pytorch 7月23日学习---cgan代码学习2

一. cgan

cGAN网络拓扑结构

二. 训练过程

1. Train Generator

(1)定义valid 和fake,定义real_imgs和labels

valid = Variable(FloatTensor(batch_size, 1).fill_(1.0), requires_grad=False)

fake = Variable(FloatTensor(batch_size, 1).fill_(0.0), requires_grad=False)

(2)

optimizer_G.zero_grad( )

(3)随机生成 z 与 gen_labels(0~9) 

z = Variable(FloatTensor(np.random.normal(0, 1, (batch_size, opt.latent_dim))))

gen_labels =Variable(LongTensor(np.random.randint(0, opt.n_classes, batch_size) ))

(4)生成器生成图片

gen_imgs = generator(z, gen_labels)

(5)计算生成器loss

validity = discriminator(gen_imgs, gen_labels)

g_loss = adversarial_loss(validity, valid)

(6)更新G

  g_loss.backward()

  optimizer_G.step()

2. Train Discriminator

(1)

optimizer_D.zero_grad( )

(2)计算假图片的loss

validity_real = discriminator(real_imgs, labels)

d_real_loss = adversarial_loss(validity_real, valid)

(3)计算真图片的loss

validity_fake = discriminator(gen_imgs.detach(), gen_labels)

d_fake_loss = adversarial_loss(validity_fake, fake)

(4)计算总loss

d_loss = (d_real_loss + d_fake_loss) / 2

(5)更新D 

 d_loss.backward()

 optimizer_D.step()

3. 一些函数

1. numpy.random.randint(low, high, size)

low、high、size三个参数。默认high是None,如果只有low,那范围就是[0,low)。如果有high,范围就是[low,high)

>>> a=np.random.randint(0, 10, 64)

>>> a

array([3, 0, 2, 2, 9, 8, 7, 8, 0, 0, 9, 5, 2, 0, 4, 6, 4, 9, 8, 7, 0, 9,

       2, 6, 1, 3, 5, 3, 8, 5, 3, 9, 6, 6, 3, 7, 9, 6, 8, 4, 5, 2, 0, 0,

       4, 0, 1, 8, 1, 7, 0, 4, 3, 8, 5, 4, 4, 6, 8, 2, 2, 9, 3, 8])

2. numpy.random.normal(loc=0.0, scale=1.0, size=None)

loc:float

    此概率分布的均值(对应着整个分布的中心centre)

scale:float

    此概率分布的标准差(对应于分布的宽度,scale越大越矮胖,scale越小,越瘦高)

size:int or tuple of ints

输出的shape,默认为None,只输出一个值

三. 保存图片

def sample_image(n_row, batches_done):

    """Saves a grid of generated digits ranging from 0 to n_classes"""

    # Sample noise

    z = Variable(FloatTensor(np.random.normal(0, 1, (n_row**2, opt.latent_dim))))

    # Get labels ranging from 0 to n_classes for n rows

    labels = np.array([num for _ in range(n_row) for num in range(n_row)])

    labels = Variable(LongTensor(labels))

    gen_imgs = generator(z, labels)

    save_image(gen_imgs.data, 'images/%d.png' % batches_done, nrow=n_row, normalize=True)

四. 结果展示

源代码网址:https://github.com/eriklindernoren/PyTorch-GAN/tree/master/implementations/cgan

猜你喜欢

转载自blog.csdn.net/weixin_42445501/article/details/81172965