AlignedReID_01 desde cero

preparación de datos

Prefacio

A partir de este blog, es necesario registrar el código para la reproducción del algoritmo AlignedReID, así como algunos problemas encontrados durante el proceso de reproducción. Como base para la posterior implementación del modelo, es muy necesario leer y comprender a otras personas. código. Y, en el proceso de implementación del código, haremos todo lo posible para hacer la ingeniería del código.

Así que a continuación, ¡comencemos!

1. Conozca el conjunto de datos Market1501

Si tiene una cierta comprensión del campo del re-reconocimiento de peatones, todos deberían conocer este conjunto de datos, y la implementación del código utilizará este conjunto de datos para entrenamiento y pruebas.

1.1 Descarga del conjunto de datos

Dirección de descarga del conjunto de datos: Market1501
(También puede ir al sitio web mencionado en el blog anterior para descargar, pero el enlace del sitio web parece no poder abrirse ~)

1.2 Introducción al directorio

Una vez descargado el conjunto de datos, se puede cambiar el nombre de la carpeta después de la descompresión.Para facilitar la adición de la ruta en el código más adelante, la estructura del directorio es como se muestra en la figura a continuación (guardé el conjunto de datos en el archivo de datos recién creado carpeta):
Inserte la descripción de la imagen aquí
1) "Bounding_box_test": utilizado para el conjunto de prueba de 750 personas, que contiene 19,732 imágenes, con el prefijo 0000, lo que significa que DPM detectó la imagen incorrecta durante el proceso de extracción de estas 750 personas (puede ser la misma persona que el consulta), -1 significa detectado
Imágenes de otras personas (no entre las 750 personas) 2) "bounding_box_train": 751 personas utilizadas en el conjunto de entrenamiento, que contiene 12,936 imágenes
3) "gt_bbox": cuadro delimitador etiquetado a mano, utilizado para determinar el límite de la detección de DPM ¿Es el cuadro un buen cuadro
4) "gt_query": formato matlab, utilizado para determinar qué imágenes de una consulta son buenas coincidencias (imágenes de la misma persona con diferentes cámaras) y malas coincidencias (imágenes de la misma persona con la misma cámara) (O imágenes de diferentes personas)
5) "consulta": seleccione aleatoriamente una imagen de cada cámara como una consulta para 750 personas, por lo que hay un máximo de 6 consultas para una persona y un total de 3368 imágenes

Nota: Las carpetas 1, 2 y 5 se utilizan principalmente ahora, la 3 y 4 están abandonadas.

1.3 Reglas de nomenclatura

Tome 0001_c1s1_000151_01.jpg como ejemplo
1) 0001 representa el número de etiqueta de cada persona, de 0001 a 1501;
2) c1 representa la primera cámara (cámara1), hay 6 cámaras en total;
3) s1 representa el primer segmento de video ( secuencia1), cada cámara tiene varios segmentos de video;
4) 000151 representa la imagen 000151 de c1s1, la velocidad de fotogramas de video es de
25 fps; 5) 01 representa el primer fotograma de detección en este fotograma de c1s1_001051, debido al detector DPM, para cada peatón en un marco puede enmarcar varios bboxes. 00 significa caja de etiquetado manual

2. Carga del conjunto de datos

Luego, primero cree un nuevo directorio de archivos de proyecto, aquí estoy usando el entorno pycharm + python3. El archivo del directorio del proyecto es el siguiente:
Inserte la descripción de la imagen aquí
Cree un nuevo archivo data_manager.py en la carpeta data_process para cargar el conjunto de datos:

#-*-coding:utf-8-*-
# 此文件用于加载数据集Market1501
"""
主要步骤:
1.拼接文件夹路径
2.获取图片路径信息、行人ID(pid)、摄像头ID(camid)
3.统计行人、图片总数

"""
from __future__ import print_function, absolute_import
import os.path as osp
import glob
import re

from IPython import embed


