[Learning Series 7] Data Loading in PyTorch

Contents

1. Why models use a data loader

2. The Dataset class

3. Iterating over a dataset


1. Why models use a data loader

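In the earlier linear regression model we had very little data, so we fed all of it into the model at once.

In deep learning, however, datasets are usually very large, and it is not feasible to run the forward computation and backpropagation over all of the data in a single pass. Instead, we typically shuffle the whole dataset randomly, split it into batches, and preprocess the data along the way.
So next we look at how data loading works in PyTorch.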
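
To make the idea of "shuffle, then split into batches" concrete, here is a minimal pure-Python sketch that does not use PyTorch at all. The helper name iterate_minibatches and the toy data are assumptions made for illustration only; PyTorch's own tools are covered in the sections that follow.

```python
import random


def iterate_minibatches(samples, batch_size, seed=None):
    """Yield shuffled mini-batches (as lists) from a sequence of samples.

    Hypothetical helper for illustration; not part of PyTorch.
    """
    indices = list(range(len(samples)))
    random.Random(seed).shuffle(indices)  # shuffle the indices, not the data
    for start in range(0, len(indices), batch_size):
        batch_indices = indices[start:start + batch_size]
        yield [samples[i] for i in batch_indices]


samples = list(range(10))  # toy data standing in for a real dataset
batches = list(iterate_minibatches(samples, batch_size=4, seed=0))
print([len(batch) for batch in batches])  # the last batch is smaller
```

Every sample appears in exactly one batch, and only the last batch may be shorter than batch_size.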

2. The Dataset class

torch provides a base class for datasets, torch.utils.data.Dataset. By subclassing it, we can implement data loading very quickly.
The source code of torch.utils.data.Dataset is as follows:

class Dataset(Generic[T_co]):
    r"""An abstract class representing a :class:`Dataset`.

    All datasets that represent a map from keys to data samples should subclass
    it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
    data sample for a given key. Subclasses could also optionally overwrite
    :meth:`__len__`, which is expected to return the size of the dataset by many
    :class:`~torch.utils.data.Sampler` implementations and the default options
    of :class:`~torch.utils.data.DataLoader`.

    .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs a index
      sampler that yields integral indices.  To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.
    """

    def __getitem__(self, index) -> T_co:
        raise NotImplementedError

    def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
        return ConcatDataset([self, other])

    # No `def __len__(self)` default?
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    # in pytorch/torch/utils/data/sampler.py

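From this we can see that a custom dataset class should subclass Dataset and implement two methods:

1. __len__, so that the built-in len() returns the number of samples in the dataset. The base class treats it as optional, but many samplers and the default DataLoader options expect it.

2. __getitem__, so that a sample can be fetched by index, e.g. dataset[i] returns the i-th sample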
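
The source listing above also defines __add__, which returns a ConcatDataset. The sketch below shows what that means in practice: two datasets can be joined with +. ListDataset is a hypothetical in-memory dataset written only for this illustration.

```python
from torch.utils.data import ConcatDataset, Dataset


class ListDataset(Dataset):
    """Hypothetical map-style dataset wrapping an in-memory list."""

    def __init__(self, items):
        self.items = list(items)

    def __getitem__(self, index):
        return self.items[index]

    def __len__(self):
        return len(self.items)


first = ListDataset(['a', 'b'])
second = ListDataset(['c', 'd', 'e'])
combined = first + second  # Dataset.__add__ builds a ConcatDataset

# Indexing continues across the boundary between the two datasets
print(type(combined).__name__, len(combined), combined[2])
```

ConcatDataset needs the length of each part to map a global index to the right dataset, which is one more reason to implement __len__.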

Example:

from torch.utils.data import Dataset

data_path = './data/SMSSpamCollection.txt'


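# Define the dataset class
class MyDataset(Dataset):
    def __init__(self):
        # Read all lines up front; the with-block closes the file afterwards
        with open(data_path, encoding='utf8') as f:
            self.lines = f.readlines()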

    def __getitem__(self, index):
        return self.lines[index].strip()

    def __len__(self):
        return len(self.lines)


if __name__ == '__main__':
    my_dataset = MyDataset()
    for i in range(10):
        print(f'{i} {my_dataset[i]}')
0 ham	Go until jurong point, crazy.. Available only in bugis n great world la e buffet... Cine there got amore wat...
1 ham	Ok lar... Joking wif u oni...
2 spam	Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005. Text FA to 87121 to receive entry question(std txt rate)T&C's apply 08452810075over18's
3 ham	U dun say so early hor... U c already then say...
4 ham	Nah I don't think he goes to usf, he lives around here though
5 spam	FreeMsg Hey there darling it's been 3 week's now and no word back! I'd like some fun you up for it still? Tb ok! XxX std chgs to send, £1.50 to rcv
6 ham	Even my brother is not like to speak with me. They treat me like aids patent.
7 ham	As per your request 'Melle Melle (Oru Minnaminunginte Nurungu Vettam)' has been set as your callertune for all Callers. Press *9 to copy your friends Callertune
8 spam	WINNER!! As a valued network customer you have been selected to receivea £900 prize reward! To claim call 09061701461. Claim code KL341. Valid 12 hours only.
9 spam	Had your mobile 11 months or more? U R entitled to Update to the latest colour mobiles with camera for Free! Call The Mobile Update Co FREE on 08002986030

3. Iterating over a dataset

The approach above can read the data, but several things are still missing:
- Batching the data
- Shuffling the data
- Loading the data in parallel with multiple worker processes (multiprocessing)

In PyTorch, torch.utils.data.DataLoader provides all of the features above.
Example usage of DataLoader:

from torch.utils.data import Dataset, DataLoader

data_path = './data/SMSSpamCollection.txt'


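# Define the dataset class
class MyDataset(Dataset):
    def __init__(self):
        # Read all lines up front; the with-block closes the file afterwards
        with open(data_path, encoding='utf8') as f:
            self.lines = f.readlines()

    def __getitem__(self, index):
        # Each line has the form "<label>\t<message>"; split on the first tab
        label, _, context = self.lines[index].strip().partition('\t')
        return label, context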

    def __len__(self):
        return len(self.lines)


if __name__ == '__main__':
    my_dataset = MyDataset()
    data_loader = DataLoader(dataset=my_dataset, batch_size=10, shuffle=True, num_workers=2)
    for index, (label, context) in enumerate(data_loader):
        print(index, label, context)
        print('*' * 50)
0 ('ham', 'ham', 'ham', 'ham', 'ham', 'ham', 'ham', 'ham', 'spam', 'ham') ('"Pete can you please ring meive hardly gotany credit"', 'This weekend is fine (an excuse not to do too much decorating)', 'Many more happy returns of the day. I wish you happy birthday.', 'Eh u send wrongly lar...', 'Wat r u doing now?', 'Take something for pain. If it moves however to any side in the next 6hrs see a doctor.', 'If we win its really no 1 side for long time.', 'Hmmm.but you should give it on one day..', 'Message Important information for O2 user. Today is your lucky day! 2 find out why log onto http://www.urawinner.com there is a fantastic surprise awaiting you', 'Hey sathya till now we dint meet not even a single time then how can i saw the situation sathya.')
**************************************************
1 ('ham', 'ham', 'spam', 'ham', 'ham', 'ham', 'ham', 'ham', 'ham', 'ham') ('Take some small dose tablet for fever', 'No screaming means shouting..', 'Reply to win £100 weekly! Where will the 2006 FIFA World Cup be held? Send STOP to 87239 to end service', 'No calls..messages..missed calls', 'Bugis oso near wat...', 'Unni thank you dear for the recharge..Rakhesh', "A bloo bloo bloo I'll miss the first bowl", "Cool, I'll text you in a few", 'Hey...Great deal...Farm tour 9am to 5pm $95/pax, $50 deposit by 16 May', 'Hi kindly give us back our documents which we submitted for loan from STAPATI')
**************************************************
2 ('ham', 'ham', 'ham', 'ham', 'ham', 'ham', 'ham', 'ham', 'ham', 'ham') ('Can you plz tell me the ans. BSLVYL sent via fullonsms.com', 'i can call in  <#>  min if thats ok', 'Gudnite....tc...practice going on', "Aight, lemme know what's up", 'I dunno they close oredi not... Ü v ma fan...', 'Camera quite good, 10.1mega pixels, 3optical and 5digital dooms. Have a lovely holiday, be safe and i hope you hav a good journey! Happy new year to you both! See you in a couple of weeks!', 'We can go 4 e normal pilates after our intro...', "Horrible u eat macs eat until u forgot abt me already rite... U take so long 2 reply. I thk it's more toot than b4 so b prepared. Now wat shall i eat?", 'Jane babes not goin 2 wrk, feel ill after lst nite. Foned in already cover 4 me chuck.:-)', 'Well boy am I glad G wasted all night at applebees for nothing')
**************************************************
3 ('ham', 'ham', 'ham', 'ham', 'ham', 'ham', 'ham', 'spam', 'ham', 'spam') ("Alright, I'll head out in a few minutes, text me where to meet you", 'Yes. Please leave at  <#> . So that at  <#>  we can leave', 'Purity of friendship between two is not about smiling after reading the forwarded message..Its about smiling just by seeing the name. Gud evng', 'What Today-sunday..sunday is holiday..so no work..', 'Hi its in durban are you still on this number', 'I will send them to your email. Do you mind  <#>  times per night?', 'Good Morning plz call me sir', 'YOU 07801543489 are guaranteed the latests Nokia Phone, a 40GB iPod MP3 player or a £500 prize! Txt word:COLLECT to No:83355! TC-LLC NY-USA 150p/Mt msgrcvd18+', 'I though we shd go out n have some fun so bar in town or something – sound ok?', 'FREE GAME. Get Rayman Golf 4 FREE from the O2 Games Arcade. 1st get UR games settings. Reply POST, then save & activ8. Press 0 key for Arcade. Termsapply')
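
In the output above, each batch arrives as one group of labels and one group of texts rather than as a list of (label, text) pairs. This is the default collation step of DataLoader regrouping the samples field by field. The sketch below reproduces the effect on a tiny in-memory dataset; PairDataset and its three samples are made up for illustration, and shuffle is turned off so the order is predictable.

```python
from torch.utils.data import DataLoader, Dataset


class PairDataset(Dataset):
    """Hypothetical dataset whose samples are (label, text) tuples."""

    def __init__(self, pairs):
        self.pairs = list(pairs)

    def __getitem__(self, index):
        return self.pairs[index]

    def __len__(self):
        return len(self.pairs)


pairs = [('ham', 'see you soon'), ('spam', 'win a prize'), ('ham', 'ok')]
loader = DataLoader(PairDataset(pairs), batch_size=2, shuffle=False)

# Each batch unpacks into all labels of the batch and all texts of the batch
batches = [(list(labels), list(texts)) for labels, texts in loader]
for labels, texts in batches:
    print(labels, texts)
```

If a different grouping is needed, DataLoader accepts a collate_fn argument that replaces this default behaviour.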

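Parameter meanings:
1. dataset: an instance of a previously defined dataset

2. batch_size: the number of samples in each batch; 128 and 256 are common choices

3. shuffle: a bool indicating whether the data is reshuffled at every epoch

4. num_workers: the number of worker subprocesses used to load the data (0 means the data is loaded in the main process)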

Note:
1. len(dataset) = the number of samples in the dataset
2. len(dataloader) = math.ceil(number of samples / batch_size), i.e. rounded up (with the default drop_last=False)
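
The two length rules can be checked directly. The sketch below uses a synthetic TensorDataset of 25 samples (an assumption chosen so that the last batch is incomplete) and also shows how drop_last=True changes the count.

```python
import math

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(25))  # 25 synthetic samples
loader = DataLoader(dataset, batch_size=10)
loader_drop = DataLoader(dataset, batch_size=10, drop_last=True)

# Sizes of the individual batches produced by the default loader
batch_sizes = [batch[0].shape[0] for batch in loader]

print(len(dataset), len(loader), len(loader_drop), batch_sizes)
print(len(loader) == math.ceil(len(dataset) / 10))
```

With drop_last=True the incomplete final batch is discarded, so the length becomes the floor of the division instead of the ceiling.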

Code for the handwritten digit dataset (MNIST):

import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import Compose, ToTensor

data_path = './data'

transform_fn = Compose([
    ToTensor(),
])

# download=False assumes the MNIST files already exist under data_path
dataset = torchvision.datasets.MNIST(root=data_path, train=True, download=False, transform=transform_fn)
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)

for index, (img, target) in enumerate(dataloader):
    print(index, img, target)
    break  # inspect only the first batch
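
To see what one batch looks like without having the MNIST files on disk, the sketch below substitutes random tensors of the same layout: ToTensor turns each 28x28 grayscale MNIST image into a float tensor of shape (1, 28, 28), and the targets are integer class labels from 0 to 9. The synthetic data and the sample count of 30 are assumptions for illustration, not real MNIST content.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-in for MNIST: 30 fake images and 30 fake digit labels
images = torch.rand(30, 1, 28, 28)
targets = torch.randint(0, 10, (30,))

dataset = TensorDataset(images, targets)
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)

# Fetch a single batch and inspect its shapes instead of printing raw pixels
img, target = next(iter(dataloader))
print(img.shape, target.shape)
```

The leading dimension of both tensors is the batch size, so a model receives 10 images and their 10 labels in one step.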


Reposted from blog.csdn.net/WakingStone/article/details/129651555