闲来无事最近复习了一下ID3决策树算法,并凭着理解用pandas实现了一遍。对pandas更熟悉的朋友可供参考(链接如下)。相比本篇博文,更简明清晰,更适合复习用。

https://github.com/DianeSoHungry/ShallowMachineLearningCodeItOut/blob/master/ID3.ipynb

 

 

 

现在要介绍的是ID3决策树算法,只适用于标称型数据,不适用于数值型数据。

 

决策树学习算法最大的优点是,他可以自学习,在学习过程中,不需要使用者了解过多的背景知识、领域知识,只需要对训练实例进行较好的标注就可以自学习了。

 

建立决策树的关键在于当前状态下选择哪一个属性作为分类依据,根据不同的目标函数,有三种主要的算法

ID3(Iterative Dichotomiser)

C4.5

CART(Classification And Regression Tree)

 

下面是一个小型的数据集,5条记录,2个特征(属性),有标签。

《机器学习实战》笔记——决策树(ID3)

根据这个数据集,我们可以建立如下决策树(用matplotlib的注释功能画的)。

观察决策树,决策节点为特征,其分支为决策节点的各个不同取值,叶节点为预测值。

《机器学习实战》笔记——决策树(ID3)

建树结束也就是建立好了一个决策树分类器,有了分类器,就可以根据这个分类器对其他的鱼进行预测了。预测准确性今天暂且不讨论。

那么如何建立这样的决策树呢?

第一步:建立决策树。

1.1 利用信息增益寻找当前最佳分类特征

想象现在你是一个判断结点,你从头顶的分支上获得了一个数据集,表中包含标签和若干属性。你现在要根据某个属性来对你接收到的数据集进行分组。到底用哪个属性来作为划分依据呢?

《机器学习实战》笔记——决策树(ID3)

我们用信息增益来选择某个节点上用哪个特征来进行分类。

 

什么是信息?

如果待分类的事物可能划分在多个分类中,则每个分类xi的信息定义为:

《机器学习实战》笔记——决策树(ID3)

(这里log前面应该有个负号。)

 

什么是香农熵?

香农熵是所有类别所有可能类别信息的期望值,即:

《机器学习实战》笔记——决策树(ID3)

 

什么是信息增益?

信息增益=原香农熵-新香农熵

 

注意:新香农熵为按照某特征划分之后,每个分支数据集的香农熵之和。

  

可以这样想:香农熵相当于数据类别(标签)的混乱程度,信息增益可以衡量划分数据集前后数据(标签)向有序性发展的程度。因此,回到怎样利用信息增益寻找当前最佳分类特征的话题,假如你是一个判断节点,你拿来一个数据集,数据集里面有若干个特征,你需要从中选取一个特征,使得信息增益最大(注意:将数据集中在该特征上取值相同的记录划分到同一个分支,得到若干个分支数据集,每个分支数据集都有自己的香农熵,各个分支数据集的香农熵的期望才是新香农熵)。要找到这个特征只需要将数据集中的每个特征遍历一次,求信息增益,取获得最大信息增益的那个特征。

代码如下(其中,calcShannonEnt(dataSet)函数用来计算数据集dataSet的香农熵,splitDataSet(dataSet, axis, value)函数将数据集dataSet的第axis列中特征值为value的记录挑出来,组成分支数据集返回给函数。这两个函数后面会给出函数定义。):

 1 # 3-3 选择最好的'数据集划分方式'(特征)
 2 # 一个一个地试每个特征,如果某个按照某个特征分类得到的信息增益(原香农熵-新香农熵)最大,
 3 # 则选这个特征作为最佳数据集划分方式
 4 def chooseBestFeatureToSplit(dataSet):
 5     numFeatures = len(dataSet[0]) - 1
 6     baseEntropy = calcShannonEnt(dataSet)
 7     bestInfoGain = 0.0
 8     bestFeature = -1
 9     for i in range(numFeatures):
10         featList = [example[i] for example in dataSet]
11         uniqueVals = set(featList)
12         newEntropy = 0.0
13         for value in uniqueVals:
14             subDataSet = splitDataSet(dataSet, i, value)
15             prob = len(subDataSet) / float(len(dataSet))
16             newEntropy += prob * calcShannonEnt(subDataSet)
17         infoGain = baseEntropy - newEntropy
18         if (infoGain > bestInfoGain):
19             bestInfoGain = infoGain
20             bestFeature = i
21     return bestFeature

 

calcShannonEnt(dataSet)函数代码:

 1 def calcShannonEnt(dataSet):
 2     numEntries = len(dataSet)    # 总记录数
 3     labelCounts = {}    # dataSet中所有出现过的标签值为键,相应标签值出现过的次数作为值
 4     for featVec in dataSet:
 5         currentLabel = featVec[-1]
 6         labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1
 7     shannonEnt = 0.0
 8     for key in labelCounts:
 9         prob = -float(labelCounts[key])/numEntries
