Few-shot classification tips: loading the dataset

Overview

Few-shot classification is a subfield of machine learning and artificial intelligence that addresses the problem of learning to classify new samples when training data is very limited. In traditional supervised learning, a model is trained on a dataset containing a large number of labeled samples, with abundant labels for every category. In practical applications, however, obtaining such large amounts of labeled data can be difficult or expensive.
At present there are very few introductory materials on few-shot learning online. I was not very clear about episodes before (an episode is a small N-way K-shot classification task sampled from the dataset: N classes, K labeled support images per class, plus a few query images to classify); after reading some documentation and code, I gradually understood how few-shot models are trained. This post summarizes the dataset-loading part first, hoping it provides some inspiration to readers.

Steps

1. Modify the file structure

- data_name
--- images
----- folder_name1
------- img1.png
------- img2.png
----- folder_name2
--- meta
----- classes.txt  
----- fsl_train.txt
----- fsl_test.txt
----- fsl_train_class.txt
----- fsl_test_class.txt

Here folder_name1 and folder_name2 are the names of the image folders, usually class names; in some datasets they are instead class indices (e.g. numbers from 1 to 100).
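
To check that an existing dataset already follows this layout, a minimal sketch like the following can be used (the dataset root path is a placeholder):

import os

def check_layout(data_root):
    # print each class folder under data_root/images together with its image count
    images_dir = os.path.join(data_root, "images")
    meta_dir = os.path.join(data_root, "meta")
    assert os.path.isdir(images_dir), f"missing {images_dir}"
    assert os.path.isdir(meta_dir), f"missing {meta_dir}"
    for folder_name in sorted(os.listdir(images_dir)):
        folder = os.path.join(images_dir, folder_name)
        if os.path.isdir(folder):
            print(folder_name, len(os.listdir(folder)))

check_layout("/path/to/data_name")  # hypothetical dataset root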

2. Find the image label file classes.txt

classes.txt contains all the image categories. If the dataset does not provide it, you need to build one yourself; a minimal sketch for generating it from the folder names is shown after the example. The content of the label file is roughly as follows:

class_name1
class_name2
class_name3
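
If classes.txt is missing, a minimal sketch like the following (assuming the directory layout from step 1; the dataset root path is a placeholder) can generate it from the folder names:

import os

data_root = "/path/to/data_name"   # hypothetical dataset root
images_dir = os.path.join(data_root, "images")
class_path = os.path.join(data_root, "meta", "classes.txt")

# write one class (folder) name per line, sorted for a stable order
with open(class_path, "w") as f:
    for folder_name in sorted(os.listdir(images_dir)):
        if os.path.isdir(os.path.join(images_dir, folder_name)):
            f.write(f"{folder_name}\n")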

3. Use code to generate the split files

The generated files are: fsl_train.txt, fsl_test.txt, fsl_train_class.txt and fsl_test_class.txt.
The code currently supports the following situations (a usage sketch follows the code below):

  • (1) folder_name is the class name
  • (2) folder_name is the index corresponding to the class name, starting from 1
  • (3) the image names under the folder_name folder are all numbers, with no other symbols
  • (4) the image names under the folder_name folder contain arbitrary symbols

The generated files look roughly like this. fsl_train.txt lists one image per line in the form class_name/img_name:

class_name1/img1.png
class_name1/img2.png
class_name2/img3.png

fsl_train_class.txt lists one training class per line (class names, or indices when the folders are numbered):

class_name1
class_name2

The code is:

import os
import random

from util.tools import list_from_file  # the author's helper: reads a text file into a list of lines


def make_file(img_root_path, names, path, is_num):
    """
    :param img_root_path: image root folder (data_root/images)
    :param names: folder names (class names or class indices) to include
    :param path: output file path
    :param is_num: whether the image file names are pure numbers
    """
    with open(path, "w") as f:
        for name in names:
            img_dir = os.path.join(img_root_path, str(name))
            img_names = os.listdir(img_dir)
            if is_num:
                # sort numerically, e.g. 2.png before 10.png
                sort_img_names = sorted(img_names, key=lambda s: int(s.split('.')[0]))
            else:
                sort_img_names = sorted(img_names)
            for img_name in sort_img_names:
                # keep only "folder_name/img_name" relative to the image root
                img_path = os.path.join(img_dir, img_name).replace(img_root_path + "/", "")
                f.write(f"{img_path}\n")


def generate_split_dataset(data_root, train_num, is_imgs_id, is_img_name_num):
    """
    :param data_root: dataset root directory
    :param train_num: number of classes used for training
    :param is_imgs_id: whether the image folder names are class indices
    :param is_img_name_num: whether the image file names are pure numbers
    :return: None
    """
    class_path = os.path.join(data_root, "meta", "classes.txt")
    class_list = list_from_file(class_path)
    # class ids start from 1; change this if your dataset starts from 0
    id2class = {i + 1: _class for i, _class in enumerate(class_list)}
    # class2id = {_class: i + 1 for i, _class in enumerate(class_list)}
    # pick train_num classes for training, the rest for testing
    train_class_ids = random.sample(range(1, len(class_list) + 1), train_num)
    test_class_ids = []
    for id in range(1, len(class_list) + 1):
        if id not in train_class_ids:
            test_class_ids.append(id)
    # get the folder names under images/
    if is_imgs_id:
        train_class_name = train_class_ids
        test_class_name = test_class_ids
    else:
        train_class_name = [id2class[id] for id in train_class_ids]
        test_class_name = [id2class[id] for id in test_class_ids]
    # sort them
    train_class_name = sorted(train_class_name)
    test_class_name = sorted(test_class_name)
    train_class_save_path = os.path.join(data_root, "meta", "fsl_train_class.txt")
    test_class_save_path = os.path.join(data_root, "meta", "fsl_test_class.txt")
    with open(train_class_save_path, "w") as f:
        for cls_name in train_class_name:
            f.write(f"{cls_name}\n")

    with open(test_class_save_path, "w") as f:
        for cls_name in test_class_name:
            f.write(f"{cls_name}\n")

    # save the image lists in fsl_train.txt / fsl_test.txt, one "class_name/img_name" per line
    img_root_path = os.path.join(data_root, "images")
    train_imgs_name_path = os.path.join(data_root, "meta", "fsl_train.txt")
    test_imgs_name_path = os.path.join(data_root, "meta", "fsl_test.txt")
    make_file(img_root_path, train_class_name, train_imgs_name_path, is_img_name_num)
    make_file(img_root_path, test_class_name, test_imgs_name_path, is_img_name_num)
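
As a usage reference for the four cases listed above, a minimal sketch (the dataset root path and train_num=120 are placeholders):

# cases (1)/(4): folders are class names, image names may contain arbitrary symbols
generate_split_dataset(data_root="/path/to/data_name", train_num=120,
                       is_imgs_id=False, is_img_name_num=False)

# cases (2)/(3): folders are 1-based class indices, image names are pure numbers
generate_split_dataset(data_root="/path/to/data_name", train_num=120,
                       is_imgs_id=True, is_img_name_num=True)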

4. Build the base dataset class BaseFewShotDataset

BaseFewShotDataset is the base dataset class: it reads the class-name file, builds the class-to-index mapping, and provides image loading and per-class shot sampling for the concrete dataset classes below. The code is:

import copy
from abc import ABCMeta, abstractmethod
from typing import Dict, List, Mapping, Optional, Sequence, Union
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import numpy as np
import os.path as osp
from PIL import Image
import torch

from util import tools

from mmpretrain.evaluation import Accuracy

class BaseFewShotDataset(Dataset, metaclass=ABCMeta):
    def __init__(self,
                 pipeline,
                 data_prefix: str,
                 classes: Optional[Union[str, List[str]]] = None,
                 ann_file: Optional[str] = None) -> None:
        super().__init__()

        self.ann_file = ann_file
        self.data_prefix = data_prefix
        self.pipeline = pipeline
        self.CLASSES = self.get_classes(classes)
        self.data_infos = self.load_annotations()
        self.data_infos_class_dict = {i: [] for i in range(len(self.CLASSES))}
        for idx, data_info in enumerate(self.data_infos):
            self.data_infos_class_dict[data_info['gt_label'].item()].append(
                idx)

    def load_image_from_file(self,info_dict):
        img_prefix = info_dict['img_prefix']
        img_name = info_dict['img_info']['filename']
        img_file = osp.join(img_prefix, img_name)
        img_data = Image.open(img_file).convert('RGB')
        return img_data

    @abstractmethod
    def load_annotations(self):
        pass

    @property
    def class_to_idx(self) -> Mapping:
        return {_class: i for i, _class in enumerate(self.CLASSES)}

    def prepare_data(self, idx: int) -> Dict:
        results = copy.deepcopy(self.data_infos[idx])
        imgs_data = self.load_image_from_file(results)
        data = {
            "img": self.pipeline(imgs_data),
            "gt_label": torch.tensor(self.data_infos[idx]['gt_label'])
        }
        return data

    def sample_shots_by_class_id(self, class_id: int,
                                 num_shots: int) -> List[int]:
        all_shot_ids = self.data_infos_class_dict[class_id]
        return np.random.choice(
            all_shot_ids, num_shots, replace=False).tolist()

    def __len__(self) -> int:
        return len(self.data_infos)

    def __getitem__(self, idx: int) -> Dict:
        return self.prepare_data(idx)

    @classmethod
    def get_classes(cls,
                    classes: Union[Sequence[str],
                                   str] = None) -> Sequence[str]:
        if isinstance(classes, str):
            class_names = tools.list_from_file(classes)
        elif isinstance(classes, (tuple, list)):
            class_names = classes
        else:
            raise ValueError(f'Unsupported type {type(classes)} of classes.')

        return class_names
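
The code above (and the episodic dataset in step 6) relies on a small util.tools module that the original post does not show; the following is only a minimal sketch of what these helpers might look like, not the author's actual implementation:

# util/tools.py -- assumed minimal stand-ins for the helpers referenced in this post
import contextlib
import numpy as np

def list_from_file(path):
    # read a text file and return its non-empty lines as a list of strings
    with open(path, "r", encoding="utf-8") as f:
        return [line.strip() for line in f if line.strip()]

@contextlib.contextmanager
def local_numpy_seed(seed=None):
    # temporarily fix the numpy random seed (used by EpisodicDataset in step 6),
    # restoring the previous random state afterwards
    state = np.random.get_state()
    try:
        if seed is not None:
            np.random.seed(seed)
        yield
    finally:
        np.random.set_state(state)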

5. Build a universal few-shot dataset loading class UniversalFewShotDataset

This class reads the annotation file generated in step 3, maps each image to its class index, and applies the training/testing preprocessing pipelines. The code is as follows:

from datasets.base import BaseFewShotDataset
from typing_extensions import Literal
from typing import List, Optional, Sequence, Union
from util import tools
import os
import os.path as osp
import numpy as np
import torchvision.transforms as transforms
class UniversalFewShotDataset(BaseFewShotDataset):
    def __init__(self,
                 img_size,
                 subset: Literal['train', 'test', 'val'] = 'train',
                 *args,
                 **kwargs):
        if isinstance(subset, str):
            subset = [subset]
        for subset_ in subset:
            assert subset_ in ['train', 'test', 'val']
        self.subset = subset
        # normalization parameters (ImageNet statistics)
        norm_params = {'mean': [0.485, 0.456, 0.406],
                       'std': [0.229, 0.224, 0.225]}
        # build the preprocessing pipeline
        if subset[0] == 'train':
            pipeline = transforms.Compose([
                transforms.RandomResizedCrop(img_size),
                transforms.RandomHorizontalFlip(),
                transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
                transforms.ToTensor(),
                transforms.Normalize(**norm_params)
            ])
        else:  # 'test' and 'val' share the evaluation pipeline
            pipeline = transforms.Compose([
                transforms.Resize(size=int(img_size * 1.15)),
                transforms.CenterCrop(size=img_size),
                transforms.ToTensor(),
                transforms.Normalize(**norm_params)
            ])
        super().__init__(pipeline=pipeline, *args, **kwargs)

    def get_classes(
            self,
            classes: Optional[Union[Sequence[str], str]] = None) -> Sequence[str]:
        class_names = tools.list_from_file(classes)
        return class_names
    
    # parse the annotation file: each line has the form class_name/img_name
    def load_annotations(self) -> List:
        data_infos = []
        ann_file = self.ann_file
        with open(ann_file) as f:
            for i, line in enumerate(f):
                class_name, filename = line.strip().split('/')
                gt_label = self.class_to_idx[class_name]
                info = {
                    'img_prefix': osp.join(self.data_prefix, 'images', class_name),
                    'img_info': {'filename': filename},
                    'gt_label': np.array(gt_label, dtype=np.int64)
                }
                data_infos.append(info)
        return data_infos
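
A quick standalone check of the class, using the meta files generated in step 3 (the paths are placeholders):

dataset = UniversalFewShotDataset(
    img_size=224,
    subset="train",
    data_prefix="/path/to/data_name",
    classes="/path/to/data_name/meta/fsl_train_class.txt",
    ann_file="/path/to/data_name/meta/fsl_train.txt")

print(len(dataset))           # total number of training images
sample = dataset[0]
print(sample["img"].shape)    # torch.Size([3, 224, 224])
print(sample["gt_label"])     # index of the class within fsl_train_class.txt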

6. Construct a dataset loading class EpisodicDataset for meta-learning

The code is as follows:

import numpy as np
from torch import Tensor
from torch.utils.data import Dataset,DataLoader
from functools import partial
import os.path as osp
from typing import Mapping
from util import tools
import json
class EpisodicDataset:
    def __init__(self,
                 dataset: Dataset,
                 num_episodes: int,
                 num_ways: int,
                 num_shots: int,
                 num_queries: int,
                 episodes_seed: int):
        self.dataset = dataset
        self.num_ways = num_ways
        self.num_shots = num_shots
        self.num_queries = num_queries
        self.num_episodes = num_episodes
        self._len = len(self.dataset)
        self.CLASSES = dataset.CLASSES
        self.episodes_seed = episodes_seed
        self.episode_idxes, self.episode_class_ids = \
            self.generate_episodic_idxes()

    def generate_episodic_idxes(self):
        """Generate batch indices for each episodic."""
        episode_idxes, episode_class_ids = [], []
        class_ids = [i for i in range(len(self.CLASSES))]
        # this context manager is optional; it only makes episode sampling reproducible
        with tools.local_numpy_seed(self.episodes_seed):
            for _ in range(self.num_episodes):
                np.random.shuffle(class_ids)
                # sample classes
                sampled_cls = class_ids[:self.num_ways]
                episode_class_ids.append(sampled_cls)
                episodic_support_idx = []
                episodic_query_idx = []
                # sample instances of each class
                for i in range(self.num_ways):
                    shots = self.dataset.sample_shots_by_class_id(
                        sampled_cls[i], self.num_shots + self.num_queries)
                    episodic_support_idx += shots[:self.num_shots]
                    episodic_query_idx += shots[self.num_shots:]
                episode_idxes.append({
                    'support': episodic_support_idx,
                    'query': episodic_query_idx
                })
        return episode_idxes, episode_class_ids

    def __getitem__(self, idx: int):
        support_data = [self.dataset[i] for i in self.episode_idxes[idx]['support']]
        query_data = [self.dataset[i] for i in self.episode_idxes[idx]['query']]
        return {
            'support_data': support_data,
            'query_data': query_data
        }

    def __len__(self):
        return self.num_episodes

    def evaluate(self, *args, **kwargs):
        return self.dataset.evaluate(*args, **kwargs)

    def get_episode_class_ids(self, idx: int):
        return self.episode_class_ids[idx]
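
Each item of EpisodicDataset is one episode containing num_ways * num_shots support samples and num_ways * num_queries query samples. A small usage sketch with example values (5-way 5-shot here, not the author's configuration), wrapping the dataset built in step 5:

episodic = EpisodicDataset(dataset=dataset,   # the UniversalFewShotDataset from step 5
                           num_episodes=2000,
                           num_ways=5,
                           num_shots=5,
                           num_queries=5,
                           episodes_seed=1001)

episode = episodic[0]
print(len(episode['support_data']))       # 5 ways * 5 shots   = 25 samples
print(len(episode['query_data']))         # 5 ways * 5 queries = 25 samples
print(episodic.get_episode_class_ids(0))  # the 5 class ids sampled for episode 0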

7. Build your own configuration file, for example in JSON format

Besides JSON, the configuration file can also take other forms; here we use JSON as an example:

{
    "train": {
        "num_episodes": 2000,
        "num_ways": 10,
        "num_shots": 5,
        "num_queries": 5,
        "episodes_seed": 1001,
        "per_gpu_batch_size": 1,
        "per_gpu_workers": 8,
        "epoches": 160,
        "dataset": {
            "name": "vireo_172",
            "img_size": 224,
            "data_prefix": "/home/gaoxingyu/dataset/vireo-172/",
            "classes": "/home/gaoxingyu/dataset/vireo-172/meta/fsl_train_class.txt",
            "ann": "/home/gaoxingyu/dataset/vireo-172/meta/fsl_train.txt"
        }
    }
}
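
With the values above, each training episode samples num_ways * (num_shots + num_queries) images; a quick check against the config file (saved as config.json, as used in step 8):

import json

with open("config.json", "r", encoding="utf-8") as f:
    cfg = json.load(f)["train"]

# each episode draws num_ways * (num_shots + num_queries) images
per_episode = cfg["num_ways"] * (cfg["num_shots"] + cfg["num_queries"])
print(per_episode)   # 10 * (5 + 5) = 100 images per episode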

8. Write the main program and test it

The code is as follows:

with open("config.json", 'r', encoding='utf-8') as f:
     f = f.read()
     configs = json.loads(f)
     logger.info(f"Experiment Setting:{
      
      configs}")
# 创建数据集
## train_dataset
train_food_dataset = UniversalFewShotDataset(data_prefix=configs['train']['dataset']['data_prefix'],
                         subset="train", classes=configs['train']['dataset']['classes'],
                         img_size=configs['train']['dataset']['img_size'],ann_file=configs['train']['dataset']['ann'])
train_dataset = EpisodicDataset(dataset=train_food_dataset,
                                num_episodes=configs['train']['num_episodes'],
                                num_ways=configs['train']['num_ways'],
                                num_shots=configs['train']['num_shots'],
                                num_queries=configs['train']['num_queries'],
                                episodes_seed=configs['train']['episodes_seed'])
## train dataloader
train_samper = torch.utils.data.distributed.DistributedSampler(train_dataset, rank = local_rank, shuffle=True)
train_data_loader = DataLoader(
    dataset=train_dataset,
    batch_size=configs['train']['per_gpu_batch_size'],
    sampler=train_samper,
    num_workers=configs['train']['per_gpu_workers'],
    collate_fn=partial(collate, samples_per_gpu=1),
    worker_init_fn=worker_init_fn,
    drop_last=True
)
for data in train_data_loader:
	print(data)
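
Since collate and worker_init_fn belong to the author's training utilities, a plain DataLoader over the episodic dataset can be used for a quick single-GPU smoke test without distributed training (a sketch, not the author's setup):

from torch.utils.data import DataLoader

# each item of EpisodicDataset is already a complete episode, so batch_size=1
# with an identity collate keeps the nested list/dict structure intact
smoke_loader = DataLoader(train_dataset, batch_size=1,
                          collate_fn=lambda batch: batch[0], shuffle=True)

episode = next(iter(smoke_loader))
print(len(episode['support_data']), len(episode['query_data']))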


Origin blog.csdn.net/qq_41234663/article/details/132006130