pickle を使った、学習済みモデルの保存・読み出し方法

Posted on 2019/10/26 in 機械学習 , Updated on: 2019/10/26

はじめに

機械学習ライブラリで学習したモデルの保存をすることで、学習済みモデルの使い回しが可能となる。サンプル数の多いデータにおいて、KFold で各Fold ごとに学習したモデルを保存 -> 削除、テストデータ推論時に、あらためて保存したモデルを読み出すというような流れを実行することでメモリの節約ができる。今回は、Python 標準ライブラリであるに pickle を使用して、LightGBM での学習済みモデルの保存・読み出し方法紹介する。

pickle は、オブジェクトをシリアライズ化(バイト列などの表現に変換)、またはデシリアライズ(バイト列から元のオブジェクトに復元)を実行してくれる。

学習モデルを作成

今回は、例として scikit-learn から breast_cancer データを用いて分類問題モデルを作成する。

In [4]:
# load libraries
import pickle

import lightgbm as lgb
from sklearn import datasets

# breast_cancer データの読み出し
cancer = datasets.load_breast_cancer()

# train データを準備
X_train, y_train = cancer.data, cancer.target

データが準備できたので、LightGBM の分類器を作成し、train データで学習する。

In [5]:
clf = lgb.LGBMClassifier()
clf.fit(X_train, y_train)
Out[5]:
LGBMClassifier(boosting_type='gbdt', class_weight=None, colsample_bytree=1.0,
        importance_type='split', learning_rate=0.1, max_depth=-1,
        min_child_samples=20, min_child_weight=0.001, min_split_gain=0.0,
        n_estimators=100, n_jobs=-1, num_leaves=31, objective=None,
        random_state=None, reg_alpha=0.0, reg_lambda=0.0, silent=True,
        subsample=1.0, subsample_for_bin=200000, subsample_freq=0)

これで、clf には、パラメータ(今回は特に何も設定しいないのでデフォルトパラメータ)および、上記のデータで学習した学習済みの情報が含まれている。これを pickle モジュールを用いて、保存する。

モデルを保存する

学習済みモデル clfpickle.dump で保存する。下記例では、カレントディレクトリに trained_model.pkl ファイルが保存される。保存したファイルを削除する。

In [6]:
file = 'trained_model.pkl'
pickle.dump(clf, open(file, 'wb'))

# 学習済みモデルを削除
del clf

モデルを読み出す

同じく pickle.load で保存したモデル読み出す。clf として読み出したモデルには、先ほどの分類器モデルの情報が含まれていることがわかる。

In [7]:
clf = pickle.load(open('trained_model.pkl', 'rb'))
    
clf
Out[7]:
LGBMClassifier(boosting_type='gbdt', class_weight=None, colsample_bytree=1.0,
        importance_type='split', learning_rate=0.1, max_depth=-1,
        min_child_samples=20, min_child_weight=0.001, min_split_gain=0.0,
        n_estimators=100, n_jobs=-1, num_leaves=31, objective=None,
        random_state=None, reg_alpha=0.0, reg_lambda=0.0, silent=True,
        subsample=1.0, subsample_for_bin=200000, subsample_freq=0)

もちろん下記のように、この学習済みモデルを未知のデータへの予測に使用することができる。

In [ ]:
clf.predict(X_test)