Principe GAN et méthode simple pour générer des images

Apprenons brièvement le GAN, principalement pour élargir l'ensemble de données. Actuellement, il y a trop peu de données disponibles et la quantité de données après amélioration des données (rotation, inversion, etc.) de 30 images dans une catégorie est loin d'être suffisante, donc J'ai essayé d'utiliser GAN. Générer des données, ajouter les données générées, puis les détecter et les classer. Je ne sais pas si cela donnera de bons résultats. Comme indiqué ci-dessous dans mon ensemble de données, je souhaite générer des fissures par lots et l'arrière-plan du circuit imprimé est complexe, donc je ne sais pas si cela peut être fait.
Insérer la description de l'image ici

Examen de certains articles et blogs : une simple augmentation de données à l'aide de GAN peut parfois améliorer les performances du classificateur, en particulier dans le cas d'ensembles de données très petits ou limités, mais les cas les plus prometteurs d'augmentation à l'aide de GAN semblent inclure l'apprentissage par transfert ou l'apprentissage en quelques étapes.

Principe du GAN

Concept GAN : Grâce à des jeux continus entre le réseau de génération G (Generator) et le réseau discriminateur D (Discriminator), G peut apprendre la répartition des données.

Insérer la description de l'image ici
La formation itérative alternative optimise en permanence les ajustements des paramètres pour que les mensonges deviennent réalité.

Formule mathématique :
Insérer la description de l'image ici
divisée en optimisation D et optimisation G :

Optimiser D :
Insérer la description de l'image ici
 Optimiser G :
Insérer la description de l'image ici
Lors de l'optimisation de D, c'est-à-dire du réseau discriminant, il n'y a en fait aucune génération de réseau. Le G(z) suivant est équivalent au faux échantillon obtenu. Optimisez le premier terme de la formule de D de sorte que lorsque le véritable échantillon x est entré, plus le résultat est grand, mieux c'est. Cela est compréhensible car le résultat de la prédiction du véritable échantillon doit être aussi proche de 1 que possible. Pour les faux échantillons, il est nécessaire d’optimiser le résultat le plus petit possible, c’est-à-dire que plus D(G(z)) est petit, mieux c’est, car son étiquette est 0. Mais le premier terme est plus grand et le deuxième terme est plus petit. Ce n'est pas une contradiction, alors changez le deuxième terme en 1-D(G(z)), de sorte que plus grand soit meilleur. La combinaison des deux est Plus grand le meilleur. Ensuite, lors de l'optimisation de G, il n'y a pas de véritable échantillon pour le moment, donc le premier élément est supprimé directement. Pour le moment, il n'y a que de faux échantillons, mais nous disons qu'à ce moment, nous espérons que l'étiquette du faux échantillon est 1, donc plus D(G(z)) est grand, mieux c'est, mais afin de l'unifier sous la forme de 1-D(G(z)) , alors il ne peut que minimiser 1-D(G(z)). Il n'y a pas de différence en substance, juste pour l'unification de la forme. Ensuite, ces deux modèles d'optimisation peuvent être combinés et écrits, et ils deviennent la fonction objectif maximale et minimale d'origine.

conseils de formation

  1. La fonction d'activation de la dernière couche utilise tanh (sauf BEGAN)

  2. En utilisant la fonction de perte de wassertein GAN,

  3. S'il existe des données d'étiquette, essayez d'utiliser des étiquettes. Certaines personnes ont suggéré que l'utilisation d'étiquettes inversées est très efficace. De plus, utilisez le lissage des étiquettes, le lissage des étiquettes unilatérales ou le lissage des étiquettes bilatérales.

  4. Utilisez la norme de mini-lot. Si vous n’utilisez pas la norme de lot, vous pouvez utiliser la norme d’instance ou la norme de poids.

  5. Pour éviter d'utiliser RELU et regrouper des couches et réduire la possibilité de dégradés clairsemés, vous pouvez utiliser la fonction d'activation de leakrelu

  6. L'optimiseur doit essayer de choisir ADAM. Ne définissez pas le taux d'apprentissage trop élevé. Le 1e-4 initial peut être utilisé comme référence. De plus, le taux d'apprentissage peut être continuellement réduit au fur et à mesure de la progression de la formation.

  7. Ajouter du bruit gaussien à la couche réseau de D équivaut à un bruit régulier

