图神经网络框架DGL实现Graph Attention Network (GAT)笔记

2024-09-08 09:18

本文主要是介绍图神经网络框架DGL实现Graph Attention Network (GAT)笔记,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

参考列表:

[1]深入理解图注意力机制
[2]DGL官方学习教程一 ——基础操作&消息传递
[3]Cora数据集介绍+python读取

一、DGL实现GAT分类机器学习论文

程序摘自[1],该程序实现了利用图神经网络框架——DGL,实现图注意网络(GAT)。应用demo为对机器学习论文数据集——Cora,对论文所属类别进行分类。(下图摘自[3])
在这里插入图片描述

1. 程序

Ubuntu:18.04
cuda:11.1
cudnn:8.0.4.30
pytorch:1.7.0
networkx:2.5

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as npclass GATLayer(nn.Module):def __init__(self, g, in_dim, out_dim):super(GATLayer, self).__init__()self.g = gself.fc = nn.Linear(in_dim, out_dim, bias=False)self.attn_fc = nn.Linear(2 * out_dim, 1, bias=False)def edge_attention(self, edges):z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)a = self.attn_fc(z2)return {'e' : F.leaky_relu(a)}def message_func(self, edges):return {'z' : edges.src['z'], 'e' : edges.data['e']}def reduce_func(self, nodes):alpha = F.softmax(nodes.mailbox['e'], dim=1)h = torch.sum(alpha * nodes.mailbox['z'], dim=1)return {'h' : h}def forward(self, h):z = self.fc(h) # eq. 1self.g.ndata['z'] = z self.g.apply_edges(self.edge_attention) # eq. 2self.g.update_all(self.message_func, self.reduce_func) # eq. 3 and 4return self.g.ndata.pop('h')class MultiHeadGATLayer(nn.Module):def __init__(self, g, in_dim, out_dim, num_heads, merge='cat'):super(MultiHeadGATLayer, self).__init__()self.heads = nn.ModuleList()for i in range(num_heads):self.heads.append(GATLayer(g, in_dim, out_dim))self.merge = mergedef forward(self, h):head_outs = [attn_head(h) for attn_head in self.heads]if self.merge == 'cat':return torch.cat(head_outs, dim=1)else:return torch.mean(torch.stack(head_outs))class GAT(nn.Module):def __init__(self, g, in_dim, hidden_dim, out_dim, num_heads):super(GAT, self).__init__()self.layer1 = MultiHeadGATLayer(g, in_dim, hidden_dim, num_heads)self.layer2 = MultiHeadGATLayer(g, hidden_dim * num_heads, out_dim, 1)def forward(self, h):h = self.layer1(h)h = F.elu(h)h = self.layer2(h)return hfrom dgl import DGLGraph
from dgl.data import citation_graph as citegrhdef load_core_data():data = citegrh.load_cora()features = torch.FloatTensor(data.features)labels = torch.LongTensor(data.labels)mask = torch.ByteTensor(data.train_mask)g = DGLGraph(data.graph)return g, features, labels, maskimport time 
import numpy as np
g, features, labels, mask = load_core_data()net = GAT(g, in_dim = features.size()[1], hidden_dim=8, out_dim=7, num_heads=8)optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
dur = []
for epoch in range(300):if epoch >= 3:t0 = time.time()logits = net(features)logp = F.log_softmax(logits, 1)loss = F.nll_loss(logp[mask], labels[mask])optimizer.zero_grad()loss.backward()optimizer.step()if epoch >= 3:dur.append(time.time() - t0)print("Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f}".format(epoch, loss.item(), np.mean(dur)))
2.笔记
2.1 初始化一个graph的两种方式

对于如下图数据结构:
0->1
1->2
3->1

多称之为小括号方式

import networkx as nx
import matplotlib.pyplot as plt
import dgl
import torch
%matplotlib inline
g = dgl.graph((torch.tensor([0, 1, 3]), torch.tensor([1, 2, 1]))) # 小括号
nx.draw(g.to_networkx(), node_size=50, node_color=[[.5, .5, .5,]])  #使用nx绘制,设置节点大小及灰度值
plt.show()

