In DINO code study notes (1), I went through how the inputs are processed before they reach the transformer; the next step is to feed those tensors into the transformer itself.
DINO's transformer is the deformable transformer from Deformable-DETR (also used in the authors' previous work).
Some of the earlier settings carry over; for coherence, they are restated up front:
1. The input images are [2, 3, 640, 701];
2. srcs is [[N,256,80,88], [N,256,40,44], [N,256,20,22], [N,256,10,11]], where N = 2;
3. pos_embeds is [[N,256,80,88], [N,256,40,44], [N,256,20,22], [N,256,10,11]];
4. masks is [[N,80,88], [N,40,44], [N,20,22], [N,10,11]];
5. input_query_bbox is [N, single_pad * 2 * dn_number, 4] ([N, 200, 4] in this batch);
6. input_query_label is [N, single_pad * 2 * dn_number, 256] ([N, 200, 256] in this batch);
7. attn_mask is [single_pad * 2 * dn_number + 900, single_pad * 2 * dn_number + 900] ([1100, 1100] in this batch).
First, here is the code of the main class:
class DeformableTransformer(nn.Module):
def __init__(self, d_model=256, nhead=8,
num_queries=300,
num_encoder_layers=6,
num_unicoder_layers=0,
num_decoder_layers=6,
dim_feedforward=2048, dropout=0.0,
activation="relu", normalize_before=False,
return_intermediate_dec=False, query_dim=4,
num_patterns=0,
modulate_hw_attn=False,
# for deformable encoder
deformable_encoder=False,
deformable_decoder=False,
num_feature_levels=1,
enc_n_points=4,
dec_n_points=4,
use_deformable_box_attn=False,
box_attn_type='roi_align',
# init query
learnable_tgt_init=False,
decoder_query_perturber=None,
add_channel_attention=False,
add_pos_value=False,
random_refpoints_xy=False,
# two stage
two_stage_type='no', # ['no', 'standard', 'early', 'combine', 'enceachlayer', 'enclayer1']
two_stage_pat_embed=0,
two_stage_add_query_num=0,
two_stage_learn_wh=False,
two_stage_keep_all_tokens=False,
# evo of #anchors
dec_layer_number=None,
rm_enc_query_scale=True,
rm_dec_query_scale=True,
rm_self_attn_layers=None,
key_aware_type=None,
# layer share
layer_share_type=None,
# for detach
rm_detach=None,
decoder_sa_type='ca',
module_seq=['sa', 'ca', 'ffn'],
# for dn
embed_init_tgt=False,
use_detached_boxes_dec_out=False,
):
super().__init__()
self.num_feature_levels = num_feature_levels
self.num_encoder_layers = num_encoder_layers
self.num_unicoder_layers = num_unicoder_layers
self.num_decoder_layers = num_decoder_layers
self.deformable_encoder = deformable_encoder
self.deformable_decoder = deformable_decoder
self.two_stage_keep_all_tokens = two_stage_keep_all_tokens
self.num_queries = num_queries
self.random_refpoints_xy = random_refpoints_xy
self.use_detached_boxes_dec_out = use_detached_boxes_dec_out
assert query_dim == 4
if num_feature_levels > 1:
assert deformable_encoder, "only support deformable_encoder for num_feature_levels > 1"
if use_deformable_box_attn:
assert deformable_encoder or deformable_decoder
assert layer_share_type in [None, 'encoder', 'decoder', 'both']
if layer_share_type in ['encoder', 'both']:
enc_layer_share = True
else:
enc_layer_share = False
if layer_share_type in ['decoder', 'both']:
dec_layer_share = True
else:
dec_layer_share = False
assert layer_share_type is None
self.decoder_sa_type = decoder_sa_type
assert decoder_sa_type in ['sa', 'ca_label', 'ca_content']
# choose encoder layer type
if deformable_encoder:
encoder_layer = DeformableTransformerEncoderLayer(d_model, dim_feedforward,
dropout, activation,
num_feature_levels, nhead, enc_n_points, add_channel_attention=add_channel_attention, use_deformable_box_attn=use_deformable_box_attn, box_attn_type=box_attn_type)
else:
raise NotImplementedError
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
self.encoder = TransformerEncoder(
encoder_layer, num_encoder_layers,
encoder_norm, d_model=d_model,
num_queries=num_queries,
deformable_encoder=deformable_encoder,
enc_layer_share=enc_layer_share,
two_stage_type=two_stage_type
)
# choose decoder layer type
if deformable_decoder:
decoder_layer = DeformableTransformerDecoderLayer(d_model, dim_feedforward,
dropout, activation,
num_feature_levels, nhead, dec_n_points, use_deformable_box_attn=use_deformable_box_attn, box_attn_type=box_attn_type,
key_aware_type=key_aware_type,
decoder_sa_type=decoder_sa_type,
module_seq=module_seq)
else:
raise NotImplementedError
decoder_norm = nn.LayerNorm(d_model)
self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
return_intermediate=return_intermediate_dec,
d_model=d_model, query_dim=query_dim,
modulate_hw_attn=modulate_hw_attn,
num_feature_levels=num_feature_levels,
deformable_decoder=deformable_decoder,
decoder_query_perturber=decoder_query_perturber,
dec_layer_number=dec_layer_number, rm_dec_query_scale=rm_dec_query_scale,
dec_layer_share=dec_layer_share,
use_detached_boxes_dec_out=use_detached_boxes_dec_out
)
self.d_model = d_model
self.nhead = nhead
self.dec_layers = num_decoder_layers
self.num_queries = num_queries # useful for single stage model only
self.num_patterns = num_patterns
if not isinstance(num_patterns, int):
warnings.warn("num_patterns should be int but {}".format(type(num_patterns)))
self.num_patterns = 0
if num_feature_levels > 1:
if self.num_encoder_layers > 0:
self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))
else:
self.level_embed = None
self.learnable_tgt_init = learnable_tgt_init
assert learnable_tgt_init, "why not learnable_tgt_init"
self.embed_init_tgt = embed_init_tgt
if (two_stage_type != 'no' and embed_init_tgt) or (two_stage_type == 'no'):
self.tgt_embed = nn.Embedding(self.num_queries, d_model)
nn.init.normal_(self.tgt_embed.weight.data)
else:
self.tgt_embed = None
# for two stage
self.two_stage_type = two_stage_type
self.two_stage_pat_embed = two_stage_pat_embed
self.two_stage_add_query_num = two_stage_add_query_num
self.two_stage_learn_wh = two_stage_learn_wh
assert two_stage_type in ['no', 'standard'], "unknown param {} of two_stage_type".format(two_stage_type)
if two_stage_type =='standard':
# anchor selection at the output of encoder
self.enc_output = nn.Linear(d_model, d_model)
self.enc_output_norm = nn.LayerNorm(d_model)
if two_stage_pat_embed > 0:
self.pat_embed_for_2stage = nn.Parameter(torch.Tensor(two_stage_pat_embed, d_model))
nn.init.normal_(self.pat_embed_for_2stage)
if two_stage_add_query_num > 0:
self.tgt_embed = nn.Embedding(self.two_stage_add_query_num, d_model)
if two_stage_learn_wh:
self.two_stage_wh_embedding = nn.Embedding(1, 2)
else:
self.two_stage_wh_embedding = None
if two_stage_type == 'no':
self.init_ref_points(num_queries) # init self.refpoint_embed
self.enc_out_class_embed = None
self.enc_out_bbox_embed = None
# evolution of anchors
self.dec_layer_number = dec_layer_number
if dec_layer_number is not None:
if self.two_stage_type != 'no' or num_patterns == 0:
assert dec_layer_number[0] == num_queries, f"dec_layer_number[0]({dec_layer_number[0]}) != num_queries({num_queries})"
else:
assert dec_layer_number[0] == num_queries * num_patterns, f"dec_layer_number[0]({dec_layer_number[0]}) != num_queries({num_queries}) * num_patterns({num_patterns})"
self._reset_parameters()
self.rm_self_attn_layers = rm_self_attn_layers
if rm_self_attn_layers is not None:
print("Removing the self-attn in {} decoder layers".format(rm_self_attn_layers))
for lid, dec_layer in enumerate(self.decoder.layers):
if lid in rm_self_attn_layers:
dec_layer.rm_self_attn_modules()
self.rm_detach = rm_detach
if self.rm_detach:
assert isinstance(rm_detach, list)
assert any([i in ['enc_ref', 'enc_tgt', 'dec'] for i in rm_detach])
self.decoder.rm_detach = rm_detach
def _reset_parameters(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
for m in self.modules():
if isinstance(m, MSDeformAttn):
m._reset_parameters()
if self.num_feature_levels > 1 and self.level_embed is not None:
nn.init.normal_(self.level_embed)
if self.two_stage_learn_wh:
nn.init.constant_(self.two_stage_wh_embedding.weight, math.log(0.05 / (1 - 0.05)))
def get_valid_ratio(self, mask):
_, H, W = mask.shape
valid_H = torch.sum(~mask[:, :, 0], 1) # the non-padding height of the feature map (i.e. its actual valid H)
valid_W = torch.sum(~mask[:, 0, :], 1) # the non-padding width of the feature map
valid_ratio_h = valid_H.float() / H # ratio of the valid H to the full feature-map H in this batch
valid_ratio_w = valid_W.float() / W # ratio of the valid W to the full feature-map W in this batch
valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
return valid_ratio
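# Example (hypothetical numbers): for a mask of shape [2, 80, 88] where one image's
# valid (non-padding) region is the top-left 72 x 77 pixels, valid_H = 72 and
# valid_W = 77, so that image's valid_ratio is [77/88, 72/80] = [0.875, 0.9].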
def init_ref_points(self, use_num_queries):
self.refpoint_embed = nn.Embedding(use_num_queries, 4)
if self.random_refpoints_xy:
self.refpoint_embed.weight.data[:, :2].uniform_(0,1)
self.refpoint_embed.weight.data[:, :2] = inverse_sigmoid(self.refpoint_embed.weight.data[:, :2])
self.refpoint_embed.weight.data[:, :2].requires_grad = False
def forward(self, srcs, masks, refpoint_embed, pos_embeds, tgt, attn_mask=None):
"""
Input:
- srcs: List of multi features [bs, ci, hi, wi]
- masks: List of multi masks [bs, hi, wi]
- refpoint_embed: [bs, num_dn, 4]. None at inference # i.e., input_query_bbox
- pos_embeds: List of multi pos embeds [bs, ci, hi, wi]
- tgt: [bs, num_dn, d_model]. None at inference # i.e., input_query_label
"""
# prepare input for encoder
src_flatten = []
mask_flatten = []
lvl_pos_embed_flatten = []
spatial_shapes = []
for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
bs, c, h, w = src.shape
spatial_shape = (h, w)
spatial_shapes.append(spatial_shape)
src = src.flatten(2).transpose(1, 2) # bs, hw, c # flatten H and W: [N,256,H,W] -> [N,H*W,256]
mask = mask.flatten(1) # bs, hw # [N,H,W] -> [N,H*W]
pos_embed = pos_embed.flatten(2).transpose(1, 2) # bs, hw, c # also flatten H and W: [N,256,H,W] -> [N,H*W,256]
if self.num_feature_levels > 1 and self.level_embed is not None: # self.level_embed is a [4,256] tensor
lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1) # add the level embedding
else:
lvl_pos_embed = pos_embed
lvl_pos_embed_flatten.append(lvl_pos_embed)
src_flatten.append(src)
mask_flatten.append(mask)
src_flatten = torch.cat(src_flatten, 1) # bs, \sum{hxw}, c # cat the flattened tensors together, [N,9350,256] in this batch
mask_flatten = torch.cat(mask_flatten, 1) # bs, \sum{hxw}, [N,9350] in this batch
lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) # bs, \sum{hxw}, c, [N,9350,256] in this batch
spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device) # stores the [H,W] of each feature-map level; shape [4,2]
level_start_index = torch.cat((spatial_shapes.new_zeros((1, )), spatial_shapes.prod(1).cumsum(0)[:-1])) # start index of each level in the concatenated sequence: level 0 starts at 0, level 1 at H1*W1, level 2 at H1*W1+H2*W2, level 3 at H1*W1+H2*W2+H3*W3; 4 entries, e.g. level_start_index = tensor([ 0, 7040, 8800, 9240], device='cuda:0')
valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1) # a [N,4,2] tensor: per level, the ratio of the valid (non-padding) width/height to the full feature-map width/height
# two stage
enc_topk_proposals = enc_refpoint_embed = None
#########################################################
# Begin Encoder
#########################################################
memory, enc_intermediate_output, enc_intermediate_refpoints = self.encoder(
src_flatten,
pos=lvl_pos_embed_flatten,
level_start_index=level_start_index,
spatial_shapes=spatial_shapes,
valid_ratios=valid_ratios,
key_padding_mask=mask_flatten,
ref_token_index=enc_topk_proposals, # bs, nq
ref_token_coord=enc_refpoint_embed, # bs, nq, 4
) # memory [N,9350,256]; enc_intermediate_output=None; enc_intermediate_refpoints=None
#########################################################
# End Encoder
# - memory: bs, \sum{hw}, c
# - mask_flatten: bs, \sum{hw}
# - lvl_pos_embed_flatten: bs, \sum{hw}, c
# - enc_intermediate_output: None or (nenc+1, bs, nq, c) or (nenc, bs, nq, c)
# - enc_intermediate_refpoints: None or (nenc+1, bs, nq, c) or (nenc, bs, nq, c)
#########################################################
if self.two_stage_type =='standard':
if self.two_stage_learn_wh:
input_hw = self.two_stage_wh_embedding.weight[0]
else:
input_hw = None
output_memory, output_proposals = gen_encoder_output_proposals(memory, mask_flatten, spatial_shapes, input_hw)
output_memory = self.enc_output_norm(self.enc_output(output_memory)) # Linear(256,256) + Layer Norm
if self.two_stage_pat_embed > 0:
bs, nhw, _ = output_memory.shape
# output_memory: bs, n, 256; self.pat_embed_for_2stage: k, 256
output_memory = output_memory.repeat(1, self.two_stage_pat_embed, 1)
_pats = self.pat_embed_for_2stage.repeat_interleave(nhw, 0)
output_memory = output_memory + _pats
output_proposals = output_proposals.repeat(1, self.two_stage_pat_embed, 1)
if self.two_stage_add_query_num > 0:
assert refpoint_embed is not None
output_memory = torch.cat((output_memory, tgt), dim=1)
output_proposals = torch.cat((output_proposals, refpoint_embed), dim=1)
enc_outputs_class_unselected = self.enc_out_class_embed(output_memory) # Linear(256,91) [N,9350,91]
enc_outputs_coord_unselected = self.enc_out_bbox_embed(output_memory) + output_proposals # (bs, \sum{hw}, 4) unsigmoid [N,9350,4]
topk = self.num_queries # 900
topk_proposals = torch.topk(enc_outputs_class_unselected.max(-1)[0], topk, dim=1)[1] # bs, nq; indices of the top-900 proposals [N,900]
# gather boxes
refpoint_embed_undetach = torch.gather(enc_outputs_coord_unselected, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)) # unsigmoid; gather along dim 1 using topk_proposals [N,900,4]
refpoint_embed_ = refpoint_embed_undetach.detach() # refpoint_embed_ [N,900,4]
init_box_proposal = torch.gather(output_proposals, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)).sigmoid() # sigmoid; init_box_proposal [N,900,4]
# gather tgt
tgt_undetach = torch.gather(output_memory, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, self.d_model))
if self.embed_init_tgt:
tgt_ = self.tgt_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1) # bs, nq, d_model after the transpose [N,900,256]
else:
tgt_ = tgt_undetach.detach()
if refpoint_embed is not None:
refpoint_embed=torch.cat([refpoint_embed,refpoint_embed_],dim=1) # [N,1100,4]
tgt=torch.cat([tgt,tgt_],dim=1) # [N,1100,256]
else:
refpoint_embed,tgt=refpoint_embed_,tgt_
elif self.two_stage_type == 'no':
tgt_ = self.tgt_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1) # bs, nq, d_model after the transpose
refpoint_embed_ = self.refpoint_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1) # bs, nq, 4 after the transpose
if refpoint_embed is not None:
refpoint_embed=torch.cat([refpoint_embed,refpoint_embed_],dim=1)
tgt=torch.cat([tgt,tgt_],dim=1)
else:
refpoint_embed,tgt=refpoint_embed_,tgt_
if self.num_patterns > 0:
tgt_embed = tgt.repeat(1, self.num_patterns, 1)
refpoint_embed = refpoint_embed.repeat(1, self.num_patterns, 1)
tgt_pat = self.patterns.weight[None, :, :].repeat_interleave(self.num_queries, 1) # 1, n_q*n_pat, d_model
tgt = tgt_embed + tgt_pat
init_box_proposal = refpoint_embed_.sigmoid()
else:
raise NotImplementedError("unknown two_stage_type {}".format(self.two_stage_type))
#########################################################
# End preparing tgt
# - tgt: bs, NQ, d_model
# - refpoint_embed (unsigmoid): bs, NQ, 4
#########################################################
#########################################################
# Begin Decoder
#########################################################
hs, references = self.decoder(
tgt=tgt.transpose(0, 1), # [1100,N,256]
memory=memory.transpose(0, 1), # [9350,N,256]
memory_key_padding_mask=mask_flatten, # [N,9350]
pos=lvl_pos_embed_flatten.transpose(0, 1), # [9350,N,256]
refpoints_unsigmoid=refpoint_embed.transpose(0, 1), # [1100,N,4]
level_start_index=level_start_index, # [4]
spatial_shapes=spatial_shapes, # [4,2]
valid_ratios=valid_ratios,tgt_mask=attn_mask) # valid_ratios [2,4,2],attn_mask [1100,1100]
#########################################################
# End Decoder
# hs: n_dec, bs, nq, d_model [N,1100,256] * 6
# references: n_dec+1, bs, nq, query_dim [N,1100,4] * 7
#########################################################
#########################################################
# Begin postprocess
#########################################################
if self.two_stage_type == 'standard':
if self.two_stage_keep_all_tokens:
hs_enc = output_memory.unsqueeze(0)
ref_enc = enc_outputs_coord_unselected.unsqueeze(0)
init_box_proposal = output_proposals
else:
hs_enc = tgt_undetach.unsqueeze(0) # [1,N,900,256]
ref_enc = refpoint_embed_undetach.sigmoid().unsqueeze(0) # [1,N,900,4]
else:
hs_enc = ref_enc = None
#########################################################
# End postprocess
# hs_enc: (n_enc+1, bs, nq, d_model) or (1, bs, nq, d_model) or (n_enc, bs, nq, d_model) or None
# ref_enc: (n_enc+1, bs, nq, query_dim) or (1, bs, nq, query_dim) or (n_enc, bs, nq, d_model) or None
#########################################################
return hs, references, hs_enc, ref_enc, init_box_proposal
# hs: (n_dec, bs, nq, d_model)
# references: sigmoid coordinates. (n_dec+1, bs, bq, 4)
# hs_enc: (n_enc+1, bs, nq, d_model) or (1, bs, nq, d_model) or None
# ref_enc: sigmoid coordinates. \
# (n_enc+1, bs, nq, query_dim) or (1, bs, nq, query_dim) or None
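For reference, here is a hypothetical instantiation sketch matching the settings used in this note (the import path is a guess; DINO's config uses num_queries=900, four feature levels, and two_stage_type='standard'):
# a sketch, assuming the DINO repo is on the path (import path is assumed)
from models.dino.deformable_transformer import DeformableTransformer

transformer = DeformableTransformer(
    d_model=256, nhead=8, num_queries=900,
    num_encoder_layers=6, num_decoder_layers=6,
    dim_feedforward=2048,
    return_intermediate_dec=True,
    deformable_encoder=True, deformable_decoder=True,
    num_feature_levels=4, enc_n_points=4, dec_n_points=4,
    learnable_tgt_init=True,   # the __init__ above asserts this is True
    embed_init_tgt=True,
    two_stage_type='standard',
)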
As the forward function shows, the inputs need some preprocessing before they enter the encoder:
# prepare input for encoder
src_flatten = []
mask_flatten = []
lvl_pos_embed_flatten = []
spatial_shapes = []
for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
    bs, c, h, w = src.shape
    spatial_shape = (h, w)
    spatial_shapes.append(spatial_shape)
    src = src.flatten(2).transpose(1, 2) # bs, hw, c # flatten H and W: [N,256,H,W] -> [N,H*W,256]
    mask = mask.flatten(1) # bs, hw # [N,H,W] -> [N,H*W]
    pos_embed = pos_embed.flatten(2).transpose(1, 2) # bs, hw, c # also flatten H and W: [N,256,H,W] -> [N,H*W,256]
    if self.num_feature_levels > 1 and self.level_embed is not None: # self.level_embed is a [4,256] tensor
        lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1) # add the level embedding
    else:
        lvl_pos_embed = pos_embed
    lvl_pos_embed_flatten.append(lvl_pos_embed)
    src_flatten.append(src)
    mask_flatten.append(mask)
src_flatten = torch.cat(src_flatten, 1) # bs, \sum{hxw}, c # cat the flattened tensors together, [N,9350,256] in this batch
mask_flatten = torch.cat(mask_flatten, 1) # bs, \sum{hxw}, [N,9350] in this batch
lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) # bs, \sum{hxw}, c, [N,9350,256] in this batch
# stores the [H, W] of each feature-map level; shape [4,2]
spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)
# start index of each level in the concatenated sequence: level 0 starts at 0, level 1 at H1*W1,
# level 2 at H1*W1+H2*W2, level 3 at H1*W1+H2*W2+H3*W3; 4 entries,
# e.g. level_start_index = tensor([ 0, 7040, 8800, 9240], device='cuda:0')
level_start_index = torch.cat((spatial_shapes.new_zeros((1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
# a [N,4,2] tensor: per level, the ratio of the valid (non-padding) width/height to the full feature-map width/height
valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
1. The four feature-map levels are flattened and concatenated into one query sequence. If C2 has size [H, W], the sequence length is len_q = H*W + H//2 * W//2 + H//4 * W//4 + H//8 * W//8, so the result is [N, len_q, 256], where N is the batch size; a per-level embedding is added along the way. Here that is [N, 9350, 256], and the positional encodings end up with the same shape.
2. The mask is aligned with the query sequence: [N, len_q] ([N, 9350]).
3. spatial_shapes records the sizes of the four levels, shape [4, 2] ([[80, 88], [40, 44], [20, 22], [10, 11]]).
4. level_start_index records where each level starts in the concatenated sequence: level 0 starts at 0, level 1 at H1*W1, level 2 at H1*W1+H2*W2, and level 3 at H1*W1+H2*W2+H3*W3; 4 entries in total (see the sketch after this list).
5. valid_ratios is a [N, 4, 2] tensor giving, per level, the ratio of the valid (non-padding) width and height to the full feature-map width and height.
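As a quick sanity check, here is a minimal standalone sketch (using only the shapes stated above) that reproduces len_q and level_start_index:
import torch

spatial_shapes = torch.as_tensor([[80, 88], [40, 44], [20, 22], [10, 11]], dtype=torch.long)
len_q = int(spatial_shapes.prod(1).sum())  # 7040 + 1760 + 440 + 110 = 9350
level_start_index = torch.cat((spatial_shapes.new_zeros((1,)),
                               spatial_shapes.prod(1).cumsum(0)[:-1]))
print(len_q)              # 9350
print(level_start_index)  # tensor([   0, 7040, 8800, 9240])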
1. Encoder
#########################################################
# Begin Encoder
#########################################################
memory, enc_intermediate_output, enc_intermediate_refpoints = self.encoder(
    src_flatten,
    pos=lvl_pos_embed_flatten,
    level_start_index=level_start_index,
    spatial_shapes=spatial_shapes,
    valid_ratios=valid_ratios,
    key_padding_mask=mask_flatten,
    ref_token_index=enc_topk_proposals, # bs, nq
    ref_token_coord=enc_refpoint_embed, # bs, nq, 4
    ) # memory [N,9350,256]; enc_intermediate_output=None; enc_intermediate_refpoints=None
#########################################################
# End Encoder
# - memory: bs, \sum{hw}, c
# - mask_flatten: bs, \sum{hw}
# - lvl_pos_embed_flatten: bs, \sum{hw}, c
# - enc_intermediate_output: None or (nenc+1, bs, nq, c) or (nenc, bs, nq, c)
# - enc_intermediate_refpoints: None or (nenc+1, bs, nq, c) or (nenc, bs, nq, c)
#########################################################
class TransformerEncoder(nn.Module):
def __init__(self,
encoder_layer, num_layers, norm=None, d_model=256,
num_queries=300,
deformable_encoder=False,
enc_layer_share=False, enc_layer_dropout_prob=None,
two_stage_type='no', # ['no', 'standard', 'early', 'combine', 'enceachlayer', 'enclayer1']
):
super().__init__()
# prepare layers
if num_layers > 0:
self.layers = _get_clones(encoder_layer, num_layers, layer_share=enc_layer_share)
else:
self.layers = []
del encoder_layer
self.query_scale = None
self.num_queries = num_queries
self.deformable_encoder = deformable_encoder
self.num_layers = num_layers
self.norm = norm
self.d_model = d_model
self.enc_layer_dropout_prob = enc_layer_dropout_prob
if enc_layer_dropout_prob is not None:
assert isinstance(enc_layer_dropout_prob, list)
assert len(enc_layer_dropout_prob) == num_layers
for i in enc_layer_dropout_prob:
assert 0.0 <= i <= 1.0
self.two_stage_type = two_stage_type
if two_stage_type in ['enceachlayer', 'enclayer1']:
_proj_layer = nn.Linear(d_model, d_model)
_norm_layer = nn.LayerNorm(d_model)
if two_stage_type == 'enclayer1':
self.enc_norm = nn.ModuleList([_norm_layer])
self.enc_proj = nn.ModuleList([_proj_layer])
else:
self.enc_norm = nn.ModuleList([copy.deepcopy(_norm_layer) for i in range(num_layers - 1) ])
self.enc_proj = nn.ModuleList([copy.deepcopy(_proj_layer) for i in range(num_layers - 1) ])
@staticmethod
def get_reference_points(spatial_shapes, valid_ratios, device):
reference_points_list = []
for lvl, (H_, W_) in enumerate(spatial_shapes): # iterate over the feature maps; level 0 is the largest, H_=80, W_=88
# build a grid over the feature map: the normalized x,y coordinates of each pixel center
ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device))
ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
ref = torch.stack((ref_x, ref_y), -1) # [N,7040,2]
reference_points_list.append(ref)
reference_points = torch.cat(reference_points_list, 1) # then cat all the normalized center coordinates together [N,9350,2]
reference_points = reference_points[:, :, None] * valid_ratios[:, None] # multiply the normalized x,y by each level's valid-area ratio, giving each center's coordinates normalized to the valid region of every level [N,9350,4,2]
return reference_points
def forward(self,
src: Tensor,
pos: Tensor,
spatial_shapes: Tensor,
level_start_index: Tensor,
valid_ratios: Tensor,
key_padding_mask: Tensor,
ref_token_index: Optional[Tensor]=None,
ref_token_coord: Optional[Tensor]=None
):
"""
Input:
- src: [bs, sum(hi*wi), 256]
- pos: pos embed for src. [bs, sum(hi*wi), 256]
- spatial_shapes: h,w of each level [num_level, 2]
- level_start_index: [num_level] start point of level in sum(hi*wi).
- valid_ratios: [bs, num_level, 2]
- key_padding_mask: [bs, sum(hi*wi)]
- ref_token_index: bs, nq
- ref_token_coord: bs, nq, 4
Intermedia:
- reference_points: [bs, sum(hi*wi), num_level, 2]
Outpus:
- output: [bs, sum(hi*wi), 256]
"""
if self.two_stage_type in ['no', 'standard', 'enceachlayer', 'enclayer1']:
assert ref_token_index is None
output = src
# preparation and reshape
if self.num_layers > 0:
if self.deformable_encoder:
reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=src.device) # [N,9350,4,2]
intermediate_output = []
intermediate_ref = []
if ref_token_index is not None:
out_i = torch.gather(output, 1, ref_token_index.unsqueeze(-1).repeat(1, 1, self.d_model))
intermediate_output.append(out_i)
intermediate_ref.append(ref_token_coord)
# main process
for layer_id, layer in enumerate(self.layers):
# main process
dropflag = False
if self.enc_layer_dropout_prob is not None:
prob = random.random()
if prob < self.enc_layer_dropout_prob[layer_id]:
dropflag = True
if not dropflag:
if self.deformable_encoder:
output = layer(src=output, pos=pos, reference_points=reference_points, spatial_shapes=spatial_shapes, level_start_index=level_start_index, key_padding_mask=key_padding_mask)
else:
output = layer(src=output.transpose(0, 1), pos=pos.transpose(0, 1), key_padding_mask=key_padding_mask).transpose(0, 1)
if ((layer_id == 0 and self.two_stage_type in ['enceachlayer', 'enclayer1']) \
or (self.two_stage_type == 'enceachlayer')) \
and (layer_id != self.num_layers - 1):
output_memory, output_proposals = gen_encoder_output_proposals(output, key_padding_mask, spatial_shapes)
output_memory = self.enc_norm[layer_id](self.enc_proj[layer_id](output_memory))
# gather boxes
topk = self.num_queries
enc_outputs_class = self.class_embed[layer_id](output_memory)
ref_token_index = torch.topk(enc_outputs_class.max(-1)[0], topk, dim=1)[1] # bs, nq
ref_token_coord = torch.gather(output_proposals, 1, ref_token_index.unsqueeze(-1).repeat(1, 1, 4))
output = output_memory
# aux loss
if (layer_id != self.num_layers - 1) and ref_token_index is not None:
out_i = torch.gather(output, 1, ref_token_index.unsqueeze(-1).repeat(1, 1, self.d_model))
intermediate_output.append(out_i)
intermediate_ref.append(ref_token_coord)
if self.norm is not None:
output = self.norm(output)
if ref_token_index is not None:
intermediate_output = torch.stack(intermediate_output) # n_enc/n_enc-1, bs, \sum{hw}, d_model
intermediate_ref = torch.stack(intermediate_ref)
else:
intermediate_output = intermediate_ref = None
return output, intermediate_output, intermediate_ref
reference_points has shape [N, len_q, 4, 2] ([N, 9350, 4, 2]): for every query, its normalized position within the valid region of each feature-map level.
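To see what get_reference_points produces, here is a minimal standalone sketch on a single 2x3 level with no padding (so all valid ratios are 1):
import torch

H_, W_ = 2, 3
valid_ratios = torch.ones(1, 1, 2)  # [bs=1, n_levels=1, 2]; no padding
ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_),
                              torch.linspace(0.5, W_ - 0.5, W_),
                              indexing='ij')
ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, 0, 1] * H_)
ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, 0, 0] * W_)
ref = torch.stack((ref_x, ref_y), -1)  # [1, 6, 2]
print(ref)
# each entry is a pixel center normalized to [0, 1]:
# x in {1/6, 1/2, 5/6}, y in {1/4, 3/4}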
class DeformableTransformerEncoderLayer(nn.Module):
def __init__(self,
d_model=256, d_ffn=1024,
dropout=0.1, activation="relu",
n_levels=4, n_heads=8, n_points=4,
add_channel_attention=False,
use_deformable_box_attn=False,
box_attn_type='roi_align',
):
super().__init__()
# self attention
if use_deformable_box_attn:
self.self_attn = MSDeformableBoxAttention(d_model, n_levels, n_heads, n_boxes=n_points, used_func=box_attn_type)
else:
self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
self.dropout1 = nn.Dropout(dropout)
self.norm1 = nn.LayerNorm(d_model)
# ffn
self.linear1 = nn.Linear(d_model, d_ffn)
self.activation = _get_activation_fn(activation, d_model=d_ffn)
self.dropout2 = nn.Dropout(dropout)
self.linear2 = nn.Linear(d_ffn, d_model)
self.dropout3 = nn.Dropout(dropout)
self.norm2 = nn.LayerNorm(d_model)
# channel attention
self.add_channel_attention = add_channel_attention
if add_channel_attention:
self.activ_channel = _get_activation_fn('dyrelu', d_model=d_model)
self.norm_channel = nn.LayerNorm(d_model)
@staticmethod
def with_pos_embed(tensor, pos):
return tensor if pos is None else tensor + pos
def forward_ffn(self, src):
src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
src = src + self.dropout3(src2)
src = self.norm2(src)
return src
def forward(self, src, pos, reference_points, spatial_shapes, level_start_index, key_padding_mask=None):
# self attention
src2 = self.self_attn(self.with_pos_embed(src, pos), reference_points, src, spatial_shapes, level_start_index, key_padding_mask)
src = src + self.dropout1(src2)
src = self.norm1(src)
# ffn
src = self.forward_ffn(src)
# channel attn
if self.add_channel_attention:
src = self.norm_channel(src + self.activ_channel(src))
return src
Diagram of the encoder (figure omitted).
MSDeformAttn:
class MSDeformAttn(nn.Module):
def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):
"""
Multi-Scale Deformable Attention Module
:param d_model hidden dimension
:param n_levels number of feature levels
:param n_heads number of attention heads
:param n_points number of sampling points per attention head per feature level
"""
super().__init__()
if d_model % n_heads != 0:
raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads))
_d_per_head = d_model // n_heads
# you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation
if not _is_power_of_2(_d_per_head):
warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 "
"which is more efficient in our CUDA implementation.")
self.im2col_step = 64
self.d_model = d_model
self.n_levels = n_levels
self.n_heads = n_heads
self.n_points = n_points
self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
self.value_proj = nn.Linear(d_model, d_model)
self.output_proj = nn.Linear(d_model, d_model)
self._reset_parameters()
def _reset_parameters(self):
constant_(self.sampling_offsets.weight.data, 0.)
thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1)
for i in range(self.n_points):
grid_init[:, :, i, :] *= i + 1
with torch.no_grad():
self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
constant_(self.attention_weights.weight.data, 0.)
constant_(self.attention_weights.bias.data, 0.)
xavier_uniform_(self.value_proj.weight.data)
constant_(self.value_proj.bias.data, 0.)
xavier_uniform_(self.output_proj.weight.data)
constant_(self.output_proj.bias.data, 0.)
def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None):
"""
:param query (N, Length_{query}, C)
:param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area
or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
:param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C)
:param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
:param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}]
:param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements
:return output (N, Length_{query}, C)
"""
N, Len_q, _ = query.shape # Len_q: 9350 in the encoder / 1100 in the decoder
N, Len_in, _ = input_flatten.shape # Len_in: 9350
assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in
value = self.value_proj(input_flatten) # the input goes through a Linear layer, [N,Len_in,256] -> [N,Len_in,256], giving value
if input_padding_mask is not None:
value = value.masked_fill(input_padding_mask[..., None], float(0)) # positions where the mask is True (padding) are filled with 0 in value
value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads) # value: [N,Len_in,256] -> [N,Len_in,8,32]
sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2) # each query predicts sampling offsets for every head and level; sampling_offsets: [N,Len_q,256] -> [N,Len_q,8,4,4,2]
attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points) # the weight of each offset vector, via Linear(256,128); attention_weights: [N,Len_q,256] -> [N,Len_q,8,16]
attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points) # per query and per head, normalize the weights over all levels and points; after softmax: [N,Len_q,8,16] -> [N,Len_q,8,4,4]
# N, Len_q, n_heads, n_levels, n_points, 2
if reference_points.shape[-1] == 2:
offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1) # offset_normalizer flips each [H,W] in input_spatial_shapes to [W,H]; shape stays [4,2]
sampling_locations = reference_points[:, :, None, :, None, :] \
+ sampling_offsets / offset_normalizer[None, None, None, :, None, :] # sampling-point coordinates [N,Len_q,8,4,4,2]
elif reference_points.shape[-1] == 4:
sampling_locations = reference_points[:, :, None, :, None, :2] \
+ sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
else:
raise ValueError(
'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1]))
# for amp
if value.dtype == torch.float16:
# for mixed precision
output = MSDeformAttnFunction.apply(
value.to(torch.float32), input_spatial_shapes, input_level_start_index, sampling_locations.to(torch.float32), attention_weights, self.im2col_step)
output = output.to(torch.float16)
output = self.output_proj(output)
return output
output = MSDeformAttnFunction.apply(
value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step)
output = self.output_proj(output) # the output goes through a Linear layer, [N,Len_q,256] -> [N,Len_q,256]
return output
In the source code, n_heads is 8, d_model is 256, n_levels is 4, and n_points is 4.
MSDeformAttn takes srcs (with pos_embeds added) as the query. Each query corresponds to a reference point on the feature maps; around each reference point, n_points = 4 sampling locations are taken per head and per level, and their features are fused according to the attention_weights produced by a Linear layer (the attention weights are not computed as Q*K; they are predicted directly from the query by a Linear layer). The specifics of sampling_offsets and attention_weights are annotated in the code above.
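To make the shape bookkeeping concrete, here is a minimal standalone sketch of the two Linear heads under these settings (random tensors, untrained weights):
import torch
import torch.nn as nn
import torch.nn.functional as F

N, Len_q, d_model, n_heads, n_levels, n_points = 2, 9350, 256, 8, 4, 4
query = torch.randn(N, Len_q, d_model)

# the two Linear heads used inside MSDeformAttn
sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)  # Linear(256, 256)
attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)     # Linear(256, 128)

with torch.no_grad():
    offsets = sampling_offsets(query).view(N, Len_q, n_heads, n_levels, n_points, 2)
    weights = F.softmax(attention_weights(query).view(N, Len_q, n_heads, n_levels * n_points), -1)
    weights = weights.view(N, Len_q, n_heads, n_levels, n_points)

print(offsets.shape)                    # torch.Size([2, 9350, 8, 4, 4, 2])
print(weights.shape)                    # torch.Size([2, 9350, 8, 4, 4])
print(weights[0, 0].sum(dim=(-2, -1)))  # 8 ones: per head, the 4*4 weights sum to 1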
Diagram of the deformable transformer, from Deformable-DETR (figure omitted).
The corresponding formula:
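For reference, this is the multi-scale deformable attention equation from the Deformable-DETR paper:

$$\mathrm{MSDeformAttn}\big(z_q, \hat{p}_q, \{x^l\}_{l=1}^{L}\big) = \sum_{m=1}^{M} W_m \Big[ \sum_{l=1}^{L} \sum_{k=1}^{K} A_{mlqk} \cdot W_m' \, x^l\big(\phi_l(\hat{p}_q) + \Delta p_{mlqk}\big) \Big]$$

where $z_q$ is the query, $\hat{p}_q$ its normalized reference point, $M = 8$ heads, $L = 4$ levels, $K = 4$ sampling points, $\Delta p_{mlqk}$ are the predicted sampling offsets, $A_{mlqk}$ the predicted attention weights with $\sum_{l}\sum_{k} A_{mlqk} = 1$, and $\phi_l$ rescales $\hat{p}_q$ to the coordinates of level $l$.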
MSDeformAttnFunction calls a custom CUDA kernel, but the code also contains a pure-PyTorch implementation:
def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights):
# for debug and test only,
# need to use cuda version instead
N_, S_, M_, D_ = value.shape # value shape [N,Len_in,8,32]
_, Lq_, M_, L_, P_, _ = sampling_locations.shape # shape [N,Len_q,8,4,4,2]
value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) # split value back into the separate feature-map levels
sampling_grids = 2 * sampling_locations - 1 # map [0,1] coordinates to the [-1,1] range expected by F.grid_sample
sampling_value_list = []
for lid_, (H_, W_) in enumerate(value_spatial_shapes):
# N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_) # [N,H_*W_,8,32] -> [N*8,32,H_,W_]
# N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)
# N_*M_, D_, Lq_, P_
# F.grid_sample: given an input and a grid, samples the input at the positions specified by grid (with bilinear interpolation where needed) to produce the output.
sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_,
mode='bilinear', padding_mode='zeros', align_corners=False)
sampling_value_list.append(sampling_value_l_)
# (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_)
attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_) # shape [N,len_q,8,4,4] -> [N*8,1,len_q,16]
output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_) # corresponds to the formula from the paper above
return output.transpose(1, 2).contiguous()
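Since the core of this PyTorch fallback is F.grid_sample, here is a tiny standalone example of its semantics (bilinear sampling at normalized [-1, 1] coordinates):
import torch
import torch.nn.functional as F

inp = torch.arange(4.0).view(1, 1, 2, 2)  # a 1x1x2x2 "feature map" with values 0..3
grid = torch.zeros(1, 1, 1, 2)            # (N, H_out, W_out, 2); (0, 0) = map center; last dim is (x, y)
out = F.grid_sample(inp, grid, mode='bilinear', padding_mode='zeros', align_corners=False)
print(out)  # tensor([[[[1.5000]]]]) -- the bilinear average of 0, 1, 2, 3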
The encoder outputs:
1. memory: [N, 9350, 256];
2. enc_intermediate_output = None;
3. enc_intermediate_refpoints = None.
That wraps up the encoder; the decoder and loss parts will be covered in a follow-up note.