代码地址:GitHub - OliverRensu/SG-Former
源码中使用的数据集为ImageNet-1K,数据量比较大,完整的数据集容量超过100G,光是下载这个数据集就得花不少时间。为了快速玩转SG-Former,可以换成自定义的较小的数据集看看效果。
首先按照大多数分类数据集的制作方式,以文件夹名作为类别名,文件夹的结构和imagenet-1k的一致:
│imagenet/
├──train/
│ ├── n01440764
│ │ ├── n01440764_10026.JPEG
│ │ ├── n01440764_10027.JPEG
│ │ ├── ......
│ ├── ......
├──val/
│ ├── n01440764
│ │ ├── ILSVRC2012_val_00000293.JPEG
│ │ ├── ILSVRC2012_val_00002138.JPEG
│ │ ├── ......
│ ├── ......
自定义数据集结构:
│datasets/
├──train/
│ ├── sunflowers
│ │ ├── 1.JPEG
│ │ ├── 2.JPEG
│ │ ├── ......
│ ├── ......
├──val/
│ ├── sunflowers
│ │ ├── 1.JPEG
│ │ ├── 2.JPEG
│ │ ├── ......
│ ├── ......
直接在源码中labeled_memcached_dataset.py的McDataset()基础上进行修改:
import os
import random
from glob import glob
from PIL import Image
from torch.utils.data import Dataset
def load_img(filepath):
    """Open the image at *filepath* and return it as an RGB ``PIL.Image``.

    Args:
        filepath: Path of the image file to read.

    Returns:
        A new image object in mode ``'RGB'``.

    Raises:
        Whatever ``Image.open`` / ``Image.convert`` raise for a missing,
        unreadable or corrupt file.
    """
    # ``Image.open`` is lazy; the context manager guarantees the underlying
    # file handle is released even when decoding fails half-way (e.g. on a
    # truncated file), instead of waiting for garbage collection.
    with Image.open(filepath) as img:
        # ``convert`` returns a separate image, so it stays valid after the
        # source image is closed.
        return img.convert('RGB')
class McDataset(Dataset):
    """Folder-per-class image dataset.

    Expects the layout ``<data_root>/<phase>/<class_name>/<image file>``.
    Each sub-directory of ``<data_root>/<phase>`` is one class; its integer
    label is its position in the alphabetically sorted list of class
    directory names, so the mapping is deterministic and identical across
    splits that contain the same class folders.

    Attributes:
        transform: Optional callable applied to every loaded image.
        root: ``os.path.join(data_root, phase)``.
        labels: Dict mapping class directory name -> integer label.
        A_paths: List of image file paths.
        A_labels: List of integer labels, parallel to ``A_paths``.
        num / A_size: Number of samples (both kept for compatibility).
    """

    # Accepted image extensions; compared case-insensitively.
    IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp')
    # Maximum number of attempts in __getitem__ before giving up.
    MAX_RETRIES = 10

    def __init__(self, data_root, file_list, phase='train', transform=None):
        """Scan ``<data_root>/<phase>`` and index every image with its label.

        Args:
            data_root: Directory that contains the split folders.
            file_list: Unused here; kept so the signature stays unchanged.
            phase: Name of the split sub-directory (default ``'train'``).
            transform: Optional callable applied to each loaded image.
        """
        self.transform = transform
        self.root = os.path.join(data_root, phase)

        # Only directories are classes; sorting makes label ids reproducible
        # (os.listdir returns entries in arbitrary order).
        class_names = sorted(
            entry for entry in os.listdir(self.root)
            if os.path.isdir(os.path.join(self.root, entry))
        )
        self.labels = {name: idx for idx, name in enumerate(class_names)}

        self.A_paths = []
        self.A_labels = []
        for name in class_names:
            class_dir = os.path.join(self.root, name)
            label = self.labels[name]
            for filename in sorted(os.listdir(class_dir)):
                # Skip hidden files, as the glob pattern "*.<ext>" would.
                if filename.startswith('.'):
                    continue
                # Case-insensitive match: each file is listed exactly once
                # regardless of the filesystem's case sensitivity.
                if not filename.lower().endswith(self.IMG_EXTENSIONS):
                    continue
                # os.path.join keeps the path valid on every platform.
                self.A_paths.append(os.path.join(class_dir, filename))
                self.A_labels.append(label)

        self.num = len(self.A_paths)
        self.A_size = self.num

    def __len__(self):
        """Return the number of indexed images."""
        return self.num

    def __getitem__(self, index):
        """Return ``(image, label)`` for *index*.

        If loading fails, a random other sample is tried instead, up to
        ``MAX_RETRIES`` attempts in total.

        Raises:
            IndexError: If the dataset contains no images.
            RuntimeError: If every attempt failed; the last underlying
                error is attached as ``__cause__``.
        """
        if self.num == 0:
            raise IndexError('McDataset is empty: no images found in %s' % self.root)

        last_error = None
        for _ in range(self.MAX_RETRIES):
            try:
                return self.load_img(index)
            except Exception as err:  # best-effort: skip unreadable samples
                last_error = err
                index = random.randint(0, self.num - 1)
        raise RuntimeError(
            'failed to load a sample after %d attempts' % self.MAX_RETRIES
        ) from last_error

    def load_img(self, index):
        """Load the image at ``index % A_size``, apply ``transform`` if set,
        and return it together with its integer label."""
        position = index % self.A_size
        A = load_img(self.A_paths[position])
        if self.transform is not None:
            A = self.transform(A)
        A_label = self.A_labels[position]
        return A, A_label
同时在main.py中将--data参数设置为数据集的路径./datasets