GAN grand public

Réalisation d'une série de lectures et d'analyses sur les modèles GAN traditionnels :

  1. StyleGAN : Convient à la synthèse faciale, l'effet est très bon, mais je ne l'ai pas vu utilisé pour d'autres générations de données, donc je n'ai pas l'intention de l'utiliser ;
  2. CycleGAN et pix2pix GAN : Convient principalement au transfert de style. Je prévois uniquement de générer des images de haute qualité, et je suis trop paresseux pour séparer l'arrière-plan et l'objet pour l'appairage. Cela n'a rien à voir avec la transformation de style, donc je n'ai pas l'intention de le faire utilisez-le. Si l'effet ultérieur n'est pas bon, j'envisagerai d'utiliser cette méthode ;
  3. SMOTE : Certaines personnes utilisent l'algorithme smote pour effectuer une amplification des données sur des échantillons de classes minoritaires, mais elles n'utilisent que l'algorithme K voisin le plus proche pour examiner la distribution des données. Enfin, elles utilisent le GAN pour l'expansion des données. Dans des applications pratiques, l'ensemble de données peut Je suis également confronté à non seulement des problèmes. Le problème est que les données sont déséquilibrées et que la taille de l'échantillon de chaque type est petite, ce qui rend difficile la formation efficace du modèle, je n'ai donc pas l'intention de l'essayer.
  4. SGAN (Semi-Supervised GAN) : En GAN semi-supervisé, le Discriminateur dispose de N+1 sorties, où N est le nombre de classifications de l'échantillon. Cette méthode peut former un classificateur plus efficace et également générer des échantillons de meilleure qualité. Dans l'expérience MNIST, nous pouvons voir si SGAN peut générer de meilleurs échantillons que le GAN général, comme le montre la figure 3, SGAN (à gauche) La sortie de est plus claire que celui du GAN (à droite), qui semble correct pour différentes initialisations et architectures de réseau, mais il est difficile d'effectuer une évaluation systématique de la qualité des échantillons pour différents hyperparamètres. Je vais essayer
    Insérer la description de l'image ici
  5. DCGAN : Pas très efficace. . . Il est également possible que mon ensemble de données provienne d'un circuit imprimé et que l'image générée soit très floue. Peut-être qu'une simple image d'arrière-plan peut être utilisée. L'image ci-dessous est l'image que j'ai générée. . .
    Insérer la description de l'image ici

Implémentation partielle du code GAN (mnist à titre d'exemple 28x28)

Réseau Discriminateur

#定义判别器  #####Discriminator######使用多层网络来作为判别器
 
#将图片28x28展开成784,然后通过多层感知器,中间经过斜率设置为0.2的LeakyReLU激活函数,
# 最后接sigmoid激活函数得到一个0到1之间的概率进行二分类。
class discriminator(nn.Module):
    def __init__(self):
        super(discriminator,self).__init__()
        self.dis=nn.Sequential(
            nn.Linear(784,256),#输入特征数为784,输出为256
            nn.LeakyReLU(0.2),#进行非线性映射
            nn.Linear(256,256),#进行一个线性映射
            nn.LeakyReLU(0.2),
            nn.Linear(256,1),
            nn.Sigmoid()
            # sigmoid可以班实数映射到【0,1】,作为概率值,
            # 多分类用softmax函数
        )
    def forward(self, x):
        x=self.dis(x)
        return x

Réseau Génératif

####### 定义生成器 Generator #####
#输入一个100维的0~1之间的高斯分布,然后通过第一层线性变换将其映射到256维,
# 然后通过LeakyReLU激活函数,接着进行一个线性变换,再经过一个LeakyReLU激活函数,
# 然后经过线性变换将其变成784维,最后经过Tanh激活函数是希望生成的假的图片数据分布
# 能够在-1~1之间。
class generator(nn.Module):
    def __init__(self):
        super(generator,self).__init__()
        self.gen=nn.Sequential(
            nn.Linear(100,256),#用线性变换将输入映射到256维
            nn.ReLU(True),#relu激活
            nn.Linear(256,256),#线性变换
            nn.ReLU(True),#relu激活
            nn.Linear(256,784),#线性变换
            nn.Tanh()#Tanh激活使得生成数据分布在【-1,1】之间
        )
 
    def forward(self, x):
        x=self.gen(x)
        return x


#创建对象
D=discriminator()
G=generator()
if torch.cuda.is_available():
    D=D.cuda()
    G=G.cuda()

Train discriminateur

La formation du discriminateur se compose de deux parties. La première partie consiste à juger les images réelles comme vraies et la deuxième partie consiste à juger les fausses images comme fausses. Dans ces deux processus, les paramètres du générateur ne participent pas à la mise à jour.

criterion = nn.BCELoss()  #定义loss的度量方式是单目标二分类交叉熵函数
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0003)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0003)

