风格迁移--生成你想要的风格

风格迁移–生成你想要的风格

标签: pytorch


随着深度网络的流行,用AI作画也不再是问题,比如下面这一张:

output.png

你能看出来是手画的,还是自动生成的吗。

下面介绍一个风格迁移网络,能够帮你生成任意你想要的style。本文也会提供一个Starry_Night_Over_the_Rhone的style模型,大家可以自己后台回复style_transform获取代码和模型。

下面简单介绍一下风格迁移网络。

网络结构

网络结构.jpg

上图就是快速风格迁移网络的结构,左边虚线框里面是一个Encoder-Decoder结构,而右边整个就是一个训练好的vgg,主要用来做特征提取进而能够计算图片间的损失。

从图中可以看出输入是一个x,经过Image Transform Net会变为一个y^,而这个y^就是我们要的图片,也就是经过风格转换后的图片。比如我们输入一张东方明珠电视塔图片作为x,那么文章刚开始的那个图片就是作为y^,那这个y^是如何学习得到的呢。主要靠后面vgg网络做损失,然后指导前面的Image Transform Net学习。

下面介绍一下这个网络中最重要的损失函数,这个损失函数不同于之前的分类网络的损失,原来的分类网络一般就是一个交叉熵函数,但是这里的损失是一个预训练好的vgg,从图中可以看出,Loss Network有三个输入,分别是ys y^ yc,其中ys就是风格图片,在本次实验中我们选择的是:

Starry_Night_Over_the_Rhone.jpg

正是由于ys的缘故,所以我们的y^在风格上和它非常像。而yc其实就是x。将这三个输入到vgg里面,然后计算利用vgg强大的特征提取能力,把提取的特征做为损失,我们的目的是使得我们y^在内容上和yc相近,而风格上和ys更近,所以引出了两类损失,第一类是风格损失,第二类是内容损失,对应图中内容损失就直接对y^ yc的中间特征用mse计算即可,也就是右边下面的那个损失,而风格损失是上面的三个,是对ys y^的中间特征计算gram得到。

最后我们优化这两个损失就能保证我们的输出y^在风格和ys更近,而内容上和yc更近。

代码简析

这部分对代码做一个简单的分析,其中main.py是主函数,里面包含了两个主要方法train、stylize,其中train是用来训练模型的, 如果你有充分的数据集,你可以自己加载数据来进行训练,只需要修改Config里面data_root即可。

train里面主要就是加载数据,加载模型TransformerNet,而TransformerNet就是前面说的那个Image Transform Net,损失网络同样使用的是Vgg,在训练的过程中只更新TransformerNet的参数,因为Vgg是作为一个损失函数来用的,它直接使用一个ImageNet的预训练参数即可。

而stylize函数则提供了一个测试,当我们训练好了模型,就可以用这个函数来帮我们生成图片了,我们在Config里面指定一个content_path,这里我们可以假定是一个东方明珠,你可以用其他代替。在stylize里面做的事情就是把TransformerNet加载一下,注意要把训练好的模型给加载上去,然后一次前向传播即可。

如果你想直接用

如果你不太懂,而想尝下鲜,那么下载完代码后,直接输入python main.py即可,如果想换个图片,直接修改第40行的content_path = 'images.jpeg' # 需要进行分割迁移的图片即可。

欢迎大家关注我的微信公众号,未来上面会推送python 机器学习 算法学习 深度学习 论文阅读 以及偶尔的小鸡汤等内容。ようこそいらっしゃい!

搜索 coderwangson 关注

image

发布了260 篇原创文章 · 获赞 254 · 访问量 21万+

猜你喜欢

转载自blog.csdn.net/qq_28888837/article/details/103651649