Keras 框架构建

import os
from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras.optimizers import Adam,SGD
import cfg
#from network import East
from network_densenet import East

from data_generator import gen
import tensorflow as tf
from keras import backend as K
from keras import models
from losses import quad_loss
#from tensorflow.keras.callbacks import TensorBoard
# TensorFlow 1.x session setup: allow soft device placement, grow GPU
# memory on demand instead of grabbing it all, and pin training to GPU "1".
config = tf.ConfigProto()
config.allow_soft_placement = True
config.gpu_options.allow_growth = True
config.gpu_options.visible_device_list = "1"
sess = tf.Session(config=config)
K.set_session(sess)

# --- Model construction / weight-loading section ---
east = East()               # East wraps Model(inputs=self.input_img, outputs=east_detect)
east_network = east.east_network()   # adds the 1x1 prediction heads
east_network.summary()   # print layer summary
if cfg.load_weights and os.path.exists(cfg.saved_model_weights_file_path):  # cfg is the config module
    print('加载模型成功')
    east_network.load_weights(cfg.saved_model_weights_file_path)

# --- Alternative: load a full saved model (with custom loss) instead ---
# models=models.load_model(cfg.load_model_file_path,custom_objects={'quad_loss': quad_loss})
# models.summary()
# east_network=models

# --- Optimizer / loss configuration ---
east_network.compile(loss=quad_loss, optimizer=Adam(lr=cfg.lr,    # configure the training model
                                                    # clipvalue=cfg.clipvalue,
                                                    decay=cfg.decay))
# --- Training: generator-driven fit with early stopping + checkpointing ---
east_network.fit_generator(generator=gen(),
                           steps_per_epoch=cfg.steps_per_epoch,
                           epochs=cfg.epoch_num,
                           validation_data=gen(is_val=True),
                           validation_steps=cfg.validation_steps,
                           verbose=1,
                           initial_epoch=cfg.initial_epoch,
                           callbacks=[
                               EarlyStopping(patience=cfg.patience, verbose=1),
                               ModelCheckpoint(filepath=cfg.model_weights_path, # checkpoint each epoch (best only)
                                               save_best_only=True,
                                               save_weights_only=True,
                                               verbose=1)])
east_network.save(cfg.saved_model_file_path)
east_network.save_weights(cfg.saved_model_weights_file_path)

#生成器配置函数:generator=gen(),返回一个迭代器(图片,bacth*标签)

#训练生成器,如何读取训练图片并进行预处理
import os
import numpy as np
from keras.preprocessing import image
from keras.applications.vgg16 import preprocess_input
import cfg
import imghdr

def gen(batch_size=cfg.batch_size, is_val=False):
    """Infinite batch generator for Keras ``fit_generator``.

    Draws ``batch_size`` random lines from the train/val list file, loads
    each image and its precomputed ground-truth array (``<name>_gt.npy``),
    and yields ``(x, y)`` where x has shape (batch, H, W, C) preprocessed to
    [-1, 1] ('tf' mode) and y has shape (batch, H//pixel_size,
    W//pixel_size, 7).

    Args:
        batch_size: number of samples per yielded batch.
        is_val: read cfg.val_fname instead of cfg.train_fname.

    Yields:
        Tuple of (images, labels) float32 arrays, reused across iterations.
    """
    img_h, img_w = cfg.max_train_img_size, cfg.max_train_img_size
    x = np.zeros((batch_size, img_h, img_w, cfg.num_channels), dtype=np.float32)
    pixel_num_h = img_h // cfg.pixel_size
    pixel_num_w = img_w // cfg.pixel_size
    y = np.zeros((batch_size, pixel_num_h, pixel_num_w, 7), dtype=np.float32)
    list_fname = cfg.val_fname if is_val else cfg.train_fname
    with open(os.path.join(cfg.data_dir, list_fname), 'r') as f_names:  # one "filename,..." line per sample
        f_list = f_names.readlines()
    while True:
        for i in range(batch_size):
            # Keep drawing until a readable image fills slot i.  The original
            # used `continue`, which silently left stale data from the
            # previous batch in x[i] / y[i].
            while True:
                random_img = np.random.choice(f_list)
                img_filename = str(random_img).strip().split(',')[0]
                img_path = os.path.join(cfg.data_dir,
                                        cfg.train_image_dir_name,
                                        img_filename)
                # imghdr.what returns None for unrecognized content and raises
                # OSError for unreadable paths; the original bare `except`
                # only caught the raise and never checked the None case.
                try:
                    if imghdr.what(img_path) is None:
                        print('skipping non-image file:', img_path)
                        continue
                    img = image.load_img(img_path)
                except OSError:
                    print('skipping unreadable image:', img_path)
                    continue
                img = image.img_to_array(img)              # PIL image -> float array
                x[i] = preprocess_input(img, mode='tf')    # normalize to [-1, 1]
                gt_file = os.path.join(cfg.data_dir,
                                       cfg.train_label_dir_name_gt,
                                       img_filename[:-4] + '_gt.npy')  # produced by label.py
                y[i] = np.load(gt_file)                    # label tensor for this image
                break
        yield x, y

