行人重识别0-10:DG-Net(ReID)-代码无死角解读(6)-lsgan损失及教师网络

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接: https://blog.csdn.net/weixin_43013761/article/details/102509237

以下链接是个人关于DG-Net(行人重识别ReID)所有见解,如有错误欢迎大家指出,我会第一时间纠正。有兴趣的朋友可以加微信:a944284742相互讨论技术。若是帮助到了你什么,一定要记得点赞!因为这是对我最大的鼓励。
行人重识别0-00:DG-GAN(ReID)-目录-史上最新最全:https://blog.csdn.net/weixin_43013761/article/details/102364512

代码引导

首先根据上篇博客,说一下代码位置networks.py:

    def calc_dis_loss(self, model, input_fake, input_real):
        """
        该loss为了训练D,即鉴别器本身
        :param model: 为自己本身MsImageDis
        :param input_fake: 输入假图片,也就是合成的图片
        :param input_real: 输入真图片,训练集里面的图片
        :return:
        """

        # calculate the loss to train D
        input_real.requires_grad_()
        # 这里一起3个元素,分别大小为[batch_size, 1,64,32], [batch_size, 1,32,16], [batch_size, 1,16,8]
        outs0 = model.forward(input_fake)
        # 这里一起3个元素,分别大小为[batch_size, 1,64,32], [batch_size, 1,32,16], [batch_size, 1,16,8]
        outs1 = model.forward(input_real)
        loss = 0
        reg = 0
        Drift = 0.001
        LAMBDA = self.LAMBDA

        # 默认gan_type = 'lsgan',即没有执行这里
        if self.gan_type == 'wgan':
            loss += torch.mean(outs0) - torch.mean(outs1)
            # progressive gan
            loss += Drift*( torch.sum(outs0**2) + torch.sum(outs1**2))
            #alpha = torch.FloatTensor(input_fake.shape).uniform_(0., 1.)
            #alpha = alpha.cuda()
            #differences = input_fake - input_real
            #interpolates =  Variable(input_real + (alpha*differences), requires_grad=True)
            #dis_interpolates = self.forward(interpolates) 
            #gradient_penalty = self.compute_grad2(dis_interpolates, interpolates).mean()
            #reg += LAMBDA*gradient_penalty 
            reg += LAMBDA* self.compute_grad2(outs1, input_real).mean() # I suggest Lambda=0.1 for wgan
            loss = loss + reg
            return loss, reg


        for it, (out0, out1) in enumerate(zip(outs0, outs1)):
            # 默认gan_type == 'lsgan',最小二乘损失方式,主要解决生成图像不稳定的问题
            if self.gan_type == 'lsgan':
                loss += torch.mean((out0 - 0)**2) + torch.mean((out1 - 1)**2)
                # regularization
                reg += LAMBDA* self.compute_grad2(out1, input_real).mean()
            elif self.gan_type == 'nsgan':
                all0 = Variable(torch.zeros_like(out0.data).cuda(), requires_grad=False)
                all1 = Variable(torch.ones_like(out1.data).cuda(), requires_grad=False)
                loss += torch.mean(F.binary_cross_entropy(F.sigmoid(out0), all0) +
                                   F.binary_cross_entropy(F.sigmoid(out1), all1))
                reg += LAMBDA* self.compute_grad2(F.sigmoid(out1), input_real).mean()
            else:
                assert 0, "Unsupported GAN type: {}".format(self.gan_type)

        loss = loss+reg
        return loss, reg

    def calc_gen_loss(self, model, input_fake):
        """
        :param model: 为自己本身MsImageDis
        :param input_fake: 输入假的图片
        :return:
        """
        # calculate the loss to train G
        # 生成图片,初一这里的输出还是有3个尺寸
        outs0 = model.forward(input_fake)
        loss = 0
        Drift = 0.001

        # 该处不执行,因为gan_type == 'lsgan'
        if self.gan_type == 'wgan':
            loss += -torch.mean(outs0)
            # progressive gan
            loss += Drift*torch.sum(outs0**2)
            return loss

        # 同理我们使用的是gan_type == 'lsgan'
        for it, (out0) in enumerate(outs0):
            if self.gan_type == 'lsgan':
                loss += torch.mean((out0 - 1)**2) * 2  # LSGAN
            elif self.gan_type == 'nsgan':
                all1 = Variable(torch.ones_like(out0.data).cuda(), requires_grad=False)
                loss += torch.mean(F.binary_cross_entropy(F.sigmoid(out0), all1))
            else:
                assert 0, "Unsupported GAN type: {}".format(self.gan_type)
        return loss

    # 计算梯度,大概反向传播使用,了解了细节的朋友麻烦告诉我下
    def compute_grad2(self, d_out, x_in):
        batch_size = x_in.size(0)
         # 这是一个对输出自动求导数的函数,这里表示对outputs=d_out.sum()求inputs=x_in的导数
        grad_dout = torch.autograd.grad(
            outputs=d_out.sum(), inputs=x_in,
            create_graph=True, retain_graph=True, only_inputs=True
        )[0]
        grad_dout2 = grad_dout.pow(2)
        assert(grad_dout2.size() == x_in.size())
        reg = grad_dout2.view(batch_size, -1).sum(1)
        return reg

