用Python实现GBDT算法并处理Iris数据集

系统 2114 0

GBDT,梯度提升树属于一种有监督的集成学习方法,与之前学习的监督算法类似,同样可以用于分类问题的识别和预测问题的解决。该集成算法体现了三个方面的又是,分别是提升Boosting、梯度Gradient、决策树Decision Tree。“提升”是指将多个弱分类器通过线下组合实现强分类器的过程;“梯度”指的是在Boosting过程中求解损失函数时增加了灵活性和便捷性,“决策树”是指算法所使用的弱分类器为CART决策树,该决策树具有简单直观、通俗易懂的特性。

GBDT模型对数据类型不做任何限制,既可以是连续的数值型,也可以是离散的字符型(但在Python的落地过程中需要将字符型变量做数值化处理或者哑变量处理)。相对于SVM模型来说,较少参数的GBDT具有更高的准确率和更少的运算时间,GBDT模型在面对异常数据时具有更强的稳定性。由于上面种种优点,GBDT常常被广大用户所采用。

最具代表的提升树算法为AdaBoost算法,即:

 

其中F(x)是由M棵基础决策树构成的最终提升树,F_(m-1) (x)表示经过m-1轮迭代后的提升树,α_m为第m棵基础决策树所对应的权重,f_m (x)为第m棵基础决策树。
AdaBoost算法在解决分类问题时,它的核心就是不停地改变样本点的权重,并将每一轮的基础决策树通过权重的方式进行线性组合。该算法在迭代过程中需要进行如下四个步骤:

用Python实现GBDT算法并处理Iris数据集_第1张图片

在这里举一个简单的例子。如下表:

x

0

1

2

3

4

5

6

7

8

9

y

1

1

1

-1

-1

-1

1

1

1

-1

用Python实现GBDT算法并处理Iris数据集_第2张图片

故第一轮迭代结果为:

x

0

1

2

3

4

5

6

7

8

9

y实际

1

1

1

-1

-1

-1

1

1

1

-1

预测得分

0.424

0.424

0.424

-0.424

-0.424

-0.424

-0.424

-0.424

-0.424

-0.424

y预测

1

1

1

-1

-1

-1

-1

-1

-1

-1

显然6、7、8三个点的预测结果是错的,所以它们对应的权重也是最大的,在进入第二轮时,模型会更加关注这三个点。

AdaBoost算法具体在Python上的实现方式为导入sklearn中的子模块ensemble,从中调用AdaBoostClassifier类。

在这里我再次使用Iris数据集进行测试,原码及效果为如下:

            adaBoost =
            
               ensemble.AdaBoostClassifier()
adaBoost.fit(x_train,y_train)
predict 
            
            =
            
               adaBoost.predict(x_test)

            
            
              print
            
            (
            
              '
            
            
              Accuracy: 
            
            
              '
            
            
              ,metrics.accuracy_score(y_test,predict))

            
            
              print
            
            (
            
              '
            
            
              Report :\n
            
            
              '
            
            ,metrics.classification_report(y_test,predict))
            

Accuracy: 0.7
Report :
precision recall f1-score support

0.0 1.00 0.92 0.96 12
1.0 0.62 0.67 0.64 12
2.0 0.33 0.33 0.33 6

micro avg 0.70 0.70 0.70 30
macro avg 0.65 0.64 0.64 30
weighted avg 0.71 0.70 0.71 30


更多文章、技术交流、商务合作、联系博主

微信扫码或搜索:z360901061

微信扫一扫加我为好友

QQ号联系: 360901061

您的支持是博主写作最大的动力,如果您喜欢我的文章,感觉我的文章对您有帮助,请用微信扫描下面二维码支持博主2元、5元、10元、20元等您想捐的金额吧,狠狠点击下面给点支持吧,站长非常感激您!手机微信长按不能支付解决办法:请将微信支付二维码保存到相册,切换到微信,然后点击微信右上角扫一扫功能,选择支付二维码完成支付。

【本文对您有帮助就好】

您的支持是博主写作最大的动力,如果您喜欢我的文章,感觉我的文章对您有帮助,请用微信扫描上面二维码支持博主2元、5元、10元、自定义金额等您想捐的金额吧,站长会非常 感谢您的哦!!!

发表我的评论
最新评论 总共0条评论