Python Implementation of ID3
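The script below builds an ID3 decision tree from a small weather dataset: at every node it splits on the feature with the largest information gain, computed from the Shannon entropy of the class labels. For reference, the two quantities implemented by calcShannonEnt and chooseBestFeatureToSplit are

    H(D) = -\sum_k p_k \log_2 p_k
    Gain(D, A) = H(D) - \sum_{v \in values(A)} \frac{|D_v|}{|D|} H(D_v)

where p_k is the fraction of samples in D with class k, and D_v is the subset of D in which feature A takes the value v.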

import math
import operator


def calcShannonEnt(dataset):
    # Shannon entropy of the class labels (last column of each sample)
    numEntries = len(dataset)
    labelCounts = {}
    for featVec in dataset:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts:
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * math.log(prob, 2)
    return shannonEnt


def CreateDataSet():
    # (a hard-coded toy dataset, left commented out)
    '''dataset = [[1, 1, 'yes'],
                  [1, 1, 'yes'],
                  [1, 0, 'no'],
                  [0, 1, 'no'],
                  [0, 1, 'no']]
    labels = ['outlook', 'temperature', 'humidity', 'false']
    return dataset, labels'''
    # Read the training data from Dataset.txt: the third line holds the
    # feature names, lines 5-11 hold the training samples
    lines_set = open('Dataset.txt').readlines()
    labelLine = lines_set[2]
    labels = labelLine.strip().split()
    lines_set = lines_set[4:11]
    dataSet = []
    for line in lines_set:
        data = line.split()
        dataSet.append(data)
    return dataSet, labels


def splitDataSet(dataSet, axis, value):
    # Keep the samples whose feature `axis` equals `value`, with that column removed
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet


def chooseBestFeatureToSplit(dataSet):
    # Return the index of the feature with the largest information gain
    numberFeatures = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numberFeatures):
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature


def majorityCnt(classList):
    # Majority vote over the class labels of the remaining samples
    classCount = {}
    for vote in classList:
        if vote not in classCount:
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]


def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    # All samples share one class: return it as a leaf
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # No features left to split on: return the majority class
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    del labels[bestFeat]
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree


myDat, labels = CreateDataSet()
myTree = createTree(myDat, labels)
print(myTree)

Running the script produces:

{'outlook': {'overcast': 'Y', 'sunny': 'N', 'rain': {'windy': {'false': 'Y', 'true': 'N'}}}}
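As a quick sanity check (not part of the original post), the entropy and the first split can be verified directly on the seven training rows listed below: with 4 'Y' and 3 'N' labels the base entropy is -(4/7)log2(4/7) - (3/7)log2(3/7) ≈ 0.985, and outlook yields the largest information gain, which is why it sits at the root of the tree above.

    # Sanity-check sketch, assuming Dataset.txt is laid out so that
    # CreateDataSet() returns the seven training rows listed below
    myDat, labels = CreateDataSet()
    print(calcShannonEnt(myDat))            # about 0.985 (4 'Y' vs 3 'N')
    print(chooseBestFeatureToSplit(myDat))  # 0, i.e. the 'outlook' column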

Training Set and Test Set

Training set (the last column of each row is the class label, Y or N):

outlook   temperature  humidity  windy
---------------------------------------------------------
sunny     hot          high      false   N
sunny     hot          high      true    N
overcast  hot          high      false   Y
rain      mild         high      false   Y
rain      cool         normal    false   Y
rain      cool         normal    true    N
overcast  cool         normal    true    Y

Test set (unlabelled):

outlook   temperature  humidity  windy
---------------------------------------------------------
sunny     mild         high      false
sunny     cool         normal    false
rain      mild         normal    false
sunny     mild         normal    true
overcast  mild         high      true
overcast  hot          normal    false
rain      mild         high      true
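The code above stops after printing the tree and never labels the test rows. Below is a minimal sketch (not from the original post) of a classify helper that walks the tree for a single feature vector; it assumes the feature order ['outlook', 'temperature', 'humidity', 'windy'] and that a copy of the label list is kept, because createTree() deletes entries from labels as it recurses.

    def classify(inputTree, featLabels, testVec):
        # Follow the branch matching the test vector's value until a leaf is reached
        firstStr = list(inputTree.keys())[0]      # feature tested at this node
        secondDict = inputTree[firstStr]
        featIndex = featLabels.index(firstStr)    # position of that feature in testVec
        classLabel = None                         # stays None for an unseen feature value
        for key in secondDict:
            if testVec[featIndex] == key:
                if isinstance(secondDict[key], dict):
                    classLabel = classify(secondDict[key], featLabels, testVec)
                else:
                    classLabel = secondDict[key]
        return classLabel


    myDat, labels = CreateDataSet()
    featLabels = labels[:]             # keep a copy before createTree() mutates labels
    myTree = createTree(myDat, labels)
    for testVec in [['sunny', 'mild', 'high', 'false'],
                    ['sunny', 'cool', 'normal', 'false'],
                    ['rain', 'mild', 'normal', 'false']]:
        print(testVec, classify(myTree, featLabels, testVec))

For example, the first test row falls under outlook = sunny and is therefore labelled 'N' by the tree shown above.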


Reposted from blog.csdn.net/qq_34514046/article/details/80923957