PyTorch Learning (8): Batch Training

Copyright notice: this is the author's original post and may not be reposted without permission. https://blog.csdn.net/github_39611196/article/details/82464445

This post covers batch training in PyTorch. torch.utils.data provides DataLoader, a utility that wraps a dataset and serves it up in mini-batches.

Example code:

import torch
import torch.utils.data as Data

BATCH_SIZE = 5
x = torch.linspace(1, 10, 10)  # x (torch tensor)
y = torch.linspace(10, 1, 10)  # y (torch tensor)

torch_dataset = Data.TensorDataset(x, y)

# DataLoader splits the dataset into mini-batches
loader = Data.DataLoader(
    dataset=torch_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,  # reshuffle the data at every epoch before sampling
    # num_workers=2,  # load batches with two worker subprocesses
)

for epoch in range(3):
    for step, (batch_x, batch_y) in enumerate(loader):
        # training
        print('Epoch: ', epoch, '| step: ', step, '| batch_x: ',
              batch_x.numpy(), '| batch_y: ', batch_y.numpy())

Output:

Epoch:  0 | step:  0 | batch_x:  [4. 8. 1. 5. 3.] | batch_y:  [ 7.  3. 10.  6.  8.]
Epoch:  0 | step:  1 | batch_x:  [ 9.  7. 10.  2.  6.] | batch_y:  [2. 4. 1. 9. 5.]
Epoch:  1 | step:  0 | batch_x:  [ 7.  5.  4.  9. 10.] | batch_y:  [4. 6. 7. 2. 1.]
Epoch:  1 | step:  1 | batch_x:  [6. 2. 8. 3. 1.] | batch_y:  [ 5.  9.  3.  8. 10.]
Epoch:  2 | step:  0 | batch_x:  [10.  6.  1.  9.  3.] | batch_y:  [ 1.  5. 10.  2.  8.]
Epoch:  2 | step:  1 | batch_x:  [2. 8. 5. 4. 7.] | batch_y:  [9. 3. 6. 7. 4.]
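Above, the 10 samples divide evenly into batches of 5, so every batch is full. When the dataset size is not divisible by the batch size, DataLoader simply makes the last batch smaller; passing drop_last=True discards that incomplete batch instead. A minimal sketch (the batch size of 8 here is an arbitrary choice for illustration):

```python
import torch
import torch.utils.data as Data

x = torch.linspace(1, 10, 10)  # 10 samples
y = torch.linspace(10, 1, 10)
dataset = Data.TensorDataset(x, y)

# 10 samples with batch_size=8: one batch of 8, then a final batch of 2
loader = Data.DataLoader(dataset=dataset, batch_size=8, shuffle=True)
print([len(bx) for bx, _ in loader])  # [8, 2]

# drop_last=True discards the incomplete final batch
loader = Data.DataLoader(dataset=dataset, batch_size=8, shuffle=True,
                         drop_last=True)
print([len(bx) for bx, _ in loader])  # [8]
```

Whether to drop the last batch is a per-task choice: keeping it uses all the data, while dropping it keeps every gradient step's batch statistics comparable.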
 
