【SVM机器学习篇】OpenCV 3.4.1 与 C++ 实现多分类问题的训练与预测

#include<opencv2\opencv.hpp>
#include<iostream>    
#include<string>
#include<vector>
#include<fstream>
#include<opencv2/ml/ml.hpp>

using namespace std;
using namespace cv;
using namespace ml;

//训练SVM

/**
 * Train a multi-class linear SVM (C-SVC) on row-major samples.
 *
 * @param model       SVM instance to configure and train (trained in place).
 * @param trainData   One sample per row. A CV_32F copy is used if the input
 *                    has another depth, because SVM training needs float samples.
 * @param trainLabels One class id per sample (N x 1 or 1 x N). A CV_32S copy is
 *                    used if needed: with C_SVC, TrainData must see integer
 *                    responses so they are treated as categorical class ids
 *                    (float responses are treated as ordered/regression targets
 *                    and training fails).
 *
 * The caller's matrices are never modified. trainAuto() is used so the
 * C parameter is selected by cross-validation.
 * Throws cv::Exception if the model is null, the data is empty, or the
 * number of labels does not match the number of samples.
 */
void svm_train(Ptr<SVM> &model, Mat &trainData, Mat &trainLabels)
{
    CV_Assert(!model.empty());
    CV_Assert(!trainData.empty());
    CV_Assert(trainLabels.total() == static_cast<size_t>(trainData.rows));

    model->setType(SVM::C_SVC);     // multi-class classification
    model->setKernel(SVM::LINEAR);  // linear kernel

    // Convert only when necessary to avoid an extra copy of large sample sets.
    Mat samples;
    if (trainData.depth() == CV_32F)
        samples = trainData;
    else
        trainData.convertTo(samples, CV_32F);

    Mat labels;
    if (trainLabels.type() == CV_32S)
        labels = trainLabels;
    else
        trainLabels.convertTo(labels, CV_32S);

    Ptr<TrainData> tData = TrainData::create(samples, ROW_SAMPLE, labels);

    cout << "SVM: start train ..." << endl;  // flushed before the (slow) training
    model->trainAuto(tData);
    cout << "SVM: train success ..." << endl;
}


/**
 * Classify every row of `test` with a trained SVM and print the predicted
 * class label of each sample, one per line.
 *
 * @param model Trained SVM (see svm_train).
 * @param test  Samples to classify, one per row, with the same column count
 *              as the training data. A CV_32F copy is used if needed, since
 *              SVM::predict requires float samples.
 */
void svm_pridect(Ptr<SVM> &model, Mat test)
{
    if (test.empty()) {
        cout << "SVM: no samples to predict" << endl;
        return;
    }

    Mat samples;
    if (test.depth() == CV_32F)
        samples = test;
    else
        test.convertTo(samples, CV_32F);

    Mat result;  // N x 1, CV_32F: one predicted label per input row
    model->predict(samples, result);

    // One label per line; the original concatenated them into a single number.
    for (int i = 0; i < result.rows; ++i) {
        cout << result.at<float>(i, 0) << '\n';
    }
    cout.flush();
}


int main() {
    vector<Mat>image;
    //string s;
    //string str = "C:\\Users\\wangz\\Documents\\Visual Studio 2015\\Projects\\Project15\\Project15\\pic\\txt.txt";//个人路径
    string ss;
    ifstream fin("C:\\Users\\wangz\\Documents\\Visual Studio 2015\\Projects\\Project15\\Project15\\pic\\txt.txt");
    //fin.open(str);

    //读取data
    while (getline(fin, ss))

    {
        ss = "C:\\Users\\wangz\\Documents\\Visual Studio 2015\\Projects\\Project15\\Project15\\pic\\" + ss;
        Mat m1 = imread(ss, 0);
        resize(m1, m1, Size(64, 64));
        image.push_back(m1); //连续放入Mat容器中
    }
    Mat imagedata(200,64*64,CV_8UC1);
    for (vector<Mat>::iterator it = image.begin(); it != image.end(); it++) {
        imagedata.push_back((*it).reshape(0, 1));
    }
    imagedata.convertTo(imagedata, CV_32F);

    //读取标签
    Mat lable(200, 1, CV_8UC1);
    string sss;
    while (getline(fin, sss))

扫描二维码关注公众号,回复: 2864157 查看本文章

    {
        sss = "C:\\Users\\wangz\\Documents\\Visual Studio 2015\\Projects\\Project15\\Project15\\pic\\" + sss;
        

        //image.push_back(m1); //连续放入Mat容器中
    }
    Mat label(200, 64 * 64, CV_8UC1);

    for (vector<Mat>::iterator it = image.begin(); it != image.end(); it++) {
        imagedata.push_back((*it).reshape(0, 1));
    }
    imagedata.convertTo(imagedata, CV_32F);

    Ptr<SVM> model = SVM::create();
    Mat trainData, trainLabels;
    get_data(train_path, trainData, trainLabels);
    svm_train(model, trainData, trainLabels);

    Mat testData;
    transform(test_set, testData);
    svm_pridect(model, testData);

    //imshow("aa", image[0]);//显示第一张图片
    waitKey(0);
    getchar();
    return 0;

}c

猜你喜欢

转载自blog.csdn.net/qq_35054151/article/details/81772593