【PyTorch】教程:torch.nn.ModuleDict

Containers-ModuleList

CLASS torch.nn.ModuleDict(modules=None)

将所有的子模块放到一个字典中。

ModuleDict 可以像常规 Python 字典一样进行索引,但它包含的模块已正确注册,所有 Module 方法都可以看到。

ModuleDict 是一个有序字典。

  • Parameters

modules (iterable, optional) – 一个(string: module)映射(字典)或者可迭代的键值对。

  • Example
class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.choices = nn.ModuleDict({
    
    
                'conv': nn.Conv2d(10, 10, 3),
                'pool': nn.MaxPool2d(3)
        })
        self.activations = nn.ModuleDict([
                ['lrelu', nn.LeakyReLU()],
                ['prelu', nn.PReLU()]
        ])

    def forward(self, x, choice, act):
        x = self.choices[choice](x)
        x = self.activations[act](x)
        return x
  • method

  • clear

移除 ModuleDict 里所有的子模块。

  • items

返回 ModuleDict 键值对的迭代器

  • keys

返回 ModuleDict 键的可迭代项。
返回: Iterable[str]

  • pop

Remove key from the ModuleDict and return its module.
参数:key (str) – 从 ModuleDict 弹出的键
返回: Module

  • update

使用映射或可迭代的键值对更新 ModuleDict, 覆盖现有键;
如果模块是 OrderedDict, ModuleDict, 或可迭代的键值对,则保留其中新元素的顺序;
参数:modules (iterable) - 从字符串到模块的映射(字典),或类型(string,Module)的键值对的可迭代

  • values

返回 ModuleDict 值的可迭代值。Iterable[Module]

【参考】

ModuleDict — PyTorch 1.13 documentation

猜你喜欢

转载自blog.csdn.net/zhoujinwang/article/details/129052373