Python GNN图神经网络代码实战;GAT代码模版,简单套用,易于修改和提升,图注意力机制代码实战

本文主要是介绍Python GNN图神经网络代码实战;GAT代码模版,简单套用,易于修改和提升,图注意力机制代码实战,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1.GAT简介

GAT(Graph Attention Network)模型是一种用于图数据的深度学习模型,由Veličković等人在2018年提出。它通过自适应地在图中计算节点之间的注意力来学习节点之间的关系,并在节点表示中捕捉全局和局部信息。

GAT模型的核心思想是通过注意力机制,对图中的节点进行加权聚合。与传统的图卷积网络(GCN)模型不同,GAT不仅考虑节点本身的特征信息,还考虑了节点与其邻居节点之间的关系。每个节点在聚合邻居节点的特征时,会分配不同的注意力权重,以捕捉不同邻居节点对该节点的贡献程度。

GAT模型具有以下特点和优势:

  1. 自适应学习的注意力机制:GAT模型能够根据数据自动学习节点之间的注意力权重,从而捕捉到不同节点之间的重要性和关系。
  2. 并行计算效率高:由于注意力权重是节点间独立计算的,可以高效地并行计算,适用于大规模图数据。
  3. 稀疏性:GAT模型引入了注意力系数,可以将注意力集中在有用的邻居节点上,减小计算量和存储需求。
  4. 灵活性:GAT模型可以根据任务需求设计不同的注意力权重计算方式,适应不同的图学习任务。

2.代码实战

模型架构分为两部分:GAT主体部分,GAT的注意力计算部分

注意力机制:首先输入参数为(节点的特征表示hi,邻接矩阵),注意这个hi可以来源于上一层,也可以是原始的;先计算每个节点到中心节点的权值,也可以称为权重或者系数,然后对所有的权值进行归一化,最后对每个邻居节点与对应的权值相乘,然后相加就得到了中心节点的最终表示,注意求权值的时候是要考虑中心节点本身的;

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as Fclass GATLayer(nn.Module):def __init__(self, in_features, out_features, dropout, alpha, concat=True):super(GATLayer, self).__init__()self.in_features = in_featuresself.out_features = out_featuresself.dropout = dropoutself.alpha = alphaself.concat = concatself.W = nn.Linear(in_features, out_features)self.a = nn.Linear(2*out_features, 1)def forward(self, h, adj):Wh = self.W(h)  # W*hN = h.size()[0]  # Number of nodesa_input = torch.cat([Wh.repeat(1, N).view(N*N, -1), Wh.repeat(N, 1)], dim=1).view(N, -1, 2*self.out_features)e = F.leaky_relu(self.a(a_input).squeeze(2), negative_slope=self.alpha)zero_vec = -9e15*torch.ones_like(e)attention = torch.where(adj > 0, e, zero_vec)attention = F.softmax(attention, dim=1)attention = F.dropout(attention, p=self.dropout, training=self.training)h_prime = torch.matmul(attention, Wh)if self.concat:return F.elu(h_prime)else:return h_primeclass GAT(nn.Module):def __init__(self, nfeat, nhid, nclass, dropout, alpha, nheads):super(GAT, self).__init__()self.dropout = dropoutself.hidden = nn.ModuleList([GATLayer(nfeat, nhid, dropout, alpha, concat=True) for _ in range(nheads)])self.out_att = GATLayer(nhid*nheads, nclass, dropout, alpha, concat=False)def forward(self, x, adj):x = F.dropout(x, self.dropout, training=self.training)x = torch.cat([att(x, adj) for att in self.hidden], dim=1)x = F.dropout(x, self.dropout, training=self.training)x = F.sigmoid(self.out_att(x, adj))return F.log_softmax(x, dim=1)# 创建示例数据和邻接矩阵
adj = torch.tensor([[0, 1, 1, 0],[1, 0, 1, 1],[1, 1, 0, 1],[0, 1, 1, 0]])  # 邻接矩阵
features = torch.randn(4, 5)  # 特征矩阵# 创建GAT模型
model = GAT(nfeat=5, nhid=8, nclass=2, dropout=0.6, alpha=0.2, nheads=2)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)# 训练模型
for epoch in range(100):optimizer.zero_grad()output = model(features, adj)# 假设这里有标签数据yy = torch.LongTensor([0, 1, 0, 1])  # 标签loss = criterion(output, y)loss.backward()optimizer.step()# 测试模型
output = model(features, adj)
_, predictions = output.max(dim=1)
correct = (predictions == y).sum().item()
accuracy = correct / len(y)
print("准确率:", accuracy)

