深度学习的一个魅力在于神经网络中各式各样的层,例如 全连接层、卷积层、池化层与循环层。虽然PyTorch提供了大量常用的层,但有时候我们依然希望自定义层。
这篇文章介绍如何使用Module来自定义层,从而可以被重复调用。
不含模型参数的自定义层
我们先介绍如何定义一个不含参数的自定义层。事实上,创建自定义层 与 使用 Module类 构造模型类似。
下面的 CenteredLayer 类通过继承 Module类 自定义了一个将输入减掉均值后输出的层,并将层的计算定义在了forward函数里。(这个层里不含模型参数)
import torch
from torch import nn
class CenteredLayer(nn.Module):
def __init__(self):
super(CenteredLayer, self).__init__()
def forward(self, x):
x -= torch.mean(x, dim=0)
return x - x.mean()
net = nn.Sequential(nn.Linear(8, 128), CenteredLayer())
y = net(torch.rand(4, 8))
print(y.shape) # torch.Size([4, 128])
print(y.mean().item()) # 1.2655618775170296e-09
含模型参数的自定义层
之前的文章已经介绍了模型参数。 Parameter类 其实是Tensor的子类,如果一个 Tensor 是 Parameter,那么它会自动被添加到模型的参数列表里。
所以在自定义含模型参数的层时,我们需要将参数定义成Parameter。
除了直接定义成Parameter类外,还可以使用 ParameterList
和 ParameterDict
分别定义参数的列表和字典。
1)ParameterList
ParameterList
: 能接收一个 Parameter 实例的列表作为输入然后得到一个参数列表,使用的时候可以用索引来访问某个参数,另外也可以使用 append 和 extend 在列表后面新增参数。
import torch
from torch import nn
class MyDense(nn.Module):
def __init__(self):
super(MyDense, self).__init__()
self.params = nn.ParameterList([nn.Parameter(torch.randn(4, 4)) for i in range(3)])
self.params.append(nn.Parameter(torch.randn(4, 1)))
def forward(self, x):
for i in range(len(self.params)):
x = torch.mm(x, self.params[i])
return x
net = MyDense()
print(net)
# MyDense(
# (params): ParameterList(
# (0): Parameter containing: [torch.FloatTensor of size 4x4]
# (1): Parameter containing: [torch.FloatTensor of size 4x4]
# (2): Parameter containing: [torch.FloatTensor of size 4x4]
# (3): Parameter containing: [torch.FloatTensor of size 4x1]
# )
# )
2)ParameterDict
import torch
from torch import nn
class MyDictDense(nn.Module):
def __init__(self):
super(MyDictDense, self).__init__()
self.params = nn.ParameterDict({
'linear1': nn.Parameter(torch.randn(4, 4)),
'linear2': nn.Parameter(torch.randn(4, 1))
})
self.params.update({
'linear3': nn.Parameter(torch.randn(4, 2))}) # 新增
def forward(self, x, choice='linear1'):
return torch.mm(x, self.params[choice])
net = MyDictDense()
print(net)
# MyDictDense(
# (params): ParameterDict(
# (linear1): Parameter containing: [torch.FloatTensor of size 4x4]
# (linear2): Parameter containing: [torch.FloatTensor of size 4x1]
# (linear3): Parameter containing: [torch.FloatTensor of size 4x2]
# )
# )