Computer Vision: Style Transfer

Style transfer

This section describes how to use convolutional neural networks to automatically apply the style of one image to another, a technique known as style transfer (Gatys et al., 2016). We need two input images: one is the content image and the other is the style image. We use a neural network to modify the content image so that it becomes stylistically close to the style image. For example, the content image in the figure below is a landscape photo taken by the author in Mount Rainier National Park near Seattle, while the style image is an oil painting with the theme of autumn oak trees. The final output composite image applies the oil brushstrokes of the style image, making the overall colors more vivid, while preserving the shapes of the main objects in the content image.

[Figure: the content image, the style image, and the synthesized composite image]



Method

First, we initialize the composite image, for example as the content image. This composite image is the only variable that needs to be updated during the style transfer process, i.e., it is the set of model parameters iterated during style transfer. Then, we choose a pre-trained convolutional neural network to extract image features; its model parameters do not need to be updated during training. This deep convolutional neural network extracts image features step by step with multiple layers, and we can choose the outputs of some of its layers as content features or style features. Take the following figure as an example: the pre-trained neural network selected here contains 3 convolutional layers, of which the second layer outputs the content features, and the first and third layers output the style features.
[Figure: CNN-based style transfer; the second layer outputs content features, while the first and third layers output style features]
Next, we calculate the loss function of style transfer through forward propagation (direction of the solid arrows) and update the model parameters, i.e., the composite image, through backpropagation (direction of the dashed arrows). The loss function commonly used in style transfer consists of three parts (summarized in the weighted sum after the list):

  1. The content loss makes the composite image and the content image close in terms of content features;

  2. The style loss makes the composite image and the style image close in terms of style features;

  3. The total variation loss helps to reduce the noise in the composite image.
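
Putting the three parts together, the objective minimized with respect to the composite image is a weighted sum of the three losses; the weights correspond to content_weight, style_weight, and tv_weight in the code later in this section:

$$\mathcal{L} = \alpha \, \mathcal{L}_{\text{content}} + \beta \, \mathcal{L}_{\text{style}} + \gamma \, \mathcal{L}_{\text{TV}}$$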

Finally, when the model training is over, we output the model parameters for style transfer, that is, the final composite image.

In the following, we will further understand the technical details of style transfer through code.

Read content and style images

%matplotlib inline
import torch
import torchvision
from torch import nn
from d2l import torch as d2l

d2l.set_figsize()
content_img = d2l.Image.open('rainier.jpg')
d2l.plt.imshow(content_img);

[Output: the content image rainier.jpg]

style_img = d2l.Image.open('autumn-oak.jpg')
d2l.plt.imshow(style_img);

[Output: the style image autumn-oak.jpg]

Preprocessing and postprocessing

rgb_mean = torch.tensor([0.485, 0.456, 0.406])
rgb_std = torch.tensor([0.229, 0.224, 0.225])

def preprocess(img, image_shape):
    transforms = torchvision.transforms.Compose([
        torchvision.transforms.Resize(image_shape),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=rgb_mean, std=rgb_std)])
    return transforms(img).unsqueeze(0)

def postprocess(img):
    img = img[0].to(rgb_std.device)
    img = torch.clamp(img.permute(1, 2, 0) * rgb_std + rgb_mean, 0, 1)
    return torchvision.transforms.ToPILImage()(img.permute(2, 0, 1))

The preprocess function takes an input image img and an image_shape parameter. It uses torchvision.transforms.Resize to resize the image to the specified size, torchvision.transforms.ToTensor to convert the image to a tensor, and torchvision.transforms.Normalize to normalize the tensor, where the mean and standard deviation values are given by rgb_mean and rgb_std, respectively. Finally, it adds an extra batch dimension with unsqueeze(0) and returns the preprocessed tensor.

The postprocess function accepts a tensor img as input. It removes the batch dimension by indexing with [0], moves the tensor to the same device as rgb_std using .to(rgb_std.device), and permutes the dimensions with .permute(1, 2, 0) so that the channels come last. It then denormalizes the tensor by multiplying by rgb_std and adding rgb_mean, and uses torch.clamp to keep the pixel values between 0 and 1, since denormalization may push some values outside this range. Finally, it permutes the channels back to the first dimension with .permute(2, 0, 1) and converts the tensor back to a PIL image using torchvision.transforms.ToPILImage() before returning it.
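
As a quick sanity check (a minimal usage sketch; the shape (300, 450) is only an example), we can verify the shapes produced by these two helpers:

# Illustrative sketch: check the shapes produced by preprocess and postprocess.
x = preprocess(content_img, (300, 450))   # normalized batch tensor
print(x.shape)                            # torch.Size([1, 3, 300, 450])
img = postprocess(x)                      # back to a PIL image
print(img.size)                           # PIL reports (width, height), i.e. (450, 300)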

Normalization is the process of converting data to values within a standard range. In computer vision, it is common to normalize images using a per-channel mean and standard deviation. In this code, torchvision.transforms.ToTensor first scales the raw pixel values from the range 0 to 255 down to the range 0 to 1, and torchvision.transforms.Normalize then subtracts the per-channel mean rgb_mean and divides by the per-channel standard deviation rgb_std; these are the ImageNet statistics that the pre-trained network expects, which makes the inputs easier to handle and compare.

Denormalization is the reverse process: it transforms values in the normalized range back to the original pixel range. In computer vision, denormalization is often used to convert model outputs back into displayable images. In this code, the multiplication by rgb_std and the addition of rgb_mean in the postprocess function undo the normalization, and torch.clamp clips the result to the range 0 to 1 so that it can be converted into a valid image.
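
As a concrete, purely illustrative example for one red-channel pixel, the round trip looks like this:

# Illustrative round trip for a single red-channel pixel (raw value 128 out of 255).
raw = 128 / 255                    # ToTensor scales to [0, 1] -> ~0.502
norm = (raw - 0.485) / 0.229       # Normalize with the ImageNet mean/std -> ~0.074
back = norm * 0.229 + 0.485        # postprocess-style denormalization -> ~0.502
print(raw, norm, back)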

Extract image features

pretrained_net = torchvision.models.vgg19(pretrained=True)
print(pretrained_net)

The structure is:

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (17): ReLU(inplace=True)
    (18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (24): ReLU(inplace=True)
    (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (26): ReLU(inplace=True)
    (27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (31): ReLU(inplace=True)
    (32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (33): ReLU(inplace=True)
    (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (35): ReLU(inplace=True)
    (36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
 (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

In order to extract the content features and style features of the image, we can select the outputs of some layers in the VGG network. Generally speaking, the closer a layer is to the input, the easier it is to extract detailed information about the image; conversely, the closer it is to the output, the easier it is to extract global information. To avoid the composite image retaining too many details of the content image, we choose a layer closer to the output of VGG, called the content layer, to output the content features of the image. We also select the outputs of several different VGG layers to match local and global styles; these layers are called style layers. The VGG network uses 5 convolutional blocks. In the experiments, we choose the last convolutional layer of the fourth convolutional block as the content layer, and the first convolutional layer of each convolutional block as a style layer. The indices of these layers can be obtained by printing the pretrained_net instance.
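
To see where these layers sit, one can also enumerate pretrained_net.features programmatically (a small exploratory sketch, not part of the original pipeline):

# Exploratory sketch: indices of the Conv2d layers in VGG-19's feature extractor.
conv_idx = [i for i, layer in enumerate(pretrained_net.features)
            if isinstance(layer, nn.Conv2d)]
print(conv_idx)
# [0, 2, 5, 7, 10, 12, 14, 16, 19, 21, 23, 25, 28, 30, 32, 34]
# First conv of each of the 5 blocks: 0, 5, 10, 19, 28 (style layers);
# last conv of the fourth block: 25 (content layer).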

style_layers, content_layers = [0, 5, 10, 19, 28], [25]

When extracting features with VGG, we only need the layers from the input layer up to the content or style layer that is closest to the output. Next, we build a new network net that retains only the VGG layers we need.

net = nn.Sequential(*[pretrained_net.features[i] for i in
                      range(max(content_layers + style_layers) + 1)])

The choice of these layers is determined by the two lists content_layers and style_layers. The content_layers list contains the indices of the layers used for content extraction, and the style_layers list contains the indices of the layers used for style extraction. The indices are determined according to the structure of VGG-19, since features extracted at these depths of the model have proven useful for capturing content and style.
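
Since the deepest layer we need has index 28, the truncated network should contain 29 modules; a quick check (illustrative only):

# Sanity check: net keeps layers 0..28 of VGG-19's feature extractor.
print(len(net))   # 29
print(net[28])    # Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))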

Given an input X, if we simply call the forward pass net(X), we can only get the output of the last layer. Since we also need the outputs of the intermediate layers, here we compute layer by layer and keep the outputs of the content layers and style layers.

def extract_features(X, content_layers, style_layers):
    contents = []
    styles = []
    for i in range(len(net)):
        X = net[i](X)
        if i in style_layers:
            styles.append(X)
        if i in content_layers:
            contents.append(X)
    return contents, styles

This code defines a function extract_features that takes an input image X and two lists of layer indices, content_layers and style_layers, used for content extraction and style extraction.

The for loop in the function traverses the neural network net layer by layer, passing X through each layer to obtain its output. If the current layer's index is in style_layers, the output of that layer is appended to the styles list; if it is in content_layers, the output is appended to the contents list. Finally, the function returns the two lists contents and styles, containing the outputs of the layers used for content and style extraction, respectively.

The purpose of this function is to use the layers in the neural network net to extract content and style features of the input image. In particular, the contents list contains features representing the content of the input image, and the styles list contains features representing the style of the input image. These features will be used to calculate the loss function, which in turn is used to train the style transfer model.
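
For intuition, here is an illustrative sketch (using a random 300x450 input; the shapes below assume the layer choices above):

# Illustrative sketch: shapes of the extracted feature maps for a 300x450 input.
dummy = torch.rand((1, 3, 300, 450))
contents, styles = extract_features(dummy, content_layers, style_layers)
print([t.shape for t in contents])   # [torch.Size([1, 512, 37, 56])] from layer 25
print([t.shape for t in styles])     # feature maps from layers 0, 5, 10, 19, 28
                                     # with 64, 128, 256, 512, 512 channels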

Two functions are defined below: the get_contents function extracts content features from the content image, and the get_styles function extracts style features from the style image. Because there is no need to change the parameters of the pre-trained VGG during training, we can extract the content features and the style features before training begins. Since the composite image is the set of model parameters iterated during style transfer, its content and style features can only be extracted by calling the extract_features function during training.

def get_contents(image_shape, device):
    content_X = preprocess(content_img, image_shape).to(device)
    contents_Y, _ = extract_features(content_X, content_layers, style_layers)
    return content_X, contents_Y

def get_styles(image_shape, device):
    style_X = preprocess(style_img, image_shape).to(device)
    _, styles_Y = extract_features(style_X, content_layers, style_layers)
    return style_X, styles_Y

Loss function

Next we describe the loss function for style transfer. It consists of three parts: content loss, style loss, and total variation loss.

Content loss

Similar to the loss function in linear regression, the content loss measures the difference in content features between the composite image and the content image through a squared error function. Both inputs of the squared error function are outputs of the content layer computed by the extract_features function.

def content_loss(Y_hat, Y):
    # We detach the target content from the tree used to dynamically compute
    # gradients: it is a specified value, not a variable.
    return torch.square(Y_hat - Y.detach()).mean()

Style loss

The style loss is similar to the content loss: it also measures the difference in style between the composite image and the style image through a squared error function. To express the style output of a style layer, we first compute the output of that style layer through the extract_features function. Suppose this output has 1 example, $c$ channels, height $h$, and width $w$. We can transform this output into a matrix $\mathbf{X}$ with $c$ rows and $hw$ columns. This matrix can be viewed as the concatenation of $c$ vectors $\mathbf{x}_1, \ldots, \mathbf{x}_c$, each of length $hw$, where the vector $\mathbf{x}_i$ represents the style features on channel $i$.

In the Gram matrix $\mathbf{X}\mathbf{X}^\top \in \mathbb{R}^{c \times c}$ of these vectors, the element $x_{ij}$ in row $i$ and column $j$ is the inner product of the vectors $\mathbf{x}_i$ and $\mathbf{x}_j$; it expresses the correlation of the style features on channels $i$ and $j$. We use such a Gram matrix to express the style output of a style layer. Note that when $hw$ is large, the elements of the Gram matrix tend to have large values. In addition, the height and width of the Gram matrix are both equal to the number of channels $c$. To make the style loss unaffected by these sizes, the gram function defined below divides the Gram matrix by the number of its elements, i.e., $chw$.

def gram(X):
    num_channels, n = X.shape[1], X.numel() // X.shape[1]
    X = X.reshape((num_channels, n))
    return torch.matmul(X, X.T) / (num_channels * n)

This code defines a function gram that takes a tensor X as input and returns its Gram matrix. The Gram matrix describes the correlations between features and is commonly used to characterize the style features of an input image.

Specifically, the function first computes the number of channels of the input tensor X and the number of elements per channel (n = hw for a batch of one), and reshapes X into a matrix of shape (num_channels, n). It then computes the matrix product X @ X.T and divides the result by num_channels * n to normalize the values of the Gram matrix. Finally, the function returns the normalized Gram matrix.
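
As a quick illustrative check (a sketch with a random tensor, not part of the pipeline), the Gram matrix of a feature map with c channels is c x c and symmetric:

# Illustrative sketch: the Gram matrix of a (1, 64, 300, 450) feature map is 64 x 64.
feat = torch.rand((1, 64, 300, 450))
G = gram(feat)
print(G.shape)                    # torch.Size([64, 64])
print(torch.allclose(G, G.T))     # True: Gram matrices are symmetric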

Naturally, the two Gram matrix inputs to the squared error function of the style loss are based on the style-layer outputs of the composite image and the style image, respectively. It is assumed here that the Gram matrix gram_Y based on the style image has been computed in advance.

def style_loss(Y_hat, gram_Y):
    return torch.square(gram(Y_hat) - gram_Y.detach()).mean()

Total variation loss

Sometimes, the learned composite image contains a lot of high-frequency noise, i.e., extremely bright or extremely dark individual pixels. A common denoising method is total variation denoising: let $x_{i,j}$ denote the pixel value at coordinate $(i, j)$; reducing the total variation loss

$$\sum_{i,j} \left| x_{i,j} - x_{i+1,j} \right| + \left| x_{i,j} - x_{i,j+1} \right|$$

makes the values of neighboring pixels as similar as possible.

def tv_loss(Y_hat):
    return 0.5 * (torch.abs(Y_hat[:, :, 1:, :] - Y_hat[:, :, :-1, :]).mean() +
                  torch.abs(Y_hat[:, :, :, 1:] - Y_hat[:, :, :, :-1]).mean())

Total loss function

The loss function for style transfer is a weighted sum of the content loss, the style loss, and the total variation loss. By tuning these weight hyperparameters, we can trade off how much the composite image preserves content, transfers style, and reduces noise.

content_weight, style_weight, tv_weight = 1, 1e3, 10

def compute_loss(X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram):
    # Compute the content loss, style loss, and total variation loss separately
    contents_l = [content_loss(Y_hat, Y) * content_weight for Y_hat, Y in zip(
        contents_Y_hat, contents_Y)]
    styles_l = [style_loss(Y_hat, Y) * style_weight for Y_hat, Y in zip(
        styles_Y_hat, styles_Y_gram)]
    tv_l = tv_loss(X) * tv_weight
    # Add up all the losses
    l = sum(10 * styles_l + contents_l + [tv_l])
    return contents_l, styles_l, tv_l, l

Initialize composite image

In style transfer, the composite image is the only variable that needs to be updated during training. Therefore, we can define a simple model, SynthesizedImage, and treat the composite image as its model parameter. The forward pass of the model simply returns this parameter.

class SynthesizedImage(nn.Module):
    def __init__(self, img_shape, **kwargs):
        super(SynthesizedImage, self).__init__(**kwargs)
        self.weight = nn.Parameter(torch.rand(*img_shape))

    def forward(self):
        return self.weight

Next, we define the get_inits function. This function creates a model instance for the composite image and initializes it with the image X. The Gram matrices styles_Y_gram of the style image at the various style layers are computed before training.

def get_inits(X, device, lr, styles_Y):
    gen_img = SynthesizedImage(X.shape).to(device)
    gen_img.weight.data.copy_(X.data)
    trainer = torch.optim.Adam(gen_img.parameters(), lr=lr)
    styles_Y_gram = [gram(Y) for Y in styles_Y]
    return gen_img(), styles_Y_gram, trainer

Training the model

def train(X, contents_Y, styles_Y, device, lr, num_epochs, lr_decay_epoch):
    X, styles_Y_gram, trainer = get_inits(X, device, lr, styles_Y)
    scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_decay_epoch, 0.8)
    animator = d2l.Animator(xlabel='epoch', ylabel='loss',
                            xlim=[10, num_epochs],
                            legend=['content', 'style', 'TV'],
                            ncols=2, figsize=(7, 2.5))
    for epoch in range(num_epochs):
        trainer.zero_grad()
        contents_Y_hat, styles_Y_hat = extract_features(
            X, content_layers, style_layers)
        contents_l, styles_l, tv_l, l = compute_loss(
            X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram)
        l.backward()
        trainer.step()
        scheduler.step()
        if (epoch + 1) % 10 == 0:
            animator.axes[1].imshow(postprocess(X))
            animator.add(epoch + 1, [float(sum(contents_l)),
                                     float(sum(styles_l)), float(tv_l)])
    return X

Now we train the model: first adjust the height and width of the content image and style image to 300 and 450 pixels respectively, and use the content image to initialize the composite image.

device, image_shape = d2l.try_gpu(), (300, 450)
net = net.to(device)
content_X, contents_Y = get_contents(image_shape, device)
_, styles_Y = get_styles(image_shape, device)
output = train(content_X, contents_Y, styles_Y, device, 0.3, 500, 50)

[Output: training curves of the content, style, and TV losses, and the final composite image]
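
To keep the result, the trained tensor can be converted back to a PIL image and saved (an optional convenience step not shown above; the filename is only an example):

# Optional: convert the final composite tensor back to a PIL image and save it.
final_img = postprocess(output)
final_img.save('composite.jpg')
d2l.plt.imshow(final_img);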
