[Translation] class torch.nn.ParameterDict(parameters=None)

Reference link: class torch.nn.ParameterDict(parameters=None)

Note: In PyTorch 1.2.0, ParameterDict.update() has a bug: passing another ParameterDict as the argument raises a ValueError. In PyTorch 1.7.1 the same call works correctly.

Original documentation:

ParameterDict

class torch.nn.ParameterDict(parameters=None)
    Holds parameters in a dictionary.
    ParameterDict can be indexed like a regular Python dictionary, but, unlike a regular
    dictionary, the parameters it contains are properly registered and will be visible to
    all Module methods.
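To make "properly registered" concrete, here is a minimal sketch assuming PyTorch 1.7 or later. The module name `Holder` and its attribute names are hypothetical. It stores one Parameter in a ParameterDict and another in a plain dict; only the first shows up in `named_parameters()` and `state_dict()`.

```python
import torch
import torch.nn as nn


class Holder(nn.Module):  # hypothetical module name, for illustration only
    def __init__(self):
        super().__init__()
        # Stored in a ParameterDict: the entry is registered with the module.
        self.registered = nn.ParameterDict({'w': nn.Parameter(torch.zeros(2, 3))})
        # Stored in a plain dict: Module bookkeeping never sees this Parameter.
        self.plain = {'v': nn.Parameter(torch.zeros(4))}


m = Holder()
param_names = [name for name, _ in m.named_parameters()]
state_keys = list(m.state_dict().keys())
print(param_names)  # ['registered.w']
print(state_keys)   # ['registered.w']
```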
    ParameterDict is an ordered dictionary that respects
        the order of insertion, and
        in update(), the order of the merged OrderedDict or another
        ParameterDict (the argument to update()).

    Note that update() with other unordered mapping types (e.g., Python’s plain dict) does
    not preserve the order of the merged mapping.
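The two ordering rules can be checked with a short sketch, assuming PyTorch 1.7 or later. The key names are arbitrary and chosen so that insertion order differs from alphabetical order. With a plain dict argument the order should not be relied on; the versions I have read sort the keys in that case.

```python
from collections import OrderedDict

import torch
import torch.nn as nn

# Insertion order is kept when building from a list of (key, Parameter) pairs.
pd = nn.ParameterDict([
    ('zeta', nn.Parameter(torch.zeros(1))),
    ('alpha', nn.Parameter(torch.zeros(1))),
])
order_after_init = list(pd.keys())
print(order_after_init)  # ['zeta', 'alpha']

# update() with an OrderedDict appends the new keys in the OrderedDict's order.
pd.update(OrderedDict([
    ('mid', nn.Parameter(torch.zeros(1))),
    ('beta', nn.Parameter(torch.zeros(1))),
]))
order_after_update = list(pd.keys())
print(order_after_update)  # ['zeta', 'alpha', 'mid', 'beta']
```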

    Parameters

        parameters (iterable, optional) – a mapping (dictionary) of (string : Parameter) or
        an iterable of key-value pairs of type (string, Parameter)
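A sketch of the accepted argument forms, assuming PyTorch 1.7 or later. The variable names are illustrative. Any iterable of (string, Parameter) pairs works, including a lazy `zip` object. An element that is not a pair is rejected with a ValueError; the message is the same one that appears in the 1.2.0 traceback further down.

```python
import torch
import torch.nn as nn

names = ['left', 'right']
tensors = [nn.Parameter(torch.randn(5, 10)) for _ in names]

# 1) A mapping from string to Parameter.
from_mapping = nn.ParameterDict({'left': tensors[0], 'right': tensors[1]})
# 2) Any iterable of (string, Parameter) pairs, here a lazy zip object.
from_pairs = nn.ParameterDict(zip(names, tensors))
# 3) No argument: an empty ParameterDict to be filled later.
empty = nn.ParameterDict()

print(len(from_mapping), len(from_pairs), len(empty))  # 2 2 0

# An element that is not a (key, value) pair is rejected.
raised = False
try:
    nn.ParameterDict([('only_key',)])
except ValueError as err:
    raised = True
    print('ValueError:', err)
```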

    Example:

    import torch
    import torch.nn as nn

    class MyModule(nn.Module):
        def __init__(self):
            super(MyModule, self).__init__()
            self.params = nn.ParameterDict({
                'left': nn.Parameter(torch.randn(5, 10)),
                'right': nn.Parameter(torch.randn(5, 10))
            })

        def forward(self, x, choice):
            x = self.params[choice].mm(x)
            return x
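Here is a usage sketch for the example above, assuming PyTorch 1.7 or later. The class is repeated so the snippet runs on its own. The input shape (10, 3) is an arbitrary choice. Each weight is (5, 10), so `mm` needs an input with 10 rows. Only the parameter chosen by `choice` takes part in the computation, so only that parameter receives a gradient.

```python
import torch
import torch.nn as nn


class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.params = nn.ParameterDict({
            'left': nn.Parameter(torch.randn(5, 10)),
            'right': nn.Parameter(torch.randn(5, 10)),
        })

    def forward(self, x, choice):
        # Pick one (5, 10) weight by key: (5, 10) @ (10, N) -> (5, N).
        return self.params[choice].mm(x)


module = MyModule()
x = torch.randn(10, 3)
out_left = module(x, 'left')
print(out_left.shape)  # torch.Size([5, 3])

# Gradients flow only into the parameter that was selected.
out_left.sum().backward()
print(module.params['left'].grad is not None)  # True
print(module.params['right'].grad is None)     # True
```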

    clear()
        Remove all items from the ParameterDict.
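A short sketch of clear(), assuming PyTorch 1.7 or later. A bare `nn.Module()` is used as a hypothetical container. Clearing the dictionary also removes its entries from the owning module's parameters.

```python
import torch
import torch.nn as nn

model = nn.Module()  # bare container module, for illustration only
model.params = nn.ParameterDict([
    ('a', nn.Parameter(torch.randn(2, 4))),
    ('b', nn.Parameter(torch.randn(1, 4))),
])
before = (len(model.params), len(list(model.parameters())))
print(before)  # (2, 2)

model.params.clear()  # removes every entry, and with it the registration
after = (len(model.params), len(list(model.parameters())))
print(after)  # (0, 0)
```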

    items()
        Return an iterable of the ParameterDict key/value pairs.
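A sketch of items(), assuming PyTorch 1.7 or later. It builds a name-to-shape summary, which is a common use when inspecting a model.

```python
import torch
import torch.nn as nn

params = nn.ParameterDict([
    ('left', nn.Parameter(torch.randn(2, 4))),
    ('right', nn.Parameter(torch.randn(3, 4))),
])

# items() yields (key, Parameter) pairs in the dictionary's order.
shapes = {name: tuple(p.shape) for name, p in params.items()}
print(shapes)  # {'left': (2, 4), 'right': (3, 4)}
```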

    keys()
        Return an iterable of the ParameterDict keys.
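A sketch of keys() together with the `in` operator, assuming PyTorch 1.7 or later. Membership is tested against the keys, just as with a regular dict.

```python
import torch
import torch.nn as nn

params = nn.ParameterDict([
    ('left', nn.Parameter(torch.randn(2, 4))),
    ('right', nn.Parameter(torch.randn(2, 4))),
])

keys = list(params.keys())
print(keys)                # ['left', 'right']
print('left' in params)    # True
print('middle' in params)  # False
```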

    pop(key)
        Remove key from the ParameterDict and return its parameter.
        Parameters
            key (string) – key to pop from the ParameterDict
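A sketch of pop(), assuming PyTorch 1.7 or later. It mirrors the transcript below: the returned object is the same Parameter that was stored, and the remaining keys keep their order.

```python
import torch
import torch.nn as nn

params = nn.ParameterDict([
    ('a', nn.Parameter(torch.randn(2, 4))),
    ('b', nn.Parameter(torch.randn(1, 4))),
    ('c', nn.Parameter(torch.randn(1, 3))),
])
b_before = params['b']

popped = params.pop('b')    # removes 'b' and returns its Parameter
print(popped is b_before)   # True: the same Parameter object is handed back
print(list(params.keys()))  # ['a', 'c']
```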

    update(parameters)
        Update the ParameterDict with the key-value pairs from a mapping or an
        iterable, overwriting existing keys.
        Note:
        If parameters is an OrderedDict, a ParameterDict, or an iterable of key-value pairs,
        the order of new elements in it is preserved.
        Parameters
            parameters (iterable) – a mapping (dictionary) from string to Parameter, or
            an iterable of key-value pairs of type (string, Parameter)
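A sketch of the overwrite semantics, assuming PyTorch 1.7 or later. An existing key keeps its position but gets the new Parameter. A new key is appended at the end. This matches the 1.7.1 transcript below, where `b` becomes 2x6 but stays in second place.

```python
from collections import OrderedDict

import torch
import torch.nn as nn

params = nn.ParameterDict([
    ('a', nn.Parameter(torch.randn(2, 4))),
    ('b', nn.Parameter(torch.randn(1, 4))),
])

# Existing key 'b' is overwritten in place; new key 'd' is appended at the end.
params.update(OrderedDict([
    ('b', nn.Parameter(torch.randn(2, 6))),
    ('d', nn.Parameter(torch.randn(1, 2))),
]))
print(list(params.keys()))       # ['a', 'b', 'd']
print(tuple(params['b'].shape))  # (2, 6)
```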

    values()
        Return an iterable of the ParameterDict values.
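A sketch of values(), assuming PyTorch 1.7 or later. It counts the total number of trainable elements. The shapes are the ones used in the transcript below.

```python
import torch
import torch.nn as nn

params = nn.ParameterDict([
    ('a', nn.Parameter(torch.randn(2, 4))),
    ('b', nn.Parameter(torch.randn(1, 4))),
    ('c', nn.Parameter(torch.randn(1, 3))),
])

# values() yields the Parameter objects themselves.
total = sum(p.numel() for p in params.values())
print(total)  # 8 + 4 + 3 = 15
```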

Code experiment:

Microsoft Windows [Version 10.0.18363.1316]
(c) 2019 Microsoft Corporation. All rights reserved.

C:\Users\chenxuqi>conda activate ssd4pytorch1_2_0

(ssd4pytorch1_2_0) C:\Users\chenxuqi>python
Python 3.7.7 (default, May  6 2020, 11:45:54) [MSC v.1916 64 bit (AMD64)] :: Anaconda, Inc. on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> import torch.nn as nn
>>> torch.manual_seed(seed=20200910)
<torch._C.Generator object at 0x00000182130FD330>
>>>
>>> params = nn.ParameterDict({
...     'left': nn.Parameter(torch.randn(2, 4)),
...     'right': nn.Parameter(torch.randn(2, 4))
... })
>>>
>>> params
ParameterDict(
    (left): Parameter containing: [torch.FloatTensor of size 2x4]
    (right): Parameter containing: [torch.FloatTensor of size 2x4]
)
>>> params['left']
Parameter containing:
tensor([[ 0.2824, -0.3715,  0.9088, -1.7601],
        [-0.1806,  2.0937,  1.0406, -1.7651]], requires_grad=True)
>>> params['right']
Parameter containing:
tensor([[ 1.1216,  0.8440,  0.1783,  0.6859],
        [-1.5942, -0.2006, -0.4050, -0.5556]], requires_grad=True)
>>> for item in params.items():
...     print(item[0],'  ',item[1])
...
left    Parameter containing:
tensor([[ 0.2824, -0.3715,  0.9088, -1.7601],
        [-0.1806,  2.0937,  1.0406, -1.7651]], requires_grad=True)
right    Parameter containing:
tensor([[ 1.1216,  0.8440,  0.1783,  0.6859],
        [-1.5942, -0.2006, -0.4050, -0.5556]], requires_grad=True)
>>> for key in params.keys():
...     print(key)
...
left
right
>>> for value in params.values():
...     print(value)
...
Parameter containing:
tensor([[ 0.2824, -0.3715,  0.9088, -1.7601],
        [-0.1806,  2.0937,  1.0406, -1.7651]], requires_grad=True)
Parameter containing:
tensor([[ 1.1216,  0.8440,  0.1783,  0.6859],
        [-1.5942, -0.2006, -0.4050, -0.5556]], requires_grad=True)
>>>
>>> params
ParameterDict(
    (left): Parameter containing: [torch.FloatTensor of size 2x4]
    (right): Parameter containing: [torch.FloatTensor of size 2x4]
)
>>> params.pop('left')
Parameter containing:
tensor([[ 0.2824, -0.3715,  0.9088, -1.7601],
        [-0.1806,  2.0937,  1.0406, -1.7651]], requires_grad=True)
>>> params
ParameterDict(  (right): Parameter containing: [torch.FloatTensor of size 2x4])
>>> params.clear()
>>> params
ParameterDict()
>>>
>>> params_1 = nn.ParameterDict([
...     ('a',nn.Parameter(torch.randn(2, 4))),
...     ('b',nn.Parameter(torch.randn(1, 4))),
...     ('c',nn.Parameter(torch.randn(1, 3)))
... ])
>>> params_1
ParameterDict(
    (a): Parameter containing: [torch.FloatTensor of size 2x4]
    (b): Parameter containing: [torch.FloatTensor of size 1x4]
    (c): Parameter containing: [torch.FloatTensor of size 1x3]
)
>>> params_1.clear()
>>> params_1
ParameterDict()
>>> params_1 = nn.ParameterDict([
...     ('a',nn.Parameter(torch.randn(2, 4))),
...     ('b',nn.Parameter(torch.randn(1, 4))),
...     ('c',nn.Parameter(torch.randn(1, 3)))
... ])
>>> params_1
ParameterDict(
    (a): Parameter containing: [torch.FloatTensor of size 2x4]
    (b): Parameter containing: [torch.FloatTensor of size 1x4]
    (c): Parameter containing: [torch.FloatTensor of size 1x3]
)
>>> params_1.pop('b')
Parameter containing:
tensor([[ 1.0600, -0.4584, -0.3792,  0.1137]], requires_grad=True)
>>> params_1
ParameterDict(
    (a): Parameter containing: [torch.FloatTensor of size 2x4]
    (c): Parameter containing: [torch.FloatTensor of size 1x3]
)
>>>
>>>
>>> params_1 = nn.ParameterDict([
...     ('a',nn.Parameter(torch.randn(2, 4))),
...     ('b',nn.Parameter(torch.randn(1, 4))),
...     ('c',nn.Parameter(torch.randn(1, 3)))
... ])
>>>
>>> params_1
ParameterDict(
    (a): Parameter containing: [torch.FloatTensor of size 2x4]
    (b): Parameter containing: [torch.FloatTensor of size 1x4]
    (c): Parameter containing: [torch.FloatTensor of size 1x3]
)
>>> params_2 = nn.ParameterDict([
...     ('d',nn.Parameter(torch.randn(1, 2))),
...     ('e',nn.Parameter(torch.randn(2, 3))),
...     ('f',nn.Parameter(torch.randn(1, 5))),
...     ('b',nn.Parameter(torch.randn(2, 6))),
... ])
>>> params_1
ParameterDict(
    (a): Parameter containing: [torch.FloatTensor of size 2x4]
    (b): Parameter containing: [torch.FloatTensor of size 1x4]
    (c): Parameter containing: [torch.FloatTensor of size 1x3]
)
>>> params_2
ParameterDict(
    (d): Parameter containing: [torch.FloatTensor of size 1x2]
    (e): Parameter containing: [torch.FloatTensor of size 2x3]
    (f): Parameter containing: [torch.FloatTensor of size 1x5]
    (b): Parameter containing: [torch.FloatTensor of size 2x6]
)
>>> params_1.update(params_2)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "D:\Anaconda3\envs\ssd4pytorch1_2_0\lib\site-packages\torch\nn\modules\container.py", line 558, in update
    "; 2 is required")
ValueError: ParameterDict update sequence element #0 has length 1; 2 is required
>>> ^Z


(ssd4pytorch1_2_0) C:\Users\chenxuqi>conda activate pytorch_1.7.1_cu102

(pytorch_1.7.1_cu102) C:\Users\chenxuqi>python
Python 3.7.9 (default, Aug 31 2020, 17:10:11) [MSC v.1916 64 bit (AMD64)] :: Anaconda, Inc. on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> import torch.nn as nn
>>> torch.manual_seed(seed=20200910)
<torch._C.Generator object at 0x0000018EB4EE7870>
>>> params_1 = nn.ParameterDict([
...     ('a',nn.Parameter(torch.randn(2, 4))),
...     ('b',nn.Parameter(torch.randn(1, 4))),
...     ('c',nn.Parameter(torch.randn(1, 3)))
... ])
>>>
>>> params_2 = nn.ParameterDict([
...     ('d',nn.Parameter(torch.randn(1, 2))),
...     ('e',nn.Parameter(torch.randn(2, 3))),
...     ('f',nn.Parameter(torch.randn(1, 5))),
...     ('b',nn.Parameter(torch.randn(2, 6))),
... ])
>>>
>>>
>>> print(torch.__version__)
1.7.1
>>> params_1
ParameterDict(
    (a): Parameter containing: [torch.FloatTensor of size 2x4]
    (b): Parameter containing: [torch.FloatTensor of size 1x4]
    (c): Parameter containing: [torch.FloatTensor of size 1x3]
)
>>> params_2
ParameterDict(
    (d): Parameter containing: [torch.FloatTensor of size 1x2]
    (e): Parameter containing: [torch.FloatTensor of size 2x3]
    (f): Parameter containing: [torch.FloatTensor of size 1x5]
    (b): Parameter containing: [torch.FloatTensor of size 2x6]
)
>>> params_1.update(params_2)
>>> params_1
ParameterDict(
    (a): Parameter containing: [torch.FloatTensor of size 2x4]
    (b): Parameter containing: [torch.FloatTensor of size 2x6]
    (c): Parameter containing: [torch.FloatTensor of size 1x3]
    (d): Parameter containing: [torch.FloatTensor of size 1x2]
    (e): Parameter containing: [torch.FloatTensor of size 2x3]
    (f): Parameter containing: [torch.FloatTensor of size 1x5]
)
>>> params_2
ParameterDict(
    (d): Parameter containing: [torch.FloatTensor of size 1x2]
    (e): Parameter containing: [torch.FloatTensor of size 2x3]
    (f): Parameter containing: [torch.FloatTensor of size 1x5]
    (b): Parameter containing: [torch.FloatTensor of size 2x6]
)
>>>
>>>
>>> ^Z


(pytorch_1.7.1_cu102) C:\Users\chenxuqi>conda activate ssd4pytorch1_2_0

(ssd4pytorch1_2_0) C:\Users\chenxuqi>python
Python 3.7.7 (default, May  6 2020, 11:45:54) [MSC v.1916 64 bit (AMD64)] :: Anaconda, Inc. on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> print(torch.__version__)
1.2.0+cu92
>>>
>>> ^Z


(ssd4pytorch1_2_0) C:\Users\chenxuqi>conda activate pytorch_1.7.1_cu102

(pytorch_1.7.1_cu102) C:\Users\chenxuqi>python
Python 3.7.9 (default, Aug 31 2020, 17:10:11) [MSC v.1916 64 bit (AMD64)] :: Anaconda, Inc. on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> print(torch.__version__)
1.7.1
>>>
>>> ^Z


(pytorch_1.7.1_cu102) C:\Users\chenxuqi>
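The 1.2.0 error message ("element #0 has length 1; 2 is required") suggests that 1.2.0 treats a ParameterDict argument as a generic iterable and iterates over its keys. The first element is then the one-character string 'd'. One possible version-independent workaround is to pass `items()` explicitly, so update() receives the documented "iterable of (string, Parameter) pairs". This assumes `items()` yields plain 2-tuples in both versions. It is a sketch and has not been re-run on 1.2.0 here. On 1.7.1 and later it should give the same result as the transcript above.

```python
import torch
import torch.nn as nn

params_1 = nn.ParameterDict([
    ('a', nn.Parameter(torch.randn(2, 4))),
    ('b', nn.Parameter(torch.randn(1, 4))),
    ('c', nn.Parameter(torch.randn(1, 3))),
])
params_2 = nn.ParameterDict([
    ('d', nn.Parameter(torch.randn(1, 2))),
    ('e', nn.Parameter(torch.randn(2, 3))),
    ('f', nn.Parameter(torch.randn(1, 5))),
    ('b', nn.Parameter(torch.randn(2, 6))),
])

# Pass (key, Parameter) pairs instead of the ParameterDict itself,
# so update() takes its "iterable of key-value pairs" path.
params_1.update(params_2.items())

print(list(params_1.keys()))       # ['a', 'b', 'c', 'd', 'e', 'f']
print(tuple(params_1['b'].shape))  # (2, 6)
```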

Code experiment:

import torch
import torch.nn as nn
torch.manual_seed(seed=20200910)
class Model(torch.nn.Module):
    def __init__(self):
        super(Model,self).__init__()
        # Registered: a ParameterDict is a submodule, so its entries are tracked.
        self.params_ParameterDict_in = nn.ParameterDict([
            ('a',nn.Parameter(torch.randn(2, 4))),
            ('b',nn.Parameter(torch.randn(1, 4))),
            ('c',nn.Parameter(torch.randn(1, 3)))
        ])

        # Not registered: a plain Python dict is just an ordinary attribute.
        self.params_PythonDict_in = {
            'top': nn.Parameter(torch.randn(88, 77)),
            'bottom': nn.Parameter(torch.randn(66, 55)),
        }

    def forward(self,x):
        pass

print('CUDA (GPU) available:',torch.cuda.is_available())
print('torch version:',torch.__version__)

model = Model() #.cuda()
print('Plain Python dicts are not properly registered'.center(100,"-"))
print("Model parameters".center(100,"-"))
for name, param in model.named_parameters(prefix='', recurse=True):
    print('Parameter name:', name, 'shape:', param.shape)

# Despite its name, this attribute holds a ParameterDict, so it is registered.
model.params_ParameterList_out = nn.ParameterDict([
    ('d',nn.Parameter(torch.randn(1, 2))),
    ('e',nn.Parameter(torch.randn(2, 3))),
    ('f',nn.Parameter(torch.randn(1, 5))),
    ('b',nn.Parameter(torch.randn(2, 6))),
])


# Despite its name, this attribute is a plain Python dict, so it is NOT registered.
model.params_PythonList_out = {
    'left': nn.Parameter(torch.randn(2, 4)),
    'right': nn.Parameter(torch.randn(2, 4))
}


print('Plain Python dicts are not properly registered'.center(100,"-"))
print("Model parameters".center(100,"-"))
for name, param in model.named_parameters(prefix='', recurse=True):
    print('Parameter name:', name, 'shape:', param.shape)

Console output:

Windows PowerShell
Copyright (C) Microsoft Corporation. All rights reserved.

Try the new cross-platform PowerShell https://aka.ms/pscore6

Loading personal and system profiles took 865ms.
(base) PS C:\Users\chenxuqi\Desktop\News4cxq\test4cxq> conda activate ssd4pytorch1_2_0
(ssd4pytorch1_2_0) PS C:\Users\chenxuqi\Desktop\News4cxq\test4cxq>  & 'D:\Anaconda3\envs\ssd4pytorch1_2_0\python.exe' 'c:\Users\chenxuqi\.vscode\extensions\ms-python.python-2020.12.424452561\pythonFiles\lib\python\debugpy\launcher' '55100' '--' 'c:\Users\chenxuqi\Desktop\News4cxq\test4cxq\test2.py'
CUDA (GPU) available: True
torch version: 1.2.0+cu92
---------------------------Plain Python dicts are not properly registered---------------------------
------------------------------------------Model parameters------------------------------------------
Parameter name: params_ParameterDict_in.a shape: torch.Size([2, 4])
Parameter name: params_ParameterDict_in.b shape: torch.Size([1, 4])
Parameter name: params_ParameterDict_in.c shape: torch.Size([1, 3])
---------------------------Plain Python dicts are not properly registered---------------------------
------------------------------------------Model parameters------------------------------------------
Parameter name: params_ParameterDict_in.a shape: torch.Size([2, 4])
Parameter name: params_ParameterDict_in.b shape: torch.Size([1, 4])
Parameter name: params_ParameterDict_in.c shape: torch.Size([1, 3])
Parameter name: params_ParameterList_out.d shape: torch.Size([1, 2])
Parameter name: params_ParameterList_out.e shape: torch.Size([2, 3])
Parameter name: params_ParameterList_out.f shape: torch.Size([1, 5])
Parameter name: params_ParameterList_out.b shape: torch.Size([2, 6])
(ssd4pytorch1_2_0) PS C:\Users\chenxuqi\Desktop\News4cxq\test4cxq> 
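The output shows that neither plain-dict attribute (`params_PythonDict_in`, `params_PythonList_out`) appears among the parameters. One straightforward fix is to wrap the plain dict in nn.ParameterDict. Here is a minimal sketch on a trimmed-down, hypothetical `FixedModel`, assuming PyTorch 1.7 or later. Because the argument is a plain dict, the resulting key order should not be relied on, so the names are sorted before printing.

```python
import torch
import torch.nn as nn


class FixedModel(nn.Module):  # hypothetical, trimmed-down variant of Model above
    def __init__(self):
        super().__init__()
        plain = {
            'top': nn.Parameter(torch.randn(88, 77)),
            'bottom': nn.Parameter(torch.randn(66, 55)),
        }
        # Wrapping the plain dict registers every entry as a parameter.
        self.params_PythonDict_in = nn.ParameterDict(plain)

    def forward(self, x):
        return x


model = FixedModel()
names = sorted(name for name, _ in model.named_parameters())
print(names)  # ['params_PythonDict_in.bottom', 'params_PythonDict_in.top']
```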

Reposted from blog.csdn.net/m0_46653437/article/details/112760276