Entrée dans la formation
Puisque la formation du générateur nécessite la sortie du discriminateur, le discriminateur doit être formé en premier.

img = img.view(num_img, -1)  # 将图片展开乘28x28=784
real_img = Variable(img).cuda()  # 将tensor变成Variable放入计算图中
real_label = Variable(torch.ones(num_img)).cuda()  # 定义真实label为1
fake_label = Variable(torch.zeros(num_img)).cuda()  # 定义假的label为0

# 计算真实图片的损失
real_out = D(real_img)  # 将真实图片放入判别器中
d_loss_real = criterion(real_out, real_label)  # 得到真实图片的loss
real_scores = real_out  # 得到真实图片的判别值,输出的值越接近1越好

# 计算假图片的损失
z = Variable(torch.randn(num_img, z_dimension)).cuda()  # 随机生成一些噪声
fake_img = G(z)  # 放入生成网络生成一张假的图片
fake_out = D(fake_img)  # 判别器判断假的图片
d_loss_fake = criterion(fake_out, fake_label)  # 得到假的图片的loss
fake_scores = fake_out  # 假的图片放入判别器越接近0越好

# 损失函数和优化
d_loss = d_loss_real + d_loss_fake  # 将真假图片的loss加起来
d_optimizer.zero_grad()  # 归0梯度
d_loss.backward()  # 反向传播
d_optimizer.step()  # 更新参数

Train Génératif

Principe : Le but est d'espérer que la fausse image générée sera jugée comme une image réelle par le discriminateur et que
D sera corrigé. Le résultat du passage de la fausse image dans le discriminateur correspond à l'étiquette réelle. Les paramètres mis à jour par rétropropagation sont les paramètres du réseau de génération. De cette façon, nous pouvons faire en sorte que le discriminateur juge les fausses images générées comme réelles en suivant les paramètres du réseau de nouvelle génération, réalisant ainsi le rôle de confrontation générative.

# compute loss of fake_img
z = Variable(torch.randn(num_img, z_dimension)).cuda()  # 得到随机噪声
fake_img = G(z)  # 生成假的图片
output = D(fake_img)  # 经过判别器得到结果
g_loss = criterion(output, real_label)  # 得到假的图片与真实图片label的loss
 
# bp and optimize
g_optimizer.zero_grad()  # 归0梯度
g_loss.backward()  # 反向传播
g_optimizer.step()  # 更新生成网络的参数

Code complet (simple)

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image
from torch.autograd import Variable
import os

if not os.path.exists('./img'):
    os.mkdir('./img')


def to_img(x):
    out = 0.5 * (x + 1)
    out = out.clamp(0, 1)
    out = out.view(-1, 1, 28, 28)
    return out


batch_size = 128
num_epoch = 100
z_dimension = 100

# Image processing
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])
# MNIST dataset
mnist = datasets.MNIST(
    root='./data/', train=True, transform=img_transform, download=True)
# Data loader
dataloader = torch.utils.data.DataLoader(
    dataset=mnist, batch_size=batch_size, shuffle=True)


