# MLiA笔记_树回归 — MLiA notes: tree regression (CART), Machine Learning in Action Ch. 9

#-*-coding:utf-8-*-

from numpy import *

# 9.1 CART算法的实现代码
# createTree()树构建函数(数据集,其他三个可选参数:建立叶结点的函数、误差计算函数、包含树构建所需其他参数的元组)
# 是一个递归函数
def loadDataSet(fileName):
    """Load a tab-delimited numeric data file.

    fileName: path to a text file, one sample per line, fields separated
    by tabs.
    Returns a list of rows, each row a list of floats.
    """
    dataMat = []
    # 'with' guarantees the file handle is closed even if parsing fails.
    with open(fileName) as fr:
        for line in fr.readlines():
            curLine = line.strip().split('\t')
            # list(...) is required on Python 3, where map() returns a lazy
            # iterator; appending the iterator object would break callers
            # that expect a list of floats per row.
            fltLine = list(map(float, curLine))
            dataMat.append(fltLine)
    return dataMat

# 参数:数据集合,待切分的特征,和该特征的某个值
# 在给定特征和特征值的情况下,该函数通过数组过滤方式将上述数据集合切分得到两个子集并返回
def binSplitDataSet(dataSet, feature, value):
    """Binary-split a data matrix on one feature via array filtering.

    Rows whose `feature` column is strictly greater than `value` form the
    first returned matrix; all remaining rows form the second.
    """
    upperRows = nonzero(dataSet[:, feature] > value)[0]
    lowerRows = nonzero(dataSet[:, feature] <= value)[0]
    return dataSet[upperRows, :], dataSet[lowerRows, :]

# 9.2 回归树的切分函数
# regLeaf()函数,负责生成叶节点,
# 当chooseBestSplit()切分函数确定不再对数据进行切分时,调用该函数来得到叶节点的模型。
# 在回归树中,该模型其实就是目标变量的均值
def regLeaf(dataSet):
    """Regression-tree leaf model: the mean of the target (last) column."""
    targetCol = dataSet[:, -1]
    return mean(targetCol)

# regErr()函数,在给定数据上计算目标变量的平方误差。
# 直接调用var()函数(当然也可以先计算出均值,然后计算每个差值在平方)
def regErr(dataSet):
    """Total squared error of the target column.

    Population variance times the sample count equals the sum of squared
    deviations from the mean.
    """
    numRows = shape(dataSet)[0]
    return var(dataSet[:, -1]) * numRows


# 首先由chooseBestSplit()切分函数将数据分成两部分,chooseBestSplit()切分函数将返回none值和某类模型的值
# 如果找不到一个“好”的二元切分,该函数返回none的同时调用createTree()树构建函数来产生叶节点,叶节点的值也将返回none
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    """Find the best feature/value to binary-split dataSet on.

    Returns (featureIndex, splitValue) when a worthwhile split exists,
    otherwise (None, leafModel). ops = (tolS, tolN): the minimum error
    reduction required to accept a split, and the minimum number of samples
    each side of a split must keep.
    """
    tolS, tolN = ops
    # Exit 1: all target values are identical -> nothing to gain, make a leaf.
    if len(set(dataSet[:, -1].T.tolist()[0])) == 1:
        return None, leafType(dataSet)
    m, n = shape(dataSet)
    # Error of the unsplit set; the baseline a candidate split must beat.
    S = errType(dataSet)
    bestS, bestIndex, bestValue = inf, 0, 0
    # Try every feature (all columns except the target) ...
    for featIndex in range(n - 1):
        # ... at every distinct value it takes in the data.
        for splitVal in set(dataSet[:, featIndex].T.A.tolist()[0]):
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            # Reject splits that leave too few samples on either side.
            if shape(mat0)[0] < tolN or shape(mat1)[0] < tolN:
                continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS:
                bestIndex, bestValue, bestS = featIndex, splitVal, newS
    # Exit 2: the improvement is too small to justify splitting.
    if (S - bestS) < tolS:
        return None, leafType(dataSet)
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    # Exit 3: the winning split still produces an undersized subset.
    if shape(mat0)[0] < tolN or shape(mat1)[0] < tolN:
        return None, leafType(dataSet)
    # A "good" split was found: report which feature and which value.
    return bestIndex, bestValue

# createTree()树构建函数(数据集,其他三个可选参数:建立叶结点的函数、误差计算函数、包含树构建所需其他参数的元组)
# 是一个递归函数
def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    """Recursively build a CART tree from a data matrix.

    dataSet: numpy matrix, target variable in the last column.
    leafType / errType: leaf-model and error functions (regression or model
    tree variants). ops: (tolS, tolN) stopping parameters forwarded to
    chooseBestSplit.
    Returns either a leaf model or a dict node with keys
    'spInd', 'spVal', 'left', 'right'.
    """
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    # chooseBestSplit signals "stop splitting" by returning feat=None; val
    # then holds the leaf model (a constant for regression trees, a weight
    # vector for model trees).
    if feat is None:  # 'is None' is the correct identity test, not '== None'
        return val
    retTree = {'spInd': feat, 'spVal': val}
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    retTree['left'] = createTree(lSet, leafType, errType, ops)
    retTree['right'] = createTree(rSet, leafType, errType, ops)
    return retTree


