随机梯度下降法(Stochastic gradient descent, SGD)+python实现! |
文章目录
- 一、设定样本
- 二、梯度下降法原理
- 三、BGD:批量梯度下降法
- 3.1、python代码实现:
- 3.2、损失函数值曲线
- 3.3、BGD总结
- 四、SGD:随机梯度下降法
- 4.1、python代码实现:
- 4.2、损失函数值曲线
- 4.3、SGD总结
- 五、MBGD:小批量梯度下降
- 5.1、python代码实现:
- 5.2、损失函数值曲线
- 5.3、MBGD总结
- 参考文章
一、设定样本
假设我们提供了这样的数据样本(样本值取自于 y = 3 x 1 + 4 x 2 y=3x_{1}+4x_{2} y = 3 x 1 + 4 x 2 ):其中: x 1 x_{1} x 1 和 x 2 x_{2} x 2 是样本值, y y y 是预测目标。
x 1 x_{1} x 1 | x 2 x_{2} x 2 | y y y |
---|---|---|
1 | 4 | 19 |
2 | 5 | 26 |
5 | 1 | 19 |
4 | 2 | 29 |
我们需要以一条直线来拟合上面的数据,待拟合的函数如下:
(1) h ( Θ ) = Θ 1 x 1 + Θ 2 x 2 h(\Theta)=\Theta_{1} x_{1}+\Theta_{2} x_{2}\tag{1}
h
(
Θ
)
=
Θ
1
x
1
+
Θ
2
x
2
(
1
)
我们的目的就是要求出
Θ 1 \Theta_{1}
Θ
1
和
Θ 2 \Theta_{2}
Θ
2
的值,让
h ( Θ ) h(\Theta)
h
(
Θ
)
尽量逼近目标值
y y
y
。这是一个线性回归问题,若对线性回归有所了解的话我们知道:利用
最小二乘法则和梯度下降法
可以求出两个参数,而深度学习也同样可以利用这两种方法求得所有的网络参数,因此,在这里用这个数学模型来解释BGD、SGD、MSGD这几个概念。
二、梯度下降法原理
我们首先确定损失函数如下(均方误差):
(2) J ( Θ ) = 1 2 m ∑ i = 1 m [ h Θ ( x i ) − y i ] 2 J(\Theta)=\frac{1}{2 m} \sum_{i=1}^{m}\left[h_{\Theta}\left(x^{i}\right)-y^{i}\right]^{2}\tag{2}
J
(
Θ
)
=
2
m
1
i
=
1
∑
m
[
h
Θ
(
x
i
)
−
y
i
]
2
(
2
)
其中:
J ( Θ ) J(\Theta)
J
(
Θ
)
是损失函数,
m m
m
代表每次取多少样本进行训练,如果采用
S G D SGD
S
G
D
进行训练,那每次随机取一个样本
m = 1 m=1
m
=
1
;如果是批处理,则
m m
m
等于每次抽取作为训练样本的数量。
Θ \Theta
Θ
是参数,对应式
( 1 ) (1)
(
1
)
的
Θ 1 \Theta_{1}
Θ
1
和
Θ 2 \Theta_{2}
Θ
2
。求出了
Θ 1 \Theta_{1}
Θ
1
和
Θ 2 \Theta_{2}
Θ
2
,
h ( Θ ) h(\Theta)
h
(
Θ
)
的表达式就出来了:
(3) h ( Θ ) = ∑ Θ j x j = Θ 1 x 1 + Θ 2 x 2 h(\Theta)=\sum \Theta_{j} x_{j}=\Theta_{1} x_{1}+\Theta_{2} x_{2}\tag{3}
h
(
Θ
)
=
∑
Θ
j
x
j
=
Θ
1
x
1
+
Θ
2
x
2
(
3
)
我们的目标是让损失函数
J ( Θ ) J(\Theta)
J
(
Θ
)
的值最小,根据梯度下降法,首先要用
J ( Θ ) J(\Theta)
J
(
Θ
)
对
Θ \Theta
Θ
求偏导:
(4) σ J ( Θ ) σ Θ j = 2 1 2 m ∑ i = 1 m [ h Θ ( x i ) − y i ] x j i = 1 m ∑ i = 1 m [ h Θ ( x i ) − y i ] x j i \frac{\sigma J(\Theta)}{\sigma \Theta_{j}}=2 \frac{1}{2 m} \sum_{i=1}^{m}\left[h_{\Theta}\left(x^{i}\right)-y^{i}\right] x_{j}^{i}=\frac{1}{m} \sum_{i=1}^{m}\left[h_{\Theta}\left(x^{i}\right)-y^{i}\right] x_{j}^{i}\tag{4}
σ
Θ
j
σ
J
(
Θ
)
=
2
2
m
1
i
=
1
∑
m
[
h
Θ
(
x
i
)
−
y
i
]
x
j
i
=
m
1
i
=
1
∑
m
[
h
Θ
(
x
i
)
−
y
i
]
x
j
i
(
4
)
由于是要最小化损失函数,所以参数
Θ \Theta
Θ
按其负梯度方向来更新:其中
α \alpha
α
是学习率;
(5) Θ j ′ = Θ j − α σ J ( Θ ) σ Θ j = Θ j − α 1 m ∑ m i = 1 ( y i − h Θ ( x i ) ) x j i \Theta_{j}^{\prime}=\Theta_{j}-\alpha\frac{\sigma J(\Theta)}{\sigma \Theta_{j}}=\Theta_{j}-\alpha \frac{1}{m} \sum_{m}^{i=1}\left(y^{i}-h_{\Theta}\left(x^{i}\right)\right) x_{j}^{i}\tag{5}
Θ
j
′
=
Θ
j
−
α
σ
Θ
j
σ
J
(
Θ
)
=
Θ
j
−
α
m
1
m
∑
i
=
1
(
y
i
−
h
Θ
(
x
i
)
)
x
j
i
(
5
)
三、BGD:批量梯度下降法
BGD(Batch gradient descent)批量梯度下降法:每次迭代使用所有的样本!
- BGD(批量梯度下降):更新每一参数都用所有样本更新,m=all,更新100次遍历所有数据100次
- 优点 :每次迭代都需要把所有样本都送入,这样的好处是每次迭代都顾及了全部的样本,能保证做的是全局最优化。
- 缺点 :由于这种方法是在一次更新中,就对整个数据集计算梯度,所以计算起来非常慢,遇到很大量的数据集也会非常棘手,而且不能投入新数据实时更新模型。
3.1、python代码实现:
import random
import matplotlib.pyplot as plt
#用y = Θ1*x1 + Θ2*x2来拟合下面的输入和输出
#input1 1 2 5 4
#input2 4 5 1 2
#output 19 26 19 20
input_x = [[1, 4], [2, 5], [5, 1], [4, 2]]
output_y = [19, 26, 19, 20]
theta = [1, 1] #θ参数初始化
loss = 10 #loss先定义一个数,为了进入循环迭代
lr = 0.01 #学习率(步长)
eps =0.0001 #精度要求
max_iters = 10000 #最大迭代次数
error = 0 #损失值
iter_count = 0 #当前迭代次数
err1=[0,0,0,0] #求Θ1梯度的中间变量1
err2=[0,0,0,0] #求Θ2梯度的中间变量2
loss_curve = []
iter_curve= []
while loss > eps and iter_count < max_iters: #迭代条件
loss = 0
err1sum = 0
err2sum = 0
for i in range(4): #每次迭代所有的样本都进行训练
pred_y = theta[0] * input_x[i][0] + theta[1] * input_x[i][1] # 预测值
err1[i] = (pred_y - output_y[i]) * input_x[i][0]
err1sum += err1[i]
err2[i] = (pred_y - output_y[i]) * input_x[i][1]
err2sum += err2[i]
theta[0] = theta[0] - lr * err1sum / 4 # 对应公式(5)
theta[1] = theta[1] - lr * err2sum / 4 # 对应公式(5)
for i in range(4):
pred_y = theta[0] * input_x[i][0] + theta[1] * input_x[i][1] # 预测值
error = (1 / (2 * 4)) * (pred_y - output_y[i])**2 #损失值
loss = loss + error #总损失值
loss_curve.append(loss)
iter_curve.append(iter_count)
iter_count += 1
print('iter_count:', iter_count, 'loss:', loss)
print('final theta:', theta)
print('final loss:', loss)
print('final iter_count:', iter_count)
plt.plot(iter_curve, loss_curve, linewidth=3.0, label = ' loss value ')
plt.xlabel('iter_count')
plt.ylabel('loss value')
plt.legend(loc='upper right')
plt.show()
- 运行结果:
C:\Anaconda3\envs\tf2\python.exe E:/Codes/MyCodes/TF2/TF2_6/bgd.py
iter_count: 1 loss: 77.30604843750001
iter_count: 2 loss: 51.92155212969726
iter_count: 3 loss: 34.93305894023124
iter_count: 4 loss: 23.55788142744176
iter_count: 5 loss: 15.93612681814619
iter_count: 6 loss: 10.82463082981111
iter_count: 7 loss: 7.392409964576455
iter_count: 8 loss: 5.083959396333441
iter_count: 9 loss: 3.527892640305925
iter_count: 10 loss: 2.47588355411119
iter_count: 11 loss: 1.7618596986778925
iter_count: 12 loss: 1.2747299205213376
iter_count: 13 loss: 0.9401570394404473
iter_count: 14 loss: 0.7083755280604223
iter_count: 15 loss: 0.5460491800945114
iter_count: 16 loss: 0.4308288035166698
iter_count: 17 loss: 0.3477144254831318
iter_count: 18 loss: 0.28662352563217885
iter_count: 19 loss: 0.2407653218442937
iter_count: 20 loss: 0.20555379684871375
iter_count: 21 loss: 0.1778808149288243
iter_count: 22 loss: 0.1556299310282455
iter_count: 23 loss: 0.13735109414029084
iter_count: 24 loss: 0.12204291187159358
iter_count: 25 loss: 0.10900683012823309
iter_count: 26 loss: 0.09774940254077526
iter_count: 27 loss: 0.08791672434030241
iter_count: 28 loss: 0.07925038532524892
iter_count: 29 loss: 0.07155782538929152
iter_count: 30 loss: 0.06469233462112152
iter_count: 31 loss: 0.058539516396843426
iter_count: 32 loss: 0.053008085575962184
iter_count: 33 loss: 0.04802357824982824
iter_count: 34 loss: 0.04352402034121233
iter_count: 35 loss: 0.03945691714912546
iter_count: 36 loss: 0.03577713642499401
iter_count: 37 loss: 0.03244539834116959
iter_count: 38 loss: 0.029427179885821626
iter_count: 39 loss: 0.026691904238495462
iter_count: 40 loss: 0.024212327873428634
iter_count: 41 loss: 0.021964066404403914
iter_count: 42 loss: 0.019925219138275187
iter_count: 43 loss: 0.01807606502771293
iter_count: 44 loss: 0.01639881126834203
iter_count: 45 loss: 0.014877381549280188
iter_count: 46 loss: 0.013497234860471212
iter_count: 47 loss: 0.012245208401310822
iter_count: 48 loss: 0.011109379935017922
iter_count: 49 loss: 0.010078946167772749
iter_count: 50 loss: 0.009144114585438616
iter_count: 51 loss: 0.008296006777324977
iter_count: 52 loss: 0.007526571698840795
iter_count: 53 loss: 0.0068285076286013985
iter_count: 54 loss: 0.006195191797999068
iter_count: 55 loss: 0.005620616837548782
iter_count: 56 loss: 0.005099333311490803
iter_count: 57 loss: 0.004626397711632085
iter_count: 58 loss: 0.0041973253611029835
iter_count: 59 loss: 0.003808047743918568
iter_count: 60 loss: 0.003454873830664867
iter_count: 61 loss: 0.003134455016855624
iter_count: 62 loss: 0.0028437533303267305
iter_count: 63 loss: 0.002580012598749813
iter_count: 64 loss: 0.002340732298900542
iter_count: 65 loss: 0.0021236438364046354
iter_count: 66 loss: 0.0019266890288379284
iter_count: 67 loss: 0.0017480005866892083
iter_count: 68 loss: 0.001585884406131675
iter_count: 69 loss: 0.0014388035050594725
iter_count: 70 loss: 0.001305363449643687
iter_count: 71 loss: 0.0011842991329443218
iter_count: 72 loss: 0.0010744627800306976
iter_count: 73 loss: 0.0009748130657567269
iter_count: 74 loss: 0.0008844052419326704
iter_count: 75 loss: 0.0008023821802307486
iter_count: 76 loss: 0.0007279662458677689
iter_count: 77 loss: 0.000660451924992406
iter_count: 78 loss: 0.0005991991358649235
iter_count: 79 loss: 0.0005436271604005447
iter_count: 80 loss: 0.0004932091385371781
iter_count: 81 loss: 0.0004474670732236192
iter_count: 82 loss: 0.0004059672986700165
iter_count: 83 loss: 0.0003683163688928273
iter_count: 84 loss: 0.0003341573275743529
iter_count: 85 loss: 0.00030316632387097236
iter_count: 86 loss: 0.0002750495420852944
iter_count: 87 loss: 0.0002495404160923516
iter_count: 88 loss: 0.00022639710211123758
iter_count: 89 loss: 0.00020540018586152943
iter_count: 90 loss: 0.00018635060236680122
iter_count: 91 loss: 0.0001690677486837031
iter_count: 92 loss: 0.0001533877716635301
iter_count: 93 loss: 0.000139162014513147
iter_count: 94 loss: 0.00012625560742820846
iter_count: 95 loss: 0.00011454618893581613
iter_count: 96 loss: 0.00010392274582505133
iter_count: 97 loss: 9.428456066652548e-05
final theta: [3.0044552563214433, 3.9955447274498894]
final loss: 9.428456066652548e-05
final iter_count: 97
3.2、损失函数值曲线
- 可以发现下降趋势比较平稳!
3.3、BGD总结
这里我们只有4个样本,所以训练的时间不长。但是,如果面对数量巨大的样本量(如40万个),采取这种训练方式,所耗费的时间会非常长。Batch gradient descent 对于凸函数可以收敛到全局极小值,对于非凸函数可以收敛到局部极小值。
四、SGD:随机梯度下降法
SGD(Stochastic gradientdescent)随机梯度下降法:每次迭代使用一个样本!
- 针对BGD算法训练速度过慢的缺点,提出了SGD算法,普通的BGD算法是每次迭代把所有样本都过一遍,每训练一组样本就把梯度更新一次。而SGD算法是从样本中随机抽出一组,训练后按梯度更新一次,然后再抽取一组,再更新一次,在样本量及其大的情况下,可能不用训练完所有的样本就可以获得一个损失值在可接受范围之内的模型了。
- SGD(随机梯度下降):更新每一参数都随机选择一个样本更新, m = 1 m=1 m = 1 。
4.1、python代码实现:
- 代码如下:
import random
#用y = Θ1*x1 + Θ2*x2来拟合下面的输入和输出
#input1 1 2 5 4
#input2 4 5 1 2
#output 19 26 19 20
input_x = [[1, 4], [2, 5], [5, 1], [4, 2]]
output_y = [19, 26, 19, 20]
theta = [1, 1] #θ参数初始化
loss = 10 #loss先定义一个数,为了进入循环迭代
lr = 0.01 #学习率(步长)
eps =0.0001 #精度要求
max_iters = 10000 #最大迭代次数
error = 0 #损失值
iter_count = 0 #当前迭代次数
while loss > eps and iter_count < max_iters: #迭代条件
loss = 0
# 0、1、2、3随意一个,包括3
i = random.randint(0, 3) #每次迭代在input_x中随机选取一组样本进行权重的更新
pred_y = theta[0] * input_x[i][0] + theta[1] * input_x[i][1] # 预测值
theta[0] = theta[0] - lr * (pred_y - output_y[i]) * input_x[i][0]
theta[1] = theta[1] - lr * (pred_y - output_y[i]) * input_x[i][1]
for i in range(4):
pred_y = theta[0] * input_x[i][0] + theta[1] * input_x[i][1] # 预测值
error = 0.5 * (pred_y - output_y[i])**2 #损失值
loss = loss + error #总损失值
iter_count += 1
print('iter_count:', iter_count, 'loss:', loss)
print('final theta:', theta)
print('final loss:', loss)
print('final iter_count:', iter_count)
- 运行结果:
C:\Anaconda3\envs\tf2\python.exe E:/Codes/MyCodes/TF2/TF2_6/sgd.py
iter_count: 1 loss: 260.8896
iter_count: 2 loss: 204.61415423999998
iter_count: 3 loss: 157.22159618457601
iter_count: 4 loss: 124.88134971162623
iter_count: 5 loss: 111.84245738309359
iter_count: 6 loss: 103.09215224055406
iter_count: 7 loss: 88.76429814860212
iter_count: 8 loss: 60.029919860818374
iter_count: 9 loss: 59.32261643585773
iter_count: 10 loss: 58.80735498266485
iter_count: 11 loss: 53.54691203311109
iter_count: 12 loss: 49.70774185517555
iter_count: 13 loss: 27.482152507395202
iter_count: 14 loss: 29.053473973655233
iter_count: 15 loss: 17.241137945844635
iter_count: 16 loss: 18.27205019986861
iter_count: 17 loss: 18.279734003553486
iter_count: 18 loss: 13.291935986982754
iter_count: 19 loss: 14.080979216135848
iter_count: 20 loss: 14.91676557815299
iter_count: 21 loss: 10.592300543806674
iter_count: 22 loss: 11.235546322899923
iter_count: 23 loss: 11.836768171833583
iter_count: 24 loss: 11.485334037115793
iter_count: 25 loss: 6.810177856305382
iter_count: 26 loss: 4.6959930466488675
iter_count: 27 loss: 3.9105771662118256
iter_count: 28 loss: 3.444454636311172
iter_count: 29 loss: 3.0123913259123145
iter_count: 30 loss: 2.5679800277626557
iter_count: 31 loss: 2.3099925227435065
iter_count: 32 loss: 1.9590397523779994
iter_count: 33 loss: 1.9666361109780515
iter_count: 34 loss: 1.957111930033281
iter_count: 35 loss: 2.051533423257222
iter_count: 36 loss: 1.5311103292311716
iter_count: 37 loss: 1.6083965712408976
iter_count: 38 loss: 1.195994912123778
iter_count: 39 loss: 1.2294713235041943
iter_count: 40 loss: 0.9457370295093082
iter_count: 41 loss: 0.7388360036353944
iter_count: 42 loss: 0.7067997228867577
iter_count: 43 loss: 0.7302176054470183
iter_count: 44 loss: 0.7733523408616976
iter_count: 45 loss: 0.8194031381390366
iter_count: 46 loss: 0.4936013810958017
iter_count: 47 loss: 0.37792875276479576
iter_count: 48 loss: 0.3917742838122076
iter_count: 49 loss: 0.27063075933550573
iter_count: 50 loss: 0.21948139835860636
iter_count: 51 loss: 0.200739559185535
iter_count: 52 loss: 0.16909904922322447
iter_count: 53 loss: 0.14881528969674812
iter_count: 54 loss: 0.139454946004233
iter_count: 55 loss: 0.13791962027556026
iter_count: 56 loss: 0.13940575127970778
iter_count: 57 loss: 0.11836942697279389
iter_count: 58 loss: 0.11781711553567996
iter_count: 59 loss: 0.1193375232933355
iter_count: 60 loss: 0.12138112869067261
iter_count: 61 loss: 0.12332118692472657
iter_count: 62 loss: 0.12467078283377861
iter_count: 63 loss: 0.07858332982329701
iter_count: 64 loss: 0.07002100640810523
iter_count: 65 loss: 0.06684354978608165
iter_count: 66 loss: 0.06681954120255698
iter_count: 67 loss: 0.05813640578322409
iter_count: 68 loss: 0.04822861760481913
iter_count: 69 loss: 0.04278262637099636
iter_count: 70 loss: 0.041198736725976785
iter_count: 71 loss: 0.04122491647644537
iter_count: 72 loss: 0.041360454874579164
iter_count: 73 loss: 0.028430758340382445
iter_count: 74 loss: 0.02673444676702969
iter_count: 75 loss: 0.024607291352450787
iter_count: 76 loss: 0.023584031391060585
iter_count: 77 loss: 0.018509550739981902
iter_count: 78 loss: 0.017020377213216822
iter_count: 79 loss: 0.01639800972323979
iter_count: 80 loss: 0.016217979857842655
iter_count: 81 loss: 0.01173575386765937
iter_count: 82 loss: 0.011095127889579639
iter_count: 83 loss: 0.011025658424848761
iter_count: 84 loss: 0.011162231122708137
iter_count: 85 loss: 0.011281435021340219
iter_count: 86 loss: 0.007484744689449223
iter_count: 87 loss: 0.007443897686619277
iter_count: 88 loss: 0.00753805887938067
iter_count: 89 loss: 0.005266781977291559
iter_count: 90 loss: 0.005050442824531053
iter_count: 91 loss: 0.005061511925216842
iter_count: 92 loss: 0.005082620784454073
iter_count: 93 loss: 0.004233203826419785
iter_count: 94 loss: 0.0033139392411621234
iter_count: 95 loss: 0.0032913233079854493
iter_count: 96 loss: 0.0028481973056087625
iter_count: 97 loss: 0.002706631981691858
iter_count: 98 loss: 0.002492232603268663
iter_count: 99 loss: 0.002221730694417584
iter_count: 100 loss: 0.0018638132079459395
iter_count: 101 loss: 0.0018362099891547448
iter_count: 102 loss: 0.0016114266187654506
iter_count: 103 loss: 0.001557626575305546
iter_count: 104 loss: 0.0014751564038771846
iter_count: 105 loss: 0.0015177552847051115
iter_count: 106 loss: 0.0015527624369436643
iter_count: 107 loss: 0.0010676134405512782
iter_count: 108 loss: 0.0008875592185810756
iter_count: 109 loss: 0.0008989428657822404
iter_count: 110 loss: 0.0009062642743213555
iter_count: 111 loss: 0.0006607681707515327
iter_count: 112 loss: 0.0005677737130591256
iter_count: 113 loss: 0.0005661179132734488
iter_count: 114 loss: 0.0004906413380605921
iter_count: 115 loss: 0.0004485281635358341
iter_count: 116 loss: 0.0004277189872613536
iter_count: 117 loss: 0.0004272993211319688
iter_count: 118 loss: 0.00042831950236645467
iter_count: 119 loss: 0.0004337135792476141
iter_count: 120 loss: 0.0004413233162548959
iter_count: 121 loss: 0.0003468706133226126
iter_count: 122 loss: 0.0003488193444937763
iter_count: 123 loss: 0.00028928141012031104
iter_count: 124 loss: 0.00028583399647533843
iter_count: 125 loss: 0.0002901585029896304
iter_count: 126 loss: 0.0002939322848970706
iter_count: 127 loss: 0.00023730769207386048
iter_count: 128 loss: 0.0001779396656655188
iter_count: 129 loss: 0.00015978049098605478
iter_count: 130 loss: 0.00013564931442016493
iter_count: 131 loss: 0.0001361042816188513
iter_count: 132 loss: 0.0001153612600831637
iter_count: 133 loss: 0.0001132019901158441
iter_count: 134 loss: 0.00010928002673695943
iter_count: 135 loss: 0.00011264583765529378
iter_count: 136 loss: 8.89461228685356e-05
final theta: [3.002053708602476, 3.997626634178193]
final loss: 8.89461228685356e-05
final iter_count: 136
4.2、损失函数值曲线
- 可以发现下降趋势优点波折,没有刚才那样平稳(这里效果不是特别明显)!
4.3、SGD总结
随机梯度下降是通过每个样本来迭代更新一次,如果样本量很大的情况,那么可能只用其中部分的样本,就已经将theta迭代到最优解了,对比上面的批量梯度下降,迭代一次需要用到十几万训练样本,一次迭代不可能最优,如果迭代10次的话就需要遍历训练样本10次。缺点是SGD的噪音较BGD要多,使得SGD并不是每次迭代都向着整体最优化方向。所以虽然训练速度快,但是准确度下降,并不是全局最优。虽然包含一定的随机性,但是从期望上来看,它是等于正确的导数的。
缺点:
-
SGD 因为更新比较频繁,会造成 cost function 有严重的震荡。
-
BGD 可以收敛到局部极小值,当然 SGD 的震荡可能会跳到更好的局部极小值处。
-
当我们稍微减小 learning rate,SGD 和 BGD 的收敛性是一样的。
五、MBGD:小批量梯度下降
MBGD(Mini-batch gradient descent)小批量梯度下降:每次迭代使用b组样本!
- SGD相对来说要快很多,但是也有存在问题,由于单个样本的训练可能会带来很多噪声, 使得SGD并不是每次迭代都向着整体最优化方向,因此在刚开始训练时可能收敛得很快,但是训练一段时间后就会变得很慢。在此基础上又提出了 小批量梯度下降法 ,它是每次从样本中随机抽取一小批进行训练,而不是一组。
5.1、python代码实现:
import random
import matplotlib.pyplot as plt
#用y = Θ1*x1 + Θ2*x2来拟合下面的输入和输出
#input1 1 2 5 4
#input2 4 5 1 2
#output 19 26 19 20
input_x = [[1, 4], [2, 5], [5, 1], [4, 2]]
output_y = [19, 26, 19, 20]
theta = [1, 1] #θ参数初始化
loss = 10 #loss先定义一个数,为了进入循环迭代
lr = 0.01 #学习率(步长)
eps =0.0001 #精度要求
max_iters = 10000 #最大迭代次数
error = 0 #损失值
iter_count = 0 #当前迭代次数
loss_curve = []
iter_curve= []
while loss > eps and iter_count < max_iters: #迭代条件
loss = 0
# 这里每次批量选取的是2个样本进行更新,另一个点是随机点+1的相邻点
# 0、1、2、3随意一个,包括3
i = random.randint(0, 3) # 随机抽取一组样本
j = (i + 1) % 4 # 抽取另一组样本,j=i+1
pred_y0 = theta[0] * input_x[i][0] + theta[1] * input_x[i][1] # 预测值
pred_y1 = theta[0] * input_x[j][0] + theta[1] * input_x[j][1] # 预测值
theta[0] = theta[0] - lr * (1 / 2) * ((pred_y0 - output_y[i]) * input_x[i][0] + (pred_y1 - output_y[j]) * input_x[j][0])# 对应5式
theta[1] = theta[1] - lr * (1 / 2) * ((pred_y0 - output_y[i]) * input_x[i][1] + (pred_y1 - output_y[j]) * input_x[j][1])
for i in range(4):
pred_y = theta[0] * input_x[i][0] + theta[1] * input_x[i][1] # 预测值
error = (1/(2*2)) * (pred_y - output_y[i])**2 #损失值
loss = loss + error #总损失值
loss_curve.append(loss)
iter_curve.append(iter_count)
iter_count += 1
print('iter_count:', iter_count, 'loss:', loss)
print('final theta:', theta)
print('final loss:', loss)
print('final iter_count:', iter_count)
plt.plot(iter_curve, loss_curve, linewidth=3.0, label = ' SGD loss value ')
plt.xlabel('iter_count')
plt.ylabel('loss value')
plt.legend(loc='upper right')
plt.show()
- 运行结果:
C:\Anaconda3\envs\tf2\python.exe E:/Codes/MyCodes/TF2/TF2_6/MBGD.py
iter_count: 1 loss: 145.51273750000001
iter_count: 2 loss: 93.61212678531253
iter_count: 3 loss: 63.84553483319602
iter_count: 4 loss: 40.78188950700523
iter_count: 5 loss: 26.635296256279755
iter_count: 6 loss: 19.14453037224417
iter_count: 7 loss: 11.863074835686199
iter_count: 8 loss: 7.351282634603321
iter_count: 9 loss: 5.285122186891444
iter_count: 10 loss: 3.5150322864765666
iter_count: 11 loss: 2.5222674128094504
iter_count: 12 loss: 1.726271885060386
iter_count: 13 loss: 1.0782688839189605
iter_count: 14 loss: 0.7734545491420018
iter_count: 15 loss: 0.4847190164962192
iter_count: 16 loss: 0.30268382725301546
iter_count: 17 loss: 0.21710714208249052
iter_count: 18 loss: 0.15636548583181112
iter_count: 19 loss: 0.0951813297472256
iter_count: 20 loss: 0.06014949607215332
iter_count: 21 loss: 0.04555522775364396
iter_count: 22 loss: 0.029448135848294105
iter_count: 23 loss: 0.019291337709467643
iter_count: 24 loss: 0.012850369417312057
iter_count: 25 loss: 0.00873525233733187
iter_count: 26 loss: 0.005002257350216796
iter_count: 27 loss: 0.004675521915666117
iter_count: 28 loss: 0.0034010908555124775
iter_count: 29 loss: 0.003314771376617083
iter_count: 30 loss: 0.002003304782844371
iter_count: 31 loss: 0.002035716129121544
iter_count: 32 loss: 0.001600873625308816
iter_count: 33 loss: 0.0016385823787498255
iter_count: 34 loss: 0.001668187043426684
iter_count: 35 loss: 0.0010776974297530592
iter_count: 36 loss: 0.0009406763734116385
iter_count: 37 loss: 0.0007948887717540301
iter_count: 38 loss: 0.0007110442660078077
iter_count: 39 loss: 0.0007262540103407296
iter_count: 40 loss: 0.0007492868334222326
iter_count: 41 loss: 0.0007731591748400462
iter_count: 42 loss: 0.0005294763719741898
iter_count: 43 loss: 0.0005440182877813707
iter_count: 44 loss: 0.00045517250370294903
iter_count: 45 loss: 0.00040394528620568096
iter_count: 46 loss: 0.0003627723578104964
iter_count: 47 loss: 0.00036950515296137464
iter_count: 48 loss: 0.0003127529335059308
iter_count: 49 loss: 0.0002801383258107215
iter_count: 50 loss: 0.00024161953198902414
iter_count: 51 loss: 0.0002452223039598993
iter_count: 52 loss: 0.0001869776452880176
iter_count: 53 loss: 0.00018574447211436057
iter_count: 54 loss: 0.00015989716877477456
iter_count: 55 loss: 0.00013838733780820178
iter_count: 56 loss: 0.00012025272881104103
iter_count: 57 loss: 0.00012094637362213334
iter_count: 58 loss: 0.00012421504197512097
iter_count: 59 loss: 9.101480560449152e-05
final theta: [3.0025974709772676, 3.996606737971915]
final loss: 9.101480560449152e-05
final iter_count: 59
5.2、损失函数值曲线
5.3、MBGD总结
MBGD(小批量梯度下降):更新每一参数都选 m m m 个样本平均梯度更新, 1 < m < a l l 1<m<all 1 < m < a l l ;超参数 m m m 设定值: m m m 一般取值在 50 ~ 256 50~256 5 0 ~ 2 5 6
- MBGD 每一次利用一小批样本,即 m 个样本进行计算, 这样它可以降低参数更新时的方差,收敛更稳定 ,另一方面可以充分地 利用深度学习库中高度优化的矩阵操作来进行更有效的梯度计算。
缺点:(两大缺点)
- 不过 Mini-batch gradient descent 不能保证很好的收敛性,learning rate 如果选择的太小,收敛速度会很慢,如果太大,loss function 就会在极小值处不停地震荡甚至偏离。(有一种措施是先设定大一点的学习率,当两次迭代之间的变化低于某个阈值后,就减小 learning rate,不过这个阈值的设定需要提前写好,这样的话就不能够适应数据集的特点。) 对于非凸函数, 还要避免陷于局部极小值处,或者鞍点处, 因为鞍点周围的error是一样的,所有维度的梯度都接近于0,SGD 很容易被困在这里。( 会在鞍点或者局部最小点震荡跳动,因为在此点处,如果是训练集全集带入即BGD,则优化会停止不动,如果是mini-batch或者SGD,每次找到的梯度都是不同的,就会发生震荡,来回跳动。 )
- SGD对所有参数更新时应用同样的 learning rate,如果我们的数据是稀疏的, 我们更希望对出现频率低的特征进行大一点的更新。LR会随着更新的次数逐渐变小。
鞍点就是:一个光滑函数的鞍点邻域的曲线,曲面,或超曲面,都位于这点的切线的不同边。例如这个二维图形,像个马鞍:在x-轴方向往上曲,在y-轴方向往下曲,鞍点就是(0,0)。
参考文章
参考了一下作者的文章,在这里表示感谢!
- https://blog.csdn.net/kwame211/article/details/80364079
- https://www.cnblogs.com/guoyaohua/p/8542554.html