pytorch之DataLoader

本文代码转载自“https://www.cnblogs.com/JeasonIsCoding/p/10168753.html” 

直接写代码

import torch
import torch.utils.data as Data

if __name__ == '__main__':
    BATCH_SIZE = 3
    # 建立两个向量x和y,一个作为输入的数据,一个作为正确的结果
    x = torch.linspace(1, 10, 10)  # x data(torch tensor)
    y = torch.linspace(10, 1, 10)  # y data(torch tensor)

    # 我们需要把x和y组成一个完整的数据集,并转化为pytorch能识别的数据集类型
    torch_dataset = Data.TensorDataset(x, y)

    # 把上一步做成的数据集放入Data.DataLoader中,可以生成一个迭代器
    loader = Data.DataLoader(
        dataset=torch_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,        # 要不要打乱数据(打乱比较好)
        num_workers=2,       # 多线程来读数据
    )

    for epoch in range(5):
        i = 0
        for batch_x, batch_y in loader:
            i = i+1
            print('Epoch:{} | num:{} | batch_x:{} | batch_y:{}'.format(epoch, i, batch_x, batch_y))

需要注意的是,代码需放在if __name__ == '__main__':下,否则会报错,报错结果显示如下:

猜你喜欢

转载自blog.csdn.net/Acmer_future_victor/article/details/105312954