Generative confrontation network CycleGAN

1 Introduction

论文:Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks

Paper address: https://arxiv.org/abs/1703.10593

What is CycleGAN: CycleGAN is mainly used for the conversion between images, assuming there are two unpaired images X and Y, the algorithm is trained to learn an "automatic mutual conversion", no paired paired samples are required for training, only the source domain and images of the target domain. After training, the network can realize the migration of the image source domain to the target domain. CycleGAN is suitable for unpaired image-to-image translation, which solves the difficulty that the model needs paired data for training.

The difference with pix2pixGAN: Both can do image transformation, the pix2pix model must require paired data (paired data), and CycleGAN can also use unpaired data for training (unpaired data).

 2. Cycle-GAN network architecture

Related work:

BYDCGANCGANpix2pixGAN

CycleGAN is actually a one-way GAN of A→B plus a one-way GAN of B→A. The two GANs share two generators, and each has a discriminator, so there are two discriminators and two generators in total. A one-way GAN has two losses, so CycleGAN adds up to four losses in total.

Cyclic consistency loss: Because the network needs to ensure that the generated image must retain the characteristics of the original image, so if we use the generator GenratorA-B to generate a fake image, then we must be able to use another generator GenratorB-A to try to restore it to the original image. This process must satisfy cycle consistency.

Identity loss: It can be understood that the generator is responsible for image generation from domain x to domain y. If the image of domain y is input, the image of domain y should still be generated.

# 用狗的图像生成猫的图像
import itertools
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from torch.utils.data.dataset import Dataset
from PIL import Image
import tqdm
import glob

dogs_path = glob.glob('D:\cnn\All Classfication\AlexNet\data/train\Dog/*.jpg') #获取数据集中的.jpg图片
cats_path = glob.glob('D:\cnn\All Classfication\AlexNet\data/train\Cat/*.jpg') #获取数据集中的.jpg图片
# print(cats_path[:3])
# print(dogs_path[:3])
cats_path_test = glob.glob('D:\cnn\All Classfication\AlexNet\data/val\Cat/*.jpg') #获取数据集中的.jpg图片
dogs_path_test = glob.glob('D:\cnn\All Classfication\AlexNet\data/val\Dog/*.jpg') #获取数据集中的.jpg图片

transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Resize((256, 256)),
                                transforms.Normalize(mean=0.5, std=0.5)]) #Normalize为转化到-1~1之间

# 定义数据读取
class SGANDataset(Dataset):
    def __init__(self, imgs_path): #初始化
        super(SGANDataset, self).__init__()
        self.imgs_path     = imgs_path #定义属性

    def __len__(self):
        return len(self.imgs_path)

    def __getitem__(self, index): #对数据切片
        img_path        = self.imgs_path[index]

        # 从文件中读取图像
        pil_img         = Image.open(img_path)
        pil_img         = transform(pil_img)
        return pil_img

# 初始化训练集
dog_dataset = SGANDataset(dogs_path) #创建dataset
cat_dataset = SGANDataset(cats_path) #创建dataset

# 初始化测试集
dog_dataset_test = SGANDataset(dogs_path_test) #创建dataset
cat_dataset_test = SGANDataset(cats_path_test) #创建dataset

dog_dataloader = torch.utils.data.DataLoader(dog_dataset, batch_size=4, shuffle=True)
cat_dataloader = torch.utils.data.DataLoader(cat_dataset, batch_size=4, shuffle=True)

dog_dataloader_test = torch.utils.data.DataLoader(dog_dataset_test, batch_size=4)
cat_dataloader_test = torch.utils.data.DataLoader(cat_dataset_test, batch_size=4)

# cat_bath = next(iter(cat_dataloader)) #查看
# dog_bath = next(iter(dog_dataloader)) #查看
# print(dog_bath.shape) #torch.Size([4, 3, 256, 256])
# print(cat_bath.shape) #torch.Size([4, 3, 256, 256])

# 查看数据集
# plt.figure(figsize=(8, 12))
# for i, (dog, cat) in enumerate(zip(dog_bath[:3], cat_bath[:3])): #zip代表元组
#     # 因为dataset返回的数据是tensor,需要转为numpy格式,因为Normalize为转化到-1~1之间,所以加1再除以2将其转化到0~1之间
#     dog = (dog.permute(1, 2, 0).numpy() + 1) / 2
#     cat = (cat.permute(1, 2, 0).numpy() + 1) / 2
#     plt.subplot(3, 2, 2*i+1)
#     plt.title('dog')
#     plt.imshow(dog)
#     plt.subplot(3, 2, 2*i+2)
#     plt.title('cat')
#     plt.imshow(cat)
# plt.show()


