Keras 版本 PSENet 训练过程记录

1.由分步执行改成一个文件

训练文件

# Training script for the Keras PSENet model: builds the network, compiles it
# with the project loss/metrics, and trains it with fit_generator while writing
# per-epoch weight checkpoints and TensorBoard logs.
import keras
import keras.backend.tensorflow_backend as KTF
import tensorflow as tf
from keras.callbacks import ModelCheckpoint, TensorBoard
from keras.optimizers import Adam
from keras.utils import multi_gpu_model

# Register a TensorFlow session with gpu_options.allow_growth enabled as the
# Keras backend session before any layer is created.
KTF.set_session(
    tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))))

# Project-local modules are imported only after the backend session is set.
import config
from models.loss import build_loss
from models.metrics import build_iou, mean_iou
from models.psenet import psenet
from tool.generator import Generator

# ---------------------------------------------------------------- model ----
# Height and width are left as None; only the 3 input channels are fixed.
image_input = keras.layers.Input(shape=(None, None, 3))
model = keras.models.Model(image_input, psenet(image_input))
model.summary()

# The model is trained as-is; the commented line is the multi_gpu_model
# wrapper that can be swapped in instead.
# parallel_model = multi_gpu_model(model,gpus=1)
parallel_model = model

parallel_model.compile(
    loss=build_loss,
    optimizer=Adam(1e-3),
    metrics=build_iou([0, 1], ['bk', 'txt']),
)

# ----------------------------------------------------------------- data ----
batch_size = 1
num_class = 2
target_size = (640, 640)  # handed to Generator as `reshape`

gen_train = Generator(
    config.MIWI_2018_TRAIN_LABEL_DIR,
    batch_size=batch_size,
    istraining=True,
    num_classes=num_class,
    mirror=False,
    reshape=target_size,
)

# The validation generator additionally switches off mirror, scale, clip and
# trans_color (presumably the augmentation options -- confirm in tool.generator).
gen_test = Generator(
    config.MIWI_2018_TEST_LABEL_DIR,
    batch_size=batch_size,
    istraining=False,
    num_classes=num_class,
    reshape=target_size,
    mirror=False,
    scale=False,
    clip=False,
    trans_color=False,
)

# ------------------------------------------------------------- training ----
# Checkpoints contain weights only; the epoch number is part of the file name.
callbacks = [
    ModelCheckpoint(r'resent50-190422_BLINEAR-{epoch:02d}.hdf5',
                    save_weights_only=True),
    TensorBoard(log_dir='./logs'),
]

print(gen_test.num_samples(), gen_train.num_samples())

# One epoch walks each sample set once (integer division by the batch size).
train_steps = gen_train.num_samples() // batch_size
val_steps = gen_test.num_samples() // batch_size

history = parallel_model.fit_generator(
    gen_train,
    steps_per_epoch=train_steps,
    epochs=40,
    validation_data=gen_test,
    validation_steps=val_steps,
    verbose=1,
    initial_epoch=0,
    workers=4,
    max_queue_size=16,
    callbacks=callbacks,
)


测试文件:

# Inference setup: configure the TensorFlow session used by Keras, rebuild the
# PSENet graph and load trained weights into it.
import os

import keras
import keras.backend.tensorflow_backend as KTF
import tensorflow as tf

# Expose only GPU 0 to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Cap this process at 85% of the GPU memory (alternative: allow_growth=True).
# NOTE(review): the device_count key is lower-case 'gpu'; TensorFlow documents
# the device type names as "CPU"/"GPU", so this entry may have no effect --
# confirm the intent before changing it.
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.85)
KTF.set_session(
    tf.Session(config=tf.ConfigProto(device_count={'gpu': 0},
                                     gpu_options=gpu_options)))

# Imported after the backend session has been registered.
from models.psenet import psenet

# Height and width are left as None; only the 3 input channels are fixed.
image_input = keras.layers.Input(shape=(None, None, 3))
model = keras.models.Model(image_input, psenet(image_input))
model.summary()
model.load_weights('resent50-190219_BLINEAR-iou8604.hdf5')

# Batch inference: run the loaded PSENet `model` over every *.jpg in the input
# folder and write one result .txt per image through save_MTWI_2108_resault.
import glob
import cv2
import matplotlib.pyplot as plt
import numpy as np
import os
import tqdm

