pix2pix code

The overall structure follows the model diagram in the paper: a U-Net generator plus a PatchGAN discriminator.
First define the generator G. Unlike a standard CGAN, pix2pix does not feed noise into the generator; dropout supplies the randomness instead. The generator's input x and output y are both images. According to the original paper, G is U-Net shaped: besides the downsampling and upsampling paths, the key ingredient is the skip connections.

import torch
import torch.nn as nn
class Block(nn.Module):
    def __init__(self, in_channels, out_channels, down=True, act="relu", use_dropout=False):
        super(Block, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False, padding_mode="reflect")
            if down
            else nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU() if act == "relu" else nn.LeakyReLU(0.2),
        )
        self.use_dropout = use_dropout
        self.dropout = nn.Dropout(0.5)
        self.down = down

    def forward(self, x):
        x = self.conv(x)
        return self.dropout(x) if self.use_dropout else x
class Generator(nn.Module):
    def __init__(self, in_channels=3, features=64):
        super().__init__()
        self.initial_down = nn.Sequential(
            nn.Conv2d(in_channels, features, 4, 2, 1, padding_mode="reflect"),
            nn.LeakyReLU(0.2),
        )
        self.down1 = Block(features, features * 2, down=True, act="leaky", use_dropout=False)
        self.down2 = Block(
            features * 2, features * 4, down=True, act="leaky", use_dropout=False
        )
        self.down3 = Block(
            features * 4, features * 8, down=True, act="leaky", use_dropout=False
        )
        self.down4 = Block(
            features * 8, features * 8, down=True, act="leaky", use_dropout=False
        )
        self.down5 = Block(
            features * 8, features * 8, down=True, act="leaky", use_dropout=False
        )
        self.down6 = Block(
            features * 8, features * 8, down=True, act="leaky", use_dropout=False
        )
        self.bottleneck = nn.Sequential(
            nn.Conv2d(features * 8, features * 8, 4, 2, 1), nn.ReLU()
        )

        self.up1 = Block(features * 8, features * 8, down=False, act="relu", use_dropout=True)
        self.up2 = Block(
            features * 8 * 2, features * 8, down=False, act="relu", use_dropout=True
        )
        self.up3 = Block(
            features * 8 * 2, features * 8, down=False, act="relu", use_dropout=True
        )
        self.up4 = Block(
            features * 8 * 2, features * 8, down=False, act="relu", use_dropout=False
        )
        self.up5 = Block(
            features * 8 * 2, features * 4, down=False, act="relu", use_dropout=False
        )
        self.up6 = Block(
            features * 4 * 2, features * 2, down=False, act="relu", use_dropout=False
        )
        self.up7 = Block(features * 2 * 2, features, down=False, act="relu", use_dropout=False)
        self.final_up = nn.Sequential(
            nn.ConvTranspose2d(features * 2, in_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):#(1,3,256,256)
        d1 = self.initial_down(x)#(1,64,128,128)
        d2 = self.down1(d1)#(1,128,64,64)
        d3 = self.down2(d2)#(1,256,32,32)
        d4 = self.down3(d3)#(1,512,16,16)
        d5 = self.down4(d4)#(1,512,8,8)
        d6 = self.down5(d5)#(1,512,4,4)
        d7 = self.down6(d6)#(1,512,2,2)
        bottleneck = self.bottleneck(d7)#(1,512,1,1)
        up1 = self.up1(bottleneck)#(1,512,2,2)
        up2 = self.up2(torch.cat([up1, d7], 1))#(1,512,4,4)
        up3 = self.up3(torch.cat([up2, d6], 1))#(1,512,8,8)
        up4 = self.up4(torch.cat([up3, d5], 1))#(1,512,16,16)
        up5 = self.up5(torch.cat([up4, d4], 1))#(1,256,32,32)
        up6 = self.up6(torch.cat([up5, d3], 1))#(1,128,64,64)
        up7 = self.up7(torch.cat([up6, d2], 1))#(1,64,128,128)
        return self.final_up(torch.cat([up7, d1], 1))#(1,3,256,256)
def test():
    x = torch.randn((1, 3, 256, 256))
    model = Generator(in_channels=3, features=64)
    preds = model(x)
    print(preds.shape)
if __name__ == "__main__":
    test()

Here a random tensor with the same shape as a real training image is generated to verify that the network runs.
The first convolution (initial_down) differs from the ordinary blocks: kernel size 4, stride 2, reflect padding, no BatchNorm, followed by LeakyReLU. The input of shape (1, 3, 256, 256) becomes (1, 64, 128, 128) after it.
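As a quick check, the output size of such a convolution follows the usual formula out = floor((H + 2*padding - kernel) / stride) + 1; a tiny sketch, not part of the original code:

H, kernel, stride, padding = 256, 4, 2, 1
out = (H + 2 * padding - kernel) // stride + 1
print(out)  # 128: every stride-2 layer halves the spatial size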
Then come six Block layers (down1 to down6), which form the encoder's successive downsampling.
Looking at one of them is enough; the others have the same structure.
Inside Block, three options are specified: the direction (down), the activation (act), and use_dropout. When down=True, a stride-2 convolution downsamples; otherwise a stride-2 transposed convolution upsamples. Either way the convolution is followed by BatchNorm and the chosen activation. Dropout is not used in the encoder.
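To see the effect of the down flag, here is a small sketch (assuming the Block class and the torch import above) that runs one downsampling block and one upsampling block and prints the shapes:

feat = torch.randn(1, 128, 64, 64)
down_block = Block(128, 256, down=True, act="leaky", use_dropout=False)
up_block = Block(256, 128, down=False, act="relu", use_dropout=True)
print(down_block(feat).shape)            # torch.Size([1, 256, 32, 32]) - spatial size halved
print(up_block(down_block(feat)).shape)  # torch.Size([1, 128, 64, 64]) - back to the original size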
Between the encoder and decoder sits the bottleneck, which is a single convolution plus ReLU.
Note the channel progression of the encoder, which differs from a ResNet-style backbone: 64 → 128 → 256 → 512, after which the channel count stays at 512 for the deeper layers.
In the decoder, each feature map is first upsampled and then concatenated with the corresponding encoder layer; without the upsampling the spatial sizes would not match and the concat would fail.
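This is also why the input channel counts of up2 through up7 are doubled (features * 8 * 2 and so on): the decoder output and the matching encoder feature map are stacked along the channel dimension before the next block. A minimal sketch of that concat:

up = torch.randn(1, 512, 2, 2)    # output of up1
skip = torch.randn(1, 512, 2, 2)  # d7 from the encoder, same spatial size
merged = torch.cat([up, skip], dim=1)
print(merged.shape)  # torch.Size([1, 1024, 2, 2]) -> fed into up2, whose in_channels is 512 * 2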
Setting down=False selects the transposed convolution, the activation is ReLU, and the first three decoder layers use dropout; these are the differences from the encoder.
Finally, one more transposed convolution followed by Tanh brings the output back to the size of the original image, with values in [-1, 1].
The code above implements the U-Net generator described in the paper.
Next is the discriminator. The original paper uses a PatchGAN, which in code is again a stack of convolutions.

import torch
import torch.nn as nn
class CNNBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride):
        super(CNNBlock, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels, out_channels, 4, stride, 1, bias=False, padding_mode="reflect"
            ),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        return self.conv(x)
class Discriminator(nn.Module):
    def __init__(self, in_channels=3, features=[64, 128, 256, 512]):
        super().__init__()
        self.initial = nn.Sequential(
            nn.Conv2d(
                in_channels * 2,
                features[0],
                kernel_size=4,
                stride=2,
                padding=1,
                padding_mode="reflect",
            ),
            nn.LeakyReLU(0.2),
        )
        layers = []
        in_channels = features[0]#64
        for feature in features[1:]:
            layers.append(
                CNNBlock(in_channels, feature, stride=1 if feature == features[-1] else 2),
            )
            in_channels = feature

        layers.append(
            nn.Conv2d(
                in_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect"
            ),
        )
        self.model = nn.Sequential(*layers)

    def forward(self, x, y):
        x = torch.cat([x, y], dim=1)#(1,6,256,256)
        x = self.initial(x)#(1,64,128,128)
        x = self.model(x)
        return x


def test():
    x = torch.randn((1, 3, 256, 256))
    y = torch.randn((1, 3, 256, 256))
    model = Discriminator(in_channels=3)
    preds = model(x, y)#(1,1,30,30)
    print(model)
    print(preds.shape)
if __name__ == "__main__":
    test()

The discriminator D takes two inputs because the model is essentially a CGAN: one input is the (real or generated) image y, and the other is the condition x.
It works in three steps: first concatenate the condition and the image along the channel dimension, then expand the channels with an initial convolution, and finally pass the result through the remaining convolutional layers.
1: Concatenate x and y along the channel dimension, giving 6 channels.
2: The concatenated channels are expanded to 64 by the initial convolution with stride 2.
3: Iterating over features builds the remaining convolutions in layers: three CNNBlocks of the form Conv + BN + LeakyReLU, followed by a final convolution that reduces the channels to 1. The output spatial size is 30x30.
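The 30x30 can be checked with the same size formula applied layer by layer (kernel 4, padding 1 everywhere); a quick sketch, not part of the original code:

size = 256
for stride in [2, 2, 2, 1, 1]:  # initial conv, 64->128, 128->256, 256->512, final 1-channel conv
    size = (size + 2 * 1 - 4) // stride + 1
    print(size)
# 128, 64, 32, 31, 30

Each of the 30x30 outputs therefore scores one overlapping patch of the input rather than the whole image, which is the 70x70 PatchGAN idea from the paper.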
Printing the model shows the Sequential built from layers (the initial convolution is defined separately):

Sequential(
  (0): CNNBlock(
    (conv): Sequential(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False, padding_mode=reflect)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
  )
  (1): CNNBlock(
    (conv): Sequential(
      (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False, padding_mode=reflect)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
  )
  (2): CNNBlock(
    (conv): Sequential(
      (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1), bias=False, padding_mode=reflect)
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
  )
  (3): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1), padding_mode=reflect)
)

Then look at the training script:

import torch
from utils import save_checkpoint, load_checkpoint, save_some_examples
import torch.nn as nn
import torch.optim as optim
import config
from dataset import MapDataset
from generator_model import Generator
from discriminator_model import Discriminator
from torch.utils.data import DataLoader
from tqdm import tqdm
from torchvision.utils import save_image

torch.backends.cudnn.benchmark = True
def train_fn(
    disc, gen, loader, opt_disc, opt_gen, l1_loss, bce, g_scaler, d_scaler,
):
    loop = tqdm(loader, leave=True)

    for idx, (x, y) in enumerate(loop):
        x = x.to(config.DEVICE)
        y = y.to(config.DEVICE)

        # Train Discriminator
        with torch.cuda.amp.autocast():
            y_fake = gen(x)
            D_real = disc(x, y)
            D_real_loss = bce(D_real, torch.ones_like(D_real))
            D_fake = disc(x, y_fake.detach())
            D_fake_loss = bce(D_fake, torch.zeros_like(D_fake))
            D_loss = (D_real_loss + D_fake_loss) / 2

        disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # Train generator
        with torch.cuda.amp.autocast():
            D_fake = disc(x, y_fake)
            G_fake_loss = bce(D_fake, torch.ones_like(D_fake))
            L1 = l1_loss(y_fake, y) * config.L1_LAMBDA
            G_loss = G_fake_loss + L1

        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        if idx % 10 == 0:
            loop.set_postfix(
                D_real=torch.sigmoid(D_real).mean().item(),
                D_fake=torch.sigmoid(D_fake).mean().item(),
            )


def main():
    disc = Discriminator(in_channels=3).to(config.DEVICE)
    gen = Generator(in_channels=3, features=64).to(config.DEVICE)
    opt_disc = optim.Adam(disc.parameters(), lr=config.LEARNING_RATE, betas=(0.5, 0.999),)
    opt_gen = optim.Adam(gen.parameters(), lr=config.LEARNING_RATE, betas=(0.5, 0.999))
    BCE = nn.BCEWithLogitsLoss()
    L1_LOSS = nn.L1Loss()

    if config.LOAD_MODEL:
        load_checkpoint(
            config.CHECKPOINT_GEN, gen, opt_gen, config.LEARNING_RATE,
        )
        load_checkpoint(
            config.CHECKPOINT_DISC, disc, opt_disc, config.LEARNING_RATE,
        )

    train_dataset = MapDataset(root_dir=config.TRAIN_DIR)
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.BATCH_SIZE,
        shuffle=True,
        num_workers=config.NUM_WORKERS,
    )
    g_scaler = torch.cuda.amp.GradScaler()
    d_scaler = torch.cuda.amp.GradScaler()
    val_dataset = MapDataset(root_dir=config.VAL_DIR)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

    for epoch in range(config.NUM_EPOCHS):
        train_fn(
            disc, gen, train_loader, opt_disc, opt_gen, L1_LOSS, BCE, g_scaler, d_scaler,
        )

        if config.SAVE_MODEL and epoch % 5 == 0:
            save_checkpoint(gen, opt_gen, filename=config.CHECKPOINT_GEN)
            save_checkpoint(disc, opt_disc, filename=config.CHECKPOINT_DISC)

        save_some_examples(gen, val_loader, epoch, folder="/home/Projects/ZQB/a/PyTorch-GAN-master/implementations/pix2pix-pytorch/results")


if __name__ == "__main__":
    main()

1: Instantiate the discriminator and the generator, and set up the optimizers and the loss functions (BCEWithLogitsLoss and L1Loss).
2: Optionally load pre-trained weights when config.LOAD_MODEL is set.
Define the dataset: the task here is colorizing sketches.
Dataset structure: the images under the train folder are loaded from the path set in config (TRAIN_DIR).
Turning to dataset.py, the main thing to look at is how the data is loaded and processed in __getitem__.

import numpy as np
import config
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import save_image


class MapDataset(Dataset):
    def __init__(self, root_dir):
        self.root_dir = root_dir
        self.list_files = os.listdir(self.root_dir)

    def __len__(self):
        return len(self.list_files)

    def __getitem__(self, index):
        img_file = self.list_files[index]
        img_path = os.path.join(self.root_dir, img_file)
        image = np.array(Image.open(img_path))
        # debug lines left over from development; safe to remove
        # t = np.unique(image)
        # print(t)
        input_image = image[:, :600, :]#(512,600,3)
        target_image = image[:, 600:, :]#(512,424,3)

        augmentations = config.both_transform(image=input_image, image0=target_image)
        input_image = augmentations["image"]#(256,256,3)
        target_image = augmentations["image0"]#(256,256,3)

        input_image = config.transform_only_input(image=input_image)["image"]#(3,256,256)
        target_image = config.transform_only_mask(image=target_image)["image"]#(3,256,256)

        return input_image, target_image


if __name__ == "__main__":
    dataset = MapDataset("data/train/")
    loader = DataLoader(dataset, batch_size=5)
    for x, y in loader:
        print(x.shape)
        save_image(x, "x.png")
        save_image(y, "y.png")
        import sys

        sys.exit()

First, the image file corresponding to the index is located and read in.
Then the picture is split into the input half and the target half, because each source image stores the two side by side (input on the left, target on the right).
Then both halves are resized to 256x256 by both_transform, and each one is transformed further on its own.
In other words, each item returned by MapDataset is an (input, target) pair.
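config.py is not listed in the post, but from the calls above (both_transform with an extra image0 target, transform_only_input, transform_only_mask) it is presumably built with albumentations. A plausible sketch; the exact augmentations are an assumption:

import albumentations as A
from albumentations.pytorch import ToTensorV2

# resize both halves together so input and target stay aligned
both_transform = A.Compose(
    [A.Resize(width=256, height=256)], additional_targets={"image0": "image"},
)

# light augmentation + normalization for the input sketch only
transform_only_input = A.Compose(
    [
        A.ColorJitter(p=0.2),
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0),
        ToTensorV2(),
    ]
)

# normalization only for the target image
transform_only_mask = A.Compose(
    [
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0),
        ToTensorV2(),
    ]
)

Normalizing to [-1, 1] here matches the Tanh at the end of the generator.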
Back in train.py, the images are loaded through train_loader for training.
Similarly, the images in the val folder are loaded through val_loader for validation.
Then the actual training begins: the models, data loader, optimizers, loss functions, and gradient scalers are passed to train_fn.
Inside train_fn, a tqdm progress bar wraps the loader, and both the input and the target are moved to the GPU (config.DEVICE). The forward passes run inside torch.cuda.amp.autocast() and the losses are scaled by the GradScalers, so training uses mixed precision.
Train the discriminator:
The real pair (x, y) is fed into the discriminator and the desired output is 1 (real); the fake y produced by the generator, together with the real x as the condition, is fed in and the desired output is 0 (fake). The two BCE losses are averaged.
Train the generator: the real x and the fake y are fed into D, and we want D to be fooled, i.e. to output 1. There is also an L1 loss between the generated image and the real target, weighted by config.L1_LAMBDA. The two losses are added together as the generator loss.
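This matches the objective in the original paper, where the conditional GAN loss is combined with an L1 term weighted by λ (config.L1_LAMBDA; the paper uses λ = 100):

G* = arg min_G max_D L_cGAN(G, D) + λ · L_L1(G),  with  L_L1(G) = E_x,y[ ||y − G(x)||_1 ]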
Next, the weights of D and G are saved every few epochs with save_checkpoint.
Then sample pictures from the val set are saved with save_some_examples.
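save_checkpoint, load_checkpoint and save_some_examples come from utils, which is not listed in the post. A minimal sketch of what they might look like, based only on how they are called above (the bodies and file naming are assumptions):

import torch
from torchvision.utils import save_image
import config


def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    # store both the model weights and the optimizer state
    torch.save({"state_dict": model.state_dict(), "optimizer": optimizer.state_dict()}, filename)


def load_checkpoint(checkpoint_file, model, optimizer, lr):
    checkpoint = torch.load(checkpoint_file, map_location=config.DEVICE)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    # restore the configured learning rate instead of the one stored in the checkpoint
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr


def save_some_examples(gen, val_loader, epoch, folder):
    # run the generator on one validation batch and write the images to disk
    x, y = next(iter(val_loader))
    x, y = x.to(config.DEVICE), y.to(config.DEVICE)
    gen.eval()
    with torch.no_grad():
        y_fake = gen(x) * 0.5 + 0.5  # undo the [-1, 1] normalization
        save_image(y_fake, folder + f"/y_gen_{epoch}.png")
        save_image(x * 0.5 + 0.5, folder + f"/input_{epoch}.png")
    gen.train()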

Origin blog.csdn.net/qq_43733107/article/details/130711045