《机器学习实战》——第9章 树回归

第8章介绍的线性回归包含了一些强大的方法,但这些方法创建的模型需要拟合所有的样本点(局部加权线性回归除外)。当数据拥有众多特征并且特征之间关系十分复杂时,构建全局模型的想法就显得太难了,也略显笨拙。而且, 实际生活中很多问题都是非线性的,不可能使用全局线性模型来拟合任何数据。
一种可行的方法是将数据集切分成很多分易建模的数据,然后利用第8章的线性回归技术来建模。如果首次切分后仍然难以拟合线性模型就继续切分。在这种切分方式下,树结构和回归法相当有用。
CART(分类回归树)的树构建算法既可以用于分类还可以用于回归,因此非常值得学习。然后利用Python来构建并显示CART树,代码会保持足够的灵活性以便能用于多个问题当中。接着利用CART算法构建回归树并介绍其中的树剪枝技术(该技术的主要目的是防止树的过拟合)。之后引入了一个更高级的模型树算法。与回归树的做法(在每个叶节点上使用各自的均值做预测)不同,该算法需要在每个叶节点上都构建出一个线性模型。在这些树的构建算法中有一些需要调整的参数,所以还会介绍如何使用Python中Tkinter模块建立图形交互界面。最后,在该界面的辅助下分析参数对回归效果的影响。

9.1 复杂数据的局部性建模

树回归
优点:可以对复杂和非线性的数据建模。
缺点:结果不易理解。
使用数据类型:数值型和标称型数据。

CART是十分著名且广泛记载的树构建算法,它使用二元切分来处理连续型变量。对CART稍作修改就可以处理回归问题。曾经我们使用香农熵来度量集合的无组织程度。如果选用其他方法来代替香农熵,就可以使用树构建算法来完成回归。
下面将实现CART算法和回归树。回归树与分类树的思路类似,但叶节点的数据类型不是离散型,而是连续型。

9.2 连续和离散型特征的树的构建

在树的构建过程中,需要解决对哦中类型数据的存储问题。我们将使用一部字典来存储树的数据结构,该字典将包含以下4个元素:

  • 待切分的特征。
  • 待切分的特征值。
  • 右子树。当不再需要切分的时候,也可以是单个值。
  • 左子树。与右子树类似。

这与第3章的树结构有一点不同。第3章用一部字典来存储每个切分,但该字典可以包含两个或两个以上的值。而CART算法只做二元切分,所以这里可以固定树的数据结构。树包含左键和右键,可以存储另一棵子树或者单个值。字典还包含特征和特征值这两个键,它们给出切分算法所有的特征和特征值。当然,读者可以用面向对象的编程模式来建立这个数据结构。例如,可以用下面的Python代码来建立树节点:

class treeNode():
def __int__(self,feat,val,right,left):
featureToSplitOn = feat
valueOfSplit = val
rightBranch = right
leftBranch = left

当使用C++这样不太灵活的编程语言时,你可能要用面向对象编程模式来实现树结构。Python具有足够的灵活性,可以直接使用字典来存储树结构而无须另外自定义一个类,从而有效地减少代码量。Python不是一种强类型编程语言,因此接下来会看到,树的每个分枝还可以再包含其他树、数值型数据甚至是向量。
本章将构建两种树:第一种是9.4节的回归树,其每个叶节点包含单个值;第二种是9.5节的模型树,其每个叶节点包含一个线性方程。创建这两种树时,我们将尽量使得代码之间可以重用。下面先给出两种树构建算法中的一些共用代码。
函数 createTree() 的伪代码大致如下:

新建文件regTrees.py并添加如下代码:

from numpy import *

def loadDataSet(fileName):      #general function to parse tab -delimited floats
    dataMat = []                #assume last column is target value
    fr = open(fileName)
    for line in fr.readlines():
        curLine = line.strip().split('\t')
        fltLine = map(float,curLine) #map all elements to float()
        dataMat.append(fltLine)
    return dataMat

def binSplitDataSet(dataSet, feature, value):
    mat0 = dataSet[nonzero(dataSet[:,feature] > value)[0],:]
    mat1 = dataSet[nonzero(dataSet[:,feature] <= value)[0],:]
    return mat0,mat1

def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):#assume dataSet is NumPy Mat so we can array filtering
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)#choose the best split
    if feat == None: return val #if the splitting hit a stop condition return val
    retTree = {}
    retTree['spInd'] = feat
    retTree['spVal'] = val
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    retTree['left'] = createTree(lSet, leafType, errType, ops)
    retTree['right'] = createTree(rSet, leafType, errType, ops)
    return retTree  

loadDataSet()函数与其他章节中同名函数功能类似。在前面的章节中,目标变量会单独存放其自己的列表中,但这里的数据会存放在一起。该函数读取一个以tab键位分隔符的文件,然后将每行的内容保存成一组浮点数。
函数 binSplitDataSet() 有3个参数:数据集合、待切分的特征和该特征的某个值。在给定特征和特征值的情况下,该函数通过数组过滤方式将上述数据集合切分得到两个子集并返回。
最后一个函数是树构建函数 creatTree(),它有4个参数:数据集合其他3个可选参数。这些可选参数决定了树的类型:leafType给出建立叶节点的函数;errType代表误差计算函数;而ops是一个包含树构建所需其他参数的元组。
函数 creatTree() 是一个递归函数。该函数首先尝试将数据集分成两个部分,切分由函数chooseBestSplit() 完成。如果满足停止条件,chooseBestSplit() 将返回None和某类模型的值。如果构建的是回归树,该模型是一个常数。如果是模型树,其模型是一个线性方程。后面会看到停止条件的作用方式。如果不满足停止条件,chooseBestSplit() 将创建一个新的Python字典并将数据集分成两份,在这两份数据集上将分别继续递归调用 createTree() 函数。
chooseBestSplit() 现在暂时无法实现,所以还无法看到 createTree() 的实际效果。但可以先测试其他两个函数的效果。

import regTrees
from numpy import *
# 创建4阶对角矩阵
testMat = mat(eye(4))
mat0,mat1 = regTrees.binSplitDataSet(testMat, 1, 0.5)
print(mat0)
print(mat1)

9.3 将CART算法用于回归

9.3.1 构建树

构建回归树需要补充代码使得函数 createTree()  得以运转。给定某个误差计算方法,该函数会找到数据集上最佳的二元切分方式。另外,该函数还要确定什么时候停止切分,一旦停止切分会生成一个叶节点。因此,函数chooseBestSplit() 只需完成两件事:用最佳方式切分数据集和生成相应的叶节点。
函数中可以看出,除了数据集以外,函数 chooseBestSplit() 还有leafType、errType和ops这三个参数。其中leafType是对创建叶节点的函数的引用,errType是对前面介绍的总方差计算函数的引用,而ops是一个用户定义的参数构成的元组,用以完成数的构建。
chooseBestSplit() 函数的目标是找到数据集切分的最佳位置。它遍历所有的特征及其可能的取值来找到使误差最小化的切分阈值。该函数的伪代码大致如下:

在regTrees.py文件中加入下列代码:

def regLeaf(dataSet):
    #对最后一列所有元素求均值
    return mean(dataSet[:,-1])

#求总方差
def regErr(dataSet):
    #目标变量的平方误差 * 样本个数(行数)的得到总方差
    return var(dataSet[:,-1]) * shape(dataSet)[0]

def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    tolS = ops[0]; tolN = ops[1]
    #if all the target variables are the same value: quit and return value
    if len(set(dataSet[:,-1].T.tolist()[0])) == 1: #exit cond 1
        return None, leafType(dataSet)
    m,n = shape(dataSet)
    #the choice of the best feature is driven by Reduction in RSS error from mean
    S = errType(dataSet)
    bestS = inf; bestIndex = 0; bestValue = 0
    for featIndex in range(n-1):
        for splitVal in set((dataSet[:,featIndex].T.A.tolist())[0]):# tolist将数组和矩阵转化为列表
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    #if the decrease (S-bestS) is less than a threshold don't do the split
    if (S - bestS) < tolS:
        return None, leafType(dataSet) #exit cond 2
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):  #exit cond 3
        return None, leafType(dataSet)
    return bestIndex,bestValue#returns the best feature to split on
                              #and the value used for that split

第一个函数 regLeaf() ,它负责生成叶节点。当 chooseBestSplit() 函数确定不再对数据进行切分时,将调用该 regLeaf() 函数来得到叶节点的模型。在回归树中,该模型其实就是目标变量的均值。
第二个函数是误差估计函数 regErr()。该函数在给定数据上计算目标变量的平方误差。当然也可以先计算出均值,然后计算每个差值再平方。但这里直接调用均方差函数 var() 更加方便。因为这里需要返回的是总方差,所以要用均方差乘以数据集中样本的个数。
第三个函数是 chooseBestSplit() ,它是回归树构建的核心函数。该函数的目的是找到数据的最佳二元切分方式。如果找不到一个“好”的二元切分,该函数返回None并同时调用 createTree() 方法来产生叶节点,叶节点的值也将返回None。接下来将会看到,在函数 chooseBestSplit() 中有三种情况不会切分,而是直接创建叶节点。如果找到了一个“好”的切分方式,则返回特征编号和切分特征值。
函数 chooseBestSplit() 一开始为ops设定了tolS和tolN这两个值。它们是用户指定的参数,用于控制函数的停止时机。其中变量tolS是容许的误差下降值,tolN是切分的最少样本数。接下来通过对当前所有目标变量建立一个集合,函数 chooseBestSplit() 会统计不同剩余特征值的数目。如果该数目为1,那么久不需要再切分而直接返回。然后函数计算了当前数据集的大小和误差。该误差S将用于与新切分误差进行对比,来检查新切分能否降低误差。
这样,用于找到最佳切分的几个变量就被建立和初始化了。下面就将在所有可能的特征及其可能取值上遍历,找到最佳的切分方式。最佳切分也就是使得切分后能达到最低误差的切分。如果切分数据集后效果提升不够大,那么就不应进行切分操作而直接创建叶节点。另外还需要检查两个切分后的子集大小,如果某个子集的大小小于用户定义的参数tolN,那么也不应切分。最后,如果这些提前终止条件都不满足,那么就返回切分特征和特征值。

9.3.2 运行代码

import regTrees
from numpy import *
myDat = regTrees.loadDataSet('ex00.txt.')
myMat = mat(myDat)
print(regTrees.createTree(myMat))

import regTrees
from numpy import *
myDat1 = regTrees.loadDataSet('ex0.txt')
myMat1 = mat(myDat1)
print(regTrees.createTree(myMat1))

9.4 树剪枝

一棵树如果节点过多,表明该模型可能对数据进行了“过拟合”。之前的章节中使用了测试集上某种交叉验证技术来发现过拟合,决策树亦是如此。
通过降低决策树的复杂程度来避免过拟合的过程称为剪枝。其实本章前面已经进行过剪枝处理。在函数 chooseBestSplit() 中的提前终止条件,实际上是在进行一种所谓的预剪枝操作。另一种形式的剪枝需要使用测试集和训练集,称作后剪枝。

9.4.1 预剪枝

上节两个简单实验的结果差强人意,但背后存在一些问题。树构建算法其实对输入的参数tolS和tolN非常敏感,如果使用其他值将不太容易达到这么好的效果。运行下面的代码:

import regTrees
from numpy import *
myDat = regTrees.loadDataSet('ex00.txt')
myMat = mat(myDat)
myDat1 = regTrees.loadDataSet('ex0.txt')
myMat1 = mat(myDat1)
print(regTrees.createTree(myMat,ops = (0,1)))

与上节只包含两个节点的树相比,这里构建的树过于臃肿,它甚至为数据集中每个样本都分配了一个叶节点。
接下来使用新数据构建一颗新的树:

import regTrees
from numpy import *
myDat = regTrees.loadDataSet('ex00.txt')
myMat = mat(myDat)
myDat1 = regTrees.loadDataSet('ex0.txt')
myMat1 = mat(myDat1)
myDat2 = regTrees.loadDataSet('ex2.txt')
myMat2 = mat(myDat2)
print(regTrees.createTree(myMat2))

