ctm代码笔记《Finding Task-Relevant Features for Few-Shot Learning by Category Traversal》

源码地址:https://github.com/Clarifai/few-shot-ctm

无法直接运行代码

看代码有人也有这个问题Why i can’t run the code with default setting,所以要稍微修改下代码。

  1. 修改tools/general_utils.pyline276
yaml_cfg = AttrDict(yaml.load(f, Loader=yaml.FullLoader))
  1. main代码line 28中使用了配置文件configs/demo/mini/20way_1shot.yaml,但是会有异常抛出。所以我们将配置文件写到core/config.pyConfig类属性中。然后将line 28
    改为:
opts = Config(None)

训练流程

main.py line 129开始对模型进行训练

学习策略

  1. 优化器定义在main.py line49
    根据不同的参数,分别使用了adamsgdrmsprop 三种优化器
  2. 学习率调整定义在main.py line59
    根据不同的参数,分别使用了MultiStepLRExponentialLR 2种更新学习率的策略
  • MultiStepLR每到一个milestones区间,学习率× gamma
  • ExponentialLR 指数衰减学习率,新学习率 = 旧*(gama ^ epoch)

计算模型loss

根据论文和解读:https://blog.csdn.net/qq_36104364/article/details/106363521,我们可以看论文是如何实现
Y = M ( r ( S ) ⊙ p , r ( Q ) ⊙ p ) , Y = { y i j } Y=\mathcal{M}(\boldsymbol{r}(\mathcal{S}) \odot p, \boldsymbol{r}(\mathcal{Q}) \odot p), \quad Y=\left\{y_{i j}\right\} Y=M(r(S)p,r(Q)p),Y={ yij}
其中 S S S是支持集的特征向量, r ( S ) \boldsymbol{r}(\mathcal{S}) r(S)是对支持集 S S S的特征向量进行 r e s h a p e r reshaper reshaper Q Q Q是查询集的特征向量, r ( Q ) \boldsymbol{r}(\mathcal{Q}) r(Q)是对查询集 Q Q Q的特征向量进行 r e s h a p e r reshaper reshaper M \mathcal{M} M是距离函数
在这里插入图片描述
代码在core/model.py line415,根据注释函数主要包括三部分,以 20 w a y − 1 s h o t − 8 w a y 20way - 1shot - 8way 20way1shot8way为例:

  1. 特征提取,使用repnet进行特征提取:core/model.py line424
# support_sz (25), c (64), d (19), d (19)
support_xf_ori = self.repnet(support_x.view(batch_sz*support_sz, -1, _d, _d))  # torch.Size([1, 20, 3, 84, 84]) -> torch.Size([20, 3, 84, 84]) -> torch.Size([20, 64, 19, 19])
# query_sz (75), c (64), d (19), d (19)
query_xf_ori = self.repnet(query_x.view(batch_sz*query_sz, -1, _d, _d))# torch.Size([1, 160, 3, 84, 84]) -> torch.Size([160, 3, 84, 84]) -> torch.Size([160, 64, 19, 19])
  1. Concentrator:core/model.py line434
mp = self.main_component(support_xf_reshape)         #    ([20, 40, 19, 19])      # 5(n_way), 64, 3, 3
  1. projection:core/model.py line442
_input_P = mp.view(1, -1, mp.size(2), mp.size(3))   # ([1, 800, 19, 19])
P = self.projection(_input_P)   # 1, 64, 3, 3
P = F.softmax(P, dim=1)
  1. reshaper:core/model.py line457对support和query的特征进行reshaper
v = self.reshaper(support_xf_ori)
query = self.reshaper(query_xf_ori)     # 75, 64, 3, 3
  1. 相乘得到结果:core/model.py line458对support和query的特征进行reshaper
query = torch.matmul(query, P)
  1. 计算分数:core/model.py line476
  2. 输出:core/model.py line479

其他

  1. 展示模型的训练loss:main.py line161
  2. 在验证集上验证模型,并且保存最好精度的模型:main.py line169

おすすめ

転載: blog.csdn.net/qq_37252519/article/details/120946015