本文主要是介绍GraphSAGE-Inductive Representation Learning on Large Graphs,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
简介
GraphSAGE-原文在摘要中这样介绍:we learn a function that generates embeddings by sampling and aggregating features from a node’s local neighborhood.我们学习一个函数,这个函数可以从一个节点的邻居节点中进行采样和聚合特征来生成embedding。
如何理解呢?简单来说,就是在当前节点中的邻居节点中,随机抽取N个节点(有放回),通过将这N个节点的特征进行聚合操作,来生成当前节点的表征。
节点采样
原理
我们以论文中的图进行展示:
我们假设取深度k=2的距离作为最终的长度:
k=1,选择3个节点:以1为中心,选取2、3、5三个节点;
k=2,选择2个节点:选取{2:8和8,3:9和11,5:14和15};
虽然当k=2时,节点2的邻居节点只有1个,因为采用的有放回的随机抽样,因此,用来表示2的节点仍然是两个8。
论文中采样以及聚合的伪代码如下:
其中, X v X_{v} Xv是特征矩阵, v v v是节点, B \Beta B是需要生成向量的节点,K是深度, σ \sigma σ是非线性激活函数, A G G R E G A T E k AGGREGATE_{k} AGGREGATEk是聚合函数, N k N_{k} Nk是采样函数, W k W^{k} Wk是一个随机权重矩阵。
代码中1-7行是对k层的节点进行采样抽取,而第1层需要依赖第2层的节点表征,第2层需要依赖第2层所采样的节点,因此计算顺序应该和聚合步骤相反。
第11行是对每一层的节点进行聚合计算;
第12行是对聚合后的特征进行连接,并做一次激活;
第13行是对表征的好特征进行归一化。
优点
论文中随机采样的优点在哪里呢?
1、减少了训练量,我们可以人为控制深度以及节点个数;
假如我们有一个节点有100个邻居,那么我们可以从中选取40个节点作为数据,这样,无疑减少了计算量;
2、生成节点向量更加灵活;
不需要进行全图计算,只需要将该节点与之的部分相关节点的特征进行计算,就可以得到节点向量,这也就是GraphSAGE为什么是归纳式,而且不是直推式。
聚合算法
论文中给出了4中聚合算法,我们做以下解释:
Mean aggregator
平均聚合,此种操作是对所选择的节点特征求均值,公式如下:
h v k ← M e a n ( h u k − 1 , ∀ u ∈ N v ) h_{v}^{k}\larr{Mean({h_u^{k-1},\forall_{u}\isin{N_{v}}})} hvk←Mean(huk−1,∀u∈Nv)
GCN aggregator
gcn聚合,和平均聚合类似,就是将特征向量输入一个一层的网络,通过激活函数后使用,公式如下:
h v k ← σ ( W ⋅ M E A N ( h v k − 1 ∪ h u k − 1 , ∀ u ∈ N v ) ) h_{v}^{k}\larr\sigma(W·MEAN({h_{v}^{k-1}}\cup{h_{u}^{k-1},\forall_{u}\isin{N_{v}}})) hvk←σ(W⋅MEAN(hvk−1∪huk−1,∀u∈Nv))
LSTM aggregator
LSTM聚合,作者考虑到lstm有较好的抽取特征能力,因此采用lstm做了实验,但是,lstm具有序列性,而节点之间是无序列的,因为做了随机排列
Pooling aggregator
池化聚合,作者认为pooling操作可以有效的捕获邻域特征的不同方面,更有利于表达节点,而作者对max pooling和mean pooling进行比较,并没有发现哪个更具优势,因此采用max pooling,公式如下:
A G G R E G A T E k p o o l = m a x ( σ ( W p o o l h u i k + b ) , ∀ u i ∈ N ( v ) ) AGGREGATE_{k}^{pool}=max({\sigma(W_{pool}h_{u_{i}}^{k}+b),\forall_{u_{i}}\isin{N_{(v)}}}) AGGREGATEkpool=max(σ(Wpoolhuik+b),∀ui∈N(v))
比较
作者分别在有监督和无监督的情况下,对这四种聚合操作进行了比较,整体来看,效果最好的应该是Pooling聚合,但是,在不同数据集上,有监督和无监督的聚合操作效果还是有略微差别的,大家可以在训练自己数据集的时候,多尝试一下不同的聚合方式。
代码
论文作者采用了TensorFlow框架做的,并且实现了6种聚合方式,分别是平均聚合、GCN聚合、最大池化聚合、平均池化聚合、2层最大池化聚合、lstm聚合,详细的代码大家可以在TensorFlow版看到,另外还有Torch版,不过torch版只实现了mean聚合和gcn聚合两种方式。
结语
以上,就是小编自己对GraphSAGE的理解,如果大家有问题或者需要补充的,请留言或者加QQ:1143948594
附:Inductive Representation Learning on Large Graphs
这篇关于GraphSAGE-Inductive Representation Learning on Large Graphs的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!