图神经网络框架DGL实现Graph Attention Network (GAT)笔记

2024-09-08 09:18

本文主要是介绍图神经网络框架DGL实现Graph Attention Network (GAT)笔记,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

参考列表:

[1]深入理解图注意力机制
[2]DGL官方学习教程一 ——基础操作&消息传递
[3]Cora数据集介绍+python读取

一、DGL实现GAT分类机器学习论文

程序摘自[1],该程序实现了利用图神经网络框架——DGL,实现图注意网络(GAT)。应用demo为对机器学习论文数据集——Cora,对论文所属类别进行分类。(下图摘自[3])
在这里插入图片描述

1. 程序

Ubuntu:18.04
cuda:11.1
cudnn:8.0.4.30
pytorch:1.7.0
networkx:2.5

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as npclass GATLayer(nn.Module):def __init__(self, g, in_dim, out_dim):super(GATLayer, self).__init__()self.g = gself.fc = nn.Linear(in_dim, out_dim, bias=False)self.attn_fc = nn.Linear(2 * out_dim, 1, bias=False)def edge_attention(self, edges):z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)a = self.attn_fc(z2)return {'e' : F.leaky_relu(a)}def message_func(self, edges):return {'z' : edges.src['z'], 'e' : edges.data['e']}def reduce_func(self, nodes):alpha = F.softmax(nodes.mailbox['e'], dim=1)h = torch.sum(alpha * nodes.mailbox['z'], dim=1)return {'h' : h}def forward(self, h):z = self.fc(h) # eq. 1self.g.ndata['z'] = z self.g.apply_edges(self.edge_attention) # eq. 2self.g.update_all(self.message_func, self.reduce_func) # eq. 3 and 4return self.g.ndata.pop('h')class MultiHeadGATLayer(nn.Module):def __init__(self, g, in_dim, out_dim, num_heads, merge='cat'):super(MultiHeadGATLayer, self).__init__()self.heads = nn.ModuleList()for i in range(num_heads):self.heads.append(GATLayer(g, in_dim, out_dim))self.merge = mergedef forward(self, h):head_outs = [attn_head(h) for attn_head in self.heads]if self.merge == 'cat':return torch.cat(head_outs, dim=1)else:return torch.mean(torch.stack(head_outs))class GAT(nn.Module):def __init__(self, g, in_dim, hidden_dim, out_dim, num_heads):super(GAT, self).__init__()self.layer1 = MultiHeadGATLayer(g, in_dim, hidden_dim, num_heads)self.layer2 = MultiHeadGATLayer(g, hidden_dim * num_heads, out_dim, 1)def forward(self, h):h = self.layer1(h)h = F.elu(h)h = self.layer2(h)return hfrom dgl import DGLGraph
from dgl.data import citation_graph as citegrhdef load_core_data():data = citegrh.load_cora()features = torch.FloatTensor(data.features)labels = torch.LongTensor(data.labels)mask = torch.ByteTensor(data.train_mask)g = DGLGraph(data.graph)return g, features, labels, maskimport time 
import numpy as np
g, features, labels, mask = load_core_data()net = GAT(g, in_dim = features.size()[1], hidden_dim=8, out_dim=7, num_heads=8)optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
dur = []
for epoch in range(300):if epoch >= 3:t0 = time.time()logits = net(features)logp = F.log_softmax(logits, 1)loss = F.nll_loss(logp[mask], labels[mask])optimizer.zero_grad()loss.backward()optimizer.step()if epoch >= 3:dur.append(time.time() - t0)print("Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f}".format(epoch, loss.item(), np.mean(dur)))
2.笔记
2.1 初始化一个graph的两种方式

对于如下图数据结构:
0->1
1->2
3->1

多称之为小括号方式

import networkx as nx
import matplotlib.pyplot as plt
import dgl
import torch
%matplotlib inline
g = dgl.graph((torch.tensor([0, 1, 3]), torch.tensor([1, 2, 1]))) # 小括号
nx.draw(g.to_networkx(), node_size=50, node_color=[[.5, .5, .5,]])  #使用nx绘制,设置节点大小及灰度值
plt.show()

