paddleocr 模型训练过程记录

要点:


一 文本检测模块

1.0.1 使用原预训练模型进行评估

python tools/eval.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml -o Global.pretrained_model="./pretrained_models/ch_PP-OCRv3_det_distill_train/best_accuracy"

这是 PaddleOCR 评估模型的命令行代码,具体含义如下:

python tools/eval.py:运行 PaddleOCR 工具目录下的 eval.py 脚本。

-c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml:使用 ch_PP-OCRv3_det_cml.yml 配置文件中的模型参数进行模型评估。

-o Global.pretrained_model="./pretrained_models/ch_PP-OCRv3_det_distill_train/best_accuracy":指定使用在训练过程中保存的预训练模型进行评估。其中 Global.pretrained_model 指定了预训练模型的路径,"./pretrained_models/ch_PP-OCRv3_det_distill_train/best_accuracy" 则表示预训练模型的存储位置

简而言之,此命令将通过指定的配置文件和预训练模型对 PaddleOCR 的目标检测模型进行评估,从而得到其在测试数据集上的性能指标。

1.1 基于PP-OCRv3检测预训练模型进行cml优化

python tools/train.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml -o Global.pretrained_model=./pretrained_models/ch_PP-OCRv3_det_distill_train/best_accuracy
  • python: 运行 Python 解释器。
  • tools/train.py: PaddleOCR 提供的模型训练的脚本
  • -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml: 指定了训练的任务名称和使用的配置文件configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml 是一个针对中文检测任务的训练配置文件,其中包括了模型的定义、优化器的设置、训练的数据集路径等。
  • -o Global.pretrained_model=./pretrained_models/ch_PP-OCRv3_det_distill_train/best_accuracy: 针对模型的预训练模型的路径设置,是一个可选参数。Global.pretrained_model=./pretrained_models/ch_PP-OCRv3_det_distill_train/best_accuracy 表示设置了预训练模型的路径为 ./pretrained_models/ch_PP-OCRv3_det_distill_train/best_accuracy

综上,执行该命令会从 configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml 中读取中文检测任务的训练配置,并在第一次训练时使用 ./pretrained_models/ch_PP-OCRv3_det_distill_train/best_accuracy 中的预训练权重参数初始化网络权重。

1.2.1 评估上一步优化的效果

python tools/eval.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml -o Global.pretrained_model="./output/ch_PP-OCR_v3_det/best_accuracy"
  • 第一步跑的次数比较少,效果一般,官方说法是跑完大概从47.5% ->  65.2%

1.2 基于PP-OCRv3检测的学生模型进行fintune优化

python tools/train.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_student.yml -o Global.pretrained_model=./pretrained_models/ch_PP-OCRv3_det_distill_train/student

python tools/train.py:运行 PaddleOCR 工具目录下的 train.py 脚本。

-c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_student.yml:使用 ch_PP-OCRv3_det_student.yml 配置文件中的模型参数进行训练。这个配置文件设置了使用深度可分离卷积(depthwise separable convolution)替代常规卷积,从而减小模型大小,加速训练和推理速度

-o Global.pretrained_model=./pretrained_models/ch_PP-OCRv3_det_distill_train/student:指定使用预训练模型进行模型优化。其中 Global.pretrained_model 指定了预训练模型的路径,"./pretrained_models/ch_PP-OCRv3_det_distill_train/student" 则表示预训练模型的存储位置

简而言之,这个命令将使用指定的配置和预训练模型对 PaddleOCR 检测模型进行训练,最终得到训练好的目标检测模型。在训练过程中,采用了深度可分离卷积和知识蒸馏等技术,来进一步提高模型的性能指标。

1.2.1 训练效果评估

python tools/eval.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_student.yml -o Global.pretrained_model="./output/ch_PP-OCR_v3_det_student/best_accuracy"
  • 3080Ti 800次跑了一天, 效果还可以 , 58% -> 87%

1.3  基于PP-OCRv3检测的教师模型进行fintune优化

首先需要从提供的预训练模型best_accuracy.pdparams中提取teacher参数,组合成适合dml训练的初始化模型,提取代码如下:

%cd /home/aistudio/PaddleOCR/pretrained_models/
# transform teacher params in best_accuracy.pdparams into teacher_dml.paramers
import paddle