叶节点数量的差异,原因在于停止条件tolS对误差的数量级十分敏感。如果在选项中花费时间并对上述误差容忍度取平方值,或许也能得到仅有两个叶节点组成的树:

import regTrees
from numpy import *
myDat = regTrees.loadDataSet('ex00.txt')
myMat = mat(myDat)
myDat1 = regTrees.loadDataSet('ex0.txt')
myMat1 = mat(myDat1)
myDat2 = regTrees.loadDataSet('ex2.txt')
myMat2 = mat(myDat2)
print(regTrees.createTree(myMat2,ops=(1000,4)))

然而,通过不断修改停止条件来得到合理结果并不是很好的办法。事实上,我们常常甚至不确定到底需要寻找什么样的结果。这正是机器学习所关注的内容,计算机应该可以给出总体的概貌。

9.4.2 后剪枝

使用后剪枝方法需要将数据集分成测试集和训练集。首先指定参数,使得构建出的树足够大、足够复杂,便于剪枝。接下来从上而下找到叶节点,用测试集来判断将这些叶节点合并是否能降低测试误差,如果是的话就合并。
函数 prune() 的伪代码如下:

添加下列代码到regTr.py中:

def isTree(obj):
    return (type(obj).__name__ == 'dict')

def getMean(tree):
    if isTree(tree['right']): tree['right'] = getMean(tree['right'])
    if isTree(tree['left']): tree['left'] = getMean(tree['left'])
    return (tree['left'] + tree['right']) / 2.0

def prune(tree, testData):
    if shape(testData)[0] == 0: return getMean(tree)  # if we have no test data collapse the tree
    if (isTree(tree['right']) or isTree(tree['left'])):  # if the branches are not trees try to prune them
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
    if isTree(tree['left']): tree['left'] = prune(tree['left'], lSet)
    if isTree(tree['right']): tree['right'] = prune(tree['right'], rSet)
    # if they are now both leafs, see if we can merge them
    if not isTree(tree['left']) and not isTree(tree['right']):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
        errorNoMerge = sum(power(lSet[:, -1] - tree['left'], 2)) + \
                       sum(power(rSet[:, -1] - tree['right'], 2))
        treeMean = (tree['left'] + tree['right']) / 2.0
        errorMerge = sum(power(testData[:, -1] - treeMean, 2))
        if errorMerge < errorNoMerge:
            print("merging")
            return treeMean
        else:
            return tree
    else:
        return tree

程序中包含三个函数:isTree()、geyMean()和 prune()。其中i sTree() 用于测试输入变量是否是一棵树,返回布尔类型的结果。即用于判断当前处理的节点是否是叶节点。
函数 getMean() 是一个递归函数,它从上往下遍历树直到叶节点为止。如果找到两个叶节点则计算它们的平均值。该函数对树进行塌陷处理(即返回树平均值),在 prune() 函数中调用该函数时应明确这一点。
prune() 是主函数,它有两个参数:待剪枝的树与剪枝所需的测试数据testData。prune() 函数首先需要确认测试集是否为空。一旦非空,则反复递归调用函数 prune() 对测试数据进行切分。因为树是由其他数据集(训练集)生成的,所以测试集上会有一些样本与原数据集样本的取值范围不同。一旦出现这种情况应当怎么办?数据发生过拟合应该进行剪枝吗?或者模型正确不需要任何剪枝?这里假设发生了过拟合,从而对树进行剪枝。
接下来要检查某个分支到底是子树还是节点。如果是子树,就调用函数 prune() 来对该子树进行剪枝。在对左右两个分支完成剪枝之后,还需要检查它们是否仍然还是子树。如果两个分支已经不再是子树,那么就可以进行合并。具体做法是对合并前后的误差进行比较。如果合并后的误差比不合并的误差小就进行合并操作,反之则不合并直接返回。

import regTrees
from numpy import *
myTree = regTrees.createTree(myMat2,ops=(0, 1))
myDatTest = regTrees.loadDataSet('ex2test.txt')
myMat2Test = mat(myDatTest)
print(regTrees.prune(myTree,myMat2Test))

