Pytorch implementa la reparación de imágenes: codificador automático de contexto GAN+

Haga clic en " Aprendizaje automático y red antagónica generativa " arriba, siga la estrella

¡Obtenga productos secos de última generación interesantes y divertidos!


Autor:Hmrishav Bandyopadhyay

Compilación: cuenta pública ronghuaiyang AI Park

guía

Un artículo más clásico sobre restauración de imágenes.

¿Sabías que las fotos antiguas de la infancia en ese álbum de fotos polvoriento se pueden recuperar? ¡Sí, del tipo en el que todos se toman de la mano y disfrutan de la vida! ¿No me crees? Échale un vistazo:

La restauración de imágenes es un área activa de la investigación de la IA, y la IA ha podido producir mejores resultados de restauración que la mayoría de los artistas. En este artículo, discutimos la imagen en pintura usando redes neuronales, específicamente codificadores contextuales. Este artículo explica e implementa el trabajo de investigación sobre codificadores de contexto presentado en CVPR 2016.

codificador de contexto

Para comenzar a usar codificadores contextuales, debemos entender qué es un "autocodificador". Un autocodificador se compone estructuralmente de un codificador, un decodificador y un cuello de botella. El propósito de un codificador automático general es reducir el tamaño de una imagen ignorando el ruido en la imagen. Sin embargo, los codificadores automáticos no son específicos de las imágenes y también pueden extenderse a otros datos. Los codificadores automáticos tienen variantes específicas para realizar tareas específicas.

Estructura del codificador automático

Ahora que entendemos los codificadores automáticos, podemos comparar los codificadores contextuales con los codificadores automáticos. Un codificador contextual es una red neuronal convolucional que está entrenada para generar el contenido de una región de imagen arbitraria en función de su entorno: es decir, un codificador contextual toma datos alrededor de una región de imagen e intenta generar algo apropiado para esa región de imagen. Es como armar un rompecabezas cuando éramos niños, excepto que no necesitamos generar las piezas del rompecabezas.

Nuestro codificador de contexto aquí consta de un codificador, que captura el contexto de una imagen como una representación compacta de características latentes, y un decodificador, que utiliza esta representación para generar contenido de imagen faltante. Dado que necesitamos un gran conjunto de datos para entrenar una red neuronal, no podemos simplemente tratar de pintar imágenes problemáticas. Por lo tanto, segmentamos imágenes del conjunto de datos de imagen normal, creamos un problema de pintura interna y alimentamos la imagen a una red neuronal, creando contenido de imagen faltante en las regiones que segmentamos.

La advertencia aquí es que las imágenes alimentadas a la red neuronal tienen tantas partes faltantes que los métodos clásicos de pintura interna simplemente no funcionarán.

Uso de GAN

Las GAN o redes antagónicas generativas han demostrado ser extremadamente útiles para la generación de imágenes. El principio básico del funcionamiento de GAN es: un generador trata de "engañar" a un discriminador, y cierto discriminador trata de distinguir la imagen generada por el generador. En otras palabras, las dos redes intentan minimizar y maximizar la función de pérdida, respectivamente.

máscara de área

Una máscara de región es la parte de la imagen que enmascaramos para que podamos enviar las preguntas de pintura resultantes al modelo. Al enmascarar, establecemos los valores de píxeles de esa área de la imagen en 0. Hay tres métodos:

  1. Área central: para ocluir los datos de la imagen, la forma más fácil es poner a cero el parche cuadrado en el centro. Mientras la red aprende a arreglar, nos enfrentamos al problema de la generalización. La red no generaliza bien y solo puede aprender características de bajo nivel.

  2. Bloques aleatorios: para contrarrestar el problema del "bloqueo" de la red en los límites de las regiones enmascaradas, como en la máscara de la región central, el proceso de enmascaramiento es aleatorio. En lugar de elegir un único parche cuadrado como máscara, se establecen múltiples máscaras cuadradas superpuestas, ocupando hasta 1/4 de la imagen.

  3. Región aleatoria: sin embargo, el enmascaramiento de bloques aleatorios todavía tiene límites claros para que la red los capture. Para resolver este problema, se deben eliminar las formas arbitrarias de la imagen. Se pueden obtener formas arbitrarias del conjunto de datos PASCAL VOC 2012 y deformarse y colocarse como máscaras en ubicaciones de imagen arbitrarias.