预测阶段:

import argparse
import time
import numpy as np
from PIL import Image, ImageDraw
import cfg_lei as cfg
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"   # enumerate GPUs in PCI bus order (see issue #152)
os.environ["CUDA_VISIBLE_DEVICES"] = ""       # hide all GPUs -> force CPU inference
import datetime
import tensorflow as tf
from keras import backend as K
import keras
from keras.models import Model
gpu_list = ""
gpu_opts = tf.GPUOptions(allow_growth = True, visible_device_list = gpu_list)
# CPU-only TF1 session configuration
config = tf.ConfigProto(
    allow_soft_placement=True,              # let TF fall back to a supported device automatically
    gpu_options=gpu_opts,
    intra_op_parallelism_threads=40,           # thread pool inside a single op
    inter_op_parallelism_threads=40,           # thread pool across independent ops
    device_count={'CPU': 40, 'GPU': 0})
sess = tf.Session(config=config)
K.set_session(sess)
sess.run(tf.global_variables_initializer())

from keras.preprocessing import image
from keras.applications.vgg16 import preprocess_input
from keras.models import load_model
from label import point_inside_of_quad
from network import East
from preprocess import resize_image, resize_and_padding
from nms import nms
import datetime

def sigmoid(x):
    """Element-wise logistic function: ``y = 1 / (1 + exp(-x))``."""
    neg_exp = np.exp(-x)
    return 1 / (1 + neg_exp)


def cut_text_line(geo, scale_ratio_w, scale_ratio_h, im_array, img_path, s):
    """Crop one predicted quad from the full-resolution image and save it
    as ``<img_path>_subim<s>.jpg``, whitening pixels outside the quad.

    Fix: the original did ``geo /= [...]``, mutating the caller's array;
    predict() then divided the same geo again when writing the txt
    coordinates, double-scaling them.  Rescale into a local copy instead.

    Args:
        geo: (4, 2) quad vertices in resized-image coordinates.
        scale_ratio_w, scale_ratio_h: resized/original size ratios.
        im_array: original image as an HxWx3 array.
        img_path: source image path, used as the output-name prefix.
        s: index of this quad, used in the output filename.
    """
    geo = geo / [scale_ratio_w, scale_ratio_h]  # local, non-mutating rescale to original coords
    p_min = np.amin(geo, axis=0)
    p_max = np.amax(geo, axis=0)
    min_xy = p_min.astype(int)
    max_xy = p_max.astype(int) + 2
    sub_im_arr = im_array[min_xy[1]:max_xy[1], min_xy[0]:max_xy[0], :].copy()
    # Whiten every pixel of the bounding box that falls outside the quad.
    for m in range(min_xy[1], max_xy[1]):
        for n in range(min_xy[0], max_xy[0]):
            if not point_inside_of_quad(n, m, geo, p_min, p_max):
                sub_im_arr[m - min_xy[1], n - min_xy[0], :] = 255
    sub_im = image.array_to_img(sub_im_arr, scale=False)
    sub_im.save(img_path + '_subim%d.jpg' % s)