class Market1501(object):
    """
    Market1501

    Reference:
    Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015.

    URL: http://www.liangzheng.org/Project/project_reid.html

    Dataset statistics:
    # identities: 1501 (+1 for background)
    # images: 12936 (train) + 3368 (query) + 15913 (gallery)
    """
    # 数据集Market1501目录
    dataset_dir = "market"

    # 通过创建类对象完成对数据集的加载,因此把读取操作都放入 init方法
    # 默认传入参数root 为数据集所在根目录
    # 默认传入参数min_seq_len 为最小序列长度 默认值为0
    # **kwargs可能会有其他参数
    def __init__(self, root='/home/dmb/Desktop/materials/data', min_seq_len=0,**kwargs):
        # 1.加载几个文件夹目录 拼接路径
        self.dataset_dir = osp.join(root, self.dataset_dir)
        self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train')
        self.query_dir = osp.join(self.dataset_dir, 'query')
        self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test')

        # 检查是否成功加载
        self._check_before_run()

        # 调用目录处理方法
        # 2.获取图片路径信息、行人ID(pid)、摄像头ID(camid)
        # train: ('/home/dmb/Desktop/materials/data/market/bounding_box_train/0796_c3s2_089653_01.jpg', 420, 2)
        train,num_train_pids,num_train_imgs = self._process_dir(self.train_dir,relabel=True)
        query,num_query_pids,num_query_imgs = self._process_dir(self.query_dir,relabel=False)
        gallery,num_gallery_pids,num_gallery_imgs = self._process_dir(self.gallery_dir,relabel=False)
        # 3.统计行人、图片总数
        num_total_pids = num_train_pids + num_query_pids
        num_total_imgs = num_train_imgs + num_query_imgs + num_gallery_imgs
        # embed()
        # 打印信息
        print("=> Market1501 loaded")
        print("Dataset statistics:")
        print("  ------------------------------")
        print("  subset   | # ids | # images")
        print("  ------------------------------")
        print("  train    | {:5d} | {:8d}".format(num_train_pids, num_train_imgs))
        print("  query    | {:5d} | {:8d}".format(num_query_pids, num_query_imgs))
        print("  gallery  | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_imgs))
        print("  ------------------------------")
        print("  total    | {:5d} | {:8d}".format(num_total_pids, num_total_imgs))
        print("  ------------------------------")

        self.train = train
        self.query = query
        self.gallery = gallery

        self.num_train_pids = num_train_pids
        self.num_query_pids = num_query_pids
        self.num_gallery_pids = num_gallery_pids

    def _check_before_run(self):
        # 定义检验加载是否成功方法
        if not osp.exists(self.dataset_dir):
            raise RuntimeError("{} is not available".format(self.dataset_dir))
        if not osp.exists(self.train_dir):
            raise RuntimeError("{} is not available".format(self.train_dir))
        if not osp.exists(self.query_dir):
            raise RuntimeError("{} is not available".format(self.query_dir))
        if not osp.exists(self.gallery_dir):
            raise RuntimeError("{} is not available".format(self.gallery_dir))


    def _process_dir(self,dir_path,relabel=False):

        # 此函数返回一个符合glob匹配的pathname的list,返回结果有可能是空
        # 2.1匹配该路径下所有以.jpg结尾的文件,放入list
        img_paths = glob.glob(osp.join(dir_path,'*.jpg'))
        # 正则表达式设置匹配规则 只提取行人id以及摄像头id
        pattern = re.compile(r'([-\d]+)_c(\d)')

        # 2.2实现relabel
        # 原因是由于训练集只有751个行人,但标注是到1501,直接使用1501会使模型产生750个无效神经元
        # set集合存放的行人ID 后面会用的到
        # 使用set集合可以去重
        pid_container = set()
        # 遍历list集合中的图片名
        for img_path in img_paths:
            # 只关心每张图片的pid,其他值设置为缺省值
            # map() 会根据提供的函数对指定序列做映射。
            # 第一个参数 function 以参数序列中的每一个元素调用 function 函数,返回包含每次 function 函数返回值的新列表
            pid,_ = map(int,pattern.search(img_path).groups())

            # 跳过所有pid为-1的项
            if pid == -1:continue
            # 添加pid到列表
            pid_container.add(pid)
        pid2label = {
    
    pid:label for label,pid in enumerate(pid_container)}
        # embed()

        dataset = []
        for img_path in img_paths:
            pid,camid = map(int,pattern.search(img_path).groups())
            if pid == -1:continue
            assert 0 <= pid <=1501
            assert 1 <= camid <= 6
            camid -= 1
            # 这里有个判断 只有relabel = True 我才relabel
            if relabel :pid = pid2label[pid]
            dataset.append((img_path,pid,camid))

        num_pids = len(pid_container)
        num_imgs = len(dataset)
        # 返回值为dataset,图片id数量,图片数量
        return dataset,num_pids,num_imgs


"""Create dataset"""

__img_factory = {
    
    
    'market1501': Market1501,
    # 'cuhk03': CUHK03,
    # 'dukemtmcreid': DukeMTMCreID,
    # 'msmt17': MSMT17,
}

# __vid_factory = {
    
    
#     'mars': Mars,
#     'ilidsvid': iLIDSVID,
#     'prid': PRID,
#     'dukemtmcvidreid': DukeMTMCVidReID,
# }

def get_names():
    return __img_factory.keys()

def init_img_dataset(name, **kwargs):
    if name not in __img_factory.keys():
        raise KeyError("Invalid dataset, got '{}', but expected to be one of {}".format(name, __img_factory.keys()))
    return __img_factory[name](**kwargs)

# 验证
if __name__ == "__main__":
    init_img_dataset(root='/home/dmb/Desktop/materials/data',name="market1501")

resultado:
Inserte la descripción de la imagen aquí

3. Refactorice la biblioteca data_loader

data_loader es una biblioteca más importante de pytorch, que es principalmente responsable del rendimiento de los datos. Necesitamos que los datos se realicen de la manera que necesitamos, por lo que debemos realizar una cierta cantidad de modificaciones basadas en el código fuente.

#-*-coding:utf-8-*-
from __future__ import print_function, absolute_import
from PIL import Image
import numpy as np
import os.path as osp

import torch
from torch.utils.data import Dataset
from IPython import embed
from AlignedReId.data_process import data_manager

# 设置图片读取方法
def read_image(image_path):
    """Keep reading image until succeed.
    This can avoid IOError incurred by heavy IO process."""
    # 标志位表示是否读取到图片
    got_image = False
    if not osp.exists(image_path):
        raise IOError("{} is not exists".format(image_path))
    # 没读到图片就一直读
    while not got_image:
        try:
            # 把读到的图片转化为RGB格式
            img = Image.open(image_path).convert('RGB')
            got_image = True
        except IOError:
            print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(image_path))
            pass
        return img


