PyTorch 实现图像风格迁移

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/qian1996/article/details/85083438

首先定义两个损失函数：内容损失和风格损失。

内容损失:

class Content_loss(torch.nn.Module):
    """Content-loss layer: weighted MSE between a feature map and a fixed target.

    The module is meant to be inserted into a ``torch.nn.Sequential`` right
    after a layer whose output it should compare. ``forward`` is a
    pass-through: it returns ``input`` unchanged and, as a side effect,
    stores the current loss on ``self.loss``.

    Args:
        weight: scalar controlling how strongly the *content* term influences
            the optimisation; it multiplies both the features and the target.
        target: feature map of the content image at this layer. It is
            detached, so no gradient flows back into whatever produced it.
    """

    def __init__(self, weight, target):
        super().__init__()
        self.weight = weight
        # Freeze the reference features (no gradient) and pre-apply the weight.
        self.target = target.detach() * weight
        # Mean squared error is the distance measure.
        self.loss_fn = torch.nn.MSELoss()

    def forward(self, input):
        """Record the loss for ``input`` and hand ``input`` on unchanged."""
        weighted_features = input * self.weight
        self.loss = self.loss_fn(weighted_features, self.target)
        return input

    def backward(self):
        """Back-propagate the stored loss and return it.

        ``retain_graph=True`` keeps the graph alive because every loss module
        inserted into the same network back-propagates through the same
        forward pass.
        """
        current = self.loss
        current.backward(retain_graph=True)
        return current

风格损失:

# Image style loss.


class Style_loss(torch.nn.Module):
    """Style-loss layer: weighted MSE between Gram matrices.

    Inserted into a ``torch.nn.Sequential`` after a conv layer. ``forward``
    returns its input unchanged and, as a side effect, stores the weighted
    Gram matrix of the current features on ``self.Gram`` and the loss
    against the target on ``self.loss``.

    Args:
        weight: scalar controlling how strongly the *style* term influences
            the optimisation; it multiplies both Gram matrices.
        target: expected to be the Gram matrix of the style image's features
            at this layer. It is detached so no gradient flows into it.
    """

    def __init__(self, weight, target):
        super().__init__()
        self.weight = weight
        # Frozen reference Gram matrix with the weight pre-applied.
        self.target = target.detach() * weight
        # Mean squared error between the two weighted Gram matrices.
        self.loss_fn = torch.nn.MSELoss()
        self.gram = Gram_matrix()

    def forward(self, input):
        """Record the style loss for ``input`` and hand ``input`` on unchanged."""
        # Gram matrix of a copy of the features (presumably so later in-place
        # layers, e.g. ReLU(inplace), cannot interfere), then apply the weight.
        self.Gram = self.gram(input.clone()).mul(self.weight)
        self.loss = self.loss_fn(self.Gram, self.target)
        return input

    def backward(self):
        """Back-propagate the stored loss and return it.

        ``retain_graph=True`` keeps the graph alive because every loss module
        in the network back-propagates through the same forward pass.
        """
        current = self.loss
        current.backward(retain_graph=True)
        return current


# Gram matrix, used by the style loss.
# Inner products between the flattened channel responses of a conv feature map
# summarise the image's "style" (which channels fire together) independently of
# where things are in the image. Style_loss additionally scales this matrix by
# its weight before computing the loss, which amplifies the style term's effect
# on the synthesised image.


class Gram_matrix(torch.nn.Module):
    """Compute the normalised Gram matrix of a 4-D feature map.

    The input of size ``(a, b, c, d)`` is flattened to ``(a*b, c*d)``, which
    gives one row per (batch, channel) pair. The result is ``F @ F.T`` divided
    by ``a*b*c*d``, so its magnitude does not grow with the feature-map size.

    Returns:
        Tensor of shape ``(a*b, a*b)``.
    """

    def forward(self, input):
        a, b, c, d = input.size()
        # reshape (not view) so non-contiguous feature maps are accepted too.
        feature = input.reshape(a * b, c * d)
        # All pairwise inner products between the flattened rows.
        gram = torch.mm(feature, feature.t())
        # Normalise by the total number of elements.
        return gram.div(a * b * c * d)

搭建网络模型

# Build the style-transfer model.
# NOTE(review): this snippet relies on names defined elsewhere in the original
# script and not shown here: cnn, use_gpu, content_layer, style_layer,
# content_img, style_img, content_weight, style_weight, content_losses,
# style_losses, and `import copy`.
new_model = torch.nn.Sequential()
# Deep copy: the reused layers are independent of the original cnn.
# (With a shallow copy, changes would show up in the original cnn model.)
model = copy.deepcopy(cnn)
gram = Gram_matrix()

