pytorch实现GAN代码详解

  1. 设置超参数。
  2. (name, preprocess, d_input_func) = ("Data and variances", lambda data: decorate_with_diffs(data, 2.0), lambda x: x * 2),这句话设置了一个name和两个函数preprocess和d_input_func,前者用于预处理data,后者用于将data乘2。
  3. 定义函数get_distribution_sampler()通过输入均值和方差,返回正态分布的torch张量,维数是n维。
  4. 定义函数get_generator_input_sampler(),无输入,返回m×n维的torch张量。
  5. class Generator定义生成器,class Discriminator定义判别器。
  6. 定义函数extract(),用于将输入的v矩阵,转换为列表形式。
  7. 定义函数states(),返回d向量的平均值和标准差。
  8. 定义函数decorate_with_diffs(),第一行mean = torch.mean(data.data, 1, keepdim=True)用于返回按1维返回data.data的均值保存于mean;第二行mean_broadcast = torch.mul(torch.ones(data.size()), mean.tolist()[0][0]),torch.mul(input, value, out=None)函数用于标量值value乘以input中的所有值,因此mean_broadcast里保存了一组与data维度相同的均值矩阵;第三行diffs = torch.pow(data - Variable(mean_broadcast), exponent),torch.pow(input, exponent, out=None)函数用于计算input的exponent次幂,因此可以判断这个函数用于预处理数据。
  9. 下面进入主函数,首先调用get_distribution_sampler()函数,生成1×n维正态分布的torch张量d_sampler,作为真实数据。
  10. 调用get_generator_input_sampler()函数,生成m×n维的伪造数据,m作为mini_batch。
  11. 通过三个超参数g_input_size,g_hidden_size,g_output_size构造生成器的神经网络模型G。
  12. 通过三个超参数d_input_func(d_input_size),d_hidden_size,d_output_size构造判别器的神经网络模型D,其中d_input_func()函数定义为x*2。
  13. 定义损失函数标准为nn.BCELoss(),生成器和判别器的优化算法都为Adam算法。
  14. 进入epoch迭代,分为两部分,首先为判别器的迭代优化,真实数据赋予标签1,backward()反向传递一次,伪造数据赋予标签0,反向传递一次。
  15. 之后迭代优化生成器,通过随机产生的gi_sampler数据经过G网络产生伪造的数据g_fake_data,g_fake_data经过D网络判别真假,由于G网络生成的数据要用于以假乱真,因此赋予的标签应为1,也就是说判别器越判别成假,则误差越大。产生的误差反向传播用于优化生成器。

代码来源:

https://github.com/devnag/pytorch-generative-adversarial-networks

猜你喜欢

转载自blog.csdn.net/u012759006/article/details/81136437