GraphSAGE 到底在训练什么? 图上的Mini-Batch 是怎么训练的 ?

2023-12-13 00:52

本文主要是介绍GraphSAGE 到底在训练什么? 图上的Mini-Batch 是怎么训练的 ?,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1. 一个端到端的 同构图(Cora数据集)节点分类代码:

import argparseimport dgl
import dgl.nn as dglnnimport torch
import torch.nn as nn
import torch.nn.functional as F
from dgl import AddSelfLoop
from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDatasetclass SAGE(nn.Module):def __init__(self, in_size, hid_size, out_size):super().__init__()self.layers = nn.ModuleList()# two-layer GraphSAGE-meanself.layers.append(dglnn.SAGEConv(in_size, hid_size, "gcn"))self.layers.append(dglnn.SAGEConv(hid_size, out_size, "gcn"))self.dropout = nn.Dropout(0.5)def forward(self, graph, x):h = self.dropout(x)for l, layer in enumerate(self.layers):h = layer(graph, h)if l != len(self.layers) - 1:h = F.relu(h)h = self.dropout(h)return hdef evaluate(g, features, labels, mask, model):model.eval()with torch.no_grad():logits = model(g, features)logits = logits[mask]labels = labels[mask]_, indices = torch.max(logits, dim=1)correct = torch.sum(indices == labels)return correct.item() * 1.0 / len(labels)def train(g, features, labels, masks, model):# define train/val samples, loss function and optimizertrain_mask, val_mask = masksloss_fcn = nn.CrossEntropyLoss()optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)# training loopfor epoch in range(200):model.train()logits = model(g, features)loss = loss_fcn(logits[train_mask], labels[train_mask])optimizer.zero_grad()loss.backward()optimizer.step()acc = evaluate(g, features, labels, val_mask, model)print("Epoch {:05d} | Loss {:.4f} | Accuracy {:.4f} ".format(epoch, loss.item(), acc))if __name__ == "__main__":parser = argparse.ArgumentParser(description="GraphSAGE")parser.add_argument("--dataset",type=str,default="cora",help="Dataset name ('cora', 'citeseer', 'pubmed')",)parser.add_argument("--dt",type=str,default="float",help="data type(float, bfloat16)",)args = parser.parse_args()print(f"Training with DGL built-in GraphSage module")# load and preprocess datasettransform = (AddSelfLoop())  # by default, it will first remove self-loops to prevent duplicationif args.dataset == "cora":data = CoraGraphDataset(transform=transform)elif args.dataset == "citeseer":data = CiteseerGraphDataset(transform=transform)elif args.dataset == "pubmed":data = PubmedGraphDataset(transform=transform)else:raise ValueError("Unknown dataset: {}".format(args.dataset))g = data[0]device = torch.device("cuda" if torch.cuda.is_available() else "cpu")g = g.int().to(device)features = g.ndata["feat"]labels = g.ndata["label"]masks = g.ndata["train_mask"], g.ndata["val_mask"]# create GraphSAGE modelin_size = features.shape[1]out_size = data.num_classesmodel = SAGE(in_size, 16, out_size).to(device)# convert model and graph to bfloat16 if neededif args.dt == "bfloat16":g = dgl.to_bfloat16(g)features = features.to(dtype=torch.bfloat16)model = model.to(dtype=torch.bfloat16)# model trainingprint("Training...")train(g, features, labels, masks, model)# test the modelprint("Testing...")acc = evaluate(g, features, labels, g.ndata["test_mask"], model)print("Test accuracy {:.4f}".format(acc))

2. GraphSAGE的实现 : SAGEConv 类:

我们先来介绍一下DGL对GraphSAGE这个模型的实现:SAGEConv() 在三方库的下述位置:

这篇关于GraphSAGE 到底在训练什么? 图上的Mini-Batch 是怎么训练的 ?的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/486576

相关文章

Ubuntu 怎么启用 Universe 和 Multiverse 软件源?

《Ubuntu怎么启用Universe和Multiverse软件源?》在Ubuntu中,软件源是用于获取和安装软件的服务器,通过设置和管理软件源,您可以确保系统能够从可靠的来源获取最新的软件... Ubuntu 是一款广受认可且声誉良好的开源操作系统,允许用户通过其庞大的软件包来定制和增强计算体验。这些软件

Ubuntu 24.04 LTS怎么关闭 Ubuntu Pro 更新提示弹窗?

《Ubuntu24.04LTS怎么关闭UbuntuPro更新提示弹窗?》Ubuntu每次开机都会弹窗提示安全更新,设置里最多只能取消自动下载,自动更新,但无法做到直接让自动更新的弹窗不出现,... 如果你正在使用 Ubuntu 24.04 LTS,可能会注意到——在使用「软件更新器」或运行 APT 命令时,

