数据集 MNIST
1.对抗损失loss 使用torch.nn.BCEloss()
#Loss adversarial_loss = nn.BCELoss() #求一个二分类的交叉熵
为什么使用交叉熵函数?
最大每个真实抽样xi的likelihood <--> 最小化KL散度 <--> 交叉熵作为目标函数
2.训练时的对抗判别
#-------------------------- train generator---------------------------------------------------- optimizer_G.zero_grad() # sample noise z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim)))) #正态分布,均值0,标准差1,size(n,100) # generate fake imgs gen_imgs = generator(z) # loss measure -> discriminator's ability g_loss = adversarial_loss(discriminator(gen_imgs), valid) g_loss.backward() optimizer_G.step() # ---------------------------train discriminator----------------------------------------------- optimizer_D.zero_grad() # loss measure -> generator's ability real_loss = adversarial_loss(discriminator(real_imgs), valid) fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake) # 这里用detach,目的是将variable参数分隔,不参与梯度更新 d_loss = (real_loss + fake_loss) / 2 d_loss.backward() optimizer_D.step()
值得注意的一点是,
训练生成器时,也是用D计算loss,但目的是使生成的和1更靠近;
训练判别器时,用D计算loss,real_img和1靠近,fake和0靠近 。