决策树转规则


有些决策或分类规则手写起来比较麻烦，但如果改用机器学习模型（比如逻辑回归 LR），又难以运营和理解。这时可以训练一个叶子节点较少的决策树模型，再把它转换成规则，这是一个折中的解决方案。


import numpy as np
from sklearn.tree import DecisionTreeRegressor
from sklearn.tree import _tree

# Load the training set from 'vm06.xy'. Each non-blank line holds
# whitespace-separated columns:
#   v[1]        -> thresholded into the binary target (True if > 60)
#   v[2]..v[5]  -> the four numeric features
# v[0] is not used here.
trainx = []
trainy = []
with open('vm06.xy') as fd:
    for line in fd:
        # split() with no separator collapses runs of spaces/tabs and drops
        # the trailing newline, so no extra strip() is needed on v[5].
        v = line.split()
        if not v:
            # Skip blank lines instead of crashing with IndexError.
            continue
        trainx.append(np.asarray([float(v[2]), float(v[3]), float(v[4]), float(v[5])]))  # features
        trainy.append(float(v[1]) > 60)  # True if v[1] > 60, otherwise False

# A regression tree fitted on a 0/1 target. With the default squared-error
# criterion each leaf predicts the mean target of its training rows, i.e. the
# fraction of those rows with v[1] > 60. max_leaf_nodes keeps the tree, and
# therefore the generated rule set, small.
regressor = DecisionTreeRegressor(max_leaf_nodes=8)
regressor.fit(np.asarray(trainx), np.asarray(trainy))

# Spot-check: compare predictions with the true labels on a slice of the
# training rows.
res = regressor.predict(trainx[39:51])
print(res, trainy[39:51])

def tree_to_code(tree, feature_names):
    """Print the decision rules of a fitted scikit-learn tree as Python source.

    The printed text is a function ``def tree(<feature_names>):`` built from
    nested ``if``/``else`` statements:

    * each internal node becomes ``if <feature> <= <threshold>:`` followed by
      the left subtree, then ``else:`` followed by the right subtree;
    * each leaf becomes ``return <tree_.value[node]>``, i.e. the raw value
      array stored for that leaf.

    All output goes through ``print``; nothing is returned.

    Args:
        tree: a fitted estimator exposing a ``tree_`` attribute (e.g. a
            fitted ``DecisionTreeRegressor``).
        feature_names: feature names indexed by the ids in
            ``tree_.feature``; also used verbatim as the generated
            function's parameter list.
    """
    structure = tree.tree_
    undefined = _tree.TREE_UNDEFINED

    # One label per node. Leaf nodes carry TREE_UNDEFINED as their feature
    # id, so they get a placeholder instead of a (negative) index lookup.
    labels = []
    for feature_id in structure.feature:
        labels.append("undefined!" if feature_id == undefined else feature_names[feature_id])

    print("def tree({}):".format(", ".join(feature_names)))

    def emit(node, level):
        pad = "  " * level
        if structure.feature[node] == undefined:
            # Leaf: emit the stored value array as the return value.
            print("{}return {}".format(pad, structure.value[node]))
            return
        # Internal node: the split, then both subtrees one level deeper.
        label = labels[node]
        cut = structure.threshold[node]
        print("{}if {} <= {}:".format(pad, label, cut))
        emit(structure.children_left[node], level + 1)
        print("{}else:  # if {} > {}".format(pad, label, cut))
        emit(structure.children_right[node], level + 1)

    emit(0, 1)

# Parameter names for the generated rules; their order must match the order
# of the feature columns loaded into trainx.
FEATURE_NAMES = ["length", "width", "height", "fps"]
tree_to_code(regressor, FEATURE_NAMES)

输出


def tree(length, width, height, fps):
  """Sample rules printed by tree_to_code for the 8-leaf regression tree.

  (This docstring is an annotation; tree_to_code does not print it.)
  Each leaf returns the tree's raw ``tree_.value`` array for that leaf. The
  tree was fitted on a boolean target, so each value is the mean of that 0/1
  target over the leaf's training rows, i.e. the share of those rows whose
  v[1] exceeded 60.
  """
  if length <= 205.5:
    if length <= 91.5:
      return [[ 0.00090733]]
    else:  # if length > 91.5
      if width <= 1703.0:
        return [[ 0.02891943]]
      else:  # if width > 1703.0
        return [[ 0.81340058]]
  else:  # if length > 205.5
    if width <= 859.0:
      if length <= 795.0:
        return [[ 0.05918367]]
      else:  # if length > 795.0
        return [[ 0.75434531]]
    else:  # if width > 859.0
      if height <= 702.0:
        if length <= 596.5:
          return [[ 0.12064343]]
        else:  # if length > 596.5
          return [[ 0.93028025]]
      else:  # if height > 702.0
        return [[ 0.892728]]


猜你喜欢

转载自：https://blog.csdn.net/mao_feng/article/details/73920261