Deep Learning (33) - CycleGAN (2)
The complete project is available here; feel free to take a look.
Data Format:
1. Generator
self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
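The generator backbone is ResNet-based (selected via opt.netG; other architectures, such as the U-Net variants, can be used instead). With the default resnet_9blocks it consists of:
- A downsampling part: a 7×7 convolution followed by two stride-2 convolutions
- A total of 9 ResNet blocks in the middle
- An upsampling part: two stride-2 transposed convolutions, then a 7×7 convolution and Tanh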
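To see what each part of the generator does to the feature map, the sketch below traces channels and spatial size through the layers in plain Python. It assumes the defaults (netG='resnet_9blocks', ngf=64, 3-channel 256×256 inputs) and the layer layout listed above; the helper names (conv_out, conv_transpose_out, trace_resnet_generator) are hypothetical, and only the standard Conv2d / ConvTranspose2d output-size formulas are used.

```python
def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Spatial output size of a Conv2d layer (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_transpose_out(size: int, kernel: int, stride: int, padding: int, output_padding: int) -> int:
    """Spatial output size of a ConvTranspose2d layer (dilation 1)."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding


def trace_resnet_generator(size=256, input_nc=3, output_nc=3, ngf=64, n_blocks=9, n_down=2):
    """Return (stage, channels, height/width) for each stage of the generator."""
    stages = [("input", input_nc, size)]
    # Stem: reflection padding of 3 + 7x7 conv with stride 1 -> size unchanged
    size = conv_out(size + 2 * 3, 7, 1, 0)
    ch = ngf
    stages.append(("stem 7x7 conv", ch, size))
    # Downsampling: 3x3 convs with stride 2 halve the size and double the channels
    for i in range(n_down):
        size = conv_out(size, 3, 2, 1)
        ch *= 2
        stages.append((f"down {i + 1}", ch, size))
    # ResNet blocks: two padded 3x3 convs plus a skip connection, shape unchanged
    for i in range(n_blocks):
        stages.append((f"resblock {i + 1}", ch, size))
    # Upsampling: stride-2 transposed convs double the size and halve the channels
    for i in range(n_down):
        size = conv_transpose_out(size, 3, 2, 1, 1)
        ch //= 2
        stages.append((f"up {i + 1}", ch, size))
    # Head: reflection padding of 3 + 7x7 conv to output_nc channels, then Tanh
    size = conv_out(size + 2 * 3, 7, 1, 0)
    stages.append(("head 7x7 conv + tanh", output_nc, size))
    return stages


for stage in trace_resnet_generator():
    print(stage)
```

The 9 ResNet blocks all run at 256 channels and 64×64 resolution, which is where most of the computation and parameters sit.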
2. Discriminator
self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
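As mentioned in the previous section, the discriminator is a binary classifier that distinguishes real images from fake ones. Its input is again a three-channel image. With the default netD='basic' (a 70×70 PatchGAN), the output is not a single score but a grid of real/fake scores, one per overlapping image patch (30×30 for a 256×256 input).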
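Where the "70×70" and "30×30" come from can be checked with a little arithmetic. This is a plain-Python sketch assuming the basic discriminator uses 4×4 convolutions with padding 1: n_layers_D (default 3) stride-2 convs, then one stride-1 conv and a final stride-1 conv to one channel. The function names are hypothetical.

```python
def patchgan_layers(n_layers: int = 3):
    """(kernel, stride) of each conv in an NLayerDiscriminator-style PatchGAN."""
    # First conv plus (n_layers - 1) more, all 4x4 with stride 2
    layers = [(4, 2)] * n_layers
    # One 4x4 stride-1 conv, then the final 4x4 stride-1 conv down to 1 channel
    layers += [(4, 1), (4, 1)]
    return layers


def output_size(size: int, n_layers: int = 3, padding: int = 1) -> int:
    """Height/width of the real/fake score map for a square input."""
    for kernel, stride in patchgan_layers(n_layers):
        size = (size + 2 * padding - kernel) // stride + 1
    return size


def receptive_field(n_layers: int = 3) -> int:
    """Side length of the input patch seen by one output score."""
    rf = 1
    # Walk the layers backwards: each layer widens the field by (k - 1) * jump
    for kernel, stride in reversed(patchgan_layers(n_layers)):
        rf = (rf - 1) * stride + kernel
    return rf


print(output_size(256), receptive_field())  # 30 70
```

Because every score only sees a 70×70 patch, the discriminator focuses on local texture and style, which is what image-to-image translation mostly needs.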
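3. Fake image pool
A buffer that stores previously generated fake images, so that the discriminators are updated with a mix of the latest and older fakes rather than only the newest batch.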
self.fake_A_pool = ImagePool(opt.pool_size)
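The pool's behavior can be sketched in plain Python on string placeholders instead of tensors. The class name SimpleImagePool is hypothetical; the logic follows the idea described above (the CycleGAN paper keeps a buffer of 50 images, which I believe is also the default opt.pool_size): until the pool is full, new fakes are stored and passed through; afterwards each new fake is, with probability 0.5, swapped for a randomly chosen stored one.

```python
import random


class SimpleImagePool:
    """History buffer of generated images (plain-Python sketch of the ImagePool idea)."""

    def __init__(self, pool_size: int):
        self.pool_size = pool_size
        self.images = []

    def query(self, images):
        """Return a batch for the discriminator, mixing new and stored fakes."""
        if self.pool_size == 0:  # pool disabled: pass the batch through unchanged
            return list(images)
        out = []
        for image in images:
            if len(self.images) < self.pool_size:
                # Pool not full yet: store the new image and also return it
                self.images.append(image)
                out.append(image)
            elif random.random() > 0.5:
                # 50%: return a random stored image and put the new one in its slot
                idx = random.randrange(self.pool_size)
                out.append(self.images[idx])
                self.images[idx] = image
            else:
                # 50%: return the new image and leave the pool unchanged
                out.append(image)
        return out


random.seed(0)
pool = SimpleImagePool(pool_size=2)
print(pool.query(["f1", "f2"]))  # pool fills up: ['f1', 'f2']
print(pool.query(["f3", "f4"]))  # each slot holds either the new fake or an older one
```

Showing the discriminator older fakes as well keeps it from overfitting to the generator's most recent outputs, which helps reduce oscillation during training.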
4. Definition of loss
self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)  # define GAN loss (MSELoss when gan_mode is 'lsgan')
self.criterionCycle = torch.nn.L1Loss()
self.criterionIdt = torch.nn.L1Loss()
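- GANLoss chooses its criterion according to gan_mode; with the default gan_mode='lsgan' it is MSELoss. The cycle loss and the identity loss are both L1 losses.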
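How a mode-dependent GAN loss works can be sketched on flat lists of discriminator outputs. As far as I know, the repo's GANLoss uses MSELoss for 'lsgan', BCEWithLogitsLoss for 'vanilla', and a plain mean for 'wgangp'; the function below is a hypothetical plain-Python stand-in for that dispatch, with the "real"/"fake" label broadcast to every prediction, plus an L1 loss with mean reduction.

```python
import math


def mean(values):
    return sum(values) / len(values)


def gan_loss(predictions, target_is_real: bool, gan_mode: str = "lsgan") -> float:
    """GAN loss on a flat list of discriminator outputs (one per patch)."""
    target = 1.0 if target_is_real else 0.0  # label broadcast to every prediction
    if gan_mode == "lsgan":
        # MSE: mean squared distance to the target label
        return mean([(p - target) ** 2 for p in predictions])
    if gan_mode == "vanilla":
        # Binary cross-entropy on logits, in its numerically stable form
        return mean([max(p, 0.0) - p * target + math.log1p(math.exp(-abs(p))) for p in predictions])
    if gan_mode == "wgangp":
        # No labels: raise the critic score for real images, lower it for fakes
        return -mean(predictions) if target_is_real else mean(predictions)
    raise NotImplementedError(f"gan mode {gan_mode} not implemented")


def l1_loss(a, b) -> float:
    """Mean absolute error between two flat lists (like L1Loss with 'mean' reduction)."""
    return mean([abs(x - y) for x, y in zip(a, b)])


print(gan_loss([0.5, 0.5], True))                 # 0.25
print(gan_loss([0.0], True, gan_mode="vanilla"))  # log(2), about 0.693
print(l1_loss([1.0, 2.0], [2.0, 0.0]))            # 1.5
```

LSGAN's squared error keeps penalizing fakes that are already on the "real" side of the boundary, which gives smoother gradients than the log loss of the vanilla GAN.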
5. Number of model parameters
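The parameter counts can be derived by hand from the layer layouts above. This sketch assumes the defaults (resnet_9blocks with ngf=64, basic discriminator with ndf=64 and n_layers_D=3, 3-channel images, norm='instance') and assumes that instance norm has no learnable affine parameters, so every convolution keeps its bias. The helper names are hypothetical; on the real model the count would be sum(p.numel() for p in net.parameters()), and it should agree only if these assumptions hold.

```python
def conv_params(in_ch: int, out_ch: int, k: int, bias: bool = True) -> int:
    """Weights + biases of a k x k Conv2d (same count for ConvTranspose2d)."""
    return in_ch * out_ch * k * k + (out_ch if bias else 0)


def resnet_generator_params(input_nc=3, output_nc=3, ngf=64, n_blocks=9, n_down=2) -> int:
    total = conv_params(input_nc, ngf, 7)            # 7x7 stem
    ch = ngf
    for _ in range(n_down):                          # stride-2 downsampling convs
        total += conv_params(ch, ch * 2, 3)
        ch *= 2
    total += n_blocks * 2 * conv_params(ch, ch, 3)   # two 3x3 convs per ResNet block
    for _ in range(n_down):                          # stride-2 transposed convs
        total += conv_params(ch, ch // 2, 3)
        ch //= 2
    total += conv_params(ch, output_nc, 7)           # 7x7 head
    return total


def nlayer_discriminator_params(input_nc=3, ndf=64, n_layers=3) -> int:
    total = conv_params(input_nc, ndf, 4)            # first 4x4 conv
    nf_mult = 1
    for n in range(1, n_layers):                     # stride-2 convs, channels double (capped at 8x)
        nf_prev, nf_mult = nf_mult, min(2 ** n, 8)
        total += conv_params(ndf * nf_prev, ndf * nf_mult, 4)
    nf_prev, nf_mult = nf_mult, min(2 ** n_layers, 8)
    total += conv_params(ndf * nf_prev, ndf * nf_mult, 4)  # stride-1 conv
    total += conv_params(ndf * nf_mult, 1, 4)               # 1-channel score map
    return total


g = resnet_generator_params()
d = nlayer_discriminator_params()
print(f"G: {g / 1e6:.3f} M, D: {d / 1e6:.3f} M")  # G: 11.378 M, D: 2.765 M
```

CycleGAN trains two generators and two discriminators, so under these assumptions the whole model has 2 × (G + D) parameters, with about 93% of each generator's parameters inside the 9 ResNet blocks.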
6. Debug record
- set_input(input): unpack the input batch into real_A and real_B
- optimize_parameters(): run forward(), then compute the losses and backpropagate, updating the generators first and the discriminators second
- forward(): generate fake_A, fake_B, rec_A, rec_B
  - generatorA first generates fake_B from real_A
  - generatorB maps fake_B back to rec_A
  - generatorB generates fake_A from real_B
  - generatorA maps fake_A back to rec_B
- backward_G(): compute the generator losses and backpropagate
  - Identity loss: generatorA maps A to B, so feeding it an image that is already in domain B (real_B) should return that image essentially unchanged. The output is named idt_A, and an L1 identity loss is computed between idt_A and real_B. Likewise, generatorB is fed real_A to produce idt_B, which is compared with real_A. This term is only used when lambda_identity > 0, and it is scaled by lambda_identity times the cycle weight of the corresponding domain.
  - Adversarial (generator) loss: we want fake_B, produced by generatorA, to fool discriminatorA, which judges domain-B images. So discriminatorA's prediction on fake_B is compared with the "real" label (True) using MSELoss. Similarly, we want discriminatorB to judge fake_A to be a real A image.
  - Cycle loss: real_A is translated by generatorA into fake_B, and fake_B is mapped back by generatorB to rec_A; the L1 loss between rec_A and real_A (weighted by lambda_A) is the cycle loss for A. The same is done for B (rec_B vs. real_B, weighted by lambda_B).
  - The final generator loss is the sum of these three kinds of terms, each computed for both directions (A and B), so there are six terms in total:
    self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
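- backward_D_A(): compute the loss of discriminator_A, which should score real_B as real and fake_B (drawn from fake_B_pool) as fake; the two MSE terms are averaged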
- backward_D_B(): compute the loss of discriminator_B in the same way, with real_A as real and fake_A (drawn from fake_A_pool) as fake
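- Note
  - While the generators are being optimized, the discriminators are frozen (requires_grad=False), so backpropagating the generator loss computes no gradients for their parameters. They are set back to requires_grad=True before the discriminators themselves are updated.
    self.set_requires_grad([self.netD_A, self.netD_B], False)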
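The loss bookkeeping of forward(), backward_G(), backward_D_A() and backward_D_B() can be traced with a framework-free toy in which "images" are lists of floats. G_A, G_B, D_A and D_B here are hypothetical stand-ins (a perfect doubling/halving pair and discriminators that always output 0.5), and the weights lambda_A = lambda_B = 10 and lambda_identity = 0.5 are assumed to match the repo's defaults. There are no gradients, so pool sampling, detach() and the requires_grad freezing have no analogue here; only which images feed which loss is shown.

```python
def mean(values):
    return sum(values) / len(values)


def mse_to_label(preds, is_real):
    """LSGAN-style loss: MSE between scores and a broadcast 1.0/0.0 label."""
    target = 1.0 if is_real else 0.0
    return mean([(p - target) ** 2 for p in preds])


def l1(a, b):
    return mean([abs(x - y) for x, y in zip(a, b)])


# Hypothetical toy networks (not the real models)
def G_A(x):  # A -> B
    return [2.0 * v for v in x]


def G_B(y):  # B -> A
    return [v / 2.0 for v in y]


def D_A(img):  # judges domain-B images, one score per "patch"
    return [0.5 for _ in img]


def D_B(img):  # judges domain-A images
    return [0.5 for _ in img]


def cycle_gan_losses(real_A, real_B, lambda_A=10.0, lambda_B=10.0, lambda_idt=0.5):
    # forward()
    fake_B = G_A(real_A)
    rec_A = G_B(fake_B)
    fake_A = G_B(real_B)
    rec_B = G_A(fake_A)
    # backward_G(): identity terms, skipped when lambda_idt == 0
    loss_idt_A = l1(G_A(real_B), real_B) * lambda_B * lambda_idt if lambda_idt > 0 else 0.0
    loss_idt_B = l1(G_B(real_A), real_A) * lambda_A * lambda_idt if lambda_idt > 0 else 0.0
    # adversarial terms: each generator wants its fakes labelled "real"
    loss_G_A = mse_to_label(D_A(fake_B), True)
    loss_G_B = mse_to_label(D_B(fake_A), True)
    # cycle-consistency terms
    loss_cycle_A = l1(rec_A, real_A) * lambda_A
    loss_cycle_B = l1(rec_B, real_B) * lambda_B
    loss_G = loss_G_A + loss_G_B + loss_cycle_A + loss_cycle_B + loss_idt_A + loss_idt_B
    # backward_D_A() / backward_D_B(): real vs. fake terms averaged
    loss_D_A = 0.5 * (mse_to_label(D_A(real_B), True) + mse_to_label(D_A(fake_B), False))
    loss_D_B = 0.5 * (mse_to_label(D_B(real_A), True) + mse_to_label(D_B(fake_A), False))
    return {
        "idt_A": loss_idt_A, "idt_B": loss_idt_B,
        "G_A": loss_G_A, "G_B": loss_G_B,
        "cycle_A": loss_cycle_A, "cycle_B": loss_cycle_B,
        "G": loss_G, "D_A": loss_D_A, "D_B": loss_D_B,
    }


losses = cycle_gan_losses([1.0, 2.0], [2.0, 4.0])
print(losses["G"])  # 19.25
```

In this toy the cycle losses are zero because G_B exactly inverts G_A, yet the identity losses are large because G_A does not leave domain-B images unchanged. That is precisely the behavior the identity term penalizes.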
That's it for this part. Feel free to ask questions and discuss. Bye~