图神经网络框架DGL实现Graph Attention Network (GAT)笔记

2024-09-08 09:18

本文主要是介绍图神经网络框架DGL实现Graph Attention Network (GAT)笔记,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

参考列表:

[1]深入理解图注意力机制
[2]DGL官方学习教程一 ——基础操作&消息传递
[3]Cora数据集介绍+python读取

一、DGL实现GAT分类机器学习论文

程序摘自[1],该程序实现了利用图神经网络框架——DGL,实现图注意网络(GAT)。应用demo为对机器学习论文数据集——Cora,对论文所属类别进行分类。(下图摘自[3])
在这里插入图片描述

1. 程序

Ubuntu:18.04
cuda:11.1
cudnn:8.0.4.30
pytorch:1.7.0
networkx:2.5

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as npclass GATLayer(nn.Module):def __init__(self, g, in_dim, out_dim):super(GATLayer, self).__init__()self.g = gself.fc = nn.Linear(in_dim, out_dim, bias=False)self.attn_fc = nn.Linear(2 * out_dim, 1, bias=False)def edge_attention(self, edges):z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)a = self.attn_fc(z2)return {'e' : F.leaky_relu(a)}def message_func(self, edges):return {'z' : edges.src['z'], 'e' : edges.data['e']}def reduce_func(self, nodes):alpha = F.softmax(nodes.mailbox['e'], dim=1)h = torch.sum(alpha * nodes.mailbox['z'], dim=1)return {'h' : h}def forward(self, h):z = self.fc(h) # eq. 1self.g.ndata['z'] = z self.g.apply_edges(self.edge_attention) # eq. 2self.g.update_all(self.message_func, self.reduce_func) # eq. 3 and 4return self.g.ndata.pop('h')class MultiHeadGATLayer(nn.Module):def __init__(self, g, in_dim, out_dim, num_heads, merge='cat'):super(MultiHeadGATLayer, self).__init__()self.heads = nn.ModuleList()for i in range(num_heads):self.heads.append(GATLayer(g, in_dim, out_dim))self.merge = mergedef forward(self, h):head_outs = [attn_head(h) for attn_head in self.heads]if self.merge == 'cat':return torch.cat(head_outs, dim=1)else:return torch.mean(torch.stack(head_outs))class GAT(nn.Module):def __init__(self, g, in_dim, hidden_dim, out_dim, num_heads):super(GAT, self).__init__()self.layer1 = MultiHeadGATLayer(g, in_dim, hidden_dim, num_heads)self.layer2 = MultiHeadGATLayer(g, hidden_dim * num_heads, out_dim, 1)def forward(self, h):h = self.layer1(h)h = F.elu(h)h = self.layer2(h)return hfrom dgl import DGLGraph
from dgl.data import citation_graph as citegrhdef load_core_data():data = citegrh.load_cora()features = torch.FloatTensor(data.features)labels = torch.LongTensor(data.labels)mask = torch.ByteTensor(data.train_mask)g = DGLGraph(data.graph)return g, features, labels, maskimport time 
import numpy as np
g, features, labels, mask = load_core_data()net = GAT(g, in_dim = features.size()[1], hidden_dim=8, out_dim=7, num_heads=8)optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
dur = []
for epoch in range(300):if epoch >= 3:t0 = time.time()logits = net(features)logp = F.log_softmax(logits, 1)loss = F.nll_loss(logp[mask], labels[mask])optimizer.zero_grad()loss.backward()optimizer.step()if epoch >= 3:dur.append(time.time() - t0)print("Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f}".format(epoch, loss.item(), np.mean(dur)))
2.笔记
2.1 初始化一个graph的两种方式

对于如下图数据结构:
0->1
1->2
3->1

多称之为小括号方式

import networkx as nx
import matplotlib.pyplot as plt
import dgl
import torch
%matplotlib inline
g = dgl.graph((torch.tensor([0, 1, 3]), torch.tensor([1, 2, 1]))) # 小括号
nx.draw(g.to_networkx(), node_size=50, node_color=[[.5, .5, .5,]])  #使用nx绘制,设置节点大小及灰度值
plt.show()

