Kerasのmodel.fit()で出力されるlossやmetricをloggingする

kerasでmodel.fit()すると、ターミナルに出力されるエポック毎のlossやmetricをPythonでloggingしたい。

Epoch 00001: LearningRateScheduler setting learning rate to 0.0001.
Epoch 1/100
7/7 [==============================] - 1s 28ms/step - loss: 9.5452 - accuracy: 0.1074 - val_loss: 8.9494 - val_accuracy: 0.1171 - lr: 1.0000e-04

Epoch 00002: LearningRateScheduler setting learning rate to 0.0005999500000000001.
Epoch 2/100
7/7 [==============================] - 0s 5ms/step - loss: 8.0551 - accuracy: 0.1624 - val_loss: 6.5261 - val_accuracy: 0.2293 - lr: 5.9995e-04

Epoch 00003: LearningRateScheduler setting learning rate to 0.0010999000000000002.
Epoch 3/100
7/7 [==============================] - 0s 5ms/step - loss: 5.9735 - accuracy: 0.2918 - val_loss: 4.3467 - val_accuracy: 0.3659 - lr: 0.0011

tf.keras.callbacks.CSVLoggerというCallbackがあるが、csvとして別に保存するのではなく、logファイル一つで確認したい。

tf.keras.callbacks.CSVLogger  |  TensorFlow v2.16.1
Callback that streams epoch results to a CSV file.

model.fit()の引数を見てみる。

tf.keras.Model  |  TensorFlow v2.16.1
A model grouping layers into an object with training/inference features.
fit(
    x=None,
    y=None,
    batch_size=None,
    epochs=1,
    verbose='auto',
    callbacks=None,
    validation_split=0.0,
    validation_data=None,
    shuffle=True,
    class_weight=None,
    sample_weight=None,
    initial_epoch=0,
    steps_per_epoch=None,
    validation_steps=None,
    validation_batch_size=None,
    validation_freq=1,
    max_queue_size=10,
    workers=1,
    use_multiprocessing=False
)

model.summary()と違って、出力先をコントロール出来そうなパラメータは無さそう。

そのため、カスタムコールバックを書いてloggingする必要がある。

Writing your own callbacks  |  TensorFlow Core

今回はlossやmetricを取得したいので、エポックの終わりだけ処理すればいい。

on_epoch_endをオーバーライドする。

エポックレベルのメソッド(トレーニングのみ)

on_epoch_end(self, epoch, logs=None)
トレーニング中、エポックの最後に呼び出されます。

パラメータはlogsに格納されている。

logs ディクショナリを使用する


logs ディクショナリは、バッチまたはエポックの最後の損失値と全てのメトリクスを含みます。次の例は、損失値と平均絶対誤差を含んでいます。

loggerはグローバルで宣言しておき、カスタムコールバックを書いていく。

logsは辞書なので、key, valueをループで回して全ての値を取得する。

一番長い名前の文字数をmax_lenとし、loggerに記録する時に右寄せして見た目を良くする。

import logging

logger = logging.getLogger(__name__)


class TrainingLogger(tf.keras.callbacks.Callback):
    def __init__(self):
        super(TrainingLogger, self).__init__()

    def on_epoch_end(self, epoch, logs: dict = None):
        logger.info(f'Epoch {(epoch + 1):05d}')

        max_len = max([len(key) for key in logs.keys()])
        for key, value in logs.items():
            message = f'{key:>{max_len + 1}}: {value:.5f}'
            logger.info(message)

model.fit()してみる。

Epoch 00001: LearningRateScheduler setting learning rate to 0.0001.
Epoch 1/100

2022-03-31 17:21:19,128 - Epoch 00001
2022-03-31 17:21:19,128 -          loss: 11.54617
2022-03-31 17:21:19,129 -      accuracy: 0.11233
2022-03-31 17:21:19,129 -      val_loss: 11.33803
2022-03-31 17:21:19,130 -  val_accuracy: 0.06829
2022-03-31 17:21:19,131 -            lr: 0.00010

7/7 [==============================] - 1s 47ms/step - loss: 11.5462 - accuracy: 0.1123 - val_loss: 11.3380 - val_accuracy: 0.0683 - lr: 1.0000e-04


Epoch 00002: LearningRateScheduler setting learning rate to 0.0005999500000000001.
Epoch 2/100

2022-03-31 17:21:19,164 - Epoch 00002
2022-03-31 17:21:19,164 -          loss: 10.15541
2022-03-31 17:21:19,165 -      accuracy: 0.13065
2022-03-31 17:21:19,166 -      val_loss: 10.06604
2022-03-31 17:21:19,166 -  val_accuracy: 0.13171
2022-03-31 17:21:19,167 -            lr: 0.00060

無事にターミナルに出力された。

ログファイルを見てみる。

2022-03-31 17:21:19,128 - Epoch 00001
2022-03-31 17:21:19,128 -          loss: 11.54617
2022-03-31 17:21:19,129 -      accuracy: 0.11233
2022-03-31 17:21:19,129 -      val_loss: 11.33803
2022-03-31 17:21:19,130 -  val_accuracy: 0.06829
2022-03-31 17:21:19,131 -            lr: 0.00010
2022-03-31 17:21:19,164 - Epoch 00002
2022-03-31 17:21:19,164 -          loss: 10.15541
2022-03-31 17:21:19,165 -      accuracy: 0.13065
2022-03-31 17:21:19,166 -      val_loss: 10.06604
2022-03-31 17:21:19,166 -  val_accuracy: 0.13171
2022-03-31 17:21:19,167 -            lr: 0.00060

プログレスバーは記録されなかったが、カスタムコールバックに書いたことは保存されていた。

これで、model.fit()のlossやmetricをloggingすることが出来た。

コメント

タイトルとURLをコピーしました