源码地址:https://github.com/Clarifai/few-shot-ctm
无法直接运行代码
看代码有人也有这个问题Why i can’t run the code with default setting,所以要稍微修改下代码。
- 修改
tools/general_utils.py
line276为
yaml_cfg = AttrDict(yaml.load(f, Loader=yaml.FullLoader))
main
代码line 28中使用了配置文件configs/demo/mini/20way_1shot.yaml
,但是会有异常抛出。所以我们将配置文件写到core/config.py
的Config类属性中
。然后将line 28
改为:
opts = Config(None)
训练流程
从main.py line 129开始对模型进行训练
学习策略
- 优化器定义在main.py line49
根据不同的参数,分别使用了adam
、sgd
、rmsprop
三种优化器 - 学习率调整定义在main.py line59
根据不同的参数,分别使用了MultiStepLR
、ExponentialLR
2种更新学习率的策略
MultiStepLR
每到一个milestones区间,学习率× gammaExponentialLR
指数衰减学习率,新学习率 = 旧*(gama ^ epoch)
计算模型loss
根据论文和解读:https://blog.csdn.net/qq_36104364/article/details/106363521,我们可以看论文是如何实现
Y = M ( r ( S ) ⊙ p , r ( Q ) ⊙ p ) , Y = { y i j } Y=\mathcal{M}(\boldsymbol{r}(\mathcal{S}) \odot p, \boldsymbol{r}(\mathcal{Q}) \odot p), \quad Y=\left\{y_{i j}\right\} Y=M(r(S)⊙p,r(Q)⊙p),Y={
yij}
其中 S S S是支持集的特征向量, r ( S ) \boldsymbol{r}(\mathcal{S}) r(S)是对支持集 S S S的特征向量进行 r e s h a p e r reshaper reshaper, Q Q Q是查询集的特征向量, r ( Q ) \boldsymbol{r}(\mathcal{Q}) r(Q)是对查询集 Q Q Q的特征向量进行 r e s h a p e r reshaper reshaper, M \mathcal{M} M是距离函数
代码在core/model.py line415,根据注释函数主要包括三部分,以 20 w a y − 1 s h o t − 8 w a y 20way - 1shot - 8way 20way−1shot−8way为例:
- 特征提取,使用
repnet
进行特征提取:core/model.py line424
# support_sz (25), c (64), d (19), d (19)
support_xf_ori = self.repnet(support_x.view(batch_sz*support_sz, -1, _d, _d)) # torch.Size([1, 20, 3, 84, 84]) -> torch.Size([20, 3, 84, 84]) -> torch.Size([20, 64, 19, 19])
# query_sz (75), c (64), d (19), d (19)
query_xf_ori = self.repnet(query_x.view(batch_sz*query_sz, -1, _d, _d))# torch.Size([1, 160, 3, 84, 84]) -> torch.Size([160, 3, 84, 84]) -> torch.Size([160, 64, 19, 19])
Concentrator
:core/model.py line434
mp = self.main_component(support_xf_reshape) # ([20, 40, 19, 19]) # 5(n_way), 64, 3, 3
projection
:core/model.py line442
_input_P = mp.view(1, -1, mp.size(2), mp.size(3)) # ([1, 800, 19, 19])
P = self.projection(_input_P) # 1, 64, 3, 3
P = F.softmax(P, dim=1)
reshaper
:core/model.py line457对support和query的特征进行reshaper
v = self.reshaper(support_xf_ori)
query = self.reshaper(query_xf_ori) # 75, 64, 3, 3
- 相乘得到结果:core/model.py line458对support和query的特征进行reshaper
query = torch.matmul(query, P)
其他
- 展示模型的训练loss:main.py line161
- 在验证集上验证模型,并且保存最好精度的模型:main.py line169