参考了解读: Few-shot learning with Graph Neural Networks - Yuan Z的文章 - 知乎
代码使用 o m n i g l o t omniglot omniglot数据集,以 5 w a y − 1 s h o t 5way-1shot 5way−1shot为例,一个 e p i s o d e episode episode只有一张 q u e r y query query,一个 b a t c h batch batch中有 300 300 300个 e p i s o d e episode episode
在
main.py
中的第144行开始训练的迭代
数据
首先加载数据
# main.py line149
data = train_loader.get_task_batch(batch_size=args.batch_size, n_way=args.train_N_way,
unlabeled_extra=args.unlabeled_extra, num_shots=args.train_N_shots,
cuda=args.cuda, variable=True)
[batch_x, label_x, _, _, batches_xi, labels_yi, oracles_yi, hidden_labels] = data
在一般的情况下(不是半监督学习),我们只需要用到
batch_x
查询的图片(300,1,28,28)
label_x
查询图片的 o n e − h o t one-hot one−hot标签(300,5)
batches_xi
支持集[(300,1,28,28),(300,1,28,28),(300,1,28,28),(300,1,28,28),(300,1,28,28)]
labels_yi
支持集 o n e − h o t one-hot one−hot标签[(300,5),(300,5),(300,5),(300,5),(300,15)]
调用train_batch
进行训练
#main.py line157
loss_d_metric = train_batch(model=[enc_nn, metric_nn, softmax_module],
data=[batch_x, label_x, batches_xi, labels_yi, oracles_yi, hidden_labels])
训练
model中有三个模型
enc_nn
对图片进行特征提取metric_nn
接下来的重点softmax_module
做softmax的
先用enc_nn
模型提取图片的特征,[-1]
是因为enc_nn
模型有多个输出,我们只需要最后一个 64 64 64维的输出
# main.py line94
z = enc_nn(batch_x)[-1] # 查询的图片的特征
zi_s = [enc_nn(batch_xi)[-1] for batch_xi in batches_xi] # 支持集的图片的特征
然后调用模型,输出预测结果
# main.py line97
out_metric, out_logits = metric_nn(inputs=[z, zi_s, labels_yi, oracles_yi, hidden_labels])
然后对out_logits
进行softmax
# main.py 98
logsoft_prob = softmax_module.forward(out_logits)
然后和标签计算损失
# main.py 101
label_x_numpy = label_x.cpu().data.numpy()
formatted_label_x = np.argmax(label_x_numpy, axis=1) # one-hot向量转换为标量
formatted_label_x = Variable(torch.LongTensor(formatted_label_x))
if args.cuda:
formatted_label_x = formatted_label_x.cuda()
loss = F.nll_loss(logsoft_prob, formatted_label_x) #计算损失,第一个参数是模型的结果,第二个是正确答案
loss.backward()
return loss
这就是大致流程,然后详细说下metric_nn
metric_nn
打断点进去,发现调用了self.gnn_iclr_forward(z, zi_s, labels_yi)
继续运行进去
再回顾一下
z
查询的特征(300,64)
zi_s
支持集图片的特征[(300,64),(300,64),(300,64),(300,64),(300,64)]
labels_yi
支持集的标签[(300,5),(300,5),(300,5),(300,5),(300,5)]
然后就调用了gnn.obj
,也就是main.py
中的GNN_nl_omniglot
可以看到载for
循环里调用了Wcompute
和Gconv
W_init
形状为(300,6,6,1)
,其中(6,6)是单位矩阵
Wcompute
Wcompute
用于计算节点之间的相似度
Gconv
先调用gmul
产生新特征,然后使用全连接网络进行降维和soft
gmul
用于产生新特征