Théorie de DeformableConv (convolution déformable) et analyse de code

Théorie de DeformableConv (convolution déformable) et analyse de code

1. Analyse théorique de DeformableConv

Exemple de convolution courant
La figure ci-dessus est prise comme exemple pour expliquer, en convolution ordinaire :

  • La taille d'entrée est [batch_size, canal, H, W]
  • La taille de l'entrée de sortie est également [batch_size, canal, H, W]
  • La taille du noyau est [kernel_szie, kernel_szie]
    , donc tout point p0 dans la sortie correspond à la taille de la zone d'échantillonnage de convolution dans l'entrée comme kernel_szie x kernel_szie, et l'opération de convolution peut être exprimée comme suit : y ( p 0
    ) = ∑ pn ∈ R w ( pn ) ⋅ x ( p 0 + pn ) \mathbf{y}\left(\mathbf{p}_{0}\right)=\sum_{\mathbf{p}_{n} \in \mathcal{R }} \mathbf{w}\left(\mathbf{p}_{n}\right) \cdot \mathbf{x}\left(\mathbf{p}_{0}+\mathbf {p}_{n }\droite)oui( p.0)=pn∈R _w( p.n)X( p.0+pn)
    parmi eux,pn \mathbf{p}_{n}pnReprésente le décalage de chaque point du noyau de convolution par rapport au point central, qui peut être exprimé par la formule suivante (noyau de convolution 3 x 3 à titre d'exemple) : R = { ( − 1 , − 1 ) , ( − 1
    , 0 ) , … , ( 0 , 1 ) , ( 1 , 1 ) } \mathcal{R}=\{(-1,-1),(-1,0), \ldots,(0,1),( 1,1 )\}R.={( 1 ,1 ) ,( 1 ,0 ) ,,( 0 ,1 ) ,( 1 ,1 )}
    insérer la description de l'image ici

w ( pn ) \mathbf{w}\left(\mathbf{p}_{n}\right)w( p.n) représente le poids de la position correspondante sur le noyau de convolution. p 0 \mathbf{p}_{0}p0Il peut être considéré comme chaque point de la sortie, y ( p 0 ) \mathbf{y}\left(\mathbf{p}_{0}\right)oui( p.0) est la valeur spécifique de chaque point sur la sortie. x ( p 0 + pn ) \mathbf{x}\left(\mathbf{p}_{0}+\mathbf{p}_{n}\right)X( p.0+pn) est la valeur spécifique de chaque point de la sortie correspondant à la zone d'échantillonnage par convolution de l'entrée. La signification globale de cette formule est l’opération de convolution.

L'étape de calcul de DeformableConv consiste à ajouter un décalage appris par le modèle lui-même sur la base d'une convolution ordinaire. La formule est :
y ( p 0 ) = ∑ pn ∈ R w ( pn ) ⋅ x ( p 0 + pn + Δ pn ) \mathbf{y}\left(\mathbf{p}_{0}\right)=\sum_{\mathbf{p}_{n} \in \mathcal{R}} \mathbf{w}\ left( \mathbf{p}_{n}\right) \cdot \mathbf{x}\left(\mathbf{p}_{0}+\mathbf{p}_{n}+\Delta \mathbf{p }_ {n}\droite)oui( p.0)=pn∈R _w( p.n)X( p.0+pn+p _n)
dans la formule avecΔ pn \Delta \mathbf{p}_{n}p _nIndique le décalage. Il convient de noter que le décalage est pour x \mathbf{x}Le x , c'est-à-dire la convolution déformable n'est pas le noyau de convolution, mais l'entrée.
insérer la description de l'image ici
Comme le montre la figure ci-dessus, l'opération de convolution normale est effectuée sur l'entrée, et la zone d'échantillonnage de convolution correspondant à un point sur la sortie est un carré de la taille d'un noyau de convolution, tandis que la zone d'échantillonnage de convolution correspondant au la convolution déformable est Certains points représentés par des cases bleues, c'est la différence entre la convolution déformable et la convolution ordinaire.