# load pretrained model
all_params = paddle.load("ch_PP-OCRv3_det_distill_train/best_accuracy.pdparams")
# print(all_params.keys())

# keep teacher params
t_params = {key[len("Teacher."):]: all_params[key] for key in all_params if "Teacher." in key}

# print(t_params.keys())

s_params = {"Student." + key: t_params[key] for key in t_params}
s2_params = {"Student2." + key: t_params[key] for key in t_params}
s_params = {**s_params, **s2_params}
# print(s_params.keys())

paddle.save(s_params, "ch_PP-OCRv3_det_distill_train/teacher_dml.pdparams")

这段代码实现了一种知识蒸馏(Knowledge Distillation)的技术,即将一个较大的、精度较高的模型(即“Teacher”)的知识迁移到一个较小、精度略低的模型(即“Student”和“Student2”)中去。

1.3.1 执行训练

python tools/train.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_dml.yml -o Global.pretrained_model=./pretrained_models/ch_PP-OCRv3_det_distill_train/teacher_dml
  • -c:指定配置文件的路径,configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_dml.yml表示使用PP-OCRv3模型进行检测,使用dml(distillation multi-task learning)方式进行训练。
  • -o:指定超参数,Global.pretrained_model=./pretrained_models/ch_PP-OCRv3_det_distill_train/teacher_dml表示使用pretrained_models/ch_PP-OCRv3_det_distill_train目录下的teacher_dml模型作为预训练模型。

具体来说,该代码首先读取指定的配置文件,然后使用指定的超参数启动模型训练。在训练过程中,会自动调整学习率、保存模型、打印日志等操作。通过训练模型,可以得到一个可以用于文字检测的模型。

1.3.2 模型评估

python tools/eval.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_dml.yml -o Global.pretrained_model="./output/ch_PP-OCR_v3_det_teacher/best_accuracy"
  • 200次跑了一天,      -> 85.9%

1.4 基于fintune好的学生和教师模型进行cml优化

需要从4.2和4.3训练得到的best_accuracy.pdparams中提取各自代表student和teacher的参数,组合成适合cml训练的初始化模型,提取代码如下:

%cd /home/aistudio/PaddleOCR/
# transform teacher params and student parameters into cml model
import paddle

all_params = paddle.load("./pretrained_models/ch_PP-OCRv3_det_distill_train/best_accuracy.pdparams")
# print(all_params.keys())

t_params = paddle.load("./output/ch_PP-OCR_v3_det_teacher/best_accuracy.pdparams")
# print(t_params.keys())

s_params = paddle.load("./output/ch_PP-OCR_v3_det_student/best_accuracy.pdparams")
# print(s_params.keys())

for key in all_params: 
    # teacher is OK
    if "Teacher." in key:
        new_key = key.replace("Teacher", "Student")
        #print("{} >> {}\n".format(key, new_key))
        assert all_params[key].shape == t_params[new_key].shape
        all_params[key] = t_params[new_key]

    if "Student." in key:
        new_key = key.replace("Student.", "")
        #print("{} >> {}\n".format(key, new_key))
        assert all_params[key].shape == s_params[new_key].shape
        all_params[key] = s_params[new_key]

    if "Student2." in key:
        new_key = key.replace("Student2.", "")
        print("{} >> {}\n".format(key, new_key))
        assert all_params[key].shape == s_params[new_key].shape
        all_params[key] = s_params[new_key]
        
paddle.save(all_params, "./pretrained_models/ch_PP-OCRv3_det_distill_train/teacher_cml_student.pdparams")

1.4.1 执行训练

python tools/train.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml -o Global.pretrained_model=./pretrained_models/ch_PP-OCRv3_det_distill_train/teacher_cml_student Global.save_model_dir=./output/ch_PP-OCR_v3_det_finetune/

1.4.2 执行评估

python tools/eval.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml -o Global.pretrained_model="./output/ch_PP-OCR_v3_det_finetune/best_accuracy"

1.5  模型推理

训练完成后,可以将训练模型转换成inference模型。inference 模型会额外保存模型的结构信息,在预测部署、加速推理上性能优越,灵活方便,适合于实际系统集成。

1.5.1 导出模型

%cd /home/aistudio/PaddleOCR
# 转化为推理模型
!python tools/export_model.py \
-c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml \
-o Global.pretrained_model=./output/ch_PP-OCR_v3_det_finetune/best_accuracy \
-o Global.save_inference_dir="./inference/det_ppocrv3"

