(李航统计学习方法)感知机Python实现

系统 1530 0

机器学习的三要素: 模型,策略,算法
模型:感知机是二分类 线性分类模型 ,属于 判别模型
策略:基于误分类点到超平面的总距离。
学习算法:略
感知机存在的问题:

  1. 存在多解,解依赖于初始超平面的选择以及迭代过程中误分类点的选择。
  2. 训练集线性不可分,算法无法收敛,解决方法:pocket算法或者使用核函数。
  3. 无法解决异或问题

Python代码实现:

            
              import numpy as np
def train(X_train,Y_train):
    print(np.shape(X_train))
    m,n=np.shape(X_train)
    w=np.zeros((n,1))
    b=0
    while True:
        count=m
        for i in range(m):
            result=Y_train[i]*(np.dot(X_train[i],w)+b)
            if result<=0:
                count-=1
                for j in range(n):
                    w[j]=w[j]+X_train[i][j]*Y_train[i]
                b=b+Y_train[i]
                print("w:",w)
                print("b:",b)
                break
        if count==m:
            break
    return  w,b
def predict(w,b,X_test):
    y_=np.dot(X_test,w)+b
    return np.where(y_>1,1,-1)
def main():
    X_train=np.array(([3,3],[4,3],[1,1]))
    Y_train=np.array(([1,1,-1]))
    w,b=train(X_train,Y_train)
    X_test=np.array(([2,3],[-15,6],[1,4]))
    print(predict(w,b,X_test))
if __name__=='__main__':
    main()

            
          

更多文章、技术交流、商务合作、联系博主

微信扫码或搜索:z360901061

微信扫一扫加我为好友

QQ号联系: 360901061

您的支持是博主写作最大的动力,如果您喜欢我的文章,感觉我的文章对您有帮助,请用微信扫描下面二维码支持博主2元、5元、10元、20元等您想捐的金额吧,狠狠点击下面给点支持吧,站长非常感激您!手机微信长按不能支付解决办法:请将微信支付二维码保存到相册,切换到微信,然后点击微信右上角扫一扫功能,选择支付二维码完成支付。

【本文对您有帮助就好】

您的支持是博主写作最大的动力,如果您喜欢我的文章,感觉我的文章对您有帮助,请用微信扫描上面二维码支持博主2元、5元、10元、自定义金额等您想捐的金额吧,站长会非常 感谢您的哦!!!

发表我的评论
最新评论 总共0条评论