Deep Learning (33) - CycleGAN (2)

The complete project is here; you are welcome to visit.


Data format: unpaired images from the two domains. In the official repo the dataset directory is organized into trainA/, trainB/, testA/ and testB/ subfolders (e.g. horse2zebra).

1. Generator

self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
The feature-extraction backbone here is a ResNet (selected via opt.netG; other backbones, such as a U-Net, can be chosen instead).

  • downsampling part: an initial 7×7 convolution followed by two stride-2 convolutions
  • a stack of 9 ResNet blocks at the bottleneck (the default netG='resnet_9blocks'); a sketch of one block follows this list
  • upsampling part: two transposed convolutions, then a final 7×7 convolution and a Tanh
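
A minimal sketch of one such ResNet block (the class name is illustrative; the repo builds the real thing in networks.py with configurable normalization and optional dropout):

import torch
import torch.nn as nn

class ResnetBlock(nn.Module):
    # reflection padding + two 3x3 convs with instance norm, plus a skip connection
    def __init__(self, dim):
        super().__init__()
        self.block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(dim, dim, kernel_size=3),
            nn.InstanceNorm2d(dim),
            nn.ReLU(True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(dim, dim, kernel_size=3),
            nn.InstanceNorm2d(dim),
        )

    def forward(self, x):
        return x + self.block(x)  # residual connection keeps gradients flowing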

2. Discriminator

self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
As mentioned in the previous section, the discriminator is a binary classifier that distinguishes real from fake. The input is still a three-channel image; with the default netD='basic' it is a 70×70 PatchGAN, so the output is a map of per-patch real/fake scores rather than a single scalar.
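
A minimal sketch of this PatchGAN discriminator (assuming the default n_layers=3 and instance normalization; layer widths follow the usual ndf doubling):

import torch
import torch.nn as nn

class PatchDiscriminator(nn.Module):
    def __init__(self, input_nc=3, ndf=64):
        super().__init__()
        layers = [nn.Conv2d(input_nc, ndf, 4, stride=2, padding=1),
                  nn.LeakyReLU(0.2, True)]
        ch = ndf
        for stride in (2, 2, 1):  # three more conv blocks, widths 128/256/512
            layers += [nn.Conv2d(ch, ch * 2, 4, stride=stride, padding=1),
                       nn.InstanceNorm2d(ch * 2),
                       nn.LeakyReLU(0.2, True)]
            ch *= 2
        layers += [nn.Conv2d(ch, 1, 4, stride=1, padding=1)]  # 1-channel score map
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)  # per-patch real/fake logits

For a 256×256 input this produces a 30×30 score map; each score has a 70×70 receptive field, hence the name 70×70 PatchGAN.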

3. fake pool

Used to store a history of generated fake images; when the discriminators are trained, fakes are drawn from this pool (of size opt.pool_size, 50 by default) instead of always using the most recent one, which stabilizes training.
self.fake_A_pool = ImagePool(opt.pool_size)
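
A minimal sketch of how such an image pool works (this mirrors the query() behavior described above: with probability 0.5 the discriminator receives the fresh fake, otherwise an older fake swapped out of the buffer):

import random
import torch

class ImagePool:
    def __init__(self, pool_size=50):
        self.pool_size = pool_size
        self.images = []

    def query(self, images):
        out = []
        for img in images:
            img = img.unsqueeze(0)
            if len(self.images) < self.pool_size:
                self.images.append(img)      # buffer not full: keep and return it
                out.append(img)
            elif random.random() > 0.5:
                idx = random.randrange(self.pool_size)
                out.append(self.images[idx].clone())  # return an older fake
                self.images[idx] = img                # store the new one in its place
            else:
                out.append(img)              # return the new fake unchanged
        return torch.cat(out, 0)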

4. Definition of loss

self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)  # define GAN loss (MSELoss when gan_mode='lsgan')
self.criterionCycle = torch.nn.L1Loss()
self.criterionIdt = torch.nn.L1Loss()
  • GANLoss is defined according to gan_mode; with the default gan_mode='lsgan' it is an MSELoss against a "real" or "fake" target
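
A minimal sketch of the lsgan variant of this loss (the real GANLoss also supports 'vanilla' and 'wgangp' modes):

import torch
import torch.nn as nn

class LSGANLoss(nn.Module):
    # least-squares GAN loss: MSE between the discriminator's score map
    # and an all-ones (real) or all-zeros (fake) target
    def __init__(self):
        super().__init__()
        self.loss = nn.MSELoss()

    def forward(self, prediction, target_is_real):
        target = (torch.ones_like(prediction) if target_is_real
                  else torch.zeros_like(prediction))
        return self.loss(prediction, target)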

5. Model parameter counts

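To reproduce the counts, a small sketch (the attribute names netG_A etc. follow the repo; 'model' is assumed to be an already-constructed CycleGAN model object):

def count_params(net):
    return sum(p.numel() for p in net.parameters() if p.requires_grad)

for name in ['netG_A', 'netG_B', 'netD_A', 'netD_B']:
    net = getattr(model, name)  # 'model' is assumed to exist
    print(f'{name}: {count_params(net) / 1e6:.2f}M parameters')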

6. Debug record

  • set_input(input): get real_A, real_B

  • optimize_parameters(): compute the losses and backpropagate (a sketch of the full step follows this list)

    • forward(): generate fake_A, fake_B, rec_A, rec_B

      • generatorA first generates fake_B from real_A
      • generatorB takes fake_B and generates rec_A
      • generatorB generates fake_A from real_B
      • generatorA takes fake_A and generates rec_B
    • backward_G(): backpropagation for the two generators

      • Calculate identity_loss: generatorA maps domain A to domain B, so feeding real_B into generatorA should return an image close to real_B; this output is named idt_A. identity_loss is the L1 distance between idt_A and real_B, and likewise between idt_B (generatorB applied to real_A) and real_A.
      • Calculate generator_loss: the adversarial loss on the fake_B generated by generatorA. We want fake_B to fool discriminatorA, i.e. we want discriminatorA to classify it as a real B image, so the MSE loss is computed between discriminatorA(fake_B) and the "real" label. Similarly, we want discriminatorB to classify fake_A as a real A image.
      • Calculate cycle_loss: real_A goes through generatorA to produce fake_B, and fake_B goes back through generatorB to produce rec_A; cycle_loss is the L1 distance between rec_A and real_A. The B direction works the same way.
      • The final generator loss is the sum of the above three; since each exists in both the A and B directions there are 6 terms in total (in the code, the cycle and identity terms are scaled by their lambda weights): self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
    • backward_D_A(): calculate the loss of discriminator_A (real_B versus a fake_B drawn from the fake pool)

    • backward_D_B(): calculate the loss of discriminator_B (real_A versus a fake_A drawn from the fake pool)

  • Note: while the generators are being optimized, the discriminators are frozen (requires_grad=False), so no gradients flow into them: self.set_requires_grad([self.netD_A, self.netD_B], False)
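
Putting the pieces together, a minimal sketch of backward_G() and optimize_parameters() as methods of the model class (attribute names follow the repo; the lambda weights are the repo's defaults, hard-coded here for illustration):

class CycleGANModelSketch:
    # assumes netG_A/netG_B, netD_A/netD_B, criterionGAN/criterionCycle/criterionIdt,
    # optimizer_G/optimizer_D, forward(), backward_D_A/B() and set_requires_grad()
    # all exist as in the official repo

    def backward_G(self):
        lambda_idt, lambda_A, lambda_B = 0.5, 10.0, 10.0  # repo defaults (assumed)

        # identity losses: generatorA(real_B) should stay close to real_B, and vice versa
        self.idt_A = self.netG_A(self.real_B)
        self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
        self.idt_B = self.netG_B(self.real_A)
        self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt

        # adversarial losses: each fake should be classified as real
        self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)
        self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)

        # cycle-consistency losses: rec_A ~ real_A and rec_B ~ real_B
        self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
        self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B

        # six terms in total
        self.loss_G = (self.loss_G_A + self.loss_G_B
                       + self.loss_cycle_A + self.loss_cycle_B
                       + self.loss_idt_A + self.loss_idt_B)
        self.loss_G.backward()

    def optimize_parameters(self):
        self.forward()  # produce fake_B, rec_A, fake_A, rec_B
        # 1) generators: freeze both discriminators, then step optimizer_G
        self.set_requires_grad([self.netD_A, self.netD_B], False)
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # 2) discriminators: unfreeze and train on real vs. pooled fakes
        self.set_requires_grad([self.netD_A, self.netD_B], True)
        self.optimizer_D.zero_grad()
        self.backward_D_A()
        self.backward_D_B()
        self.optimizer_D.step()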

That's all for this one; questions and discussion are welcome. Bye~

Origin: blog.csdn.net/qq_43368987/article/details/132035216