ビデオベースの歩行者の再識別-02
1はじめに
このセクションでは、主にデータのインポートについて説明します。モデルのトレーニングにはデータサポートが必要なため、データを前処理して入力する必要があります。
データ量が比較的少ない場合は手動入力の形式を使用できますが、データ量が多い場合はこの方法は非効率的です。
シャッフルを使用したり、ミニバッチやその他の操作に分割したりする必要がある場合は、PyTorchのAPIを使用してこれらの操作をすばやく完了することができます(データローダー)。
DataLoaderは、データをパッケージ化するためにトーチによって提供されるツールです。独自の(numpy配列またはその他の)データフォームをTensorにロードしてから、このラッパーに配置する必要があります。
Datasetはパッケージ化クラスであり、データをDatasetクラスにパッケージ化するために使用され、DataLoaderに渡されます。次に、DataLoaderクラスを使用してデータをより迅速に操作します。
前のセクションでは、火星データセットのデータセットへのカプセル化を実装しました。次に、データセットメソッドを書き直して、データを必要な方法でデータローダーに渡す必要があります。
2データセットを書き換えます
この部分では、歩行者の再識別とVIdeoベースのReIDに違いがあります。
2.1パッケージのインポート
from __future__ import print_function, absolute_import
import os
from PIL import Image
import numpy as np
import torch
from torch.utils.data import Dataset
import random
# import data_manager
# import torchvision.transforms as T
# from torch.utils.data import DataLoader
# from torch.autograd import Variable
2.2写真の読み方
def read_image(img_path):
"""Keep reading image until succeed.
This can avoid IOError incurred by heavy IO process."""
got_img = False
while not got_img:
try:
img = Image.open(img_path).convert('RGB')
got_img = True
except IOError:
print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path))
pass
return img
2.3データセットの書き換え
Datasetクラスを統合した後、init、len、およびgetitemメソッドを書き直す必要があります。
- initは主にいくつかの必要なパラメータを取得することです
- lenメソッドは、データセットのサイズを提供します。
- getitemメソッド、このメソッドは0からlen(self)までのインデックスをサポートします
# 这个方法可以用于常见的基于视频重识别的数据集
class VideoDataset(Dataset):
"""Video Person ReID Dataset.
Note batch data has shape (batch, seq_len, channel, height, width).
"""
# 枚举读取方法
sample_methods = ['evenly', 'random', 'all']
# 重写init 在创建类对象时调用
def __init__(self, dataset, seq_len=15, sample='evenly', transform=None):
# dataset为上一节mars对象
self.dataset = dataset
# seq——len 默认为15 项目中一般为4
self.seq_len = seq_len
# 采样方式
self.sample = sample
# 数据增强方式
self.transform = transform
# 返回dataset的大小
def __len__(self):
return len(self.dataset)
# 从 0 到 len(self)的索引
def __getitem__(self, index):
#print(index, len(self.dataset))
img_paths, pid, camid = self.dataset[index]
num = len(img_paths)
# 训练集 输入
if self.sample == 'random':
"""
Randomly sample seq_len consecutive frames from num frames,
if num is smaller than seq_len, then replicate items.
This sampling strategy is used in training phase.
"""
# 从n帧里挑出连续的seq帧作为样本
frame_indices = list(range(num))
rand_end = max(0, len(frame_indices) - self.seq_len - 1)
begin_index = random.randint(0, rand_end)
end_index = min(begin_index + self.seq_len, len(frame_indices))
indices = frame_indices[begin_index:end_index]
# 如果indices帧数不足seq,使用indices补全
for index in indices:
if len(indices) >= self.seq_len:
break
indices.append(index)
indices=np.array(indices)
# 这里准备数组 就是要把img拼接在一起
imgs = []
for index in indices:
index=int(index)
img_path = img_paths[index]
img = read_image(img_path)
if self.transform is not None:
img = self.transform(img)
img = img.unsqueeze(0)
imgs.append(img)
# imgs = [s,c,h,w]
imgs = torch.cat(imgs, dim=0)
#imgs=imgs.permute(1,0,2,3)
return imgs, pid, camid
# 测试集输入
elif self.sample == 'dense':
"""
Sample all frames in a video into a list of clips, each clip contains seq_len frames, batch_size needs to be set to 1.
This sampling strategy is used in test phase.
"""
cur_index=0
frame_indices = list(range(num))
indices_list=[]
# 训练和测试的不同就在于测试需要分析每一张图片
while num-cur_index > self.seq_len:
# 每次向list中添加seq长度的list
indices_list.append(frame_indices[cur_index:cur_index+self.seq_len])
cur_index+=self.seq_len
last_seq=frame_indices[cur_index:]
# 最后不足4个 补全
for index in last_seq:
if len(last_seq) >= self.seq_len:
break
last_seq.append(index)
# imdices——list = [(0,4),(4,8),...]
indices_list.append(last_seq)
imgs_list=[]
for indices in indices_list:
imgs = []
for index in indices:
index=int(index)
img_path = img_paths[index]
img = read_image(img_path)
if self.transform is not None:
img = self.transform(img)
img = img.unsqueeze(0)
imgs.append(img)
# imgs =[s,c,h,w]
imgs = torch.cat(imgs, dim=0)
#imgs=imgs.permute(1,0,2,3)
# imgs_list = [1,s,c,h,w]
imgs_list.append(imgs)
imgs_array = torch.stack(imgs_list)
return imgs_array, pid, camid
else:
raise KeyError("Unknown sample method: {}. Expected one of {}".format(self.sample, self.sample_methods))
3結果テストの解釈
3.1トレーニングセットのインポート
# test
if __name__ == "__main__":
dataset =data_manager.init_dataset(name="mars")
transform_train = T.Compose([
T.Resize((224, 112)),
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]),
])
trainloader = DataLoader(
VideoDataset(dataset.train, seq_len=4, sample='random', transform=transform_train),
batch_size=32, shuffle=True, num_workers=1,
pin_memory=False, drop_last=False,
)
# queryloader = VideoDataset(dataset.query,sample='dense',seq_len=16,transform=transform_test)
for batch_idx, (imgs, pids, camids) in enumerate(trainloader):
imgs = Variable(imgs, volatile=True)
print(imgs.size())
# b=1, n=number of clips, s=seq
b, s, c, h, w = imgs.size()
print(b,s,c,h,w)
- データセットの長さは8298で、トレーニングセット内のトラックレットの数に対応します。
- imgsのサイズはimgsです。Size([32、4、3、224、112])は[b、s、c、h、w]に対応します
3.2クエリデータセットのインポート
# test
if __name__ == "__main__":
dataset =data_manager.init_dataset(name="mars")
transform_test = T.Compose([
T.Resize((224, 112)),
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]),
])
trainloader = DataLoader(
VideoDataset(dataset.query, seq_len=4, sample='dense', transform=transform_test),
batch_size=1, shuffle=False, num_workers=4,
pin_memory=False, drop_last=False,
)
# queryloader = VideoDataset(dataset.query,sample='dense',seq_len=16,transform=transform_test)
for batch_idx, (imgs, pids, camids) in enumerate(trainloader):
imgs = Variable(imgs, volatile=True)
print(imgs.size())
# b=1, n=number of clips, s=seq
b,n, s, c, h, w = imgs.size()
print(b,s,c,h,w)
- 1980年はクエリトラックレットです
- imgs_arrays.size()= [1,10,4,3,224,112]#[b、n、s、c、h、w]
- 最初のトラックには39枚の画像が含まれ、4枚ごとにシーケンスがあるため、n = 10
3.3ギャラリーデータセットのインポート
# test
if __name__ == "__main__":
dataset =data_manager.init_dataset(name="mars")
transform_test = T.Compose([
T.Resize((224, 112)),
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]),
])
trainloader = DataLoader(
VideoDataset(dataset.gallery, seq_len=4, sample='dense', transform=transform_test),
batch_size=1, shuffle=False, num_workers=1,
pin_memory=False, drop_last=False,
)
# queryloader = VideoDataset(dataset.query,sample='dense',seq_len=16,transform=transform_test)
for batch_idx, (imgs, pids, camids) in enumerate(trainloader):
imgs = Variable(imgs, volatile=True)
print(imgs.size())
# b=1, n=number of clips, s=seq
b,n, s, c, h, w = imgs.size()
print(b,n,s,c,h,w)
- ギャラリーデータセットには9330のトラックレットがあります
- imgs.size()= [1、3、4、3、224、112] ## [b、n、s、c、h、w]
- クエリデータセットと基本的に一致