【论文笔记】GraphSAGE:Inductive Representation Learning on Large Graphs(NIPS)

本文主要是介绍【论文笔记】GraphSAGE:Inductive Representation Learning on Large Graphs(NIPS),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

学习心得

  • GCN不能泛化到训练过程中没有出现的节点(即属于 t r a n s d u c t i v e transductive transductive 直推式学习,若加入新节点则需要重新训练模型),既然有新增的结点(一定会改变原有节点),那就没必要一定得到每个节点的固定表示。而GraphSAGE就是为了解决这种问题,利用Sample(采样)和Aggregate(聚合)两大核心步骤,通过利用学习到的聚合函数,得到一个新节点的表示。
  • 本文先介绍GraphSAGE向前传播过程(生成节点embedding),不同的聚合函数设定,然后介绍无监督学习和有监督学习的损失函数和参数学习——参数学习:通过前向传播得到节点 u 的embedding z u z_u zu,然后梯度下降(实现使用Adam优化器) 进行反向传播优化参数 W k \mathbf{W}^{k} Wk 和聚合函数内的参数。
  • GCN每次迭代AW 是会用到A整个图邻接矩阵;graphsage可以说对GCN做了进一步精简,每次迭代只抽样取直接相连的邻居;而且GraphSAGE可以通过mini-batch的形式训练,定义合适的领域范围,可以大大减小领接矩阵的维度。

在这里插入图片描述

图源自百度飞桨
注:因无法一次性全图送入计算资源,需要借鉴深度学习中的MiniBatch

PS:其中的SAGE即是Sample(采样)和Aggregate(聚合)两个单词里面抽取的英文字母组合(这两个步骤也是该 i n d u c t i v e inductive inductive 归纳式图表示学习模型的核心步骤)。

用v在k-1跳的表示和v的邻居在k跳的表示聚合得到v在k跳的表示,因此每一跳都包含了v的上一跳信息和这一跳的邻居信息。v在1跳的表示是input(v自己),而通过递归得到的最后一跳(K跳)的表示是output z(v)。z(v)融合了从1跳到K跳所有v的表示信息,因此可以更准确地描述v在整张图里的地位,从而更好地聚类v。
每次只计算一跳邻居特征,然后通过递归得获得k跳内的邻居信息,极大减少计算复杂度

文章目录

  • 学习心得
  • 零、Abstract
  • 一、Introduction
    • 论文的工作:
  • 二、Related work
    • 2.1 Factorization-based embedding approaches(节点embedding)
    • 2.2 Supervised learning over graphs
    • 2.3 Graph convolutional networks
  • 三、Proposed method:GraphSAGE
    • 3.1 Embedding generation algorithm
      • (1)Relation to the Weisfeiler-Lehman Isomorphism Test
      • (2)Neighborhood definition
    • 3.2 Learning the parameters of GraphSAGE
      • (1)基于图的无监督损失
      • (2)基于图的有监督损失
      • (3)参数学习
      • (4)新节点embedding的生成
    • 3.3 Aggregator Architectures
      • (1)均值聚合Mean aggregator
      • (2)LSTM聚合
      • (3)池化聚合Pooling aggregator
  • 四、Experiments
    • 4.1 Inductive learning on evolving graphs:Gitation and Reddit data
      • (一)实验背景
        • (1)实验目的
        • (2)数据集和任务
        • (3)baselines
        • (4)实验参数设置
      • (二)数据集介绍
        • (1)数据集Citation data
        • (2)数据集Reddit data
        • (3)Generalizing across graphs: Protein-protein interactions
      • (三)实验结果
    • 4.2 Runtime and parameter sensitivity
    • 4.3 Summary comparison between the different aggregator architectures
  • 五、Theoretical analysis
  • 六、Conclusion
  • Reference

论文题目:Inductive Representation Learning on Large Graphs
题目中文:在大规模图上的归纳表示学习
作者:Hamilton, William L. and Ying, Rex and Leskovec, Jure(NIPS 2017)
论文链接:https://arxiv.org/abs/1706.02216
一作代码:https://github.com/williamleif/GraphSAGE
官方链接:http://snap.stanford.edu/graphsage/

零、Abstract

从内容推荐到蛋白质功能识别,大型图中节点的低维嵌入在各种预测任务中被证明是非常有用的。然而,现有的大多数方法都要求在训练嵌入时,需要图中的所有节点都存在。以往的方法都是 t r a n s d u c t i v e transductive transductive 式的,而作者提出的 i n d u c t i v e inductive inductive(直推式) 式的 G r a p h S A G E GraphSAGE GraphSAGE 算法利用节点特征信息(如文本属性)有效地为未知的数据生成节点嵌入。本文的方法并非训练每个节点的单独嵌入,而是学习一个函数,该函数通过从节点的局部邻域采样和聚合特征来生成嵌入。该算法在三个 i n d u c t i v e inductive inductive(归纳式) 式节点分类 b e n c h m a r k s benchmarks benchmarks 上超过 b a s e l i n e s baselines baselines :能够根据引文和 R e d d i t Reddit Reddit 数据对演化信息图中未知的节点进行分类,实验表明使用一个 P P I PPI PPI p r o t e i n − p r o t e i n protein-protein proteinprotein i n t e r a c t i o n s interactions interactions)多图数据集,算法可以泛化到完全未见过的图上。

