ICLR 2019。
ps:我觉得论文看method看不大懂,不如直接去看代码最清楚。
1.一种无监督的训练方式,核心:最大化互信息。(全图的信息与正样本局部信息最大化,全图的信息与负样本局部信息最小化。)
2.大致流程:通过最大化互信息训练图嵌入结果(无监督),训练线性分类器(有监督)完成图分类任务。
3.DGI定义:
class DGI(nn.Module):
    """Deep Graph Infomax: GCN encoder + mean readout + bilinear discriminator."""

    def __init__(self, n_in, n_h, activation):
        # n_in: input feature size (ft_size); n_h: hidden units; activation: nonlinearity
        super(DGI, self).__init__()
        self.gcn = GCN(n_in, n_h, activation)
        # Readout: mean of all node embeddings -> whole-graph summary vector
        self.read = AvgReadout()
        self.sigm = nn.Sigmoid()
        # Discriminator: a bilinear form scoring (summary, node-embedding) pairs
        self.disc = Discriminator(n_h)

    def forward(self, seq1, seq2, adj, sparse, msk, samp_bias1, samp_bias2):
        # seq1: original features (positives); seq2: shuffled features (negatives).
        # msk, samp_bias1, samp_bias2 are all None in the reference training script.
        h_1 = self.gcn(seq1, adj, sparse)
        c = self.read(h_1, msk)
        c = self.sigm(c)  # c: graph-level summary
        h_2 = self.gcn(seq2, adj, sparse)
        # Bilinear discriminator scores for (c, h_1) [positive] and (c, h_2) [negative]
        return self.disc(c, h_1, h_2, samp_bias1, samp_bias2)

    def embed(self, seq, adj, sparse, msk):
        # Detached node embeddings + summary for the downstream supervised classifier;
        # detach() cuts them out of the autograd graph so no gradients flow back.
        h_1 = self.gcn(seq, adj, sparse)
        c = self.read(h_1, msk)
        return h_1.detach(), c.detach()
4.Discriminator定义:
class Discriminator(nn.Module):
    """Bilinear discriminator: scores how well each node embedding matches the graph summary."""

    def __init__(self, n_h):
        super(Discriminator, self).__init__()
        # Bilinear layer computes x1^T W x2 + b -> a single score per (node, summary) pair
        self.f_k = nn.Bilinear(n_h, n_h, 1)
        for module in self.modules():
            self.weights_init(module)

    def weights_init(self, m):
        # Xavier-uniform weight init with zero bias for the bilinear layer
        if isinstance(m, nn.Bilinear):
            torch.nn.init.xavier_uniform_(m.weight.data)
            if m.bias is not None:
                m.bias.data.fill_(0.0)

    def forward(self, c, h_pl, h_mi, s_bias1=None, s_bias2=None):
        # c: graph summary; h_pl: positive node embeddings; h_mi: negative (shuffled) ones.
        # The loss pushes c close to h_pl scores and away from h_mi scores.
        c_x = torch.unsqueeze(c, 1).expand_as(h_pl)  # broadcast summary to every node
        sc_1 = torch.squeeze(self.f_k(h_pl, c_x), 2)  # positive scores
        sc_2 = torch.squeeze(self.f_k(h_mi, c_x), 2)  # negative scores
        if s_bias1 is not None:
            sc_1 = sc_1 + s_bias1
        if s_bias2 is not None:
            sc_2 = sc_2 + s_bias2
        # Concatenate so logits line up with labels [ones | zeros]
        return torch.cat((sc_1, sc_2), 1)
5.训练:
正样例:features
负样例:打乱了顺序的features
# Unsupervised training step: BCE-with-logits over positive/negative node scores.
model = DGI(ft_size, hid_units, nonlinearity)  # build the DGI model
b_xent = nn.BCEWithLogitsLoss()  # binary cross-entropy applied directly to raw logits
lbl_1 = torch.ones(batch_size, nb_nodes)  # labels for positive samples (original features)
lbl_2 = torch.zeros(batch_size, nb_nodes)  # labels for negative samples (shuffled features)
lbl = torch.cat((lbl_1, lbl_2), 1)  # e.g. shape torch.Size([1, 5416]), i.e. 2 * nb_nodes
logits = model(features, shuf_fts, sp_adj if sparse else adj, sparse, None, None, None)
# shuf_fts: features with their row order shuffled -- serves as the negative samples
loss = b_xent(logits, lbl)  # matches logits layout [positive | negative] from the discriminator