Using AI in Python: make your own comic avatar or landscape picture with animegan2-pytorch

Quick start: clone the repository and run the test script:

git clone https://github.com/bryandlee/animegan2-pytorch
cd ./animegan2-pytorch
python test.py --checkpoint weights/face_paint_512_v2.pt --input_dir samples/inputs --output_dir samples/results --device cpu
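
If you only need a one-off conversion, the repository's README also documents torch.hub entry points, so you can skip editing any scripts. A minimal sketch, assuming the "generator" and "face2paint" hub entry points and the "face_paint_512_v2" pretrained weights described in that README:

import torch
from PIL import Image

# load the pretrained generator plus a helper that handles pre/post-processing
model = torch.hub.load("bryandlee/animegan2-pytorch:main", "generator",
                       pretrained="face_paint_512_v2", device="cpu")
face2paint = torch.hub.load("bryandlee/animegan2-pytorch:main", "face2paint",
                            size=512, device="cpu")

img = Image.open("images/photo_test.jpg").convert("RGB")
out = face2paint(model, img)  # returns a stylized PIL image
out.save("images/animegan2_result.png")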

1. Renderings

The official renderings are as follows:

The renderings of the v2 512 model (face_paint_512_v2) are as follows:
[image]

The renderings of the v1 512 model (face_paint_512_v1) are as follows:
[image]

A case where v1 does not render well:
[image]

The renderings of the celeba_distill model are as follows. Portraits come out with a sickly, overly pale complexion, while landscapes fare better; the portrait results are somewhat similar to photo2cartoon's:
[image]

The renderings of the paprika model are as follows. Its texture strokes are too pronounced on portraits, which makes it better suited to landscapes. The second picture (of Minglan) turns out reasonably well; different models differ slightly from image to image:
[image]

The landscape comparison chart (origin vs v1Res vs v2Res vs paprikaRes vs celedistillRes, as stacked by the plotting script in section 3) is as follows:
[image]

The character comparison chart with the same layout is as follows:
[image]

2. Principle

The goal of portrait/landscape toon-style rendering is to convert photorealistic images into cartoon-like, non-photorealistic images while preserving the original image's identity information and texture details.
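
In code, inference is just a normalization round trip around the generator; the per-image core of test.py below condenses to:

# condensed from test.py below
x = to_tensor(img).unsqueeze(0) * 2 - 1    # add batch dim, scale [0,1] -> [-1,1]
with torch.no_grad():
    y = net(x.to(device)).cpu()            # stylized output, roughly in [-1,1]
y = y.squeeze(0).clip(-1, 1) * 0.5 + 0.5   # clamp, scale back to [0,1]
out = to_pil_image(y)                      # PIL image, ready to save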

3. Source code

For the source code and example model files, see: https://download.csdn.net/download/qq_40985985/87739198
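
The two scripts below expect a layout along these lines (folder names inferred from the paths hard-coded in the scripts; the per-model output folders are whatever you pass as --output_dir):

animegan2-pytorch/
├── weights/           # *.pt checkpoints (face_paint_512_v1/v2, paprika, celeba_distill)
├── samples/
│   ├── inputs/        # original photos
│   ├── resv1/         # face_paint_512_v1 results
│   ├── resv2/         # face_paint_512_v2 results
│   ├── paprika/       # paprika results
│   └── celedistill/   # celeba_distill results (any folder name containing "cele" matches)
├── test.py
└── plot_sample.py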

# animegan2-pytorch: generate a comic avatar or landscape image
# Usage: python test.py --checkpoint weights/face_paint_512_v2.pt --input_dir samples/faces/ --device cpu --output_dir samples/resv2
# Expected output: model loaded: weights/face_paint_512_v2.pt

import os
import argparse

from PIL import Image
import numpy as np

import torch
from torchvision.transforms.functional import to_tensor, to_pil_image

from model import Generator


# disable cuDNN autotuning and force deterministic behavior for reproducible outputs
torch.backends.cudnn.enabled = False
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True


def load_image(image_path, x32=False):
    img = Image.open(image_path).convert("RGB")

    if x32:
        # snap each side down to a multiple of 32 (minimum 256) so the
        # generator's down-/up-sampling stages produce matching shapes
        def to_32s(x):
            return 256 if x < 256 else x - x % 32
        w, h = img.size
        img = img.resize((to_32s(w), to_32s(h)))

    return img


def test(args):
    device = args.device
    
    net = Generator()
    net.load_state_dict(torch.load(args.checkpoint, map_location="cpu"))
    net.to(device).eval()
    print(f"model loaded: {
      
      args.checkpoint}")
    
    os.makedirs(args.output_dir, exist_ok=True)

    for image_name in sorted(os.listdir(args.input_dir)):
        if os.path.splitext(image_name)[-1].lower() not in [".jpg", ".png", ".bmp", ".tiff"]:
            continue
            
        image = load_image(os.path.join(args.input_dir, image_name), args.x32)

        with torch.no_grad():
            image = to_tensor(image).unsqueeze(0) * 2 - 1
            out = net(image.to(device), args.upsample_align).cpu()
            out = out.squeeze(0).clip(-1, 1) * 0.5 + 0.5
            out = to_pil_image(out)

        out.save(os.path.join(args.output_dir, image_name))
        print(f"image saved: {
      
      image_name}")


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--checkpoint',
        type=str,
        default='./weights/paprika.pt',
    )
    parser.add_argument(
        '--input_dir', 
        type=str, 
        default='./samples/inputs',
    )
    parser.add_argument(
        '--output_dir', 
        type=str, 
        default='./samples/results',
    )
    parser.add_argument(
        '--device',
        type=str,
        default='cuda:0',
    )
    parser.add_argument(
        '--upsample_align',
        # argparse's type=bool would treat any non-empty string as True,
        # so this option is exposed as a flag
        action="store_true",
        help="Align corners in decoder upsampling layers"
    )
    parser.add_argument(
        '--x32',
        action="store_true",
        help="Resize images to multiple of 32"
    )
    args = parser.parse_args()
    
    test(args)
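
The script defaults to cuda:0. On a machine without a GPU, pass --device cpu as in the usage comment at the top, or, as a small optional tweak (not in the original script), select the device automatically:

device = "cuda:0" if torch.cuda.is_available() else "cpu"

The second script stitches the originals and the four models' outputs into the comparison charts shown in section 1.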

# Plot original vs. stylized result comparisons
# python plot_sample.py

# collect all images under the input path
import cv2
import imutils
import numpy as np
from imutils import paths

imagePaths = sorted(list(paths.list_images("samples")))

inputs = [x for x in imagePaths if x.find('inputs') > 0]  # original input photos
print(inputs)

resv1 = [x for x in imagePaths if x.find("resv1") > 0]    # face_paint_512_v1 results
resv2 = [x for x in imagePaths if x.find("resv2") > 0]    # face_paint_512_v2 results
cele = [x for x in imagePaths if x.find("cele") > 0]      # celeba_distill results
pap = [x for x in imagePaths if x.find("paprika") > 0]    # paprika results

img = None
for i in inputs:
    # only build the comparison grid for the sample named ml2.jpg
    if i.find("ml2.jpg") < 0:
        continue
    img = None
    # each model loop below appends an [origin | result] row to the running grid
    for j in resv1:
        if j.split("\\")[2] == i.split("\\")[2]:  # same filename (Windows-style paths)
            origin = cv2.imread(i)
            res = cv2.imread(j)
            if origin.shape[0] != res.shape[0] or origin.shape[1] != res.shape[1]:
                res = cv2.resize(res, (origin.shape[1], origin.shape[0]))
            cv2.imshow('origin vs ' + j.split("\\")[1].replace("res", "") + 'Res',
                       imutils.resize(np.hstack([origin, res]), width=300))
            if img is None:
                img = imutils.resize(np.hstack([origin, res]), width=300)
            else:
                img = np.vstack([img, imutils.resize(np.hstack([origin, res]), width=300)])
            cv2.imshow('origin vs ' + j.split("\\")[1].replace("res", "") + 'ResAll', img)
    for j in resv2:
        if j.split("\\")[2] == i.split("\\")[2]:
            origin = cv2.imread(i)
            res = cv2.imread(j)
            if origin.shape[0] != res.shape[0] or origin.shape[1] != res.shape[1]:
                res = cv2.resize(res, (origin.shape[1], origin.shape[0]))
            if img is None:
                img = imutils.resize(np.hstack([origin, res]), width=300)
            else:
                img = np.vstack([img, imutils.resize(np.hstack([origin, res]), width=300)])
    for j in pap:
        if j.split("\\")[2] == i.split("\\")[2]:
            origin = cv2.imread(i)
            res = cv2.imread(j)
            if origin.shape[0] != res.shape[0] or origin.shape[1] != res.shape[1]:
                res = cv2.resize(res, (origin.shape[1], origin.shape[0]))
            if img is None:
                img = imutils.resize(np.hstack([origin, res]), width=300)
            else:
                img = np.vstack([img, imutils.resize(np.hstack([origin, res]), width=300)])
    for j in cele:
        if j.split("\\")[2] == i.split("\\")[2]:
            origin = cv2.imread(i)
            res = cv2.imread(j)
            if origin.shape[0] != res.shape[0] or origin.shape[1] != res.shape[1]:
                res = cv2.resize(res, (origin.shape[1], origin.shape[0]))
            if img is None:
                img = imutils.resize(np.hstack([origin, res]), width=300)
            else:
                img = np.vstack([img, imutils.resize(np.hstack([origin, res]), width=300)])
            # show the fully stacked comparison once the last model's row is added
            cv2.imshow('origin vs v1Res vs v2Res vs paprikaRes vs celedistillResAll', img)
            cv2.waitKey(0)
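
cv2.imshow requires a desktop session. On a headless machine, a hypothetical one-line substitute for the final imshow/waitKey pair writes the assembled grid to disk instead:

cv2.imwrite("samples/comparison_ml2.jpg", img)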

Original article: blog.csdn.net/qq_40985985/article/details/130379775