pytorch加载自己的数据集,通过读取train.txt、test.txt文件(附数据集txt生成完整代码,注释详细)

pytorch读取指定train.txt、test.txt文件加载自己的数据集

txt生成脚本完整代码如下:

读取自己的数据集,打乱并划分,生成train.txt、test.txt (每一行为图片的绝对路径+标签,完整代码,注释详细)
https://blog.csdn.net/weixin_44414948/article/details/110205546

train.txt、test.txt示例如下图所示:

先占坑,回去补图。文字说明:每一行的格式为"图片绝对路径 标签",路径与整数标签之间用空格分隔,例如:C:\Users\hq\Desktop\HoldingObject\0001.jpg 0

pytorch数据集加载完整代码:

import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image

# Dataset root directory; train.txt and test.txt are read from this folder below.
root = r"C:\Users\hq\Desktop\HoldingObject"

def MyLoader(path):
    """Load the image at ``path`` and return it converted to RGB.

    This is the default image reader used by the dataset; extra steps such as
    resizing or data augmentation can be added here.

    The source image is opened in a ``with`` block so its file handle is
    released deterministically instead of whenever the temporary image object
    happens to be garbage-collected. ``convert`` returns a new image, so the
    result stays usable after the source is closed.

    Args:
        path: Path of the image file to read.

    Returns:
        A PIL image in ``'RGB'`` mode.
    """
    with Image.open(path) as img:
        return img.convert('RGB')
    
class MyDataset(Dataset):
    """Image dataset described by a text file with one sample per line.

    Every non-blank line of the list file must contain an image path followed
    by whitespace and an integer label, e.g. ``/data/img 01.jpg 3``. The label
    is the last whitespace-separated field, so the path itself may contain
    spaces.

    Args:
        txt: Path of the list file.
        transform: Optional callable applied to the loaded image.
        target_transform: Optional callable applied to the integer label.
        loader: Callable that takes an image path and returns the image.

    Raises:
        ValueError: If a non-blank line is not ``<image path> <int label>``.
    """

    def __init__(self, txt, transform=None, target_transform=None, loader=MyLoader):
        imgs = []
        with open(txt, 'r') as fh:
            for line_no, raw in enumerate(fh, start=1):
                line = raw.strip()
                if not line:
                    # Blank lines (commonly a trailing one) carry no sample.
                    continue
                try:
                    # Split from the right so paths containing spaces survive.
                    path, label = line.rsplit(maxsplit=1)
                    imgs.append((path, int(label)))
                except ValueError as err:
                    raise ValueError(
                        f"{txt}:{line_no}: expected '<image path> <int label>', "
                        f"got {line!r}"
                    ) from err
        # List of (image path, integer label) pairs, in file order.
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair for sample ``index``.

        The image is read lazily through ``self.loader``; ``transform`` and
        ``target_transform`` are applied when they are not ``None``.
        """
        fn, label = self.imgs[index]
        img = self.loader(fn)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return img, label

    def __len__(self):
        """Return the number of samples listed in the file."""
        return len(self.imgs)


# Build the training and test datasets from the list files located directly
# under `root` (joined with a literal backslash); ToTensor is passed as the
# image transform for both.
train_data = MyDataset(
    txt='\\'.join([root, 'train.txt']),
    transform=transforms.ToTensor(),
)
test_data = MyDataset(
    txt='\\'.join([root, 'test.txt']),
    transform=transforms.ToTensor(),
)

# train_data and test_data hold all training and test samples; wrap them in
# DataLoaders to fetch them in batches of 64. Only the training loader shuffles.
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(test_data, batch_size=64)

print('加载成功!')

注:MyDataset 的 transform 参数支持 pytorch(torchvision)的 transforms 操作,例如上面代码中使用的 transforms.ToTensor()。

猜你喜欢

转载自blog.csdn.net/weixin_44414948/article/details/110206088