%% Main script
clc;
clear;
% Watermelon dataset (from Zhou Zhihua's "Machine Learning"): one row per melon,
% columns = color / root / knock sound / texture / navel / touch, last column = class ("是" good / "否" bad)
data = ["青绿","蜷缩","浊响","清晰","凹陷","硬滑","是";
"乌黑","蜷缩","沉闷","清晰","凹陷","硬滑","是";
"乌黑","蜷缩","浊响","清晰","凹陷","硬滑","是";
"青绿","蜷缩","沉闷","清晰","凹陷","硬滑","是";
"浅白","蜷缩","浊响","清晰","凹陷","硬滑","是";
"青绿","稍蜷","浊响","清晰","稍凹","软粘","是";
"乌黑","稍蜷","浊响","稍糊","稍凹","软粘","是";
"乌黑","稍蜷","浊响","清晰","稍凹","硬滑","是";
"乌黑","稍蜷","沉闷","稍糊","稍凹","硬滑","否";
"青绿","硬挺","清脆","清晰","平坦","软粘","否";
"浅白","硬挺","清脆","模糊","平坦","硬滑","否";
"浅白","蜷缩","浊响","模糊","平坦","软粘","否";
"青绿","稍蜷","浊响","稍糊","凹陷","硬滑","否";
"浅白","稍蜷","沉闷","稍糊","凹陷","硬滑","否";
"乌黑","稍蜷","浊响","清晰","稍凹","软粘","否";
"浅白","蜷缩","浊响","模糊","平坦","硬滑","否";
"青绿","蜷缩","沉闷","稍糊","稍凹","硬滑","否"];
% Attribute names (last entry is the class-label name)
label = ["色泽","根蒂","敲声","纹理","脐部","触感","好瓜"];
% Parameters
datasetRate = 1;
dataSize = size(data);
% Train/test split: all 17 rows are used for training here.
% Uncomment the randperm line to draw a random subsample instead.
% index = randperm(dataSize(1,1),round(datasetRate*(dataSize(1,1)-1)));
index = 1:17;
trainSet = data(index,:);
testSet = data;
testSet(index,:) = [];
% One 0/1 flag per attribute column: 1 = still available for splitting on this path
deepth = ones(1,dataSize(1,2)-1);
% Build the decision tree
rootNode = makeTree(label,trainSet,deepth,'null');
% Plot the resulting tree
drawTree(rootNode);
%% Decision-tree drawing sub-function
% Plot the decision tree.
% node: root node struct produced by makeTree (fields: value, label, branch, children).
function [] = drawTree(node)
% Walk the tree to collect treeplot's parent-pointer vector plus node/edge label text.
nodeVec = [];
nodeSpec = [];
edgeSpec = [];
[nodeVec,nodeSpec,edgeSpec,~] = travesing(node,0,0,nodeVec,nodeSpec,edgeSpec);
treeplot(nodeVec);
[x,y] = treelayout(nodeVec);
n = numel(nodeVec);
x = x';
y = y';
% Label every node with its attribute name (internal) or class (leaf)
text(x(:,1),y(:,1),nodeSpec,'VerticalAlignment','bottom','HorizontalAlignment','right');
% Edge labels sit at the midpoint of each parent-child segment;
% node 1 is the root and has no incoming edge, hence the 2:n range.
x_branch = [];
y_branch = [];
for i = 2:n
x_branch = [x_branch; (x(i,1)+x(nodeVec(i),1))/2];
y_branch = [y_branch; (y(i,1)+y(nodeVec(i),1))/2];
end
text(x_branch(:,1),y_branch(:,1),edgeSpec(1,2:n),'VerticalAlignment','bottom','HorizontalAlignment','right');
end
% Pre-order traversal of the tree.
% Accumulates the parent-pointer vector expected by treeplot (nodeVec),
% one display label per node (nodeSpec) and one label per incoming edge (edgeSpec).
% current_count is the running node counter, last_node the parent's index.
function [nodeVec,nodeSpec,edgeSpec,current_count] = travesing(node,current_count,last_node,nodeVec,nodeSpec,edgeSpec)
nodeVec = [nodeVec last_node];
if node.value == 'null'
    % Internal node: display the attribute it splits on.
    nodeSpec = [nodeSpec node.label];
elseif node.value == '是'
    nodeSpec = [nodeSpec '好瓜'];
else
    nodeSpec = [nodeSpec '坏瓜'];
end
edgeSpec = [edgeSpec node.branch];
current_count = current_count + 1;
this_node = current_count;
if node.value == 'null'
    % Recurse into every child, threading the node counter through.
    for child = node.children
        [nodeVec,nodeSpec,edgeSpec,current_count] = travesing(child,current_count,this_node,nodeVec,nodeSpec,edgeSpec);
    end
end
end
%% Decision-tree generation sub-function
% Recursively build a decision tree (ID3-style, information-gain splitting).
% features: string array of attribute names (last entry is the class-label name)
% examples: samples, one row each; last column is the class ("是"/"否")
% deepth:   1xK 0/1 flags; flag k is zeroed once attribute k is used on this path
% branch:   attribute value carried by the edge leading into this node
% Returns a node struct:
%   value:    classification result; 'null' marks an internal (branch) node
%   label:    attribute this node splits on (internal nodes only)
%   branch:   incoming edge value
%   children: child node structs
function node = makeTree(features,examples,deepth,branch)
node = struct('value','null','label',[],'branch',branch,'children',[]);
[m,n] = size(examples);
% If every sample carries the same class, emit a leaf.
sample = examples(1,n);
check_res = true;
for i = 1:m
if sample ~= examples(i,n)
check_res = false;
end
end
if check_res
node.value = examples(1,n);
return;
end
% Entropy impurity of the current sample set
impurity = calculateImpurity(examples);
% Pick the attribute with the largest information gain
bestLabel = getBestlabel(impurity,deepth,examples);
deepth(bestLabel) = 0; % attribute is now consumed on this path
node.label = features(bestLabel);
% Distinct values of the chosen attribute, in first-appearance order
grouping_res = unique(examples(:,bestLabel),'stable')';
% One child subtree per attribute value
for k = grouping_res
sub_sample = examples(examples(:,bestLabel)==k,:);
node.children = [node.children makeTree(features,sub_sample,deepth,k)];
end
end
%% Attribute-selection sub-function
% Select the attribute with the largest information gain (ID3 criterion).
% impurity_: entropy impurity of the sample set before splitting
% features_: 1x(n-1) 0/1 flags; 1 means the attribute is still available
% samples_:  samples to split, one row each; last column is the class
% Returns the column index of the chosen attribute (first one on ties).
function label = getBestlabel(impurity_,features_,samples_)
[m,n] = size(samples_);
% Gain(i) = Ent(D) - sum_v (|Dv|/|D|) * Ent(Dv).
% Unavailable (already-used) attributes stay at -Inf so they can never be
% selected again on this path; the original implicit zeros could re-pick
% a consumed attribute when all remaining gains were 0.
Gain = -inf(1,n-1);
Gain_ratio = -inf(1,n-1);
for i = 1:n-1
if features_(i) ~= 1
continue; % attribute already consumed on this branch
end
% Partition the samples by the values of attribute i (first-appearance order).
values = unique(samples_(:,i),'stable')';
Dv = [];               % group sizes |Dv|
grouped_impurity = []; % entropy impurity of each group
for k = values
sub_sample = samples_(samples_(:,i)==k,:);
Dv = [Dv size(sub_sample,1)];
grouped_impurity = [grouped_impurity calculateImpurity(sub_sample)];
end
Gain(i) = impurity_ - sum(Dv/m .* grouped_impurity);
% C4.5 gain ratio, IV(a) = -sum((|Dv|/|D|)*log2(|Dv|/|D|)).
% Computed for reference only; selection below still uses plain Gain.
IV_a = -sum((Dv/m).*log2(Dv/m));
Gain_ratio(i) = Gain(i)/IV_a;
end
% If several attributes tie on the maximal gain, use the first one.
label = find(Gain == max(Gain),1);
end
%% Information-entropy sub-function
% Entropy impurity of a sample set.
% The last column of examples_ holds the class ("是" positive / "否" negative);
% a pure set (all one class) has impurity 0.
function res = calculateImpurity(examples_)
[m_,n_] = size(examples_);
p_pos = sum(examples_(:,n_) == '是') / m_;
p_neg = sum(examples_(:,n_) == '否') / m_;
if p_pos == 0 || p_pos == 1
    % log2(0) is -Inf, so pure sets are handled explicitly.
    res = 0;
else
    res = -(p_pos*log2(p_pos) + p_neg*log2(p_neg));
end
end