How to use OpenCV's Random Forest (Python)

Earlier post: "Random Forests C++ Implementation: Details, Use and Experiments"

OpenCV's ml module contains a C++ implementation of random forest (RTrees) and also exposes it through a Python interface. This article shows how to use OpenCV's random forest for classification and regression. It explains how to set the algorithm's parameters and how to use its functions, with reference to the source code and the official documentation.

0. Experimental environment

OS: Ubuntu 18.04
opencv: 4.1.2
python: 3.6.9

1. Calling from Python

The following code shows classification with a random forest:

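import cv2
import numpy as np
......
# Load the dataset; X: samples, Y: class labels
# X and Y are numpy arrays
......
rf=cv2.ml.RTrees_create()
# Number of features randomly drawn as split candidates at each node
rf.setActiveVarCount(27)
# Stop splitting a node once it holds fewer than 5 samples
rf.setMinSampleCount(5)
# RTrees caps the maximum depth at 25: it is set to 40 here, but training actually uses 25
rf.setMaxDepth(40)
# Termination criteria (type, maxCount = number of trees, epsilon)
rf.setTermCriteria((1,200,0.0))
X=X.astype(np.float32)
# For a classifier, make sure the labels are integers; otherwise the algorithm treats the problem as regression
Y=Y.astype(np.int32)
traindata=cv2.ml.TrainData_create(X,cv2.ml.ROW_SAMPLE,Y)
rf.train(traindata)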

2. Function and parameter description

OpenCV's random forest is implemented in cv::ml::RTrees. While using it, I ran into a few things that were not obvious. One was how to set the random forest parameters (such as the number of trees). Another was whether training would run classification or regression. The interface is more abstract and less explicit than sklearn's. This is probably because OpenCV keeps the interfaces of the ml module generic across algorithms. As a result, setting RF parameters is less intuitive and somewhat unfriendly to newcomers.

2.1 About tree depth

Note that in the OpenCV implementation the maximum depth of a tree never exceeds 25. If the user sets a larger value, training still uses 25. For problems of moderate size, this limit is sufficient.

2.2 Number of trees

The number of random trees in the forest is an important random forest parameter. RTrees has no dedicated setter for it. Instead, the value is carried by the termination criteria; see section 2.4, Termination conditions, for details.

In rf.setTermCriteria((1,200,0.0)), the number of random trees is the second element of the termination-criteria triple (200 here).


2.3 Classification or regression

In OpenCV, both classification and regression are trained through the same train function. So how does the user control which of the two the algorithm runs? Let's follow the call sequence in the source code, using the RF classification code from section 1 as the example:

Constructing the OpenCV training-data object (traindata) from the sample set (X, Y) calls the create method of the TrainDataImpl class. The create method in turn calls setData.

traindata=cv2.ml.TrainData_create(X,cv2.ml.ROW_SAMPLE,Y)


In the setData method, after a series of checks, varType.at(ninputvars) is set to VAR_CATEGORICAL.

Note the condition "responses.type() < CV_32F" in setData. Our Python script contains the line Y=Y.astype(np.int32). That Y is the responses argument here, and its type is a 32-bit integer (CV_32S), which satisfies "< CV_32F".
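The rule can be mirrored in plain NumPy to see which dtypes fall on which side of CV_32F. The helper below is hypothetical and reproduces only this single comparison for a one-column response. setData itself performs further checks. The depth codes are OpenCV's standard values (CV_8U = 0 up to CV_64F = 6).

```python
import numpy as np

# OpenCV element-depth codes (same values as cv2.CV_8U ... cv2.CV_64F).
OPENCV_DEPTH = {
    np.dtype(np.uint8): 0,    # CV_8U
    np.dtype(np.int8): 1,     # CV_8S
    np.dtype(np.uint16): 2,   # CV_16U
    np.dtype(np.int16): 3,    # CV_16S
    np.dtype(np.int32): 4,    # CV_32S
    np.dtype(np.float32): 5,  # CV_32F
    np.dtype(np.float64): 6,  # CV_64F
}
CV_32F = 5


def looks_categorical(responses: np.ndarray) -> bool:
    """Hypothetical helper: mirrors only the `responses.type() < CV_32F` test.

    Covers just the dtypes listed in OPENCV_DEPTH.
    """
    return OPENCV_DEPTH[responses.dtype] < CV_32F


print(looks_categorical(np.array([0, 1], dtype=np.int32)))      # integer labels
print(looks_categorical(np.array([0.0, 1.0], dtype=np.float32)))  # float targets
```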


At the end of setData, conditions such as "varType.at(ninputvars) == VAR_CATEGORICAL" are met, so the labels are assigned to classLabels. classLabels is therefore not empty; keep this in mind.

Training then starts. In the startTraining method of the DTreesImpl class, _isClassifier is set from the result of the expression data->getResponseType() == VAR_CATEGORICAL.

rf.train(traindata)


The getResponseType() method returns VAR_CATEGORICAL or VAR_ORDERED, depending on whether classLabels is empty. As noted above, setData has already filled classLabels, so it is not empty. getResponseType() therefore returns VAR_CATEGORICAL, and _isClassifier is set to true.

In summary, OpenCV's random forest decides whether to run classification or regression from the type (integer or floating point) of the training outputs. These are Y in this article's sample code and responses in the OpenCV source. An integer Y means classification; a floating-point Y means regression. That is why the classification code in section 1 deliberately converts the output type with Y=Y.astype(np.int32).

2.4 Termination conditions

The termination condition in the Python script is rf.setTermCriteria((1,200,0.0)). The triple is (type, maxCount, epsilon), and its three fields mean the following:

The first parameter, type, can be 1 (COUNT), 2 (EPS) or 3 (COUNT+EPS). It determines whether training stops on maxCount, on epsilon, or on whichever of the two is reached first.
The second parameter, maxCount, is the number of random trees in the RF algorithm.
The third parameter, epsilon, is the required accuracy.
For example, with type set to 1, only the second field (maxCount) is used as the stopping rule: training stops after maxCount random trees have been built.

/*
COUNT: the maximum number of iterations or elements to compute
MAX_ITER: same as COUNT
EPS: the desired accuracy or change in parameters at which the iterative algorithm stops
*/
enum Type {
  COUNT = 1,
  MAX_ITER = COUNT,
  EPS = 2
};

[Figure: the three parameters of the setTermCriteria method]
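The type field combines the two stopping rules as bit flags, which can be sketched in plain Python. The constants mirror the enum values above, and the describing helper is hypothetical. Reading epsilon as an out-of-bag (OOB) error threshold follows the OpenCV RTrees documentation for termCrit.

```python
# Flag values mirroring cv::TermCriteria::Type (cv2.TERM_CRITERIA_COUNT / _EPS).
COUNT = 1
EPS = 2


def active_stop_rules(criteria):
    """Hypothetical helper: describe which parts of a (type, maxCount, epsilon)
    triple are in effect for RTrees training."""
    crit_type, max_count, eps = criteria
    rules = []
    if crit_type & COUNT:
        rules.append(f"stop after {max_count} trees")
    if crit_type & EPS:
        rules.append(f"stop once the OOB error is small enough (epsilon={eps})")
    return rules


print(active_stop_rules((COUNT, 200, 0.0)))         # the article's setting
print(active_stop_rules((COUNT | EPS, 200, 0.01)))  # whichever is reached first
```

With type 3 (COUNT+EPS), both rules are active and training ends as soon as either one is satisfied.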

Appendix: complete sample code

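import cv2
import pandas as pd
import numpy as np

# Load spambase; in this file column 0 holds the class label
pd_data=pd.read_csv('/to/your/dir/spambase-pandas.data',header=None,sep=' ')
npdata=pd_data.values
# Features: column 1 up to the second-to-last column (the last column is excluded, as in the original script)
X=npdata[:,1:-1]
# Labels: column 0
Y=npdata[:,0]

rf=cv2.ml.RTrees_create()
rf.setActiveVarCount(7)
rf.setMinSampleCount(5)
rf.setMaxDepth(40)
rf.setTermCriteria((1,200,0.0))
# Float32 features and int32 labels, so the forest runs classification
X=X.astype(np.float32)
Y=Y.astype(np.int32)
traindata=cv2.ml.TrainData_create(X,cv2.ml.ROW_SAMPLE,Y)
rf.train(traindata)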

(The spambase dataset used above can be downloaded here.)


Origin blog.csdn.net/gxf1027/article/details/127196427