GPUでfitしてCPUでpredictすることでTensorFlowのOOMを回避する

TensorFlowで大量の画像や巨大な配列をpredictするとメモリーリークが生じることが多い。

model.fitでは問題ないのに、model.predictになるとOOMになってしまう報告が散見する。

Attention Required! | Cloudflare
OOM when calling model.predict()
Hi everyone, I’m having an issue with model.predict() causing OOM errors. Strangely, this doesn’t happen while training,...
out of memory when using model.predict() · Issue #5337 · keras-team/keras
hi~ I am now using keras to build my network. the training is normal. but when I use the model.predict() to predict the ...

根本的な解決ではないが、predictを明示的にCPUで行うようにしたところOOMは起きなくなった。

TensorFlowでは、tf.deviceを使うことで処理を行うデバイスを指定することができる。

tf.device  |  TensorFlow v2.16.1
Specifies the device for ops created/executed in this context.

そのため、model.fitはGPUで、model.predictはCPUで行うことが可能となる。

以下のように処理を行うことで、OOMを回避することができた。

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    model = ...
    model.fit(...)

with tf.device("/cpu:0"):
    pred = model.predict(...)

コメント

タイトルとURLをコピーしました