torch.utils.data.DataLoader(データセット、BATCH_SIZE、シャッフル、num_workers、collate_fn)
パラメータ説明
dataset
着信データ・セットbatch_size
どのように多くの各バッチのサンプルshuffle
かどうかは、データを混乱させるnum_workers
データのロードを処理するには、いくつかのプロセスがあります。collate_fn
ミニバッチ機能を形成するためのサンプルのリスト
具体的な使用
- データローダ処理ロジックは、データセットによって最初のクラス内で
__getitem__
次に使用、ゲッター個人データ、およびその後のバッチに組み合わさcollate_fn
例えば等パディングとして、または特定の操作を実行するためにバッチに割り当てられた機能を。 - 主に2つの2のものを再構築するために使用NLP、新しいデータセットクラスは、データセットを構築し、継承しなければならない
torch.utils.data.Dataset
機能を実現するために2つの内部、クラスがされた__len__
データセット全体のサイズを取得するために使用され、1がされて__getitem__
から使用しますデータセグメントは、データ・セットの項目を取得します。
次のコードの作成MyDataset
最初の入力パラメータのデータセットデータローダーを構築するために、クラスを
class MyDataset(torch.utils.data.Dataset):
def __init__(self, centers, contexts, negatives):
assert len(centers) == len(contexts) == len(negatives)
self.centers = centers
self.contexts = contexts
self.negatives = negatives
def __getitem__(self, index):
return (self.centers[index], self.contexts[index], self.negatives[index])
def __len__(self):
return len(self.centers)
# args:
# centers, contexts, negatives 都是list, 且contexts中元素也是list, 并且长度不一定都相同
データセットで特別に構築トーチ機能がありTensorDataset
、以下を使用して、
dataset = torch.utils.data.TensorDataset(x, y)
# 输入的x, y 是tensor类型
- カスタマイズすることができ
collate_fn=myfunction
、データをBATCH_SIZE __getitem__ DataSetクラスは、関数collate_fn、NLPタスクに指定されたパケットの形で送信される上記機能によってサンプリングされたデータの収集を設計するための方法を、それがしばしばcollate_fnに指定されています内部パディングを行う関数
次のようにパディング機能が定義されています
def batchify(data):
max_len = max(len(c)+len(n) for _, c, n in data)
centers, contexts_negatives, masks, labels = [], [], [], []
for center, context, negative in data:
cur_len = len(context) + len(negative)
centers += [center]
contexts_negatives += [context + negative + [0]*(max_len-cur_len)]
masks += [[1]*cur_len + [0]*(max_len-cur_len)]
labels += [[1]*len(context) + [0]*(max_len-len(context))]
return torch.tensor(centers).view(-1,1), torch.tensor(contexts_negatives), torch.tensor(masks), torch.tensor(labels)
概要
総流量が使用されている:再構成データセットのクラス(または直接torch.utils.data.TensorDataset()
、他のパラメータ設定のビルド入力パラメータデータセットに)、カスタムデータ等パディングとしての機能(パラメータ渡さcollate_fn)を処理します
これは、質量参加の完全なプロセスを以下の
dataset = MyDataset(all_centers, all_contexts, all_negatives)
data_iter = Data.DataLoader(dataset, batch_size, shuffle=True, collate_fn=batchify, num_workers=num_workers)
完全なプログラムgithupを参照してください:完全なプログラムword2vec.py