def predict(east_detect, img_path, pixel_threshold, quiet=False):
    """Run EAST detection on one image: save the raw activation map
    (``..._act.jpg``), run NMS, save the final quads (``..._predict.jpg``),
    and write rescaled quad coordinates to a txt file.

    Args:
        east_detect: loaded Keras detection model.
        img_path: path of the image to predict on.
        pixel_threshold: inside-score threshold for a pixel to count as text.
        quiet: if True, suppress warnings about invalid quads.
    """
    img = image.load_img(img_path)
    #d_wight, d_height = resize_image(img, cfg.max_predict_img_size)
    #img = img.resize((d_wight, d_height), Image.NEAREST).convert('RGB')
    d_wight, d_height = cfg.max_predict_img_size, cfg.max_predict_img_size
    img = resize_and_padding(img, cfg.max_predict_img_size)    # scale keeping the original aspect ratio, pad the rest
    img = image.img_to_array(img)
    img = preprocess_input(img, mode='tf')
    x = np.expand_dims(img, axis=0)
    #y = east_detect.predict(x)
    #get_output = K.function([east_detect.layers[0].input, K.learning_phase()], [east_detect.layers[-1].output])
    #y = get_output([x, 0])[0]
    # Feed the graph through the module-level session directly instead of model.predict.
    net_inp = east_detect.layers[0].input
    net_out = east_detect.layers[-1].output
    y = sess.run(net_out, feed_dict={net_inp: x})

    y = np.squeeze(y, axis=0)
    y[:, :, :3] = sigmoid(y[:, :, :3])   # activate the three score channels (inside / head / tail)
    cond = np.greater_equal(y[:, :, 0], pixel_threshold)      # boolean mask: pixels whose inside-score passes the threshold
    activation_pixels = np.where(cond)        # (row indices, col indices) of activated pixels
    start = datetime.datetime.now()
    quad_scores, quad_after_nms = nms(y, activation_pixels)     # NMS, gated by side_vertex_pixel_threshold
    end = datetime.datetime.now()
    print('NMS', (end-start).total_seconds())
    with Image.open(img_path) as im:     # draw boxes (img_path + '_act.jpg') and save txt coords (img_path[:-4] + '.txt')
        im_array = image.img_to_array(im.convert('RGB'))
        #d_wight, d_height = resize_image(im, cfg.max_predict_img_size)
        #scale_ratio_w = d_wight / im.width
        #scale_ratio_h = d_height / im.height
        #im = im.resize((d_wight, d_height), Image.NEAREST).convert('RGB')
        d_wight, d_height = cfg.max_predict_img_size, cfg.max_predict_img_size
        scale_ratio_w = d_wight / im.width
        # NOTE(review): computed from width, not height — likely meant
        # d_height / im.height (see the commented-out lines above); confirm
        # against resize_and_padding's scaling behavior.
        scale_ratio_h = d_wight / im.width
        im = resize_and_padding(im, cfg.max_predict_img_size)
        quad_im = im.copy()

        start = datetime.datetime.now()
        draw = ImageDraw.Draw(im)
        for i, j in zip(activation_pixels[0], activation_pixels[1]):
            px = (j + 0.5) * cfg.pixel_size
            py = (i + 0.5) * cfg.pixel_size
            line_width, line_color = 1, 'red'
            if y[i, j, 1] >= cfg.side_vertex_pixel_threshold:   # boundary score high enough: pixel may be a head/tail vertex
                if y[i, j, 2] < cfg.trunc_threshold:         # tail score below threshold: treat as a head pixel
                    line_width, line_color = 2, 'yellow'
                elif y[i, j, 2] >= 1 - cfg.trunc_threshold:    # tail score above 1 - threshold: treat as a tail pixel
                    line_width, line_color = 2, 'green'
            draw.line([(px - 0.5 * cfg.pixel_size, py - 0.5 * cfg.pixel_size),
                       (px + 0.5 * cfg.pixel_size, py - 0.5 * cfg.pixel_size),
                       (px + 0.5 * cfg.pixel_size, py + 0.5 * cfg.pixel_size),
                       (px - 0.5 * cfg.pixel_size, py + 0.5 * cfg.pixel_size),
                       (px - 0.5 * cfg.pixel_size, py - 0.5 * cfg.pixel_size)],
                      width=line_width, fill=line_color)
        im.save('/workdir/src/EastTrain/data/chenyu/data/pre_lei/' +img_path.split('/')[-1]+ '_act.jpg') # raw per-pixel activation map (before NMS)
        end = datetime.datetime.now()
        print('draw_act.jpg', (end-start).total_seconds())

        quad_draw = ImageDraw.Draw(quad_im)
        txt_items = []
        for score, geo, s in zip(quad_scores, quad_after_nms,
                                 range(len(quad_scores))):
            if np.amin(score) > 0:
                quad_draw.line([tuple(geo[0]),
                                tuple(geo[1]),
                                tuple(geo[2]),
                                tuple(geo[3]),
                                tuple(geo[0])], width=2, fill='red')
                if cfg.predict_cut_text_line:
                    cut_text_line(geo, scale_ratio_w, scale_ratio_h, im_array,
                                  img_path, s)
                # NOTE(review): if cut_text_line rescales geo in place, the
                # division below applies twice — verify.
                rescaled_geo = geo / [scale_ratio_w, scale_ratio_h]
                rescaled_geo_list = np.reshape(rescaled_geo, (8,)).tolist()
                txt_item = ','.join(map(str, rescaled_geo_list))
                txt_items.append(txt_item + '\n')
            elif not quiet:
                print('quad invalid with vertex num less than 4.')
        quad_im.save('/workdir/src/EastTrain/data/chenyu/data/pre_lei/' +img_path.split('/')[-1] + '_predict.jpg') # activation map after NMS
        if cfg.predict_write2txt and len(txt_items) > 0:
            with open('/workdir/src/EastTrain/data/chenyu/data/pre_lei/' +img_path.split('/')[-1] + '.txt', 'w') as f_txt:
                f_txt.writelines(txt_items)

