统计学习方法(二)——感知机

系统 1653 0

/*先把标题给写了、这样就能经常提醒自己*/

1. 感知机模型

我们先来定义一下什么是感知机。所谓感知机,就是二类分类的线性分类模型,其输入为样本的特征向量,输出为样本的类别,取+1和-1二值,即通过某样本的特征,就可以准确判断该样本属于哪一类。顾名思义,感知机能够解决的问题首先要求特征空间是线性可分的,再者是二类分类,即将样本分为{+1, -1}两类。从比较学术的层面来说,由输入空间到输出空间的函数:

                                                                                                         (1)

称为感知机,w和b为感知机参数,w为权值(weight),b为偏置(bias)。sign为符号函数:

                                                                                                         (2)

感知机模型的假设空间是定义在特征空间中的所有线性分类模型,即函数集合{f|f(x) = w·x + b}。在感知机的定义中,线性方程w·x + b = 0对应于问题空间中的一个超平面S,位于这个超平面两侧的样本分别被归为两类,例如下图,红色作为一类,蓝色作为另一类,它们的特征很简单,就是它们的坐标

统计学习方法(二)——感知机_第1张图片

图1

作为监督学习的一种方法,感知机学习由训练集求得感知机模型,即求得模型参数w,b,这里x和y分别是特征向量和类别(也称为目标)。基于此,感知机模型可以对新的输入样本进行分类。

前面半抄书半自说自话把感知机的定义以及是用来干嘛的简单记录了一下,作为早期的机器学习方法(1957年由Frank Rosenblatt提出),它是最简单的前馈神经网络,对之后的机器学习方法如神经网络起一个基础的作用,下一节将详细介绍感知机学习策略。

2. 感知机学习策略

上节说到,感知机是一个简单的二类分类的线性分类模型,要求我们的样本是线性可分的,什么样的样本是线性可分的呢?举例来说,在二维平面中,可以用一条直线将+1类和-1类完美分开,那么这个样本空间就是线性可分的。如图1就是线性可分的,图2中的样本就是线性不可分的,感知机就不能处理这种情况。因此,在本章中的所有问题都基于一个前提,就是问题空间线性可分。

统计学习方法(二)——感知机_第2张图片

图2

为方便说明问题,我们假设数据集中所有的的实例i,有;对的实例有。

这里先给出输入空间中任一点到超平面S的距离:

                                                                                                               (3)

这里||w||是w的范数。

对于误分类的数据,根据我们之前的假设,有

                                                                                                          (4)

因此误分类点到超平面S的距离可以写作:

                                                                                                            (5)

假设超平面S的误分类点集合为M,那么所有误分类点到超平面S的总距离为

                                                                                                    (6)

这里的||w||值是固定的,不必考虑,这样就得到了感知机学习的损失函数。根据我们的定义,这个损失函数自然是越小越好,因为这样就代表着误分类点越少、误分类点距离超平面S的距离越近,即我们的分类面越正确。显然,这个损失函数是非负的,若所有的样本都分类正确,那么我们的损失函数值为0。一个特定的样本集T的损失函数:在误分类时是参数w,b的线性函数。也就是说,为求得正确的参数w,b,我们的目标函数为

                                                                                          (7)

而它是连续可导的,这就使得我们可以比较容易的求得其最小值。

感知机学习的策略是在假设空间中选取使我们的损失函数(7)最小的模型参数w,b,即感知机模型。

根据感知机定义以及我们的假设,得到了感知机的模型,即目标函数(7),将其最小化的本质就是使得分类面S尽可能的正确,下一节介绍将其最小化的方法——随机梯度下降。

3. 感知机学习算法

根据感知机学习的策略,我们已经将寻找超平面S的问题转化为求解式(7)的最优化问题,最优化的方法是随机梯度下降法,书中介绍了两种形式:原始形式和对偶形式,并证明了在训练集线性可分时算法的收敛性。

3.1 原始形式

所谓原始形式,就是我们用梯度下降的方法,对参数w和b进行不断的迭代更新。具体来说,就是先任意选取一个超平面,对应的参数分别为和,当然现在是可以任意赋值的,比如说选取为全为0的向量,的值为0。然后用梯度下降不断地极小化损失函数(7)。由于随机梯度下降(stochastic     gradient descent)的效率要高于批量梯度下降(batch gradient descent)(详情可参考Andrew Ng教授的 讲义 ,在Part 1的LMS algorithm部分),所以这里采用随机梯度下降的方法,每次随机选取一个误分类点对w和b进行更新。

设误分类点集合M是固定的,为求式(7)的最小值,我们需要知道往哪个方向下降速率最快,这是可由对损失函数L(w, b)求梯度得到,L(w, b)的梯度为

