Deep Graph Infomax: Code Reading Notes

ICLR 2019.
PS: I found the method section of the paper hard to follow on its own; reading the code directly makes things much clearer.

1. An unsupervised training scheme whose core idea is maximizing mutual information: the whole-graph summary should agree strongly with the local representations of positive samples and weakly with those of negative samples. In practice this is implemented as a binary cross-entropy loss that pushes discriminator scores for positive pairs toward 1 and for negative pairs toward 0.
2. Overall pipeline: train node embeddings by maximizing mutual information (unsupervised), then train a linear classifier on the frozen embeddings (supervised) to perform the downstream node classification task. A schematic of both stages is sketched below.
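The two stages map to code roughly as follows. This is a schematic sketch, not the repo's exact evaluation script: it reuses names that appear later in these notes (model, features, hid_units, sp_adj), while the linear head log_reg, the split idx_train, the class count nb_classes, and the integer labels tensor are assumed pieces.

import torch
import torch.nn as nn

# Stage 1 (unsupervised): train `model` (the DGI module below) by maximizing
# mutual information -- see the training code in item 5.
# Stage 2 (supervised): freeze the encoder and fit a linear classifier on the
# detached embeddings.
embeds, _ = model.embed(features, sp_adj if sparse else adj, sparse, None)
log_reg = nn.Linear(hid_units, nb_classes)      # assumed linear head
opt = torch.optim.Adam(log_reg.parameters(), lr=0.01)  # assumed settings
xent = nn.CrossEntropyLoss()
for _ in range(100):                            # assumed number of steps
    opt.zero_grad()
    loss = xent(log_reg(embeds[0, idx_train]), labels[idx_train])
    loss.backward()
    opt.step()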
3. DGI definition:

class DGI(nn.Module):
    # n_in = ft_size, n_h = hid_units, activation = nonlinearity (names from the training script)
    def __init__(self, n_in, n_h, activation):
        super(DGI, self).__init__()
        self.gcn = GCN(n_in, n_h, activation)
        self.read = AvgReadout()  # readout function; here it is simply the mean of all node representations

        self.sigm = nn.Sigmoid()

        self.disc = Discriminator(n_h)  # discriminator, defined as a bilinear function

    def forward(self, seq1, seq2, adj, sparse, msk, samp_bias1, samp_bias2):
        # in the reference script msk, samp_bias1 and samp_bias2 are all None
        h_1 = self.gcn(seq1, adj, sparse)  # node embeddings of the original graph (positives)

        c = self.read(h_1, msk)
        c = self.sigm(c)  # c is the whole-graph summary

        h_2 = self.gcn(seq2, adj, sparse)  # node embeddings of the corrupted graph (negatives)

        ret = self.disc(c, h_1, h_2, samp_bias1, samp_bias2)  # bilinear discriminator scores for (c, h_1) and (c, h_2)

        return ret

    # Detach the return variables
    def embed(self, seq, adj, sparse, msk):
        h_1 = self.gcn(seq, adj, sparse)
        c = self.read(h_1, msk)

        return h_1.detach(), c.detach()  # detach the tensors from the computation graph so they take no part in backprop
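The DGI module relies on three helpers: GCN, AvgReadout and Discriminator. The discriminator is shown next; for the other two, here is a minimal sketch of what they do. The class names come from the repo, but the layer details are my reconstruction, not the exact reference code:

import torch
import torch.nn as nn

class AvgReadout(nn.Module):
    # Mean-pooling readout: average all node embeddings into one summary vector.
    def forward(self, seq, msk):
        # seq: (batch, nb_nodes, n_h); msk: optional (batch, nb_nodes) node mask
        if msk is None:
            return torch.mean(seq, 1)
        msk = torch.unsqueeze(msk, -1)
        return torch.sum(seq * msk, 1) / torch.sum(msk, 1)

class GCN(nn.Module):
    # One graph-convolution layer: linear map, then aggregation with the
    # (normalized) adjacency matrix, then a nonlinearity.
    def __init__(self, in_ft, out_ft, act):
        super(GCN, self).__init__()
        self.fc = nn.Linear(in_ft, out_ft, bias=False)
        self.act = nn.PReLU() if act == 'prelu' else nn.ReLU()
        self.bias = nn.Parameter(torch.zeros(out_ft))

    def forward(self, seq, adj, sparse=False):
        seq_fts = self.fc(seq)                 # (batch, nb_nodes, out_ft)
        if sparse:
            # a sparse adj has no batch dim: (nb_nodes, nb_nodes)
            out = torch.unsqueeze(torch.spmm(adj, torch.squeeze(seq_fts, 0)), 0)
        else:
            out = torch.bmm(adj, seq_fts)      # dense adj: (batch, n, n)
        return self.act(out + self.bias)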

4. Discriminator definition:

class Discriminator(nn.Module):
    def __init__(self, n_h):
        super(Discriminator, self).__init__()
        self.f_k = nn.Bilinear(n_h, n_h, 1)  # bilinear layer x_1^T W x_2 + b; one scalar score per node, measuring how well a local embedding matches the summary

        for m in self.modules():
            self.weights_init(m)

    def weights_init(self, m):
        if isinstance(m, nn.Bilinear):
            torch.nn.init.xavier_uniform_(m.weight.data)
            if m.bias is not None:
                m.bias.data.fill_(0.0)

    def forward(self, c, h_pl, h_mi, s_bias1=None, s_bias2=None):  # c should score high against h_pl (positives) and low against h_mi (negatives)

        c_x = torch.unsqueeze(c, 1)
        c_x = c_x.expand_as(h_pl)

        sc_1 = torch.squeeze(self.f_k(h_pl, c_x), 2)

        sc_2 = torch.squeeze(self.f_k(h_mi, c_x), 2)

        if s_bias1 is not None:
            sc_1 += s_bias1
        if s_bias2 is not None:
            sc_2 += s_bias2

        logits = torch.cat((sc_1, sc_2), 1)

        return logits
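A quick shape check makes the output layout concrete. The sizes here are illustrative, chosen to match the torch.Size([1, 5416]) mentioned in the training section (2708 Cora nodes, 512-dimensional embeddings):

import torch

disc = Discriminator(512)
c    = torch.rand(1, 512)         # graph summary
h_pl = torch.rand(1, 2708, 512)   # positive node embeddings
h_mi = torch.rand(1, 2708, 512)   # negative node embeddings
logits = disc(c, h_pl, h_mi)
print(logits.shape)               # torch.Size([1, 5416]): positives first, then negatives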

5. Training:
Positive samples: the original features.
Negative samples: the same features with their node rows shuffled (see the snippet below).
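In the reference script the corruption function is just a random permutation of the feature rows, roughly like this (a sketch of the reference code from memory):

import numpy as np

idx = np.random.permutation(nb_nodes)   # random node permutation
shuf_fts = features[:, idx, :]          # same features, attached to the wrong nodes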

model = DGI(ft_size, hid_units, nonlinearity)  # create the model
b_xent = nn.BCEWithLogitsLoss()

lbl_1 = torch.ones(batch_size, nb_nodes)   # labels for positive samples
lbl_2 = torch.zeros(batch_size, nb_nodes)  # labels for negative samples
lbl = torch.cat((lbl_1, lbl_2), 1)         # shape: torch.Size([1, 5416]) on Cora

logits = model(features, shuf_fts, sp_adj if sparse else adj, sparse, None, None, None)
# shuf_fts is the row-shuffled copy of features
loss = b_xent(logits, lbl)
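Putting it together, one optimization step looks roughly like this. This is a sketch: Adam with lr 0.001 and no weight decay matches the reference setup as far as I recall, but treat the hyperparameters and nb_epochs as assumptions:

optimiser = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0)

for epoch in range(nb_epochs):           # nb_epochs assumed defined, e.g. 10000
    model.train()
    optimiser.zero_grad()

    # regenerate negatives each epoch with a fresh permutation (see above)
    idx = np.random.permutation(nb_nodes)
    shuf_fts = features[:, idx, :]

    logits = model(features, shuf_fts, sp_adj if sparse else adj,
                   sparse, None, None, None)
    loss = b_xent(logits, lbl)           # positives -> 1, negatives -> 0
    loss.backward()
    optimiser.step()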

Reposted from blog.csdn.net/ptxx_p/article/details/124027761