[Translation] class torch.nn.ModuleList(modules=None)

Reference: class torch.nn.ModuleList(modules=None)


Original documentation:

ModuleList
class torch.nn.ModuleList(modules=None)
    Holds submodules in a list.
    ModuleList can be indexed like a regular Python list, but modules it contains
    are properly registered, and will be visible by all Module methods.

    Parameters

        modules (iterable, optional) – an iterable of modules to add

    Example:

    class MyModule(nn.Module):
        def __init__(self):
            super(MyModule, self).__init__()
            self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])

        def forward(self, x):
            # ModuleList can act as an iterable, or be indexed using ints
            for i, l in enumerate(self.linears):
                x = self.linears[i // 2](x) + l(x)
            return x
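The registration claim above can be checked directly: since every `nn.Linear` in the `ModuleList` is registered, the module exposes all of their parameters. A minimal sketch (reusing the `MyModule` class from the example, with an assumed batch size of 2):

```python
import torch
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        # Ten Linear layers, registered through nn.ModuleList
        self.linears = nn.ModuleList([nn.Linear(10, 10) for _ in range(10)])

    def forward(self, x):
        # ModuleList can act as an iterable, or be indexed using ints
        for i, l in enumerate(self.linears):
            x = self.linears[i // 2](x) + l(x)
        return x

model = MyModule()
# Each Linear contributes a weight and a bias: 10 layers -> 20 parameter tensors
n_params = len(list(model.parameters()))
out = model(torch.randn(2, 10))
```

Because every layer maps 10 features to 10 features, the output keeps the input's shape.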

    append(module)
        Appends a given module to the end of the list.
        Parameters
            module (nn.Module) – module to append

    extend(modules)
        Appends modules from a Python iterable to the end of the list.
        Parameters
            modules (iterable) – iterable of modules to append

    insert(index, module)
        Insert a given module before a given index in the list.
        Parameters
                index (int) – index to insert.
                module (nn.Module) – module to insert
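The three methods above mirror their `list` counterparts. A minimal sketch combining them (the layer sizes here are arbitrary, chosen only to make the positions distinguishable):

```python
import torch.nn as nn

ml = nn.ModuleList([nn.Linear(4, 4)])
ml.append(nn.Linear(4, 4))                      # list now holds 2 modules
ml.extend([nn.Linear(4, 4), nn.Linear(4, 4)])   # list now holds 4 modules
ml.insert(0, nn.Linear(8, 8))                   # inserted at the front

count = len(ml)            # 5 modules in total
first_in = ml[0].in_features  # the inserted Linear(8, 8) is now first
```

As with a plain list, `insert(0, ...)` shifts every existing entry one position to the right.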

Interactive session:

Microsoft Windows [Version 10.0.18363.1316]
(c) 2019 Microsoft Corporation. All rights reserved.

C:\Users\chenxuqi>conda activate ssd4pytorch1_2_0

(ssd4pytorch1_2_0) C:\Users\chenxuqi>python
Python 3.7.7 (default, May  6 2020, 11:45:54) [MSC v.1916 64 bit (AMD64)] :: Anaconda, Inc. on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> torch.manual_seed(seed=20200910)
<torch._C.Generator object at 0x0000022D335FD330>
>>>
>>> import torch.nn as nn
>>> linears = nn.ModuleList([nn.Linear(10+i, 10+i*2) for i in range(10)])
>>> linears
ModuleList(
  (0): Linear(in_features=10, out_features=10, bias=True)
  (1): Linear(in_features=11, out_features=12, bias=True)
  (2): Linear(in_features=12, out_features=14, bias=True)
  (3): Linear(in_features=13, out_features=16, bias=True)
  (4): Linear(in_features=14, out_features=18, bias=True)
  (5): Linear(in_features=15, out_features=20, bias=True)
  (6): Linear(in_features=16, out_features=22, bias=True)
  (7): Linear(in_features=17, out_features=24, bias=True)
  (8): Linear(in_features=18, out_features=26, bias=True)
  (9): Linear(in_features=19, out_features=28, bias=True)
)
>>> type(linears)
<class 'torch.nn.modules.container.ModuleList'>
>>> linears[0]
Linear(in_features=10, out_features=10, bias=True)
>>> linears[9]
Linear(in_features=19, out_features=28, bias=True)
>>> linears[10]
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "D:\Anaconda3\envs\ssd4pytorch1_2_0\lib\site-packages\torch\nn\modules\container.py", line 138, in __getitem__
    return self._modules[self._get_abs_string_index(idx)]
  File "D:\Anaconda3\envs\ssd4pytorch1_2_0\lib\site-packages\torch\nn\modules\container.py", line 129, in _get_abs_string_index
    raise IndexError('index {} is out of range'.format(idx))
IndexError: index 10 is out of range
>>> linears[-1]
Linear(in_features=19, out_features=28, bias=True)
>>> linears[-10]
Linear(in_features=10, out_features=10, bias=True)
>>> linears[-11]
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "D:\Anaconda3\envs\ssd4pytorch1_2_0\lib\site-packages\torch\nn\modules\container.py", line 138, in __getitem__
    return self._modules[self._get_abs_string_index(idx)]
  File "D:\Anaconda3\envs\ssd4pytorch1_2_0\lib\site-packages\torch\nn\modules\container.py", line 129, in _get_abs_string_index
    raise IndexError('index {} is out of range'.format(idx))
IndexError: index -11 is out of range
>>>
>>> linears.append(nn.Linear(4, 3))
ModuleList(
  (0): Linear(in_features=10, out_features=10, bias=True)
  (1): Linear(in_features=11, out_features=12, bias=True)
  (2): Linear(in_features=12, out_features=14, bias=True)
  (3): Linear(in_features=13, out_features=16, bias=True)
  (4): Linear(in_features=14, out_features=18, bias=True)
  (5): Linear(in_features=15, out_features=20, bias=True)
  (6): Linear(in_features=16, out_features=22, bias=True)
  (7): Linear(in_features=17, out_features=24, bias=True)
  (8): Linear(in_features=18, out_features=26, bias=True)
  (9): Linear(in_features=19, out_features=28, bias=True)
  (10): Linear(in_features=4, out_features=3, bias=True)
)
>>> linears
ModuleList(
  (0): Linear(in_features=10, out_features=10, bias=True)
  (1): Linear(in_features=11, out_features=12, bias=True)
  (2): Linear(in_features=12, out_features=14, bias=True)
  (3): Linear(in_features=13, out_features=16, bias=True)
  (4): Linear(in_features=14, out_features=18, bias=True)
  (5): Linear(in_features=15, out_features=20, bias=True)
  (6): Linear(in_features=16, out_features=22, bias=True)
  (7): Linear(in_features=17, out_features=24, bias=True)
  (8): Linear(in_features=18, out_features=26, bias=True)
  (9): Linear(in_features=19, out_features=28, bias=True)
  (10): Linear(in_features=4, out_features=3, bias=True)
)
>>> module = nn.Linear(2, 2)
>>> linears[10] = module
>>> linears
ModuleList(
  (0): Linear(in_features=10, out_features=10, bias=True)
  (1): Linear(in_features=11, out_features=12, bias=True)
  (2): Linear(in_features=12, out_features=14, bias=True)
  (3): Linear(in_features=13, out_features=16, bias=True)
  (4): Linear(in_features=14, out_features=18, bias=True)
  (5): Linear(in_features=15, out_features=20, bias=True)
  (6): Linear(in_features=16, out_features=22, bias=True)
  (7): Linear(in_features=17, out_features=24, bias=True)
  (8): Linear(in_features=18, out_features=26, bias=True)
  (9): Linear(in_features=19, out_features=28, bias=True)
  (10): Linear(in_features=2, out_features=2, bias=True)
)
>>> modules = [nn.Linear(30+i, 30+i*2) for i in range(5)]
>>> modules
[Linear(in_features=30, out_features=30, bias=True), Linear(in_features=31, out_features=32, bias=True), Linear(in_features=32, out_features=34, bias=True), Linear(in_features=33, out_features=36, bias=True), Linear(in_features=34, out_features=38, bias=True)]
>>> linears.extend(modules)
ModuleList(
  (0): Linear(in_features=10, out_features=10, bias=True)
  (1): Linear(in_features=11, out_features=12, bias=True)
  (2): Linear(in_features=12, out_features=14, bias=True)
  (3): Linear(in_features=13, out_features=16, bias=True)
  (4): Linear(in_features=14, out_features=18, bias=True)
  (5): Linear(in_features=15, out_features=20, bias=True)
  (6): Linear(in_features=16, out_features=22, bias=True)
  (7): Linear(in_features=17, out_features=24, bias=True)
  (8): Linear(in_features=18, out_features=26, bias=True)
  (9): Linear(in_features=19, out_features=28, bias=True)
  (10): Linear(in_features=2, out_features=2, bias=True)
  (11): Linear(in_features=30, out_features=30, bias=True)
  (12): Linear(in_features=31, out_features=32, bias=True)
  (13): Linear(in_features=32, out_features=34, bias=True)
  (14): Linear(in_features=33, out_features=36, bias=True)
  (15): Linear(in_features=34, out_features=38, bias=True)
)
>>> linears
ModuleList(
  (0): Linear(in_features=10, out_features=10, bias=True)
  (1): Linear(in_features=11, out_features=12, bias=True)
  (2): Linear(in_features=12, out_features=14, bias=True)
  (3): Linear(in_features=13, out_features=16, bias=True)
  (4): Linear(in_features=14, out_features=18, bias=True)
  (5): Linear(in_features=15, out_features=20, bias=True)
  (6): Linear(in_features=16, out_features=22, bias=True)
  (7): Linear(in_features=17, out_features=24, bias=True)
  (8): Linear(in_features=18, out_features=26, bias=True)
  (9): Linear(in_features=19, out_features=28, bias=True)
  (10): Linear(in_features=2, out_features=2, bias=True)
  (11): Linear(in_features=30, out_features=30, bias=True)
  (12): Linear(in_features=31, out_features=32, bias=True)
  (13): Linear(in_features=32, out_features=34, bias=True)
  (14): Linear(in_features=33, out_features=36, bias=True)
  (15): Linear(in_features=34, out_features=38, bias=True)
)
>>> index = 7
>>> module = nn.Linear(99, 88)
>>> index
7
>>> module
Linear(in_features=99, out_features=88, bias=True)
>>> # linears.insert(index, module)
>>> len(linears)
16
>>> linears.insert(index, module)
>>> len(linears)
17
>>> linears
ModuleList(
  (0): Linear(in_features=10, out_features=10, bias=True)
  (1): Linear(in_features=11, out_features=12, bias=True)
  (2): Linear(in_features=12, out_features=14, bias=True)
  (3): Linear(in_features=13, out_features=16, bias=True)
  (4): Linear(in_features=14, out_features=18, bias=True)
  (5): Linear(in_features=15, out_features=20, bias=True)
  (6): Linear(in_features=16, out_features=22, bias=True)
  (7): Linear(in_features=99, out_features=88, bias=True)
  (8): Linear(in_features=17, out_features=24, bias=True)
  (9): Linear(in_features=18, out_features=26, bias=True)
  (10): Linear(in_features=19, out_features=28, bias=True)
  (11): Linear(in_features=2, out_features=2, bias=True)
  (12): Linear(in_features=30, out_features=30, bias=True)
  (13): Linear(in_features=31, out_features=32, bias=True)
  (14): Linear(in_features=32, out_features=34, bias=True)
  (15): Linear(in_features=33, out_features=36, bias=True)
  (16): Linear(in_features=34, out_features=38, bias=True)
)
>>>
>>> for lin in linears:
...     print(lin.weight.shape)
...
torch.Size([10, 10])
torch.Size([12, 11])
torch.Size([14, 12])
torch.Size([16, 13])
torch.Size([18, 14])
torch.Size([20, 15])
torch.Size([22, 16])
torch.Size([88, 99])
torch.Size([24, 17])
torch.Size([26, 18])
torch.Size([28, 19])
torch.Size([2, 2])
torch.Size([30, 30])
torch.Size([32, 31])
torch.Size([34, 32])
torch.Size([36, 33])
torch.Size([38, 34])
>>>
>>>

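The indexing behavior demonstrated interactively above can be condensed into a script. This sketch reproduces the key observations: list-style indexing, negative indices, and the `IndexError` raised for out-of-range positions:

```python
import torch
import torch.nn as nn

torch.manual_seed(20200910)
linears = nn.ModuleList([nn.Linear(10 + i, 10 + i * 2) for i in range(10)])

# Indexing works like a regular Python list, including negative indices
first = linears[0]       # Linear(in_features=10, out_features=10)
last = linears[-1]       # Linear(in_features=19, out_features=28)

# Out-of-range indices raise IndexError, just as in the session above
try:
    linears[10]
    raised = False
except IndexError:
    raised = True
```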
Code experiment: a plain Python list is not properly registered

import torch
import torch.nn as nn

torch.manual_seed(seed=20200910)

class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linearsModuleList_in = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])
        self.linearsPythonList_in = [nn.Linear(30, 30) for i in range(10)]

    def forward(self, x):
        pass

print('CUDA (GPU) available:', torch.cuda.is_available())
print('torch version:', torch.__version__)

model = Model()  # .cuda()
print('A plain Python list is not properly registered'.center(100, "-"))
print("Print the model".center(100, "-"))
for name, param in model.named_parameters(prefix='', recurse=True):
    print('parameter name:', name, 'parameter shape:', param.shape)

model.linearsModuleList_out = nn.ModuleList([nn.Linear(20, 20) for i in range(10)])
model.linearsPythonList_out = [nn.Linear(40, 40) for i in range(10)]
print('A plain Python list is not properly registered'.center(100, "-"))
print("Print the model".center(100, "-"))
for name, param in model.named_parameters(prefix='', recurse=True):
    print('parameter name:', name, 'parameter shape:', param.shape)

Console output:

Windows PowerShell
Copyright (C) Microsoft Corporation. All rights reserved.

Try the new cross-platform PowerShell https://aka.ms/pscore6

Loading personal and system profiles took 937 ms.
(base) PS C:\Users\chenxuqi\Desktop\News4cxq\test4cxq> conda activate ssd4pytorch1_2_0
(ssd4pytorch1_2_0) PS C:\Users\chenxuqi\Desktop\News4cxq\test4cxq>  & 'D:\Anaconda3\envs\ssd4pytorch1_2_0\python.exe' 'c:\Users\chenxuqi\.vscode\extensions\ms-python.python-2020.12.424452561\pythonFiles\lib\python\debugpy\launcher' '54026' '--' 'c:\Users\chenxuqi\Desktop\News4cxq\test4cxq\test2.py'
CUDA (GPU) available: True
torch version: 1.2.0+cu92
---------------------------A plain Python list is not properly registered---------------------------
------------------------------------------Print the model-------------------------------------------
parameter name: linearsModuleList_in.0.weight parameter shape: torch.Size([10, 10])
parameter name: linearsModuleList_in.0.bias parameter shape: torch.Size([10])
parameter name: linearsModuleList_in.1.weight parameter shape: torch.Size([10, 10])
parameter name: linearsModuleList_in.1.bias parameter shape: torch.Size([10])
parameter name: linearsModuleList_in.2.weight parameter shape: torch.Size([10, 10])
parameter name: linearsModuleList_in.2.bias parameter shape: torch.Size([10])
parameter name: linearsModuleList_in.3.weight parameter shape: torch.Size([10, 10])
parameter name: linearsModuleList_in.3.bias parameter shape: torch.Size([10])
parameter name: linearsModuleList_in.4.weight parameter shape: torch.Size([10, 10])
parameter name: linearsModuleList_in.4.bias parameter shape: torch.Size([10])
parameter name: linearsModuleList_in.5.weight parameter shape: torch.Size([10, 10])
parameter name: linearsModuleList_in.5.bias parameter shape: torch.Size([10])
parameter name: linearsModuleList_in.6.weight parameter shape: torch.Size([10, 10])
parameter name: linearsModuleList_in.6.bias parameter shape: torch.Size([10])
parameter name: linearsModuleList_in.7.weight parameter shape: torch.Size([10, 10])
parameter name: linearsModuleList_in.7.bias parameter shape: torch.Size([10])
parameter name: linearsModuleList_in.8.weight parameter shape: torch.Size([10, 10])
parameter name: linearsModuleList_in.8.bias parameter shape: torch.Size([10])
parameter name: linearsModuleList_in.9.weight parameter shape: torch.Size([10, 10])
parameter name: linearsModuleList_in.9.bias parameter shape: torch.Size([10])
---------------------------A plain Python list is not properly registered---------------------------
------------------------------------------Print the model-------------------------------------------
parameter name: linearsModuleList_in.0.weight parameter shape: torch.Size([10, 10])
parameter name: linearsModuleList_in.0.bias parameter shape: torch.Size([10])
parameter name: linearsModuleList_in.1.weight parameter shape: torch.Size([10, 10])
parameter name: linearsModuleList_in.1.bias parameter shape: torch.Size([10])
parameter name: linearsModuleList_in.2.weight parameter shape: torch.Size([10, 10])
parameter name: linearsModuleList_in.2.bias parameter shape: torch.Size([10])
parameter name: linearsModuleList_in.3.weight parameter shape: torch.Size([10, 10])
parameter name: linearsModuleList_in.3.bias parameter shape: torch.Size([10])
parameter name: linearsModuleList_in.4.weight parameter shape: torch.Size([10, 10])
parameter name: linearsModuleList_in.4.bias parameter shape: torch.Size([10])
parameter name: linearsModuleList_in.5.weight parameter shape: torch.Size([10, 10])
parameter name: linearsModuleList_in.5.bias parameter shape: torch.Size([10])
parameter name: linearsModuleList_in.6.weight parameter shape: torch.Size([10, 10])
parameter name: linearsModuleList_in.6.bias parameter shape: torch.Size([10])
parameter name: linearsModuleList_in.7.weight parameter shape: torch.Size([10, 10])
parameter name: linearsModuleList_in.7.bias parameter shape: torch.Size([10])
parameter name: linearsModuleList_in.8.weight parameter shape: torch.Size([10, 10])
parameter name: linearsModuleList_in.8.bias parameter shape: torch.Size([10])
parameter name: linearsModuleList_in.9.weight parameter shape: torch.Size([10, 10])
parameter name: linearsModuleList_in.9.bias parameter shape: torch.Size([10])
parameter name: linearsModuleList_out.0.weight parameter shape: torch.Size([20, 20])
parameter name: linearsModuleList_out.0.bias parameter shape: torch.Size([20])
parameter name: linearsModuleList_out.1.weight parameter shape: torch.Size([20, 20])
parameter name: linearsModuleList_out.1.bias parameter shape: torch.Size([20])
parameter name: linearsModuleList_out.2.weight parameter shape: torch.Size([20, 20])
parameter name: linearsModuleList_out.2.bias parameter shape: torch.Size([20])
parameter name: linearsModuleList_out.3.weight parameter shape: torch.Size([20, 20])
parameter name: linearsModuleList_out.3.bias parameter shape: torch.Size([20])
parameter name: linearsModuleList_out.4.weight parameter shape: torch.Size([20, 20])
parameter name: linearsModuleList_out.4.bias parameter shape: torch.Size([20])
parameter name: linearsModuleList_out.5.weight parameter shape: torch.Size([20, 20])
parameter name: linearsModuleList_out.5.bias parameter shape: torch.Size([20])
parameter name: linearsModuleList_out.6.weight parameter shape: torch.Size([20, 20])
parameter name: linearsModuleList_out.6.bias parameter shape: torch.Size([20])
parameter name: linearsModuleList_out.7.weight parameter shape: torch.Size([20, 20])
parameter name: linearsModuleList_out.7.bias parameter shape: torch.Size([20])
parameter name: linearsModuleList_out.8.weight parameter shape: torch.Size([20, 20])
parameter name: linearsModuleList_out.8.bias parameter shape: torch.Size([20])
parameter name: linearsModuleList_out.9.weight parameter shape: torch.Size([20, 20])
parameter name: linearsModuleList_out.9.bias parameter shape: torch.Size([20])
(ssd4pytorch1_2_0) PS C:\Users\chenxuqi\Desktop\News4cxq\test4cxq> 
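The output confirms that neither `linearsPythonList_in` nor `linearsPythonList_out` contributes any parameters. The fix is simply to wrap the plain list in `nn.ModuleList`; a minimal sketch based on the experiment above:

```python
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # Wrapping the plain list in nn.ModuleList registers every submodule,
        # so their parameters appear in named_parameters() and state_dict()
        self.linears = nn.ModuleList([nn.Linear(30, 30) for _ in range(10)])

model = Model()
# 10 Linear layers -> 10 weights + 10 biases = 20 registered parameter tensors
registered = len(list(model.named_parameters()))
```

With the plain list, `registered` would be 0; registration is what lets `model.parameters()`, `.cuda()`, and `state_dict()` reach the submodules.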


Reposted from blog.csdn.net/m0_46653437/article/details/112760242