[Paper] Image text erasure based on GANs - 2010.EraseNet: End-to-End Text Removal in the Wild (open source)

Official PyTorch code: https://github.com/lcy0604/EraseNet
Paper: 2010.EraseNet: End-to-End Text Removal in the Wild (network-disk extraction code: 0719)

1. Image text removal effect

Figure 10 Text removal results on the real-world SCUT-EnsText dataset

The first column is the original image with text, the second column is the ground truth with the text removed, and the remaining columns show the removal results of different algorithms (Pix2Pix, SceneTextEraser, EnsNet, and the EraseNet proposed in this paper)
insert image description here

Figure 11 Comparison of text removal results on the synthetic dataset

insert image description here

Figure 12 Comparison of the removal results with inpainting methods

insert image description here

2. Method overview

The model combines a two-stage coarse-to-refine generator network with a local-global discriminator network. (In this paper, the authors improve SN-GAN and propose a local-global SN-Patch-GAN architecture.)
An extra semantic segmentation head is integrated into the whole algorithm to perceive text regions.

At the same time, features extracted by an external pre-trained VGG-16 network are used to supervise the discrepancies in high-level semantics between the generated fake samples and the ground truths. Figure 8
insert image description here
insert image description here
Discriminator architecture
insert image description here

Figure 9 Comparison of the effects of different algorithms

insert image description here

training details

A single NVIDIA 2080 Ti GPU, batch size = 4
use

data set

SCUT-EnsText: Proposed and collected by South China University of Technology, see
Synthetic data: generated with the code library of "Synthetic Data for Text Localisation in Natural Images" (proposed in 2016), which is used to synthesize the datasets

3. Experimental results on my own local dataset

Shopping image conversion

inference code


# -*- coding: utf-8 -*-
# @Time : 2023/7/6 20:36
# @Author : XyZeng
import os
import math
import argparse
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from PIL import Image
import numpy as np
from torch.autograd import Variable
from torchvision.utils import save_image

from data.dataloader import ErasingData,ImageTransform
from models.sa_gan import STRnet2




# Command-line options for this inference script. In the visible part of the
# file, numOfWorkers, modelsSavePath and logPath are parsed but never read.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--numOfWorkers', default=0, type=int, help='workers for dataloader'
)
parser.add_argument(
    '--modelsSavePath', default='', type=str, help='path for saving models'
)
parser.add_argument('--logPath', default='', type=str)
parser.add_argument('--batchSize', default=16, type=int)
parser.add_argument(
    '--loadSize', default=512, type=int, help='image loading size'
)
parser.add_argument('--dataRoot', default='./', type=str)
parser.add_argument(
    '--pretrained',
    default='./model.pth',
    type=str,
    help='pretrained models for finetuning',
)
parser.add_argument('--savePath', default='./output', type=str)
# Parsed once at import time; the rest of the script reads from `args`.
args = parser.parse_args()

# Detect CUDA once; this flag later decides whether the model and the input
# tensors are moved to the GPU.
cuda = torch.cuda.is_available()
if cuda:
    # Enable the cuDNN benchmark flag when running on a GPU.
    cudnn.benchmark = True
    print('Cuda is available!')


def visual(image):
    """Open the first image of a batch in the system image viewer.

    Axes 1 and 2 and then axes 2 and 3 are swapped (channel-first to
    channel-last for a 4-D batch), and the first entry is cast to uint8
    without any rescaling before being shown.

    Args:
        image: Batched image tensor.
            # assumes (N, C, H, W) layout — TODO confirm against callers
    """
    channels_last = image.transpose(1, 2).transpose(2, 3)
    batch_array = channels_last.detach().cpu().numpy()
    first_image = batch_array[0].astype(np.uint8)
    Image.fromarray(first_image).show()

# Pull the run-time settings out of the parsed command-line arguments.
savePath = args.savePath
dataRoot = args.dataRoot
batchSize = args.batchSize
loadSize = (args.loadSize,) * 2  # square (size, size) pair

import torch.nn.functional as F

# Create the output directory up front; an already existing one is fine.
os.makedirs(savePath, exist_ok=True)




# Build the generator and restore its weights for inference.
# NOTE(review): the argument 3 is presumably the number of input channels —
# confirm against models/sa_gan.py.
netG = STRnet2(3)
# Load the checkpoint onto the CPU first so that weights saved from a GPU run
# can still be restored on a machine without CUDA; the model is moved to the
# GPU right below when one is available.
netG.load_state_dict(torch.load(args.pretrained, map_location='cpu'))

