基于Python3.0的决策树手写算法实现和对西瓜书第四章决策树习题4.3

系统 1402 0

手写代码实现基于信息熵划分的决策树算法

文章目录

  • 手写代码实现基于信息熵划分的决策树算法
  • 1. 简介
  • 2. 算法实现思路
  • 3.代码如下
  • 参考

1. 简介

阅读本文需要以下背景知识:
- 掌握周志华《西瓜书》第四章决策树原理
- Python3.0基础语法及数据类型及操作

不了解决策树请点击下面链接西瓜书第四章决策树学习笔记

本文是基于信息熵准则进行划分选择的决策树算法的手写实现,不使用现有的机器学习包。算法流程见《西瓜书》第四章第一节。数据集使用西瓜数据集3.0(数据集在代码中不需要另外下载),实现语言为Python3.0。代码注解详细,适合新手,欢迎转载

2. 算法实现思路

算法流程是现成的,关键是如何把数据集嵌入到算法中并实现递归,我的思路如下:

对决策树不同功能进行划分,每个功能封装成函数,不同功能的函数有
- def createDataSet() #对数据集进行加工,返回数据集dataSet和特征集labels
- def get_Value(dataSet, labels) #以字典labelsCounts返回数据集dataSet中所有的特征,和对应特征的所有取值
- def calcShannonEnt(dataSet) #计算dataSet的信息熵。返回信息熵数值
- def chooseBestFeatureToSplit(dataSet) #计算出信息增益,选择信息增益最大的特征作为最优划分属性。返回最优属性在特征集labels中的索引
- def splitDataSet(dataSet, bestFeat, value) #由给定的父数据集dataSet,最优特征 bestFeat,和最优特征的取值value(由labelsCounts获得)划分出数据子集,返回数据子集
- def majorityCnt(classList) #输入数据集dataSet的类别标签列classList得到在数据集dataSet中类别最多的样本的类别名(字符串)
- def createTree(dataSet, labels, labelscounts) #这是一个递归函数,输入数据集dataSet,特征集labels和所有特征取值字典labelscounts得到一个具有一层分支的树,要是这层分支中每个子集subdataSet都是叶节点,创建字典,以被划分的最优属性的取值value为键,对应这个取值的叶节点类型为值( 叶节点判定标准:集合中样本都相同标签也相同标为叶节点,叶类型为集合中样本标签;集合中样本都相同但是标签不同标为叶节点,叶类型为集合中众数样本类别;集合为空集标为叶结点,叶类别为其父节点众数样本类别 )。若这层分支中不全为叶节点,还有内部节点。则对于叶节点,创建字典,以被划分的最优属性的取值为键,对应这个取值的叶节点类型为值。对于内部节点,把这个子集subdataSet作为新的父集,以新父集的划分最优属性键,值是一个字典,并调用函数def createTree(subdataSet, sublabels, labelscounts)完成递归。返回一个以字典形式存储的决策树
- treePlotter.createPlot(desicionTree) #调用库函数将决策树绘出,treePlotter包是自定义包,代码及使用方法见此treePlotter

