GPUでfitしてCPUでpredictすることでTensorFlowのOOMを回避する

TensorFlowで大量の画像や巨大な配列をpredictするとメモリーリークが生じることが多い。

model.fitでは問題ないのに、model.predictになるとOOMになってしまう報告が散見する。

model.predict leads to oom issues, but model.fit does not:
my training model is a modified version of U-Net with the goal being to add a filter to a 512x1808 image. the issue I'm ...
301 Moved Permanently
out of memory when using model.predict() · Issue #5337 · keras-team/keras
hi~ I am now using keras to build my network. the training is normal. but when I use the model.predict() to predict the ...

根本的な解決ではないが、predictを明示的にCPUで行うようにしたところOOMは起きなくなった。

TensorFlowでは、tf.deviceを使うことで処理を行うデバイスを指定することができる。

tf.device  |  TensorFlow v2.16.1
Specifies the device for ops created/executed in this context.

そのため、model.fitはGPUで、model.predictはCPUで行うことが可能となる。

以下のように処理を行うことで、OOMを回避することができた。

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    model = ...
    model.fit(...)

with tf.device("/cpu:0"):
    pred = model.predict(...)

コメント

タイトルとURLをコピーしました