# 树剪枝
# 9.3 回归树剪枝函数
# isTree()用于测试输入变量是否是一棵树,返回布尔类型的结果
def isTree(obj):
    """Return True when obj is an internal tree node (a dict), False for a leaf.

    isinstance() is the idiomatic type test and, unlike comparing the class
    name string, also accepts dict subclasses.
    """
    return isinstance(obj, dict)

# getMean()是一个递归函数,从上往下遍历树直到叶节点为止
# 如果找到两个叶节点则计算它们的平均值
# 该函数对树进行塌陷处理(即返回树平均值),在prune()函数中调用该函数时应明确这一点
def getMean(tree):
    """Collapse a subtree into a single value (recursive mean of its leaves).

    Works top-down until leaves are reached, then averages pairs of leaf
    values upward. Mutates the tree in place: each dict child is replaced by
    its collapsed mean before the final average is returned.
    """
    for side in ('right', 'left'):
        if isTree(tree[side]):
            tree[side] = getMean(tree[side])
    return (tree['left'] + tree['right']) / 2.0


# 主函数是prune()函数,参数:待剪枝的树与剪枝所需的测试数据testData
def prune(tree, testData):
    """Post-prune `tree` using a held-out test data matrix.

    Recursively descends the tree, partitioning testData to match each split.
    When both children of a node are leaves, compares test error with and
    without merging them and merges (returns their mean) when that lowers
    the error. Returns the possibly-pruned subtree or a single leaf value.
    """
    # No test samples reach this subtree: collapse it to its mean value.
    if shape(testData)[0] == 0:
        return getMean(tree)
    # Split the test data the same way this node splits the training data.
    if isTree(tree['right']) or isTree(tree['left']):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
    if isTree(tree['left']):
        tree['left'] = prune(tree['left'], lSet)
    if isTree(tree['right']):
        tree['right'] = prune(tree['right'], rSet)
    # Both children are leaves: decide whether merging them reduces error.
    if not isTree(tree['left']) and not isTree(tree['right']):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
        errorNoMerge = sum(power(lSet[:, -1] - tree['left'], 2)) + \
                       sum(power(rSet[:, -1] - tree['right'], 2))
        treeMean = (tree['left'] + tree['right']) / 2.0
        # BUG FIX: the merged error must be measured against the target
        # column (index -1); the original used column 1, which is only the
        # target by coincidence on two-column data.
        errorMerge = sum(power(testData[:, -1] - treeMean, 2))
        if errorMerge < errorNoMerge:
            print("merging")
            return treeMean
        else:
            return tree
    else:
        return tree


# 9.4 模型树的叶节点生成函数
# linearSolve()函数,会被其他两个函数调用
# 其主要功能是将数据集格式化成目标变量Y和自变量X
def linearSolve(dataSet):
    """Format dataSet into X and Y, then solve ordinary least squares.

    dataSet: numpy matrix; features in the first n-1 columns, target in the
    last column.
    Returns (ws, X, Y): the regression weights, the design matrix (column 0
    is all ones for the intercept), and the target column vector.
    Raises NameError when X^T X is singular (kept as NameError, not a more
    appropriate type, for backward compatibility with existing callers).
    """
    m, n = shape(dataSet)
    X = mat(ones((m, n)))           # column 0 stays 1.0 -> intercept term
    X[:, 1:n] = dataSet[:, 0:n-1]   # remaining columns hold the features
    # (The original also built a throwaway ones((m,1)) matrix for Y that was
    # immediately overwritten; that dead allocation is removed.)
    Y = dataSet[:, -1]
    xTx = X.T * X
    if linalg.det(xTx) == 0.0:
        raise NameError("this matrix is singular, cannot do inverse,try increasing the second value of ops")
    ws = xTx.I * (X.T * Y)
    return ws, X, Y


# modelLesf()函数,当数据不在需要切分时它负责乘胜叶节点的模型
def modelLesf(dataSet):
    """Model-tree leaf generator: fit a linear model, keep only the weights.

    NOTE(review): the name looks like a typo for 'modelLeaf'; it is kept
    unchanged because it is the public name callers pass to createTree.
    """
    ws, _, _ = linearSolve(dataSet)
    return ws