3.代码如下

            
              
                #基于ID3算法的信息增益来实现的决策树
              
              
                #调用库
              
              
                from
              
               math 
              
                import
              
               log

              
                import
              
               operator

              
                import
              
               treePlotter                                 
              
                #自定义包,包和源程序应在同一文件夹,包代码见链接
              
              
                ''' 西瓜数据集3.0, dataset=[ # 1 ['青绿', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', 0.697, 0.460, '好瓜'], # 2 ['乌黑', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', 0.774, 0.376, '好瓜'], # 3 ['乌黑', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', 0.634, 0.264, '好瓜'], # 4 ['青绿', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', 0.608, 0.318, '好瓜'], # 5 ['浅白', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', 0.556, 0.215, '好瓜'], # 6 ['青绿', '稍蜷', '浊响', '清晰', '稍凹', '软粘', 0.403, 0.237, '好瓜'], # 7 ['乌黑', '稍蜷', '浊响', '稍糊', '稍凹', '软粘', 0.481, 0.149, '好瓜'], # 8 ['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '硬滑', 0.437, 0.211, '好瓜'], # ---------------------------------------------------- # 9 ['乌黑', '稍蜷', '沉闷', '稍糊', '稍凹', '硬滑', 0.666, 0.091, '坏瓜'], # 10 ['青绿', '硬挺', '清脆', '清晰', '平坦', '软粘', 0.243, 0.267, '坏瓜'], # 11 ['浅白', '硬挺', '清脆', '模糊', '平坦', '硬滑', 0.245, 0.057, '坏瓜'], # 12 ['浅白', '蜷缩', '浊响', '模糊', '平坦', '软粘', 0.343, 0.099, '坏瓜'], # 13 ['青绿', '稍蜷', '浊响', '稍糊', '凹陷', '硬滑', 0.639, 0.161, '坏瓜'], # 14 ['浅白', '稍蜷', '沉闷', '稍糊', '凹陷', '硬滑', 0.657, 0.198, '坏瓜'], # 15 ['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '软粘', 0.360, 0.370, '坏瓜'], # 16 ['浅白', '蜷缩', '浊响', '模糊', '平坦', '硬滑', 0.593, 0.042, '坏瓜'], # 17 ['青绿', '蜷缩', '沉闷', '稍糊', '稍凹', '硬滑', 0.719, 0.103, '坏瓜'] ] '''
              
              
                #导入数据,数据集有八个特征 '色泽', '根蒂', '敲声', '纹理','脐部','触感','密度','含糖率' ,
              
              
                #其中密度和含糖率是连续值,为了简略程序,我们忽略他们。为接下来要计算它们的信息增益率,来选择节点的构成方式做准备。
              
              
                def
              
              
                createDataSet
              
              
                (
              
              
                )
              
              
                :
              
              
                """ 对数据集进行一定处理,以方便显示,不出现乱码 色泽Color-> 0: 浅白 | 1: 青绿 | 2: 乌黑 根蒂Root-> 0: 硬挺 | 1: 稍蜷 | 2: 蜷缩 敲声Knock-> 0: 清脆 | 1: 浊响 | 2:沉闷 纹理Texture-> 0: 清晰 | 1: 稍糊 | 2:模糊 脐部Umbilical-> 0: 平坦 | 1: 稍凹 | 2: 凹陷 触感Touch-> 0: 硬滑 | 1: 软粘 标签lab->'GoodMalen'| 'BadMalen' """
              
              
    dataSet 
              
                =
              
              
                [
              
              
                [
              
              
                1
              
              
                ,
              
              
                2
              
              
                ,
              
              
                1
              
              
                ,
              
              
                0
              
              
                ,
              
              
                2
              
              
                ,
              
              
                0
              
              
                ,
              
              
                'GoodMalen'
              
              
                ]
              
              
                ,
              
              
                [
              
              
                2
              
              
                ,
              
              
                2
              
              
                ,
              
              
                2
              
              
                ,
              
              
                0
              
              
                ,
              
              
                2
              
              
                ,
              
              
                0
              
              
                ,
              
              
                'GoodMalen'
              
              
                ]
              
              
                ,
              
              
                [
              
              
                2
              
              
                ,
              
              
                2
              
              
                ,
              
              
                1
              
              
                ,
              
              
                0
              
              
                ,
              
              
                2
              
              
                ,
              
              
                0
              
              
                ,
              
              
                'GoodMalen'
              
              
                ]
              
              
                ,
              
              
                [
              
              
                1
              
              
                ,
              
              
                2
              
              
                ,
              
              
                2
              
              
                ,
              
              
                0
              
              
                ,
              
              
                2
              
              
                ,
              
              
                0
              
              
                ,
              
              
                'GoodMalen'
              
              
                ]
              
              
                ,
              
              
                [
              
              
                0
              
              
                ,
              
              
                2
              
              
                ,
              
              
                1
              
              
                ,
              
              
                0
              
              
                ,
              
              
                2
              
              
                ,
              
              
                0
              
              
                ,
              
              
                'GoodMalen'
              
              
                ]
              
              
                ,
              
              
                [
              
              
                1
              
              
                ,
              
              
                1
              
              
                ,
              
              
                1
              
              
                ,
              
              
                0
              
              
                ,
              
              
                1
              
              
                ,
              
              
                1
              
              
                ,
              
              
                'GoodMalen'
              
              
                ]
              
              
                ,
              
              
                [
              
              
                2
              
              
                ,
              
              
                1
              
              
                ,
              
              
                1
              
              
                ,
              
              
                1
              
              
                ,
              
              
                1
              
              
                ,
              
              
                1
              
              
                ,
              
              
                'GoodMalen'
              
              
                ]
              
              
                ,
              
              
                [
              
              
                2
              
              
                ,
              
              
                1
              
              
                ,
              
              
                1
              
              
                ,
              
              
                0
              
              
                ,
              
              
                1
              
              
                ,
              
              
                0
              
              
                ,
              
              
                'GoodMalen'
              
              
                ]
              
              
                ,
              
              
                [
              
              
                2
              
              
                ,
              
              
                1
              
              
                ,
              
              
                2
              
              
                ,
              
              
                1
              
              
                ,
              
              
                1
              
              
                ,
              
              
                0
              
              
                ,
              
              
                'BadMalen'
              
              
                ]
              
              
                ,
              
              
                [
              
              
                1
              
              
                ,
              
              
                0
              
              
                ,
              
              
                0
              
              
                ,
              
              
                0
              
              
                ,
              
              
                0
              
              
                ,
              
              
                1
              
              
                ,
              
              
                'BadMalen'
              
              
                ]
              
              
                ,
              
              
                [
              
              
                0
              
              
                ,
              
              
                0
              
              
                ,
              
              
                0
              
              
                ,
              
              
                2
              
              
                ,
              
              
                0
              
              
                ,
              
              
                0
              
              
                ,
              
              
                'BadMalen'
              
              
                ]
              
              
                ,
              
              
                [
              
              
                0
              
              
                ,
              
              
                2
              
              
                ,
              
              
                1
              
              
                ,
              
              
                2
              
              
                ,
              
              
                0
              
              
                ,
              
              
                1
              
              
                ,
              
              
                'BadMalen'
              
              
                ]
              
              
                ,
              
              
                [
              
              
                1
              
              
                ,
              
              
                1
              
              
                ,
              
              
                1
              
              
                ,
              
              
                1
              
              
                ,
              
              
                2
              
              
                ,
              
              
                0
              
              
                ,
              
              
                'BadMalen'
              
              
                ]
              
              
                ,
              
              
                [
              
              
                0
              
              
                ,
              
              
                1
              
              
                ,
              
              
                2
              
              
                ,
              
              
                1
              
              
                ,
              
              
                2
              
              
                ,
              
              
                0
              
              
                ,
              
              
                'BadMalen'
              
              
                ]
              
              
                ,
              
              
                [
              
              
                2
              
              
                ,
              
              
                1
              
              
                ,
              
              
                1
              
              
                ,
              
              
                0
              
              
                ,
              
              
                1
              
              
                ,
              
              
                1
              
              
                ,
              
              
                'BadMalen'
              
              
                ]
              
              
                ,
              
              
                [
              
              
                0
              
              
                ,
              
              
                1
              
              
                ,
              
              
                1
              
              
                ,
              
              
                2
              
              
                ,
              
              
                0
              
              
                ,
              
              
                0
              
              
                ,
              
              
                'BadMalen'
              
              
                ]
              
              
                ,
              
              
                [
              
              
                1
              
              
                ,
              
              
                1
              
              
                ,
              
              
                2
              
              
                ,
              
              
                1
              
              
                ,
              
              
                1
              
              
                ,
              
              
                0
              
              
                ,
              
              
                'BadMalen'
              
              
                ]
              
              
                ]
              
              
    labels 
              
                =
              
              
                [
              
              
                'Color'
              
              
                ,
              
              
                'Root'
              
              
                ,
              
              
                'Knock'
              
              
                ,
              
              
                'Texture'
              
              
                ,
              
              
                'Umbilical'
              
              
                ,
              
              
                'Touch'
              
              
                ]
              
              
                return
              
               dataSet
              
                ,
              
               labels


              
                #获得每个特征的所有出现的取值
              
              
                def
              
              
                get_Values
              
              
                (
              
              dataSet
              
                ,
              
               labels
              
                )
              
              
                :
              
              
                ''' 输入:一个数据集 输出:数据集中每个特征的所有取值,字典形式;键是特征名,值是对应特征的所有取值 描述:获得特征的取值,为分支划分做准备 '''
              
              
    labelsCounts 
              
                =
              
              
                {
              
              
                }
              
              
                #初始化字典
              
              
                for
              
               label 
              
                in
              
               labels
              
                :
              
              
                #遍历特征集
              
              
        index 
              
                =
              
               labels
              
                .
              
              index
              
                (
              
              label
              
                )
              
              
                #获得特征名称在特征集中的索引
              
              
        featValues 
              
                =
              
              
                [
              
              example
              
                [
              
              index
              
                ]
              
              
                for
              
               example 
              
                in
              
               dataSet
              
                ]
              
              
                #取出一个特征的所有取值
              
              
        uniqueVals 
              
                =
              
              
                set
              
              
                (
              
              featValues
              
                )
              
              
                #利用集合性质数据去重
              
              
        labelsCounts
              
                [
              
              label
              
                ]
              
              
                =
              
               uniqueVals                    
              
                #将去重后的数据放入字典中,键名为特征名字
              
              
                return
              
               labelsCounts


              
                #计算数据集信息熵
              
              
                def
              
              
                calcShannonEnt
              
              
                (
              
              dataSet
              
                )
              
              
                :
              
              
                """ 输入:数据集 输出:数据集的信息熵 描述:计算给定数据集的信息熵;熵越大,数据集的混乱程度越大 """
              
              
    numEntries 
              
                =
              
              
                len
              
              
                (
              
              dataSet
              
                )
              
              
                #样本数
              
              
    labelCounts 
              
                =
              
              
                {
              
              
                }
              
              
                #创建一个数据字典:key是最后一列的数值(即标签,也就是目标分类的类别),value是属于该类别的样本个数,这个字典用来计数各个类别的样本的个数
              
              
                for
              
               featVec 
              
                in
              
               dataSet
              
                :
              
              
                #遍历数据集,每次取一行就是一个样本
              
              
        currentLabel 
              
                =
              
               featVec
              
                [
              
              
                -
              
              
                1
              
              
                ]
              
              
                #取出每行最后一列的元素(也就是样本标签)给currentLabel
              
              
                if
              
               currentLabel 
              
                not
              
              
                in
              
               labelCounts
              
                .
              
              keys
              
                (
              
              
                )
              
              
                :
              
              
                #判断:标签在不在字典labelCounts中?
              
              
            labelCounts
              
                [
              
              currentLabel
              
                ]
              
              
                =
              
              
                0
              
              
                #不在字典中则给字典创建新键值对,key是标签,value设为0
              
              
        labelCounts
              
                [
              
              currentLabel
              
                ]
              
              
                +=
              
              
                1
              
              
                #计数每一类样本的数量, {'GoodMalen': 8, 'BadMalen': 9}
              
              
                # print(labelCounts)
              
              
    shannonEnt 
              
                =
              
              
                0.0
              
              
                # 初始化信息熵
              
              
                for
              
               key 
              
                in
              
               labelCounts
              
                :
              
              
                #遍历数据字典的键
              
              
        prob 
              
                =
              
              
                float
              
              
                (
              
              labelCounts
              
                [
              
              key
              
                ]
              
              
                )
              
              
                /
              
              numEntries 
              
                #计算数据集D中K类样本所占比例Pk
              
              
        shannonEnt 
              
                -=
              
               prob 
              
                *
              
               log
              
                (
              
              prob
              
                ,
              
              
                2
              
              
                )
              
              
                #计算信息熵log2
              
              
                return
              
               shannonEnt 


              
                #计算样本集中类别数最多的类别
              
              
                def
              
              
                calmaxCnt
              
              
                (
              
              dataSet
              
                )
              
              
                :
              
              
                ''' 输入:数据集 输出:在输入数据集中类别数最多的类别名称 描述:对划分出的数据集为空的子数据集不能划分,标记为叶节点,将其类别设定为其父节点所含样本中类 别数最多的类别名称 '''
              
              
    classCount 
              
                =
              
              
                {
              
              
                }
              
              
                #创建字典
              
              
                for
              
               featVec 
              
                in
              
               dataSet
              
                :
              
              
                #对数据集中每一行遍历
              
              
                if
              
               featVec
              
                [
              
              
                -
              
              
                1
              
              
                ]
              
              
                not
              
              
                in
              
               classCount
              
                .
              
              keys
              
                (
              
              
                )
              
              
                :
              
              
                #键已存在字典中+1,不存在字典中创建后初始为0后+1
              
              
            classCount
              
                [
              
              featVec
              
                [
              
              
                -
              
              
                1
              
              
                ]
              
              
                ]
              
              
                =
              
              
                0
              
              
        classCount
              
                [
              
              featVec
              
                [
              
              
                -
              
              
                1
              
              
                ]
              
              
                ]
              
              
                +=
              
              
                1
              
              
    items 
              
                =
              
              
                list
              
              
                (
              
              classCount
              
                .
              
              items
              
                (
              
              
                )
              
              
                )
              
              
                #字典转为列表
              
              
    items
              
                .
              
              sort
              
                (
              
              key
              
                =
              
              
                lambda
              
               x
              
                :
              
              x
              
                [
              
              
                1
              
              
                ]
              
              
                ,
              
               reverse
              
                =
              
              
                True
              
              
                )
              
              
                #列表以值来排序(从大到小)
              
              
                return
              
               items
              
                [
              
              
                0
              
              
                ]
              
              
                [
              
              
                0
              
              
                ]
              
              
                #输出类别数最多的类别名称
              
              
                #对数据集进行叶节点标记的准则
              
              
                def
              
              
                majorityCnt
              
              
                (
              
              classList
              
                )
              
              
                :
              
              
                """ #返回该数据集中类别数最多的类名 #该函数使用分类名称的列表(某个数据集或者其子集的),然后创建键值为classList中唯一值的 #数据字典。字典对象的存储了classList中每个类标签出现的频率。最后利用operator操作键值排序字典, #并返回出现次数最多的分类名称 输入:分类类别列表 输出:子节点的分类 描述:数据集已经处理了所有属性,但是类标签依然不是唯一的, 则采用多数判决的方法决定该子节点的分类 """
              
              
    classCount 
              
                =
              
              
                {
              
              
                }
              
              
                #创建字典
              
              
                for
              
               vote 
              
                in
              
               classList
              
                :
              
              
                #对类名列表遍历
              
              
                if
              
               vote 
              
                not
              
              
                in
              
               classCount
              
                .
              
              keys
              
                (
              
              
                )
              
              
                :
              
              
                #键已存在字典中+1,不存在字典中创建后初始为0后+1
              
              
            classCount
              
                [
              
              vote
              
                ]
              
              
                =
              
              
                0
              
              
        classCount
              
                [
              
              vote
              
                ]
              
              
                +=
              
              
                1
              
              
                # print(classCount)
              
              
    sortedClassCount 
              
                =
              
              
                sorted
              
              
                (
              
              classCount
              
                .
              
              iteritems
              
                (
              
              
                )
              
              
                ,
              
               key
              
                =
              
              operator
              
                .
              
              itemgetter
              
                (
              
              
                1
              
              
                )
              
              
                ,
              
              
                reversed
              
              
                =
              
              
                True
              
              
                )
              
              
                #将字典转换成列表并按照值([i][1])进行从大到小排序
              
              
                return
              
               sortedClassCount
              
                [
              
              
                0
              
              
                ]
              
              
                [
              
              
                0
              
              
                ]
              
              
                #选出最优划分特征
              
              
                def
              
              
                chooseBestFeatureToSplit
              
              
                (
              
              dataSet
              
                )
              
              
                :
              
              
                """ 选取当前数据集下,用于划分数据集的最优特征 输入:数据集dataSet 输出:最好的划分维度 描述:选择最好的数据集划分维度,返回的是该特征在该数据集中的索引 """
              
              
    numFeatures 
              
                =
              
              
                len
              
              
                (
              
              dataSet
              
                [
              
              
                0
              
              
                ]
              
              
                )
              
              
                -
              
              
                1
              
              
                #特征feature个数,数据集列数减一,减去的那个一是类别标签
              
              
    baseEntropy 
              
                =
              
               calcShannonEnt
              
                (
              
              dataSet
              
                )
              
              
                #计算父样本集的信息熵
              
              
    bestInfoGain 
              
                =
              
              
                0.0
              
              
                #初始化信息增益为0.0
              
              
    bestFeature 
              
                =
              
              
                -
              
              
                1
              
              
                #初始化最佳特征索引维度
              
              
                for
              
               i 
              
                in
              
              
                range
              
              
                (
              
              numFeatures
              
                )
              
              
                :
              
              
                #遍历每个特征
              
              
        featList 
              
                =
              
              
                [
              
              example
              
                [
              
              i
              
                ]
              
              
                for
              
               example 
              
                in
              
               dataSet
              
                ]
              
              
                ##获取数据集中当前特征下的所有值组成list
              
              
        uniqueVals 
              
                =
              
              
                set
              
              
                (
              
              featList
              
                )
              
              
                #集合数据去重,获得当前特征的所有取值 
              
              
        newEntropy 
              
                =
              
              
                0.0
              
              
                # splitInfo = 0.0 #初始化固有值,用于C4.5决策树实现
              
              
                for
              
               value 
              
                in
              
               uniqueVals
              
                :
              
              
                #遍历该特征每一种取值结果
              
              
            subDataSet 
              
                =
              
               splitDataSet
              
                (
              
              dataSet
              
                ,
              
               i
              
                ,
              
               value
              
                )
              
              
                #获得该种特征该种结果的子样本集(去除了这种特征后的)
              
              
            prob 
              
                =
              
              
                len
              
              
                (
              
              subDataSet
              
                )
              
              
                /
              
              
                float
              
              
                (
              
              
                len
              
              
                (
              
              dataSet
              
                )
              
              
                )
              
              
                #计算|Dv|/|D|,计算子样本集样本数所占父样本数权重
              
              
            newEntropy 
              
                +=
              
               prob 
              
                *
              
               calcShannonEnt
              
                (
              
              subDataSet
              
                )
              
              
                #计算各个子样本集的权重*子样本集信息熵并加和
              
              
                # splitInfo += -prob * log(prob, 2) #计算该特征固有值,用于C4.5决策树实现 
              
              
        infoGain 
              
                =
              
               baseEntropy 
              
                -
              
               newEntropy                
              
                #这个feature的infoGain
              
              
                # if (splitInfo == 0): # fix the overflow bug #用于C4.5决策树实现 
              
              
                # continue #用于C4.5决策树实现
              
              
                # infoGainRatio = infoGain / splitInfo #这个feature的infoGainRatio#用于C4.5决策树实现 
              
              
                if
              
              
                (
              
              infoGain 
              
                >
              
               bestInfoGain
              
                )
              
              
                :
              
              
                #选择最大的信息增益gain对应的特征,并获得其索引,若用于C4.5决策树实现需要更改一部分变量名称
              
              
            bestInfoGain 
              
                =
              
               infoGain
            bestFeature 
              
                =
              
               i                                
              
                #选择最大的gain对应的特征,并把其索引赋值给bestFeature
              
              
                return
              
               bestFeature


              
                #划分数据集,为下一层计算准备
              
              
                def
              
              
                splitDataSet
              
              
                (
              
              dataSet
              
                ,
              
                bestFeat
              
                ,
              
               value
              
                )
              
              
                :
              
              
                """ #axis是dataSet数据集下要进行特征划分的列号例如outlook是0列,value是该列下某个特征值,0列中的sunny 输入:数据集,选择维度,选择值 输出:划分数据集 描述:按照给定特征划分数据集;想要将某个数据集以某特征完全划分成几个子数据集需要遍历该特征的不同取值并重复调用这个函数 新数据集由样本中某特征axis取指定值value的样本组成,且去除了该特征axis的列以避免之后的对该特征重复划分 """
              
              
    retDataSet 
              
                =
              
              
                [
              
              
                ]
              
              
                #初始化一个列表作为子集
              
              
                for
              
               featVec 
              
                in
              
               dataSet
              
                :
              
              
                #对数据集中每一行遍历
              
              
                if
              
               featVec
              
                [
              
               bestFeat
              
                ]
              
              
                ==
              
               value
              
                :
              
              
                #当某样本在被选择的特征列axis上取值=value(我们所指定的特征值)时
              
              
            reduceFeatVec 
              
                =
              
               featVec
              
                [
              
              
                :
              
               bestFeat
              
                ]
              
              
                #复制出选中特征列前面的列
              
              
            reduceFeatVec
              
                .
              
              extend
              
                (
              
              featVec
              
                [
              
               bestFeat
              
                +
              
              
                1
              
              
                :
              
              
                ]
              
              
                )
              
              
                #由上面的列拼接选中特征列后面的列
              
              
                #上两行代码作用是除去原样本集的第axis列
              
              
            retDataSet
              
                .
              
              append
              
                (
              
              reduceFeatVec
              
                )
              
              
                #把除去第axis列的样本放到新数据集中
              
              
                return
              
               retDataSet



              
                #多重字典构建树 
              
              
                def
              
              
                createTree
              
              
                (
              
              dataSet
              
                ,
              
               labels
              
                ,
              
               labelscounts
              
                )
              
              
                :
              
              
                """ 输入:数据集,特征标签 输出:决策树,每个数据集中优势类别的名称 描述:递归构建决策树 """
              
              
    classList 
              
                =
              
              
                [
              
              example
              
                [
              
              
                -
              
              
                1
              
              
                ]
              
              
                for
              
               example 
              
                in
              
               dataSet
              
                ]
              
              
                #返回当前数据集下标签列所有值
              
              
                if
              
               classList
              
                .
              
              count
              
                (
              
              classList
              
                [
              
              
                0
              
              
                ]
              
              
                )
              
              
                ==
              
              
                len
              
              
                (
              
              classList
              
                )
              
              
                :
              
              
                #classList所有元素都相等,即类别完全相同,停止划分,设置为叶节点,以该集合中的类别名作为叶节点标签
              
              
                return
              
               classList
              
                [
              
              
                0
              
              
                ]
              
              
                #返回该类标签值
              
              
                if
              
              
                len
              
              
                (
              
              dataSet
              
                [
              
              
                0
              
              
                ]
              
              
                )
              
              
                ==
              
              
                1
              
              
                :
              
              
                #因为每次划分都除去了被划分特征值对应的列,那么随着划分的进行,列越来越短,直到只剩下标
              
              
                #签列,该标签列中对应的样本都是特征值完全相同的,此时按照叶节点命名规则,取该标签列中类
              
              
                #别数最多的类别作为叶节点的划分 
              
              
                return
              
               majorityCnt
              
                (
              
              classList
              
                )
              
              
                #遍历完所有特征后返回出现次数最多的类别标签值
              
              
    bestFeat 
              
                =
              
               chooseBestFeatureToSplit
              
                (
              
              dataSet
              
                )
              
              
                #获得下次划分时候的最佳特征的索引 
              
              
                #选择最大的gain对应的feature
              
              
    bestFeatLabel 
              
                =
              
               labels
              
                [
              
              bestFeat
              
                ]
              
              
                #由索引取得最优特征名称
              
              
                # 这里直接使用字典变量来存储树信息,这对于绘制树形图很重要。 
              
              
    myTree 
              
                =
              
              
                {
              
              bestFeatLabel
              
                :
              
              
                {
              
              
                }
              
              
                }
              
              
                #当前数据集选取最好的特征存储在bestFeat中
              
              
                del
              
              
                (
              
              labels
              
                [
              
              bestFeat
              
                ]
              
              
                )
              
              
                #在labels中删除已经被选择的特征 
              
              
    uniqueVals 
              
                =
              
               labelscounts
              
                [
              
              bestFeatLabel
              
                ]
              
              
                #获得最佳特征对应的所有特征值取值
              
              
                for
              
               value 
              
                in
              
               uniqueVals
              
                :
              
              
                #对所有特征取值遍历
              
              
        subLabels 
              
                =
              
               labels
              
                [
              
              
                :
              
              
                ]
              
              
                #获得子集的特征集
              
              
        subdataSet 
              
                =
              
               splitDataSet
              
                (
              
              dataSet
              
                ,
              
               bestFeat
              
                ,
              
               value
              
                )
              
              
                #划分出数据子集
              
              
                if
              
              
                len
              
              
                (
              
              subdataSet
              
                )
              
              
                ==
              
              
                0
              
              
                :
              
              
                #若划分出的数据子集为空集
              
              
            myTree
              
                [
              
              bestFeatLabel
              
                ]
              
              
                [
              
              value
              
                ]
              
              
                =
              
               calmaxCnt
              
                (
              
              dataSet
              
                )
              
              
                #数据子集设置为叶节点,用数据子集的父集中众数样本类别作为叶节点标签
              
              
                else
              
              
                :
              
              
            myTree
              
                [
              
              bestFeatLabel
              
                ]
              
              
                [
              
              value
              
                ]
              
              
                =
              
               createTree
              
                (
              
              subdataSet
              
                ,
              
               subLabels
              
                ,
              
               labelscounts
              
                )
              
              
                #以最优特征划分数据集为多个数据子集,并提供子集特征集,放入createTree()函数中开始递归
              
              
                return
              
               myTree                                            
              
                #返回字典形式树结构信息
              
              
                #可视化决策树的结果
              
              