主要讲解这三个函数,需要要注意的是,我们这里讲解的是鉴别器(MsImageDis)相关的损失函数,这个鉴别器,是鉴别图片的真假,不是对身份的鉴别。MsImageDis中MS表示的应该是对尺度的意思,dis表示鉴别。Image自己体会下。

首先,来说一个概念。GANS网路一般都会有两个模块,即生成模块和鉴别模块。在生成模块计算损失的时候,我们是不需要真实图片的。直白的说,就是把生成模块生成的图片,丢给鉴别器就可以了,通过鉴别器告诉生成模块是真的还是假的,如果是假的,生成模块就会继续优化。所以生成模块损失计算函数def calc_gen_loss(self, model, input_fake)只有一个,只需要一张假的图片就能够计算损失了。

但是鉴别模块的损失计算是不一样的,如def calc_dis_loss(self, model, input_fake, input_real),这里可以看到,其是有两个参数的,一个为假的照片,一个为真的照片。为什么呢?因为鉴别模块,不仅要认出假冒的图片,还要认出真实的图片。所以假的图片和真的图片都要让他学习(以后就不再提及这个概念了)。

LSGAN

LSGAN是一篇关于GAN网络论文提出的,如果我先去阅读这篇论文,然后再为大家详细介绍,就没有必要了,我相信你也不想听我啰嗦,同样我也不想去看那个论文,毕竟时间是宝贵的,我这里给大家简单介绍一下就行了:
首先是针对生成器G优化的LOSS计算(论文中得公式):
L ( G ) = ξ x p x ( D ( G ( x ) ) c ) 2 L(G)=\xi_{x-p_x} (D(G(x))-c)^2
给大家提一下,以后看这样的公式,不要想得太复杂了,如下面我们这样分析,找到我们源码中代码如下:

       for it, (out0) in enumerate(outs0):
            if self.gan_type == 'lsgan':
                # 这里我们可以看到,当out0=1的时候,其损失是最小的,同时out0是假的图片经过odel.forward(input_fake)鉴别之后的输出
                loss += torch.mean((out0 - 1)**2) * 2  # LSGAN

可以看到,和上面得公式就一一对应起来了。公式中的C等于源码中的1。所以我们从源码可以知道当(out0 - 1)=0的时候,loss是最小的,也就是ut0 = 1,我们从代码可以知道ut0是生成图片经过鉴别器的结果。结果为1,这就表示,鉴别器被生成器欺骗过去了,鉴别器把假冒生成的图片鉴别成了真实图片。这个时候生成器的模型的最好的。