可以看到,大量的节点已经被剪枝掉了,但没有像预期的那样剪枝成两部分,这说明后剪枝可能不如预剪枝有效。一般地,为了寻求最佳模型可以同时使用两种剪枝技术。
下节将重用部分已有的树构建代码来创建一种新的树。该树仍采用二元切分,但叶节点不再是简单的数值,取而代之的是一些线性模型。

9.5 模型树

用树来对数据建模,除了把叶节点简单地设定为常数值之外,还有一种方法是把叶节点设定为分段线性函数,这里所谓的分段线性是指模型由多个线性分段组成。下图给出一个例子,使用两条直线拟合显然比使用一组常数来建模更好。因为数据集里的一部分数据以某个线性模型建模,而另一部分数据则以另一个线性模型建模,因此我们说采用了所谓的分段线性模型。


下一个问题就是,为了找到最佳切分,应该怎样计算误差呢?前面用于回归树的误差计算方法这里不能再用。稍加变化,对于给定的数据集,应该先用线性的模型来对它进行拟合,然后计算真实的目标值与预模型预测值间的差值。最后将这些差值的平方求和就得到了所需的误差。 

def linearSolve(dataSet):   #helper function used in two places
    m,n = shape(dataSet)
    X = mat(ones((m,n))); Y = mat(ones((m,1)))#create a copy of data with 1 in 0th postion
    X[:,1:n] = dataSet[:,0:n-1]; Y = dataSet[:,-1]#and strip out Y
    xTx = X.T*X
    if linalg.det(xTx) == 0.0:
        raise NameError('This matrix is singular, cannot do inverse,\n\
        try increasing the second value of ops')
    ws = xTx.I * (X.T * Y)
    return ws,X,Y

def modelLeaf(dataSet):#create linear model and return coeficients
    ws,X,Y = linearSolve(dataSet)
    return ws

def modelErr(dataSet):
    ws,X,Y = linearSolve(dataSet)
    yHat = X * ws
    return sum(power(Y - yHat,2))

函数 linearSolve() 会被其他两个函数调用。其主要功能是将数据集格式化成目标变量Y和自变量X。X和Y用于执行简单的线性回归。另外在这个函数中也应当注意,如果矩阵的逆不存在也会造成程序异常。
第二个函数 modelLeaf() 与函数 regLeaf() 类似,当数据不再需要切分的时候它负责生成叶节点的模型。该函数在数据集上调用 linearSolve() 并返回回归系数ws。
最后一个函数是 modelErr() ,可以在给定的数据集上计算误差。它与函数 regErr() 类似,会被chooseBestSplit() 调用来找到最佳的切分。该函数在数据集上调用 linearSolve() ,之后返回yHat和Y之间的平方误差。

import regTrees
from numpy import *
myMat2 = mat(regTrees.loadDataSet('exp2.txt'))
print(regTrees.createTree(myMat2,regTrees.modelLeaf,regTrees.modelErr,(1,10)))

可以看到,代码以0.285477为界创建了两个模型,而图像的数据实际在0.3处分段。createTree() 生成的两个线性模型分别是 y=3.468+1.1852和 y=0.0016985+11.96477x,与用于生成该数据的真实模型非常接近。该数据实际是由模型 y=3.5+1.0x 和 y=0+12x再加上高斯噪声生成的。下图可以看到前图的数据以及生成的线性模型。

模型树、回归树以及第8章里的其他模型,哪一种模型更好呢?一个比较客观的方法是计算相关系数,也称为R²值。该相关系数可以通过调用NumPy库中的命令 corrcoef(yHat,y, rowvar=0) 来求解,其中yHat是预测值,y是目标变量的实际值。
前一章使用了标准的线性回归法,本章则使用了树回归法,下面将通过实例对二者进行比较,最后用函数 corrcoef() 来分析哪个模型是最优的。

9.6 示例:树回归与标准回归的比较

本节首先给出一些函数,它们可以在树构建好的情况下对给定的输入进行预测,之后利用这些函数来计算三种回归模型的测试误差。
在regTrees.py加入下列代码:

def regTreeEval(model, inDat):
    return float(model)