一、Introduction

在大规模图中,节点的低维向量embedding被证明了作为各种预测和图分析任务的特征输入是极为有用的。顶点embedding的基本思想是使用降维将节点图邻域的高维信息提取成密集的向量嵌入。这些节点嵌入可以提供给下游机器学习系统,并帮助完成节点分类、聚类和链接预测等任务[11, 28, 35]。

然而以往的工作一般是从单一的固定图中抽取顶点嵌入,现实中的应用需要能够快速地从不可见的节点或者全新的(子)图中生成 e m b e d d i n g embedding embedding,这种生成式能力对高吞吐量的机器学习系统很重要,特别是当数据处于一个不断演化的图中,不断加入新节点的情况下(如Reddit上的帖子、Youtube上的用户和视频)。

生成节点嵌入的归纳算法也有助于在具有相同特征形式的图之间进行泛化:比如我们可以在模型生物的蛋白质相互作用图上训练一个embedding生成器,然后利用这个生成器方便地为收集的新生物数据生成节点嵌入。

归纳式顶点嵌入问题比直推式难很多,因为要推广到未知节点需要将新观察的子图与算法已经优化过的节点嵌入“对齐”(aligning)。归纳式学习框架必须学会识别一个节点邻域的结构属性,它揭示了图中节点的局部角色和全局位置。

大部分现有的生成顶点嵌入的方法都是直推式的。这些方法中的大多数使用基于矩阵分解的目标直接优化每个节点的嵌入,而不会自然地推广到看不见的数据,因为它们是对单个固定图中的节点进行预测。这些方法可以修改为归纳式的(比如DeepWalk就可以),但是这些修改在计算上很昂贵,在做出新的预测之前需要额外的梯度下降。到目前为止,GCN仅被应用在固定图直推式的任务上。本文将GCN扩展成归纳式无监督学习,并提出一种GCN的推广算法(使用可训练的聚合函数,而不是简单的卷积)。

论文的工作:

  • 提出归纳式节点embedding算法GraphSAGE,该方法与基于矩阵分解的嵌入方法不同,作者利用节点特征(例如文本属性、节点概要信息、节点度)来学习一个可推广到未见节点的嵌入函数。通过在学习算法中引入节点特征,我们同时学习了每个节点的邻域拓扑结构以及节点特征在邻域中的分布情况。
  • 虽然本文关注特征丰富的图(例如,带有文本属性的引文数据,带有功能/分子标记的生物数据),但GraphSAGE也可以利用所有图中存在的结构特征(例如,节点度)。因此该算法也能应用在没有节点特征的图上。

在这里插入图片描述

  • 如上图的figure1所示,GraphSAGE不是为每个节点训练一个不同的嵌入向量,而是训练一组聚合器函数,学习从节点的局部邻域聚合的特征信息。给定义一个节点,每一个聚合函数从一个不同跳数或搜索深度上聚合信息。在测试或者推理的时候,我们使用训练好的系统通过应用学习聚合函数,生成整个不可见节点的嵌入。顺着之前生成节点嵌入的思路,作者设计了一个无监督的损失函数,允许GraphSAGE能做任务无关的监督式训练。同时也展示了GraphSAGE可以以一种完全监督学习的方式训练。
  • 作者在三个顶点分类的基准任务上评估了GraphSAGE模型,测试GraphSAGE在生成不可见节点的嵌入的能力。使用了两个动态文档图(一个是文献引用数据和一个是Reddit帖子数据,一个是预测论文分类,一个是预测帖子分类),并且还有一个多图广义化实验(蛋白质之间的交互作用,是预测蛋白质功能的任务)。使用这些基准任务,我们展示了我们的方法有能力生成不可见节点的表示,并且效果远超baseline模型。

二、Related work

GraphSAGE算法在概念上与以前的节点embedding方法、一般的图形学习监督方法以及最近将卷积神经网络应用于图形结构化数据的进展有关。

2.1 Factorization-based embedding approaches(节点embedding)

使用随机游走的统计方法和基于矩阵分解学习低维的embeddings

  • Grarep: Learning graph representations with global structural information. In KDD, 2015
  • node2vec: Scalable feature learning for networks. In KDD, 2016
  • Deepwalk: Online learning of social - representations. In KDD, 2014
  • Line: Large-scale information network embedding. In WWW, 2015
  • Structural deep network embedding. In KDD, 2016

上述embedding算法直接训练单个节点的节点embedding,本质上是transductive,而且需要大量的额外训练(如随机梯度下降)使他们能预测新的顶点。

此外,Yang et al.的Planetoid-I算法,是一个inductive的基于embedding的半监督学习算法。然而,Planetoid-I在推断的时候不使用任何图结构信息,而在训练的时候将图结构作为一种正则化的形式。

与上述方法不同,本文利用特征(feature)信息来训练可以对未见过的顶点生成embedding的模型。

2.2 Supervised learning over graphs