# NOTE: dead code — predict_txt is kept for reference inside a no-op string
# literal.  It shares the suspicious scale_ratio_h-from-width computation
# with predict() above.
'''
def predict_txt(east_detect, img_path, txt_path, pixel_threshold, quiet=False):
    img = image.load_img(img_path)
    #d_wight, d_height = resize_image(img, cfg.max_predict_img_size)
    #scale_ratio_w = d_wight / img.width
    #scale_ratio_h = d_height / img.height
    #img = img.resize((d_wight, d_height), Image.NEAREST).convert('RGB')
    d_wight, d_height = cfg.max_predict_img_size, cfg.max_predict_img_size
    scale_ratio_w = d_wight / img.width
    scale_ratio_h = d_wight / img.width
    img = resize_and_padding(img, cfg.max_predict_img_size)
    img = image.img_to_array(img)
    img = preprocess_input(img, mode='tf')
    x = np.expand_dims(img, axis=0)
    y = east_detect.predict(x)

    y = np.squeeze(y, axis=0)
    y[:, :, :3] = sigmoid(y[:, :, :3])
    cond = np.greater_equal(y[:, :, 0], pixel_threshold)
    activation_pixels = np.where(cond)
    quad_scores, quad_after_nms = nms(y, activation_pixels)

    txt_items = []
    for score, geo in zip(quad_scores, quad_after_nms):
        if np.amin(score) > 0:
            rescaled_geo = geo / [scale_ratio_w, scale_ratio_h]
            rescaled_geo_list = np.reshape(rescaled_geo, (8,)).tolist()
            txt_item = ','.join(map(str, rescaled_geo_list))
            txt_items.append(txt_item + '\n')
        elif not quiet:
            print('quad invalid with vertex num less than 4.')
    if cfg.predict_write2txt and len(txt_items) > 0:
        with open(txt_path, 'w') as f_txt:
            f_txt.writelines(txt_items)
'''

def parse_args():
    """Parse prediction CLI options: single image path, image directory,
    and the pixel activation threshold (defaults from cfg)."""
    ap = argparse.ArgumentParser()
    ap.add_argument('--path', '-p',
                    help='image path')
    ap.add_argument('--dir', '-d',
                    default='/workdir/src/EastTrain/data/chenyu/data/test/',
                    help='image dir')
    ap.add_argument('--threshold', '-t',
                    default=cfg.pixel_threshold,
                    help='pixel activation threshold')
    return ap.parse_args()



if __name__ == '__main__':
    args = parse_args()
    img_path = args.path            # e.g. demo/001.png; if unset, --dir is used instead
    threshold = float(args.threshold)

    imgdir = args.dir
    imgnames = os.listdir(imgdir)     # every file in the prediction directory
    print('imgnames',imgnames)

    #east = East()
    #east_detect = east.east_network()
    #

    from keras.utils.generic_utils import CustomObjectScope
#   mobilenet defines custom layers (e.g. relu6, DepthwiseConv2D); wrap
#   load_model in a CustomObjectScope as below when loading such models
#     with CustomObjectScope({'relu6': keras.applications.mobilenet.relu6, 'DepthwiseConv2D': keras.applications.mobilenet.DepthwiseConv2D}):
#
#         east_detect = load_model(cfg.saved_model_file_path, compile=False)
    east_detect = load_model(cfg.saved_model_file_path, compile=False)
  #  east_detect.load_weights(cfg.saved_model_weights_file_path_change)
    print(cfg.saved_model_file_path)

   # Model(inputs=, outputs=east_detect).summary()


    # t1 = datetime.datetime.now()
    if img_path:
        # Single-image mode: run 20 times (crude timing / warm-up loop).
        for xxxx in range(20):
            predict(east_detect, img_path, threshold)
    else:
        # Directory mode: predict every image and append per-image timings.
        for name in imgnames:
            start = datetime.datetime.now()
            predict(east_detect, os.path.join(imgdir, name), threshold)    # threshold = pixel_threshold
            end = datetime.datetime.now()
            dis=(end-start).total_seconds()
            print(dis)
            with open('/workdir/src/EastTrain/data/chenyu/data/pre_lei/record.txt','a') as f:
                f.write(name)
                f.write('\n'+str(dis))




    # t2 = datetime.datetime.now()
  #  print("%d s, %d ms" % (int((t2-t1).seconds), int((t2-t1).microseconds / 1000)))

