GAT学习:PyG实现multi-head GAT(二)

2024-02-01 08:18
文章标签 实现 学习 head multi pyg gat

本文主要是介绍GAT学习:PyG实现multi-head GAT(二),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

PyG实现GAT网络

  • 预备知识
  • 代码分析
    • GAT

接上篇学习笔记GAT学习:PyG实现GAT(图注意力神经网络)网络(一)为了使得Attention的效果更好,所以加入multi-head attention。画个图说明multi-head attention的工作原理。
在这里插入图片描述
其实就相当于并联了head_num个attention后,将每个attention层的输出特征拼接起来,然后再输入一个attenion层得到输出结果。

预备知识

关于GAT的原理等知识,参考我的上篇博客:PyG实现GAT(图注意力神经网络)网络(一)

代码分析

import torch
import math
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops,degree
from torch_geometric.datasets import Planetoid
import ssl
import torch.nn.functional as Fclass GAL(MessagePassing):def __init__(self,in_features,out_featrues):super(GAL,self).__init__(aggr='add')self.a = torch.nn.Parameter(torch.zeros(size=(2*out_featrues, 1)))torch.nn.init.xavier_uniform_(self.a.data, gain=1.414)  # 初始化# 定义leakyrelu激活函数self.leakyrelu = torch.nn.LeakyReLU()self.linear=torch.nn.Linear(in_features,out_featrues)def forward(self,x,edge_index):x=self.linear(x)N=x.size()[0]row,col=edge_indexa_input = torch.cat([x[row], x[col]], dim=1)# [N, N, 1] => [N, N] 图注意力的相关系数(未归一化)temp=torch.mm(a_input,self.a).squeeze()e = self.leakyrelu(temp)#e_all为同一个节点与其全部邻居的计算的分数的和,用于计算归一化softmaxe_all=torch.zeros(x.size()[0])count = 0for i in col:e_all[i]+=e[count]count=count+1for i in range(len(e)):e[i]=math.exp(e[i])/math.exp(e_all[col[i]])return self.propagate(edge_index,x=x,norm=e)def message(self, x_j, norm):return norm.view(-1, 1) * x_jclass GAT(torch.nn.Module):def __init__(self, in_features, hid_features, out_features, n_heads):"""n_heads 表示有几个GAL层,最后进行拼接在一起,类似self-attention从不同的子空间进行抽取特征。"""super(GAT, self).__init__()# 定义multi-head的图注意力层self.attentions = [GAL(in_features, hid_features) for _ inrange(n_heads)]# 输出层,也通过图注意力层来实现,可实现分类、预测等功能self.out_att = GAL(hid_features * n_heads, out_features)def forward(self, x, edge_index):# 将每个head得到的x特征进行拼接x = torch.cat([att(x, edge_index) for att in self.attentions], dim=1)print('x.size after cat',x.size())x = F.elu(self.out_att(x,edge_index))  # 输出并激活print('x.size after elu',x.size())return F.log_softmax(x, dim=1)  # log_softmax速度变快,保持数值稳定class Net(torch.nn.Module):def __init__(self):super(Net, self).__init__()self.gat = GAT(dataset.num_node_features,16,7,4)def forward(self, data):x, edge_index = data.x, data.edge_indexx = F.dropout(x, training=self.training)x = self.gat(x, edge_index)print('X_GAT',x.size())return F.log_softmax(x, dim=1)ssl._create_default_https_context = ssl._create_unverified_context
dataset = Planetoid(root='Cora', name='Cora')
x=dataset[0].x
edge_index=dataset[0].edge_index
model=Net()
data=dataset[0]
out=Net()(data)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
model.train()
for epoch in range(2):optimizer.zero_grad()out = model(data)loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])loss.backward()optimizer.step()
model.eval()
_, pred = model(data).max(dim=1)
correct = int(pred[data.test_mask].eq(data.y[data.test_mask]).sum().item())
acc = correct/int(data.test_mask.sum())
print('Accuracy:{:.4f}'.format(acc))
>>>Accuracy:0.1930

GAT

class GAT(torch.nn.Module):def __init__(self, in_features, hid_features, out_features, n_heads):"""n_heads 表示有几个GAL层,最后进行拼接在一起,类似self-attention从不同的子空间进行抽取特征。"""super(GAT, self).__init__()# 定义multi-head的图注意力层self.attentions = [GAL(in_features, hid_features) for _ inrange(n_heads)]# 输出层,也通过图注意力层来实现,可实现分类、预测等功能self.out_att = GAL(hid_features * n_heads, out_features)def forward(self, x, edge_index):# 将每个head得到的x特征进行拼接x = torch.cat([att(x, edge_index) for att in self.attentions], dim=1)print('x.size after cat',x.size())x = F.elu(self.out_att(x,edge_index))  # 输出并激活print('x.size after elu',x.size())return F.log_softmax(x, dim=1)  # log_softmax速度变快,保持数值稳定
>>>x.size after cat torch.Size([2708, 64])
x.size after elu torch.Size([2708, 7])
x.size after cat torch.Size([2708, 64])
x.size after elu torch.Size([2708, 7])
x.size after cat torch.Size([2708, 64])
x.size after elu torch.Size([2708, 7])
x.size after cat torch.Size([2708, 64])
x.size after elu torch.Size([2708, 7])