# modelErr()函数,在给定的数据集上计算误差,会被chooseBestSplit()函数调用来找到最佳的切分
def modelErr(dataSet):
    """Squared error of a linear model fit on dataSet.

    Used by chooseBestSplit (as errType) to rank candidate splits for
    model trees.
    """
    ws, X, Y = linearSolve(dataSet)
    residual = Y - X * ws
    return sum(power(residual, 2))


# 9.5 用树回归进行预测的代码
# 要对回归树叶节点进行预测,就调用regTreeEval()函数,对输入数据进行格式化处理
def regTreeEval(model, inDat):
    """Evaluate a regression-tree leaf: the model is just a constant.

    inDat is unused here; the parameter exists only so that regTreeEval and
    modelTreeEval share one calling convention inside treeForeCast.
    """
    prediction = float(model)
    return prediction

# 要对模型树叶节点进行预测,就调用modelTreeEval()函数,对输入数据进行格式化处理
def modelTreeEval(model, inDat):
    """Evaluate a model-tree leaf: apply the linear model to the input row.

    model: (n+1) x 1 weight column vector (intercept first).
    inDat: 1 x n row matrix of feature values.
    Returns the scalar prediction.
    """
    n = shape(inDat)[1]
    # BUG FIX: ones() takes a shape *tuple*; ones(1, n+1) raised a TypeError
    # because the second argument was interpreted as a dtype.
    X = mat(ones((1, n + 1)))
    X[:, 1:n+1] = inDat  # column 0 keeps the 1.0 intercept term
    return float(X * model)

# 对于输入的单个数据点或者行向量,函数treeForeCast()会返回一个浮点值。
# 在给定树结构的情况下,对于单个数据点,该函数会给出一个预测值。
# 调用函数treeForeCast()时需要指定树的类型,以便在叶节点上能够调用合适的模型
def treeForeCast(tree, inData, modelEval=regTreeEval):
    """Predict the value of a single data row by walking the tree.

    tree: dict node or leaf model; inData: 1 x n row matrix; modelEval
    selects the leaf convention (regTreeEval for regression trees,
    modelTreeEval for model trees). Returns a float prediction.
    """
    if not isTree(tree):
        return modelEval(tree, inData)
    # BUG FIX: inData is a matrix, not a callable -- the original
    # inData(tree['spInd']) raised TypeError; index the row instead.
    if inData[0, tree['spInd']] > tree['spVal']:
        branch = tree['left']
    else:
        branch = tree['right']
    if isTree(branch):
        return treeForeCast(branch, inData, modelEval)
    return modelEval(branch, inData)

def createForeCast(tree, testData, modelEval=regTreeEval):
    """Predict every row of testData with the tree; return an m x 1 matrix."""
    numRows = len(testData)
    yHat = mat(zeros((numRows, 1)))
    for row in range(numRows):
        yHat[row, 0] = treeForeCast(tree, mat(testData[row]), modelEval)
    return yHat

# 9.6 Tkinter widgets for building the tree-manager GUI
from numpy import *
# BUG FIX: on Python 3 the module is lowercase 'tkinter' ('Tkinter' is the
# Python 2 name; this file's print() calls indicate Python 3).
from tkinter import *
import regTree

def reDraw(tolS, tolN):
    # Placeholder: redraw the plot for the given stopping parameters.
    pass

def drawNewTree():
    # Placeholder: read the entry widgets and trigger a redraw.
    pass

root = Tk()

Label(root, text="Plot Place Holder").grid(row=0, columnspan=2)

# tolN entry (minimum samples per split), pre-filled with 10.
Label(root, text="tolN").grid(row=1, column=0)
tolNentry = Entry(root)
tolNentry.grid(row=1, column=1)
tolNentry.insert(0, '10')
# tolS entry (minimum error reduction), pre-filled with 1.0.
Label(root, text="tolS").grid(row=2, column=0)
tolSentry = Entry(root)
tolSentry.grid(row=2, column=1)
tolSentry.insert(0, '1.0')
Button(root, text="ReDraw", command=drawNewTree).grid(row=1, column=2, rowspan=3)
chkBtnVar = IntVar()
# BUG FIX: the Checkbutton option is spelled 'variable'; the original
# 'varible' would be rejected as an unknown option.
chkBtn = Checkbutton(root, text="model tree", variable=chkBtnVar)
chkBtn.grid(row=3, column=0, columnspan=2)
reDraw.rawDat = mat(regTree.loadDataSet('sine.txt'))
# BUG FIX: the original nested max() inside min(); arange needs three
# separate arguments (start, stop, step).
reDraw.testDat = arange(min(reDraw.rawDat[:, 0]), max(reDraw.rawDat[:, 0]), 0.01)
reDraw(1.0, 10)
root.mainloop()

# (blog-scrape residue, commented out so the file remains valid Python)
# 猜你喜欢 — "you may also like" widget text
# 转载自blog.csdn.net/weixin_42836351/article/details/81393075