简单决策树调用&可视化【Python】

系统 1344 0

决策树部分理论支撑

1* 通过选取一定的特征来降低数据的不确定性(熵)

2* 建议寻找多分类问题的最优特征的最优候选值。把多分类问题转换成多几层递归的二分类问题,防止数据对特征值的控制敏感。

3* 停止条件

  • 取得了最够好的分类结果
  • 递归到了预定的最深深度
  • 叶子节点的纯度
  • 分裂次数达到极限
  • 最大特征数
  • . . .

4* 相关公式

  • e n t r o p y ( D ) = − ∑ i = 1 n P i l o g 2 P i entropy(D) = -\sum_{i=1}^n P_ilog_2 P_i e n t r o p y ( D ) = i = 1 n P i l o g 2 P i
    e n t r o p y ( D , A ) = ∑ i = 1 k D A i D l o g 2 D A i entropy(D,A) = \sum_{i=1}^k \frac {D_{A_i}}{D} log_2D_{A_i} e n t r o p y ( D , A ) = i = 1 k D D A i l o g 2 D A i
    g a i n ( D , A ) = e n t r o p y ( D ) − e n t r o p y ( D , A ) gain(D,A) = entropy(D) - entropy(D,A) g a i n ( D , A ) = e n t r o p y ( D ) e n t r o p y ( D , A )
    原本的熵 减去 考虑某种特征条件A之后的熵,得到信息增益
    g a i n r a t e ( D , A ) = g a i n ( D , A ) / e n t r o p y ( D , A ) gain_rate(D,A) = gain(D,A)/entropy(D,A) g a i n r a t e ( D , A ) = g a i n ( D , A ) / e n t r o p y ( D , A )
    同理,根据同样的方法可以得到 信息增益率

简单决策树调用&可视化【Python】_第1张图片

            
              
                import
              
               pandas 
              
                as
              
               pd

              
                import
              
               numpy 
              
                as
              
               np
df 
              
                =
              
               pd
              
                .
              
              read_csv
              
                (
              
              
                'C:\\Users\\76485\\Desktop\\column.2C.csv'
              
              
                )
              
              
                '''
导入数据源&导入基本包;
交互命令窗口输入df.head()查看前五行数据
'''
              
              
X 
              
                =
              
               df
              
                .
              
              drop
              
                (
              
              
                'V7'
              
              
                ,
              
              axis 
              
                =
              
              
                1
              
              
                )
              
              
                #drop进行有选择的数据删除,删除头标签为'V7'的列数据
              
              
y 
              
                =
              
               df
              
                .
              
              V7

              
                '''
轴用来为超过一维的数组定义属性,二维数据:0轴沿着行的方向向下,1轴沿着列的水平方向延伸
'''
              
              
                '''
#将文字转换为数字的标准化程序

from sklearn.preprocessing import LabelEncoder
labelencoder = LabelEncoder()
for col in data.columns:
    data[col] = labelencoder.fit_transform(data[col])

'''
              
              
                #from sklearn.cross_validation import train_test_split#该包在新版本中,如下实现
              
              
                from
              
               sklearn
              
                .
              
              model_selection 
              
                import
              
               train_test_split
X_train
              
                ,
              
               X_test
              
                ,
              
               y_train
              
                ,
              
               y_test 
              
                =
              
               train_test_split
              
                (
              
              X
              
                ,
              
               y
              
                ,
              
               random_state
              
                =
              
              
                1
              
              
                )
              
              
                #将数据集拆分为训练集和测试集
              
              
                from
              
               sklearn 
              
                import
              
               tree
clf 
              
                =
              
               tree
              
                .
              
              DecisionTreeClassifier
              
                (
              
              max_depth 
              
                =
              
              
                4
              
              
                )
              
              
                #建树
              
              
clf 
              
                =
              
               clf
              
                .
              
              fit
              
                (
              
              X_train
              
                ,
              
               y_train
              
                )
              
              

test_rec 
              
                =
              
               X_test
              
                .
              
              iloc
              
                [
              
              
                1
              
              
                ,
              
              
                :
              
              
                ]
              
              
clf
              
                .
              
              predict
              
                (
              
              
                [
              
              test_rec
              
                ]
              
              
                )
              
              
                #测试集测试,交互窗口输入
              
              
                '''
Ans:
Out[16]: array(['NO'], dtype=object)
'''
              
              

