Foreword
This article introduces the basic idea of the Conditional DETR paper and its code implementation. First, here is a link to an excellent Zhihu interpretation; on top of it I only add my own brief take, which is far less thorough than that author's.
1. Zhihu interpretation
2. Code address
3. Paper address
In addition, if you are interested, you can read my other articles on DETR:
1. Using nn.Transformer
2. Interpreting Detr with mmdet
3. DeformableDetr
1. Introduction to the paper
1.1. Research questions
This paper tackles the slow convergence of Detr, so the authors first analyze where it could come from: the encoder only extracts image feature vectors; self-attention in the decoder only handles interaction and de-duplication among queries; so the problem most likely lies in cross-attention. In the original Detr, query = content query + object query, and the authors found that removing the object query from the second decoder layer onward barely costs any accuracy, so the slow convergence is attributed to the content query.
1.2. Visualizing Spatial Attention Heatmaps
The authors visualize the spatial attention heatmaps of decoder cross-attention in Detr, i.e. (content query + object query) * pk. They find that at 50 epochs Detr still cannot localize object boundaries well, which explains the slow convergence.
I wrote my own code to visualize the spatial attention heatmap of each head; if you are interested, take a look at: Visualization of Detr Spatial Attention.
1.3. Causes
First analyze the cross-attention computation of the original Detr. Note that it uses addition, i.e. cq interacts with both ck and pk at the same time, which easily confuses the network; the fix is simply to decouple the content part c from the spatial part p.
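To make the entanglement concrete, expand the additive cross-attention score (scaling omitted; this expansion is my own illustration, using cq/pq for the query's content/spatial parts and ck/pk for the key's):

```latex
(c_q + p_q)^\top (c_k + p_k)
  = \underbrace{c_q^\top c_k}_{\text{content}}
  + \underbrace{c_q^\top p_k + p_q^\top c_k}_{\text{mixed terms}}
  + \underbrace{p_q^\top p_k}_{\text{spatial}}
```

The two mixed terms force the content embedding to also carry localization, which is exactly what the decoupling removes.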
1.4. Conditional Cross Attn
The strategy adopted by the authors is very simple: just decouple the two parts.
1.5. Structure diagram
First, the object query [N, 256] is mapped to a 2D reference point s [N, 2], and then s is mapped with the sinusoidal encoding below (the same form as pk) to obtain Ps.
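The formula referenced here (the original image did not survive; this is my reconstruction from the gen_sineembed_for_position code shown later, temperature 10000, 128 dims per coordinate) is the standard sinusoidal encoding; written out for the x coordinate of s:

```latex
P_s(x)_{2i}   = \sin\!\left(\frac{2\pi x}{10000^{2i/128}}\right),\qquad
P_s(x)_{2i+1} = \cos\!\left(\frac{2\pi x}{10000^{2i/128}}\right),\qquad i = 0,\dots,63
```

and the full embedding is the concatenation $P_s = [P_s(y);\, P_s(x)] \in \mathbb{R}^{256}$.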
With Ps in hand, the authors consider that cq contains the object's boundary information, so cq is passed through an FFN to get a transformation T, which is multiplied element-wise with Ps to obtain Pq.
Pq is then concatenated with the cq coming out of self-attention and fed into cross-attention.
In the final prediction stage, the box is obtained from the reference point s plus a predicted offset.
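As a rough scalar sketch of this last step (my own minimal version, assuming the usual inverse-sigmoid composition used in the official repo; predict_center and its inputs are illustrative names):

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def inverse_sigmoid(p, eps=1e-5):
    # Clamp to avoid log(0), then invert the sigmoid
    p = min(max(p, eps), 1.0 - eps)
    return math.log(p / (1.0 - p))

def predict_center(ref_point, offset_logit):
    # Final coordinate = sigmoid(predicted offset + inverse_sigmoid(s)),
    # so the head only needs to learn a correction around s
    return sigmoid(offset_logit + inverse_sigmoid(ref_point))

# A zero offset keeps the prediction exactly at the reference point
print(predict_center(0.5, 0.0))  # → 0.5
```

Because the offset is added in logit space, a zero-initialized box head starts out predicting the reference point itself.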
2. Code explanation
Since the authors' code only changes the cross-attention part of Detr, I only introduce the changed part. For the rest, please refer to: mmdet Interpretation of Detr.
2.1. Core code
#-------------------#
# The FFN in the structure diagram
#-------------------#
class MLP(nn.Module):
    """Very simple multi-layer perceptron (also called FFN)."""
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x
#--------------------------------------------------------#
# Map the reference point s to a 256-d sin/cos encoding
#--------------------------------------------------------#
def gen_sineembed_for_position(pos_tensor):
    # pos_tensor: [n_query, bs, 2], coordinates normalized to [0, 1]
    scale = 2 * math.pi
    dim_t = torch.arange(128, dtype=torch.float32, device=pos_tensor.device)
    dim_t = 10000 ** (2 * (dim_t // 2) / 128)
    x_embed = pos_tensor[:, :, 0] * scale
    y_embed = pos_tensor[:, :, 1] * scale
    pos_x = x_embed[:, :, None] / dim_t
    pos_y = y_embed[:, :, None] / dim_t
    pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
    pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2)
    pos = torch.cat((pos_y, pos_x), dim=2)  # [n_query, bs, 256]
    return pos
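To see what this encoding produces without pulling in torch, here is a pure-Python scalar equivalent for one coordinate (my own sketch; sine_embed_1d is a made-up name mirroring the (sin, cos) interleaving above):

```python
import math

def sine_embed_1d(coord, num_feats=128, temperature=10000.0):
    """sin/cos encoding of one normalized coordinate, interleaved as
    (sin, cos) pairs of decreasing frequency, like the torch version."""
    x = coord * 2 * math.pi
    emb = []
    for i in range(0, num_feats, 2):
        freq = temperature ** (i / num_feats)  # shared by the sin/cos pair
        emb.append(math.sin(x / freq))
        emb.append(math.cos(x / freq))
    return emb

emb = sine_embed_1d(0.25)   # x * 2*pi = pi/2
print(len(emb))             # → 128
print(round(emb[0]))        # sin(pi/2) → 1
```

Concatenating the y and x embeddings of a reference point then gives the 256-d Ps.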
#-------------------#
# How pq is generated
#-------------------#
class TransformerDecoder(nn.Module):
    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False, d_model=256):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate
        self.query_scale = MLP(d_model, d_model, d_model, 2)
        self.ref_point_head = MLP(d_model, d_model, 2, 2)
        for layer_id in range(num_layers - 1):
            self.layers[layer_id + 1].ca_qpos_proj = None

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        output = tgt
        intermediate = []
        # Map object queries to reference points s: [num_queries, batch_size, 2]
        reference_points_before_sigmoid = self.ref_point_head(query_pos)
        # Squash to [0, 1] with sigmoid
        reference_points = reference_points_before_sigmoid.sigmoid().transpose(0, 1)
        # Iterate over the 6 decoder layers
        for layer_id, layer in enumerate(self.layers):
            # Record the reference point s, also used later for box prediction
            obj_center = reference_points[..., :2].transpose(0, 1)
            # For the first decoder layer, we do not apply transformation over p_s
            if layer_id == 0:
                pos_transformation = 1
            else:
                # T in the paper: map cq through an FFN
                pos_transformation = self.query_scale(output)
            # Sine-encode the reference point to get ps
            query_sine_embed = gen_sineembed_for_position(obj_center)
            # Element-wise product gives pq; both are [300, 2, 256]
            query_sine_embed = query_sine_embed * pos_transformation
            # ... query_sine_embed is then passed into the decoder layer

# The following snippet comes from the cross-attention part of the
# decoder layer's forward, which consumes query_sine_embed:
# Split into heads; each head is content + pos_embed --> (32d + 32d)
q = q.view(num_queries, bs, self.nhead, n_model//self.nhead)
query_sine_embed = self.ca_qpos_sine_proj(query_sine_embed)
query_sine_embed = query_sine_embed.view(num_queries, bs, self.nhead, n_model//self.nhead)
# Concatenate cq and pq
q = torch.cat([q, query_sine_embed], dim=3).view(num_queries, bs, n_model * 2)
k = k.view(hw, bs, self.nhead, n_model//self.nhead)
k_pos = k_pos.view(hw, bs, self.nhead, n_model//self.nhead)
# Concatenate ck and pk
k = torch.cat([k, k_pos], dim=3).view(hw, bs, n_model * 2)
# Cross-attention via an nn.MultiheadAttention module
tgt2 = self.cross_attn(query=q,
                       key=k,
                       value=v, attn_mask=memory_mask,
                       key_padding_mask=memory_key_padding_mask)[0]
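The point of the concatenation is that each head's attention score now splits cleanly into a content term plus a spatial term, with no mixed terms, unlike the additive form. A toy check in plain Python (illustrative vectors only):

```python
# Per-head score with concatenated vectors: [cq; pq] . [ck; pk]
cq, pq = [1.0, 2.0], [3.0, 4.0]   # content / spatial parts of the query
ck, pk = [5.0, 6.0], [7.0, 8.0]   # content / spatial parts of the key

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

concat_score = dot(cq + pq, ck + pk)       # list + here is concatenation
split_score = dot(cq, ck) + dot(pq, pk)    # content term + spatial term
print(concat_score == split_score)  # → True
```

So content matching and spatial localization each get their own half of the dot product, which is the decoupling the paper is after.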
Summary
The method in this paper is simple and effective: a single decoupling strategy speeds up training by up to 10x. DAB-Detr and DN-Detr will be explained later, so stay tuned. If you have any questions, feel free to add me on WeChat (wulele2541612007) and I will pull you into the group for discussion and exchange.