除了节点嵌入方法,还有大量关于图结构数据的监督学习的文献工作。这包括各种各样的基于内核(graph kernels)的方法,其中图的特征向量来自不同的图内核(参见Weisfeiler-lehman graph kernels和其中的引用)。

一些神经网络方法用于图结构上的监督学习,本文的方法在概念上受到了这些算法的启发。

  • Discriminative embeddings of latent variable models for structured data. In ICML, 2016
  • A new model for learning in graph domains
  • Gated graph sequence neural networks. In ICLR, 2015
  • The graph neural network model

然而,这些以前的方法是尝试对整个图(或子图)进行分类的,但是本文的工作的重点是为单个节点生成有用的表示。

2.3 Graph convolutional networks

近几年提出的几种用于图上学习的卷积神经网络结构:

  • Spectral networks and locally connected networks on graphs. In ICLR, 2014
  • Convolutional neural networks on graphs with fast localized spectral filtering. In NIPS, 2016
  • Convolutional networks on graphs for learning molecular fingerprints. In NIPS,2015
  • Semi-supervised classification with graph convolutional networks. In ICLR, 2016
  • Learning convolutional neural networks for graphs. In ICML, 2016

上述方法中的大多数不能扩展到大型图,或者设计用于全图分类(或者两者都是)。原始的GCN算法[17]是为在 t r a n s d u c t i v e transductive transductive 设置下的半监督学习而设计的,算法要求在训练过程中已知完整的图拉普拉斯算子。GraphSAGE可以看作是对 t r a n s d u c t i v e transductive transductive 的GCN框架对 i n d u c t i v e inductive inductive 设置的扩展。

三、Proposed method:GraphSAGE

3.1 Embedding generation algorithm

GraphSAGE的前向传播算法如下,算法3.1描述了如何使用聚合函数对节点的邻居信息进行聚合,从而生成节点embedding。

在每次迭代(或搜索深度),顶点从它们的局部邻居聚合信息,并且随着这个过程的迭代,顶点会从越来越远的地方获得信息。PinSAGE使用的前向传播算法和GraphSAGE一样,GraphSAGE是PinSAGE的理论基础。

下图算法1描述了在整个图上生成embedding的过程,其中:

  • G = ( V , E ) \mathcal{G}=(\mathcal{V}, \mathcal{E}) G=(V,E) K K K 是网络层数,也代表每个顶点能够聚合的邻接点的跳数(因为每增加一层,可以聚合更远的一层邻居信息)
  • x v , ∀ v ∈ V x_{v}, \forall v \in \mathcal{V} xv,vV表示节点 v v v的属性(特征向量),并且作为输入
  • { h u k − 1 , ∀ u ∈ N ( v ) } \left\{\mathbf{h}_{u}^{k-1}, \forall u \in \mathcal{N}(v)\right\} {huk1,uN(v)}表示在 ( k − 1 ) (k-1) (k1)层中节点 v v v的邻居节点 u u u的embedding
  • h v k , ∀ v ∈ V \mathbf{h}_{v}^{k}, \forall v \in V hvk,vV表示在第 k k k层,节点 v v v的特征表示
  • N ( v ) \mathcal{N}(v) N(v)定义为邻居节点集合 { u ∈ v : ( u , V ) ∈ E } \{u \in v:(u, \mathcal{V}) \in \mathcal{E}\} {uv:(u,V)E}需要从中均匀采样出固定数量的节点做聚合,即GraphSAGE中每一层的节点邻居都是从上一层网络采样的,并不是所有邻居参与,并且采样后的邻居的size是固定的 h N ( v ) k \mathbf{h}_{\mathcal{N}(v)}^{k} hN(v)k表示在第 k k k层,节点 v v v的所有邻居节点的特征表示。

在这里插入图片描述
敲黑板
上图的4-5行是核心代码,介绍了卷积层操作:聚合与节点v相连的邻居(采样)k-1层的embedding,得到第k层邻居聚合特征 h N ( v ) k \mathbf{h}_{\mathcal{N}(v)}^{k} hN(v)k,与节点v第k-1层 embedding h V k \mathbf{h}_{\mathcal{V}}^{k} hVk拼接,并通过全连接层转换,然后进入激活函数计算,得到节点v在第k层的embedding h V k \mathbf{h}_{\mathcal{V}}^{k} hVk
第7行代码:通过除以矢量范数来标准化节点嵌入,以防止梯度爆炸。
在这里插入图片描述
如上图的栗子进行采样,要预测 0 号节点,因此首先随机选择 0 号节点的一阶邻居 2、4、5,然后随机选择 2 号节点的一阶邻居 8、9,对于4、5号节点也是类似。

为了将算法1扩展到mini-batch环境上,给定一组输入顶点,先采样采出需要的邻居集合(直到深度K),然后运行内部循环(算法1的第三行)(附录A包括了完整的mini-batch伪代码)。
在这里插入图片描述

(1)Relation to the Weisfeiler-Lehman Isomorphism Test

(和同构测试的相关性)
GraphSAGE算法从概念上受到测试图同构的经典算法的启发。对于算法1,如果满足条件:

  • K = ∣ V ∣ K=|\mathcal{V}| K=V
  • 设置权重矩阵为单位矩阵
  • 使用一个适当的哈希函数作为一个聚合器(没有非线性)

