pytorch中的顺序容器——torch.nn.Sequential

1.torch.nn.Sequential概要

pytorch官网对torch.nn.Sequential的描述如下。

使用方式:

# 写法一
net = nn.Sequential(
    nn.Linear(num_inputs, 1)
    # 此处还可以传入其他层
    )

# 写法二
net = nn.Sequential()
net.add_module('linear', nn.Linear(num_inputs, 1))
# net.add_module ......

# 写法三
from collections import OrderedDict
net = nn.Sequential(OrderedDict([
          ('linear', nn.Linear(num_inputs, 1))
          # ......
        ]))

方式一:
这是一个有顺序的容器,将特定神经网络模块按照在传入构造器的顺序依次被添加到计算图中执行。
方式二:
也可以将以特定神经网络模块为元素的有序字典(OrderedDict)为参数传入。
方式三:
也可以利用add_module函数将特定的神经网络模块插入到计算图中。add_module函数是神经网络模块的基础类(torch.nn.Module)中的方法,如下描述所示用于将子模块添加到现有模块中。

2.Sequential源码分析

先看一下初始化函数__init__,在初始化函数中,首先是if条件判断,如果传入的参数为1个,并且类型为OrderedDict,通过字典索引的方式利用add_module函数将子模块添加到现有模块中,否则,通过for循环遍历参数,将所有的子模块添加到现有中。

由于每一个神经网络模块都继承于nn.Module,因此都会实现__call__forward函数,所以forward函数中通过for循环依次调用添加到现有模块中的子模块,最后输出经过所有神经网络层的结果。

参考文献:
https://blog.csdn.net/dss_dssssd/article/details/82980222

发布了233 篇原创文章 · 获赞 187 · 访问量 40万+

猜你喜欢

转载自blog.csdn.net/qiu931110/article/details/104254752
今日推荐