#定义下采样模块
class Downsample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Downsample, self).__init__()
        self.conv_relu = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(inplace=True)
        )
        self.bn = nn.InstanceNorm2d(out_channels)

    def forward(self, x, is_bn=True): #is_bn用于确定是否使用bn层,默认为True
        x = self.conv_relu(x)
        if is_bn:
            x = self.bn(x)
        return x

#定义上采样模块
class Upsample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Upsample, self).__init__()
        self.upconv_relu = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.LeakyReLU(inplace=True)
        )
        self.bn = nn.InstanceNorm2d(out_channels)

    def forward(self, x, is_drop=False): #is_drop用于确定是否使用drop层,默认为False
        x = self.upconv_relu(x)
        x = self.bn(x)
        if is_drop:
            x = F.dropout2d(x)
        return x

# 定义生成器,包含6个下采样层,6个上采样层
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.down1 = Downsample(3, 64)     #3,256,256 -- 64,128,128
        self.down2 = Downsample(64, 128)   #64,128,128 -- 128,64,64
        self.down3 = Downsample(128, 256)  #128,64,64 -- 256,32,32
        self.down4 = Downsample(256, 512)  #256,32,32 -- 512,16,16
        self.down5 = Downsample(512, 512)  #512,16,16 -- 512,8,8
        self.down6 = Downsample(512, 512)  #512,8,8 -- 512,4,4

        self.up1 = Upsample(512, 512)      #512,4,4 -- 512,8,8
        self.up2 = Upsample(1024, 512)     #1024,8,8 -- 512,16,16
        self.up3 = Upsample(1024, 256)     #1024,16,16 -- 256,32,32
        self.up4 = Upsample(512, 128)      #512,32,32 -- 128,64,64
        self.up5 = Upsample(256, 64)       #256,64,64 -- 64,128,128
        #128,128,128 -- 3,256,256
        self.last = nn.ConvTranspose2d(128, 3, kernel_size=3, stride=2, padding=1, output_padding=1)

    def forward(self, x):
        x1 = self.down1(x)
        x2 = self.down2(x1)
        x3 = self.down3(x2)
        x4 = self.down4(x3)
        x5 = self.down5(x4)
        x6 = self.down6(x5)

        x6 = self.up1(x6, is_drop=True)
        x6 = torch.cat([x6, x5], dim=1)

        x6 = self.up2(x6, is_drop=True)
        x6 = torch.cat([x6, x4], dim=1)

        x6 = self.up3(x6, is_drop=True)
        x6 = torch.cat([x6, x3], dim=1)

        x6 = self.up4(x6)
        x6 = torch.cat([x6, x2], dim=1)

        x6 = self.up5(x6)
        x6 = torch.cat([x6, x1], dim=1)

        x6 = torch.tanh(self.last(x6))

        return x6

# 定义判别器
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.down1 = Downsample(3, 64)
        self.down2 = Downsample(64, 128)
        self.last = nn.Conv2d(128, 1, 3)

    def forward(self, img):
        x = self.down1(img)
        x = self.down2(x)
        x =torch.sigmoid(self.last(x))
        return x

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 初始化两个生成器
gen_AB = Generator().to(device)
gen_BA = Generator().to(device)

# 初始化两个判别器
dis_A = Discriminator().to(device)
dis_B = Discriminator().to(device)

# 损失函数  1.gan loss  2.cycle consistance loss  3.identity loss
bce_loss = torch.nn.BCELoss()
l1_loss = torch.nn.L1Loss()

# 初始化优化器
# 对两个生成器同时进行优化, 使用itertools.chain对二者同时进行迭代
gen_optimizer = torch.optim.Adam(itertools.chain(gen_AB.parameters(), gen_BA.parameters()), lr=2e-4, betas=(0.5, 0.999))

# 对两个判别器分别进行优化
dis_A_optimizer = torch.optim.Adam(dis_A.parameters(), lr=2e-4, betas=(0.5, 0.999))
dis_B_optimizer = torch.optim.Adam(dis_B.parameters(), lr=2e-4, betas=(0.5, 0.999))

# 绘图函数,将每一个epoch中生成器生成的图片绘制
def gen_img_plot(model, epoch, test_input): # model为gen_AB/gen_BA,test_input
    generate = model(test_input).permute(0, 2, 3, 1).cpu().numpy() #将通道维度放在最后
    test_input = test_input.permute(0, 2, 3, 1).cpu().numpy() #1,3,256,256 -- 1,256,256,3
    plt.figure(figsize=(10, 6))
    display_list = [test_input[0], generate[0]]
    title = ['Input image', 'Generate image']
    for i in range(2):
        plt.subplot(1, 2, i + 1)
        plt.title(title[i])
        plt.imshow((display_list[i]+1)/2) #从-1~1 --> 0~1
        plt.axis('off')
    plt.savefig('./image/image_at_{}.png'.format(epoch))

