虚拟试穿--测试上衣代码详解

虚拟试穿

简介:本文梳理虚拟试穿算法框架结构。
算法仓库:https://github.com/beauthy/DeepFashion_Try_On


网络模型:latest_net_U.pth,latest_net_G1.pth,latest_net_G2.pth,latest_net_G.pth

测试效果:如图
在这里插入图片描述
测试效果:如视频

计算机视觉神经网络虚拟试穿测试


前言

本文将梳理算法实现过程原理。


提示:本文内容仅供学术研究与参考。

一、Try_On算法里面有什么?

0.环境; 1. 数据读取; 2. 数据模型:U-Net,G-Net; 3.损失函数; 4.调试常见的bug。

二、梳理步骤

1.环境

代码如下(示例):
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述在这里插入图片描述
以上只等下次优化成requirements.txt再传上来。

2.读入数据

模型输入数据需要哪些呢?
在这里插入图片描述
测试数据集长什么样?
在这里插入图片描述
数据直观内容分析,我把模型需要的输入放一起,展示如下:
在这里插入图片描述
实际上,pose_关键点数据,和label_分割数据,是img_模特数据得到的(怎么生成关键点数据和人物分割数据的详细解读和代码,我再开一篇博客放上来);edge_数据就是待穿衣服color生成的。mask掩码数据是根据需要随机生成的。所以,完整的项目,的输入只需要模特和服装款式即可,也就是说可以实现给个人物和一件衣服就给实现换装。
再看,
在这里插入图片描述
看具体情况:通过photoshop的拾色器可以直观看到数据的值如下(label是灰度图):
在这里插入图片描述
背景的亮度L:0;面部的亮度L:9,左胳膊的亮度L:10,右胳膊的亮度L:8,上衣衣服位置的L:2。
把肢体图像分割出精确部分,用不同的亮度表示,到时候换衣服就有边界了。学姿势,纹理和褶皱等也有边界。

注:名字带mask的三张掩码图,黑色区域亮度0,白色区域亮度为100,它们没有实际意义,可用于增加噪声,让模型稳定性好一些(我是这样理解的,因为训练的时候中间结果也有损失函数的backforword).

ok,以上就是输入数据,理清了没?
下面分析怎么使用的,以及后面模型是怎么组合的。

3. 输入配置

怎么读取数据集,生成模型输入需要的数据?项目写了一个配置options,专用于对数据集目录信息,训练测试信息和超参数进行配置的文件。
在这里插入图片描述
test.py文件中:

 opt = TrainOptions().parse()

ctrl+鼠标左键点击TestOptions,找到opt对象的具体内容:

class TestOptions(BaseOptions):
    def initialize(self):
    BaseOptions.initialize(self)
    ......

ctrl+鼠标左键点击BaseOptions

class BaseOptions():
    def __init__(self):
        self.parser = argparse.ArgumentParser()
        self.initialized = False

    def initialize(self):
    	.....    

TestOptions类的initialize函数系重写,但还是调用了BaseOptions.initialize(self)的,所以BaseOptions.initialize的数据也包含了的。根据训练和测试所需要的数据不同,控制生成数据集和其他超参数。

4. 数据集处理详细

生成数据迭代器,同其他pytorch自制数据集相差不大。重点文件时aligned_dataset.py
在这里插入图片描述
具体一起来看看:
test.py文件中:

data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()

CreateDataLoader是封装数据迭代器的类:
点进去看一眼

def CreateDataLoader(opt):
    from data.custom_dataset_data_loader import CustomDatasetDataLoader
    data_loader = CustomDatasetDataLoader()
    print(data_loader.name())
    data_loader.initialize(opt)
    return data_loader

具体内容在CustomDatasetDataLoader

class CustomDatasetDataLoader(BaseDataLoader):
    def name(self):
        return 'CustomDatasetDataLoader'

    def initialize(self, opt):
        BaseDataLoader.initialize(self, opt)
        self.dataset = CreateDataset(opt)
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batchSize,
            shuffle=not opt.serial_batches,
            num_workers=int(opt.nThreads))

    def load_data(self):
        return self.dataloader

    def __len__(self):
        return min(len(self.dataset), self.opt.max_dataset_size)

