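While using Python's logging module, I wanted to save the output of Keras's model.summary() to the log.
However, passing it directly as the logger's argument just logged None, and the summary itself was not recorded.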
import logging

logger = logging.getLogger(__name__)
# model: an already-built tf.keras.Model
logger.info(model.summary())
> 2022-03-31 05:19:17,741 - None
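The None is expected: model.summary() writes the summary to standard output itself and returns None, so logger.info() receives None as its message. The sketch below reproduces this without TensorFlow. fake_summary is a hypothetical stand-in that, like model.summary(), prints its lines and returns nothing, and the in-memory handler is only there to show what actually gets logged.

```python
import io
import logging


def fake_summary(print_fn=print):
    """Hypothetical stand-in for model.summary(): prints each line, returns None."""
    for line in ['Model: "model"', "Total params: 385"]:
        print_fn(line)


# Send log records to an in-memory buffer so we can see exactly what was logged.
log_stream = io.StringIO()
handler = logging.StreamHandler(log_stream)
handler.setFormatter(logging.Formatter("%(message)s"))

logger = logging.getLogger("summary_demo")
logger.setLevel(logging.INFO)
logger.addHandler(handler)
logger.propagate = False

result = fake_summary()  # the summary lines go to stdout, not to the logger
logger.info(result)      # result is None, so the logged message is "None"
```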
Let's check the arguments of model.summary().
The tf.keras.Model documentation (TensorFlow v2.16.1) gives the signature:
summary(
    line_length=None,
    positions=None,
    print_fn=None,
    expand_nested=False,
    show_trainable=False
)
There is an argument called print_fn.
print_fn
Print function to use. Defaults to
Let's look at the source code.
def summary(self,
            line_length=None,
            positions=None,
            print_fn=None,
            expand_nested=False,
            show_trainable=False):
    """Prints a string summary of the network.

    Args:
        line_length: Total length of printed lines
            (e.g. set this to adapt the display to different
            terminal window sizes).
        positions: Relative or absolute positions of log elements
            in each line. If not provided,
            defaults to `[.33, .55, .67, 1.]`.
        print_fn: Print function to use. Defaults to `print`.
            It will be called on each line of the summary.
            You can set it to a custom function
            in order to capture the string summary.
        expand_nested: Whether to expand the nested models.
            If not provided, defaults to `False`.
        show_trainable: Whether to show if a layer is trainable.
            If not provided, defaults to `False`.

    Raises:
        ValueError: if `summary()` is called before the model is built.
    """
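Going by the docstring, print_fn is called once per summary line, which is what makes it possible to capture the summary as a string. A minimal sketch of that idea, again using a hypothetical fake_summary in place of a real model: the lines are collected with list.append and joined afterwards. With a real tf.keras model the same callback would be passed as model.summary(print_fn=lines.append), assuming the per-line behavior described in the docstring above.

```python
def fake_summary(print_fn=print):
    """Hypothetical stand-in for model.summary(): calls print_fn once per line."""
    for line in ['Model: "model"', "Total params: 385", "Trainable params: 385"]:
        print_fn(line)


lines = []
fake_summary(print_fn=lines.append)  # each line is appended instead of printed
summary_text = "\n".join(lines)      # the whole summary as one string
```

Calling logger.info(summary_text) once after joining would store the whole summary as a single log record, with one timestamp instead of one per line.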
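print_fn takes a function that is called on each line of the summary, so passing a function that forwards each line to the logger should work.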
model.summary(print_fn=lambda x: logger.info(x))
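Since print_fn is called with each line as its only argument, logger.info can also be passed directly instead of being wrapped in a lambda. The sketch below additionally writes the lines to a file, which was the original goal. The format string "%(asctime)s - %(message)s" is an assumption inferred from the shape of the log output (the original does not show its logging configuration), the file name train.log is hypothetical, and fake_summary stands in for model.summary().

```python
import logging
import os
import tempfile


def fake_summary(print_fn=print):
    """Hypothetical stand-in for model.summary(): calls print_fn once per line."""
    for line in ['Model: "model"', "Total params: 385"]:
        print_fn(line)


# Assumed format: timestamp, " - ", message (same shape as the log output).
log_path = os.path.join(tempfile.mkdtemp(), "train.log")  # hypothetical file name
handler = logging.FileHandler(log_path, encoding="utf-8")
handler.setFormatter(logging.Formatter("%(asctime)s - %(message)s"))

logger = logging.getLogger("train")
logger.setLevel(logging.INFO)
logger.addHandler(handler)
logger.propagate = False

# Equivalent to print_fn=lambda x: logger.info(x); with a real model:
# model.summary(print_fn=logger.info)
fake_summary(print_fn=logger.info)

# Detach and close the handler so the file is flushed before reading it back.
logger.removeHandler(handler)
handler.close()

with open(log_path, encoding="utf-8") as f:
    logged = f.read().splitlines()
```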
2022-03-31 06:49:57,082 - Model: "model"
2022-03-31 06:49:57,082 - __________________________________________________________________________________________________
2022-03-31 06:49:57,082 - Layer (type) Output Shape Param # Connected to
2022-03-31 06:49:57,082 - ==================================================================================================
2022-03-31 06:49:57,082 - input_ids (InputLayer) [(None, 136)] 0 []
2022-03-31 06:49:57,082 -
2022-03-31 06:49:57,082 - attention_mask (InputLayer) [(None, 136)] 0 []
2022-03-31 06:49:57,082 -
2022-03-31 06:49:57,087 - tf_deberta_v2_model (TFDeberta TFBaseModelOutput(l 70682112 ['input_ids[0][0]',
2022-03-31 06:49:57,087 - V2Model) ast_hidden_state=(N 'attention_mask[0][0]']
2022-03-31 06:49:57,087 - one, 136, 384),
2022-03-31 06:49:57,088 - hidden_states=None
2022-03-31 06:49:57,088 - , attentions=None)
2022-03-31 06:49:57,088 -
2022-03-31 06:49:57,088 - global_average_pooling1d (Glob (None, 384) 0 ['tf_deberta_v2_model[0][0]']
2022-03-31 06:49:57,088 - alAveragePooling1D)
2022-03-31 06:49:57,088 -
2022-03-31 06:49:57,088 - output (Dense) (None, 1) 385 ['global_average_pooling1d[0][0]'
2022-03-31 06:49:57,088 - ]
2022-03-31 06:49:57,088 -
2022-03-31 06:49:57,088 - ==================================================================================================
2022-03-31 06:49:57,093 - Total params: 70,682,497
2022-03-31 06:49:57,093 - Trainable params: 70,682,497
2022-03-31 06:49:57,093 - Non-trainable params: 0
2022-03-31 06:49:57,093 - __________________________________________________________________________________________________
Now the output of model.summary() is written to the log.