Práctica de código ReID basada en el aprendizaje de representación (1)

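data_manager.py: procesa automáticamente el conjunto de datos Market1501 y devuelve sus atributos más habituales: las listas de imágenes de entrenamiento, consulta y galería (ruta, identidad y cámara de cada imagen) y el número de identidades de cada subconjunto.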

from __future__ import print_function, absolute_import

import os
import os.path as osp
import re

import numpy as np
import glob
# from .utils import mkdir_if_missing, write_json, read_json
from IPython import embed


# embed()


class Market1501(object):
    """
    Market1501 dataset index.

    Scans the three sub-directories of the dataset (training boxes, query
    images, gallery boxes), parses the person id and camera id out of every
    ``*.jpg`` file name and exposes the result as lists of
    ``(img_path, pid, camid)`` tuples.

    Reference:
    Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015.
    URL: http://www.liangzheng.org/Project/project_reid.html

    Dataset statistics:
    # identities: 1501 (+1 for background)
    # images: 12936 (train) + 3368 (query) + 15913 (gallery)

    Args:
        root: directory that contains the ``market1501`` folder.
        **kwargs: accepted and ignored.

    Attributes:
        train, query, gallery: lists of ``(img_path, pid, camid)`` tuples.
            ``camid`` is zero-based; ``pid`` is remapped to a contiguous
            0..N-1 label for ``train`` only.
        num_train_pids, num_query_pids, num_gallery_pids: number of distinct
            person ids found in each subset.

    Raises:
        RuntimeError: if one of the expected directories is missing, or if a
            ``*.jpg`` file name does not follow the ``<pid>_c<camid>...``
            naming scheme.
    """
    dataset_dir = 'market1501'

    # File names look like "0002_c1s1_000451_03.jpg": "<pid>_c<camid>...".
    # Compiled once at class level instead of on every _process_dir call.
    _name_pattern = re.compile(r'([-\d]+)_c(\d)')

    def __init__(self, root='data', **kwargs):
        self.dataset_dir = osp.join(root, self.dataset_dir)
        self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train')
        self.query_dir = osp.join(self.dataset_dir, 'query')
        self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test')

        self._check_before_run()

        # What we want out of each directory: the file paths, the annotations
        # (person id, camera id) and the number of images.
        # Only the training ids are relabelled to 0..N-1.
        train, num_train_pids, num_train_imgs = self._process_dir(self.train_dir, relabel=True)
        query, num_query_pids, num_query_imgs = self._process_dir(self.query_dir, relabel=False)
        gallery, num_gallery_pids, num_gallery_imgs = self._process_dir(self.gallery_dir, relabel=False)
        num_total_pids = num_train_pids + num_query_pids
        num_total_imgs = num_train_imgs + num_query_imgs + num_gallery_imgs

        print("=> Market1501 loaded")
        print("Dataset statistics:")
        print("  ------------------------------")
        print("  subset   | # ids | # images")
        print("  ------------------------------")
        print("  train    | {:5d} | {:8d}".format(num_train_pids, num_train_imgs))
        print("  query    | {:5d} | {:8d}".format(num_query_pids, num_query_imgs))
        print("  gallery  | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_imgs))
        print("  ------------------------------")
        print("  total    | {:5d} | {:8d}".format(num_total_pids, num_total_imgs))
        print("  ------------------------------")

        self.train = train
        self.query = query
        self.gallery = gallery

        self.num_train_pids = num_train_pids
        self.num_query_pids = num_query_pids
        self.num_gallery_pids = num_gallery_pids

    def _check_before_run(self):
        """Raise RuntimeError if any of the required directories is missing."""
        if not osp.exists(self.dataset_dir):
            raise RuntimeError("'{}'is not available".format(self.dataset_dir))
        if not osp.exists(self.train_dir):
            raise RuntimeError("'{}' is not available".format(self.train_dir))
        if not osp.exists(self.query_dir):
            raise RuntimeError("'{}' is not available".format(self.query_dir))
        if not osp.exists(self.gallery_dir):
            raise RuntimeError("'{}' is not available".format(self.gallery_dir))

    def _parse_name(self, img_path):
        """Return ``(pid, camid)`` parsed from the file name of ``img_path``.

        Only the base name is inspected so that digits in parent directory
        names can never be mistaken for the person/camera id.

        Raises:
            RuntimeError: if the file name does not contain ``<pid>_c<camid>``.
        """
        match = self._name_pattern.search(osp.basename(img_path))
        if match is None:
            raise RuntimeError(
                "'{}' does not follow the '<pid>_c<camid>' naming scheme".format(img_path))
        pid, camid = map(int, match.groups())
        return pid, camid

    def _process_dir(self, dir_path, relabel=False):
        """Index every ``*.jpg`` file directly inside ``dir_path``.

        Args:
            dir_path: directory to scan (not recursive).
            relabel: when true, replace each person id by its rank among the
                sorted ids found in this directory (0..N-1).

        Returns:
            ``(dataset, num_pids, num_imgs)`` where ``dataset`` is a list of
            ``(img_path, pid, camid)`` tuples with zero-based ``camid``,
            ``num_pids`` is the number of distinct person ids and ``num_imgs``
            is ``len(dataset)``. Files whose parsed id is -1 are excluded from
            all three values.
        """
        img_paths = glob.glob(osp.join(dir_path, '*.jpg'))

        # Parse every file name exactly once.
        records = []
        for img_path in img_paths:
            pid, camid = self._parse_name(img_path)
            if pid == -1:
                continue  # images labelled -1 are ignored
            assert 0 <= pid <= 1501
            assert 1 <= camid <= 6
            records.append((img_path, pid, camid - 1))  # camera index starts from 0

        pid_container = {pid for _, pid, _ in records}

        if relabel:
            # Sorting makes the pid -> label mapping independent of set
            # iteration order and of the order glob returned the files in.
            pid2label = {pid: label for label, pid in enumerate(sorted(pid_container))}
            dataset = [(path, pid2label[pid], camid) for path, pid, camid in records]
        else:
            dataset = records

        num_pids = len(pid_container)
        # Count the images actually kept, not every file that glob returned.
        num_imgs = len(dataset)
        return dataset, num_pids, num_imgs

if __name__ == "__main__":
    # Manual smoke run: index the dataset stored two directories above the
    # working directory and print its statistics table.
    data = Market1501(
        root="../../data",
    )

Publicado 134 artículos originales · elogiado 38 · 90,000 visitas +

Supongo que te gusta

Origin blog.csdn.net/rytyy/article/details/105590743
Recomendado
Clasificación