GCN,GraphSAGE 到底在训练什么呢?

2023-12-05 00:44
文章标签 训练 到底 gcn graphsage

本文主要是介绍GCN,GraphSAGE 到底在训练什么呢?,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

根据DGL 来做的,按照DGL 实现来讲述

1. GCN Cora 训练代码:

import osos.environ["DGLBACKEND"] = "pytorch"
import dgl
import dgl.data
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import GraphConvclass GCN(nn.Module):def __init__(self, in_feats, h_feats, num_classes):super(GCN, self).__init__()self.conv1 = GraphConv(in_feats, h_feats)self.conv2 = GraphConv(h_feats, num_classes)def forward(self, g, in_feat):h = self.conv1(g, in_feat)h = F.relu(h)h = self.conv2(g, h)return hdef train(g, model):optimizer = torch.optim.Adam(model.parameters(), lr=0.01)best_val_acc = 0best_test_acc = 0features = g.ndata["feat"]labels = g.ndata["label"]train_mask = g.ndata["train_mask"]val_mask = g.ndata["val_mask"]test_mask = g.ndata["test_mask"]for e in range(100):# Forwardlogits = model(g, features)# Compute predictionpred = logits.argmax(1)# Compute loss# Note that you should only compute the losses of the nodes in the training set.loss = F.cross_entropy(logits[train_mask], labels[train_mask])# Compute accuracy on training/validation/testtrain_acc = (pred[train_mask] == labels[train_mask]).float().mean()val_acc = (pred[val_mask] == labels[val_mask]).float().mean()test_acc = (pred[test_mask] == labels[test_mask]).float().mean()# Save the best validation accuracy and the corresponding test accuracy.if best_val_acc < val_acc:best_val_acc = val_accbest_test_acc = test_acc# Backwardoptimizer.zero_grad()loss.backward()optimizer.step()if e % 5 == 0:print(f"In epoch {e}, loss: {loss:.3f}, val acc: {val_acc:.3f} (best {best_val_acc:.3f}), test acc: {test_acc:.3f} (best {best_test_acc:.3f})")if __name__ == "__main__" :dataset = dgl.data.CoraGraphDataset()# print(f"Number of categories: {dataset.num_classes}")g = dataset[0]g = g.to('cuda')model = GCN(g.ndata["feat"].shape[1], 16, dataset.num_classes).to('cuda')train(g, model)

一些基础python torch.tensor语法概述:

1.  

if __name__ == "__main__" :XXXXXXXXXXXXXX

当我们直接执行这个脚本时,__name__属性被设置为__main__,因此满足if条件,语句块中的代码被调用。
但如果我们将该脚本作为模块导入到另一个脚本中,则__name__属性会被设置为模块的名称(例如"example"),语句块中的代码不会被执行。

2. 

# Compute prediction
pred = logits.argmax(1)    # 返回沿着第一个维度(即维度索引为1)的最大值的索引。# 即,加入有5个样本,每个样本有3个维度的评分,那么就会给出没个样本3中维度评分最高的哪个维度的索引序号

 

3. numpy 关于 tensor 的一个用法:

在DGL 中使用一串 True 或 False 组成的 一维tensor 来标识 这个节点到底是属于 train test val 哪一类

train_mask = g.ndata["train_mask"]
val_mask = g.ndata["val_mask"]
test_mask = g.ndata["test_mask"]

而后,由于对于torch中的tensor来说:

就可以:select_label_tensor = labels[train_mask] 了

import torch# 定义一个Tensor
tensor = torch.tensor([1, 2, 3, 4, 5])# 定义一个布尔数组,选择索引为1和4的元素
mask = torch.tensor([False, True, False, False, True])# 通过布尔索引选择元素
selected_tensor = tensor[mask]print(selected_tensor)  # tensor([2, 5])

顺便,查看一个变量到底是什么类型可以使用 type() 函数:

train_mask = g.ndata["train_mask"]
print(type(train_mask))# 输出为:
# <class 'torch.Tensor'>

这篇关于GCN,GraphSAGE 到底在训练什么呢?的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/455502

相关文章

YOLO v3 训练速度慢的问题

一天一夜出了两个模型,仅仅迭代了200次   原因:编译之前没有将Makefile 文件里的GPU设置为1,编译的是CPU版本,必须训练慢   解决方案: make clean  vim Makefile make   再次训练 速度快了,5分钟迭代了500次