那么算法1是Weisfeiler-Lehman的实例(WL)同构测试,也称为“naive vertex refinement”[Weisfeiler-lehman graph kernels]。

如果对两个子图通过算法1的特征表示的集合 { z v , ∀ v ∈ V } \left\{z_v,\forall v \in \mathcal{V}\right\} {zv,vV}输出是相同的,那么ML测试就认为这两个子图是同构的。如果用可训练的神经网络聚合器替换了哈希函数,那么GraphSAGE就是WL测试的一个连续近似。当然,文中使用GraphSAGE是为了生成有用的节点表示—而不是测试图的同构。然而,GraphSAGE与经典的WL检验之间的联系为算法设计学习节点邻域的拓扑结构提供了理论基础。

(2)Neighborhood definition

(采样邻居顶点)
出于对计算效率的考虑,对每个顶点采样一定数量的邻居顶点作为待聚合信息的顶点。设需要的邻居数量,即采样数量为S,若顶点邻居数少于S,则采用有放回的抽样方法,直到采样出S个顶点。若顶点邻居数大于S,则采用无放回的抽样。(即采用有放回的重采样/负采样方法达到S)

若不考虑计算效率,完全可以对每个顶点利用其所有的邻居顶点进行信息聚合,这样是信息无损的。

文中在较大的数据集上实验。因此,统一采样一个固定大小的邻域集,以保持每个batch的计算占用空间是固定的(即 graphSAGE并不是使用全部的相邻节点,而是做了固定size的采样)。

这样固定size的采样,每个节点和采样后的邻居的个数都相同,可以把每个节点和它们的邻居拼成一个batch送到GPU中进行批训练。

  • 这里需要注意的是,每一层的node的表示都是由上一层生成的,跟本层的其他节点无关,这也是一种基于层的采样方式。
  • 在图中的“1层”,节点v聚合了“0层”的两个邻居的信息,v的邻居u也是聚合了“0层”的两个邻居的信息。到了“2层”,可以看到节点v通过“1层”的节点u,扩展到了“0层”的二阶邻居节点。因此,在聚合时,聚合K次,就可以扩展到K阶邻居。
  • 没有这种采样,单个batch的内存和预期运行时是不可预测的,在最坏的情况下是 O ( ∣ V ∣ ) O(|\mathcal{V}|) O(V) 。相比之下,GraphSAGE的每个batch空间和时间复杂度是固定的: O ( ∏ i = 1 K S i ) { }_{O}\left(\prod_{i=1}^{K} S_{i}\right) O(i=1KSi)其中 S i , i ∈ { 1 , . . . , K } S_i,i \in \left\{1,...,K\right\} Si,i{1,...,K} 是可以指定的参数。
  • 实验发现,K不必取很大的值,当K=2时,效果就很好了。至于邻居的个数,文中提到 S i ⋅ S 2 ≤ 500 S_{i} · S_{2} ≤ 500 SiS2500,即两次扩展的邻居数之际小于500,大约每次只需要扩展20来个邻居时获得较高的性能。
  • 论文里说固定长度的随机游走其实就是随机选择了固定数量的邻居。

为啥这样采样,因为:
(1)方便批处理:在给定一批要更新的节点后,要先取出它们的K阶邻居节点集合。
(2)减低时间复杂度:只采样固定数量的邻居节点而非所有的。

3.2 Learning the parameters of GraphSAGE

在定义好聚合函数之后,接下来就是对函数中的参数进行学习。文章分别介绍了无监督学习和监督学习两种方式。

(1)基于图的无监督损失

基于图的损失函数倾向于使得相邻的顶点有相似的表示,但这会使相互远离的顶点的表示差异变大:
J G ( z u ) = − log ⁡ ( σ ( z u T z v ) ) − Q ⋅ E v n ∼ P n ( v ) log ⁡ ( σ ( − z u T z v n ) ) ( 1 ) J \mathcal{G}\left(\mathbf{z}_{u}\right)=-\log \left(\sigma\left(\mathbf{z}_{u}^{T} \mathbf{z}_{v}\right)\right)-Q \cdot \mathbb{E}_{v_{n} \sim P_{n}(v)} \log \left(\sigma\left(-\mathbf{z}_{u}^{T} \mathbf{z}_{v_{n}}\right)\right) \qquad (1) JG(zu)=log(σ(zuTzv))QEvnPn(v)log(σ(zuTzvn))(1)其中:

  • Z u \mathbf{Z}_{u} Zu为节点 u u u 通过GraphSAGE生成的embedding
  • 节点 v v v 是节点 u u u 随机游走到达的邻居
  • σ \sigma σ是sigmoid函数
  • v n ∼ P n ( v ) v_n \sim P_{n}(v) vnPn(v)是负采样的概率分布(类似word2vec的负采样),节点 v n v_n vn是从节点u的负采样分布 P n P_n Pn采样的,Q为采样样本数。
  • embedding之间相似度通过向量点积计算得到
  • 负采样指我们还需要一批不是v邻居的节点作为负样本,那么上面式子就是指相邻节点的embedding的相似度尽量大的情况下保证不相邻节点的embedding的期望相似度尽可能小。

