ConditionalDetr paper interpretation + core source code interpretation


Foreword

 This article mainly introduces the basic idea of the ConditionalDetr paper and its code implementation. First, I link to an excellent Zhihu interpretation; on top of it I only add a few brief opinions of my own, which are far from as thorough as that author's reading. (Hats off to the big guy.)
 1. Zhihu interpretation
 2. Code address
 3. Paper address
 In addition, if you are interested, you can read my other Detr-related articles:
 1. Using nn.Transformer
 2. mmdet interpretation of Detr
 3. DeformableDetr

1. Introduction to the paper

1.1. Research questions

 This paper mainly addresses the slow convergence of Detr, so the author first analyzes where it could come from: the encoder only extracts image feature vectors, and the self-attn in the decoder only handles interaction and deduplication among the queries, so the problem most likely lies in the cross attn. In the original Detr, query = content query + object query, and the original paper found that removing the object query from the second decoder layer onward basically does not lose accuracy, so the slow convergence is attributed to the content query.

1.2. Visualizing Spatial Attention Heatmaps

 The author visualizes the spatial attention heatmaps of the decoder cross-attention in Detr, i.e. (content query + object query) · pk, and finds that at 50 epochs Detr still cannot attend to the object boundaries well, which is exactly why it converges slowly.
 I also wrote my own code for visualizing the spatial attention heatmap of each head; if you are interested, take a look at: Visualization of Detr Spatial Attention.
[Figure: decoder cross-attention spatial attention heatmaps]
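 As a minimal sketch of my own (not the linked visualization code), the heatmaps can be grabbed with a forward hook on the decoder cross-attention. The hub entry point and attribute names below follow the facebookresearch/detr repo; note that this plots the full cross-attention weight map, averaged over heads, rather than the isolated spatial term shown in the paper's figure.

import math
import torch
import matplotlib.pyplot as plt

# load the official DETR model from torch hub
model = torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True).eval()

attn_store = {}
def save_weights(module, inputs, outputs):
    # outputs = (attn_output, attn_weights); nn.MultiheadAttention returns the
    # weights averaged over heads: [batch, num_queries, H*W]
    attn_store['weights'] = outputs[1].detach()

# hook the cross-attention of the last decoder layer
model.transformer.decoder.layers[-1].multihead_attn.register_forward_hook(save_weights)

with torch.no_grad():
    model(torch.randn(1, 3, 800, 800))        # square dummy input -> square C5 feature map

weights = attn_store['weights'][0]            # [num_queries, H*W]
side = int(math.sqrt(weights.shape[-1]))      # 25 x 25 for an 800 x 800 input (stride 32)
plt.imshow(weights[0].reshape(side, side).numpy())   # heatmap of query 0
plt.show()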

1.3. Causes

 First look at how the original Detr computes cross-attention. Note that it uses addition, so cq interacts with both ck and pk at the same time, which easily confuses the network; the fix is simply to decouple the content part c from the spatial part p.
[Figure: cross-attention formulation in the original Detr]
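 Written out explicitly (my own transcription of the figure, in the paper's notation), the attention logits of the original Detr cross-attention expand as

 (c_q + p_q)^T (c_k + p_k) = c_q^T c_k + c_q^T p_k + p_q^T c_k + p_q^T p_k

 where c_q is the content query (decoder embedding), p_q the object query, c_k the content key (encoder feature) and p_k its spatial positional encoding. The two cross terms mix content and spatial information, which is exactly what the decoupling below removes.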

1.4. Conditional Cross Attn

 The strategy the author adopts is very simple: decouple content from position, concatenating the two instead of adding them:
[Figure: the conditional (decoupled) cross-attention formulation]
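 In formula form (again my own transcription), the query and key become concatenations rather than sums, q = [c_q; p_q] and k = [c_k; p_k], so the attention logits split cleanly into

 [c_q; p_q]^T [c_k; p_k] = c_q^T c_k + p_q^T p_k

 i.e. a pure content term plus a pure spatial term, with no mixed terms.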

1.5. Structure diagram

 First, the object query [N, 256] is mapped to a 2d reference point s [N, 2], and s is then mapped by the following formula into a sine encoding consistent with pk, giving Ps.
[Figure: mapping the object query to the reference point s and its sine encoding Ps]
With Ps in hand, the author argues that cq contains the boundary information of the object, so cq is passed through an FFN to get T, which is multiplied element-wise with Ps to get Pq.

[Figure: Pq = T ⊙ Ps]
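 Putting the last few figures together, the whole chain of formulas is (my own summary, matching the code below):

 s = sigmoid(FFN(oq)),   Ps = sinusoidal(s),   T = FFN(cq),   Pq = T ⊙ Ps

 where oq is the object query, ⊙ denotes element-wise multiplication, and Ps uses the same sine/cosine encoding as pk, so that their dot product behaves like a spatial similarity.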
 Pq is then concatenated with the cq that comes out of self-attn and sent into cross-attn.

[Figure: Conditional Detr decoder structure]
 In the final prediction stage, the box is obtained from the reference point s plus a predicted offset.
[Figure: box prediction from the reference point s plus predicted offsets]
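 Here is a simplified sketch of my own showing how the reference point s enters the box head (based on my reading of the released code; hs, reference and bbox_embed are stand-in names, and a single Linear replaces the real box FFN):

import torch
import torch.nn as nn

def inverse_sigmoid(x, eps=1e-5):
    # map a value in (0, 1) back to logit space
    x = x.clamp(min=eps, max=1 - eps)
    return torch.log(x / (1 - x))

bs, num_queries, d_model = 2, 300, 256
hs = torch.randn(bs, num_queries, d_model)     # decoder output of one layer
reference = torch.rand(bs, num_queries, 2)     # reference points s, in (0, 1)
bbox_embed = nn.Linear(d_model, 4)             # stand-in for the box FFN

tmp = bbox_embed(hs)                                       # per-query (cx, cy, w, h) logits
tmp[..., :2] = tmp[..., :2] + inverse_sigmoid(reference)   # center offsets are added to s
pred_boxes = tmp.sigmoid()                                 # normalized boxes in (0, 1)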

2. Code explanation

 Since the author's code only changes the cross-attention part of Detr, I only cover the changed part. For the rest, please refer to: mmdet interpretation of Detr.

2.1. Core code

import copy
import math
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn, Tensor

#-----------------------------------#
# The FFN in the structure diagram
#-----------------------------------#
class MLP(nn.Module):
    """ Very simple multi-layer perceptron (also called FFN)"""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x
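# Note: in the decoder below, this class is used both as
#   query_scale    = MLP(256, 256, 256, 2)  -> the transformation T, and as
#   ref_point_head = MLP(256, 256, 2, 2)    -> object query -> 2-d reference point s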
#------------------------------------------------------------------#
# Turn the reference point s into a 256-d sine/cosine encoding
#------------------------------------------------------------------#
def gen_sineembed_for_position(pos_tensor):
    # pos_tensor: [num_queries, batch_size, 2] reference points in (0, 1)
    scale = 2 * math.pi
    dim_t = torch.arange(128, dtype=torch.float32, device=pos_tensor.device)
    dim_t = 10000 ** (2 * (dim_t // 2) / 128)
    x_embed = pos_tensor[:, :, 0] * scale
    y_embed = pos_tensor[:, :, 1] * scale
    pos_x = x_embed[:, :, None] / dim_t
    pos_y = y_embed[:, :, None] / dim_t
    pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
    pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2)
    pos = torch.cat((pos_y, pos_x), dim=2)
    return pos
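# Shape check (my own example, not from the repo): a [num_queries, batch, 2] tensor
# of reference points in (0, 1) becomes a [num_queries, batch, 256] embedding
# (the first 128 dims encode y, the last 128 encode x), matching pk:
#   s  = torch.rand(300, 2, 2)
#   ps = gen_sineembed_for_position(s)   # -> torch.Size([300, 2, 256])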

#---------------------------------------------#
# How Pq is generated (TransformerDecoder)
#---------------------------------------------#
def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


class TransformerDecoder(nn.Module):

    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False, d_model=256):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate
        # FFN producing the transformation T from the decoder embedding cq
        self.query_scale = MLP(d_model, d_model, d_model, 2)
        # FFN mapping the object query to a 2-d reference point s
        self.ref_point_head = MLP(d_model, d_model, 2, 2)
        # from the 2nd layer onward, the object query projection in cross-attn is dropped
        for layer_id in range(num_layers - 1):
            self.layers[layer_id + 1].ca_qpos_proj = None

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        output = tgt

        intermediate = []
        # map the object query to the reference point s: [num_queries, batch_size, 2]
        reference_points_before_sigmoid = self.ref_point_head(query_pos)
        # squash to (0, 1) with sigmoid
        reference_points = reference_points_before_sigmoid.sigmoid().transpose(0, 1)
        # iterate over the 6 decoder layers
        for layer_id, layer in enumerate(self.layers):
            # the reference point s, also reused later for box prediction
            obj_center = reference_points[..., :2].transpose(0, 1)
            # For the first decoder layer, we do not apply transformation over p_s
            if layer_id == 0:
                pos_transformation = 1
            else:
                # T in the paper: map cq (the decoder embedding) through an FFN
                pos_transformation = self.query_scale(output)
            # sine-encode the reference point to get Ps
            query_sine_embed = gen_sineembed_for_position(obj_center)
            # element-wise product gives Pq; both tensors are [300, 2, 256]
            query_sine_embed = query_sine_embed * pos_transformation
            # hand Pq to the decoder layer together with the usual inputs
            # (in the released code the layer also receives an is_first flag)
            output = layer(output, memory, tgt_mask=tgt_mask,
                           memory_mask=memory_mask,
                           tgt_key_padding_mask=tgt_key_padding_mask,
                           memory_key_padding_mask=memory_key_padding_mask,
                           pos=pos, query_pos=query_pos,
                           query_sine_embed=query_sine_embed,
                           is_first=(layer_id == 0))

#---------------------------------------------------------------------------#
# Cross-attention part of TransformerDecoderLayer.forward (abridged).
# q, k, v and k_pos are produced just above this excerpt by linear
# projections of tgt (cq), memory (ck and the value v) and pos (pk).
#---------------------------------------------------------------------------#
        # split into heads: each head is content + pos_embed --> (32d + 32d)
        q = q.view(num_queries, bs, self.nhead, n_model // self.nhead)
        query_sine_embed = self.ca_qpos_sine_proj(query_sine_embed)
        query_sine_embed = query_sine_embed.view(num_queries, bs, self.nhead, n_model // self.nhead)
        # concatenate cq and Pq along the per-head channel dimension
        q = torch.cat([q, query_sine_embed], dim=3).view(num_queries, bs, n_model * 2)
        k = k.view(hw, bs, self.nhead, n_model // self.nhead)
        k_pos = k_pos.view(hw, bs, self.nhead, n_model // self.nhead)
        # concatenate ck and pk
        k = torch.cat([k, k_pos], dim=3).view(hw, bs, n_model * 2)
        # feed into the cross-attention module (constructed with embed_dim = d_model * 2
        # in the released code so the concatenated vectors fit)
        tgt2 = self.cross_attn(query=q,
                               key=k,
                               value=v, attn_mask=memory_mask,
                               key_padding_mask=memory_key_padding_mask)[0]
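
 To make the head-splitting above concrete, here is a tiny shape walkthrough of my own (toy tensors, not repo code): with n_model = 256 and nhead = 8, each head attends over 32 content dimensions concatenated with 32 positional dimensions, so the cross-attention effectively runs with an embedding dimension of 512.

import torch

num_queries, bs = 300, 2
n_model, nhead = 256, 8

cq = torch.randn(num_queries, bs, n_model)   # content query after self-attn (projected)
pq = torch.randn(num_queries, bs, n_model)   # Pq after ca_qpos_sine_proj

q = torch.cat([cq.view(num_queries, bs, nhead, n_model // nhead),
               pq.view(num_queries, bs, nhead, n_model // nhead)], dim=3)
q = q.view(num_queries, bs, n_model * 2)
print(q.shape)   # torch.Size([300, 2, 512]): each head sees 32 content + 32 position dims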

Summary

 The structure in this paper is simple and effective: a single decoupling strategy speeds up convergence by roughly 10 times. DAB-Detr and DN-Detr will be covered later, so stay tuned. If you have any questions, feel free to add me on WeChat (wulele2541612007) and I will pull you into the discussion group.


Original post: blog.csdn.net/wulele2/article/details/123993727