1.5.2 推理预测

%cd /home/aistudio/PaddleOCR
# 推理预测
!python tools/infer/predict_det.py --image_dir="train_data/icdar2015/text_localization/test/1.jpg" --det_model_dir="./inference/det_ppocrv3/Student"

这段代码是用来进行文本检测的推理,其中 --image_dir 指定输入图片的路径,--det_model_dir 指定训练好的文本检测模型的路径。

具体来说,--image_dir 指定了待推理的单张图片路径,可以是一张图片或者是一个文件夹路径,会自动对其内部所有图片进行检测,--det_model_dir 指定了训练好的文本检测模型的路径。推理的过程会使用指定的模型对图片进行文本检测,输出检测结果。

二 针对文字识别的优化

2.0.1 下载预训练模型

下载需要的PP-OCRv3识别预训练模型,更多选择请自行下载其他的文字识别模型

%cd /home/aistudio/PaddleOCR
# 使用该指令下载需要的预训练模型
!wget -P ./pretrained_models/ https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_train.tar
# 解压预训练模型文件
!tar -xf ./pretrained_models/ch_PP-OCRv3_rec_train.tar -C pretrained_models

2.0.2 原模型评估

python tools/eval.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml -o Global.pretrained_model="./pretrained_models/ch_PP-OCRv3_rec_train/best_accuracy"
  • python: 启动Python解释器
  • tools/eval.py: 执行PaddleOCR中的eval.py脚本,该脚本是用于评估模型性能的
  • -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml: 指定使用的模型配置文件的路径,该模型是基于PP-OCRv3模型结构的知识蒸馏模型,其中rec表示文本识别模型,PP-OCRv3表示基础模型为PP-OCRv3,ch_PP-OCRv3_rec_distillation.yml表示该模型的配置文件名
  • -o Global.pretrained_model="./pretrained_models/ch_PP-OCRv3_rec_train/best_accuracy": 指定模型的预训练参数,即使用的权重文件,Global.pretrained_model表示全局参数pretrained_model"./pretrained_models/ch_PP-OCRv3_rec_train/best_accuracy"表示权重文件路径
  • 执行以上命令后,程序将加载模型配置和权重文件,对指定的测试集进行识别,并输出模型的性能指标。

2.1 识别模型优化

2.1.1 修正参数

  epoch_num: 100 # 训练epoch数
  save_model_dir: ./output/ch_PP-OCR_v3_rec
  save_epoch_step: 10
  eval_batch_step: [0, 100] # 评估间隔,每隔100step评估一次
  cal_metric_during_train: true
  pretrained_model: ./pretrained_models/ch_PP-OCRv3_rec_train/best_accuracy  # 预训练模型路径
  character_dict_path: ppocr/utils/ppocr_keys_v1.txt
  use_space_char: true  # 使用空格

  lr:
    name: Cosine # 修改学习率衰减策略为Cosine
    learning_rate: 0.0002 # 修改fine-tune的学习率
    warmup_epoch: 2 # 修改warmup轮数

Train:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/ic15_data/ # 训练集图片路径
    ext_op_transform_idx: 1
    label_file_list:
    - ./train_data/ic15_data/rec_gt_train.txt # 训练集标签
    ratio_list:
    - 1.0
  loader:
    shuffle: true
    batch_size_per_card: 64
    drop_last: true
    num_workers: 4
Eval:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/ic15_data/ # 测试集图片路径
    label_file_list:
    - ./train_data/ic15_data/rec_gt_test.txt # 测试集标签
    ratio_list:
    - 1.0
  loader:
    shuffle: false
    drop_last: false
    batch_size_per_card: 64
    num_workers: 4

2.1.2 执行训练

python tools/train.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml

具体来说,train.py是一个命令行工具,可以用于在PaddlePaddle上训练OCR模型。该脚本采用Python语言编写,主要调用了PaddlePaddle的API实现训练过程。其输入参数包括配置文件路径-c),以及其他一些可选的参数-o)。在这个例子中,使用了配置文件ch_PP-OCRv3_rec_distillation.yml来指定训练的超参数、数据路径、网络结构等信息。

