【论文阅读】Variational Graph Auto-Encoder

2024-01-07 21:12

本文主要是介绍【论文阅读】Variational Graph Auto-Encoder,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

0、基本信息

  • 会议:2016-NIPS
  • 作者:Thomas N. Kipf,Max Welling
  • 文章链接:Variational Graph Auto-Encoder
  • 代码链接:Variational Graph Auto-Encoder

1、介绍

本文提出一个变分图自编码器,一个基于变分自编码(VAE)的,用于在图结构数据上无监督学习的框架。其基本思路是:用已知的图(graph)经过编码(图卷积)学到节点向量表示的分布,在分布中采样得到节点的向量表示,然后进行解码(链接预测)重新构建图。

2、创新点

  • 将变分自编码器(VAE)迁移到图领域

3、准备工作

3.1、自编码器(AE)

自编码器由两个部分组成,分别是编码器和解码器,其中编码器通过神经网络,得到原始数据的低维向量表示;解码器也通过神经网络,将低维向量表示还原为原始数据。
下图是自编码器的一个例子:
![[AE.png]]
自编码器的训练目标是最小化重建误差,即使输入和输出保持尽量一致。

3.2、变分自编码器

如果将解码器看做一个生成模型,我们只要有低维向量表示,就可以用这个生成模型得到近似真实的样本。但是,这样的生成模型存在一个问题:低维向量表示必须是由真实样本通过编码器得到的,否则随机产生的低维向量表示通过生成模型几乎不可能得到近似真实的样本。
那么,如果能将低维向量表示约束在一个分布(比如正态分布)中,那么从该分布中随机采样,产生的低维向量表示通过生成模型不是就能产生近似真实的样本了吗?

变分自编码器就是这样的一种自编码器:变分自编码器通过编码器学到的不是样本的低维向量表示,而是低维向量表示的分布。假设这个分布服从正态分布,然后在低维向量表示的分布中采样得到低维向量表示,接下来经过解码器还原出原始样本。
变分自编码器将真实样本 X X X输入变分图自编码器,通过编码器学到每个样本对应的低维向量表示的均值 μ \mu μ和方差 σ 2 \sigma^2 σ2,然后再 N ( μ , σ 2 ) N(\mu,\sigma^2) N(μ,σ2)中采样出变量的表征,再通过解码器(生成器)生成样本 X ^ \hat{X} X^。均值描述了概率分布的期望值,而标准差描述了概率分布的广度。

3.3、图自编码器(GAE)

图自编码器,即Graph Auto-Encoders,简写GAE。图自编码器也由两部分组成,编码器和解码器。

(1)编码器
GAE的编码器是一个简单的两层GCN模型:
Z = G C N ( A , X ) \mathrm{Z} = \mathrm{GCN}\mathrm(A,\mathrm{X}) Z=GCN(A,X)
更具体一些,两层的GCN在论文中定义如下:
G C N ( A , X ) = A ~ R e L U ( A ^ X W 0 ) W 1 \mathrm{GCN(A,X)}=\tilde{A}\mathrm{ReLU(\hat{A}\mathrm{X}\mathrm{W^0})W^1} GCN(A,X)=A~ReLU(A^XW0)W1
其中, A ~ = D − 1 2 A D − 1 2 \tilde{A}=D^{-\frac{1}{2}}AD^{-\frac{1}{2}} A~=D21AD21,即对称归一化邻接矩阵, W 1 W^1 W1 W 0 W^0 W0是GCN需要学习的参数。
通过编码器,我们可以得到结点的嵌入向量 Z Z Z.
(2)解码器
编码器得到节点表示的向量后,通过解码器通过向量的内积来重构邻接矩阵:
A ^ = σ ( Z Z T ) \hat{A} = \sigma{(ZZ^T)} A^=σ(ZZT)
在GAE中,我们需要优化编码器中的 W 0 W^0 W0 W 1 W^1 W1进而使得经解码器重构出的邻接矩阵 A ^ \hat{A} A^ 与原始的邻接矩阵A尽量相似。因为邻居矩阵决定了图的结构,经节点向量表示重构出的邻接矩阵与原始邻接矩阵越相似,说明节点的向量表示越符合图的结构。因此,GAE中的损失函数可以定义如下:
L = − 1 N ∑ ( y l o g ( y ^ ) + ( 1 − y ) l o g ( 1 − y ^ ) ) \mathcal{L}=-\frac{1}{N}\sum (ylog(\hat{y})+(1-y)log(1-\hat{y})) L=N1(ylog(y^)+(1y)log(1y^))
这里 y y y表示原始邻接矩阵A中的元素,其值为0或1; y ^ \hat{y} y^ 为重构的邻接矩阵 A ^ \hat{A} A^中的元素。从上述损失函数可以看出,损失函数的本质就是两个交叉熵损失函数之和。
当然,我们可以对原始论文中的GAE进行扩展,例如编码器可以使用其他的GNN模型。
2.4、KL散度(又叫相对熵)
如果我们对于同一个随机变量 x有两个单独的概率分布 P(x) 和 Q(x),我们可以使用 KL 散度(Kullback-Leibler (KL) divergence)来衡量这两个分布的差异。

