版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/github_39611196/article/details/82464445
本篇博客主要介绍PyTorch中的批训练。Torch中提供了一种整理数据结构的工具DataLoader。
示例代码:
import torch
import torch.utils.data as Data
BATCH_SIZE = 5
x = torch.linspace(1, 10, 10) # x (torch tensor)
y = torch.linspace(10, 1, 10) # y (torch tensor)
torch_dataset = Data.TensorDataset(x, y)
# Loader让数据变成多个小批次
loader = Data.DataLoader(
dataset=torch_dataset,
batch_size=BATCH_SIZE,
shuffle=True, # 在训练时需不需要打乱数据后再进行抽样
# num_workers=2, # 提取的时候使用两个线程来进行提取
)
for epoch in range(3):
for step, (batch_x, batch_y) in enumerate(loader):
# training
print('Epoch: ', epoch, '| step: ', step, '| bacth_x: ',
batch_x.numpy(), '| batch_y: ', batch_y.numpy())
运行结果:
Epoch: 0 | step: 0 | bacth_x: [4. 8. 1. 5. 3.] | batch_y: [ 7. 3. 10. 6. 8.]
Epoch: 0 | step: 1 | bacth_x: [ 9. 7. 10. 2. 6.] | batch_y: [2. 4. 1. 9. 5.]
Epoch: 1 | step: 0 | bacth_x: [ 7. 5. 4. 9. 10.] | batch_y: [4. 6. 7. 2. 1.]
Epoch: 1 | step: 1 | bacth_x: [6. 2. 8. 3. 1.] | batch_y: [ 5. 9. 3. 8. 10.]
Epoch: 2 | step: 0 | bacth_x: [10. 6. 1. 9. 3.] | batch_y: [ 1. 5. 10. 2. 8.]
Epoch: 2 | step: 1 | bacth_x: [2. 8. 5. 4. 7.] | batch_y: [9. 3. 6. 7. 4.]