发现还封装了一层,数据是:self.dataset = CreateDataset(opt):
发现还有函数封装,

def CreateDataset(opt):
    dataset = None
    from data.aligned_dataset import AlignedDataset
    dataset = AlignedDataset()

    print("dataset [%s] was created" % (dataset.name()))
    dataset.initialize(opt)
    return dataset

注意dataset.initialize(opt),封装数据过程中动不动都在初始化。
继续AlignedDataset,ctrl+鼠标左键点击AlignedDataset

class AlignedDataset(BaseDataset):
    def initialize(self, opt):
    ......
    def __getitem__(self, index):
    ......

找到,def __getitem__(self, index):,看到它是不是很眼熟了,就是pytorch生成batch数据可迭代数据。这个类继承于BaseDataset,父类有transform等方法:

class BaseDataset(data.Dataset):
    def __init__(self):
        super(BaseDataset, self).__init__()

    def name(self):
        return 'BaseDataset'

    def initialize(self, opt):
        pass

def get_params(opt, size):
    w, h = size
    new_h = h
    new_w = w
    if opt.resize_or_crop == 'resize_and_crop':
        new_h = new_w = opt.loadSize            
    elif opt.resize_or_crop == 'scale_width_and_crop':
        new_w = opt.loadSize
        new_h = opt.loadSize * h // w

    x = random.randint(0, np.maximum(0, new_w - opt.fineSize))
    y = random.randint(0, np.maximum(0, new_h - opt.fineSize))
    
    #flip = random.random() > 0.5
    flip = 0
    return {
    
    'crop_pos': (x, y), 'flip': flip}

def get_transform(opt, params, method=Image.BICUBIC, normalize=True):
    transform_list = []
    if 'resize' in opt.resize_or_crop:
        osize = [opt.loadSize, opt.loadSize]
        transform_list.append(transforms.Scale(osize, method))   
    elif 'scale_width' in opt.resize_or_crop:
        transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.loadSize, method)))
        osize = [256,192]
        transform_list.append(transforms.Scale(osize, method))  
    if 'crop' in opt.resize_or_crop:
        transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize)))

    if opt.resize_or_crop == 'none':
        base = float(2 ** opt.n_downsample_global)
        if opt.netG == 'local':
            base *= (2 ** opt.n_local_enhancers)
        transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))

    if opt.isTrain and not opt.no_flip:
        transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))

    transform_list += [transforms.ToTensor()]

    if normalize:
        transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
                                                (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)

def normalize():    
    return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

def __make_power_2(img, base, method=Image.BICUBIC):
    ow, oh = img.size        
    h = int(round(oh / base) * base)
    w = int(round(ow / base) * base)
    if (h == oh) and (w == ow):
        return img
    return img.resize((w, h), method)

def __scale_width(img, target_width, method=Image.BICUBIC):
    ow, oh = img.size
    if (ow == target_width):
        return img    
    w = target_width
    h = int(target_width * oh / ow)    
    return img.resize((w, h), method)

def __crop(img, pos, size):
    ow, oh = img.size
    x1, y1 = pos
    tw = th = size
    if (ow > tw or oh > th):        
        return img.crop((x1, y1, x1 + tw, y1 + th))
    return img

def __flip(img, flip):
    if flip:
        return img.transpose(Image.FLIP_LEFT_RIGHT)
    return img

既然找到数据了,就看看AlignedDataset的数据生成方式吧。
首先是初始化,

def initialize(self, opt):
	......

到这里还记得输入数据都有哪些吗?
待穿服装:color,轮廓edge;
模特:img,pose关键点,模特分割数据label;
掩码:两个黑背景掩码,一个白背景的掩码。

 dir_C = '_color'
 self.dir_C = os.path.join(opt.dataroot, opt.phase + dir_C)
 self.C_paths = sorted(make_dataset(self.dir_C))
 self.CR_paths = make_dataset(self.dir_C)
 dir_E = '_edge'
 self.dir_E = os.path.join(opt.dataroot, opt.phase + dir_E)
 self.E_paths = sorted(make_dataset(self.dir_E))
 self.ER_paths = make_dataset(self.dir_E)

