行人重识别0-07:DG-Net(ReID)-代码无死角解读(3)-网络总体前向传播

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接: https://blog.csdn.net/weixin_43013761/article/details/102488283

以下链接是个人关于DG-Net(行人重识别ReID)所有见解,如有错误欢迎大家指出,我会第一时间纠正。有兴趣的朋友可以加微信:a944284742相互讨论技术。若是帮助到了你什么,一定要记得点赞!因为这是对我最大的鼓励。
行人重识别0-00:DG-Net(ReID)-目录-史上最新最全:https://blog.csdn.net/weixin_43013761/article/details/102364512

前言

通过上一篇博客,我相信大家对于生成模块部分,在总体框架上,应该有了深刻的认识。既然网络总体框架了解,我们就来了解细节部分了。我会为大家详细的解析(当然,如果要我详细到每个卷积,是怎么卷积,怎么来的,那就真的太为难我胖虎了!),玩笑到这里。愿大家愉快的阅读我的博客,那就开始吧!

源码引导

首先我们打开train.py文件:

if __name__=='__main__':
	# Setup model and data loader
    # 加载,训练数据,进行模型构建和加载
    if opts.trainer == 'DGNet':
        trainer = DGNet_Trainer(config, gpu_ids)
        trainer.cuda()

是的,是他,是他,就是他。我们还是要从这里开始,然后进入

class DGNet_Trainer(nn.Module):

在上篇博中,我们是介绍了其中得forward(self, x_a, x_b, xp_a, xp_b)函数,该小节我们先看看他的初始化函数,了解下在创建这个类得时候,其都做了什么(大家先浏览一下注释就行,函数内部不要去细究,因为不是有我在吗?我随后会逐步讲解)。

