【MOT】多目标追踪学习笔记之MOTR

1、前言

由于transformer的成功,现已经在多目标跟踪领域(MOTR)广泛应用,例如TransTrack和TrackFormer,但这两个工作严格来说不能算作端到端的模型,而MOTR的出现弥补了上面两个工作的缺点。本文将对MOTR的网络结构简单说明,然后对推理代码进行解析,如有不对的地方,还请各位大佬指出。
paper : https://arxiv.org/abs/2105.03247
repo :https://github.com/megvii-research/MOTR
可以加v:Rex1586662742或者q群:468713665一起讨论
学习链接:MOTRMOTR

MOTR中涉及到的Deformerable DRET 可参考:【BEV】学习笔记之 DeformableDETR(原理+代码解析)

2、网络结构

下面根据论文中的图示来分步解析MOTR的网络结构
在这里插入图片描述
如上图左边为DETR的解码过程,利用Object queries和Image featue进行Decoder获得目标框。MOTR中利用了这种方法,利用多帧的特征与Track queries进行Decoder从而获得被跟踪的目标, 从代码里面发现这里的track queries其实是包括300个初始化的 detect queries + n个track queries(下文中假定每一帧中有n个track queries),得到了当前帧的目标后,即更新track queries。

针对上图(b)中的Iterative Updata过程,MOTR论文中提供了下图进行说明。
在这里插入图片描述
左右两边是一个对应的过程,在t1时刻,首先初始化300个detect queries ,并检测到了两个目标,即Object1和Object2,由于者两个目标在之前从未出现过,因此将其划分为track queries。在t2时刻,再次初始化300个object queries 并检测到了新的目标Object3,由于Object1和object2已经被跟踪上,因此不会再次被检测到。于是在t3时刻得到了3个track queries,在t4时刻,由于Obeject2消失,因此,track queries相应的也要将object3对应的object query删除。

下面将对track query的加入与删除模块进行说明,文中将这个模块命名为 query interaction module(QIM)。
在这里插入图片描述
在上图中,通过预测可以得到当前帧的 detect query 和 track query。1、detect 进入到Object Enterance模块, 从detect queries筛选出符合条件的 query作为新的track query。2、track query进入到Object Exit中,筛选出目标消失的track queries,然后在track queries中做self_attn,最终将detect queries以及track queries筛选后的query作为当前帧的track queries。

通过上面几个步骤,最终的网络结构如下图所示:
在这里插入图片描述

3、代码解析

在代码开始之前需要对Instances类进行说明,后续的代码中track_instances是 Instances实例化的对象,用于储存每帧中的 track queries,以及每个track queries对应的类别、boxes等信息。

3.1、models/structures/instances.py

class Instances:
    """
    This class represents a list of instances in an image.
    It stores the attributes of instances (e.g., boxes, masks, labels, scores) as "fields".
    All fields must have the same ``__len__`` which is the number of instances.
    储存每一帧的检测以及跟踪信息,包含detection query + track query
    """
    def __init__(...):
        self._fields: Dict[str, Any] = {
    
    }
    
    def set(self, name: str, value: Any) -> None:
        """
        Set the field named `name` to `value`.
        储存检测结果
        """
        self._fields[name] = value
    
    # 由于类初始的时候没有定义成员,后面给类添加成员变量的时候会默认调用这个函数
    def __setattr__(self, name: str, val: Any) -> None:
        if name.startswith("_"):
            super().__setattr__(name, val)
        else:
            self.set(name, val)  # 将值存到 self._fields中
            
    def cat(..)
        """合并两个 Instances"""
        ...

3.2、demo.py

class Detector(...):
    def __init__(...):
        self.tr_tracker = MOTR()
    
    def run(...):
        for _, cur_img, ori_img in tqdm(self.dataloader):
            ...
            res = self.model.inference_single_image(...)  # -> models/motr.py
            
            track_instances = res['track_instances'] # 当前帧的 track_instances
            
            # 筛选出track_instances 得分大于阈值的 track query
            dt_instances = self.filter_dt_by_score(dt_instances, prob_threshold)
            # 筛选出 dt_instances 中 box面积大于阈值的track query
            dt_instances = self.filter_dt_by_area(dt_instances, area_threshold)
            
            if dump: 
                tracker_outputs = self.tr_tracker.update(dt_instances)  # [n, 6] 当前帧的 阈值与面积超过一定阈值的 [x1,y1,x2,y2,score,id]
                