dir_B =  '_img'
 self.dir_B = os.path.join(opt.dataroot, opt.phase + dir_B)
 self.B_paths = sorted(make_dataset(self.dir_B))
  self.BR_paths = sorted(make_dataset(self.dir_B))
  #  pose的关键点名称和img模特命名相差不大:pose_name = B_path.replace('.jpg', '_keypoints.json').replace('test_img', 'test_pose')               
 dir_A = '_label'
 self.dir_A = os.path.join(opt.dataroot, opt.phase + dir_A)
 self.A_paths = sorted(make_dataset(self.dir_A))
 self.AR_paths = make_dataset(self.dir_A)

发现初始化就是把大家的地址放到明面上了,有的还排好了序,def __getitem__(self, index):函数取数据就方便多了。
下面两个函数,就是具体生成路径字典或列表的:

def make_dataset(dir):
    images = []
    assert os.path.isdir(dir), '%s is not a valid directory' % dir

    f = dir.split('/')[-1].split('_')[-1]
    print(dir, f)
    dirs = os.listdir(dir)
    for img in dirs:
        path = os.path.join(dir, img)
        # print(path)
        images.append(path)
    return images

def build_index(self, dirs):
    for k, dir in enumerate(dirs):
        name = dir.split('/')[-1]
        name = name.split('-')[0]

        # print(name)
        for k, d in enumerate(dirs[max(k - 20, 0):k + 20]):
            if name in d:
                if name not in self.diction.keys():
                    self.diction[name] = []
                    self.diction[name].append(d)
                else:
                    self.diction[name].append(d)

到这里了,就最后看一眼,生成的数据长什么样吧!

def __getitem__(self, index):
	......
	 if self.opt.isTrain:
	     input_dict = {
    
    'label': A_tensor, 'label_ref': AR_tensor, 'image': B_tensor, 'image_ref': BR_tensor,
	                          'path': A_path, 'path_ref': AR_path,
	                          'edge': E_tensor, 'color': C_tensor, 'mask': M_tensor, 'colormask': MC_tensor,
	                          'pose': P_tensor, 'name': name
	                          }
	 else:
	     input_dict = {
    
    'label': A_tensor, 'label_ref': AR_tensor, 'image': B_tensor, 'image_ref': BR_tensor, 'path': A_path, 'path_ref': AR_path}
	
	return input_dict

对,就是他input_dict,返回的字典。

注意:原作者共享的代码中,有这些后缀的文件夹,如图,我并没有相应后缀名文件,就配置训练模式,来测试数据了。
在这里插入图片描述

5.模型结构

在这里插入图片描述
继续看test.py文件,到模型板块

model = create_model(opt)

训练和测试使用的输入数据有区别的,训练的时候,除了将带穿衣服和模特的数据输入以外,还需要将穿好的结果输入,在最后面模型输出相比较,得出损失函数。

def create_model(opt):
    if opt.model == 'pix2pixHD':
        from .pix2pixHD_model import Pix2PixHDModel, InferenceModel
        if opt.isTrain:
            model = Pix2PixHDModel()
        else:
            model = InferenceModel()

    model.initialize(opt)
    if opt.verbose:
        print("model [%s] was created" % (model.name()))

    if opt.isTrain and len(opt.gpu_ids):
        model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)

    return model

测试model = InferenceModel()类实例也是继承的Pix2PixHDModel,重写了前向传播函数forward.

class InferenceModel(Pix2PixHDModel):
    def forward(self, inp):
        label = inp
        return self.inference(label)

我们直接去看Pix2PixHDModel:

pix2pixHD可以实现高分辨率图像生成和图片的语义编辑。
对于一个生成对抗网络(GAN),学习的关键就是理解生成器、判别器和损失函数这三部分。
pix2pixHD的生成器和判别器都是多尺度的,损失函数由GAN loss、Feature matching loss和Content loss组成。

class Pix2PixHDModel(BaseModel):
    def name(self):
        return 'Pix2PixHDModel'
	def initialize(self, opt):
		BaseModel.initialize(self, opt)
		......
        with torch.no_grad():
            self.Unet = networks.define_UnetMask(4, self.gpu_ids).eval()
            self.G1 = networks.define_Refine(37, 14, self.gpu_ids).eval()
            self.G2 = networks.define_Refine(19 + 18, 1, self.gpu_ids).eval()
            self.G = networks.define_Refine(24, 3, self.gpu_ids).eval()		