test_batch = next(iter(dog_dataloader_test)) #batch_size,3,256,256
# 测试输入:选取test_batch中的第一张图片,并添加一个batch_size维度  3,256,256--1,3,256,256
test_input = torch.unsqueeze(test_batch[0], 0).to(device)

# cycleGAN训练
D_loss = []
G_loss = []
epochs = 50
for epoch in range(epochs):
    d_epoch_loss = 0
    g_epoch_loss = 0
    for step, (real_A, real_B) in enumerate(zip(dog_dataloader, cat_dataloader)): #取出真实的狗,猫图片
        real_A = real_A.to(device)
        real_B = real_B.to(device)
        #--------------------begin--------------------#
        # 生成器训练
        gen_optimizer.zero_grad() #训练之前梯度清0
        # identity loss
        same_B = gen_AB(real_B) #真实的B经过生成器gen_AB还是要得到真实的B
        identity_B_loss = l1_loss(same_B, real_B)
        same_A = gen_AB(real_A) #真实的A经过生成器gen_BA还是要得到真实的A
        identity_A_loss = l1_loss(same_A, real_A)
        # 对抗损失 gan loss
        fake_B = gen_AB(real_A) #真实A通过生成器生成了B,此时生成器希望判别器将其判别为真
        D_pred_fake_B = dis_B(fake_B)
        gen_loss_AB = bce_loss(D_pred_fake_B, torch.ones_like(D_pred_fake_B, device=device))
        fake_A = gen_BA(real_B) #真实B通过生成器生成了A,此时生成器希望判别器将其判别为真
        D_pred_fake_A = dis_A(fake_A)
        gen_loss_BA = bce_loss(D_pred_fake_A, torch.ones_like(D_pred_fake_A, device=device))
        # 循环一致损失
        recovered_A = gen_BA(fake_B)
        cycle_loss_ABA = l1_loss(recovered_A, real_A)

        recovered_B = gen_AB(fake_A)
        cycle_loss_BAB = l1_loss(recovered_B, real_B)

        # 生成器总的损失
        g_loss = identity_A_loss + identity_B_loss + gen_loss_AB + gen_loss_BA +cycle_loss_ABA + cycle_loss_BAB

        g_loss.backward()
        gen_optimizer.step()
        # --------------------end--------------------#

        # --------------------begin--------------------#
        # 判别器训练
        # dis_A训练
        dis_A_optimizer.zero_grad()
        dis_A_real_output = dis_A(real_A) #输入为真,期望判定为真
        dis_A_real_loss = bce_loss(dis_A_real_output, torch.ones_like(dis_A_real_output, device=device))

        dis_A_fake_output = dis_A(fake_A.detach())  #输入为假,期望判定为假,梯度截断
        dis_A_fake_loss = bce_loss(dis_A_fake_output, torch.zeros_like(dis_A_fake_output, device=device))

        dis_A_loss = dis_A_real_loss + dis_A_fake_loss #生成器A的总损失
        dis_A_loss.backward()
        dis_A_optimizer.step()

        # dis_B训练
        dis_B_optimizer.zero_grad()
        dis_B_real_output = dis_B(real_B)  #输入为真,期望判定为真
        dis_B_real_loss = bce_loss(dis_B_real_output, torch.ones_like(dis_B_real_output, device=device))

        dis_B_fake_output = dis_B(fake_B.detach())  #输入为假,期望判定为假,梯度截断
        dis_B_fake_loss = bce_loss(dis_B_fake_output, torch.zeros_like(dis_B_fake_output, device=device))

        dis_B_loss = dis_B_real_loss + dis_B_fake_loss #生成器B的总损失
        dis_B_loss.backward()
        dis_B_optimizer.step()
        # --------------------end--------------------#

        with torch.no_grad():
            g_epoch_loss += g_loss.item() #将每一个批次的loss累加
            d_epoch_loss += (dis_A_loss + dis_B_loss).item()  # 将每一个批次的loss累加

    with torch.no_grad():
        g_epoch_loss /= (step + 1) #求得每一轮的平均loss
        d_epoch_loss /= (step + 1) #求得每一轮的平均loss
        D_loss.append(d_epoch_loss)
        G_loss.append(g_epoch_loss)
        print('epoch:', epoch, 'g_epoch_loss:', g_epoch_loss, 'd_epoch_loss:', d_epoch_loss)
        gen_img_plot(gen_AB, epoch, test_input)

Guess you like

Origin blog.csdn.net/m0_56247038/article/details/130395582