在这里插入图片描述
或中括号方式:

import networkx as nx
import matplotlib.pyplot as plt
import dgl
import torch
%matplotlib inline
g = dgl.graph([torch.tensor([0, 1]), torch.tensor([1, 2]), torch.tensor([3, 1])]) # 中括号
nx.draw(g.to_networkx(), node_size=50, node_color=[[.5, .5, .5,]])  #使用nx绘制,设置节点大小及灰度值
plt.show()

在这里插入图片描述
note: 同一个graph,每次打印出来的各节点的位置是随机的。

2.2 DGL的update_all函数实际工作过程

利用如下例程说明:

import networkx as nx
import matplotlib.pyplot as plt
import torch
import dglN = 100  # number of nodes
DAMP = 0.85  # damping factor阻尼因子
K = 10  # number of iterations
g = nx.nx.erdos_renyi_graph(N, 0.1) #图随机生成器,生成nx图
g = dgl.DGLGraph(g)                 #转换成DGL图
g.ndata['pv'] = torch.ones(N) / N  #初始化PageRank值
g.ndata['deg'] = g.in_degrees(g.nodes()).float()  #初始化节点特征
print(g.ndata['deg'])
#定义message函数,它将每个节点的PageRank值除以其out-degree,并将结果作为消息传递给它的邻居:
def pagerank_message_func(edges):return {'pv' : edges.src['pv'] / edges.src['deg']}
#定义reduce函数,它从mailbox中删除并聚合message,并计算其新的PageRank值:
def pagerank_reduce_func(nodes):print("-batch size--pv size-------------")print(nodes.batch_size(), nodes.mailbox['pv'].size())msgs = torch.sum(nodes.mailbox['pv'], dim=1)pv = (1 - DAMP) / N + DAMP * msgsreturn {'pv' : pv}
g.update_all(pagerank_message_func, pagerank_reduce_func)

打印g.ndata[‘deg’]信息(也即每个节点的入度信息)如下:

tensor([ 9., 7., 17., 10., 12., 13., 13., 9., 5., 14., 7., 12., 15., 6.,
15., 7., 13., 7., 11., 9., 9., 15., 9., 12., 10., 8., 10., 9.,
15., 7., 8., 10., 10., 8., 11., 13., 6., 10., 10., 11., 5., 13.,
6., 12., 12., 8., 6., 11., 9., 10., 12., 8., 11., 5., 7., 12.,
4., 7., 8., 13., 11., 14., 9., 10., 12., 10., 10., 9., 10., 13.,
7., 15., 15., 10., 6., 11., 4., 6., 5., 10., 9., 11., 19., 9.,
12., 13., 15., 12., 12., 11., 10., 8., 11., 9., 7., 7., 11., 3.,
10., 5.])

pagerank_reduce_func函数内的打印信息如下:

-batch size–pv size-------------
1 torch.Size([1, 3])
-batch size–pv size-------------
2 torch.Size([2, 4])
-batch size–pv size-------------
5 torch.Size([5, 5])
-batch size–pv size-------------
6 torch.Size([6, 6])
-batch size–pv size-------------
10 torch.Size([10, 7])
-batch size–pv size-------------
7 torch.Size([7, 8])
-batch size–pv size-------------
12 torch.Size([12, 9])
-batch size–pv size-------------
16 torch.Size([16, 10])
-batch size–pv size-------------
11 torch.Size([11, 11])
-batch size–pv size-------------
11 torch.Size([11, 12])
-batch size–pv size-------------
8 torch.Size([8, 13])
-batch size–pv size-------------
2 torch.Size([2, 14])
-batch size–pv size-------------
7 torch.Size([7, 15])
-batch size–pv size-------------
1 torch.Size([1, 17])
-batch size–pv size-------------
1 torch.Size([1, 19])

入度为3的节点只有一个,入度为4的节点有两个,入度为5的节点五个,…

对比图的入度信息与pagerank_reduce_func函数内的打印信息,我们发现:入度为3的节点只有一个,入度为4的节点有两个,入度为5的节点五个,…因此,得出:
1)函数update_all并不是将所有节点一起更新;
2)函数update_all将具有同等个数目标节点的节点放在一起更新,形成一个batch,这也是为什么reduce_func(nodes)中的入参中的入参type为dgl.udf.NodeBatch的原因。reduce_func(nodes)中的入参nodes的不同行代表与不同节点相关的数据。

