使用svm训练mist数据集

1.SVM对象的创建和训练

1.1 创建svm

Ptr<ml::SVM> svm = ml::SVM::create();

1.2 svm参数设置

//设置SVM参数
svm->setType(ml::SVM::C_SVC);
svm->setKernel(ml::SVM::RBF);
svm->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER, 100, 1e-6));

或者是

cv::SVM::Params params;
params.svmType = cv::SVM::C_SVC;
params.kernelType = cv::SVM::RBF;
params.termCrit = cv::TermCriteria(cv::TermCriteria::MAX_ITER, 100, 1e-6);
params.C = 1.0;
params.gamma = 0.1;

在这里插入图片描述

2. 使用mist数据集进行分类

使用mist数据集进行分类

#include <opencv2/opencv.hpp>
#include <iostream>
#include <fstream>
#include <string>
#include <vector>
#include <Winsock2.h>
//在对话框左侧选择“配置属性->链接器->输入”,在右侧的“附加依赖项”中添加ws2_32.lib库文件

using namespace cv;
using namespace std;

//定义存储训练图像和标签的向量
vector<Mat> train_images;
vector<int> train_labels;

//定义函数来读取MNIST数据集
//是将一个无符号长整形数从网络字节顺序转换为主机字节顺序
//ntohl()返回一个以主机字节顺序表达的数。
void read_MNIST(string filename, vector<Mat>& vec_images, vector<int>& vec_labels)
{
    
    
    
    ifstream file(filename, ios::binary);
    if (file.is_open())
    {
    
    
        cout << "begin to read MNIST" << endl;
        int magic_number = 0;
        int number_of_images = 0;
        int rows = 0;
        int cols = 0;
        file.read((char*)&magic_number, sizeof(magic_number));
        magic_number = ntohl(magic_number);
        file.read((char*)&number_of_images, sizeof(number_of_images));
        number_of_images = ntohl(number_of_images);
        file.read((char*)&rows, sizeof(rows));
        rows = ntohl(rows);
        file.read((char*)&cols, sizeof(cols));
        cols = ntohl(cols);
        for (int i = 0; i < number_of_images; ++i)
        {
    
    
            Mat img = Mat::zeros(rows, cols, CV_8UC1);
            for (int r = 0; r < rows; ++r)
            {
    
    
                for (int c = 0; c < cols; ++c)
                {
    
    
                    unsigned char temp = 0;
                    file.read((char*)&temp, sizeof(temp));
                    img.at<uchar>(r, c) = (int)temp;
                }
            }
            int label = 0;
            file.read((char*)&label, sizeof(label));
            label = ntohl(label);
            vec_images.push_back(img);
            vec_labels.push_back(label);
        }
        cout << "read MNIST finish" << endl;
    }
}

int main()
{
    
    
    //读取训练数据
    
    string train_images_path = "E:/det/mnist/train-images.idx3-ubyte";
    string train_labels_path = "E:/det/mnist/train-labels.idx1-ubyte";
    read_MNIST(train_images_path, train_images, train_labels);
    
    //设置SVM参数
    Ptr<ml::SVM> svm = ml::SVM::create();
    svm->setType(ml::SVM::C_SVC);
    svm->setKernel(ml::SVM::RBF);
    svm->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER, 100, 1e-6));

    //将图像转换为特征向量
    Mat trainData;
    cout << "images 2 vector" << endl;
    for (int i = 0; i < train_images.size(); ++i)
    {
    
    
        Mat img;
        train_images[i].convertTo(img, CV_32FC1);
        img = img.reshape(1, 1);
        trainData.push_back(img);
    }

    //训练SVM模型
    cout << "training svm"<< endl;
    Mat labelsMat(train_labels.size(), 1, CV_32SC1, train_labels.data());
    svm->train(trainData, ml::ROW_SAMPLE, labelsMat);
    cout << "training finished" << endl;


    //保存模型
    svm->save("svm_model.xml");

    return 0;
}



猜你喜欢

转载自blog.csdn.net/weixin_50862344/article/details/129980732