......
	def forward(self,label,pre_clothes_mask,img_fore,clothes_mask,clothes,all_clothes_label,real_image,pose,mask):
	......
   	

代码贴太多了,容易看不清重点。我就只贴些关键的,帮助理清楚整个网络的脉络。
Pix2PixHDModel继承了BaseModelBaseModel类有初始化函数,save_network函数,load_network函数。就没什么可以看的了。

from . import networks

networks.py文件中写了多种网络的具体结构,用类封装:
在这里插入图片描述
pix2pixHD模型类,初始化的时候,会对需要用到的网络进行初始化:
在这里插入图片描述
然后,在forward函数里面,是组合网络使用的方法和顺序,以及哪些地方需要计算损失来约束网络。

    def forward(self,label,pre_clothes_mask,img_fore,clothes_mask,clothes,all_clothes_label,real_image,pose,mask):
        # Encode Inputs
        #ipdb.set_trace()
        input_label,masked_label,all_clothes_label= self.encode_input(label,clothes_mask,all_clothes_label)
        #ipdb.set_trace()
        arm1_mask=torch.FloatTensor((label.cpu().numpy()==11).astype(np.float)).cuda()
        arm2_mask=torch.FloatTensor((label.cpu().numpy()==13).astype(np.float)).cuda()
        pre_clothes_mask=torch.FloatTensor((pre_clothes_mask.detach().cpu().numpy() > 0.5).astype(np.float)).cuda()
        clothes=clothes*pre_clothes_mask
        ......

forward函数的输入数据是:label, pre_clothes_mask, img_fore, clothes_mask, clothes, all_clothes_label, real_image, pose, mask.
我们输入的数据:input_dict = {'label': A_tensor, 'label_ref': AR_tensor, 'image': B_tensor, 'image_ref': BR_tensor, 'path': A_path, 'path_ref': AR_path}做了一点预处理的(加入高斯噪声,wash the label):

data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()

......
 for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        if epoch != start_epoch:
            epoch_iter = epoch_iter % dataset_size
        for i, data in enumerate(dataset, start=epoch_iter):
        	mask_clothes = torch.FloatTensor((data['label'].cpu().numpy() == 4).astype(np.int))
            mask_fore = torch.FloatTensor((data['label'].cpu().numpy() > 0).astype(np.int))
            img_fore = data['image'] * mask_fore
            img_fore_wc = img_fore * mask_fore
            all_clothes_label = changearm(data['label'])

            ############## 模型向前传播 ######################
            losses, fake_image, real_image, input_label, L1_loss, style_loss, clothes_mask, CE_loss, rgb, alpha = model(
                Variable(data['label'].cuda()), Variable(data['edge'].cuda()), Variable(img_fore.cuda()),
                Variable(mask_clothes.cuda()), Variable(data['color'].cuda()), Variable(all_clothes_label.cuda()), 
                Variable(data['image'].cuda()),
                Variable(data['pose'].cuda()), Variable(data['image'].cuda()), Variable(mask_fore.cuda()))

data里就是一次迭代获取dataset里的一个batch的数据。每个data都是input_dict样子,是字典。
mask_clothes = torch.FloatTensor((data['label'].cpu().numpy() == 4).astype(np.int)),将datalabel对应的数据取出来处理(得到mask_clothes为区域分割数据label同尺寸的掩码图像,里面的值,原图等于4的是1,其他全0,就是把衣服区域取出来);label是什么还记得吗?看下图:
在这里插入图片描述
看:mask_fore = torch.FloatTensor((data['label'].cpu().numpy() > 0).astype(np.int))(得到mask_fore 为区域分割数据label同尺寸的掩码图像,里面的值,原图大于1的是1,其他全0)
img_fore = data['image'] * mask_fore他们相乘,就是在抠图,去掉背景得到img_fore
all_clothes_label = changearm(data['label']):调用changearm函数(变胳膊区域):

    def changearm(old_label):
        label = old_label
        arm1 = torch.FloatTensor((data['label'].cpu().numpy() == 11).astype(np.int))
        arm2 = torch.FloatTensor((data['label'].cpu().numpy() == 13).astype(np.int))
        noise = torch.FloatTensor((data['label'].cpu().numpy() == 7).astype(np.int))
        label = label * (1 - arm1) + arm1 * 4
        label = label * (1 - arm2) + arm2 * 4
        label = label * (1 - noise) + noise * 4
        return label

