KerasでカスタムAUC、カスタムAccuracyを実装する

10クラスの画像分類をするモデルを構築していた時、精度改善のために拡張して11クラス分類のモデルを構築した。

しかし、評価は元々の10クラスで行いたかったので、custom metricsのクラスを実装した。

例えばカスタムAUCだと以下の通り。

class auc(tf.keras.metrics.AUC):
    def __init__(self, **kwargs):
        super(auc, self).__init__(**kwargs)
        self.auc = tf.keras.metrics.AUC(multi_label=True)

    def update_state(self, y_true, y_pred, sample_weight=None):
        self.auc.update_state(y_true[:, :10], y_pred[:, :10])

    def result(self):
        return self.auc.result()

    def reset_states(self):
        self.auc.reset_states()

親クラスとして使いたいmetricsを継承し、initで変数として宣言する。

update_state、result、reset_statusをオーバーライドするだけでカスタムAUCの完成。

学習データでは11クラスになっているけど、update_stateにわたす際に(y_true[:, :10], y_pred[:, :10])としてオリジナルの10クラスで評価するようにした。

バッチx出力になるので、二次元配列として渡す必要がある。

間違えてy_true[:10]にしないよう注意する。

参考

tf.keras.Metric  |  TensorFlow v2.16.1
Encapsulates metric logic and state.
Attention Required! | Cloudflare

コメント

タイトルとURLをコピーしました