DGL 的GATConv报错:Expect number of features to match number of nodes (len(u)). Got 2397 and 799 instead

使用DGL的GATConv层,居然意外的出现如下错误:

dgl._ffi.base.DGLError: Expect number of features to match number of nodes (len(u)). Got 2397 and 799 instead.

注意到799是节点数,而2390刚好是799的3倍,这个3恰好又是num_heads的数值。因此
GATConv的返回值的shape为: ( N , H , M ) (N,H,M) (N,H,M) ,其中 N N N 是节点个数, H H H 是特征长度,而 M M M是头的数目。
当不做任何处理,DGL会默认对返回的矩阵做reshape,reshape的目标是(-1,H) 于是矩阵的行数就变成了 N × M N\times M N×M 了,此时就不对了。

解决方法:对GATConv的返回值执行一次flatten:

def forward(g):
···
        for layer in self.layers:
            pkt_length_matrix = layer(g,pkt_length_matrix.to(th.device(self.device)))
            arv_time_matrix = layer(g,arv_time_matrix.to(th.device(self.device)))
            if self.layer_type =='GAT':
                pkt_length_matrix = th.flatten(pkt_length_matrix,1)
                arv_time_matrix= th.flatten(arv_time_matrix,1)
···

同时,下一层GATConv的in_feat设置为上一层的out_feat × \times × num_heads。
这个就可以了。

猜你喜欢

转载自blog.csdn.net/jmh1996/article/details/107075205