def modelTreeEval(model, inDat):
    n = shape(inDat)[1]
    X = mat(ones((1, n + 1)))
    X[:, 1:n + 1] = inDat
    return float(X * model)

def treeForeCast(tree, inData, modelEval=regTreeEval):
    if not isTree(tree): return modelEval(tree, inData)
    if inData[tree['spInd']] > tree['spVal']:
        if isTree(tree['left']):
            return treeForeCast(tree['left'], inData, modelEval)
        else:
            return modelEval(tree['left'], inData)
    else:
        if isTree(tree['right']):
            return treeForeCast(tree['right'], inData, modelEval)
        else:
            return modelEval(tree['right'], inData)

def createForeCast(tree, testData, modelEval=regTreeEval):
    m = len(testData)
    yHat = mat(zeros((m, 1)))
    for i in range(m):
        yHat[i, 0] = treeForeCast(tree, mat(testData[i]), modelEval)
    return yHat

对于输入的单个数据点或者行向量,函数 treeForeCast() 会返回一个浮点值。在给定树结构的情况下,对于单个数据点,该函数会给出一个预测值。调用函数 treeForeCast() 时需要指定树的类型,以便在叶节点上能够调用合适的模型。参数modelEval是对叶节点数据进行预测的函数的引用。函数 treeForeCast() 自顶向下遍历整棵树,直到命中叶节点为止。一旦到达叶节点,它就会在输入数据上调用 modelEval() 函数,而该函数的默认值是 regTreeEval( )。
要对回归树叶节点进行预测,就调用函数 regTreeEval() ;要对模型树节点进行预测时,就调用 modelTreeEval() 函数。它们会对输入数据进行格式化处理,在原数据矩阵上增加第0列,然后计算并返回预测值。为了与函数 modelTreeEval() 保持一致,尽管 regTreeEval() 只使用一个输入,但仍保留了两个输入参数。
最后一个函数是 createForCast() ,它会多次调用 treeForeCast() 函数。由于它能够以向量形式返回一组预测值,因此该函数在对整个测试集进行预测时非常有用。下面很快会看到这一点。
下图是一组数据,给出了骑自行车的速度和人的智商之间的关系。下面将基于该数据集建立多个模型并在另一个测试集上进行测试。对应的训练集数据保存在文件bikeSpeedVsIq_train.txt中,而测试集数据保存在文件bikeSpeedVsIq_test.txt中。

下面将为上图数据构建三个模型。

import regTrees
from numpy import *
#创建回归树
trainMat = mat(regTrees.loadDataSet('bikeSpeedVsIq_train.txt'))
testMat = mat(regTrees.loadDataSet('bikeSpeedVsIq_test.txt'))
myTree = regTrees.createTree(trainMat,ops=(1,20))
yHat = regTrees.createForeCast(myTree,testMat[:, 0])
print(corrcoef(yHat,testMat[:,1],rowvar=0)[0,1])#相关系数
#创建模型树
myTree = regTrees.createTree(trainMat,regTrees.modelLeaf,regTrees.modelErr,(1,20))
yHat = regTrees.createForeCast(myTree,testMat[:,0],regTrees.modelTreeEval)
print(corrcoef(yHat,testMat[:,1],rowvar=0)[0,1])

 

R²值越接近1.0越好,所以从上面的结果可以看出,这里模型树的结果比回归树好。下面是标准的线性回归效果验证,使用已实现过的线性方程求解函数 linearSolve():

import regTrees
from numpy import *
trainMat = mat(regTrees.loadDataSet('bikeSpeedVsIq_train.txt'))
ws,X,T = regTrees.linearSolve(trainMat)
print(ws)

为了得到测试集上所有的yHat预测值,在测试数据上循环执行: 

import regTrees
from numpy import *
trainMat = mat(regTrees.loadDataSet('bikeSpeedVsIq_train.txt'))
testMat = mat(regTrees.loadDataSet('bikeSpeedVsIq_test.txt'))
myTree = regTrees.createTree(trainMat,ops=(1,20))
yHat = regTrees.createForeCast(myTree,testMat[:, 0])
ws,X,T = regTrees.linearSolve(trainMat)
for i in range(shape(testMat)[0]):
    yHat[i] = testMat[i,0]*ws[1,0]+ws[0,0]
