pytorch学习之七 批训练

批训练是什么东西呢?在之前的迭代训练代码中。

for t in range(100):
    out = net(x)
    loss = loss_func(out,y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

一次迭代,需要到用到训练样本的所有数据。那么当训练集非常大,或者说样本无法同时取出来的时候,就比较难以训练,这时候要用上批处理的方法。

何为批处理呢?就是每次迭代只使用训练集的一部分作为一个代表,来训练整个网络。这样可以加速网络的训练,同时,精度又不会有太大的下降。

pytorch提供了一些方法来进行批训练,主要下面两个

  • Data.TensorDataset
    将训练样本的x,y封装起来的一个数据集类型

  • Data.DataLoader
    这个是将数据集切分的一个工具,一般都是随机切分。

上代码

import torch#导入模块
import torch.utils.data as Data

#每一批的数据量
BATCH_SIZE=5#每一批的数据量

x=torch.linspace(1,10,10)
y=torch.linspace(10,1,10)


#转换成torch能识别的Dataset
torch_dataset=Data.TensorDataset(x,y)  #将数据放入torch_dataset


#torch.utils.data.DataLoader这个接口定义在dataloader.py脚本中,只要是用PyTorch来训练
#模型都会用到该接口,该接口主要用来将自动以的数据读取接口的输出或者PyTorch已有的数据读取接口按照batch size
#封装成Tensor,后续只需要再包装成Variable即可作为模型的输入,因此该接口有点承上启下的作用,比较
#

loader=Data.DataLoader(
        dataset=torch_dataset,      #将数据放入loader
        batch_size=BATCH_SIZE,      #批的尺寸,五个为一个批次
        shuffle=True,               #是否打断数据  
        num_workers=0              #多线程读取数据,如果为0就是主线程来读取数据
        )

#for epoch in range(3):  #训练所有的整套数据3次
#    for step,(batch_x,batch_y) in enumerate(loader):   #
for epoch in range(3):
    for step,(batch_x,batch_y) in enumerate(loader):  
        print('Epoch:',epoch,'|Step:',step,'|batch x:',batch_x.numpy()
            ,'|batch y:',batch_y.numpy())

loader=Data.DataLoader(
dataset=torch_dataset, #将数据放入loader
batch_size=BATCH_SIZE, #批的尺寸,五个为一个批次
shuffle=True, #是否打断数据
num_workers=0 #多线程读取数据,如果为0就是主线程来读取数据
)

dataset就是刚刚建立好的数据集,batch_size是每一批的大小。shuffle,每个批是否是随机从dataset里面取数据。num_works,是否是多个线程来读取数据,我的机器上这个参数为非零值就会执行失败,不知道为什么

for epoch in range(3):
    for step,(batch_x,batch_y) in enumerate(loader):  
        print('Epoch:',epoch,'|Step:',step,'|batch x:',batch_x.numpy()
            ,'|batch y:',batch_y.numpy())

for step,(batch_x,batch_y) in enumerate(loader):
这句话是生成了一个关于loader的迭代器,然后遍历loader
step代表索引,batch_x,batch_y代表每次随机切分的训练集,大小为5
打印效果如下

Epoch: 0 |Step: 0 |batch x: [1. 2. 3. 4. 5.] |batch y: [10.  9.  8.  7.  6.]
Epoch: 0 |Step: 1 |batch x: [ 6.  7.  8.  9. 10.] |batch y: [5. 4. 3. 2. 1.]
Epoch: 1 |Step: 0 |batch x: [1. 2. 3. 4. 5.] |batch y: [10.  9.  8.  7.  6.]
Epoch: 1 |Step: 1 |batch x: [ 6.  7.  8.  9. 10.] |batch y: [5. 4. 3. 2. 1.]
Epoch: 2 |Step: 0 |batch x: [1. 2. 3. 4. 5.] |batch y: [10.  9.  8.  7.  6.]
Epoch: 2 |Step: 1 |batch x: [ 6.  7.  8.  9. 10.] |batch y: [5. 4. 3. 2. 1.]

猜你喜欢

转载自blog.csdn.net/ronaldo_hu/article/details/91958278
今日推荐