changearm函数将模特区域分割数据label的左右胳膊取出来( == 11, == 13),把整幅噪声也找出来,将他们的值变成4.就像下图,把手也和衣服化为一个区域了。
在这里插入图片描述

############## Forward Pass ######################
losses, fake_image, real_image, input_label, L1_loss, style_loss, clothes_mask, CE_loss, rgb, alpha = model(
                Variable(data['label'].cuda()), Variable(data['edge'].cuda()), Variable(img_fore.cuda()),
                Variable(mask_clothes.cuda()), Variable(data['color'].cuda()), Variable(all_clothes_label.cuda()), 
                Variable(data['image'].cuda()),
                Variable(data['pose'].cuda()), Variable(data['image'].cuda()), Variable(mask_fore.cuda()))

所以,输入的量有,label–模特区域分割数据;edge–衣服轮廓;img_fore–去背景模特图像;mask_clothes–模特正穿着的衣服的掩码区域,color–衣服;all_clothes_label–模特区域分割label将胳膊手融入衣服的区域分割图像;image–模特;pose–模特关键点;mask_fore–模特区域;
对比看一看,模型中形参叫什么:

def forward(self,label,pre_clothes_mask,img_fore,
				clothes_mask,clothes,all_clothes_label,
				real_image,pose,grid,mask):
       

那就进入pix2pixHD模型的forward函数:

# Encode Inputs
input_label, masked_label, all_clothes_label = self.encode_input(label, clothes_mask, all_clothes_label)

对输入数据编码处理:

    def encode_input(self, label_map, clothes_mask, all_clothes_label):

        size = label_map.size()
        oneHot_size = (size[0], 14, size[2], size[3])
        input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
        input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0)

        masked_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
        masked_label = masked_label.scatter_(1, (label_map * (1 - clothes_mask)).data.long().cuda(), 1.0)

        c_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
        c_label = c_label.scatter_(1, all_clothes_label.data.long().cuda(), 1.0)

        input_label = Variable(input_label)

        return input_label, masked_label, c_label

手动实现one_hot 时,关于scatter_()函数: scatter_()函数有三个参数 scatter_(dim, index, src)

  1. dim指的是在哪个维度进行索引
  2. index指的是:用来进行索引的tensor
  3. src指scatter的源元素,可以是一个标量也可以是一个张量。

一句话解释上面的scatter:
input_label.scatter_(1, label_map.data.long().cuda(), 1.0)
既input_label.scatter_(dim, index, src)将src中数据根据index中的索引按照dim的方向填进input_label中。

继续net_G1(conditional GAN):

 G1_in = torch.cat([pre_clothes_mask, clothes, all_clothes_label, pose, self.gen_noise(shape)], dim=1)
 arm_label = self.G1.refine(G1_in)
 arm_label = self.sigmoid(arm_label)
 CE_loss = self.cross_entropy2d(arm_label, (label * (1 - clothes_mask)).transpose(0, 1)[0].long()) * 10

在这里插入图片描述
直观了吧,看网络G1只有一个输出arm_label。在训练中,就是模特换好新衣服后的分割图,与网络输出做损失反向传输。(模特原来是长袖,后面要穿短袖,自然胳膊是重点。脖子似乎没有人关心高领和低领的问题,待改进)。测试时直接生成要的穿color款衣服的模特分割图。

armlabel_map = generate_discrete_label(arm_label.detach(), 14, False)
dis_label = generate_discrete_label(arm_label.detach(), 14)

生成离散标签函数的输入是G1网络的输出结果arm_label