# Discriminator
class discriminator(nn.Module):
    def __init__(self):
        super(discriminator, self).__init__()
        self.dis = nn.Sequential(
            nn.Linear(784, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid())

    def forward(self, x):
        x = self.dis(x)
        return x


# Generator
class generator(nn.Module):
    def __init__(self):
        super(generator, self).__init__()
        self.gen = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(True),
            nn.Linear(256, 256),
            nn.ReLU(True),
            nn.Linear(256, 784),
            nn.Tanh())

    def forward(self, x):
        x = self.gen(x)
        return x


D = discriminator()
G = generator()
if torch.cuda.is_available():
    D = D.cuda()
    G = G.cuda()
# Binary cross entropy loss and optimizer
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0003)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0003)

# Start training
for epoch in range(num_epoch):
    for i, (img, _) in enumerate(dataloader):
        num_img = img.size(0)
        # =================train discriminator
        img = img.view(num_img, -1)
        real_img = Variable(img).cuda()
        real_label = Variable(torch.ones(num_img)).cuda()
        fake_label = Variable(torch.zeros(num_img)).cuda()

        # compute loss of real_img
        real_out = D(real_img)
        d_loss_real = criterion(real_out, real_label)
        real_scores = real_out  # closer to 1 means better

        # compute loss of fake_img
        z = Variable(torch.randn(num_img, z_dimension)).cuda()
        fake_img = G(z)
        fake_out = D(fake_img)
        d_loss_fake = criterion(fake_out, fake_label)
        fake_scores = fake_out  # closer to 0 means better

        # bp and optimize
        d_loss = d_loss_real + d_loss_fake
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # ===============train generator
        # compute loss of fake_img
        z = Variable(torch.randn(num_img, z_dimension)).cuda()
        fake_img = G(z)
        output = D(fake_img)
        g_loss = criterion(output, real_label)

        # bp and optimize
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        if (i + 1) % 100 == 0:
            print('Epoch [{}/{}], d_loss: {:.6f}, g_loss: {:.6f} '
                  'D real: {:.6f}, D fake: {:.6f}'.format(
                epoch, num_epoch, d_loss.item(), g_loss.item(),
                real_scores.data.mean(), fake_scores.data.mean()))
    if epoch == 0:
        real_images = to_img(real_img.cpu().data)
        save_image(real_images, './img/real_images.png')

    fake_images = to_img(fake_img.cpu().data)
    save_image(fake_images, './img/fake_images-{}.png'.format(epoch + 1))

torch.save(G.state_dict(), './generator.pth')
torch.save(D.state_dict(), './discriminator.pth')

La version de pytorch utilisée cette fois est 0.5 ou supérieure, elle est donc écrite comme train_loss+=loss.item(). S'il s'agit de la version 0.3, remplacez-la par train_loss+=loss.data[0], sinon une erreur sera signalée.

Résultats expérimentaux:

L'image de gauche est la 100ème image générée et l'image de droite est l'image réelle.
Insérer la description de l'image ici
résultats

Code complet (avec convolution)

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image
from torch.autograd import Variable
from matplotlib import pyplot as plt
import os

if not os.path.exists('./cnn_img'):
    os.mkdir('./cnn_img')


def to_img(x):
    out = 0.5 * (x + 1)
    out = out.clamp(0, 1)
    out = out.view(-1, 1, 28, 28)
    return out


batch_size = 128
num_epoch = 100
z_dimension = 100    #噪声维度

# Image processing
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])
# MNIST dataset
mnist = datasets.MNIST(
    root='./data/', train=True, transform=img_transform, download=True)
# Data loader
dataloader = torch.utils.data.DataLoader(
    dataset=mnist, batch_size=batch_size, shuffle=True)


