opencv系列之机器学习(ml)

(版本为3.4.0)

opencv官方api文档:https://docs.opencv.org/

ml模块的svm操作:

python版本

1、生成训练数据

训练数据按类别分文件夹存放:每个子文件夹以类标签(整数)命名,里面存放属于该类的图片文件

def generate_data(self,file_dir):
    """Load the images below ``file_dir`` as flat feature vectors.

    Every sub-directory of ``file_dir`` is treated as one class whose integer
    label is the directory name. Each file inside it is read with
    ``cv2.imread``, resized to ``self.resize`` and flattened to one row.

    Args:
        file_dir: root directory of the training set.

    Returns:
        A tuple ``(train_data, train_labels)`` of two parallel lists: the
        flattened pixel arrays and the integer labels. Both lists are empty
        when ``file_dir`` does not exist.

    Raises:
        ValueError: if a sub-directory name is not an integer.
    """
    train_data = []
    train_labels = []
    if not os.path.exists(file_dir):
        return (train_data, train_labels)

    for fl in os.listdir(file_dir):
        class_dir = os.path.join(file_dir, fl)
        if not os.path.isdir(class_dir):
            # Only directories denote classes; stray files are ignored.
            continue
        label = int(fl)
        for f in os.listdir(class_dir):
            img = cv2.imread(os.path.join(class_dir, f))
            if img is None:
                # cv2.imread yields None for unreadable / non-image files;
                # skip them instead of failing inside cv2.resize.
                continue
            img = cv2.resize(img, self.resize, interpolation=cv2.INTER_CUBIC)
            # One row per sample: width * height * 3 colour channels.
            new_img = img.reshape((1, self.resize[0] * self.resize[1] * 3))
            train_data.append(new_img[0])
            train_labels.append(label)
    return (train_data, train_labels)
    

2、训练

def svmtrain(train_data,train_labels,model_path="svm_data.dat"):
        """Train a linear C-SVC on the given samples and save the model.

        Args:
            train_data: sequence of equally sized feature rows, one per sample.
            train_labels: sequence of integer class labels, parallel to
                ``train_data``.
            model_path: file the trained model is written to. Defaults to
                ``"svm_data.dat"``, the path that was previously hard-coded.

        Returns:
            The value returned by ``svm.train`` (the training status).
        """
        # Create and configure the classifier.
        svm = cv2.ml.SVM_create()
        svm.setType(cv2.ml.SVM_C_SVC)      # C-support vector classification
        svm.setKernel(cv2.ml.SVM_LINEAR)   # linear kernel
        svm.setC(1.0)

        # Samples are passed as a float32 matrix, labels as an int32 column.
        samples = np.array(train_data, np.float32)
        labels = np.array(train_labels, np.int32)
        labels = labels.reshape((labels.size, 1))

        # Each row of `samples` is one training sample.
        ret = svm.train(samples, cv2.ml.ROW_SAMPLE, labels)
        svm.save(model_path)
        return ret

3、测试

 def svmtest(model_path,test_file,resize,label=None):
        """Classify one image with a saved SVM and write an annotated copy.

        The image is resized to ``resize`` and flattened the same way as the
        training samples, then passed to the model. The predicted class is
        drawn on the resized image, which is written to
        ``"out" + basename(test_file)`` in the current directory.

        Args:
            model_path: path of a model saved with ``svm.save``.
            test_file: path of the image to classify.
            resize: ``(width, height)`` tuple; presumably has to match the
                size used for training -- confirm against the training code.
            label: optional ground-truth label to draw below the prediction.
                It replaces the former lookup of ``self.img_labels``, which
                cannot work here because this function has no ``self``.

        Returns:
            The predicted class as an ``int``.
        """
        svm = cv2.ml.SVM_load(model_path)
        img = cv2.imread(test_file)
        img = cv2.resize(img,resize,interpolation=cv2.INTER_CUBIC)
        new_img = img.reshape((1,resize[0]*resize[1]*3))
        # predict() needs a float32 sample matrix, not a list of uint8 rows.
        test_data = np.array(new_img, np.float32)
        (ret, res) = svm.predict(test_data)
        # A single sample was submitted, so `res` holds exactly one row.
        predicted = int(res[0][0])
        cv2.putText(img,'result:'+str(predicted),(0,20),cv2.FONT_HERSHEY_COMPLEX,1,(0,0,255),1)
        if label is not None:
            cv2.putText(img,'label:'+str(label),(0,50),cv2.FONT_HERSHEY_COMPLEX,1,(0,0,255),1)
        ff = os.path.basename(test_file)
        cv2.imwrite("out"+ff,img)
        return predicted

c++版本

扫描二维码关注公众号,回复: 3565907 查看本文章

1、生成训练数据

/**
 * Recursively walk `path` and load every regular file as a training sample.
 *
 * For each sub-directory, `label` is overwritten with the integer value of
 * the directory name (atoi) and the directory is scanned recursively. For
 * each regular file, the image is read, resized to `dsize`, flattened to a
 * single-channel 1 x N row and appended together with the current `label`.
 *
 * @param path            directory to scan.
 * @param files           receives the path of every image that was loaded.
 * @param trainingLabels  receives one label per loaded image.
 * @param label           label applied to files found directly in `path`;
 *                        modified while descending into sub-directories.
 * @param trainingImages  receives one flattened image row per loaded image.
 * @param dsize           size every image is resized to before flattening.
 *
 * Entries that cannot be opened, stat'ed or decoded are reported on stdout
 * and skipped, so the three output vectors stay parallel.
 */
