论文:Learning Discriminative Features with Multiple Granularities for Person Re-Identification
Why is local information needed?
The intuitive approach to pedestrian representation is to extract discriminative features from the whole body in an image. The aim of global feature learning is to capture the most salient appearance clues to represent the identities of different pedestrians. However, the high complexity of images captured in surveillance scenes usually restricts the accuracy of feature learning in large-scale Re-ID scenarios. Due to the limited scale and weak diversity of person Re-ID training datasets, some non-salient or infrequent detailed information can easily be ignored and make no contribution to better discrimination during the global feature learning procedure, which makes it hard for global features to adapt to similar inter-class common properties or large intra-class differences.
Part-based methods fall into three main categories:
- Locating part regions with strong structural information such as empirical knowledge about human bodies or strong learning-based pose information
- Locating part regions by region proposal methods
- Enhancing features by middle-level attention on salient partitions
However, these methods have several drawbacks. First, pose or occlusion variations can affect the reliability of local representations. Second, these methods focus almost exclusively on specific parts with fixed semantics and cannot cover all the discriminative information. Last but not least, most of these methods are not end-to-end learning processes, which increases the complexity and difficulty of feature learning.
Note that these part regions need not be located partitions with specific semantics; each is simply one of the equally split horizontal stripes of the original image.
MGN
mgn net
import copy
import torch
from torch import nn
from torchvision.models.resnet import resnet50, Bottleneck
class MGN(nn.Module):
    """Multiple Granularity Network (MGN) for person re-identification.

    A ResNet-50 trunk (everything up to and including ``layer3[0]``) is
    shared by three branches. Each branch has its own copy of the remaining
    ``layer3`` blocks followed by a ``layer4``-style stage:

    * ``p1`` uses ResNet-50's own ``layer4`` and yields one global feature.
    * ``p2`` uses a ``layer4`` variant whose first block has no stride
      (initialised from the same weights) and yields one global feature plus
      2 horizontal stripe features.
    * ``p3`` is built like ``p2`` and yields one global feature plus 3
      horizontal stripe features.

    Each of the 8 pooled 2048-channel maps is reduced to ``feats`` channels by
    its own 1x1 conv + BN + ReLU block, and each reduced vector feeds its own
    linear ID classifier.

    Args:
        num_classes: Number of identities predicted by every ID classifier.
        pool: ``'avg'`` or ``'max'``; the pooling used for both the global
            and the stripe features.
        feats: Channel count of each reduced feature vector.
        pretrained: Passed to ``resnet50``; set to ``False`` to skip loading
            pretrained trunk weights (e.g. when a full checkpoint will be
            loaded afterwards).

    Raises:
        ValueError: If ``pool`` is neither ``'avg'`` nor ``'max'``.
    """

    def __init__(self, num_classes, pool='avg', feats=256, pretrained=True):
        super().__init__()

        # Validate before building the trunk so a typo fails fast instead of
        # after the (possibly downloaded) ResNet-50 has been constructed.
        pool_types = {'avg': nn.AdaptiveAvgPool2d, 'max': nn.AdaptiveMaxPool2d}
        if pool not in pool_types:
            raise ValueError(f"pool must be 'avg' or 'max', got {pool!r}")
        pool2d = pool_types[pool]

        # NOTE: torchvision >= 0.13 prefers ``weights=``; ``pretrained=`` is
        # kept so the file still works with older torchvision releases.
        resnet = resnet50(pretrained=pretrained)

        # Shared trunk. The attribute keeps its original "backone" spelling
        # so existing state_dict checkpoints still load.
        self.backone = nn.Sequential(
            resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool,
            resnet.layer1, resnet.layer2, resnet.layer3[0],
        )

        # Remaining layer3 blocks (indices 1..5); every branch gets a copy.
        res_conv4 = nn.Sequential(*resnet.layer3[1:])

        res_g_conv5 = resnet.layer4
        # layer4 variant whose shortcut and first block use the default
        # stride of 1, so p2/p3 keep the layer3 spatial resolution. Its
        # parameter layout matches layer4, so the layer4 weights load into it.
        res_p_conv5 = nn.Sequential(
            Bottleneck(1024, 512, downsample=nn.Sequential(
                nn.Conv2d(1024, 2048, 1, bias=False), nn.BatchNorm2d(2048))),
            Bottleneck(2048, 512),
            Bottleneck(2048, 512),
        )
        res_p_conv5.load_state_dict(resnet.layer4.state_dict())

        # The three branches: independent copies of conv4 + a conv5 stage.
        self.p1 = nn.Sequential(copy.deepcopy(res_conv4), copy.deepcopy(res_g_conv5))
        self.p2 = nn.Sequential(copy.deepcopy(res_conv4), copy.deepcopy(res_p_conv5))
        self.p3 = nn.Sequential(copy.deepcopy(res_conv4), copy.deepcopy(res_p_conv5))

        # Adaptive pooling to fixed output grids (1x1 global, 2x1 / 3x1
        # stripes). On 24x8 (p2/p3) and 12x4 (p1) feature maps this equals
        # non-overlapping kernels of (24, 8) / (12, 4), (12, 8) and (8, 8);
        # unlike fixed kernels it also works for other input sizes. The
        # "maxpool_" prefix is historical and used for both pool types.
        self.maxpool_zg_p1 = pool2d((1, 1))
        self.maxpool_zg_p2 = pool2d((1, 1))
        self.maxpool_zg_p3 = pool2d((1, 1))
        self.maxpool_zp2 = pool2d((2, 1))
        self.maxpool_zp3 = pool2d((3, 1))

        # 2048 -> feats reduction; all 8 copies start from identical weights.
        # Index 0-2: global (p1, p2, p3); 3-4: p2 stripes; 5-7: p3 stripes.
        reduction = nn.Sequential(
            nn.Conv2d(2048, feats, 1, bias=False), nn.BatchNorm2d(feats), nn.ReLU())
        self._init_reduction(reduction)
        for i in range(8):
            setattr(self, f'reduction_{i}', copy.deepcopy(reduction))

        # One ID classifier per reduced feature (branch 1: 1, branch 2: 3,
        # branch 3: 4). Every classifier takes a ``feats``-dim input; the
        # "2048"/"256" in the names are historical and kept for checkpoint
        # compatibility. All are created before any is re-initialised, which
        # keeps the original random-number consumption order.
        fc_names = (
            'fc_id_2048_0', 'fc_id_2048_1', 'fc_id_2048_2',
            'fc_id_256_1_0', 'fc_id_256_1_1',
            'fc_id_256_2_0', 'fc_id_256_2_1', 'fc_id_256_2_2',
        )
        for name in fc_names:
            setattr(self, name, nn.Linear(feats, num_classes))
        for name in fc_names:
            self._init_fc(getattr(self, name))

    @staticmethod
    def _init_reduction(reduction):
        """Initialise a (conv, bn, ...) reduction block in place."""
        # conv
        nn.init.kaiming_normal_(reduction[0].weight, mode='fan_in')
        # bn: scale ~ N(1, 0.02), zero shift
        nn.init.normal_(reduction[1].weight, mean=1.0, std=0.02)
        nn.init.constant_(reduction[1].bias, 0.0)

    @staticmethod
    def _init_fc(fc):
        """Initialise a linear classifier in place (Kaiming fan-out, zero bias)."""
        nn.init.kaiming_normal_(fc.weight, mode='fan_out')
        nn.init.constant_(fc.bias, 0.0)

    def forward(self, x):
        """Run the shared trunk and all three branches.

        Args:
            x: Image batch of shape (N, 3, H, W).

        Returns:
            A 12-tuple ``(predict, fg_p1, fg_p2, fg_p3, l_p1, l_p2, l_p3,
            l0_p2, l1_p2, l0_p3, l1_p3, l2_p3)``:

            * ``predict``: (N, 8 * feats) concatenation of all 8 reduced
              features (2048-d with the default ``feats=256``); presumably
              the embedding used at inference time.
            * ``fg_p1``, ``fg_p2``, ``fg_p3``: (N, feats) global features,
              intended for the triplet loss.
            * ``l_*``: (N, num_classes) logits of the 8 ID classifiers,
              intended for the cross-entropy loss.
        """
        x = self.backone(x)

        # p1 is downsampled by its strided conv5; p2/p3 keep the trunk's
        # spatial size.
        p1 = self.p1(x)
        p2 = self.p2(x)
        p3 = self.p3(x)

        # Global features: one 1x1 pooled map per branch.
        zg_p1 = self.maxpool_zg_p1(p1)
        zg_p2 = self.maxpool_zg_p2(p2)
        zg_p3 = self.maxpool_zg_p3(p3)

        # Local features: pool p2 to 2 rows and p3 to 3 rows, then take each
        # row (a horizontal stripe) as its own (N, C, 1, 1) map.
        zp2 = self.maxpool_zp2(p2)
        z0_p2 = zp2[:, :, 0:1, :]
        z1_p2 = zp2[:, :, 1:2, :]

        zp3 = self.maxpool_zp3(p3)
        z0_p3 = zp3[:, :, 0:1, :]
        z1_p3 = zp3[:, :, 1:2, :]
        z2_p3 = zp3[:, :, 2:3, :]

        # Reduce 2048 -> feats channels and drop the 1x1 spatial dims.
        fg_p1 = self.reduction_0(zg_p1).flatten(1)
        fg_p2 = self.reduction_1(zg_p2).flatten(1)
        fg_p3 = self.reduction_2(zg_p3).flatten(1)
        f0_p2 = self.reduction_3(z0_p2).flatten(1)
        f1_p2 = self.reduction_4(z1_p2).flatten(1)
        f0_p3 = self.reduction_5(z0_p3).flatten(1)
        f1_p3 = self.reduction_6(z1_p3).flatten(1)
        f2_p3 = self.reduction_7(z2_p3).flatten(1)

        # ID logits for the 3 global features ...
        l_p1 = self.fc_id_2048_0(fg_p1)
        l_p2 = self.fc_id_2048_1(fg_p2)
        l_p3 = self.fc_id_2048_2(fg_p3)

        # ... and for the 5 stripe features.
        l0_p2 = self.fc_id_256_1_0(f0_p2)
        l1_p2 = self.fc_id_256_1_1(f1_p2)
        l0_p3 = self.fc_id_256_2_0(f0_p3)
        l1_p3 = self.fc_id_256_2_1(f1_p3)
        l2_p3 = self.fc_id_256_2_2(f2_p3)

        predict = torch.cat(
            [fg_p1, fg_p2, fg_p3, f0_p2, f1_p2, f0_p3, f1_p3, f2_p3], dim=1)
        return (predict, fg_p1, fg_p2, fg_p3,
                l_p1, l_p2, l_p3, l0_p2, l1_p2, l0_p3, l1_p3, l2_p3)