将一维机械振动信号构造为训练集和测试集(Python)

从如下链接中下载轴承数据集。 https://www.sciencedirect.com/science/article/pii/S2352340918314124 import numpy as npimport scipy.io as sioimport matplotlib.pyplot as pltimport statistics as statsimport pandas

PAT-1039 到底买不买(20)(字符串的使用)

题目描述 小红想买些珠子做一串自己喜欢的珠串。卖珠子的摊主有很多串五颜六色的珠串,但是不肯把任何一串拆散了卖。于是小红要你帮忙判断一下,某串珠子里是否包含了全部自己想要的珠子?如果是,那么告诉她有多少多余的珠子;如果不是,那么告诉她缺了多少珠子。为方便起见,我们用[0-9]、[a-z]、[A-Z]范围内的字符来表示颜色。例如,YrR8RrY是小红想做的珠串;那么ppRYYGrrYBR2258可以

6月21日训练 (东北林业大学)(个人题解)

前言:   这次训练是大一大二一起参加的训练,总体来说难度是有的,我和队友在比赛时间内就写出了四道题,之后陆陆续续又补了了三道题,还有一道题看了学长题解后感觉有点超出我的能力范围了,就留给以后的自己吧。话不多说,上正文。 正文:   Problem:A 幸运数字: #include <bits/stdc++.h>using namespace std;int sum,ans;in

国产AI算力训练大模型技术实践

&nbsp;&nbsp; ChatGPT引领AI大模型热潮,国内外模型如雨后春笋,掀起新一轮科技浪潮。然而,国内大模型研发推广亦面临不小挑战。面对机遇与挑战,我们需保持清醒,持续推进技术创新与应用落地。 为应对挑战,我们需从战略高度全面规划大模型的研发与运营,利用我们的制度优势,集中资源攻坚克难。通过加强顶层设计,统一规划,并加大政策与资源的扶持,我们必将推动中国人工智能实现从追赶者到

预训练是什么?

预训练是什么? 图像领域的预训练 在介绍图像领域的预训练之前,我们首先介绍下卷积神经网络(CNN),CNN 一般用于图片分类任务,并且CNN 由多个层级结构组成,不同层学到的图像特征也不同,越浅的层学到的特征越通用(横竖撇捺),越深的层学到的特征和具体任务的关联性越强(人脸-人脸轮廓、汽车-汽车轮廓) 由此,当领导给我们一个任务:阿猫、阿狗、阿虎的图片各十张,然后让我们设计一个深度神经网

http:与https:到底有哪些区别?

http和https使用的是完全不同的连接方式,用的端口也不一样,前者是80,后者是443。http的连接很简单,是无状态的,... HTTPS协议是由SSL+HTTP协议构建的可进行加密传输、身份认证的网络协议要比http协议安全

「Bionano系列」下机数据的BNX文件到底说了什么

最近我拿到了一批Bionano数据,用关键字 “Bionano+组装” 进行检索时,并没有发现任何的教程,所以这应是中文网络世界里第一个Bionano数据分析系列 Bionano技术简单来说,就是给分子加上荧光标记,然后拍照,所以最原始的下机数据就是TIFF格式,但是用户拿到的一般都是经AutoDetect/IrysView 转换过的BNX格式。这篇文章主要就是讲讲BNX格式的具体含义。

本地离线模型搭建指南-LLaMA-Factory训练框架及工具

搭建一个本地中文大语言模型(LLM)涉及多个关键步骤,从选择模型底座,到运行机器和框架,再到具体的架构实现和训练方式。以下是一个详细的指南,帮助你从零开始构建和运行一个中文大语言模型。 本地离线模型搭建指南将按照以下四个部分展开 中文大语言模型底座选择依据本地运行显卡选择RAG架构实现LLaMA-Factory训练框架及工具 4 训练架构及工具 4.1 为什么要使用LLaMA-Factor

ChatGPT原理和训练【 ChatGPT是由OpenAI开发】

本人详解 作者:王文峰,参加过 CSDN 2020年度博客之星,《Java王大师王天师》 公众号:JAVA开发王大师,专注于天道酬勤的 Java 开发问题中国国学、传统文化和代码爱好者的程序人生,期待你的关注和支持!本人外号:神秘小峯 山峯 转载说明:务必注明来源(注明:作者:王文峰哦) ChatGPT原理和训练【 ChatGPT是由OpenAI开发】 学习教程(传送门)1.概述2