print(corrcoef(yHat,testMat[:,1],rowvar=0)[0,1])

可以看到,该方法在R²值上的表现上不如上面两种树回归方法。所以,树回归方法在预测复杂数据时会比简单的线性模型更有效。下面将展示如何对回归模型进行定性的比较。

9.7 使用Python的Tkinter库创建GUI

9.7.1 用Tkinter创建GUI

下面的代码能够创建一个小窗口,并显示一行文字。

from tkinter import *
root = Tk()  # 建立窗口
myLabel = Label(root,text = "Hello World")
myLabel.grid()  # 窗口布局
root.mainloop()  # 进入等待与处理窗口事件,使该窗口在众多事件中可以相应鼠标点击、按键和重绘等动作

Tkinter的GUI由一些小部件(Widget)组成。所谓小部件,指的是文本框(Text Box)、按钮(Button)、标签(Label)和复选按钮(Check Button)等对象。在刚才的Hello World例子中,标签myLabel就是其中唯一的小部件。当调用myLabel的.grid()方法时,就等于把myLabel的位置告诉了布局管理器(Geometry Manager)。Tkinter中提供了几种不同的布局管理器,其中的.grid()方法会把小部件安排在一个二维的表格中。用户可以设定每个小部件所在的行列位置。这里没有做任何设定,myLabel会默认显示在0行0列。
下面将所需的小部件集成在一起构建树管理器。建立一个新的Python文件treeExplore.py,并加入下列代码:

from numpy import *
from tkinter import *
import regTrees

def reDraw(tolS,tolN):
    pass

def drawNewTree():
    pass

root=Tk()
Label(root,text = "plot place holder").grid(row=0,columnspan=3)
Label(root,text = "tolN").grid(row=1,column=0)
tolNentry = Entry(root)
tolNentry.grid(row=1,column=1)
tolNentry.insert(0,'10')
Label(root,text="tolS").grid(row=2,column=0)
tolSentry = Entry(root)
tolSentry.grid(row=2,column=1)
tolSentry.insert(0,'1.0')
Button(root,text="ReDraw",command=drawNewTree).grid(row=1,column=2,rowspan=3)
chkBtnVar = IntVar()
chkBtn = Checkbutton(root,text="Model Tree",variable=chkBtnVar)
chkBtn.grid(row=3,column=0,columnspan=2)
reDraw.rawDat = mat(regTrees.loadDataSet('sine.txt'))
reDraw.testDat = arange(min(reDraw.rawDat[:,0]),max(reDraw.rawDat[:,0]),0,0.01)
reDraw(1.0,10)
root.mainloop()

上述代码建立了一组Tkinter模块,并用网格布局管理器安排了它们的位置,这里还给出了两个绘制占位符(plot placeholder)函数,这两个函数的内容会在后面补充。代码格式与之前一致,首先创建一个Tk类型的根部件然后插入标签。我们可以使用.grid()方法设定行和列的位置。另外,也可以通过设定columnspan和rowspan的值来告诉布局管理器是否允许一个小部件跨行或跨列。除此之外还有其他设置项可供使用。
还有一些新的小部件暂时未使用到,这些小部件包括文本输入框(Entry)、复选按钮(Check-button)和按钮整数值(IntVar)等。其中Entry部件是一个允许单行文本输入的文本框。Checkbutton和IntVar的功能:为了读取CheckButton的状态需要创建一个变量,也就是IntVar。
最后初始化一些与 reDraw() 关联的全局变量,这些变量会在后面用到。退出可以通过右上角关闭整个窗口,或是通过下面的代码添加退出按钮:

Button(root,text='Quit',fg="black",command=root.quit).grid(low=1,column=2)

运行代码可看到如下图:

9.7.2 集成Matplotlib和Tkinter

