Notes d'apprentissage automatique - Traduction d'images avec Pix2Pix

1. Présentation de Pix2Pix

        La traduction d'images convient à une variété de tâches, de la simple amélioration et édition de photos à des tâches plus subtiles telles que les niveaux de gris au RVB. Par exemple, supposons que votre tâche consiste à augmenter l'image et que votre jeu de données est un ensemble d'images normales et de leurs homologues augmentées. Le but ici est d'apprendre un mappage efficace des images d'entrée à leurs homologues de sortie.

        Les auteurs de Pix2Pix s'appuient sur la méthode de base de calcul des mappages d'entrée-sortie et forment une fonction de perte supplémentaire pour améliorer ce mappage. Selon l' article Pix2Pix , leur méthode fonctionne bien sur une variété de tâches, y compris (mais sans s'y limiter) la synthèse de photos à partir de masques de segmentation.

        Démo 1 : Générer des chats à partir des bords

         Démo 2 : Génération de surfaces de construction

         En plus du GAN conditionnel, Pix2Pix mélange également la distance L1 (distance entre deux points) entre les images réelles et générées.

2. Générateur

        Pix2Pix utilise U-Net (ci-dessous) car il a des connexions de saut. Un U-Net est généralement caractérisé par son premier ensemble de couches de sous-échantillonnage, les couches de goulot d'étranglement, suivies de couches de suréchantillonnage. Le point clé à retenir ici est que les couches de sous-échantillonnage sont connectées aux couches de suréchantillonnage correspondantes, comme indiqué par les lignes en pointillés dans l'image ci-dessous.

        Code de référence 

# import the necessary packages
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import Conv2DTranspose
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.layers import concatenate
from tensorflow.keras.layers import MaxPool2D
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import Dropout
from tensorflow.keras import Model
from tensorflow.keras import Input

class Pix2Pix(object):
	def __init__(self, imageHeight, imageWidth):
		# initialize the image height and width
		self.imageHeight = imageHeight
		self.imageWidth = imageWidth

	def generator(self):
		# initialize the input layer
		inputs = Input([self.imageHeight, self.imageWidth, 3])
  
		# down Layer 1 (d1) => final layer 1 (f1)
		d1 = Conv2D(32, (3, 3), activation="relu", padding="same")(
			inputs)
		d1 = Dropout(0.1)(d1)
		f1 = MaxPool2D((2, 2))(d1)

		# down Layer 2 (l2) => final layer 2 (f2)
		d2 = Conv2D(64, (3, 3), activation="relu", padding="same")(f1)
		f2 = MaxPool2D((2, 2))(d2)

		#  down Layer 3 (l3) => final layer 3 (f3)
		d3 = Conv2D(96, (3, 3), activation="relu", padding="same")(f2)
		f3 = MaxPool2D((2, 2))(d3)

		# down Layer 4 (l3) => final layer 4 (f4)
		d4 = Conv2D(96, (3, 3), activation="relu", padding="same")(f3)
		f4 = MaxPool2D((2, 2))(d4)

		# u-bend of the u-bet
		b5 = Conv2D(96, (3, 3), activation="relu", padding="same")(f4)
		b5 = Dropout(0.3)(b5)
		b5 = Conv2D(256, (3, 3), activation="relu", padding="same")(b5)

		# upsample Layer 6 (u6)
		u6 = Conv2DTranspose(128, (2, 2), strides=(2, 2),
			padding="same")(b5)
		u6 = concatenate([u6, d4])
		u6 = Conv2D(128, (3, 3), activation="relu", padding="same")(
			u6)

		# upsample Layer 7 (u7)
		u7 = Conv2DTranspose(96, (2, 2), strides=(2, 2),
			padding="same")(u6)
		u7 = concatenate([u7, d3])
		u7 = Conv2D(128, (3, 3), activation="relu", padding="same")(
			u7)

		# upsample Layer 8 (u8)
		u8 = Conv2DTranspose(64, (2, 2), strides=(2, 2),
			padding="same")(u7)
		u8 = concatenate([u8, d2])
		u8 = Conv2D(128, (3, 3), activation="relu", padding="same")(u8)

		# upsample Layer 9 (u9)
		u9 = Conv2DTranspose(32, (2, 2), strides=(2, 2),
			padding="same")(u8)
		u9 = concatenate([u9, d1])
		u9 = Dropout(0.1)(u9)
		u9 = Conv2D(128, (3, 3), activation="relu", padding="same")(u9)

		# final conv2D layer
		outputLayer = Conv2D(3, (1, 1), activation="tanh")(u9)
	
		# create the generator model
		generator = Model(inputs, outputLayer)

		# return the generator
		return generator

3. Discriminateur

        Le discriminateur est un discriminateur Patch GAN. Un discriminateur GAN normal prend une image en entrée et sort une seule valeur de 0 (faux) ou 1 (vrai). Le discriminateur patch GAN analyse l'entrée en tant que patchs d'image locaux. Il évaluera si chaque patch de l'image est réel ou faux.

        Code de référence 

	def discriminator(self):
		# initialize input layer according to PatchGAN
		inputMask = Input(shape=[self.imageHeight, self.imageWidth, 3], 
			name="input_image"
		)
		targetImage = Input(
			shape=[self.imageHeight, self.imageWidth, 3], 
			name="target_image"
		)
  
		# concatenate the inputs
		x = concatenate([inputMask, targetImage])  

		# add four conv2D convolution layers
		x = Conv2D(64, 4, strides=2, padding="same")(x)  
		x = LeakyReLU()(x)
		x = Conv2D(128, 4, strides=2, padding="same")(x)
		x = LeakyReLU()(x)  
		x = Conv2D(256, 4, strides=2, padding="same")(x)
		x = LeakyReLU()(x)   
		x = Conv2D(512, 4, strides=1, padding="same")(x)  

		# add a batch-normalization layer => LeakyReLU => zeropad
		x = BatchNormalization()(x)
		x = LeakyReLU()(x)

		# final conv layer
		last = Conv2D(1, 3, strides=1)(x) 
  
		# create the discriminator model
		discriminator = Model(inputs=[inputMask, targetImage],
			outputs=last)

		# return the discriminator
		return discriminator

Quatrièmement, le processus de formation

        Dans Pix2Pix, Patch GAN recevra une paire d'images : masque de saisie et image générée et masque de saisie et image cible. C'est parce que la sortie dépend de l'entrée. Par conséquent, il est important de conserver l'image d'entrée dans le mixage ( comme illustré dans la figure ci- dessous, où le discriminateur prend deux entrées).

5. Code complet

ml_toolset/case 100 Traduction d'images avec Pix2Pix sur le GitHub principal bashendixie/ml_toolset Contribuez au développement de bashendixie/ml_toolset en créant un compte sur GitHub. https://github.com/bashendixie/ml_toolset/tree/main/%E6%A1 %88% E4%BE%8B100%20%E4%BD%BF%E7%94%A8Pix2Pix%E8%BF%9B%E8%A1%8C%E5%9B%BE%E5%83%8F%E7%BF %BB% E8%AF%91

6. Références connexes

Traduction d'image à image avec des réseaux contradictoires conditionnels https://phillipi.github.io/pix2pix/

Je suppose que tu aimes

Origine blog.csdn.net/bashendixie5/article/details/127177800
conseillé
Classement