pytorch1.0批训练神经网络

pytorch1.0批训练神经网络

import torch
import torch.utils.data as Data
# Torch 中提供了一种帮助整理数据结构的工具, 叫做 DataLoader, 能用它来包装自己的数据, 进行批训练.
torch.manual_seed(1)    # reproducible
# 批训练的数据个数
BATCH_SIZE = 5
BATCH_SIZE = 8

x = torch.linspace(1, 10, 10)       # this is x data (torch tensor)
y = torch.linspace(10, 1, 10)       # this is y data (torch tensor)
# DataLoader 是 torch 用来包装开发者自己的数据的工具.
# 将自己的 (numpy array 或其他) 数据形式装换成 Tensor, 然后再放进这个包装器中.
# 使用 DataLoader 的好处就是他们帮你有效地迭代数据

# 先转换成 torch 能识别的 Dataset
torch_dataset = Data.TensorDataset(x, y)  # torch_dataset = Data.TensorDataset(data_tensor=x, target_tensor=y)
# 把 dataset 放入 DataLoader
loader = Data.DataLoader(
    dataset=torch_dataset,      # torch TensorDataset format
    batch_size=BATCH_SIZE,      # mini batch size
    shuffle=True,               # random shuffle for training     # 随机打乱数据--打乱比较好
    num_workers=2,              # subprocesses for loading data   # 多线程来读数据
)


def show_batch():
    for epoch in range(3):   # train entire dataset 3 times   # 训练所有/整套数据 3 次
        for step, (batch_x, batch_y) in enumerate(loader):  # for each training step  # 每一步 loader 释放一小批数据用来学习
            # train your data...  # 假设这里就是训练的代码块...
            print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',
                  batch_x.numpy(), '| batch y: ', batch_y.numpy())

if __name__ == '__main__':
    show_batch()
# BATCH_SIZE = 5
'''
Epoch: 0 | Step: 0 | batch x: [ 5. 7. 10. 3. 4.] | batch y: [6. 4. 1. 8. 7.]
Epoch: 0 | Step: 1 | batch x: [2. 1. 8. 9. 6.] | batch y: [ 9. 10. 3. 2. 5.]
Epoch: 1 | Step: 0 | batch x: [ 4. 6. 7. 10. 8.] | batch y: [7. 5. 4. 1. 3.]
Epoch: 1 | Step: 1 | batch x: [5. 3. 2. 1. 9.] | batch y: [ 6. 8. 9. 10. 2.]
Epoch: 2 | Step: 0 | batch x: [ 4. 2. 5. 6. 10.] | batch y: [7. 9. 6. 5. 1.]
Epoch: 2 | Step: 1 | batch x: [3. 9. 1. 8. 7.] | batch y: [ 8. 2. 10. 3. 4.]
'''
# BATCH_SIZE = 8
'''
Epoch: 0 | Step: 0 | batch x: [ 5. 7. 10. 3. 4. 2. 1. 8.] | batch y: [ 6. 4. 1. 8. 7. 9. 10. 3.]
Epoch: 0 | Step: 1 | batch x: [9. 6.] | batch y: [2. 5.]
Epoch: 1 | Step: 0 | batch x: [ 4. 6. 7. 10. 8. 5. 3. 2.] | batch y: [7. 5. 4. 1. 3. 6. 8. 9.]
Epoch: 1 | Step: 1 | batch x: [1. 9.] | batch y: [10. 2.]
Epoch: 2 | Step: 0 | batch x: [ 4. 2. 5. 6. 10. 3. 9. 1.] | batch y: [ 7. 9. 6. 5. 1. 8. 2. 10.]
Epoch: 2 | Step: 1 | batch x: [8. 7.] | batch y: [3. 4.]
'''

猜你喜欢

转载自www.cnblogs.com/jeshy/p/11200000.html