GAN_pytorch

Dataset: MNIST

1. Adversarial loss: torch.nn.BCELoss()

# Loss
adversarial_loss = nn.BCELoss()  # binary cross-entropy for the real/fake classification
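As a quick sanity check (toy numbers, not from the original), BCELoss expects probabilities in (0, 1) and targets of the same shape, and averages -[y*log(p) + (1-y)*log(1-p)] over the batch:

import torch
import torch.nn as nn

criterion = nn.BCELoss()
pred = torch.tensor([0.9, 0.2])    # e.g. discriminator outputs (after sigmoid)
target = torch.tensor([1.0, 0.0])  # 1 = real, 0 = fake
loss = criterion(pred, target)     # mean of -[y*log(p) + (1-y)*log(1-p)]
print(loss.item())                 # ~= 0.1643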

Why use cross-entropy as the objective?

Maximizing the likelihood of each real sample x_i <--> minimizing the KL divergence between the data and model distributions <--> using cross-entropy as the objective function.
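In symbols (a standard argument, sketched here with p_data the real-data distribution and p_θ the model distribution):

\max_\theta \frac{1}{n}\sum_{i=1}^{n} \log p_\theta(x_i)
\;\approx\; \max_\theta \mathbb{E}_{x \sim p_{\mathrm{data}}}[\log p_\theta(x)]
\;=\; \min_\theta H(p_{\mathrm{data}}, p_\theta)
\;=\; \min_\theta \big( H(p_{\mathrm{data}}) + D_{\mathrm{KL}}(p_{\mathrm{data}} \,\|\, p_\theta) \big)

Since H(p_data) is a constant in θ, minimizing the cross-entropy H(p_data, p_θ), minimizing the KL divergence, and maximizing the likelihood all pick out the same θ.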

2. The adversarial steps during training

        # -------------------- train generator --------------------
        optimizer_G.zero_grad()

        # sample noise z ~ N(0, 1), shape (batch_size, latent_dim)
        # (Variable is deprecated since PyTorch 0.4; tensors track gradients directly)
        z = Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim)))

        # generate fake images
        gen_imgs = generator(z)

        # generator loss: how well the fakes fool the discriminator
        # (target "valid" = 1, so G is rewarded when D scores fakes as real)
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)

        g_loss.backward()
        optimizer_G.step()

        # -------------------- train discriminator --------------------
        optimizer_D.zero_grad()

        # discriminator loss: ability to separate real images from fakes
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        # detach() cuts the graph here so d_loss.backward() does not
        # propagate gradients into the generator
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)

        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()
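The excerpt above relies on setup it does not show: the Tensor alias, the valid/fake label tensors, the models, and the optimizers. Below is a minimal sketch of one plausible version; the toy models, the hyperparameters, and the name latent_dim (standing in for opt.latent_dim) are assumptions, not taken from the original.

import numpy as np
import torch
import torch.nn as nn
from torchvision import datasets, transforms

cuda = torch.cuda.is_available()
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor  # assumed alias

latent_dim = 100                      # plays the role of opt.latent_dim
img_shape = (1, 28, 28)

generator = nn.Sequential(            # toy stand-in for a real generator
    nn.Linear(latent_dim, int(np.prod(img_shape))), nn.Tanh(),
    nn.Unflatten(1, img_shape))
discriminator = nn.Sequential(        # toy stand-in; sigmoid output for BCELoss
    nn.Flatten(), nn.Linear(int(np.prod(img_shape)), 1), nn.Sigmoid())
if cuda:
    generator, discriminator = generator.cuda(), discriminator.cuda()

adversarial_loss = nn.BCELoss()
# lr/betas are common GAN choices, assumed here
optimizer_G = torch.optim.Adam(generator.parameters(), lr=2e-4, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))

dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(".", train=True, download=True,
                   transform=transforms.Compose(
                       [transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])),
    batch_size=64, shuffle=True)

for imgs, _ in dataloader:
    real_imgs = imgs.type(Tensor)
    # ground-truth labels: 1 = real, 0 = fake
    valid = Tensor(imgs.shape[0], 1).fill_(1.0)
    fake = Tensor(imgs.shape[0], 1).fill_(0.0)
    # ... generator / discriminator steps as in section 2 ...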

One point worth noting:

when training the generator, the loss is also computed through D, but the goal is to push D's scores on the generated images toward 1;

when training the discriminator, D's loss pushes real_imgs toward 1 and the fakes toward 0.
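Written out with BCE (the standard formulation; D outputs a probability), the two steps above become:

g\_loss = \mathrm{BCE}(D(G(z)), 1) = -\log D(G(z)) \quad \text{(the non-saturating generator loss)}

d\_loss = \tfrac{1}{2}\big[\mathrm{BCE}(D(x), 1) + \mathrm{BCE}(D(G(z)), 0)\big] = -\tfrac{1}{2}\big[\log D(x) + \log(1 - D(G(z)))\big]

So both players optimize through the same discriminator, just with opposite targets for the fake images.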

Reposted from www.cnblogs.com/Wiikk/p/12815520.html