DeformableConv (deformable convolution) theory and code analysis

1. Theoretical analysis of DeformableConv

Ordinary convolution
Taking an ordinary 2D convolution as an example:

  • The input size is [batch_size, channel, H, W]
  • The output size is also [batch_size, channel, H, W]
  • The size of the kernel is [kernel_size, kernel_size]
    , so any point p0 in the output corresponds to a convolution sampling region of size kernel_size x kernel_size in the input, and the convolution operation can be expressed as:

    $$\mathbf{y}\left(\mathbf{p}_{0}\right)=\sum_{\mathbf{p}_{n} \in \mathcal{R}} \mathbf{w}\left(\mathbf{p}_{n}\right) \cdot \mathbf{x}\left(\mathbf{p}_{0}+\mathbf{p}_{n}\right)$$

    Here $\mathbf{p}_{n}$ represents the offset of each point of the convolution kernel relative to the kernel center, which for a 3 x 3 kernel can be written as:

    $$\mathcal{R}=\{(-1,-1),(-1,0), \ldots,(0,1),(1,1)\}$$

$\mathbf{w}(\mathbf{p}_{n})$ represents the weight at the corresponding position of the convolution kernel, $\mathbf{p}_{0}$ can be regarded as a point on the output, and $\mathbf{y}(\mathbf{p}_{0})$ is the value of the output at that point. $\mathbf{x}(\mathbf{p}_{0}+\mathbf{p}_{n})$ is the value of the input at the corresponding point of the convolution sampling region. Taken as a whole, the formula is simply the ordinary convolution operation.
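
To make the formula concrete, here is a minimal sketch (names and sizes are purely illustrative) that evaluates $\mathbf{y}(\mathbf{p}_0)$ at one interior point and checks it against PyTorch's own convolution (which, like the formula, is a cross-correlation):

import torch

# ordinary convolution at a single output point p0
x = torch.randn(7, 7)          # single-channel input
w = torch.randn(3, 3)          # 3 x 3 kernel
R = [(i, j) for i in (-1, 0, 1) for j in (-1, 0, 1)]  # the offsets p_n

p0 = (3, 3)                    # an interior point
y_p0 = sum(w[pn[0] + 1, pn[1] + 1] * x[p0[0] + pn[0], p0[1] + pn[1]]
           for pn in R)

# input center (3, 3) maps to output index (2, 2) for a 3x3 kernel, no padding
ref = torch.nn.functional.conv2d(x.view(1, 1, 7, 7), w.view(1, 1, 3, 3))
assert torch.allclose(y_p0, ref[0, 0, 2, 2])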

The calculation of DeformableConv adds an offset, learned by the model itself, on top of the ordinary convolution. The formula is:

$$\mathbf{y}\left(\mathbf{p}_{0}\right)=\sum_{\mathbf{p}_{n} \in \mathcal{R}} \mathbf{w}\left(\mathbf{p}_{n}\right) \cdot \mathbf{x}\left(\mathbf{p}_{0}+\mathbf{p}_{n}+\Delta \mathbf{p}_{n}\right)$$

where $\Delta \mathbf{p}_{n}$ denotes the offset. Note that the offset is applied to $\mathbf{x}$: what the deformable convolution deforms is not the convolution kernel but the sampling locations on the input.
[Figure: sampling locations of an ordinary convolution (regular grid) vs. a deformable convolution (blue boxes), with the offset field that produces the shifts]
As the figure shows, for an ordinary convolution the sampling region corresponding to a point on the output is a square the size of the convolution kernel, while for a deformable convolution the sampling region consists of the points marked by the blue boxes. This is the difference between deformable convolution and ordinary convolution.
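
Conceptually this is just the previous sketch with $\Delta \mathbf{p}_n$ added. The sketch below uses hand-picked integer offsets so that no interpolation is needed yet; real learned offsets are fractional and require the bilinear interpolation described later:

import torch

# deformable sampling at one output point p0 (illustrative only)
x = torch.randn(7, 7)
w = torch.randn(3, 3)
R = [(i, j) for i in (-1, 0, 1) for j in (-1, 0, 1)]
delta_pn = {pn: (1, 0) for pn in R}   # shift every sampling point one row down

p0 = (3, 3)
y_p0 = sum(w[pn[0] + 1, pn[1] + 1]
           * x[p0[0] + pn[0] + delta_pn[pn][0],
               p0[1] + pn[1] + delta_pn[pn][1]]
           for pn in R)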

Next, let's look at the specific details of deformable convolution, taking an N x N convolution kernel as an example. A point on the output corresponds to an N x N convolution sampling region on the input. According to the deformable convolution operation, each of the N x N sampling points in this region must learn its own offset, and an offset is expressed as a coordinate pair, so a single output point needs 2N² learned parameters. The output size is H x W, so a total of 2N² x H x W parameters must be learned. That is, the offset field in the figure above has dimension B x 2N² x H x W, where B is the batch_size (note that in the figure, N denotes the number of sampling points, i.e. the kernel area, rather than the kernel side length used here). Two details worth noting (a quick shape check follows the list):

  • For an input of dimension B x C x H x W, all C feature maps (channels) of a sample share one offset field; that is, every channel uses the same offsets.
  • Deformable convolution does not change the spatial size of the input, so the output is also H x W.
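
These shapes can be checked numerically. The snippet below is a small sketch that mirrors the offset-predicting p_conv layer defined in section 2.1:

import torch
import torch.nn as nn

B, C, H, W = 2, 16, 32, 32
N = 3                                 # kernel side length
x = torch.randn(B, C, H, W)

# one conv predicts all offsets: 2 coordinates per sampling point, N*N points
p_conv = nn.Conv2d(C, 2 * N * N, kernel_size=3, padding=1, stride=1)
offset = p_conv(x)
print(offset.shape)                   # torch.Size([2, 18, 32, 32]), i.e. B x 2N^2 x H x W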

After the offset is added to a sampling point, the result will almost never land on integer coordinates, so a bilinear interpolation is performed. Concretely: take the four nearest integer points surrounding the offset point (each within distance 1 along every axis), multiply the value at each of them by a weight determined by its distance to the offset point, and sum the four products; the sum is the value sampled at the offset point.
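
A minimal sketch of this interpolation for a single fractional point (names are illustrative; the vectorized version appears in the forward pass below):

import torch

def bilinear_sample(x, px, py):
    # x: 2D tensor; (px, py): fractional coordinates inside x
    x0, y0 = int(px), int(py)          # top-left integer neighbor (floor, for non-negative coords)
    x1, y1 = x0 + 1, y0 + 1            # bottom-right integer neighbor
    dx, dy = px - x0, py - y0
    return (x[x0, y0] * (1 - dx) * (1 - dy) +   # weights shrink with distance
            x[x0, y1] * (1 - dx) * dy +
            x[x1, y0] * dx * (1 - dy) +
            x[x1, y1] * dx * dy)

x = torch.arange(16.).view(4, 4)
print(bilinear_sample(x, 1.5, 2.25))   # tensor(8.2500)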

2. DeformableConv code analysis

Following the theory above, the code implementation of DeformableConv can be broken down into the flowchart below:
[Flowchart: DeformableConv implementation steps]

2.1 Initialization operation

import torch
import torch.nn as nn


class DeformConv2d(nn.Module):
    def __init__(self, 
                 inc, 
                 outc, 
                 kernel_size=3, 
                 padding=1, 
                 stride=1, 
                 bias=None, 
                 modulation=False):
        """
        Args:
            modulation (bool, optional): If True, use Modulated Deformable
            Convolution (Deformable ConvNets v2).
        """
        super(DeformConv2d, self).__init__()
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride
        self.zero_padding = nn.ZeroPad2d(padding)
        self.conv = nn.Conv2d(inc,  # this conv performs the final convolution
                              outc, 
                              kernel_size=kernel_size, 
                              stride=kernel_size,  # stride = kernel_size: x_offset is expanded ks-fold by _reshape_x_offset
                              bias=bias)

        self.p_conv = nn.Conv2d(inc,  # this conv learns the offsets from the input
                                2*kernel_size*kernel_size,  # 2 coordinates for each of the k*k sampling points
                                kernel_size=3, 
                                padding=1, 
                                stride=stride)
        nn.init.constant_(self.p_conv.weight, 0)  # start with zero offsets, i.e. an ordinary convolution
        self.p_conv.register_backward_hook(self._set_lr)

        self.modulation = modulation  # modulation is the Deformable ConvNets v2 variant; it can be skipped on first reading
        if modulation:
            self.m_conv = nn.Conv2d(inc, 
                                    kernel_size*kernel_size, 
                                    kernel_size=3, 
                                    padding=1, 
                                    stride=stride)
            nn.init.constant_(self.m_conv.weight, 0)
            self.m_conv.register_backward_hook(self._set_lr)

    @staticmethod
    def _set_lr(module, grad_input, grad_output):
        # scale the gradients of the offset (and modulation) convs by 0.1,
        # giving them a lower effective learning rate; the hook must return
        # the new grad_input tuple for the scaling to take effect
        return tuple(g * 0.1 if g is not None else None for g in grad_input)
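
With the initialization in place, a quick shape sketch (untrained weights, illustrative sizes) shows that the module maps B x C x H x W to B x outc x H x W, as promised by the theory section:

import torch

dcn = DeformConv2d(inc=16, outc=32, kernel_size=3, padding=1, stride=1)
x = torch.randn(2, 16, 32, 32)
out = dcn(x)          # runs the forward pass defined in section 2.2
print(out.shape)      # torch.Size([2, 32, 32, 32])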

2.2 Execution

    def forward(self, x):
        offset = self.p_conv(x)  # predict the offsets from the input, (b, 2N, h, w)
        if self.modulation:
            m = torch.sigmoid(self.m_conv(x))

        dtype = offset.data.type()
        ks = self.kernel_size
        N = offset.size(1) // 2  # N = kernel_size * kernel_size sampling points

        if self.padding:
            x = self.zero_padding(x)

        # (b, 2N, h, w)
        p = self._get_p(offset, dtype)

        # (b, h, w, 2N)
        p = p.contiguous().permute(0, 2, 3, 1)
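        # the four integer corners around each fractional sampling point p:
        # q_lt = floor(p) is the top-left corner, q_rb = q_lt + 1 the
        # bottom-right; q_lb and q_rt mix their coordinates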
        q_lt = p.detach().floor()
        q_rb = q_lt + 1

        q_lt = torch.cat([torch.clamp(q_lt[..., :N], 0, x.size(2)-1), torch.clamp(q_lt[..., N:], 0, x.size(3)-1)], dim=-1).long()
        q_rb = torch.cat([torch.clamp(q_rb[..., :N], 0, x.size(2)-1), torch.clamp(q_rb[..., N:], 0, x.size(3)-1)], dim=-1).long()
        q_lb = torch.cat([q_lt[..., :N], q_rb[..., N:]], dim=-1)
        q_rt = torch.cat([q_rb[..., :N], q_lt[..., N:]], dim=-1)

        # clip p
        p = torch.cat([torch.clamp(p[..., :N], 0, x.size(2)-1), torch.clamp(p[..., N:], 0, x.size(3)-1)], dim=-1)

        # bilinear kernel (b, h, w, N)
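        # each g_* equals (1 - |p_x - q_x|) * (1 - |p_y - q_y|),
        # the bilinear weight of the corresponding corner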
        g_lt = (1 + (q_lt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_lt[..., N:].type_as(p) - p[..., N:]))
        g_rb = (1 - (q_rb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_rb[..., N:].type_as(p) - p[..., N:]))
        g_lb = (1 + (q_lb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_lb[..., N:].type_as(p) - p[..., N:]))
        g_rt = (1 - (q_rt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_rt[..., N:].type_as(p) - p[..., N:]))

        # (b, c, h, w, N)
        x_q_lt = self._get_x_q(x, q_lt, N)
        x_q_rb = self._get_x_q(x, q_rb, N)
        x_q_lb = self._get_x_q(x, q_lb, N)
        x_q_rt = self._get_x_q(x, q_rt, N)

        # (b, c, h, w, N)
        x_offset = g_lt.unsqueeze(dim=1) * x_q_lt + \
                   g_rb.unsqueeze(dim=1) * x_q_rb + \
                   g_lb.unsqueeze(dim=1) * x_q_lb + \
                   g_rt.unsqueeze(dim=1) * x_q_rt

        # modulation
        if self.modulation:
            m = m.contiguous().permute(0, 2, 3, 1)
            m = m.unsqueeze(dim=1)
            m = torch.cat([m for _ in range(x_offset.size(1))], dim=1)
            x_offset *= m

        x_offset = self._reshape_x_offset(x_offset, ks)  # (b, c, h*ks, w*ks)
        out = self.conv(x_offset)  # the stride=ks conv maps it back to (b, outc, h, w)

        return out
        
    def _get_p_n(self, N, dtype):  # compute p_n, the kernel-grid offsets relative to the center
        p_n_x, p_n_y = torch.meshgrid(
            torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1),
            torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1))
        # (2N, 1)
        p_n = torch.cat([torch.flatten(p_n_x), torch.flatten(p_n_y)], 0)
        p_n = p_n.view(1, 2*N, 1, 1).type(dtype)

        return p_n

    def _get_p_0(self, h, w, N, dtype):
        p_0_x, p_0_y = torch.meshgrid(
            torch.arange(1, h*self.stride+1, self.stride),
            torch.arange(1, w*self.stride+1, self.stride))
        p_0_x = torch.flatten(p_0_x).view(1, 1, h, w).repeat(1, N, 1, 1)
        p_0_y = torch.flatten(p_0_y).view(1, 1, h, w).repeat(1, N, 1, 1)
        p_0 = torch.cat([p_0_x, p_0_y], 1).type(dtype)

        return p_0

    def _get_p(self, offset, dtype):
        N, h, w = offset.size(1)//2, offset.size(2), offset.size(3)

        # (1, 2N, 1, 1)
        p_n = self._get_p_n(N, dtype)
        # (1, 2N, h, w)
        p_0 = self._get_p_0(h, w, N, dtype)
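        # absolute sampling coords = output-pixel centers in the padded input (p_0)
        #                            + kernel-grid offsets (p_n) + learned offsets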
        p = p_0 + p_n + offset
        return p

    def _get_x_q(self, x, q, N):
        b, h, w, _ = q.size()
        padded_w = x.size(3)
        c = x.size(1)
        # (b, c, h*w)
        x = x.contiguous().view(b, c, -1)

        # (b, h, w, N)
        index = q[..., :N]*padded_w + q[..., N:]  # offset_x*w + offset_y
        # (b, c, h*w*N)
        index = index.contiguous().unsqueeze(dim=1).expand(-1, c, -1, -1, -1).contiguous().view(b, c, -1)

        x_offset = x.gather(dim=-1, index=index).contiguous().view(b, c, h, w, N)

        return x_offset

    @staticmethod
    def _reshape_x_offset(x_offset, ks):
        b, c, h, w, N = x_offset.size()
        x_offset = torch.cat([x_offset[..., s:s+ks].contiguous().view(b, c, h, w*ks) for s in range(0, N, ks)], dim=-1)
        x_offset = x_offset.contiguous().view(b, c, h*ks, w*ks)

        return x_offset
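
Why does self.conv use stride=kernel_size? Because _reshape_x_offset tiles the N = ks x ks sampled values of each output pixel into a ks x ks spatial block, so a convolution with kernel_size=ks and stride=ks consumes exactly one block per output pixel. A small standalone sketch of the reshape (illustrative shapes):

import torch

b, c, h, w, ks = 1, 1, 2, 2, 3
N = ks * ks
x_offset = torch.randn(b, c, h, w, N)

# tile the N sampled values of each pixel into a ks x ks spatial block
tiled = torch.cat([x_offset[..., s:s+ks].contiguous().view(b, c, h, w*ks)
                   for s in range(0, N, ks)], dim=-1)
tiled = tiled.contiguous().view(b, c, h*ks, w*ks)
print(tiled.shape)  # torch.Size([1, 1, 6, 6]): a stride-ks conv now reads one block per output pixel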

Origin: blog.csdn.net/weixin_45453121/article/details/129139347