这篇关于GAT学习:PyG实现multi-head GAT(二)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/666653

相关文章

HarmonyOS学习(七)——UI(五)常用布局总结

自适应布局 1.1、线性布局(LinearLayout) 通过线性容器Row和Column实现线性布局。Column容器内的子组件按照垂直方向排列,Row组件中的子组件按照水平方向排列。 属性说明space通过space参数设置主轴上子组件的间距,达到各子组件在排列上的等间距效果alignItems设置子组件在交叉轴上的对齐方式,且在各类尺寸屏幕上表现一致,其中交叉轴为垂直时,取值为Vert

Ilya-AI分享的他在OpenAI学习到的15个提示工程技巧

Ilya(不是本人,claude AI)在社交媒体上分享了他在OpenAI学习到的15个Prompt撰写技巧。 以下是详细的内容: 提示精确化:在编写提示时,力求表达清晰准确。清楚地阐述任务需求和概念定义至关重要。例:不用"分析文本",而用"判断这段话的情感倾向:积极、消极还是中性"。 快速迭代:善于快速连续调整提示。熟练的提示工程师能够灵活地进行多轮优化。例:从"总结文章"到"用

【前端学习】AntV G6-08 深入图形与图形分组、自定义节点、节点动画(下)

【课程链接】 AntV G6:深入图形与图形分组、自定义节点、节点动画(下)_哔哩哔哩_bilibili 本章十吾老师讲解了一个复杂的自定义节点中,应该怎样去计算和绘制图形,如何给一个图形制作不间断的动画,以及在鼠标事件之后产生动画。(有点难,需要好好理解) <!DOCTYPE html><html><head><meta charset="UTF-8"><title>06

学习hash总结

2014/1/29/   最近刚开始学hash,名字很陌生,但是hash的思想却很熟悉,以前早就做过此类的题,但是不知道这就是hash思想而已,说白了hash就是一个映射,往往灵活利用数组的下标来实现算法,hash的作用:1、判重;2、统计次数;

hdu1043(八数码问题,广搜 + hash(实现状态压缩) )

利用康拓展开将一个排列映射成一个自然数,然后就变成了普通的广搜题。 #include<iostream>#include<algorithm>#include<string>#include<stack>#include<queue>#include<map>#include<stdio.h>#include<stdlib.h>#include<ctype.h>#inclu

【C++】_list常用方法解析及模拟实现

相信自己的力量,只要对自己始终保持信心,尽自己最大努力去完成任何事,就算事情最终结果是失败了,努力了也不留遗憾。💓💓💓 目录   ✨说在前面 🍋知识点一:什么是list? •🌰1.list的定义 •🌰2.list的基本特性 •🌰3.常用接口介绍 🍋知识点二:list常用接口 •🌰1.默认成员函数 🔥构造函数(⭐) 🔥析构函数 •🌰2.list对象

【Prometheus】PromQL向量匹配实现不同标签的向量数据进行运算

✨✨ 欢迎大家来到景天科技苑✨✨ 🎈🎈 养成好习惯,先赞后看哦~🎈🎈 🏆 作者简介:景天科技苑 🏆《头衔》:大厂架构师,华为云开发者社区专家博主,阿里云开发者社区专家博主,CSDN全栈领域优质创作者,掘金优秀博主,51CTO博客专家等。 🏆《博客》:Python全栈,前后端开发,小程序开发,人工智能,js逆向,App逆向,网络系统安全,数据分析,Django,fastapi

让树莓派智能语音助手实现定时提醒功能

最初的时候是想直接在rasa 的chatbot上实现,因为rasa本身是带有remindschedule模块的。不过经过一番折腾后,忽然发现,chatbot上实现的定时,语音助手不一定会有响应。因为,我目前语音助手的代码设置了长时间无应答会结束对话,这样一来,chatbot定时提醒的触发就不会被语音助手获悉。那怎么让语音助手也具有定时提醒功能呢? 我最后选择的方法是用threading.Time

Android实现任意版本设置默认的锁屏壁纸和桌面壁纸(两张壁纸可不一致)

客户有些需求需要设置默认壁纸和锁屏壁纸  在默认情况下 这两个壁纸是相同的  如果需要默认的锁屏壁纸和桌面壁纸不一样 需要额外修改 Android13实现 替换默认桌面壁纸: 将图片文件替换frameworks/base/core/res/res/drawable-nodpi/default_wallpaper.*  (注意不能是bmp格式) 替换默认锁屏壁纸: 将图片资源放入vendo

C#实战|大乐透选号器[6]:实现实时显示已选择的红蓝球数量

哈喽,你好啊,我是雷工。 关于大乐透选号器在前面已经记录了5篇笔记,这是第6篇; 接下来实现实时显示当前选中红球数量,蓝球数量; 以下为练习笔记。 01 效果演示 当选择和取消选择红球或蓝球时,在对应的位置显示实时已选择的红球、蓝球的数量; 02 标签名称 分别设置Label标签名称为:lblRedCount、lblBlueCount