图神经网络框架DGL实现Graph Attention Network (GAT)笔记

2024-09-08 09:18

本文主要是介绍图神经网络框架DGL实现Graph Attention Network (GAT)笔记,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

参考列表:

[1]深入理解图注意力机制
[2]DGL官方学习教程一 ——基础操作&消息传递
[3]Cora数据集介绍+python读取

一、DGL实现GAT分类机器学习论文

程序摘自[1],该程序实现了利用图神经网络框架——DGL,实现图注意网络(GAT)。应用demo为对机器学习论文数据集——Cora,对论文所属类别进行分类。(下图摘自[3])
在这里插入图片描述

1. 程序

Ubuntu:18.04
cuda:11.1
cudnn:8.0.4.30
pytorch:1.7.0
networkx:2.5

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as npclass GATLayer(nn.Module):def __init__(self, g, in_dim, out_dim):super(GATLayer, self).__init__()self.g = gself.fc = nn.Linear(in_dim, out_dim, bias=False)self.attn_fc = nn.Linear(2 * out_dim, 1, bias=False)def edge_attention(self, edges):z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)a = self.attn_fc(z2)return {'e' : F.leaky_relu(a)}def message_func(self, edges):return {'z' : edges.src['z'], 'e' : edges.data['e']}def reduce_func(self, nodes):alpha = F.softmax(nodes.mailbox['e'], dim=1)h = torch.sum(alpha * nodes.mailbox['z'], dim=1)return {'h' : h}def forward(self, h):z = self.fc(h) # eq. 1self.g.ndata['z'] = z self.g.apply_edges(self.edge_attention) # eq. 2self.g.update_all(self.message_func, self.reduce_func) # eq. 3 and 4return self.g.ndata.pop('h')class MultiHeadGATLayer(nn.Module):def __init__(self, g, in_dim, out_dim, num_heads, merge='cat'):super(MultiHeadGATLayer, self).__init__()self.heads = nn.ModuleList()for i in range(num_heads):self.heads.append(GATLayer(g, in_dim, out_dim))self.merge = mergedef forward(self, h):head_outs = [attn_head(h) for attn_head in self.heads]if self.merge == 'cat':return torch.cat(head_outs, dim=1)else:return torch.mean(torch.stack(head_outs))class GAT(nn.Module):def __init__(self, g, in_dim, hidden_dim, out_dim, num_heads):super(GAT, self).__init__()self.layer1 = MultiHeadGATLayer(g, in_dim, hidden_dim, num_heads)self.layer2 = MultiHeadGATLayer(g, hidden_dim * num_heads, out_dim, 1)def forward(self, h):h = self.layer1(h)h = F.elu(h)h = self.layer2(h)return hfrom dgl import DGLGraph
from dgl.data import citation_graph as citegrhdef load_core_data():data = citegrh.load_cora()features = torch.FloatTensor(data.features)labels = torch.LongTensor(data.labels)mask = torch.ByteTensor(data.train_mask)g = DGLGraph(data.graph)return g, features, labels, maskimport time 
import numpy as np
g, features, labels, mask = load_core_data()net = GAT(g, in_dim = features.size()[1], hidden_dim=8, out_dim=7, num_heads=8)optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
dur = []
for epoch in range(300):if epoch >= 3:t0 = time.time()logits = net(features)logp = F.log_softmax(logits, 1)loss = F.nll_loss(logp[mask], labels[mask])optimizer.zero_grad()loss.backward()optimizer.step()if epoch >= 3:dur.append(time.time() - t0)print("Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f}".format(epoch, loss.item(), np.mean(dur)))
2.笔记
2.1 初始化一个graph的两种方式

对于如下图数据结构:
0->1
1->2
3->1

多称之为小括号方式

import networkx as nx
import matplotlib.pyplot as plt
import dgl
import torch
%matplotlib inline
g = dgl.graph((torch.tensor([0, 1, 3]), torch.tensor([1, 2, 1]))) # 小括号
nx.draw(g.to_networkx(), node_size=50, node_color=[[.5, .5, .5,]])  #使用nx绘制,设置节点大小及灰度值
plt.show()

在这里插入图片描述
或中括号方式:

