Style Transfer and PyTorch Implementation

Style transfer uses an algorithm to learn the style of one picture and then apply that style to another picture.

This article introduces the underlying principle and implements it with PyTorch.


In a convolutional network, shallow features are more concrete and deep features are more abstract. From the perspective of style, shallow features capture information such as color and texture, while deep features capture higher-level structure.

The basic method is to start from some initial image and, by optimizing a content loss and a style loss, modify it so that its content stays close to the content picture while its style approaches the style picture.

Content loss: the Euclidean distance between the feature maps of the generated image and the content image.

Style loss: the Euclidean distance between the Gram matrices of the feature maps of the generated image and the style image.
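
Concretely, writing $F$ for a feature map flattened to shape $(C, H \cdot W)$ and $G(F) = F F^{\top}$ for its Gram matrix, the two losses are, up to the per-layer weights used in the code below (a sketch of the standard formulation, not notation from the original post):

$$
L_{content} = \lVert F_{gen} - F_{content} \rVert_2^2, \qquad
L_{style} = \lVert G(F_{gen}) - G(F_{style}) \rVert_2^2
$$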

The Gram matrix is computed as follows:

import torch


def get_gram_matrix(f_map):
    # f_map: feature map of shape (N, C, H, W); only batch size 1 is supported
    n, c, h, w = f_map.shape
    if n == 1:
        f_map = f_map.reshape(c, h * w)           # merge the spatial dimensions
        gram_matrix = torch.mm(f_map, f_map.t())  # (C, H*W) x (H*W, C) -> (C, C)
        return gram_matrix
    else:
        raise ValueError('Batch size should be 1, but got a different value')

Reshape the feature map so that the height and width dimensions are merged into one, then multiply the resulting (C, H·W) matrix by its own transpose.
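
A quick check with a dummy feature map (the shapes here are only illustrative):

f_map = torch.rand(1, 64, 224, 224)   # batch 1, 64 channels
gram = get_gram_matrix(f_map)
print(gram.shape)                     # torch.Size([64, 64])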

Load a pre-trained VGG19 model and split it so that it outputs feature maps at five different depths.

from torchvision.models import vgg19
from torch import nn
from torchvision.utils import save_image
import torch
import cv2


class VGG19(nn.Module):
    def __init__(self):
        super(VGG19, self).__init__()
        a = vgg19(pretrained=True)  # newer torchvision versions use the `weights` argument instead
        a = a.features
        # Split the feature extractor into five stages, one per VGG block
        self.layer1 = a[:4]     # up to relu1_2
        self.layer2 = a[4:9]    # up to relu2_2
        self.layer3 = a[9:18]   # up to relu3_4
        self.layer4 = a[18:27]  # up to relu4_4
        self.layer5 = a[27:36]  # up to relu5_4

    def forward(self, input_):
        out1 = self.layer1(input_)
        out2 = self.layer2(out1)
        out3 = self.layer3(out2)
        out4 = self.layer4(out3)
        out5 = self.layer5(out4)
        return out1, out2, out3, out4, out5
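
As a sanity check (illustrative, assuming a 224×224 RGB input), the five outputs should have 64, 128, 256, 512 and 512 channels at progressively smaller spatial sizes:

net_check = VGG19()
out1, out2, out3, out4, out5 = net_check(torch.rand(1, 3, 224, 224))
for o in (out1, out2, out3, out4, out5):
    print(o.shape)
# (1, 64, 224, 224), (1, 128, 112, 112), (1, 256, 56, 56), (1, 512, 28, 28), (1, 512, 14, 14)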

Define the image itself as a network parameter so that it can be trained directly. Here training starts from the original content image; white noise can also be used.

class GNet(nn.Module):
    def __init__(self, image):
        super(GNet, self).__init__()
        self.image_g = nn.Parameter(image.detach().clone())
        # self.image_g = nn.Parameter(torch.rand(image.shape))  # alternatively, initialize with white noise

    def forward(self):
        return self.image_g.clamp(0, 1)  # clamp to keep pixel values in [0, 1]

Define the function to load the picture:

def load_image(path):
    image = cv2.imread(path)                        # read the picture
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV reads BGR by default; convert to RGB
    image = torch.from_numpy(image).float() / 255   # normalize pixel values to [0, 1]
    image = image.permute(2, 0, 1).unsqueeze(0)     # (H, W, C) -> (C, H, W), then add a batch dimension
    return image

The pictures used need to have the same shape.
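
If the source pictures differ in size, one simple option (not part of the original code; the helper name and target size here are arbitrary) is to resize them to a common shape when loading:

def load_image_resized(path, size=(512, 512)):
    image = cv2.imread(path)
    image = cv2.resize(image, size)                 # force every picture to the same shape
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = torch.from_numpy(image).float() / 255
    return image.permute(2, 0, 1).unsqueeze(0)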

First load the content picture and the style picture, then instantiate the VGG19 network and the trainable image. The image is initialized directly from the content picture.

Instantiate the optimizer and the loss function.

image_content = load_image('c.jpg').cuda()
image_style = load_image('s.jpg').cuda()
net = VGG19().cuda()
g_net = GNet(image_content).cuda()
optimizer = torch.optim.Adam(g_net.parameters())
loss_func = nn.MSELoss().cuda()

Feed the style picture through VGG19 and compute the Gram matrix of each of the five feature maps.

s1, s2, s3, s4, s5 = net(image_style)
s1 = get_gram_matrix(s1).detach().clone()
s2 = get_gram_matrix(s2).detach().clone()
s3 = get_gram_matrix(s3).detach().clone()
s4 = get_gram_matrix(s4).detach().clone()
s5 = get_gram_matrix(s5).detach().clone()

Feed the content picture through VGG19 and keep its feature maps as the content targets.

c1, c2, c3, c4, c5 = net(image_content)
c1 = c1.detach().clone()
c2 = c2.detach().clone()
c3 = c3.detach().clone()
c4 = c4.detach().clone()
c5 = c5.detach().clone()

Train the picture.

i = 0
while True:
    """Generate the image and compute its feature maps"""
    image = g_net()
    out1, out2, out3, out4, out5 = net(image)

    """Compute the style loss"""
    loss_s1 = loss_func(get_gram_matrix(out1), s1)
    loss_s2 = loss_func(get_gram_matrix(out2), s2)
    loss_s3 = loss_func(get_gram_matrix(out3), s3)
    loss_s4 = loss_func(get_gram_matrix(out4), s4)
    loss_s5 = loss_func(get_gram_matrix(out5), s5)
    loss_s = 0.1*loss_s1 + 0.1*loss_s2 + 0.6*loss_s3 + 0.1*loss_s4 + 0.1*loss_s5

    """Compute the content loss"""
    loss_c1 = loss_func(out1, c1)
    loss_c2 = loss_func(out2, c2)
    loss_c3 = loss_func(out3, c3)
    loss_c4 = loss_func(out4, c4)
    loss_c5 = loss_func(out5, c5)
    loss_c = 0.05 * loss_c1 + 0.05 * loss_c2 + 0.15 * loss_c3 + 0.3 * loss_c4 + 0.45 * loss_c5

    """Total loss"""
    loss = 0.5*loss_c + 0.5*loss_s

    """Zero the gradients, backpropagate, update the image"""
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print(i, loss.item(), loss_c.item(), loss_s.item())
    if i % 1000 == 0:
        save_image(image, f'{i}.jpg', padding=0, normalize=True, range=(0, 1))  # newer torchvision uses value_range instead of range
    i += 1

The style loss and the content loss are computed separately and then combined into a total loss, which is what the optimizer minimizes.

A reasonable result can usually be obtained after roughly a thousand iterations.
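
The loop above runs forever; if it should stop on its own, a simple cutoff can be added at the end of the loop body (the iteration count is an arbitrary choice):

    if i >= 5000:   # stop after a fixed number of iterations (illustrative)
        break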

The content picture is:

[Content picture]

The results for several style pictures:

[Style picture / generated picture pairs]

Adjusting the weighting coefficients of the individual loss terms produces different effects; experiment as needed.
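
For instance, giving the style term a larger share pushes the result toward stronger stylization (the weights below are purely illustrative, not the settings used above):

loss = 0.2 * loss_c + 0.8 * loss_s   # heavier style weighting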


Origin blog.csdn.net/weixin_48866452/article/details/109045157