85
京東網上商城
純幹貨 | 機器學習中梯度下降法的分類及對比分析(附源碼)



α
線性回歸
其中, 是參數,
是輸入特征。為了求解線性回歸模型,需要找到合適的參數使擬合函數能夠更好地適合模型,然後使用梯度下降最小化代價函數J(θ)。
代價函數:
下麵的偽代碼能夠解釋其詳細原理:
1. 初始化參數值
2. 迭代更新這些參數使目標函數J(θ)不斷變小。
使用整個數據集()去計算代價函數的梯度批量梯度下降法會很慢

3. 然後重複上麵每一步;
4. 這意味著需要較長的時間才能收斂;
批量梯度下降法不適合大數據集。下麵的Python代碼實現了批量梯度下降法:
1. import numpy as np
2. import random
3. def gradient_descent(alpha, x, y, ep=0.0001, max_iter=10000):
4. converged = False
5. iter = 0
6. m = x.shape[0] # number of samples
7.
8. # initial theta
9. t0 = np.random.random(x.shape[1])
10. t1 = np.random.random(x.shape[1])
11.
12. # total error, J(theta)
13. J = sum([(t0 + t1*x[i] - y[i])**2 for i in range(m)])
14.
15. # Iterate Loop
16. while not converged:
17. # for each training sample, compute the gradient (d/d_theta j(theta))
18. grad0 = 1.0/m * sum([(t0 + t1*x[i] - y[i]) for i in range(m)])
19. grad1 = 1.0/m * sum([(t0 + t1*x[i] - y[i])*x[i] for i in range(m)])
20. # update the theta_temp
21. temp0 = t0 - alpha * grad0
22. temp1 = t1 - alpha * grad1
23.
24. # update theta
25. t0 = temp0
26. t1 = temp1
27.
28. # mean squared error
29. e = sum( [ (t0 + t1*x[i] - y[i])**2 for i in range(m)] )
30.
31. if abs(J-e) <= ep:
32. print 'Converged, iterations: ', iter, '!!!'
33. converged = True
34.
35. J = e # update error
36. iter += 1 # update iter
37.
38. if iter == max_iter:
39. print 'Max interactions exceeded!'
40. converged = True
41.
42. return t0,t1
批量梯度下降法被證明是一個較慢的算法,所以,我們可以選擇隨機梯度下降法達到更快的計算。隨機梯度下降法的第一步是隨機化整個數據集。在每次迭代僅選擇一個訓練樣本去計算代價函數的梯度,然後更新參數。即使是大規模數據集,隨機梯度下降法也會很快收斂。隨機梯度下降法得到結果的準確性可能不會是最好的,但是計算結果的速度很快。在隨機化初始參數之後,使用如下方法計算代價函數的梯度:

如下為隨機梯度下降法的偽碼:
如下圖所示,隨機梯度下降法不像批量梯度下降法那樣收斂,而是遊走到接近全局最小值的區域終止。

小批量梯度下降法不是使用完整數據集,在每次迭代中僅使用m個訓練樣本去計算代價函數的梯度

這裏b表示一批訓練樣本的個數,m是訓練樣本的總數。
1. 實現該算法時,同時更新參數。
2. 學習速率α(也稱之為步長)。如果α過大,算法可能不會收斂;如果α比較小,就會很容易收斂。
3. 檢查梯度下降法的工作過程。畫出迭代次數與每次迭代後代價函數值的關係圖,這能夠幫助你了解梯度下降法是否取得了好的效果。每次迭代後J(θ)應該降低,多次迭代後應該趨於收斂。
4. 不同的學習速率在梯度下降法中的效果
本文詳細介紹了不同類型的梯度下降法。這些算法已經被廣泛應用於神經網絡。下麵的圖詳細展示了3種梯度下降法的比較。
以上為譯文
本文由北郵@愛可可-愛生活 老師推薦,阿裏雲雲棲社區組織翻譯。
文章原標題《3 Types of Gradient Descent Algorithms for Small & Large Data Sets》,由HackerEarth blog發布。
譯者:李烽 ;審校:
文章為簡譯,更為詳細的內容,請查看原文。中文譯製文檔下載見此。
最後更新:2017-04-09 18:03:54