if cuda:
    netG = netG.cuda()

# Inference only: no parameter needs gradients.
for param in netG.parameters():
    param.requires_grad = False

print('OK!')

import time
# Start timestamp; not read again in the visible part of the script.
start = time.time()
# Put the network into evaluation mode.
netG.eval()

# Transform applied to every input image before it is fed to the network.
ImgTrans=ImageTransform(args.loadSize)


def get_img_tensor(path):
    """Load an image file and turn it into a single-image network input.

    The image is converted to RGB, resized to ``args.loadSize`` x
    ``args.loadSize`` with bilinear resampling, passed through ``ImgTrans``
    and given a leading batch axis.

    Args:
        path: Path of the image file to read.

    Returns:
        A tuple ``(inputImage, ori_img_size)``: the transformed image with a
        batch axis of length 1, and the ``size`` PIL reported for the file
        before resizing (used by the caller to restore the output size).
    """
    # The context manager closes the underlying file deterministically instead
    # of leaving the handle to the garbage collector.
    with Image.open(path) as source:
        ori_img_size = source.size
        # Image.BILINEAR is the named form of resampling filter 2
        # (NEAREST=0, LANCZOS=1, BILINEAR=2, BICUBIC=3, BOX=4).
        img = source.convert('RGB').resize(
            (args.loadSize, args.loadSize), Image.BILINEAR
        )
    inputImage = ImgTrans(img).unsqueeze(0)
    print('inputImage',inputImage.size())
    print('ori_img_size',ori_img_size)
    return inputImage,ori_img_size


def torch_img_2_pil(gpu_tensor):
    """Convert an image tensor with values in [0, 1] into a PIL image.

    Args:
        gpu_tensor: Image tensor shaped (C, H, W) or (1, C, H, W); it may
            live on any device.

    Returns:
        A PIL image built from the rounded 8-bit pixel values. Single-channel
        input yields a greyscale image.
    """
    tensor = gpu_tensor.detach().cpu()
    # Drop only the leading batch axis. A bare squeeze() would also remove a
    # height, width or channel axis of length 1 and break the permute below.
    if tensor.dim() == 4:
        tensor = tensor.squeeze(0)
    # Map [0, 1] to [0, 255]; adding 0.5 before the uint8 cast rounds to the
    # nearest integer instead of truncating.
    pixels = tensor.mul(255).add(0.5).clamp(0, 255).to(torch.uint8)
    # Channel-first -> channel-last, the layout Image.fromarray expects.
    numpy_array = pixels.permute(1, 2, 0).numpy()
    # Image.fromarray wants a 2-D array for single-channel data.
    if numpy_array.shape[2] == 1:
        numpy_array = numpy_array[:, :, 0]
    return Image.fromarray(numpy_array)


if __name__ == '__main__':
    # Run the generator on every file of one directory and write, for each
    # input, the text-erased image (restored to the input's original size)
    # plus the network's fifth output as '<name>_mask.png' into savePath.

    # Directory holding the images to convert; change it to your own folder,
    # e.g. r'example\all_images'.
    input_dir = r'F:\code\23_0613_image_translate\inpainting_datasets'
    for name in os.listdir(input_dir):
        path = os.path.join(input_dir, name)
        # Sub-directories cannot be opened as images; skip them.
        if not os.path.isfile(path):
            continue
        imgs, ori_img_size = get_img_tensor(path)

        if cuda:
            imgs = imgs.cuda()

        # The network returns five outputs; the paper and the model source
        # show what each one corresponds to. Only the last two are used here.
        # NOTE(review): `mm` is presumably the predicted text mask — confirm
        # against models/sa_gan.py.
        with torch.no_grad():
            _, _, _, g_images, mm = netG(imgs)
        g_image = g_images.detach().cpu()
        mm = mm.detach().cpu()

        out_path = os.path.join(savePath, name)
        mask_path = os.path.join(savePath, name + '_mask.png')

        # Restore the original image size before saving. The result must not
        # be written again with save_image(g_image, out_path) afterwards:
        # that would overwrite this file with the network-resolution output.
        pil_image = torch_img_2_pil(g_image)
        result_img = pil_image.resize(ori_img_size)
        result_img.save(out_path)

        # The mask is saved as produced by the network (not resized).
        save_image(mm, mask_path)






Guess you like

Origin blog.csdn.net/imwaters/article/details/131812511