使用image-segmentation-keras的SegNet训练和预测自己的数据集

前面我的两篇博客分别介绍了语义分割FCNSegNet的算法重点知识及代码实现,最近在github上又fork了一个好资源https://github.com/divamgupta/image-segmentation-keras,这里分享一下。

该资源实现了FCN,UNet, SegNet, PSPNet网络,本篇以SegNet为例来说明下如何使用其来训练和预测自己的数据集。

值得一提的是,该资源是在最新的tensorflow2.x框架下实现的,不像很多其它资源还是基于tensorflow1.x搭配老版本的keras实现。这给电脑上配置的tensorflow2.x框架的用户带来了便利,不需要把精力放在不同版本的兼容上。

1.按照要求配置好版本

博主的环境配置如下:

cuda10.1, cuddn7.6

2. pip安装keras-segmentation库

pip install keras-segmentation

安装过程中也许你的已配置的相关库会和其要求的版本有冲突,需要结合各自的配置更改下。

3. 下载源代码(篇末也会附上本博客的修改后的项目资源),新建一个dataset文件夹,pycharm中的工程目录结构如下:

dataset中有4个文件夹,是我自己所做的数据集(train中是训练原图,trainannot中是训练用标注图,val是训练时的验证集原图,valannot是训练时验证集对应的标注图)

该数据集的制作方式及下载路径可见我的博客https://blog.csdn.net/jiugeshao/article/details/113836354

4. test_models.py被我修改为如下:

import numpy as np
import tempfile
import  os;

from keras_segmentation.models import all_models
from keras_segmentation.data_utils.data_loader import \
    verify_segmentation_dataset, image_segmentation_generator
from keras_segmentation.predict import predict_multiple, predict, evaluate

tr_im = "D:\\mycode\\0-Object_Segmentation\\image-segmentation-keras-master\\dataset\\train"
tr_an = "D:\\mycode\\0-Object_Segmentation\\image-segmentation-keras-master\\dataset\\trainannot"
te_im = "D:\\mycode\\0-Object_Segmentation\\image-segmentation-keras-master\\dataset\\val"
te_an = "D:\\mycode\\0-Object_Segmentation\\image-segmentation-keras-master\\dataset\\valannot"


def test_verify():
    verify_segmentation_dataset(tr_im, tr_an, 50)


def test_datag():
    g = image_segmentation_generator(images_path=tr_im, segs_path=tr_an,
                                     batch_size=3,  n_classes=50,
                                     input_height=224, input_width=324,
                                     output_height=114, output_width=134,
                                     do_augment=False)

    x, y = next(g)
    assert x.shape[0] == 3
    assert y.shape[0] == 3
    assert y.shape[-1] == 50


# with augmentation
def test_datag2():
    g = image_segmentation_generator(images_path=tr_im, segs_path=tr_an,
                                     batch_size=3,  n_classes=50,
                                     input_height=224, input_width=324,
                                     output_height=114, output_width=134,
                                     do_augment=True)

    x, y = next(g)
    assert x.shape[0] == 3
    assert y.shape[0] == 3
    assert y.shape[-1] == 50


def test_model():
    model_name = "fcn_8"
    h = 288
    w = 512
    n_c = 2
    check_path = "D:\\model\\ckpt\\trainmodel"

    predict_multiple(
      inp_dir=te_im, checkpoints_path=check_path, out_dir="D:/tmp_batch")

    p = predict(inp="D:\\mycode\\0-Object_Segmentation\\image-segmentation-keras-master\\dataset\\val\\132.bmp", out_fname="D://tmp_one//out.jpg" , checkpoints_path=check_path)

后面会用该文件中的test_model()函数来完成训练模型对预测图片的测试。

predict_muiltiple()函数可以批量对图片进行测试

predict()函数可以指定某张图片进行测试

5.如下目录下新建一个testAll.py文件来完成对该框架的测试。

该框架中有很多有用的功能可供自己在后面深入开发中做借鉴,本篇主要目的是训练一个网络并用训练出来的模型去预测一张图片,但也提下里面好的功能:

(1)使用visualize_segmentation_dataset(tr_im, tr_an)函数可以可视化原图和对应的标注图

(2)采用预训练模型来初始化语义分割网络的编码器结构中的网络参数

(3)  支持增量学习(在前面最新一次的模型基础上继续学习,调整网络参数)

  (4)  如何设置每隔多次存一次模型,同时预测时能够判断最近的模型

