1、前言
由于transformer的成功,现已经在多目标跟踪领域(MOTR)广泛应用,例如TransTrack和TrackFormer,但这两个工作严格来说不能算作端到端的模型,而MOTR的出现弥补了上面两个工作的缺点。本文将对MOTR的网络结构简单说明,然后对推理代码进行解析,如有不对的地方,还请各位大佬指出。
paper : https://arxiv.org/abs/2105.03247
repo :https://github.com/megvii-research/MOTR
可以加v:Rex1586662742或者q群:468713665一起讨论
学习链接:MOTR、MOTR
MOTR中涉及到的Deformerable DRET 可参考:【BEV】学习笔记之 DeformableDETR(原理+代码解析)
2、网络结构
下面根据论文中的图示来分步解析MOTR的网络结构
如上图左边为DETR的解码过程,利用Object queries和Image featue进行Decoder获得目标框。MOTR中利用了这种方法,利用多帧的特征与Track queries进行Decoder从而获得被跟踪的目标, 从代码里面发现这里的track queries其实是包括300个初始化的 detect queries + n个track queries(下文中假定每一帧中有n个track queries),得到了当前帧的目标后,即更新track queries。
针对上图(b)中的Iterative Updata过程,MOTR论文中提供了下图进行说明。
左右两边是一个对应的过程,在t1时刻,首先初始化300个detect queries ,并检测到了两个目标,即Object1和Object2,由于者两个目标在之前从未出现过,因此将其划分为track queries。在t2时刻,再次初始化300个object queries 并检测到了新的目标Object3,由于Object1和object2已经被跟踪上,因此不会再次被检测到。于是在t3时刻得到了3个track queries,在t4时刻,由于Obeject2消失,因此,track queries相应的也要将object3对应的object query删除。
下面将对track query的加入与删除模块进行说明,文中将这个模块命名为 query interaction module(QIM)。
在上图中,通过预测可以得到当前帧的 detect query 和 track query。1、detect 进入到Object Enterance模块, 从detect queries筛选出符合条件的 query作为新的track query。2、track query进入到Object Exit中,筛选出目标消失的track queries,然后在track queries中做self_attn,最终将detect queries以及track queries筛选后的query作为当前帧的track queries。
通过上面几个步骤,最终的网络结构如下图所示:
3、代码解析
在代码开始之前需要对Instances类进行说明,后续的代码中track_instances是 Instances实例化的对象,用于储存每帧中的 track queries,以及每个track queries对应的类别、boxes等信息。
3.1、models/structures/instances.py
class Instances:
"""
This class represents a list of instances in an image.
It stores the attributes of instances (e.g., boxes, masks, labels, scores) as "fields".
All fields must have the same ``__len__`` which is the number of instances.
储存每一帧的检测以及跟踪信息,包含detection query + track query
"""
def __init__(...):
self._fields: Dict[str, Any] = {
}
def set(self, name: str, value: Any) -> None:
"""
Set the field named `name` to `value`.
储存检测结果
"""
self._fields[name] = value
# 由于类初始的时候没有定义成员,后面给类添加成员变量的时候会默认调用这个函数
def __setattr__(self, name: str, val: Any) -> None:
if name.startswith("_"):
super().__setattr__(name, val)
else:
self.set(name, val) # 将值存到 self._fields中
def cat(..)
"""合并两个 Instances"""
...
3.2、demo.py
class Detector(...):
def __init__(...):
self.tr_tracker = MOTR()
def run(...):
for _, cur_img, ori_img in tqdm(self.dataloader):
...
res = self.model.inference_single_image(...) # -> models/motr.py
track_instances = res['track_instances'] # 当前帧的 track_instances
# 筛选出track_instances 得分大于阈值的 track query
dt_instances = self.filter_dt_by_score(dt_instances, prob_threshold)
# 筛选出 dt_instances 中 box面积大于阈值的track query
dt_instances = self.filter_dt_by_area(dt_instances, area_threshold)
if dump:
tracker_outputs = self.tr_tracker.update(dt_instances) # [n, 6] 当前帧的 阈值与面积超过一定阈值的 [x1,y1,x2,y2,score,id]
3.3、models/motr.py
class MOTR(nn.Module):
def __init__(...):
self.post_process = TrackerPostProcess() # 将预测结果恢复到原图尺寸
self.track_base = RuntimeTrackerBase() # 进入到QIM模块之前给track queries进行标记
def _generate_empty_tracks(self):
...
def inference_single_image(...):
if track_instances is None:
track_instances = self._generate_empty_tracks() # 在第0帧初始化300个 detection queries
# track_instances:上一帧的跟踪结果 len(track_instances) = 300 + n 300为空的detect queries的个数, n为track queries的个数
# 利用上一帧的track qeries 和当前帧的图片特征进行decoder,来预测当前帧的检测结果,融合了时序信息。
res = self._forward_single_image(img,track_instances=track_instances) # res:{"pred_logits":"pred_boxes","","hs":"",...}当前帧的目标检测结果
res = self._post_process_single_image(res, track_instances, False) # 本帧检测到的目标,与上一帧的跟踪目标进行匹配
track_instances = self.post_process(track_instances, ori_img_size) # 将当前帧的跟踪结果缩放到图片尺寸大小
return ... # -> demo.py
def _forward_single_image(...):
features, pos = self.backbone(samples) # 提取图片特征,参考deformable detr
# track_instances.query_pos:[300 + n, 512]
# hs 每个decode的中间结果 len(hs)==6
# inter_references 每个decode的box
# enc_outputs_class 每个decode 的类别
hs, init_reference, inter_references, ... = self.transformer(...) # 2->models/deformable_transformer_plus.py
# pred_logits:每个框的类别、pred_boxes:每个框的box
out = {
'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1], 'ref_pts': ref_pts_all[5]}
return out # -> inference_single_image
def _post_process_single_image(....):
with torch.no_grad():
if self.training:
track_scores = frame_res['pred_logits'][0, :].sigmoid().max(dim=-1).values
else:
track_scores = frame_res['pred_logits'][0, :, 0].sigmoid() # 当前帧检测到所有目标的得分
track_instances.scores = track_scores # [300 + n] detection queries + tracking queries
track_instances.pred_logits = frame_res['pred_logits'][0]
track_instances.pred_boxes = frame_res['pred_boxes'][0]
track_instances.output_embedding = frame_res['hs'][0]
if self.training:
track_instances = self.criterion.match_for_single_frame(frame_res)
else:
self.track_base.update(track_instances) # 根据当前帧的检测结果更新当前帧的track_instances的ID信息
tmp['init_track_instances'] = self._generate_empty_tracks() # 初始化300个detection query
tmp['track_instances'] = track_instances # 当前帧的 track_instances [300 + n]
if not is_last:
out_track_instances = self.track_embed(tmp) # QIM网络 -> models/qim.py
else:
...
return frame_res # -> inference_single_image
class RuntimeTrackerBase(...):
def __init__(...):
...
def update(...):
track_instances.disappear_time[track_instances.scores >= self.score_thresh] = 0 # 得分大于阈值的设为0, 得分小于阈值超过一定次数后将被过滤
for i in range(len(track_instances)):
# 当前 query 没有目标 且 得分超过一定值
if track_instances.obj_idxes[i] == -1 and track_instances.scores[i] >= self.score_thresh:
# 为这个query 设定一个ID ,说明开始跟踪一个目标
track_instances.obj_idxes[i] = self.max_obj_id
self.max_obj_id += 1 # 最大ID+1
# object_query存在目标,但是阈值小于一定值
elif track_instances.obj_idxes[i] >= 0 and track_instances.scores[i] < self.filter_score_thresh:
track_instances.disappear_time[i] += 1 # 目标消失的次数+1
# 当消失的次数大于阈值,则将query设为无目标
if track_instances.disappear_time[i] >= self.miss_tolerance:
track_instances.obj_idxes[i] = -1
3.4、models/deformable_transformer_plus.py
class DeformableTransformer(nn.Module):
def __init__(...):
...
def forward(...):
# encoder 提取图片特征 memory:[1, 23674, 256]
memory = self.encoder(...)
if self.two_stage:
...
else:
query_embed, tgt = torch.split(query_embed, c, dim=1) # query_embed:[300 + n, 256] tgt:[300 + n, 256]
if ref_pts is None:
...
else:
reference_points = ref_pts.unsqueeze(0).repeat(bs, 1, 1).sigmoid() # 归一化坐标
# decoder len(hs)==6,每个decode的中间结果,inter_references [6, 1, 300+n, 4] 每个decode的每个object_queries的bbox
hs, inter_references = self.decoder(...)
return hs, init_reference_out, inter_references_out, None, None # -> models/motr.py
class DeformableTransformerDecoder(nn.Module):
def __init__:
...
def forward(...):
"""
tgt: [1,300 + n , 256]
reference_points:[1,300+n,2]
src:[1, 23674, 256]
"""
for lid, layer in enumerate(self.layers):
if reference_points.shape[-1] == 4:
...
else:
reference_points_input = reference_points[:, :, None] * src_valid_ratios[:, None] # [1,300,4] 每个object_queries 在每个特征层上找一个对应点
output = layer(...) # [1,300+n,256]
if self.return_intermediate:
return torch.stack(intermediate), torch.stack(intermediate_reference_points) #
class DeformableTransformerDecoderLayer(...):
def __init__(...):
...
def forward(...):
if self.self_cross:
return self._forward_self_cross(...)
def _forward_self_cross(...):
# self attention
tgt = self._forward_self_attn(tgt, query_pos, attn_mask) # [1,300+n,256]
# cross attention
# detect/track querise 与 图片特征图进行cross_attn 交叉注意力
tgt2 = self.cross_attn(...) # [1,300+n,256]
tgt = self.forward_ffn(tgt) # [1,300+n,256]
return tgt
def _forward_self_attn(...):
if self.extra_track_attn:
# tgt:[1,300+n,256] track queries 内部做self_attn
tgt = self._forward_track_attn(tgt, query_pos)
q = k = self.with_pos_embed(tgt, query_pos)
if attn_mask is not None:
...
else:
# detect queries 与 track queries 一起做self-attn
tgt2 = self.self_attn(...)
return self.norm2(tgt)
def _forward_track_attn(...):
q = k = self.with_pos_embed(tgt, query_pos) # [1,300+n,256]
if q.shape[1] > 300:
tgt2 = self.update_attn(...) # 即有目标的track_queries做self-attn,然后再将track queries与 detect queries
tgt = torch.cat([tgt[:, :300],self.norm4(tgt[:, 300:]+self.dropout5(tgt2))], dim=1)
return tgt
3.5、models/qim.py
class QueryInteractionModule(...):
def __init__(...):
...
def forward(...):
# data = track_instances
# data:[300 + n] 在进入 QIM网络之前,已经通过RuntimeTrackerBase更新了data中每个queries的属性
active_track_instances = self._select_active_tracks(data) # 挑选出 data里面跟踪到目标的queries [n]
active_track_instances = self._update_track_embedding(active_track_instances) # track queries 之间进行self_attn
merged_track_instances = Instances.cat([init_track_instances, active_track_instances]) # 将300个空的detetion_query 和当前帧最终的track query合并
return merged_track_instances
def _update_track_embedding(...):
# 当前帧跟踪到的 track query
tgt2 = self.self_attn(q[:, None], k[:, None], value=tgt[:, None])[0][:, 0]
if self.update_query_pos:
...
track_instances.query_pos[:, :dim // 2] = query_pos
track_instances.query_pos[:, dim//2:] = query_feat
4、学习小结
由于数据集的问题,本文只说明了推理过程的部分,按照这个,训练部分的代码可参考推理部分的代码,只是增加了最后计算loss的部分。通过本文,学习了如何利用transformer进行端到端的多目标跟踪,可以看出transformer已经在跟多的领域发挥作用,有机会分享更多的论文学习记录。