pytorch白话入门笔记1.8-批数据训练

目录

 

1.批数据训练

(1)代码

(2)运行结果


1.批数据训练

(1)代码

import torch
import torch.utils.data as Data
torch.manual_seed(1)    # reproducible

BATCH_SIZE = 8      # 批训练的数据个数

x = torch.linspace(1, 10, 10)       # x data (torch tensor)
y = torch.linspace(10, 1, 10)       # y data (torch tensor)

# 先转换成 torch 能识别的 Dataset
torch_dataset = Data.TensorDataset(x, y)

# 把 dataset 放入 DataLoader
loader = Data.DataLoader(
    dataset=torch_dataset,      # torch TensorDataset format
    batch_size=BATCH_SIZE,      # mini batch size
    shuffle=True,               # 要不要打乱数据 (打乱比较好:True;不打乱:False)
    # num_workers=2,              # 多线程来读数据 windows删除此行
)

for epoch in range(5):   # 训练所有数据 5 次
    for step, (batch_x, batch_y) in enumerate(loader):  # 每一步 loader 释放一小批数据用来学习
        # 训练...

        # 打数据
        print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',
              batch_x.numpy(), '| batch y: ', batch_y.numpy())

(2)运行结果

Epoch:  0 | Step:  0 | batch x:  [ 5.  7. 10.  3.  4.  2.  1.  8.] | batch y:  [ 6.  4.  1.  8.  7.  9. 10.  3.]
Epoch:  0 | Step:  1 | batch x:  [9. 6.] | batch y:  [2. 5.]
Epoch:  1 | Step:  0 | batch x:  [ 4.  6.  7. 10.  8.  5.  3.  2.] | batch y:  [7. 5. 4. 1. 3. 6. 8. 9.]
Epoch:  1 | Step:  1 | batch x:  [1. 9.] | batch y:  [10.  2.]
Epoch:  2 | Step:  0 | batch x:  [ 4.  2.  5.  6. 10.  3.  9.  1.] | batch y:  [ 7.  9.  6.  5.  1.  8.  2. 10.]
Epoch:  2 | Step:  1 | batch x:  [8. 7.] | batch y:  [3. 4.]
Epoch:  3 | Step:  0 | batch x:  [ 4. 10.  9.  8.  7.  6.  1.  2.] | batch y:  [ 7.  1.  2.  3.  4.  5. 10.  9.]
Epoch:  3 | Step:  1 | batch x:  [5. 3.] | batch y:  [6. 8.]
Epoch:  4 | Step:  0 | batch x:  [9. 8. 4. 6. 5. 3. 7. 2.] | batch y:  [2. 3. 7. 5. 6. 8. 4. 9.]
Epoch:  4 | Step:  1 | batch x:  [10.  1.] | batch y:  [ 1. 10.]

Process finished with exit code 0
原创文章 23 获赞 1 访问量 723

猜你喜欢

转载自blog.csdn.net/BSZJYAJ/article/details/105208100