opencv-SVM

void MainWindow::on_SVM_Train_pushButton_clicked()
{
    const char *match_pic[match_all] = {
        "rightPic/0.bmp", "rightPic/1.bmp", "rightPic/2.bmp", "rightPic/3.bmp", "rightPic/4.bmp",
        "rightPic/5.bmp", "errorPic/6.bmp", "errorPic/7.bmp", "errorPic/8.bmp", "errorPic/9.bmp","errorPic/10.bmp","errorPic/11.bmp"
    };
    int i, j, k;

    float labels[match_all] = { 1, 1, 1, 1, 1, 1, -1, -1, -1, -1,-1,-1 };//标记样本类别
    Mat labelsMat(match_all, 1, CV_32FC1, labels);//必须为float

    //读取样本数据
    vector<vector<float>>  trainingData(match_all);
    for (i = 0; i<match_detect; i++){
        cv::Mat mat1 = cv::imread(match_pic[i], 0);
        uchar* ptr = mat1.ptr(0);
        int length = mat1.rows * mat1.cols;

        for (j = 0; j<length; j++){
            trainingData[i].push_back((float)ptr[j])  ;
        }
    }
    //准备训练数据
    //#define pic_size 68040 即图片像素总数
    Mat trainingDataMat(match_all, pic_size, CV_32FC1);
    for (i = 0; i<match_detect; i++)
    {
        for (j = 0; j<pic_size; j++)
        {
            trainingDataMat.at<float>(i, j) = trainingData[i][j];
        }
    }
    // Set up SVM's parameters
    //有很多训练参数的设置方式
    CvSVMParams params;
    params.svm_type = CvSVM::C_SVC;
    params.C = 0.1;
    params.kernel_type = CvSVM::LINEAR;
    params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER, 100, 1e-6);

    // Train the SVM
    CvSVM SVM;
    SVM.train(trainingDataMat, labelsMat, Mat(), Mat(), params);
    SVM.save("1.xml");
}
void  MainWindow::on_SVM_detect_pushButton_clicked()
{
    CvSVM SVM;
    int i, j, k;
    string svm_test_img;
    SVM.load("1.xml");
    svm_test_img = ui->svm_test_file_path->text().toStdString();
    cv::Mat mat2 = cv::imread(svm_test_img, 0);
    uchar* ptr2 = mat2.ptr(0);
    vector<float>  testData;
    //float testData[pic_size];
    for (k = 0; k<pic_size; k++){
        testData.push_back((float)ptr2[k]);;
    }
    //cv::Mat mat3(1, pic_size, CV_32FC1, testData);
    Mat mat3(1, pic_size, CV_32FC1);//测试的数据
    for (i = 0; i<pic_size; i++)
    {

        mat3.at<float>(i) = testData[i];
    }
    float response = SVM.predict(mat3);//进行分类预测
    if (1 == response)
    {
        QString inf = QStringLiteral("合格");
        QMessageBox::information(this, QStringLiteral("测试结果"), inf);
    }
    else
    {
        QString inf = QStringLiteral("不合格 ");
        QMessageBox::information(this, QStringLiteral("测试结果"), inf);
    }

    printf("response:%f\n", response);
    ShowImage(mat2, ui->label_5);
}

猜你喜欢

转载自blog.csdn.net/nathan1025/article/details/54889841