ModuleList and ModuleDict
1、ModuleList
1) ModuleList
ModuleList takes a list of submodules as input and supports append and extend operations, like a Python list:
import torch
import torch.nn as nn

net = nn.ModuleList([nn.Linear(784, 256), nn.ReLU()])
net.append(nn.Linear(256, 10))  # append works like list.append
print(net[-1])  # indexing works like a list
print(net)
# net(torch.zeros(1, 784))  # would raise NotImplementedError
# Output:
# Linear(in_features=256, out_features=10, bias=True)
# ModuleList(
#   (0): Linear(in_features=784, out_features=256, bias=True)
#   (1): ReLU()
#   (2): Linear(in_features=256, out_features=10, bias=True)
# )
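The text mentions extend, but the example only uses append. The following is a minimal sketch of the other list-like operations; the variable name `layers` and the layer sizes are illustrative assumptions, not part of the original example.

```python
import torch.nn as nn

# Start from a two-module list, then grow it in place.
layers = nn.ModuleList([nn.Linear(784, 256), nn.ReLU()])

# extend() takes an iterable of modules, like list.extend().
layers.extend([nn.Linear(256, 64), nn.ReLU()])

# insert() places a module at a given index, like list.insert().
layers.insert(0, nn.Flatten())

# len() and iteration also behave like a Python list.
print(len(layers))
for idx, layer in enumerate(layers):
    print(idx, type(layer).__name__)
```

Every module added this way is registered as a submodule, just as with append.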
2) Differences between nn.Sequential and nn.ModuleList
nn.ModuleList is just a list that stores modules. There is no connection between these modules, so the input and output dimensions of adjacent layers do not need to match. The modules in nn.Sequential, by contrast, must be arranged in order so that the input and output sizes of adjacent layers match.
nn.ModuleList does not implement a forward function, so you need to implement it yourself. This is why calling net(torch.zeros(1, 784)) above raises NotImplementedError. nn.Sequential already implements forward internally.
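The difference can be checked directly. The sketch below builds the same three layers both ways; the variable names (`seq`, `mlist`, `call_failed`) are illustrative assumptions.

```python
import torch
import torch.nn as nn

x = torch.zeros(1, 784)

# nn.Sequential chains its modules in order, so it can be called directly.
seq = nn.Sequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))
seq_out = seq(x)
print(seq_out.shape)  # torch.Size([1, 10])

# nn.ModuleList only stores modules; calling it raises NotImplementedError.
mlist = nn.ModuleList([nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10)])
try:
    mlist(x)
    call_failed = False
except NotImplementedError:
    call_failed = True
print(call_failed)

# The forward logic has to be written by hand, here as a plain loop.
out = x
for layer in mlist:
    out = layer(out)
print(out.shape)
```

The hand-written loop happens to reproduce what nn.Sequential does, but with a ModuleList the modules could just as well be applied in a different order, skipped, or reused.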
ModuleList simply makes it more flexible to define a network's forward propagation; see the example from the official documentation below:
class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])

    def forward(self, x):
        # ModuleList can act as an iterable, or be indexed using ints
        for i, l in enumerate(self.linears):
            x = self.linears[i // 2](x) + l(x)
        return x
3) In addition, unlike an ordinary Python list, nn.ModuleList automatically registers the parameters of every module added to it with the enclosing network. Let's compare the two in an example.
import torch
import torch.nn as nn

class Module_ModuleList(nn.Module):
    def __init__(self):
        super(Module_ModuleList, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(10, 10)])

class Module_List(nn.Module):
    def __init__(self):
        super(Module_List, self).__init__()
        self.linears = [nn.Linear(10, 10)]

net1 = Module_ModuleList()
net2 = Module_List()

print(net1)
for p in net1.parameters():
    print(p.size())

print('*'*20)

print(net2)
for p in net2.parameters():
    print(p)
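Output:
Module_ModuleList(
  (linears): ModuleList(
    (0): Linear(in_features=10, out_features=10, bias=True)
  )
)
torch.Size([10, 10])
torch.Size([10])
********************
Module_List()
Module_List shows no submodules and yields no parameters, because a Linear layer held in a plain Python list is not registered with the network.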
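One practical consequence is that unregistered parameters are invisible to anything that relies on parameters() or state_dict(), such as an optimizer or a checkpoint. The sketch below illustrates this with two hypothetical classes (`WithModuleList`, `WithPlainList`) that mirror the ones above.

```python
import torch.nn as nn

class WithModuleList(nn.Module):  # hypothetical name, mirrors Module_ModuleList
    def __init__(self):
        super().__init__()
        self.linears = nn.ModuleList([nn.Linear(10, 10)])

class WithPlainList(nn.Module):  # hypothetical name, mirrors Module_List
    def __init__(self):
        super().__init__()
        self.linears = [nn.Linear(10, 10)]

registered = WithModuleList()
unregistered = WithPlainList()

# Number of parameter tensors each network exposes.
print(len(list(registered.parameters())))    # weight and bias of the Linear
print(len(list(unregistered.parameters())))  # nothing is registered

# Keys that would be written to a checkpoint.
print(list(registered.state_dict().keys()))
print(list(unregistered.state_dict().keys()))
```

An optimizer built from unregistered.parameters() would have nothing to update, and the layer's weights would be missing from a saved state_dict.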
2、ModuleDict
ModuleDict accepts a dictionary of submodules as input; modules can then be added and accessed as with a Python dictionary:
net = nn.ModuleDict({
    'linear': nn.Linear(784, 256),
    'act': nn.ReLU(),
})
net['output'] = nn.Linear(256, 10)  # add a module
print(net['linear'])  # access by key
print(net.output)  # access as an attribute
print(net)
# net(torch.zeros(1, 784))  # would raise NotImplementedError
# Output (key order as in the original run; recent PyTorch versions keep insertion order):
# Linear(in_features=784, out_features=256, bias=True)
# Linear(in_features=256, out_features=10, bias=True)
# ModuleDict(
#   (act): ReLU()
#   (linear): Linear(in_features=784, out_features=256, bias=True)
#   (output): Linear(in_features=256, out_features=10, bias=True)
# )
(1) Like nn.ModuleList, an nn.ModuleDict instance is only a dictionary that stores modules. It does not define a forward function, so you need to define one yourself.
(2) Similarly, unlike a Python dict, nn.ModuleDict automatically registers the parameters of all the modules it holds with the enclosing network.
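Both points can be seen in a small network that picks a module by key inside its own forward function. This is a minimal sketch; the class name `BranchNet`, the `act` argument, and the layer sizes are assumptions chosen for illustration.

```python
import torch
import torch.nn as nn

class BranchNet(nn.Module):  # hypothetical example class
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(784, 256)
        # Activations are stored by name; ModuleDict registers them as submodules.
        self.acts = nn.ModuleDict({'relu': nn.ReLU(), 'tanh': nn.Tanh()})
        self.output = nn.Linear(256, 10)

    def forward(self, x, act='relu'):
        # ModuleDict has no forward of its own, so the lookup is done here.
        x = self.hidden(x)
        x = self.acts[act](x)
        return self.output(x)

branch_net = BranchNet()
x = torch.zeros(2, 784)
print(branch_net(x, act='relu').shape)
print(branch_net(x, act='tanh').shape)

# Two Linear layers contribute a weight and a bias each.
print(len(list(branch_net.parameters())))
```

Selecting the activation with a string key is the kind of forward logic that nn.Sequential cannot express on its own.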
3. Summary
- Sequential, ModuleList, and ModuleDict all inherit from the Module class.
- Unlike Sequential, ModuleList and ModuleDict do not define a complete network. They just store different modules together, and you need to define the forward function yourself.
- Although classes such as Sequential can make model construction easier, directly inheriting from the Module class greatly increases the flexibility of model construction.
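The inheritance relationship in the first point can be confirmed with a quick check. This is only a sketch, and the `results` dictionary is an illustrative name.

```python
import torch.nn as nn

# Check that each container class is a subclass of nn.Module.
results = {
    cls.__name__: issubclass(cls, nn.Module)
    for cls in (nn.Sequential, nn.ModuleList, nn.ModuleDict)
}
for name, is_module in results.items():
    print(name, is_module)
```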