pix2pixHD---損失---損失関数

まず、Pix2PixHDModel コードで損失を定義します。
ここに画像の説明を挿入
まず、最初のものを見てください。2 つの入力パラメータ use_gan_feat_loss と use_vgg_loss はデフォルトで false ですが、プレフィックスは false であるため、両方のパラメータは True です。

    def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss):
        flags = (True, use_gan_feat_loss, use_vgg_loss, True, True)
        def loss_filter(g_gan, g_gan_feat, g_vgg, d_real, d_fake):
            return [l for (l,f) in zip((g_gan,g_gan_feat,g_vgg,d_real,d_fake),flags) if f]
        return loss_filter

次に、フラグには 5 つの True があり、zip 関数は各値と True をタプルに結合します。全部で5つあります。
次に、2 番目のものを見てください。
ここに画像の説明を挿入

class GANLoss(nn.Module):
    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
                 tensor=torch.FloatTensor):
        super(GANLoss, self).__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        self.real_label_var = None
        self.fake_label_var = None
        self.Tensor = tensor
        if use_lsgan:
            self.loss = nn.MSELoss()
        else:
            self.loss = nn.BCELoss()

    def get_target_tensor(self, input, target_is_real):
        target_tensor = None
        if target_is_real:
            create_label = ((self.real_label_var is None) or
                            (self.real_label_var.numel() != input.numel()))
            if create_label:
                real_tensor = self.Tensor(input.size()).fill_(self.real_label)
                self.real_label_var = Variable(real_tensor, requires_grad=False)
            target_tensor = self.real_label_var
        else:
            create_label = ((self.fake_label_var is None) or
                            (self.fake_label_var.numel() != input.numel()))
            if create_label:
                fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
                self.fake_label_var = Variable(fake_tensor, requires_grad=False)
            target_tensor = self.fake_label_var
        return target_tensor

    def __call__(self, input, target_is_real):
        if isinstance(input[0], list):
            loss = 0
            for input_i in input:
                pred = input_i[-1]
                target_tensor = self.get_target_tensor(pred, target_is_real)
                loss += self.loss(pred, target_tensor)
            return loss
        else:            
            target_tensor = self.get_target_tensor(input[-1], target_is_real)
            return self.loss(input[-1], target_tensor)

call 関数を通じて呼び出される場合、入力は 2 つあり、識別子には出力リストに 5 つの値があるため、ここでは for ループが使用されます。get_target_tensor に pred 値と target 値を入力して target_tensor を取得します。

    def get_target_tensor(self, input, target_is_real):
        target_tensor = None
        if target_is_real:
            create_label = ((self.real_label_var is None) or
                            (self.real_label_var.numel() != input.numel()))
            if create_label:
                real_tensor = self.Tensor(input.size()).fill_(self.real_label)
                self.real_label_var = Variable(real_tensor, requires_grad=False)
            target_tensor = self.real_label_var
        else:
            create_label = ((self.fake_label_var is None) or
                            (self.fake_label_var.numel() != input.numel()))
            if create_label:
                fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
                self.fake_label_var = Variable(fake_tensor, requires_grad=False)
            target_tensor = self.fake_label_var
        return target_tensor

ここは入力と同じサイズの行列を取得するもので、行列の値は1または0で構成されます。
ここに画像の説明を挿入
最後に、pred と target を損失計算に使用し、計算結果を累積します: MSELoss は損失関数で使用され、
ここに画像の説明を挿入
ここに画像の説明を挿入
GANloss に加えて、L1loss と VGGloss もあります。
ここに画像の説明を挿入
VGGloss または機能マッチング損失を使用する場合:

class VGGLoss(nn.Module):
    def __init__(self, gpu_ids):
        super(VGGLoss, self).__init__()        
        self.vgg = Vgg19().cuda()
        self.criterion = nn.L1Loss()
        self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]        
    def forward(self, x, y):              
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)
        loss = 0
        for i in range(len(x_vgg)):
            loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())        
        return loss

実画像と生成画像をVGG19に入力し、得られた値に対してL1ロス計算を行い、それぞれの値に重みを付けます。
ここに画像の説明を挿入
モデル内の損失を使用して次を計算します。

        pred_fake_pool = self.discriminate(input_label, fake_image, use_pool=True)
        loss_D_fake = self.criterionGAN(pred_fake_pool, False)        

        # Real Detection and Loss        
        pred_real = self.discriminate(input_label, real_image)
        loss_D_real = self.criterionGAN(pred_real, True)

        # GAN loss (Fake Posibility Loss)
        pred_fake = self.netD.forward(torch.cat((input_label, fake_image), dim=1))        
        loss_G_GAN = self.criterionGAN(pred_fake, True)   

弁別器の場合、入力が偽の画像、つまり生成された画像の場合、その出力が 0 であることを望み、入力が本物の画像の場合、弁別器の出力が 1 であることを望みます。 、デバイスを予測できない、つまり、生成されるものはすべて 1 であることを区別したいと考えています。これらは 3 つの GANLoss です。
次に、l1 損失計算のために、本物の画像を弁別器の出力に入力し、偽の画像を弁別器の出力に入力します。
ここに画像の説明を挿入
最後に、加算ピクチャとジェネレータで生成した実ピクチャを VGG に入力し、VGGloss 計算を実行します。損失の計算過程を描画します。
ここに画像の説明を挿入
最後に、すべての損失は loss_filter に入力され、それぞれが true でタプルを形成し、大きなリストを返します。同時に、モデルはトレーニング中に別の出力 None を出力し、推論中に偽の画像を出力します。
電車に戻ると、モデル全体が構築されます。

おすすめ

転載: blog.csdn.net/qq_43733107/article/details/130969740