1. Présentation de Pix2Pix
La traduction d'images convient à une variété de tâches, de la simple amélioration et édition de photos à des tâches plus subtiles telles que les niveaux de gris au RVB. Par exemple, supposons que votre tâche consiste à augmenter l'image et que votre jeu de données est un ensemble d'images normales et de leurs homologues augmentées. Le but ici est d'apprendre un mappage efficace des images d'entrée à leurs homologues de sortie.
Les auteurs de Pix2Pix s'appuient sur la méthode de base de calcul des mappages d'entrée-sortie et forment une fonction de perte supplémentaire pour améliorer ce mappage. Selon l' article Pix2Pix , leur méthode fonctionne bien sur une variété de tâches, y compris (mais sans s'y limiter) la synthèse de photos à partir de masques de segmentation.
Démo 1 : Générer des chats à partir des bords
Démo 2 : Génération de surfaces de construction
En plus du GAN conditionnel, Pix2Pix mélange également la distance L1 (distance entre deux points) entre les images réelles et générées.
2. Générateur
Pix2Pix utilise U-Net (ci-dessous) car il a des connexions de saut. Un U-Net est généralement caractérisé par son premier ensemble de couches de sous-échantillonnage, les couches de goulot d'étranglement, suivies de couches de suréchantillonnage. Le point clé à retenir ici est que les couches de sous-échantillonnage sont connectées aux couches de suréchantillonnage correspondantes, comme indiqué par les lignes en pointillés dans l'image ci-dessous.
Code de référence
# import the necessary packages
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import Conv2DTranspose
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.layers import concatenate
from tensorflow.keras.layers import MaxPool2D
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import Dropout
from tensorflow.keras import Model
from tensorflow.keras import Input
class Pix2Pix(object):
def __init__(self, imageHeight, imageWidth):
# initialize the image height and width
self.imageHeight = imageHeight
self.imageWidth = imageWidth
def generator(self):
# initialize the input layer
inputs = Input([self.imageHeight, self.imageWidth, 3])
# down Layer 1 (d1) => final layer 1 (f1)
d1 = Conv2D(32, (3, 3), activation="relu", padding="same")(
inputs)
d1 = Dropout(0.1)(d1)
f1 = MaxPool2D((2, 2))(d1)
# down Layer 2 (l2) => final layer 2 (f2)
d2 = Conv2D(64, (3, 3), activation="relu", padding="same")(f1)
f2 = MaxPool2D((2, 2))(d2)
# down Layer 3 (l3) => final layer 3 (f3)
d3 = Conv2D(96, (3, 3), activation="relu", padding="same")(f2)
f3 = MaxPool2D((2, 2))(d3)
# down Layer 4 (l3) => final layer 4 (f4)
d4 = Conv2D(96, (3, 3), activation="relu", padding="same")(f3)
f4 = MaxPool2D((2, 2))(d4)
# u-bend of the u-bet
b5 = Conv2D(96, (3, 3), activation="relu", padding="same")(f4)
b5 = Dropout(0.3)(b5)
b5 = Conv2D(256, (3, 3), activation="relu", padding="same")(b5)
# upsample Layer 6 (u6)
u6 = Conv2DTranspose(128, (2, 2), strides=(2, 2),
padding="same")(b5)
u6 = concatenate([u6, d4])
u6 = Conv2D(128, (3, 3), activation="relu", padding="same")(
u6)
# upsample Layer 7 (u7)
u7 = Conv2DTranspose(96, (2, 2), strides=(2, 2),
padding="same")(u6)
u7 = concatenate([u7, d3])
u7 = Conv2D(128, (3, 3), activation="relu", padding="same")(
u7)
# upsample Layer 8 (u8)
u8 = Conv2DTranspose(64, (2, 2), strides=(2, 2),
padding="same")(u7)
u8 = concatenate([u8, d2])
u8 = Conv2D(128, (3, 3), activation="relu", padding="same")(u8)
# upsample Layer 9 (u9)
u9 = Conv2DTranspose(32, (2, 2), strides=(2, 2),
padding="same")(u8)
u9 = concatenate([u9, d1])
u9 = Dropout(0.1)(u9)
u9 = Conv2D(128, (3, 3), activation="relu", padding="same")(u9)
# final conv2D layer
outputLayer = Conv2D(3, (1, 1), activation="tanh")(u9)
# create the generator model
generator = Model(inputs, outputLayer)
# return the generator
return generator
3. Discriminateur
Le discriminateur est un discriminateur Patch GAN. Un discriminateur GAN normal prend une image en entrée et sort une seule valeur de 0 (faux) ou 1 (vrai). Le discriminateur patch GAN analyse l'entrée en tant que patchs d'image locaux. Il évaluera si chaque patch de l'image est réel ou faux.
Code de référence
def discriminator(self):
# initialize input layer according to PatchGAN
inputMask = Input(shape=[self.imageHeight, self.imageWidth, 3],
name="input_image"
)
targetImage = Input(
shape=[self.imageHeight, self.imageWidth, 3],
name="target_image"
)
# concatenate the inputs
x = concatenate([inputMask, targetImage])
# add four conv2D convolution layers
x = Conv2D(64, 4, strides=2, padding="same")(x)
x = LeakyReLU()(x)
x = Conv2D(128, 4, strides=2, padding="same")(x)
x = LeakyReLU()(x)
x = Conv2D(256, 4, strides=2, padding="same")(x)
x = LeakyReLU()(x)
x = Conv2D(512, 4, strides=1, padding="same")(x)
# add a batch-normalization layer => LeakyReLU => zeropad
x = BatchNormalization()(x)
x = LeakyReLU()(x)
# final conv layer
last = Conv2D(1, 3, strides=1)(x)
# create the discriminator model
discriminator = Model(inputs=[inputMask, targetImage],
outputs=last)
# return the discriminator
return discriminator
Quatrièmement, le processus de formation
Dans Pix2Pix, Patch GAN recevra une paire d'images : masque de saisie et image générée et masque de saisie et image cible. C'est parce que la sortie dépend de l'entrée. Par conséquent, il est important de conserver l'image d'entrée dans le mixage ( comme illustré dans la figure ci- dessous, où le discriminateur prend deux entrées).