# send_recv




'''
Send messages through all edges, then update all nodes.
DGLGraph.update_all(message_func='default', reduce_func='default', apply_node_func='default')

message_func: message function applied on the edges
reduce_func: reduce function applied on the nodes
apply_node_func: apply function applied on the nodes
'''


'''
DGLGraph.send(edges='__ALL__', message_func='default')
edges:
    int: one edge, given by its edge id
    pair of int: one edge, given by its endpoints
    int iterable / tensor: multiple edges, given by their edge ids
    pair of int iterable / pair of tensor: multiple edges, given by their endpoints

Produces messages on the edges, which can later be fetched from the destination node's mailbox.

'''


'''
DGLGraph.recv(v='__ALL__', reduce_func='default', apply_node_func='default', inplace=False)

'''

import warnings
warnings.filterwarnings("ignore")
import torch as th
import dgl
# Build a small example graph: three nodes, each carrying a one-element
# feature "x".
g = dgl.DGLGraph()
g.add_nodes(3)
g.ndata["x"] = th.tensor([[5.0], [6.0], [7.0]])

# Edges 0->1 and 1->2 are added from plain Python lists ...
g.add_edges([0, 1], [1, 2])

# ... and edge 0->2 is added from tensors, showing both accepted input forms.
src, dst = th.tensor([0]), th.tensor([2])
g.add_edges(src, dst)

print("ndata", g.ndata["x"])


def send_source(edges):
    """Message function: send each edge's source-node feature as message "m".

    Args:
        edges: Edge batch handed in by DGL. Only ``edges.src["x"]`` and
            ``edges.dst["x"]`` are read here.

    Returns:
        dict: ``{"m": edges.src["x"]}`` -- the source-node features, passed
        through unchanged.
    """
    # Source-node features for the edges in this batch. In the recorded run
    # below (messages sent on all three edges) the shape is [3, 1].
    print("src", edges.src["x"].shape, edges.src["x"])
    # Destination-node features for the same edges (also [3, 1] in that run).
    print("dst", edges.dst["x"].shape, edges.dst["x"])

    return {"m": edges.src["x"]}

# Register send_source as the graph's message function; presumably this is
# what message_func='default' resolves to in later send calls -- confirm
# against the DGL version in use.
g.register_message_func(send_source)

'''
ndata tensor([[5.],
[6.],
[7.]])
src torch.Size([3, 1]) tensor([[5.],
[6.],
[5.]])
dst torch.Size([3, 1]) tensor([[6.],
[7.],
[7.]])

'''




def simple_reduce(nodes):
    """Reduce function: sum the incoming "m" messages into node feature "x".

    Args:
        nodes: Node batch handed in by DGL. Reads ``nodes.data["x"]`` (for
            printing only) and ``nodes.mailbox["m"]``.

    Returns:
        dict: ``{"x": nodes.mailbox["m"].sum(1)}`` -- the messages summed
        over dimension 1.
    """
    # Current features of the nodes in this batch.
    print("data_nodes", nodes.data["x"])

    # The mailbox stacks all incoming messages along the second dimension.
    # In the recorded run below its shape is [1, 1, 1] for the node with one
    # incoming edge and [1, 2, 1] for the node with two -- presumably
    # (nodes in batch, messages per node, feature size).
    inbox = nodes.mailbox["m"]
    print("mailbox", inbox.shape, inbox)

    # Sum over the message dimension once and reuse it for both the log line
    # and the returned feature.
    total = inbox.sum(1)
    print("sum", total)

    return {"x": total}

# Register simple_reduce as the graph's reduce function; presumably this is
# what reduce_func='default' resolves to in the recv call below -- confirm
# against the DGL version in use.
g.register_reduce_func(simple_reduce)



# Send on every edge of the graph, then receive on every node.
# Order matters: recv consumes the messages produced by send.
g.send(g.edges())
g.recv(g.nodes())
# Show the node features after the send/recv round.
print("ndata",g.ndata["x"])


'''

data_nodes tensor([[6.]])
mailbox torch.Size([1, 1, 1]) tensor([[[5.]]])
sum tensor([[5.]])
data_nodes tensor([[7.]])
mailbox torch.Size([1, 2, 1]) tensor([[[6.],
[5.]]])
sum tensor([[11.]])
ndata tensor([[ 0.],
[ 5.],
[11.]])

'''

# You may also like
#
# Reposted from www.cnblogs.com/hapyygril/p/11586319.html