AlignedReID from Scratch (Part 04)

Model Evaluation Metrics

Preface

When introducing person re-identification, I mentioned that the commonly used evaluation metrics are mAP and CMC, and that re-ranking is often applied as a post-processing step before evaluation. This post walks through the code that implements them.
For background on these metrics, see this post:
Common Evaluation Metrics for Person Re-identification

mAP and CMC evaluation

The following code is in the eval_metrics.py file under the utils package:

from __future__ import print_function, absolute_import
import numpy as np
import copy
from collections import defaultdict
import sys

# distmat.shape: (num_query, num_gallery), e.g. (3368, 15913) in this post
def eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, max_rank):
    """Evaluation with market1501 metric
    Key: for each query identity, its gallery images from the same camera view are discarded.

    Args:
        distmat: (num_query, num_gallery) distance matrix; results are
            ranked in ascending order, so smaller means more similar.
        q_pids, g_pids: person IDs of the query / gallery images.
        q_camids, g_camids: camera IDs of the query / gallery images.
        max_rank: length of the returned CMC curve; clipped to num_gallery.

    Returns:
        (all_cmc, mAP): ``all_cmc`` is a float32 array of length ``max_rank``
        whose k-th entry is the fraction of valid queries having a correct
        match within the top k+1 results; ``mAP`` is the mean average
        precision over valid queries. A query is valid when its identity
        still appears in the gallery after same-camera filtering.

    Raises:
        AssertionError: if no query is valid.
    """
    # Accept any array-like input (lists, CPU tensors, ...).
    distmat = np.asarray(distmat)
    q_pids = np.asarray(q_pids)
    g_pids = np.asarray(g_pids)
    q_camids = np.asarray(q_camids)
    g_camids = np.asarray(g_camids)

    num_q, num_g = distmat.shape
    # The CMC curve cannot be longer than the gallery.
    if num_g < max_rank:
        max_rank = num_g
        print("Note: number of gallery samples is quite small, got {}".format(num_g))
    # For each query, gallery indices sorted by increasing distance.
    indices = np.argsort(distmat, axis=1)
    # matches[i, r] == 1 iff the r-th ranked gallery image has the query's pid.
    matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)

    all_cmc = []
    all_AP = []  # average precision of every valid query
    num_valid_q = 0.  # number of valid queries
    for q_idx in range(num_q):
        q_pid = q_pids[q_idx]
        q_camid = q_camids[q_idx]

        # Ranking of the gallery for this query.
        order = indices[q_idx]
        # Drop gallery samples with the same pid AND camid as the query
        # (same person seen by the same camera is not a valid retrieval).
        remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
        keep = np.invert(remove)

        # Binary vector; positions with value 1 are correct matches.
        orig_cmc = matches[q_idx][keep]
        if not np.any(orig_cmc):
            # Query identity does not appear in the filtered gallery.
            continue

        # Cumulative hits clipped to 1: cmc[r] == 1 once a match has been
        # seen at rank <= r.
        cmc = orig_cmc.cumsum()
        cmc[cmc > 1] = 1
        cmc = cmc[:max_rank]
        if len(cmc) < max_rank:
            # Filtering can leave fewer than max_rank gallery samples; pad
            # with the last value (1, since a match exists) so that every
            # row has the same length and np.asarray below stays 2-D.
            cmc = np.pad(cmc, (0, max_rank - len(cmc)), mode='edge')
        all_cmc.append(cmc)
        num_valid_q += 1.

        # Average precision:
        # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
        num_rel = orig_cmc.sum()
        hits = orig_cmc.cumsum()
        precision_at_k = hits / np.arange(1, len(hits) + 1)
        AP = (precision_at_k * orig_cmc).sum() / num_rel
        all_AP.append(AP)

    # Explicit raise (not `assert`) so the check survives `python -O`.
    if num_valid_q <= 0:
        raise AssertionError("Error: all query identities do not appear in gallery")

    # Average the per-query CMC curves.
    all_cmc = np.asarray(all_cmc).astype(np.float32)
    all_cmc = all_cmc.sum(0) / num_valid_q
    mAP = np.mean(all_AP)

    return all_cmc, mAP

# Public entry point used by callers; delegates to the Market-1501 protocol.
def evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50):
    """Compute CMC and mAP with the Market-1501 evaluation protocol.

    Forwards every argument unchanged to ``eval_market1501`` and returns
    its ``(all_cmc, mAP)`` result.
    """
    metrics = eval_market1501(
        distmat,
        q_pids,
        g_pids,
        q_camids,
        g_camids,
        max_rank,
    )
    return metrics


Re-ranking

Below is the re-ranking code, reproduced as-is. I have not yet studied the corresponding paper closely, so I do not fully understand how the implementation works; I will come back to it when I have time. This code lives in the re-rank file in the utils package.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri, 25 May 2018 20:29:09

@author: luohao
"""

"""
CVPR2017 paper:Zhong Z, Zheng L, Cao D, et al. Re-ranking Person Re-identification with k-reciprocal Encoding[J]. 2017.
url:http://openaccess.thecvf.com/content_cvpr_2017/papers/Zhong_Re-Ranking_Person_Re-Identification_CVPR_2017_paper.pdf
Matlab version: https://github.com/zhunzhong07/person-re-ranking
"""

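"""
API

probFea: all feature vectors of the query set (torch tensor)
galFea: all feature vectors of the gallery set (torch tensor)
k1,k2,lambda: parameters, the original paper is (k1=20,k2=6,lambda=0.3)
local_distmat: optional extra distance matrix over query+gallery, added to the feature distance
only_local: if True, use local_distmat alone as the original distance
MemorySave: set to 'True' when using MemorySave mode (not supported by re_ranking below)
Minibatch: available when 'MemorySave' is 'True' (not supported by re_ranking below)
"""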

import numpy as np
import torch

def _k_reciprocal_neigh(initial_rank, i, k):
    """Return the k-reciprocal neighbours of sample ``i``.

    ``j`` is returned when ``j`` is among the first ``k + 1`` entries of
    ``initial_rank[i]`` and ``i`` is among the first ``k + 1`` entries of
    ``initial_rank[j]``. The result keeps the order of ``initial_rank[i]``.
    """
    forward_k_neigh_index = initial_rank[i, :k + 1]
    backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k + 1]
    fi = np.where(backward_k_neigh_index == i)[0]
    return forward_k_neigh_index[fi]


def re_ranking(probFea, galFea, k1, k2, lambda_value, local_distmat = None, only_local = False):
    """Re-rank query/gallery distances with k-reciprocal encoding (Zhong et al., CVPR 2017).

    Args:
        probFea: query features, a torch tensor with one row per query image.
            Numpy features must be converted with ``torch.tensor`` first.
        galFea: gallery features, a torch tensor with one row per gallery image.
        k1: neighbourhood size used to build k-reciprocal sets (paper: 20).
        k2: number of neighbours averaged for local query expansion of the
            encodings; ``k2 == 1`` disables the expansion (paper: 6).
        lambda_value: weight of the original distance in the final distance;
            the Jaccard distance is weighted by ``1 - lambda_value`` (paper: 0.3).
        local_distmat: optional extra distance matrix over all query+gallery
            samples (presumably a numpy array of shape (all_num, all_num) —
            TODO confirm). Added to the feature distance, or used alone when
            ``only_local`` is True.
        only_local: if True, skip the feature distance and use
            ``local_distmat`` as the original distance.

    Returns:
        numpy array of shape (num_query, num_gallery) with re-ranked distances.
    """
    query_num = probFea.size(0)
    all_num = query_num + galFea.size(0)
    if only_local:
        original_dist = local_distmat
    else:
        print('using GPU to compute original distance')
        # Squared Euclidean distance between every pair of samples:
        # ||a||^2 + ||b||^2 - 2 a.b. Gradients are never needed here, and
        # no_grad lets features that require grad be converted to numpy.
        with torch.no_grad():
            feat = torch.cat([probFea, galFea])
            sq_norm = torch.pow(feat, 2).sum(dim=1, keepdim=True).expand(all_num, all_num)
            distmat = sq_norm + sq_norm.t()
            # Keyword beta/alpha: the positional addmm_(beta, alpha, m1, m2)
            # overload is deprecated in PyTorch.
            distmat.addmm_(feat, feat.t(), beta=1, alpha=-2)
        # .cpu() so that CUDA tensors can be converted to numpy as well.
        original_dist = distmat.cpu().numpy()
        # Release the tensors (only the numpy view is needed from here on).
        del feat, distmat
        if local_distmat is not None:
            original_dist = original_dist + local_distmat
    # Number of rows of the distance matrix, i.e. query + gallery samples.
    total_num = original_dist.shape[0]
    # Divide each column by its maximum, then transpose.
    original_dist = np.transpose(original_dist / np.max(original_dist, axis=0))
    # V[i] holds the weighted k-reciprocal encoding of sample i
    # (float16, presumably to limit memory for large galleries).
    V = np.zeros_like(original_dist).astype(np.float16)
    initial_rank = np.argsort(original_dist).astype(np.int32)

    print('starting re_ranking')
    half_k1 = int(np.around(k1 / 2))
    for i in range(all_num):
        # k-reciprocal neighbours of sample i
        k_reciprocal_index = _k_reciprocal_neigh(initial_rank, i, k1)
        k_reciprocal_expansion_index = k_reciprocal_index
        # Expand the set with each candidate's (k1/2)-reciprocal neighbours
        # when they overlap the original set by more than 2/3.
        for candidate in k_reciprocal_index:
            candidate_k_reciprocal_index = _k_reciprocal_neigh(initial_rank, candidate, half_k1)
            if len(np.intersect1d(candidate_k_reciprocal_index, k_reciprocal_index)) > 2 / 3 * len(
                    candidate_k_reciprocal_index):
                k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index, candidate_k_reciprocal_index)

        k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index)
        # exp(-distance) weights, normalised to sum to 1 over the expanded set.
        weight = np.exp(-original_dist[i, k_reciprocal_expansion_index])
        V[i, k_reciprocal_expansion_index] = weight / np.sum(weight)
    original_dist = original_dist[:query_num, ]
    if k2 != 1:
        # Local query expansion: average each encoding over its k2 nearest samples.
        V_qe = np.zeros_like(V, dtype=np.float16)
        for i in range(all_num):
            V_qe[i, :] = np.mean(V[initial_rank[i, :k2], :], axis=0)
        V = V_qe
        del V_qe
    del initial_rank
    # Inverted index: for every column, the rows of V with a non-zero entry.
    invIndex = [np.where(V[:, i] != 0)[0] for i in range(total_num)]

    jaccard_dist = np.zeros_like(original_dist, dtype=np.float16)
    for i in range(query_num):
        # Accumulates sum_k min(V[i, k], V[j, k]) for every sample j.
        temp_min = np.zeros(shape=[1, total_num], dtype=np.float16)
        indNonZero = np.where(V[i, :] != 0)[0]
        indImages = [invIndex[ind] for ind in indNonZero]
        for j in range(len(indNonZero)):
            temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + np.minimum(V[i, indNonZero[j]],
                                                                               V[indImages[j], indNonZero[j]])
        jaccard_dist[i] = 1 - temp_min / (2 - temp_min)

    final_dist = jaccard_dist * (1 - lambda_value) + original_dist * lambda_value
    del original_dist
    del V
    del jaccard_dist
    # Keep only query-to-gallery distances.
    final_dist = final_dist[:query_num, query_num:]
    return final_dist


Guess you like

Origin blog.csdn.net/qq_37747189/article/details/112662779