Code Walkthrough of the Paper "Fast Online Object Tracking and Segmentation: A Unifying Approach" (Part 3)

1. demo.py

demo.py is the project's main entry point.

Argument parsing:
parser = argparse.ArgumentParser(description='PyTorch Tracking Demo')
parser.add_argument('--resume', default='', type=str, required=True, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('--config', dest='config', default='config_davis.json',
                    help='hyper-parameter of SiamMask in json format')
parser.add_argument('--base_path', default='../../data/tennis', help='datasets')
args = parser.parse_args()
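Because --resume is declared with required=True, the script refuses to start without a checkpoint path. A quick way to see the parser's behavior (the checkpoint filename below is just a placeholder):

import argparse

parser = argparse.ArgumentParser(description='PyTorch Tracking Demo')
parser.add_argument('--resume', default='', type=str, required=True, metavar='PATH')
parser.add_argument('--config', default='config_davis.json')
parser.add_argument('--base_path', default='../../data/tennis')

# equivalent to: python demo.py --resume <checkpoint>.pth
args = parser.parse_args(['--resume', 'checkpoint.pth'])
print(args.resume, args.config, args.base_path)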
Loading the model:
cfg = load_config(args)
from custom import Custom
siammask = Custom(anchors=cfg['anchors'])
if args.resume:
    assert isfile(args.resume), '{} is not a valid file'.format(args.resume)
    siammask = load_pretrain(siammask, args.resume)
siammask.eval().to(device)  # device is defined earlier in demo.py
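load_pretrain comes from the repo's loader utilities. As a rough sketch of what such a checkpoint loader typically does (an illustration, not the repo's actual implementation):

import torch

def load_pretrain_sketch(model, ckpt_path):
    # minimal illustrative loader, not the repo's load_pretrain
    ckpt = torch.load(ckpt_path, map_location='cpu')
    state_dict = ckpt.get('state_dict', ckpt)  # some checkpoints wrap the weights
    # drop a 'module.' prefix left over from nn.DataParallel training, if present
    state_dict = {k[7:] if k.startswith('module.') else k: v
                  for k, v in state_dict.items()}
    model.load_state_dict(state_dict, strict=False)
    return model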
Tracking and segmentation:
for f, im in enumerate(ims):
    tic = cv2.getTickCount()
    if f == 0:  # init
        target_pos = np.array([x + w / 2, y + h / 2])
        target_sz = np.array([w, h])
        state = siamese_init(im, target_pos, target_sz, siammask, cfg['hp'])  # init tracker
    elif f > 0:  # tracking
        state = siamese_track(state, im, mask_enable=True, refine_enable=True)  # track
        location = state['ploygon'].flatten()  # 'ploygon' (sic) is the key the repo uses
        mask = state['mask'] > state['p'].seg_thr
        # tint the segmented region red and draw the rotated box in green
        im[:, :, 2] = (mask > 0) * 255 + (mask == 0) * im[:, :, 2]
        cv2.polylines(im, [np.int0(location).reshape((-1, 1, 2))], True, (0, 255, 0), 3)
        cv2.imshow('SiamMask', im)
        key = cv2.waitKey(1)
        if key > 0:
            break
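The x, y, w, h consumed at frame 0 come from a box the user draws on the first frame; the released demo uses cv2.selectROI for this. A sketch (the frame-loading pattern is an assumption):

import cv2
import glob

ims = [cv2.imread(f) for f in sorted(glob.glob('../../data/tennis/*.jpg'))]
x, y, w, h = cv2.selectROI('SiamMask', ims[0], False, False)  # drag a box, confirm with ENTER/SPACE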

2. test.py

The siamese_track method:
def siamese_track(state, im, mask_enable=False, refine_enable=False):
    # read the target state carried over from the previous frame
    p = state['p']
    net = state['net']
    avg_chans = state['avg_chans']
    window = state['window']
    target_pos = state['target_pos']
    target_sz = state['target_sz']

    # side of the context-padded exemplar region, then grow it to search-region scale
    wc_x = target_sz[1] + p.context_amount * sum(target_sz)
    hc_x = target_sz[0] + p.context_amount * sum(target_sz)
    s_x = np.sqrt(wc_x * hc_x)
    scale_x = p.exemplar_size / s_x
    d_search = (p.instance_size - p.exemplar_size) / 2
    pad = d_search / scale_x
    s_x = s_x + 2 * pad
    # search-region box in image coordinates, after scaling
    crop_box = [target_pos[0] - round(s_x) / 2, target_pos[1] - round(s_x) / 2, round(s_x), round(s_x)]

    # extract scaled crop for search region x at previous target position
    x_crop = Variable(get_subwindow_tracking(im, target_pos, p.instance_size, round(s_x), avg_chans).unsqueeze(0))
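A worked example of the search-region size, assuming the usual exemplar_size=127, instance_size=255 and context_amount=0.5 from the config:

import numpy as np

target_sz = np.array([100., 80.])       # previous target (w, h)
context = 0.5 * target_sz.sum()         # context margin = 90
wc_x = target_sz[1] + context           # 170
hc_x = target_sz[0] + context           # 190
s_x = np.sqrt(wc_x * hc_x)              # ~179.7, context-padded exemplar side
scale_x = 127 / s_x                     # ~0.707, image -> crop scale
pad = (255 - 127) / 2 / scale_x         # ~90.6, extra context per side
print(s_x + 2 * pad)                    # ~360.9, final search-region side in image pixels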
    if mask_enable:
        score, delta, mask = net.track_mask(x_crop.cuda())
    else:
        score, delta = net.track(x_crop.cuda())
    # raw branch outputs; the mask branch still needs one more step below

    delta = delta.permute(1, 2, 3, 0).contiguous().view(4, -1).data.cpu().numpy()
    score = F.softmax(score.permute(1, 2, 3, 0).contiguous().view(2, -1).permute(1, 0), dim=1).data[:, 1].cpu().numpy()
    # bounding-box regression: decode the deltas against the anchors
    delta[0, :] = delta[0, :] * p.anchor[:, 2] + p.anchor[:, 0]
    delta[1, :] = delta[1, :] * p.anchor[:, 3] + p.anchor[:, 1]
    delta[2, :] = np.exp(delta[2, :]) * p.anchor[:, 2]
    delta[3, :] = np.exp(delta[3, :]) * p.anchor[:, 3]
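Each column of delta is decoded against its anchor (cx, cy, w, h): center offsets are scaled by the anchor size, widths and heights are regressed in log space. A minimal numeric check with illustrative values:

import numpy as np

anchor = np.array([0., 0., 64., 64.])   # (cx, cy, w, h), centered anchor
d = np.array([0.1, -0.05, 0.2, 0.0])    # network output (dx, dy, dw, dh)
cx = d[0] * anchor[2] + anchor[0]       # 6.4
cy = d[1] * anchor[3] + anchor[1]       # -3.2
w = np.exp(d[2]) * anchor[2]            # ~78.2
h = np.exp(d[3]) * anchor[3]            # 64.0
print(cx, cy, w, h)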
    def change(r):
        return np.maximum(r, 1. / r)

    def sz(w, h):
        pad = (w + h) * 0.5
        sz2 = (w + pad) * (h + pad)
        return np.sqrt(sz2)

    def sz_wh(wh):
        pad = (wh[0] + wh[1]) * 0.5
        sz2 = (wh[0] + pad) * (wh[1] + pad)
        return np.sqrt(sz2)
    # size penalty
    target_sz_in_crop = target_sz * scale_x
    s_c = change(sz(delta[2, :], delta[3, :]) / (sz_wh(target_sz_in_crop)))  # scale penalty
    r_c = change((target_sz_in_crop[0] / target_sz_in_crop[1]) / (delta[2, :] / delta[3, :]))  # ratio penalty
    penalty = np.exp(-(r_c * s_c - 1) * p.penalty_k)
    pscore = penalty * score

    # cos window (motion model)
    pscore = pscore * (1 - p.window_influence) + window * p.window_influence
    best_pscore_id = np.argmax(pscore)  # the feature location of this best box drives mask refinement below
    pred_in_crop = delta[:, best_pscore_id] / scale_x
    lr = penalty[best_pscore_id] * score[best_pscore_id] * p.lr  # lr for OTB

    res_x = pred_in_crop[0] + target_pos[0]
    res_y = pred_in_crop[1] + target_pos[1]
    res_w = target_sz[0] * (1 - lr) + pred_in_crop[2] * lr
    res_h = target_sz[1] * (1 - lr) + pred_in_crop[3] * lr

    target_pos = np.array([res_x, res_y])
    target_sz = np.array([res_w, res_h])
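Two small numeric checks of the re-ranking and the size update (penalty_k and lr are hyper-parameters from the hp config; the values below are only illustrative). Since sz scales linearly under uniform scaling, a proposal twice the previous size with the same aspect ratio gives s_c = 2 and r_c = 1:

import numpy as np

penalty_k, base_lr = 0.04, 0.9          # illustrative hyper-parameter values

# penalty for a same-size proposal vs. one twice as large (r_c = 1 in both cases)
for s_c in (1.0, 2.0):
    print(np.exp(-(1.0 * s_c - 1) * penalty_k))   # 1.0, then ~0.96

# size update: confident proposals move the box, weak ones barely do
for score, penalty in [(0.95, 0.98), (0.30, 0.60)]:
    lr = penalty * score * base_lr
    print(100 * (1 - lr) + 140 * lr)    # old width 100, predicted 140 -> ~133.5 vs ~106.5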
    # for Mask Branch
    if mask_enable:
        # recover (anchor, row, col) of the best score on the 5 x score_size x score_size map
        best_pscore_id_mask = np.unravel_index(best_pscore_id, (5, p.score_size, p.score_size))
        delta_x, delta_y = best_pscore_id_mask[2], best_pscore_id_mask[1]
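np.unravel_index converts the flat argmax index back into coordinates on the score volume. With 5 anchors and score_size = 25 (the value implied by the DAVIS config: (255 - 127)/8 + 1 + 8 = 25):

import numpy as np

best_pscore_id = 1892
anchor_id, row, col = np.unravel_index(best_pscore_id, (5, 25, 25))
print(anchor_id, row, col)   # 3 0 17: anchor 3 at grid cell (row 0, col 17)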
        if refine_enable:
            mask = net.track_refine((delta_y, delta_x)).cuda().sigmoid().squeeze().view(
                p.out_size, p.out_size).cpu().data.numpy()
        else:
            mask = mask[0, :, delta_y, delta_x].sigmoid().squeeze().view(
                p.out_size, p.out_size).cpu().data.numpy()
        def crop_back(image, bbox, out_sz, padding=-1):
            # affine map that resamples the bbox region of `image` onto an out_sz grid
            a = (out_sz[0] - 1) / bbox[2]
            b = (out_sz[1] - 1) / bbox[3]
            c = -a * bbox[0]
            d = -b * bbox[1]
            mapping = np.array([[a, 0, c], [0, b, d]]).astype(np.float)
            crop = cv2.warpAffine(image, mapping, (out_sz[0], out_sz[1]),
                                  flags=cv2.INTER_LINEAR,
                                  borderMode=cv2.BORDER_CONSTANT,
                                  borderValue=padding)
            return crop
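A quick check of the affine: it maps the bbox's top-left corner to (0, 0) and its bottom-right to (out_sz - 1, out_sz - 1), so warpAffine resamples exactly that region onto the output grid (numbers are illustrative):

import numpy as np

bbox = [-50., -40., 400., 300.]         # (x, y, w, h) in mask coordinates
out_sz = (480, 360)
a = (out_sz[0] - 1) / bbox[2]
b = (out_sz[1] - 1) / bbox[3]
M = np.array([[a, 0, -a * bbox[0]], [0, b, -b * bbox[1]]])
print(M @ [bbox[0], bbox[1], 1.])                       # [0. 0.]
print(M @ [bbox[0] + bbox[2], bbox[1] + bbox[3], 1.])   # [479. 359.]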
        	
        # map the out_size x out_size mask back onto the full image
        s = crop_box[2] / p.instance_size
        sub_box = [crop_box[0] + (delta_x - p.base_size / 2) * p.total_stride * s,
                   crop_box[1] + (delta_y - p.base_size / 2) * p.total_stride * s,
                   s * p.exemplar_size, s * p.exemplar_size]
        s = p.out_size / sub_box[2]
        back_box = [-sub_box[0] * s, -sub_box[1] * s, state['im_w'] * s, state['im_h'] * s]
        mask_in_img = crop_back(mask, back_box, (state['im_w'], state['im_h']))
        target_mask = (mask_in_img > p.seg_thr).astype(np.uint8)
        if cv2.__version__[-5] == '4':
            # OpenCV 4.x returns (contours, hierarchy)
            contours, _ = cv2.findContours(target_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        else:
            # OpenCV 3.x returns (image, contours, hierarchy)
            _, contours, _ = cv2.findContours(target_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        cnt_area = [cv2.contourArea(cnt) for cnt in contours]
        if len(contours) != 0 and np.max(cnt_area) > 100:
            contour = contours[np.argmax(cnt_area)]  # use max area polygon
            polygon = contour.reshape(-1, 2)
            # pbox = cv2.boundingRect(polygon)  # Min Max Rectangle
            prbox = cv2.boxPoints(cv2.minAreaRect(polygon))  # Rotated Rectangle
            # box_in_img = pbox
            rbox_in_img = prbox
        else:  # empty mask: fall back to the axis-aligned box from the regression branch
            location = cxy_wh_2_rect(target_pos, target_sz)
            rbox_in_img = np.array([[location[0], location[1]],
                                    [location[0] + location[2], location[1]],
                                    [location[0] + location[2], location[1] + location[3]],
                                    [location[0], location[1] + location[3]]])
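The rotated box comes from OpenCV's minimum-area rectangle: minAreaRect returns ((cx, cy), (w, h), angle) and boxPoints expands it into the four corners later stored under state['ploygon']. A standalone example:

import numpy as np
import cv2

pts = np.array([[10, 10], [60, 20], [55, 70], [5, 60]], dtype=np.float32)
rect = cv2.minAreaRect(pts)     # ((cx, cy), (w, h), angle)
corners = cv2.boxPoints(rect)   # 4x2 array of corner coordinates
print(rect)
print(corners)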
    # clamp the updated state to the image bounds
    target_pos[0] = max(0, min(state['im_w'], target_pos[0]))
    target_pos[1] = max(0, min(state['im_h'], target_pos[1]))
    target_sz[0] = max(10, min(state['im_w'], target_sz[0]))
    target_sz[1] = max(10, min(state['im_h'], target_sz[1]))

    state['target_pos'] = target_pos
    state['target_sz'] = target_sz
    state['score'] = score
    state['mask'] = mask_in_img if mask_enable else []
    state['ploygon'] = rbox_in_img if mask_enable else []  # 'ploygon' (sic) is the key the repo uses
    return state
The authors have not released the training code.

Reprinted from blog.csdn.net/CsdnWujinming/article/details/88950600