图像修复实例解析(三)

本篇基于SIGGRAPH 2017 (ACM ToG)的 Globally and Locally Consistent Image Completion 

(在 Context Encoder(CE)的基础上加入全局(Global)与局部(Local)两个判别器的改进)。

项目主页:http://hi.cs.waseda.ac.jp/~iizuka/projects/completion/

Github代码:

1)https://github.com/satoshiiizuka/siggraph2017_inpainting

2)https://github.com/shinseung428/GlobalLocalImageCompletion_TF

  其中第一个是官方代码;第二个实现与原论文略有不同,但展示效果非常棒。

因此,我这边主要以 2) 中的代码为例进行解析。先看看它的 README:

Tensorflow implementation of Globally and Locally Consistent Image Completion on celebA dataset.

因此数据集采用的是 CelebA,自行下载即可。当然也可以自己准备数据集;后面时间充足的话,我准备用亚洲人脸数据重新训练此模型。

What's different from the paper

  • smaller image input size (64x64)
  • smaller patch sizes
  • fewer training iterations (500,000 iterations in the paper)
  • Adam optimizer used instead of Adadelta

Requirements

  • OpenCV 2.4
  • Tensorflow 1.4

Folder Setting

-data
  -img_align_celeba
    -img1.jpg
    -img2.jpg
    -...

Train

$ python train.py 

To continue training

$ python train.py --continue_training=True

Test

Download pretrained weights

$ python download.py
$ python test.py --img_path=./data/test/test_img.jpg

用法如上所示:将数据集直接解压后放到指定目录即可。比如我直接放到了

/home/gavin/Dataset/,那么训练时加上相应参数即可:
 python3 train.py --continue_training=True --data /home/gavin/Dataset/

由于原版是 Python 2 版本,有些写法需要修改;我这边做了些小改动,已经改成了 Python 3 版本。

训练截图如下:

最后是测试脚本 test.py,代码如下:

import cv2
import numpy as np
import tensorflow as tf

from config import *
from network import *


# Module-level state shared with the mouse callback inside erase_img().
drawing = False  # True while the left mouse button is held down
ix = iy = -1  # declared global by the callback but never updated in this file
color = (255, 255, 255)  # fill colour for the painted (erased) squares
size = 10  # half side length, in pixels, of each painted square

