统计学习方法(三)——K近邻法

系统 1681 0

 

/*先把标题给写了、这样就能经常提醒自己*/

1. k近邻算法

k临近算法的过程,即对一个新的样本,找到特征空间中与其最近的k个样本,这k个样本多数属于某个类,就把这个新的样本也归为这个类。

算法 

输入:训练数据集

其中为样本的特征向量,为实例的类别,i=1,2,…,N;样本特征向量x(新样本);

输出:样本x所属的类y。

(1)根据给定的距离度量,在训练集T中找出与x最相邻的k个点,涵盖这k个点的邻域记作;

(2)在中根据分类决策规则(如多数表决)决定x的类别y:

                                                                    (1)

式中I为指示函数,即当时I为1,否则为0。

由这个简单的算法过程可以看出来,距离的选择、以及k的选择都是很重要的,这恰好对应的三个要素中的两个,另一个为分类决策规则,一般来说是多数表决法。

 

2. k近邻模型

k近邻算法使用的模型实际上对应于特征空间的划分,模型由三个基本要素——距离度量、k值的选择和分类决策规则决定。

距离度量

      特征空间中俩个实例的距离是俩个实例点相似程度的反映,k近邻中一般使用欧氏距离,本文中主要只介绍这一种。

设特征空间 维实数向量空间 , , 距离定义为

统计学习方法(三)——K近邻法_第1张图片

 

当p=2时,称为欧氏距离(Euclidean distance).

==================在此吐槽一下,博客园的图片插入好折腾人啊,已经搞出肩周炎了,明天再继续码第二要素了 2014-6-30========================

举个粟子,已知 , ,则  的欧氏距离为 ,挺容易理解的吧!

K 值的选择

首先说明一下K值的选择对最终的结果有很大的影响!!!

      如果选择的k过小,则预测的结果对近邻的实例点非常敏感,如果近邻刚好是噪声,则预测就会出错,例如k=1,很难保证最近的一个点就是正确的预测,亦即容易发生过拟合!如果选择的k过大,则会忽略掉训练实例中的大量有用信息,例如k=N,那么无论输入实例是什么最终的结果都将是训练实例中最多的类。

      关于分类决策规则这里就不再赘述,正常情况下直接采用多数表决即可,如果觉得结果不满意的话,可以加入各个类的先验概率进去融合!

 

3. K近邻的实现

该小节书本中用到了KD树,通过构造平衡KD树来方便快速查找训练数据中离测试实例最近的点,不过构造这颗树本身是一个比较繁琐的过程(其实是本人代码能力实在太菜了,真的觉得把KD树写下来需要花太多时间了,而且KD树中每增加一个新数据又要进行节点插入操作,实在不方便,直接放弃),所以直接用最土豪的方法,时间复杂度差就差了,咱有的是CPU!!!

在这里直接套用书中例子,不过实现上就用其它算法了。稍等,我勒个去!书中的例子只是用于构造KD树的,李航兄你不厚道啊,说好的K近邻怎么变成这样了,不能直接引用书中例子了,自己再编一个得了。

例子: 训练数据集中,正样本点有 ,负样本点有 ,现要求判断实例 属于哪个类别,如下图所示:

       统计学习方法(三)——K近邻法_第2张图片

