Make your cartoon avatar in Python with AI photo2cartoon

To get started, clone the official repository and run the test script on a sample photo:

git clone https://github.com/minivision-ai/photo2cartoon.git
cd ./photo2cartoon
python test.py --photo_path images/photo_test.jpg --save_path images/cartoon_result.png
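
Note: the script expects the pretrained weights at models/photo2cartoon_weights.pt (see the repository for the weights download), and the input photo should contain a clear, detectable face; otherwise the face-detection step below fails.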

1. Results

The official renderings are as follows:

[Figure: official sample renderings]

Result 1:

[Figure: result 1]

Result 2:

[Figure: result 2]

Result 3:

[Figure: result 3]

2. Principle

The goal of portrait cartoonization is to convert a photorealistic image into a cartoon-style, non-photorealistic image while preserving the identity information and texture details of the original photo.

However, image cartoonization faces several difficulties:

  • Cartoon images tend to have sharp edges, smooth color blocks, and simplified textures, which set them apart from other artistic styles. Cartoons produced with traditional image-processing techniques cannot adapt to complex lighting and textures, so the results are poor; style-transfer methods cannot outline facial details accurately.
  • Data acquisition is difficult. Drawing cartoons in a refined, consistent style takes a great deal of time and money, and the face shape and facial features of a drawn cartoon differ from the original photo, so the data do not form pixel-aligned pairs and paired image translation methods are hard to apply.
  • Identity information is easily lost after cartoonization. The cycle-consistency loss (Cycle Loss) used by unpaired image translation methods cannot effectively constrain the identity of the input and output (a sketch of these two losses follows this list).
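
To make the last point concrete, here is a minimal, generic sketch of the two losses involved. It is not the authors' exact formulation, and face_encoder stands in for any pretrained face-recognition network that maps an image to an identity embedding:

import torch.nn.functional as F

def cycle_loss(real, reconstructed):
    # photo -> cartoon -> photo reconstruction error; this constrains
    # content, but not identity directly
    return F.l1_loss(reconstructed, real)

def id_loss(face_encoder, photo, cartoon):
    # cosine distance between identity embeddings of the input photo and
    # the generated cartoon; a small distance means identity is preserved
    emb_photo = face_encoder(photo)
    emb_cartoon = face_encoder(cartoon)
    return 1.0 - F.cosine_similarity(emb_photo, emb_cartoon, dim=1).mean()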

The research team at Minivision (Xiaoshi Technology) proposed a cartoonization model based on a generative adversarial network (GAN) that produces pleasing results from only a small amount of unpaired training data. The cartoon-style rendering network is the core of the solution and consists of three parts: feature extraction, feature fusion, and feature reconstruction.
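
As an illustration of that three-part layout, the sketch below uses a plain encoder / residual-fusion / decoder; the real photo2cartoon generator (a U-GAT-IT-style network) is considerably more elaborate:

import torch.nn as nn

class ResBlock(nn.Module):
    # residual block used in the fusion stage of this sketch
    def __init__(self, dim):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(dim, dim, 3, 1, 1), nn.ReLU(inplace=True),
            nn.Conv2d(dim, dim, 3, 1, 1),
        )

    def forward(self, x):
        return x + self.block(x)

class ToonGenerator(nn.Module):
    # illustrative only: feature extraction -> fusion -> reconstruction
    def __init__(self, ngf=32):
        super().__init__()
        # feature extraction: downsample the photo into a feature map
        self.encode = nn.Sequential(
            nn.Conv2d(3, ngf, 7, 1, 3), nn.ReLU(inplace=True),
            nn.Conv2d(ngf, ngf * 2, 3, 2, 1), nn.ReLU(inplace=True),
            nn.Conv2d(ngf * 2, ngf * 4, 3, 2, 1), nn.ReLU(inplace=True),
        )
        # feature fusion: residual blocks blend content and style features
        self.fuse = nn.Sequential(*[ResBlock(ngf * 4) for _ in range(4)])
        # feature reconstruction: upsample back to an RGB cartoon image
        self.decode = nn.Sequential(
            nn.Upsample(scale_factor=2), nn.Conv2d(ngf * 4, ngf * 2, 3, 1, 1), nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2), nn.Conv2d(ngf * 2, ngf, 3, 1, 1), nn.ReLU(inplace=True),
            nn.Conv2d(ngf, 3, 7, 1, 3), nn.Tanh(),  # pixel values in [-1, 1]
        )

    def forward(self, x):
        return self.decode(self.fuse(self.encode(x)))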

3. Source code

For the full source code and the sample/model files, see the resource download: https://download.csdn.net/download/qq_40985985/87739184


# Generate a cartoon avatar with the pretrained model:
# python test.py --photo_path images/ml.jpg --save_path images/cartoon_ml_result.png

import argparse
import os

import cv2
import numpy as np
import torch

from models import ResnetGenerator
from utils import Preprocess

parser = argparse.ArgumentParser()
parser.add_argument('--photo_path', type=str, default='images/photo_test.jpg', help='input photo path')
parser.add_argument('--save_path', type=str, default='images/photo_test_cartoon.jpg', help='cartoon save path')
args = parser.parse_args()

# make sure the output directory exists (fall back to '.' for a bare filename)
os.makedirs(os.path.dirname(args.save_path) or '.', exist_ok=True)


class Photo2Cartoon:
    def __init__(self):
        self.pre = Preprocess()
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.net = ResnetGenerator(ngf=32, img_size=256, light=True).to(self.device)

        assert os.path.exists('./models/photo2cartoon_weights.pt'), \
            "[Step1: load weights] Can not find 'photo2cartoon_weights.pt' in folder 'models'!!!"
        params = torch.load('./models/photo2cartoon_weights.pt', map_location=self.device)
        self.net.load_state_dict(params['genA2B'])
        print('[Step1: load weights] success!')

    def inference(self, img):
        # face alignment and segmentation
        face_rgba = self.pre.process(img)
        if face_rgba is None:
            print('[Step2: face detect] can not detect face!!!')
            return None

        print('[Step2: face detect] success!')
        face_rgba = cv2.resize(face_rgba, (256, 256), interpolation=cv2.INTER_AREA)
        face = face_rgba[:, :, :3].copy()
        mask = face_rgba[:, :, 3][:, :, np.newaxis].copy() / 255.
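        # composite the segmented face onto a white background and
        # normalize pixel values to [-1, 1]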
        face = (face * mask + (1 - mask) * 255) / 127.5 - 1

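        # HWC image -> NCHW float32 tensor expected by the network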
        face = np.transpose(face[np.newaxis, :, :, :], (0, 3, 1, 2)).astype(np.float32)
        face = torch.from_numpy(face).to(self.device)

        # inference
        with torch.no_grad():
            cartoon = self.net(face)[0][0]

        # post-process
        cartoon = np.transpose(cartoon.cpu().numpy(), (1, 2, 0))
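        # map the output from [-1, 1] back to [0, 255], restore the white
        # background outside the face mask, and convert RGB -> BGR for OpenCV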
        cartoon = (cartoon + 1) * 127.5
        cartoon = (cartoon * mask + 255 * (1 - mask)).astype(np.uint8)
        cartoon = cv2.cvtColor(cartoon, cv2.COLOR_RGB2BGR)
        print('[Step3: photo to cartoon] success!')
        return cartoon


if __name__ == '__main__':
    img = cv2.imread(args.photo_path)  # BGR, or None if the path is wrong
    assert img is not None, f"Can not read input photo '{args.photo_path}'!"
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    c2p = Photo2Cartoon()
    cartoon = c2p.inference(img)
    if cartoon is not None:
        cv2.imwrite(args.save_path, cartoon)
        print('Cartoon portrait has been saved successfully!')
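        # show the original photo and the cartoon result side by side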
        origin = cv2.resize(cv2.imread(args.photo_path), (256, 256))
        res = cv2.imread(args.save_path)
        print(origin.shape, res.shape)
        cv2.imshow("origin VS cartoon", np.hstack([origin, res]))
        cv2.waitKey(0)

Original article: blog.csdn.net/qq_40985985/article/details/130366039