class DGNet_Trainer(nn.Module):

    def __init__(self, hyperparameters, gpu_ids=[0]):
        super(DGNet_Trainer, self).__init__()
        # 从配置文件获取生成模型和鉴别模型的学习了率
        lr_g = hyperparameters['lr_g']
        lr_d = hyperparameters['lr_d']

        # ID的类别,这里要注意,不同的数据集都是不一样的,应该是训练数据集的ID数目,非测试集
        ID_class = hyperparameters['ID_class']

        # 看是否设置使用float16,估计float16可以增加精确度
        if not 'apex' in hyperparameters.keys():
            hyperparameters['apex'] = False
        self.fp16 = hyperparameters['apex']

        # Es
        # Initiate the networks
        # We do not need to manually set fp16 in the network for the new apex. So here I set fp16=False.
        # 注意这里包含了两个步骤,Es编码+解码过程,既然解码(论文Figure 2的黄色梯形G)包含到这里了,下面Ea应该不会包含解码过程了
        # 因为这里是一个类,如后续gen_a.encode()可以进行编码,gen_b.encode()可以进行解码
        self.gen_a = AdaINGen(hyperparameters['input_dim_a'], hyperparameters['gen'], fp16 = False)  # auto-encoder for domain a
        self.gen_b = self.gen_a  # auto-encoder for domain b




        #我们使用的是ft_netAB,是代码中Ea编码的过程,也就得到 ap code的过程,除了ap code还会得到两个分类结果
        #现在怀疑,该分类结果,可能就是行人重识别的结果
        # ID_stride,外观编码器池化层的stride
        if not 'ID_stride' in hyperparameters.keys():
            hyperparameters['ID_stride'] = 2
        # hyperparameters['ID_style']默认为'AB',论文中的Ea编码器
        if hyperparameters['ID_style']=='PCB':
            self.id_a = PCB(ID_class)
        elif hyperparameters['ID_style']=='AB':
            # 这是我们执行的模型,注意的是,id_a返回两个x(表示身份),获得f,具体介绍看函数内部
            self.id_a = ft_netAB(ID_class, stride = hyperparameters['ID_stride'], norm=hyperparameters['norm_id'], pool=hyperparameters['pool']) 
        else:
            self.id_a = ft_net(ID_class, norm=hyperparameters['norm_id'], pool=hyperparameters['pool']) # return 2048 now
        # 这里进行的是浅拷贝,所以我认为他们的权重是一起的,可以理解为一个
        self.id_b = self.id_a




        # 鉴别器,行人重识别,这里使用的是一个多尺寸的鉴别器,大概就是说,对图片进行几次缩放,并且对每次缩放都会预测,计算总的损失
        # 经过网络3个元素,分别大小为[batch_size,1,64,32], [batch_size,1,32,16], [batch_size,1,16,8]
        self.dis_a = MsImageDis(3, hyperparameters['dis'], fp16 = False)  # discriminator for domain a
        self.dis_b = self.dis_a # discriminator for domain b


        # 加载教师模型
        # load teachers

        # teacher:老师模型名称。对于DukeMTMC,您可以设置“best - duke”
        if hyperparameters['teacher'] != "":
            teacher_name = hyperparameters['teacher']
            print(teacher_name)
            # 有这个操作,我怀疑是可以加载多个教师模型
            teacher_names = teacher_name.split(',')
            # 构建教师模型
            teacher_model = nn.ModuleList()
            teacher_count = 0

            # 默认只有一个teacher_name='teacher_name',所以其加载的模型为项目根目录models/best/opts.yaml模型
            for teacher_name in teacher_names:
                config_tmp = load_config(teacher_name)

                #  默认stride=1,是池化层的stride
                if 'stride' in config_tmp:
                    stride = config_tmp['stride'] 
                else:
                    stride = 2

                # 网络搭建

                model_tmp = ft_net(ID_class, stride = stride)

                teacher_model_tmp = load_network(model_tmp, teacher_name)
                # 移除原本的全连接层
                teacher_model_tmp.model.fc = nn.Sequential()  # remove the original fc layer in ImageNet
                # 应该是进行网络搭建
                teacher_model_tmp = teacher_model_tmp.cuda()
                # summary(teacher_model_tmp, (3, 224, 224))
                # 使用浮点型
                if self.fp16:
                    teacher_model_tmp = amp.initialize(teacher_model_tmp, opt_level="O1")
                teacher_model.append(teacher_model_tmp.cuda().eval())
                teacher_count +=1

            self.teacher_model = teacher_model
            # 选择是否使用bn
            if hyperparameters['train_bn']:
                self.teacher_model = self.teacher_model.apply(train_bn)

        # 实例正则化
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)


        # RGB to one channel
        # 默认设置single = gray,前面提到,鉴别器使用灰度图,进行行人重识别
        if hyperparameters['single']=='edge':
            self.single = to_edge
        else:
            self.single = to_gray(False)


        # Random Erasing when training
        # erasing_p表示随机擦除的概率
        if not 'erasing_p' in hyperparameters.keys():
            self.erasing_p = 0
        else:
            self.erasing_p = hyperparameters['erasing_p']
        # 随机擦除矩形区域的一些像素,应该类似于数据增强
        self.single_re = RandomErasing(probability = self.erasing_p, mean=[0.0, 0.0, 0.0])

        # 默认没有设置T_w。所以T_w被设置为1
        if not 'T_w' in hyperparameters.keys():
            hyperparameters['T_w'] = 1

        # Setup the optimizers,设置优化器的参数
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) #+ list(self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) #+ list(self.gen_b.parameters())

        # 使用Adam优化器
        self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
                                        lr=lr_d, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
                                        lr=lr_g, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])


        # id params

        # 因为ID_style默认为AB,所以这里不执行
        if hyperparameters['ID_style']=='PCB':
            ignored_params = (list(map(id, self.id_a.classifier0.parameters() ))
                            +list(map(id, self.id_a.classifier1.parameters() ))
                            +list(map(id, self.id_a.classifier2.parameters() ))
                            +list(map(id, self.id_a.classifier3.parameters() ))
                            )
            base_params = filter(lambda p: id(p) not in ignored_params, self.id_a.parameters())
            lr2 = hyperparameters['lr2']
            self.id_opt = torch.optim.SGD([
                 {'params': base_params, 'lr': lr2},
                 {'params': self.id_a.classifier0.parameters(), 'lr': lr2*10},
                 {'params': self.id_a.classifier1.parameters(), 'lr': lr2*10},
                 {'params': self.id_a.classifier2.parameters(), 'lr': lr2*10},
                 {'params': self.id_a.classifier3.parameters(), 'lr': lr2*10}
            ], weight_decay=hyperparameters['weight_decay'], momentum=0.9, nesterov=True)

        # 这里是我们执行的代码
        elif hyperparameters['ID_style']=='AB':
            # 忽略的参数,应该是适用于'PCB'或者其他的,但是不适用于'AB'的
            ignored_params = (list(map(id, self.id_a.classifier1.parameters()))
                            + list(map(id, self.id_a.classifier2.parameters())))

            # 获得基本的配置参数,如学习率
            base_params = filter(lambda p: id(p) not in ignored_params, self.id_a.parameters())
            lr2 = hyperparameters['lr2']
            self.id_opt = torch.optim.SGD([
                 {'params': base_params, 'lr': lr2},
                 {'params': self.id_a.classifier1.parameters(), 'lr': lr2*10},
                 {'params': self.id_a.classifier2.parameters(), 'lr': lr2*10}
            ], weight_decay=hyperparameters['weight_decay'], momentum=0.9, nesterov=True)
        else:
            ignored_params = list(map(id, self.id_a.classifier.parameters() ))
            base_params = filter(lambda p: id(p) not in ignored_params, self.id_a.parameters())
            lr2 = hyperparameters['lr2']
            self.id_opt = torch.optim.SGD([
                 {'params': base_params, 'lr': lr2},
                 {'params': self.id_a.classifier.parameters(), 'lr': lr2*10}
            ], weight_decay=hyperparameters['weight_decay'], momentum=0.9, nesterov=True)

        # 选择各个网络优化的策略
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)
        self.id_scheduler = get_scheduler(self.id_opt, hyperparameters)
        self.id_scheduler.gamma = hyperparameters['gamma2']

        #ID Loss
        self.id_criterion = nn.CrossEntropyLoss()
        self.criterion_teacher = nn.KLDivLoss(size_average=False)


        # Load VGG model if needed
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

        # save memory
        # 其实我不会告诉你我没有看懂,如果你看懂了要记得告诉我。不过这并不影响我们对代码的阅读,
        # 大概就是赋值保存,加初始化吧
        if self.fp16:
            # Name the FP16_Optimizer instance to replace the existing optimizer
            assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."
            self.gen_a = self.gen_a.cuda()
            self.dis_a = self.dis_a.cuda()
            self.id_a = self.id_a.cuda()

            self.gen_b = self.gen_a
            self.dis_b = self.dis_a
            self.id_b = self.id_a

            self.gen_a, self.gen_opt = amp.initialize(self.gen_a, self.gen_opt, opt_level="O1")
            self.dis_a, self.dis_opt = amp.initialize(self.dis_a, self.dis_opt, opt_level="O1")
            self.id_a, self.id_opt = amp.initialize(self.id_a, self.id_opt, opt_level="O1")

如果注意的朋友,应该发现了,我使用很多空格,分成了三个部分,分别对应Es(结构编码器,这里包含了解码和编码过程),Ea(外观编码器),以及鉴别器(行人重识别吧)。

下篇博客,我将详细的对他们一一进行介绍,当然也可能需要多篇博客!今天有学到新东西吗?可爱的你!

猜你喜欢

转载自blog.csdn.net/weixin_43013761/article/details/102488283
今日推荐