GCN学习:Pytorch-Geometric教程(二)

2024-02-01 08:18

本文主要是介绍GCN学习:Pytorch-Geometric教程(二),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

PyG教程二

  • 数据转换
  • GCN网络

数据转换

PyTorch Geometric带有自己的变换,该变换期望将Data对象作为输入并返回一个新的变换后的Data对象。 可以使用torch_geometric.transforms.Compose将变换链接在一起,并在将处理后的数据集保存到磁盘之前(pre_transform)或访问数据集中的图形之前(transform)应用变换。
让我们看一个示例,其中我们对ShapeNet数据集(包含17,000个3D形状点clouds和来自16个形状类别的每个点标签)应用变换。

from torch_geometric.datasets import ShapeNet
dataset = ShapeNet(root='/tmp/ShapeNet', categories=['Airplane'])
print(dataset[0])
>>> Data(pos=[2518, 3], y=[2518])

通过transform将其从点云转化成图。

import torch_geometric.transforms as T
from torch_geometric.datasets import ShapeNet
dataset = ShapeNet(root='ShapeNet', categories=['Airplane'],pre_transform=T.KNNGraph(k=6))
print(dataset[0])
>>> Data(edge_index=[2, 15108], pos=[2518, 3], y=[2518])

在将数据保存到磁盘之前,我们使用pre_transform进行了转换(从而缩短了加载时间)。 请注意,下一次初始化数据集时,即使我们不传递任何变换,也将已经包含图的边。

GCN网络

我们建立如下一个GCN网络
在这里插入图片描述先获取数据集:

from torch_geometric.datasets import Planetoid
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
dataset = Planetoid(root='Cora', name='Cora')
print(dataset)
print(dataset.num_node_features)
print(dataset.num_classes)
>>>Cora()
>>>1433
>>>7

Cora是一个机器学习论文数据集,其中共有7个类别(num_classes:基于案例、遗传算法、 神经网络、概率方法、强化学习 、规则学习、理论。整个数据集中共有2708篇论文,在词干堵塞和去除词尾后,只剩下1433个独特的单词(num_node_features),文档频率小于10的所有单词都被删除。
这里我们并不需要使用dataloader和transform。

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
class Net(torch.nn.module):def __init__(self):super(Net, self).__init__()#GCNConv的两个参数为input channel size和Output channel size#conv1将每个顶点的1433个特征压缩到16个特征值#conv2根据之前得到的16个特征值将其再压缩为7self.conv1=GCNConv(dataset.num_node_features, 16)self.conv2=GCNConv(16, dataset.num_classes)def forward(self, data):x, edge_index=data.x, data.edge_indexx=self.conv1(x, edge_index)x=F.relu(x)#dropout用于降低过拟合情况x=F.dropout((x, training = self.training))x=self.conv2(x, edge_index)#dim=0对一列所有元素的进行softmax运算#dim=1对一行所有元素的进行softmax运算return F.log_softmax(x,dim=1)
class GCNConv(in_channels: int, out_channels: int, improved: bool = False, cached: bool = False, add_self_loops: bool = True, normalize: bool = True, bias: bool = True, **kwargs)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
data = dataset[0].to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)model.train()
for epoch in range(200):optimizer.zero_grad()out = model(data)#在训练集上计算loss,out为图在gcn网络中的计算结果,data.y即7类的概率大小loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])loss.backward()optimizer.step()
#计算准确率
model.eval()
#选取7种类别中概率最大的类别为预测的节点类别
_, pred = model(data).max(dim=1)
correct = int(pred[data.test_mask].eq(data.y[data.test_mask]).sum().item())
acc = correct/int(data.test_mask.sum())
print('Accuracy:{:.4f}'.format(acc))
>>> Accuracy: 0.8150

这篇关于GCN学习:Pytorch-Geometric教程(二)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/666646

相关文章

Windows环境下解决Matplotlib中文字体显示问题的详细教程

《Windows环境下解决Matplotlib中文字体显示问题的详细教程》本文详细介绍了在Windows下解决Matplotlib中文显示问题的方法,包括安装字体、更新缓存、配置文件设置及编码調整,并... 目录引言问题分析解决方案详解1. 检查系统已安装字体2. 手动添加中文字体(以SimHei为例)步骤

Java JDK1.8 安装和环境配置教程详解

《JavaJDK1.8安装和环境配置教程详解》文章简要介绍了JDK1.8的安装流程,包括官网下载对应系统版本、安装时选择非系统盘路径、配置JAVA_HOME、CLASSPATH和Path环境变量,... 目录1.下载JDK2.安装JDK3.配置环境变量4.检验JDK官网下载地址:Java Downloads

使用Docker构建Python Flask程序的详细教程

《使用Docker构建PythonFlask程序的详细教程》在当今的软件开发领域,容器化技术正变得越来越流行,而Docker无疑是其中的佼佼者,本文我们就来聊聊如何使用Docker构建一个简单的Py... 目录引言一、准备工作二、创建 Flask 应用程序三、创建 dockerfile四、构建 Docker

深度解析Spring AOP @Aspect 原理、实战与最佳实践教程

《深度解析SpringAOP@Aspect原理、实战与最佳实践教程》文章系统讲解了SpringAOP核心概念、实现方式及原理,涵盖横切关注点分离、代理机制(JDK/CGLIB)、切入点类型、性能... 目录1. @ASPect 核心概念1.1 AOP 编程范式1.2 @Aspect 关键特性2. 完整代码实

Java Web实现类似Excel表格锁定功能实战教程

《JavaWeb实现类似Excel表格锁定功能实战教程》本文将详细介绍通过创建特定div元素并利用CSS布局和JavaScript事件监听来实现类似Excel的锁定行和列效果的方法,感兴趣的朋友跟随... 目录1. 模拟Excel表格锁定功能2. 创建3个div元素实现表格锁定2.1 div元素布局设计2.

SpringBoot连接Redis集群教程

《SpringBoot连接Redis集群教程》:本文主要介绍SpringBoot连接Redis集群教程,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录1. 依赖2. 修改配置文件3. 创建RedisClusterConfig4. 测试总结1. 依赖 <de

Nexus安装和启动的实现教程

《Nexus安装和启动的实现教程》:本文主要介绍Nexus安装和启动的实现教程,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、Nexus下载二、Nexus安装和启动三、关闭Nexus总结一、Nexus下载官方下载链接:DownloadWindows系统根

Go学习记录之runtime包深入解析

《Go学习记录之runtime包深入解析》Go语言runtime包管理运行时环境,涵盖goroutine调度、内存分配、垃圾回收、类型信息等核心功能,:本文主要介绍Go学习记录之runtime包的... 目录前言:一、runtime包内容学习1、作用:① Goroutine和并发控制:② 垃圾回收:③ 栈和

CnPlugin是PL/SQL Developer工具插件使用教程

《CnPlugin是PL/SQLDeveloper工具插件使用教程》:本文主要介绍CnPlugin是PL/SQLDeveloper工具插件使用教程,具有很好的参考价值,希望对大家有所帮助,如有错... 目录PL/SQL Developer工具插件使用安装拷贝文件配置总结PL/SQL Developer工具插

Java中的登录技术保姆级详细教程

《Java中的登录技术保姆级详细教程》:本文主要介绍Java中登录技术保姆级详细教程的相关资料,在Java中我们可以使用各种技术和框架来实现这些功能,文中通过代码介绍的非常详细,需要的朋友可以参考... 目录1.登录思路2.登录标记1.会话技术2.会话跟踪1.Cookie技术2.Session技术3.令牌技