CycleGan 画风迁移初探

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/sinat_30665603/article/details/80843461

论文题目:Unpaired Image-to-Image Translationusing Cycle-Consistent Adversarial Networks

论文链接:https://arxiv.org/abs/1703.10593

源起:

    在“使用图片描述方式解决翻译问题”中,遇到了对称结构“翻译”模型,其基本思想是

    利用非配对样本特征的相互转化来解决传统特征转化问题(包括翻译及本文将要提到的画风迁移)配对样本不足的问题(一些细节参见前文)。上一篇特征转化过程涉及的是两类特征(文本、图片)间的特征转化,较为精妙。其VAE的思想可以与本文GAN的思想进行比较。

    对于原本paired的方式大多是配对监督的基本思路,只不过是“配对”监督的生成。

    (数据配对“对齐”)

    unpaired sample,对于这种问题,一种思路就是利用cycle对抗(交互)的方式,

    GAN或者VAE的基本形式。这时的基本结构就会变成非监督编码解码器结构或者与之对应的GAN结构。“配对”非监督。

    (算法配对“对齐”)

      

问题提出:

       在现实生活中,有很多不同画质风格的游戏。比如辐射4,对于一些辐射老炮,更倾向于辐射新维加斯的感觉,所以想在“联邦”中“模拟”沙漠地形。

                    

所以就涉及到了画风迁移技术。这里我们以把橙子变成苹果为例子进行这种模型的说明。

模型:

       在VAE中中间隐状态的概率分布是整体分布生成能力的组成部分,其使用编码、解码之间的loss给与这种生成一种“有向”的引导,一定程度上来说隐状态就是一种“变化”的“新息”。站在隐状态的角度,从“引导”的模型化方式来讲,VAE对应回归;与之相对的,GAN对应分类。对于MNIST二者有下面生成过程图的区别:

       先看VAE:


                    

        再看GAN:

                                    

    VAE先具备了“型“(回归的有向指导就决定了其学习的方向),之后”填补“了”真“,是一个”清晰化“的过程,通过完善隐状态统计表达能力加强特征,是一种清晰化引导的结果。

    “初始点“相当于(0-9)特征的叠加,不断使得隐状态学到(0-9)的统计差异(抽象出类别的变化),是一种分离特征的思想,可以将其类比于均方对于方差(隐状态)及偏差(编码解码loss)的有统计意义的分离。

    GAN不具备“型“,但直接向其统计意义精确的分布进行收敛。以概率论中的特征函数为例,要使得一个分布去逼近另一个分布可以用对应特征函数的逼近来表达这个过程。这就是一种从抽象特征的角度直接逼近的统计方法。GAN的情形就非常类似于这种方法,其初始分布并不具备目标分布的”型“却可以去逼近。这也是GAN在一些有关距离的定义及变体中与泛函分析产生关系的原因,因为它的观点就是统计特征的函数逼近。(感兴趣可以参看Wasserstein GAN)。

下面简单地提一下GAN:

       我们直接看目标函数:

                    

        把D (discriminator) 及G (generator) 看成两个神经网络,对于G,类比VAE给定隐状态z,这个网络负责生成想要的统计特征。对于D,其负责对于统计特征(包含真实的统计特征及生成的统计特征)进行区分。目标函数的意思就是使得G生成的统计特征在于真实统计特征的对比下难以被D进行区分。(以DCGAN为例,G基本上是一个反卷积结构(解码结构),D基本上是一个卷积分类结构)

        现在看一下用GAN解决画风迁移问题。

        将画风类比为随机变量,GAN所要完成的任务是使得生成的画风是我们想要的画风。而且从需求角度,我们还要求保留原有画风的“边界“结构。如果仅仅从类似DCGAN的解码结构出发,生成有对应画风的”噪声“(不包含原画的边界信息等)十分简单。所以要保留边界信息还要借助编码解码结构。故使用的是条件GAN,见下面公式:

                    

        y为条件随机变量。原生的特征设计合并结构如下(这里原生指一般情况,在本文的模型中并没有使用这种结构):

                    

        下面简单提一下本文用到的条件GAN结构——带dropout的u-net。