接下来随机选取一个误分类点对w,b进行更新

                                                                                                                (8)

                                                                                                                      (9)

其中为步长,也称为学习速率(learning rate),一般在0到1之间取值,步长越大,我们梯度下降的速度越快,也就能更快接近极小点。如果步长过大,就有直接跨过极小点导致函数发散的问题;如果步长过小,可能会耗费比较长的时间才能达到极小点。通过这样的迭代,我们的损失函数就不断减小,直到为0。综上所述,得到如下算法:

算法1 (感知机学习算法的原始形式)

输入:训练数据集,其中,,i = 1,2,…,N;学习率

输出:w,b;感知机模型

(1)选取初始值,

(2)在训练集中选取数据

(3)如果(从公式(3)变换而来)

(4)转至(2),直至训练集中没有误分类点

这种学习算法直观上有如下解释:当一个样本被误分类时,就调整w和b的值,使超平面S向误分类点的一侧移动,以减少该误分类点到超平面的距离,直至超平面越过该点使之被正确分类。

书上还给出了一个例题,这是我推崇这本书的原因之一,凡是只讲理论不给例子的行为都是耍流氓!

例1  如图3所示的训练数据集,其正实例点是,,负实例点是,试用感知机学习算法的原始形式求感知机模型,即求出w和b。这里,

统计学习方法(二)——感知机_第3张图片

图3

这里我们取初值,取。具体问题解释不写了,求解的方法就是 算法1 。下面给出这道题的Java代码(终于有一段是自己纯原创的了)。

 
 
          
            package
          
          
             org.juefan.perceptron;

          
          
            import
          
          
             java.util.ArrayList;

          
          
            import
          
          
             org.juefan.basic.FileIO;

          
          
            public
          
          
            class
          
          
             PrimevalPerceptron {
    
    
          
          
            public
          
          
            static
          
           ArrayList<Integer> w  = 
          
            new
          
           ArrayList<>
          
            ();
    
          
          
            public
          
          
            static
          
          
            int
          
          
             b ;
    
    
          
          
            /*
          
          
            初始化参数
          
          
            */
          
          
            public
          
          
             PrimevalPerceptron(){
        w.add(
          
          5
          
            );
        w.add(
          
          -2
          
            );
        b 
          
          = 3
          
            ;
    }
    
    
          
          
            /**
          
          
            
     * 判断是否分类正确
     * 
          
          
            @param
          
          
             data 待判断数据
     * 
          
          
            @return
          
          
             返回判断正确与否
     
          
          
            */
          
          
            public
          
          
            static
          
          
            boolean
          
          
             getValue(Data data){
        
          
          
            int
          
           state = 0
          
            ;
        
          
          
            for
          
          (
          
            int
          
           i = 0; i < data.x.size(); i++
          
            ){
            state 
          
          += w.get(i) *
          
             data.x.get(i);
        }
        state 
          
          +=
          
             b;
        
          
          
            return
          
           state * data.y > 0? 
          
            true
          
          : 
          
            false
          
          
            ;    
    }
    
    
          
          
            //
          
          
            此算法基于数据是线性可分的,如果线性不可分,则会进入死循环
          
          
            public
          
          
            static
          
          
            boolean
          
           isStop(ArrayList<Data>
          
             datas){
        
          
          
            boolean
          
           isStop = 
          
            true
          
          
            ;
        
          
          
            for
          
          
            (Data data: datas){
            isStop 
          
          = isStop &&
          
             getValue(data);
        }
        
          
          
            return
          
          
             isStop;
    }
    
    
          
          
            public
          
          
            static
          
          
            void
          
          
             main(String[] args) {
        PrimevalPerceptron model 
          
          = 
          
            new
          
          
             PrimevalPerceptron();
        ArrayList
          
          <Data> datas = 
          
            new
          
           ArrayList<>
          
            ();
        FileIO fileIO 
          
          = 
          
            new
          
          
             FileIO();
        fileIO.setFileName(
          
          ".//file//perceptron.txt"
          
            );
        fileIO.FileRead();
        
          
          
            for
          
          
            (String data: fileIO.fileList){
            datas.add(
          
          
            new
          
          
             Data(data));
        }
    
        
          
          
            /**
          
          
            
         * 如果全部数据都分类正确则结束迭代
         
          
          
            */
          
          
            while
          
          (!
          
            isStop(datas)){
            
          
          
            for
          
          (
          
            int
          
           i = 0; i < datas.size(); i++
          
            ){
                
          
          
            if
          
          (!getValue(datas.get(i))){  
          
            //
          
          
            这里面可以理解为是一个简单的梯度下降法
          
          
            for
          
          (
          
            int
          
           j = 0; j < datas.get(i).x.size(); j++
          
            )
                    w.set(j, w.get(j) 
          
          + datas.get(i).y *
          
             datas.get(i).x.get(j));
                    b 
          
          +=
          
             datas.get(i).y;
                    System.out.println(w 
          
          + "\t" +
          
             b);
                }
            }
        }    
        System.out.println(w 
          
          + "\t" + b);        
          
            //
          
          
            输出最终的结果
          
          
                }
}
          
        

