Training a decision tree on the Iris flower dataset
A Python implementation with k-fold cross-validation, pre-pruning, and post-pruning; the code is admittedly messy.
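For orientation before reading the code: the tree is stored as a plain nested list [split, left_subtree, right_subtree], where split is a [feature_index, threshold] pair and a fully grown branch ends in a class-name string. Purely as an illustration (these particular thresholds are made up, not output of the program, and assume the usual column order sepal length, sepal width, petal length, petal width):

    # hypothetical tree: split on feature 2 (petal length) at 2.45,
    # then on feature 3 (petal width) at 1.75 in the right branch
    tree = [[2, 2.45],
            'Iris-setosa',
            [[3, 1.75], 'Iris-versicolor', 'Iris-virginica']]
    # label_sample walks this list: it compares test_data[feature_index] with the
    # threshold and descends left (index 1) or right (index 2) until it hits a string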
import random
import math
import copy
import sys


class decisionTree(object):
    def __init__(self, file, label):
        self.file = file
        self.label = label
        self.dataset = self.initDataset()
        self.readfile()

    def initDataset(self):
        dataset = {}
        for i in range(len(self.label)):
            dataset[self.label[i]] = []
        return dataset

    def readfile(self):
        with open(self.file, 'r') as myfile:
            for line in myfile:
                line = line.strip()
                data = line.split(',')
                if data[-1] != '':
                    self.dataset[data[-1]].append(data[:-1])
        for i in range(len(self.label)):
            for j in range(len(self.dataset[self.label[i]])):
                for m in range(len(self.dataset[self.label[i]][j])):
                    self.dataset[self.label[i]][j][m] = float(self.dataset[self.label[i]][j][m])

    # Per-class sample counts, so that each fold keeps the classes balanced
    def k_cross(self, k):
        num = []
        for i in range(len(self.label)):
            num.append(len(self.dataset[self.label[i]]))
        all = 0
        for i in range(len(num)):
            all += num[i]
        for i in range(len(num)):
            num[i] = num[i] / k
        for i in range(len(num)):
            if num[i] - int(num[i]) > 0.5:
                num[i] = int(num[i]) + 1
            else:
                num[i] = int(num[i])
        # draw the k-fold samples
        data = []
        for i in range(k):
            data.append([])
        all_list = []
        for i in range(len(self.label)):
            all_list.append([])
        for i in range(len(self.label)):
            start = 0
            for j in range(len(self.dataset[self.label[i]])):
                all_list[i].append(start)
                start += 1
        for k in range(k):
            for i in range(len(self.label)):
                list = random.sample(all_list[i], num[i])
                # keep the remaining indices as a list so random.sample still works on newer Python
                all_list[i] = sorted(set(all_list[i]) ^ set(list))
                for j in range(len(list)):
                    pop_data = self.dataset[self.label[i]][list[j]]
                    pop_data.append(self.label[i])
                    data[k].append(pop_data)
        return data

    # Return all k folds as [[train, test], [train, test], ...] pairs
    def k_data(self, k):
        dataset = self.k_cross(k)
        data = []
        for i in range(len(dataset)):
            data.append([[], []])
        # build the k [train set, test set] pairs
        for i in range(len(dataset)):
            data[i][1] = dataset[i]
            for j in range(len(dataset)):
                if j != i:
                    for k in range(len(dataset[j])):
                        data[i][0].append(dataset[j][k])
        return data

    # dataset here is the training set you pass in, a list of samples like [[...], [...], ...];
    # returns the custom tree structure
    def train_tree(self, dataset):
        # entropy of a dataset (used for information gain)
        def ent(dataset):
            num = []
            for i in range(len(self.label)):
                num.append(0)
            for i in range(len(dataset)):
                for j in range(len(self.label)):
                    if dataset[i][-1] == self.label[j]:
                        num[j] += 1
            all = 0
            for i in range(len(num)):
                all += num[i]
            ent_data = 0
            for i in range(len(num)):
                if num[i] != 0:
                    ent_data -= num[i] / all * math.log2(num[i] / all)
            return ent_data

        def Gain(dataset, root, key):  # key is the feature index
            def Gain_sub(dataset, root, key, a):  # a is the index of the candidate split point
                ent_data = ent(dataset)
                sub_data_1 = []
                sub_data_2 = []
                for i in range(len(dataset)):
                    if dataset[i][key] < root[key][a]:
                        sub_data_1.append(dataset[i])
                    else:
                        sub_data_2.append(dataset[i])
                gain_data = 0
                gain_data = ent_data - (len(sub_data_1) / len(dataset)) * ent(sub_data_1) - (len(sub_data_2) / len(dataset)) * ent(sub_data_2)
                return gain_data

            gain = []
            for i in range(len(root[key])):
                gain.append(Gain_sub(dataset, root, key, i))
            return max(gain), gain.index(max(gain))

        # Compute the information gain for every feature of this dataset.
        # The result is a nested dict: outer key = feature index, inner key = split point, value = gain.
        def next_opr(dataset):
            # collect the distinct values of each feature
            feature = {}
            for i in range(len(dataset[0]) - 1):
                feature[i] = []
            for i in range(len(dataset)):
                for j in range(len(dataset[i]) - 1):
                    if dataset[i][j] not in feature[j]:
                        feature[j].append(dataset[i][j])
            for i in range(len(feature.keys())):
                feature[i] = sorted(feature[i])
            # candidate split points for the continuous values (midpoints of adjacent values)
            root = {}
            for i in range(len(feature.keys())):
                root[i] = []
            for i in range(len(feature.keys())):
                for j in range(len(feature[i])):
                    if j != len(feature[i]) - 1:
                        root[i].append((feature[i][j] + feature[i][j + 1]) / 2)
            gain = {}
            for i in range(len(root.keys())):
                gain[i] = {}
            for i in range(len(gain.keys())):
                for j in range(len(root[i])):
                    gain_data, k = Gain(dataset, root, i)
                    gain[i][root[i][k]] = gain_data
            return gain

        # The tree is stored as a list [split, left, right]: element 0 is the split, element 1 the
        # left subtree, element 2 the right subtree; a child that is not yet a label stores its dataset.
        def key_root(my_gain):
            # parse the gain structure and return the best split: key = feature index, root = split point
            key_1 = list(my_gain.keys())
            max = 0
            key = 0
            root = 0
            for i in range(len(my_gain)):
                key_2 = list(my_gain[key_1[i]].keys())
                for j in range(len(key_2)):
                    if my_gain[key_1[i]][key_2[j]] > max:
                        max = my_gain[key_1[i]][key_2[j]]
                        key = key_1[i]
                        root = key_2[j]
            return key, root

        tree = []
        my_gain = next_opr(dataset)
        key, root = key_root(my_gain)
        tree.append([key, root])
        tree.append([])
        tree.append([])

        # split the dataset by key / root
        def sub_root(dataset, key, root):
            sub_left = []
            sub_right = []
            for i in range(len(dataset)):
                if dataset[i][key] < root:
                    sub_left.append(dataset[i])
                else:
                    sub_right.append(dataset[i])
            return sub_left, sub_right

        sub_left, sub_right = sub_root(dataset, key, root)
        tree[1] = sub_left
        tree[2] = sub_right

        # check whether all samples in a subtree share the same label
        def test(dataset):
            num = []
            for i in range(len(self.label)):
                num.append(0)
            for i in range(len(dataset)):
                for j in range(len(self.label)):
                    if dataset[i][-1] == self.label[j]:
                        num[j] += 1
            if max(num) == len(dataset):
                return self.label[num.index(max(num))]
            else:
                return dataset

        def next(tree):
            for i in range(len(tree)):
                if i != 0:
                    tree[i] = test(tree[i])
            for i in range(len(tree)):
                if i != 0:
                    if tree[i] not in self.label:
                        dataset = tree[i]
                        tree[i] = []
                        gains = next_opr(dataset)
                        key, root = key_root(gains)
                        tree[i].append([key, root])
                        tree[i].append([])
                        tree[i].append([])
                        left, right = sub_root(dataset, key, root)
                        tree[i][1] = test(left)
                        tree[i][2] = test(right)
                        next(tree[i])

        next(tree)
        return tree

    # train a tree with pre-pruning
    def pre_pruning(self, train_dataset, test_dataset):
        # Compute the information gain for every feature of this dataset.
        # The result is a nested dict: outer key = feature index, inner key = split point, value = gain.
        def next_opr(dataset):
            def Gain(dataset, root, key):  # key is the feature index
                # entropy of a dataset
                def ent(dataset):
                    num = []
                    for i in range(len(self.label)):
                        num.append(0)
                    for i in range(len(dataset)):
                        for j in range(len(self.label)):
                            if dataset[i][-1] == self.label[j]:
                                num[j] += 1
                    all = 0
                    for i in range(len(num)):
                        all += num[i]
                    ent_data = 0
                    for i in range(len(num)):
                        if num[i] != 0:
                            ent_data -= num[i] / all * math.log2(num[i] / all)
                    return ent_data

                def Gain_sub(dataset, root, key, a):  # a is the index of the candidate split point
                    ent_data = ent(dataset)
                    sub_data_1 = []
                    sub_data_2 = []
                    for i in range(len(dataset)):
                        if dataset[i][key] < root[key][a]:
                            sub_data_1.append(dataset[i])
                        else:
                            sub_data_2.append(dataset[i])
                    gain_data = 0
                    gain_data = ent_data - (len(sub_data_1) / len(dataset)) * ent(sub_data_1) - (len(sub_data_2) / len(dataset)) * ent(sub_data_2)
                    return gain_data

                gain = []
                for i in range(len(root[key])):
                    gain.append(Gain_sub(dataset, root, key, i))
                return max(gain), gain.index(max(gain))

            # collect the distinct values of each feature
            feature = {}
            for i in range(len(dataset[0]) - 1):
                feature[i] = []
            for i in range(len(dataset)):
                for j in range(len(dataset[i]) - 1):
                    if dataset[i][j] not in feature[j]:
                        feature[j].append(dataset[i][j])
            for i in range(len(feature.keys())):
                feature[i] = sorted(feature[i])
            # candidate split points for the continuous values (midpoints of adjacent values)
            root = {}
            for i in range(len(feature.keys())):
                root[i] = []
            for i in range(len(feature.keys())):
                for j in range(len(feature[i])):
                    if j != len(feature[i]) - 1:
                        root[i].append((feature[i][j] + feature[i][j + 1]) / 2)
            gain = {}
            for i in range(len(root.keys())):
                gain[i] = {}
            for i in range(len(gain.keys())):
                for j in range(len(root[i])):
                    gain_data, k = Gain(dataset, root, i)
                    gain[i][root[i][k]] = gain_data
            return gain

        # The tree is stored as a list [split, left, right]; a child that is not yet a label stores its dataset.
        def key_root(my_gain):
            # parse the gain structure and return the split with the largest gain:
            # key = feature index, root = split point
            key_1 = list(my_gain.keys())
            max = 0
            key = 0
            root = 0
            for i in range(len(my_gain)):
                key_2 = list(my_gain[key_1[i]].keys())
                for j in range(len(key_2)):
                    if my_gain[key_1[i]][key_2[j]] > max:
                        max = my_gain[key_1[i]][key_2[j]]
                        key = key_1[i]
                        root = key_2[j]
            return key, root

        # split the dataset by key / root
        def sub_root(dataset, key, root):
            sub_left = []
            sub_right = []
            for i in range(len(dataset)):
                if dataset[i][key] < root:
                    sub_left.append(dataset[i])
                else:
                    sub_right.append(dataset[i])
            return sub_left, sub_right

        def max_lable(dataset):
            # return the majority label of a dataset (or the label itself if it is already a label)
            if type(dataset) == type([]):
                num = []
                for i in range(len(self.label)):
                    num.append(0)
                for i in range(len(dataset)):
                    for j in range(len(self.label)):
                        if dataset[i][-1] == self.label[j]:
                            num[j] += 1
                max = 0
                for i in range(len(num)):
                    if num[i] > max:
                        max = num[i]
                return self.label[num.index(max)]
            elif type(dataset) == type('abc'):
                return dataset

        def test_tree(tree, dataset):  # accuracy of this tree on dataset
            def test_label(train_tree, test_data):  # does the tree classify this sample correctly?
                label = None
                if test_data[train_tree[0][0]] < train_tree[0][1]:
                    if train_tree[1] not in self.label:
                        train_tree = train_tree[1]
                        label = test_label(train_tree, test_data)
                    else:
                        label = train_tree[1]
                else:
                    if train_tree[2] not in self.label:
                        train_tree = train_tree[2]
                        label = test_label(train_tree, test_data)
                    else:
                        label = train_tree[2]
                if label == test_data[-1]:
                    return True
                else:
                    return False

            all_num = len(dataset)
            right = 0
            for i in range(len(dataset)):
                if test_label(tree, dataset[i]):
                    right += 1
            return right / all_num

        # initial tree, i.e. before the first root split
        tree = [[0, 0], max_lable(train_dataset), max_lable(train_dataset)]
        tree_data = train_dataset
        # tree after performing the root split
        next_tree = []
        next_tree_data = []
        my_gain = next_opr(train_dataset)
        key, root = key_root(my_gain)
        next_tree.append([key, root])
        next_tree.append([])
        next_tree.append([])
        next_tree_data.append([key, root])
        next_tree_data.append([])
        next_tree_data.append([])
        sub_left, sub_right = sub_root(train_dataset, key, root)
        next_tree[1] = max_lable(sub_left)
        next_tree[2] = max_lable(sub_right)
        next_tree_data[1] = sub_left
        next_tree_data[2] = sub_right

        # check whether all samples in a subtree share the same label
        def test(dataset):
            num = []
            for i in range(len(self.label)):
                num.append(0)
            for i in range(len(dataset)):
                for j in range(len(self.label)):
                    if dataset[i][-1] == self.label[j]:
                        num[j] += 1
            if max(num) == len(dataset):
                return self.label[num.index(max(num))]
            else:
                return dataset

        # take two trees and decide whether the next expansion should be kept
        def pruning(first, first_data, next, next_data, test_dataset):
            def next_tree(tree_data, bool):
                # decide whether the expansion passes the pre-pruning check; tree_data here should be next_data
                if bool:
                    for i in range(len(tree_data)):
                        if i != 0:
                            tree_data[i] = test(tree_data[i])
                    for i in range(len(tree_data)):
                        if i != 0 and type(tree_data[i]) == type([]):
                            next_data = copy.deepcopy(tree_data)  # deep copy so the two structures do not share memory
                            if tree_data[i] not in self.label:
                                dataset = tree_data[i]
                                tree_data[i] = []
                                gains = next_opr(dataset)
                                key, root = key_root(gains)
                                tree_data[i].append([key, root])
                                tree_data[i].append([])
                                tree_data[i].append([])
                                next_data[i] = []
                                next_data[i].append([key, root])
                                next_data[i].append([])
                                next_data[i].append([])
                                left, right = sub_root(dataset, key, root)
                                next_data[i][1] = left
                                next_data[i][2] = right
                                tree_data[i][1] = max_lable(left)
                                tree_data[i][2] = max_lable(right)
                                break
                else:
                    for i in range(len(tree_data)):
                        if i != 0:
                            tree_data[i] = test(tree_data[i])
                    for i in range(len(tree_data)):
                        if i != 0 and type(tree_data[i]) == type([]):
                            next_data = copy.copy(tree_data)
                            dataset = None
                            if type(tree_data[i][1]) == type('str') and type(tree_data[i][2]) == type([]):
                                str = tree_data[i][1]
                                tree_data[i][2].append(str)  # append in place, then use the list (append itself returns None)
                                dataset = tree_data[i][2]
                            elif type(tree_data[i][2]) == type('str') and type(tree_data[i][1]) == type([]):
                                str = tree_data[i][2]
                                tree_data[i][1].append(str)
                                dataset = tree_data[i][1]
                            else:
                                dataset = tree_data[i][1] + tree_data[i][2]
                            next_data[i] = max_lable(dataset)
                            break
                    for i in range(len(tree_data)):
                        if i != 0 and type(tree_data[i]) == type([]):
                            dataset = None
                            if type(tree_data[i][1]) == type('str') and type(tree_data[i][2]) == type([]):
                                tree_data[i][2].append(tree_data[i][1])
                                dataset = tree_data[i][2]
                            elif type(tree_data[i][2]) == type('str') and type(tree_data[i][1]) == type([]):
                                tree_data[i][1].append(tree_data[i][2])
                                dataset = tree_data[i][1]
                            else:
                                dataset = tree_data[i][1] + tree_data[i][2]
                            tree_data[i] = max_lable(dataset)
                return tree_data, next_data

            while (next != next_data):
                if test_tree(first, test_dataset) < test_tree(next, test_dataset):
                    first = next
                    first_data = copy.deepcopy(next_data)
                    next, next_data = next_tree(next_data, True)
                else:
                    next, next_data = next_tree(next_data, False)
            return next

        new_tree = pruning(tree, tree_data, next_tree, next_tree_data, test_dataset)
        return new_tree

    def post_pruning(self, train_dataset, test_dataset):
        # entropy of a dataset (used for information gain)
        def ent(dataset):
            num = []
            for i in range(len(self.label)):
                num.append(0)
            for i in range(len(dataset)):
                for j in range(len(self.label)):
                    if dataset[i][-1] == self.label[j]:
                        num[j] += 1
            all = 0
            for i in range(len(num)):
                all += num[i]
            ent_data = 0
            for i in range(len(num)):
                if num[i] != 0:
                    ent_data -= num[i] / all * math.log2(num[i] / all)
            return ent_data

        def Gain(dataset, root, key):  # key is the feature index
            def Gain_sub(dataset, root, key, a):  # a is the index of the candidate split point
                ent_data = ent(dataset)
                sub_data_1 = []
                sub_data_2 = []
                for i in range(len(dataset)):
                    if dataset[i][key] < root[key][a]:
                        sub_data_1.append(dataset[i])
                    else:
                        sub_data_2.append(dataset[i])
                gain_data = 0
                gain_data = ent_data - (len(sub_data_1) / len(dataset)) * ent(sub_data_1) - (len(sub_data_2) / len(dataset)) * ent(sub_data_2)
                return gain_data

            gain = []
            for i in range(len(root[key])):
                gain.append(Gain_sub(dataset, root, key, i))
            return max(gain), gain.index(max(gain))

        # Compute the information gain for every feature of this dataset.
        # The result is a nested dict: outer key = feature index, inner key = split point, value = gain.
        def next_opr(dataset):
            # collect the distinct values of each feature
            feature = {}
            for i in range(len(dataset[0]) - 1):
                feature[i] = []
            for i in range(len(dataset)):
                for j in range(len(dataset[i]) - 1):
                    if dataset[i][j] not in feature[j]:
                        feature[j].append(dataset[i][j])
            for i in range(len(feature.keys())):
                feature[i] = sorted(feature[i])
            # candidate split points for the continuous values (midpoints of adjacent values)
            root = {}
            for i in range(len(feature.keys())):
                root[i] = []
            for i in range(len(feature.keys())):
                for j in range(len(feature[i])):
                    if j != len(feature[i]) - 1:
                        root[i].append((feature[i][j] + feature[i][j + 1]) / 2)
            gain = {}
            for i in range(len(root.keys())):
                gain[i] = {}
            for i in range(len(gain.keys())):
                for j in range(len(root[i])):
                    gain_data, k = Gain(dataset, root, i)
                    gain[i][root[i][k]] = gain_data
            return gain

        # The tree is stored as a list [split, left, right]; a child that is not yet a label stores its dataset.
        def key_root(my_gain):
            # parse the gain structure and return the best split: key = feature index, root = split point
            key_1 = list(my_gain.keys())
            max = 0
            key = 0
            root = 0
            for i in range(len(my_gain)):
                key_2 = list(my_gain[key_1[i]].keys())
                for j in range(len(key_2)):
                    if my_gain[key_1[i]][key_2[j]] > max:
                        max = my_gain[key_1[i]][key_2[j]]
                        key = key_1[i]
                        root = key_2[j]
            return key, root

        tree = []
        my_gain = next_opr(train_dataset)
        key, root = key_root(my_gain)
        tree.append([key, root])
        tree.append([])
        tree.append([])

        # split the dataset by key / root
        def sub_root(dataset, key, root):
            sub_left = []
            sub_right = []
            for i in range(len(dataset)):
                if dataset[i][key] < root:
                    sub_left.append(dataset[i])
                else:
                    sub_right.append(dataset[i])
            return sub_left, sub_right

        sub_left, sub_right = sub_root(train_dataset, key, root)
        tree[1] = sub_left
        tree[2] = sub_right

        # check whether all samples in a subtree share the same label
        def test(dataset):
            num = []
            for i in range(len(self.label)):
                num.append(0)
            for i in range(len(dataset)):
                for j in range(len(self.label)):
                    if dataset[i][-1] == self.label[j]:
                        num[j] += 1
            if max(num) == len(dataset):
                return self.label[num.index(max(num))]
            else:
                return dataset

        def next(tree):
            for i in range(len(tree)):
                if i != 0:
                    tree[i] = test(tree[i])
            for i in range(len(tree)):
                if i != 0:
                    if tree[i] not in self.label:
                        dataset = tree[i]
                        tree[i] = []
                        gains = next_opr(dataset)
                        key, root = key_root(gains)
                        tree[i].append([key, root])
                        tree[i].append([])
                        tree[i].append([])
                        left, right = sub_root(dataset, key, root)
                        tree[i][1] = test(left)
                        tree[i][2] = test(right)
                        next(tree[i])

        next(tree)
        tree_data = copy.deepcopy(tree)

        def clear_lable_add_data(tree_data, train_dataset):
            # replace leaf labels with empty lists, then push the training samples back down the tree
            def clear_lable(tree_data):
                for i in range(len(tree_data)):
                    if i != 0:
                        if type(tree_data[i]) == type('str'):
                            tree_data[i] = []
                        elif type(tree_data) == type([]):
                            clear_lable(tree_data[i])

            clear_lable(tree_data)

            def put_data(tree, data):
                if data[tree[0][0]] < tree[0][1]:
                    if len(tree[1]) != 3:
                        tree[1].append(data)
                    else:
                        if len(tree[1][0]) == 2:
                            put_data(tree[1], data)
                        else:
                            tree[1].append(data)
                elif data[tree[0][0]] > tree[0][1]:
                    if len(tree[2]) != 3:
                        tree[2].append(data)
                    else:
                        if len(tree[2][0]) == 2:
                            put_data(tree[2], data)
                        else:
                            tree[2].append(data)

            for i in range(len(train_dataset)):
                put_data(tree_data, train_dataset[i])
            return tree_data

        tree_data = clear_lable_add_data(tree_data, train_dataset)

        def post_purning(tree, tree_data, test_dataset):
            def test_tree(tree, dataset):  # accuracy of this tree on dataset
                def test_label(train_tree, test_data):  # predict the label of this sample with the tree
                    label = None
                    if test_data[train_tree[0][0]] < train_tree[0][1]:
                        if train_tree[1] not in self.label:
                            train_tree = train_tree[1]
                            label = test_label(train_tree, test_data)
                        else:
                            label = train_tree[1]
                        return label
                    else:
                        if train_tree[2] not in self.label:
                            train_tree = train_tree[2]
                            label = test_label(train_tree, test_data)
                        else:
                            label = train_tree[2]
                        return label
                    return label

                all_num = len(dataset)
                right = 0
                for i in range(len(dataset)):
                    if test_label(tree, dataset[i]) == dataset[i][-1]:
                        right += 1
                return right / all_num

            def branchs(tree):
                def branch(tree):
                    list = []
                    for i in range(len(tree)):
                        if len(tree[i]) == 3 and len(tree[i][0]) == 2:
                            if len(tree[i][1]) == 3 or len(tree[i][2]) == 3:
                                list.append(i)
                                list.append(branch(tree[i]))
                            elif len(tree[i][1][0]) != 2 and len(tree[i][2][0]) != 2:
                                list.append(i)
                    return list

                branch_dict = branch(tree)
                return branch_dict

            def process(br):
                all = []

                def all_num(list):
                    key = 0
                    for i in range(len(list)):
                        if type(list[i]) == type(list):
                            return False
                        else:
                            key += 1
                    if key == len(list):
                        return True

                def empty_list(list):
                    # check whether the list contains an empty list
                    all_key = 0
                    for i in range(len(list)):
                        if list[i] == []:
                            return True
                        else:
                            all_key += 1
                    if all_key == len(list):
                        return False

                def br_tree(br, list):
                    for i in range(len(br)):
                        if type(br[i]) == type([]) and len(br[i]) != 1:
                            br_tree(br[i], list)
                            if empty_list(br[i]):
                                list.insert(0, br[i][br[i].index([]) - 1])
                                br[i].pop(br[i].index([]))
                                list.insert(0, br[i - 1])
                            else:
                                list.insert(0, br[i - 1])
                            break
                        elif type(br[i]) == type([]) and len(br[i]) == 1:
                            num = br[i].pop(0)
                            list.append(num)
                            break
                        elif type(br[i]) == type([]) and len(br[i]) == 0:
                            br.pop(br.index([]))
                            br.pop(-1)
                        elif all_num(br):
                            num = br.pop(0)
                            list.append(num)
                            break
                        elif type(br[0]) == type(1) and type(br[1]) == type([]) and br[1] == []:
                            br.pop(-1)
                            break

                while (br != []):
                    list = []
                    br_tree(br, list)
                    if len(list) == 0:
                        list.append(br[0])
                        br.pop(-1)
                    elif len(list) == 1:
                        list.insert(0, br[0])
                    all.append(list)
                return all

            def max_lable(dataset):
                # return the majority label of a dataset (or the label itself if it is already a label)
                if type(dataset) == type([]):
                    num = []
                    for i in range(len(self.label)):
                        num.append(0)
                    for i in range(len(dataset)):
                        for j in range(len(self.label)):
                            if dataset[i][-1] == self.label[j]:
                                num[j] += 1
                    max = 0
                    for i in range(len(num)):
                        if num[i] > max:
                            max = num[i]
                    return self.label[num.index(max)]
                elif type(dataset) == type('abc'):
                    return dataset

            def tree_lable(tree_data):
                for i in range(len(tree_data)):
                    if i != 0 and len(tree_data[i]) == 3 and len(tree_data[i][0]) == 2:
                        tree_data[i] = tree_lable(tree_data[i])
                    elif i != 0 and len(tree_data[i]) != 3:
                        tree_data[i] = max_lable(tree_data[i])
                    elif i != 0 and len(tree_data[i]) == 3 and len(tree_data[i][0]) != 2:
                        tree_data[i] = max_lable(tree_data[i])
                return tree_data

            br = branchs(tree)
            first = tree
            first_data = tree_data
            br = process(br)

            def tree_list(tree_data, list):
                # use an index path (list) to prune one node of the tree
                def the_Data(tree):
                    the_da = []
                    for i in range(len(tree)):
                        if len(tree[i]) == 3 and len(tree[i][0]) == 2:
                            the_da.extend(the_Data(tree[i]))
                        elif i != 0 and len(tree[i]) != 3:
                            the_da.extend(tree[i])
                        elif i != 0 and len(tree[i]) == 3 and type(tree[i][0][-1]) == type('str'):
                            the_da.extend(tree[i])
                    return the_da

                def dir(tree_data, list):
                    if len(list) == 0:
                        tree_data = the_Data(tree_data)
                        return tree_data
                    else:
                        tree_data[list[0]] = dir(tree_data[list[0]], list[1:])
                        return tree_data

                dir(tree_data, list)
                return tree_data

            # prune using every processed sequence of candidate nodes
            for i in range(len(br)):
                second_data = copy.deepcopy(first_data)
                tree_list(second_data, br[i])
                data = copy.deepcopy(second_data)
                second = tree_lable(data)
                firs_acc = test_tree(first, test_dataset)
                sec_acc = test_tree(second, test_dataset)
                # print(firs_acc)
                # print(sec_acc)
                if firs_acc <= sec_acc:
                    first = second
                    first_data = second_data
            return first

        tree = post_purning(tree, tree_data, test_dataset)
        return tree

    def label_sample(self, train_tree, test_data):
        def process(train_tree, test_data):
            label = None
            if test_data[train_tree[0][0]] < train_tree[0][1]:
                if train_tree[1] not in self.label:
                    train_tree = train_tree[1]
                    label = process(train_tree, test_data)
                else:
                    label = train_tree[1]
            else:
                if train_tree[2] not in self.label:
                    train_tree = train_tree[2]
                    label = process(train_tree, test_data)
                else:
                    label = train_tree[2]
            return label

        label = process(train_tree, test_data)
        # print(test_data[:-1], '\'s true label:', test_data[-1], ' predict label is :', label)
        return label

    def k_accuracy(self, data):
        def accuracy(train, test):
            acc = 0
            my_tree = self.train_tree(train)
            for i in range(len(test)):
                if test[i][-1] == self.label_sample(my_tree, test[i]):
                    acc += 1
            return acc / len(test)

        for i in range(len(data)):
            print('No-pruning Processing', i + 1, 'batch.....')
            print('The accuracy is', accuracy(data[i][0], data[i][1]), '.....')
            print('Batch', i + 1, 'is finished.....')
        print('The pruning is finished..')
        print('*****************************************')

    def pre_accuracy(self, data):
        def accuracy(train, test):
            acc = 0
            my_tree = self.pre_pruning(train, test)
            for i in range(len(test)):
                if test[i][-1] == self.label_sample(my_tree, test[i]):
                    acc += 1
            return acc / len(test)

        for i in range(len(data)):
            print('Pre-pruning Processing', i + 1, 'batch.....')
            print('The accuracy is', accuracy(data[i][0], data[i][1]), '.....')
            print('Batch', i + 1, 'is finished.....')
        print('The pruning is finished..')
        print('*****************************************')

    def post_accuracy(self, data):
        def accuracy(train, test):
            acc = 0
            my_tree = self.post_pruning(train, test)
            for i in range(len(test)):
                if test[i][-1] == self.label_sample(my_tree, test[i]):
                    acc += 1
            return acc / len(test)

        for i in range(len(data)):
            print('Post-pruning Processing', i + 1, 'batch.....')
            print('The accuracy is', accuracy(data[i][0], data[i][1]), '.....')
            print('Batch', i + 1, 'is finished.....')
        print('The pruning is finished..')
        print('*****************************************')


if __name__ == '__main__':
    label = ['Iris-setosa', 'Iris-versicolor', 'Iris-virginica']
    tree = decisionTree('iris.data', label)
    k_data = tree.k_data(5)
    tree.k_accuracy(k_data)
    tree.pre_accuracy(k_data)
    tree.post_accuracy(k_data)

The dataset (iris.data) used above is listed below.
Description of the Iris dataset:
Iris records the sepal and petal lengths and widths of iris plants, four attributes in total, and the plants fall into three classes. The dataset therefore contains 4 feature variables and 1 class variable, with 150 samples in all, split across three subspecies: Iris-setosa, Iris-versicolor, and Iris-virginica.
In other words, every sample carries four attribute values, and the task is a three-class classification problem.
For example:
Sample 1: 5.1, 3.5, 1.4, 0.2, Iris-setosa
Here "5.1, 3.5, 1.4, 0.2" are the values of the sample's four attributes, and "Iris-setosa" is its class.
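To make the file format concrete, here is a minimal sketch of how one such line is split into features and a label, mirroring what readfile does (the variable names are only for illustration):

    line = '5.1,3.5,1.4,0.2,Iris-setosa'
    data = line.strip().split(',')
    features = [float(x) for x in data[:-1]]   # [5.1, 3.5, 1.4, 0.2]
    label = data[-1]                           # 'Iris-setosa'

In iris.data each sample sits on its own line in exactly this comma-separated form.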
5.1,3.5,1.4,0.2,Iris-setosa 4.9,3.0,1.4,0.2,Iris-setosa 4.7,3.2,1.3,0.2,Iris-setosa 4.6,3.1,1.5,0.2,Iris-setosa 5.0,3.6,1.4,0.2,Iris-setosa 5.4,3.9,1.7,0.4,Iris-setosa 4.6,3.4,1.4,0.3,Iris-setosa 5.0,3.4,1.5,0.2,Iris-setosa 4.4,2.9,1.4,0.2,Iris-setosa 4.9,3.1,1.5,0.1,Iris-setosa 5.4,3.7,1.5,0.2,Iris-setosa 4.8,3.4,1.6,0.2,Iris-setosa 4.8,3.0,1.4,0.1,Iris-setosa 4.3,3.0,1.1,0.1,Iris-setosa 5.8,4.0,1.2,0.2,Iris-setosa 5.7,4.4,1.5,0.4,Iris-setosa 5.4,3.9,1.3,0.4,Iris-setosa 5.1,3.5,1.4,0.3,Iris-setosa 5.7,3.8,1.7,0.3,Iris-setosa 5.1,3.8,1.5,0.3,Iris-setosa 5.4,3.4,1.7,0.2,Iris-setosa 5.1,3.7,1.5,0.4,Iris-setosa 4.6,3.6,1.0,0.2,Iris-setosa 5.1,3.3,1.7,0.5,Iris-setosa 4.8,3.4,1.9,0.2,Iris-setosa 5.0,3.0,1.6,0.2,Iris-setosa 5.0,3.4,1.6,0.4,Iris-setosa 5.2,3.5,1.5,0.2,Iris-setosa 5.2,3.4,1.4,0.2,Iris-setosa 4.7,3.2,1.6,0.2,Iris-setosa 4.8,3.1,1.6,0.2,Iris-setosa 5.4,3.4,1.5,0.4,Iris-setosa 5.2,4.1,1.5,0.1,Iris-setosa 5.5,4.2,1.4,0.2,Iris-setosa 4.9,3.1,1.5,0.1,Iris-setosa 5.0,3.2,1.2,0.2,Iris-setosa 5.5,3.5,1.3,0.2,Iris-setosa 4.9,3.1,1.5,0.1,Iris-setosa 4.4,3.0,1.3,0.2,Iris-setosa 5.1,3.4,1.5,0.2,Iris-setosa 5.0,3.5,1.3,0.3,Iris-setosa 4.5,2.3,1.3,0.3,Iris-setosa 4.4,3.2,1.3,0.2,Iris-setosa 5.0,3.5,1.6,0.6,Iris-setosa 5.1,3.8,1.9,0.4,Iris-setosa 4.8,3.0,1.4,0.3,Iris-setosa 5.1,3.8,1.6,0.2,Iris-setosa 4.6,3.2,1.4,0.2,Iris-setosa 5.3,3.7,1.5,0.2,Iris-setosa 5.0,3.3,1.4,0.2,Iris-setosa 7.0,3.2,4.7,1.4,Iris-versicolor 6.4,3.2,4.5,1.5,Iris-versicolor 6.9,3.1,4.9,1.5,Iris-versicolor 5.5,2.3,4.0,1.3,Iris-versicolor 6.5,2.8,4.6,1.5,Iris-versicolor 5.7,2.8,4.5,1.3,Iris-versicolor 6.3,3.3,4.7,1.6,Iris-versicolor 4.9,2.4,3.3,1.0,Iris-versicolor 6.6,2.9,4.6,1.3,Iris-versicolor 5.2,2.7,3.9,1.4,Iris-versicolor 5.0,2.0,3.5,1.0,Iris-versicolor 5.9,3.0,4.2,1.5,Iris-versicolor 6.0,2.2,4.0,1.0,Iris-versicolor 6.1,2.9,4.7,1.4,Iris-versicolor 5.6,2.9,3.6,1.3,Iris-versicolor 6.7,3.1,4.4,1.4,Iris-versicolor 5.6,3.0,4.5,1.5,Iris-versicolor 5.8,2.7,4.1,1.0,Iris-versicolor 6.2,2.2,4.5,1.5,Iris-versicolor 5.6,2.5,3.9,1.1,Iris-versicolor 5.9,3.2,4.8,1.8,Iris-versicolor 6.1,2.8,4.0,1.3,Iris-versicolor 6.3,2.5,4.9,1.5,Iris-versicolor 6.1,2.8,4.7,1.2,Iris-versicolor 6.4,2.9,4.3,1.3,Iris-versicolor 6.6,3.0,4.4,1.4,Iris-versicolor 6.8,2.8,4.8,1.4,Iris-versicolor 6.7,3.0,5.0,1.7,Iris-versicolor 6.0,2.9,4.5,1.5,Iris-versicolor 5.7,2.6,3.5,1.0,Iris-versicolor 5.5,2.4,3.8,1.1,Iris-versicolor 5.5,2.4,3.7,1.0,Iris-versicolor 5.8,2.7,3.9,1.2,Iris-versicolor 6.0,2.7,5.1,1.6,Iris-versicolor 5.4,3.0,4.5,1.5,Iris-versicolor 6.0,3.4,4.5,1.6,Iris-versicolor 6.7,3.1,4.7,1.5,Iris-versicolor 6.3,2.3,4.4,1.3,Iris-versicolor 5.6,3.0,4.1,1.3,Iris-versicolor 5.5,2.5,4.0,1.3,Iris-versicolor 5.5,2.6,4.4,1.2,Iris-versicolor 6.1,3.0,4.6,1.4,Iris-versicolor 5.8,2.6,4.0,1.2,Iris-versicolor 5.0,2.3,3.3,1.0,Iris-versicolor 5.6,2.7,4.2,1.3,Iris-versicolor 5.7,3.0,4.2,1.2,Iris-versicolor 5.7,2.9,4.2,1.3,Iris-versicolor 6.2,2.9,4.3,1.3,Iris-versicolor 5.1,2.5,3.0,1.1,Iris-versicolor 5.7,2.8,4.1,1.3,Iris-versicolor 6.3,3.3,6.0,2.5,Iris-virginica 5.8,2.7,5.1,1.9,Iris-virginica 7.1,3.0,5.9,2.1,Iris-virginica 6.3,2.9,5.6,1.8,Iris-virginica 6.5,3.0,5.8,2.2,Iris-virginica 7.6,3.0,6.6,2.1,Iris-virginica 4.9,2.5,4.5,1.7,Iris-virginica 7.3,2.9,6.3,1.8,Iris-virginica 6.7,2.5,5.8,1.8,Iris-virginica 7.2,3.6,6.1,2.5,Iris-virginica 6.5,3.2,5.1,2.0,Iris-virginica 6.4,2.7,5.3,1.9,Iris-virginica 6.8,3.0,5.5,2.1,Iris-virginica 5.7,2.5,5.0,2.0,Iris-virginica 5.8,2.8,5.1,2.4,Iris-virginica 6.4,3.2,5.3,2.3,Iris-virginica 6.5,3.0,5.5,1.8,Iris-virginica 
7.7,3.8,6.7,2.2,Iris-virginica 7.7,2.6,6.9,2.3,Iris-virginica 6.0,2.2,5.0,1.5,Iris-virginica 6.9,3.2,5.7,2.3,Iris-virginica 5.6,2.8,4.9,2.0,Iris-virginica 7.7,2.8,6.7,2.0,Iris-virginica 6.3,2.7,4.9,1.8,Iris-virginica 6.7,3.3,5.7,2.1,Iris-virginica 7.2,3.2,6.0,1.8,Iris-virginica 6.2,2.8,4.8,1.8,Iris-virginica 6.1,3.0,4.9,1.8,Iris-virginica 6.4,2.8,5.6,2.1,Iris-virginica 7.2,3.0,5.8,1.6,Iris-virginica 7.4,2.8,6.1,1.9,Iris-virginica 7.9,3.8,6.4,2.0,Iris-virginica 6.4,2.8,5.6,2.2,Iris-virginica 6.3,2.8,5.1,1.5,Iris-virginica 6.1,2.6,5.6,1.4,Iris-virginica 7.7,3.0,6.1,2.3,Iris-virginica 6.3,3.4,5.6,2.4,Iris-virginica 6.4,3.1,5.5,1.8,Iris-virginica 6.0,3.0,4.8,1.8,Iris-virginica 6.9,3.1,5.4,2.1,Iris-virginica 6.7,3.1,5.6,2.4,Iris-virginica 6.9,3.1,5.1,2.3,Iris-virginica 5.8,2.7,5.1,1.9,Iris-virginica 6.8,3.2,5.9,2.3,Iris-virginica 6.7,3.3,5.7,2.5,Iris-virginica 6.7,3.0,5.2,2.3,Iris-virginica 6.3,2.5,5.0,1.9,Iris-virginica 6.5,3.0,5.2,2.0,Iris-virginica 6.2,3.4,5.4,2.3,Iris-virginica 5.9,3.0,5.1,1.8,Iris-virginica