3.3、models/motr.py

class MOTR(nn.Module):
    def __init__(...):
        self.post_process = TrackerPostProcess()  # 将预测结果恢复到原图尺寸
        self.track_base = RuntimeTrackerBase()    # 进入到QIM模块之前给track queries进行标记
    
    def _generate_empty_tracks(self):
        ...
    def inference_single_image(...):
        if track_instances is None:
            track_instances = self._generate_empty_tracks()  # 在第0帧初始化300个 detection queries
     
        # track_instances:上一帧的跟踪结果  len(track_instances) = 300 + n   300为空的detect queries的个数, n为track queries的个数 
        # 利用上一帧的track qeries 和当前帧的图片特征进行decoder,来预测当前帧的检测结果,融合了时序信息。
        res = self._forward_single_image(img,track_instances=track_instances)  # res:{"pred_logits":"pred_boxes","","hs":"",...}当前帧的目标检测结果
        
        res = self._post_process_single_image(res, track_instances, False)   # 本帧检测到的目标,与上一帧的跟踪目标进行匹配
        
        track_instances = self.post_process(track_instances, ori_img_size)  # 将当前帧的跟踪结果缩放到图片尺寸大小
     
        return ... # -> demo.py
        
    def _forward_single_image(...):
        features, pos = self.backbone(samples) # 提取图片特征,参考deformable detr
        
        # track_instances.query_pos:[300 + n, 512]
    
        # hs 每个decode的中间结果 len(hs)==6 
        # inter_references 每个decode的box
        # enc_outputs_class 每个decode 的类别
        hs, init_reference, inter_references, ... = self.transformer(...)  # 2->models/deformable_transformer_plus.py
        
        
        #  pred_logits:每个框的类别、pred_boxes:每个框的box
        out = {
    
    'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1], 'ref_pts': ref_pts_all[5]}
        return out # -> inference_single_image
        
    def _post_process_single_image(....):
        with torch.no_grad():
            if self.training:
                track_scores = frame_res['pred_logits'][0, :].sigmoid().max(dim=-1).values
            else:
                track_scores = frame_res['pred_logits'][0, :, 0].sigmoid()  # 当前帧检测到所有目标的得分
                
        track_instances.scores = track_scores  # [300 + n]   detection queries + tracking queries
        track_instances.pred_logits = frame_res['pred_logits'][0]
        track_instances.pred_boxes = frame_res['pred_boxes'][0]
        track_instances.output_embedding = frame_res['hs'][0]
        
        if self.training:
            track_instances = self.criterion.match_for_single_frame(frame_res)
        else: 
            self.track_base.update(track_instances)   # 根据当前帧的检测结果更新当前帧的track_instances的ID信息
        
            
        tmp['init_track_instances'] = self._generate_empty_tracks()  # 初始化300个detection query
        tmp['track_instances'] = track_instances  # 当前帧的 track_instances [300 + n]
        if not is_last:
            out_track_instances = self.track_embed(tmp)   # QIM网络  -> models/qim.py
        else:
            ...
        return frame_res  # -> inference_single_image
        
class RuntimeTrackerBase(...):
    def __init__(...):
        ...
    
    def update(...):
        track_instances.disappear_time[track_instances.scores >= self.score_thresh] = 0 # 得分大于阈值的设为0, 得分小于阈值超过一定次数后将被过滤
        for i in range(len(track_instances)):
            # 当前 query 没有目标 且 得分超过一定值
            if track_instances.obj_idxes[i] == -1 and track_instances.scores[i] >= self.score_thresh:
                # 为这个query 设定一个ID ,说明开始跟踪一个目标
                track_instances.obj_idxes[i] = self.max_obj_id
            self.max_obj_id += 1 # 最大ID+1
            # object_query存在目标,但是阈值小于一定值
            elif track_instances.obj_idxes[i] >= 0 and track_instances.scores[i] < self.filter_score_thresh:
                 track_instances.disappear_time[i] += 1 # 目标消失的次数+1 
                 # 当消失的次数大于阈值,则将query设为无目标
                 if track_instances.disappear_time[i] >= self.miss_tolerance:
                    track_instances.obj_idxes[i] = -1
                
                

3.4、models/deformable_transformer_plus.py

class DeformableTransformer(nn.Module):
    def __init__(...):
        ...
    
    def forward(...):
        # encoder 提取图片特征 memory:[1, 23674, 256]
        memory = self.encoder(...)
        if self.two_stage:
            ...
        else:
            query_embed, tgt = torch.split(query_embed, c, dim=1)  # query_embed:[300 + n, 256]  tgt:[300 + n, 256]        
        if ref_pts is None:
            ...
        else:
            reference_points = ref_pts.unsqueeze(0).repeat(bs, 1, 1).sigmoid()  # 归一化坐标
        
        # decoder len(hs)==6,每个decode的中间结果,inter_references [6, 1, 300+n, 4] 每个decode的每个object_queries的bbox
        hs, inter_references = self.decoder(...)
        
        return hs, init_reference_out, inter_references_out, None, None  # ->  models/motr.py
        
class DeformableTransformerDecoder(nn.Module):
    def __init__:
        ...
    
    def forward(...):
        """
        tgt: [1,300 + n , 256]
        reference_points:[1,300+n,2]
        src:[1, 23674, 256]
        
        """
        for lid, layer in enumerate(self.layers):
            if reference_points.shape[-1] == 4:
                ...
            else:
                reference_points_input = reference_points[:, :, None] * src_valid_ratios[:, None]  # [1,300,4] 每个object_queries 在每个特征层上找一个对应点
            
            output = layer(...)  #  [1,300+n,256]
         
        if self.return_intermediate:
            return torch.stack(intermediate), torch.stack(intermediate_reference_points) # 
        
class DeformableTransformerDecoderLayer(...):
    def __init__(...):
		...
    
    def forward(...):
        if self.self_cross:
            return self._forward_self_cross(...)
            

    def _forward_self_cross(...):
        # self attention
        tgt = self._forward_self_attn(tgt, query_pos, attn_mask)  # [1,300+n,256]
        
        # cross attention
        # detect/track querise 与 图片特征图进行cross_attn 交叉注意力
        tgt2 = self.cross_attn(...)   # [1,300+n,256]
        tgt = self.forward_ffn(tgt) # [1,300+n,256]
        return tgt 
        
    def _forward_self_attn(...):
        if self.extra_track_attn:
            # tgt:[1,300+n,256]  track queries 内部做self_attn
            tgt = self._forward_track_attn(tgt, query_pos) 
        q = k = self.with_pos_embed(tgt, query_pos)
        if attn_mask is not None:
            ...
        else:
            # detect queries 与 track queries 一起做self-attn
            tgt2 = self.self_attn(...)
        return self.norm2(tgt)
    def _forward_track_attn(...):
        q = k = self.with_pos_embed(tgt, query_pos)  # [1,300+n,256]
        if q.shape[1] > 300:
            tgt2 = self.update_attn(...)  # 即有目标的track_queries做self-attn,然后再将track queries与 detect queries
            tgt = torch.cat([tgt[:, :300],self.norm4(tgt[:, 300:]+self.dropout5(tgt2))], dim=1)
        return tgt

        

3.5、models/qim.py

class QueryInteractionModule(...):
    def __init__(...):
        ...
    
    def forward(...):
        # data = track_instances
        # data:[300 + n] 在进入 QIM网络之前,已经通过RuntimeTrackerBase更新了data中每个queries的属性
        active_track_instances = self._select_active_tracks(data) # 挑选出 data里面跟踪到目标的queries  [n]
        active_track_instances = self._update_track_embedding(active_track_instances) # track queries 之间进行self_attn
        merged_track_instances = Instances.cat([init_track_instances, active_track_instances])  # 将300个空的detetion_query  和当前帧最终的track query合并
        return merged_track_instances
    
    def _update_track_embedding(...):
        # 当前帧跟踪到的 track query
        tgt2 = self.self_attn(q[:, None], k[:, None], value=tgt[:, None])[0][:, 0]
        if self.update_query_pos:
            ...
            track_instances.query_pos[:, :dim // 2] = query_pos
        track_instances.query_pos[:, dim//2:] = query_feat

4、学习小结

由于数据集的问题,本文只说明了推理过程的部分,按照这个,训练部分的代码可参考推理部分的代码,只是增加了最后计算loss的部分。通过本文,学习了如何利用transformer进行端到端的多目标跟踪,可以看出transformer已经在跟多的领域发挥作用,有机会分享更多的论文学习记录。

猜你喜欢

转载自blog.csdn.net/weixin_42108183/article/details/129000574