torch.utils.data.DataLoader(dataset, batch_size, shuffle, num_workers, collate_fn)
参数说明
dataset
传入的数据集batch_size
每个batch有多少个样本shuffle
是否打乱数据num_workers
有几个进程来处理data loadingcollate_fn
将一个list的sample组成一个mini-batch的函数
具体使用
- Dataloader的处理逻辑是先通过Dataset类里面的
__getitem__
函数获取单个的数据,然后组合成batch,再使用collate_fn
所指定的函数对这个batch做一些操作,比如padding之类的。 - 在NLP中的使用主要是要重构两个两个东西,一个是新建Dataset类构建dataset,必须继承
torch.utils.data.Dataset
类,内部要实现两个函数一个是__len__
用来获取整个数据集的大小,一个是__getitem__
用来从数据集中得到一个数据片段item。
如下代码新建MyDataset
类来构建输入DataLoader第一个参数dataset
class MyDataset(torch.utils.data.Dataset):
def __init__(self, centers, contexts, negatives):
assert len(centers) == len(contexts) == len(negatives)
self.centers = centers
self.contexts = contexts
self.negatives = negatives
def __getitem__(self, index):
return (self.centers[index], self.contexts[index], self.negatives[index])
def __len__(self):
return len(self.centers)
# args:
# centers, contexts, negatives 都是list, 且contexts中元素也是list, 并且长度不一定都相同
在torch中有一个专门构建dataset的函数TensorDataset
, 使用如下
dataset = torch.utils.data.TensorDataset(x, y)
# 输入的x, y 是tensor类型
- 可以通过自定义
collate_fn=myfunction
来设计数据收集的方式,也就是通过上面的Dataset类中的__getitem__函数采样了batch_size数据,以一个包的形式传递给collate_fn所指定的函数, nlp任务中,经常在collate_fn指定的函数里面做padding
如下定义一个函数进行padding
def batchify(data):
max_len = max(len(c)+len(n) for _, c, n in data)
centers, contexts_negatives, masks, labels = [], [], [], []
for center, context, negative in data:
cur_len = len(context) + len(negative)
centers += [center]
contexts_negatives += [context + negative + [0]*(max_len-cur_len)]
masks += [[1]*cur_len + [0]*(max_len-cur_len)]
labels += [[1]*len(context) + [0]*(max_len-len(context))]
return torch.tensor(centers).view(-1,1), torch.tensor(contexts_negatives), torch.tensor(masks), torch.tensor(labels)
总结
总的使用流程就是: 重建Dataset类(或直接用torch.utils.data.TensorDataset()
)来构建输入参数dataset, 设定其他参数, 自定义一个函数处理数据(传给collate_fn参数)如padding等
如下完整的传参过程
dataset = MyDataset(all_centers, all_contexts, all_negatives)
data_iter = Data.DataLoader(dataset, batch_size, shuffle=True, collate_fn=batchify, num_workers=num_workers)
完整程序见githup: 完整程序 word2vec.py