ICNet-tensorflow 训练自己的数据集

代码原址:https://github.com/hellochick/ICNet-tensorflow
参考:

  1. PSPNet-tensorflow实现并训练数据
  2. 【图像语义分割】Label data的标注–Labelme(python)

一、修改ICNet参数

类别4类(其中0为背景)

  • inference.py train.py
    • 修改 num_classes(cityscapes_class) 为对应类别4
    • 修改 ignore_label 为0(背景为0)
  • icnet_cityscapes.prototxt icnet_cityscapes_bnnomerge.prototxt
    • 修改 num_output 为对应类别4
  • tools.py (与数据集中的类别颜色无关,与输出颜色有关)
    • 修改颜色 label_colours
    label_colours = [[0,0,0], [128, 0, 0], [0, 128, 0], [128, 128, 0]]
    # 0 = background, 1 = dangerous1, 2 = dangerous2, 3 = warning
    
  • network.py
    • 在加载预训练模型时,由于最后一层分类层输出个数与预训练模型不同,所以加载时候要将最后一层删除
def load(self, data_path, session, ignore_missing=True):
    '''Load network weights.
    data_path: The path to the numpy-serialized network weights
    session: The current TensorFlow session
    ignore_missing: If true, serialized weights for missing layers are ignored.
    '''
    data_dict = np.load(data_path, encoding='latin1').item()
    for op_name in data_dict:
        with tf.variable_scope(op_name, reuse=True):
            for param_name, data in data_dict[op_name].items():
               try:
                    if 'bn' in op_name:
                        param_name = BN_param_map[param_name]

                    var = tf.get_variable(param_name)
                    # 新增
                    if 'conv6_cls' not in var.name:
                       session.run(var.assign(data))
               except ValueError:
                    if not ignore_missing:
                        raise
  • inference.py
    • 选择使用的checkpoint
    if args.model == 'others':
        ckpt = tf.train.get_checkpoint_state(model_path)
        print("ckpt: ", ckpt)
        if ckpt and ckpt.model_checkpoint_path:
            loader = tf.train.Saver(var_list=tf.global_variables())
            load_step = int(os.path.basename(ckpt.all_model_checkpoint_paths[3]).split('-')[1])
            load(loader, sess, ckpt.all_model_checkpoint_paths[3])
            # 选择保存的五个checkpoint的第4个
        else:
            print('No checkpoint file found.')
    

二、制作数据集

之前输出一直得到全黑图像,我修改了两个地方,具体不知道是哪个起了作用
1、label_img单通道
  • 采用labelme制作数据集,得到.json文件
    • .json文件应该要转换成单通道的image labelme_json_to_dataset <文件名>.json
    • 得到四个文件 *.png, info.yaml , label.png, label_viz.png,其中label.png即是我们要的label_data。
    • 但是labelme输出的不是单通道的(不仅有类别信息还有颜色信息),修改 /Users/xxxx/.pyenv/versions/anaconda3-5.0.1/envs/labelme/lib/python3.6/site-packages/labelme/utils/_io.py
    def lblsave(filename, lbl):
        if osp.splitext(filename)[1] != '.png':
            filename += '.png'
        # Assume label ranses [-1, 254] for int32,
        # and [0, 255] for uint8 as VOC.
        if lbl.min() >= -1 and lbl.max() < 255:
            lbl_pil = PIL.Image.fromarray(lbl.astype(np.uint8), mode='L') #P
            # colormap = label_colormap(255)
            # lbl_pil.putpalette((colormap * 255).astype(np.uint8).flatten())
            lbl_pil.save(filename)
        else:
            logger.warn(
                '[%s] Cannot save the pixel-wise class label as PNG, '
                'so please use the npy file.' % filename
            )
    
    • 得到一张黑色的图像,可以用matlab将其打开,为一个包含0~4矩阵
2、训练图片
  • 训练的原始图片使用由 .json 得到的 img.png 而不是原始的1.jpg

猜你喜欢

转载自blog.csdn.net/manga_/article/details/82808077