De izquierda a derecha, a) máscara central, b) máscara de bloque aleatorio, c) máscara de región aleatoria

Aquí, solo implementé el método de enmascaramiento del área central, porque esta es solo una guía para comenzar con la restauración de pinturas con IA. ¡Puedes probar otros métodos de enmascaramiento y dejarme saber el resultado en los comentarios!

estructura

A estas alturas, debería tener cierta comprensión del modelo. Veamos si tienes razón.

El modelo consta de una parte codificadora y una parte decodificadora, que construyen la parte contextual del codificador del modelo. Esta parte también actúa como generador generando datos y tratando de engañar al discriminador. El discriminador consta de una red convolucional y una función sigmoidea que finalmente da como salida un escalar.

pérdida

La función de pérdida del modelo se divide en 2 partes:

1. Pérdida de reconstrucción: La pérdida de reconstrucción es la función de pérdida L2. Ayuda a capturar la estructura general de las regiones faltantes y la coherencia con respecto a su contexto. Matemáticamente se expresa como:

La advertencia aquí es que usar solo la pérdida L2 desenfocará la imagen. Debido a que la imagen borrosa reduce el error de píxel promedio, la pérdida de L2 es mínima, pero esto no es lo que queremos.

2. Pérdida adversaria: trata de hacer que las predicciones "parezcan" reales (¡recuerde que el generador debe ser capaz de engañar al discriminador!), lo que nos ayuda a superar las imágenes borrosas que la pérdida de L2 puede hacer que obtengamos. Matemáticamente, podemos expresarlo como:

Aquí hay una observación interesante: la pérdida adversaria anima a que todo el resultado parezca real, no solo las partes que faltan. En otras palabras, la red contradictoria le da a toda la imagen un aspecto realista.

La función de pérdida global:

¡Construyamos este modelo!

Ahora, ya que tenemos clara la esencia principal de la red, comencemos a construir el modelo. Comenzaré configurando la estructura del modelo y luego pasaré a las funciones de entrenamiento y pérdida. El modelo se construye utilizando PyTorch.

Comencemos generando la red:

import torch
from torch import nn
class generator(nn.Module):

    #generator model
    def __init__(self):
        super(generator,self).__init__()
        

        self.t1=nn.Sequential(
            nn.Conv2d(in_channels=3,out_channels=64,kernel_size=(4,4),stride=2,padding=1),
            nn.LeakyReLU(0.2,in_place=True)
        )
        
        self.t2=nn.Sequential(
            nn.Conv2d(in_channels=64,out_channels=64,kernel_size=(4,4),stride=2,padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2,in_place=True)
        )
        self.t3=nn.Sequential(
            nn.Conv2d(in_channels=64,out_channels=128,kernel_size=(4,4),stride=2,padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2,in_place=True)
        )
        self.t4=nn.Sequential(
            nn.Conv2d(in_channels=128,out_channels=256,kernel_size=(4,4),stride=2,padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2,in_place=True)
        )
        self.t5=nn.Sequential(
            nn.Conv2d(in_channels=256,out_channels=512,kernel_size=(4,4),stride=2,padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2,in_place=True)
            
        )
        self.t6=nn.Sequential(
            nn.Conv2d(512,4000,kernel_size=(4,4))#bottleneck
            nn.BatchNorm2d(4000),
            nn.ReLU()
        )
        self.t7=nn.Sequential(
            nn.ConvTranspose2d(in_channels=512,out_channels=256,kernel_size=(4,4),stride=2,padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU()
            )
        self.t8=nn.Sequential(
            nn.ConvTranspose2d(in_channels=256,out_channels=128,kernel_size=(4,4),stride=2,padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU()
            )
        self.t9=nn.Sequential(
            nn.ConvTranspose2d(in_channels=128,out_channels=64,kernel_size=(4,4),stride=2,padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU()
            )
        self.t10=nn.Sequential(
            nn.ConvTranspose2d(in_channels=64,out_channels=3,kernel_size=(4,4),stride=2,padding=1),
            nn.Tanh()
            )
                
    def forward(self,x):
     x=self.t1(x)
     x=self.t2(x)
     x=self.t3(x)
     x=self.t4(x)
     x=self.t5(x)
     x=self.t6(x)
     x=self.t7(x)
     x=self.t8(x)
     x=self.t9(x)
     x=self.t10(x)
     return x #output of generator
Modelos de Generadores para Redes

Ahora, la red discriminadora:

import torch
from torch import nn
class discriminator(nn.Module):

    #discriminator model
    def __init__(self):
        super(discriminator,self).__init__()
        
        self.t1=nn.Sequential(
            nn.Conv2d(in_channels=3,out_channels=64,kernel_size=(4,4),stride=2,padding=1),
            nn.LeakyReLU(0.2,in_place=True)
        )
        
        self.t2=nn.Sequential(
            nn.Conv2d(in_channels=64,out_channels=128,kernel_size=(4,4),stride=2,padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2,in_place=True)
        )
        
        self.t3=nn.Sequential(
            nn.Conv2d(in_channels=128,out_channels=256,kernel_size=(4,4),stride=2,padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2,in_place=True)
        )
        self.t4=nn.Sequential(
            nn.Conv2d(in_channels=256,out_channels=512,kernel_size=(4,4),stride=2,padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2,in_place=True)
        )
        self.t5=nn.Sequential(
            nn.Conv2d(in_channels=512,out_channels=1,kernel_size=(4,4),stride=1,padding=0),
            nn.Sigmoid()
        )        
    
    def forward(self,x):
     x=self.t1(x)
     x=self.t2(x)
     x=self.t3(x)
     x=self.t4(x)
     x=self.t5(x)
     return x #output of discriminator
red discriminatoria

Ahora comencemos a entrenar la red. Establecemos el tamaño del lote en 64 y el número de épocas en 100. La tasa de aprendizaje se establece en 0,0002.

from model import generator, discriminator
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable

from model import _netlocalD,_netG
import utils
epochs=100
Batch_Size=64
lr=0.0002
beta1=0.5
over=4
parser = argparse.ArgumentParser()
parser.add_argument('--dataroot',  default='dataset/train', help='path to dataset')
opt = parser.parse_args()
try:
    os.makedirs("result/train/cropped")
    os.makedirs("result/train/real")
    os.makedirs("result/train/recon")
    os.makedirs("model")
except OSError:
    pass

transform = transforms.Compose([transforms.Scale(128),
                                transforms.CenterCrop(128),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
dataset = dset.ImageFolder(root=opt.dataroot, transform=transform )
assert dataset
dataloader = torch.utils.data.DataLoader(dataset, batch_size=Batch_Size,
                                         shuffle=True, num_workers=2)

ngpu = int(opt.ngpu)

wtl2 = 0.999

# custom weights initialization called on netG and netD
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


resume_epoch=0

netG = generator()
netG.apply(weights_init)


netD = discriminator()
netD.apply(weights_init)

criterion = nn.BCELoss()
criterionMSE = nn.MSELoss()

input_real = torch.FloatTensor(Batch_Size, 3, 128, 128)
input_cropped = torch.FloatTensor(Batch_Size, 3, 128, 128)
label = torch.FloatTensor(Batch_Size)
real_label = 1
fake_label = 0

real_center = torch.FloatTensor(Batch_Size, 3, 64,64)


netD.cuda()
netG.cuda()
criterion.cuda()
criterionMSE.cuda()
input_real, input_cropped,label = input_real.cuda(),input_cropped.cuda(), label.cuda()
real_center = real_center.cuda()


input_real = Variable(input_real)
input_cropped = Variable(input_cropped)
label = Variable(label)


real_center = Variable(real_center)

optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

for epoch in range(resume_epoch,epochs):
    for i, data in enumerate(dataloader, 0):
        real_cpu, _ = data
        real_center_cpu = real_cpu[:,:,int(128/4):int(128/4)+int(128/2),int(128/4):int(128/4)+int(128/2)]
        batch_size = real_cpu.size(0)
        with torch.no_grad():
            input_real.resize_(real_cpu.size()).copy_(real_cpu)
            input_cropped.resize_(real_cpu.size()).copy_(real_cpu)
            real_center.resize_(real_center_cpu.size()).copy_(real_center_cpu)
            input_cropped[:,0,int(128/4+over):int(128/4+128/2-over),int(128/4+over):int(128/4+128/2-over)] = 2*117.0/255.0 - 1.0
            input_cropped[:,1,int(128/4+over):int(128/4+128/2-over),int(128/4+over):int(128/4+128/2-over)] = 2*104.0/255.0 - 1.0
            input_cropped[:,2,int(128/4+over):int(128/4+128/2-over),int(128/4+over):int(128/4+128/2-over)] = 2*123.0/255.0 - 1.0

        #start the discriminator by training with real data---
        netD.zero_grad()
        with torch.no_grad():
            label.resize_(batch_size).fill_(real_label)

        output = netD(real_center)
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.data.mean()

        # train the discriminator with fake data---
        fake = netG(input_cropped)
        label.data.fill_(fake_label)
        output = netD(fake.detach())
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.data.mean()
        errD = errD_real + errD_fake
        optimizerD.step()


        #train the generator now---
        netG.zero_grad()
        label.data.fill_(real_label)  # fake labels are real for generator cost
        output = netD(fake)
        errG_D = criterion(output, label)

        wtl2Matrix = real_center.clone()
        wtl2Matrix.data.fill_(wtl2*10)
        wtl2Matrix.data[:,:,int(over):int(128/2 - over),int(over):int(128/2 - over)] = wtl2

        errG_l2 = (fake-real_center).pow(2)
        errG_l2 = errG_l2 * wtl2Matrix
        errG_l2 = errG_l2.mean()

        errG = (1-wtl2) * errG_D + wtl2 * errG_l2

        errG.backward()

        D_G_z2 = output.data.mean()
        optimizerG.step()

        print('[%d / %d][%d / %d] Loss_D: %.4f Loss_G: %.4f / %.4f l_D(x): %.4f l_D(G(z)): %.4f'
              % (epoch, epochs, i, len(dataloader),
                 errD.data, errG_D.data,errG_l2.data, D_x,D_G_z1, ))

        if i % 100 == 0:

            vutils.save_image(real_cpu,
                    'result/train/real/real_samples_epoch_%03d.png' % (epoch))
            vutils.save_image(input_cropped.data,
                    'result/train/cropped/cropped_samples_epoch_%03d.png' % (epoch))
            recon_image = input_cropped.clone()
            recon_image.data[:,:,int(128/4):int(128/4+128/2),int(128/4):int(128/4+128/2)] = fake.data
            vutils.save_image(recon_image.data,
                    'result/train/recon/recon_center_samples_epoch_%03d.png' % (epoch))
Módulos de formación para la formación de generadores y discriminadores

resultado

Veamos qué puede construir nuestro modelo. Imagen (ruido) en la época 0:

En la época 100:

Veamos qué es la entrada al modelo:

-FIN-

Texto original en inglés: https://towardsdatascience.com/inpainting-with-ai-get-back-your-images-pytorch-a68f689128e5

Supongo que te gustará:

¡Más de 100 artículos! ¡Resumen de las ponencias GAN más completas en CVPR 2020!

Descarga adjunta | Versión en chino "Python Advanced"

Descarga adjunta | Versión china clásica "Think Python"

Descarga adjunta | "Tutorial Práctico de Entrenamiento del Modelo Pytorch"

Descarga adjunta | El último Li Mu 2020 "Aprendizaje profundo a mano"

Descarga adjunta | Versión en chino de "Aprendizaje automático explicable"

Descarga adjunta | "Algoritmos de aprendizaje profundo de TensorFlow 2.0 en la práctica"

Descarga adjunta | ¡Más de 100 artículos! ¡Resumen de las ponencias GAN más completas en CVPR 2020!

Descarga adjunta | "Métodos Matemáticos en Visión por Computador" compartir

Supongo que te gusta

Origin blog.csdn.net/lgzlgz3102/article/details/114465510#comments_26977746
Recomendado
Clasificación