论文《Fast Online Object Tracking and Segmentation: A Unifying Approach》项目代码解读1:Custom

论文解析 https://blog.csdn.net/CsdnWujinming/article/details/88895146
项目地址:https://github.com/foolwood/SiamMask

网络结构

(原文此处为两张网络结构图,图片已缺失)

代码解读

1. Custom.py SiamMask网络具体实现

ResDownS

ResNet 特征提取之后的adjust层:用1×1卷积+BN把通道数从1024降到256;当特征图宽度小于20时,再把四周各裁掉4个像素。对应ResNet-50之后的调整操作,在ResDown中调用。

class ResDownS(nn.Module):
    """Adjust layer: 1x1 conv + BatchNorm mapping ``inplane`` -> ``outplane``
    channels, plus a border crop for narrow feature maps.

    Args:
        inplane: number of input channels.
        outplane: number of output channels.
    """

    def __init__(self, inplane, outplane):
        super(ResDownS, self).__init__()
        # Conv bias is omitted because a BatchNorm layer follows it.
        self.downsample = nn.Sequential(
            nn.Conv2d(inplane, outplane, kernel_size=1, bias=False),
            nn.BatchNorm2d(outplane),
        )

    def forward(self, x):
        """Project ``x`` to ``outplane`` channels.

        When the width (dim 3) of the projected map is below 20, 4 pixels are
        also removed from every border of both spatial dims (2 and 3).
        """
        x = self.downsample(x)
        # Only the width is tested but both spatial dims are cropped;
        # presumably inputs are square (e.g. template features) - TODO confirm.
        if x.size(3) < 20:
            lo, hi = 4, -4
            x = x[:, :, lo:hi, lo:hi]
        return x

ResDown

孪生网络的特征提取层,对应图2中的ResNet-50(只构建到layer3),其最后一层输出再经ResDownS调整为256通道。

class ResDown(Features):
    """Siamese feature extractor: a ``resnet50(layer3=True, layer4=False)``
    backbone whose last output is adjusted by ``ResDownS(1024, 256)``.

    Args:
        pretrain: if True, load weights from 'resnet.model' into the backbone.
    """

    def __init__(self, pretrain=False):
        super(ResDown, self).__init__()
        self.features = resnet50(layer3=True, layer4=False)
        if pretrain:
            load_pretrain(self.features, 'resnet.model')
        self.downsample = ResDownS(1024, 256)

    def forward(self, x):
        """Return only the adjusted last backbone feature map."""
        # Delegate so the backbone/adjust pipeline is defined in one place.
        return self.forward_all(x)[1]

    def forward_all(self, x):
        """Return ``(output, p3)``.

        ``output`` is the raw backbone result (indexed with ``[-1]``, so
        presumably a list of stage features - TODO confirm against resnet50);
        ``p3`` is its last entry passed through the adjust layer.
        """
        output = self.features(x)
        p3 = self.downsample(output[-1])
        return output, p3

UP(RPN)

RPN的分类与边框回归分支,两者都由DepthCorr对象实现:分类分支输出2×anchor_num个通道,回归分支输出4×anchor_num个通道。

class UP(RPN):
    """RPN head with a classification and a box-regression branch, each a
    DepthCorr between ``z_f`` and ``x_f`` (presumably template and search
    features - TODO confirm).

    Args:
        anchor_num: anchors per location; the cls branch outputs
            ``2 * anchor_num`` channels, the loc branch ``4 * anchor_num``.
        feature_in: input channels of both DepthCorr branches.
        feature_out: hidden-layer channels of both DepthCorr branches.
    """

    def __init__(self, anchor_num=5, feature_in=256, feature_out=256):
        super(UP, self).__init__()
        self.anchor_num = anchor_num
        self.feature_in = feature_in
        self.feature_out = feature_out
        # Output channel counts of the two branches.
        self.cls_output = 2 * self.anchor_num
        self.loc_output = 4 * self.anchor_num
        self.cls = DepthCorr(feature_in, feature_out, self.cls_output)
        self.loc = DepthCorr(feature_in, feature_out, self.loc_output)

    def forward(self, z_f, x_f):
        """Return ``(cls, loc)`` computed by correlating ``z_f`` with ``x_f``."""
        cls = self.cls(z_f, x_f)
        loc = self.loc(z_f, x_f)
        return cls, loc

MaskCorr

mask分支网络,同样由DepthCorr对象实现:输入256通道,输出63×63=3969个通道。

(原文此处为图片,已缺失)

class MaskCorr(Mask):
    """Mask branch: a DepthCorr from 256 input/hidden channels to ``oSz**2``
    output channels (3969 for the default ``oSz=63``).

    Args:
        oSz: mask side length; presumably each output channel is one pixel
            of an ``oSz x oSz`` mask - TODO confirm.
    """

    def __init__(self, oSz=63):
        super(MaskCorr, self).__init__()
        self.oSz = oSz
        self.mask = DepthCorr(256, 256, self.oSz ** 2)

    def forward(self, z, x):
        """Return the DepthCorr response of ``z`` against ``x``."""
        return self.mask(z, x)

Refine

网络结构见图2上半部分。

(原文此处为图片,已缺失)
三个post属性分别对应U2,U3,U4

