Generative Adversarial Nets (GAN)


This is the paper that introduced the idea of adversarial learning; even more valuable is its theoretical proof, a part that is rarely read carefully but is in fact crucial.

Aims

GAN, translated as Generative Adversarial Network, aims to train a network to fit the distribution of the data. Earlier methods such as Gaussian kernels and Parzen windows can also be used to estimate a distribution (though I am not very familiar with them).
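For comparison, here is a minimal sketch of Parzen-window (Gaussian kernel) density estimation; the bandwidth h and the toy data are my own illustrative choices, not from the paper:

import numpy as np

def parzen_density(x, samples, h=0.2):
    # Gaussian Parzen-window estimate: p_hat(x) = (1/(n*h)) * sum_i K((x - x_i)/h)
    diffs = (x - samples[:, None]) / h                      # shape (n, len(x))
    kernels = np.exp(-0.5 * diffs ** 2) / np.sqrt(2 * np.pi)
    return kernels.mean(axis=0) / h

rng = np.random.default_rng(0)
# toy data: samples from a mixture of two Gaussians
samples = np.concatenate([rng.normal(-2, 0.5, 500), rng.normal(1, 0.3, 500)])
xs = np.linspace(-4, 3, 200)
p_hat = parzen_density(xs, samples)   # estimated density on a grid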

GAN has two networks: a generator network \(G(z)\) and a discriminator network \(D(x)\), where \(z\) follows a simple random distribution and \(x\) is the original data. That \(z\) follows a known random distribution is a very important point: writing \(\hat{x} = G(z)\), we have
\[
p(\hat{x}) = \int p(z)\, \mathbb{I}(G(z) = \hat{x}) \,\mathrm{d}z,
\]
where \(\mathbb{I}\) denotes the indicator function. In other words, the network \(G\) induces a distribution, and we want this distribution to fit the distribution of the original data \(x\) as well as possible.
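To make the point that \(G\) induces a distribution concrete, here is a toy sketch of my own (not from the paper): pushing a uniform \(z\) through a fixed nonlinear map and histogramming the outputs is a Monte-Carlo approximation of \(p(\hat{x})\):

import torch

def G(z):
    # a fixed, hand-picked nonlinear map standing in for a trained generator
    return torch.tanh(3 * z) + 0.1 * z

z = torch.rand(100000) * 2 - 1       # z ~ Uniform(-1, 1)
x_hat = G(z)                         # samples from the induced distribution
hist = torch.histc(x_hat, bins=50)   # Monte-Carlo estimate of p(x_hat), up to binning
print(hist / hist.sum())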

Framework

(figure: the GAN framework, omitted)
GAN trains the two networks above. The output of \(D\) is a scalar between 0 and 1, indicating whether the input \(x\) is real data (1 means real). The loss is the value function \(V(D, G)\):
\[
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log (1 - D(G(z)))].
\]
In practice, the two are updated alternately: fix \(G\) and update \(D\), then fix \(D\) and update \(G\), iterating until convergence (Algorithm 1 in the paper takes \(k\) gradient steps on \(D\), then one step on \(G\)).
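A minimal sketch of this alternating scheme (the full implementation is in the Code section below; the names gan_step, g_opt, d_opt are mine):

import torch
import torch.nn as nn

def gan_step(G, D, g_opt, d_opt, real, z_size, k=1):
    # one round of Algorithm 1: k discriminator steps, then one generator step
    bce = nn.BCELoss()
    ones = torch.ones(real.size(0), 1)
    zeros = torch.zeros(real.size(0), 1)

    for _ in range(k):                     # gradient steps on D with G fixed
        z = torch.randn(real.size(0), z_size)
        fake = G(z).detach()               # detach: do not backprop into G here
        d_loss = bce(D(real), ones) + bce(D(fake), zeros)
        d_opt.zero_grad(); d_loss.backward(); d_opt.step()

    z = torch.randn(real.size(0), z_size)  # one step on G with D fixed
    g_loss = bce(D(G(z)), ones)            # non-saturating variant (see Code section)
    g_opt.zero_grad(); g_loss.backward(); g_opt.step()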

Theory

As for why this works, the authors give a refined proof.

Proposition 1 of the paper: for \(G\) fixed, the optimal discriminator is
\[
D_G^*(x) = \frac{p_{data}(x)}{p_{data}(x) + p_g(x)},
\]
where \(p_g\) denotes the distribution induced by \(G\).
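The key step is a pointwise maximization: for each \(x\), the integrand of \(V\) has the form \(a \log D + b \log (1 - D)\) with \(a = p_{data}(x)\) and \(b = p_g(x)\); setting the derivative with respect to \(D\) to zero gives
\[
\frac{a}{D} - \frac{b}{1 - D} = 0 \quad \Longrightarrow \quad D^* = \frac{a}{a + b} = \frac{p_{data}(x)}{p_{data}(x) + p_g(x)},
\]
and since the second derivative is negative on \((0, 1)\), this stationary point is indeed the maximum.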
The only puzzling point in the proof is the change from \(p_z\) to \(p_g\). At first I thought it was a plain change of variables, but judging from other blogs it seems to require some measure theory (the Radon-Nikodym derivative) and, at the end, a bit of calculus of variations.

Plugging \(D_G^*\) back in gives the criterion
\[
C(G) = \max_D V(G, D) = \mathbb{E}_{x \sim p_{data}}\left[\log \frac{p_{data}(x)}{p_{data}(x) + p_g(x)}\right] + \mathbb{E}_{x \sim p_g}\left[\log \frac{p_g(x)}{p_{data}(x) + p_g(x)}\right],
\]
which can be rewritten as
\[
C(G) = -\log 4 + 2 \cdot \mathrm{JSD}(p_{data} \,\|\, p_g).
\]
The idea of the proof: when \(p_g = p_{data}\), \(C(G) = -\log 4\), so it suffices to show that this is the minimum and that it is attained only when \(p_g = p_{data}\). To prove this, the authors piece together the JSD term above, which meets the requirement exactly, since \(\mathrm{JSD} \ge 0\) with equality iff the two distributions coincide (the KL part can in fact also be handled directly with Gibbs' inequality).
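A quick numerical sanity check of \(C(G) = -\log 4 + 2\,\mathrm{JSD}\) on discrete distributions (my own toy check, not from the paper):

import numpy as np

def jsd(p, q):
    # Jensen-Shannon divergence between two discrete distributions
    m = 0.5 * (p + q)
    kl = lambda a, b: np.sum(a * np.log(a / b))
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

def C(p_data, p_g):
    # C(G) with the optimal discriminator plugged in
    d_star = p_data / (p_data + p_g)
    return np.sum(p_data * np.log(d_star)) + np.sum(p_g * np.log(1 - d_star))

p_data = np.array([0.2, 0.5, 0.3])
for p_g in [np.array([0.2, 0.5, 0.3]), np.array([0.6, 0.1, 0.3])]:
    assert np.isclose(C(p_data, p_g), -np.log(4) + 2 * jsd(p_data, p_g))
# the minimum -log 4 is attained exactly at p_g = p_data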

Numerical experiments

In the experiments (the simulation code borrows from others' implementations), the goal is: given a random \(z\), \(G\) should produce digits like those in the MNIST dataset.

Samples from the network without convolutional layers: (figure omitted)

With convolutional layers added to the discriminator, no matter how \(z\) changes, the generated results are all the same, which felt a little strange. But in fact, if everything \(G\) generates is scored as 1 (real), then it really does fool \(D\), so is this even a problem? It seems contradictory... (This is the phenomenon usually called mode collapse: \(G\) can fool \(D\) while covering only a tiny part of the data distribution.) (figure omitted)

Code

One thing to note in the code: BCELoss is used, but when updating the \(G\) network we pass in real_label rather than fake_label, because \(G\) needs to fool \(D\). (This is the non-saturating trick from the paper: instead of minimizing \(\log(1 - D(G(z)))\), which saturates early in training, \(G\) maximizes \(\log D(G(z))\).)
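Concretely, the two generator-loss variants look like this (a sketch; out stands in for D's scores on generated images):

import torch
import torch.nn as nn

bce = nn.BCELoss()
out = torch.sigmoid(torch.randn(8, 1))   # stand-in for D(G(z)), scores in (0, 1)
real_label = torch.ones(8, 1)

# saturating loss from the minimax game: minimize log(1 - D(G(z)))
saturating = -torch.log(1 - out).mean()

# non-saturating trick used in the code below: maximize log D(G(z)),
# which is exactly BCE against real_label
non_saturating = bce(out, real_label)    # equals -log(D(G(z))).mean()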


import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt


class Generator(nn.Module):
    def __init__(self, input_size):
        super(Generator, self).__init__()
        self.dense = nn.Sequential(
            nn.Linear(input_size, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 784)
        )

    def forward(self, x):
        out = self.dense(x)
        return out

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.dense = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = x.view(x.size(0), -1)
        out = self.dense(x)
        return out



class Train:
    def __init__(self, trainset, batch_size, z_size=100, criterion=nn.BCELoss(), lr=1e-3):
        self.generator = Generator(z_size)
        self.discriminator = Discriminator()
        self.opt1 = torch.optim.SGD(self.generator.parameters(), lr=lr, momentum=0.9)
        self.opt2 = torch.optim.SGD(self.discriminator.parameters(), lr=lr, momentum=0.9)
        self.trainset = trainset
        self.batch_size = batch_size
        self.real_label = torch.ones(batch_size, 1)   # shape (batch, 1) to match D's output
        self.fake_label = torch.zeros(batch_size, 1)
        self.criterion = criterion
        self.z_size = z_size


    def train(self, epoch_size, path):
        running_loss1 = 0.0
        running_loss2 = 0.0
        for epoch in range(epoch_size):
            for i, data in enumerate(self.trainset, 0):
                try:
                    real_img, _ = data

                    out1 = self.discriminator(real_img)
                    real_loss = self.criterion(out1, self.real_label)

                    z = torch.randn(self.batch_size, self.z_size)
                    fake_img = self.generator(z)
                    # detach so that updating D does not backprop through G
                    out2 = self.discriminator(fake_img.detach())
                    fake_loss = self.criterion(out2, self.fake_label)

                    loss = real_loss + fake_loss
                    self.opt2.zero_grad()
                    loss.backward()
                    self.opt2.step()

                    z = torch.randn(self.batch_size, self.z_size)
                    fake_img = self.generator(z)
                    out2 = self.discriminator(fake_img)
                    fake_loss = self.criterion(out2, self.real_label)  # real_label: G wants D to score fakes as real

                    self.opt1.zero_grad()
                    fake_loss.backward()
                    self.opt1.step()

                    running_loss1 += fake_loss.item()  # .item() avoids retaining the autograd graph
                    running_loss2 += real_loss.item()
                    if i % 10 == 9:
                        print("[epoch:{}    loss1: {:.7f}   loss2: {:.7f}]".format(
                            epoch,
                            running_loss1 / 10,
                            running_loss2 / 10
                        ))
                        running_loss1 = 0.0
                        running_loss2 = 0.0
                except ValueError as err:
                    print(err)  # the last batch may be smaller than batch_size
                    continue
        torch.save(self.generator.state_dict(), path)

    def loading(self, path):
        self.generator.load_state_dict(torch.load(path))
        self.generator.eval()
"""
加了点卷积
"""
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt


class Generator(nn.Module):
    def __init__(self, input_size):
        super(Generator, self).__init__()
        self.dense = nn.Sequential(
            nn.Linear(input_size, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 784)
        )

    def forward(self, x):
        out = self.dense(x)
        return out

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, 5, 3, 2),  # 1x28x28 --> 32x10x10
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # 32 x 10 x 10 --> 32x5x5
            nn.Conv2d(32, 64, 3, 1, 1),  # 32x5x5 --> 64x5x5
            nn.ReLU()
        )
        self.dense = nn.Sequential(
            nn.Linear(1600, 512),
            nn.ReLU(),
            nn.Linear(512, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = x.view(x.size(0), 1, 28, 28)
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        out = self.dense(x)
        return out



class Train:
    def __init__(self, trainset, batch_size, z_size=100, criterion=nn.BCELoss(), lr=1e-3):
        self.generator = Generator(z_size)
        self.discriminator = Discriminator()
        self.opt1 = torch.optim.SGD(self.generator.parameters(), lr=lr, momentum=0.9)
        self.opt2 = torch.optim.SGD(self.discriminator.parameters(), lr=lr, momentum=0.9)
        self.trainset = trainset
        self.batch_size = batch_size
        self.real_label = torch.ones(batch_size, 1)   # shape (batch, 1) to match D's output
        self.fake_label = torch.zeros(batch_size, 1)
        self.criterion = criterion
        self.z_size = z_size


    def train(self, epoch_size, path):
        running_loss1 = 0.0
        running_loss2 = 0.0
        for epoch in range(epoch_size):
            for i, data in enumerate(self.trainset, 0):
                try:
                    real_img, _ = data

                    out1 = self.discriminator(real_img)
                    real_loss = self.criterion(out1, self.real_label)

                    z = torch.randn(self.batch_size, self.z_size)
                    fake_img = self.generator(z)
                    # detach so that updating D does not backprop through G
                    out2 = self.discriminator(fake_img.detach())
                    fake_loss = self.criterion(out2, self.fake_label)

                    loss = real_loss + fake_loss
                    self.opt2.zero_grad()
                    loss.backward()
                    self.opt2.step()

                    z = torch.randn(self.batch_size, self.z_size)
                    fake_img = self.generator(z)
                    out2 = self.discriminator(fake_img)
                    fake_loss = self.criterion(out2, self.real_label)  # real_label: G wants D to score fakes as real

                    self.opt1.zero_grad()
                    fake_loss.backward()
                    self.opt1.step()

                    running_loss1 += fake_loss.item()  # .item() avoids retaining the autograd graph
                    running_loss2 += real_loss.item()
                    if i % 10 == 9:
                        print("[epoch:{}    loss1: {:.7f}   loss2: {:.7f}]".format(
                            epoch,
                            running_loss1 / 10,
                            running_loss2 / 10
                        ))
                        running_loss1 = 0.0
                        running_loss2 = 0.0
                except ValueError as err:
                    print(err)  # the last batch may be smaller than batch_size
                    continue
        torch.save(self.generator.state_dict(), path)

    def loading(self, path):
        self.generator.load_state_dict(torch.load(path))
        self.generator.eval()
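Neither script includes a driver, so here is a minimal usage sketch under my own assumptions (the Train class from either script above is in scope; the hyperparameters and the file name generator.pt are illustrative):

import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

batch_size = 64
mnist = torchvision.datasets.MNIST(root="./data", train=True, download=True,
                                   transform=transforms.ToTensor())
# drop_last=True sidesteps the ValueError branch for a short final batch
loader = torch.utils.data.DataLoader(mnist, batch_size=batch_size,
                                     shuffle=True, drop_last=True)

trainer = Train(loader, batch_size)
trainer.train(epoch_size=5, path="generator.pt")

# sample a few digits from the trained generator
trainer.loading("generator.pt")
with torch.no_grad():
    imgs = trainer.generator(torch.randn(16, 100)).view(-1, 28, 28)
plt.imshow(imgs[0], cmap="gray")
plt.show()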

Origin www.cnblogs.com/MTandHJ/p/11332262.html