含并行连结的网络 GoogLeNet

在GoogleNet出现值前,流行的网络结构使用的卷积核从1×1到11×11,卷积核的选择并没有太多的原因。GoogLeNet的提出,说明有时候使用多个不同大小的卷积核组合是有利的。

import torch
from torch import nn
from torch.nn import functional as F

1. Inception块

Inception块是 GoogLeNet 的基本组成单元。Inception 块由四条并行的路径组成,每个路径使用不同大小的卷积核:

路径1:使用 1×1 卷积层;

路径2:先对输出执行 1×1 卷积层,来减少通道数,降低模型复杂性,然后接 3×3 卷积层;

路径3:先对输出执行 1×1 卷积层,然后接 5×5 卷积层;

路径4:使用 3×3 最大汇聚层,然后使用 1×1 卷积层;

在各自路径中使用合适的 padding ,使得各个路径的输出拥有相同的高和宽,然后将每条路径的输出在通道维度上做连结,作为 Inception 块的最终输出.

class Inception(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Inception, self).__init__()
        # 路径1
        c1, c2, c3, c4 = out_channels
        self.route1_1 = nn.Conv2d(in_channels, c1, kernel_size=1)
        # 路径2
        self.route2_1 = nn.Conv2d(in_channels, c2[0], kernel_size=1)
        self.route2_2 = nn.Conv2d(c2[0], c2[1], kernel_size=3, padding=1)
        # 路径3
        self.route3_1 = nn.Conv2d(in_channels, c3[0], kernel_size=1)
        self.route3_2 = nn.Conv2d(c3[0], c3[1], kernel_size=5, padding=2)
        # 路径4
        self.route4_1 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        self.route4_2 = nn.Conv2d(in_channels, c4, kernel_size=1)
    def forward(self, x):
        x1 = F.relu(self.route1_1(x))
        x2 = F.relu(self.route2_2(F.relu(self.route2_1(x))))
        x3 = F.relu(self.route3_2(F.relu(self.route3_1(x))))
        x4 = F.relu(self.route4_2(self.route4_1(x)))
        return torch.cat((x1, x2, x3, x4), dim=1) 

2. 构造 GoogLeNet 网络

顺序定义 GoogLeNet 的模块。

第一个模块,顺序使用三个卷积层。

# 模型的第一个模块
b1 = nn.Sequential(
    nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3,),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
    nn.Conv2d(64, 64, kernel_size=1),
    nn.ReLU(),
    nn.Conv2d(64, 192, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
                   )

第二个模块,使用两个Inception模块。

# Inception组成的第二个模块
b2 = nn.Sequential(
    Inception(192, (64, (96, 128), (16, 32), 32)),
    Inception(256, (128, (128, 192), (32, 96), 64)),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
                    )

第三个模块,串联五个Inception模块。

# Inception组成的第三个模块
b3 = nn.Sequential(
    Inception(480, (192, (96, 208), (16, 48), 64)),
    Inception(512, (160, (112, 224), (24, 64), 64)),
    Inception(512, (128, (128, 256), (24, 64), 64)),
    Inception(512, (112, (144, 288), (32, 64), 64)),
    Inception(528, (256, (160, 320), (32, 128), 128)),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
                    )

第四个模块,传来两个Inception模块。

GoogLeNet使用 avg pooling layer 代替了 fully-connected layer。一方面降低了维度,另一方面也可以视为对低层特征的组合。

# Inception组成的第四个模块
b4 = nn.Sequential(
    Inception(832, (256, (160, 320), (32, 128), 128)),
    Inception(832, (384, (192, 384), (48, 128), 128)),
    nn.AdaptiveAvgPool2d((1, 1)),
    nn.Flatten()
                    )
net = nn.Sequential(b1, b2, b3, b4, nn.Linear(1024, 10))
x = torch.randn(1, 1, 96, 96)
for layer in net:
    x = layer(x)
    print(layer.__class__.__name__, "output shape: ", x.shape)

输出:

Sequential output shape:  torch.Size([1, 192, 28, 28])
Sequential output shape:  torch.Size([1, 480, 14, 14])
Sequential output shape:  torch.Size([1, 832, 7, 7])
Sequential output shape:  torch.Size([1, 1024])
Linear output shape:  torch.Size([1, 10])

3. FashionMNIST训练测试

def load_datasets_Cifar10(batch_size, resize=None):
    trans = [transforms.ToTensor()]
    if resize:
        transform = trans.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(trans)
    train_data = torchvision.datasets.CIFAR10(root="../data", train=True, transform=trans, download=True)
    test_data = torchvision.datasets.CIFAR10(root="../data", train=False, transform=trans, download=True)
    print("Cifar10 下载完成...")
    return (torch.utils.data.DataLoader(train_data, batch_size, shuffle=True),
            torch.utils.data.DataLoader(test_data, batch_size, shuffle=False))