from tool.utils import ufunc_4, scale_expand_kernels, fit_minarearectange, fit_boundingRect, save_MTWI_2108_resault

# Alternative input folder:
# image_dir = '/home/yang/Documents/data/ali/mtwi_2018_train/image_test'
image_dir = '/home/yang/Documents/data/ali/icpr_mtwi_task2/image_test'
saveimgdir = "/home/yang/Documents/model/detect/PSENET-keras/imgs/result/img"
savetxtdir = "/home/yang/Documents/model/detect/PSENET-keras/imgs/result/txt"

imagesfile = glob.glob(os.path.join(image_dir, '*.jpg'))
MIN_LEN = 640   # the shorter image side is scaled up to at least this
MAX_LEN = 1024  # each side is capped at this


def _load_bgr(path):
    """Read the image at *path* and return it as a 3-channel array.

    The file is decoded unchanged (flag -1); single-channel and 4-channel
    results are converted to 3 channels so that the later reshape to
    (1, h, w, 3) cannot fail on grayscale or alpha images.

    Raises:
        ValueError: if OpenCV cannot decode the file.
    """
    image = cv2.imdecode(np.fromfile(path, dtype=np.uint8), -1)
    if image is None:
        raise ValueError('cannot decode image')
    if image.ndim == 2:
        return cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    if image.shape[2] == 4:
        return cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
    return image


def _network_input_size(height, width):
    """Return the (height, width) an image is resized to before prediction.

    If the shorter side is below MIN_LEN the image is scaled up so that side
    equals MIN_LEN. Each side is then capped at MAX_LEN independently (the
    aspect ratio is not preserved by the cap) and rounded down to a multiple
    of 32.
    """
    if width < height and width < MIN_LEN:
        height = MIN_LEN / width * height
        width = MIN_LEN
    elif height <= width and height < MIN_LEN:
        width = MIN_LEN / height * width
        height = MIN_LEN

    width = min(width, MAX_LEN)
    height = min(height, MAX_LEN)
    return int(height // 32 * 32), int(width // 32 * 32)


for path in tqdm.tqdm(imagesfile):
    try:
        image = _load_bgr(path)
        h, w = _network_input_size(*image.shape[:2])

        # Factors that map resized-image coordinates back to the original.
        scalex = image.shape[1] / w
        scaley = image.shape[0] / h

        # `interpolation` must be passed by keyword: the third positional
        # parameter of cv2.resize is `dst`, so a positional flag is not
        # applied as the interpolation mode.
        resized = cv2.resize(image, (w, h), interpolation=cv2.INTER_AREA)
        images = np.reshape(resized, (1, h, w, 3))

        res1 = model.predict(images)[0]
        # Binarise every output map at a score of 0.9.
        res1[res1 > 0.9] = 1
        res1[res1 <= 0.9] = 0

        # Channels 0-4 are each masked by channel 5; channel 5 itself is
        # appended last, all scaled to 0/255.
        newres1 = [np.logical_and(res1[:, :, 5], res1[:, :, k]) * 255
                   for k in range(5)]
        newres1.append(res1[:, :, 5] * 255)

        num_label, labelimage = scale_expand_kernels(newres1)
        rects = fit_minarearectange(num_label, labelimage)
        # The factor 2 presumably maps output-map coordinates to the resized
        # input (output at half resolution) -- confirm against models.psenet.
        scaled_rects = np.array(rects) * 2

        cv2.drawContours(images[0], scaled_rects, -1, (0, 0, 255), 2)

        base_name = '.'.join(os.path.basename(path).split('.')[:-1])
        # cv2.imwrite(os.path.join(saveimgdir, base_name + '.jpg'), images[0])

        save_MTWI_2108_resault(os.path.join(savetxtdir, base_name + '.txt'),
                               scaled_rects, scalex, scaley)
    except Exception as e:
        # Best effort: report the failing file together with the reason and
        # carry on with the next image.
        print(path, repr(e))
        continue

2.要先生成标签文件

执行 tool 文件夹下的 gen_dataset.py,并将 npy 文件的保存类型修改为 uint8,否则文件会很大。

# Label volume: same height/width as the image, config.n channels, stored as
# uint8 to keep the saved .npy files small.
npys = np.zeros(img.shape[:2] + (config.n,), dtype=np.uint8)

猜你喜欢

转载自blog.csdn.net/u011489887/article/details/98985285