该命令的作用是在上述配置文件的基础上开始一个OCR模型的训练过程。在训练过程中,模型会不断地根据训练数据调整自己的参数,以提高在验证数据上的准确率。

2.1.3 模型评估

python tools/eval.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml -o Global.checkpoints="./output/ch_PP-OCR_v3_rec/best_accuracy"
  • tools/eval.py: 要运行的Python脚本路径,这里是进行文本识别模型的评估。
  • -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml: 指定使用的配置文件,这里是PP-OCRv3文本识别模型的蒸馏配置文件。
  • -o Global.checkpoints="./output/ch_PP-OCR_v3_rec/best_accuracy": 指定了评估时加载的模型路径,这里是训练时保存的在验证集上最佳精度的模型
  • 跑了360次效果

2.2 模型推理

2.2.1 导出模型

训练完成后,可以将训练模型转换成inference模型。inference 模型会额外保存模型的结构信息,在预测部署、加速推理上性能优越,灵活方便,适合于实际系统集成。

python tools/export_model.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml -o Global.pretrained_model="./output/ch_PP-OCR_v3_rec/best_accuracy" Global.save_inference_dir="./inference/rec_ppocrv3/"
  • -c:指定 OCR 模型配置文件路径
  • -o Global.pretrained_model:指定要导出的 OCR 模型路径
  • -o Global.save_inference_dir:指定导出的模型的保存路径

具体来说,该命令使用 configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml 配置文件中指定的模型配置,将模型导出到 ./inference/rec_ppocrv3/ 目录中,导出的模型使用的是 ./output/ch_PP-OCR_v3_rec/best_accuracy 中的最佳模型。

2.2.2 模型推理

python tools/infer/predict_rec.py --image_dir="train_data/ic15_data/test/1_crop_0.jpg" --rec_model_dir="./inference/rec_ppocrv3/Student"

这段代码使用了 PaddleOCR 工具包中的 predict_rec.py 脚本进行文本识别(Text Recognition)预测。具体来说,它会载入一张图片,使用指定的文本识别模型对图片中的文本进行识别,最后输出识别结果。

  • --image_dir:要进行识别的图片路径;
  • --rec_model_dir:文本识别模型路径。

三 结合检测和识别模型

3.1 将上面训练好的检测和识别模型进行串联测试

python3 tools/infer/predict_system.py --image_dir="./train_data/icdar2015/text_localization/test/142.jpg" --det_model_dir="./inference/det_ppocrv3/Student"  --rec_model_dir="./inference/rec_ppocrv3/Student"
  • tools/infer/predict_system.py:运行 OCR 模型推理的脚本
  • --image_dir="./train_data/icdar2015/text_localization/test/142.jpg":待识别图片所在的目录。
  • --det_model_dir="./inference/det_ppocrv3/Student"文本检测模型所在的目录。
  • --rec_model_dir="./inference/rec_ppocrv3/Student"文本识别模型所在的目录。

测试结果保存在./inference_results/目录下,可以用下面代码进行可视化


%cd /home/aistudio/PaddleOCR
# 显示结果
import matplotlib.pyplot as plt
from PIL import Image
img_path= "./inference_results/142.jpg"
img = Image.open(img_path)
plt.figure("test_img", figsize=(30,30))
plt.imshow(img)
plt.show()

3.2 后处理

如果需要获取key-value信息,可以基于启发式的规则,将识别结果与关键字库进行匹配;如果匹配上了,则取该字段为key, 后面一个字段为value

def postprocess(rec_res):
    keys = ["型号", "厂家", "版本号", "检定校准分类", "计量器具编号", "烟尘流量",
            "累积体积", "烟气温度", "动压", "静压", "时间", "试验台编号", "预测流速",
            "全压", "烟温", "流速", "工况流量", "标杆流量", "烟尘直读嘴", "烟尘采样嘴",
            "大气压", "计前温度", "计前压力", "干球温度", "湿球温度", "流量", "含湿量"]
    key_value = []
    if len(rec_res) > 1:
        for i in range(len(rec_res) - 1):
            rec_str, _ = rec_res[i]
            for key in keys:
                if rec_str in key:
                    key_value.append([rec_str, rec_res[i + 1][0]])
                    break
    return key_value
key_value = postprocess(filter_rec_res)

猜你喜欢

转载自blog.csdn.net/March_A/article/details/130442379