Pytorch はスタイル転送スタイル転送を実装します

  この記事では、スタイル転送を実装するための簡単なコードを示します。

1. 原理の紹介

  スタイル転送は、前回の記事で説明したディープ ドリーム アルゴリズムに似ており、何らかの最適化指標に従って勾配を計算し、入力画像のピクセルを逆最適化します。したがって、深い夢を学んだ後、鉄は熱いうちにこれをもう一度学びましたが、この記事は基本バージョンの実現に限定されており、この分野で開発された多くの進化バージョンについては説明しません。
  深層学習に基づくスタイル転送は、2015 年に Gatys によって初めて提案されました。その中心となる理論は、画像のスタイルを抽出するためにグラム行列 (グラム行列) を使用することです。グラム行列の計算方法は、画像の中間層の特徴(サイズは C H W)に独自の転置を使用して行列乗算を実行し、形状を C*(H W) にして、行列を計算します。のC C が得られる場合、この行列は元の行列の偏心共分散行列であり、ピクセルレベルの情報を取り除き、チャネル間の相関を表現します。Deep Dream の分析から、中間層の特徴の各チャネルは、画像の特徴の異なる次元を表現していることがわかりました。たとえば、あるものは尖った建物を表し、あるものは黒い縞模様を表します。2 つのチャネルの特徴をピクセル順に乗算して合計すると、2 つの特徴が同時に現れるか現れないかの程度を示します。たとえば、尖塔と黒い縞は常に同時に表示されますが、これは画像の特定の好み、つまりそのスタイルです。この分析から、このように定義されたスタイルは、画像内の地物がどこに現れるかには関係なく、2 つの地物が同時に現れるかどうかに関係していることがわかります。この定義は非常に賢明です。しかし、私はこの定義があまり完全ではないといつも感じており、これは単なる単純なスタイルであるべきであり、実際のスタイルはこれよりもはるかに複雑であるはずです。
  いずれにせよ、量から「スタイル」の量を抽出した後、それを操作してさまざまな花で遊ぶことができますが、最も一般的な方法は、ある画像から別の画像にスタイルを転送することです。アルゴリズムの原理を次の図に示します (Zhihu からの転載)。
ここに画像の説明を挿入

図 1. スタイル転送アルゴリズムの原理

  まず、スタイル イメージ、スタイル イメージ、コンテンツ イメージを選択し、次に初期シード イメージ (ノイズ イメージ、コンテンツ イメージ、または別のイメージ) を指定する必要があります。効果はさまざまです。次に、スタイル ピクチャとコンテンツ ピクチャを最初にネットワークに通過させます。スタイル ピクチャは、ネットワークの複数の中間層から特徴マップを抽出し、各層のスタイル マトリックス (グラム マトリックス) を計算します。レイヤーは比較的ベーシックなスタイル、ディープレイヤーは比較的ベーシックなスタイルですが、より高度なスタイルであり、どちらも便利です。(実際には、浅い層と深い層の間でグラム行列を相互計算すると便利だと思います。なぜなら、浅い特徴と深い特徴が同時に現れることがあるからです。これは、時間があるときに実験として残しておきます) 。次に、コンテンツ画像をネットワーク経由で渡し、後者の層の特定の層の特徴のみを抽出します。浅い特徴自体には、より多くのスタイルの意味が含まれており、スタイルを転送するときにそれらを保持する必要がないためです。最後に、ネットワークからの初期画像を繰り返しループし、転送されるたびに対応する各レイヤーのスタイルを計算し、それをスタイル マップと比較してスタイルの損失を取得し、対応するレイヤーのコンテンツを計算して、それと比較します。コンテンツの損失を取得するためのコンテンツ マップ。この論文では、著者は 3 番目の損失、つまり、生成された画像の滑らかさを示し、画像の歪みを防ぐ TV 損失を以下のコードに追加しました。包括的損失に従って入力画像の勾配を計算します。つまり、入力画像を繰り返し最適化して、スタイル画像と同じスタイル、コンテンツ画像と同じ内容を持ち、シード画像構造を保持した新しい画像を取得します(シード画像がランダムノイズでない場合)。
  以下で完全なコードを見てみましょう。

2. 完全なコード

import torch
import torchvision.models as models
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
import numbers
import math
import cv2
from PIL import Image
from torchvision.transforms import Compose, ToTensor, Normalize, Resize, ToPILImage
import time
t0 = time.time()

model = models.vgg19(pretrained=True).cuda()
batch_size = 1

for params in model.parameters():
    params.requires_grad = False
model.eval()

mu = torch.Tensor([0.485, 0.456, 0.406]).unsqueeze(-1).unsqueeze(-1).cuda()
std = torch.Tensor([0.229, 0.224, 0.225]).unsqueeze(-1).unsqueeze(-1).cuda()
unnormalize = lambda x: x*std + mu
normalize = lambda x: (x-mu)/std
transform_test = Compose([
    Resize((512,512)),
    ToTensor(),
])

content_img = Image.open('./data/tubingen.jpg')
image_size = content_img.size
content_img = transform_test(content_img).unsqueeze(0).cuda()

style_img = Image.open('./data/starry_night.jpg')
style_img = transform_test(style_img).unsqueeze(0).cuda()

var_img = content_img.clone()
#var_img = torch.rand_like(content_img)
var_img.requires_grad=True