dataSet
              
                ,
              
               labels 
              
                =
              
               createDataSet
              
                (
              
              
                )
              
              
                #生成数据集D和特征集A
              
              
                #print(len(dataSet[0]))#7
              
              
labelscounts 
              
                =
              
               get_Values
              
                (
              
              dataSet
              
                ,
              
               labels
              
                )
              
              
                #获得每种特征对应的所有特征值取值
              
              
                #print(labelscounts)#{'Color': {0, 1, 2}, 'Root': {0, 1, 2}, 'Knock': {0, 1, 2}, 'Texture': {0, 1, 2}, 'Umbilical': {0, 1, 2}, 'Touch': {0, 1}}
              
              
labels_tmp 
              
                =
              
               labels
              
                [
              
              
                :
              
              
                ]
              
              
                #复制特征集
              
              
desicionTree 
              
                =
              
               createTree
              
                (
              
              dataSet
              
                ,
              
               labels_tmp
              
                ,
              
               labelscounts
              
                )
              
              
                #创建决策树
              
              
                print
              
              
                (
              
              desicionTree
              
                )
              
              
                #{'Texture': {0: {'Root': {0: 'BadMalen', 1: {'Color': {0: 'GoodMalen', 1: 'GoodMalen', 2: {'Touch': {0: 'GoodMalen', 1: 'BadMalen'}}}}, 2: 'GoodMalen'}}, 1: {'Touch': {0: 'BadMalen', 1: 'GoodMalen'}}, 2: 'BadMalen'}}
              
              
                #决策树是一层层嵌套的字典,键是节点名(内部节点)或者特征值(子树的划分),值是一个字典(子树)或者类别名(叶节点)
              
              