在机器学习中,P往往用来表示样本的真实分布,Q用来表示模型所预测的分布,那么KL散度就可以计算两个分布的差异,也就是Loss损失值。其公式定义如下:
K L ( p ∣ ∣ q ) = ∑ i = 1 n p ( x i ) l o g ( p ( x i ) q ( x i ) ) KL(p||q)=\sum_{i=1}^np(x_i)log(\frac{p(x_i)}{q(x_i)}) KL(p∣∣q)=i=1np(xi)log(q(xi)p(xi))
从KL散度公式中可以看到Q的分布越接近P(Q分布越拟合P),那么散度值越小,即损失值越小。

4、变分图自编码器(VGAE)

变分图自编码器也有两部分组成,分别是推理模型(编码器)和生成模型(解码器)。

4.1、推理模型

在GAE中,可训练的参数只有 W 0 W^0 W0 W 1 W^1 W1 ,训练结束后只要输入邻接矩阵 A A A和节点特征矩阵 X X X,就能得到节点的向量表征 Z Z Z

与GAE不同,在变分图自编码器VGAE中,节点表征 Z Z Z不是由一个确定的GCN得到,而是从一个多维高斯分布中采样得到。

多维高斯分布的均值和方差由两个GCN确定:
μ = G C N μ ( X , A ) \mu=\mathrm{GCN}_\mu(\mathrm{X,A}) μ=GCNμ(X,A)
以及
l o g σ = G C N σ ( X , A ) log \;\sigma = \mathrm{GCN}_\sigma(\mathrm{X,A}) logσ=GCNσ(X,A)
论文中,这两个不同的GCN都是两层,并且第一层的参数 W 0 W^0 W0是共享的。

有了均值和方差后,我们就能唯一地确定一个多维高斯分布,然后从中进行采样以得到节点的向量表示 Z Z Z,也就是说,向量表征的后验概率分布为:

q ( Z ∣ X , A ) = ∏ i = 1 N q ( z i ∣ X , A ) , q(\mathbf{Z}|\mathbf{X},\mathbf{A})=\prod_{i=1}^Nq(\mathbf{z}_i|\mathbf{X},\mathbf{A}), q(ZX,A)=i=1Nq(ziX,A),
其中,
q ( z i ∣ X , A ) = N ( z i ∣ μ i , diag ⁡ ( σ i 2 ) ) q(\mathbf{z}_i|\mathbf{X},\mathbf{A})=\mathcal{N}(\mathbf{z}_i|\boldsymbol{\mu}_i,\operatorname{diag}(\boldsymbol{\sigma}_i^2)) q(ziX,A)=N(ziμi,diag(σi2))
其中 和 μ i \mu_i μi σ i 2 \sigma^2_i σi2分别表示节点向量的均值和方差。也就是说,通过两个GCN我们得到了所有节点向量的均值和方差,然后再从中采样形成节点向量。具体来讲,编码器得到多维高斯分布的均值向量和协方差矩阵后,我们就可以通过采样来得到节点的向量表示,但是,采样操作无法提供梯度信息,这对神经网络来讲是没有意义的,因此作者做了重采样:

z = μ + ϵ σ \mathrm{z=\mu+\epsilon\sigma} z=μ+ϵσ
这里 ϵ \epsilon ϵ服从 N ( 0 , 1 ) \mathcal{N}(0,1) N(0,1),也就是标准高斯分布,因为 ϵ \epsilon ϵ服从标准高斯分布,所以 μ + ϵ σ \mu+\epsilon \sigma μ+ϵσ服从 N ( μ , σ 2 ) \mathcal{N}(\mu,\sigma^2) N(μ,σ2).

反复从训练集中抽取样本,并对每个样本重新拟合感兴趣的模型,以获得有关拟合模型的其他信息。

4.2、生成模型

生成模型,通过计算图中任意两个节点间存在边的概率来重构图,
p ( A ∣ Z ) = ∏ i = 1 N ∏ j = 1 N p ( A i j ∣ z i , z j ) , with p ( A i j = 1 ∣ z i , z j ) = σ ( z i ⊤ z j ) , p\left(\mathbf{A}|\mathbf{Z}\right)=\prod_{i=1}^N\prod_{j=1}^Np\left(A_{ij}|\mathbf{z}_i,\mathbf{z}_j\right),\quad\text{with}\quad p\left(A_{ij}=1|\mathbf{z}_i,\mathbf{z}_j\right)=\sigma(\mathbf{z}_i^\top\mathbf{z}_j), p(AZ)=i=1Nj=1Np(Aijzi,zj),withp(Aij=1∣zi,zj)=σ(zizj),