def generate_discrete_label(inputs, label_nc, onehot=True, encode=True):
    pred_batch = []
    size = inputs.size()
    for input in inputs:
        input = input.view(1, label_nc, size[2], size[3])
        pred = np.squeeze(input.data.max(1)[1].cpu().numpy(), axis=0)
        pred_batch.append(pred)

    pred_batch = np.array(pred_batch)
    pred_batch = torch.from_numpy(pred_batch)
    label_map = []
    for p in pred_batch:
        p = p.view(1, 256, 192)
        label_map.append(p)
    label_map = torch.stack(label_map, 0)
    if not onehot:
        return label_map.float().cuda()
    size = label_map.size()
    oneHot_size = (size[0], label_nc, size[2], size[3])
    input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
    input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0)

    return input_label

上面要是看不明白出入输出变化,可以把结果或者结果的shape打印出来,对比着看。

继续net_G2:

G2_in = torch.cat([pre_clothes_mask, clothes, dis_label, pose, self.gen_noise(shape)], 1)
fake_cl = self.G2.refine(G2_in)
fake_cl = self.sigmoid(fake_cl)
CE_loss += self.BCE(fake_cl, clothes_mask) * 10

在这里插入图片描述
G2的输入,是G1的输出+Pose+color+edge+noise组合输入,输出为模特穿上新衣后衣服的轮廓。训练的时候,是模特穿上新的衣服数据的衣服轮廓与G2的输出做损失,反向传播的。测试时,G2输出模特换上新衣的轮廓数据,此时,还没有图案和纹理的变化。
损失函数BCE参考:https://blog.csdn.net/qq_22210253/article/details/85222093

继续:

fake_cl_dis = torch.FloatTensor((fake_cl.detach().cpu().numpy() > 0.5).astype(np.float)).cuda()
fake_cl_dis = morpho(fake_cl_dis, 1, True)
def morpho(mask, iter, bigger=True):
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    new = []
    for i in range(len(mask)):
        tem = mask[i].cpu().detach().numpy().squeeze().reshape(256, 192, 1) * 255
        tem = tem.astype(np.uint8)
        if bigger:
            tem = cv2.dilate(tem, kernel, iterations=iter)
        else:
            tem = cv2.erode(tem, kernel, iterations=iter)
        tem = tem.astype(np.float64)
        tem = tem.reshape(1, 256, 192)
        new.append(tem.astype(np.float64) / 255.0)
    new = np.stack(new)
    new = torch.FloatTensor(new).cuda()
    return new

detach(): 神经网络的训练有时候可能希望保持一部分的网络参数不变,只对其中一部分的参数进行调整;或者值训练部分分支网络,并不让其梯度对主网络的梯度造成影响,torch.tensor.detach()和torch.tensor.detach_()函数来切断一些分支的反向传播。

cv2.getStructuringElement( ) 返回指定形状和尺寸的结构元素。
函数的第一个参数表示内核的形状,有三种形状可以选择。
矩形:MORPH_RECT;
交叉形:MORPH_CROSS;
椭圆形:MORPH_ELLIPSE;
第二和第三个参数分别是内核的尺寸以及锚点的位置。一般在调用erode以及dilate函数之前,先定义一个Mat类型的变量来获得getStructuringElement函数的返回值: 对于锚点的位置,有默认值Point(-1,-1),表示锚点位于中心点。element形状唯一依赖锚点位置,其他情况下,锚点只是影响了形态学运算结果的偏移。

cv2.erode()腐蚀:将前景物体变小,理解成将图像断开裂缝变大(在图片上画上黑色印记,印记越来越大)
dst = cv.erode(src, kernel[, dst[, anchor[, iterations[, borderType[, borderValue]]]]])

cv2.dilate()膨胀:将前景物体变大,理解成将图像断开裂缝变小(在图片上画上黑色印记,印记越来越小)
dst = cv2.dilate(src, kernel[, dst[, anchor[, iterations[, borderType[, borderValue]]]]])

numpy.stack(arrays, axis=0)
沿着新轴连接数组的序列。
axis参数指定新轴在结果尺寸中的索引。例如,如果axis=0,它将是第一个维度,如果axis=-1,它将是最后一个维度。
参数: 数组:array_like的序列每个数组必须具有相同的形状。axis:int,可选输入数组沿其堆叠的结果数组中的轴。
返回: 堆叠:ndarray堆叠数组比输入数组多一个维。