在这里插入图片描述
或中括号方式:

import networkx as nx
import matplotlib.pyplot as plt
import dgl
import torch
%matplotlib inline
g = dgl.graph([torch.tensor([0, 1]), torch.tensor([1, 2]), torch.tensor([3, 1])]) # 中括号
nx.draw(g.to_networkx(), node_size=50, node_color=[[.5, .5, .5,]])  #使用nx绘制,设置节点大小及灰度值
plt.show()

在这里插入图片描述
note: 同一个graph,每次打印出来的各节点的位置是随机的。

2.2 DGL的update_all函数实际工作过程

利用如下例程说明:

import networkx as nx
import matplotlib.pyplot as plt
import torch
import dglN = 100  # number of nodes
DAMP = 0.85  # damping factor阻尼因子
K = 10  # number of iterations
g = nx.nx.erdos_renyi_graph(N, 0.1) #图随机生成器,生成nx图
g = dgl.DGLGraph(g)                 #转换成DGL图
g.ndata['pv'] = torch.ones(N) / N  #初始化PageRank值
g.ndata['deg'] = g.in_degrees(g.nodes()).float()  #初始化节点特征
print(g.ndata['deg'])
#定义message函数,它将每个节点的PageRank值除以其out-degree,并将结果作为消息传递给它的邻居:
def pagerank_message_func(edges):return {'pv' : edges.src['pv'] / edges.src['deg']}
#定义reduce函数,它从mailbox中删除并聚合message,并计算其新的PageRank值:
def pagerank_reduce_func(nodes):print("-batch size--pv size-------------")print(nodes.batch_size(), nodes.mailbox['pv'].size())msgs = torch.sum(nodes.mailbox['pv'], dim=1)pv = (1 - DAMP) / N + DAMP * msgsreturn {'pv' : pv}
g.update_all(pagerank_message_func, pagerank_reduce_func)

打印g.ndata[‘deg’]信息(也即每个节点的入度信息)如下:

tensor([ 9., 7., 17., 10., 12., 13., 13., 9., 5., 14., 7., 12., 15., 6.,
15., 7., 13., 7., 11., 9., 9., 15., 9., 12., 10., 8., 10., 9.,
15., 7., 8., 10., 10., 8., 11., 13., 6., 10., 10., 11., 5., 13.,
6., 12., 12., 8., 6., 11., 9., 10., 12., 8., 11., 5., 7., 12.,
4., 7., 8., 13., 11., 14., 9., 10., 12., 10., 10., 9., 10., 13.,
7., 15., 15., 10., 6., 11., 4., 6., 5., 10., 9., 11., 19., 9.,
12., 13., 15., 12., 12., 11., 10., 8., 11., 9., 7., 7., 11., 3.,
10., 5.])

pagerank_reduce_func函数内的打印信息如下:

-batch size–pv size-------------
1 torch.Size([1, 3])
-batch size–pv size-------------
2 torch.Size([2, 4])
-batch size–pv size-------------
5 torch.Size([5, 5])
-batch size–pv size-------------
6 torch.Size([6, 6])
-batch size–pv size-------------
10 torch.Size([10, 7])
-batch size–pv size-------------
7 torch.Size([7, 8])
-batch size–pv size-------------
12 torch.Size([12, 9])
-batch size–pv size-------------
16 torch.Size([16, 10])
-batch size–pv size-------------
11 torch.Size([11, 11])
-batch size–pv size-------------
11 torch.Size([11, 12])
-batch size–pv size-------------
8 torch.Size([8, 13])
-batch size–pv size-------------
2 torch.Size([2, 14])
-batch size–pv size-------------
7 torch.Size([7, 15])
-batch size–pv size-------------
1 torch.Size([1, 17])
-batch size–pv size-------------
1 torch.Size([1, 19])

入度为3的节点只有一个,入度为4的节点有两个,入度为5的节点五个,…

对比图的入度信息与pagerank_reduce_func函数内的打印信息,我们发现:入度为3的节点只有一个,入度为4的节点有两个,入度为5的节点五个,…因此,得出:
1)函数update_all并不是将所有节点一起更新;
2)函数update_all将具有同等个数目标节点的节点放在一起更新,形成一个batch,这也是为什么reduce_func(nodes)中的入参中的入参type为dgl.udf.NodeBatch的原因。reduce_func(nodes)中的入参nodes的不同行代表与不同节点相关的数据。