几个金字塔结构的上采样(concat)融合部分,利用了递归

# coding=utf-8
from keras import Input, Model
#from keras.applications.vgg16 import VGG16
#from keras.applications.densenet import DenseNet121
from keras.layers import Concatenate, Conv2D, UpSampling2D, BatchNormalization
from shufftle import basemodel
import cfg_shuffle as cfg

"""
input_shape=(img.height, img.width, 3), height and width must scaled by 32.
So images's height and width need to be pre-processed to the nearest num that
scaled by 32.And the annotations xy need to be scaled by the same ratio 
as height and width respectively.
"""


class East:
    """EAST detection head on top of a ShuffleNet/DenseNet-style backbone.

    Pyramid features from three backbone stages are merged top-down by the
    mutually recursive g()/h() pair (upsample -> concat -> 1x1 conv -> 3x3
    conv), then three 1x1 heads predict inside-score, head/tail codes and
    vertex coordinates, concatenated into a 7-channel output map.
    """

    def __init__(self):
        # Backbone layer names used as pyramid levels.
        # NOTE(review): assumes cfg.feature_layers_range indexes this list
        # 1-based via densenet_layers[i-1] — confirm against cfg.
        densenet_layers = ['stage1', 'stage5', 'stage13']
        self.input_img = Input(name='input_img',
                               shape=(cfg.max_train_img_size, cfg.max_train_img_size, cfg.num_channels),
                               dtype='float32')

        densenet = basemodel(input_tensor=self.input_img)

        # Collected feature maps; index 0 is padded with None so that
        # self.f[i] lines up with the 1-based indices used by g()/h().
        self.f = [densenet.get_layer(densenet_layers[i-1]).output
                  for i in cfg.feature_layers_range]
        self.f.insert(0, None)

    def g(self, i):
        """Merge step: upsample h(i), except at the final level where the
        merged map is normalized and projected to 32 channels."""
        if i == cfg.feature_layers_num:
            bn = BatchNormalization()(self.h(i))
            return Conv2D(32, 3, activation='relu', padding='same')(bn)
        else:
            return UpSampling2D((2, 2))(self.h(i))

    def h(self, i):
        """Fuse level i with the recursively merged levels below it:
        concat -> BN -> 1x1 conv -> BN -> 3x3 conv (channel count halves
        per level)."""
        if i == 1:
            # Base case: deepest pyramid feature is used as-is.
            return self.f[i]
        else:
            concat = Concatenate(axis=-1)([self.g(i - 1), self.f[i]])
            bn1 = BatchNormalization()(concat)
            conv_1 = Conv2D(128 // 2 ** (i - 2), 1,
                            activation='relu', padding='same',)(bn1)
            bn2 = BatchNormalization()(conv_1)
            conv_3 = Conv2D(128 // 2 ** (i - 2), 3,
                            activation='relu', padding='same',)(bn2)
            return conv_3

    def east_network(self):
        """Build the full detection model with three 1x1 prediction heads.

        Returns:
            Model mapping input_img to the 7-channel 'east_detect' map
            (1 inside-score + 2 head/tail codes + 4 vertex coords).
        """
        # Build the merge branch ONCE and share it across the three heads.
        # The original called self.g(...) per head, instantiating three
        # independent copies of the whole decoder (triple the parameters).
        # NOTE: this changes the layer topology, so weight files saved from
        # the old triple-decoder model will not load into this one.
        before_output = self.g(cfg.feature_layers_num)
        inside_score = Conv2D(1, 1, padding='same', name='inside_score'
                              )(before_output)            # does the pixel contain text?
        side_v_code = Conv2D(2, 1, padding='same', name='side_vertex_code'
                             )(before_output)              # head / tail codes
        side_v_coord = Conv2D(4, 1, padding='same', name='side_vertex_coord'
                              )(before_output)            # two vertex coordinates
        east_detect = Concatenate(axis=-1,
                                  name='east_detect')([inside_score,
                                                       side_v_code,
                                                       side_v_coord])
        return Model(inputs=self.input_img, outputs=east_detect)


if __name__ == '__main__':
    # Smoke test: build the network and print its layer summary.
    model = East().east_network()
    model.summary()

猜你喜欢

转载自blog.csdn.net/weixin_38740463/article/details/89841923