new_arm1_mask = torch.FloatTensor((armlabel_map.cpu().numpy() == 11).astype(np.float)).cuda()
new_arm2_mask = torch.FloatTensor((armlabel_map.cpu().numpy() == 13).astype(np.float)).cuda()
fake_cl_dis = fake_cl_dis * (1 - new_arm1_mask) * (1 - new_arm2_mask)
fake_cl_dis *= mask_fore

arm1_occ = clothes_mask * new_arm1_mask
arm2_occ = clothes_mask * new_arm2_mask
bigger_arm1_occ = morpho(arm1_occ, 10)
bigger_arm2_occ = morpho(arm2_occ, 10)
arm1_full = arm1_occ + (1 - clothes_mask) * arm1_mask
arm2_full = arm2_occ + (1 - clothes_mask) * arm2_mask
armlabel_map *= (1 - new_arm1_mask)
armlabel_map *= (1 - new_arm2_mask)
armlabel_map = armlabel_map * (1 - arm1_full) + arm1_full * 11
armlabel_map = armlabel_map * (1 - arm2_full) + arm2_full * 13
armlabel_map *= (1 - fake_cl_dis)
dis_label = encode(armlabel_map, armlabel_map.shape)
fake_c, warped, warped_mask, warped_grid = self.Unet(clothes, fake_cl_dis, pre_clothes_mask, grid)
mask = fake_c[:, 3, :, :]
mask = self.sigmoid(mask) * fake_cl_dis
fake_c = self.tanh(fake_c[:, 0:3, :, :])
fake_c = fake_c * (1 - mask) + mask * warped
skin_color = self.ger_average_color((arm1_mask + arm2_mask - arm2_mask * arm1_mask),
                                    (arm1_mask + arm2_mask - arm2_mask * arm1_mask) * real_image)
occlude = (1 - bigger_arm1_occ * (arm2_mask + arm1_mask + clothes_mask)) * (
            1 - bigger_arm2_occ * (arm2_mask + arm1_mask + clothes_mask))
img_hole_hand = img_fore * (1 - clothes_mask) * occlude * (1 - fake_cl_dis)
self.Unet = networks.define_UnetMask(4, self.gpu_ids).eval()

前面的G1,G2的网络我没有展开。后面会专门分析网络里面的组成,和输入输出等细节。在这里,我们溯源一下这个Unet:

def define_UnetMask(input_nc, gpu_ids=[]):
    netG = UnetMask(input_nc, output_nc=4)
    netG.cuda(gpu_ids[0])
    netG.apply(weights_init)
    return netG

Unet来源于UnetMask:

class UnetMask(nn.Module):
    def __init__(self, input_nc, output_nc=3):
        super(UnetMask, self).__init__()
        self.stn = STNNet()
        nl = nn.InstanceNorm2d
        self.conv1 = nn.Sequential(*[nn.Conv2d(input_nc, 64, kernel_size=3, stride=1, padding=1), nl(64), nn.ReLU(),
                                     nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), nl(64), nn.ReLU()])
        self.pool1 = nn.MaxPool2d(kernel_size=(2, 2))

        ......

    def forward(self, input, refer, mask, grid):
        input, warped_mask, rx, ry, cx, cy, grid = self.stn(input, torch.cat([mask, refer, input], 1), mask, grid)
        # print(input.shape)

        conv1 = self.conv1(torch.cat([refer.detach(), input.detach()], 1))
        ......
        conv9 = self.conv9(torch.cat([conv1, up9], 1))
        return conv9, input, warped_mask, grid

UnetMask有一个特殊的网络层STNNet:

class STNNet(nn.Module):

    def __init__(self):
        super(STNNet, self).__init__()
        range = 0.9
        r1 = range
        r2 = range
        grid_size_h = 5
        grid_size_w = 5

        assert r1 < 1 and r2 < 1  # if >= 1, arctanh will cause error in BoundedGridLocNet
        target_control_points = torch.Tensor(list(itertools.product(
            np.arange(-r1, r1 + 0.00001, 2.0 * r1 / (grid_size_h - 1)),
            np.arange(-r2, r2 + 0.00001, 2.0 * r2 / (grid_size_w - 1)),
        )))
        Y, X = target_control_points.split(1, dim=1)
        target_control_points = torch.cat([X, Y], dim=1)
        self.target_control_points = target_control_points
        # self.get_row(target_control_points,5)
        GridLocNet = {
    
    
            'unbounded_stn': UnBoundedGridLocNet,
            'bounded_stn': BoundedGridLocNet,
        }['bounded_stn']
        self.loc_net = GridLocNet(grid_size_h, grid_size_w, target_control_points)

        self.tps = TPSGridGen(256, 192, target_control_points)

    def get_row(self, coor, num):
        for j in range(num):
            sum = 0
            buffer = 0
            flag = False
            max = -1
            for i in range(num - 1):
                differ = (coor[j * num + i + 1, :] - coor[j * num + i, :]) ** 2
                if not flag:
                    second_dif = 0
                    flag = True
                else:
                    second_dif = torch.abs(differ - buffer)

                buffer = differ
                sum += second_dif
            print(sum / num)

    def get_col(self, coor, num):
        for i in range(num):
            sum = 0
            buffer = 0
            flag = False
            max = -1
            for j in range(num - 1):
                differ = (coor[(j + 1) * num + i, :] - coor[j * num + i, :]) ** 2
                if not flag:
                    second_dif = 0
                    flag = True
                else:
                    second_dif = torch.abs(differ - buffer)

                buffer = differ
                sum += second_dif
            print(sum)

    def forward(self, x, reference, mask, grid_pic):
        batch_size = x.size(0)
        source_control_points, rx, ry, cx, cy = self.loc_net(reference)
        source_control_points = (source_control_points)
        # print('control points',source_control_points.shape)
        source_coordinate = self.tps(source_control_points)
        grid = source_coordinate.view(batch_size, 256, 192, 2)
        # print('grid size',grid.shape)
        transformed_x = grid_sample(x, grid, canvas=0)
        warped_mask = grid_sample(mask, grid, canvas=0)
        warped_gpic = grid_sample(grid_pic, grid, canvas=0)
        return transformed_x, warped_mask, rx, ry, cx, cy, warped_gpic

在这里插入图片描述
U_net不仅含有简单的神经网络层,还有STN网络层(spatial transform network,空间变换网络)。

前面还完成了step3的过程:
在这里插入图片描述
以上,G1,G2,Unet,step3完成后,才是G3网络。G3的输入:

G_in = torch.cat([img_hole_hand, dis_label, fake_c, skin_color, self.gen_noise(shape)], 1)
fake_image = self.G.refine(G_in.detach())
fake_image = self.tanh(fake_image)

在这里插入图片描述
返回所有输出结果:

return [self.loss_filter(loss_G_GAN, 0, loss_G_VGG, loss_D_real, loss_D_fake), fake_image,
                clothes, arm_label, L1_loss, style_loss, fake_cl, CE_loss, real_image, warped_grid]

到这里,算是解释完了。


总结

在这里插入图片描述
总结,流程图最左侧三个灰蓝色模型,为基本输入color(服装),和img(模特)的预处理。本项目中,已经提供了它们三个的输出(未提供相关模型,不是重点),作为输入。
所以整个网络的关键输入就是:edge,color;pose,img,label.
整个网络就是G1+G2+Unet+G3构成;中间数据输入做了些些掩码和校正外,没有其他结构了。
小伙伴们,是不是弄清楚?

下载

最后,分享本blog需要用到的资源,前面贴了代码地址;训练好的上衣的G1,G2,Unet,G3,四个网络(latest_net_U.pth,latest_net_G1.pth,latest_net_G2.pth,latest_net_G.pth)下载地址:链接:https://pan.baidu.com/s/1uJVIUijvFjQz_geCBpizPQ
提取码:aqo8
测试数据地址:链接:https://pan.baidu.com/s/1jUaSii_-V965BUIFlbXjKg
提取码:ytkg

本文主要是自己巩固一下学习内容,下一步可以开始换其他裤子,裙子,大衣之类的了。

谨以此文与大家共勉!如果你觉得对你有用,请给点个赞

参考文献

感谢原作者:
[1] Yang, Han and Zhang, Ruimao and Guo, Xiaobao and Liu, Wei and Zuo, Wangmeng and Luo, Ping.Towards Photo-Realistic Virtual Try-On by Adaptively Generating-Preserving Image Content,IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR),June,2020.

猜你喜欢

转载自blog.csdn.net/beauthy/article/details/113698320