PyTorch custom network layers

Sometimes a convolution needs fixed, hand-specified kernel weights. The module classes torch.nn.Conv2d and torch.nn.Conv3d create and initialize their own (trainable) weights, so it is simpler to call the functional versions, torch.nn.functional.conv2d or torch.nn.functional.conv3d, which take the kernel weights as an explicit argument.

As an example, we implement a custom SRM layer.

The Spatial Rich Model (SRM) filter layer uses three convolution kernels with different weights to extract three different high-frequency residual signals. Its parameters are fixed (non-trainable). The three kernels used by the SRM filter are:

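$$
K_1 = \frac{1}{4}\begin{bmatrix} 0&0&0&0&0\\ 0&-1&2&-1&0\\ 0&2&-4&2&0\\ 0&-1&2&-1&0\\ 0&0&0&0&0 \end{bmatrix},\quad
K_2 = \frac{1}{12}\begin{bmatrix} -1&2&-2&2&-1\\ 2&-6&8&-6&2\\ -2&8&-12&8&-2\\ 2&-6&8&-6&2\\ -1&2&-2&2&-1 \end{bmatrix},\quad
K_3 = \frac{1}{4}\begin{bmatrix} 0&0&0&0&0\\ 0&0&-1&0&0\\ 0&-1&4&-1&0\\ 0&0&-1&0&0\\ 0&0&0&0&0 \end{bmatrix}
$$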

In effect, this is a custom convolutional layer whose parameters are not trainable. A custom layer subclasses nn.Module and overrides two methods: __init__ and forward.

Conv2d

import torch.nn as nn
import torch
import torch.nn.functional as F
import numpy as np

class SRM2D(nn.Module):
    """Fixed (non-trainable) SRM high-pass filter layer for 2-D inputs.

    The input is convolved with the three SRM residual kernels.
    Output channel ``k`` applies kernel ``k`` to each of the three input
    channels and sums the results. The residuals are then clamped to
    ``[-threshold, threshold]``.

    Input shape is ``(N, 3, H, W)`` and output shape is ``(N, 3, H, W)``.
    The 5x5 kernel with stride 1 and padding 2 preserves the spatial size.

    Args:
        threshold: Truncation bound for the residuals. Defaults to 2, the
            value that was previously hard-coded.
    """

    def __init__(self, threshold=2.0):
        super().__init__()
        self.threshold = threshold

        q = [4.0, 12.0, 4.0]
        filter1 = [[0,  0,  0,  0, 0],
                   [0, -1,  2, -1, 0],
                   [0,  2, -4,  2, 0],
                   [0, -1,  2, -1, 0],
                   [0,  0,  0,  0, 0]]

        filter2 = [[-1, 2, -2,  2,-1],
                   [ 2,-6,  8, -6, 2],
                   [-2, 8, -12, 8,-2],
                   [ 2,-6,  8, -6, 2],
                   [-1, 2, -2,  2,-1]]

        filter3 = [[0,  0,  0,  0, 0],
                   [0,  0, -1,  0, 0],
                   [0, -1, +4, -1, 0],
                   [0,  0, -1,  0, 0],
                   [0,  0,  0,  0, 0]]

        # Three normalized 5x5 kernels stacked into one (3, 5, 5) array.
        kernels = np.stack([
            np.asarray(filter1, dtype=float) / q[0],
            np.asarray(filter2, dtype=float) / q[1],
            np.asarray(filter3, dtype=float) / q[2],
        ])

        # Custom kernel weights laid out as F.conv2d expects:
        # (out_channels=3, in_channels=3, 5, 5), with weight[k, c] = kernel k.
        weight = (torch.tensor(kernels, dtype=torch.float32)
                  .unsqueeze(1)
                  .repeat(1, 3, 1, 1))

        # Register the kernel as a buffer rather than a Parameter, so it is
        # never trained. A buffer still follows .to()/.cuda()/.double() with
        # the module. persistent=False keeps it out of state_dict, so
        # checkpoints saved before this change still load with strict=True.
        self.register_buffer("filter", weight, persistent=False)

    def forward(self, input):
        """Filter ``input`` with the SRM kernels and clamp the residuals."""
        result = F.conv2d(input,
                          weight=self.filter,
                          stride=(1, 1),
                          # The kernel is 5x5 with stride 1, so padding 2 keeps
                          # the output the same spatial size as the input.
                          padding=(2, 2))

        # Truncate the residuals to [-threshold, threshold].
        return torch.clamp(result, -self.threshold, self.threshold)

Conv3d

import torch.nn as nn
import torch
import torch.nn.functional as F
import numpy as np

class SRM3D(nn.Module):
    """Fixed (non-trainable) SRM high-pass filter layer for 3-D (volumetric) inputs.

    This uses the same three SRM kernels as the 2-D variant. They are applied
    with ``F.conv3d`` using a depth-1 kernel and no depth padding, so each
    slice along the depth axis is filtered independently. The residuals are
    then clamped to ``[-threshold, threshold]``.

    Input shape is ``(N, 3, D, H, W)`` and output shape is ``(N, 3, D, H, W)``.
    The 1x5x5 kernel with stride 1 and padding (0, 2, 2) preserves the size.

    Args:
        threshold: Truncation bound for the residuals. Defaults to 2, the
            value that was previously hard-coded.
    """

    def __init__(self, threshold=2.0):
        super().__init__()
        self.threshold = threshold

        q = [4.0, 12.0, 4.0]
        filter1 = [[0,  0,  0,  0, 0],
                   [0, -1,  2, -1, 0],
                   [0,  2, -4,  2, 0],
                   [0, -1,  2, -1, 0],
                   [0,  0,  0,  0, 0]]

        filter2 = [[-1, 2, -2,  2,-1],
                   [ 2,-6,  8, -6, 2],
                   [-2, 8, -12, 8,-2],
                   [ 2,-6,  8, -6, 2],
                   [-1, 2, -2,  2,-1]]

        filter3 = [[0,  0,  0,  0, 0],
                   [0,  0, -1,  0, 0],
                   [0, -1, +4, -1, 0],
                   [0,  0, -1,  0, 0],
                   [0,  0,  0,  0, 0]]

        # Three normalized 5x5 kernels stacked into one (3, 5, 5) array.
        kernels = np.stack([
            np.asarray(filter1, dtype=float) / q[0],
            np.asarray(filter2, dtype=float) / q[1],
            np.asarray(filter3, dtype=float) / q[2],
        ])

        # Custom kernel weights: (out_channels=3, in_channels=3, 5, 5), with
        # weight[k, c] = kernel k.
        weight = (torch.tensor(kernels, dtype=torch.float32)
                  .unsqueeze(1)
                  .repeat(1, 3, 1, 1))

        # conv3d needs a depth dimension, so insert a depth of 1:
        # the final shape is (3, 3, 1, 5, 5).
        weight = weight.unsqueeze(2)

        # Register the kernel as a buffer rather than a Parameter, so it is
        # never trained. A buffer still follows .to()/.cuda()/.double() with
        # the module. persistent=False keeps it out of state_dict, so
        # checkpoints saved before this change still load with strict=True.
        self.register_buffer("filter", weight, persistent=False)

    def forward(self, input):
        """Filter ``input`` with the SRM kernels and clamp the residuals."""
        result = F.conv3d(input,
                          weight=self.filter,
                          stride=(1, 1, 1),
                          # The kernel is 1x5x5 with stride 1, so padding
                          # (0, 2, 2) keeps the output the same size as the input.
                          padding=(0, 2, 2))

        # Truncate the residuals to [-threshold, threshold].
        return torch.clamp(result, -self.threshold, self.threshold)

Source: blog.csdn.net/weixin_46566663/article/details/127838652