Inspiration: I was recently studying GAN code and read many well-written blogs that helped me finally understand both the principle of GAN and how to implement it. So I am writing this down as a record; links to other good articles are at the end.
Source of inspiration: pix2pixGAN theory and code implementation
Table of contents
1. What is pix2pix GAN
2. Design of pix2pixGAN generator
3. Design of pix2pixGAN discriminator
4. Loss function
5. Code implementation
1. What is pix2pix GAN
It is actually a cGAN, a conditional GAN, but it changes the output of an ordinary GAN's discriminator. An ordinary discriminator outputs a single probability, whereas the discriminator of pix2pixGAN (also called PatchGAN) outputs a matrix, in which each element represents the probability that one patch is real. Background on patches can be found elsewhere; there is also a link at the end of the article.
The image x is used as the condition of this cGAN and is fed into both G and D. The input of G is x (the image to be converted) and its output is the generated image G(x). D must distinguish between the pairs (x, G(x)) and (x, y).
pix2pixGAN is mainly used for conversion between images, also known as image translation.
2. Design of pix2pixGAN generator
For image translation tasks, a lot of information is shared between input and output; for example, contour information. How do we preserve this shared information? We need to think about it from the design of the generator network.
If you use a plain convolutional neural network, every layer has to carry all of the information, which makes the network error-prone (it easily loses some of it).
So the UNet model, whose skip connections pass encoder features directly to the matching decoder layers, is used as the generator.
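To make the skip-connection idea concrete, here is a minimal sketch of a UNet-style generator. This is my own simplified toy version (class name `TinyUNetGenerator` and all layer sizes are hypothetical, not the exact architecture from the pix2pix paper): each encoder level's feature map is concatenated onto the matching decoder level, so shared low-level information such as contours can bypass the bottleneck instead of being squeezed through every layer.

```python
import torch
import torch.nn as nn

class TinyUNetGenerator(nn.Module):
    """A toy 2-level UNet: downsample twice, upsample twice, one skip connection."""

    def __init__(self, in_ch=3, out_ch=3, base=16):
        super().__init__()
        # Encoder: each block halves the spatial resolution
        self.down1 = nn.Sequential(nn.Conv2d(in_ch, base, 4, 2, 1), nn.LeakyReLU(0.2))
        self.down2 = nn.Sequential(nn.Conv2d(base, base * 2, 4, 2, 1), nn.LeakyReLU(0.2))
        # Decoder: up2's input has 2*base channels because of the skip concat
        self.up1 = nn.Sequential(nn.ConvTranspose2d(base * 2, base, 4, 2, 1), nn.ReLU())
        self.up2 = nn.ConvTranspose2d(base * 2, out_ch, 4, 2, 1)

    def forward(self, x):
        d1 = self.down1(x)               # (B, base,   H/2, W/2)
        d2 = self.down2(d1)              # (B, 2*base, H/4, W/4)
        u1 = self.up1(d2)                # (B, base,   H/2, W/2)
        u1 = torch.cat([u1, d1], dim=1)  # skip connection from the encoder
        return torch.tanh(self.up2(u1))  # (B, out_ch, H, W), values in (-1, 1)
```

The `torch.cat` line is the whole point: without it, contour information would have to survive the compressed bottleneck `d2`.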
3. Design of pix2pixGAN discriminator
D takes pairs of images as input, similar to cGAN. For a real pair (x, y), the discriminator should output 1;
for a generated pair (x, G(x)), the discriminator should output 0, while the generator hopes the discriminator will judge it as 1.
D in pix2pixGAN is implemented as patch_D in the paper. "Patch" means that no matter how large the generated image is, it is divided into multiple fixed-size patches, each of which D judges separately, as shown in the figure.
The advantage of this design is that D's effective input becomes smaller, the amount of computation drops, and training is faster.
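As a sketch of the idea (my own simplified toy version, not the paper's exact patch_D; the class name `TinyPatchDiscriminator` is hypothetical): the condition image and the candidate image are concatenated along the channel axis, and the network outputs a small score map rather than a single scalar, where each element scores one receptive-field patch as real or fake.

```python
import torch
import torch.nn as nn

class TinyPatchDiscriminator(nn.Module):
    """Judges (condition, candidate) pairs patch by patch via a fully-convolutional net."""

    def __init__(self, in_ch=6, base=16):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, base, 4, 2, 1), nn.LeakyReLU(0.2),
            nn.Conv2d(base, base * 2, 4, 2, 1), nn.LeakyReLU(0.2),
            # Final conv collapses to a 1-channel map: one probability per patch
            nn.Conv2d(base * 2, 1, 4, 1, 1),
        )

    def forward(self, condition, candidate):
        # The pair (x, y) or (x, G(x)) is judged jointly, as in cGAN
        pair = torch.cat([condition, candidate], dim=1)
        return torch.sigmoid(self.net(pair))  # probability map, not a scalar
```

Because the network is fully convolutional, the same small set of weights slides over the whole image, which is what keeps the computation low regardless of image size.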
4. Loss function
D network loss function: given a real image pair, D should judge it as 1; given the generated image together with the input image, D should judge it as 0.
G network loss function: given the generated image together with the input image, G hopes D will judge it as 1.
For image translation tasks, a lot of information is actually shared between the input and output of G. Therefore, to keep G's output close to the real target image, an L1 loss between them is also added; the formula is as follows:
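The formula image did not survive in this post; for reference, the objective from the pix2pix paper combines the cGAN loss with the L1 term:

```latex
\mathcal{L}_{cGAN}(G,D) = \mathbb{E}_{x,y}\big[\log D(x,y)\big]
                        + \mathbb{E}_{x}\big[\log\bigl(1 - D(x, G(x))\bigr)\big]

\mathcal{L}_{L1}(G) = \mathbb{E}_{x,y}\big[\lVert y - G(x)\rVert_1\big]

G^{*} = \arg\min_{G}\max_{D}\; \mathcal{L}_{cGAN}(G,D) + \lambda\,\mathcal{L}_{L1}(G)
```

Here λ weights the L1 term; it corresponds to the `LAMBDA` constant in the training code below.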
5. Code implementation
As for the code, there are the official implementation and others' implementations, but I found them a bit hard to follow. The code at the following link is what finally made it click for me.
All the code is here: pix2pixGAN theory and code implementation
I am recording this as a note, writing down my understanding of the parts of the code I think are key.
for step, (annos, imgs) in enumerate(dataloader):
    imgs = imgs.to(device)    # the real target images that G should produce
    annos = annos.to(device)  # the input (condition) images to be translated

    # Discriminator: loss computation and optimization
    d_optimizer.zero_grad()
    disc_real_output = dis(annos, imgs)  # feed a real image pair
    d_real_loss = loss_fn(disc_real_output,
                          torch.ones_like(disc_real_output, device=device))
    # Real pairs are labeled 1; we want D's output close to 1, since they are real
    d_real_loss.backward()  # compute gradients

    gen_output = gen(annos)  # generate an image from the input
    disc_gen_output = dis(annos, gen_output.detach())  # feed the input together with the generated image
    d_fake_loss = loss_fn(disc_gen_output,
                          torch.zeros_like(disc_gen_output, device=device))
    # D should judge the (input, generated) pair as 0, i.e. as fake
    d_fake_loss.backward()
    disc_loss = d_real_loss + d_fake_loss  # the discriminator loss is the sum of the two
    d_optimizer.step()  # update D's parameters

    # Generator: loss computation and optimization
    g_optimizer.zero_grad()
    disc_gen_out = dis(annos, gen_output)  # D judges how well the input and the generated image match
    gen_loss_crossentropyloss = loss_fn(disc_gen_out,
                                        torch.ones_like(disc_gen_out,
                                                        device=device))
    # The generator wants the opposite of D: it hopes the generated pair is judged as real
    gen_l1_loss = torch.mean(torch.abs(gen_output - imgs))  # L1 loss against the real target
    gen_loss = gen_loss_crossentropyloss + LAMBDA * gen_l1_loss
    gen_loss.backward()  # backpropagation
    g_optimizer.step()   # update G's parameters

    # Accumulate the loss of every batch
    with torch.no_grad():
        D_epoch_loss += disc_loss.item()
        G_epoch_loss += gen_loss.item()
The loss_fn used above is the BCE loss. Since our discriminator outputs probabilities between 0 and 1, this amounts to binary classification, so BCE can be used.
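As a tiny check (with hypothetical, made-up probability values) that BCE behaves the way described above: the loss is small when the discriminator's probability map matches its target labels, and large when it does not.

```python
import torch
import torch.nn as nn

loss_fn = nn.BCELoss()

# Pretend the discriminator emitted a 2x2 patch probability map, all near 1
probs = torch.tensor([[0.90, 0.80],
                      [0.95, 0.85]])

real_target = torch.ones_like(probs)   # real pairs are labeled 1
fake_target = torch.zeros_like(probs)  # generated pairs are labeled 0

loss_as_real = loss_fn(probs, real_target)  # small: probabilities already near 1
loss_as_fake = loss_fn(probs, fake_target)  # large: probabilities far from 0
```

Note that `nn.BCELoss` expects probabilities, which is why the discriminator applies a sigmoid to its output; if the discriminator emitted raw logits instead, `nn.BCEWithLogitsLoss` would be the numerically safer choice.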