# Discriminator
class discriminator(nn.Module):
    def __init__(self):
        super(discriminator, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 32, 5, padding=2),    # batch,32,28,28
            nn.LeakyReLU(0.2),
            nn.MaxPool2d(2, stride=2)     # batch,32,14,14
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, 5, padding=2),  # batch, 64, 14, 14
            nn.LeakyReLU(0.2),
            nn.MaxPool2d(2, stride=2)  # batch, 64, 7, 7
        )
        self.fc = nn.Sequential(
            nn.Linear(64 * 7 * 7, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x


# Generator
class generator(nn.Module):
    def __init__(self, input_size, num_feature):
        super(generator, self).__init__()
        # 1.第一层线性变换
        self.fc = nn.Linear(input_size, num_feature)  # batch, 3136=1x56x56
        self.br = nn.Sequential(
            nn.BatchNorm2d(1),
            nn.ReLU(True)
        )
        self.conv1_g = nn.Sequential(
            nn.Conv2d(1, 50, 3, stride=1, padding=1),  # batch, 50, 56, 56
            nn.BatchNorm2d(50),
            nn.ReLU(True)
        )
        self.conv2_g = nn.Sequential(
            nn.Conv2d(50, 25, 3, stride=1, padding=1),  # batch, 25, 56, 56
            nn.BatchNorm2d(25),
            nn.ReLU(True)
        )
        self.conv3_g = nn.Sequential(
            nn.Conv2d(25, 1, 2, stride=2),  # batch, 1, 28, 28
            nn.Tanh()
        )

    def forward(self, x):
        x = self.fc(x)
        x = x.view(x.size(0), 1, 56, 56)
        x = self.br(x)
        x = self.conv1_g(x)
        x = self.conv2_g(x)
        x = self.conv3_g(x)
        return x


D = discriminator()
G = generator(z_dimension, 3136)
if torch.cuda.is_available():
    D = D.cuda()
    G = G.cuda()
# Binary cross entropy loss and optimizer
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0001)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0001)

# Start training
for epoch in range(num_epoch):
    for i, (img, _) in enumerate(dataloader):
        num_img = img.size(0)
        #print(img.shape)             # inputs:img=[128,1,28,28]
        # =================train discriminator
        #img = img.view(num_img, -1)        # img.shape: [128, 784]
        real_img = Variable(img).cuda()
        #print(real_img.shape)
        real_label = Variable(torch.ones(num_img)).cuda()
        fake_label = Variable(torch.zeros(num_img)).cuda()

        # compute loss of real_img
        real_out = D(real_img)
        d_loss_real = criterion(real_out, real_label)
        real_scores = real_out  # closer to 1 means better

        # compute loss of fake_img
        z = Variable(torch.randn(num_img, z_dimension)).cuda()
        fake_img = G(z)
        fake_out = D(fake_img)
        d_loss_fake = criterion(fake_out, fake_label)
        fake_scores = fake_out  # closer to 0 means better

        # bp and optimize
        d_loss = d_loss_real + d_loss_fake
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # ===============train generator
        # compute loss of fake_img
        z = Variable(torch.randn(num_img, z_dimension)).cuda()
        fake_img = G(z)
        output = D(fake_img)
        g_loss = criterion(output, real_label)

        # bp and optimize
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        if (i + 1) % 100 == 0:
            print('Epoch [{}/{}], d_loss: {:.6f}, g_loss: {:.6f} '
                  'D real: {:.6f}, D fake: {:.6f}'.format(
                epoch, num_epoch, d_loss.item(), g_loss.item(),
                real_scores.data.mean(), fake_scores.data.mean()))
    if epoch == 0:
        real_images = to_img(real_img.cpu().data)
        save_image(real_images, './cnn_img/real_images.png')

    fake_images = to_img(fake_img.cpu().data)
    save_image(fake_images, './cnn_img/fake_images-{}.png'.format(epoch + 1))

torch.save(G.state_dict(), './cnn_img/generator.pth')
torch.save(D.state_dict(), './cnn_img/discriminator.pth')

Résultats expérimentaux

Insérer la description de l'image ici
Le côté gauche est une image réelle et le côté droit est une image générée à l'époque 100. Je l'ai exécuté sur mon ordinateur NVIDIA-MX450 sans aucun problème et l'effet est toujours bon.

Référence

1. Compréhension simple et expérience du réseau neuronal contradictoire génératif GAN
2. Notes papier sur le GAN et l'amplification de petits échantillons
3. Analyse du principe du GAN (réseau neuronal contradictoire génératif)
4. Implémentation de Pytorch consistant à utiliser le GAN pour générer un ensemble de données MNIST

Guess you like

Origin blog.csdn.net/qq_42740834/article/details/123692837