u-net示意图如下:

                        

        其为基本的卷积编码解码结构,加在卷积于反卷积对称位置、输入端对输出端的信息传递。这种“信息传递“用于在解码端更好地”复现“编码端图像信息细节。

        把input image对应待迁移样本,output对应迁移后的样本。这种情形就是条件GAN,z为退化分布的对应情形。可以理解为“复现“边界结构,要把迁移的画风信息加入到里面就要改变z。比较直观的思想是把z加入u-net的中间层(简单的tensor concatenate)。但文中使用的是另一种不改变网络结构并加入z的方法,即dropout。利用这种网络结构就可以直接做样本配对情形下的风格迁移,相应的内容可以参见pix2pix。(并不局限于画风迁移问题)

链接:https://arxiv.org/abs/1611.07004

        下面考虑的是把这种网络结构改变到非配对样本情形。这里的结构基本与“使用图片描述方式解决翻译问题”是类似的,下面直接给出模型图示:

                    

        图中的G、F可以看成上述条件GAN情形下的u-net结构(为了方便采用这种结构,具体结构以原文为准),从一种画风迁移到另一种画风的生成变换。然后在D利用对于生成后的图片特征与真实的图片特征进行打分,并得到loss,基本结构可参见DCGAN中D的结构,其中的细节是PatchGAN,其结构特征与要进行区分这一目的是统一的,区分“画风“——”纹理“而非边界,所以PatchGAN可以看成将图片”剪成“若干个小碎片,用同一部分权重区分对应小碎片间的画风区别做平均——CNN共用权重的思想。

        正是由于PatchGAN进行的是纹理上的判别,故要求要在模型结构上对于“边界“做”正则“,以保证迁移过程两端在边界上的一致性。(对应的部分在”使用图片描述方式解决翻译问题“中就是文本特征对于图片特征的均方一致收敛,保证特征转化的一致性,可以实现,agant的特征回传)

        由于画风迁移是一种特征,就可以直接把随机变量的一致性对应到可逆性(l1下)。((b)(c)两图)(在实现上要注意,可逆性的正则如果太强可能导致画风不变,所以可能涉及随着训练的进行在loss层面适当减小影响的调节,在下面的实现中提到了这个问题)

总的目标函数的构建如下图:

                    

                    

                    

下面尝试给出实现:(u-net部分借用junyanz https://github.com/junyanz

的实现,下面会给出对应链接)

       数据集链接:

https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/apple2orange.zip

       数据导出:

import torch
from torchvision import transforms
from PIL import Image
import os
from collections import defaultdict
from itertools import cycle

def walk_over_dir(rootDir):
    req_dict = defaultdict(list)
    def Test(rootDir):
        for root, dirs, files in os.walk(rootDir):
            root_key = root.split("\\")[-1].strip()
            for filespath in files:
                full_file_path = os.path.join(root,filespath)
                req_dict[root_key].append(full_file_path)
        return dict(req_dict.items())
    return Test(rootDir)

def global_transform_part(resize = 256):
    toTensor = transforms.ToTensor()
    r = transforms.Resize(size = resize)
    return transforms.Compose([r ,toTensor])

gp = global_transform_part()
# input img file
def transform_img_to_np_array(img_f, img_transform = gp, is_cuda = True):
    image = Image.open(img_f)
    image = image.convert("RGB")
    image = torch.tensor(img_transform(image), requires_grad = False)
    if is_cuda:
        image = image.cuda()
    return image

data_path_format = r"C:\tempCodingUsage\python\Pix2Pix\data\apple2orange\{}{}"

def data_loader(type = "train" ,batch_size = 1):
    assert type in ["train", "test"]
    apple_path = data_path_format.format(type, "A")
    orange_path = data_path_format.format(type, "B")

    apple_img_paths_dict = walk_over_dir(apple_path)
    orange_img_paths_dict = walk_over_dir(orange_path)

    def path_generator(img_dict):
        for k, v in img_dict.items():
            for apple_path in cycle(v):
                yield apple_path

    apple_generator = path_generator(apple_img_paths_dict)
    orange_generator = path_generator(orange_img_paths_dict)

    apple_list, orange_list = [], []
    while True:
        apple_list.append(apple_generator.__next__())
        orange_list.append(orange_generator.__next__())

        assert len(apple_list) == len(orange_list)
        if len(apple_list) == batch_size:
            # [batch, 3, 128, 128]
            apple_batch = torch.cat(list(map(lambda img_tensor: img_tensor.unsqueeze(0) ,map(transform_img_to_np_array, apple_list))),
                                    0)
            orange_batch = torch.cat(list(map(lambda img_tensor: img_tensor.unsqueeze(0) ,map(transform_img_to_np_array, orange_list))),
                                     0)

            input = {
                "apple": apple_batch,
                "orange": orange_batch
            }

            yield input
            apple_list, orange_list = [], []