所以同过上面loss的计算,对生成器进行优化了。

既然生成器G讲解完成了,我们来看看鉴别器又是怎么优化的。首先看论文公式:
L ( D ) = 1 2 ξ z p z ( D ( G ( z ) ) b ) 2 + 1 2 ξ x p x ( D ( x ) a ) 2 L(D)= \frac{1}{2} \xi_{z-p_z} (D(G(z))-b)^2 + \frac{1}{2} \xi_{x-p_x} (D(x)-a)^2
然后在找到对应的代码:

        for it, (out0, out1) in enumerate(zip(outs0, outs1)):
            # 默认gan_type == 'lsgan',最小二乘损失方式,主要解决生成图像不稳定的问题
            if self.gan_type == 'lsgan':
                loss += torch.mean((out0 - 0)**2) + torch.mean((out1 - 1)**2)

可以看到loss的计算和上面的公式一一对应起来了,out0表示假的照片经过鉴别器得到的结果,,即out0= D ( G ( z ) ) D(G(z)) ,out1表示真实的经过鉴别器得到的结果,也就是 o u t 1 = D ( x ) out1=D(x) ,可以看到,我们如果想loss最小的,那么就是out0=0(假冒的图片),out1=1(真实的图片)。

这样就能通过loss的计算进行网络迭代,对图片鉴别器进行优化了。当真实图片输入为1,合成图片输出0的时候,测试鉴别效果是最好的。

教师网络

到这里为止,我们基本把图片真假的鉴别器已经讲解完成了,下面我们回到trainer.py文件,然后找到教师网络的相关部分:

       # 加载教师模型
        # load teachers
        # teacher:老师模型名称。对于DukeMTMC,您可以设置“best - duke”
        if hyperparameters['teacher'] != "":
            teacher_name = hyperparameters['teacher']
            print(teacher_name)
            # 有这个操作,我怀疑是可以加载多个教师模型
            teacher_names = teacher_name.split(',')
            # 构建教师模型
            teacher_model = nn.ModuleList()
            teacher_count = 0

            # 默认只有一个teacher_name='teacher_name',所以其加载的模型为项目根目录models/best/opts.yaml模型
            for teacher_name in teacher_names:
                config_tmp = load_config(teacher_name)

                #  默认stride=1,是池化层的stride
                if 'stride' in config_tmp:
                    stride = config_tmp['stride'] 
                else:
                    stride = 2

                # 网络搭建

                model_tmp = ft_net(ID_class, stride = stride)

                teacher_model_tmp = load_network(model_tmp, teacher_name)
                # 移除原本的全连接层
                teacher_model_tmp.model.fc = nn.Sequential()  # remove the original fc layer in ImageNet
                # 应该是进行网络搭建
                teacher_model_tmp = teacher_model_tmp.cuda()
                #summary(teacher_model_tmp, (3, 224, 224))
                # 使用浮点型
                if self.fp16:
                    teacher_model_tmp = amp.initialize(teacher_model_tmp, opt_level="O1")
                teacher_model.append(teacher_model_tmp.cuda().eval())
                teacher_count +=1

            self.teacher_model = teacher_model
            # 选择是否使用bn
            if hyperparameters['train_bn']:
                self.teacher_model = self.teacher_model.apply(train_bn)

