tensorflow2.0入门实例五(模型训练)

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接: https://blog.csdn.net/weixin_43162240/article/details/102664017

  做东西,最重要的就是动手了,所以这篇文章动手跑了一个fcn32s和fcn8s以及deeplab v3+的例子,这个例子的数据集选用自动驾驶相关竞赛的kitti数据集, FCN8s在训练过程中用tensorflow2.0自带的评估能达到91%精确率, deeplab v3+能达到97%的准确率。这篇文章适合入门级选手,在文章中不再讲述fcn的结构,直接百度就可以搜到。
  文章使用的是tensorflow2.0框架,该框架集成了keras,在模型的训练方面极其简洁,不像tf1.x那么复杂,综合其他深度学习框架,发现这个是最适合新手使用的一种。
  文章中用到的库函数,参数等均可在tensorflow2.0 api中查找到。
  文章的代码在github可以获取,地址:https://github.com/fengshilin/tf2.0-FCN

  文章的结构如下:

  1. 数据下载与分析
  2. 数据预处理(重点在label的预处理)
  3. 模型加载
  4. 模型建立(FCN与Deeplab)
  5. 模型训练与测试

模型训练

  tensorflow2.0集成了keras,使得模型训练变得更简洁
  这里需要引入前几步完成的dataset以及模型Mymodel,然后用compile编译并fit训练即可。

import os
import tensorflow as tf
import numpy as np
import scipy
import cv2

train_dataset = tf.data.Dataset.from_generator(
    train_generator, (tf.float32, tf.float32), (tf.TensorShape([None, None, None]), tf.TensorShape([None, None, None])))

train_dataset = train_dataset.shuffle(buffer_size=len(train_filenames))
train_dataset = train_dataset.batch(10)

model = MyModel(2)  # FCN模型
# model = DeepLabV3Plus(image_shape[0], image_shape[1], nclasses=2)  # deeplab模型

# 在得到模型后,配置compile的参数即可,需要optimizer与loss与评估器。learning_rate一般取0.001或者0.0001较好。
optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)
model.compile(
    optimizer=optimizer,
    loss=tf.compat.v2.nn.softmax_cross_entropy_with_logits,
    metrics=['accuracy']
)
# fit输入的参数为训练集dataset以及跑的轮数,可自己配置,用colab的gpu跑100轮需要1个小时左右。
model.fit(train_dataset, epochs=num_epochs)
# model.summary()

# 保存模型,以便以后在这个模型的基础上训练,或者用这个模型测试时,只需要这一个文件与测试数据集即可,不需要再定义模型。
model.save('weights/fcn8s_20191021.h5')

模型测试

  以上模型训练过程中会打印训练过程的准确率,FCN8s跑100轮,准确率能达到91%。下面讲解如何用数据来跑模型。

第一步

  生成测试集的dataset。

# 生成测试集的路径列表
test_dir = os.path.join("data", "test")+"/"
test_list_dir = os.listdir(test_dir)
test_list_dir.sort()
test_filenames = [test_dir + filename for filename in test_list_dir]

def test_generator():
    """测试集生成器"""
    for test_filename in test_filenames:
        image = handle_data(test_filename)

        yield image


def handle_data(train_filenames, train_label_filenames=None):
    """对数据做处理"""
    image = scipy.misc.imresize(
        scipy.misc.imread(train_filenames), image_shape)
    
    image_yuv = cv2.cvtColor(image, cv2.COLOR_RGB2YUV)
    image_yuv[:, :, 0] = cv2.equalizeHist(image_yuv[:, :, 0])
    image = cv2.cvtColor(image_yuv, cv2.COLOR_YUV2RGB)

    if train_label_filenames is not None:
        gt_image = scipy.misc.imresize(
            scipy.misc.imread(train_label_filenames), image_shape)
        
        background_color = np.array([255, 0, 0])
        gt_bg = np.all(gt_image == background_color, axis=2)
        gt_bg = gt_bg.reshape(*gt_bg.shape, 1)
        gt_image = np.concatenate((gt_bg, np.invert(gt_bg)), axis=2)
    
        return np.array(image), gt_image
    else:
        return np.array(image)

# 测试数据集生成dataset,并设置批量为1
test_dataset = tf.data.Dataset.from_generator(
    test_generator, tf.float32, tf.TensorShape([None, None, None]))
test_dataset = test_dataset.batch(1)

第二步

  加载训练好的模型, 并预测结果

model = tf.keras.models.load_model('weights/fcn8s_20191021.h5')

for filename in test_filenames:
  ori_image = scipy.misc.imresize(
        scipy.misc.imread(filename), image_shape)
        
  #  image[np.newaxis,:,:,:]是给image增加一个维度1,将维度[h,w, c]转为[1, h, w, c]以作为模型的输入
  image = ori_image[np.newaxis,:,:,:].astype("float32")
  
  # 对输入做预测, 这里的输入维度是[batch, h, w, c], 输出维度为[batch, h, w, n_class]
  out = model.predict(image)
  
  # 写入图片,第一个参数是模型的预测值,第二个参数为输出的路径
  write_img(out, filename, ori_image)

第三步

  还记得之前对label做的操作吗?将每个像素点的像素值[0,0,255]映射为了形如[1,0]的数组,所以预测值也是[0.8,0.2]的概率值数组,需要将其设置为形如[255,255,255](白色,这个像素值数组可以自己设置,比如[0,0,0]表示黑色)的像素值。

def write_img(pred_images, filename, ori_image):
    
    pred = pred_images[0]

    COLORMAP = [[255,255,255],[0, 255, 0]]  # 第一个是背景的颜色,第二个是道路的颜色
    cm = np.array(COLORMAP).astype(np.uint8)

    pred = np.argmax(np.array(pred), axis=2)

    pred_val = cm[pred]
    
    overlap = cv2.addWeighted(ori_image, 0.7, pred_val, 0.3, 0)
    cv2.imwrite(os.path.join("gdrive", "My Drive", "data", "deeplab", filename.split("/")[-1]), overlap)
    print(os.path.join("gdrive", "My Drive", "data", "deeplab", filename.split("/")[-1])+"finished")

猜你喜欢

转载自blog.csdn.net/weixin_43162240/article/details/102664017