代码原址:https://github.com/hellochick/ICNet-tensorflow
参考:
一、修改ICNet参数
类别4类(其中0为背景)
- inference.py train.py
- 修改
num_classes(cityscapes_class)
为对应类别4 - 修改
ignore_label
为0(背景为0)
- 修改
- icnet_cityscapes.prototxt icnet_cityscapes_bnnomerge.prototxt
- 修改
num_output
为对应类别4
- 修改
- tools.py (与数据集中的类别颜色无关,与输出颜色有关)
- 修改颜色
label_colours
label_colours = [[0,0,0], [128, 0, 0], [0, 128, 0], [128, 128, 0]] # 0 = background, 1 = dangerous1, 2 = dangerous2, 3 = warning
- 修改颜色
- network.py
- 在加载预训练模型时,由于最后一层分类层输出个数与预训练模型不同,所以加载时候要将最后一层删除
def load(self, data_path, session, ignore_missing=True):
'''Load network weights.
data_path: The path to the numpy-serialized network weights
session: The current TensorFlow session
ignore_missing: If true, serialized weights for missing layers are ignored.
'''
data_dict = np.load(data_path, encoding='latin1').item()
for op_name in data_dict:
with tf.variable_scope(op_name, reuse=True):
for param_name, data in data_dict[op_name].items():
try:
if 'bn' in op_name:
param_name = BN_param_map[param_name]
var = tf.get_variable(param_name)
# 新增
if 'conv6_cls' not in var.name:
session.run(var.assign(data))
except ValueError:
if not ignore_missing:
raise
- inference.py
- 选择使用的checkpoint
if args.model == 'others': ckpt = tf.train.get_checkpoint_state(model_path) print("ckpt: ", ckpt) if ckpt and ckpt.model_checkpoint_path: loader = tf.train.Saver(var_list=tf.global_variables()) load_step = int(os.path.basename(ckpt.all_model_checkpoint_paths[3]).split('-')[1]) load(loader, sess, ckpt.all_model_checkpoint_paths[3]) # 选择保存的五个checkpoint的第4个 else: print('No checkpoint file found.')
二、制作数据集
之前输出一直得到全黑图像,我修改了两个地方,具体不知道是哪个起了作用
1、label_img单通道
- 采用labelme制作数据集,得到.json文件
- .json文件应该要转换成单通道的image
labelme_json_to_dataset <文件名>.json
- 得到四个文件 *.png, info.yaml , label.png, label_viz.png,其中
label.png
即是我们要的label_data。 - 但是labelme输出的不是单通道的(不仅有类别信息还有颜色信息),修改
/Users/xxxx/.pyenv/versions/anaconda3-5.0.1/envs/labelme/lib/python3.6/site-packages/labelme/utils/_io.py
def lblsave(filename, lbl): if osp.splitext(filename)[1] != '.png': filename += '.png' # Assume label ranses [-1, 254] for int32, # and [0, 255] for uint8 as VOC. if lbl.min() >= -1 and lbl.max() < 255: lbl_pil = PIL.Image.fromarray(lbl.astype(np.uint8), mode='L') #P # colormap = label_colormap(255) # lbl_pil.putpalette((colormap * 255).astype(np.uint8).flatten()) lbl_pil.save(filename) else: logger.warn( '[%s] Cannot save the pixel-wise class label as PNG, ' 'so please use the npy file.' % filename )
- 得到一张黑色的图像,可以用matlab将其打开,为一个包含0~4矩阵
- .json文件应该要转换成单通道的image
2、训练图片
- 训练的原始图片使用由
.json
得到的img.png
而不是原始的1.jpg