Logging Keras model.summary()

While using Python's logging module, I wanted to save the output of Keras's model.summary() to the log.

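However, passing it straight to the logger just recorded None instead of the summary. That is because model.summary() prints the summary itself and returns None.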

import logging

logger = logging.getLogger(__name__)

logger.info(model.summary())  # summary() prints the table itself and returns None

> 2022-03-31 05:19:17,741 - None
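The same failure can be reproduced without TensorFlow installed. In this minimal sketch, fake_summary is a hypothetical stand-in, not a Keras API: like model.summary() it prints its lines and returns None, so the only thing the logger receives is None.

```python
import io
import logging


# Hypothetical stand-in for model.summary(): it prints each line and returns None.
def fake_summary(print_fn=None):
    print_fn = print_fn or print
    for line in ['Model: "demo"', "layer table (placeholder)"]:
        print_fn(line)


# Send log records to an in-memory stream so they can be inspected.
stream = io.StringIO()
handler = logging.StreamHandler(stream)
handler.setFormatter(logging.Formatter("%(levelname)s - %(message)s"))

logger = logging.getLogger("summary_demo")
logger.setLevel(logging.INFO)
logger.addHandler(handler)
logger.propagate = False  # keep records away from the root logger's handlers

result = fake_summary()  # the summary lines go to stdout...
logger.info(result)      # ...and only the return value, None, is logged
```

After this runs, the in-memory log holds the single line "INFO - None", mirroring the output above.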

Let's check the arguments of model.summary().

The tf.keras.Model API reference (TensorFlow v2.16.1) gives the following signature:
summary(
    line_length=None,
    positions=None,
    print_fn=None,
    expand_nested=False,
    show_trainable=False
)

It has an argument called print_fn.

print_fn

Print function to use. Defaults to print. It will be called on each line of the summary. You can set it to a custom function in order to capture the string summary.
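As the description says, print_fn can also be used to capture the summary as a string. Here is a minimal sketch, assuming print_fn is called with one line at a time as the docstring describes. fake_summary is again a hypothetical stand-in for model.summary() so that the example runs without TensorFlow; with a real model the call would be model.summary(print_fn=lines.append).

```python
# Hypothetical stand-in for model.summary(): it calls print_fn once per line.
def fake_summary(print_fn=None):
    print_fn = print_fn or print
    for line in ['Model: "demo"', "layer table (placeholder)", "Total params (placeholder)"]:
        print_fn(line)


lines = []
fake_summary(print_fn=lines.append)  # each summary line is appended to the list
summary_text = "\n".join(lines)      # the whole summary as one multi-line string
```

Passing summary_text to a single logger.info call keeps the whole table in one log record, instead of one record (and one timestamp) per line.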

Let's also look at the source code.

  def summary(self,
              line_length=None,
              positions=None,
              print_fn=None,
              expand_nested=False,
              show_trainable=False):
    """Prints a string summary of the network.
    Args:
        line_length: Total length of printed lines
            (e.g. set this to adapt the display to different
            terminal window sizes).
        positions: Relative or absolute positions of log elements
            in each line. If not provided,
            defaults to `[.33, .55, .67, 1.]`.
        print_fn: Print function to use. Defaults to `print`.
            It will be called on each line of the summary.
            You can set it to a custom function
            in order to capture the string summary.
        expand_nested: Whether to expand the nested models.
            If not provided, defaults to `False`.
        show_trainable: Whether to show if a layer is trainable.
            If not provided, defaults to `False`.
    Raises:
        ValueError: if `summary()` is called before the model is built.
    """

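print_fn takes a function that is called with each line of the summary, so passing a function that forwards each line to the logger should do the trick.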

model.summary(print_fn=lambda x: logger.info(x))


2022-03-31 06:49:57,082 - Model: "model"
2022-03-31 06:49:57,082 - __________________________________________________________________________________________________
2022-03-31 06:49:57,082 -  Layer (type)                   Output Shape         Param #     Connected to
2022-03-31 06:49:57,082 - ==================================================================================================
2022-03-31 06:49:57,082 -  input_ids (InputLayer)         [(None, 136)]        0           []
2022-03-31 06:49:57,082 -
2022-03-31 06:49:57,082 -  attention_mask (InputLayer)    [(None, 136)]        0           []
2022-03-31 06:49:57,082 -
2022-03-31 06:49:57,087 -  tf_deberta_v2_model (TFDeberta  TFBaseModelOutput(l  70682112   ['input_ids[0][0]',
2022-03-31 06:49:57,087 -  V2Model)                       ast_hidden_state=(N               'attention_mask[0][0]']
2022-03-31 06:49:57,087 -                                 one, 136, 384),
2022-03-31 06:49:57,088 -                                  hidden_states=None
2022-03-31 06:49:57,088 -                                 , attentions=None)
2022-03-31 06:49:57,088 -
2022-03-31 06:49:57,088 -  global_average_pooling1d (Glob  (None, 384)         0           ['tf_deberta_v2_model[0][0]']
2022-03-31 06:49:57,088 -  alAveragePooling1D)
2022-03-31 06:49:57,088 -
2022-03-31 06:49:57,088 -  output (Dense)                 (None, 1)            385         ['global_average_pooling1d[0][0]'
2022-03-31 06:49:57,088 -                                                                  ]
2022-03-31 06:49:57,088 -
2022-03-31 06:49:57,088 - ==================================================================================================
2022-03-31 06:49:57,093 - Total params: 70,682,497
2022-03-31 06:49:57,093 - Trainable params: 70,682,497
2022-03-31 06:49:57,093 - Non-trainable params: 0
2022-03-31 06:49:57,093 - __________________________________________________________________________________________________

Now the output of model.summary() is recorded in the log.
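To make this reusable and actually persist the summary to a file, here is a minimal sketch. log_summary, FakeModel and train.log are hypothetical names for illustration, not part of Keras or the original code. The helper passes functools.partial(logger.log, level) as print_fn, which behaves like the lambda above but lets the caller pick the log level. The "%(asctime)s - %(message)s" format is an assumption based on the timestamps in the log output above.

```python
import functools
import logging
import os
import tempfile


def log_summary(model, logger, level=logging.INFO):
    """Forward every line of model.summary() to logger at the given level.

    Assumes model.summary accepts a print_fn callable and calls it with one
    line of text at a time, as the tf.keras docstring above describes.
    """
    model.summary(print_fn=functools.partial(logger.log, level))


class FakeModel:
    """Hypothetical stand-in for a Keras model so the sketch runs without TensorFlow."""

    def summary(self, print_fn=None):
        print_fn = print_fn or print
        for line in ['Model: "demo"', "Total params (placeholder)"]:
            print_fn(line)


# Hypothetical log file; the format string is inferred from the output above.
log_path = os.path.join(tempfile.mkdtemp(), "train.log")
file_handler = logging.FileHandler(log_path, encoding="utf-8")
file_handler.setFormatter(logging.Formatter("%(asctime)s - %(message)s"))

logger = logging.getLogger("train")
logger.setLevel(logging.INFO)
logger.addHandler(file_handler)
logger.propagate = False  # keep records away from the root logger's handlers

log_summary(FakeModel(), logger)  # with Keras: log_summary(model, logger)

# Detach and close the handler so everything is flushed to disk.
logger.removeHandler(file_handler)
file_handler.close()

with open(log_path, encoding="utf-8") as f:
    saved = f.read().splitlines()
```

Each summary line becomes its own log record, so every line in the file gets its own timestamp, just as in the output above.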