也就是说,解码器通过计算任意两个节点向量表示的相似性来重建图结构。

4.3、学习

损失函数定义如下:
L = E q ( Z ∣ X , A ) [ log ⁡ p ( A ∣ Z ) ] − K L [ q ( Z ∣ X , A ) ∣ ∣ p ( Z ) ] \mathcal{L}=\mathbb{E}_{q(\mathbf{Z}\mid\mathbf{X},\mathbf{A})}\big[\log p\left(\mathbf{A}\mid\mathbf{Z}\right)\big]-\mathrm{KL}\big[q(\mathbf{Z}\mid\mathbf{X},\mathbf{A})||p(\mathbf{Z})\big] L=Eq(ZX,A)[logp(AZ)]KL[q(ZX,A)∣∣p(Z)]

  • 第一部分与GAE中类似,为交叉熵函数,也就是经分布 q q q得到的向量重构出的图与原图的差异,这种差异越小越好;
  • 第二部分表示利用GCN得到的分布 q ( ⋅ ) q(·) q()与标准高斯分布 p ( Z ) p(Z) p(Z)间的KL散度,也就是要求分布 q q q尽量与标准高斯分布相似。
    其中, p ( Z ) p(Z) p(Z)为为标准高斯分布 p ( Z ) = ∏ i p ( z i ) = ∏ i N ( z i ∣ 0 , I ) . \begin{aligned}p(\mathbf{Z})=\prod_ip(\mathbf{z_i})=\prod_i\mathcal{N}(\mathbf{z}_i|0,\mathbf{I}).\end{aligned} p(Z)=ip(zi)=iN(zi∣0,I).

损失函数展开形式下:
∑ i = 1 n { − p ( x ) l o g ( p ( x ) ) − ( 1 − p ( x ) ) l o g ( 1 − p ( x ) ) + 1 / 2 ( μ ( i ) 2 + σ ( i ) 2 − l o g σ ( i ) 2 − 1 ) } \sum_{i=1}^n\{-p(x)log(p(x))-(1-p(x))log(1-p(x)) +1/2(\mu_{(i)}^{2}+\sigma_{(i)}^{2}-log\sigma_{(i)}^{2}-1)\} i=1n{p(x)log(p(x))(1p(x))log(1p(x))+1/2(μ(i)2+σ(i)2logσ(i)21)}

5、模型实现

import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import numpy as np
import args
class VGAE(nn.Module):def __init__(self,adj) -> None:super(VGAE,self).__init__()self.adj = adjself.share_gcn = GraphConvLayer(args.infeat,args.hidfeat1)self.gcn_mean  = GraphConvLayer(args.hidfeat1,args.hidfeat2,activation = lambda x: x)self.gcn_logstd = GraphConvLayer(args.hidfeat1,args.hidfeat2,activation = lambda x:x)def encode(self,x,adj):hidden = self.share_gcn(x,adj)self.mean = self.gcn_mean(hidden,adj)self.logstd = self.gcn_logstd(hidden,adj)gaussian_noise = torch.randn(x.shape[0],args.hidfeat2)sampled_z = gaussian_noise * torch.exp(self.logstd) + self.meanreturn sampled_z@staticmethoddef decode(z):'''静态方法'''A_pred = F.sigmoid(torch.matmul(z,z.T))return A_preddef forward(self,x):Z = self.encode(x,self.adj)A_pred = self.decode(Z)return A_pred
class GraphConvLayer(nn.Module):    def __init__(self, infeat,outfeat,activation = F.relu) -> None:super(GraphConvLayer,self).__init__()self.activation = activationself.layer = nn.Linear(infeat,outfeat,bias=False)self.init_param()passdef init_param(self):self.layer.reset_parameters()def forward(self,x,adj):AX = torch.mm(adj,x)AXW = self.layer(AX)return self.activation(AXW)

参考链接

  • 参考链接1
  • 参考链接2
  • 参考链接3

这篇关于【论文阅读】Variational Graph Auto-Encoder的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/581242

相关文章

ssh在本地虚拟机中的应用——解决虚拟机中编写和阅读代码不方便问题的一个小技巧

虚拟机中编程小技巧分享——ssh的使用 事情的起因是这样的:前几天一位工程师过来我这边,他看到我在主机和虚拟机运行了两个vscode环境,不经意间提了句:“这么艰苦的环境写代码啊”。 后来我一想:确实。 我长时间以来都是直接在虚拟机里写的代码,但是毕竟是虚拟机嘛,有时候编辑器没那么流畅,在文件比较多的时候跳转很麻烦,容易卡住。因此,我当晚简单思考了一下,想到了一个可行的解决方法——即用ssh