def load_datasets_FashionMNIST(batch_size, resize=None):
    trans = [transforms.ToTensor()]
    if resize:
        transform = trans.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(trans)
    train_data = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)
    test_data = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)
    print("FashionMNIST 下载完成...")
    return (torch.utils.data.DataLoader(train_data, batch_size, shuffle=True),
            torch.utils.data.DataLoader(test_data, batch_size, shuffle=False))
def load_datasets(dataset, batch_size, resize):
    if dataset == "Cifar10":
        return load_datasets_Cifar10(batch_size, resize=resize)
    else:
        return load_datasets_FashionMNIST(batch_size, resize=resize)
train_iter, test_iter = load_datasets("", 128, 96) # Cifar10

训练结果:

到此这篇关于PyTorch详解经典网络种含并行连结的网络GoogLeNet实现流程的文章就介绍到这了,更多相关PyTorch GoogLeNet内容请搜索Devmax以前的文章或继续浏览下面的相关文章希望大家以后多多支持Devmax!

PyTorch详解经典网络种含并行连结的网络GoogLeNet实现流程的更多相关文章

  1. Python使用pytorch动手实现LSTM模块

    这篇文章主要介绍了Python使用pytorch动手实现LSTM模块,LSTM是RNN中一个较为流行的网络模块。主要包括输入,输入门,输出门,遗忘门,激活函数,全连接层(Cell)和输出

  2. Pytorch搭建yolo3目标检测平台实现源码

    这篇文章主要为大家介绍了Pytorch搭建yolo3目标检测平台实现源码,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪

  3. PyTorch搭建双向LSTM实现时间序列负荷预测

    这篇文章主要为大家介绍了PyTorch搭建双向LSTM实现时间序列负荷预测,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪

  4. pytorch使用nn.Moudle实现逻辑回归

    这篇文章主要为大家详细介绍了pytorch使用nn.Moudle实现逻辑回归,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下

  5. pytorch加载自己的图片数据集的2种方法详解

    数据预处理在解决深度学习问题的过程中,往往需要花费大量的时间和精力,下面这篇文章主要给大家介绍了关于pytorch加载自己的图片数据集的2种方法,文中通过示例代码介绍的非常详细,需要的朋友可以参考下

  6. PyTorch实现手写数字的识别入门小白教程

    这篇文章主要介绍了python实现手写数字识别,非常适合小白入门学习,本文通过实例图文相结合给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下

  7. pytorch人工智能之torch.gather算子用法示例

    这篇文章主要介绍了pytorch人工智能之torch.gather算子用法示例,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪

  8. Pytorch深度学习addmm()和addmm_()函数用法解析

    这篇文章主要为大家介绍了Pytorch中addmm()和addmm_()函数用法解析,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪

  9. 基于Pytorch实现逻辑回归

    这篇文章主要为大家详细介绍了基于Pytorch实现逻辑回归,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下

  10. pytorch关于Tensor的数据类型说明

    这篇文章主要介绍了pytorch关于Tensor的数据类型说明,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

随机推荐

  1. 10 个Python中Pip的使用技巧分享

    众所周知,pip 可以安装、更新、卸载 Python 的第三方库,非常方便。本文小编为大家总结了Python中Pip的使用技巧,需要的可以参考一下

  2. python数学建模之三大模型与十大常用算法详情

    这篇文章主要介绍了python数学建模之三大模型与十大常用算法详情,文章围绕主题展开详细的内容介绍,具有一定的参考价值,感想取得小伙伴可以参考一下

  3. Python爬取奶茶店数据分析哪家最好喝以及性价比

    这篇文章主要介绍了用Python告诉你奶茶哪家最好喝性价比最高,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习吧

  4. 使用pyinstaller打包.exe文件的详细教程

    PyInstaller是一个跨平台的Python应用打包工具,能够把 Python 脚本及其所在的 Python 解释器打包成可执行文件,下面这篇文章主要给大家介绍了关于使用pyinstaller打包.exe文件的相关资料,需要的朋友可以参考下

  5. 基于Python实现射击小游戏的制作

    这篇文章主要介绍了如何利用Python制作一个自己专属的第一人称射击小游戏,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起动手试一试

  6. Python list append方法之给列表追加元素

    这篇文章主要介绍了Python list append方法如何给列表追加元素,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

  7. Pytest+Request+Allure+Jenkins实现接口自动化

    这篇文章介绍了Pytest+Request+Allure+Jenkins实现接口自动化的方法,文中通过示例代码介绍的非常详细。对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下

  8. 利用python实现简单的情感分析实例教程

    商品评论挖掘、电影推荐、股市预测……情感分析大有用武之地,下面这篇文章主要给大家介绍了关于利用python实现简单的情感分析的相关资料,文中通过示例代码介绍的非常详细,需要的朋友可以参考下

  9. 利用Python上传日志并监控告警的方法详解

    这篇文章将详细为大家介绍如何通过阿里云日志服务搭建一套通过Python上传日志、配置日志告警的监控服务,感兴趣的小伙伴可以了解一下

  10. Pycharm中运行程序在Python console中执行,不是直接Run问题

    这篇文章主要介绍了Pycharm中运行程序在Python console中执行,不是直接Run问题,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

返回
顶部