GCN学习:Pytorch-Geometric教程(二)

2024-02-01 08:18

本文主要是介绍GCN学习:Pytorch-Geometric教程(二),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

PyG教程二

  • 数据转换
  • GCN网络

数据转换

PyTorch Geometric带有自己的变换,该变换期望将Data对象作为输入并返回一个新的变换后的Data对象。 可以使用torch_geometric.transforms.Compose将变换链接在一起,并在将处理后的数据集保存到磁盘之前(pre_transform)或访问数据集中的图形之前(transform)应用变换。
让我们看一个示例,其中我们对ShapeNet数据集(包含17,000个3D形状点clouds和来自16个形状类别的每个点标签)应用变换。

from torch_geometric.datasets import ShapeNet
dataset = ShapeNet(root='/tmp/ShapeNet', categories=['Airplane'])
print(dataset[0])
>>> Data(pos=[2518, 3], y=[2518])

通过transform将其从点云转化成图。

import torch_geometric.transforms as T
from torch_geometric.datasets import ShapeNet
dataset = ShapeNet(root='ShapeNet', categories=['Airplane'],pre_transform=T.KNNGraph(k=6))
print(dataset[0])
>>> Data(edge_index=[2, 15108], pos=[2518, 3], y=[2518])

在将数据保存到磁盘之前,我们使用pre_transform进行了转换(从而缩短了加载时间)。 请注意,下一次初始化数据集时,即使我们不传递任何变换,也将已经包含图的边。

GCN网络

我们建立如下一个GCN网络
在这里插入图片描述先获取数据集:

from torch_geometric.datasets import Planetoid
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
dataset = Planetoid(root='Cora', name='Cora')
print(dataset)
print(dataset.num_node_features)
print(dataset.num_classes)
>>>Cora()
>>>1433
>>>7

Cora是一个机器学习论文数据集,其中共有7个类别(num_classes:基于案例、遗传算法、 神经网络、概率方法、强化学习 、规则学习、理论。整个数据集中共有2708篇论文,在词干堵塞和去除词尾后,只剩下1433个独特的单词(num_node_features),文档频率小于10的所有单词都被删除。
这里我们并不需要使用dataloader和transform。

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
class Net(torch.nn.module):def __init__(self):super(Net, self).__init__()#GCNConv的两个参数为input channel size和Output channel size#conv1将每个顶点的1433个特征压缩到16个特征值#conv2根据之前得到的16个特征值将其再压缩为7self.conv1=GCNConv(dataset.num_node_features, 16)self.conv2=GCNConv(16, dataset.num_classes)def forward(self, data):x, edge_index=data.x, data.edge_indexx=self.conv1(x, edge_index)x=F.relu(x)#dropout用于降低过拟合情况x=F.dropout((x, training = self.training))x=self.conv2(x, edge_index)#dim=0对一列所有元素的进行softmax运算#dim=1对一行所有元素的进行softmax运算return F.log_softmax(x,dim=1)
class GCNConv(in_channels: int, out_channels: int, improved: bool = False, cached: bool = False, add_self_loops: bool = True, normalize: bool = True, bias: bool = True, **kwargs)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
data = dataset[0].to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)model.train()
for epoch in range(200):optimizer.zero_grad()out = model(data)#在训练集上计算loss,out为图在gcn网络中的计算结果,data.y即7类的概率大小loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])loss.backward()optimizer.step()
#计算准确率
model.eval()
#选取7种类别中概率最大的类别为预测的节点类别
_, pred = model(data).max(dim=1)
correct = int(pred[data.test_mask].eq(data.y[data.test_mask]).sum().item())
acc = correct/int(data.test_mask.sum())
print('Accuracy:{:.4f}'.format(acc))
>>> Accuracy: 0.8150

这篇关于GCN学习:Pytorch-Geometric教程(二)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/666646

相关文章

Java使用Curator进行ZooKeeper操作的详细教程

《Java使用Curator进行ZooKeeper操作的详细教程》ApacheCurator是一个基于ZooKeeper的Java客户端库,它极大地简化了使用ZooKeeper的开发工作,在分布式系统... 目录1、简述2、核心功能2.1 CuratorFramework2.2 Recipes3、示例实践3

springboot简单集成Security配置的教程

《springboot简单集成Security配置的教程》:本文主要介绍springboot简单集成Security配置的教程,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,... 目录集成Security安全框架引入依赖编写配置类WebSecurityConfig(自定义资源权限规则

MySQL Workbench 安装教程(保姆级)

《MySQLWorkbench安装教程(保姆级)》MySQLWorkbench是一款强大的数据库设计和管理工具,本文主要介绍了MySQLWorkbench安装教程,文中通过图文介绍的非常详细,对大... 目录前言:详细步骤:一、检查安装的数据库版本二、在官网下载对应的mysql Workbench版本,要是

通过Docker Compose部署MySQL的详细教程

《通过DockerCompose部署MySQL的详细教程》DockerCompose作为Docker官方的容器编排工具,为MySQL数据库部署带来了显著优势,下面小编就来为大家详细介绍一... 目录一、docker Compose 部署 mysql 的优势二、环境准备与基础配置2.1 项目目录结构2.2 基

Linux安装MySQL的教程

《Linux安装MySQL的教程》:本文主要介绍Linux安装MySQL的教程,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录linux安装mysql1.Mysql官网2.我的存放路径3.解压mysql文件到当前目录4.重命名一下5.创建mysql用户组和用户并修

使用PyTorch实现手写数字识别功能

《使用PyTorch实现手写数字识别功能》在人工智能的世界里,计算机视觉是最具魅力的领域之一,通过PyTorch这一强大的深度学习框架,我们将在经典的MNIST数据集上,见证一个神经网络从零开始学会识... 目录当计算机学会“看”数字搭建开发环境MNIST数据集解析1. 认识手写数字数据库2. 数据预处理的

Pytorch微调BERT实现命名实体识别

《Pytorch微调BERT实现命名实体识别》命名实体识别(NER)是自然语言处理(NLP)中的一项关键任务,它涉及识别和分类文本中的关键实体,BERT是一种强大的语言表示模型,在各种NLP任务中显著... 目录环境准备加载预训练BERT模型准备数据集标记与对齐微调 BERT最后总结环境准备在继续之前,确

最新Spring Security实战教程之Spring Security安全框架指南

《最新SpringSecurity实战教程之SpringSecurity安全框架指南》SpringSecurity是Spring生态系统中的核心组件,提供认证、授权和防护机制,以保护应用免受各种安... 目录前言什么是Spring Security?同类框架对比Spring Security典型应用场景传统

最新Spring Security实战教程之表单登录定制到处理逻辑的深度改造(最新推荐)

《最新SpringSecurity实战教程之表单登录定制到处理逻辑的深度改造(最新推荐)》本章节介绍了如何通过SpringSecurity实现从配置自定义登录页面、表单登录处理逻辑的配置,并简单模拟... 目录前言改造准备开始登录页改造自定义用户名密码登陆成功失败跳转问题自定义登出前后端分离适配方案结语前言

Centos环境下Tomcat虚拟主机配置详细教程

《Centos环境下Tomcat虚拟主机配置详细教程》这篇文章主要讲的是在CentOS系统上,如何一步步配置Tomcat的虚拟主机,内容很简单,从目录准备到配置文件修改,再到重启和测试,手把手带你搞定... 目录1. 准备虚拟主机的目录和内容创建目录添加测试文件2. 修改 Tomcat 的 server.X