Source of the question: Statistical Learning Methods (Second Edition Li Hang) Chapter 8 Section 1 AdaBoost Example Implementation P158
Question: Given the training data set as shown in the figure. Assume that a weak classifier is generated by x<v or x>v, and its threshold v enables the classifier to have the lowest classification error rate on the training data. Learn a strong classifier using the AdaBoost algorithm.
Data information:
Code: Ideas are for reference only
#include <iostream>
#include <string>
#include <vector>
#include <cmath>
using namespace std;
// 定义弱分类器结构
typedef struct G* G_cls;
struct G{
G(double k) {
this->k = k;
this->is_Reverse = false; // 默认为 num > k : 1 : -1
}
int cal(int n) {
if(is_Reverse)
return n > k ? -1 : 1;
else
return n > k ? 1 : -1;
}
double k;
bool is_Reverse;
};
class Adaboost{
public:
Adaboost(vector<int>& data, vector<int>& label, int epoch, double acc) {
this->data = data;
this->label = label;
this->N = data.size();
this->weight = vector<float> (N, 1.0/N);
this->M = epoch;
this->acc_set = acc;
}
double cal_signx() {
// 当前的累加模型在 数据集上的 分类正确点个数
int cnt = 0;
for(int i=0; i<N; i++) {
double output = 0;
for(int j=0; j<models.size(); j++) {
output += (alpha[j] * models[j]->cal(data[i]));
}
if((output * label[i]) > 0) // 分类正确
cnt++;
}
double ret = double(cnt) / double(N);
return ret;
}
double cal_E(G_cls classifier) {
// 计算 当前分类器 在当前数据分布上的 误差率
double cnt = 0;
for(int i=0; i<N; i++) {
if( (classifier->cal(data[i]) * label[i] ) != 1 )
cnt += weight[i];
}
return cnt;
}
double cal_alpha(double e) {
// 计算 当前弱分类器(m) 的权重系数
return 0.5 * log((1-e)/e);
}
void updata_w(double alp, G_cls model) {
// 更新 训练样本的权值分布
double Z_m = 0.0;
for(int i=0; i<N; i++) {
// 计算分母 用于归一化
Z_m += weight[i] * exp(-alp * model->cal(data[i]) * label[i]);
}
for(int i=0; i<N; i++) {
weight[i] = weight[i] * exp(-alp * model->cal(data[i]) * label[i]) / Z_m;
}
}
pair<G_cls, double> cal_K() {
double k = (double(data[0]) + double(data[1])) / 2;
double ret_K = -1;
double best_E = 1.0;
G_cls bst_classifier = NULL;
for(int j=0; j<N-1; j++) {
// 遍历K,选则最优模型
G_cls classifier = new G(k); // 默认正常形式 rev = false
double e = cal_E(classifier);
if(e > 0.5) {
e = 1 - e;
if(e < best_E) {
classifier->is_Reverse = true;
bst_classifier = classifier;
best_E = e;
ret_K = k;
}
}
else {
if(e < best_E) {
bst_classifier = classifier;
best_E = e;
ret_K = k;
}
}
k += 1.0;
}
return make_pair(bst_classifier, best_E);
}
void refresh_mode(G_cls new_model, double alp) {
models.push_back(new_model);
alpha.push_back(alp);
}
void run() {
for(int i=0; i<M; i++) {
// 根据上一轮更新好的权值分布 暴力计算当前第i个 [弱分类器k值 及 分类误差率]
pair<G_cls, double> ret = cal_K();
// 计算当前分类器权值
double alp = cal_alpha(ret.second);
// 更新 数据分布权重
updata_w(alp, ret.first);
// 更新 加法模型
refresh_mode(ret.first, alp);
// 判断加法模型精度
double acc = cal_signx();
if(acc > acc_set) {
cout << "第" << i+1 << "轮的总分类误差率已满足要求: 当前[" << acc << "]" << " 预设[" << acc_set << "]" << endl;
break;
}
}
// 打印当前模型 及 最后一个 弱分类器 对应的权重分布参数
cout << "mode is: f(x)=";
for(int i=0; i<models.size(); i++) {
if(i==models.size()-1)
cout << alpha[i] << "*G" << i+1 << "(x)" << endl;
else
cout << alpha[i] << "*G" << i+1 << "(x) + ";
}
for(int i=0; i<models.size(); i++) {
string tmp = models[i]->is_Reverse? "发生" : "未发生" ;
cout << "弱分类器G" << i+1 << "(x)" << "的k值为" << models[i]->k << ",该分类器符号" << tmp << "转化" << endl;
}
cout << "最后一个弱分类器对应的权重分布参数: ";
for(int i=0; i<N; i++) {
cout << weight[i] << ' ';
}
}
int N; // 数据的数量
int M; // 最大弱分类器的个数(最大循环轮数)
double acc_set; // 设定精度
vector<int> data;
vector<int> label;
vector<float> weight; // 当前弱分类器 样本的权值分布
vector<float> alpha; // 所有弱分类器的信息
vector<G_cls> models;
};
int main()
{
vector<int> data = {
0,1,2,3,4,5,6,7,8,9};
vector<int> label = {
1,1,1,-1,-1,-1,1,1,1,-1};
Adaboost base(data, label, 10, 0.99);
base.run();
return 0;
}
// 输出:
// 第3轮的总分类误差率已满足要求: 当前[1] 预设[0.99]
// mode is: f(x)=0.423649*G1(x) + 0.649642*G2(x) + 0.752039*G3(x)
// 弱分类器G1(x)的k值为2.5,该分类器符号发生转化
// 弱分类器G2(x)的k值为8.5,该分类器符号发生转化
// 弱分类器G3(x)的k值为5.5,该分类器符号未发生转化
// 最后一个弱分类器对应的权重分布参数: 0.125 0.125 0.125 0.101852 0.101852 0.101852 0.0648148 0.0648148 0.0648148 0.125