康奈尔大学之论文审稿模型Reviewer2及我司七月对其的实现(含PeerRead)

前言 自从我司于23年7月开始涉足论文审稿领域之后「截止到24年6月份,我司的七月论文审稿GPT已经迭代到了第五版,详见此文的8.1 七月论文审稿GPT(从第1版到第5版)」,在业界的影响力越来越大,所以身边朋友如发现业界有相似的工作,一般都会第一时间发给我,比如本部分要介绍的康奈尔大学的reviewer2 当然,我自己也会各种看类似工作的论文,毕竟同行之间的工作一定会互相借鉴的,我们会学他们

芯片后端之 PT 使用 report_timing 产生报告如何阅读

今天,就PT常用的命令,做一个介绍,希望对大家以后的工作,起到帮助作用。 在PrimeTime中,使用report_timing -delay max命令生成此报告。switch -delay max表示定时报告用于设置(这是默认值)。 首先,我们整体看一下通过report_timing 运行之后,报告产生的整体样式。 pt_shell> report_timing -from start_

【论文精读】分类扩散模型:重振密度比估计(Revitalizing Density Ratio Estimation)

文章目录 一、文章概览(一)问题的提出(二)文章工作 二、理论背景(一)密度比估计DRE(二)去噪扩散模型 三、方法(一)推导分类和去噪之间的关系(二)组合训练方法(三)一步精确的似然计算 四、实验(一)使用两种损失对于实现最佳分类器的重要性(二)去噪结果、图像质量和负对数似然 论文:Classification Diffusion Models: Revitalizing

【python】python葡萄酒国家分布情况数据分析pyecharts可视化(源码+数据集+论文)【独一无二】

👉博__主👈:米码收割机 👉技__能👈:C++/Python语言 👉公众号👈:测试开发自动化【获取源码+商业合作】 👉荣__誉👈:阿里云博客专家博主、51CTO技术博主 👉专__注👈:专注主流机器人、人工智能等相关领域的开发、测试技术。 python葡萄酒国家分布情况数据分析pyecharts可视化(源码+数据集+论文)【独一无二】 目录 python葡

论文阅读--Efficient Hybrid Zoom using Camera Fusion on Mobile Phones

这是谷歌影像团队 2023 年发表在 Siggraph Asia 上的一篇文章,主要介绍的是利用多摄融合的思路进行变焦。 单反相机因为卓越的硬件性能,可以非常方便的实现光学变焦。不过目前的智能手机,受制于物理空间的限制,还不能做到像单反一样的光学变焦。目前主流的智能手机,都是采用多摄的设计,一般来说一个主摄搭配一个长焦,为了实现主摄与长焦之间的变焦,目前都是采用数字变焦的方式,数字变焦相比于光学

【LLM之KG】CoK论文阅读笔记

研究背景 大规模语言模型(LLMs)在许多自然语言处理(NLP)任务中取得了显著进展,特别是在零样本/少样本学习(In-Context Learning, ICL)方面。ICL不需要更新模型参数,只需利用几个标注示例就可以生成预测。然而,现有的ICL和链式思维(Chain-of-Thought, CoT)方法在复杂推理任务上仍存在生成的推理链常常伴随错误的问题,导致不真实和不可靠的推理结果。

【python】python基于akshare企业财务数据对比分析可视化(源码+数据集+论文)【独一无二】

👉博__主👈:米码收割机 👉技__能👈:C++/Python语言 👉公众号👈:测试开发自动化【获取源码+商业合作】 👉荣__誉👈:阿里云博客专家博主、51CTO技术博主 👉专__注👈:专注主流机器人、人工智能等相关领域的开发、测试技术。 系列文章目录 目录 系列文章目录一、设计要求二、设计思路三、可视化分析 一、设计要求 选取中铁和贵州茅

AIGC-Animate Anyone阿里的图像到视频 角色合成的框架-论文解读

Animate Anyone: Consistent and Controllable Image-to-Video Synthesis for Character Animation 论文:https://arxiv.org/pdf/2311.17117 网页:https://humanaigc.github.io/animate-anyone/ MOTIVATION 角色动画的

【python】python股票量化交易策略分析可视化(源码+数据集+论文)【独一无二】

👉博__主👈:米码收割机 👉技__能👈:C++/Python语言 👉公众号👈:测试开发自动化【获取源码+商业合作】 👉荣__誉👈:阿里云博客专家博主、51CTO技术博主 👉专__注👈:专注主流机器人、人工智能等相关领域的开发、测试技术。 【python】python股票量化交易策略分析可视化(源码+数据集+论文)【独一无二】 目录 【python】pyt