ビデオベース-ReID_TP

序文

次に、ビデオ歩行者再認識トレーニングモデルの時間的集約方法の1つである時間的プーリングを見てみましょう。
これは比較的簡単な方法であり、効果は良好です。平均プーリングを使用して、seq_lenに従って各クリップの機能を各クリップの機能にマージします。
パートAなど:
ここに画像の説明を挿入

モデル入力

  • imgs
    • imgs.size()= [b、s、c、h、w]
    • トレーニングレベルでは、bはバッチで、通常は32に設定され、seq_lenは4に設定され、cはチャネル数が3、hは画像の高さ、wは画像の幅です。

モデル初期化パラメータ

        model = models.init_model(name=args.arch, num_classes=dataset.num_train_pids, loss={
    
    'xent', 'htri'})
  • name使用されているモデルの名前
  • dataset.num_train_pids分類中の分類の数
  • 損失xent =クロスエントロピー損失htri = Tripletloss

モデルの実現

class ResNet50TP(nn.Module):
    def __init__(self, num_classes, loss={
    
    'xent'}, **kwargs):
    	# 继承的是ResNet50TP父类的初始化方法
        super(ResNet50TP, self).__init__()
        # 设置loss总类
        self.loss = loss
        # 使用resnet501模型
        resnet50 = torchvision.models.resnet50(pretrained=True)
        # 使用resnet501模型 (除了最后两层)
        self.base = nn.Sequential(*list(resnet50.children())[:-2])
        # 特征维数为2048
        self.feat_dim = 2048
        self.classifier = nn.Linear(self.feat_dim, num_classes)

    # 前向传播 x=imgs=[32,4,3,224,112]  [b,t,c,h,w]
    def forward(self, x):
        # b = 32 batch
        b = x.size(0)
        # t = 4 seq——len
        t = x.size(1)
        # x = [128,3,224,112]
        x = x.view(b*t,x.size(2), x.size(3), x.size(4))

        # resnet
        # x = [128,2048,7,4]  CNN提取features
        x = self.base(x)
        # 平均池化   这里得到的是每一帧的features
        # x = [128,2048,1,1]
        x = F.avg_pool2d(x, x.size()[2:])
        # x = [32,4,2048]  
        x = x.view(b,t,-1)
        # x= [32,2048,4]
        x=x.permute(0,2,1) 
        # 这里对得到的features进行平均池化 得到每个clips的feature
        f = F.avg_pool1d(x,t)
        # f= [32,2048]
        f = f.view(b, self.feat_dim)
        # embed()
        # 不是训练阶段的化 直接使用得到的features
        if not self.training:
            return f
        # 将特征放入全链接层
        y = self.classifier(f)
		# 根据计算loss的方法不同,返回不同的参数
        if self.loss == {
    
    'xent'}:
            return y
        elif self.loss == {
    
    'xent', 'htri'}:
            return y, f
        elif self.loss == {
    
    'cent'}:
            return y, f
        else:
            raise KeyError("Unsupported loss: {}".format(self.loss))

おすすめ

転載: blog.csdn.net/qq_37747189/article/details/114729566