下面是网络结构的打印:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]           4,096
       BatchNorm2d-6           [-1, 64, 56, 56]             128
              ReLU-7           [-1, 64, 56, 56]               0
            Conv2d-8           [-1, 64, 56, 56]          36,864
       BatchNorm2d-9           [-1, 64, 56, 56]             128
             ReLU-10           [-1, 64, 56, 56]               0
           Conv2d-11          [-1, 256, 56, 56]          16,384
      BatchNorm2d-12          [-1, 256, 56, 56]             512
           Conv2d-13          [-1, 256, 56, 56]          16,384
      BatchNorm2d-14          [-1, 256, 56, 56]             512
             ReLU-15          [-1, 256, 56, 56]               0
       Bottleneck-16          [-1, 256, 56, 56]               0
           Conv2d-17           [-1, 64, 56, 56]          16,384
      BatchNorm2d-18           [-1, 64, 56, 56]             128
             ReLU-19           [-1, 64, 56, 56]               0
           Conv2d-20           [-1, 64, 56, 56]          36,864
      BatchNorm2d-21           [-1, 64, 56, 56]             128
             ReLU-22           [-1, 64, 56, 56]               0
           Conv2d-23          [-1, 256, 56, 56]          16,384
      BatchNorm2d-24          [-1, 256, 56, 56]             512
             ReLU-25          [-1, 256, 56, 56]               0
       Bottleneck-26          [-1, 256, 56, 56]               0
           Conv2d-27           [-1, 64, 56, 56]          16,384
      BatchNorm2d-28           [-1, 64, 56, 56]             128
             ReLU-29           [-1, 64, 56, 56]               0
           Conv2d-30           [-1, 64, 56, 56]          36,864
      BatchNorm2d-31           [-1, 64, 56, 56]             128
             ReLU-32           [-1, 64, 56, 56]               0
           Conv2d-33          [-1, 256, 56, 56]          16,384
      BatchNorm2d-34          [-1, 256, 56, 56]             512
             ReLU-35          [-1, 256, 56, 56]               0
       Bottleneck-36          [-1, 256, 56, 56]               0
           Conv2d-37          [-1, 128, 56, 56]          32,768
      BatchNorm2d-38          [-1, 128, 56, 56]             256
             ReLU-39          [-1, 128, 56, 56]               0
           Conv2d-40          [-1, 128, 28, 28]         147,456
      BatchNorm2d-41          [-1, 128, 28, 28]             256
             ReLU-42          [-1, 128, 28, 28]               0
           Conv2d-43          [-1, 512, 28, 28]          65,536
      BatchNorm2d-44          [-1, 512, 28, 28]           1,024
           Conv2d-45          [-1, 512, 28, 28]         131,072
      BatchNorm2d-46          [-1, 512, 28, 28]           1,024
             ReLU-47          [-1, 512, 28, 28]               0
       Bottleneck-48          [-1, 512, 28, 28]               0
           Conv2d-49          [-1, 128, 28, 28]          65,536
      BatchNorm2d-50          [-1, 128, 28, 28]             256
             ReLU-51          [-1, 128, 28, 28]               0
           Conv2d-52          [-1, 128, 28, 28]         147,456
      BatchNorm2d-53          [-1, 128, 28, 28]             256
             ReLU-54          [-1, 128, 28, 28]               0
           Conv2d-55          [-1, 512, 28, 28]          65,536
      BatchNorm2d-56          [-1, 512, 28, 28]           1,024
             ReLU-57          [-1, 512, 28, 28]               0
       Bottleneck-58          [-1, 512, 28, 28]               0
           Conv2d-59          [-1, 128, 28, 28]          65,536
      BatchNorm2d-60          [-1, 128, 28, 28]             256
             ReLU-61          [-1, 128, 28, 28]               0
           Conv2d-62          [-1, 128, 28, 28]         147,456
      BatchNorm2d-63          [-1, 128, 28, 28]             256
             ReLU-64          [-1, 128, 28, 28]               0
           Conv2d-65          [-1, 512, 28, 28]          65,536
      BatchNorm2d-66          [-1, 512, 28, 28]           1,024
             ReLU-67          [-1, 512, 28, 28]               0
       Bottleneck-68          [-1, 512, 28, 28]               0
           Conv2d-69          [-1, 128, 28, 28]          65,536
      BatchNorm2d-70          [-1, 128, 28, 28]             256
             ReLU-71          [-1, 128, 28, 28]               0
           Conv2d-72          [-1, 128, 28, 28]         147,456
      BatchNorm2d-73          [-1, 128, 28, 28]             256
             ReLU-74          [-1, 128, 28, 28]               0
           Conv2d-75          [-1, 512, 28, 28]          65,536
      BatchNorm2d-76          [-1, 512, 28, 28]           1,024
             ReLU-77          [-1, 512, 28, 28]               0
       Bottleneck-78          [-1, 512, 28, 28]               0
           Conv2d-79          [-1, 256, 28, 28]         131,072
      BatchNorm2d-80          [-1, 256, 28, 28]             512
             ReLU-81          [-1, 256, 28, 28]               0
           Conv2d-82          [-1, 256, 14, 14]         589,824
      BatchNorm2d-83          [-1, 256, 14, 14]             512
             ReLU-84          [-1, 256, 14, 14]               0
           Conv2d-85         [-1, 1024, 14, 14]         262,144
      BatchNorm2d-86         [-1, 1024, 14, 14]           2,048
           Conv2d-87         [-1, 1024, 14, 14]         524,288
      BatchNorm2d-88         [-1, 1024, 14, 14]           2,048
             ReLU-89         [-1, 1024, 14, 14]               0
       Bottleneck-90         [-1, 1024, 14, 14]               0
           Conv2d-91          [-1, 256, 14, 14]         262,144
      BatchNorm2d-92          [-1, 256, 14, 14]             512
             ReLU-93          [-1, 256, 14, 14]               0
           Conv2d-94          [-1, 256, 14, 14]         589,824
      BatchNorm2d-95          [-1, 256, 14, 14]             512
             ReLU-96          [-1, 256, 14, 14]               0
           Conv2d-97         [-1, 1024, 14, 14]         262,144
      BatchNorm2d-98         [-1, 1024, 14, 14]           2,048
             ReLU-99         [-1, 1024, 14, 14]               0
      Bottleneck-100         [-1, 1024, 14, 14]               0
          Conv2d-101          [-1, 256, 14, 14]         262,144
     BatchNorm2d-102          [-1, 256, 14, 14]             512
            ReLU-103          [-1, 256, 14, 14]               0
          Conv2d-104          [-1, 256, 14, 14]         589,824
     BatchNorm2d-105          [-1, 256, 14, 14]             512
            ReLU-106          [-1, 256, 14, 14]               0
          Conv2d-107         [-1, 1024, 14, 14]         262,144
     BatchNorm2d-108         [-1, 1024, 14, 14]           2,048
            ReLU-109         [-1, 1024, 14, 14]               0
      Bottleneck-110         [-1, 1024, 14, 14]               0
          Conv2d-111          [-1, 256, 14, 14]         262,144
     BatchNorm2d-112          [-1, 256, 14, 14]             512
            ReLU-113          [-1, 256, 14, 14]               0
          Conv2d-114          [-1, 256, 14, 14]         589,824
     BatchNorm2d-115          [-1, 256, 14, 14]             512
            ReLU-116          [-1, 256, 14, 14]               0
          Conv2d-117         [-1, 1024, 14, 14]         262,144
     BatchNorm2d-118         [-1, 1024, 14, 14]           2,048
            ReLU-119         [-1, 1024, 14, 14]               0
      Bottleneck-120         [-1, 1024, 14, 14]               0
          Conv2d-121          [-1, 256, 14, 14]         262,144
     BatchNorm2d-122          [-1, 256, 14, 14]             512
            ReLU-123          [-1, 256, 14, 14]               0
          Conv2d-124          [-1, 256, 14, 14]         589,824
     BatchNorm2d-125          [-1, 256, 14, 14]             512
            ReLU-126          [-1, 256, 14, 14]               0
          Conv2d-127         [-1, 1024, 14, 14]         262,144
     BatchNorm2d-128         [-1, 1024, 14, 14]           2,048
            ReLU-129         [-1, 1024, 14, 14]               0
      Bottleneck-130         [-1, 1024, 14, 14]               0
          Conv2d-131          [-1, 256, 14, 14]         262,144
     BatchNorm2d-132          [-1, 256, 14, 14]             512
            ReLU-133          [-1, 256, 14, 14]               0
          Conv2d-134          [-1, 256, 14, 14]         589,824
     BatchNorm2d-135          [-1, 256, 14, 14]             512
            ReLU-136          [-1, 256, 14, 14]               0
          Conv2d-137         [-1, 1024, 14, 14]         262,144
     BatchNorm2d-138         [-1, 1024, 14, 14]           2,048
            ReLU-139         [-1, 1024, 14, 14]               0
      Bottleneck-140         [-1, 1024, 14, 14]               0
          Conv2d-141          [-1, 512, 14, 14]         524,288
     BatchNorm2d-142          [-1, 512, 14, 14]           1,024
            ReLU-143          [-1, 512, 14, 14]               0
          Conv2d-144          [-1, 512, 14, 14]       2,359,296
     BatchNorm2d-145          [-1, 512, 14, 14]           1,024
            ReLU-146          [-1, 512, 14, 14]               0
          Conv2d-147         [-1, 2048, 14, 14]       1,048,576
     BatchNorm2d-148         [-1, 2048, 14, 14]           4,096
          Conv2d-149         [-1, 2048, 14, 14]       2,097,152
     BatchNorm2d-150         [-1, 2048, 14, 14]           4,096
            ReLU-151         [-1, 2048, 14, 14]               0
      Bottleneck-152         [-1, 2048, 14, 14]               0
          Conv2d-153          [-1, 512, 14, 14]       1,048,576
     BatchNorm2d-154          [-1, 512, 14, 14]           1,024
            ReLU-155          [-1, 512, 14, 14]               0
          Conv2d-156          [-1, 512, 14, 14]       2,359,296
     BatchNorm2d-157          [-1, 512, 14, 14]           1,024
            ReLU-158          [-1, 512, 14, 14]               0
          Conv2d-159         [-1, 2048, 14, 14]       1,048,576
     BatchNorm2d-160         [-1, 2048, 14, 14]           4,096
            ReLU-161         [-1, 2048, 14, 14]               0
      Bottleneck-162         [-1, 2048, 14, 14]               0
          Conv2d-163          [-1, 512, 14, 14]       1,048,576
     BatchNorm2d-164          [-1, 512, 14, 14]           1,024
            ReLU-165          [-1, 512, 14, 14]               0
          Conv2d-166          [-1, 512, 14, 14]       2,359,296
     BatchNorm2d-167          [-1, 512, 14, 14]           1,024
            ReLU-168          [-1, 512, 14, 14]               0
          Conv2d-169         [-1, 2048, 14, 14]       1,048,576
     BatchNorm2d-170         [-1, 2048, 14, 14]           4,096
            ReLU-171         [-1, 2048, 14, 14]               0
      Bottleneck-172         [-1, 2048, 14, 14]               0
AdaptiveAvgPool2d-173           [-1, 2048, 4, 1]               0
AdaptiveAvgPool2d-174           [-1, 2048, 1, 1]               0
          Linear-175                  [-1, 512]       1,049,088
     BatchNorm1d-176                  [-1, 512]           1,024
         Dropout-177                  [-1, 512]               0
          Linear-178                  [-1, 751]         385,263
      ClassBlock-179                  [-1, 751]               0
================================================================

比较尴尬啊,又大又长。反正我们只要暂时知道,该网络就是输入一张图片,然后其给出这个图片属于的类别,或者ID编号就可以了。估计后续带大家把这个教师模型训练一篇的命运是逃脱不了了,因为感觉项目的落实,非需要他不可。

小结

这篇博客有点水了啊,都没有什么东西讲,就完了,都不好意思叫大叫点赞了。表示很尴尬,不怪我啊,怪公式与网络太简单了!

猜你喜欢

转载自blog.csdn.net/weixin_43013761/article/details/102509237