y_test
              
                .
              
              iloc
              
                [
              
              
                1
              
              
                ]
              
              
                #调出真实结果,交互窗口输入
              
              
                '''
Ans:
Out[17]: 'NO' 
'''
              
              
                from
              
               sklearn
              
                .
              
              metrics 
              
                import
              
               accuracy_score
rate_ac 
              
                =
              
               accuracy_score
              
                (
              
              y_test
              
                ,
              
               clf
              
                .
              
              predict
              
                (
              
              X_test
              
                )
              
              
                )
              
              
                #测试模型准确率
              
              
                print
              
              
                (
              
              rate_ac
              
                )
              
              
                '''
0.8205128205128205 稳定在80%左右,建树层数对准确率影响较小
'''
              
              
                '''决策树可视化'''
              
              
                '''
with open("lc-is.dot", 'w') as f:
     f = tree.export_graphviz(clf,
                              out_file=f,
                              max_depth = 3,
                              impurity = True,
                              feature_names = list(X_train),
                              class_names = ['AB', 'NO'],
                              rounded = True,
                              filled= True )
'''
              
              
                from
              
               sklearn
              
                .
              
              tree 
              
                import
              
               DecisionTreeClassifier

              
                import
              
               pydotplus
              
                #若提示没有此包,需在cmd-Anaconda Prompt键入install pydotplus
              
              
                from
              
               IPython
              
                .
              
              display 
              
                import
              
               Image

              
                from
              
               IPython
              
                .
              
              display 
              
                import
              
               display

              
                from
              
               sklearn
              
                .
              
              tree 
              
                import
              
               export_graphviz
              
                #需手动下载并配置绝对路径
              
              
                import
              
               os
os
              
                .
              
              environ
              
                [
              
              
                "PATH"
              
              
                ]
              
              
                +=
              
               os
              
                .
              
              pathsep 
              
                +
              
              
                'C:/Program Files (x86)/Graphviz2.38/bin/'
              
              
                #沙雕pydotplus,配置环境变量(路径)
              
              

dot_tree 
              
                =
              
               tree
              
                .
              
              export_graphviz
              
                (
              
              clf
              
                ,
              
              out_file
              
                =
              
              
                None
              
              
                ,
              
              
                                feature_names
              
                =
              
              
                [
              
              
                'V1'
              
              
                ,
              
              
                'V2'
              
              
                ,
              
              
                'V3'
              
              
                ,
              
              
                'V4'
              
              
                ,
              
              
                'V5'
              
              
                ,
              
              
                'V6'
              
              
                ]
              
              
                ,
              
              
                                class_names
              
                =
              
              
                [
              
              
                'AB'
              
              
                ,
              
              
                'NO'
              
              
                ]
              
              
                ,
              
              
                                filled
              
                =
              
              
                True
              
              
                ,
              
               
                                rounded
              
                =
              
              
                True
              
              
                ,
              
              
                                special_characters
              
                =
              
              
                True
              
              
                )
              
              
graph 
              
                =
              
               pydotplus
              
                .
              
              graph_from_dot_data
              
                (
              
              dot_tree
              
                )
              
              
img 
              
                =
              
               Image
              
                (
              
              graph
              
                .
              
              create_png
              
                (
              
              
                )
              
              
                )
              
              
graph
              
                .
              
              write_png
              
                (
              
              
                "out.png"
              
              
                )
              
            
          

更多文章、技术交流、商务合作、联系博主

微信扫码或搜索:z360901061

微信扫一扫加我为好友

QQ号联系: 360901061

您的支持是博主写作最大的动力,如果您喜欢我的文章,感觉我的文章对您有帮助,请用微信扫描下面二维码支持博主2元、5元、10元、20元等您想捐的金额吧,狠狠点击下面给点支持吧,站长非常感激您!手机微信长按不能支付解决办法:请将微信支付二维码保存到相册,切换到微信,然后点击微信右上角扫一扫功能,选择支付二维码完成支付。

【本文对您有帮助就好】

您的支持是博主写作最大的动力,如果您喜欢我的文章,感觉我的文章对您有帮助,请用微信扫描上面二维码支持博主2元、5元、10元、自定义金额等您想捐的金额吧,站长会非常 感谢您的哦!!!

发表我的评论
最新评论 总共0条评论