class Refine(nn.Module):
    """Mask refinement module.

    Please refer SiamMask (Appendix A): https://arxiv.org/abs/1812.05050

    Upsamples the 256-channel correlation feature at one location to a
    127x127 single-channel map, fusing it stage by stage with crops of three
    backbone feature maps (64, 256 and 512 channels).
    """

    def __init__(self):
        super(Refine, self).__init__()
        # Registration order is kept as-is: it fixes parameter init order and
        # state_dict key order.
        # v*: compress the cropped backbone features f[0], f[1], f[2].
        self.v0 = nn.Sequential(nn.Conv2d(64, 16, 3, padding=1), nn.ReLU(),
                                nn.Conv2d(16, 4, 3, padding=1), nn.ReLU())
        self.v1 = nn.Sequential(nn.Conv2d(256, 64, 3, padding=1), nn.ReLU(),
                                nn.Conv2d(64, 16, 3, padding=1), nn.ReLU())
        self.v2 = nn.Sequential(nn.Conv2d(512, 128, 3, padding=1), nn.ReLU(),
                                nn.Conv2d(128, 32, 3, padding=1), nn.ReLU())
        # h*: refine the mask features coming from the previous stage.
        self.h2 = nn.Sequential(nn.Conv2d(32, 32, 3, padding=1), nn.ReLU(),
                                nn.Conv2d(32, 32, 3, padding=1), nn.ReLU())
        self.h1 = nn.Sequential(nn.Conv2d(16, 16, 3, padding=1), nn.ReLU(),
                                nn.Conv2d(16, 16, 3, padding=1), nn.ReLU())
        self.h0 = nn.Sequential(nn.Conv2d(4, 4, 3, padding=1), nn.ReLU(),
                                nn.Conv2d(4, 4, 3, padding=1), nn.ReLU())
        # Kernel 15, stride 15: turns the 1x1 correlation vector into 15x15.
        self.deconv = nn.ConvTranspose2d(256, 32, 15, 15)
        self.post0 = nn.Conv2d(32, 16, 3, padding=1)
        self.post1 = nn.Conv2d(16, 4, 3, padding=1)
        self.post2 = nn.Conv2d(4, 1, 3, padding=1)

    def forward(self, f, corr_feature, pos=None):
        """Predict a flattened 127x127 map for correlation location ``pos``.

        Args:
            f: indexable with f[0], f[1], f[2]: backbone maps with 64, 256
                and 512 channels respectively.
            corr_feature: 256-channel map indexed as ``[:, :, row, col]``.
            pos: ``(row, col)`` in ``corr_feature``. Required; the ``None``
                default is kept only for signature compatibility.

        Returns:
            Tensor of shape ``(-1, 127 * 127)``; no activation is applied.

        Raises:
            TypeError: if ``pos`` is None.
        """
        if pos is None:
            raise TypeError("Refine.forward() requires pos=(row, col)")
        row, col = pos[0], pos[1]

        # Pad each map, then crop the window aligned with `pos`. The 4/2/1
        # scale factors presumably match each map's stride relative to
        # corr_feature - TODO confirm.
        p0 = F.pad(f[0], [16, 16, 16, 16])[:, :, 4 * row:4 * row + 61, 4 * col:4 * col + 61]
        p1 = F.pad(f[1], [8, 8, 8, 8])[:, :, 2 * row:2 * row + 31, 2 * col:2 * col + 31]
        p2 = F.pad(f[2], [4, 4, 4, 4])[:, :, row:row + 15, col:col + 15]

        p3 = corr_feature[:, :, row, col].view(-1, 256, 1, 1)

        out = self.deconv(p3)  # 15x15, 32 channels
        # F.interpolate replaces the deprecated F.upsample (same default
        # 'nearest' mode, so identical results without the warning).
        out = self.post0(F.interpolate(self.h2(out) + self.v2(p2), size=(31, 31)))
        out = self.post1(F.interpolate(self.h1(out) + self.v1(p1), size=(61, 61)))
        out = self.post2(F.interpolate(self.h0(out) + self.v0(p0), size=(127, 127)))
        return out.view(-1, 127 * 127)

Custom

完整网络的组装:继承自SiamMask类,依次挂载特征提取(ResDown)、RPN(UP)、mask分支(MaskCorr)和refine模块(Refine);父类SiamMask的方法此处未贴出。

class Custom(SiamMask):
    """SiamMask model built from the modules defined in this file.

    Submodules, registered in this order: ``features`` (ResDown backbone),
    ``rpn_model`` (UP head), ``mask_model`` (MaskCorr) and ``refine_model``
    (Refine).

    Args:
        pretrain: passed to ResDown to control loading of backbone weights.
        **kwargs: passed unchanged to the SiamMask base initializer.
    """

    def __init__(self, pretrain=False, **kwargs):
        super().__init__(**kwargs)
        self.features = ResDown(pretrain=pretrain)
        # anchor_num is not assigned in this class; it comes from the
        # SiamMask base.
        head_channels = 256
        self.rpn_model = UP(
            anchor_num=self.anchor_num,
            feature_in=head_channels,
            feature_out=head_channels,
        )
        self.mask_model = MaskCorr()
        self.refine_model = Refine()

猜你喜欢

转载自blog.csdn.net/CsdnWujinming/article/details/88943591
今日推荐