Machine Learning Notes - Image Translation with Pix2Pix

1. Overview of Pix2Pix

        Image translation covers a wide range of tasks, from simple photo enhancement and editing to subtler ones such as converting grayscale images to RGB. For example, suppose your task is image augmentation and your dataset consists of normal images paired with their augmented counterparts. The goal is to learn an efficient mapping from each input image to its output counterpart.

        The authors of Pix2Pix build on the basic idea of learning an input-to-output mapping and add an adversarial loss to sharpen that mapping. According to the Pix2Pix paper, their method works well on a variety of tasks, including (but not limited to) synthesizing photos from segmentation masks.

        Demo 1: Generating cats from edge maps

        Demo 2: Generating building facades

        On top of the conditional GAN loss, Pix2Pix also mixes in the L1 distance (the sum of absolute pixel differences) between the real and generated images, which encourages the output to stay close to the ground truth.
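
        Concretely, the combined objective from the paper is the conditional GAN loss plus the weighted L1 loss, where x is the input image, y the target image, and z random noise (realized as dropout in Pix2Pix); the paper uses λ = 100:

        L_cGAN(G, D) = E_{x,y}[log D(x, y)] + E_{x,z}[log(1 − D(x, G(x, z)))]

        L_L1(G) = E_{x,y,z}[ ||y − G(x, z)||_1 ]

        G* = arg min_G max_D L_cGAN(G, D) + λ · L_L1(G)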

2. Generator

        Pix2Pix uses a U-Net (implemented below) as its generator because of its skip connections. A U-Net is typically characterized by a first set of downsampling layers, the bottleneck layers, and a following set of upsampling layers. The key point to remember is that each downsampling layer is connected to its corresponding upsampling layer by a skip connection, usually drawn as dashed lines in U-Net diagrams.

        Reference Code 

# import the necessary packages
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import Conv2DTranspose
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.layers import concatenate
from tensorflow.keras.layers import MaxPool2D
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import Dropout
from tensorflow.keras import Model
from tensorflow.keras import Input

class Pix2Pix(object):
	def __init__(self, imageHeight, imageWidth):
		# initialize the image height and width
		self.imageHeight = imageHeight
		self.imageWidth = imageWidth

	def generator(self):
		# initialize the input layer
		inputs = Input([self.imageHeight, self.imageWidth, 3])
  
		# down Layer 1 (d1) => final layer 1 (f1)
		d1 = Conv2D(32, (3, 3), activation="relu", padding="same")(
			inputs)
		d1 = Dropout(0.1)(d1)
		f1 = MaxPool2D((2, 2))(d1)

		# down Layer 2 (d2) => final layer 2 (f2)
		d2 = Conv2D(64, (3, 3), activation="relu", padding="same")(f1)
		f2 = MaxPool2D((2, 2))(d2)

		# down Layer 3 (d3) => final layer 3 (f3)
		d3 = Conv2D(96, (3, 3), activation="relu", padding="same")(f2)
		f3 = MaxPool2D((2, 2))(d3)

		# down Layer 4 (d4) => final layer 4 (f4)
		d4 = Conv2D(96, (3, 3), activation="relu", padding="same")(f3)
		f4 = MaxPool2D((2, 2))(d4)

		# bottleneck (the u-bend of the U-Net)
		b5 = Conv2D(96, (3, 3), activation="relu", padding="same")(f4)
		b5 = Dropout(0.3)(b5)
		b5 = Conv2D(256, (3, 3), activation="relu", padding="same")(b5)

		# upsample Layer 6 (u6)
		u6 = Conv2DTranspose(128, (2, 2), strides=(2, 2),
			padding="same")(b5)
		u6 = concatenate([u6, d4])
		u6 = Conv2D(128, (3, 3), activation="relu", padding="same")(
			u6)

		# upsample Layer 7 (u7)
		u7 = Conv2DTranspose(96, (2, 2), strides=(2, 2),
			padding="same")(u6)
		u7 = concatenate([u7, d3])
		u7 = Conv2D(128, (3, 3), activation="relu", padding="same")(
			u7)

		# upsample Layer 8 (u8)
		u8 = Conv2DTranspose(64, (2, 2), strides=(2, 2),
			padding="same")(u7)
		u8 = concatenate([u8, d2])
		u8 = Conv2D(128, (3, 3), activation="relu", padding="same")(u8)

		# upsample Layer 9 (u9)
		u9 = Conv2DTranspose(32, (2, 2), strides=(2, 2),
			padding="same")(u8)
		u9 = concatenate([u9, d1])
		u9 = Dropout(0.1)(u9)
		u9 = Conv2D(128, (3, 3), activation="relu", padding="same")(u9)

		# final conv2D layer
		outputLayer = Conv2D(3, (1, 1), activation="tanh")(u9)
	
		# create the generator model
		generator = Model(inputs, outputLayer)

		# return the generator
		return generator
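
        To sanity-check the generator, here is a minimal sketch using the Pix2Pix class above (it assumes the height and width are divisible by 16, since the encoder max-pools four times):

# build the generator and verify the output shape matches the input
pix2pix = Pix2Pix(imageHeight=256, imageWidth=256)
generator = pix2pix.generator()
print(generator.output_shape)  # (None, 256, 256, 3)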

3. Discriminator

        The discriminator is a PatchGAN discriminator. A standard GAN discriminator takes an image as input and outputs a single scalar: 0 (fake) or 1 (real). A PatchGAN discriminator instead analyzes the input as local image patches and evaluates whether each patch is real or fake, producing a grid of predictions.

        Reference Code 

	def discriminator(self):
		# initialize the two input layers: the input mask and the
		# image to be judged (real target or generated)
		inputMask = Input(shape=[self.imageHeight, self.imageWidth, 3], 
			name="input_image"
		)
		targetImage = Input(
			shape=[self.imageHeight, self.imageWidth, 3], 
			name="target_image"
		)
  
		# concatenate the inputs
		x = concatenate([inputMask, targetImage])  

		# add four Conv2D layers with LeakyReLU activations
		x = Conv2D(64, 4, strides=2, padding="same")(x)  
		x = LeakyReLU()(x)
		x = Conv2D(128, 4, strides=2, padding="same")(x)
		x = LeakyReLU()(x)  
		x = Conv2D(256, 4, strides=2, padding="same")(x)
		x = LeakyReLU()(x)   
		x = Conv2D(512, 4, strides=1, padding="same")(x)  

		# add a batch normalization layer followed by LeakyReLU
		x = BatchNormalization()(x)
		x = LeakyReLU()(x)

		# final conv layer
		last = Conv2D(1, 3, strides=1)(x) 
  
		# create the discriminator model
		discriminator = Model(inputs=[inputMask, targetImage],
			outputs=last)

		# return the discriminator
		return discriminator
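
        The same kind of sanity check for the discriminator (again a sketch assuming 256x256 inputs): the three stride-2 convolutions and the final unpadded 3x3 convolution leave a 30x30 grid, where each unit scores one local patch of the input pair rather than the whole image.

# build the discriminator and verify it outputs a patch grid, not a scalar
pix2pix = Pix2Pix(imageHeight=256, imageWidth=256)
discriminator = pix2pix.discriminator()
print(discriminator.output_shape)  # (None, 30, 30, 1)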

4. Training Process

        In Pix2Pix, the PatchGAN receives pairs of images: the input mask together with the generated image, and the input mask together with the target image. This is because the output is conditioned on the input, so it is important to keep the input image in the mix; that is why the discriminator above takes two inputs. A sketch of a single training step is shown below.
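
        A minimal sketch of one training step, not the repository's exact code: the names generator, discriminator, genOptimizer, and discOptimizer are assumed to exist, and targets are assumed to be scaled to [-1, 1] to match the generator's tanh output. The discriminator's last layer has no activation, so it outputs raw logits (hence from_logits=True). The losses follow the paper's objective: binary cross-entropy for the adversarial terms plus λ = 100 times the L1 term.

# import the necessary packages
import tensorflow as tf

bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)
LAMBDA = 100  # weight of the L1 term, as in the paper

@tf.function
def train_step(inputMask, targetImage):
	with tf.GradientTape() as genTape, tf.GradientTape() as discTape:
		# generate an image from the input mask
		genOutput = generator(inputMask, training=True)

		# score the (mask, target) and (mask, generated) pairs
		discReal = discriminator([inputMask, targetImage], training=True)
		discFake = discriminator([inputMask, genOutput], training=True)

		# generator loss: fool the discriminator and stay close
		# to the target via the L1 term
		ganLoss = bce(tf.ones_like(discFake), discFake)
		l1Loss = tf.reduce_mean(tf.abs(targetImage - genOutput))
		genLoss = ganLoss + LAMBDA * l1Loss

		# discriminator loss: real pairs => 1, generated pairs => 0
		discLoss = (bce(tf.ones_like(discReal), discReal) +
			bce(tf.zeros_like(discFake), discFake))

	# compute and apply the gradients for each model separately
	genGrads = genTape.gradient(genLoss, generator.trainable_variables)
	discGrads = discTape.gradient(discLoss,
		discriminator.trainable_variables)
	genOptimizer.apply_gradients(
		zip(genGrads, generator.trainable_variables))
	discOptimizer.apply_gradients(
		zip(discGrads, discriminator.trainable_variables))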

5. Complete code

ml_toolset / Case 100: Image translation with Pix2Pix (GitHub): https://github.com/bashendixie/ml_toolset/tree/main/%E6%A1%88%E4%BE%8B100%20%E4%BD%BF%E7%94%A8Pix2Pix%E8%BF%9B%E8%A1%8C%E5%9B%BE%E5%83%8F%E7%BF%BB%E8%AF%91

6. Related References

Isola et al., "Image-to-Image Translation with Conditional Adversarial Networks": https://phillipi.github.io/pix2pix/

Source: https://blog.csdn.net/bashendixie5/article/details/127177800