if use_gpu:
    # Move the copied layers *before* they are reused. Calling .cuda() on the
    # still-empty new_model does nothing, and the target features below are
    # computed by running these layers on the (GPU) input images.
    model = model.cuda()
    gram = gram.cuda()

index = 1
# Only the first eight layers of the pretrained feature extractor are used.
# Layers of any other type among them are skipped.
for layer in list(model)[:8]:
    if isinstance(layer, torch.nn.Conv2d):
        name = "Conv_" + str(index)
        # Append the layer to the custom model being built.
        new_model.add_module(name, layer)
        if name in content_layer:
            # Reference features are fixed (the loss module detaches them), so
            # no autograd graph is needed while computing them.
            with torch.no_grad():
                target = new_model(content_img).clone()
            content_loss = Content_loss(content_weight, target)
            new_model.add_module("content_loss_" + str(index), content_loss)
            content_losses.append(content_loss)
        if name in style_layer:
            with torch.no_grad():
                target = gram(new_model(style_img).clone())
            style_loss = Style_loss(style_weight, target)
            new_model.add_module("style_loss_" + str(index), style_loss)
            style_losses.append(style_loss)
    elif isinstance(layer, torch.nn.ReLU):
        name = "ReLU_" + str(index)
        new_model.add_module(name, layer)
        # Each conv/ReLU pair shares one index; advance after the ReLU.
        index = index + 1
    elif isinstance(layer, torch.nn.MaxPool2d):
        name = "MaxPool2d_" + str(index)
        new_model.add_module(name, layer)

打印自定义的网络结构，输出如下：

Sequential(
  (Conv_1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (style_loss_1): Style_loss(
    (loss_fn): MSELoss()
    (gram): Gram_matrix()
  )
  (ReLU_1): ReLU(inplace)
  (Conv_2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (style_loss_2): Style_loss(
    (loss_fn): MSELoss()
    (gram): Gram_matrix()
  )
  (ReLU_2): ReLU(inplace)
  (MaxPool2d_3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (Conv_3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (content_loss_3): Content_loss(
    (loss_fn): MSELoss()
  )
  (style_loss_3): Style_loss(
    (loss_fn): MSELoss()
    (gram): Gram_matrix()
  )
  (ReLU_3): ReLU(inplace)
  (Conv_4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (style_loss_4): Style_loss(
    (loss_fn): MSELoss()
    (gram): Gram_matrix()
  )
)

训练

optimizer = torch.optim.LBFGS([parameter])

epoch_n = 300
# One-element list so the closure can update the counter in place. LBFGS may
# evaluate the closure several times per optimizer.step(), so this counts
# closure evaluations, not step() calls.
epoch = [0]


def closure():
    """Run one forward/backward pass for LBFGS and return the total loss."""
    optimizer.zero_grad()
    style_score = 0
    content_score = 0
    # Keep pixel values in [0, 1] before evaluating the model.
    parameter.data.clamp_(0, 1)
    # The forward pass populates .loss on every inserted loss module.
    new_model(parameter)
    for sl in style_losses:
        style_score += sl.backward()
    for cl in content_losses:
        content_score += cl.backward()

    epoch[0] += 1
    if epoch[0] % 50 == 0:
        # float() works both for 0-dim tensors and for the plain int 0 left
        # when a loss list is empty; indexing `.data[0]` fails on 0-dim tensors.
        print("Epoch:{} StyleLoss :{:4f} Content Loss:{:4f}".format(
            epoch[0], float(style_score), float(content_score)))
        # Show the image being optimised. The model's output is a conv feature
        # map, not an image, so display `parameter` itself.
        # NOTE(review): assumes parameter is (1, 3, H, W) in [0, 1] - confirm.
        img_cs = (parameter.detach().clamp(0, 1).cpu()
                  .squeeze(0).permute(1, 2, 0).numpy())
        plt.figure("Img_cs")
        plt.imshow(img_cs)
        plt.show()
    return style_score + content_score


while epoch[0] <= epoch_n:
    optimizer.step(closure)

# The closure clamps before each evaluation, but the last LBFGS update is
# applied afterwards; clamp once more so the final image is in range.
parameter.data.clamp_(0, 1)

结果

猜你喜欢

转载自blog.csdn.net/qian1996/article/details/85083438