# 重写dataset类
class ImageDataset(Dataset):
    """Image Person ReID Dataset"""
    def __init__(self,dataset,transform=None):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self,index):
        # 读取dataset的一行信息
        img_path,pid,camid =self.dataset[index]
        # 使用read_image读取图片
        img = read_image(img_path)
        # 判断是否进行数据增广
        if self.transform is not None:
            img = self.transform(img)
        return img, pid, camid



# 验证
# if __name__ == "__main__":
#     dataset =data_manager.init_img_dataset(root='/home/dmb/Desktop/materials/data',name="market1501")
#     train_loader = ImageDataset(dataset.train)
#     for batch_id,(img,pid,camid) in enumerate(train_loader):
#         break
#     print(batch_id,img,pid,camid)


Devolver resultado:
Inserte la descripción de la imagen aquíse puede ver que train_loader es una instancia de conjunto de datos, que se puede obtener a través de un iterador, y los valores obtenidos son batch_id, picture, pedestrian id e camera id.
(Pero una cosa a tener en cuenta aquí, lo que queremos no es la imagen en sí, sino convertir la imagen en un tensor, que se mencionará más adelante)

Muestreo de datos

Cree un nuevo archivo sample.py en utils, responsable de muestrear el conjunto de entrenamiento. Para cada época, recopilaremos 4 imágenes de cada peatón en el conjunto de entrenamiento, es decir, 751 * 4 = 3004 imágenes. El código es el siguiente:

from __future__ import absolute_import
from collections import defaultdict
import numpy as np

import torch
from torch.utils.data.sampler import Sampler

class RandomIdentitySampler(Sampler):
    """
    Randomly sample N identities, then for each identity,
    randomly sample K instances, therefore batch size is N*K.

    Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/data/sampler.py.

    Args:
        data_source (Dataset): dataset to sample from.
        num_instances (int): number of instances per identity.
    """
    def __init__(self, data_source, num_instances=4):
        self.data_source = data_source
        self.num_instances = num_instances
        self.index_dic = defaultdict(list)
        for index, (_, pid, _) in enumerate(data_source):
            self.index_dic[pid].append(index)
        self.pids = list(self.index_dic.keys())
        self.num_identities = len(self.pids)

    def __iter__(self):
        indices = torch.randperm(self.num_identities)
        ret = []
        for i in indices:
            pid = self.pids[i]
            t = self.index_dic[pid]
            replace = False if len(t) >= self.num_instances else True
            t = np.random.choice(t, size=self.num_instances, replace=replace)
            ret.extend(t)
        return iter(ret)

    def __len__(self):
        return self.num_identities * self.num_instances

if __name__ == "__main__":
    dataset =dataset_manager.init_img_dataset(root='/home/dmb/Desktop/materials/data',name="market1501")
    train_loader = ImageDataset(dataset.train)
    sample = RandomIdentitySampler(train_loader)
    for ret in enumerate(sample):
        print(ret)

El resultado de imprimir ret es el siguiente:
Inserte la descripción de la imagen aquíel primer elemento entre paréntesis es la enésima imagen y el segundo elemento es la etiqueta correspondiente a la imagen.

5. Procesamiento previo de datos

También debería haber una parte sobre la mejora de datos en este lugar, pero pytorch proporciona una gran cantidad de métodos de mejora de datos, por lo que no lo escribiré yo mismo aquí y usaré los métodos de mejora de datos proporcionados directamente más adelante.
Si necesita diseñar su propio método de mejora de datos más adelante, ¡regístrelo nuevamente!

Supongo que te gusta

Origin blog.csdn.net/qq_37747189/article/details/111867247
Recomendado
Clasificación