假设取K=3,则距离 最近的3个点为 ,按照多数表决规则可得出 应该属于正类。

  为了表示咱们不是拍脑袋给出的结果,下面给出具体的代码实现

      
        package
      
      
         org.juefan.knn;


      
      
        import
      
      
         java.util.ArrayList;

      
      
        import
      
      
         java.util.Collections;

      
      
        import
      
      
         java.util.Comparator;

      
      
        import
      
      
         java.util.HashMap;

      
      
        import
      
      
         java.util.Map;


      
      
        import
      
      
         org.juefan.basic.FileIO;

      
      
        import
      
      
         org.juefan.data.Data;


      
      
        public
      
      
        class
      
      
         SimpleKnn {
    
    
      
      
        public
      
      
        static
      
      
        final
      
      
        int
      
       K = 3
      
        ;        
    
      
      
        public
      
      
        static
      
      
        int
      
       P = 2;        
      
        //
      
      
        距离函数的选择,P=2即欧氏距离
      
      
        public
      
      
        class
      
      
         LabelDistance{
        
      
      
        public
      
      
        double
      
       distance = 0
      
        ;
        
      
      
        public
      
      
        int
      
      
         label;        
        
      
      
        public
      
       LabelDistance(
      
        double
      
       d, 
      
        int
      
      
         l){
            distance 
      
      =
      
         d;
            label 
      
      =
      
         l;
        }
    }
    
    
      
      
        public
      
       sort compare = 
      
        new
      
      
         sort();
    
      
      
        public
      
      
        class
      
       sort 
      
        implements
      
       Comparator<LabelDistance>
      
         {
        
      
      
        public
      
      
        int
      
      
         compare(LabelDistance arg0, LabelDistance arg1) {
            
      
      
        return
      
       arg0.distance < arg1.distance ? -1 : 1;        
      
        //
      
      
        JDK1.7的新特性,返回值必须是一对正负数
      
      
                }
    }
    
    
      
      
        /**
      
      
        
     * 俩个实例间的距离函数
     * 
      
      
        @param
      
      
         a
     * 
      
      
        @param
      
      
         b
     * 
      
      
        @return
      
      
         返回距离值,如果俩个实例的维度不一致则返回一个极大值
     
      
      
        */
      
      
        public
      
      
        double
      
      
         getLdistance(Data a, Data b){
        
      
      
        if
      
      (a.x.size() !=
      
         b.x.size())
            
      
      
        return
      
      
         Double.MAX_VALUE;
        
      
      
        double
      
       inner = 0
      
        ;
        
      
      
        for
      
      (
      
        int
      
       i = 0; i < P; i++
      
        ){
            inner 
      
      += Math.pow((a.x.get(i) -
      
         b.x.get(i)) , P);
        }
        
      
      
        return
      
       Math.pow(inner, (
      
        double
      
      )1/
      
        P);    
    }
    
    
      
      
        /**
      
      
        
     * 计算实例与训练集的距离并返回最终判断结果
     * 
      
      
        @param
      
      
         d 待判断实例
     * 
      
      
        @param
      
      
         tran 训练集
     * 
      
      
        @return
      
      
         实例的判断结果
     
      
      
        */
      
      
        public
      
      
        int
      
       getLabelvalue(Data d, ArrayList<Data>
      
         tran){
        ArrayList
      
      <LabelDistance> labelDistances= 
      
        new
      
       ArrayList<>
      
        ();
        Map
      
      <Integer, Integer> map = 
      
        new
      
       HashMap<>
      
        ();
        
      
      
        int
      
       label = 0
      
        ;
        
      
      
        int
      
       count = 0
      
        ;
        
      
      
        for
      
      
        (Data data: tran){
            labelDistances.add(
      
      
        new
      
      
         LabelDistance(getLdistance(d, data), data.y));
        }
        Collections.sort(labelDistances, compare);
        
      
      
        for
      
      (
      
        int
      
       i = 0; i < K & i < labelDistances.size(); i++
      
        ){
            //System.out.println(labelDistances.get(i).distance 
      
      + "\t" +
      
         labelDistances.get(i).label);
            
      
      
        int
      
       tmplabel =
      
         labelDistances.get(i).label;
            
      
      
        if
      
      
        (map.containsKey(tmplabel)){
                map.put(tmplabel, map.get(tmplabel) 
      
      + 1
      
        );
            }
      
      
        else
      
      
         {
                map.put(tmplabel, 
      
      1
      
        );
            }
        }
        
      
      
        for
      
      (
      
        int
      
      
         key: map.keySet()){
            
      
      
        if
      
      (map.get(key) >
      
         count){
                count 
      
      =
      
         map.get(key);
                label 
      
      =
      
         key;
            }
        }
        
      
      
        return
      
      
         label;    
    }
    
    
      
      
        public
      
      
        static
      
      
        void
      
      
         main(String[] args) {
        SimpleKnn knn 
      
      = 
      
        new
      
      
         SimpleKnn();
        ArrayList
      
      <Data> datas = 
      
        new
      
       ArrayList<>
      
        ();
        FileIO fileIO 
      
      = 
      
        new
      
      
         FileIO();
        fileIO.setFileName(
      
      ".//file//knn.txt"
      
        );
        fileIO.FileRead();
        
      
      
        for
      
      
        (String data: fileIO.fileList){
            datas.add(
      
      
        new
      
      
         Data(data));
        }
        Data data 
      
      = 
      
        new
      
      
         Data();
        data.x.add(
      
      2); data.x.add(1
      
        );
        System.out.println(knn.getLabelvalue(data, datas));
    }
}
      
    

 

 

对代码有兴趣的可以上本人的GitHub查看: https://github.com/JueFan/StatisticsLearningMethod/

统计学习方法(三)——K近邻法


更多文章、技术交流、商务合作、联系博主

微信扫码或搜索:z360901061

微信扫一扫加我为好友

QQ号联系: 360901061

您的支持是博主写作最大的动力,如果您喜欢我的文章,感觉我的文章对您有帮助,请用微信扫描下面二维码支持博主2元、5元、10元、20元等您想捐的金额吧,狠狠点击下面给点支持吧,站长非常感激您!手机微信长按不能支付解决办法:请将微信支付二维码保存到相册,切换到微信,然后点击微信右上角扫一扫功能,选择支付二维码完成支付。

【本文对您有帮助就好】

您的支持是博主写作最大的动力,如果您喜欢我的文章,感觉我的文章对您有帮助,请用微信扫描上面二维码支持博主2元、5元、10元、自定义金额等您想捐的金额吧,站长会非常 感谢您的哦!!!

发表我的评论
最新评论 总共0条评论