A Tribute to GANs and My Favorite Framework, PyTorch

I started working in the lab in the summer of 2017 and taught myself deep learning and machine learning, though my understanding stayed shallow. In the summer of 2018 I took over the topic "line simplification" from a senior labmate, but my work was mainly dataset labeling and running experiments; even though I made small code changes at the end and contributed a few figures to the paper, my final credit was a fifth-author position. Well, sure enough, you have to be strong yourself to gain the initiative. Now, halfway through 2019, I have finally come to understand GANs and their PyTorch implementations at a deeper level, and reproduced the well-known 2018 CartoonGAN in PyTorch.

(1) Basic principles of GAN

GANs work because of the adversarial game between the generator G and the discriminator D. G continuously improves its generation ability (in image conversion tasks, its conversion ability), mapping a sample x from the source domain SRC to y' = G(x) and trying to fool D into mistaking y' for a genuine sample y from the target domain TAR.

At the macro level (over the whole data domain TAR or SRC), we hope that G maps the data distribution of x onto that of y', so that the distributions of y' and y become similar.

(2) Mathematical expression

The generator G is usually an Encoder-Decoder model; in computer vision its role is often to output an image from an input image (or from noise). The discriminator D is a binary classifier that distinguishes True (real) from False (fake).

In general we have: $x \overset{G}{\rightarrow} G(x) \overset{D}{\rightarrow} \text{True}/\text{False}$.

In classification tasks, the commonly used loss is cross-entropy; for binary classification, we customarily use the binary cross-entropy loss (BCE). It is defined as

$BCE(p, y) = -\left[\, y\log p + (1-y)\log(1-p) \,\right]$

where $p$ is the predicted probability and $y \in \{0, 1\}$ is the label.
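To make the definition concrete, here is a minimal sketch (tensor values are made up) checking the formula against PyTorch's built-in `nn.BCELoss`:

```python
import torch
import torch.nn as nn

# BCE(p, y) = -[ y*log(p) + (1-y)*log(1-p) ]
pred   = torch.tensor([0.9, 0.2, 0.7])  # predicted probabilities p (e.g. discriminator outputs)
target = torch.tensor([1.0, 0.0, 1.0])  # ground-truth labels y (1 = real, 0 = fake)

manual  = -(target * pred.log() + (1 - target) * (1 - pred).log()).mean()
builtin = nn.BCELoss()(pred, target)
assert torch.allclose(manual, builtin)
```

Note that `nn.BCELoss` averages over the batch by default (`reduction='mean'`).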

For G, we hope the synthesized image G(x) escapes detection by D, so the response D(G(x)) should be as large as possible (close to 1); namely:

$L_{G}=\log(1-D(G(x)))$

The more realistic G(x) is, the larger D(G(x)) becomes, so 1-D(G(x)) shrinks and its log decreases (toward negative infinity).

For D, we hope it can recognize which images really come from the TAR domain and which are counterfeit. Therefore we want the response D(y) to be large (close to 1) and D(G(x)) to be small (close to 0); that is:

$L_{D}=-\log(D(y))-\log(1-D(G(x)))$
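Both losses are just BCE with different target labels. Note that the training scripts below actually use the non-saturating generator loss $-\log(D(G(x)))$ (fake samples paired with label 1) rather than $\log(1-D(G(x)))$; the two push G in the same direction, but the former gives stronger gradients early in training. A minimal sketch with made-up discriminator outputs:

```python
import torch
import torch.nn as nn

bce = nn.BCELoss()
d_real = torch.tensor([0.8])  # D(y): made-up discriminator output on a real sample
d_fake = torch.tensor([0.3])  # D(G(x)): made-up output on a fake sample

# L_D = -log(D(y)) - log(1 - D(G(x)))
loss_d = bce(d_real, torch.ones(1)) + bce(d_fake, torch.zeros(1))

# Non-saturating generator loss used in the scripts below: -log(D(G(x)))
loss_g = bce(d_fake, torch.ones(1))
```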

(3) GAN update strategy combined with pytorch script

The content below draws on two good blog posts that I referenced (listed at the end); I think it is worth walking through it again myself to deepen the impression.

1. Strategy 1: Update the discriminator D first, then update the generator G.

"""Updating in a single iteration for GAN-training in pytorch"""
for epoch in range(EPOCH):
    for i in range(ITERATION):
        #### Define the two targets (1 and 0) used by the adversarial loss
        valid = torch.Tensor(imgs.size(0), 1).fill_(1.0).to(device) # real labels, all 1
        fake = torch.Tensor(imgs.size(0), 1).fill_(0.0).to(device)  # fake labels, all 0

        #### Train D
        optimizer_D.zero_grad() # zero the gradients of all discriminator parameters

        ## Train D with real data (Y)
        real_imgs = imgs.to(device)
        pred_real = discriminator(real_imgs)               # discriminator output on real data
        real_loss = adversarial_loss(pred_real, valid)     # discriminator loss on real samples

        ## Train with fake data (Y')
        z = torch.randn((imgs.shape[0], 100)).to(device)   # noise
        gen_imgs = generator(z)                            # generate fake data from noise
        pred_gen = discriminator(gen_imgs)                 # discriminator output on fake data
        fake_loss = adversarial_loss(pred_gen, fake)       # discriminator loss on fake samples

        d_loss = (real_loss + fake_loss) / 2               # average of the two losses

        # The next line is crucial and is discussed in detail in the text below
        d_loss.backward(retain_graph=True)                 # compute gradients; retain_graph explicitly keeps the graph alive, otherwise its memory would be freed
        optimizer_D.step()                                 # update discriminator parameters

        #### Train G
        g_loss = adversarial_loss(pred_gen, valid)         # generator loss
        optimizer_G.zero_grad()                            # zero generator parameter gradients
        g_loss.backward()                                  # backpropagate the generator loss
        optimizer_G.step()                                 # update generator parameters
    
        # end of the iteration
    # end of the epoch

Before the discussion, let us make two things clear:

  • A computation graph records a computation: squares represent "operations" and circles represent "variables" (a variable holds data and, for a layer, its weights). Forward propagation along the graph yields the result and every intermediate variable; backpropagation computes the gradient at each node. Crucially, by default a computation graph supports exactly one backward pass (BP): once backpropagation finishes, the graph's buffers are freed, so a second backward pass through that part of the graph within the same iteration will fail.
  • If multiple data streams flow through the same computation graph within one iteration, and their losses are summed before a single backward pass, the effect is equivalent to accumulating the gradients of each stream.
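The "one graph, one backward pass" rule above can be sketched in a few lines (a toy example, not GAN code):

```python
import torch

x = torch.ones(2, requires_grad=True)

y = (x * 3).sum()              # builds a computation graph
y.backward(retain_graph=True)  # graph is explicitly kept alive
y.backward()                   # second backward works; gradients accumulate
assert torch.equal(x.grad, torch.full((2,), 6.0))  # 3 + 3 accumulated

z = (x * 2).sum()
z.backward()                   # graph is freed after this backward
try:
    z.backward()               # second backward on a freed graph fails
except RuntimeError as e:
    print("second backward failed:", type(e).__name__)
```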

When training the discriminator: computing D(y) builds a graph covering the forward pass of D only, while computing D(G(x)) builds a graph covering the forward passes of both G and D. Since d_loss includes real_loss = criterion(D(y), True_label) and fake_loss = criterion(D(G(x)), False_label), backpropagating it computes gradients for both G and D. However, only optimizer_D performs a step afterwards. When a PyTorch optimizer is constructed, it is told exactly which parameters it is responsible for updating and it ignores all others; therefore optimizer_D.step() updates D only.
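The point about optimizer "responsibility" is easy to verify with two stand-in modules (toy linear layers, not a real GAN): gradients land on both, but `step()` moves only the parameters the optimizer was constructed with.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
g = nn.Linear(2, 2, bias=False)   # stand-in "generator"
d = nn.Linear(2, 1, bias=False)   # stand-in "discriminator"
opt_d = torch.optim.SGD(d.parameters(), lr=0.1)  # responsible for d only

loss = d(g(torch.ones(1, 2))).sum()
loss.backward()                   # gradients are computed for BOTH g and d

g_before, d_before = g.weight.clone(), d.weight.clone()
opt_d.step()                      # ...but only d's weights move

assert torch.equal(g.weight, g_before)      # generator untouched
assert not torch.equal(d.weight, d_before)  # discriminator updated
```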

So in this backward pass the gradients of both D and G were computed, but only D's parameters were updated; since G was not updated, we want to keep its part of the computation graph around for reuse.

When training the generator, we directly reuse the previously computed D(G(x)) instead of recomputing it; the corresponding graph spans the forward passes of both G and D. When g_loss is backpropagated, gradients must flow backwards through D and then through G. This is why d_loss.backward() had to declare retain_graph=True: the later generator step still needs G's part of the graph when computing its gradients, and this flag keeps the graph from being freed. Note that this only matters because our script feeds the same batch of fake data to both the D step and the G step; if the G step generated fresh fake data instead (a second forward pass through G), retain_graph would be unnecessary, because a new graph for G would be built anyway.

On the other hand, before backpropagating g_loss we must call optimizer_G.zero_grad() to clear the G gradients computed during the earlier d_loss backward pass (otherwise they would accumulate; and since we train D first and G's update depends on D, the G gradients computed before D's update were based on the old D and are of little value). After that, optimizer_G.step() updates the parameters of G only.

To sum up, in this strategy both D and G are backpropagated twice (their gradients are computed twice): the first pass updates D's parameters but computes G's gradients as a by-product; the second pass updates G's parameters but computes D's gradients as a by-product.

2. Strategy 2: Update the generator G first, then update the discriminator D.

"""Updating in a single iteration for GAN-training in pytorch"""
for epoch in range(EPOCH):
    for i in range(ITERATION):
        #### Define the two targets (1 and 0) used by the adversarial loss
        valid = torch.Tensor(imgs.size(0), 1).fill_(1.0).to(device) # real labels, all 1
        fake  = torch.Tensor(imgs.size(0), 1).fill_(0.0).to(device) # fake labels, all 0

        real_imgs = imgs.to(device)                                 # real data y

        #### Train G
        optimizer_G.zero_grad()                                     # zero generator parameter gradients
        z = torch.randn((imgs.shape[0], opt.latent_dim)).to(device) # random noise
        gen_imgs = generator(z)                                     # generate fake samples from noise
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)   # generator loss: fake samples + real labels
        g_loss.backward()                                           # the backward pass runs through D as well, so D's parameters also receive gradients
        optimizer_G.step()                                          # update G only; D has gradients, but D must not be updated in this step

        #### Train D
        optimizer_D.zero_grad()                                     # clear the D gradients computed as a by-product of the G backward pass
        real_loss = adversarial_loss(discriminator(real_imgs), valid) # real samples + real labels: D loss
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake) # fake samples + fake labels: D loss; note the use of ".detach()"
        d_loss = (real_loss + fake_loss) / 2                        # total discriminator loss
        d_loss.backward()                                           # backpropagate the discriminator loss
        optimizer_D.step()                                          # update discriminator parameters
        
        # end of the iteration
    # end of the epoch

Analysis:

When training G: computing gen_imgs = G(z) builds a graph covering G's forward pass, and computing g_loss = criterion(D(G(z)), True) extends it with D's forward pass. Backpropagating g_loss computes gradients for both D and G, but only optimizer_G updates G. After this backward pass, the graphs of D and G are freed.

When training D: computing D(y) builds a fresh graph containing only D, and computing D(gen_imgs.detach()) likewise builds a graph that starts at the detached fake images and contains only D. Without .detach(), backpropagating fake_loss would try to continue past D's input into G, but G's graph was already freed after g_loss.backward(), so there would be no way forward. We therefore explicitly tell autograd to stop the gradient at this point via gen_imgs.detach(): detaching "decouples" the data from the graph that produced it, so the gradient stops there and propagates no further.
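`.detach()` is easy to demonstrate with two toy layers standing in for G and D: the detached tensor shares the data but is cut off from the graph that produced it, so backpropagation stops at D's input.

```python
import torch
import torch.nn as nn

gen  = nn.Linear(3, 3)   # stand-in generator
disc = nn.Linear(3, 1)   # stand-in discriminator

fake = gen(torch.randn(1, 3))

loss = disc(fake.detach()).sum()  # the graph starts at the detached tensor
loss.backward()

assert disc.weight.grad is not None  # D received gradients
assert gen.weight.grad is None       # the gradient never reached G
```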

To sum up, in this strategy G is backpropagated once and D twice, and there is no need to deliberately keep G's computation graph in memory.

 

【Summary】

The advantage of Strategy 1 is that the noise is forwarded through G only once; the disadvantage is that both D and G must be backpropagated twice, and the combined computation graph (D+G) must be kept in memory.

The advantage of Strategy 2 is that G is updated first, so the (D+G) graph built during its forward pass can be safely destroyed right after the update, saving memory; updating D afterwards does require building a new graph, but that graph contains only D and is smaller than in Strategy 1. The cost is a second forward pass through D (and, correspondingly, a second backward pass through it).

The former spends an extra backward pass through G; the latter an extra forward pass through D. If D is the more complex network, Strategy 1 should be adopted; otherwise Strategy 2. In practice D is usually much simpler than G, so Strategy 2 is usually the better choice.

(This last remark comes from the original Zhihu post.) Still, Strategy 2 always feels a bit odd: it updates the generator first and the discriminator afterwards, yet the generator's update relies on the discriminator providing an accurate loss and gradient; otherwise, isn't it updating blindly?

 

【References】

1. Pytorch: detach and retain_graph, and the principle analysis of GAN

2. Pytorch: detach 和 retain_graph


Origin blog.csdn.net/WinerChopin/article/details/95986694