1. Overview of Pix2Pix
Image-to-image translation covers a variety of tasks, from simple photo enhancement and editing to subtler ones such as converting grayscale images to RGB. For example, suppose your task is image augmentation and your dataset is a set of normal images paired with their augmented counterparts. The goal is to learn an efficient mapping from each input image to its output counterpart.
The authors of Pix2Pix build on this basic idea of learning an input-to-output mapping by adding an adversarial (conditional GAN) loss to sharpen the mapping. According to the Pix2Pix paper, their method works well on a variety of tasks, including (but not limited to) synthesizing photos from segmentation masks.
Demo 1: Generating cats from edges
Demo 2: Generating building facades
On top of the conditional GAN loss, Pix2Pix also mixes in the L1 distance (the mean absolute difference, pixel by pixel) between the real and generated images.
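As a quick illustration (a NumPy sketch with toy values, not the TensorFlow code used later), the L1 term is simply the mean absolute difference over all pixels:

```python
import numpy as np

# two toy "images" standing in for a target and a generated sample
real = np.full((4, 4, 3), 0.75)       # hypothetical target image
generated = np.full((4, 4, 3), 0.25)  # hypothetical generator output

# L1 distance: mean absolute difference, pixel by pixel
l1 = np.mean(np.abs(real - generated))
print(l1)  # 0.5 (every pixel differs by exactly 0.5)
```

Minimizing this term pushes the generator's output toward the ground truth pixel-wise, while the adversarial term pushes it toward looking realistic.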
2. Generator
Pix2Pix uses a U-Net (below) as its generator because of its skip connections. A U-Net is typically characterized by a first set of downsampling layers, a bottleneck, and a final set of upsampling layers. The key point to remember is that each downsampling layer is connected to the corresponding upsampling layer, as shown by the dashed lines in the image below.
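A skip connection itself is just a channel-wise concatenation: the feature map saved from a downsampling layer is stacked onto the upsampled feature map of the same spatial size. A minimal NumPy sketch of the shape bookkeeping (the shapes here are illustrative, not the ones produced by the model below):

```python
import numpy as np

# feature map saved from a downsampling layer: (batch, H, W, channels)
down_features = np.zeros((1, 64, 64, 32))

# upsampled feature map arriving at the matching decoder stage
up_features = np.zeros((1, 64, 64, 128))

# skip connection: concatenate along the channel axis
merged = np.concatenate([up_features, down_features], axis=-1)
print(merged.shape)  # (1, 64, 64, 160)
```

This is exactly what the `concatenate([...])` calls in the generator code do; the extra channels give the decoder direct access to fine spatial detail from the encoder.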
Reference Code
# import the necessary packages
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import Conv2DTranspose
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.layers import concatenate
from tensorflow.keras.layers import MaxPool2D
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import Dropout
from tensorflow.keras import Model
from tensorflow.keras import Input
class Pix2Pix(object):
    def __init__(self, imageHeight, imageWidth):
        # initialize the image height and width
        self.imageHeight = imageHeight
        self.imageWidth = imageWidth

    def generator(self):
        # initialize the input layer
        inputs = Input([self.imageHeight, self.imageWidth, 3])
        # down layer 1 (d1) => pooled layer 1 (f1)
        d1 = Conv2D(32, (3, 3), activation="relu", padding="same")(inputs)
        d1 = Dropout(0.1)(d1)
        f1 = MaxPool2D((2, 2))(d1)
        # down layer 2 (d2) => pooled layer 2 (f2)
        d2 = Conv2D(64, (3, 3), activation="relu", padding="same")(f1)
        f2 = MaxPool2D((2, 2))(d2)
        # down layer 3 (d3) => pooled layer 3 (f3)
        d3 = Conv2D(96, (3, 3), activation="relu", padding="same")(f2)
        f3 = MaxPool2D((2, 2))(d3)
        # down layer 4 (d4) => pooled layer 4 (f4)
        d4 = Conv2D(96, (3, 3), activation="relu", padding="same")(f3)
        f4 = MaxPool2D((2, 2))(d4)
        # u-bend (bottleneck) of the U-Net
        b5 = Conv2D(96, (3, 3), activation="relu", padding="same")(f4)
        b5 = Dropout(0.3)(b5)
        b5 = Conv2D(256, (3, 3), activation="relu", padding="same")(b5)
        # upsample layer 6 (u6), skip-connected to d4
        u6 = Conv2DTranspose(128, (2, 2), strides=(2, 2),
            padding="same")(b5)
        u6 = concatenate([u6, d4])
        u6 = Conv2D(128, (3, 3), activation="relu", padding="same")(u6)
        # upsample layer 7 (u7), skip-connected to d3
        u7 = Conv2DTranspose(96, (2, 2), strides=(2, 2),
            padding="same")(u6)
        u7 = concatenate([u7, d3])
        u7 = Conv2D(128, (3, 3), activation="relu", padding="same")(u7)
        # upsample layer 8 (u8), skip-connected to d2
        u8 = Conv2DTranspose(64, (2, 2), strides=(2, 2),
            padding="same")(u7)
        u8 = concatenate([u8, d2])
        u8 = Conv2D(128, (3, 3), activation="relu", padding="same")(u8)
        # upsample layer 9 (u9), skip-connected to d1
        u9 = Conv2DTranspose(32, (2, 2), strides=(2, 2),
            padding="same")(u8)
        u9 = concatenate([u9, d1])
        u9 = Dropout(0.1)(u9)
        u9 = Conv2D(128, (3, 3), activation="relu", padding="same")(u9)
        # final 1x1 conv2D layer maps back to 3 channels in [-1, 1]
        outputLayer = Conv2D(3, (1, 1), activation="tanh")(u9)
        # create the generator model
        generator = Model(inputs, outputLayer)
        # return the generator
        return generator
3. Discriminator
The discriminator is a PatchGAN discriminator. A normal GAN discriminator takes an image as input and outputs a single scalar: 0 (fake) or 1 (real). The PatchGAN discriminator instead treats the input as a grid of local image patches and evaluates whether each patch is real or fake, producing one score per patch.
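With the layer strides used in the code below, the size of the patch-score grid follows from standard convolution arithmetic. A small pure-Python sketch, assuming a hypothetical 256×256 input (the image size is not fixed anywhere in the class):

```python
import math

def conv_out(size, kernel, stride, padding):
    # output spatial size of a 2D convolution along one axis
    if padding == "same":
        return math.ceil(size / stride)
    # "valid" padding (the Keras default, used by the final layer)
    return (size - kernel) // stride + 1

size = 256
for kernel, stride, padding in [(4, 2, "same"),    # Conv2D(64)
                                (4, 2, "same"),    # Conv2D(128)
                                (4, 2, "same"),    # Conv2D(256)
                                (4, 1, "same"),    # Conv2D(512)
                                (3, 1, "valid")]:  # final Conv2D(1)
    size = conv_out(size, kernel, stride, padding)

print(size)  # 30: the discriminator emits a 30x30 grid of patch scores
```

Each of the 30×30 outputs corresponds to one overlapping patch of the input pair, which is what makes the discriminator "patch-wise" rather than whole-image.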
Reference Code
    def discriminator(self):
        # initialize the two PatchGAN inputs: the mask and the image
        inputMask = Input(shape=[self.imageHeight, self.imageWidth, 3],
            name="input_image")
        targetImage = Input(shape=[self.imageHeight, self.imageWidth, 3],
            name="target_image")
        # concatenate the inputs along the channel axis
        x = concatenate([inputMask, targetImage])
        # add four conv2D convolution layers
        x = Conv2D(64, 4, strides=2, padding="same")(x)
        x = LeakyReLU()(x)
        x = Conv2D(128, 4, strides=2, padding="same")(x)
        x = LeakyReLU()(x)
        x = Conv2D(256, 4, strides=2, padding="same")(x)
        x = LeakyReLU()(x)
        x = Conv2D(512, 4, strides=1, padding="same")(x)
        # add a batch-normalization layer followed by LeakyReLU
        x = BatchNormalization()(x)
        x = LeakyReLU()(x)
        # final conv layer produces the patch-wise real/fake score map
        last = Conv2D(1, 3, strides=1)(x)
        # create the discriminator model
        discriminator = Model(inputs=[inputMask, targetImage],
            outputs=last)
        # return the discriminator
        return discriminator
4. Training Process
In Pix2Pix, the PatchGAN receives pairs of images: the input mask with the generated image, and the input mask with the target image. This is because the output is conditioned on the input. It is therefore important to keep the input image in the mix (as shown in the figure below, where the discriminator takes two inputs).
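The objective can be sketched numerically. The following is a simplified NumPy illustration of the two losses, not the paper's TensorFlow training code; it treats the patch scores as probabilities for simplicity (the Keras discriminator above outputs raw logits), and λ = 100 follows the paper:

```python
import numpy as np

def bce(pred, target, eps=1e-7):
    # binary cross-entropy on probabilities in (0, 1)
    pred = np.clip(pred, eps, 1.0 - eps)
    return -np.mean(target * np.log(pred) + (1 - target) * np.log(1 - pred))

def discriminator_loss(real_patch_scores, fake_patch_scores):
    # the (input mask, target image) pair should be scored as real ...
    real_loss = bce(real_patch_scores, np.ones_like(real_patch_scores))
    # ... and the (input mask, generated image) pair as fake
    fake_loss = bce(fake_patch_scores, np.zeros_like(fake_patch_scores))
    return real_loss + fake_loss

def generator_loss(fake_patch_scores, generated, target, lam=100.0):
    # the generator wants its fake pair to be scored as real
    adv = bce(fake_patch_scores, np.ones_like(fake_patch_scores))
    # plus the L1 distance to the ground-truth image
    l1 = np.mean(np.abs(target - generated))
    return adv + lam * l1

# toy example: an undecided discriminator (0.5 everywhere) and a
# generator whose output already matches the target exactly
scores = np.full((30, 30), 0.5)
img = np.zeros((256, 256, 3))
print(round(discriminator_loss(scores, scores), 4))  # 1.3863 (= 2 ln 2)
print(round(generator_loss(scores, img, img), 4))    # 0.6931 (= ln 2)
```

Because λ is large, the L1 term dominates early in training and keeps the generator close to the ground truth, while the adversarial term sharpens the results.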
5. Complete code
6. Related References
Image-to-Image Translation with Conditional Adversarial Networks: https://phillipi.github.io/pix2pix/