Comprehensive Analysis of the PaddleSeg Image Segmentation Suite (Part 8): Interpreting the Prediction Code

Once the model is trained, it can be used to run prediction on images, visualize the results, and inspect the segmentation quality.

Run the following command:

python predict.py \
       --config configs/quick_start/bisenet_optic_disc_512x512_1k.yml \
       --model_path output/iter_1000/model.pdparams \
       --image_path data/optic_disc_seg/JPEGImages/H0003.jpg \
       --save_dir output/result

First, here is what each parameter in the command above means:

--config specifies the configuration file, which defines the model (including its name).

--model_path specifies the path of the trained model weights.

--image_path specifies the input image to predict; it can be a single image file or a directory of images.

--save_dir specifies the directory where the prediction results are saved.

You can also use the following parameters to perform multi-scale flip prediction:

--aug_pred: whether to enable augmented prediction.

--scales: the scaling factors; the default is 1.0. Several factors can be given, separated by spaces.

--flip_horizontal: whether to enable horizontal flipping.

--flip_vertical: whether to enable vertical flipping.

Multi-scale flip prediction builds on ordinary prediction: the input image is rescaled to several scales and flipped horizontally and/or vertically, producing multiple prediction results, which are then summed to form the final output. The following figure illustrates the workflow of the prediction program.

Below, we interpret the code of predict.py.
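To make the idea concrete, here is a minimal NumPy sketch of multi-scale flip prediction for a single image. It is not PaddleSeg's own implementation: the toy `model` (any callable mapping a (C, H, W) array to (num_classes, H, W) scores), the helper names `resize_nearest` and `aug_predict`, and the use of nearest-neighbour resizing in place of bilinear interpolation are all assumptions made for illustration.

```python
import numpy as np


def resize_nearest(x, out_h, out_w):
    """Nearest-neighbour resize of a (C, H, W) array (a stand-in for bilinear)."""
    _, h, w = x.shape
    rows = np.arange(out_h) * h // out_h
    cols = np.arange(out_w) * w // out_w
    return x[:, rows][:, :, cols]


def aug_predict(model, im, scales=(1.0,), flip_horizontal=True, flip_vertical=False):
    """Hypothetical multi-scale flip prediction for one (C, H, W) image.

    `model` is assumed to map a (C, h, w) array to (num_classes, h, w) scores.
    Returns an (H, W) array of class indices.
    """
    _, h, w = im.shape
    # Flip combinations: always the unflipped image, plus the requested flips.
    flips = [()]
    if flip_horizontal:
        flips.append((2,))  # flip along the width axis
    if flip_vertical:
        flips.append((1,))  # flip along the height axis
    if flip_horizontal and flip_vertical:
        flips.append((1, 2))
    total = 0
    for scale in scales:
        sh, sw = int(h * scale + 0.5), int(w * scale + 0.5)
        scaled = resize_nearest(im, sh, sw)
        for axes in flips:
            x = np.flip(scaled, axis=axes) if axes else scaled
            scores = model(x)
            # Undo the flip so the scores line up with the unflipped image.
            if axes:
                scores = np.flip(scores, axis=axes)
            # Resize the scores back to the input size and accumulate them.
            total = total + resize_nearest(scores, h, w)
    # At each pixel, the class with the highest summed score wins.
    return np.argmax(total, axis=0)
```

Because each flipped prediction is flipped back before it is summed, all partial results refer to the same pixel positions; the argmax over the summed scores then picks one class per pixel.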

if __name__ == '__main__':
    # Parse the command-line arguments
    args = parse_args()
    # Run the main function
    main(args)

Next, let's look at the parse_args function. The input parameters supported by predict.py are essentially the same as those of val.py.

import argparse


def parse_args():
    parser = argparse.ArgumentParser(description='Model prediction')

    # params of prediction
    # Path of the config file
    parser.add_argument(
        "--config", dest="cfg", help="The config file.", default=None, type=str)
    # Path of the trained model weights
    parser.add_argument(
        '--model_path',
        dest='model_path',
        help='The path of model for prediction',
        type=str,
        default=None)
    # Path of the input image(s) to predict
    parser.add_argument(
        '--image_path',
        dest='image_path',
        help=
        'The path of image, it can be a file or a directory including images',
        type=str,
        default=None)
    # Directory where the prediction results are saved
    parser.add_argument(
        '--save_dir',
        dest='save_dir',
        help='The directory for saving the predicted results',
        type=str,
        default='./output/result')

    # augment for prediction
    # Whether to predict with multi-scale and flip augmentation; this usually improves accuracy and is recommended
    parser.add_argument(
        '--aug_pred',
        dest='aug_pred',
        help='Whether to use multi-scale and flip augment for prediction',
        action='store_true')
    # Scaling factors; 1.0 keeps the original size. Several factors can be given, separated by spaces.
    parser.add_argument(
        '--scales',
        dest='scales',
        nargs='+',
        help='Scales for augment',
        type=float,
        default=1.0)
    # Enable horizontal flipping of the image
    parser.add_argument(
        '--flip_horizontal',
        dest='flip_horizontal',
        help='Whether to use flip horizontally augment',
        action='store_true')
    # Enable vertical flipping of the image
    parser.add_argument(
        '--flip_vertical',
        dest='flip_vertical',
        help='Whether to use flip vertically augment',
        action='store_true')

    # sliding window prediction
    # Sliding-window options: whether to enable sliding-window prediction
    parser.add_argument(
        '--is_slide',
        dest='is_slide',
        help='Whether to predict by sliding window',
        action='store_true')
    # Sliding-window size (width, height)
    parser.add_argument(
        '--crop_size',
        dest='crop_size',
        nargs=2,
        help=
        'The crop size of sliding window, the first is width and the second is height.',
        type=int,
        default=None)
    # Sliding-window stride; two values are required, horizontal first, then vertical.
    parser.add_argument(
        '--stride',
        dest='stride',
        nargs=2,
        help=
        'The stride of sliding window, the first is width and the second is height.',
        type=int,
        default=None)

    return parser.parse_args()
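One detail is easy to miss: `--scales` uses `nargs='+'` with a float default, so its type depends on whether the flag was given. The self-contained sketch below rebuilds a subset of the arguments above (the name `build_parser` is hypothetical) and parses a few example command lines to show this, along with the two-value `--crop_size` and `--stride` options.

```python
import argparse


def build_parser():
    """Hypothetical trimmed-down copy of the prediction arguments (subset only)."""
    parser = argparse.ArgumentParser(description='Model prediction')
    parser.add_argument('--config', dest='cfg', type=str, default=None)
    parser.add_argument('--image_path', dest='image_path', type=str, default=None)
    parser.add_argument('--aug_pred', dest='aug_pred', action='store_true')
    parser.add_argument('--scales', dest='scales', nargs='+', type=float, default=1.0)
    parser.add_argument('--flip_horizontal', dest='flip_horizontal', action='store_true')
    parser.add_argument('--is_slide', dest='is_slide', action='store_true')
    parser.add_argument('--crop_size', dest='crop_size', nargs=2, type=int, default=None)
    parser.add_argument('--stride', dest='stride', nargs=2, type=int, default=None)
    return parser


parser = build_parser()

# Without --scales, args.scales is the bare float default 1.0 ...
plain = parser.parse_args(['--config', 'cfg.yml'])
# ... but when --scales is given, argparse returns a list of floats.
aug = parser.parse_args(['--config', 'cfg.yml', '--aug_pred',
                         '--scales', '0.75', '1.0', '1.25', '--flip_horizontal'])
# nargs=2 options come back as two-element lists of ints.
slide = parser.parse_args(['--is_slide', '--crop_size', '512', '512',
                           '--stride', '256', '256'])
```

Code that consumes `args.scales` therefore has to accept either a single float or a list of floats.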

That covers the input parameters. The main function predicts the images mainly by calling the predict function in the core/predict.py module.
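The predict function receives an `image_list` and an optional `image_dir`. Since `--image_path` may be either a single file or a directory, main has to turn it into those two values first. The following is a minimal sketch of that step; `collect_images` and the extension list `IMAGE_EXTS` are hypothetical names chosen here, not PaddleSeg's API.

```python
import os
import tempfile

# Extensions treated as images: an assumption for this sketch.
IMAGE_EXTS = ('.jpg', '.jpeg', '.png', '.bmp')


def collect_images(image_path):
    """Hypothetical helper: turn --image_path into (image_list, image_dir).

    A single file yields a one-element list and image_dir=None; a directory
    is walked recursively and image_dir is the directory itself.
    """
    if os.path.isfile(image_path):
        return [image_path], None
    if not os.path.isdir(image_path):
        raise FileNotFoundError(image_path)
    image_list = []
    for root, _, files in os.walk(image_path):
        for name in files:
            # Case-insensitive extension match, so "b.PNG" is accepted too.
            if name.lower().endswith(IMAGE_EXTS):
                image_list.append(os.path.join(root, name))
    return sorted(image_list), image_path


# Usage: build a small throwaway directory tree and collect its images.
with tempfile.TemporaryDirectory() as demo_dir:
    os.makedirs(os.path.join(demo_dir, 'sub'))
    for rel in ('a.jpg', os.path.join('sub', 'b.PNG'), 'notes.txt'):
        open(os.path.join(demo_dir, rel), 'w').close()
    found, found_dir = collect_images(demo_dir)
    found_rel = [os.path.relpath(p, demo_dir) for p in found]
    single, single_dir = collect_images(os.path.join(demo_dir, 'a.jpg'))
    single_rel = [os.path.relpath(p, demo_dir) for p in single]
```

Returning the directory as `image_dir` lets predict later strip it from each path, so results for nested images keep their relative folder layout under the save directory.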

First, take a look at an overview of the predict function's code.

Next comes the annotated code of the predict function.

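import os

import cv2
import numpy as np
import paddle

from paddleseg import utils
from paddleseg.core import infer
from paddleseg.utils import logger, progbar


def mkdir(path):
    # Helper from the same module: create the parent directory of `path` if needed
    sub_dir = os.path.dirname(path)
    if not os.path.exists(sub_dir):
        os.makedirs(sub_dir)


def predict(model,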
            model_path,
            transforms,
            image_list,
            image_dir=None,
            save_dir='output',
            aug_pred=False,
            scales=1.0,
            flip_horizontal=True,
            flip_vertical=False,
            is_slide=False,
            stride=None,
            crop_size=None):
    # Load the trained model weights
    para_state_dict = paddle.load(model_path)
    model.set_dict(para_state_dict)
    # Switch the model to evaluation mode
    model.eval()

    added_saved_dir = os.path.join(save_dir, 'added_prediction')
    pred_saved_dir = os.path.join(save_dir, 'pseudo_color_prediction')

    logger.info("Start to predict...")
    # Set up the progress bar
    progbar_pred = progbar.Progbar(target=len(image_list), verbose=1)
    # Iterate over the image list
    for i, im_path in enumerate(image_list):
        # Read the image (OpenCV returns an H x W x 3 BGR array)
        im = cv2.imread(im_path)
        # Record the original (height, width) of the image
        ori_shape = im.shape[:2]
        # Apply the preprocessing transforms defined in the config
        im, _ = transforms(im)
        # Add a leading batch dimension
        im = im[np.newaxis, ...]
        # Convert the ndarray to a Paddle tensor
        im = paddle.to_tensor(im)
        # Check whether multi-scale flip prediction is enabled
        if aug_pred:
            # If so, run multi-scale flip (augmented) inference on the image
            pred = infer.aug_inference(
                model,
                im,
                ori_shape=ori_shape,
                transforms=transforms.transforms,
                scales=scales,
                flip_horizontal=flip_horizontal,
                flip_vertical=flip_vertical,
                is_slide=is_slide,
                stride=stride,
                crop_size=crop_size)
        else:
            # Otherwise, run ordinary inference on the image
            pred = infer.inference(
                model,
                im,
                ori_shape=ori_shape,
                transforms=transforms.transforms,
                is_slide=is_slide,
                stride=stride,
                crop_size=crop_size)
        # Squeeze out the size-1 dimensions and convert to uint8 so the result can be saved as an image
        pred = paddle.squeeze(pred)
        pred = pred.numpy().astype('uint8')

        # Get the file name under which the results are saved
        if image_dir is not None:
            im_file = im_path.replace(image_dir, '')
        else:
            im_file = os.path.basename(im_path)
        if im_file[0] == '/':
            im_file = im_file[1:]
        # Save the original image blended with the prediction
        added_image = utils.visualize.visualize(im_path, pred, weight=0.6)
        added_image_path = os.path.join(added_saved_dir, im_file)
        mkdir(added_image_path)
        cv2.imwrite(added_image_path, added_image)

        # Save the pseudo-color prediction
        pred_mask = utils.visualize.get_pseudo_color_map(pred)
        # splitext strips only the final extension, so "scan.v2.jpg" -> "scan.v2.png"
        pred_saved_path = os.path.join(pred_saved_dir,
                                       os.path.splitext(im_file)[0] + ".png")
        mkdir(pred_saved_path)
        pred_mask.save(pred_saved_path)

        # pred_im = utils.visualize(im_path, pred, weight=0.0)
        # pred_saved_path = os.path.join(pred_saved_dir, im_file)
        # mkdir(pred_saved_path)
        # cv2.imwrite(pred_saved_path, pred_im)
        # Advance the progress bar by one
        progbar_pred.update(i + 1)
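The naming and blending steps at the end of the loop can be tried out without a model. The sketch below mirrors the path logic of predict for one image and adds a simple blend function; the function names `saved_paths` and `overlay` are hypothetical, and `overlay` assumes that `weight` (0.6 in the call above) is the share given to the original image, with the rest going to the color mask. It derives the `.png` name with `os.path.splitext`, which strips only the final extension, whereas a plain `rsplit(".")[0]` would shorten `scan.v2.jpg` to `scan`.

```python
import os

import numpy as np


def saved_paths(im_path, image_dir, save_dir='output'):
    """Hypothetical mirror of predict()'s naming logic for one image.

    Returns (overlay_path, pseudo_color_path).
    """
    if image_dir is not None:
        # Keep the path relative to the input directory.
        im_file = im_path.replace(image_dir, '')
    else:
        im_file = os.path.basename(im_path)
    if im_file[0] == '/':
        im_file = im_file[1:]
    overlay_path = os.path.join(save_dir, 'added_prediction', im_file)
    pseudo_path = os.path.join(save_dir, 'pseudo_color_prediction',
                               os.path.splitext(im_file)[0] + '.png')
    return overlay_path, pseudo_path


def overlay(image, color_mask, weight=0.6):
    """Blend two uint8 images: weight * image + (1 - weight) * color_mask."""
    blended = (weight * image.astype(np.float32)
               + (1 - weight) * color_mask.astype(np.float32))
    # Round to the nearest integer and clamp to the valid uint8 range.
    return np.clip(blended + 0.5, 0, 255).astype(np.uint8)
```

For example, `saved_paths('data/imgs/a/H0003.jpg', 'data/imgs')` keeps the `a/` subfolder in both output paths, and the pseudo-color result is always written as a `.png`.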

In the code above, different inference functions are called depending on the input parameters. They were covered in the interpretation of the evaluation code in the previous part, so they are not repeated here.
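For reference, sliding-window prediction (`--is_slide` with `--crop_size` and `--stride`) can be sketched generically: the image is covered by overlapping crops, each crop is predicted separately, and the scores of overlapping pixels are averaged before the argmax. This is a NumPy illustration of the technique under stated assumptions (a toy `model` callable, the hypothetical name `slide_predict`, and a stride no larger than the crop), not the code of PaddleSeg's `infer` module.

```python
import numpy as np


def slide_predict(model, im, num_classes, crop_size, stride):
    """Hypothetical sliding-window prediction for one (C, H, W) image.

    crop_size and stride are (width, height), matching the CLI order above.
    `model` is assumed to map a (C, h, w) crop to (num_classes, h, w) scores.
    """
    _, h, w = im.shape
    w_crop, h_crop = crop_size
    w_stride, h_stride = stride
    if w_stride > w_crop or h_stride > h_crop:
        # A stride larger than the crop would leave pixels uncovered.
        raise ValueError('stride must not exceed crop_size')
    scores = np.zeros((num_classes, h, w), dtype=np.float64)
    counts = np.zeros((h, w), dtype=np.float64)
    # Number of window positions along each axis (at least one).
    rows = max(int(np.ceil((h - h_crop) / h_stride)), 0) + 1
    cols = max(int(np.ceil((w - w_crop) / w_stride)), 0) + 1
    for r in range(rows):
        for c in range(cols):
            # Clamp the last window so it ends exactly at the image border.
            y2 = min(r * h_stride + h_crop, h)
            x2 = min(c * w_stride + w_crop, w)
            y1 = max(y2 - h_crop, 0)
            x1 = max(x2 - w_crop, 0)
            scores[:, y1:y2, x1:x2] += model(im[:, y1:y2, x1:x2])
            counts[y1:y2, x1:x2] += 1
    # Average the overlapping scores, then pick one class per pixel.
    return np.argmax(scores / counts, axis=0)
```

Clamping the last window to the border keeps every crop at full size when the image is large enough, so the model always sees inputs of the size given by `--crop_size`.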

This completes the interpretation of the main code in the current version of PaddleSeg.

This series will be updated periodically to keep in sync with new PaddleSeg releases. My knowledge is limited, so please forgive any errors.

PaddleSeg repository: https://github.com/PaddlePaddle/PaddleSeg

Original article: blog.csdn.net/txyugood/article/details/113487050