Getting started with pyTorch (4) - Exporting the MNIST model and recognizing digits with C++ OpenCV DNN

Learn from others,

be the better one.

—— "Weika Zhixiang"


This article is about 2548 words and should take 8 minutes to read.

Foreword

The first three articles in this series covered training with pyTorch, and we successfully saved a model. Today's article uses the DNN module of C++ OpenCV to run inference on handwritten digit images.


Result

(Two screenshots: the recognized digits drawn on the handwritten test images.)

The exported inference model is the ResNet that reached 99% test accuracy on MNIST during training. As the two screenshots above show, most of the digits are recognized correctly, but the digit 7 in both images is recognized as 1. Solving that is not the goal of this article; let's look at how to export the model and run inference.

Weika Zhixiang

Exporting the Model


Since I didn't want to write a new network model, I moved the training/test set loading, the network model, and so on from train.py into trainModel.py, then created a new traintoonnx.py for exporting the ONNX model file. The source code follows; the key points are discussed afterwards.

train.py

import torch
import time
import torch.optim as optim
import matplotlib.pyplot as plt
from pylab import mpl
import trainModel as tm


## number of training epochs
epoch_times = 10

## initial best accuracy; the model is saved whenever the test accuracy exceeds it
toppredicted = 0.0

## learning rate
learnrate = 0.01
## momentum: when the previous update and the current gradient point the same way,
## the step size grows, which speeds up convergence
momentnum = 0.5

## arrays used for plotting
## test accuracy per epoch
predict_list = []
## epoch markers
epoch_list = []
## loss values
loss_list = []

model = tm.Net(tm.train_name)
## choose CPU or GPU training
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

## optimizer
optimizer = optim.SGD(model.parameters(), lr=learnrate, momentum=momentnum)
# optimizer = optim.NAdam(model.parameters(), lr=learnrate)

## training loop for one epoch
def train(epoch):
    running_loss = 0.0
    current_train = 0.0
    model.train()
    for batch_idx, data in enumerate(tm.train_dataloader, 0):
        inputs, target = data
        ## move to the selected device (CPU or GPU)
        inputs, target = inputs.to(device), target.to(device)

        optimizer.zero_grad()

        # forward, backward, update
        outputs = model(inputs)
        loss = model.criterion(outputs, target)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        ## print progress every 300 batches
        if batch_idx % 300 == 299:
            current_train = current_train + 0.3
            current_epoch = epoch + 1 + current_train
            epoch_list.append(current_epoch)
            current_loss = running_loss / 300
            loss_list.append(current_loss)

            print('[%d, %5d] loss: %.3f' % (current_epoch, batch_idx + 1, current_loss))
            running_loss = 0.0


def test():
    correct = 0
    total = 0
    model.eval()
    ## no gradients are needed for evaluation
    with torch.no_grad():
        for data in tm.test_dataloader:
            inputs, labels = data
            ## move to the selected device (CPU or GPU)
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            ## torch.max over dim=1 returns (max values, indices);
            ## the index is the predicted digit 0-9
            _, predicted = torch.max(outputs.data, dim=1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    currentpredicted = (100 * correct / total)
    ## declare toppredicted as global so the function can modify
    ## the module-level variable; otherwise Python raises an error
    global toppredicted
    ## save the model when the accuracy beats the previous best
    if currentpredicted > toppredicted:
        toppredicted = currentpredicted
        torch.save(model.state_dict(), tm.savemodel_name)
        print(tm.savemodel_name+" saved, currentpredicted:%d %%" % currentpredicted)

    predict_list.append(currentpredicted)
    print('Accuracy on test set: %d %%' % currentpredicted)


## start training
timestart = time.time()
for epoch in range(epoch_times):
    train(epoch)
    test()
timeend = time.time() - timestart
print("use time: {:.0f}m {:.0f}s".format(timeend // 60, timeend % 60))


## set the canvas font so the Chinese labels below display correctly
mpl.rcParams["font.sans-serif"] = ["SimHei"]
## display minus signs correctly
mpl.rcParams["axes.unicode_minus"] = False

## create the canvas
fig, (axloss, axpredict) = plt.subplots(nrows=1, ncols=2, figsize=(8,6))

# loss subplot
axloss.plot(epoch_list, loss_list, label='loss', color='r')
## set the ticks
axloss.set_xticks(range(epoch_times)[::1])
axloss.set_xticklabels(range(epoch_times)[::1])

axloss.set_xlabel('训练轮数')              # training epochs
axloss.set_ylabel('数值')                  # value
axloss.set_title(tm.train_name+' 损失值')  # loss
# add the legend
axloss.legend(loc=0)

# accuracy subplot
axpredict.plot(range(epoch_times), predict_list, label='predict', color='g')
## set the ticks
axpredict.set_xticks(range(epoch_times)[::1])
axpredict.set_xticklabels(range(epoch_times)[::1])
# axpredict.set_yticks(range(100)[::5])
# axpredict.set_yticklabels(range(100)[::5])

axpredict.set_xlabel('训练轮数')              # training epochs
axpredict.set_ylabel('预测值')                # accuracy
axpredict.set_title(tm.train_name+' 预测值')  # accuracy
# add the legend
axpredict.legend(loc=0)

# show the figure
plt.show()

trainModel.py

import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
from ModelLinearNet import LinearNet
from ModelConv2d import Conv2dNet
from ModelGoogleNet import GoogleNet
from ModelResNet import ResNet




batch_size = 64
## network to train this run
train_name = 'ResNet'
print("train_name:" + train_name)
## file name used to save the model
savemodel_name = train_name + ".pt"
print("savemodel_name:" + savemodel_name)


transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307,), std=(0.3081,))
]) ## in Normalize, 0.1307 is the mean and 0.3081 the std of MNIST, precomputed values used directly

## training set location; downloaded automatically if missing
train_dataset = datasets.MNIST(
    root='../datasets/mnist',
    train=True,
    download=True,
    transform=transform
)
## loader for the training set
train_dataloader = DataLoader(
    dataset=train_dataset,
    shuffle=True,
    batch_size=batch_size
)
## test set location; downloaded automatically if missing
test_dataset = datasets.MNIST(
    root='../datasets/mnist',
    train=False,
    download=True,
    transform=transform
)
## loader for the test set
test_dataloader = DataLoader(
    dataset=test_dataset,
    shuffle=True,
    batch_size=batch_size
)


## select the model to train; this project runs on Python 3.9,
## so the match-case syntax (3.10+) is not available
def switch(train_name):
    if train_name == 'LinearNet':
        return LinearNet()
    elif train_name == 'Conv2dNet':
        return Conv2dNet()
    elif train_name == 'GoogleNet':
        return GoogleNet()
    elif train_name == 'ResNet':
        return ResNet()


## wrapper around the selected model
class Net(torch.nn.Module):
    def __init__(self, train_name):
        super(Net, self).__init__()
        self.model = switch(train_name=train_name)
        self.criterion = self.model.criterion

    def forward(self, x):
        x = self.model(x)
        return x

traintoonnx.py

import torch
import trainModel as tm


## grab one batch from the test loader to use as the dummy input
data = iter(tm.test_dataloader)
dummy_inputs, labels = next(data)
print(dummy_inputs.shape)

## load the saved model
model = tm.Net(tm.train_name)
model.load_state_dict(torch.load(tm.savemodel_name))
## switch to eval mode so BatchNorm/Dropout behave deterministically
model.eval()
print(model)

## sanity-check the loaded model
outputs = model(dummy_inputs)
print(outputs)
## torch.max over dim=1 returns (max values, indices); the index is the predicted digit 0-9
_, predicted = torch.max(outputs.data, dim=1)
print(_)
print(predicted)
outlabels = predicted.numpy().tolist()
print(outlabels)

## names for the input and output nodes
input_name = ["input"]
output_name = ["output"]

onnx_name = tm.train_name + '.onnx'

torch.onnx.export(
    model,
    dummy_inputs,
    onnx_name,
    verbose=True,
    input_names=input_name,
    output_names=output_name
)

Key Points

01

Load the saved model, then export

As mentioned in "Super Simple pyTorch Training->onnx Model->C++ OpenCV DNN Reasoning (with source code address)", the ONNX model there was exported right after training. In traintoonnx.py, the model was already saved by the earlier training run, so we load the saved weights first and then export.


02

x.view cannot be used when exporting an ONNX model for OpenCV inference

This one is critical. The original training model flattened with x.view in the forward pass (along the lines of x = x.view(x.size(0), -1)). Loading the exported ONNX in OpenCV's DNN module then failed with an error, so the call has to be changed to x = x.flatten(1), as sketched below.
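A minimal sketch of the change (the layer names here are illustrative, not the exact ones from the model files):

# before: flattening with view; this model exported, but OpenCV's
# ONNX importer raised an error when reading it
def forward(self, x):
    x = self.features(x)        # conv/residual blocks (illustrative name)
    x = x.view(x.size(0), -1)   # reshape to (batch, features)
    return self.fc(x)

# after: flatten(1) exports as an ONNX Flatten node that OpenCV handles
def forward(self, x):
    x = self.features(x)
    x = x.flatten(1)            # flatten every dim after the batch dim
    return self.fc(x)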

Weika Zhixiang

C++ OpenCV inference

Using OpenCV DNN for inference here is not as simple as in "Super Simple pyTorch Training->onnx Model->C++ OpenCV DNN Reasoning (with source code address)": this is handwritten digit recognition, and the MNIST training images are all 1×28×28 samples, so the picture must be preprocessed before inference. The idea is as follows.

The approach:
1. Read the image; grayscale, Gaussian blur, binarize.
2. Morphological dilation (prevents contour finding from going wrong on broken strokes).
3. Find contours and sort the cropped regions into reading order.
4. Pad and scale each sorted crop to 28×28.
5. Pass each processed crop to the DNN for inference.

C++ inference source code

#pragma once
#include<iostream>
#include<opencv2/opencv.hpp>
#include<opencv2/dnn/dnn.hpp>


using namespace cv;
using namespace std;


dnn::Net net;


//sort rectangles into reading order (left to right, top to bottom)
void SortRect(vector<Rect>& inputrects) {
  for (int i = 0; i < inputrects.size(); ++i) {
    for (int j = i; j < inputrects.size(); ++j) {
      //rect i ends above rect j: already in order, nothing to change
      if (inputrects[i].y + inputrects[i].height < inputrects[j].y) {

      }
      //same row: order left to right
      else if (inputrects[i].y <= inputrects[j].y + inputrects[j].height) {
        if (inputrects[i].x > inputrects[j].x) {
          swap(inputrects[i], inputrects[j]);
        }
      }
      //rect i starts below rect j: move the upper one first
      else if (inputrects[i].y > inputrects[j].y + inputrects[j].height) {
        swap(inputrects[i], inputrects[j]);
      }
    }
  }
}


//preprocess a contour crop for MNIST DNN input, so a tall or wide
//rectangle is not squashed when resized straight to 28*28
void DealInputMat(Mat& src, int row = 28, int col = 28, int tmppadding = 5) {
  int w = src.cols;
  int h = src.rows;
  //compare width and height and pad the short side with black first,
  //so the image is close to square and the 28*28 resize keeps its proportions
  if (w > h) {
    int tmptopbottompadding = (w - h) / 2 + tmppadding;
    copyMakeBorder(src, src, tmptopbottompadding, tmptopbottompadding, tmppadding, tmppadding,
      BORDER_CONSTANT, Scalar(0));
  }
  else {
    int tmpleftrightpadding = (h - w) / 2 + tmppadding;
    copyMakeBorder(src, src, tmppadding, tmppadding, tmpleftrightpadding, tmpleftrightpadding,
      BORDER_CONSTANT, Scalar(0));
  }
  resize(src, src, Size(col, row));
}


int main(int argc, char** argv) {
  //path to the ONNX file
  string onnxfile = "D:/Business/DemoTEST/CPP/OpenCVMinistDNN/torchminist/ResNet.onnx";

  //test image
  string testfile = "D:/Business/DemoTEST/CPP/OpenCVMinistDNN/test5.png";

  net = dnn::readNetFromONNX(onnxfile);
  if (net.empty()) {
    cout << "Failed to load the ONNX file!" << endl;
    return -1;
  }

  //read the image, then grayscale and Gaussian blur
  Mat src = imread(testfile);
  //keep a copy of the source image
  Mat backsrc;
  src.copyTo(backsrc);
  cvtColor(src, src, COLOR_BGR2GRAY);
  GaussianBlur(src, src, Size(3, 3), 0.5, 0.5);
  //binarize; THRESH_BINARY_INV gives white digits on black, matching MNIST
  threshold(src, src, 0, 255, THRESH_BINARY_INV | THRESH_OTSU);

  //dilate so strokes of a handwritten digit that are not quite connected
  //do not split into separate contours; three passes here
  Mat kernel = getStructuringElement(MORPH_RECT, Size(5, 5));
  morphologyEx(src, src, MORPH_DILATE, kernel, Point(-1, -1), 3);
  imshow("src", src);

  vector<vector<Point>> contours;
  vector<Vec4i> hierarchy;
  vector<Rect> rects;

  //find the contours
  findContours(src, contours, hierarchy, RETR_EXTERNAL, CHAIN_APPROX_NONE);
  for (int i = 0; i < contours.size(); ++i) {
    RotatedRect rect = minAreaRect(contours[i]);
    Rect outrect = rect.boundingRect();
    //collect the bounding rectangles
    rects.push_back(outrect);
  }

  //sort left to right, top to bottom
  SortRect(rects);
  //run inference on each sorted crop
  for (int i = 0; i < rects.size(); ++i) {
    Mat tmpsrc = src(rects[i]);
    DealInputMat(tmpsrc);

    //Mat inputBlob = dnn::blobFromImage(tmpsrc, 0.3081, Size(28, 28), Scalar(0.1307), false, false);
    Mat inputBlob = dnn::blobFromImage(tmpsrc, 1, Size(28, 28), Scalar(), false, false);

    //set the input
    net.setInput(inputBlob, "input");
    //forward pass
    Mat output = net.forward("output");

    //find the class with the highest score
    Point maxLoc;
    minMaxLoc(output, NULL, NULL, NULL, &maxLoc);

    cout << "predicted: " << maxLoc.x << endl;

    //draw the crop rectangle and the recognized digit
    rectangle(backsrc, rects[i], Scalar(255, 0, 255));
    putText(backsrc, to_string(maxLoc.x), Point(rects[i].x, rects[i].y), FONT_HERSHEY_PLAIN, 5, Scalar(255, 0, 255), 1);
  }

  imshow("backsrc", backsrc);

  waitKey(0);
  return 0;
}

Key Points

01

Use THRESH_BINARY_INV when binarizing

The MNIST training images are all white digits on a black background, so use THRESH_BINARY_INV during binarization to turn the input into white-on-black directly.


02

Morphological dilation

Dilation is used mainly to keep the strokes of a handwritten digit connected; otherwise a single digit can produce two contours during contour finding.


A 5×5 kernel is used here with three dilation passes. Comparing the result with and without dilation:

(With dilation: each digit is found as a single contour.)

(Without dilation: one extra contour is detected.)

03

Contour sorting

Using the contours in the order findContours returns them is fine for drawing the recognized digits on the picture, but the output order is wrong. For the three digits 5, 6, 7 in the test image, the raw contour order comes out as 7, 5, 6: I clearly wrote 567, yet the program prints 756. So the detected contours must be sorted into reading order, left to right and top to bottom; that is what the SortRect function in the source above does.
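SortRect does this with a pairwise pass. As a rough equivalent, here is a sketch of my own (not from the original post) that first sorts top to bottom, then splits the boxes into rows and sorts each row left to right:

#include <algorithm>
#include <vector>
#include <opencv2/core.hpp>

//Sketch: reading-order sort. Assumes rows of digits are vertically separated.
void SortRectAlt(std::vector<cv::Rect>& rects) {
  //pass 1: top to bottom
  std::sort(rects.begin(), rects.end(),
    [](const cv::Rect& a, const cv::Rect& b) { return a.y < b.y; });
  //pass 2: walk the sorted list; a new row starts when a box begins
  //below the bottom edge of the current row's first box
  size_t rowStart = 0;
  for (size_t i = 1; i <= rects.size(); ++i) {
    if (i == rects.size() || rects[i].y > rects[rowStart].y + rects[rowStart].height) {
      std::sort(rects.begin() + rowStart, rects.begin() + i,
        [](const cv::Rect& a, const cv::Rect& b) { return a.x < b.x; });
      rowStart = i;
    }
  }
}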

04

Scaling the crops to 28×28

(Contour crops: the box around the digit 1 is much taller than it is wide.)

As the crop of the digit 1 shows, a contour region can be tall and narrow; scaling it straight to 28×28 would distort the aspect ratio, so the extracted crop has to be processed first.

Compare the width and height and pad the short side. For the digit 1 above, the width is far smaller than the height, so take the difference and pad half of it on each side (halving keeps the digit centered); the crop becomes nearly square and no longer distorts when scaled. The padding uses copyMakeBorder.

To keep the digit from sticking to the border after scaling, a small extra margin (tmppadding, 5 pixels in DealInputMat) is also added all around, filled with black, before the final resize. For example, a 30×120 crop of a 1 would get (120-30)/2 + 5 = 50 pixels of black on the left and right and 5 on the top and bottom, giving a 130×130 square. The effect is roughly as follows:

(Before processing: the raw contour crop. After processing: the black-padded, nearly square image, ready to scale to 28×28.)

05

OpenCV DNN inference

During inference, blobFromImage first converts the preprocessed crop into a blob, the network runs a forward pass, and minMaxLoc finds the position of the maximum score in the returned 1×10 output; that position is the predicted digit.

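One caveat worth noting (my observation, not a claim from the original post): the listing builds the blob with scale 1 and no mean subtraction, while training normalized with mean 0.1307 and std 0.3081, as the commented-out blobFromImage line in the source hints. A sketch of a blob that matches the training normalization, assuming the crop is 0-255 grayscale (blobFromImage computes (pixel - mean) * scalefactor):

//Sketch (an assumption on my part, not from the original post): build the
//blob to match the training-time Normalize(mean=0.1307, std=0.3081).
//The crop here is 0-255 grayscale, so rescale both constants to 0-255 units.
Mat inputBlob = dnn::blobFromImage(tmpsrc,
  1.0 / (0.3081 * 255.0),       //scalefactor = 1/std
  Size(28, 28),
  Scalar(0.1307 * 255.0),       //mean
  false, false);                //no channel swap, no crop

A mismatch between training and inference preprocessing is a common cause of confusions like the 7-recognized-as-1 above, so this is worth experimenting with.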

With the steps above, C++ OpenCV completes the handwritten digit recognition. When this series is finished, the source code will be published on GitHub.

End


Review of past articles

Getting started with pyTorch (3) - GoogleNet and ResNet training

Getting Started with pyTorch (2) - Common Network Layer Functions and Convolutional Neural Network Training

Getting started with pyTorch (1) - MNIST handwritten digit recognition training with a fully connected network

Origin: blog.csdn.net/Vaccae/article/details/128379461