【翻译】class torch.nn.ModuleDict(modules=None)

参考链接: class torch.nn.ModuleDict(modules=None)

说明:在PyTorch的1.2.0版本上这个方法有bug,用一个ModuleDict对象调用update()来更新另一个ModuleDict会报错,但在PyTorch的1.7.1版本上可以正常使用.

在这里插入图片描述

在这里插入图片描述

原文及翻译:

ModuleDict  ModuleDict章节

class torch.nn.ModuleDict(modules=None)
类型 class torch.nn.ModuleDict(modules=None)
    Holds submodules in a dictionary.
    该类能够以字典的方式持有子模块.
    ModuleDict can be indexed like a regular Python dictionary, but modules it contains 
    are properly registered, and will be visible by all Module methods.
    ModuleDict 类型能够像普通Python字典一样被索引访问,但是它和普通Python字典不同的是,该类型所
    包含的模块会被正确地注册登记,并且这些模块能被所有地Module模块方法可见.
    ModuleDict is an ordered dictionary that respects
    ModuleDict 类型是一个有序字典,它遵循:
        the order of insertion, and
        插入地先后顺序,并且
        in update(), the order of the merged OrderedDict or another ModuleDict 
        (the argument to update()).
		在方法update(),遵循被合并的有序字典OrderedDict的顺序或者
		另一个ModuleDict(,传递给方法update()的参数)的顺序.
    Note that update() with other unordered mapping types (e.g., Python’s plain dict) does 
    not preserve the order of the merged mapping.
    值得注意的是,在这个update()方法中如果传递了一个无序的映射类型(比如,Python的普通字典),那么不会
    保持被合并的这个映射类型的顺序.

    Parameters  参数
        modules (iterable, optional) – a mapping (dictionary) of (string: module) or 
        an iterable of key-value pairs of type (string, module)
		modules (iterable可迭代类型, 可选) – 一个映射(字符串:模块)类型(字典)或者
		一个键值对(字符串,模块)类型的可迭代对象.


    Example:  例子:

    class MyModule(nn.Module):
        def __init__(self):
            super(MyModule, self).__init__()
            self.choices = nn.ModuleDict({
    
    
                    'conv': nn.Conv2d(10, 10, 3),
                    'pool': nn.MaxPool2d(3)
            })
            self.activations = nn.ModuleDict([
                    ['lrelu', nn.LeakyReLU()],
                    ['prelu', nn.PReLU()]
            ])

        def forward(self, x, choice, act):
            x = self.choices[choice](x)
            x = self.activations[act](x)
            return x

    clear()
    方法: clear()
        Remove all items from the ModuleDict.
        移除ModuleDict中的所有项目.

    items()
    方法: items()
        Return an iterable of the ModuleDict key/value pairs.
        返回一个ModuleDict的键/值对的可迭代对象.

    keys()
    方法: keys()
        Return an iterable of the ModuleDict keys.
        返回ModuleDict关键字的可迭代对象.

    pop(key)
    方法: pop(key)
        Remove key from the ModuleDict and return its module.
        在ModuleDict中移除关键字key.并且返回这个关键字对应的模块.
        Parameters  参数
            key (string) – key to pop from the ModuleDict
            
    update(modules)
    方法: update(modules)
        Update the ModuleDict with the key-value pairs from a mapping or an iterable, 
        overwriting existing keys.
        用一个键值对的映射类型或者可迭代对象来更新ModuleDict,覆写已经存在的关键字.
        Note  注意:
        If modules is an OrderedDict, a ModuleDict, or an iterable of key-value pairs, 
        the order of new elements in it is preserved.
        如果modules参数是一个有序字典OrderedDict或者ModuleDict或者键值对的可迭代对象,
        那么新元素的顺序也同样被维持.
        Parameters  参数
            modules (iterable) – a mapping (dictionary) from string to Module, or an 
            iterable of key-value pairs of type (string, Module)
            modules (iterable可迭代对象) – 字符串映射到Module模块的映射类型(字典),或者
            是键值对(字符串,Module模块)类型的可迭代对象.

    values()
    方法: values()
        Return an iterable of the ModuleDict values.
        返回一个ModuleDict键值对中的值的可迭代对象.

