【pytorch】自定义层

深度学习的一个魅力在于神经网络中各式各样的层,例如 全连接层、卷积层、池化层与循环层。虽然PyTorch提供了大量常用的层,但有时候我们依然希望自定义层。

这篇文章介绍如何使用Module来自定义层,从而可以被重复调用。


不含模型参数的自定义层

我们先介绍如何定义一个不含参数的自定义层。事实上,创建自定义层 与 使用 Module类 构造模型类似。

下面的 CenteredLayer 类通过继承 Module类 自定义了一个将输入减掉均值后输出的层,并将层的计算定义在了forward函数里。(这个层里不含模型参数)

import torch
from torch import nn

class CenteredLayer(nn.Module):
    def __init__(self):
        super(CenteredLayer, self).__init__()

    def forward(self, x):
        x -= torch.mean(x, dim=0)
        return x - x.mean()


net = nn.Sequential(nn.Linear(8, 128), CenteredLayer())
y = net(torch.rand(4, 8))
print(y.shape)   # torch.Size([4, 128])
print(y.mean().item())   # 1.2655618775170296e-09

含模型参数的自定义层

之前的文章已经介绍了模型参数。 Parameter类 其实是Tensor的子类,如果一个 Tensor 是 Parameter,那么它会自动被添加到模型的参数列表里。

所以在自定义含模型参数的层时,我们需要将参数定义成Parameter。
除了直接定义成Parameter类外,还可以使用 ParameterListParameterDict 分别定义参数的列表和字典。

1)ParameterList

ParameterList : 能接收一个 Parameter 实例的列表作为输入然后得到一个参数列表,使用的时候可以用索引来访问某个参数,另外也可以使用 append 和 extend 在列表后面新增参数。

import torch
from torch import nn

class MyDense(nn.Module):
    def __init__(self):
        super(MyDense, self).__init__()
        self.params = nn.ParameterList([nn.Parameter(torch.randn(4, 4)) for i in range(3)])
        self.params.append(nn.Parameter(torch.randn(4, 1)))

    def forward(self, x):
        for i in range(len(self.params)):
            x = torch.mm(x, self.params[i])
        return x
        
net = MyDense()
print(net)

# MyDense(
#   (params): ParameterList(
#       (0): Parameter containing: [torch.FloatTensor of size 4x4]
#       (1): Parameter containing: [torch.FloatTensor of size 4x4]
#       (2): Parameter containing: [torch.FloatTensor of size 4x4]
#       (3): Parameter containing: [torch.FloatTensor of size 4x1]
#   )
# )

2)ParameterDict

import torch
from torch import nn

class MyDictDense(nn.Module):
    def __init__(self):
        super(MyDictDense, self).__init__()
        self.params = nn.ParameterDict({
    
    
                'linear1': nn.Parameter(torch.randn(4, 4)),
                'linear2': nn.Parameter(torch.randn(4, 1))
        })
        self.params.update({
    
    'linear3': nn.Parameter(torch.randn(4, 2))}) # 新增

    def forward(self, x, choice='linear1'):
        return torch.mm(x, self.params[choice])

net = MyDictDense()
print(net)

# MyDictDense(
#   (params): ParameterDict(
#       (linear1): Parameter containing: [torch.FloatTensor of size 4x4]
#       (linear2): Parameter containing: [torch.FloatTensor of size 4x1]
#       (linear3): Parameter containing: [torch.FloatTensor of size 4x2]
#   )
# )

猜你喜欢

转载自blog.csdn.net/weixin_37804469/article/details/129133324