代码参考:https://github.com/msight-tech/research-ms-loss
论文参考:Wang_Multi-Similarity_Loss_With_General_Pair_Weighting_for_Deep_Metric_Learning_CVPR_2019_paper.pdf
1.环境准备
ubuntu16.04
cuda10.1
torch==1.7.1
torchvision==0.8.2
numpy==1.19.5
yacs==0.1.8
Pillow==8.1.0
cd reseach-ms-loss
python setup.py develop build
2.准备预训练模型
3)将其放到
~/.cache/torch/checkpoints/目录下
4)修改ret_benchmark/config/model_path.py中的目录:
from yacs.config import CfgNode as CN
MODEL_PATH = {
'bninception': "~/.cache/torch/checkpoints/bn_inception-52deb4733.pth",
'resnet50': "~/.cache/torch/checkpoints/resnet50-19c8e357.pth",
}
MODEL_PATH = CN(MODEL_PATH)
3.数据下载与准备
1)下载:https://blog.csdn.net/zengyujianjianghu/article/details/98323836
2)准备:
①修改split_cub_for_ms_loss.py中的,CUB_ROOT路径。
②运行(生成train.txt与test.txt文件,之前我误以为是ms loss对数据划分有特殊要求):
python scripts/split_cub_for_ms_loss.py
3)修改configs/example.yaml中的目录:
DATA:
TRAIN_IMG_SOURCE: /root/datasets/CUB_200_2011/CUB_200_2011/train.txt
TEST_IMG_SOURCE: /root/datasets/CUB_200_2011/CUB_200_2011/test.txt
4.运行训练
bast scripts/run_cub.sh
显存大概使用6G。
如果出现了显存不足,最好是调节下,config/example.yaml中的test batch size,为128等更小的。
和readme中对比,结果应该是可以复现,比较一致的。