Pytorch 用cfg构建网络结构

动机:使用cfg构建网络是很方便的,有必要了解一下
  • 本篇是通过参考Pytorch源码,来进行相关学习的。一直以来觉得这种方法比较抽象,但其实看一看,还是很值得的,毕竟我们生活在21世纪,接下来做研究可以借助强大的GPU算力(啥时候有钱了)构造更多深层网络,只用一个列表写出来结构显然会比一层一层手动写出来灵活性更高,修改也比较省时省力。

  • 通过阅读本文,可以学到用for循环迭代cfg列表的形式来构建网络

  • 首先给出我参考源码写的


print("Testing addLayer...")

cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M']


batch_norm = 1
layers = []
in_channel = 3 # VGG第一层输入是RGB三通道的图像
for i in cfg:
    if i == "M": # 当遇到M
        layers += [nn.MaxPool2d(kernel_size=2, stride=2)]

    else: # 如果遇到的不是M 那么就创建 多少入多少出的conv2d
        conv2d = nn.Conv2d(in_channel, i, kernel_size=3, padding=1)
        if batch_norm:
            layers += [conv2d, nn.BatchNorm2d(i), nn.ReLU(inplace=True)]
        else:
            layers += [conv2d, nn.ReLU(inplace=True)]
        in_channel = i  # VGG 里面,输入输出通道除了第一层以外都是相等的

a = nn.Sequential(* layers)

print(a)

  • 参考
  1. inplace的作用
  2. Pytorch写好的几个Network
发布了51 篇原创文章 · 获赞 1 · 访问量 3075

猜你喜欢

转载自blog.csdn.net/m0_38139098/article/details/105452659