版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/qian1996/article/details/85083438
首先定义两个损失函数:
内容损失:
class Content_loss(torch.nn.Module):
    """Pass-through layer that records a weighted content loss.

    The layer returns its input unchanged, so it can be inserted into a
    ``Sequential`` without altering the features seen by later layers.
    As a side effect, ``forward`` stores the weighted mean-squared error
    between the input and the fixed target in ``self.loss``.
    """

    def __init__(self, weight, target):
        """Store the weight and a constant, pre-scaled copy of the target.

        Args:
            weight: Factor applied to both the target and the incoming
                features before the MSE is taken; it controls how strongly
                the content term influences the result.
            target: Features to match (the caller passes the content
                image's activations at this layer).
        """
        super(Content_loss, self).__init__()
        self.weight = weight
        # detach() removes the target from the autograd graph so that it is
        # treated as a constant and never receives gradients.
        self.target = target.detach() * weight
        # Mean-squared error is used as the distance measure.
        self.loss_fn = torch.nn.MSELoss()

    def forward(self, input):
        """Record the loss for ``input`` and return ``input`` unchanged."""
        self.loss = self.loss_fn(input * self.weight, self.target)
        return input

    def backward(self):
        """Back-propagate the most recently recorded loss and return it.

        ``forward`` must have been called first so that ``self.loss`` exists.
        """
        # retain_graph=True keeps the graph alive so it can be back-propagated
        # through again (e.g. by other loss layers fed by the same forward
        # pass).
        self.loss.backward(retain_graph=True)
        return self.loss
风格损失:
class Style_loss(torch.nn.Module):
    """Pass-through layer that records a weighted style loss.

    The layer returns its input unchanged. As a side effect, ``forward``
    stores in ``self.loss`` the mean-squared error between the weighted Gram
    matrix of the input and the fixed target.
    """

    def __init__(self, weight, target):
        """Store the weight and a constant, pre-scaled copy of the target.

        Args:
            weight: Factor applied to both the target and the input's Gram
                matrix; it controls how strongly the style term influences
                the result.
            target: Value the input's Gram matrix is compared against, so it
                is expected to already be a Gram matrix (of the style
                image's features at this layer).
        """
        super(Style_loss, self).__init__()
        self.weight = weight
        # detach() makes the target a constant that never receives gradients.
        self.target = target.detach() * weight
        # Mean-squared error is used as the distance measure.
        self.loss_fn = torch.nn.MSELoss()
        self.gram = Gram_matrix()

    def forward(self, input):
        """Record the style loss for ``input`` and return ``input`` unchanged."""
        # The clone is deliberate: the Gram computation saves its operand for
        # the backward pass, and a following in-place layer (such as
        # ReLU(inplace=True)) would otherwise overwrite that saved tensor.
        # The weighting is done out-of-place so that no tensor recorded by
        # autograd is mutated.
        self.Gram = self.gram(input.clone()) * self.weight
        self.loss = self.loss_fn(self.Gram, self.target)
        return input

    def backward(self):
        """Back-propagate the most recently recorded loss and return it.

        ``forward`` must have been called first so that ``self.loss`` exists.
        """
        # retain_graph=True keeps the graph alive so it can be back-propagated
        # through again (e.g. by other loss layers fed by the same forward
        # pass).
        self.loss.backward(retain_graph=True)
        return self.loss
class Gram_matrix(torch.nn.Module):
    """Compute the normalised Gram matrix of a 4-D feature tensor.

    Instances of this class take part in the style-loss computation: the
    Gram matrix holds the inner products between the flattened rows of the
    feature tensor.
    """

    def forward(self, input):
        """Return the ``(a*b, a*b)`` Gram matrix of ``input`` divided by ``a*b*c*d``.

        Args:
            input: Tensor with exactly four dimensions ``(a, b, c, d)``.
        """
        a, b, c, d = input.size()
        # Flatten to a*b rows by c*d columns. reshape() is used instead of
        # view() so that non-contiguous inputs are accepted as well.
        feature = input.reshape(a * b, c * d)
        # Inner product of every row with every other row.
        gram = torch.mm(feature, feature.t())
        # Normalise by the total number of elements.
        return gram.div(a * b * c * d)