Parlons ensuite des détails spécifiques de la convolution déformable, en prenant le noyau de convolution N x N comme exemple. Un point en sortie correspond à la zone d'échantillonnage par convolution en entrée. La taille de la zone d'échantillonnage par convolution est N x N. Selon le fonctionnement de la convolution déformable, chaque point d'échantillonnage par convolution dans cette zone N x N doit apprendre un décalage d'écart , et offset Il est représenté par des coordonnées, donc une sortie doit apprendre 2N x N paramètres. Une taille de sortie est H x L, donc un total de paramètres 2NxN x H x L doivent être appris. Autrement dit, le champ de décalage dans la figure ci-dessus a une dimension de B x 2NxN x H x W, où B représente batch_size (la figure ci-dessus se trouve sur Internet et N fait référence à la zone de convolution noyau). Deux détails à noter sont :

  • les cartes de caractéristiques d'entrée (en supposant que la dimension est B x C x H x L) (un total de C) sur tous les canaux d'un lot partagent un champ de décalage, c'est-à-dire que le décalage utilisé par chaque carte de caractéristiques dans un lot est le même que .
  • La convolution déformable ne modifie pas la taille de l'entrée, donc la sortie est également H x W.

Après avoir ajouté un décalage à un point, il y a une forte probabilité qu'un point de coordonnées entier ne soit pas obtenu. À ce stade, une opération d'interpolation linéaire doit être effectuée. La méthode spécifique consiste à trouver les quatre points les plus proches dont la distance par rapport au point de décalage est inférieur à 1 et calculez leurs valeurs. Multipliez par le poids (le poids est mesuré par la distance), puis effectuez l'opération de somme, qui est la valeur du point de décalage final.

2. Analyse du code DeformableConv

Selon la théorie ci-dessus, l'implémentation du code de DeformableConv peut être démontée dans l'organigramme suivant :
insérer la description de l'image ici

2.1 Opération d'initialisation

class DeformConv2d(nn.Module):
    def __init__(self, 
                 inc, 
                 outc, 
                 kernel_size=3, 
                 padding=1, 
                 stride=1, 
                 bias=None, 
                 modulation=False):
        """
        Args:
            modulation (bool, optional): If True, Modulated Defomable 
            Convolution (Deformable ConvNets v2).
        """
        super(DeformConv2d, self).__init__()
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride
        self.zero_padding = nn.ZeroPad2d(padding)
        self.conv = nn.Conv2d(inc, #该卷积用于最终的卷积
                              outc, 
                              kernel_size=kernel_size, 
                              stride=kernel_size, 
                              bias=bias)

        self.p_conv = nn.Conv2d(inc, #该卷积用于从input中学习offset
                                2*kernel_size*kernel_size, 
                                kernel_size=3, 
                                padding=1, 
                                stride=stride)
        nn.init.constant_(self.p_conv.weight, 0)
        self.p_conv.register_backward_hook(self._set_lr)

        self.modulation = modulation #该部分是DeformableConv V2版本的,可以暂时不看
        if modulation:
            self.m_conv = nn.Conv2d(inc, 
                                    kernel_size*kernel_size, 
                                    kernel_size=3, 
                                    padding=1, 
                                    stride=stride)
            nn.init.constant_(self.m_conv.weight, 0)
            self.m_conv.register_backward_hook(self._set_lr)

    @staticmethod
    def _set_lr(module, grad_input, grad_output):
        grad_input = (grad_input[i] * 0.1 for i in range(len(grad_input)))
        grad_output = (grad_output[i] * 0.1 for i in range(len(grad_output)))