if __name__ == "__main__":
    pass

模型实现:

import torch
import torch.nn as nn
import torch.nn.functional as F
from data_preprocess.preprocess_256 import data_loader
from torch import optim
import pause
import numpy as np

# use code from u-net implementation by
# junyanz https://github.com/junyanz
# download https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
from networks import define_G

# 128 64 32 16 8 4 2 2 4 8 16 32 64 128
# 3 64 128 256 512 1024 1024 512 256 128 64 3
params = [3, 3, 64, "unet_256", "instance", True, "normal",
          [0]]

ato_g = define_G(*params)
ato_g.bce_loss = nn.BCELoss()
ota_g = define_G(*params)
ota_g.bce_loss = nn.BCELoss()

class Schedule(nn.Module):
    def __init__(self, lambda_val = 10.0, N = 8,
                 ):
        super(Schedule, self).__init__()
        self.aa_loss = nn.L1Loss().cuda()
        self.oo_loss = nn.L1Loss().cuda()
        self.lambda_val = lambda_val

        self.identity_loss = nn.MSELoss()

        # generator params
        generator_channels = [3, 64, 128, 256, 512, 1024, 512, 256, 128, 64,
                              3]

        # discrimination part
        self.discriminator_seq_0 = \
            nn.Sequential(
                nn.Conv2d(generator_channels[0], generator_channels[1],
                          kernel_size = N, stride=N),
                nn.BatchNorm2d(num_features=generator_channels[1]),
                nn.LeakyReLU(),

                nn.Conv2d(generator_channels[1], generator_channels[2],
                          kernel_size=N, padding=1, stride=N),
                nn.BatchNorm2d(num_features=generator_channels[2]),
                nn.LeakyReLU(),
            ).cuda()

        flatten_conv_num = 2048
        self.final_linear_0 = nn.Linear(flatten_conv_num, 1).cuda()
        self.bce_loss_0 = nn.BCELoss().cuda()

        self.discriminator_seq_1 = \
            nn.Sequential(
                nn.Conv2d(generator_channels[0], generator_channels[1],
                          kernel_size = N, stride=N),
                nn.BatchNorm2d(num_features=generator_channels[1]),
                nn.LeakyReLU(),

                nn.Conv2d(generator_channels[1], generator_channels[2],
                          kernel_size=N, padding=1, stride=N),
                nn.BatchNorm2d(num_features=generator_channels[2]),
                nn.LeakyReLU(),
            ).cuda()

        flatten_conv_num = 2048
        self.final_linear_1 = nn.Linear(flatten_conv_num, 1).cuda()
        self.bce_loss_1 = nn.BCELoss().cuda()

    # orange_input [batch, 3, 128, 128] apple_input [batch, 3, 128, 128]
    def construct_l1_penalize(self, ato_transformation_0, ota_transformation_1,
                              apple_input, orange_input):
        ota_transformation_0 = ota_g(ato_transformation_0)
        ato_transformation_1 = ato_g(ota_transformation_1)

        return self.aa_loss(ota_transformation_0, apple_input) + self.oo_loss(ato_transformation_1, orange_input)

    def forward(self, input, step):
        '''
        input {"orange": [batch, 3, 128, 128],
                "apple": [batch, 3, 128, 128]
        } -> transform
        :param input:
        :return:
        '''
        orange_input = input["orange"]
        apple_input = input["apple"]

        # discrimination loss construct
        ato_transformation_0 = ato_g(apple_input)
        ota_identity_loss = self.identity_loss(ota_g(apple_input), apple_input)

        ota_transformation_1 = ota_g(orange_input)
        ato_identity_loss = self.identity_loss(ato_g(orange_input), orange_input)

        # ato discrimination loss
        ota_get_ato_sig = self.get_forward(ato_transformation_0, self.discriminator_seq_0, self.final_linear_0)

        true_part = torch.ones(ota_get_ato_sig.size()).cuda()
        false_part = torch.zeros(ota_get_ato_sig.size()).cuda()
        ato_gen_loss = ota_g.bce_loss(ota_get_ato_sig, true_part) + ota_g.bce_loss(ota_get_ato_sig, false_part)

        ota_get_o_sig = self.get_forward(orange_input, self.discriminator_seq_0, self.final_linear_0)

        true_part = torch.ones(ota_get_ato_sig.size()).cuda()
        ato_true_loss = ota_g.bce_loss(ota_get_o_sig, true_part)
        ato_loss = ato_gen_loss + ato_true_loss

        # ota discrimination loss
        ato_get_ota_sig = self.get_forward(ota_transformation_1, self.discriminator_seq_1, self.final_linear_1)

        true_part = torch.ones(ato_get_ota_sig.size()).cuda()
        false_part = torch.zeros(ato_get_ota_sig.size()).cuda()
        ota_gen_loss = ato_g.bce_loss(ato_get_ota_sig, true_part) + ota_g.bce_loss(ato_get_ota_sig, false_part)

        ato_get_a_sig = self.get_forward(apple_input, self.discriminator_seq_1, self.final_linear_1)

        true_part = torch.ones(ato_get_a_sig.size()).cuda()

        ota_true_loss = ato_g.bce_loss(ato_get_a_sig, true_part)

        ota_loss = ota_gen_loss + ota_true_loss

        l1_loss = self.construct_l1_penalize(ato_transformation_0, ota_transformation_1,
                                             apple_input, orange_input)

        d_loss = ato_loss + ota_loss
        g_loss = self.lambda_val * l1_loss + (ato_identity_loss + ota_identity_loss) * 1.0

        # adjust lambda_val in training step, prevent overfitting
        def step_adjust(step, ratio = 100):
            return torch.tensor(np.exp(-1 * (step) / ratio)).cuda()
        lambda_mul = step_adjust(step)
        g_loss = lambda_mul * g_loss

        return d_loss, g_loss, lambda_mul

    def get_forward(self, img_input, discriminator_seq, final_linear):
        conv_output = discriminator_seq(img_input)
        batch_size = conv_output.size(0)
        flatten_feature = conv_output.view(batch_size, -1)

        return F.sigmoid(final_linear(flatten_feature))