import networkx as nx
import matplotlib.pyplot as plt
import dgl
import torch
%matplotlib inline
g = dgl.graph([torch.tensor([0, 1]), torch.tensor([1, 2]), torch.tensor([3, 1])]) # 中括号
nx.draw(g.to_networkx(), node_size=50, node_color=[[.5, .5, .5,]])  #使用nx绘制,设置节点大小及灰度值
plt.show()

在这里插入图片描述
note: 同一个graph,每次打印出来的各节点的位置是随机的。

2.2 DGL的update_all函数实际工作过程

利用如下例程说明:

import networkx as nx
import matplotlib.pyplot as plt
import torch
import dglN = 100  # number of nodes
DAMP = 0.85  # damping factor阻尼因子
K = 10  # number of iterations
g = nx.nx.erdos_renyi_graph(N, 0.1) #图随机生成器,生成nx图
g = dgl.DGLGraph(g)                 #转换成DGL图
g.ndata['pv'] = torch.ones(N) / N  #初始化PageRank值
g.ndata['deg'] = g.in_degrees(g.nodes()).float()  #初始化节点特征
print(g.ndata['deg'])
#定义message函数,它将每个节点的PageRank值除以其out-degree,并将结果作为消息传递给它的邻居:
def pagerank_message_func(edges):return {'pv' : edges.src['pv'] / edges.src['deg']}
#定义reduce函数,它从mailbox中删除并聚合message,并计算其新的PageRank值:
def pagerank_reduce_func(nodes):print("-batch size--pv size-------------")print(nodes.batch_size(), nodes.mailbox['pv'].size())msgs = torch.sum(nodes.mailbox['pv'], dim=1)pv = (1 - DAMP) / N + DAMP * msgsreturn {'pv' : pv}
g.update_all(pagerank_message_func, pagerank_reduce_func)

打印g.ndata[‘deg’]信息(也即每个节点的入度信息)如下:

tensor([ 9., 7., 17., 10., 12., 13., 13., 9., 5., 14., 7., 12., 15., 6.,
15., 7., 13., 7., 11., 9., 9., 15., 9., 12., 10., 8., 10., 9.,
15., 7., 8., 10., 10., 8., 11., 13., 6., 10., 10., 11., 5., 13.,
6., 12., 12., 8., 6., 11., 9., 10., 12., 8., 11., 5., 7., 12.,
4., 7., 8., 13., 11., 14., 9., 10., 12., 10., 10., 9., 10., 13.,
7., 15., 15., 10., 6., 11., 4., 6., 5., 10., 9., 11., 19., 9.,
12., 13., 15., 12., 12., 11., 10., 8., 11., 9., 7., 7., 11., 3.,
10., 5.])

pagerank_reduce_func函数内的打印信息如下:

-batch size–pv size-------------
1 torch.Size([1, 3])
-batch size–pv size-------------
2 torch.Size([2, 4])
-batch size–pv size-------------
5 torch.Size([5, 5])
-batch size–pv size-------------
6 torch.Size([6, 6])
-batch size–pv size-------------
10 torch.Size([10, 7])
-batch size–pv size-------------
7 torch.Size([7, 8])
-batch size–pv size-------------
12 torch.Size([12, 9])
-batch size–pv size-------------
16 torch.Size([16, 10])
-batch size–pv size-------------
11 torch.Size([11, 11])
-batch size–pv size-------------
11 torch.Size([11, 12])
-batch size–pv size-------------
8 torch.Size([8, 13])
-batch size–pv size-------------
2 torch.Size([2, 14])
-batch size–pv size-------------
7 torch.Size([7, 15])
-batch size–pv size-------------
1 torch.Size([1, 17])
-batch size–pv size-------------
1 torch.Size([1, 19])

入度为3的节点只有一个,入度为4的节点有两个,入度为5的节点五个,…

对比图的入度信息与pagerank_reduce_func函数内的打印信息,我们发现:入度为3的节点只有一个,入度为4的节点有两个,入度为5的节点五个,…因此,得出:
1)函数update_all并不是将所有节点一起更新;
2)函数update_all将具有同等个数目标节点的节点放在一起更新,形成一个batch,这也是为什么reduce_func(nodes)中的入参中的入参type为dgl.udf.NodeBatch的原因。reduce_func(nodes)中的入参nodes的不同行代表与不同节点相关的数据。