这篇关于Python GNN图神经网络代码实战;GAT代码模版,简单套用,易于修改和提升,图注意力机制代码实战的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1027103

相关文章

Spring Boot + MyBatis Plus 高效开发实战从入门到进阶优化(推荐)

《SpringBoot+MyBatisPlus高效开发实战从入门到进阶优化(推荐)》本文将详细介绍SpringBoot+MyBatisPlus的完整开发流程,并深入剖析分页查询、批量操作、动... 目录Spring Boot + MyBATis Plus 高效开发实战:从入门到进阶优化1. MyBatis

MyBatis 动态 SQL 优化之标签的实战与技巧(常见用法)

《MyBatis动态SQL优化之标签的实战与技巧(常见用法)》本文通过详细的示例和实际应用场景,介绍了如何有效利用这些标签来优化MyBatis配置,提升开发效率,确保SQL的高效执行和安全性,感... 目录动态SQL详解一、动态SQL的核心概念1.1 什么是动态SQL?1.2 动态SQL的优点1.3 动态S

Python基于wxPython和FFmpeg开发一个视频标签工具

《Python基于wxPython和FFmpeg开发一个视频标签工具》在当今数字媒体时代,视频内容的管理和标记变得越来越重要,无论是研究人员需要对实验视频进行时间点标记,还是个人用户希望对家庭视频进行... 目录引言1. 应用概述2. 技术栈分析2.1 核心库和模块2.2 wxpython作为GUI选择的优

Mysql表的简单操作(基本技能)

《Mysql表的简单操作(基本技能)》在数据库中,表的操作主要包括表的创建、查看、修改、删除等,了解如何操作这些表是数据库管理和开发的基本技能,本文给大家介绍Mysql表的简单操作,感兴趣的朋友一起看... 目录3.1 创建表 3.2 查看表结构3.3 修改表3.4 实践案例:修改表在数据库中,表的操作主要

Pandas使用SQLite3实战

《Pandas使用SQLite3实战》本文主要介绍了Pandas使用SQLite3实战,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学... 目录1 环境准备2 从 SQLite3VlfrWQzgt 读取数据到 DataFrame基础用法:读

Spring Boot 3.4.3 基于 Spring WebFlux 实现 SSE 功能(代码示例)

《SpringBoot3.4.3基于SpringWebFlux实现SSE功能(代码示例)》SpringBoot3.4.3结合SpringWebFlux实现SSE功能,为实时数据推送提供... 目录1. SSE 简介1.1 什么是 SSE?1.2 SSE 的优点1.3 适用场景2. Spring WebFlu

java之Objects.nonNull用法代码解读

《java之Objects.nonNull用法代码解读》:本文主要介绍java之Objects.nonNull用法代码,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐... 目录Java之Objects.nonwww.chinasem.cnNull用法代码Objects.nonN

Python如何使用__slots__实现节省内存和性能优化

《Python如何使用__slots__实现节省内存和性能优化》你有想过,一个小小的__slots__能让你的Python类内存消耗直接减半吗,没错,今天咱们要聊的就是这个让人眼前一亮的技巧,感兴趣的... 目录背景:内存吃得满满的类__slots__:你的内存管理小助手举个大概的例子:看看效果如何?1.

Python+PyQt5实现多屏幕协同播放功能

《Python+PyQt5实现多屏幕协同播放功能》在现代会议展示、数字广告、展览展示等场景中,多屏幕协同播放已成为刚需,下面我们就来看看如何利用Python和PyQt5开发一套功能强大的跨屏播控系统吧... 目录一、项目概述:突破传统播放限制二、核心技术解析2.1 多屏管理机制2.2 播放引擎设计2.3 专

Python中随机休眠技术原理与应用详解

《Python中随机休眠技术原理与应用详解》在编程中,让程序暂停执行特定时间是常见需求,当需要引入不确定性时,随机休眠就成为关键技巧,下面我们就来看看Python中随机休眠技术的具体实现与应用吧... 目录引言一、实现原理与基础方法1.1 核心函数解析1.2 基础实现模板1.3 整数版实现二、典型应用场景2