《Few-Shot Learning with Graph Neural Networks》代码理解

参考了解读: Few-shot learning with Graph Neural Networks - Yuan Z的文章 - 知乎

代码使用 o m n i g l o t omniglot omniglot数据集,以 5 w a y − 1 s h o t 5way-1shot 5way1shot为例,一个 e p i s o d e episode episode只有一张 q u e r y query query,一个 b a t c h batch batch中有 300 300 300 e p i s o d e episode episode


main.py中的第144行开始训练的迭代

数据

首先加载数据

# main.py line149
data = train_loader.get_task_batch(batch_size=args.batch_size, n_way=args.train_N_way,
                                   unlabeled_extra=args.unlabeled_extra, num_shots=args.train_N_shots,
                                   cuda=args.cuda, variable=True)
[batch_x, label_x, _, _, batches_xi, labels_yi, oracles_yi, hidden_labels] = data

在一般的情况下(不是半监督学习),我们只需要用到

  • batch_x查询的图片(300,1,28,28)
  • label_x查询图片的 o n e − h o t one-hot onehot标签(300,5)
  • batches_xi支持集[(300,1,28,28),(300,1,28,28),(300,1,28,28),(300,1,28,28),(300,1,28,28)]
  • labels_yi支持集 o n e − h o t one-hot onehot标签[(300,5),(300,5),(300,5),(300,5),(300,15)]

调用train_batch进行训练

#main.py line157
loss_d_metric = train_batch(model=[enc_nn, metric_nn, softmax_module],
                                    data=[batch_x, label_x, batches_xi, labels_yi, oracles_yi, hidden_labels])

训练

model中有三个模型

  • enc_nn对图片进行特征提取
  • metric_nn接下来的重点
  • softmax_module做softmax的

先用enc_nn模型提取图片的特征,[-1]是因为enc_nn模型有多个输出,我们只需要最后一个 64 64 64维的输出

# main.py line94
z = enc_nn(batch_x)[-1] # 查询的图片的特征
zi_s = [enc_nn(batch_xi)[-1] for batch_xi in batches_xi] # 支持集的图片的特征

然后调用模型,输出预测结果

# main.py line97
out_metric, out_logits = metric_nn(inputs=[z, zi_s, labels_yi, oracles_yi, hidden_labels])

然后对out_logits进行softmax

# main.py 98
 logsoft_prob = softmax_module.forward(out_logits)

然后和标签计算损失

# main.py 101
label_x_numpy = label_x.cpu().data.numpy()
formatted_label_x = np.argmax(label_x_numpy, axis=1)  # one-hot向量转换为标量
formatted_label_x = Variable(torch.LongTensor(formatted_label_x))
if args.cuda:
    formatted_label_x = formatted_label_x.cuda()
loss = F.nll_loss(logsoft_prob, formatted_label_x) #计算损失,第一个参数是模型的结果,第二个是正确答案
loss.backward()

return loss

这就是大致流程,然后详细说下metric_nn

metric_nn

在这里插入图片描述
打断点进去,发现调用了self.gnn_iclr_forward(z, zi_s, labels_yi)
继续运行进去
在这里插入图片描述
再回顾一下

  • z查询的特征(300,64)
  • zi_s支持集图片的特征[(300,64),(300,64),(300,64),(300,64),(300,64)]
  • labels_yi支持集的标签[(300,5),(300,5),(300,5),(300,5),(300,5)]

在这里插入图片描述
然后就调用了gnn.obj,也就是main.py中的GNN_nl_omniglot
在这里插入图片描述
可以看到载for循环里调用了WcomputeGconv
W_init形状为(300,6,6,1),其中(6,6)是单位矩阵

Wcompute

Wcompute用于计算节点之间的相似度
在这里插入图片描述

Gconv

先调用gmul产生新特征,然后使用全连接网络进行降维和soft

gmul

用于产生新特征
在这里插入图片描述

流程图

在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/qq_37252519/article/details/119653516