Look at the model diagram:
First, define the generator G. Unlike a CGAN, pix2pix does not feed a noise vector to G; instead it uses dropout to introduce randomness. Both the generator's input x and its output y are images. Following the original paper, G is a U-Net: besides the downsampling and upsampling paths, the most important ingredient is the skip connections.
import torch
import torch.nn as nn
class Block(nn.Module):
    """One U-Net stage: 4x4 stride-2 resampling conv -> BatchNorm -> activation.

    down=True halves H and W with a reflect-padded ``nn.Conv2d``; down=False
    doubles them with an ``nn.ConvTranspose2d``. ``act == "relu"`` selects
    ReLU, any other value selects LeakyReLU(0.2). When ``use_dropout`` is set,
    Dropout(0.5) is applied to the block's output.
    """

    def __init__(self, in_channels, out_channels, down=True, act="relu", use_dropout=False):
        super().__init__()
        if down:
            resample = nn.Conv2d(
                in_channels, out_channels, 4, 2, 1, bias=False, padding_mode="reflect"
            )
        else:
            resample = nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False)
        norm = nn.BatchNorm2d(out_channels)
        activation = nn.ReLU() if act == "relu" else nn.LeakyReLU(0.2)
        # Keep the (resample, norm, activation) order: state_dict keys are
        # conv.0.* / conv.1.*, and saved checkpoints rely on them.
        self.conv = nn.Sequential(resample, norm, activation)
        self.use_dropout = use_dropout
        self.dropout = nn.Dropout(0.5)
        self.down = down

    def forward(self, x):
        out = self.conv(x)
        if self.use_dropout:
            out = self.dropout(out)
        return out
class Generator(nn.Module):
    """pix2pix U-Net generator.

    The encoder applies eight stride-2 4x4 convolutions (initial_down,
    down1-down6, bottleneck), each halving H and W. The decoder mirrors them
    with stride-2 transposed convolutions. Before every up block after the
    first, the matching encoder feature map is concatenated on the channel
    axis (skip connection). Those concatenations require exact size matches,
    so the input H and W must be multiples of 256 (= 2 ** 8).

    Randomness comes from the dropout in up1-up3 rather than from a noise
    input. The final tanh keeps output values in [-1, 1].

    Args:
        in_channels: channels of the input image.
        features: base width. Encoder widths are features * 1, 2, 4, 8 and
            then stay at features * 8.
        out_channels: channels of the generated image. Defaults to
            ``in_channels``, which was the only option before. Set it when the
            input and output channel counts differ (e.g. 1-channel sketch ->
            3-channel image).
    """

    def __init__(self, in_channels=3, features=64, out_channels=None):
        super().__init__()
        if out_channels is None:
            out_channels = in_channels
        # Module construction order is unchanged so weight initialisation under
        # a fixed seed and state_dict keys stay identical to earlier versions.
        self.initial_down = nn.Sequential(
            nn.Conv2d(in_channels, features, 4, 2, 1, padding_mode="reflect"),
            nn.LeakyReLU(0.2),
        )
        self.down1 = Block(features, features * 2, down=True, act="leaky", use_dropout=False)
        self.down2 = Block(features * 2, features * 4, down=True, act="leaky", use_dropout=False)
        self.down3 = Block(features * 4, features * 8, down=True, act="leaky", use_dropout=False)
        self.down4 = Block(features * 8, features * 8, down=True, act="leaky", use_dropout=False)
        self.down5 = Block(features * 8, features * 8, down=True, act="leaky", use_dropout=False)
        self.down6 = Block(features * 8, features * 8, down=True, act="leaky", use_dropout=False)
        # Bottleneck: plain conv + ReLU, no BatchNorm.
        self.bottleneck = nn.Sequential(
            nn.Conv2d(features * 8, features * 8, 4, 2, 1), nn.ReLU()
        )
        # Decoder: inputs of up2..up7 are doubled by the skip concatenation.
        # Only the first three up blocks use dropout.
        self.up1 = Block(features * 8, features * 8, down=False, act="relu", use_dropout=True)
        self.up2 = Block(features * 8 * 2, features * 8, down=False, act="relu", use_dropout=True)
        self.up3 = Block(features * 8 * 2, features * 8, down=False, act="relu", use_dropout=True)
        self.up4 = Block(features * 8 * 2, features * 8, down=False, act="relu", use_dropout=False)
        self.up5 = Block(features * 8 * 2, features * 4, down=False, act="relu", use_dropout=False)
        self.up6 = Block(features * 4 * 2, features * 2, down=False, act="relu", use_dropout=False)
        self.up7 = Block(features * 2 * 2, features, down=False, act="relu", use_dropout=False)
        self.final_up = nn.Sequential(
            nn.ConvTranspose2d(features * 2, out_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        # Shape comments assume x is (1, in_channels, 256, 256) and features=64.
        d1 = self.initial_down(x)  # (1, 64, 128, 128)
        d2 = self.down1(d1)  # (1, 128, 64, 64)
        d3 = self.down2(d2)  # (1, 256, 32, 32)
        d4 = self.down3(d3)  # (1, 512, 16, 16)
        d5 = self.down4(d4)  # (1, 512, 8, 8)
        d6 = self.down5(d5)  # (1, 512, 4, 4)
        d7 = self.down6(d6)  # (1, 512, 2, 2)
        bottleneck = self.bottleneck(d7)  # (1, 512, 1, 1)
        up1 = self.up1(bottleneck)  # (1, 512, 2, 2)
        up2 = self.up2(torch.cat([up1, d7], 1))  # (1, 512, 4, 4)
        up3 = self.up3(torch.cat([up2, d6], 1))  # (1, 512, 8, 8)
        up4 = self.up4(torch.cat([up3, d5], 1))  # (1, 512, 16, 16)
        up5 = self.up5(torch.cat([up4, d4], 1))  # (1, 256, 32, 32)
        up6 = self.up6(torch.cat([up5, d3], 1))  # (1, 128, 64, 64)
        up7 = self.up7(torch.cat([up6, d2], 1))  # (1, 64, 128, 128)
        return self.final_up(torch.cat([up7, d1], 1))  # (1, out_channels, 256, 256)
def test():
    """Smoke-test the generator: run one random 256x256 RGB image and print the output shape."""
    sample = torch.randn((1, 3, 256, 256))
    net = Generator(in_channels=3, features=64)
    print(net(sample).shape)


if __name__ == "__main__":
    test()
Here, a random tensor with the same size as the real images is generated to check the output shape.
The first convolution differs from a typical one: the kernel size is 4 with stride 2, the padding mode is reflect, there is no BatchNorm, and it is followed by LeakyReLU(0.2).
After the first convolution, the input of size (1,3,256,256) becomes (1,64,128,128).
It then passes through 6 down blocks, which successively downsample it in the encoder.
They are all built the same way, so it is enough to look at one of them.
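Inside Block, the arguments down, act and use_dropout select the variant.
- With down=True (the default), a stride-2 convolution with reflect padding downsamples.
- With down=False, a stride-2 transposed convolution upsamples.

Either one is followed by BatchNorm and an activation: LeakyReLU(0.2) in the encoder (act="leaky") and ReLU in the decoder (act="relu"). Dropout is not used in the encoder.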
Between the encoder and the decoder is the bottleneck: a single stride-2 convolution followed by ReLU (no BatchNorm), which reduces the feature map to 1x1.
Note how the number of channels evolves in the encoder, which differs from ResNet: 64 → 128 → 256 → 512, after which the remaining down blocks stay at 512 channels instead of continuing to grow.
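In the decoder, each up block first upsamples its input. Its output is then concatenated (along the channel dimension) with the encoder feature map of the same spatial size, and the result feeds the next up block. The upsampling is necessary because tensors with different spatial sizes cannot be concatenated.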
With down=False, the decoder blocks use transposed convolutions and ReLU activations, and the first three decoder blocks (up1–up3) use dropout. These are the differences from the encoder.
Finally, after a transposed convolution and tanh, the final output is the same size as the original image.
The above code implements:
Next comes the discriminator. The original paper uses a PatchGAN, which classifies each local image patch as real or fake; in the code it is likewise implemented purely with convolutions.
import torch
import torch.nn as nn
class CNNBlock(nn.Module):
    """Discriminator stage: 4x4 conv (reflect padding 1, no bias) -> BatchNorm -> LeakyReLU(0.2).

    ``stride`` is passed straight to the convolution; the discriminator uses 2
    to downsample and 1 for its last stage.
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        stages = [
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=4,
                stride=stride,
                padding=1,
                bias=False,  # BatchNorm right after makes a conv bias redundant
                padding_mode="reflect",
            ),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)
class Discriminator(nn.Module):
    """PatchGAN discriminator for conditional (pix2pix) training.

    ``forward(x, y)`` concatenates the condition ``x`` and the image ``y`` on
    the channel axis and returns a map of raw logits (no sigmoid), one per
    patch. For 256x256 inputs and the default ``features`` the map is 30x30.

    Args:
        in_channels: channels of ``x`` and of ``y`` each; the first conv sees
            ``in_channels * 2``.
        features: widths of the conv stages. ``features[0]`` is the stride-2
            stem (no BatchNorm). Each later entry is a CNNBlock with stride 2,
            except the last entry, which uses stride 1. A final 4x4 conv then
            maps to 1 channel. Any sequence is accepted; the default is an
            immutable tuple rather than a shared mutable list.
    """

    def __init__(self, in_channels=3, features=(64, 128, 256, 512)):
        super().__init__()
        features = tuple(features)
        self.initial = nn.Sequential(
            nn.Conv2d(
                in_channels * 2,
                features[0],
                kernel_size=4,
                stride=2,
                padding=1,
                padding_mode="reflect",
            ),
            nn.LeakyReLU(0.2),
        )

        layers = []
        prev_channels = features[0]
        last_index = len(features) - 1
        for index, feature in enumerate(features[1:], start=1):
            # Pick the stride by position, not by value. The old check
            # `feature == features[-1]` also gave stride 1 to any earlier stage
            # whose width happened to equal the last one.
            stride = 1 if index == last_index else 2
            layers.append(CNNBlock(prev_channels, feature, stride=stride))
            prev_channels = feature
        layers.append(
            nn.Conv2d(
                prev_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect"
            ),
        )
        self.model = nn.Sequential(*layers)

    def forward(self, x, y):
        """Return patch logits for the pair (condition ``x``, image ``y``)."""
        x = torch.cat([x, y], dim=1)  # e.g. (1, 6, 256, 256)
        x = self.initial(x)  # e.g. (1, 64, 128, 128)
        return self.model(x)  # e.g. (1, 1, 30, 30)
def test():
    """Smoke-test the discriminator on a random (condition, image) pair; print the model and output shape."""
    cond = torch.randn((1, 3, 256, 256))
    image = torch.randn((1, 3, 256, 256))
    critic = Discriminator(in_channels=3)
    patch_logits = critic(cond, image)  # expected (1, 1, 30, 30)
    print(critic)
    print(patch_logits.shape)


if __name__ == "__main__":
    test()
The discriminator D takes two inputs because pix2pix is still a conditional GAN. One input is the image to be judged: either the real target y or a generated image. The other is the condition x.
This happens in three steps:
1. Concatenate the condition and the image.
2. Increase the number of channels with one convolution.
3. Pass the result through the remaining convolutional layers.
1: concat
2: The 6 concatenated channels are expanded to 64 by a stride-2 convolution.
3: Iterating over features[1:] builds three CNNBlocks (Conv+BN+LeakyReLU), the last of which uses stride 1. A plain convolution with 1 output channel follows them. For a 256x256 input, the output is 30x30, i.e. one real/fake logit per patch.
The `self.model` part of the printed discriminator (the layers after `initial`):
Sequential(
(0): CNNBlock(
(conv): Sequential(
(0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False, padding_mode=reflect)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): LeakyReLU(negative_slope=0.2)
)
)
(1): CNNBlock(
(conv): Sequential(
(0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False, padding_mode=reflect)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): LeakyReLU(negative_slope=0.2)
)
)
(2): CNNBlock(
(conv): Sequential(
(0): Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1), bias=False, padding_mode=reflect)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): LeakyReLU(negative_slope=0.2)
)
)
(3): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1), padding_mode=reflect)
)
Next, look at the training script:
import torch
from utils import save_checkpoint, load_checkpoint, save_some_examples
import torch.nn as nn
import torch.optim as optim
import config
from dataset import MapDataset
from generator_model import Generator
from discriminator_model import Discriminator
from torch.utils.data import DataLoader
from tqdm import tqdm
from torchvision.utils import save_image
# Let cuDNN benchmark candidate convolution algorithms and reuse the fastest.
# This pays off only when input shapes stay constant across batches
# (assumed here: every sample has the same size — confirm in config).
torch.backends.cudnn.benchmark = True
def train_fn(
    disc, gen, loader, opt_disc, opt_gen, l1_loss, bce, g_scaler, d_scaler,
):
    """Run one epoch of pix2pix training with mixed precision.

    For every (x, y) batch:
      1. Discriminator step: BCE pushes disc(x, y) towards 1 and
         disc(x, gen(x)) towards 0; the two losses are averaged.
      2. Generator step: BCE pushes disc(x, gen(x)) towards 1, plus an L1
         reconstruction term scaled by config.L1_LAMBDA.
    Each step scales its loss through its own GradScaler. Every 10 batches
    the progress bar shows the mean sigmoid of the real logits and of the
    fake logits (the latter are recomputed in the generator step, i.e. after
    the discriminator update).
    """
    progress = tqdm(loader, leave=True)
    for step, (cond, target) in enumerate(progress):
        cond = cond.to(config.DEVICE)
        target = target.to(config.DEVICE)

        # ---- discriminator update ----
        with torch.cuda.amp.autocast():
            fake = gen(cond)
            real_logits = disc(cond, target)
            loss_real = bce(real_logits, torch.ones_like(real_logits))
            # detach(): this loss must not backpropagate into the generator
            fake_logits = disc(cond, fake.detach())
            loss_fake = bce(fake_logits, torch.zeros_like(fake_logits))
            disc_loss = (loss_real + loss_fake) / 2
        disc.zero_grad()
        d_scaler.scale(disc_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # ---- generator update ----
        with torch.cuda.amp.autocast():
            fake_logits = disc(cond, fake)
            adv_loss = bce(fake_logits, torch.ones_like(fake_logits))
            recon_loss = l1_loss(fake, target) * config.L1_LAMBDA
            gen_loss = adv_loss + recon_loss
        opt_gen.zero_grad()
        g_scaler.scale(gen_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        if step % 10 == 0:
            progress.set_postfix(
                D_real=torch.sigmoid(real_logits).mean().item(),
                D_fake=torch.sigmoid(fake_logits).mean().item(),
            )
def main():
    """Build the pix2pix models, optimizers, losses and data, then train.

    Runs ``config.NUM_EPOCHS`` epochs of ``train_fn``. When
    ``config.SAVE_MODEL`` is set, generator and discriminator checkpoints are
    written every 5 epochs and after the final epoch. Example outputs on the
    validation set are saved after every epoch.
    """
    disc = Discriminator(in_channels=3).to(config.DEVICE)
    gen = Generator(in_channels=3, features=64).to(config.DEVICE)
    opt_disc = optim.Adam(disc.parameters(), lr=config.LEARNING_RATE, betas=(0.5, 0.999))
    opt_gen = optim.Adam(gen.parameters(), lr=config.LEARNING_RATE, betas=(0.5, 0.999))
    # The discriminator outputs raw logits, hence the logits variant of BCE.
    BCE = nn.BCEWithLogitsLoss()
    L1_LOSS = nn.L1Loss()

    if config.LOAD_MODEL:
        load_checkpoint(config.CHECKPOINT_GEN, gen, opt_gen, config.LEARNING_RATE)
        load_checkpoint(config.CHECKPOINT_DISC, disc, opt_disc, config.LEARNING_RATE)

    train_dataset = MapDataset(root_dir=config.TRAIN_DIR)
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.BATCH_SIZE,
        shuffle=True,
        num_workers=config.NUM_WORKERS,
    )
    g_scaler = torch.cuda.amp.GradScaler()
    d_scaler = torch.cuda.amp.GradScaler()
    val_dataset = MapDataset(root_dir=config.VAL_DIR)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

    # The output folder was hard-coded to one machine's absolute path. It can
    # now be overridden with config.RESULTS_DIR; the old path remains the
    # default so existing setups are unaffected.
    results_dir = getattr(
        config,
        "RESULTS_DIR",
        "/home/Projects/ZQB/a/PyTorch-GAN-master/implementations/pix2pix-pytorch/results",
    )

    last_epoch = config.NUM_EPOCHS - 1
    for epoch in range(config.NUM_EPOCHS):
        train_fn(
            disc, gen, train_loader, opt_disc, opt_gen, L1_LOSS, BCE, g_scaler, d_scaler,
        )
        # Also checkpoint the final epoch. Otherwise up to 4 trailing epochs of
        # training were never saved when (NUM_EPOCHS - 1) % 5 != 0.
        if config.SAVE_MODEL and (epoch % 5 == 0 or epoch == last_epoch):
            save_checkpoint(gen, opt_gen, filename=config.CHECKPOINT_GEN)
            save_checkpoint(disc, opt_disc, filename=config.CHECKPOINT_DISC)
        save_some_examples(gen, val_loader, epoch, folder=results_dir)


if __name__ == "__main__":
    main()
1: Instantiate the discriminator, generator, set optimizer and loss function.
2: Optionally load pretrained weights (when config.LOAD_MODEL is set).
3: Define the dataset. Here it is a dataset for colorizing sketches.
Dataset structure: the images are loaded from the train folder at the location set in config (config.TRAIN_DIR).
In the dataset code, the main thing to look at is how __getitem__ loads and processes each sample.
import numpy as np
import config
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import save_image
class MapDataset(Dataset):
    """Paired-image dataset where every file holds input and target side by side.

    Each image in ``root_dir`` is split along its width: columns
    ``[:split_col]`` are the input and ``[split_col:]`` are the target. Both
    halves go through ``config.both_transform`` together (the target is passed
    as ``image0``). Then the input goes through ``config.transform_only_input``
    and the target through ``config.transform_only_mask``.

    Args:
        root_dir: directory containing the paired images.
        split_col: column where the input ends and the target begins. Defaults
            to 600, the value previously hard-coded for this dataset.
    """

    def __init__(self, root_dir, split_col=600):
        self.root_dir = root_dir
        self.split_col = split_col
        # Sorted for a reproducible index -> file mapping (os.listdir order is
        # arbitrary). Sub-directories such as .ipynb_checkpoints are skipped,
        # since they cannot be opened as images.
        self.list_files = sorted(
            name
            for name in os.listdir(self.root_dir)
            if os.path.isfile(os.path.join(self.root_dir, name))
        )

    def __len__(self):
        return len(self.list_files)

    def __getitem__(self, index):
        img_path = os.path.join(self.root_dir, self.list_files[index])
        # Close the file deterministically. convert("RGB") guarantees an
        # (H, W, 3) array: grayscale files would break the 3-axis slicing
        # below, and RGBA/palette files would yield the wrong channel count.
        # (The per-sample np.unique + print debug output was removed: it
        # flooded stdout and sorted every pixel of every image.)
        with Image.open(img_path) as img:
            image = np.array(img.convert("RGB"))
        input_image = image[:, :self.split_col, :]  # e.g. (512, 600, 3)
        target_image = image[:, self.split_col:, :]  # e.g. (512, 424, 3)
        augmentations = config.both_transform(image=input_image, image0=target_image)
        input_image = augmentations["image"]  # e.g. (256, 256, 3)
        target_image = augmentations["image0"]  # e.g. (256, 256, 3)
        input_image = config.transform_only_input(image=input_image)["image"]  # e.g. (3, 256, 256)
        target_image = config.transform_only_mask(image=target_image)["image"]  # e.g. (3, 256, 256)
        return input_image, target_image
if __name__ == "__main__":
    import sys

    # Quick visual check: take the first batch of 5 training pairs, print the
    # input batch shape, write both batches to x.png / y.png, then exit.
    # NOTE(review): save_image clamps to [0, 1]; if config normalizes to
    # [-1, 1], the saved previews will look off — confirm.
    preview_loader = DataLoader(MapDataset("data/train/"), batch_size=5)
    for batch_inputs, batch_targets in preview_loader:
        print(batch_inputs.shape)
        save_image(batch_inputs, "x.png")
        save_image(batch_targets, "y.png")
        sys.exit()
First, according to the index, we find the corresponding picture and read it in:
Then split the image, because each file stores the input and the target side by side.
Split the picture:
input:
target:
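Next, the shared augmentation config.both_transform, applied to the input and the target together, brings both images to 256x256. After that, the input and the target each get their own transform (config.transform_only_input / config.transform_only_mask):
In other words, each sample returned by MapDataset is an input image together with its paired target.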
Back in the training script, the training images are loaded through train_loader.
Likewise, the images in the val folder are loaded through val_loader for validation.
Then the actual training loop runs:
the models, data loader, optimizers, and loss functions are passed to train_fn.
train_fn first wraps the loader in a tqdm progress bar, then moves both the input and the target to config.DEVICE (typically the GPU).
Train the discriminator:
The real pair (x, y) is fed to the discriminator, and we want it to output 1. The generated image y_fake = G(x) is fed together with the real x (as the condition), and we want the output to be 0. The two BCE losses are averaged.
Training the generator: the real x and the fake y are fed into D, and we want D to be fooled, i.e. to output 1. There is also an L1 loss between the real target y and the generated image, scaled by config.L1_LAMBDA. The two losses are added together to form the generator loss.
Next, the weights of D and G are saved periodically (when config.SAVE_MODEL is set).
Finally, some example outputs on the validation set are saved: