深度學習小技巧(二):如何保存和恢複scikit-learn訓練的模型
更多深度文章,請關注雲計算頻道:https://yq.aliyun.com/cloud
深度學習小技巧(一):如何保存和恢複TensorFlow訓練的模型
在許多情況下,在使用scikit學習庫的同時,你需要將預測模型保存到文件中,然後在使用它們的時候還原它們,以便重複使用以前的工作。比如在新數據上測試模型,比較多個模型的優劣。這種保存過程也稱為對象序列化——表示具有字節流的對象,以便將其存儲在磁盤上,它可以通過網絡發送或保存到數據庫,而其恢複的過程被稱為反序列化。在本文中,我們將在Python和scikit學習中看到三種可能的方法,而且每種都有其優點和缺點。
1.保存和恢複模型的工具
我們第一個介紹的工具是Pickle,用於對象(de)序列化的標準Python工具。之後,我們會介紹Joblib庫,它提供了容易(de)序列化方法,其中包含了大數據數組的對象,最後我們會介紹一種手動方法來保存和恢複JSON對象(JavaScript Object Notation)。這些方法都不能代表最佳解決方案,但是可以根據項目的需要選擇合適的方案。
2.模型初始化
首先,我們要創建一個scikit學習模型。在我們的例子中,我們將使用Logistic回歸模型和Iris數據集。我們導入所需的庫,並且加載數據,並將其拆分為訓練集和測試集。
from sklearn.linear_model import LogisticRegression
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
# Load and split data
data = load_iris()
Xtrain, Xtest, Ytrain, Ytest = train_test_split(data.data, data.target, test_size=0.3, random_state=4)
現在讓我們用一些非默認參數來創建模型,並用訓練數據來“喂養”它。我們假設你先前已經找到了模型的最優參數,即產生最高估計精度的參數。
# Create a model
model = LogisticRegression(C=0.1,
max_iter=20,
fit_intercept=True,
n_jobs=3,
solver='liblinear')
model.fit(Xtrain, Ytrain)
這是我們產生的模型:
LogisticRegression(C=0.1, class_weight=None, dual=False, fit_intercept=True,
intercept_scaling=1, max_iter=20, multi_class='ovr', n_jobs=3,
penalty='l2', random_state=None, solver='liblinear', tol=0.0001,
verbose=0, warm_start=False)
使用該fit
方法,模型已經學習了存儲在其中的係數model.coef_
。目標是將模型的參數和係數保存到文件中,因此你不需要再次對新數據重複模型訓練和參數優化的步驟。
3.Pickle模塊
在以下幾行代碼中,我們將上一步中創建的模型保存到文件中,然後作為一個新對象加載pickled_model
。然後使用加載的模型計算準確度分數,並對新的未見(測試)數據進行預測結果。
import pickle
#
# Create your model here (same as above)
#
# Save to file in the current working directory
pkl_filename = "pickle_model.pkl"
with open(pkl_filename, 'wb') as file:
pickle.dump(model, file)
# Load from file
with open(pkl_filename, 'rb') as file:
pickle_model = pickle.load(file)
# Calculate the accuracy score and predict target values
score = pickle_model.score(Xtest, Ytest)
print("Test score: {0:.2f} %".format(100 * score))
Ypredict = pickle_model.predict(Xtest)
運行此代碼應該會產生你的預測分數,並通過Pickle保存模型:
$ python save_model_pickle.py
Test score: 91.11 %
使用Pickle來保存和恢複學習模型的好處在於它很快,並且你可以用兩行代碼完成。如果你已經對訓練數據上的模型參數進行了優化,那麼這是非常有用的,因此你不需要重複此步驟。不管如何,它都不保存測試結果和任何數據。但仍然可以保存多個對象的元組或列表(並記住哪個對象在哪裏),如下所示:
tuple_objects = (model, Xtrain, Ytrain, score)
# Save tuple
pickle.dump(tuple_objects, open("tuple_model.pkl", 'wb'))
# Restore tuple
pickled_model, pickled_Xtrain, pickled_Ytrain, pickled_score = pickle.load(open("tuple_model.pkl", 'rb'))
3.Joblib模塊
Joblib庫它的目的是替代Pickle,用於包含大數據的對象。我們將重複與Pickle一樣的保存和恢複過程。
from sklearn.externals import joblib
# Save to file in the current working directory
joblib_file = "joblib_model.pkl"
joblib.dump(model, joblib_file)
# Load from file
joblib_model = joblib.load(joblib_file)
# Calculate the accuracy and predictions
score = joblib_model.score(Xtest, Ytest)
print("Test score: {0:.2f} %".format(100 * score))
Ypredict = pickle_model.predict(Xtest)
$ python save_model_joblib.py
Test score: 91.11 %
從示例中可以看出,與Pickle相比,Joblib庫提供了一個簡單的工作流程。雖然Pickle要求將文件對象作為參數傳遞,但是Joblib可與文件對象和字符串文件名一起使用。如果你的模型包含大量數據,則每個數組將存儲在單獨的文件中,但整體的保存和恢複過程將保持不變。Joblib還允許使用不同的壓縮方法,如“zlib”,“gzip”,“bz2”和不同的壓縮級別。
4.手動保存並還原到JSON
根據你的項目,很多時候你會發現Pickle和Joblib都不是合適的解決方案。其中一些原因將在兼容性問題部分中稍後討論。無論何時要想完全控製保存和恢複過程,最好的方法是手動構建自己的功能。
以下顯示了使用JSON手動保存和恢複對象的示例。這種方法允許我們選擇需要保存的數據,例如模型參數,係數,訓練數據以及我們需要的任何其他數據。
由於我們想將所有這些數據保存在一個對象中,所以一個可能的方法是創建一個繼承我們的示例中的模型類的新類LogisticRegression
。這個新類被MyLogReg
調用,然後分別實現save_json
和load_json
的
方法以保存和恢複JSON文件。
為簡單起見,我們將隻保存三個模型參數和訓練數據。我們可以用這種方法存儲一些額外的數據,例如訓練集上的交叉驗證分數,測試數據,測試數據的準確度等等。
import json
import numpy as np
class MyLogReg(LogisticRegression):
# Override the class constructor
def __init__(self, C=1.0, solver='liblinear', max_iter=100, X_train=None, Y_train=None):
LogisticRegression.__init__(self, C=C, solver=solver, max_iter=max_iter)
self.X_train = X_train
self.Y_train = Y_train
# A method for saving object data to JSON file
def save_json(self, filepath):
dict_ = {}
dict_['C'] = self.C
dict_['max_iter'] = self.max_iter
dict_['solver'] = self.solver
dict_['X_train'] = self.X_train.tolist() if self.X_train is not None else 'None'
dict_['Y_train'] = self.Y_train.tolist() if self.Y_train is not None else 'None'
# Creat json and save to file
json_txt = json.dumps(dict_, indent=4)
with open(filepath, 'w') as file:
file.write(json_txt)
# A method for loading data from JSON file
def load_json(self, filepath):
with open(filepath, 'r') as file:
dict_ = json.load(file)
self.C = dict_['C']
self.max_iter = dict_['max_iter']
self.solver = dict_['solver']
self.X_train = np.asarray(dict_['X_train']) if dict_['X_train'] != 'None' else None
self.Y_train = np.asarray(dict_['Y_train']) if dict_['Y_train'] != 'None' else None
現在我們來試一試MyLogReg
。首先我們創建一個對象mylogreg
,將訓練數據傳遞給它,並將其保存到文件中。然後我們創建一個新對象json_mylogreg
,
並調用該load_json
方法從文件加載數據。
filepath = "mylogreg.json"
# Create a model and train it
mylogreg = MyLogReg(X_train=Xtrain, Y_train=Ytrain)
mylogreg.save_json(filepath)
# Create a new object and load its data from JSON file
json_mylogreg = MyLogReg()
json_mylogreg.load_json(filepath)
json_mylogreg
打印出新的對象,我們可以根據需要來查看我們的參數和訓練數據。
MyLogReg(C=1.0,
X_train=array([[ 4.3, 3. , 1.1, 0.1],
[ 5.7, 4.4, 1.5, 0.4],
...,
[ 7.2, 3. , 5.8, 1.6],
[ 7.7, 2.8, 6.7, 2. ]]),
Y_train=array([0, 0, ..., 2, 2]), class_weight=None, dual=False,
fit_intercept=True, intercept_scaling=1, max_iter=100,
multi_class='ovr', n_jobs=1, penalty='l2', random_state=None,
solver='liblinear', tol=0.0001, verbose=0, warm_start=False)
由於使用JSON的數據序列化實際上是將對象保存為字符串格式,而不是字節流,所以'mylogreg.json'文件可以使用文本編輯器打開和修改。雖然這種方法對開發人員來說很方便,但是由於入侵者可以查看和修改JSON文件的內容,因此安全性較低。此外,這種方法更適合於具有少量實例變量的對象,例如scikit-learn模型,因為任何添加新變量都需要在保存和恢複方法中進行更改。
5.兼容性問題
盡管到目前為止,每個工具的優點和缺點已被介紹,但Pickle和Joblib工具的最大缺點可能是其與不同型號的Python版本的兼容性。
5.1:Python版本的兼容性——兩種工具的文檔都指出,不建議(de)在不同的Python版本之間對對象進行序列化,盡管它可能在低級的版本更改中起作用。
5.2:模型兼容性——最常見的錯誤之一是使用Pickle或Joblib保存模型,然後在嚐試從文件還原之前更改模型。模型的內部結構需要在保存和重新加載之間保持不變。
Pickle和Joblib的最後一個問題與安全性有關。這兩種工具都可能包含惡意代碼,因此不建議從不受信任或未經身份驗證的源代碼。
6.結論
在這篇文章中,我們描述了三種保存和恢複scikit學習模型的工具。Pickle和Joblib庫可以快速方便地使用,但是在不同的Python版本和學習模型的變化中存在兼容性問題。另一方麵,手動方法更難實現,需要在模型結構發生任何變化中進行修改,但在另一方麵,它可以輕鬆地適應各種需求,並且沒有任何兼容性問題。
作者信息
作者:Mihajlo Pavloski,數據科學與機器學習的愛好者,博士生。
本文由阿裏雲雲棲社區組織翻譯。
文章原標題《TensorFlow : Save and Restore Models》
作者:Mihajlo Pavloski 譯者:虎說八道,審閱:
文章為簡譯,更為詳細的內容,請查看原文
最後更新:2017-10-30 16:34:08