2.2 Exécution

    def forward(self, x):
        offset = self.p_conv(x) #此处得到offset
        if self.modulation:
            m = torch.sigmoid(self.m_conv(x))

        dtype = offset.data.type()
        ks = self.kernel_size
        N = offset.size(1) // 2

        if self.padding:
            x = self.zero_padding(x)

        # (b, 2N, h, w)
        p = self._get_p(offset, dtype)

        # (b, h, w, 2N)
        p = p.contiguous().permute(0, 2, 3, 1)
        q_lt = p.detach().floor()
        q_rb = q_lt + 1

        q_lt = torch.cat([torch.clamp(q_lt[..., :N], 0, x.size(2)-1), torch.clamp(q_lt[..., N:], 0, x.size(3)-1)], dim=-1).long()
        q_rb = torch.cat([torch.clamp(q_rb[..., :N], 0, x.size(2)-1), torch.clamp(q_rb[..., N:], 0, x.size(3)-1)], dim=-1).long()
        q_lb = torch.cat([q_lt[..., :N], q_rb[..., N:]], dim=-1)
        q_rt = torch.cat([q_rb[..., :N], q_lt[..., N:]], dim=-1)

        # clip p
        p = torch.cat([torch.clamp(p[..., :N], 0, x.size(2)-1), torch.clamp(p[..., N:], 0, x.size(3)-1)], dim=-1)

        # bilinear kernel (b, h, w, N)
        g_lt = (1 + (q_lt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_lt[..., N:].type_as(p) - p[..., N:]))
        g_rb = (1 - (q_rb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_rb[..., N:].type_as(p) - p[..., N:]))
        g_lb = (1 + (q_lb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_lb[..., N:].type_as(p) - p[..., N:]))
        g_rt = (1 - (q_rt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_rt[..., N:].type_as(p) - p[..., N:]))

        # (b, c, h, w, N)
        x_q_lt = self._get_x_q(x, q_lt, N)
        x_q_rb = self._get_x_q(x, q_rb, N)
        x_q_lb = self._get_x_q(x, q_lb, N)
        x_q_rt = self._get_x_q(x, q_rt, N)

        # (b, c, h, w, N)
        x_offset = g_lt.unsqueeze(dim=1) * x_q_lt + \
                   g_rb.unsqueeze(dim=1) * x_q_rb + \
                   g_lb.unsqueeze(dim=1) * x_q_lb + \
                   g_rt.unsqueeze(dim=1) * x_q_rt

        # modulation
        if self.modulation:
            m = m.contiguous().permute(0, 2, 3, 1)
            m = m.unsqueeze(dim=1)
            m = torch.cat([m for _ in range(x_offset.size(1))], dim=1)
            x_offset *= m

        x_offset = self._reshape_x_offset(x_offset, ks)
        out = self.conv(x_offset)

        return out
        
    def _get_p_n(self, N, dtype): #求
        p_n_x, p_n_y = torch.meshgrid(
            torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1),
            torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1))
        # (2N, 1)
        p_n = torch.cat([torch.flatten(p_n_x), torch.flatten(p_n_y)], 0)
        p_n = p_n.view(1, 2*N, 1, 1).type(dtype)

        return p_n

    def _get_p_0(self, h, w, N, dtype):
        p_0_x, p_0_y = torch.meshgrid(
            torch.arange(1, h*self.stride+1, self.stride),
            torch.arange(1, w*self.stride+1, self.stride))
        p_0_x = torch.flatten(p_0_x).view(1, 1, h, w).repeat(1, N, 1, 1)
        p_0_y = torch.flatten(p_0_y).view(1, 1, h, w).repeat(1, N, 1, 1)
        p_0 = torch.cat([p_0_x, p_0_y], 1).type(dtype)

        return p_0

    def _get_p(self, offset, dtype):
        N, h, w = offset.size(1)//2, offset.size(2), offset.size(3)

        # (1, 2N, 1, 1)
        p_n = self._get_p_n(N, dtype)
        # (1, 2N, h, w)
        p_0 = self._get_p_0(h, w, N, dtype)
        p = p_0 + p_n + offset
        return p

    def _get_x_q(self, x, q, N):
        b, h, w, _ = q.size()
        padded_w = x.size(3)
        c = x.size(1)
        # (b, c, h*w)
        x = x.contiguous().view(b, c, -1)

        # (b, h, w, N)
        index = q[..., :N]*padded_w + q[..., N:]  # offset_x*w + offset_y
        # (b, c, h*w*N)
        index = index.contiguous().unsqueeze(dim=1).expand(-1, c, -1, -1, -1).contiguous().view(b, c, -1)

        x_offset = x.gather(dim=-1, index=index).contiguous().view(b, c, h, w, N)

        return x_offset

    @staticmethod
    def _reshape_x_offset(x_offset, ks):
        b, c, h, w, N = x_offset.size()
        x_offset = torch.cat([x_offset[..., s:s+ks].contiguous().view(b, c, h, w*ks) for s in range(0, N, ks)], dim=-1)
        x_offset = x_offset.contiguous().view(b, c, h*ks, w*ks)

        return x_offset

Je suppose que tu aimes

Origine blog.csdn.net/weixin_45453121/article/details/129139347
conseillé
Classement