void getFiles(string path, vector<string>& files, vector<int> &trainingLabels, int &label, vector<Mat>& trainingImages, Size dsize)
{
    DIR *p_dir = opendir(path.c_str());
    if (p_dir == NULL)
    {
        cout << "can't open :" << path << endl;
        // readdir()/closedir() must not be called with a null handle.
        return;
    }

    struct dirent *p_dirent;
    while ((p_dirent = readdir(p_dir)) != NULL)
    {
        string tmpFileName = p_dirent->d_name;
        if (tmpFileName == "." || tmpFileName == "..")
        {
            continue;
        }

        string filepath = path + "/" + tmpFileName;
        cout << "filename:" << filepath << endl;

        // Fetch the file information; s_buf is only valid when stat succeeds.
        struct stat s_buf;
        if (stat(filepath.c_str(), &s_buf) != 0)
        {
            cout << "can't stat :" << filepath << endl;
            continue;
        }

        if (S_ISDIR(s_buf.st_mode))
        {
            // The directory name is the class label of everything below it.
            label = atoi(tmpFileName.c_str());
            getFiles(filepath, files, trainingLabels, label, trainingImages, dsize);
        }
        else if (S_ISREG(s_buf.st_mode))
        {
            // Regular file: load it as one training sample.
            Mat SrcImage = imread(filepath);
            if (SrcImage.empty())
            {
                // Not a decodable image; resize() would fail on an empty Mat.
                cout << "can't read image :" << filepath << endl;
                continue;
            }

            // Scale to the common size, then flatten to one single-channel row.
            Mat DstImage;
            resize(SrcImage, DstImage, dsize, 0, 0, INTER_LINEAR);
            DstImage = DstImage.reshape(1, 1);

            files.push_back(filepath);
            trainingImages.push_back(DstImage);
            trainingLabels.push_back(label);
        }
    }
    closedir(p_dir);
}

2、svm训练

int main()
{
    //获取训练数据
    Mat classes;
    // Mat trainingData;
    vector<Mat> trainingImages;
    vector<int> trainingLabels;
    int label = 0;
    string path = "";
    string model_path = "";
    vector<string> files;
    Size dsize = Size(20, 20);

    cout<<"===========================01"<<endl;
    getFiles(path, files, trainingLabels, label, trainingImages, dsize);
    // getFiles(string path, vector<string>& files, vector<int> &trainingLabels, const int &label, Mat& trainingImages)

    // get_1(trainingImages, trainingLabels);
    // get_0(trainingImages, trainingLabels);
    Mat trainingData(trainingImages.size(), trainingImages[0].cols, CV_32FC1);
	for (int i = 0; i < trainingImages.size(); i++)
	{
		Mat temp(trainingImages[i]);
		temp.copyTo(trainingData.row(i));
	}
	trainingData.convertTo(trainingData, CV_32FC1);
	Mat(trainingLabels).copyTo(classes);
	// classes.convertTo(classes, CV_32SC1);

    //配置SVM训练器参数
    Ptr<SVM> model = SVM::create();//以下是设置SVM训练模型的配置
	model->setType(SVM::C_SVC);
	model->setKernel(SVM::LINEAR);
	model->setGamma(1);
	model->setC(1);
	model->setCoef0(0);
	model->setNu(0);
	model->setP(0);
	model->setTermCriteria(cvTermCriteria(CV_TERMCRIT_ITER, 20000, 0.0001));
 
 
	Ptr<TrainData> tdata = TrainData::create(trainingData, ROW_SAMPLE, classes);
	//model->train(trainingData, ROW_SAMPLE, classes);
	model->train(tdata);
	model->save(model_path);//保存

    // svm.save(model_path);
    cout<<"训练好了!!!"<<endl;
    // getchar();
    return 0;
}

3、svm测试

/**
 * Scan a directory (POSIX dirent API) and collect the names of its entries.
 *
 * Every entry except "." and ".." is appended to `fileList`; the names are
 * bare entry names, not full paths. Sub-directories are not descended into.
 *
 * @param fileList        receives the entry names (appended to any existing
 *                        content).
 * @param inputDirectory  directory to scan; a trailing "/" is added.
 * @return the size of `fileList` after the scan. When the directory cannot
 *         be opened a message is printed and `fileList` is left unchanged.
 */
int scanFiles(std::vector<std::string> &fileList, std::string inputDirectory)
{
    inputDirectory.append("/");

    DIR *p_dir = opendir(inputDirectory.c_str());
    if (p_dir == NULL)
    {
        std::cout << "can't open :" << inputDirectory << std::endl;
        // readdir()/closedir() must not be called with a null handle.
        return static_cast<int>(fileList.size());
    }

    struct dirent *p_dirent;
    while ((p_dirent = readdir(p_dir)) != NULL)
    {
        std::string tmpFileName = p_dirent->d_name;
        if (tmpFileName != "." && tmpFileName != "..")
        {
            fileList.push_back(tmpFileName);
        }
    }
    closedir(p_dir);
    return static_cast<int>(fileList.size());
}



int main(int argc, char** argv) {
    
    
    //读取文件夹下所有文件
    string file_path;
    string out_path;
    string svm_model_path;
    
    vector<string> files;
    int size = scanFiles(files, file_path);
    //加载svm模型
    Ptr<SVM> model = SVM::load(svm_model_path);
    for (int i = 0;i < size;i++)  
    {  
        string filename = file_path + '/' + files[i].c_str();
        string out_filename = out_path + '/' + files[i].c_str();
        // cout<<"filename:"<<filename<<endl;
        Mat img = imread(filename);
        Mat DestImage;
        DestImage= img.reshape(1, 1);
        DstImage.convertTo(DstImage, CV_32FC1);
        float response = model->predict(DstImage);
        
    }

   return 0;
}

猜你喜欢

转载自blog.csdn.net/zhanghenan123/article/details/82631594