opencv SVM 使用

SVM是一种分类器,下面通过手写0-9数字识别对其进行以下介绍。
1.首先准备训练使用的手写字体
这里写图片描述
这里写图片描述
如图所示,将手写字体分类放在不同的文件夹。
2.读取图片

//每种数字个数
const int count[10] = {5923,6742,5958,6131,5842,5421,5918,6265,5851,5949};
    string filename = "shouxieziti/";
    vector<Mat> imgin;
    vector<int> number;
    int sum = 0;
    for(int i = 0; i < 10; i++){
        string s;
        stringstream ss;
        ss<<i;
        ss>>s;
        for(int j = 1; j < count[i]+1; j++){
            string s1;
            stringstream ss1;
            ss1<<j;
            ss1>>s1;
            if(j<10){
                s1 = s+"_0000"+s1;
            }else if(j < 100){
                s1 = s+"_000"+s1;
            }else if(j < 1000){
                s1 = s+"_00"+s1;
            }else{
                s1 = s+"_0"+s1;
            }
            string in = filename + s + "/" + s1 +".jpg";
            Mat img = imread(in,IMREAD_GRAYSCALE);
//            imshow(in,img);
            imgin.push_back(img);
            number.push_back(i);
            cout<<in<<"  ok"<<" "<<img.channels()<<" "<<number[sum + j - 1]<<" "<<sum<<endl;

        }
        sum += count[i];
    }
    cout<<imgin.size()<<" "<<imgin[0].size()<<"have been read"<<endl;

图片信息的读取由自己的存储方式进行。
3.生成opencv中SVM需要的形式

    Mat imgtrain((int)imgin.size(), 28*28, CV_32FC1);
    Mat imglabel((int)imgin.size(), 1, CV_32SC1);
//    cout<<imgtrain.channels()<<" "<<imglabel.channels()<<endl;
    cout<<"creat train data..."<<endl;
    for(int i = 0; i < (int)imgin.size(); i++){
        Mat_<float>::iterator trainbegin = imgtrain.begin<float>() + 28*28*i;
        Mat_<int>::iterator labelbegin = imglabel.begin<int>();
        Mat_<uchar>::iterator inbegin = imgin[i].begin<uchar>();
        for(int j = 0; j < 28*28; j++){
            float data = (float)*(inbegin+j);
            *(trainbegin+j) = (data+0.0)/255.0;
//            if(data > 200){
//                cout<<*(trainbegin+j)<<" "<<*(labelbegin+j);
//            }
        }
        *(labelbegin+i) = number[i];
        cout<<*(labelbegin+i)<<" ";
    }

其中训练数据是CV_32FC1类型;label数据是CV_32SC1类型。
另外,需要将数据进行归一化,因为读取的是灰度图0-255范围之内,所以我们将每个数据除以255就可以得到0-1之间的数据。
4.利用SVM进行训练

    //设置SVM参数
    Ptr<ml::SVM> svm = ml::SVM::create();
    svm->setType(ml::SVM::C_SVC);
    svm->setKernel(ml::SVM::RBF);
    svm->setGamma(0.01);
    svm->setC(10.0);
    svm->setTermCriteria(TermCriteria(CV_TERMCRIT_ITER, 1000,FLT_EPSILON));
    //进行训练
        cout<<"trainning..."<<endl;
    bool f = svm->train(imgtrain,ml::ROW_SAMPLE,imglabel);

//    Ptr<ml::TrainData> traindata = ml::TrainData::create(imgtrain,ml::ROW_SAMPLE,imglabel);
//    bool f = svm->trainAuto(traindata, 10);
//    cout<<f<<endl;
    //保存训练好的数据
    cout<<"saving..."<<endl;
    svm->save("train1.xml");
    cout<<"save done..."<<endl;

5.读取生成的train1.xml进行预测

Ptr<ml::SVM> svm = ml::StatModel::load<ml::SVM>("train1.xml");
cout<<"predicting..."<<endl;
    vector<float> result;
    int right = 0, wrong = 0;
    Mat_<int>::iterator labelbegin = imglabel.begin<int>();
    for(int i = 0; i < (int)imgtrain.rows; i++){
        Mat sample = imgtrain.row(i);
        result.push_back(svm->predict(sample));
        cout<<result[i]<<endl;
        if(abs(result[i] - *(labelbegin+i)) < 0.001){
            right++;
        }else{
            wrong++;
        }
    }
    cout<<"predict done... "<<right<<" right "<<wrong<<" wrong"<<endl;
    cout<<"right rate "<<(float)right/(float)(right+wrong)<<endl;
    cout<<"wrong rate "<<(float)wrong/(float)(right+wrong)<<endl;

6.通过训练60000个样本,能实现非常高的正确率。下图是识别了10000个测试数据的结果
这里写图片描述
7.补充
学习过程中主要参考了如下链接:
https://www.cnblogs.com/cheermyang/p/5624333.html
手写字体是由mnist手写字体图像数据库生成的,参考下列链接:
http://m.blog.csdn.net/fengbingchun/article/details/49611549

猜你喜欢

转载自blog.csdn.net/qq_34359028/article/details/78905417