pytorch:DataLoader

torch中的DataLoader主要是用来将给定数据集中的样本打包成一个一个batch的,那么它具体是怎么工作的呢?对于给定的数据集又有什么要求呢?

from torch.utils.data import DataLoader

class show_how_dataloader_work():
    def __init__(self,x):
        self.x = x
    
    def __len__(self): #必须要有!
        return len(self.x)
    
    def __getitem__(self,index): #必须要有!
        print('index是{},即dataloader取出了第{}个元素'.format(index,index+1))
        return self.x[index]

a = show_how_dataloader_work(['wyb','xz','zql','wx','hjy'])
a_batch = DataLoader(a,batch_size=2,shuffle=True)
#a就是相当于给定的数据集。
list(a_batch)
index是3,即dataloader取出了第4个元素
index是1,即dataloader取出了第2个元素
index是0,即dataloader取出了第1个元素
index是4,即dataloader取出了第5个元素
index是2,即dataloader取出了第3个元素
[['wx', 'xz'], ['wyb', 'hjy'], ['zql']]
  • 具体来说,dataloader是如何工作的呢?
  • 首先得到length,length是__len(self)__的返回值
  • 如果shuffle=True,那么就随机从range(length)中取出batch_size个数(一共去length次,不重复),依次作为index,放到__getitem__(self,index)中,然后得到batch_size个__getitem__(self,index)的返回值,并打包在一起
  • 如果shuffle=False,那么就依次从range(length)中取出batch_size个数,作为index,放到__getitem__(self,index)中,然后得到batch_size个__getitem__(self,index)的返回值,并打包在一起

    因此对于给定数据集的最基本要求就是要有__len__(self)函数和__getitem__(self,index)函数



带到上述例子中具体分析:
我们的self.x是['wyb','xz','zql','wx','hjy'],__len__(self)的返回值是len(self.x),依次length=5
shuffle=True,batch_size是2
第一次,取两个,index分别是3,1.那么调用__getitem__(self,index)函数两次,得到返回值self.x[3]即‘wx’和self.x[1]即‘xz’,并将它们打包在一起
第二次,取两个,index分别是0,4.那么调用__getitem__(self,index)函数两次,得到返回值self.x[0]即‘wyb’和self.x[4]即‘hjy’,并将它们打包在一起
最后一次,只剩一个了,index是2,那么调用__getitem__(self,index)函数一次,得到返回值self.x[2]即‘zql’
至此,length个返回值已经全部取完


再来举一个例子:
class diy():
    
    def __len__(self):#必须要有!
        print('len函数被调用了')
        return 4
    
    def __getitem__(self,index): #必须要有!
        print('getitem函数被调用一次,这次的index是{}'.format(index))
        return '并肩于雪山之巅!!'*(index+1)

b = diy()

b_batch = DataLoader(b,batch_size=2,shuffle=False)

for batch_idx,output_sentence in enumerate(b_batch):
    print('这是第{}个batch'.format(batch_idx+1))
    for i in range(len(output_sentence)):
        print(output_sentence[i])
len函数被调用了
getitem函数被调用一次,这次的index是0
getitem函数被调用一次,这次的index是1
这是第1个batch
并肩于雪山之巅!!
并肩于雪山之巅!!并肩于雪山之巅!!
getitem函数被调用一次,这次的index是2
getitem函数被调用一次,这次的index是3
这是第2个batch
并肩于雪山之巅!!并肩于雪山之巅!!并肩于雪山之巅!!
并肩于雪山之巅!!并肩于雪山之巅!!并肩于雪山之巅!!并肩于雪山之巅!!
c_batch = DataLoader(b,batch_size=3,shuffle=True)
len函数被调用了
len函数被调用了
list(c_batch)
len函数被调用了
len函数被调用了
getitem函数被调用一次,这次的index是0
getitem函数被调用一次,这次的index是1
getitem函数被调用一次,这次的index是3
getitem函数被调用一次,这次的index是2
[['并肩于雪山之巅!!', '并肩于雪山之巅!!并肩于雪山之巅!!', '并肩于雪山之巅!!并肩于雪山之巅!!并肩于雪山之巅!!并肩于雪山之巅!!'],
 ['并肩于雪山之巅!!并肩于雪山之巅!!并肩于雪山之巅!!']]

在这里插入图片描述

发布了43 篇原创文章 · 获赞 1 · 访问量 737

猜你喜欢

转载自blog.csdn.net/weixin_41391619/article/details/104994985