代码实验展示:

Microsoft Windows [版本 10.0.18363.1316]
(c) 2019 Microsoft Corporation。保留所有权利。

C:\Users\chenxuqi>conda activate ssd4pytorch1_2_0

(ssd4pytorch1_2_0) C:\Users\chenxuqi>python
Python 3.7.7 (default, May  6 2020, 11:45:54) [MSC v.1916 64 bit (AMD64)] :: Anaconda, Inc. on win32
Type "help", "copyright", "credits" or "license" for more information.
>>>
>>> import torch
>>> torch.manual_seed(seed=20200910)
<torch._C.Generator object at 0x0000021D828AD330>
>>> import torch.nn as nn
>>> layers_1 = nn.ModuleDict({
    
    
...                 'conv': nn.Conv2d(10, 10, 3),
...                 'pool': nn.MaxPool2d(3)
...         })
>>> layers_1
ModuleDict(
  (conv): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
)
>>> layers_1['conv']
Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))
>>> layers_1['pool']
MaxPool2d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
>>>
>>> for item in layers_1.items():
...     print(item)
...
('conv', Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1)))
('pool', MaxPool2d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False))
>>>
>>> for key in layers_1.keys():
...     print(key)
...
conv
pool
>>> for value in layers_1.values():
...     print(value)
...
Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))
MaxPool2d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
>>> layers_1.pop()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: pop() missing 1 required positional argument: 'key'
>>> layers_1.pop('conv')
Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))
>>> layers_1
ModuleDict(
  (pool): MaxPool2d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
)
>>> layers_1.pop('pool')
MaxPool2d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
>>> layers_1
ModuleDict()
>>>
>>> layers_1 = nn.ModuleDict({
    
    
...                 'conv': nn.Conv2d(10, 10, 3),
...                 'pool': nn.MaxPool2d(3)
...         })
>>> layers_1
ModuleDict(
  (conv): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
)
>>> layers_1.clear()
>>> layers_1
ModuleDict()
>>> layers_1 = nn.ModuleDict({
    
    
...                 'conv': nn.Conv2d(10, 10, 3),
...                 'pool': nn.MaxPool2d(3)
...         })
>>>
>>> layers_2 = nn.ModuleDict({
    
    
...             'conv': nn.Conv2d(5, 5, 5),
...             'pool2': nn.MaxPool2d(7)
...     })
>>>

代码实验:

Microsoft Windows [版本 10.0.18363.1316]
(c) 2019 Microsoft Corporation。保留所有权利。

C:\Users\chenxuqi>conda activate pytorch_1.7.1_cu102

(pytorch_1.7.1_cu102) C:\Users\chenxuqi>python
Python 3.7.9 (default, Aug 31 2020, 17:10:11) [MSC v.1916 64 bit (AMD64)] :: Anaconda, Inc. on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> import torch.nn as nn
>>> torch.manual_seed(seed=20200910)
<torch._C.Generator object at 0x000002C864EB7870>
>>>
>>> print(torch.__version__)
1.7.1
>>>
>>> layers_1 = nn.ModuleDict({
    
    
...     'conv': nn.Conv2d(10, 10, 3),
...     'pool': nn.MaxPool2d(3)
... })
>>>
>>>
>>> layers_2 = nn.ModuleDict([
...     ['lrelu', nn.LeakyReLU()],
...     ['prelu', nn.PReLU()]
... ])
>>>
>>> layers_2
ModuleDict(
  (lrelu): LeakyReLU(negative_slope=0.01)
  (prelu): PReLU(num_parameters=1)
)
>>> layers_1
ModuleDict(
  (conv): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
)
>>> layers_1.update(layers_2)
>>> layers_1
ModuleDict(
  (conv): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
  (lrelu): LeakyReLU(negative_slope=0.01)
  (prelu): PReLU(num_parameters=1)
)
>>> layers_2
ModuleDict(
  (lrelu): LeakyReLU(negative_slope=0.01)
  (prelu): PReLU(num_parameters=1)
)
>>>
>>>

猜你喜欢

转载自blog.csdn.net/m0_46653437/article/details/112760256