torch.utils.data のデータセットについて、DataLoader は辞書を返します。

テスト データセットは辞書を返すことができますか?
DataLoader がbatch_szie>1 を設定すると、返されるデータはどのようになりますか?
次に答えてください。

まずテストコードを見てください。

import torch.nn as nn
from torch.utils.data import Dataset, DataLoader


class TestDataset(Dataset):
    def __init__(self):
        # 随意定义一个数据集类,并且继承Dataset

        # 定义一个list模拟数据集 
        self.lines = [1, 2, 3, 4, 5, 6, 7, 8, 9]
        # 获取数据长度,保证类别可以是一个迭代类型
        self.length = len(self.lines)

    def __len__(self):
        # 返回数据集的长度,可以通过len方法获取数据集的长度
        return self.length

    def __getitem__(self, index):
        # 可以根据index获取数据集中的一个元素
        index = index % self.length  # 保证index安全,如果超过数据集长度,不会导致代码崩溃
        line = self.lines[index]  # 获取数据集的元素

        # 返回一个字典(字典的内容是瞎写的没有任何意义,只不过是为了测试是否可以返回一个字典)
        return {
    
    "image_name": str(index + 1),
                "image": line,
                "shape": np.array([line ** 3 + 5, - line ** 2 + line, 3])}


class Model(nn.Module):
    def __init__(self):
        """
        随意定义一个模型,可以获取DataLoader的内容,如果直接打印DataLoader返回的元素值,是一个地址
        eg:
            dataset = TestDataset()
            dataloader = DataLoader(dataset, batch_size=4, )
            for data_dict in dataloader:
                print(data_dict)
        output:
            <torch.utils.data.dataloader.DataLoader object at 0x000001485C9C6848>
            <torch.utils.data.dataloader.DataLoader object at 0x000001485C9C6848>
            <torch.utils.data.dataloader.DataLoader object at 0x000001485C9C6848>
        所以需要一个模型来对数据进行解码,这个步骤是编写训练模型的一个标准流程
        """
        super(Model, self).__init__()
        pass

    def forward(self, data_dict):
        print(data_dict)


def dataset_collate(batch):
	# dataloader加载时先是一个batch的数据保存到一个列表中,之后拼接到一起处理成一个整体的批次,但是由于不同图片可能存在不同多个目标边界框,所以需要自己编写拼接规则。
	pass

if __name__ == '__main__':
    model = Model()
    dataset = TestDataset()
    dataloader = DataLoader(dataset, batch_size=4, collate_fn=collate_fn(可选参数))
    for data_dict in dataloader:
        model(data_dict)

出力結果:

{
    
    'image_name': ['1', '2', '3', '4'], 'image': tensor([1, 2, 3, 4]), 'shape': tensor([[  6,   0,   3],
        [ 13,  -2,   3],
        [ 32,  -6,   3],
        [ 69, -12,   3]], dtype=torch.int32)}
{
    
    'image_name': ['5', '6', '7', '8'], 'image': tensor([5, 6, 7, 8]), 'shape': tensor([[130, -20,   3],
        [221, -30,   3],
        [348, -42,   3],
        [517, -56,   3]], dtype=torch.int32)}
{
    
    'image_name': ['9'], 'image': tensor([9]), 'shape': tensor([[734, -72,   3]], dtype=torch.int32)}

出力結果から、戻り値が辞書の場合、DataLoader はキーに従ってバッチのデータをリストに格納し、すべてのデータ型が Tensor に変換されることがわかります。

おすすめ

転載: blog.csdn.net/weixin_50727642/article/details/128248900