One-click batch extraction of clothing masks for Stable Diffusion

Have you ever wanted to use an algorithm to extract the clothes worn by the models in your images in batches, and then swap them out with Stable Diffusion?

Cutting them out one by one in Photoshop is exhausting; an algorithm can extract them all in one batch, and this approach is even simpler than using Segment Anything.



Mask batch extraction

import os

from tqdm import tqdm
from PIL import Image
import numpy as np

import warnings

warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)

import torch
import torch.nn.functional as F
import torchvision.transforms as transforms

from data.base_dataset import Normalize_image
from utils.saving_utils import load_checkpoint_mgpu

from networks import U2NET

device = "cuda"

image_dir = "input_images"
result_dir = "output_images"
mask_dir = "output_mask"
checkpoint_path = os.path.join("trained_checkpoint", "cloth_segm_u2net_latest.pth")
do_palette = True


def get_palette(num_cls):
    """Returns the color map for visualizing the segmentation mask.
    Args:
        num_cls: Number of classes
    Returns:
        The color map
    """
    n = num_cls
    palette = [0] * (n * 3)
    for j in range(0, n):
        lab = j
        palette[j * 3 + 0] = 0
        palette[j * 3 + 1] = 0
        palette[j * 3 + 2] = 0
        i = 0
        while lab:
            palette[j * 3 + 0] |= ((lab >> 0) & 1) << (7 - i)
            palette[j * 3 + 1] |= ((lab >> 1) & 1) << (7 - i)
            palette[j * 3 + 2] |= ((lab >> 2) & 1) << (7 - i)
            i += 1
            lab >>= 3
    return palette


transforms_list = []
transforms_list += [transforms.ToTensor()]
transforms_list += [Normalize_image(0.5, 0.5)]
transform_rgb = transforms.Compose(transforms_list)

net = U2NET(in_ch=3, out_ch=4)
net = load_checkpoint_mgpu(net, checkpoint_path)
net = net.to(device)
net = net.eval()

palette = get_palette(4)

images_list = sorted(os.listdir(image_dir))
pbar = tqdm(total=len(images_list))
for image_name in images_list:
    img = Image.open(os.path.join(image_dir, image_name)).convert("RGB")
    image_tensor = transform_rgb(img)
    image_tensor = torch.unsqueeze(image_tensor, 0)

    output_tensor = net(image_tensor.to(device))
    output_tensor = F.log_softmax(output_tensor[0], dim=1)
    output_tensor = torch.max(output_tensor, dim=1, keepdim=True)[1]
    output_tensor = torch.squeeze(output_tensor, dim=0)
    output_tensor = torch.squeeze(output_tensor, dim=0)
    output_arr = output_tensor.cpu().numpy()

    output_img = Image.fromarray(output_arr.astype("uint8"), mode="L")
    if do_palette:
        output_img.putpalette(palette)
    output_img.save(os.path.join(result_dir, os.path.splitext(image_name)[0] + ".png"))

    pbar.update(1)

pbar.close()

# Convert the paletted segmentation output into a black-and-white mask
dir_list = os.listdir(result_dir)

for n in dir_list:
    # open the segmentation image
    im = Image.open(os.path.join(result_dir, n))
    # convert to RGB mode
    im = im.convert('RGB')
    # get the pixel access object
    pixels = im.load()
    # walk over every pixel
    for i in range(im.size[0]):
        for j in range(im.size[1]):
            # keep black (background) pixels as they are
            if pixels[i, j] == (0, 0, 0):
                pass
            else:
                # turn every non-black (clothing) pixel white
                pixels[i, j] = (255, 255, 255)
    # save the resulting mask
    im.save(os.path.join(mask_dir, os.path.splitext(n)[0] + ".png"))

The first part of the code segments all the model images under input_images and saves the preprocessed results to output_images.

The second part then converts each segmentation result into a black-and-white mask image and saves it to output_mask.
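The per-pixel loop above works, but it is slow on large images. Here is a rough vectorized NumPy version of the same black-and-white conversion (a sketch only; it reuses the result_dir and mask_dir variables from the script above):

import numpy as np
from PIL import Image

for name in os.listdir(result_dir):
    seg = np.array(Image.open(os.path.join(result_dir, name)).convert("RGB"))
    # any pixel that is not pure black belongs to the clothing mask
    binary = (seg.any(axis=-1) * 255).astype("uint8")
    Image.fromarray(binary, mode="L").save(
        os.path.join(mask_dir, os.path.splitext(name)[0] + ".png"))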

Changing outfits with SD

In SD, open the img2img page and go to the Inpaint upload tab. Upload the model's original image and its mask.

Fill in your prompt, hit Generate, and you're done.

It's that simple. If you're interested, you can wrap this into a batch-processing script and then pick out the images you like.
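If you want to skip the manual clicking entirely, the AUTOMATIC1111 WebUI exposes an HTTP API when launched with the --api flag. The sketch below batch-submits each original image together with its mask to the img2img inpaint endpoint; the URL, output folder, prompt, and parameter values are placeholders, and the field names should be checked against your WebUI version.

import base64
import os
import requests

webui_url = "http://127.0.0.1:7860"  # assumes a local WebUI started with --api
sd_out_dir = "sd_output"             # hypothetical folder for the generated images
prompt = "a model wearing a red evening dress"  # example prompt, change as you like

os.makedirs(sd_out_dir, exist_ok=True)

def to_b64(path):
    # the API expects images as base64-encoded strings
    with open(path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")

for name in sorted(os.listdir("input_images")):
    mask_name = os.path.splitext(name)[0] + ".png"
    payload = {
        "init_images": [to_b64(os.path.join("input_images", name))],
        "mask": to_b64(os.path.join("output_mask", mask_name)),
        "prompt": prompt,
        "denoising_strength": 0.75,
        "inpainting_fill": 1,        # 1 = keep "original" masked content as the starting point
        "inpaint_full_res": False,
    }
    r = requests.post(webui_url + "/sdapi/v1/img2img", json=payload)
    r.raise_for_status()
    # the response carries the generated images as base64 strings
    with open(os.path.join(sd_out_dir, mask_name), "wb") as f:
        f.write(base64.b64decode(r.json()["images"][0]))

A lower denoising_strength keeps more of the original garment's shape, while higher values give SD more freedom to invent a new outfit.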

Origin: blog.csdn.net/qq_20288327/article/details/131579632