def train():
    train_generator = data_loader(type="train", batch_size=4)
    schedule = Schedule().cuda()
    schedule.train()

    g_optimizer = optim.Adam(list(ato_g.parameters()) + list(ota_g.parameters()),
                             lr = 0.001)
    d_optimizer = optim.Adam(schedule.parameters(), lr = 0.001)

    step = 0
    d_loss_list = []
    g_loss_list = []

    while True:
        input = train_generator.__next__()
        d_loss, g_loss, lambda_mul = schedule(input, step)

        d_loss_list.append(d_loss.cpu().data.numpy())
        g_loss_list.append(g_loss.cpu().data.numpy())

        print("d_loss :{}".format(np.mean(d_loss_list)) + " g_loss :{}".format(np.mean(g_loss_list)) \
              + " lambda_mul :{}".format(lambda_mul))

        if step % 100 == 0:
            with open(r"C:\tempCodingUsage\python\CycleGan\cycgan_drop.pkl", "wb") as f:
                torch.save({
                    "schedule": schedule,
                    "ato_g": ato_g,
                    "ota_g": ota_g
                }, f)
            d_loss_list = []
            g_loss_list = []
            print("call save" + "-" * 1000)

        for param in schedule.parameters():
            param.requires_grad = False

        g_optimizer.zero_grad()
        g_loss.backward(retain_graph = True)
        g_optimizer.step()

        for param in ato_g.parameters():
            param.requires_grad = False
        for param in ota_g.parameters():
            param.requires_grad = False

        for param in schedule.parameters():
            param.requires_grad = True

        d_optimizer.zero_grad()
        d_loss.backward(retain_graph = True)
        d_optimizer.step()

        for param in ato_g.parameters():
            param.requires_grad = True
        for param in ota_g.parameters():
            param.requires_grad = True

        step += 1

if __name__ == "__main__":
    train()

    networks.py 给出的代码风格,有些类似于在 javascript 中对函数对象使用工厂模式。

    下面给一个某一训练步类似橙子到苹果的转化图片:

                         



















猜你喜欢

转载自blog.csdn.net/sinat_30665603/article/details/80843461
今日推荐