SG-Former in practice: training on a custom classification dataset

Source code: GitHub - OliverRensu/SG-Former

Paper: https://arxiv.org/pdf/2308.12216.pdf

The source code trains on ImageNet-1K, which is large: the full dataset exceeds 100 GB, and just downloading it takes a long time. To experiment with SG-Former, it can be swapped for a smaller custom dataset so you can see results quickly.

First, following the convention used by most classification datasets, each folder name serves as the category name, and the folder structure mirrors ImageNet-1K:

│imagenet/
├──train/
│  ├── n01440764
│  │   ├── n01440764_10026.JPEG
│  │   ├── n01440764_10027.JPEG
│  │   ├── ......
│  ├── ......
├──val/
│  ├── n01440764
│  │   ├── ILSVRC2012_val_00000293.JPEG
│  │   ├── ILSVRC2012_val_00002138.JPEG
│  │   ├── ......
│  ├── ......

Custom dataset structure:

│datasets/
├──train/
│  ├── sunflowers
│  │   ├── 1.JPEG
│  │   ├── 2.JPEG
│  │   ├── ......
│  ├── ......
├──val/
│  ├── sunflowers
│  │   ├── 1.JPEG
│  │   ├── 2.JPEG
│  │   ├── ......
│  ├── ......
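If the images start out in one folder per class with no train/val split, a small script can generate the layout above. This is only a sketch; the flat `src_root/<class>/` source layout, the `split_dataset` name, and the 80/20 split are assumptions for illustration, not part of the SG-Former repo:

```python
import os
import random
import shutil

def split_dataset(src_root, dst_root, val_ratio=0.2, seed=0):
    """Copy images from src_root/<class>/ into dst_root/{train,val}/<class>/."""
    rng = random.Random(seed)  # fixed seed so the split is reproducible
    for class_name in sorted(os.listdir(src_root)):
        class_dir = os.path.join(src_root, class_name)
        if not os.path.isdir(class_dir):
            continue
        files = sorted(os.listdir(class_dir))
        rng.shuffle(files)
        # reserve at least one image per class for validation
        n_val = max(1, int(len(files) * val_ratio))
        for i, fname in enumerate(files):
            phase = 'val' if i < n_val else 'train'
            out_dir = os.path.join(dst_root, phase, class_name)
            os.makedirs(out_dir, exist_ok=True)
            shutil.copy(os.path.join(class_dir, fname),
                        os.path.join(out_dir, fname))
```

Running `split_dataset('all_images', 'datasets')` then produces the `datasets/train` and `datasets/val` folders shown above.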

Next, directly modify McDataset() in labeled_memcached_dataset.py in the source code:

import os
import random
from glob import glob

from PIL import Image
from torch.utils.data import Dataset


def load_img(filepath):
    img = Image.open(filepath).convert('RGB')
    return img


class McDataset(Dataset):
    def __init__(self, data_root, file_list, phase='train', transform=None):
        # file_list is unused here; it is kept so the signature matches the original code.
        self.transform = transform
        self.root = os.path.join(data_root, phase)

        # Folder names under train/ or val/ become the class names;
        # sort them so the label mapping is deterministic across runs.
        class_names = sorted(os.listdir(self.root))
        self.labels = {name: i for i, name in enumerate(class_names)}

        externs = ['png', 'jpg', 'JPEG', 'BMP', 'bmp']
        imgfiles = []
        for class_name in class_names:
            for extern in externs:
                # os.path.join keeps the separator portable (the original
                # hard-coded "\\", which only works on Windows).
                imgfiles.extend(glob(os.path.join(self.root, class_name, '*.' + extern)))

        self.A_paths = []
        self.A_labels = []
        for path in imgfiles:
            # The parent folder name of each image is its class.
            label = self.labels[os.path.basename(os.path.dirname(path))]
            self.A_paths.append(path)
            self.A_labels.append(label)

        self.A_size = len(self.A_paths)

    def __len__(self):
        return self.A_size

    def __getitem__(self, index):
        try:
            return self.load_img(index)
        except Exception:
            # Fall back to a random sample if an image is unreadable.
            return self.__getitem__(random.randint(0, self.__len__() - 1))

    def load_img(self, index):
        A_path = self.A_paths[index % self.A_size]
        A = load_img(A_path)
        if self.transform is not None:
            A = self.transform(A)
        A_label = self.A_labels[index % self.A_size]
        return A, A_label
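The folder-scanning and label-assignment logic inside McDataset can be sanity-checked on its own, without torch installed. The sketch below reproduces it as a standalone function; the name `index_images` and the lowercase extension list are assumptions for illustration:

```python
import os
from glob import glob

def index_images(root, exts=('png', 'jpg', 'jpeg', 'bmp')):
    """Map class-folder names to integer labels and collect (path, label) pairs."""
    # sorted() makes the class -> label mapping deterministic
    class_names = sorted(d for d in os.listdir(root)
                         if os.path.isdir(os.path.join(root, d)))
    labels = {name: i for i, name in enumerate(class_names)}
    samples = []
    for name in class_names:
        for ext in exts:
            for path in glob(os.path.join(root, name, '*.' + ext)):
                samples.append((path, labels[name]))
    return samples, labels
```

Calling `index_images('datasets/train')` should report one label per class folder and one sample per image file; if a class comes back empty, check the file extensions against the `exts` list.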

At the same time, set the --data argument in main.py to the dataset path, ./datasets.

Origin blog.csdn.net/athrunsunny/article/details/133632982