最后解得(这里应该是写错了,最终结果b=-3)。不过,如果选取的初值不同,或者选取的误分类点不同,我们得到的超平面S也不尽相同,毕竟感知机模型的解是一组符合条件的超平面的集合,而不是某一个最优超平面。

3.2 算法的收敛性

一节纯数学的东西,用了两整页证明了Novikoff定理,看到这里才知道智商真的是个硬伤,反复看了两遍,又把证明的过程自己推导了一遍,算是看懂了,如果凭空证明的话,自己的功力还差得远。

Novikoff于1962年证明了感知机算法的收敛性,作为一个懒人,由于这一节涉及大量公式,即使有latex插件也是个麻烦的工作,具体什么情况我就不谈了,同时,哥伦比亚大学有这样的一篇叫《 Convergence Proof for the Perceptron Algorithm 》的笔记,讲解了这个定理的证明过程,也给了我一个偷懒的理由 微笑

3.3 感知机学习算法的对偶形式

书上说对偶形式的基本想法是,将w和b表示为实例和的线性组合形式,通过求解其系数而求得w和b。这个想法及下面的算法描述很容易看懂,只不过为什么要这么做?将w和b用x和y来表示有什么好处呢?看起来也不怎么直观,计算量上似乎并没有减少。如果说支持向量机的求最优过程使用对偶形式是为了方便引入核函数,那这里的对偶形式是用来做什么的呢?暂且认为是对后面支持向量机的一个铺垫吧,或者是求解这种优化问题的一个普遍解法。

继续正题,为了方便推导,可将初始值和都设为0,据上文,我们对误分类点通过

来更新w,b,假设我们通过误分类点更新参数的次数为次,那么w,b关于的增量为和,为方便,可将用来表示,很容易可以得到

                                                                                                          (10)

                                                                                                                (11)

这里i = 1,2,…,N。当时,表示第i个样本由于被误分类而进行更新的次数。某样本更新次数越多,表示它距离超平面S越近,也就越难正确分类。换句话说,这样的样本对学习结果影响最大。

算法2 (感知机学习算法的对偶形式)

输入:训练数据集,其中,,i = 1,2,…,N;学习率

输出:,b;感知机模型

其中

(1),

(2)在训练集中选取样本

(3)如果

(4)转至(2)直到没有误分类样本出现

由于训练实例仅以内积的形式出现,为方便,可预先将训练集中实例间的内积计算出来并以矩阵形式存储(就是那个的部分),这就是所谓的Gram矩阵(线性代数学的不好的飘过To T):

又到例题时间!再说一遍,这本书最大的好处就是有实例,凡是只讲理论不给例子的行为都是耍流氓!

例2  同 例1 ,只不过是用对偶形式来求解。

