pytorch --- torch.utils.data.DataLoder()を使用

torch.utils.data.DataLoader(データセット、BATCH_SIZE、シャッフル、num_workers、collat​​e_fn)

パラメータ説明

  • dataset 着信データ・セット
  • batch_sizeどのように多くの各バッチのサンプル
  • shuffleかどうかは、データを混乱させる
  • num_workersデータのロードを処理するには、いくつかのプロセスがあります。
  • collate_fnミニバッチ機能を形成するためのサンプルのリスト

具体的な使用

  1. データローダ処理ロジックは、データセットによって最初のクラス内で__getitem__次に使用、ゲッター個人データ、およびその後のバッチに組み合わさcollate_fn例えば等パディングとして、または特定の操作を実行するためにバッチに割り当てられた機能を。
  2. 主に2つの2のものを再構築するために使用NLP、新しいデータセットクラスは、データセットを構築し、継承しなければならないtorch.utils.data.Dataset機能を実現するために2つの内部、クラスがされた__len__データセット全体のサイズを取得するために使用され、1がされて__getitem__から使用しますデータセグメントは、データ・セットの項目を取得します。

次のコードの作成MyDataset最初の入力パラメータのデータセットデータローダーを構築するために、クラスを

class MyDataset(torch.utils.data.Dataset):
    def __init__(self, centers, contexts, negatives):
        assert len(centers) == len(contexts) == len(negatives)
        self.centers = centers
        self.contexts = contexts
        self.negatives = negatives

    def __getitem__(self, index):
        return (self.centers[index], self.contexts[index], self.negatives[index])

    def __len__(self):
        return len(self.centers)
  # args:
  #		centers, contexts, negatives 都是list, 且contexts中元素也是list, 并且长度不一定都相同

データセットで特別に構築トーチ機能がありTensorDataset、以下を使用して、

dataset = torch.utils.data.TensorDataset(x, y)
# 输入的x, y 是tensor类型
  1. カスタマイズすることができcollate_fn=myfunction、データをBATCH_SIZE __getitem__ DataSetクラスは、関数collat​​e_fn、NLPタスクに指定されたパケットの形で送信される上記機能によってサンプリングされたデータの収集を設計するための方法を、それがしばしばcollat​​e_fnに指定されています内部パディングを行う関数

次のようにパディング機能が定義されています

def batchify(data):
    max_len = max(len(c)+len(n) for _, c, n in data)
    centers, contexts_negatives, masks, labels = [], [], [], []
    for center, context, negative in data:
        cur_len = len(context) + len(negative)
        centers += [center]
        contexts_negatives += [context + negative + [0]*(max_len-cur_len)]
        masks += [[1]*cur_len + [0]*(max_len-cur_len)]
        labels += [[1]*len(context) + [0]*(max_len-len(context))]
    return torch.tensor(centers).view(-1,1), torch.tensor(contexts_negatives), torch.tensor(masks), torch.tensor(labels)

概要
総流量が使用されている:再構成データセットのクラス(または直接torch.utils.data.TensorDataset()、他のパラメータ設定のビルド入力パラメータデータセットに)、カスタムデータ等パディングとしての機能(パラメータ渡さcollat​​e_fn)を処理します

これは、質量参加の完全なプロセスを以下の

dataset = MyDataset(all_centers, all_contexts, all_negatives)
data_iter = Data.DataLoader(dataset, batch_size, shuffle=True, collate_fn=batchify, num_workers=num_workers)

完全なプログラムgithupを参照してください:完全なプログラムword2vec.py

公開された33元の記事 ウォンの賞賛1 ビュー2619

おすすめ

転載: blog.csdn.net/orangerfun/article/details/103955880