SG-Former in Practice: Training on a Custom Classification Dataset

Code: GitHub - OliverRensu/SG-Former

Paper: https://arxiv.org/pdf/2308.12216.pdf

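The source code trains on ImageNet-1K, which is large: the full dataset is over 100 GB, and downloading it alone takes considerable time. To get hands-on with SG-Former, you can swap in a smaller custom dataset and see how it performs.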

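First, organize the data the way most classification datasets are organized: each folder name is a class name, and the directory structure matches that of ImageNet-1K: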

│imagenet/
├──train/
│  ├── n01440764
│  │   ├── n01440764_10026.JPEG
│  │   ├── n01440764_10027.JPEG
│  │   ├── ......
│  ├── ......
├──val/
│  ├── n01440764
│  │   ├── ILSVRC2012_val_00000293.JPEG
│  │   ├── ILSVRC2012_val_00002138.JPEG
│  │   ├── ......
│  ├── ......

Custom dataset structure:

│datasets/
├──train/
│  ├── sunflowers
│  │   ├── 1.JPEG
│  │   ├── 2.JPEG
│  │   ├── ......
│  ├── ......
├──val/
│  ├── sunflowers
│  │   ├── 1.JPEG
│  │   ├── 2.JPEG
│  │   ├── ......
│  ├── ......
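If your images currently sit in one folder per class with no train/val split, a minimal sketch like the following can produce the layout above. It uses only the Python standard library; the function name split_dataset, the 20% validation ratio, the fixed seed, and the extension list are illustrative assumptions, not part of SG-Former.

```python
import os
import random
import shutil

# Extensions treated as images, compared case-insensitively (assumed list; adjust as needed)
IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp')


def split_dataset(src_root, dst_root, val_ratio=0.2, seed=0):
    """Hypothetical helper: copy src_root/<class>/<img> into dst_root/{train,val}/<class>/<img>.

    Returns a dict mapping class name -> (num_train, num_val).
    """
    rng = random.Random(seed)  # fixed seed so the split is reproducible
    counts = {}
    for cls in sorted(os.listdir(src_root)):
        cls_dir = os.path.join(src_root, cls)
        if not os.path.isdir(cls_dir):
            continue  # skip stray files at the top level
        files = sorted(f for f in os.listdir(cls_dir)
                       if f.lower().endswith(IMG_EXTENSIONS))
        rng.shuffle(files)
        n_val = int(len(files) * val_ratio)
        if len(files) > 1:
            n_val = max(1, n_val)  # keep at least one val image when the class has 2+ images
        splits = {'val': files[:n_val], 'train': files[n_val:]}
        for phase, names in splits.items():
            out_dir = os.path.join(dst_root, phase, cls)
            os.makedirs(out_dir, exist_ok=True)
            for name in names:
                shutil.copy2(os.path.join(cls_dir, name), os.path.join(out_dir, name))
        counts[cls] = (len(splits['train']), len(splits['val']))
    return counts
```

For example, split_dataset('raw_images', './datasets') would copy files from a hypothetical raw_images/<class>/ folder into ./datasets/train and ./datasets/val and return per-class (train, val) counts. Copying rather than moving leaves the original images untouched.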

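Next, modify McDataset() in the source file labeled_memcached_dataset.py. The version below scans data_root/<phase>/<class>/ directly to build the list of image paths and labels: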

import os
import random

from PIL import Image
from torch.utils.data import Dataset

# Image extensions to collect, compared case-insensitively so that
# e.g. .JPEG/.jpeg are both found on Linux and not counted twice on Windows
IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp')


def load_img(filepath):
    # Open the image and force 3-channel RGB
    img = Image.open(filepath).convert('RGB')
    return img


class McDataset(Dataset):
    def __init__(self, data_root, file_list, phase='train', transform=None):
        # file_list is unused here; it is kept to match the original McDataset signature
        self.transform = transform
        self.root = os.path.join(data_root, phase)
        # Sort the class folders so train and val assign the same index to each class
        # (os.listdir returns entries in arbitrary order)
        class_name = sorted(d for d in os.listdir(self.root)
                            if os.path.isdir(os.path.join(self.root, d)))
        self.labels = {name: i for i, name in enumerate(class_name)}

        self.A_paths = []
        self.A_labels = []

        # os.path.join instead of hard-coded "\\" so the paths work on Linux and Windows
        for clas_name in class_name:
            class_dir = os.path.join(self.root, clas_name)
            for fname in sorted(os.listdir(class_dir)):
                if fname.lower().endswith(IMG_EXTENSIONS):
                    self.A_paths.append(os.path.join(class_dir, fname))
                    self.A_labels.append(self.labels[clas_name])

        self.num = len(self.A_paths)
        self.A_size = len(self.A_paths)

    def __len__(self):
        return self.num

    def __getitem__(self, index):
        try:
            return self.load_img(index)
        except Exception:
            # Unreadable or corrupt image: fall back to a random sample
            return self.__getitem__(random.randint(0, self.__len__() - 1))

    def load_img(self, index):
        A_path = self.A_paths[index % self.A_size]
        A = load_img(A_path)  # module-level helper defined above
        if self.transform is not None:
            A = self.transform(A)
        A_label = self.A_labels[index % self.A_size]
        return A, A_label
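McDataset derives labels from the class folder names, so train/ and val/ should contain exactly the same class folders; if one split is missing a class, the same index can refer to different classes in training and validation. A minimal standard-library sketch that checks this before training, assuming labels are assigned in sorted folder-name order (the function names class_index and check_splits are hypothetical):

```python
import os

IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp')  # assumed extension list


def class_index(split_dir):
    """Class name -> index, assuming labels follow sorted folder-name order."""
    classes = sorted(d for d in os.listdir(split_dir)
                     if os.path.isdir(os.path.join(split_dir, d)))
    return {name: i for i, name in enumerate(classes)}


def check_splits(data_root):
    """Return (train_class_to_idx, problems) for data_root/train and data_root/val."""
    train = class_index(os.path.join(data_root, 'train'))
    val = class_index(os.path.join(data_root, 'val'))
    problems = []
    if train != val:
        # A class present in only one split shifts the indices of later classes
        problems.append('class folders differ: only in train=%s, only in val=%s' % (
            sorted(set(train) - set(val)), sorted(set(val) - set(train))))
    for phase, mapping in (('train', train), ('val', val)):
        for name in mapping:
            folder = os.path.join(data_root, phase, name)
            n = sum(f.lower().endswith(IMG_EXTENSIONS) for f in os.listdir(folder))
            if n == 0:
                problems.append('%s/%s has no images' % (phase, name))
    return train, problems
```

Running check_splits('./datasets') and confirming that the problem list is empty is a cheap way to catch a misplaced or empty folder before starting a long training run.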

Finally, in main.py, set --data to the dataset path ./datasets.


Reposted from blog.csdn.net/athrunsunny/article/details/133632982