GPUでfitしてCPUでpredictすることでTensorFlowのOOMを回避する

TensorFlowで大量の画像や巨大な配列をpredictするとメモリーリークが生じることが多い。

model.fitでは問題ないのに、model.predictになるとOOMになってしまう報告が散見する。

model.predict leads to oom issues, but model.fit does not:
my training model is a modified version of U-Net with the goal being to add a filter to a 512x1808 image. the issue I'm having is that I keep getting out of mem...
OOM when calling model.predict()
Hi everyone, I’m having an issue with model.predict() causing OOM errors. Strangely, this doesn’t happen while training, only while predicting on the val datase...
out of memory when using model.predict() · Issue #5337 · keras-team/keras
hi~ I am now using keras to build my network. the training is normal. but when I use the model.predict() to predict the results, it happend to a error : out of ...

根本的な解決ではないが、predictを明示的にCPUで行うようにしたところOOMは起きなくなった。

TensorFlowでは、tf.deviceを使うことで処理を行うデバイスを指定することができる。

tf.device  |  TensorFlow v2.13.0
Specifies the device for ops created/executed in this context.

そのため、model.fitはGPUで、model.predictはCPUで行うことが可能となる。

以下のように処理を行うことで、OOMを回避することができた。

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    model = ...
    model.fit(...)

with tf.device("/cpu:0"):
    pred = model.predict(...)

コメント

タイトルとURLをコピーしました