10クラスの画像分類をするモデルを構築していた時、精度改善のために拡張して11クラス分類のモデルを構築した。
しかし、評価は元々の10クラスで行いたかったので、custom metricsのクラスを実装した。
例えばカスタムAUCだと以下の通り。
class auc(tf.keras.metrics.AUC):
def __init__(self, **kwargs):
super(auc, self).__init__(**kwargs)
self.auc = tf.keras.metrics.AUC(multi_label=True)
def update_state(self, y_true, y_pred, sample_weight=None):
self.auc.update_state(y_true[:, :10], y_pred[:, :10])
def result(self):
return self.auc.result()
def reset_states(self):
self.auc.reset_states()
親クラスとして使いたいmetricsを継承し、initで変数として宣言する。
update_state、result、reset_statusをオーバーライドするだけでカスタムAUCの完成。
学習データでは11クラスになっているけど、update_stateにわたす際に(y_true[:, :10], y_pred[:, :10])としてオリジナルの10クラスで評価するようにした。
バッチx出力になるので、二次元配列として渡す必要がある。
間違えてy_true[:10]にしないよう注意する。
参考
tf.keras.Metric | TensorFlow v2.16.1
Encapsulates metric logic and state.
Attention Required! | Cloudflare
コメント