LaplacianShot官方代码笔记

训练模式

run.sh中复制命令行

python ./src/train_lshot.py -c $config --proto-rect $protrec --lmd $lmd --tune-lmd $tune  --data $datapath --lshot --log-file /LaplacianShot.log --evaluate

用值替换$变量,去掉最前的python ./src/train_lshot.py和最后的--evaluate

-c ../configs/mini/softmax/resnet18.config --proto-rect True --lmd 1.0 --tune-lmd True  --data ../data/images --lshot --log-file /LaplacianShot.log

然后复制到pycharm -> run -> Edit Configurations
在这里插入图片描述

项目目录

在这里插入图片描述

  • data下的image放了miniImageNet
    在这里插入图片描述

链接:https://pan.baidu.com/s/1rMy5aoDLS20_hYc41taePg
提取码:ldzv

  • tmp下面保存了作者的模型
    在这里插入图片描述

链接:https://pan.baidu.com/s/1DgIJnaoZxbaaLYzGU_GoIQ
提取码:ltyd

  • model下面的文件,放到results在这里插入图片描述

评估模式

run.sh中复制命令行

python ./src/train_lshot.py -c $config --proto-rect $protrec --lmd $lmd --tune-lmd $tune  --data $datapath --lshot --log-file /LaplacianShot.log --evaluate

用值替换$变量,去掉最前的python ./src/train_lshot.py

-c ../configs/mini/softmax/resnet18.config --proto-rect True --lmd 1.0 --tune-lmd True  --data ../data/images --lshot --log-file /LaplacianShot.log

然后复制到pycharm -> run -> Edit Configurations

main()主要逻辑

注意,代码用resnet18产生512维向量

  • 首先设置损失函数和优化器
# 损失函数
criterion = nn.CrossEntropyLoss().cuda()

# 优化器
optimizer = get_optimizer(model)
  • 从断点处加载模型
# 从断点处恢复
if os.path.isfile(args.save_path + '/checkpoint.pth.tar') and args.resume == '':
    args.resume = args.save_path + '/checkpoint.pth.tar'  # '../results/mini/softmax/resnet18/checkpoint.pth.tar'
# 从最后的断点处开始训练
if args.resume:  # path to latest checkpoint -> '../results/mini/softmax/resnet18/checkpoint.pth.tar'
    if os.path.isfile(args.resume):
        log.info("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']  # 90
        best_prec1 = checkpoint['best_prec1']  # 0.6250933455824852
        # scheduler.load_state_dict(checkpoint['scheduler'])
        model.load_state_dict(checkpoint['state_dict'])  # 模型参数
        optimizer.load_state_dict(checkpoint['optimizer'])  # 优化器参数
        log.info("=> loaded checkpoint '{}' (epoch {})"
                 .format(args.resume, checkpoint['epoch']))
    else:
        log.info('[Attention]: Do not find checkpoint {}'.format(args.resume))
  • 获得train_loadertest_loader
  • 设置回调函数
scheduler = get_scheduler(len(train_loader), optimizer)
  • 继续训练,epoch范围是range(args.start_epoch, args.epochs)
    1. train_loader训练模型,普通训练
    2. 回调函数调整学习率
    3. 用`val_loader验证模型,使用episode
      1. 如果精度最高,则保存model_bets.pth.tar
      2. 保存断点模型checkpoint.pth.tar
  • 最后评估模型
    • 微调 λ \lambda λ,用train_loader 1 s h o t 1shot 1shot 5 s h o t 5shot 5shot上找到使精度最高的 λ 1 \lambda_1 λ1 λ 5 \lambda_5 λ5
    • 使用最后一个断点的模型
      • out_mean, fc_out_mean, out_dict, fc_out_dict = extract_feature(train_loader, val_loader, model, 'last')
        • out_mean512维向量的均值
        • fc_out_mean64维向量的均值
        • out_dict标签对应列表、其中包含600个512维度向量
        • fc_out_dict标签对应列表、其中包含600个64维向量
      • 使用3种标准化得到的置信区间:均值和一半区间长度
        • accuracy_info_shot1 = meta_evaluate(out_dict, out_mean, 1)
        • accuracy_info_shot5 = meta_evaluate(out_dict, out_mean, 5)
    • 使用最好的模型
      • 同上
      • out_mean, fc_out_mean, out_dict, fc_out_dict = extract_feature(train_loader, val_loader, model, 'best')
      • accuracy_info_shot1 = meta_evaluate(out_dict, out_mean, 1)
      • accuracy_info_shot5 = meta_evaluate(out_dict, out_mean, 5)

中文注释代码

https://gitee.com/Lost_star/LaplacianShot-master

Guess you like

Origin blog.csdn.net/qq_37252519/article/details/120042139