Matplotlib绘制决策树代码:
# coding=utf-8
import matplotlib.pyplot as plt
'''
遇到不懂的问题?Python学习交流群:821460695满足你的需求,资料都已经上传群文件,可以自行下载!
'''
decisionNode = dict(boxstyle='sawtooth', fc='10')
leafNode = dict(boxstyle='round4',fc='0.8')
arrow_args = dict(arrowstyle='<-')
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',\
xytext=centerPt,textcoords='axes fraction',\
va='center', ha='center',bbox=nodeType,arrowprops\
=arrow_args)
def getNumLeafs(myTree):
numLeafs = 0
firstStr = list(myTree.keys())[0]
secondDict = myTree[firstStr]
for key in secondDict:
if(type(secondDict[key]).__name__ == 'dict'):
numLeafs += getNumLeafs(secondDict[key])
else:
numLeafs += 1
return numLeafs
def getTreeDepth(myTree):
maxDepth = 0
firstStr = list(myTree.keys())[0]
secondDict = myTree[firstStr]
for key in secondDict:
if(type(secondDict[key]).__name__ == 'dict'):
thisDepth = 1+getTreeDepth((secondDict[key]))
else:
thisDepth = 1
if thisDepth > maxDepth: maxDepth = thisDepth
return maxDepth
def retrieveTree(i):
#预先设置树的信息
listOfTree = [{'no surfacing':{0:'no',1:{'flipper':{0:'no',1:'yes'}}}},
{'no surfacing':{0:'no',1:{'flipper':{0:{'head':{0:'no',1:'yes'}},1:'no'}}}},
{'Comment score greater than 8.0':{0:{'Comment score greater than 9.5':{0:'Yes',1:{'More than 45,000 people commented': {
0: 'Yes',1: 'No'}}}},1:'No'}}]
return listOfTree[i]
def createPlot(inTree):
fig = plt.figure(1,facecolor='white')
fig.clf()
axprops = dict(xticks = [], yticks=[])
createPlot.ax1 = plt.subplot(111,frameon = False,**axprops)
plotTree.totalW = float(getNumLeafs(inTree))
plotTree.totalD = float(getTreeDepth(inTree))
plotTree.xOff = -0.5/plotTree.totalW;plotTree.yOff = 1.0
plotTree(inTree,(0.5,1.0), '')
plt.title('Douban reading Decision Tree\n')
plt.show()
def plotMidText(cntrPt, parentPt,txtString):
xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
yMid = (parentPt[1] - cntrPt[1])/2.0 + cntrPt[1]
createPlot.ax1.text(xMid, yMid, txtString)
def plotTree(myTree, parentPt, nodeTxt):
numLeafs = getNumLeafs(myTree)
depth = getTreeDepth(myTree)
firstStr = list(myTree.keys())[0]
cntrPt = (plotTree.xOff+(1.0+float(numLeafs))/2.0/plotTree.totalW,\
plotTree.yOff)
plotMidText(cntrPt,parentPt,nodeTxt)
plotNode(firstStr,cntrPt,parentPt,decisionNode)
secondDict = myTree[firstStr]
plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
for key in secondDict:
if type(secondDict[key]).__name__ == 'dict':
plotTree(secondDict[key],cntrPt,str(key))
else:
plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),\
cntrPt,leafNode)
plotMidText((plotTree.xOff,plotTree.yOff),cntrPt,str(key))
plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
if __name__ == '__main__':
myTree = retrieveTree(2)
createPlot(myTree)