TFRecordsでmodel.fitする時にメモリリークが起きてメモリ使用量が増え続けてしまう

大量のTFRecordsを使い、tf.dataでモデルの訓練をしていた際、メモリの使用量が増え続け、Out of Memoryになってしまっていた。

仮想メモリ(スワップ)を3TB程度用意してみたが、スワップも食い続けてしまったため、ハードの限界というよりもソフトでメモリーリークが起きてるのではないかと予想した。

公式サイトのtf.dataのパフォーマンス改善を試みるも、特にメモリ使用量は改善しなかった。

tf.data API によるパフォーマンスの改善  |  TensorFlow Core
プロファイラを使用した TensorFlow のパフォーマンス最適化  |  TensorFlow Core

調べてみると、大量のデータでトレーニングする時にメモリ使用量が増大し続けるケースを発見。

自分の状況によく似ている。

Memory leak when using TFRecords · Issue #909 · ROCm/tensorflow-upstream
System information Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes, ...
Getting memory error when training a larger dataset on the GPU
I’m having the same problem. I have a large (5.7GB) NumPy array called train_x. When I try to create a tf.data.Dataset t...

prefetchをautoにするとメモリを食い続けてしまう可能性があるらしい。

prefetchはtf.data.AUTOTUNEになっていることが多く、自動でいい感じにしてくれる機能なので引数を詳しく掘り下げたことがなかったので調べてみる。

tf.data.Dataset  |  TensorFlow v2.16.1
Represents a potentially large set of elements.

prefetch(
buffer_size, name=None
)

Creates a Dataset that prefetches elements from this dataset.

Most dataset input pipelines should end with a call to prefetch. This allows later elements to be prepared while the current element is being processed. This often improves latency and throughput, at the cost of using additional memory to store prefetched elements.

GPUの学習中にCPUで事前にデータを用意してくれる機能と思われる。

メモリリークが起きるのは、prefetchによりデータをメモリに事前読み込みし続けてしまうからではと予想。

明示的にprefetchの量を制限する。

dataset = dataset.prefetch(2)

事前に読み込むバッチ数を2つにしてみた。

これでメモリリークは起きなくなり、スワップが必要になることもなくなった。

ただし、パイプラインによっては学習速度が遅くなる可能性があるので、速度とメモリのバランスを見ながらprefetchのパラメーターを決める必要があるように感じた。

コメント

タイトルとURLをコピーしました