Pytorch自己定义Dataloader加载高光谱数据集

Pytorch自己定义Dataloader加载高光谱数据集


为了方便进行对比,这里简单说一下 图像分类中的数据集加载

ImageFolder是pytorch框架已经编好的加载数据集方法,可以直接拿来用。
但是如果我们的数据集为深度图、高光谱遥感图....这类非传统RGB图,我们就需要定义自己的加载数据集方法。
#准备好训练集
train_dateset = ImageFolder(image_path + '/train', transform=data_transform["train"])  

train_loader = DataLoader(train_dateset, batch_size=batch_size, shuffle=True, num_workers=0)  
)

1、 加载高光谱数据集

注意这里的dataset是模块,dataset.Dataset才是类,加载数据集需要继承Dataset这个类

首先的生成数据集只会调用 __init____len__方法。

一般就是__init__方法中传入图像和GT,以及预处理方法(将数据转换成tensor格式等等...)
import scipy.io as sio
import torch
from torch.utils.data import dataset
class my_dataset(dataset.Dataset):
    def __init__(self, root, transform=None, target_transform=None):
        super(my_dataset, self).__init__()
        #预处理
        self.transform = transform
        self.target_transform = target_transform
        #加载Samson数据集的mat文件(字典文件)
        hsi_data = sio.loadmat(root)
        #原始HSI图像[156, 9025]
        training_data = hsi_data['Y']
        #丰度标签[3, 9025]
        labels = hsi_data['A']

        self.train_data = torch.reshape(torch.from_numpy(training_data), (156, 95, 95))
        self.labels = torch.reshape(torch.from_numpy(labels), (3, 95, 95))

    def __getitem__(self, index):
        img, target = self.train_data[index], self.labels[index]

        if self.transform is not None:
            img = torch.tensor(img)
        if self.target_transform is not None:
            target = torch.tensor(target)
        return img, target

    def __len__(self):
        return len(self.train_data)
        # return 1

但是__len__方法有一些讲究。比如我这里之间将原始mat文件里的图像和GT读出来,全部reshape成[channel, width, height]。这个时候返回的train_dataset 的数量取决于__len__方法。
这里以Samson数据集为例,如果想要将整个[156, 95, 95]的图片当成一个图片进行后续的训练,那么需要在__len__方法直接return1。
在这里插入图片描述
但是如果我们需要比如逐像素进行输入,就不能够直接return1。每个像素可以看作是一个1-D的向量,把每一个像素当成一个输入数据的话,是不是就是原始图像中[156, 9025]进行转置–>[9025, 156],逐个去取9025的每一个?

self.train_data = training_data.T  #[9025, 156]
self.labels = labels.T  #[9025, 3]

__len__return len(self.train_data)
在这里插入图片描述
可以看到train_dataset是一个9025个像素组成的数据集。
label是9025个(3,), img是9025个(156,),相当于取每一个像素作为一个训练数据。后面根据batch_size的大小划分这9025个数据就行。
在这里插入图片描述
在这里插入图片描述

// 接着就是my_dataset类实例化对象,train_dataset 就是我们自己高光谱图像的数据集了。
train_dataset = my_dataset(root="samson_dataset.mat")

2、 送进Dataloder加载

DataLoader按照 batch_size=20将train_dataset 划分。比如这里如果逐像素进行输入,一共9025个1-D向量数据,每一个batch训练20个数据,一轮epoch要训练完,需要迭代9025/20=452次。

source_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=20)

开始迭代训练

for i, (x, y) in enumerate(source_dataloader):
    output = net(x).to(device)
    n += 1
print(n)

这里按照batch_size开始迭代,就会访问上面创建数据集中的__getitem__方法了。index其实就是逐个索引访问每个像素,每batch_size=20个为一次。
所以上面的代码中的n=9025/20=4025.

def __getitem__(self, index):
    img, target = self.train_data[index], self.labels[index]
     if self.transform is not None:
         img = torch.tensor(img)
     if self.target_transform is not None:
         target = torch.tensor(target)
     return img, target
// 高光谱,我目前涉及的是无监督的解混方法,这里batch_size就直接设为1了。
将数据集送进Dataloder,加载我们的数据集。
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=1, shuffle=False
)

循环遍历train_loader,这里因为输入为1个156x95x95的Samson数据集,idx只有一个0的索引,返回图像和GT。
在这里插入图片描述
i遍历train_loader,是一个包含两个元素的元组,i[0]就是原始输入HSI了,第一个维度是batch,i[1]就是GT啦。
在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/weixin_50557558/article/details/130337086