文中输入到损失函数的表示 z u \mathbf{z}_u zu​ 是从包含一个顶点局部邻居的特征生成出来的,而不像之前的那些方法(如DeepWalk),那些方法是对每个顶点训练一个独一无二的embedding,然后简单进行一个embedding lookup(查找)操作得到。

(2)基于图的有监督损失

无监督损失函数的设定来学习节点embedding 可以供下游多个任务使用。监督学习形式根据任务的不同直接设置目标函数即可,如最常用的节点分类任务使用交叉熵损失函数。

(3)参数学习

通过前向传播得到节点 u 的embedding z u z_u zu,然后梯度下降(实现使用Adam优化器) 进行反向传播优化参数 W k \mathbf{W}^{k} Wk 和聚合函数内的参数。

(4)新节点embedding的生成

W k W^k Wk 是所谓的dynamic embedding核心,因为保存下来了从节点原始的高维特征生成低维embedding的方式。现在,如果想得到一个点的embedding,只需要输入节点的特征向量,经过卷积(利用已经训练好的 W k \mathbf{W}^{k} Wk 以及特定聚合函数聚合neighbor的属性信息),就产生了节点的embedding。

3.3 Aggregator Architectures

在图中顶点的邻居是无序的,所以希望构造出的聚合函数是对称的(即改变输入的顺序,函数的输出结果不变),同时具有较高的表达能力。 聚合函数的对称性(symmetry property)确保了神经网络模型可以被训练且可以应用于任意顺序的顶点邻居特征集合上
在这里插入图片描述

(1)均值聚合Mean aggregator