这篇关于图神经网络框架DGL实现Graph Attention Network (GAT)笔记的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1147712

相关文章

Python xmltodict实现简化XML数据处理

《Pythonxmltodict实现简化XML数据处理》Python社区为提供了xmltodict库,它专为简化XML与Python数据结构的转换而设计,本文主要来为大家介绍一下如何使用xmltod... 目录一、引言二、XMLtodict介绍设计理念适用场景三、功能参数与属性1、parse函数2、unpa

C#实现获得某个枚举的所有名称

《C#实现获得某个枚举的所有名称》这篇文章主要为大家详细介绍了C#如何实现获得某个枚举的所有名称,文中的示例代码讲解详细,具有一定的借鉴价值,有需要的小伙伴可以参考一下... C#中获得某个枚举的所有名称using System;using System.Collections.Generic;usi

Go语言实现将中文转化为拼音功能

《Go语言实现将中文转化为拼音功能》这篇文章主要为大家详细介绍了Go语言中如何实现将中文转化为拼音功能,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 有这么一个需求:新用户入职 创建一系列账号比较麻烦,打算通过接口传入姓名进行初始化。想把姓名转化成拼音。因为有些账号即需要中文也需要英

C# 读写ini文件操作实现

《C#读写ini文件操作实现》本文主要介绍了C#读写ini文件操作实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧... 目录一、INI文件结构二、读取INI文件中的数据在C#应用程序中,常将INI文件作为配置文件,用于存储应用程序的

C#实现获取电脑中的端口号和硬件信息

《C#实现获取电脑中的端口号和硬件信息》这篇文章主要为大家详细介绍了C#实现获取电脑中的端口号和硬件信息的相关方法,文中的示例代码讲解详细,有需要的小伙伴可以参考一下... 我们经常在使用一个串口软件的时候,发现软件中的端口号并不是普通的COM1,而是带有硬件信息的。那么如果我们使用C#编写软件时候,如

Python使用qrcode库实现生成二维码的操作指南

《Python使用qrcode库实现生成二维码的操作指南》二维码是一种广泛使用的二维条码,因其高效的数据存储能力和易于扫描的特点,广泛应用于支付、身份验证、营销推广等领域,Pythonqrcode库是... 目录一、安装 python qrcode 库二、基本使用方法1. 生成简单二维码2. 生成带 Log

Go语言使用Buffer实现高性能处理字节和字符

《Go语言使用Buffer实现高性能处理字节和字符》在Go中,bytes.Buffer是一个非常高效的类型,用于处理字节数据的读写操作,本文将详细介绍一下如何使用Buffer实现高性能处理字节和... 目录1. bytes.Buffer 的基本用法1.1. 创建和初始化 Buffer1.2. 使用 Writ

基于WinForm+Halcon实现图像缩放与交互功能

《基于WinForm+Halcon实现图像缩放与交互功能》本文主要讲述在WinForm中结合Halcon实现图像缩放、平移及实时显示灰度值等交互功能,包括初始化窗口的不同方式,以及通过特定事件添加相应... 目录前言初始化窗口添加图像缩放功能添加图像平移功能添加实时显示灰度值功能示例代码总结最后前言本文将

Redis延迟队列的实现示例

《Redis延迟队列的实现示例》Redis延迟队列是一种使用Redis实现的消息队列,本文主要介绍了Redis延迟队列的实现示例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习... 目录一、什么是 Redis 延迟队列二、实现原理三、Java 代码示例四、注意事项五、使用 Redi

C#实现WinForm控件焦点的获取与失去

《C#实现WinForm控件焦点的获取与失去》在一个数据输入表单中,当用户从一个文本框切换到另一个文本框时,需要准确地判断焦点的转移,以便进行数据验证、提示信息显示等操作,本文将探讨Winform控件... 目录前言获取焦点改变TabIndex属性值调用Focus方法失去焦点总结最后前言在一个数据输入表单