OpenCV 随机森林(RTrees)的使用

RandomTree.h

#ifndef RANDOM_TREE_H
#define RANDOM_TREE_H
#include <opencv2/core/core.hpp>
#include <opencv2/ml/ml.hpp>
#include <string>
#include <vector>
#include <iostream>

using namespace cv;
using namespace std;

class RandomTree
{
public:
    RandomTree(string name);
    void set_train_data_(vector<vector<float>> &train_data, vector<int> &label);
    void Train();
    void Save();
    float AccOnTrain();
    int Predict(vector<float> &f);
private:
    string name_;
    Ptr<ml::TrainData> train_data_;
    Ptr<ml::RTrees> forest_;
};

/**
 * Creates an untrained random forest and applies the fixed hyper-parameters
 * used by this project.
 *
 * Declared `inline` because the definition lives in a header: without it,
 * including this header from more than one translation unit produces
 * multiple-definition link errors.
 *
 * @param name  Model name; Save() uses it to build the output file name.
 */
inline RandomTree::RandomTree(string name)
    : name_(name), forest_(cv::ml::RTrees::create())
{
    forest_->setMaxDepth(10);             // maximum depth of each tree
    forest_->setPriors(cv::Mat());        // empty matrix: no class priors supplied
    forest_->setRegressionAccuracy(0.01); // regression accuracy
    // Termination criteria: iteration limit 100 and epsilon 0.01. Per the OpenCV
    // RTrees documentation these bound the number of trees and the OOB error.
    forest_->setTermCriteria(cv::TermCriteria(cv::TermCriteria::MAX_ITER + cv::TermCriteria::EPS, 100, 0.01));
    forest_->setMinSampleCount(10);       // minimum number of samples in a node
    forest_->setUseSurrogates(false);
    forest_->setMaxCategories(15);
    forest_->setCalculateVarImportance(true); // compute variable importance
    forest_->setActiveVarCount(4);            // size of the random feature subset tried at each tree node
}

/**
 * Packs the samples into a single CV_32F matrix (one row per sample) and
 * stores it, together with the labels, as the training set.
 *
 * @param train_data  One feature vector per sample; must be non-empty and all
 *                    vectors must have the same, non-zero length.
 * @param label       One integer class label per sample (same count as
 *                    train_data).
 * @throws cv::Exception (via CV_Assert) if the input is empty, the label count
 *         does not match the sample count, or the rows have differing lengths.
 *
 * Declared `inline` because this definition lives in a header.
 */
inline void RandomTree::set_train_data_(vector<vector<float>> &train_data, vector<int> &label)
{
    // The original indexed train_data[0] unconditionally, which is undefined
    // behaviour for an empty training set.
    CV_Assert(!train_data.empty());
    CV_Assert(train_data.size() == label.size());

    const int sample_num = static_cast<int>(train_data.size());
    const int feat_num = static_cast<int>(train_data[0].size());
    CV_Assert(feat_num > 0);

    // Allocate the whole matrix once instead of growing it row by row.
    cv::Mat train_mat(sample_num, feat_num, CV_32F);
    for (int i = 0; i < sample_num; ++i)
    {
        const vector<float> &row = train_data[i];
        // A shorter row would otherwise be read past its end.
        CV_Assert(static_cast<int>(row.size()) == feat_num);
        float *dst = train_mat.ptr<float>(i);
        for (int j = 0; j < feat_num; ++j)
        {
            dst[j] = row[j];
        }
    }

    // copyData = true: Mat(vector) only wraps the vector's buffer by default,
    // so copy the labels to keep the training set independent of the
    // caller's vector lifetime.
    train_data_ = ml::TrainData::create(train_mat, ml::ROW_SAMPLE, cv::Mat(label, true));
}

void RandomTree::Train()
{
    forest_->train(train_data_);
}

void RandomTree::Save()
{
    string path = "../model/" + name_ + ".xml";
    // cout << path << endl;
    forest_->save(path);
}

