Aprendizaje semi-supervisado con redes generativas de confrontación (小白 学 GAN 七)

Enlace original: https://arxiv.org/pdf/1606.01583.pdf

Introducción

La idea central: discriminador al determinar la autenticidad de los datos, mientras que el proceso de clasificación de datos, es decir, la integración de la tarea de aprendizaje supervisada por GAN original al discriminador en .

       

Como se muestra en la figura anterior, el discriminador es en realidad un modelo de fusión de múltiples tareas, que no solo completa el juicio sobre la autenticidad de los datos, sino que también los clasifica.

estructura basica

El generador y el discriminador de este artículo se tomaron prestados de DCGAN , pero este artículo no utiliza la deconvolución sino una capa de muestreo superior.

PÉRDIDA

Debido a que la estructura del modelo cambia, la salida del discriminador son dos cosas, por lo que también se calculan dos partes al calcular la PÉRDIDA. El cálculo de LOSS de la parte de autenticidad de los datos es similar a otros GAN, por lo que no lo repetiré aquí. En cuanto al cálculo de la PÉRDIDA de la categoría de datos, se divide en dos casos para discutir: el primer tipo de datos reales se ingresa al discriminador, la etiqueta de datos reales se usa para calcular la PÉRDIDA de la categoría de datos de salida; el segundo tipo es cuando los datos de entrada son datos generados, la etiqueta Es un muestreo aleatorio del número total de categorías.

\ min_G \ max_D V (D, G) = \ mathbb {E} _ {x \ sim p_ {data} (x)} [logD_ {r} (x) + D_c (\ hat y = y | x)] \ \ \ \ \ \ + \ mathbb {E} _ {z \ sim p_ {z} (z)} [(1-logD_ {r} (G (z))) + D_c (\ hat y = y '| G (z))]

Semi supervisado

En el aprendizaje supervisado, se conocen las etiquetas de todos los datos, mientras que se desconocen las etiquetas de parte de los datos en el aprendizaje semi-supervisado. Bajo esta arquitectura, la etiqueta de los datos generados por el generador es simplemente desconocida, pero esta es otra situación embarazosa. Como no hay etiqueta, ¿cómo calcular la PÉRDIDA después de la clasificación? El autor agregó otra categoría a la categoría original. , Utilice esta clase como etiqueta para los datos generados. Esto es muy inteligente. Cuando los datos generados comienzan a ajustarse gradualmente a los datos originales, el discriminador tiene que profundizar más en el aprendizaje de los detalles de los datos para distinguir los datos generados y eliminar la interferencia, y su capacidad discriminativa también se obtiene en el proceso. Una mejora adicional.

La figura anterior son los resultados de la prueba del autor sobre el conjunto de datos minist Se puede ver que cuando el tamaño de la muestra es pequeño, el uso de esta estrategia de aprendizaje puede obtener mejores resultados de clasificación.

Resultados de código y práctica

Enlace de referencia: https://github.com/WingsofFAN/PyTorch-GAN/blob/master/implementations/sgan/sgan.py

import argparse
import os
import numpy as np
import math

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

os.makedirs("images", exist_ok=True)

parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--num_classes", type=int, default=10, help="number of classes for dataset")
parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval between image sampling")
opt = parser.parse_args()
print(opt)

cuda = True if torch.cuda.is_available() else False


def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        self.label_emb = nn.Embedding(opt.num_classes, opt.latent_dim)

        self.init_size = opt.img_size // 4  # Initial size before upsampling
        self.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128 * self.init_size ** 2))

        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, noise):
        out = self.l1(noise)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            """Returns layers of each discriminator block"""
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]
            if bn:
                block.append(nn.BatchNorm2d(out_filters, 0.8))
            return block

        self.conv_blocks = nn.Sequential(
            *discriminator_block(opt.channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )

        # The height and width of downsampled image
        ds_size = opt.img_size // 2 ** 4

        # Output layers
        self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())
        self.aux_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, opt.num_classes + 1), nn.Softmax())

    def forward(self, img):
        out = self.conv_blocks(img)
        out = out.view(out.shape[0], -1)
        validity = self.adv_layer(out)
        label = self.aux_layer(out)

        return validity, label


