Python GNN图神经网络代码实战;GAT代码模版,简单套用,易于修改和提升,图注意力机制代码实战

本文主要是介绍Python GNN图神经网络代码实战;GAT代码模版,简单套用,易于修改和提升,图注意力机制代码实战,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1.GAT简介

GAT(Graph Attention Network)模型是一种用于图数据的深度学习模型,由Veličković等人在2018年提出。它通过自适应地在图中计算节点之间的注意力来学习节点之间的关系,并在节点表示中捕捉全局和局部信息。

GAT模型的核心思想是通过注意力机制,对图中的节点进行加权聚合。与传统的图卷积网络(GCN)模型不同,GAT不仅考虑节点本身的特征信息,还考虑了节点与其邻居节点之间的关系。每个节点在聚合邻居节点的特征时,会分配不同的注意力权重,以捕捉不同邻居节点对该节点的贡献程度。

GAT模型具有以下特点和优势:

  1. 自适应学习的注意力机制:GAT模型能够根据数据自动学习节点之间的注意力权重,从而捕捉到不同节点之间的重要性和关系。
  2. 并行计算效率高:由于注意力权重是节点间独立计算的,可以高效地并行计算,适用于大规模图数据。
  3. 稀疏性:GAT模型引入了注意力系数,可以将注意力集中在有用的邻居节点上,减小计算量和存储需求。
  4. 灵活性:GAT模型可以根据任务需求设计不同的注意力权重计算方式,适应不同的图学习任务。

2.代码实战

模型架构分为两部分:GAT主体部分,GAT的注意力计算部分

注意力机制:首先输入参数为(节点的特征表示hi,邻接矩阵),注意这个hi可以来源于上一层,也可以是原始的;先计算每个节点到中心节点的权值,也可以称为权重或者系数,然后对所有的权值进行归一化,最后对每个邻居节点与对应的权值相乘,然后相加就得到了中心节点的最终表示,注意求权值的时候是要考虑中心节点本身的;

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as Fclass GATLayer(nn.Module):def __init__(self, in_features, out_features, dropout, alpha, concat=True):super(GATLayer, self).__init__()self.in_features = in_featuresself.out_features = out_featuresself.dropout = dropoutself.alpha = alphaself.concat = concatself.W = nn.Linear(in_features, out_features)self.a = nn.Linear(2*out_features, 1)def forward(self, h, adj):Wh = self.W(h)  # W*hN = h.size()[0]  # Number of nodesa_input = torch.cat([Wh.repeat(1, N).view(N*N, -1), Wh.repeat(N, 1)], dim=1).view(N, -1, 2*self.out_features)e = F.leaky_relu(self.a(a_input).squeeze(2), negative_slope=self.alpha)zero_vec = -9e15*torch.ones_like(e)attention = torch.where(adj > 0, e, zero_vec)attention = F.softmax(attention, dim=1)attention = F.dropout(attention, p=self.dropout, training=self.training)h_prime = torch.matmul(attention, Wh)if self.concat:return F.elu(h_prime)else:return h_primeclass GAT(nn.Module):def __init__(self, nfeat, nhid, nclass, dropout, alpha, nheads):super(GAT, self).__init__()self.dropout = dropoutself.hidden = nn.ModuleList([GATLayer(nfeat, nhid, dropout, alpha, concat=True) for _ in range(nheads)])self.out_att = GATLayer(nhid*nheads, nclass, dropout, alpha, concat=False)def forward(self, x, adj):x = F.dropout(x, self.dropout, training=self.training)x = torch.cat([att(x, adj) for att in self.hidden], dim=1)x = F.dropout(x, self.dropout, training=self.training)x = F.sigmoid(self.out_att(x, adj))return F.log_softmax(x, dim=1)# 创建示例数据和邻接矩阵
adj = torch.tensor([[0, 1, 1, 0],[1, 0, 1, 1],[1, 1, 0, 1],[0, 1, 1, 0]])  # 邻接矩阵
features = torch.randn(4, 5)  # 特征矩阵# 创建GAT模型
model = GAT(nfeat=5, nhid=8, nclass=2, dropout=0.6, alpha=0.2, nheads=2)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)# 训练模型
for epoch in range(100):optimizer.zero_grad()output = model(features, adj)# 假设这里有标签数据yy = torch.LongTensor([0, 1, 0, 1])  # 标签loss = criterion(output, y)loss.backward()optimizer.step()# 测试模型
output = model(features, adj)
_, predictions = output.max(dim=1)
correct = (predictions == y).sum().item()
accuracy = correct / len(y)
print("准确率:", accuracy)

