データの準備
序文
このブログから、AlignedReIDアルゴリズムを再現するためのコードと、再現プロセス中に発生したいくつかの問題を記録する必要があります。その後のモデルの実装の基礎として、他の人のを読んで理解することが非常に必要です。コード。そして、コードを実装する過程で、コードエンジニアリングを行うために最善を尽くします。
それでは、次に始めましょう!
1.Market1501データセットを理解する
歩行者の再認識の分野をある程度理解している場合は、このデータセットをすべて知っている必要があります。コードの実装では、このデータセットをトレーニングとテストに使用します。
1.1データセットのダウンロード
データセットのダウンロードアドレス:Market1501
(前のブログに記載されているWebサイトにアクセスしてダウンロードすることもできますが、Webサイトのリンクを開くことができないことがよくあります〜)
1.2ディレクトリの紹介
データセットをダウンロードした後、解凍後にフォルダの名前を変更できます。後でコードにパスを追加しやすくするために、ディレクトリ構造は次の図のようになります(新しく作成したデータにデータセットを保存しました) folder):
1) "Bounding_box_test" -750人のテストセットに使用されます。19,732枚の画像が含まれ、接頭辞は0000です。これは、DPMがこれらの750人を抽出するプロセス中に間違った画像を検出したことを意味します(クエリ)、-1は検出された他の人の
写真(750人ではない)を意味します2)「bounding_box_train」—トレーニングセットで使用された751人、12,936枚の画像を含む
3)「gt_bbox」—決定に使用される手書きのバウンディングボックスDPM検出の境界ボックスは良いボックス
ですか4) "gt_query" -matlab形式。クエリのどの画像が良い一致(異なるカメラを使用した同じ人物の画像)と悪い一致(同じ画像)であるかを判断するために使用されます同じカメラを持っている人)(または異なる人の画像)
5)「クエリ」-750人のクエリとして各カメラからランダムに画像を選択するため、1人のクエリは最大6つ、合計3,368枚の画像になります。
注:フォルダー1、2、および5は現在主に使用されており、3および4は使用されていません。
1.3命名規則
例として0001_c1s1_000151_01.jpgを取り上げます。1
)0001は0001から1501までの各人のタグ番号を表します
。2)c1は最初のカメラ(camera1)を
表し、合計6台のカメラがあります。3)s1は最初のビデオセグメントを表します( sequece1)、各カメラにはいくつかのビデオセグメントがあります;
4)000151はc1s1の000151番目の画像を表します;ビデオフレームレートは25fpsです;
5)01は各歩行者のDPM検出器によるc1s1_001051のこのフレームの最初の検出フレームを表しますフレーム上で複数のbboxをフレーム化する場合があります。00は手動ラベリングボックスを意味します
2.データセットのロード
次に、最初に新しいプロジェクトファイルディレクトリを作成します。ここでは、pycharm + python3環境を使用しています。プロジェクトディレクトリファイルは次のとおり
です。data_processフォルダの下に新しいファイルdata_manager.pyを作成して、データセットをロードします。
#-*-coding:utf-8-*-
# 此文件用于加载数据集Market1501
"""
主要步骤:
1.拼接文件夹路径
2.获取图片路径信息、行人ID(pid)、摄像头ID(camid)
3.统计行人、图片总数
"""
from __future__ import print_function, absolute_import
import os.path as osp
import glob
import re
from IPython import embed
class Market1501(object):
"""
Market1501
Reference:
Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015.
URL: http://www.liangzheng.org/Project/project_reid.html
Dataset statistics:
# identities: 1501 (+1 for background)
# images: 12936 (train) + 3368 (query) + 15913 (gallery)
"""
# 数据集Market1501目录
dataset_dir = "market"
# 通过创建类对象完成对数据集的加载,因此把读取操作都放入 init方法
# 默认传入参数root 为数据集所在根目录
# 默认传入参数min_seq_len 为最小序列长度 默认值为0
# **kwargs可能会有其他参数
def __init__(self, root='/home/dmb/Desktop/materials/data', min_seq_len=0,**kwargs):
# 1.加载几个文件夹目录 拼接路径
self.dataset_dir = osp.join(root, self.dataset_dir)
self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train')
self.query_dir = osp.join(self.dataset_dir, 'query')
self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test')
# 检查是否成功加载
self._check_before_run()
# 调用目录处理方法
# 2.获取图片路径信息、行人ID(pid)、摄像头ID(camid)
# train: ('/home/dmb/Desktop/materials/data/market/bounding_box_train/0796_c3s2_089653_01.jpg', 420, 2)
train,num_train_pids,num_train_imgs = self._process_dir(self.train_dir,relabel=True)
query,num_query_pids,num_query_imgs = self._process_dir(self.query_dir,relabel=False)
gallery,num_gallery_pids,num_gallery_imgs = self._process_dir(self.gallery_dir,relabel=False)
# 3.统计行人、图片总数
num_total_pids = num_train_pids + num_query_pids
num_total_imgs = num_train_imgs + num_query_imgs + num_gallery_imgs
# embed()
# 打印信息
print("=> Market1501 loaded")
print("Dataset statistics:")
print(" ------------------------------")
print(" subset | # ids | # images")
print(" ------------------------------")
print(" train | {:5d} | {:8d}".format(num_train_pids, num_train_imgs))
print(" query | {:5d} | {:8d}".format(num_query_pids, num_query_imgs))
print(" gallery | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_imgs))
print(" ------------------------------")
print(" total | {:5d} | {:8d}".format(num_total_pids, num_total_imgs))
print(" ------------------------------")
self.train = train
self.query = query
self.gallery = gallery
self.num_train_pids = num_train_pids
self.num_query_pids = num_query_pids
self.num_gallery_pids = num_gallery_pids
def _check_before_run(self):
# 定义检验加载是否成功方法
if not osp.exists(self.dataset_dir):
raise RuntimeError("{} is not available".format(self.dataset_dir))
if not osp.exists(self.train_dir):
raise RuntimeError("{} is not available".format(self.train_dir))
if not osp.exists(self.query_dir):
raise RuntimeError("{} is not available".format(self.query_dir))
if not osp.exists(self.gallery_dir):
raise RuntimeError("{} is not available".format(self.gallery_dir))
def _process_dir(self,dir_path,relabel=False):
# 此函数返回一个符合glob匹配的pathname的list,返回结果有可能是空
# 2.1匹配该路径下所有以.jpg结尾的文件,放入list
img_paths = glob.glob(osp.join(dir_path,'*.jpg'))
# 正则表达式设置匹配规则 只提取行人id以及摄像头id
pattern = re.compile(r'([-\d]+)_c(\d)')
# 2.2实现relabel
# 原因是由于训练集只有751个行人,但标注是到1501,直接使用1501会使模型产生750个无效神经元
# set集合存放的行人ID 后面会用的到
# 使用set集合可以去重
pid_container = set()
# 遍历list集合中的图片名
for img_path in img_paths:
# 只关心每张图片的pid,其他值设置为缺省值
# map() 会根据提供的函数对指定序列做映射。
# 第一个参数 function 以参数序列中的每一个元素调用 function 函数,返回包含每次 function 函数返回值的新列表
pid,_ = map(int,pattern.search(img_path).groups())
# 跳过所有pid为-1的项
if pid == -1:continue
# 添加pid到列表
pid_container.add(pid)
pid2label = {
pid:label for label,pid in enumerate(pid_container)}
# embed()
dataset = []
for img_path in img_paths:
pid,camid = map(int,pattern.search(img_path).groups())
if pid == -1:continue
assert 0 <= pid <=1501
assert 1 <= camid <= 6
camid -= 1
# 这里有个判断 只有relabel = True 我才relabel
if relabel :pid = pid2label[pid]
dataset.append((img_path,pid,camid))
num_pids = len(pid_container)
num_imgs = len(dataset)
# 返回值为dataset,图片id数量,图片数量
return dataset,num_pids,num_imgs
"""Create dataset"""
__img_factory = {
'market1501': Market1501,
# 'cuhk03': CUHK03,
# 'dukemtmcreid': DukeMTMCreID,
# 'msmt17': MSMT17,
}
# __vid_factory = {
# 'mars': Mars,
# 'ilidsvid': iLIDSVID,
# 'prid': PRID,
# 'dukemtmcvidreid': DukeMTMCVidReID,
# }
def get_names():
return __img_factory.keys()
def init_img_dataset(name, **kwargs):
if name not in __img_factory.keys():
raise KeyError("Invalid dataset, got '{}', but expected to be one of {}".format(name, __img_factory.keys()))
return __img_factory[name](**kwargs)
# 验证
if __name__ == "__main__":
init_img_dataset(root='/home/dmb/Desktop/materials/data',name="market1501")
結果:
3.data_loaderライブラリをリファクタリングします
data_loaderは、主にデータのスループットを担当するpytorchのより重要なライブラリです。必要な方法でデータをスループットにする必要があるため、ソースコードに基づいてある程度の変更を加える必要があります。
#-*-coding:utf-8-*-
from __future__ import print_function, absolute_import
from PIL import Image
import numpy as np
import os.path as osp
import torch
from torch.utils.data import Dataset
from IPython import embed
from AlignedReId.data_process import data_manager
# 设置图片读取方法
def read_image(image_path):
"""Keep reading image until succeed.
This can avoid IOError incurred by heavy IO process."""
# 标志位表示是否读取到图片
got_image = False
if not osp.exists(image_path):
raise IOError("{} is not exists".format(image_path))
# 没读到图片就一直读
while not got_image:
try:
# 把读到的图片转化为RGB格式
img = Image.open(image_path).convert('RGB')
got_image = True
except IOError:
print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(image_path))
pass
return img
# 重写dataset类
class ImageDataset(Dataset):
"""Image Person ReID Dataset"""
def __init__(self,dataset,transform=None):
self.dataset = dataset
self.transform = transform
def __len__(self):
return len(self.dataset)
def __getitem__(self,index):
# 读取dataset的一行信息
img_path,pid,camid =self.dataset[index]
# 使用read_image读取图片
img = read_image(img_path)
# 判断是否进行数据增广
if self.transform is not None:
img = self.transform(img)
return img, pid, camid
# 验证
# if __name__ == "__main__":
# dataset =data_manager.init_img_dataset(root='/home/dmb/Desktop/materials/data',name="market1501")
# train_loader = ImageDataset(dataset.train)
# for batch_id,(img,pid,camid) in enumerate(train_loader):
# break
# print(batch_id,img,pid,camid)
戻り結果:
train_loaderはデータセットのインスタンスであり、イテレーターを介して取得でき、取得される値はbatch_id、picture、pedestrian id、cameraidであることがわかります。
(ただし、ここで注意すべきことの1つは、画像自体ではなく、画像をテンソルに変換することです。これについては後で説明します)
データサンプリング
utilsに新しいsample.pyファイルを作成し、トレーニングセットのサンプリングを担当します。エポックごとに、トレーニングセット内の各歩行者の4枚の写真、つまり751 * 4 = 3004枚の写真を収集します。コードは次のとおりです。
from __future__ import absolute_import
from collections import defaultdict
import numpy as np
import torch
from torch.utils.data.sampler import Sampler
class RandomIdentitySampler(Sampler):
"""
Randomly sample N identities, then for each identity,
randomly sample K instances, therefore batch size is N*K.
Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/data/sampler.py.
Args:
data_source (Dataset): dataset to sample from.
num_instances (int): number of instances per identity.
"""
def __init__(self, data_source, num_instances=4):
self.data_source = data_source
self.num_instances = num_instances
self.index_dic = defaultdict(list)
for index, (_, pid, _) in enumerate(data_source):
self.index_dic[pid].append(index)
self.pids = list(self.index_dic.keys())
self.num_identities = len(self.pids)
def __iter__(self):
indices = torch.randperm(self.num_identities)
ret = []
for i in indices:
pid = self.pids[i]
t = self.index_dic[pid]
replace = False if len(t) >= self.num_instances else True
t = np.random.choice(t, size=self.num_instances, replace=replace)
ret.extend(t)
return iter(ret)
def __len__(self):
return self.num_identities * self.num_instances
if __name__ == "__main__":
dataset =dataset_manager.init_img_dataset(root='/home/dmb/Desktop/materials/data',name="market1501")
train_loader = ImageDataset(dataset.train)
sample = RandomIdentitySampler(train_loader)
for ret in enumerate(sample):
print(ret)
retを印刷した結果は次のとおりです。
括弧内の最初の項目はn番目の画像で、2番目の項目は画像に対応するラベルです。
5.データの前処理
この場所にもデータ拡張に関する部分があるはずですが、pytorchには豊富なデータ拡張方法が用意されているので、ここでは自分で記述せず、提供されているデータ拡張方法をすぐに使用します。
後で独自のデータ拡張方法を設計する必要がある場合は、それを再度記録してください。