xgboost multiclass classification demo and how it works

This post first runs the multiclass classification demo that ships with xgboost and prints the generated tree structures, then explains how xgboost actually implements multiclass classification. Going in this order makes it easier to understand.

xgboost multiclass classification demo

The demo can be found directly in the xgboost source code, at /demo/multiclass_classification/train.py. The data used by train.py (dermatology.data) can be downloaded from https://archive.ics.uci.edu/ml/machine-learning-databases/dermatology/dermatology.data . The downloaded file has a .data extension; renaming it to .csv or .txt makes it directly usable. I renamed mine to 'data.txt'.
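
If you prefer to fetch the file from a script, here is a minimal sketch; it simply downloads the UCI file from the URL above and saves it directly as 'data.txt':

import urllib.request

# download the UCI dermatology data and save it under the name used in this post
url = ('https://archive.ics.uci.edu/ml/machine-learning-databases/'
       'dermatology/dermatology.data')
urllib.request.urlretrieve(url, 'data.txt')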

Now let's look at the code in train.py.

The full code is listed below. The labels in this dataset fall into 6 classes, and I set the number of boosting rounds to 2.

import numpy as np
import xgboost as xgb

# labels need to be in the range 0 to num_class - 1
# (column 34 holds the class, shifted from 1-6 to 0-5; '?' in column 33 marks missing values)
# note: the raw UCI file is comma-separated, so use delimiter=',' if you only renamed it
data = np.loadtxt('data.txt', delimiter='\t',
        converters={33: lambda x: int(x == '?'), 34: lambda x: int(x) - 1})
sz = data.shape

# 70/30 split into training and test sets
train = data[:int(sz[0] * 0.7), :]
test = data[int(sz[0] * 0.7):, :]

# columns 0-32 are the features, column 34 is the label
train_X = train[:, :33]
train_Y = train[:, 34]

test_X = test[:, :33]
test_Y = test[:, 34]

xg_train = xgb.DMatrix(train_X, label=train_Y)
xg_test = xgb.DMatrix(test_X, label=test_Y)
# setup parameters for xgboost
param = {}
# use softmax multi-class classification
param['objective'] = 'multi:softmax'
# step size shrinkage (learning rate)
param['eta'] = 0.1
param['max_depth'] = 6
param['silent'] = 1
param['nthread'] = 4
param['num_class'] = 6

watchlist = [(xg_train, 'train'), (xg_test, 'test')]
num_round = 2   # number of boosting rounds set to 2
bst = xgb.train(param, xg_train, num_round, watchlist)
# get prediction
pred = bst.predict(xg_test)
error_rate = np.sum(pred != test_Y) / test_Y.shape[0]
print('Test error using softmax = {}'.format(error_rate))
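
As a side note, the train.py demo that ships with xgboost also trains the same model with the multi:softprob objective, which outputs a probability for every class instead of a hard label. A minimal sketch of that variant, reusing the variables defined above:

# same settings, but ask for a probability for each of the 6 classes
param_prob = dict(param, objective='multi:softprob')
bst_prob = xgb.train(param_prob, xg_train, num_round, watchlist)
# predictions come back as one probability per class; argmax gives the label
pred_prob = bst_prob.predict(xg_test).reshape(test_Y.shape[0], 6)
pred_label = np.argmax(pred_prob, axis=1)
error_rate_prob = np.sum(pred_label != test_Y) / test_Y.shape[0]
print('Test error using softprob = {}'.format(error_rate_prob))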

How xgboost implements multiclass classification

After training, a key step is to dump the trained trees in a form that is easy to inspect. The code below saves the tree structure as plain text, which I find much more readable than a plotted tree.

bst.dump_model('multiclass_model')

Then open the file. Each booster is one tree; this model contains 12 trees in total, booster[0] through booster[11].
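
As a quick sanity check, bst.get_dump() returns one text dump per tree, so its length should equal num_round * num_class = 12 here:

# one entry per tree: 2 rounds * 6 classes = 12 trees
print(len(bst.get_dump()))   # 12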

booster[0]:
0:[f19<0.5] yes=1,no=2,missing=1
	1:[f21<0.5] yes=3,no=4,missing=3
		3:leaf=-0.0587906
		4:leaf=0.0906977
	2:[f6<0.5] yes=5,no=6,missing=5
		5:leaf=0.285523
		6:leaf=0.0906977
booster[1]:
0:[f27<1.5] yes=1,no=2,missing=1
	1:[f12<0.5] yes=3,no=4,missing=3
		3:[f31<0.5] yes=7,no=8,missing=7
			7:leaf=-1.67638e-09
			8:leaf=-0.056044
		4:[f4<0.5] yes=9,no=10,missing=9
			9:leaf=0.132558
			10:leaf=-0.0315789
	2:[f4<0.5] yes=5,no=6,missing=5
		5:[f11<0.5] yes=11,no=12,missing=11
			11:[f10<0.5] yes=15,no=16,missing=15
				15:leaf=0.264427
				16:leaf=0.0631579
			12:leaf=-0.0428571
		6:[f15<1.5] yes=13,no=14,missing=13
			13:leaf=-0.00566038
			14:leaf=-0.0539326
booster[2]:
0:[f32<1.5] yes=1,no=2,missing=1
	1:leaf=-0.0589339
	2:[f9<0.5] yes=3,no=4,missing=3
		3:leaf=0.280919
		4:leaf=0.0631579
booster[3]:
0:[f4<0.5] yes=1,no=2,missing=1
	1:[f0<1.5] yes=3,no=4,missing=3
		3:[f3<0.5] yes=7,no=8,missing=7
			7:[f27<0.5] yes=13,no=14,missing=13
				13:leaf=-0.0375
				14:leaf=0.0631579
			8:leaf=-0.0515625
		4:leaf=-0.058371
	2:[f2<1.5] yes=5,no=6,missing=5
		5:[f32<0.5] yes=9,no=10,missing=9
			9:[f15<0.5] yes=15,no=16,missing=15
				15:leaf=-0.0348837
				16:leaf=0.230097
			10:leaf=-0.0428571
		6:[f3<0.5] yes=11,no=12,missing=11
			11:leaf=0.0622641
			12:[f16<1.5] yes=17,no=18,missing=17
				17:leaf=-1.67638e-09
				18:[f3<1.5] yes=19,no=20,missing=19
					19:leaf=-0.00566038
					20:leaf=-0.0554622
booster[4]:
0:[f14<0.5] yes=1,no=2,missing=1
	1:leaf=-0.0590296
	2:leaf=0.255665
booster[5]:
0:[f30<0.5] yes=1,no=2,missing=1
	1:leaf=-0.0591241
	2:leaf=0.213253
booster[6]:
0:[f19<0.5] yes=1,no=2,missing=1
	1:[f21<0.5] yes=3,no=4,missing=3
		3:leaf=-0.0580493
		4:leaf=0.0831786
	2:leaf=0.214441
booster[7]:
0:[f27<1.5] yes=1,no=2,missing=1
	1:[f12<0.5] yes=3,no=4,missing=3
		3:[f31<0.5] yes=7,no=8,missing=7
			7:leaf=0.000227226
			8:leaf=-0.0551713
		4:[f15<1.5] yes=9,no=10,missing=9
			9:leaf=-0.0314418
			10:leaf=0.121289
	2:[f4<0.5] yes=5,no=6,missing=5
		5:[f11<0.5] yes=11,no=12,missing=11
			11:[f10<0.5] yes=15,no=16,missing=15
				15:leaf=0.206326
				16:leaf=0.0587528
			12:leaf=-0.0420568
		6:[f15<1.5] yes=13,no=14,missing=13
			13:leaf=-0.00512865
			14:leaf=-0.0531389
booster[8]:
0:[f32<1.5] yes=1,no=2,missing=1
	1:leaf=-0.0581933
	2:[f11<0.5] yes=3,no=4,missing=3
		3:leaf=0.0549185
		4:leaf=0.218241
booster[9]:
0:[f4<0.5] yes=1,no=2,missing=1
	1:[f0<1.5] yes=3,no=4,missing=3
		3:[f3<0.5] yes=7,no=8,missing=7
			7:[f27<0.5] yes=13,no=14,missing=13
				13:leaf=-0.0367718
				14:leaf=0.0600201
			8:leaf=-0.0506891
		4:leaf=-0.0576147
	2:[f27<0.5] yes=5,no=6,missing=5
		5:[f3<0.5] yes=9,no=10,missing=9
			9:leaf=0.0238016
			10:leaf=-0.054874
		6:[f5<1] yes=11,no=12,missing=11
			11:leaf=0.200442
			12:leaf=-0.0508502
booster[10]:
0:[f14<0.5] yes=1,no=2,missing=1
	1:leaf=-0.058279
	2:leaf=0.201977
booster[11]:
0:[f30<0.5] yes=1,no=2,missing=1
	1:leaf=-0.0583675
	2:leaf=0.178016

Remember that this is a 6-class problem and num_round was set to 2. Of these 12 trees, the first round produces 6, booster[0] through booster[5], one per class, and the second round produces another 6, booster[6] through booster[11]. The n-th tree of the second round continues learning on top of the n-th tree of the first round; the leaf outputs of the two trees for class n are summed, and passing the 6 per-class sums through the softmax function gives the predicted probability of class n. So now it should be clear how xgboost trains a multiclass model: in every boosting round it fits 6 trees, one per class. If softmax is unfamiliar, see my post on the softmax function and loss derivation for multiclass problems.
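
To see this numerically, you can ask predict() for the raw per-class scores (the summed leaf values of each class's trees, plus the global base_score) and apply softmax yourself. A minimal sketch using the bst trained above:

# raw per-class scores: for each class, the sum of the leaf values of its
# 2 trees (one per round) plus the base_score (0.5 by default)
margins = bst.predict(xg_test, output_margin=True).reshape(test_Y.shape[0], 6)

# softmax turns the 6 scores of each sample into class probabilities
exp_m = np.exp(margins - margins.max(axis=1, keepdims=True))
probs = exp_m / exp_m.sum(axis=1, keepdims=True)

# the most probable class is exactly what multi:softmax predicts
print(np.all(np.argmax(probs, axis=1) == bst.predict(xg_test)))   # True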

The formula derivation behind xgboost multiclass classification is not covered in this post; I plan to write a separate post on that.

That's it for this post. If anything here is wrong, please leave a comment.

Reposted from blog.csdn.net/qq_32103261/article/details/116748189