[pytorch] ModuleList and ModuleDict


1. ModuleList

1) nn.ModuleList receives a list of submodules as input, and then supports append and extend operations just like a Python list:

import torch
import torch.nn as nn

net = nn.ModuleList([nn.Linear(784, 256), nn.ReLU()])
net.append(nn.Linear(256, 10))  # append works like a Python list
print(net[-1])  # indexing works like a Python list
print(net)
# net(torch.zeros(1, 784))  # raises NotImplementedError


# Output:
# Linear(in_features=256, out_features=10, bias=True)
# ModuleList(
#   (0): Linear(in_features=784, out_features=256, bias=True)
#   (1): ReLU()
#   (2): Linear(in_features=256, out_features=10, bias=True)
# )

2) Differences between nn.Sequential and nn.ModuleList:

  • nn.ModuleList is just a list that stores modules; there is no connection between them (so there is no need to ensure that the input and output dimensions of adjacent layers match). The modules in nn.Sequential, by contrast, must be arranged in order, with the input and output sizes of adjacent layers matching.
  • nn.ModuleList does not implement a forward function, so you need to implement it yourself; that is why the call net(torch.zeros(1, 784)) above raises NotImplementedError. nn.Sequential, on the other hand, already implements forward internally.
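A small runnable sketch of this difference, reusing the layer sizes from the example above: the nn.Sequential can be called directly, while calling the nn.ModuleList raises NotImplementedError.

```python
import torch
import torch.nn as nn

layers = [nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10)]
seq = nn.Sequential(*layers)   # forward is already implemented
mlist = nn.ModuleList(layers)  # no forward defined

x = torch.zeros(1, 784)
out = seq(x)  # works: shape torch.Size([1, 10])

raised = False
try:
    mlist(x)  # nn.ModuleList has no forward
except NotImplementedError:
    raised = True
print(out.shape, raised)
```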

ModuleList exists to make defining forward propagation more flexible. See the example from the official docs below:

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])

    def forward(self, x):
        # ModuleList can act as an iterable, or be indexed using ints
        for i, l in enumerate(self.linears):
            x = self.linears[i // 2](x) + l(x)
        return x
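Redefining the module above so the snippet is self-contained, we can check that forward runs and that all ten layers' parameters are registered (each nn.Linear contributes a weight and a bias):

```python
import torch
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])

    def forward(self, x):
        # each step combines the (i // 2)-th layer with the i-th layer
        for i, l in enumerate(self.linears):
            x = self.linears[i // 2](x) + l(x)
        return x

net = MyModule()
y = net(torch.randn(2, 10))
print(y.shape)                        # torch.Size([2, 10])
print(len(list(net.parameters())))   # 20: weight + bias for each of 10 layers
```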

3) In addition, unlike a plain Python list, the parameters of all modules added to an nn.ModuleList are automatically registered with the whole network. Compare the following example:

import torch
import torch.nn as nn

class Module_ModuleList(nn.Module):
    def __init__(self):
        super(Module_ModuleList, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(10, 10)])

class Module_List(nn.Module):
    def __init__(self):
        super(Module_List, self).__init__()
        self.linears = [nn.Linear(10, 10)]

net1 = Module_ModuleList()
net2 = Module_List()

print(net1)
for p in net1.parameters():
    print(p.size())

print('*'*20)
print(net2)
for p in net2.parameters():
    print(p)

Output:

Module_ModuleList(
  (linears): ModuleList(
    (0): Linear(in_features=10, out_features=10, bias=True)
  )
)
torch.Size([10, 10])
torch.Size([10])
********************
Module_List()
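The same comparison can be made numerically by counting registered parameters. A ModuleList registers the Linear layer's weight and bias (2 parameters), while a plain Python list registers nothing:

```python
import torch.nn as nn

class Module_ModuleList(nn.Module):
    def __init__(self):
        super(Module_ModuleList, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(10, 10)])

class Module_List(nn.Module):
    def __init__(self):
        super(Module_List, self).__init__()
        self.linears = [nn.Linear(10, 10)]  # plain list: NOT registered

n1 = len(list(Module_ModuleList().parameters()))  # 2 (weight + bias)
n2 = len(list(Module_List().parameters()))        # 0
print(n1, n2)
```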

2. ModuleDict

nn.ModuleDict accepts a dictionary of submodules as input; entries can then be added and accessed like a dictionary:

net = nn.ModuleDict({
    'linear': nn.Linear(784, 256),
    'act': nn.ReLU(),
})
net['output'] = nn.Linear(256, 10)  # add like a dict
print(net['linear'])  # access like a dict
print(net.output)     # access as an attribute
print(net)
# net(torch.zeros(1, 784))  # raises NotImplementedError


# Output:
# Linear(in_features=784, out_features=256, bias=True)
# Linear(in_features=256, out_features=10, bias=True)
# ModuleDict(
#   (act): ReLU()
#   (linear): Linear(in_features=784, out_features=256, bias=True)
#   (output): Linear(in_features=256, out_features=10, bias=True)
# )

(1) As with nn.ModuleList, an nn.ModuleDict instance is only a dictionary that stores modules; it does not define a forward function, which you need to define yourself.
(2) Similarly, unlike a plain Python dict, the parameters of all modules in an nn.ModuleDict are automatically registered with the whole network.
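A minimal sketch of giving an nn.ModuleDict its own forward, using the same layer names as above (DictMLP is a hypothetical class name, not part of PyTorch):

```python
import torch
import torch.nn as nn

class DictMLP(nn.Module):
    def __init__(self):
        super(DictMLP, self).__init__()
        self.layers = nn.ModuleDict({
            'linear': nn.Linear(784, 256),
            'act': nn.ReLU(),
            'output': nn.Linear(256, 10),
        })

    def forward(self, x):
        # chain the stored modules by key
        x = self.layers['linear'](x)
        x = self.layers['act'](x)
        return self.layers['output'](x)

net = DictMLP()
y = net(torch.zeros(1, 784))
print(y.shape)  # torch.Size([1, 10])
```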


3. Summary

  1. The Sequential, ModuleList, and ModuleDict classes all inherit from the Module class.
  2. Unlike Sequential, ModuleList and ModuleDict do not define a complete network; they only store different modules together, and the forward function needs to be defined by yourself.
  3. Although classes such as Sequential can make model construction easier, directly inheriting from the Module class greatly expands the flexibility of model construction.

Origin blog.csdn.net/weixin_37804469/article/details/129126824