这篇关于Python GNN图神经网络代码实战;GAT代码模版,简单套用,易于修改和提升,图注意力机制代码实战的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1027103

相关文章

Python正则表达式语法及re模块中的常用函数详解

《Python正则表达式语法及re模块中的常用函数详解》这篇文章主要给大家介绍了关于Python正则表达式语法及re模块中常用函数的相关资料,正则表达式是一种强大的字符串处理工具,可以用于匹配、切分、... 目录概念、作用和步骤语法re模块中的常用函数总结 概念、作用和步骤概念: 本身也是一个字符串,其中

Python使用getopt处理命令行参数示例解析(最佳实践)

《Python使用getopt处理命令行参数示例解析(最佳实践)》getopt模块是Python标准库中一个简单但强大的命令行参数处理工具,它特别适合那些需要快速实现基本命令行参数解析的场景,或者需要... 目录为什么需要处理命令行参数?getopt模块基础实际应用示例与其他参数处理方式的比较常见问http

python实现svg图片转换为png和gif

《python实现svg图片转换为png和gif》这篇文章主要为大家详细介绍了python如何实现将svg图片格式转换为png和gif,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录python实现svg图片转换为png和gifpython实现图片格式之间的相互转换延展:基于Py

Python中的getopt模块用法小结

《Python中的getopt模块用法小结》getopt.getopt()函数是Python中用于解析命令行参数的标准库函数,该函数可以从命令行中提取选项和参数,并对它们进行处理,本文详细介绍了Pyt... 目录getopt模块介绍getopt.getopt函数的介绍getopt模块的常用用法getopt模

Python利用ElementTree实现快速解析XML文件

《Python利用ElementTree实现快速解析XML文件》ElementTree是Python标准库的一部分,而且是Python标准库中用于解析和操作XML数据的模块,下面小编就来和大家详细讲讲... 目录一、XML文件解析到底有多重要二、ElementTree快速入门1. 加载XML的两种方式2.

Python如何精准判断某个进程是否在运行

《Python如何精准判断某个进程是否在运行》这篇文章主要为大家详细介绍了Python如何精准判断某个进程是否在运行,本文为大家整理了3种方法并进行了对比,有需要的小伙伴可以跟随小编一起学习一下... 目录一、为什么需要判断进程是否存在二、方法1:用psutil库(推荐)三、方法2:用os.system调用

Java的栈与队列实现代码解析

《Java的栈与队列实现代码解析》栈是常见的线性数据结构,栈的特点是以先进后出的形式,后进先出,先进后出,分为栈底和栈顶,栈应用于内存的分配,表达式求值,存储临时的数据和方法的调用等,本文给大家介绍J... 目录栈的概念(Stack)栈的实现代码队列(Queue)模拟实现队列(双链表实现)循环队列(循环数组

使用Python从PPT文档中提取图片和图片信息(如坐标、宽度和高度等)

《使用Python从PPT文档中提取图片和图片信息(如坐标、宽度和高度等)》PPT是一种高效的信息展示工具,广泛应用于教育、商务和设计等多个领域,PPT文档中常常包含丰富的图片内容,这些图片不仅提升了... 目录一、引言二、环境与工具三、python 提取PPT背景图片3.1 提取幻灯片背景图片3.2 提取

C++如何通过Qt反射机制实现数据类序列化

《C++如何通过Qt反射机制实现数据类序列化》在C++工程中经常需要使用数据类,并对数据类进行存储、打印、调试等操作,所以本文就来聊聊C++如何通过Qt反射机制实现数据类序列化吧... 目录设计预期设计思路代码实现使用方法在 C++ 工程中经常需要使用数据类,并对数据类进行存储、打印、调试等操作。由于数据类

Python实现图片分割的多种方法总结

《Python实现图片分割的多种方法总结》图片分割是图像处理中的一个重要任务,它的目标是将图像划分为多个区域或者对象,本文为大家整理了一些常用的分割方法,大家可以根据需求自行选择... 目录1. 基于传统图像处理的分割方法(1) 使用固定阈值分割图片(2) 自适应阈值分割(3) 使用图像边缘检测分割(4)