赞
踩
reduce_function
或 apply_node
函数作用于节点本身,即传入的参数是节点信息和节点的邮箱信息。聚合函数和更新函数。
绿框为节点属性,蓝框为边的属性,节点的状态更新需要聚合各入边传递来的消息。
消息函数:接受一个参数 edges
,这是一个 dgl.EdgeBatch
的实例, 在消息传递时,它被DGL在内部生成以表示一批边。edges
有三个成员属性:src
、dst
和data
,分别用于访问源节点、目标节点和边的特征。
mailbox
,暂存消息函数发送过来的数据;node+node->mailbox或node+edge->mailbox。dgl.function.u_add_v('hu','hv','he')
聚合函数:接受一个参数 nodes
,这是一个 NodeBatch
的实例, 在消息传递时,它被DGL在内部生成以表示一批节点。 nodes
的成员属性 mailbox
可以用来访问节点收到的消息。 一些最常见的聚合操作包括 sum
、max
、min
、mean
等。聚合函数一般有2个参数,它们的类型都是字符串:
mailbox
中的字段名;dgl.function.sum('m', 'h')
等价于如下所示的对接收到消息求和的用户定义函数:import torch
def reduce_func(nodes):
return {'h': torch.sum(nodes.mailbox['m'], dim=1)}
nodes
。此函数对 聚合函数
的聚合结果进行操作, 通常在消息传递的最后一步将其与节点的特征相结合,并将输出作为节点的新特征。即将mailbox
中的数据与节点的数据合并。(1)命名空间dfl.function
中实现了常用的(内置的)消息函数和聚合函数(能够自动处理维度广播),当然也可以自定义函数。
(2)(自定义)内置消息函数:可以是一元函数(dgl支持copy
函数),也支持二元函数(dgl支持add、sub、mul、div、dot函数):
u
表示源节点,v
表示目标节点,e
表示边。hu
特征和目标节点的hv
特征求和,然后将结果保存在边的he
特征上:dgl.function.u_add_v('hu', 'hv', 'he')
;def message_func(edges):
return {'he': edges.src['hu'] + edges.dst['hv']}
(3)在DGL中,也可以在不涉及消息传递的情况下,通过 apply_edges()
单独调用逐边计算。
apply_edges()
的参数是一个消息函数。import dgl.function as fn
graph.apply_edges(fn.u_add_v('el', 'er', 'e'))
(4)消息传递的高级APIupdate_all()
:
update_all()
的参数:一个消息、聚合、更新函数(可选,也可以在外面操作,dgl不推荐在update_all
中指定更新函数)。
update_all
执行完后直接对节点特征进行操作;update_all
中指定更新函数,如函数:final f t i = 2 ∗ ∑ j ∈ N ( i ) ( f t j ∗ a i j ) \text { final } f t_{i}=2 * \sum_{j \in \mathcal{N}(i)}\left(f t_{j} * a_{i j}\right) final fti=2∗j∈N(i)∑(ftj∗aij)
def updata_all_example(graph):
# 在graph.ndata['ft']中存储结果
graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
fn.sum('m', 'ft'))
# 在update_all外调用更新函数
final_ft = graph.ndata['ft'] * 2
return final_ft
ex:graph.update_all(fn.u_mul_e('ft', 'a', 'm')
将源节点特征tf
和边特征a
相乘生成消息m
,fn.sum('m', 'ft')
再对所有消息求和来更新节点特征ft
,再乘2后得到最终结果final_ft
。调用后,中间消息m
会被清除。
关于dgl内置函数是如何优化消息传递的内存消耗和计算速度的, 详见文字描述: DGL官方文档 ; 总结来说主要是合并内核, 并行逐边运算, 减少点边拷贝等; 如update_all()
函数就是一个效率很高的接口; 如果确实需要使用apply_edges()
函数在边上保存消息, 则内存占用会非常大;
(1)一个通过对节点特征降维来减少消息维度的示例:
import torch
import torch.nn as nn
linear = nn.Parameter(torch.FloatTensor(size=(1, node_feat_dim * 2)))
def concat_message_function(edges):
return {'cat_feat': torch.cat([edges.src.ndata['feat'], edges.dst.ndata['feat']])}
g.apply_edges(concat_message_function)
g.edata['out'] = g.edata['cat_feat'] * linear
也可以将先行操作分成两部分, 即分别对源节点特征和目标节点特征进行线性变换后再相加, 即 W l × u + W r × v W_{l} \times u+W_{r} \times v Wl×u+Wr×v,其中 W = ( W l ∥ W r ) W = \left(W_{l} \| W_{r}\right) W=(Wl∥Wr),这样可能会更加优化。代码实例:
import dgl.function as fn
linear_src = nn.Parameter(torch.FloatTensor(size=(1, node_feat_dim)))
linear_dst = nn.Parameter(torch.FloatTensor(size=(1, node_feat_dim)))
out_src = g.ndata['feat'] * linear_src
out_dst = g.ndata['feat'] * linear_dst
g.srcdata.update({'out_src': out_src})
g.dstdata.update({'out_dst': out_dst})
g.apply_edges(fn.u_add_v('out_src', 'out_dst', 'out'))
这两种方法数学上等价, 但后一种方法更加高效, 因为无需再边上保存feat_src
和feat_dst
, 空间占用小, 另外加法可以直接用内置函数u_add_v
进行优化, 内置函数的效率一般比自定义函数要高。
如果用户只想更新图中部分节点,先将想处理的节点编号创建一个子图,然后对其调用update_all()
(这也是小批量处理中的常见用法)。
nid = [0, 2, 3, 6, 7, 9]
sg = g.subgraph(nid)
sg.update_all(message_func, reduce_func, apply_node_func)
常见的GNN建模做法:在消息聚合前使用边的权重,如GAT和一些GCN的变种。dgl的处理:
ex:假定下面的权重eweight
是一个形状为(E, *)的张量,E是边的数量。权重存为边的特征,即eweight
被用作边的权重(通常是一个标量)。
import dgl.function as fn
# 假定eweight是一个形状为(E, *)的张量,E是边的数量。
graph.edata['a'] = eweight
graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
fn.sum('m', 'ft'))
本质上异构图的消息传递与同构图并没有太大区别,异构图上的消息传递可以分为两个部分:
对每个关系计算和聚合消息。
对每个结点聚合来自不同关系的消息。
DGLGraph.multi_update_all(etype_dict, cross_reducer, apply_node_func=None)
:
etype_dict
: dict
类型, 键为一种关系, 值为这种关系对应的update_all()
的参数;cross_reducer
: str
类型, 表示跨类型整合函数, 来指定整合不同关系聚合结果的方式, 可以是sum, min, max, mean, stack中之一;在DGL中,对异构图进行消息传递的接口是 multi_update_all()
。 multi_update_all()
接受一个字典。这个字典的每一个键值对里,键是一种关系, 值是这种关系对应 update_all()
的参数。 multi_update_all()
还接受一个字符串来表示跨类型整合函数,来指定整合不同关系聚合结果的方式。 这个整合方式可以是 sum
、 min
、 max
、 mean
和 stack
中的一个。
import dgl.function as fn
for c_etype in G.canonical_etypes:
srctype, etype, dsttype = c_etype
Wh = self.weight[etype](feat_dict[srctype])
# 把它存在图中用来做消息传递
G.nodes[srctype].data['Wh_%s' % etype] = Wh
# 指定每个关系的消息传递函数:(message_func, reduce_func).
# 注意结果保存在同一个目标特征“h”,说明聚合是逐类进行的。
funcs[etype] = (fn.copy_u('Wh_%s' % etype, 'm'), fn.mean('m', 'h'))
# 将每个类型消息聚合的结果相加。
G.multi_update_all(funcs, 'sum')
# 返回更新过的节点特征字典
return {ntype : G.nodes[ntype].data['h'] for ntype in G.ntypes}
DGL遵循Gilmer等人提出的消息8传递框架,很多GNN模型能符合如下框架:
m
u
→
v
(
l
)
=
M
(
l
)
(
h
v
(
l
−
1
)
,
h
u
(
l
−
1
)
,
e
u
→
v
(
l
−
1
)
)
m
v
(
l
)
=
∑
u
∈
N
(
v
)
m
u
→
v
(
l
)
h
v
(
l
)
=
U
(
l
)
(
h
v
(
l
−
1
)
,
m
v
(
l
)
)
m(l)u→v=M(l)(h(l−1)v,h(l−1)u,e(l−1)u→v)m(l)v=∑u∈N(v)m(l)u→vh(l)v=U(l)(h(l−1)v,m(l)v)
其中:
如GraphSAGE可表示为:
h
N
(
v
)
k
←
Average
{
h
u
k
−
1
,
∀
u
∈
N
(
v
)
}
h
v
k
←
ReLU
(
W
k
⋅
CONCAT
(
h
v
k
−
1
,
h
N
(
v
)
k
)
)
hkN(v)← Average {hk−1u,∀u∈N(v)}hkv←ReLU(Wk⋅CONCAT(hk−1v,hkN(v)))
我们可以看到消息传递是定向(有方向)的:从一个节点u发送到另一个节点v的消息不一定与从节点v发送到相反方向的节点u的消息相同。
DGL提供了GraphSAGE的实现dgl.nn.SAGEConv
。
import dgl.function as fn class SAGEConv(nn.Module): """Graph convolution module used by the GraphSAGE model. Parameters ---------- in_feat : int Input feature size. out_feat : int Output feature size. """ def __init__(self, in_feat, out_feat): super(SAGEConv, self).__init__() # A linear submodule for projecting the input and neighbor feature to the output. self.linear = nn.Linear(in_feat * 2, out_feat) def forward(self, g, h): """Forward computation Parameters ---------- g : Graph The input graph. h : Tensor The input node feature. """ with g.local_scope(): g.ndata['h'] = h # update_all is a message passing API. g.update_all(message_func=fn.copy_u('h', 'm'), reduce_func=fn.mean('m', 'h_N')) h_N = g.ndata['h_N'] h_total = torch.cat([h, h_N], dim=1) return self.linear(h_total)
上述代码中的核心部分是g.update_all
函数,该函数收集并平均相邻特征。这里有三个概念:
fn.copy_u('h','m')
,它将名为“h”的节点特征复制为发送给邻居的消息fn.mean('m', 'h_N')
,该函数对所有接收到的消息中名为’m’的信息进行平均,并将结果保存为新的节点特征’h_N’update_all
让DGL触发所有节点和边的消息函数和聚合函数然后我们可以堆叠自己的GraphSAGE卷积层以构成多层GraphSAGE网络:
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self, in_feats, h_feats, num_classes):
super(Model, self).__init__()
self.conv1 = GraphSAGE(in_feats, h_feats)
self.conv2 = GraphSAGE(h_feats, num_classes)
def forward(self, g, in_feat):
h = self.conv1(g, in_feat)
h = F.relu(h)
h = self.conv2(g, h)
return h
train_mask
掩码为true的地方,labels
的值被保留,反之false的地方其labels
被丢弃。所以在计算交叉熵损失时,只会计算存在训练集中的点的损失,即loss = F.cross_entropy(logits[train_mask], labels[train_mask])
。
import dgl.data dataset = dgl.data.CoraGraphDataset() g = dataset[0] def train(g, model): optimizer = torch.optim.Adam(model.parameters(), lr=0.01) all_logits = [] best_val_acc = 0 best_test_acc = 0 features = g.ndata['feat'] labels = g.ndata['label'] train_mask = g.ndata['train_mask'] val_mask = g.ndata['val_mask'] test_mask = g.ndata['test_mask'] for e in range(200): # Forward logits = model(g, features) # Compute prediction pred = logits.argmax(1) # Compute loss # Note that we should only compute the losses of the nodes in the training set, # i.e. with train_mask 1. loss = F.cross_entropy(logits[train_mask], labels[train_mask]) # Compute accuracy on training/validation/test train_acc = (pred[train_mask] == labels[train_mask]).float().mean() val_acc = (pred[val_mask] == labels[val_mask]).float().mean() test_acc = (pred[test_mask] == labels[test_mask]).float().mean() # Save the best validation accuracy and the corresponding test accuracy. if best_val_acc < val_acc: best_val_acc = val_acc best_test_acc = test_acc # Backward optimizer.zero_grad() loss.backward() optimizer.step() all_logits.append(logits.detach()) if e % 5 == 0: print('In epoch {}, loss: {:.3f}, val acc: {:.3f} (best {:.3f}), test acc: {:.3f} (best {:.3f})'.format( e, loss, val_acc, best_val_acc, test_acc, best_test_acc)) model = Model(g.ndata['feat'].shape[1], 16, dataset.num_classes) train(g, model)
Using backend: pytorch NumNodes: 2708 NumEdges: 10556 NumFeats: 1433 NumClasses: 7 NumTrainingSamples: 140 NumValidationSamples: 500 NumTestSamples: 1000 Done loading data from cached files. In epoch 0, loss: 1.951, val acc: 0.114 (best 0.114), test acc: 0.103 (best 0.103) In epoch 5, loss: 1.900, val acc: 0.290 (best 0.292), test acc: 0.278 (best 0.277) In epoch 10, loss: 1.790, val acc: 0.462 (best 0.462), test acc: 0.435 (best 0.435) In epoch 15, loss: 1.614, val acc: 0.502 (best 0.502), test acc: 0.489 (best 0.489) In epoch 20, loss: 1.372, val acc: 0.548 (best 0.548), test acc: 0.529 (best 0.529) In epoch 25, loss: 1.087, val acc: 0.592 (best 0.592), test acc: 0.591 (best 0.591) In epoch 30, loss: 0.798, val acc: 0.650 (best 0.650), test acc: 0.639 (best 0.639) In epoch 35, loss: 0.547, val acc: 0.690 (best 0.690), test acc: 0.682 (best 0.682) In epoch 40, loss: 0.358, val acc: 0.710 (best 0.710), test acc: 0.721 (best 0.721) In epoch 45, loss: 0.230, val acc: 0.736 (best 0.736), test acc: 0.734 (best 0.734) In epoch 50, loss: 0.149, val acc: 0.738 (best 0.738), test acc: 0.743 (best 0.744) In epoch 55, loss: 0.099, val acc: 0.740 (best 0.740), test acc: 0.744 (best 0.743) In epoch 60, loss: 0.068, val acc: 0.742 (best 0.742), test acc: 0.743 (best 0.745) In epoch 65, loss: 0.048, val acc: 0.734 (best 0.742), test acc: 0.749 (best 0.745) In epoch 70, loss: 0.036, val acc: 0.736 (best 0.742), test acc: 0.753 (best 0.745) In epoch 75, loss: 0.028, val acc: 0.734 (best 0.742), test acc: 0.755 (best 0.745) In epoch 80, loss: 0.023, val acc: 0.738 (best 0.742), test acc: 0.757 (best 0.745) In epoch 85, loss: 0.019, val acc: 0.738 (best 0.742), test acc: 0.758 (best 0.745) In epoch 90, loss: 0.017, val acc: 0.742 (best 0.742), test acc: 0.756 (best 0.745) In epoch 95, loss: 0.015, val acc: 0.742 (best 0.742), test acc: 0.755 (best 0.745) In epoch 100, loss: 0.013, val acc: 0.742 (best 0.742), test acc: 0.755 (best 0.745) In epoch 105, loss: 0.012, val acc: 0.742 (best 0.742), test acc: 0.755 (best 0.745) In epoch 110, loss: 0.011, val acc: 0.742 (best 0.742), test acc: 0.753 (best 0.745) In epoch 115, loss: 0.010, val acc: 0.742 (best 0.742), test acc: 0.753 (best 0.745) In epoch 120, loss: 0.009, val acc: 0.742 (best 0.742), test acc: 0.754 (best 0.745) In epoch 125, loss: 0.008, val acc: 0.742 (best 0.742), test acc: 0.754 (best 0.745) In epoch 130, loss: 0.008, val acc: 0.742 (best 0.742), test acc: 0.752 (best 0.745) In epoch 135, loss: 0.007, val acc: 0.742 (best 0.742), test acc: 0.752 (best 0.745) In epoch 140, loss: 0.007, val acc: 0.744 (best 0.744), test acc: 0.751 (best 0.751) In epoch 145, loss: 0.006, val acc: 0.744 (best 0.744), test acc: 0.751 (best 0.751) In epoch 150, loss: 0.006, val acc: 0.744 (best 0.744), test acc: 0.749 (best 0.751) In epoch 155, loss: 0.006, val acc: 0.744 (best 0.744), test acc: 0.750 (best 0.751) In epoch 160, loss: 0.005, val acc: 0.742 (best 0.744), test acc: 0.750 (best 0.751) In epoch 165, loss: 0.005, val acc: 0.742 (best 0.744), test acc: 0.753 (best 0.751) In epoch 170, loss: 0.005, val acc: 0.742 (best 0.744), test acc: 0.753 (best 0.751) In epoch 175, loss: 0.005, val acc: 0.742 (best 0.744), test acc: 0.752 (best 0.751) In epoch 180, loss: 0.004, val acc: 0.742 (best 0.744), test acc: 0.753 (best 0.751) In epoch 185, loss: 0.004, val acc: 0.742 (best 0.744), test acc: 0.753 (best 0.751) In epoch 190, loss: 0.004, val acc: 0.742 (best 0.744), test acc: 0.754 (best 0.751) In epoch 195, loss: 0.004, val acc: 0.742 (best 0.744), test acc: 0.754 (best 0.751)
这里需要改的两个位置,g.update_all
的2个参数,以及在Model
中要设置边权重的传参。其他就没啥变化了。即使用带权平均聚合邻居表示,edata
成员可以保存边权重(特征),这些特征也可以参与消息传递。
# data可以包含边特征信息,同时传递 class WeightedSAGEConv(nn.Module): """ in_feat : int Input feature size. out_feat : int Output feature size. """ def __init__(self, in_feat, out_feat): super(WeightedSAGEConv, self).__init__() # 将input和邻近节点特征映射到outpu线性子模块 self.linear = nn.Linear(in_feat * 2, out_feat) def forward(self, g, h, w): """ g : Graph The input graph. h : Tensor The input node feature. w : Tensor The edge weight. """ with g.local_scope(): g.ndata['h'] = h # 加入边的权重,进行消息传递和更新 g.edata['w'] = w g.update_all(message_func=fn.u_mul_e('h', 'w', 'm'), reduce_func=fn.mean('m', 'h_N')) h_N = g.ndata['h_N'] h_total = torch.cat([h, h_N], dim=1) return self.linear(h_total) # 因为这个数据集中的图没有边的权值, # 所以我们在模型的 forward 函数中手动将所有边的权值赋给1。 class Model(nn.Module): def __init__(self, in_feats, h_feats, num_classes): super(Model, self).__init__() self.conv1 = WeightedSAGEConv(in_feats, h_feats) self.conv2 = WeightedSAGEConv(h_feats, num_classes) def forward(self, g, in_feat): # 3个参数,(g, h, w)即图,点特征,边权重 h = self.conv1(g, in_feat, torch.ones(g.num_edges(), 1).to(g.device)) h = F.relu(h) # 设置所有边的权重为1 h = self.conv2(g, h, torch.ones(g.num_edges(), 1).to(g.device)) return h
dgl.nn
模块;dgl.nn.functional
内置方法,适合一些简单操作,如为每个节点计算softmax;update_all
,内置的消息函数和聚合函数;message
)函数和聚合(reduce
)函数。DGL允许用户自定义消息函数和聚合函数以获得最大的表达能力。以下是一个用户定义的消息函数,它等价于fn.u_mul_e('h', 'w', 'm')
。
def u_mul_e_udf(edges):
return {"m": edges.src["h"] * edges.data["w"]}
参数edges共有三个成员:src,data和dst,分别代表所有边的源节点特征,边特征和目标节点特征。
也可以编写自己的聚合函数。例如,下面的函数相当于内置的fn.sum(‘m’, ‘h’)函数,它对传入的消息求和:
def sum_udf(nodes):
return {"h": nodes.mailbox["m"].sum(dim=1)}
# dim=1,按行求和
总之,DGL将按节点的度数对节点进行分组,对于每个组DGL将传入的消息沿着第2维度(按行)进行堆叠,然后沿第2个维度执行缩减(reduce)以聚合消息。
(1)NYU、AWS联合推出:全新图神经网络框架DGL正式发布
(2)https://www.dgl.ai/
(3)Write your own GNN module
(4)图神经网络框架DGL中的 消息函数、聚合函数及更新函数 的理解与说明
(5)dgl.DGLGraph.update_all官方文档
(6)对DGL中train_mask的理解
(7)How Does DGL Represent A Graph
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。