TP-LINK/水星和hasivo交换机怎么选? 三款网管交换机系统功能对比

《TP-LINK/水星和hasivo交换机怎么选?三款网管交换机系统功能对比》今天选了三款都是”8+1″的2.5G网管交换机,分别是TP-LINK水星和hasivo交换机,该怎么选呢?这些交换机功... TP-LINK、水星和hasivo这三台交换机都是”8+1″的2.5G网管交换机,我手里的China编程has

AI绘图怎么变现?想做点副业的小白必看!

在科技飞速发展的今天,AI绘图作为一种新兴技术,不仅改变了艺术创作的方式,也为创作者提供了多种变现途径。本文将详细探讨几种常见的AI绘图变现方式,帮助创作者更好地利用这一技术实现经济收益。 更多实操教程和AI绘画工具,可以扫描下方,免费获取 定制服务:个性化的创意商机 个性化定制 AI绘图技术能够根据用户需求生成个性化的头像、壁纸、插画等作品。例如,姓氏头像在电商平台上非常受欢迎,

W外链微信推广短连接怎么做?

制作微信推广链接的难点分析 一、内容创作难度 制作微信推广链接时,首先需要创作有吸引力的内容。这不仅要求内容本身有趣、有价值,还要能够激起人们的分享欲望。对于许多企业和个人来说,尤其是那些缺乏创意和写作能力的人来说,这是制作微信推广链接的一大难点。 二、精准定位难度 微信用户群体庞大,不同用户的需求和兴趣各异。因此,制作推广链接时需要精准定位目标受众,以便更有效地吸引他们点击并分享链接

电脑桌面文件删除了怎么找回来?别急,快速恢复攻略在此

在日常使用电脑的过程中,我们经常会遇到这样的情况:一不小心,桌面上的某个重要文件被删除了。这时,大多数人可能会感到惊慌失措,不知所措。 其实,不必过于担心,因为有很多方法可以帮助我们找回被删除的桌面文件。下面,就让我们一起来了解一下这些恢复桌面文件的方法吧。 一、使用撤销操作 如果我们刚刚删除了桌面上的文件,并且还没有进行其他操作,那么可以尝试使用撤销操作来恢复文件。在键盘上同时按下“C

webm怎么转换成mp4?这几种方法超多人在用!

webm怎么转换成mp4?WebM作为一种新兴的视频编码格式,近年来逐渐进入大众视野,其背后承载着诸多优势,但同时也伴随着不容忽视的局限性,首要挑战在于其兼容性边界,尽管WebM已广泛适应于众多网站与软件平台,但在特定应用环境或老旧设备上,其兼容难题依旧凸显,为用户体验带来不便,再者,WebM格式的非普适性也体现在编辑流程上,由于它并非行业内的通用标准,编辑过程中可能会遭遇格式不兼容的障碍,导致操

怎么让1台电脑共享给7人同时流畅设计

在当今的创意设计与数字内容生产领域,图形工作站以其强大的计算能力、专业的图形处理能力和稳定的系统性能,成为了众多设计师、动画师、视频编辑师等创意工作者的必备工具。 设计团队面临资源有限,比如只有一台高性能电脑时,如何高效地让七人同时流畅地进行设计工作,便成为了一个亟待解决的问题。 一、硬件升级与配置 1.高性能处理器(CPU):选择多核、高线程的处理器,例如Intel的至强系列或AMD的Ry

MiniGPT-3D, 首个高效的3D点云大语言模型,仅需一张RTX3090显卡,训练一天时间,已开源

项目主页:https://tangyuan96.github.io/minigpt_3d_project_page/ 代码:https://github.com/TangYuan96/MiniGPT-3D 论文:https://arxiv.org/pdf/2405.01413 MiniGPT-3D在多个任务上取得了SoTA,被ACM MM2024接收,只拥有47.8M的可训练参数,在一张RTX

速盾高防cdn是怎么解决网站攻击的?

速盾高防CDN是一种基于云计算技术的网络安全解决方案,可以有效地保护网站免受各种网络攻击的威胁。它通过在全球多个节点部署服务器,将网站内容缓存到这些服务器上,并通过智能路由技术将用户的请求引导到最近的服务器上,以提供更快的访问速度和更好的网络性能。 速盾高防CDN主要采用以下几种方式来解决网站攻击: 分布式拒绝服务攻击(DDoS)防护:DDoS攻击是一种常见的网络攻击手段,攻击者通过向目标网