这篇关于图神经网络框架DGL实现Graph Attention Network (GAT)笔记的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1147712

相关文章

使用Python实现一个优雅的异步定时器

《使用Python实现一个优雅的异步定时器》在Python中实现定时器功能是一个常见需求,尤其是在需要周期性执行任务的场景下,本文给大家介绍了基于asyncio和threading模块,可扩展的异步定... 目录需求背景代码1. 单例事件循环的实现2. 事件循环的运行与关闭3. 定时器核心逻辑4. 启动与停

基于Python实现读取嵌套压缩包下文件的方法

《基于Python实现读取嵌套压缩包下文件的方法》工作中遇到的问题,需要用Python实现嵌套压缩包下文件读取,本文给大家介绍了详细的解决方法,并有相关的代码示例供大家参考,需要的朋友可以参考下... 目录思路完整代码代码优化思路打开外层zip压缩包并遍历文件:使用with zipfile.ZipFil

Python实现word文档内容智能提取以及合成

《Python实现word文档内容智能提取以及合成》这篇文章主要为大家详细介绍了如何使用Python实现从10个左右的docx文档中抽取内容,再调整语言风格后生成新的文档,感兴趣的小伙伴可以了解一下... 目录核心思路技术路径实现步骤阶段一:准备工作阶段二:内容提取 (python 脚本)阶段三:语言风格调

C#实现将Excel表格转换为图片(JPG/ PNG)

《C#实现将Excel表格转换为图片(JPG/PNG)》Excel表格可能会因为不同设备或字体缺失等问题,导致格式错乱或数据显示异常,转换为图片后,能确保数据的排版等保持一致,下面我们看看如何使用C... 目录通过C# 转换Excel工作表到图片通过C# 转换指定单元格区域到图片知识扩展C# 将 Excel

基于Java实现回调监听工具类

《基于Java实现回调监听工具类》这篇文章主要为大家详细介绍了如何基于Java实现一个回调监听工具类,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录监听接口类 Listenable实际用法打印结果首先,会用到 函数式接口 Consumer, 通过这个可以解耦回调方法,下面先写一个

使用Java将DOCX文档解析为Markdown文档的代码实现

《使用Java将DOCX文档解析为Markdown文档的代码实现》在现代文档处理中,Markdown(MD)因其简洁的语法和良好的可读性,逐渐成为开发者、技术写作者和内容创作者的首选格式,然而,许多文... 目录引言1. 工具和库介绍2. 安装依赖库3. 使用Apache POI解析DOCX文档4. 将解析

Qt中QGroupBox控件的实现

《Qt中QGroupBox控件的实现》QGroupBox是Qt框架中一个非常有用的控件,它主要用于组织和管理一组相关的控件,本文主要介绍了Qt中QGroupBox控件的实现,具有一定的参考价值,感兴趣... 目录引言一、基本属性二、常用方法2.1 构造函数 2.2 设置标题2.3 设置复选框模式2.4 是否

C++使用printf语句实现进制转换的示例代码

《C++使用printf语句实现进制转换的示例代码》在C语言中,printf函数可以直接实现部分进制转换功能,通过格式说明符(formatspecifier)快速输出不同进制的数值,下面给大家分享C+... 目录一、printf 原生支持的进制转换1. 十进制、八进制、十六进制转换2. 显示进制前缀3. 指

springboot整合阿里云百炼DeepSeek实现sse流式打印的操作方法

《springboot整合阿里云百炼DeepSeek实现sse流式打印的操作方法》:本文主要介绍springboot整合阿里云百炼DeepSeek实现sse流式打印,本文给大家介绍的非常详细,对大... 目录1.开通阿里云百炼,获取到key2.新建SpringBoot项目3.工具类4.启动类5.测试类6.测

pytorch自动求梯度autograd的实现

《pytorch自动求梯度autograd的实现》autograd是一个自动微分引擎,它可以自动计算张量的梯度,本文主要介绍了pytorch自动求梯度autograd的实现,具有一定的参考价值,感兴趣... autograd是pytorch构建神经网络的核心。在 PyTorch 中,结合以下代码例子,当你