Prefacio
A continuación, echemos un vistazo a uno de los métodos de agregación temporal del modelo de entrenamiento de re-reconocimiento de peatones por video: agrupación temporal.
Esta es una forma relativamente simple y el efecto es bueno. Utiliza la agrupación promedio para fusionar las características de cada clip en las características de cada clip de acuerdo con seq_len.
Como la parte A:
Entrada de modelo
- imgs
- imgs.size () = [b, s, c, h, w]
- En el nivel de entrenamiento, b es por lotes y generalmente se establece en 32, seq_len se establece en 4, c es el número de canales es 3, h es la altura de la imagen, w es la anchura de la imagen
Parámetros de inicialización del modelo
model = models.init_model(name=args.arch, num_classes=dataset.num_train_pids, loss={
'xent', 'htri'})
- name El nombre del modelo utilizado
- dataset.num_train_pids El número de clasificaciones durante la clasificación
- pérdida xent = pérdida de entropía cruzada htri = triple pérdida
Realización de modelos
class ResNet50TP(nn.Module):
def __init__(self, num_classes, loss={
'xent'}, **kwargs):
# 继承的是ResNet50TP父类的初始化方法
super(ResNet50TP, self).__init__()
# 设置loss总类
self.loss = loss
# 使用resnet501模型
resnet50 = torchvision.models.resnet50(pretrained=True)
# 使用resnet501模型 (除了最后两层)
self.base = nn.Sequential(*list(resnet50.children())[:-2])
# 特征维数为2048
self.feat_dim = 2048
self.classifier = nn.Linear(self.feat_dim, num_classes)
# 前向传播 x=imgs=[32,4,3,224,112] [b,t,c,h,w]
def forward(self, x):
# b = 32 batch
b = x.size(0)
# t = 4 seq——len
t = x.size(1)
# x = [128,3,224,112]
x = x.view(b*t,x.size(2), x.size(3), x.size(4))
# resnet
# x = [128,2048,7,4] CNN提取features
x = self.base(x)
# 平均池化 这里得到的是每一帧的features
# x = [128,2048,1,1]
x = F.avg_pool2d(x, x.size()[2:])
# x = [32,4,2048]
x = x.view(b,t,-1)
# x= [32,2048,4]
x=x.permute(0,2,1)
# 这里对得到的features进行平均池化 得到每个clips的feature
f = F.avg_pool1d(x,t)
# f= [32,2048]
f = f.view(b, self.feat_dim)
# embed()
# 不是训练阶段的化 直接使用得到的features
if not self.training:
return f
# 将特征放入全链接层
y = self.classifier(f)
# 根据计算loss的方法不同,返回不同的参数
if self.loss == {
'xent'}:
return y
elif self.loss == {
'xent', 'htri'}:
return y, f
elif self.loss == {
'cent'}:
return y, f
else:
raise KeyError("Unsupported loss: {}".format(self.loss))