treePlotter
              
                .
              
              createPlot
              
                (
              
              desicionTree
              
                )
              
              
                #使用treePlotter绘制决策树,
              
              
                #对新数据进行分类
              
              
                def
              
              
                classify
              
              
                (
              
              inputTree
              
                ,
              
               featLabels
              
                ,
              
               testVec
              
                )
              
              
                :
              
              
                """ 输入:决策树,分类标签,测试数据 输出:测试数据的决策结果 描述:跑决策树去预测测试数据的标签,返回一个预测值 """
              
              
                # print(testVec)
              
              
    classLabel
              
                =
              
              
                [
              
              
                ]
              
              
                #初始化测试数据标签
              
              
    firstStr 
              
                =
              
              
                list
              
              
                (
              
              inputTree
              
                .
              
              keys
              
                (
              
              
                )
              
              
                )
              
              
                [
              
              
                0
              
              
                ]
              
              
                #取出输入树中第一层字典的键名(某个特征)列表。树字典中第一层只有一个键值对,是父节点名字(键)及其对应子分支(值:字典形式)
              
              
    secondDict 
              
                =
              
               inputTree
              
                [
              
              firstStr
              
                ]
              
              
                #取出输入树字典中父节点键对应的值:除去了输入树第一层的树字典:二层树字典{0: {'B': {0: 'BadMalen', 1: {'A': {1: 'GoodMalen', 2: {'F': {0: 'GoodMalen', 1: 'BadMalen'}}}}, 2: 'GoodMalen'}}, 1: {'F': {0: 'BadMalen', 1: 'GoodMalen'}}, 2: 'BadMalen'}
              
              
    featIndex 
              
                =
              
               featLabels
              
                .
              
              index
              
                (
              
              firstStr
              
                )
              
              
                #获得输入树中第一层字典的键名(父节点名称:某个特征)对应特征名在特征集中的索引
              
              
                for
              
               key 
              
                in
              
               secondDict
              
                .
              
              keys
              
                (
              
              
                )
              
              
                :
              
              
                #对第二层树的键进行遍历,keys_value{'0','1','2'},第二层树的键的取值keys_value是对应父节点名字的特征值取值
              
              
                if
              
               testVec
              
                [
              
              featIndex
              
                ]
              
              
                ==
              
               key
              
                :
              
              
                # test数据的父节点上特征的取了哪个特征值({'0','1','2'}),就走哪个子分支
              
              
                if
              
              
                type
              
              
                (
              
              secondDict
              
                [
              
              key
              
                ]
              
              
                )
              
              
                .
              
              __name__ 
              
                ==
              
              
                'dict'
              
              
                :
              
              
                # 如果子分支的键值对中的值secondDict[key]仍然是字典,则进行递归
              
              
                classLabel 
              
                =
              
               classify
              
                (
              
              secondDict
              
                [
              
              key
              
                ]
              
              
                ,
              
               featLabels
              
                ,
              
               testVec
              
                )
              
              
                #递归函数的输入是(子分支的键值对中的值secondDict[key](字典,作为输入树),特征集,测试数据)
              
              
                else
              
              
                :
              
              
                # 如果子分支的键值对中的值secondDict[key]已经只是分类标签了,则返回这个类别标签
              
              
                # print(testVec)
              
              
                classLabel 
              
                =
              
               secondDict
              
                [
              
              key
              
                ]
              
              
                return
              
               classLabel                                          
              
                #返回测试数据的分类标签
              
              
                # Create Test Set生成测试集
              
              
                def
              
              
                createTestSet
              
              
                (
              
              
                )
              
              
                :
              
              
                """ 色泽Color-> 0: 浅白 | 1: 青绿 | 2: 乌黑 根蒂Root-> 0: 硬挺 | 1: 稍蜷 | 2: 蜷缩 敲声Knock-> 0: 清脆 | 1: 浊响 | 2:沉闷 纹理Texture-> 0: 清晰 | 1: 稍糊 | 2:模糊 脐部Umbilical-> 0: 平坦 | 1: 稍凹 | 2: 凹陷 触感Touch-> 0: 硬滑 | 1: 软粘 标签lab->'GoodMalen'| 'BadMalen' """
              
              
    testSet 
              
                =
              
              
                [
              
              
                [
              
              
                0
              
              
                ,
              
              
                1
              
              
                ,
              
              
                0
              
              
                ,
              
              
                0
              
              
                ,
              
              
                1
              
              
                ,
              
              
                0
              
              
                ]
              
              
                ,
              
              
                [
              
              
                1
              
              
                ,
              
              
                1
              
              
                ,
              
              
                2
              
              
                ,
              
              
                1
              
              
                ,
              
              
                1
              
              
                ,
              
              
                0
              
              
                ]
              
              
                ]
              
              
                return
              
               testSet