Matplotlib的构建程序包含一个前端,也就是面向用户的一些代码,如plot()和scatter()方法等。事实上,它同时创建了一个后端,用于实现绘图和不同应用之间接口。通过改变后端可以将图像绘制在PNG、PDF、SVG等格式的文件上。下面将设置后端为TkAgg(Agg是一个C++的库,可以从图像创建光栅图)。TkAgg可以在所选GUI框架上调用Agg,把Agg呈现在画布上。我们可以在Tk的GUI上放置一个画布,并用.grid()来调整布局。
先用画布来替换绘制占位符,删掉对应标签并添加以下代码:

reDraw.f = Figure(figsize=(5,4), dpi=100) #create canvas
reDraw.canvas = FigureCanvasTkAgg(reDraw.f, master=root)
reDraw.canvas.draw()
reDraw.canvas.get_tk_widget().grid(row=0, columnspan=3)

现在将树创建函数与该画布链接起来。打开treeEx.py并添加下面的代码。注意之前实现过reDraw()和drawTree()的存根(stub),确保同一个函数不要重复出现。

def reDraw(tolS, tolN):
    reDraw.f.clf()  # clear the figure
    reDraw.a = reDraw.f.add_subplot(111)
    if chkBtnVar.get():
        if tolN < 2: tolN = 2
        myTree = regTrees.createTree(reDraw.rawDat, regTrees.modelLeaf, \
                                     regTrees.modelErr, (tolS, tolN))
        yHat = regTrees.createForeCast(myTree, reDraw.testDat, \
                                       regTrees.modelTreeEval)
    else:
        myTree = regTrees.createTree(reDraw.rawDat, ops=(tolS, tolN))
        yHat = regTrees.createForeCast(myTree, reDraw.testDat)
    reDraw.a.scatter(reDraw.rawDat[:, 0], reDraw.rawDat[:, 1], s=5)  # use scatter for data set
    reDraw.a.plot(reDraw.testDat, yHat, linewidth=2.0)  # use plot for yHat
    reDraw.canvas.show()

def getInputs():
    try:
        tolN = int(tolNentry.get())
    except:
        tolN = 10
        print("enter Integer for tolN")
        tolNentry.delete(0, END)
        tolNentry.insert(0, '10')
    try:
        tolS = float(tolSentry.get())
    except:
        tolS = 1.0
        print("enter Float for tolS")
        tolSentry.delete(0, END)
        tolSentry.insert(0, '1.0')
    return tolN, tolS

def drawNewTree():
    tolN, tolS = getInputs()  # get values from Entry boxes
    reDraw(tolS, tolN)

上述程序中一开始导入Matplotlib文件并设定后端为TkAgg。接下来的两个import声明将TkAgg和Matplotlib图链接起来。
当有人点击ReDraw按钮时就会调用 drawNewTree()  函数。函数实现了两个功能:第一,调用getInputs() 方法得到输入框的值;第二,利用该值调用 reDraw() 方法生成一个漂亮的图。
函数 getInputs() 试图理解用户的输入并防止程序崩溃。其中tolS期望的输入是浮点数,而tolN期望的输入是整数。为了得到用户输入的文本,可以在Entry部件上调用.get()方法。虽然表单验证会在GUI编程时花费大量的时间,但这一点对于用户体验来说必不可少。另外,这里使用了try:和except:模式。如果Python可以把输入文本解析成整数就继续执行,如果不能识别则输出错误消息,同时清空输入框并恢复其默认值。对tolS而言也存在同样的处理过程,最后返回输入值。
函数 reDraw() 的主要目的是把树绘制出来。该函数假定输入是合法的,它首先要做的是清空之前的图像,使得前后两个图像不会重叠。清空时图像的各个子图也都会被清除,所以需要重新添加一个新图。接下来函数会检查复选框是否被选中。根据复选框是否被选中,确定基于tolS和tolN参数构建模型树还是回归树。当树构建完成之后就对测试集testDat进行预测,该测试集与训练集有相同的范围且点的分布均匀。最后,真实数据和预测值都被绘制出来。具体实现是,真实值采用scatter() 方法绘制,而预测值则采用 plot() 方法绘制,这是因为 scatter() 方法构建的是离散型散点图,而 plot() 方法则构建连续曲线。

为构建尽可能大的树,应当将tolN设为1,tolS设为0。

猜你喜欢

转载自blog.csdn.net/fjyalzl/article/details/126942356