10         shannonEnt += prob * log(prob, 2)
11     return shannonEnt

 

splitDataSet(dataSet, axis, value)函数代码:

 1 # 3-2 按照给定特征划分数据集(在某个特征axis上,值等于value的所有记录
 2 # 组成新的数据集retDataSet,新数据集不需要axis这个特征,注意value是特征值,axis指的是特征(所在的列下标))
 3 def splitDataSet(dataSet, axis, value):
 4     retDataSet = []
 5     for featVec in dataSet:
 6         if featVec[axis] == value:
 7             reducedFeatVec = featVec[:axis]
 8             reducedFeatVec.extend(featVec[axis+1:])
 9             retDataSet.append(reducedFeatVec)
10     return retDataSet

 

1.2 建树

建树是一个递归的过程。

 

递归结束的标志(判断某节点是叶节点的标志):

情况1. 分到该节点的数据集中,所有记录的标签列取值都一样。

情况2. 分到该节点的数据集中,只剩下标签列。

 

a. 经判断,若是叶节点,则:

对应情况1,返回数据集中第一条记录的标签值(反正所有标签值都一样)。

对应情况2,返回数据集中所有标签值中,出现次数最多的那个标签值(代码中,定义一个函数majorityCnt(classList)来实现)

 

b. 经判断,若不是叶节点,则:

step1. 建立一个字典,字典的键为该数据集上选出的最佳特征(划分依据)。

step2. 将具有相同特征值的记录组成新的数据集(利用splitDataSet(dataSet, axis, value)函数实现,注意期间抛弃了当前用于划分数据的特征列),对新的数据集们进行递归建树。

 

建树代码:

 1 # 3-4 创建树的函数代码
 2 # 如果非叶子结点,则以当前数据集建树,并返回该树。该树的根节点是一个字典,键为划分当前数据集的最佳特征,值为按照键值划分后各个数据集构造的树
 3 # 叶子节点有两种:1.只剩没有特征时,叶子节点的返回值为所有记录中,出现次数最多的那个标签值 2.该叶子节点中,所有记录的标签相同。
 4 
 5 def createTree(dataSet, labels): #label向量的维度为特征数,不是记录数,是不同列下标对应的特征
 6     classList = [example[-1] for example in dataSet]
 7     if classList.count(classList[0]) == len(classList):
 8         return classList[0]
 9     if len(dataSet[0]) == 1:
10         return majorityCnt(classList)
11     bestFeat = chooseBestFeatureToSplit(dataSet)
12     bestFeatLabel = labels[bestFeat]
13     myTree = {bestFeatLabel: {}}
14     del(labels[bestFeat])
15     featValues = [example[bestFeat] for example in dataSet]
16     uniqueVals = set(featValues)
17     for value in uniqueVals:  #递归建子树,若值为字典,则非叶节点,若为字符串,则为叶节点
18         myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), labels)
19     return myTree

 

用上面给出的数据来建立一颗决策树做示范:

《机器学习实战》笔记——决策树(ID3)

在同一个程序中输入如下代码并运行:

 1 def createDataSet():
 2     dataSet = [[1, 1, 'yes'],
 3                [1, 1, 'yes'],
 4                [1, 0, 'no'],
 5                [0, 1, 'no'],
 6                [0, 1, 'no']]
 7     labels = ['no surfacing', 'flippers']
 8     return dataSet, labels
 9 
10 myDat, labels = createDataSet()
11 myTree = createTree(myDat, labels)
12 print myTree

运行结果为:

{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

 若利用后面画决策树的代码可以画出这颗决策树:

《机器学习实战》笔记——决策树(ID3)

 

案例:

我们通过建立决策树来预测患者需要佩戴哪种隐形眼镜(soft(软材质)、hard(硬材质)、no lenses(不适合硬性眼睛)),数据集包含下面几个特征:age(年龄), prescript(近视还是远视), astigmatic(散光), tearRate(眼泪清除率)

建树的结果为:

{'tearRate': {'reduced': 'no lenses', 'normal': {'astigmatic': {'yes': {'prescript': {'hyper': {'age': {'pre': 'no lenses', 'presbyopic': 'no lenses', 'young': 'hard'}}, 'myope': 'hard'}}, 'no': {'age': {'pre': 'soft', 'presbyopic': {'prescript': {'hyper': 'soft', 'myope': 'no lenses'}}, 'young': 'soft'}}}}}}

 

画出来是这个样子:

《机器学习实战》笔记——决策树(ID3)

 

 

画决策树的代码(不讲)

涉及matplotlib.pyplot模块中的annotation的用法,点击链接进入官网学习这块内容的prerequisite。

 1 # _*_coding:utf-8_*_
 2 
 3 # 3-7 plotTree函数
 4 import matplotlib.pyplot as plt
 5 
 6 # 定义节点和箭头格式的常量
 7 decisionNode = dict(boxstyle="sawtooth", fc="0.8")
 8 leafNode = dict(boxstyle="round4", fc="0.8")
 9 arrow_args = dict(arrowstyle="<-")
10 
11 
12 def plotMidTest(cntrPt, parentPt,txtString):
13     xMid = (parentPt[0] + cntrPt[0])/2.0
14     yMid = (parentPt[1] + cntrPt[1])/2.0
15     createPlot.ax1.text(xMid, yMid, txtString)
16 
17 # 绘制自身
18 # 若当前子节点不是叶子节点,递归
19 # 若当子节点为叶子节点,绘制该节点
20 def plotTree(myTree, parentPt, nodeTxt):
21     numLeafs = getNumLeafs(myTree)
22     # depth = getTreeDepth(myTree)
23     firstStr = myTree.keys()[0]
24     cntrPt = (plotTree.xoff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yoff)
25     plotMidTest(cntrPt, parentPt, nodeTxt)
26     plotNode(firstStr, cntrPt, parentPt, decisionNode)
27     secondDict = myTree[firstStr]
28     plotTree.yoff = plotTree.yoff - 1.0/plotTree.totalD
29     for key in secondDict.keys():
30         if type(secondDict[key]).__name__=='dict':
31             plotTree(secondDict[key], cntrPt, str(key))
32         else:
33             plotTree.xoff = plotTree.xoff + 1.0/plotTree.totalW
34             plotNode(secondDict[key], (plotTree.xoff, plotTree.yoff), cntrPt, leafNode)
35             plotMidTest((plotTree.xoff, plotTree.yoff), cntrPt, str(key))
36     plotTree.yoff = plotTree.yoff + 1.0/plotTree.totalD
37 
38 
39 # figure points
40 # 画结点的模板
41 def plotNode(nodeTxt, centerPt, parentPt, nodeType):
42     createPlot.ax1.annotate(nodeTxt,  # 注释的文字,(一个字符串)
43                             xy=parentPt,  # 被注释的地方(一个坐标)
44                             xycoords='axes fraction',  # xy所用的坐标系
45                             xytext=centerPt,  # 插入文本的地方(一个坐标)
46                             textcoords='axes fraction', # xytext所用的坐标系
47                             va="center",
48                             ha="center",
49                             bbox=nodeType,  # 注释文字用的框的格式
50                             arrowprops=arrow_args)  # 箭头属性
51 
52 
53 def createPlot(inTree):
54     fig = plt.figure(1, facecolor='white')
55     fig.clf()
56     axprops = dict(xticks=[], yticks=[])
57     createPlot.ax1 = plt.subplot(111,frameon=False, **axprops)
58     plotTree.totalW = float(getNumLeafs(inTree))
59     plotTree.totalD = float(getTreeDepth(inTree))
60     plotTree.xoff = -0.5/plotTree.totalW
61     plotTree.yoff = 1.0
62 
63     plotTree(inTree, (0.5, 1.0),'') #树的引用作为父节点,但不画出来,所以用''
64     plt.show()
65 
66 def getNumLeafs(myTree):
67     numLeafs = 0
68     firstStr = myTree.keys()[0]
69     secondDict = myTree[firstStr]
70     for key in secondDict.keys():
71         if type(secondDict[key]).__name__ =='dict':
72             numLeafs += getNumLeafs(secondDict[key])
73         else:
74             numLeafs += 1
75     return numLeafs
76 
77 # 子树中树高最大的那一颗的高度+1作为当前数的高度
78 def getTreeDepth(myTree):
79     maxDepth = 0    #用来记录最高子树的高度+1
80     firstStr = myTree.keys()[0]
81     secondDict = myTree[firstStr]
82     for key in secondDict.keys():
83         if type(secondDict[key]).__name__ == 'dict':
84             thisDepth = 1 + getTreeDepth(secondDict[key])
85         else:
86             thisDepth = 1
87         if(thisDepth > maxDepth):
88             maxDepth = thisDepth
89     return maxDepth
90 
91 # 方便测试用的人造测试树
92 def retrieveTree(i):
93     listofTrees = [{'no surfacing':{0:'no',1:{'flippers':{0:'no',1:'yes'}}}},
94                    {'no surfacing':{0:'no',1:{'flippers':{0:{'head':{0:'no',1:'yes'}},1:'no'}}}}
95                    ]
96     return listofTrees[i]