class ShuntModel(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.module = model.features.cuda().eval()
        self.con_layers = [22]
        self.sty_layers = [1,6,11,20,29]
        for name, layer in self.module.named_children():
            if isinstance(layer, nn.MaxPool2d):
                self.module[int(name)] = nn.AvgPool2d(kernel_size = 2, stride = 2)

    def forward(self, tensor: torch.Tensor) -> dict:
        sty_feat_maps = []; con_feat_maps = [];
        x = normalize(tensor)
        for name, layer in self.module.named_children():
            x = layer(x);
            if int(name) in self.con_layers: con_feat_maps.append(x)
            if int(name) in self.sty_layers: sty_feat_maps.append(x)
        return {"Con_features": con_feat_maps, "Sty_features": sty_feat_maps}

model = ShuntModel(model)
sty_target = model(style_img)["Sty_features"]
con_target = model(content_img)["Con_features"]
gram_target = []
for i in range(len(sty_target)):
    b, c, h, w  = sty_target[i].size()
    tensor_ = sty_target[i].view(b * c, h * w)
    gram_i = torch.mm(tensor_, tensor_.t()).div(b*c*h*w)
    gram_target.append(gram_i)

optimizer = torch.optim.Adam([var_img], lr = 0.01, betas = (0.9,0.999), eps = 1e-8)
lam1 = 1e-3; lam2 = 1e7; lam3 = 5e-3
for itera in range(20001):
    optimizer.zero_grad()
    output = model(var_img)
    sty_output = output["Sty_features"]
    con_output = output["Con_features"]
    
    con_loss = torch.tensor([0]).cuda().float()
    for i in range(len(con_output)):
        con_loss = con_loss + F.mse_loss(con_output[i], con_target[i])
    
    sty_loss = torch.tensor([0]).cuda().float()
    for i in range(len(sty_output)):
        b, c, h, w  = sty_output[i].size()
        tensor_ = sty_output[i].view(b * c, h * w)
        gram_i = torch.mm(tensor_, tensor_.t()).div(b*c*h*w)
        sty_loss = sty_loss + F.mse_loss(gram_i, gram_target[i])
    
    b, c, h, w  = style_img.size()
    TV_loss = (torch.sum(torch.abs(style_img[:, :, :, :-1] - style_img[:, :, :, 1:])) +
                torch.sum(torch.abs(style_img[:, :, :-1, :] - style_img[:, :, 1:, :])))/(b*c*h*w)
    
    loss = con_loss * lam1 + sty_loss * lam2 + TV_loss * lam3
    loss.backward()
    var_img.data.clamp_(0, 1)
    optimizer.step()
    if itera%100==0:
        print('itera: %d, con_loss: %.4f, sty_loss: %.4f, TV_loss: %.4f'%(itera,
              con_loss.item()*lam1,sty_loss.item()*lam2,TV_loss.item()*lam3),'\n\t total loss:',loss.item())
        print('var_img mean:%.4f, std:%.4f'%(var_img.mean().item(),var_img.std().item()))
        print('time: %.2f seconds'%(time.time()-t0))
    if itera%1000==0:    
        save_img = var_img.clone()
        save_img = torch.clamp(save_img,0,1)
        save_img = save_img[0].permute(1,2,0).data.cpu().numpy()*255
        save_img = save_img[...,::-1].astype('uint8')  #注意cv2使用BGR顺序
        save_img = cv2.resize(save_img,image_size)
        cv2.imwrite('./data/output1/transfer%d.jpg'%itera,save_img)

3.効果

ここに画像の説明を挿入

図 2. 効果図
  通常は、元のイメージをシードとして使用することをお勧めします。

  この方法は時間がかかり、通常、イメージの生成に数分かかります。フォローアップの改善論文では、高速生成アルゴリズムや、効果を向上させるための領域ベースの移行 (空のスタイルの建物への転送の回避など) など、多くの改善が行われています。そして当分は学ぶつもりはありません。(ここ2年でAIの穴が掘られるのが早すぎて、知識量が多すぎるので、まずは大まかに覚えることしかできません。)
  この技術は、基本的な使い方をしていても、まだある程度は役に立ちます。この記事のコードのバージョンを使用すると、インターネットでいくつかのスタイルの写真を見つけて、それらを処理するだけで、かなりまともな画像を生成できます。以下は、私の女の赤ちゃんの写真を処理した結果の一部です。
ここに画像の説明を挿入

図 3. 効果図

4. 補足(2021.4.13)

  前述したように、複数の層の特徴を組み合わせてグラム行列を計算することができますが、この実装はそれほど難しくないので、今日試してみましたが、効果的でした。
ここに画像の説明を挿入

図 4. 統合グラム行列を使用したスタイルの抽出

  図 3 の右端の図と比較すると、統一グラム マトリックスを使用すると、スタイル マップ上のより詳細なスタイルが転送されることがわかります。たとえば、鼻の頭に黄色い点が現れたり、目の形が変化したりします。眉毛はスタイル マップに似ており、顔はペイントされています。色はスタイル マップに似ています。ただし、この方法はより多くのビデオ メモリ (メモリ) を使用し、計算速度が大幅に遅くなりますので、特定のニーズに応じて選択してください。

おすすめ

転載: blog.csdn.net/Brikie/article/details/115602714