这篇关于图神经网络框架DGL实现Graph Attention Network (GAT)笔记的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1147712

相关文章

Linux下删除乱码文件和目录的实现方式

《Linux下删除乱码文件和目录的实现方式》:本文主要介绍Linux下删除乱码文件和目录的实现方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录linux下删除乱码文件和目录方法1方法2总结Linux下删除乱码文件和目录方法1使用ls -i命令找到文件或目录

SpringBoot+EasyExcel实现自定义复杂样式导入导出

《SpringBoot+EasyExcel实现自定义复杂样式导入导出》这篇文章主要为大家详细介绍了SpringBoot如何结果EasyExcel实现自定义复杂样式导入导出功能,文中的示例代码讲解详细,... 目录安装处理自定义导出复杂场景1、列不固定,动态列2、动态下拉3、自定义锁定行/列,添加密码4、合并

mybatis执行insert返回id实现详解

《mybatis执行insert返回id实现详解》MyBatis插入操作默认返回受影响行数,需通过useGeneratedKeys+keyProperty或selectKey获取主键ID,确保主键为自... 目录 两种方式获取自增 ID:1. ​​useGeneratedKeys+keyProperty(推

Spring Boot集成Druid实现数据源管理与监控的详细步骤

《SpringBoot集成Druid实现数据源管理与监控的详细步骤》本文介绍如何在SpringBoot项目中集成Druid数据库连接池,包括环境搭建、Maven依赖配置、SpringBoot配置文件... 目录1. 引言1.1 环境准备1.2 Druid介绍2. 配置Druid连接池3. 查看Druid监控

Linux在线解压jar包的实现方式

《Linux在线解压jar包的实现方式》:本文主要介绍Linux在线解压jar包的实现方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录linux在线解压jar包解压 jar包的步骤总结Linux在线解压jar包在 Centos 中解压 jar 包可以使用 u

c++ 类成员变量默认初始值的实现

《c++类成员变量默认初始值的实现》本文主要介绍了c++类成员变量默认初始值,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧... 目录C++类成员变量初始化c++类的变量的初始化在C++中,如果使用类成员变量时未给定其初始值,那么它将被

Qt使用QSqlDatabase连接MySQL实现增删改查功能

《Qt使用QSqlDatabase连接MySQL实现增删改查功能》这篇文章主要为大家详细介绍了Qt如何使用QSqlDatabase连接MySQL实现增删改查功能,文中的示例代码讲解详细,感兴趣的小伙伴... 目录一、创建数据表二、连接mysql数据库三、封装成一个完整的轻量级 ORM 风格类3.1 表结构

基于Python实现一个图片拆分工具

《基于Python实现一个图片拆分工具》这篇文章主要为大家详细介绍了如何基于Python实现一个图片拆分工具,可以根据需要的行数和列数进行拆分,感兴趣的小伙伴可以跟随小编一起学习一下... 简单介绍先自己选择输入的图片,默认是输出到项目文件夹中,可以自己选择其他的文件夹,选择需要拆分的行数和列数,可以通过

Python中将嵌套列表扁平化的多种实现方法

《Python中将嵌套列表扁平化的多种实现方法》在Python编程中,我们常常会遇到需要将嵌套列表(即列表中包含列表)转换为一个一维的扁平列表的需求,本文将给大家介绍了多种实现这一目标的方法,需要的朋... 目录python中将嵌套列表扁平化的方法技术背景实现步骤1. 使用嵌套列表推导式2. 使用itert

Python使用pip工具实现包自动更新的多种方法

《Python使用pip工具实现包自动更新的多种方法》本文深入探讨了使用Python的pip工具实现包自动更新的各种方法和技术,我们将从基础概念开始,逐步介绍手动更新方法、自动化脚本编写、结合CI/C... 目录1. 背景介绍1.1 目的和范围1.2 预期读者1.3 文档结构概述1.4 术语表1.4.1 核