まず、Pix2PixHDModel コードで損失を定義します。
まず、最初のものを見てください。2 つの入力パラメータ use_gan_feat_loss と use_vgg_loss はデフォルトで false ですが、プレフィックスは false であるため、両方のパラメータは True です。
def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss):
flags = (True, use_gan_feat_loss, use_vgg_loss, True, True)
def loss_filter(g_gan, g_gan_feat, g_vgg, d_real, d_fake):
return [l for (l,f) in zip((g_gan,g_gan_feat,g_vgg,d_real,d_fake),flags) if f]
return loss_filter
次に、フラグには 5 つの True があり、zip 関数は各値と True をタプルに結合します。全部で5つあります。
次に、2 番目のものを見てください。
class GANLoss(nn.Module):
def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
tensor=torch.FloatTensor):
super(GANLoss, self).__init__()
self.real_label = target_real_label
self.fake_label = target_fake_label
self.real_label_var = None
self.fake_label_var = None
self.Tensor = tensor
if use_lsgan:
self.loss = nn.MSELoss()
else:
self.loss = nn.BCELoss()
def get_target_tensor(self, input, target_is_real):
target_tensor = None
if target_is_real:
create_label = ((self.real_label_var is None) or
(self.real_label_var.numel() != input.numel()))
if create_label:
real_tensor = self.Tensor(input.size()).fill_(self.real_label)
self.real_label_var = Variable(real_tensor, requires_grad=False)
target_tensor = self.real_label_var
else:
create_label = ((self.fake_label_var is None) or
(self.fake_label_var.numel() != input.numel()))
if create_label:
fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
self.fake_label_var = Variable(fake_tensor, requires_grad=False)
target_tensor = self.fake_label_var
return target_tensor
def __call__(self, input, target_is_real):
if isinstance(input[0], list):
loss = 0
for input_i in input:
pred = input_i[-1]
target_tensor = self.get_target_tensor(pred, target_is_real)
loss += self.loss(pred, target_tensor)
return loss
else:
target_tensor = self.get_target_tensor(input[-1], target_is_real)
return self.loss(input[-1], target_tensor)
call 関数を通じて呼び出される場合、入力は 2 つあり、識別子には出力リストに 5 つの値があるため、ここでは for ループが使用されます。get_target_tensor に pred 値と target 値を入力して target_tensor を取得します。
def get_target_tensor(self, input, target_is_real):
target_tensor = None
if target_is_real:
create_label = ((self.real_label_var is None) or
(self.real_label_var.numel() != input.numel()))
if create_label:
real_tensor = self.Tensor(input.size()).fill_(self.real_label)
self.real_label_var = Variable(real_tensor, requires_grad=False)
target_tensor = self.real_label_var
else:
create_label = ((self.fake_label_var is None) or
(self.fake_label_var.numel() != input.numel()))
if create_label:
fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
self.fake_label_var = Variable(fake_tensor, requires_grad=False)
target_tensor = self.fake_label_var
return target_tensor
ここは入力と同じサイズの行列を取得するもので、行列の値は1または0で構成されます。
最後に、pred と target を損失計算に使用し、計算結果を累積します: MSELoss は損失関数で使用され、
GANloss に加えて、L1loss と VGGloss もあります。
VGGloss または機能マッチング損失を使用する場合:
class VGGLoss(nn.Module):
def __init__(self, gpu_ids):
super(VGGLoss, self).__init__()
self.vgg = Vgg19().cuda()
self.criterion = nn.L1Loss()
self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]
def forward(self, x, y):
x_vgg, y_vgg = self.vgg(x), self.vgg(y)
loss = 0
for i in range(len(x_vgg)):
loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
return loss
実画像と生成画像をVGG19に入力し、得られた値に対してL1ロス計算を行い、それぞれの値に重みを付けます。
モデル内の損失を使用して次を計算します。
pred_fake_pool = self.discriminate(input_label, fake_image, use_pool=True)
loss_D_fake = self.criterionGAN(pred_fake_pool, False)
# Real Detection and Loss
pred_real = self.discriminate(input_label, real_image)
loss_D_real = self.criterionGAN(pred_real, True)
# GAN loss (Fake Posibility Loss)
pred_fake = self.netD.forward(torch.cat((input_label, fake_image), dim=1))
loss_G_GAN = self.criterionGAN(pred_fake, True)
弁別器の場合、入力が偽の画像、つまり生成された画像の場合、その出力が 0 であることを望み、入力が本物の画像の場合、弁別器の出力が 1 であることを望みます。 、デバイスを予測できない、つまり、生成されるものはすべて 1 であることを区別したいと考えています。これらは 3 つの GANLoss です。
次に、l1 損失計算のために、本物の画像を弁別器の出力に入力し、偽の画像を弁別器の出力に入力します。
最後に、加算ピクチャとジェネレータで生成した実ピクチャを VGG に入力し、VGGloss 計算を実行します。損失の計算過程を描画します。
最後に、すべての損失は loss_filter に入力され、それぞれが true でタプルを形成し、大きなリストを返します。同時に、モデルはトレーニング中に別の出力 None を出力し、推論中に偽の画像を出力します。
電車に戻ると、モデル全体が構築されます。