Person Re-Identification 0-11: DG-Net (ReID) - Code Walkthrough with No Blind Spots (7) - Global Losses Explained - Annotated Edition

Copyright notice: this is an original post by the author, released under the CC 4.0 BY-SA license. Please include the original source link and this notice when reposting.
Original link: https://blog.csdn.net/weixin_43013761/article/details/102524473

The link below collects all of my posts on DG-Net (person re-identification, ReID). If you spot any mistakes, please point them out and I will fix them right away. Anyone interested is welcome to add me on WeChat (a944284742) to discuss. And if this helped you, remember to leave a like, since that is the biggest encouragement for me.
Person Re-Identification 0-00: DG-Net (ReID) - Table of Contents - The Newest and Most Complete: https://blog.csdn.net/weixin_43013761/article/details/102364512

Preface

In this post I walk through all the remaining losses in detail (the generator loss and the real/fake discriminator loss were analyzed earlier, so they are excluded). This is probably both the most important and the most difficult part of DG-Net. Grab a bottle of coke and savor this post slowly and carefully. Let's get started!

Code Walkthrough

In the train.py file we can find the following code:

with Timer("Elapsed time in update: %f"):
     # Main training code
     # Run the forward pass
     # Mixed synthesized images: x_ab [st code of images_a + ap code of images_b]    x_ba [st code of images_b + ap code of images_a]
     # s_a [st code from encoding images_a with Es]          s_b [st code from encoding images_b with Es]
     # f_a [ap code from encoding images_a with Ea]          f_b [ap code from encoding images_b with Ea]
     # p_a [identity (ID) prediction for images_a from Ea]   p_b [identity (ID) prediction for images_b from Ea]
     # pp_a [identity (ID) prediction for pos_a from Ea]     pp_b [identity (ID) prediction for pos_b from Ea]
     # x_a_recon [image rebuilt from images_a's own s_a and f_a; it should look identical to images_a]
     # x_b_recon [image rebuilt from images_b's own s_b and f_b; it should look identical to images_b]
     # x_a_recon_p [image rebuilt from images_a's s_a and pos_a's fp_a; it should still look like images_a]
     # x_b_recon_p [image rebuilt from images_b's s_b and pos_b's fp_b; it should still look like images_b]
     # Friends, these comments were no small effort!
     # ......
     x_ab, x_ba, s_a, s_b, f_a, f_b, p_a, p_b, pp_a, pp_b, x_a_recon, x_b_recon, x_a_recon_p, x_b_recon_p = \
         trainer.forward(images_a, images_b, pos_a, pos_b)
if num_gpu>1:
    trainer.module.dis_update(x_ab.clone(), x_ba.clone(), images_a, images_b, config, num_gpu)
    #
    trainer.module.gen_update(x_ab, x_ba, s_a, s_b, f_a, f_b, p_a, p_b, pp_a, pp_b, x_a_recon, x_b_recon, x_a_recon_p, x_b_recon_p, images_a, images_b, pos_a, pos_b, labels_a, labels_b, config, iterations, num_gpu)
else:
    trainer.dis_update(x_ab.clone(), x_ba.clone(), images_a, images_b, config, num_gpu=1)
    trainer.gen_update(x_ab, x_ba, s_a, s_b, f_a, f_b, p_a, p_b, pp_a, pp_b, x_a_recon, x_b_recon, x_a_recon_p, x_b_recon_p, images_a, images_b, pos_a, pos_b, labels_a, labels_b, config, iterations, num_gpu=1)

Both trainer.dis_update and trainer.gen_update are implemented in the trainer.py file; dis_update looks like this:

    def dis_update(self, x_ab, x_ba, x_a, x_b, hyperparameters, num_gpu):
        self.dis_opt.zero_grad()
        # D loss
        # calc_dis_loss was annotated in detail in an earlier post, so it is not re-annotated here
        if num_gpu>1:
            self.loss_dis_a, reg_a = self.dis_a.module.calc_dis_loss(self.dis_a, x_ba.detach(), x_a)
            self.loss_dis_b, reg_b = self.dis_b.module.calc_dis_loss(self.dis_b, x_ab.detach(), x_b)
        else:
            self.loss_dis_a, reg_a = self.dis_a.calc_dis_loss(self.dis_a, x_ba.detach(), x_a)
            self.loss_dis_b, reg_b = self.dis_b.calc_dis_loss(self.dis_b, x_ab.detach(), x_b)
        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + hyperparameters['gan_w'] * self.loss_dis_b
        print("DLoss: %.4f"%self.loss_dis_total, "Reg: %.4f"%(reg_a+reg_b) )
        if self.fp16:
            with amp.scale_loss(self.loss_dis_total, self.dis_opt) as scaled_loss:
                scaled_loss.backward()
        else:
            self.loss_dis_total.backward()
        self.dis_opt.step()

We already covered calc_dis_loss, so this part is easy to follow: compute the discriminator loss via calc_dis_loss, then backpropagate. Next comes the main event, gen_update. Don't be intimidated (work through it once carefully; a plain-language explanation follows in a later post).
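As a refresher, here is a minimal sketch of what an LSGAN-style calc_dis_loss can compute. The real implementation lives in networks.py and supports several GAN types plus a gradient regularizer (the Reg value printed above); treat the function below as an illustration under those assumptions, not the author's exact code:

    import torch

    def calc_dis_loss_sketch(dis, x_fake, x_real):
        # LSGAN objective: push real scores toward 1 and fake scores toward 0
        x_real = x_real.requires_grad_()   # needed to differentiate w.r.t. the input
        out_fake = dis(x_fake)             # discriminator scores for generated images
        out_real = dis(x_real)             # discriminator scores for real images
        loss = torch.mean(out_fake ** 2) + torch.mean((out_real - 1) ** 2)
        # simple gradient regularizer: penalize the gradient of the real scores w.r.t. the input
        grad = torch.autograd.grad(out_real.sum(), x_real, create_graph=True)[0]
        reg = grad.pow(2).reshape(grad.size(0), -1).sum(1).mean()
        return loss, reg

Note how the trainer detaches x_ba and x_ab before passing them in: the discriminator step must not push gradients back into the generator.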

gen_update(), Annotated

    def gen_update(self, x_ab, x_ba, s_a, s_b, f_a, f_b, p_a, p_b, pp_a, pp_b, x_a_recon, x_b_recon, x_a_recon_p, x_b_recon_p, x_a, x_b, xp_a, xp_b, l_a, l_b, hyperparameters, iteration, num_gpu):
        """
        # Mixed synthesized images: x_ab [st code of images_a + ap code of images_b]    x_ba [st code of images_b + ap code of images_a]
        # s_a [st code from encoding images_a with Es]          s_b [st code from encoding images_b with Es]
        # f_a [ap code from encoding images_a with Ea]          f_b [ap code from encoding images_b with Ea]
        # p_a [identity (ID) prediction for images_a from Ea]   p_b [identity (ID) prediction for images_b from Ea]
        # pp_a [identity (ID) prediction for pos_a from Ea]     pp_b [identity (ID) prediction for pos_b from Ea]
        # x_a_recon [image rebuilt from images_a's own s_a and f_a; it should look identical to images_a]
        # x_b_recon [image rebuilt from images_b's own s_b and f_b; it should look identical to images_b]
        # x_a_recon_p [image rebuilt from images_a's s_a and pos_a's fp_a; it should still look like images_a]
        # x_b_recon_p [image rebuilt from images_b's s_b and pos_b's fp_b; it should still look like images_b]
        """

        # pp_a, pp_b are predictions for the positive samples, i.e. the same person as images_a / images_b
        self.gen_opt.zero_grad()
        self.id_opt.zero_grad()
 
        # no gradient
        # make graph-detached copies of the synthesized x_ba and x_ab
        x_ba_copy = Variable(x_ba.data, requires_grad=False)
        x_ab_copy = Variable(x_ab.data, requires_grad=False)
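        # (Note, not in the original code: Variable is deprecated in modern PyTorch;
        #  x_ba.detach() produces the same graph-detached copy.)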

        rand_num = random.uniform(0,1)
        #################################
        # encode structure
        # enc_content is a ContentEncoder object
        if hyperparameters['use_encoder_again']>=rand_num:
            # encode again (encoder is tuned, input is fixed)
            # Es re-encodes the synthesized images, giving the st codes s_a_recon and s_b_recon
            # For an ideal model, s_a_recon == s_a and s_b_recon == s_b
            s_a_recon = self.gen_b.enc_content(self.single(x_ab_copy))
            s_b_recon = self.gen_a.enc_content(self.single(x_ba_copy))
        else:
            # copy the encoder
            # this is a deep copy
            self.enc_content_copy = copy.deepcopy(self.gen_a.enc_content)
            self.enc_content_copy = self.enc_content_copy.eval()

            # same operation as above
            # encode again (encoder is fixed, input is tuned)
            s_a_recon = self.enc_content_copy(self.single(x_ab))
            s_b_recon = self.enc_content_copy(self.single(x_ba))


        #################################
        # encode appearance
        self.id_a_copy = copy.deepcopy(self.id_a)
        self.id_a_copy = self.id_a_copy.eval()

        if hyperparameters['train_bn']:
            self.id_a_copy = self.id_a_copy.apply(train_bn)
        self.id_b_copy = self.id_a_copy
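        # (Note, not in the original code: id_b_copy is simply id_a_copy, so both
        #  branches share a single copy of the appearance encoder.)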

        # encode again (encoder is fixed, input is tuned)
        # Run the Ea appearance encoder on the synthesized images x_ba and x_ab, which also predicts identity
        # f_a_recon and f_b_recon are the re-encoded ap codes; p_a_recon and p_b_recon are the identity predictions
        f_a_recon, p_a_recon = self.id_a_copy(scale2(x_ba))
        f_b_recon, p_b_recon = self.id_b_copy(scale2(x_ab))

        # teacher Loss
        #  Tune the ID model
        log_sm = nn.LogSoftmax(dim=1)
        # if a teacher network is used
        if hyperparameters['teacher_w'] >0 and hyperparameters['teacher'] != "":
            if hyperparameters['ID_style'] == 'normal':
                _, p_a_student = self.id_a(scale2(x_ba_copy))
                p_a_student = log_sm(p_a_student)
                p_a_teacher = predict_label(self.teacher_model, scale2(x_ba_copy), num_class = hyperparameters['ID_class'], alabel = l_a, slabel = l_b, teacher_style = hyperparameters['teacher_style'])
                self.loss_teacher = self.criterion_teacher(p_a_student, p_a_teacher) / p_a_student.size(0)

                _, p_b_student = self.id_b(scale2(x_ab_copy))
                p_b_student = log_sm(p_b_student)
                p_b_teacher = predict_label(self.teacher_model, scale2(x_ab_copy), num_class = hyperparameters['ID_class'], alabel = l_b, slabel = l_a, teacher_style = hyperparameters['teacher_style'])
                self.loss_teacher += self.criterion_teacher(p_b_student, p_b_teacher) / p_b_student.size(0)
            elif hyperparameters['ID_style'] == 'AB':
                # normal teacher-student loss
                # BA -> LabelA(smooth) + LabelB(batchB)
                # Feed the synthesized image to the identity classifier to get per-ID probabilities. Note that p_ba_student[0] is used:
                # in 'AB' mode there are two identity predictions and only the first is taken here; it becomes p_a_student
                # and is combined with the teacher output to compute the loss.
                _, p_ba_student = self.id_a(scale2(x_ba_copy))# f_a, s_b
                p_a_student = log_sm(p_ba_student[0])

                with torch.no_grad():
                    p_a_teacher = predict_label(self.teacher_model, scale2(x_ba_copy), num_class = hyperparameters['ID_class'], alabel = l_a, slabel = l_b, teacher_style = hyperparameters['teacher_style'])
                # criterion_teacher = nn.KLDivLoss(size_average=False)
                # This is the KL divergence, roughly the summed per-element gap between p_a_student and p_a_teacher,
                # divided by p_a_student.size(0) to average over the batch.
                # In other words, the closer the student network (Ea) gets to the teacher's predictions, the better.
                self.loss_teacher = self.criterion_teacher(p_a_student, p_a_teacher) / p_a_student.size(0)
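                # (Illustration, not in the original code: with log-probabilities as input and
                #  probabilities as target, nn.KLDivLoss(size_average=False) computes
                #  sum(p_teacher * (log(p_teacher) - log_p_student)) over batch and classes,
                #  which the line above then divides by the batch size.)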


                # The same procedure for the other synthesized image
                _, p_ab_student = self.id_b(scale2(x_ab_copy)) # f_b, s_a
                p_b_student = log_sm(p_ab_student[0])
                with torch.no_grad():
                    p_b_teacher = predict_label(self.teacher_model, scale2(x_ab_copy), num_class = hyperparameters['ID_class'], alabel = l_b, slabel = l_a, teacher_style = hyperparameters['teacher_style'])
                self.loss_teacher += self.criterion_teacher(p_b_student, p_b_teacher) / p_b_student.size(0)


                # branch b loss
                # here we give a different label
                # Loss between the student network's (Ea's) ID predictions and the actual IDs. Note that p_ba_student[1] is used here;
                # take care to work out the difference between p_ba_student[0] and p_ba_student[1].
                loss_B = self.id_criterion(p_ba_student[1], l_b) + self.id_criterion(p_ab_student[1], l_a)

                # In the released config, T_w = 1 and B_w = 0.2
                # Honestly I'm torn about what this split buys us; it feels like loss_teacher alone,
                # or loss_B alone, would be enough.
                self.loss_teacher = hyperparameters['T_w'] * self.loss_teacher + hyperparameters['B_w'] * loss_B

        else:
            self.loss_teacher = 0.0

        # What remains are the reconstruction losses between images.
        # As mentioned before, reconstruction is not synthesis: a reconstruction should match the original image,
        # so the reconstructed and original images can be compared pixel by pixel.
        # Synthesized images offer no such option: the training set holds no ground truth for them, so a pixel loss is impossible.
        # auto-encoder image reconstruction
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_xp_a = self.recon_criterion(x_a_recon_p, x_a)
        self.loss_gen_recon_xp_b = self.recon_criterion(x_b_recon_p, x_b)

        # feature reconstruction
        # losses between the re-encoded st / ap codes and the original ones
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a) if hyperparameters['recon_s_w'] > 0 else 0
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b) if hyperparameters['recon_s_w'] > 0 else 0
        self.loss_gen_recon_f_a = self.recon_criterion(f_a_recon, f_a) if hyperparameters['recon_f_w'] > 0 else 0
        self.loss_gen_recon_f_b = self.recon_criterion(f_b_recon, f_b) if hyperparameters['recon_f_w'] > 0 else 0


        # Decode the re-encoded st and ap codes back into images (the decoder G in the paper)
        x_aba = self.gen_a.decode(s_a_recon, f_a_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(s_b_recon, f_b_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None
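        # (Note, not in the original code: x_aba and x_bab are the cycle reconstructions;
        #  they should again look like x_a and x_b, and are compared against them below in
        #  loss_gen_cycrecon_x_a / loss_gen_cycrecon_x_b.)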

        # Rather than annotating every remaining line, I'll summarize the rest in the blog text; it gets repetitive.
        # ID loss AND Tune the Generated image
        if hyperparameters['ID_style']=='PCB':
            self.loss_id = self.PCB_loss(p_a, l_a) + self.PCB_loss(p_b, l_b)
            self.loss_pid = self.PCB_loss(pp_a, l_a) + self.PCB_loss(pp_b, l_b)
            self.loss_gen_recon_id = self.PCB_loss(p_a_recon, l_a) + self.PCB_loss(p_b_recon, l_b)
        elif hyperparameters['ID_style']=='AB':
            weight_B = hyperparameters['teacher_w'] * hyperparameters['B_w']
            self.loss_id = self.id_criterion(p_a[0], l_a) + self.id_criterion(p_b[0], l_b) \
                         + weight_B * ( self.id_criterion(p_a[1], l_a) + self.id_criterion(p_b[1], l_b) )
            self.loss_pid = self.id_criterion(pp_a[0], l_a) + self.id_criterion(pp_b[0], l_b) #+ weight_B * ( self.id_criterion(pp_a[1], l_a) + self.id_criterion(pp_b[1], l_b) )
            self.loss_gen_recon_id = self.id_criterion(p_a_recon[0], l_a) + self.id_criterion(p_b_recon[0], l_b)
        else:
            self.loss_id = self.id_criterion(p_a, l_a) + self.id_criterion(p_b, l_b)
            self.loss_pid = self.id_criterion(pp_a, l_a) + self.id_criterion(pp_b, l_b)
            self.loss_gen_recon_id = self.id_criterion(p_a_recon, l_a) + self.id_criterion(p_b_recon, l_b)

        #print(f_a_recon, f_a)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        # GAN loss
        if num_gpu>1:
            self.loss_gen_adv_a = self.dis_a.module.calc_gen_loss(self.dis_a, x_ba)
            self.loss_gen_adv_b = self.dis_b.module.calc_gen_loss(self.dis_b, x_ab)
        else:
            self.loss_gen_adv_a = self.dis_a.calc_gen_loss(self.dis_a, x_ba)
            self.loss_gen_adv_b = self.dis_b.calc_gen_loss(self.dis_b, x_ab)

        # domain-invariant perceptual loss
        # Clever move (much respect): you can swap your own network in here, load any classification
        # (person re-ID) model, and let the synthesized data help train it.
        # VGG extracts features from the synthesized image and a real image, and the loss is computed between the two feature maps,
        # so pedestrians can then be compared through these feature vectors.
        # Note, however, that the author's released config does not enable this (vgg_w is 0); porting it is up to you.
        self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0


        # The w entries in the config below are per-loss weights. After warm_iter, the reconstruction weights are
        # warmed up linearly: each iteration adds warm_scale, capped at max_w / max_cyc_w
        # (and likewise teacher_w after warm_teacher_iter).
        if iteration > hyperparameters['warm_iter']:
            hyperparameters['recon_f_w'] += hyperparameters['warm_scale']
            hyperparameters['recon_f_w'] = min(hyperparameters['recon_f_w'], hyperparameters['max_w'])
            hyperparameters['recon_s_w'] += hyperparameters['warm_scale']
            hyperparameters['recon_s_w'] = min(hyperparameters['recon_s_w'], hyperparameters['max_w'])
            hyperparameters['recon_x_cyc_w'] += hyperparameters['warm_scale']
            hyperparameters['recon_x_cyc_w'] = min(hyperparameters['recon_x_cyc_w'], hyperparameters['max_cyc_w'])

        if iteration > hyperparameters['warm_teacher_iter']:
            hyperparameters['teacher_w'] += hyperparameters['warm_scale']
            hyperparameters['teacher_w'] = min(hyperparameters['teacher_w'], hyperparameters['max_teacher_w'])
        # total loss
        # weighted sum of all the losses above
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_xp_w'] * self.loss_gen_recon_xp_a + \
                              hyperparameters['recon_f_w'] * self.loss_gen_recon_f_a + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_xp_w'] * self.loss_gen_recon_xp_b + \
                              hyperparameters['recon_f_w'] * self.loss_gen_recon_f_b + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
                              hyperparameters['id_w'] * self.loss_id + \
                              hyperparameters['pid_w'] * self.loss_pid + \
                              hyperparameters['recon_id_w'] * self.loss_gen_recon_id + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_b + \
                              hyperparameters['teacher_w'] * self.loss_teacher
        if self.fp16:
            with amp.scale_loss(self.loss_gen_total, [self.gen_opt,self.id_opt]) as scaled_loss:
                scaled_loss.backward()
            self.gen_opt.step()
            self.id_opt.step()
        else:
            self.loss_gen_total.backward()
            self.gen_opt.step()
            self.id_opt.step()
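Since vgg_w is 0 in the released config, the perceptual loss is easy to overlook. For reference, here is a minimal sketch in the spirit of compute_vgg_loss in MUNIT, the codebase DG-Net builds on. The feature shape and helper names are my assumptions, not the exact DG-Net code; the key idea is that instance-normalizing the VGG features strips image-specific style, so only domain-invariant content is compared:

    import torch
    import torch.nn as nn

    instancenorm = nn.InstanceNorm2d(512, affine=False)

    def compute_vgg_loss_sketch(vgg, img, target):
        # vgg is assumed to return a (B, 512, H, W) feature map (e.g. relu4-level features)
        img_fea = vgg(img)
        target_fea = vgg(target)
        # instance normalization removes per-image contrast/style, keeping structural content
        return torch.mean((instancenorm(img_fea) - instancenorm(target_fea)) ** 2)

In the trainer this is called as compute_vgg_loss(self.vgg, x_ba, x_b): the synthesized x_ba borrows its structure from x_b, so their normalized features should agree.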

Don't panic, stay calm! We'll take it slowly: in the next post I explain all of this in plain language!