inputTree 
              
                =
              
               desicionTree                                                
              
                #导入已经建立的决策树
              
              
featLabels 
              
                =
              
              
                [
              
              
                'Color'
              
              
                ,
              
              
                'Root'
              
              
                ,
              
              
                'Knock'
              
              
                ,
              
              
                'Texture'
              
              
                ,
              
              
                'Umbilical'
              
              
                ,
              
              
                'Touch'
              
              
                ]
              
              
                #定义特征集
              
              
testVec 
              
                =
              
              
                [
              
              
                0
              
              
                ,
              
              
                1
              
              
                ,
              
              
                0
              
              
                ,
              
              
                0
              
              
                ,
              
              
                1
              
              
                ,
              
              
                0
              
              
                ]
              
              
                #一个测试数据
              
              
classify
              
                (
              
              inputTree
              
                ,
              
               featLabels
              
                ,
              
               testVec
              
                )
              
              
                #对测试数据分类
              
              
                #print(classify(inputTree, featLabels, testVec))
              
              
                #对多条新数据进行分类
              
              
                def
              
              
                classifyAll
              
              
                (
              
              inputTree
              
                ,
              
               featLabels
              
                ,
              
               testDataSet
              
                )
              
              
                :
              
              
                """ 输入:决策树,分类标签,测试数据集 输出:决策结果 描述:跑决策树 """
              
              
    classLabelAll 
              
                =
              
              
                [
              
              
                ]
              
              
                #初始化标签集
              
              
                for
              
               testVec 
              
                in
              
               testDataSet
              
                :
              
              
                #对测试数据集中的数据逐行遍历,对测试数据集中的数据逐个测试
              
              
                # print(testVec) 
              
              
        classLabelAll
              
                .
              
              append
              
                (
              
              classify
              
                (
              
              inputTree
              
                ,
              
               featLabels
              
                ,
              
               testVec
              
                )
              
              
                )
              
              
                #将测试结果添加到标签集中
              
              
                return
              
               classLabelAll                      
              
                #返回测试集的标签集
              
              
