私はTensorflowを使用しますが、私はよ一般的に深い学習の枠組みを越え異なりますユーザーのためにドキュメントを書きます。
ローカルファイルシステム(TB +)に収まらないデータセットを扱うIリモートデータストアからのサンプルデータと書き込みサンプルローカルTensorflow標準のtfrecords
フォーマット。
私はので、いくつかの値をサンプリングしています訓練の最初のエポックの間、エポックローカルデータのは非常に小さいですが、私はそれに訓練します。上エポック2 Iは、データファイルは、次のエポックのためのローカルデータファイルの拡張セットに私のサンプリングサブプロセス(今より)と電車で製造されていたものを再検討します。プロセスに、各エポックを繰り返します。このように、私は、サンプルのローカルキャッシュを構築し、私は、ローカルストレージを埋めるよう古いサンプルを立ち退かせることができます。キャッシュモデルは(訓練の後半に向かって)ほとんどの分散を必要とする時間程度で成長する地元のサンプル。
それはPythonのGILは、データ転送速度をサポートすることはできませんので、私はPythonのトレーニングループ処理中のデータをデシリアライズないことが重要だTensorflowのPython /ではこのように、とGPUのパフォーマンス(300〜600メガバイト/秒、データは生の非圧縮科学的です) PythonのGILは、高速トレーニングループにサービスを提供できない場合に苦しんでいます。
サンプルを書き込みtfrecords
、サブプロセスからファイル(Pythonのマルチプロセッシング)tensorflowのネイティブことができますTFRecordsDataset
のPythonのデシリアライズ外を行うには、したがって、我々は、Python GILの問題を回避し、私は高いIOデータレートでGPUを飽和することができます。
私はPytorchでこの問題に対処する方法を知っていただきたいと思います。私が使用し、TensorflowとPyTorchの両方のユーザーに特定の推奨事項を提供したいされていたサンプリング戦略について書いているが、私は十分な詳細と書き込みに十分な生態系を前処理PyTorchを知りません。
サイドノート:これらのデータ転送速度をサポートする唯一の純粋なPythonのベースのソリューションは、System VとのPython 3.8で来るかもしれないが、メモリを共有し、マルチプロセッシングが、私は(すぐになり、それに対するサポートは非常に十分ではないことを、まだとして試していません)。彼らは訓練のループ処理では、逆シリアル化が必要なため、高いIOレートでのデシリアライズ時にGILをロックしているため、既存のマルチプロセッシングソリューションでは十分ではありません。
実際には、簡単に使用することにより、サブプロセス内のデータをデシリアライズすることができますtorch.utils.data.DataLoader
。設定することにより、num_workers
1への引数またはより大きな値を、あなたは自分のPythonの通訳とGILSでサブプロセスを起動することができます。
loader = torch.utils.data.DataLoader(your_dataset, num_workers=n, **kwargs)
for epoch in range(epochs):
for batch_idx, data in enumerate(loader):
# loader in the main process does not claim GIL at this point
AがDataloader
必要ですtorch.utils.data.Dataset
からデータを取得します。あなたのケースでは、適切なサブクラスを実装するために些細な仕事ではないかもしれません。場合は、あなたが再作成する必要があるDataset
すべてのエポックのインスタンスを、あなたはこのような何かを行うことができます。
for epcoh in range(epochs):
dset = get_new_dataset()
loader = torch.utils.data.DataLoader(dset, num_workers=n, **kwargs)
for batch_idx, data in enumerate(loader):
# Do training
またはより良いです
dset = get_new_dataset()
loader = torch.utils.data.DataLoader(dset, num_workers=n, **kwargs)
for epcoh in range(epochs):
last_batch_idx = (len(dset)-1) // loader.batch_size
for batch_idx, data in enumerate(loader):
# Prepare next loader in advance to avoid blocking
if batch_idx == last_batch_idx:
dset = get_new_dataset()
loader = torch.utils.data.DataLoader(dset, num_workers=n, **kwargs)
# Do training
注意点として、すなわち、ほとんどの場合、GILの影響を受けていることのCPUバウンドの操作ではなく、I / Oバウンドの操作がありますのでご注意くださいthreading
任意の純粋なI / O重い操作のために行うとあなたも必要性がないだろうsubprocess
。詳細については、これを参照してください質問や、このWikipediaの記事。