def erase_img(args, img):
    """Let the user paint a mask over *img* interactively and return model inputs.

    Opens an OpenCV window named 'image' showing *img*. Pressing the left mouse
    button (and dragging while it is held) paints filled squares of half side
    ``size`` (module global) in ``color`` onto both *img* (in place) and a mask
    array of the same shape. Pressing Esc ends the painting session.

    Args:
        args: config object; reads ``input_height``, ``input_width`` and
            ``batch_size``.
        img: image array, modified in place. test() passes an RGB image; its
            R/B channels are swapped only for display.

    Returns:
        A tuple ``(test_img, test_mask)``. Both are resized to
        ``input_height`` x ``input_width`` and tiled ``batch_size`` times along
        a new leading axis. ``test_img`` is scaled via ``/127.5 - 1`` (i.e. to
        [-1, 1] for uint8 input) with the painted region filled with 1;
        ``test_mask`` is the painted mask scaled to [0, 1].
    """
    # Same shape as the image; painted regions become 255 in every channel.
    mask = np.zeros(img.shape)

    def paint(x, y):
        """Fill a square of half side ``size`` centred at (x, y) on img and mask."""
        top_left = (x - size, y - size)
        bottom_right = (x + size, y + size)
        cv2.rectangle(img, top_left, bottom_right, color, -1)
        cv2.rectangle(mask, top_left, bottom_right, color, -1)

    # Mouse callback: paint on press, while dragging, and on release.
    def erase_rect(event, x, y, flags, param):
        global drawing

        if event == cv2.EVENT_LBUTTONDOWN:
            drawing = True
            paint(x, y)
        elif event == cv2.EVENT_MOUSEMOVE:
            if drawing:
                paint(x, y)
        elif event == cv2.EVENT_LBUTTONUP:
            drawing = False
            paint(x, y)

    cv2.namedWindow('image')
    cv2.setMouseCallback('image', erase_rect)

    # Redraw the (partially painted) image until Esc (key code 27) is pressed.
    while True:
        cv2.imshow('image', cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        if cv2.waitKey(1) & 0xFF == 27:
            break

    # cv2.resize takes dsize as (width, height), so width must come first.
    dsize = (args.input_width, args.input_height)
    test_img = cv2.resize(img, dsize) / 127.5 - 1
    test_mask = cv2.resize(mask, dsize) / 255.0
    # Fill the masked region with 1 (the top of the [-1, 1] range).
    test_img = (test_img * (1 - test_mask)) + test_mask

    cv2.destroyAllWindows()

    reps = [args.batch_size, 1, 1, 1]
    return (np.tile(test_img[np.newaxis, ...], reps),
            np.tile(test_mask[np.newaxis, ...], reps))




def test(args, sess, model):
    """Restore the latest checkpoint, let the user mask an image, and show the result.

    Displays the original, the masked input and the reconstruction side by
    side at half the original image resolution.

    Args:
        args: config object; reads ``checkpoints_path``, ``img_path``,
            ``input_height``, ``input_width`` and ``batch_size``.
        sess: TensorFlow session used to restore and run the model.
        model: network object exposing the ``single_orig``, ``single_test``
            and ``single_mask`` inputs and the ``test_res_imgs`` output.

    Raises:
        IOError: if no checkpoint exists under ``checkpoints_path`` or the
            image at ``img_path`` cannot be read.
    """
    # Restore the most recent checkpoint; fail early with a clear message.
    saver = tf.train.Saver()
    last_ckpt = tf.train.latest_checkpoint(args.checkpoints_path)
    if last_ckpt is None:
        raise IOError("No checkpoint found in " + str(args.checkpoints_path))
    saver.restore(sess, last_ckpt)
    print("Loaded model file from " + str(last_ckpt))

    # cv2.imread returns None (instead of raising) for a missing/unreadable file.
    img = cv2.imread(args.img_path)
    if img is None:
        raise IOError("Could not read image " + str(args.img_path))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    orig_h, orig_w = img.shape[0], img.shape[1]

    # Clean reference input, taken before erase_img paints on img in place.
    # cv2.resize takes dsize as (width, height).
    orig_test = cv2.resize(img, (args.input_width, args.input_height)) / 127.5 - 1
    orig_test = np.tile(orig_test[np.newaxis, ...], [args.batch_size, 1, 1, 1])
    orig_test = orig_test.astype(np.float32)

    test_img, mask = erase_img(args, img)
    test_img = test_img.astype(np.float32)

    print("Testing ...")
    res_img = sess.run(model.test_res_imgs, feed_dict={model.single_orig: orig_test,
                                                       model.single_test: test_img,
                                                       model.single_mask: mask})

    # Map the first batch element from [-1, 1] to [0, 1] and shrink to half
    # of the original image size for display.
    display_size = (orig_w // 2, orig_h // 2)
    panels = [cv2.resize((batch[0] + 1) / 2, display_size)
              for batch in (orig_test, test_img, res_img)]

    res = np.hstack(panels)
    # Swap R/B back for cv2.imshow.
    res = cv2.cvtColor(res, cv2.COLOR_BGR2RGB)

    cv2.imshow("result", res)
    cv2.waitKey()
    cv2.destroyAllWindows()

    print("Done.")


def main(_):
    """Build the network, initialise its variables and run the interactive test.

    The positional parameter is unused; configuration is read from the global
    ``args`` (presumably provided by ``from config import *`` — verify).
    """
    run_config = tf.ConfigProto()
    run_config.gpu_options.allow_growth = True

    with tf.Session(config=run_config) as sess:
        # Build the graph first: an initializer created before the model exists
        # has no variables to initialise. test() then restores the
        # checkpointed values over these initial ones.
        model = network(args)
        init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
        sess.run(init_op)
        print('Start Testing...')
        test(args, sess, model)


if __name__ == "__main__":
    main(args)

效果展示:

运行后会显示原图,按住鼠标左键拖动即可涂抹出待修复区域(mask),最后按下 Esc 键结束涂抹,程序随即自动进行修复。

Alt text 

猜你喜欢

转载自blog.csdn.net/Gavinmiaoc/article/details/83826750