Théorie de DeformableConv (convolution déformable) et analyse de code
- Référence du code : code DeformableConv
- Référence théorique : explication vidéo bilibili
1. Analyse théorique de DeformableConv
La figure ci-dessus est prise comme exemple pour expliquer, en convolution ordinaire :
- La taille d'entrée est [batch_size, canal, H, W]
- La taille de l'entrée de sortie est également [batch_size, canal, H, W]
- La taille du noyau est [kernel_szie, kernel_szie]
, donc tout point p0 dans la sortie correspond à la taille de la zone d'échantillonnage de convolution dans l'entrée comme kernel_szie x kernel_szie, et l'opération de convolution peut être exprimée comme suit : y ( p 0
) = ∑ pn ∈ R w ( pn ) ⋅ x ( p 0 + pn ) \mathbf{y}\left(\mathbf{p}_{0}\right)=\sum_{\mathbf{p}_{n} \in \mathcal{R }} \mathbf{w}\left(\mathbf{p}_{n}\right) \cdot \mathbf{x}\left(\mathbf{p}_{0}+\mathbf {p}_{n }\droite)oui( p.0)=pn∈R _∑w( p.n)⋅X( p.0+pn)
parmi eux,pn \mathbf{p}_{n}pnReprésente le décalage de chaque point du noyau de convolution par rapport au point central, qui peut être exprimé par la formule suivante (noyau de convolution 3 x 3 à titre d'exemple) : R = { ( − 1 , − 1 ) , ( − 1
, 0 ) , … , ( 0 , 1 ) , ( 1 , 1 ) } \mathcal{R}=\{(-1,-1),(-1,0), \ldots,(0,1),( 1,1 )\}R.={( − 1 ,− 1 ) ,( − 1 ,0 ) ,…,( 0 ,1 ) ,( 1 ,1 )}
w ( pn ) \mathbf{w}\left(\mathbf{p}_{n}\right)w( p.n) représente le poids de la position correspondante sur le noyau de convolution. p 0 \mathbf{p}_{0}p0Il peut être considéré comme chaque point de la sortie, y ( p 0 ) \mathbf{y}\left(\mathbf{p}_{0}\right)oui( p.0) est la valeur spécifique de chaque point sur la sortie. x ( p 0 + pn ) \mathbf{x}\left(\mathbf{p}_{0}+\mathbf{p}_{n}\right)X( p.0+pn) est la valeur spécifique de chaque point de la sortie correspondant à la zone d'échantillonnage par convolution de l'entrée. La signification globale de cette formule est l’opération de convolution.
L'étape de calcul de DeformableConv consiste à ajouter un décalage appris par le modèle lui-même sur la base d'une convolution ordinaire. La formule est :
y ( p 0 ) = ∑ pn ∈ R w ( pn ) ⋅ x ( p 0 + pn + Δ pn ) \mathbf{y}\left(\mathbf{p}_{0}\right)=\sum_{\mathbf{p}_{n} \in \mathcal{R}} \mathbf{w}\ left( \mathbf{p}_{n}\right) \cdot \mathbf{x}\left(\mathbf{p}_{0}+\mathbf{p}_{n}+\Delta \mathbf{p }_ {n}\droite)oui( p.0)=pn∈R _∑w( p.n)⋅X( p.0+pn+p _n)
dans la formule avecΔ pn \Delta \mathbf{p}_{n}p _nIndique le décalage. Il convient de noter que le décalage est pour x \mathbf{x}Le x , c'est-à-dire la convolution déformable n'est pas le noyau de convolution, mais l'entrée.
Comme le montre la figure ci-dessus, l'opération de convolution normale est effectuée sur l'entrée, et la zone d'échantillonnage de convolution correspondant à un point sur la sortie est un carré de la taille d'un noyau de convolution, tandis que la zone d'échantillonnage de convolution correspondant au la convolution déformable est Certains points représentés par des cases bleues, c'est la différence entre la convolution déformable et la convolution ordinaire.
Parlons ensuite des détails spécifiques de la convolution déformable, en prenant le noyau de convolution N x N comme exemple. Un point en sortie correspond à la zone d'échantillonnage par convolution en entrée. La taille de la zone d'échantillonnage par convolution est N x N. Selon le fonctionnement de la convolution déformable, chaque point d'échantillonnage par convolution dans cette zone N x N doit apprendre un décalage d'écart , et offset Il est représenté par des coordonnées, donc une sortie doit apprendre 2N x N paramètres. Une taille de sortie est H x L, donc un total de paramètres 2NxN x H x L doivent être appris. Autrement dit, le champ de décalage dans la figure ci-dessus a une dimension de B x 2NxN x H x W, où B représente batch_size (la figure ci-dessus se trouve sur Internet et N fait référence à la zone de convolution noyau). Deux détails à noter sont :
- les cartes de caractéristiques d'entrée (en supposant que la dimension est B x C x H x L) (un total de C) sur tous les canaux d'un lot partagent un champ de décalage, c'est-à-dire que le décalage utilisé par chaque carte de caractéristiques dans un lot est le même que .
- La convolution déformable ne modifie pas la taille de l'entrée, donc la sortie est également H x W.
Après avoir ajouté un décalage à un point, il y a une forte probabilité qu'un point de coordonnées entier ne soit pas obtenu. À ce stade, une opération d'interpolation linéaire doit être effectuée. La méthode spécifique consiste à trouver les quatre points les plus proches dont la distance par rapport au point de décalage est inférieur à 1 et calculez leurs valeurs. Multipliez par le poids (le poids est mesuré par la distance), puis effectuez l'opération de somme, qui est la valeur du point de décalage final.
2. Analyse du code DeformableConv
Selon la théorie ci-dessus, l'implémentation du code de DeformableConv peut être démontée dans l'organigramme suivant :
2.1 Opération d'initialisation
class DeformConv2d(nn.Module):
def __init__(self,
inc,
outc,
kernel_size=3,
padding=1,
stride=1,
bias=None,
modulation=False):
"""
Args:
modulation (bool, optional): If True, Modulated Defomable
Convolution (Deformable ConvNets v2).
"""
super(DeformConv2d, self).__init__()
self.kernel_size = kernel_size
self.padding = padding
self.stride = stride
self.zero_padding = nn.ZeroPad2d(padding)
self.conv = nn.Conv2d(inc, #该卷积用于最终的卷积
outc,
kernel_size=kernel_size,
stride=kernel_size,
bias=bias)
self.p_conv = nn.Conv2d(inc, #该卷积用于从input中学习offset
2*kernel_size*kernel_size,
kernel_size=3,
padding=1,
stride=stride)
nn.init.constant_(self.p_conv.weight, 0)
self.p_conv.register_backward_hook(self._set_lr)
self.modulation = modulation #该部分是DeformableConv V2版本的,可以暂时不看
if modulation:
self.m_conv = nn.Conv2d(inc,
kernel_size*kernel_size,
kernel_size=3,
padding=1,
stride=stride)
nn.init.constant_(self.m_conv.weight, 0)
self.m_conv.register_backward_hook(self._set_lr)
@staticmethod
def _set_lr(module, grad_input, grad_output):
grad_input = (grad_input[i] * 0.1 for i in range(len(grad_input)))
grad_output = (grad_output[i] * 0.1 for i in range(len(grad_output)))
2.2 Exécution
def forward(self, x):
offset = self.p_conv(x) #此处得到offset
if self.modulation:
m = torch.sigmoid(self.m_conv(x))
dtype = offset.data.type()
ks = self.kernel_size
N = offset.size(1) // 2
if self.padding:
x = self.zero_padding(x)
# (b, 2N, h, w)
p = self._get_p(offset, dtype)
# (b, h, w, 2N)
p = p.contiguous().permute(0, 2, 3, 1)
q_lt = p.detach().floor()
q_rb = q_lt + 1
q_lt = torch.cat([torch.clamp(q_lt[..., :N], 0, x.size(2)-1), torch.clamp(q_lt[..., N:], 0, x.size(3)-1)], dim=-1).long()
q_rb = torch.cat([torch.clamp(q_rb[..., :N], 0, x.size(2)-1), torch.clamp(q_rb[..., N:], 0, x.size(3)-1)], dim=-1).long()
q_lb = torch.cat([q_lt[..., :N], q_rb[..., N:]], dim=-1)
q_rt = torch.cat([q_rb[..., :N], q_lt[..., N:]], dim=-1)
# clip p
p = torch.cat([torch.clamp(p[..., :N], 0, x.size(2)-1), torch.clamp(p[..., N:], 0, x.size(3)-1)], dim=-1)
# bilinear kernel (b, h, w, N)
g_lt = (1 + (q_lt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_lt[..., N:].type_as(p) - p[..., N:]))
g_rb = (1 - (q_rb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_rb[..., N:].type_as(p) - p[..., N:]))
g_lb = (1 + (q_lb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_lb[..., N:].type_as(p) - p[..., N:]))
g_rt = (1 - (q_rt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_rt[..., N:].type_as(p) - p[..., N:]))
# (b, c, h, w, N)
x_q_lt = self._get_x_q(x, q_lt, N)
x_q_rb = self._get_x_q(x, q_rb, N)
x_q_lb = self._get_x_q(x, q_lb, N)
x_q_rt = self._get_x_q(x, q_rt, N)
# (b, c, h, w, N)
x_offset = g_lt.unsqueeze(dim=1) * x_q_lt + \
g_rb.unsqueeze(dim=1) * x_q_rb + \
g_lb.unsqueeze(dim=1) * x_q_lb + \
g_rt.unsqueeze(dim=1) * x_q_rt
# modulation
if self.modulation:
m = m.contiguous().permute(0, 2, 3, 1)
m = m.unsqueeze(dim=1)
m = torch.cat([m for _ in range(x_offset.size(1))], dim=1)
x_offset *= m
x_offset = self._reshape_x_offset(x_offset, ks)
out = self.conv(x_offset)
return out
def _get_p_n(self, N, dtype): #求
p_n_x, p_n_y = torch.meshgrid(
torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1),
torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1))
# (2N, 1)
p_n = torch.cat([torch.flatten(p_n_x), torch.flatten(p_n_y)], 0)
p_n = p_n.view(1, 2*N, 1, 1).type(dtype)
return p_n
def _get_p_0(self, h, w, N, dtype):
p_0_x, p_0_y = torch.meshgrid(
torch.arange(1, h*self.stride+1, self.stride),
torch.arange(1, w*self.stride+1, self.stride))
p_0_x = torch.flatten(p_0_x).view(1, 1, h, w).repeat(1, N, 1, 1)
p_0_y = torch.flatten(p_0_y).view(1, 1, h, w).repeat(1, N, 1, 1)
p_0 = torch.cat([p_0_x, p_0_y], 1).type(dtype)
return p_0
def _get_p(self, offset, dtype):
N, h, w = offset.size(1)//2, offset.size(2), offset.size(3)
# (1, 2N, 1, 1)
p_n = self._get_p_n(N, dtype)
# (1, 2N, h, w)
p_0 = self._get_p_0(h, w, N, dtype)
p = p_0 + p_n + offset
return p
def _get_x_q(self, x, q, N):
b, h, w, _ = q.size()
padded_w = x.size(3)
c = x.size(1)
# (b, c, h*w)
x = x.contiguous().view(b, c, -1)
# (b, h, w, N)
index = q[..., :N]*padded_w + q[..., N:] # offset_x*w + offset_y
# (b, c, h*w*N)
index = index.contiguous().unsqueeze(dim=1).expand(-1, c, -1, -1, -1).contiguous().view(b, c, -1)
x_offset = x.gather(dim=-1, index=index).contiguous().view(b, c, h, w, N)
return x_offset
@staticmethod
def _reshape_x_offset(x_offset, ks):
b, c, h, w, N = x_offset.size()
x_offset = torch.cat([x_offset[..., s:s+ks].contiguous().view(b, c, h, w*ks) for s in range(0, N, ks)], dim=-1)
x_offset = x_offset.contiguous().view(b, c, h*ks, w*ks)
return x_offset