使用DGL的GATConv层,居然意外的出现如下错误:
dgl._ffi.base.DGLError: Expect number of features to match number of nodes (len(u)). Got 2397 and 799 instead.
注意到799是节点数,而2390刚好是799的3倍,这个3恰好又是num_heads的数值。因此
GATConv的返回值的shape为: ( N , H , M ) (N,H,M) (N,H,M) ,其中 N N N 是节点个数, H H H 是特征长度,而 M M M是头的数目。
当不做任何处理,DGL会默认对返回的矩阵做reshape,reshape的目标是(-1,H) 于是矩阵的行数就变成了 N × M N\times M N×M 了,此时就不对了。
解决方法:对GATConv的返回值执行一次flatten:
def forward(g):
···
for layer in self.layers:
pkt_length_matrix = layer(g,pkt_length_matrix.to(th.device(self.device)))
arv_time_matrix = layer(g,arv_time_matrix.to(th.device(self.device)))
if self.layer_type =='GAT':
pkt_length_matrix = th.flatten(pkt_length_matrix,1)
arv_time_matrix= th.flatten(arv_time_matrix,1)
···
同时,下一层GATConv的in_feat设置为上一层的out_feat × \times × num_heads。
这个就可以了。