mean aggregator将目标顶点和邻居顶点的第k − 1层向量拼接起来,然后对向量的每个维度进行求均值的操作,将得到的结果做一次非线性变换产生目标顶点的第k层表示向量。
文中用下面的式子替换算法1(3.1的图)伪代码中的4行和5行得到GCN的inductive变形:
h v k ← σ ( W ⋅ MEAN ⁡ ( { h v k − 1 } ∪ { h u k − 1 , ∀ u ∈ N ( v ) } ) ) ( 2 ) \mathbf{h}_{v}^{k} \leftarrow \sigma\left(\mathbf{W} \cdot \operatorname{MEAN}\left(\left\{\mathbf{h}_{v}^{k-1}\right\} \cup\left\{\mathbf{h}_{u}^{k-1}, \forall u \in \mathcal{N}(v)\right\}\right)\right) \qquad (2) hvkσ(WMEAN({hvk1}{huk1,uN(v)}))(2)原始第4、5行是:
h v k ← σ ( W ⋅ MEAN ⁡ ( { h v k − 1 } ∪ { h u k − 1 , ∀ u ∈ N ( v ) } ) \mathbf{h}_{v}^{k} \leftarrow \sigma\left(\mathbf{W} \cdot \operatorname{MEAN}\left(\left\{\mathbf{h}_{v}^{k-1}\right\} \cup\left\{\mathbf{h}_{u}^{k-1}, \forall u \in \mathcal{N}(v)\right\}\right)\right. hvkσ(WMEAN({hvk1}{huk1,uN(v)}) h v k ← σ ( W k ⋅ CONCAT ⁡ ( h v k − 1 , h N ( v ) k ) ) \mathbf{h}_{v}^{k} \leftarrow \sigma\left(\mathbf{W}^{k} \cdot \operatorname{CONCAT}\left(\mathbf{h}_{v}^{k-1}, \mathbf{h}_{\mathcal{N}(v)}^{k}\right)\right) hvkσ(WkCONCAT(hvk1,hN(v)k))

  • 均值聚合近似等价在transducttive GCN框架[Semi-supervised classification with graph convolutional networks. In ICLR, 2016]中的卷积传播规则
  • 文中称这个修改后的基于均值的聚合器是convolutional的,这个卷积聚合器和文中的其他聚合器的重要不同在于它没有算法1中第5行的CONCAT操作——卷积聚合器没有将顶点前一层的表示 h v k − 1 \mathbf{h}^{k-1}_v hvk1和聚合的邻居向量 h N ( v ) k \mathbf{h}^k_{\mathcal{N}(v)} hN(v)k拼接起来
  • 拼接操作可以看作一个是GraphSAGE算法在不同的搜索深度或层之间的简单的skip connection[Identity mappings in deep residual networks]的形式,它使得模型获得了巨大的提升
  • 举个简单例子,比如一个节点的3个邻居的embedding分别为[1,2,3,4],[2,3,4,5],[3,4,5,6]按照每一维分别求均值就得到了聚合后的邻居embedding为[2,3,4,5]

(2)LSTM聚合

文中也测试了一个基于LSTM的复杂的聚合器[Long short-term memory]。和均值聚合器相比,LSTMs有更强的表达能力。但是,LSTM不是symmetric的,也就是说不具有排列不变性(permutation invariant),因为它们以一个序列的方式处理输入。因此,每次迭代时先随机打乱要聚合的邻接点的顺序,然后将邻居序列的embedding作为LSTM的输入

排列不变性(permutation invariance):指输入的顺序改变不会影响输出的值。

(3)池化聚合Pooling aggregator

pooling聚合器,它既是对称的,又是可训练的。Pooling aggregator 先对目标顶点的邻居顶点的embedding向量进行一次非线性变换,之后进行一次pooling操作(max pooling or mean pooling),将得到结果与目标顶点的表示向量拼接,最后再经过一次非线性变换得到目标顶点的第k层表示向量。

一个element-wise max pooling操作应用在邻居集合上来聚合信息:
h N ( v ) k = A G G R E G A T E k p o o l = max ⁡ ( { σ ( W p o o l h u k − 1 + b ) , ∀ u ∈ N ( v ) } ) ( 2 ) \mathbf{h}_{\mathcal{N}(v)}^{k}=\mathrm{AGGREGATE}_{k}^{p o o l}=\max \left(\left\{\sigma\left(\mathbf{W}_{p o o l} \mathbf{h}_{u}^{k-1}+\mathbf{b}\right), \forall u \in \mathcal{N}(v)\right\}\right) \qquad (2) hN(v)k=AGGREGATEkpool=max({σ(Wpoolhuk1+b),uN(v)})(2) h v k ← σ ( W k ⋅ CONCAT ⁡ ( h v k − 1 , h N ( v ) k ) ) \mathbf{h}_{v}^{k} \leftarrow \sigma\left(\mathbf{W}^{k} \cdot \operatorname{CONCAT}\left(\mathbf{h}_{v}^{k-1}, \mathbf{h}_{\mathcal{N}(v)}^{k}\right)\right) hvkσ(WkCONCAT(hvk1,hN(v)k))

  • max表示element-wise最大值操作,取每个特征的最大值
  • σ \sigma σ 是非线性激活函数
  • 所有相邻节点的向量共享权重,先经过一个非线性全连接层,然后做max-pooling
  • 按维度应用 max/mean pooling,可以捕获邻居集上在某一个维度的突出的/综合的表现。

池化聚合先让所有邻接点通过一个全连接层,然后做最大池化。

四、Experiments

4.1 Inductive learning on evolving graphs:Gitation and Reddit data

(一)实验背景

(1)实验目的
  • 比较GraphSAGE 相比baseline 算法的提升效果
  • 比较GraphSAGE的不同聚合函数
(2)数据集和任务

Citation 论文引用网络(节点分类)
Reddit 帖子论坛 (节点分类)
PPI 蛋白质网络 (graph分类)

(3)baselines
  • Random,随机分类器
  • Raw features,手工特征(非图特征)
  • deepwalk(图拓扑特征)
  • DeepWalk + features, deepwalk+手工特征

除此以外,还比较了GraphSAGE四个变种 ,并无监督生成embedding输入给LR和端到端有监督。因为,GraphSAGE的卷积变体是一种扩展形式,是Kipf et al. 半监督GCN的inductive版本,称这个变体为GraphSAGE-GCN。
以上baselines的分类器均采用LR。

在所有这些实验中,预测在训练期间看不到的节点,在PPI数据集的情况下,实验在完全看不见的图上进行了测试。

(4)实验参数设置
  • K=2,聚合两跳内邻居特征
  • S1=25,S2=10: 对一跳邻居抽样25个,二跳邻居抽样10个
  • RELU 激活单元
  • Adam 优化器(除了DeepWalk,DeepWalk使用梯度下降效果更好)
  • 文中所有的模型都是用TensorFlow实现
  • 对每个节点进行步长为5的50次随机游走
  • 负采样参考word2vec,按平滑degree进行,对每个节点采样20个
  • 保证公平性:所有版本都采用相同的minibatch迭代器、损失函数、邻居采样器
  • 实验测试了根据式1的损失函数训练的GraphSAGE的各种变体,还有在分类交叉熵损失上训练的可监督变体
  • 对于Reddit和citation数据集,使用”online”的方式来训练DeepWalk
  • 在多图情况下,不能使用DeepWalk,因为通过DeepWalk在不同不相交的图上运行后生成的embedding空间对它们彼此说可能是arbitrarily rotated的(见文中附录D)

(二)数据集介绍

前两个实验是在演化的信息图中对节点进行分类,这是一个与高吞吐量生产系统特别相关的任务,该系统经常遇到不可见的数据。

(1)数据集Citation data

第一个任务是在一个大的引文数据集中预测论文主题类别。文中使用来自汤姆森路透科学核心数据库(Thomson Reuters Web of Science Core
Collection)的无向的引文图数据集(对应于2000-2005年六个生物相关领域的所有论文)。这个数据集的节点标签对应于六个不同的领域的标签。该数据集共包含302,424个节点,平均度数为9.15。文中使用2000-2004年的数据集对所有算法进行训练,并使用2005年的数据进行测试(30%用于验证)。对于特征,本文使用节点的度。此外,按照Arora等人的sentence embedding方法处理论文摘要(使用GenSim用word2vec实现训练的300维单词向量)。

(2)数据集Reddit data

第二个任务预测不同的Reddit帖子(posts)属于哪个社区。Reddit是一个大型的在线论坛,用户可以在这里对不同主题社区的内容进行发布和评论。作者在Reddit上对2014年9月发布的帖子构建了一个图形数据集。本例中的节点标签是帖子所属的社区或“subreddit”。文中对50个大型社区进行了抽样,并构建了一个帖子-帖子的图,如果同一个用户评论了两个帖子,就将这两个帖子连接起来。

该数据集共包含232,965个帖子,平均度为492。文中将前20天的用于训练,其余的用于测试(30%用于验证)。对于特征,文中使用现成的300维GloVe CommonCrawl词向量对于每一篇帖子,将下面的内容连接起来:

  • 帖子标题的平均embedding
  • 所有帖子评论的平均embedding
  • 该帖子的得分
  • 该帖子的评论数量
(3)Generalizing across graphs: Protein-protein interactions

考虑跨图进行泛化的任务,这需要了解节点的角色,而不是社区结构。文中在各种蛋白质-蛋白质相互作用(PPI)图中对蛋白质角色进行分类,每个图对应一个不同的人体组织。并且使用从Molecular Signatures Database中收集的位置基因集、motif基因集和免疫学signatures作为特征,gene ontology作为标签(共121个)。图中平均包含2373个节点,平均度为28.8。文中将所有算法在20个图上训练,然后在两个测试图上预测F1 socres(另外两个图用于验证)。

(三)实验结果

在这里插入图片描述

  • 可以看到GraphSAGE的性能显著优于baseline方法
  • 三个数据集上的实验结果表明,一般是LSTM或pooling效果比较好,有监督都比无监督好(PS:每个数据集下有2个指标,Unsup是无监督,sup F1是有监督对应的F1值)
  • 无监督版本的GraphSAGE-pool对引文数据和Reddit数据的连接(concatenation)性能分别比DeepWalk embeddings和raw features的连接性能好13.8%和29.1%,而有监督版本的连接性能分别提高了19.7%和37.2%
  • 尽管LSTM是为有序数据而不是无序集设计的,但是基于LSTM的聚合器显示了强大的性能
  • 最后,可以看到无监督GraphSAGE的性能与完全监督的版本相比具有相当的竞争力,这表明文中的框架可以在不进行特定于任务的微调(task-specific fine-tuning)的情况下实现强大的性能

4.2 Runtime and parameter sensitivity

在这里插入图片描述
运行时间和参数敏感性:

  • 计算时间:GraphSAGE中LSTM训练速度最慢,但相比DeepWalk,GraphSAGE在预测时间减少100-500倍(因为对于未知节点,DeepWalk要重新进行随机游走以及通过SGD学习embedding)
  • 邻居采样数量:图B中邻居采样数量递增,F1也增大,但计算时间也变大。 为了平衡F1和计算时间,将S1设为25
  • 聚合K跳内信息:在GraphSAGE, K=2 相比K=1 有10-15%的提升;但将K设置超过2,效果上只有0-5%的提升,但是计算时间却变大了10-100倍

4.3 Summary comparison between the different aggregator architectures

不同聚合器之间的比较:

  • LSTM和pool的效果较好
  • 为了更定量地了解这些趋势,实验中将设置六种不同的实验,即(3个数据集)×(非监督vs.监督))
  • GraphSAGE-LSTM比GraphSAGE-pool慢得多(≈2×),这可能使基于pooling的聚合器在总体上略占优势
  • LSTM方法和pooling方法之间没有显著差异
  • 文中使用非参数Wilcoxon Signed-Rank检验来量化实验中不同聚合器之间的差异,在适用的情况下报告T-statistic和p-value。

五、Theoretical analysis

略。

六、Conclusion

GraphSAGE能够有效地泛化到未知节点,该方法也比很多baselines要强,但作者认为仍可以拓展和改进,例如扩展GraphSAGE以合并有向图或多模态图。作者认为未来工作的一个特别有趣的方向是探索非均匀邻域采样函数,甚至可能学习这些函数作为GraphSAGE优化的一部分。

最后小结下GraphSAGE 的优点和缺点:
优点:
(1)通过邻居采样的方式解决了GCN内存爆炸的问题,适用于大规模图的表示学习;
(2)将 transductive 转化为 inductive learning,而且支持增量特征;
(3)引入邻居采样,可有效防止训练过拟合,增强泛化能力;GraphSAGE保存了生成embedding的映射,可扩展性更强,对于节点分类和链接预测问题的表现也比较突出;
(4)可以根据不同领域的图场景来自定义图聚合方式。

缺点:
(1)邻居采样引入随机过程,推理阶段同一节点 embedding 特征不稳定,且邻居采样会导致反向传播时梯度不稳定;
(2)邻居采样数目限制会导致部分节点的重要局部信息丢失;

Reference

(1)【Graph Neural Network】GraphSAGE: 算法原理,实现和应用,阿里浅梦
(2)https://blog.csdn.net/yyl424525/article/details/100532849?spm=1001.2014.3001.5501
(3)GraphSAGE代码讲解-知乎
(4)网络表示学习: 淘宝推荐系统&&GraphSAGE
(5)鱼佬GNN 系列(三):GraphSAGE
(6)GraphSAGE: GCN落地必读论文
(7)https://baidu-pgl.gz.bcebos.com/pgl-course/lesson_4.pdf
(8)https://wmathor.com/index.php/archives/1533/
(9)斯坦福论文主页:http://snap.stanford.edu/graphsage/
(9)Transductive Learning vs Inductive Learning
(10)https://www.zhihu.com/search?type=content&q=GraphSAGE%E7%BF%BB%E8%AF%91

这篇关于【论文笔记】GraphSAGE:Inductive Representation Learning on Large Graphs(NIPS)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/389803

相关文章

AI hospital 论文Idea

一、Benchmarking Large Language Models on Communicative Medical Coaching: A Dataset and a Novel System论文地址含代码 大多数现有模型和工具主要迎合以患者为中心的服务。这项工作深入探讨了LLMs在提高医疗专业人员的沟通能力。目标是构建一个模拟实践环境,人类医生(即医学学习者)可以在其中与患者代理进行医学

【学习笔记】 陈强-机器学习-Python-Ch15 人工神经网络(1)sklearn

系列文章目录 监督学习:参数方法 【学习笔记】 陈强-机器学习-Python-Ch4 线性回归 【学习笔记】 陈强-机器学习-Python-Ch5 逻辑回归 【课后题练习】 陈强-机器学习-Python-Ch5 逻辑回归(SAheart.csv) 【学习笔记】 陈强-机器学习-Python-Ch6 多项逻辑回归 【学习笔记 及 课后题练习】 陈强-机器学习-Python-Ch7 判别分析 【学

系统架构师考试学习笔记第三篇——架构设计高级知识(20)通信系统架构设计理论与实践

本章知识考点:         第20课时主要学习通信系统架构设计的理论和工作中的实践。根据新版考试大纲,本课时知识点会涉及案例分析题(25分),而在历年考试中,案例题对该部分内容的考查并不多,虽在综合知识选择题目中经常考查,但分值也不高。本课时内容侧重于对知识点的记忆和理解,按照以往的出题规律,通信系统架构设计基础知识点多来源于教材内的基础网络设备、网络架构和教材外最新时事热点技术。本课时知识

论文翻译:arxiv-2024 Benchmark Data Contamination of Large Language Models: A Survey

Benchmark Data Contamination of Large Language Models: A Survey https://arxiv.org/abs/2406.04244 大规模语言模型的基准数据污染:一项综述 文章目录 大规模语言模型的基准数据污染:一项综述摘要1 引言 摘要 大规模语言模型(LLMs),如GPT-4、Claude-3和Gemini的快

论文阅读笔记: Segment Anything

文章目录 Segment Anything摘要引言任务模型数据引擎数据集负责任的人工智能 Segment Anything Model图像编码器提示编码器mask解码器解决歧义损失和训练 Segment Anything 论文地址: https://arxiv.org/abs/2304.02643 代码地址:https://github.com/facebookresear

数学建模笔记—— 非线性规划

数学建模笔记—— 非线性规划 非线性规划1. 模型原理1.1 非线性规划的标准型1.2 非线性规划求解的Matlab函数 2. 典型例题3. matlab代码求解3.1 例1 一个简单示例3.2 例2 选址问题1. 第一问 线性规划2. 第二问 非线性规划 非线性规划 非线性规划是一种求解目标函数或约束条件中有一个或几个非线性函数的最优化问题的方法。运筹学的一个重要分支。2

【C++学习笔记 20】C++中的智能指针

智能指针的功能 在上一篇笔记提到了在栈和堆上创建变量的区别,使用new关键字创建变量时,需要搭配delete关键字销毁变量。而智能指针的作用就是调用new分配内存时,不必自己去调用delete,甚至不用调用new。 智能指针实际上就是对原始指针的包装。 unique_ptr 最简单的智能指针,是一种作用域指针,意思是当指针超出该作用域时,会自动调用delete。它名为unique的原因是这个

查看提交历史 —— Git 学习笔记 11

查看提交历史 查看提交历史 不带任何选项的git log-p选项--stat 选项--pretty=oneline选项--pretty=format选项git log常用选项列表参考资料 在提交了若干更新,又或者克隆了某个项目之后,你也许想回顾下提交历史。 完成这个任务最简单而又有效的 工具是 git log 命令。 接下来的例子会用一个用于演示的 simplegit

记录每次更新到仓库 —— Git 学习笔记 10

记录每次更新到仓库 文章目录 文件的状态三个区域检查当前文件状态跟踪新文件取消跟踪(un-tracking)文件重新跟踪(re-tracking)文件暂存已修改文件忽略某些文件查看已暂存和未暂存的修改提交更新跳过暂存区删除文件移动文件参考资料 咱们接着很多天以前的 取得Git仓库 这篇文章继续说。 文件的状态 不管是通过哪种方法,现在我们已经有了一个仓库,并从这个仓

忽略某些文件 —— Git 学习笔记 05

忽略某些文件 忽略某些文件 通过.gitignore文件其他规则源如何选择规则源参考资料 对于某些文件,我们不希望把它们纳入 Git 的管理,也不希望它们总出现在未跟踪文件列表。通常它们都是些自动生成的文件,比如日志文件、编译过程中创建的临时文件等。 通过.gitignore文件 假设我们要忽略 lib.a 文件,那我们可以在 lib.a 所在目录下创建一个名为 .gi