# Loss functions
adversarial_loss = torch.nn.BCELoss()
auxiliary_loss = torch.nn.CrossEntropyLoss()

# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()
    auxiliary_loss.cuda()

# Initialize weights
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)

# Configure data loader
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "../../data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size=opt.batch_size,
    shuffle=True,
)

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor

# ----------
#  Training
# ----------

for epoch in range(opt.n_epochs):
    for i, (imgs, labels) in enumerate(dataloader):

        batch_size = imgs.shape[0]

        # Adversarial ground truths
        valid = Variable(FloatTensor(batch_size, 1).fill_(1.0), requires_grad=False)
        fake = Variable(FloatTensor(batch_size, 1).fill_(0.0), requires_grad=False)
        fake_aux_gt = Variable(LongTensor(batch_size).fill_(opt.num_classes), requires_grad=False)

        # Configure input
        real_imgs = Variable(imgs.type(FloatTensor))
        labels = Variable(labels.type(LongTensor))

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        # Sample noise and labels as generator input
        z = Variable(FloatTensor(np.random.normal(0, 1, (batch_size, opt.latent_dim))))

        # Generate a batch of images
        gen_imgs = generator(z)

        # Loss measures generator's ability to fool the discriminator
        validity, _ = discriminator(gen_imgs)
        g_loss = adversarial_loss(validity, valid)

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Loss for real images
        real_pred, real_aux = discriminator(real_imgs)
        d_real_loss = (adversarial_loss(real_pred, valid) + auxiliary_loss(real_aux, labels)) / 2

        # Loss for fake images
        fake_pred, fake_aux = discriminator(gen_imgs.detach())
        d_fake_loss = (adversarial_loss(fake_pred, fake) + auxiliary_loss(fake_aux, fake_aux_gt)) / 2

        # Total discriminator loss
        d_loss = (d_real_loss + d_fake_loss) / 2

        # Calculate discriminator accuracy
        pred = np.concatenate([real_aux.data.cpu().numpy(), fake_aux.data.cpu().numpy()], axis=0)
        gt = np.concatenate([labels.data.cpu().numpy(), fake_aux_gt.data.cpu().numpy()], axis=0)
        d_acc = np.mean(np.argmax(pred, axis=1) == gt)

        d_loss.backward()
        optimizer_D.step()

        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %d%%] [G loss: %f]"
            % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), 100 * d_acc, g_loss.item())
        )

        batches_done = epoch * len(dataloader) + i
        if batches_done % opt.sample_interval == 0:
            save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)


#测试判别器的分类能力
#将预测的最后一维去掉


acc = 0
for imgs, labels in dataloader:
    real_imgs = Variable(imgs.type(FloatTensor))
    _,predict = discriminator(real_imgs)
    predict = predict[:,:-1].cpu().detach().numpy()
    d_acc = np.mean(np.argmax(predict, axis=1) == labels.numpy())
    acc += d_acc
print("discriminator's acc:" , acc/len(dataloader))

resultados de la prueba minist

Cuando el entrenamiento finaliza después de 400 iteraciones epoch, la precisión del discriminador es de aproximadamente el 50%. Cuando se prueba el rendimiento del discriminador, la salida del vector de etiqueta del discriminador se elimina de la última dimensión, es decir, los datos generados ya no se clasifican en una categoría, luego la discriminación El detector puede considerarse un clasificador ordinario, por lo que no se dividirá en la undécima categoría de datos generados. Sin embargo, sucedió algo triste. Después de quitar la undécima dimensión, la precisión de la medición era casi cero. Es decir, este clasificador básicamente solo puede distinguir si los datos son verdaderos o no. Por eso, la tasa de precisión es cuando se termina el entrenamiento. La razón por la que está cerca del 50% es porque básicamente el discriminador en las diez categorías reales está "adivinando". Se puede ver que es difícil entrenar con éxito, pero aquí no se utilizan técnicas como WGAN, por lo que es difícil entrenar para converger.

En resumen, el entrenamiento semi-supervisado mencionado en el artículo es todavía inmaduro .

 

 

Supongo que te gusta

Origin blog.csdn.net/fan1102958151/article/details/106451799
Recomendado
Clasificación