同时也建议大家看看作者的博客https://divamgupta.com/image-segmentation/2019/06/06/deep-learning-semantic-segmentation-keras.html

testAll.py中代码如下:

import test_models
from keras_segmentation.train import train
from keras_segmentation.data_utils.visualize_dataset import visualize_segmentation_dataset
from keras_segmentation.pretrained import pspnet_50_ADE_20K , pspnet_101_cityscapes, pspnet_101_voc12
tr_im = "D:\\mycode\\0-Object_Segmentation\\image-segmentation-keras-master\\dataset\\train"
tr_an = "D:\\mycode\\0-Object_Segmentation\\image-segmentation-keras-master\\dataset\\trainannot"
te_im = "D:\\mycode\\0-Object_Segmentation\\image-segmentation-keras-master\\dataset\\val"
te_an = "D:\\mycode\\0-Object_Segmentation\\image-segmentation-keras-master\\dataset\\valannot"
#visualize_segmentation_dataset(tr_im,tr_an,8 )
# model = pspnet_50_ADE_20K() # load the pretrained model trained on ADE20k dataset
# model = pspnet_101_cityscapes() # load the pretrained model trained on Cityscapes dataset
# model = pspnet_101_voc12() # load the pretrained model trained on Pascal VOC 2012 dataset
# # load any of the 3 pretrained models
# out = model.predict_segmentation(
#     inp="input_image.jpg",
#     out_fname="out.png"
#)
check_path = "D:\\model\\ckpt\\trainmodel"
train("vgg_segnet",tr_im ,tr_an,288,512,8,verify_dataset=True,checkpoints_path=check_path,epochs=5,
          batch_size=4,
          validate=False,
          val_images=te_im,
          val_annotations=te_im,
          val_batch_size=2,
          steps_per_epoch=100,
          val_steps_per_epoch=100,
)
test_models.test_model()

该py文件在训练结束后,会预测一次,训练所用的数据集路径设定好,要保存的模型所在文件夹路径设定好,同时注意下这里:

原keras_segmentation中vgg16.py会加载预训练模型,但该模型由于github限速原因,不定能下载下来,所以我做了修改,指定了一个路径,当然这不是一个好的做法,

可以将此路径在外面调用接口中开放出来。该预训练模型"vgg16_weights_th_dim_ordering_th_kernels_notop.h5",也在篇末的资源路径里附上,有需要的可下载。

6.运行testAll.py文件,进行模型训练和预测

保存的模型路径下可看到每次迭代完成后所保存的模型

7.批量图片预测结果如下:

一一对应的原图如下:

8.再仔细分析下所预测的单张图的效果

                                                   原图

                                           对应的人工标注图

                                            该篇预测结果图

     博客https://blog.csdn.net/jiugeshao/article/details/113836354预测结果

可使用如下代码查看标注图的显示效果:

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# author:Icecream.Shao
from skimage import io,data,color
import cv2 as cv
import numpy as np
 
#img_name='Data_zoo/MIT_SceneParsing/ADEChallengeData2016/images/validation/gt/img_000495_bad_gt.png'
#img=io.imread(img_name,as_grey=False)
img_name='logs/pred_8.png'
img=io.imread(img_name)
img_gray=color.rgb2gray(img)
rows,cols=img_gray.shape
for i in range(rows):
    for j in range(cols):
        if (img_gray[i,j]<=0.5):
            img_gray[i,j]=0
        else:
            img_gray[i,j]=1
 
io.imshow(img_gray)
io.show()
 
cv.imshow("original", img)
cv.waitKey(0)
 
ret, binary = cv.threshold(img, 0, 255, cv.THRESH_BINARY | cv.THRESH_OTSU)#大律法,全局自适应阈值 参数0可改为任意数字但不起作用
print("阈值:%s" % ret)
cv.imshow("OTSU", binary)
cv.waitKey(0)
 
# ret, binary = cv.threshold(gray, 150, 255, cv.THRESH_BINARY)# 自定义阈值为150,大于150的是白色 小于的是黑色
# print("阈值:%s" % ret)
# cv.imshow("自定义", binary)

对应的完整工程资源见如下链接

链接:https://pan.baidu.com/s/1FBybwLwJtNti_ZJbwXwWiQ 
提取码:oaaa 
 

Guess you like

Origin blog.csdn.net/jiugeshao/article/details/115023441