aprender mejor de los demás,
ser el mejor
—— "WeikaZhixiang"
La extensión de este artículo es de 2548 palabras , y se espera una lectura de 8 minutos.
prefacio
Los primeros tres capítulos introducen el entrenamiento de pyTorch, y hemos guardado con éxito el modelo.El artículo de hoy trata sobre el uso del módulo DNN de C++ OpenCV para razonar sobre imágenes escritas a mano.
lograr efecto
El modelo de inferencia derivado usa el modelo ResNet con una tasa de predicción de entrenamiento del 99 % en Minist. De las dos imágenes anteriores, la mayoría del reconocimiento de dígitos no es un problema, pero el número 7 en las dos imágenes se reconoce como el número 1. Este no es el problema a resolver en este artículo por el momento, veamos cómo implementar el modelo derivado y el razonamiento.
Micro tarjeta Zhixiang
modelo de exportación
Como no quiero escribir un nuevo modelo de red, cambié el conjunto de entrenamiento cargado y el conjunto de prueba, el modelo de red, etc. en train.py a trainmodel.py. Luego, cree un nuevo archivo traintoonnx.py para exportar archivos de modelo ONNX. A continuación, coloque el código fuente directamente y hable sobre los puntos clave.
tren.py
import torch
import time
import torch.optim as optim
import matplotlib.pyplot as plt
from pylab import mpl
import trainModel as tm
##训练轮数
epoch_times = 10
##设置初始预测率,用于判断高于当前预测率的保存模型
toppredicted = 0.0
##设置学习率
learnrate = 0.01
##设置动量值,如果上一次的momentnum与本次梯度方向是相同的,梯度下降幅度会拉大,起到加速迭代的作用
momentnum = 0.5
##生成图用的数组
##预测值
predict_list = []
##训练轮次值
epoch_list = []
##loss值
loss_list = []
model = tm.Net(tm.train_name)
##加入判断是CPU训练还是GPU训练
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
##优化器
optimizer = optim.SGD(model.parameters(), lr= learnrate, momentum= momentnum)
# optimizer = optim.NAdam(model.parameters(), lr= learnrate)
##训练函数
def train(epoch):
running_loss = 0.0
current_train = 0.0
model.train()
for batch_idx, data in enumerate(tm.train_dataloader, 0):
inputs, target = data
##加入CPU和GPU选择
inputs, target = inputs.to(device), target.to(device)
optimizer.zero_grad()
#前馈,反向传播,更新
outputs = model(inputs)
loss = model.criterion(outputs, target)
loss.backward()
optimizer.step()
running_loss += loss.item()
##计算每300次打印一次学习效果
if batch_idx % 300 == 299:
current_train = current_train + 0.3
current_epoch = epoch + 1 + current_train
epoch_list.append(current_epoch)
current_loss = running_loss / 300
loss_list.append(current_loss)
print('[%d, %5d] loss: %.3f' % (current_epoch, batch_idx + 1, current_loss))
running_loss = 0.0
def test():
correct = 0
total = 0
model.eval()
##with这里标记是不再计算梯度
with torch.no_grad():
for data in tm.test_dataloader:
inputs, labels = data
##加入CPU和GPU选择
inputs, labels = inputs.to(device), labels.to(device)
outputs = model(inputs)
##预测返回的是两列,第一列是下标就是0-9的值,第二列为预测值,下面的dim=1就是找维度1(第二列)最大值输出
_, predicted = torch.max(outputs.data, dim=1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
currentpredicted = (100 * correct / total)
##用global声明toppredicted,用于在函数内部修改在函数外部声明的全局变量,否则报错
global toppredicted
##当预测率大于原来的保存模型
if currentpredicted > toppredicted:
toppredicted = currentpredicted
torch.save(model.state_dict(), tm.savemodel_name)
print(tm.savemodel_name+" saved, currentpredicted:%d %%" % currentpredicted)
predict_list.append(currentpredicted)
print('Accuracy on test set: %d %%' % currentpredicted)
##开始训练
timestart = time.time()
for epoch in range(epoch_times):
train(epoch)
test()
timeend = time.time() - timestart
print("use time: {:.0f}m {:.0f}s".format(timeend // 60, timeend % 60))
##设置画布显示中文字体
mpl.rcParams["font.sans-serif"] = ["SimHei"]
##设置正常显示符号
mpl.rcParams["axes.unicode_minus"] = False
##创建画布
fig, (axloss, axpredict) = plt.subplots(nrows=1, ncols=2, figsize=(8,6))
#loss画布
axloss.plot(epoch_list, loss_list, label = 'loss', color='r')
##设置刻度
axloss.set_xticks(range(epoch_times)[::1])
axloss.set_xticklabels(range(epoch_times)[::1])
axloss.set_xlabel('训练轮数')
axloss.set_ylabel('数值')
axloss.set_title(tm.train_name+' 损失值')
#添加图例
axloss.legend(loc = 0)
#predict画布
axpredict.plot(range(epoch_times), predict_list, label = 'predict', color='g')
##设置刻度
axpredict.set_xticks(range(epoch_times)[::1])
axpredict.set_xticklabels(range(epoch_times)[::1])
# axpredict.set_yticks(range(100)[::5])
# axpredict.set_yticklabels(range(100)[::5])
axpredict.set_xlabel('训练轮数')
axpredict.set_ylabel('预测值')
axpredict.set_title(tm.train_name+' 预测值')
#添加图例
axpredict.legend(loc = 0)
#显示图像
plt.show()
trenmodelo.py
import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
from ModelLinearNet import LinearNet
from ModelConv2d import Conv2dNet
from ModelGoogleNet import GoogleNet
from ModelResNet import ResNet
batch_size = 64
##设置本次要训练用的模型
train_name = 'ResNet'
print("train_name:" + train_name)
##设置模型保存名称
savemodel_name = train_name + ".pt"
print("savemodel_name:" + savemodel_name)
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=(0.1307,), std=(0.3081,))
]) ##Normalize 里面两个值0.1307是均值mean, 0.3081是标准差std,计算好的直接用了
##训练数据集位置,如果不存在直接下载
train_dataset = datasets.MNIST(
root = '../datasets/mnist',
train = True,
download = True,
transform = transform
)
##读取训练数据集
train_dataloader = DataLoader(
dataset= train_dataset,
shuffle=True,
batch_size=batch_size
)
##测试数据集位置,如果不存在直接下载
test_dataset = datasets.MNIST(
root= '../datasets/mnist',
train= False,
download=True,
transform= transform
)
##读取测试数据集
test_dataloader = DataLoader(
dataset= test_dataset,
shuffle= True,
batch_size=batch_size
)
##设置选择训练模型,因为python用的是3.9,用不了match case语法
def switch(train_name):
if train_name == 'LinearNet':
return LinearNet()
elif train_name == 'Conv2dNet':
return Conv2dNet()
elif train_name == 'GoogleNet':
return GoogleNet()
elif train_name == 'ResNet':
return ResNet()
##定义训练模型
class Net(torch.nn.Module):
def __init__(self, train_name):
super(Net, self).__init__()
self.model = switch(train_name= train_name)
self.criterion = self.model.criterion
def forward(self, x):
x = self.model(x)
return x
trentoonnx.py
import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import trainModel as tm
##获取输入参数
data = iter(tm.test_dataloader)
dummy_inputs, labels = next(data)
print(dummy_inputs.shape)
##加载模型
model = tm.Net(tm.train_name)
model.load_state_dict(torch.load(tm.savemodel_name))
print(model)
##加载的模型测试效果
outputs = model(dummy_inputs)
print(outputs)
##预测返回的是两列,第一列是下标就是0-9的值,第二列为预测值,下面的dim=1就是找维度1(第二列)最大值输出
_, predicted = torch.max(outputs.data, dim=1)
print(_)
print(predicted)
outlabels = predicted.numpy().tolist()
print(outlabels)
##定义输出输出的参数名
input_name = ["input"]
output_name = ["output"]
onnx_name = tm.train_name + '.onnx'
torch.onnx.export(
model,
dummy_inputs,
onnx_name,
verbose=True,
input_names=input_name,
output_names=output_name
)
enfocar
01
Exportar después de cargar el modelo
Exporte el modelo Onnx, como se menciona en " Super Simple pyTorch Training->onnx Model->C++ OpenCV DNN Reasoning (with source code address) ", se exporta directamente después del entrenamiento, mientras que en traintoonnx.py se guarda el entrenamiento anterior. el modelo, aquí cargamos directamente el modelo para leer, y luego lo exportamos.
02
Al exportar el modelo ONNX y usar la inferencia de OpenCV, no se puede usar x.view
Esto es más crítico En nuestro modelo de entrenamiento original, x.view se usó en la propagación hacia adelante, como se muestra en la figura a continuación.
Se informó un error directamente al exportar ONNX para razonar en OpenCV, por lo que aquí debemos cambiarlo a x = x.flatten (1)
Micro tarjeta Zhixiang
Inferencia C++ OpenCV
Cuando se usa OpenCV DNN para la inferencia, no es tan simple como en " Super Simple pyTorch Training->onnx Model->C++ OpenCV DNN Reasoning (with source code address) ", porque es un reconocimiento de dígitos escrito a mano y la imagen durante el entrenamiento Minist Todas son muestras de 1X28X28, por lo que la imagen debe procesarse previamente antes de la inferencia. La siguiente es la idea de implementación.
# | tren de pensamiento |
---|---|
1 | Lea la imagen, procese en escala de grises, desenfoque gaussiano, binarización |
2 | Operaciones morfológicas, usando dilatación (evita encontrar contornos problemáticos) |
3 | Búsqueda de contornos, ordenar imágenes de captura de pantalla según el orden |
4 | La imagen ordenada se procesa y escala a (28X28) |
5 | Use el DNN para pasar la imagen procesada para la inferencia |
Código fuente de razonamiento C++
#pragma once
#include<iostream>
#include<opencv2/opencv.hpp>
#include<opencv2/dnn/dnn.hpp>
using namespace cv;
using namespace std;
dnn::Net net;
//排序矩形
void SortRect(vector<Rect>& inputrects) {
for (int i = 0; i < inputrects.size(); ++i) {
for (int j = i; j < inputrects.size(); ++j) {
//说明顺序在上方,这里不用变
if (inputrects[i].y + inputrects[i].height < inputrects[i].y) {
}
//同一排
else if (inputrects[i].y <= inputrects[j].y + inputrects[j].height) {
if (inputrects[i].x > inputrects[j].x) {
swap(inputrects[i], inputrects[j]);
}
}
//下一排
else if (inputrects[i].y > inputrects[j].y + inputrects[j].height) {
swap(inputrects[i], inputrects[j]);
}
}
}
}
//处理DNN检测的MINIST图像,防止长方形图像直接转为28*28扁了
void DealInputMat(Mat& src, int row = 28, int col = 28, int tmppadding=5) {
int w = src.cols;
int h = src.rows;
//看图像的宽高对比,进行处理,先用padding填充黑色,保证图像接近正方形,这样缩放28*28比例不会失衡
if (w > h) {
int tmptopbottompadding = (w-h) / 2 + tmppadding;
copyMakeBorder(src, src, tmptopbottompadding, tmptopbottompadding, tmppadding, tmppadding,
BORDER_CONSTANT, Scalar(0));
}
else {
int tmpleftrightpadding = (h-w) / 2+ tmppadding;
copyMakeBorder(src, src, tmppadding, tmppadding, tmpleftrightpadding, tmpleftrightpadding,
BORDER_CONSTANT, Scalar(0));
}
resize(src, src, Size(row, col));
}
int main(int argc, char** argv) {
//定义onnx文件
string onnxfile = "D:/Business/DemoTEST/CPP/OpenCVMinistDNN/torchminist/ResNet.onnx";
//测试图片文件
string testfile = "D:/Business/DemoTEST/CPP/OpenCVMinistDNN/test5.png";
net = dnn::readNetFromONNX(onnxfile);
if (net.empty()) {
cout << "加载Onnx文件失败!" << endl;
return -1;
}
//读取图片,灰度,高斯模糊
Mat src = imread(testfile);
//备份源图
Mat backsrc;
src.copyTo(backsrc);
cvtColor(src, src, COLOR_BGR2GRAY);
GaussianBlur(src, src, Size(3, 3), 0.5, 0.5);
//二值化图片,注意用THRESH_BINARY_INV改为黑底白字,对应MINIST
threshold(src, src, 0, 255, THRESH_BINARY_INV | THRESH_OTSU);
//做彭账处理,防止手写的数字没有连起来,这里做了3次膨胀处理
Mat kernel = getStructuringElement(MORPH_RECT, Size(5, 5));
morphologyEx(src, src, MORPH_DILATE, kernel, Point(-1,-1), 3);
imshow("src", src);
vector<vector<Point>> contours;
vector<Vec4i> hierarchy;
vector<Rect> rects;
//查找轮廓
findContours(src, contours, hierarchy, RETR_EXTERNAL, CHAIN_APPROX_NONE);
for (int i = 0; i < contours.size(); ++i) {
RotatedRect rect = minAreaRect(contours[i]);
Rect outrect = rect.boundingRect();
//插入到矩形列表中
rects.push_back(outrect);
}
//按从左到右,从上到下排序
SortRect(rects);
//要输出的图像参数
for (int i = 0; i < rects.size(); ++i) {
Mat tmpsrc = src(rects[i]);
DealInputMat(tmpsrc);
//Mat inputBlob = dnn::blobFromImage(tmpsrc, 0.3081, Size(28, 28), Scalar(0.1307), false, false);
Mat inputBlob = dnn::blobFromImage(tmpsrc, 1, Size(28, 28), Scalar(), false, false);
//输入参数值
net.setInput(inputBlob, "input");
//预测结果
Mat output = net.forward("output");
//查找出结果中推理的最大值
Point maxLoc;
minMaxLoc(output, NULL, NULL, NULL, &maxLoc);
cout << "预测值:" << maxLoc.x << endl;
//画出截取图像位置,并显示识别的数字
rectangle(backsrc, rects[i], Scalar(255, 0, 255));
putText(backsrc, to_string(maxLoc.x), Point(rects[i].x, rects[i].y), FONT_HERSHEY_PLAIN, 5, Scalar(255, 0, 255), 1, -1);
}
imshow("backsrc", backsrc);
waitKey(0);
return 0;
}
enfocar
01
Utilice THRESH_BINARY_INV al binarizar
Todas las imágenes en el conjunto de entrenamiento Minist usan caracteres blancos sobre un fondo negro, por lo que debe usar THRESH_BINARY_INV para cambiarlas directamente a caracteres blancos sobre un fondo negro al binarizar.
02
Expansión de operaciones morfológicas
El uso de la expansión es principalmente para evitar que los números escritos a mano se desconecten, lo que da como resultado dos contornos cuando se busca el contorno.
Aquí, se usa una convolución 5X5, que se expande tres veces, y se usa la comparación entre la expansión y la no utilizada:
usar dilatación
No se utilizó dilatación, se identificó un contorno más
03
Clasificación de contorno
Si usa directamente la salida de contorno detectada, no hay problema para mostrar los números reconocidos en la imagen, pero habrá problemas con el orden de salida, como la imagen de arriba, los tres números 5, 6, 7, si encuentra directamente el contorno, presione Si el número de serie de los contornos está ordenado, el orden es 7, 5, 6
Si el texto sale en orden, obviamente escribí 567 a mano, pero el resultado es 756 si ingreso lo mismo, habrá problemas, así que aquí necesitamos ordenar los contornos encontrados, y el método de clasificación es de izquierda a derecha , ordena de arriba a abajo.
El método de clasificación de contorno
04
Escala la imagen a 28X28
Al igual que la imagen de arriba, especialmente el contorno buscado por el número 1, si se escala directamente a 28X28, la relación de la imagen estará desequilibrada, por lo que aquí es necesario procesar primero la imagen del contorno extraído.
A juzgar por el ancho y la altura, compense la diferencia. Por ejemplo, el número 1 en la imagen de arriba, el ancho es mucho peor que la altura, luego restamos el ancho de la altura actual y luego lo dividimos por 2 (dividir 2 es para llenar los lados izquierdo y derecho de manera uniforme), entonces que la relación esté cerca de un cuadrado, cuando la escala no esté desequilibrada. Rellene la función copyMakeBorder utilizada.
Para evitar que el número se adjunte directamente al borde después de escalar, llenamos un umbral alrededor del contorno extraído, lo llenamos todo con negro y finalmente lo escalamos. El efecto es más o menos el siguiente:
imagen de extracción de contorno
antes de procesar
imagen llena
después del tratamiento
05
Inferencia de DNN de OpenCV
Durante la inferencia, primero use blobFromImage para preprocesar la imagen y luego use DNN para la inferencia. El resultado final devuelto debe extraerse al valor máximo a través de minMaxLoc para juzgar el número de inferencia.
Después de los pasos anteriores, C++ OpenCV puede completar el reconocimiento de dígitos escritos a mano. Cuando se complete esta serie, el código fuente se colocará en GitHub.
encima
Maravillosa revisión del pasado.
Primeros pasos con pyTorch (3) - Capacitación de GoogleNet y ResNet