同样过程不再分析,给出我的求解代码:

 

          
            package
          
          
             org.juefan.perceptron;

          
          
            import
          
          
             java.util.ArrayList;

          
          
            import
          
          
             org.juefan.basic.FileIO;

          
          
            public
          
          
            class
          
          
             GramPerceptrom {
    
    
          
          
            public
          
          
            static
          
           ArrayList<Integer> a  = 
          
            new
          
           ArrayList<>
          
            ();
    
          
          
            public
          
          
            static
          
          
            int
          
          
             b ;
    
    
          
          
            /*
          
          
            初始化参数
          
          
            */
          
          
            public
          
           GramPerceptrom(
          
            int
          
          
             num){
        
          
          
            for
          
          (
          
            int
          
           i = 0; i < num; i++
          
            )
            a.add(
          
          0
          
            );
        b 
          
          = 0
          
            ;
    }
    
          
          
            /**
          
          
            Gram矩阵
          
          
            */
          
          
            public
          
          
            static
          
           ArrayList<ArrayList<Integer>> gram = 
          
            new
          
           ArrayList<>
          
            ();
    
          
          
            public
          
          
            void
          
           setGram(ArrayList<Data>
          
             datas){
        
          
          
            for
          
          (
          
            int
          
           i = 0; i < datas.size(); i++
          
            ){
            ArrayList
          
          <Integer> rowGram = 
          
            new
          
           ArrayList<>
          
            ();
            
          
          
            for
          
          (
          
            int
          
           j = 0; j < datas.size(); j++
          
            ){
                rowGram.add(Data.getInner(datas.get(i), datas.get(j)));
            }
            gram.add(rowGram);
        }
    }
    
    
          
          
            /**
          
          
            是否正确分类
          
          
            */
          
          
            public
          
          
            static
          
          
            boolean
          
           isCorrect(
          
            int
          
           i, ArrayList<Data>
          
             datas){
        
          
          
            int
          
           value = 0
          
            ;
        
          
          
            for
          
          (
          
            int
          
           j = 0; j < datas.size(); j++
          
            )
            value 
          
          += a.get(j)*datas.get(j).y *
          
             gram.get(j).get(i);
        value 
          
          = datas.get(i).y * (value +
          
             b);
        
          
          
            return
          
           value > 0 ? 
          
            true
          
          : 
          
            false
          
          
            ;
    }
    
    
          
          
            //
          
          
            此算法基于数据是线性可分的,如果线性不可分,则会进入死循环
          
          
            public
          
          
            static
          
          
            boolean
          
           isStop(ArrayList<Data>
          
             datas){
        
          
          
            boolean
          
           isStop = 
          
            true
          
          
            ;
        
          
          
            for
          
          (
          
            int
          
           i = 0; i < datas.size(); i++
          
            ){
            isStop 
          
          = isStop &&
          
             isCorrect(i, datas);
        }
        
          
          
            return
          
          
             isStop;
    }
    
    
          
          
            public
          
          
            static
          
          
            void
          
          
             main(String[] args) {
        ArrayList
          
          <Data> datas = 
          
            new
          
           ArrayList<>
          
            ();
        FileIO fileIO 
          
          = 
          
            new
          
          
             FileIO();
        fileIO.setFileName(
          
          ".//file//perceptron.txt"
          
            );
        fileIO.FileRead();
        
          
          
            for
          
          
            (String data: fileIO.fileList){
            datas.add(
          
          
            new
          
          
             Data(data));
        }
        GramPerceptrom gram  
          
          = 
          
            new
          
          
             GramPerceptrom(datas.size());
        gram.setGram(datas);
        System.out.println(datas.size());
        
          
          
            while
          
          (!
          
            isStop(datas)){
            
          
          
            for
          
          (
          
            int
          
           i = 0; i < datas.size(); i++
          
            )
                
          
          
            if
          
          (!
          
            isCorrect(i, datas)){
                    a.set(i, a.get(i) 
          
          + 1
          
            );
                    b 
          
          +=
          
             datas.get(i).y;
                    System.out.println(a 
          
          + "\t" +
          
             b);
                }
        }
    }
}
          
        

 

4. 小结

终于写完一章的内容了,用了不少功夫,果然是说起来容易做起来难呢。不过通过这样记录的方式(虽然大部分是抄书),自己对相关算法的理论及过程就又学了一遍,觉得这个时间还是花的值得的。

本章介绍了统计学习中最简单的一种算法——感知机,对现在的机器学习理论来说,这个算法的确是太简单了,但这样简单的东西却是很多现在流行算法的基础,比如神经网络,比如支持向量机,Deep Learning还没了解过,不知道有多大联系。当然,实际应用价值不能说没有,可以用于一些简单的线性分类的情况,也可以由简单的二类分类扩展到多类分类(详见 此PPT ),可以用于自然语义处理等领域。

再将思路整理一下,尽量掌握感知机算法,再好好看看 维基百科链接 中的有关文献。

 

PS:本文的内容转载自 http://www.cnblogs.com/OldPanda/archive/2013/04/12/3017100.html

原博主是Python写的代码,我这边改成Java了

对代码有兴趣的可以上本人的GitHub查看: https://github.com/JueFan/StatisticsLearningMethod/

2014-06-29:俩种算法的代码都完成了,更新完毕

统计学习方法(二)——感知机


更多文章、技术交流、商务合作、联系博主

微信扫码或搜索:z360901061

微信扫一扫加我为好友

QQ号联系: 360901061

您的支持是博主写作最大的动力,如果您喜欢我的文章,感觉我的文章对您有帮助,请用微信扫描下面二维码支持博主2元、5元、10元、20元等您想捐的金额吧,狠狠点击下面给点支持吧,站长非常感激您!手机微信长按不能支付解决办法:请将微信支付二维码保存到相册,切换到微信,然后点击微信右上角扫一扫功能,选择支付二维码完成支付。

【本文对您有帮助就好】

您的支持是博主写作最大的动力,如果您喜欢我的文章,感觉我的文章对您有帮助,请用微信扫描上面二维码支持博主2元、5元、10元、自定义金额等您想捐的金额吧,站长会非常 感谢您的哦!!!

发表我的评论
最新评论 总共0条评论