ann_train.h源码定义一个 AnnTrain类,该类继承自ITrain类(在train.h文件中):
#include "easypr/train/train.h"
#include "easypr/util/kv.h"
#include <memory>
namespace easypr {
// Trainer for the ANN (multi-layer perceptron) character classifier.
// Inherits the ITrain interface declared in train.h.
class AnnTrain : public ITrain {
public:
// chars_folder: directory holding training samples; xml: path the trained
// model is saved to. See note 1 at the end of this document.
explicit AnnTrain(const char* chars_folder, const char* xml);
virtual void train();// trains the network; see note 2 at the end
virtual void test();// evaluates the model by calling identify / identifyChinese
std::pair<std::string, std::string> identifyChinese(cv::Mat input);// recognizes a single Chinese character image
std::pair<std::string, std::string> identify(cv::Mat input);// predicts a single character image; see note 4 at the end
private:
virtual cv::Ptr<cv::ml::TrainData> tdata();// builds TrainData from all samples
cv::Ptr<cv::ml::TrainData> sdata(size_t number_for_count = 100);// builds TrainData from a sampled subset
cv::Ptr<cv::ml::ANN_MLP> ann_;// the multi-layer perceptron model
const char* ann_xml_;// output path for the saved model
const char* chars_folder_;// training-sample directory
std::shared_ptr<Kv> kv_;// key/value store mapping province keys to province names
int type;// 0 = all characters, 1 = Chinese only (see constructor)
};
}
文末
1、
// Constructs the trainer: records the sample folder and output xml path,
// creates the underlying MLP and loads the province-name mapping.
// All members are initialized in the init list, in declaration order
// (members are initialized in declaration order regardless of list order,
// so keeping them aligned avoids -Wreorder surprises).
AnnTrain::AnnTrain(const char* chars_folder, const char* xml)
    : ann_(cv::ml::ANN_MLP::create()),  // create the multi-layer perceptron object
      ann_xml_(xml),
      chars_folder_(chars_folder),
      kv_(std::make_shared<Kv>()),  // make_shared: one allocation, exception-safe
      // type=0, all characters
      // type=1, only chinese
      type(0) {
  kv_->load("resources/text/province_mapping");
}
AnnTrain::AnnTrain(const char* chars_folder, const char* xml) : chars_folder_(chars_folder), ann_xml_(xml)
通过使用构造函数初始化列表,显式地初始化类的成员。
2、可以看到下列代码中有一些参数如:kCharsTotalNumber、kAnnInput、kNeurons等并没有在该头文件中定义,其实这些变量统一定义在config.h头文件中,让我们来详细了解一下config.h文件中都定义了哪些变量,参见文末3
// Configures the network topology, trains it on the samples found in
// chars_folder_, and saves the resulting model to ann_xml_.
// type==0 trains on all 65 characters, type==1 on the 31 Chinese ones.
void AnnTrain::train() {
  // Check for training data first, so we do not build a network for nothing.
  auto files = Utils::getFiles(chars_folder_);  // returns vector<std::string>
  if (files.size() == 0) {
    fprintf(stdout, "No file found in the train folder!\n");
    fprintf(stdout, "You should create a folder named \"tmp\" in EasyPR main folder.\n");
    fprintf(stdout, "Copy train data folder(like \"ann\") under \"tmp\". \n");
    return;
  }

  int classNumber = 0;
  cv::Mat layers;
  int input_number = 0;
  int hidden_number = 0;
  int output_number = 0;
  if (type == 0) {  // type == 0: recognize all characters (digits, letters, Chinese)
    classNumber = kCharsTotalNumber;  // 65 characters in total
    input_number = kAnnInput;         // 120 input neurons: one per feature value of a sample
    hidden_number = kNeurons;         // 40 hidden-layer neurons
    output_number = classNumber;      // one output neuron per class, i.e. 65
  }
  else if (type == 1) {  // Chinese characters only
    classNumber = kChineseNumber;     // 31 Chinese characters
    input_number = kAnnInput;
    hidden_number = kNeurons;
    output_number = classNumber;
  }

  // Heuristic layer sizes for the optional two-hidden-layer topology below.
  int N = input_number;
  int m = output_number;
  int first_hidden_neurons = int(std::sqrt((m + 2) * N) + 2 * std::sqrt(N / (m + 2)));
  int second_hidden_neurons = int(m * std::sqrt(N / (m + 2)));

  bool useTLFN = false;  // a 4-layer network is hard to train, so use the 3-layer scheme
  if (!useTLFN) {  // 3-layer network: input, one hidden layer, output
    layers.create(1, 3, CV_32SC1);  // 1x3 matrix holding the neuron count of each layer
    layers.at<int>(0) = input_number;
    layers.at<int>(1) = hidden_number;
    layers.at<int>(2) = output_number;
  }
  else {  // 4-layer network with two hidden layers
    fprintf(stdout, ">> Use two-layers neural networks,\n");
    fprintf(stdout, ">> First_hidden_neurons: %d \n", first_hidden_neurons);
    fprintf(stdout, ">> Second_hidden_neurons: %d \n", second_hidden_neurons);
    layers.create(1, 4, CV_32SC1);
    layers.at<int>(0) = input_number;
    layers.at<int>(1) = first_hidden_neurons;
    layers.at<int>(2) = second_hidden_neurons;
    layers.at<int>(3) = output_number;
  }

  ann_->setLayerSizes(layers);  // neuron count per layer
  ann_->setActivationFunction(cv::ml::ANN_MLP::SIGMOID_SYM, 1, 1);  // symmetric sigmoid activation
  ann_->setTrainMethod(cv::ml::ANN_MLP::TrainingMethods::BACKPROP);  // back-propagation training
  // Stop after 30000 iterations. Use cv::TermCriteria instead of the legacy
  // cvTermCriteria/CV_TERMCRIT_ITER C API (deprecated in 3.x, removed in 4.x).
  ann_->setTermCriteria(cv::TermCriteria(cv::TermCriteria::MAX_ITER, 30000, 0.0001));
  ann_->setBackpropWeightScale(0.1);    // strength of the weight-gradient term, default 0.1
  ann_->setBackpropMomentumScale(0.1);  // momentum strength, 0 disables it; default 0.1

  // sdata() wraps up to 350 samples per class into
  // cv::ml::TrainData::create(samples_, ROW_SAMPLE, train_classes).
  auto traindata = sdata(350);
  std::cout << "Training ANN model, please wait..." << std::endl;
  long start = utils::getTimestamp();
  ann_->train(traindata);
  long end = utils::getTimestamp();
  ann_->save(ann_xml_);

  test();
  std::cout << "Your ANN Model was saved to " << ann_xml_ << std::endl;
  // Floating-point division, so sub-minute runs do not print "0 minute".
  std::cout << "Training done. Time elapse: " << (end - start) / (1000.0 * 60.0) << "minute" << std::endl;
}
3、
//#define CV_VERSION_THREE_ZERO
#define CV_VERSION_THREE_TWO // set to match the installed OpenCV version (3.2); the project does not compile/run without the correct one
namespace easypr {
enum Color { BLUE, YELLOW, WHITE, UNKNOWN };// plate colors: blue, yellow, white, unknown
// Plate locating methods: Sobel edge detection, color detection, MSER (maximally stable extremal regions), other
enum LocateType { SOBEL, COLOR, CMSER, OTHER };
enum CharSearchDirection { LEFT, RIGHT };// character search direction: left-to-right or right-to-left
enum// the environment the plate image was captured in
{
PR_MODE_UNCONSTRAINED, // unconstrained environment
PR_MODE_CAMERPOHNE, // camera / mobile phone (NOTE: "CAMERPOHNE" looks like a typo of CAMERPHONE, kept as-is for compatibility)
PR_MODE_PARKING, // parking lot
PR_MODE_HIGHWAY // highway
};
enum
{
PR_DETECT_SOBEL = 0x01, /**Sobel detect type, using twice Sobel */
PR_DETECT_COLOR = 0x02, /**color-based plate detection */
PR_DETECT_CMSER = 0x04, /**MSER-based plate detection */
};
// Paths where the trained models are stored
static const char* kDefaultSvmPath = "model/svm_hist.xml";
static const char* kLBPSvmPath = "model/svm_lbp.xml";
static const char* kHistSvmPath = "model/svm_hist.xml";
static const char* kDefaultAnnPath = "model/ann.xml";
static const char* kChineseAnnPath = "model/ann_chinese.xml";
static const char* kGrayAnnPath = "model/annCh.xml";
//This is important to for key transform to chinese: maps a province key (e.g. "zh_wan") to the province name
static const char* kChineseMappingPath = "model/province_mapping";
typedef enum {
kForward = 1, // correspond to "has plate"
kInverse = 0 // correspond to "no plate"
} SvmLabel;
static const int kPlateResizeWidth = 136;
static const int kPlateResizeHeight = 36;
static const int kShowWindowWidth = 1000;
static const int kShowWindowHeight = 800;
static const float kSvmPercentage = 0.7f;
static const int kCharacterInput = 120;// number of classifier inputs, i.e. feature values describing one sample
static const int kChineseInput = 440;// a Chinese character is represented by 440 features
static const int kAnnInput = kCharacterInput;
static const int kCharacterSize = 10;
static const int kChineseSize = 20;
static const int kPredictSize = kCharacterSize;
static const int kNeurons = 40;// hidden-layer neuron count
static const char *kChars[] = {// array of every recognizable character
"0", "1", "2","3", "4", "5","6", "7", "8", "9", /* 10 digit characters */
"A", "B", "C","D", "E", "F", "G", "H", "J", "K", "L", "M", "N","P", "Q", "R", "S", "T", "U", "V", "W", "X","Y", "Z",
/* 24 English letters; "I" and "O" are excluded */
"zh_cuan" , "zh_e" , "zh_gan" , "zh_gan1" , "zh_gui" , "zh_gui1" , "zh_hei" , "zh_hu" , "zh_ji" , "zh_jin" , "zh_jing" , "zh_jl" , "zh_liao" , "zh_lu" , "zh_meng" ,"zh_min" , "zh_ning" , "zh_qing" , "zh_qiong", "zh_shan" , "zh_su" ,"zh_sx" , "zh_wan" , "zh_xiang", "zh_xin" , "zh_yu" , "zh_yu1" , "zh_yue" , "zh_yun" , "zh_zang" , "zh_zhe"};
/* 31 province-abbreviation Chinese characters; a key such as "zh_wan" is mapped to the province name (e.g. Anhui) via province_mapping */
static const int kCharactersNumber = 34; // non-Chinese character count
static const int kChineseNumber = 31; // Chinese character count
static const int kCharsTotalNumber = 65; // total character count (Chinese + non-Chinese)
static bool kDebug = false;
static const int kGrayCharWidth = 20;
static const int kGrayCharHeight = 32;
static const int kCharLBPGridX = 4;
static const int kCharLBPGridY = 4;
static const int kCharLBPPatterns = 16;
static const int kCharHiddenNeurans = 64;
static const int kCharsCountInOnePlate = 7;
static const int kSymbolsCountInChinesePlate = 6;
static const float kPlateMaxSymbolCount = 7.5f;
static const int kSymbolIndex = 2;
// Disable the copy and assignment operator for this class.
#define DISABLE_ASSIGN_AND_COPY(className) \
private:\
className& operator=(const className&); \
className(const className&)
// Display the image.
#define SET_DEBUG(param) \
kDebug = param
// Display the image.
#define SHOW_IMAGE(imgName, debug) \
if (debug) { \
namedWindow("imgName", WINDOW_AUTOSIZE); \
moveWindow("imgName", 500, 500); \
imshow("imgName", imgName); \
waitKey(0); \
destroyWindow("imgName"); \
}
// Load a model; compatible with OpenCV 3.0, 3.1 and 3.2
#ifdef CV_VERSION_THREE_TWO
#define LOAD_SVM_MODEL(model, path) \
model = ml::SVM::load(path);
#define LOAD_ANN_MODEL(model, path) \
model = ml::ANN_MLP::load(path);
#else
#define LOAD_SVM_MODEL(model, path) \
model = ml::SVM::load<ml::SVM>(path);
#define LOAD_ANN_MODEL(model, path) \
model = ml::ANN_MLP::load<ml::ANN_MLP>(path);
#endif
}
4、
// Predicts a single character image. Returns a pair (character key,
// display string): for digits/letters both entries are the character
// itself; for Chinese characters the second entry is the province name
// looked up through kv_.
std::pair<std::string, std::string> AnnTrain::identify(cv::Mat input) {
  cv::Mat feature = charFeatures2(input, kPredictSize);  // feature vector of the character image
  cv::Mat output(1, kCharsTotalNumber, CV_32FC1);        // one score per possible character
  // predict() fills `output` with kCharsTotalNumber scores, one per class;
  // the class with the highest score is the predicted character.
  ann_->predict(feature, output);

  // Arg-max over the scores: cv::minMaxLoc replaces the hand-written loop
  // (and avoids the fragile "-2" sentinel the loop started from).
  cv::Point maxLoc;
  cv::minMaxLoc(output, nullptr, nullptr, nullptr, &maxLoc);
  auto index = maxLoc.x;  // column of the best score == index into kChars

  if (index < kCharactersNumber) {
    // Digit or English letter: key and display string are identical.
    return std::make_pair(kChars[index], kChars[index]);
  }
  else {
    // Chinese character: map the key (e.g. "zh_wan") to the province name.
    const char* key = kChars[index];
    std::string s = key;
    std::string province = kv_->get(s);
    return std::make_pair(s, province);
  }
}