float RandomTree::AccOnTrain()
{
    cv::Mat train_sample = train_data_->getTrainSamples();
    cv::Mat train_label = train_data_->getTrainResponses();
    int size = train_sample.rows, cnt = 0;
    // cout << size << endl;
    for (int i = 0; i < size; ++i)
    {
        cv::Mat sample = train_sample.row(i);
        int r = forest_->predict(sample);
        // cout << train_label.at<int>(i) << endl;
        if (r == train_label.at<int>(i))
            cnt++;
    }
    return 1.0 * cnt / size;
}

/**
 * Predicts the label of a single feature vector.
 *
 * @param f  Feature vector; must be non-empty. It is only read.
 * @return   The forest's prediction truncated to int.
 * @throws cv::Exception (via CV_Assert) if f is empty.
 *
 * Declared `inline` because this definition lives in a header.
 */
inline int RandomTree::Predict(vector<float> &f)
{
    CV_Assert(!f.empty());
    // Wrap the vector's buffer as a 1 x N CV_32F row without copying; the
    // header is only used for the duration of this call, while f is alive.
    const cv::Mat ele(1, static_cast<int>(f.size()), CV_32F, f.data());
    return static_cast<int>(forest_->predict(ele));
}
#endif

main.cpp
特征提取部分(GetFeature、GetLF 等)需要换成你自己的实现;这里保留的是原项目的代码,没有改动。

#include "feature.h"
#include "RandomTree.h"
#include "tool.h"
#include <algorithm>
#include <cstdio>
#include <iomanip>

using namespace std;

int main(int argc, char **argv)
{
  string pos_path("data/train/pos");
  vector<vector<float>> train_data;
  vector<int> label;
  vector<string> file_names;
  findfileinfolder(pos_path.c_str(), "txt", file_names);
  for (int i = 0; i < file_names.size(); ++i)
  {
    vector<float> feat = GetFeature(pos_path + "/" + file_names[i]);
    if (feat.size() == 0)
    {
      cout << file_names[i] << " kong!\n";
      continue;
    }
    train_data.push_back(feat);
    label.push_back(1);
  }
  cout << "load pos " << file_names.size() << endl;
  string neg_path("data/train/neg");
  file_names.clear();
  findfileinfolder(neg_path.c_str(), "txt", file_names);
  for (int i = 0; i < file_names.size(); ++i)
  {
    vector<float> feat = GetFeature(neg_path + "/" + file_names[i]);
    if (feat.size() == 0)
    {
      cout << file_names[i] << " kong!\n";
      continue;
    }
    train_data.push_back(feat);
    label.push_back(-1);
  }
  cout << "load neg " << file_names.size() << endl;
  RandomTree rt("leg");
  rt.set_train_data_(train_data, label);
  rt.Train();
  cout << "AccOnTrain is " << rt.AccOnTrain() << endl;
  /////////////////////////////////////////////////////
  string test_path(argv[1]);
  file_names.clear();
  findfileinfolder(test_path.c_str(), "txt", file_names);
  sort(file_names.begin(), file_names.end());
  int flag, cnt = 0;
  sscanf(argv[2], "%d", &flag);
  for (int i = 0; i < file_names.size(); ++i)
  {
    vector<float> x, y;
    loadxy(test_path + "/" + file_names[i], x, y);
    vector<float> feat;
    GetLF(x, y, feat);
    if (feat.size() == 0)
    {
      cout << file_names[i] << " kong!\n";
      continue;
    }
    if (rt.Predict(feat) == flag)
    {
      cnt++;
    }
  }
  cout << "Predict " << flag << " : " << cnt << "/" << file_names.size() << endl;
  rt.Save();
  return 0;
}
发布了50 篇原创文章 · 获赞 31 · 访问量 3万+

猜你喜欢

转载自blog.csdn.net/random_repick/article/details/104063306
今日推荐