在这里插入图片描述
或中括号方式:

import networkx as nx
import matplotlib.pyplot as plt
import dgl
import torch
%matplotlib inline
g = dgl.graph([torch.tensor([0, 1]), torch.tensor([1, 2]), torch.tensor([3, 1])]) # 中括号
nx.draw(g.to_networkx(), node_size=50, node_color=[[.5, .5, .5,]])  #使用nx绘制,设置节点大小及灰度值
plt.show()

在这里插入图片描述
note: 同一个graph,每次打印出来的各节点的位置是随机的。

2.2 DGL的update_all函数实际工作过程

利用如下例程说明:

import networkx as nx
import matplotlib.pyplot as plt
import torch
import dglN = 100  # number of nodes
DAMP = 0.85  # damping factor阻尼因子
K = 10  # number of iterations
g = nx.nx.erdos_renyi_graph(N, 0.1) #图随机生成器,生成nx图
g = dgl.DGLGraph(g)                 #转换成DGL图
g.ndata['pv'] = torch.ones(N) / N  #初始化PageRank值
g.ndata['deg'] = g.in_degrees(g.nodes()).float()  #初始化节点特征
print(g.ndata['deg'])
#定义message函数,它将每个节点的PageRank值除以其out-degree,并将结果作为消息传递给它的邻居:
def pagerank_message_func(edges):return {'pv' : edges.src['pv'] / edges.src['deg']}
#定义reduce函数,它从mailbox中删除并聚合message,并计算其新的PageRank值:
def pagerank_reduce_func(nodes):print("-batch size--pv size-------------")print(nodes.batch_size(), nodes.mailbox['pv'].size())msgs = torch.sum(nodes.mailbox['pv'], dim=1)pv = (1 - DAMP) / N + DAMP * msgsreturn {'pv' : pv}
g.update_all(pagerank_message_func, pagerank_reduce_func)

打印g.ndata[‘deg’]信息(也即每个节点的入度信息)如下:

tensor([ 9., 7., 17., 10., 12., 13., 13., 9., 5., 14., 7., 12., 15., 6.,
15., 7., 13., 7., 11., 9., 9., 15., 9., 12., 10., 8., 10., 9.,
15., 7., 8., 10., 10., 8., 11., 13., 6., 10., 10., 11., 5., 13.,
6., 12., 12., 8., 6., 11., 9., 10., 12., 8., 11., 5., 7., 12.,
4., 7., 8., 13., 11., 14., 9., 10., 12., 10., 10., 9., 10., 13.,
7., 15., 15., 10., 6., 11., 4., 6., 5., 10., 9., 11., 19., 9.,
12., 13., 15., 12., 12., 11., 10., 8., 11., 9., 7., 7., 11., 3.,
10., 5.])

pagerank_reduce_func函数内的打印信息如下:

-batch size–pv size-------------
1 torch.Size([1, 3])
-batch size–pv size-------------
2 torch.Size([2, 4])
-batch size–pv size-------------
5 torch.Size([5, 5])
-batch size–pv size-------------
6 torch.Size([6, 6])
-batch size–pv size-------------
10 torch.Size([10, 7])
-batch size–pv size-------------
7 torch.Size([7, 8])
-batch size–pv size-------------
12 torch.Size([12, 9])
-batch size–pv size-------------
16 torch.Size([16, 10])
-batch size–pv size-------------
11 torch.Size([11, 11])
-batch size–pv size-------------
11 torch.Size([11, 12])
-batch size–pv size-------------
8 torch.Size([8, 13])
-batch size–pv size-------------
2 torch.Size([2, 14])
-batch size–pv size-------------
7 torch.Size([7, 15])
-batch size–pv size-------------
1 torch.Size([1, 17])
-batch size–pv size-------------
1 torch.Size([1, 19])

入度为3的节点只有一个,入度为4的节点有两个,入度为5的节点五个,…