testSet 
              
                =
              
               createTestSet
              
                (
              
              
                )
              
              
                #获得测试集
              
              
                print
              
              
                (
              
              
                'classifyResult:\n'
              
              
                ,
              
               classifyAll
              
                (
              
              desicionTree
              
                ,
              
               labels
              
                ,
              
               testSet
              
                )
              
              
                )
              
              
                #打印分类结果
              
            
          

参考

周志华. (2016). 机器学习. 清华大学出版社, 北京
决策树的python实现
决策树算法及python实现
treePlotter模块


更多文章、技术交流、商务合作、联系博主

微信扫码或搜索:z360901061

微信扫一扫加我为好友

QQ号联系: 360901061

您的支持是博主写作最大的动力,如果您喜欢我的文章,感觉我的文章对您有帮助,请用微信扫描下面二维码支持博主2元、5元、10元、20元等您想捐的金额吧,狠狠点击下面给点支持吧,站长非常感激您!手机微信长按不能支付解决办法:请将微信支付二维码保存到相册,切换到微信,然后点击微信右上角扫一扫功能,选择支付二维码完成支付。

【本文对您有帮助就好】

您的支持是博主写作最大的动力,如果您喜欢我的文章,感觉我的文章对您有帮助,请用微信扫描上面二维码支持博主2元、5元、10元、自定义金额等您想捐的金额吧,站长会非常 感谢您的哦!!!

发表我的评论
最新评论 总共0条评论