大量のTFRecordsを使い、tf.dataでモデルの訓練をしていた際、メモリの使用量が増え続け、Out of Memoryになってしまっていた。
仮想メモリ(スワップ)を3TB程度用意してみたが、スワップも食い続けてしまったため、ハードの限界というよりもソフトでメモリーリークが起きてるのではないかと予想した。
公式サイトのtf.dataのパフォーマンス改善を試みるも、特にメモリ使用量は改善しなかった。
調べてみると、大量のデータでトレーニングする時にメモリ使用量が増大し続けるケースを発見。
自分の状況によく似ている。
prefetchをautoにするとメモリを食い続けてしまう可能性があるらしい。
prefetchはtf.data.AUTOTUNEになっていることが多く、自動でいい感じにしてくれる機能なので引数を詳しく掘り下げたことがなかったので調べてみる。
prefetch(
buffer_size, name=None
)
Creates a Dataset that prefetches elements from this dataset.Most dataset input pipelines should end with a call to prefetch. This allows later elements to be prepared while the current element is being processed. This often improves latency and throughput, at the cost of using additional memory to store prefetched elements.
GPUの学習中にCPUで事前にデータを用意してくれる機能と思われる。
メモリリークが起きるのは、prefetchによりデータをメモリに事前読み込みし続けてしまうからではと予想。
明示的にprefetchの量を制限する。
dataset = dataset.prefetch(2)
事前に読み込むバッチ数を2つにしてみた。
これでメモリリークは起きなくなり、スワップが必要になることもなくなった。
ただし、パイプラインによっては学習速度が遅くなる可能性があるので、速度とメモリのバランスを見ながらprefetchのパラメーターを決める必要があるように感じた。
コメント