Kerasのmodel.fitが返すhistoryをpandasで保存して図のplotまで

model.fit() や model.fit_generator() はコールバックのHistoryを返す。

これを保存しておくとエポック毎のAccuracyやLossの推移が見れて面白いので、生データの保存と可視化を行いたい。

# 生データの保存
history = model.fit(...)
hist_df = pd.DataFrame(history.history)
hist_df.to_csv('history.csv')

# 可視化
plt.figure()
hist_df[['accuracy', 'val_accuracy']].plot()
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.savefig(os.path.join(self.result_dir, 'acc.png'))
plt.close()

plt.figure()
hist_df[['loss', 'val_loss']].plot()
plt.ylabel('loss')
plt.xlabel('epoch')
plt.savefig(os.path.join(self.result_dir, 'loss.png'))
plt.close()

matplotlibで色々しなくても、pandasから直接plod出来るので楽ちん。

pandasでplotする時は二次元配列でラベルを指定することに注意。

参考

FAQ - Keras Documentation
pandasのplotメソッドでグラフを作成しデータを可視化 | note.nkmk.me
pandas.Series, pandas.DataFrameのメソッドとしてplot()がある。Pythonのグラフ描画ライブラリMatplotlibのラッパーで、簡単にグラフを作成できる。pandas.DataFrame.plot — pandas 0.22.0 documentation Visualizatio...
pandas.DataFrame.plot — pandas 1.0.4 documentation

コメント

タイトルとURLをコピーしました