对比图的入度信息与pagerank_reduce_func函数内的打印信息,我们发现:入度为3的节点只有一个,入度为4的节点有两个,入度为5的节点五个,…因此,得出:
1)函数update_all并不是将所有节点一起更新;
2)函数update_all将具有同等个数目标节点的节点放在一起更新,形成一个batch,这也是为什么reduce_func(nodes)中的入参中的入参type为dgl.udf.NodeBatch的原因。reduce_func(nodes)中的入参nodes的不同行代表与不同节点相关的数据。

这篇关于图神经网络框架DGL实现Graph Attention Network (GAT)笔记的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1147712

相关文章

如何使用Java实现请求deepseek

《如何使用Java实现请求deepseek》这篇文章主要为大家详细介绍了如何使用Java实现请求deepseek功能,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录1.deepseek的api创建2.Java实现请求deepseek2.1 pom文件2.2 json转化文件2.2

python使用fastapi实现多语言国际化的操作指南

《python使用fastapi实现多语言国际化的操作指南》本文介绍了使用Python和FastAPI实现多语言国际化的操作指南,包括多语言架构技术栈、翻译管理、前端本地化、语言切换机制以及常见陷阱和... 目录多语言国际化实现指南项目多语言架构技术栈目录结构翻译工作流1. 翻译数据存储2. 翻译生成脚本

如何通过Python实现一个消息队列

《如何通过Python实现一个消息队列》这篇文章主要为大家详细介绍了如何通过Python实现一个简单的消息队列,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录如何通过 python 实现消息队列如何把 http 请求放在队列中执行1. 使用 queue.Queue 和 reque

Python如何实现PDF隐私信息检测

《Python如何实现PDF隐私信息检测》随着越来越多的个人信息以电子形式存储和传输,确保这些信息的安全至关重要,本文将介绍如何使用Python检测PDF文件中的隐私信息,需要的可以参考下... 目录项目背景技术栈代码解析功能说明运行结php果在当今,数据隐私保护变得尤为重要。随着越来越多的个人信息以电子形

使用 sql-research-assistant进行 SQL 数据库研究的实战指南(代码实现演示)

《使用sql-research-assistant进行SQL数据库研究的实战指南(代码实现演示)》本文介绍了sql-research-assistant工具,该工具基于LangChain框架,集... 目录技术背景介绍核心原理解析代码实现演示安装和配置项目集成LangSmith 配置(可选)启动服务应用场景

使用Python快速实现链接转word文档

《使用Python快速实现链接转word文档》这篇文章主要为大家详细介绍了如何使用Python快速实现链接转word文档功能,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 演示代码展示from newspaper import Articlefrom docx import

前端原生js实现拖拽排课效果实例

《前端原生js实现拖拽排课效果实例》:本文主要介绍如何实现一个简单的课程表拖拽功能,通过HTML、CSS和JavaScript的配合,我们实现了课程项的拖拽、放置和显示功能,文中通过实例代码介绍的... 目录1. 效果展示2. 效果分析2.1 关键点2.2 实现方法3. 代码实现3.1 html部分3.2

Java深度学习库DJL实现Python的NumPy方式

《Java深度学习库DJL实现Python的NumPy方式》本文介绍了DJL库的背景和基本功能,包括NDArray的创建、数学运算、数据获取和设置等,同时,还展示了如何使用NDArray进行数据预处理... 目录1 NDArray 的背景介绍1.1 架构2 JavaDJL使用2.1 安装DJL2.2 基本操

最长公共子序列问题的深度分析与Java实现方式

《最长公共子序列问题的深度分析与Java实现方式》本文详细介绍了最长公共子序列(LCS)问题,包括其概念、暴力解法、动态规划解法,并提供了Java代码实现,暴力解法虽然简单,但在大数据处理中效率较低,... 目录最长公共子序列问题概述问题理解与示例分析暴力解法思路与示例代码动态规划解法DP 表的构建与意义动

java父子线程之间实现共享传递数据

《java父子线程之间实现共享传递数据》本文介绍了Java中父子线程间共享传递数据的几种方法,包括ThreadLocal变量、并发集合和内存队列或消息队列,并提醒注意并发安全问题... 目录通过 ThreadLocal 变量共享数据通过并发集合共享数据通过内存队列或消息队列共享数据注意并发安全问题总结在 J