搭建网络模型
# Style-transfer model: the leading layers of the pretrained CNN with
# content/style loss layers spliced in after selected convolutions.
new_model = torch.nn.Sequential()
# Deep copy: changes made to one of the two networks do not affect the other.
# (With a shallow copy, changing the original cnn would change model too.)
model = copy.deepcopy(cnn)
gram = Gram_matrix()
if use_gpu:
    new_model = new_model.cuda()
    gram = gram.cuda()

index = 1
# Only the first eight layers of the pretrained feature extractor are used.
for layer in list(model)[:8]:
    # Dispatch on the concrete layer type.
    if isinstance(layer, torch.nn.Conv2d):
        name = "Conv_{}".format(index)
        # Append the layer to the (initially empty) custom model.
        new_model.add_module(name, layer)

        if name in content_layer:
            # Content target: activations of the content image at this depth.
            target = new_model(content_img).clone()
            content_loss = Content_loss(content_weight, target)
            new_model.add_module("content_loss_{}".format(index), content_loss)
            content_losses.append(content_loss)

        if name in style_layer:
            # Style target: Gram matrix of the style image's activations at
            # this depth.
            target = gram(new_model(style_img).clone())
            style_loss = Style_loss(style_weight, target)
            new_model.add_module("style_loss_{}".format(index), style_loss)
            style_losses.append(style_loss)

    elif isinstance(layer, torch.nn.ReLU):
        name = "ReLU_{}".format(index)
        new_model.add_module(name, layer)
        # The counter advances only after a ReLU, so a Conv and the ReLU that
        # follows it share the same number.
        index += 1

    elif isinstance(layer, torch.nn.MaxPool2d):
        name = "MaxPool2d_{}".format(index)
        new_model.add_module(name, layer)
输出自定义的网络结构
Sequential(
(Conv_1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(style_loss_1): Style_loss(
(loss_fn): MSELoss()
(gram): Gram_matrix()
)
(ReLU_1): ReLU(inplace)
(Conv_2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(style_loss_2): Style_loss(
(loss_fn): MSELoss()
(gram): Gram_matrix()
)
(ReLU_2): ReLU(inplace)
(MaxPool2d_3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(Conv_3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(content_loss_3): Content_loss(
(loss_fn): MSELoss()
)
(style_loss_3): Style_loss(
(loss_fn): MSELoss()
(gram): Gram_matrix()
)
(ReLU_3): ReLU(inplace)
(Conv_4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(style_loss_4): Style_loss(
(loss_fn): MSELoss()
(gram): Gram_matrix()
)
)
训练
# L-BFGS optimises the pixels of the synthesized image directly.
optimizer = torch.optim.LBFGS([parameter])
epoch_n = 300
# One-element list so that closure() can update the counter in place.
epoch = [0]


def closure():
    """Evaluate the total loss once; L-BFGS calls this repeatedly per step.

    Runs one forward pass so every loss layer records its loss, back-propagates
    each of them into ``parameter`` and returns the summed loss.
    """
    optimizer.zero_grad()
    # Keep the image inside the valid [0, 1] range without recording the
    # clamp in the autograd graph.
    with torch.no_grad():
        parameter.clamp_(0, 1)
    # The output is not needed; the forward pass is run for the side effect of
    # refreshing the loss stored on every loss layer.
    new_model(parameter)

    style_score = 0
    content_score = 0
    for sl in style_losses:
        style_score += sl.backward()
    for cl in content_losses:
        content_score += cl.backward()

    epoch[0] += 1
    if epoch[0] % 50 == 0:
        # float() works for 0-dim tensors and for the plain 0 left behind
        # when a loss list is empty; indexing a 0-dim tensor with [0] raises
        # on PyTorch >= 0.4.
        print("Epoch:{} StyleLoss :{:4f} Content Loss:{:4f}".format(
            epoch[0], float(style_score), float(content_score)))
        # Show the synthesized image itself. The model's output is a feature
        # map of the last convolution, which imshow cannot display.
        # NOTE(review): assumes parameter has shape (1, 3, H, W) - confirm.
        img_cs = parameter.detach().cpu().clamp(0, 1)
        img_cs = img_cs.squeeze(0).permute(1, 2, 0).numpy()
        plt.figure("Img_cs")
        plt.imshow(img_cs)
        plt.show()
    return style_score + content_score


# closure() advances epoch[0] on every evaluation, so the loop counts loss
# evaluations rather than optimizer steps.
while epoch[0] <= epoch_n:
    optimizer.step(closure)
结果