Deep Learning Summary: GAN, principle, algorithm description, PyTorch implementation

The GAN architecture:

[Figure: schematic of the GAN setup (generator vs. discriminator), not reproduced here]

The training algorithm from the original GAN paper (Goodfellow et al., 2014):

[Figure: screenshot of the paper's algorithm description, not reproduced here]
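
Although the screenshot is not reproduced, the objective that the algorithm optimizes is the standard minimax value function from the paper: the discriminator D is trained to maximize it, the generator G to minimize it,

\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]

Each iteration updates D for one (or k) gradient steps on a minibatch of real samples x and fake samples G(z), then updates G for one gradient step. The PyTorch code below follows exactly this alternating scheme with k = 1.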

PyTorch implementation
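
The snippets below refer to a few constants that are never defined in the post. A minimal set of imports and hyperparameters that makes them runnable is sketched here; the concrete values (batch size, learning rates, number of points, etc.) are illustrative assumptions, not something stated in the original post.

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

# Illustrative hyperparameters (assumed values, not given in the post)
BATCH_SIZE = 64          # number of "paintings" per training step
LR_G = 0.0001            # learning rate for the generator
LR_D = 0.0001            # learning rate for the discriminator
N_IDEAS = 5              # dimensionality of G's random input (the "ideas")
ART_COMPONENTS = 15      # number of points that make up one painting
# x-coordinates shared by every painting, shape (BATCH_SIZE, ART_COMPONENTS)
PAINT_POINTS = np.vstack([np.linspace(-1, 1, ART_COMPONENTS) for _ in range(BATCH_SIZE)])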

Build the generator and discriminator:

G = nn.Sequential(                      # Generator
    nn.Linear(N_IDEAS, 128),            # random ideas (could be drawn from a normal distribution)
    nn.ReLU(),
    nn.Linear(128, ART_COMPONENTS),     # making a painting from these random ideas
)

D = nn.Sequential(                      # Discriminator
    nn.Linear(ART_COMPONENTS, 128),     # receive art work either from the famous artist or a newbie like G
    nn.ReLU(),
    nn.Linear(128, 1),
    nn.Sigmoid(),                       # output the probability that the artwork was made by the artist
)

Generate fake data:

    G_ideas = torch.randn(BATCH_SIZE, N_IDEAS)  # random ideas
    G_paintings = G(G_ideas)                    # fake painting from G (random ideas)

Generate real data:


def artist_works():     # painting from the famous artist (real target)
    a = np.random.uniform(1, 2, size=BATCH_SIZE)[:, np.newaxis]
    paintings = a * np.power(PAINT_POINTS, 2) + (a-1)
    paintings = torch.from_numpy(paintings).float()
    return paintings

artist_paintings = artist_works()           # real painting from artist
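
Each call returns a batch of "real paintings": a has shape (BATCH_SIZE, 1) and PAINT_POINTS has shape (BATCH_SIZE, ART_COMPONENTS), so every row of the result is a quadratic curve a * x^2 + (a - 1) with a drawn uniformly from [1, 2]. A quick shape check, using the illustrative hyperparameters assumed above:

print(artist_works().shape)    # torch.Size([64, 15]) with BATCH_SIZE = 64, ART_COMPONENTS = 15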

Define the loss for training D and the loss for training G; this is effectively the forward pass:

These two losses are what connect G and D into a single computational graph, forming a complete path from noise to probability; this is where PyTorch's dynamic-graph design shows through.

    prob_artist0 = D(artist_paintings)          # D try to increase this prob
    prob_artist1 = D(G_paintings)               # D try to reduce this prob
    
    D_loss = - torch.mean(torch.log(prob_artist0) + torch.log(1. - prob_artist1))
    G_loss = torch.mean(torch.log(1. - prob_artist1))
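
In batch form (m = BATCH_SIZE), these two lines compute

D\_loss = -\frac{1}{m} \sum_{i=1}^{m} \left[ \log D(x_i) + \log\big(1 - D(G(z_i))\big) \right], \qquad
G\_loss = \frac{1}{m} \sum_{i=1}^{m} \log\big(1 - D(G(z_i))\big)

so minimizing D_loss maximizes the value function V(D, G) with respect to D, and minimizing G_loss minimizes it with respect to G. (The original paper also notes that log(1 - D(G(z))) can saturate early in training and suggests maximizing log D(G(z)) instead, i.e. G_loss = -torch.mean(torch.log(prob_artist1)); the code here keeps the plain minimax form.)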

The optimization step, i.e., the backward pass:

To implement "fix G, train D" and "fix D, train G", we give each network its own optimizer, so that each update step only touches the intended set of parameters:

opt_D = torch.optim.Adam(D.parameters(), lr=LR_D)
opt_G = torch.optim.Adam(G.parameters(), lr=LR_G)

Fix G, train D:

opt_D.zero_grad(): clear the gradients accumulated on D first, so that earlier iterations do not leak into the current update.
D_loss.backward(retain_graph=True): backpropagate through the whole graph. What does retain_graph=True mean? By default PyTorch frees the computational graph after a single backward pass; once it is freed, you would have to rerun the forward pass from scratch. Here we still need to backpropagate through the D part of the same graph once more (for G_loss below), so the graph has to be kept.
opt_D.step(): update only D's parameters (the ones handed to opt_D) using the computed gradients.

    opt_D.zero_grad()
    D_loss.backward(retain_graph=True)      # reusing computational graph
    opt_D.step()

Fix D, train G:

G_loss.backward(): since G_loss = torch.mean(torch.log(1. - prob_artist1)) and prob_artist1 = D(G_paintings), the backward pass has to go through D again (D's parameters are not updated here, only G's). This backward runs on the same graph built earlier, which is exactly why retain_graph=True was needed in the D step above. (An alternative that avoids retain_graph altogether is sketched after the snippet below.)

    opt_G.zero_grad()
    G_loss.backward()
    opt_G.step()
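
For reference, a common alternative is to detach G's output when training D and to run a fresh forward pass through D when training G. This removes the need for retain_graph=True, and it also sidesteps a problem the original ordering can hit on recent PyTorch versions: opt_D.step() modifies D's weights in place after they were saved in the retained graph, which may trigger an in-place modification error during G_loss.backward(). The sketch below reuses the names from the snippets above and is an illustration, not the original post's code:

    prob_artist0 = D(artist_paintings)
    prob_artist1 = D(G_paintings.detach())     # detach: no gradient flows back into G here
    D_loss = - torch.mean(torch.log(prob_artist0) + torch.log(1. - prob_artist1))
    opt_D.zero_grad()
    D_loss.backward()                          # graph can be freed; G was never part of it
    opt_D.step()

    prob_artist1 = D(G_paintings)              # fresh forward pass through the (updated) D
    G_loss = torch.mean(torch.log(1. - prob_artist1))
    opt_G.zero_grad()
    G_loss.backward()
    opt_G.step()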

The complete training loop (assuming the imports and hyperparameters sketched at the top of the PyTorch section):

def artist_works():     # painting from the famous artist (real target)
    a = np.random.uniform(1, 2, size=BATCH_SIZE)[:, np.newaxis]
    paintings = a * np.power(PAINT_POINTS, 2) + (a-1)
    paintings = torch.from_numpy(paintings).float()
    return paintings

G = nn.Sequential(                      # Generator
    nn.Linear(N_IDEAS, 128),            # random ideas (could be drawn from a normal distribution)
    nn.ReLU(),
    nn.Linear(128, ART_COMPONENTS),     # making a painting from these random ideas
)

D = nn.Sequential(                      # Discriminator
    nn.Linear(ART_COMPONENTS, 128),     # receive art work either from the famous artist or a newbie like G
    nn.ReLU(),
    nn.Linear(128, 1),
    nn.Sigmoid(),                       # output the probability that the artwork was made by the artist
)

opt_D = torch.optim.Adam(D.parameters(), lr=LR_D)
opt_G = torch.optim.Adam(G.parameters(), lr=LR_G)

plt.ion()   # turn on interactive plotting (only matters if something is drawn inside the loop)

for step in range(10000):
    artist_paintings = artist_works()           # real painting from artist
    G_ideas = torch.randn(BATCH_SIZE, N_IDEAS)  # random ideas
    G_paintings = G(G_ideas)                    # fake painting from G (random ideas)

    prob_artist0 = D(artist_paintings)          # D try to increase this prob
    prob_artist1 = D(G_paintings)               # D try to reduce this prob

    D_loss = - torch.mean(torch.log(prob_artist0) + torch.log(1. - prob_artist1))
    G_loss = torch.mean(torch.log(1. - prob_artist1))

    opt_D.zero_grad()
    D_loss.backward(retain_graph=True)      # reusing computational graph
    opt_D.step()

    opt_G.zero_grad()
    G_loss.backward()
    opt_G.step()
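
plt.ion() only has an effect if something is actually drawn inside the loop, and the post does not include that plotting code. A rough sketch of what such a block could look like, placed at the end of the loop body (names reuse the snippets above; the update interval is an arbitrary choice):

    if step % 500 == 0:
        plt.cla()
        plt.plot(PAINT_POINTS[0], G_paintings.detach().numpy()[0], lw=3, label='generated painting')
        plt.plot(PAINT_POINTS[0], 2 * np.power(PAINT_POINTS[0], 2) + 1, lw=3, label='upper bound (a = 2)')
        plt.plot(PAINT_POINTS[0], 1 * np.power(PAINT_POINTS[0], 2) + 0, lw=3, label='lower bound (a = 1)')
        plt.legend(loc='upper right')
        plt.draw()
        plt.pause(0.01)

After enough steps the generated curve should settle between the two bounds, i.e. it starts to look like one of the artist's quadratics.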

Reposted from blog.csdn.net/weixin_40759186/article/details/87532637