GraphSAGE 到底在训练什么? 图上的Mini-Batch 是怎么训练的 ?

2023-12-13 00:52

本文主要是介绍GraphSAGE 到底在训练什么? 图上的Mini-Batch 是怎么训练的 ?,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1. 一个端到端的 同构图(Cora数据集)节点分类代码:

import argparseimport dgl
import dgl.nn as dglnnimport torch
import torch.nn as nn
import torch.nn.functional as F
from dgl import AddSelfLoop
from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDatasetclass SAGE(nn.Module):def __init__(self, in_size, hid_size, out_size):super().__init__()self.layers = nn.ModuleList()# two-layer GraphSAGE-meanself.layers.append(dglnn.SAGEConv(in_size, hid_size, "gcn"))self.layers.append(dglnn.SAGEConv(hid_size, out_size, "gcn"))self.dropout = nn.Dropout(0.5)def forward(self, graph, x):h = self.dropout(x)for l, layer in enumerate(self.layers):h = layer(graph, h)if l != len(self.layers) - 1:h = F.relu(h)h = self.dropout(h)return hdef evaluate(g, features, labels, mask, model):model.eval()with torch.no_grad():logits = model(g, features)logits = logits[mask]labels = labels[mask]_, indices = torch.max(logits, dim=1)correct = torch.sum(indices == labels)return correct.item() * 1.0 / len(labels)def train(g, features, labels, masks, model):# define train/val samples, loss function and optimizertrain_mask, val_mask = masksloss_fcn = nn.CrossEntropyLoss()optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)# training loopfor epoch in range(200):model.train()logits = model(g, features)loss = loss_fcn(logits[train_mask], labels[train_mask])optimizer.zero_grad()loss.backward()optimizer.step()acc = evaluate(g, features, labels, val_mask, model)print("Epoch {:05d} | Loss {:.4f} | Accuracy {:.4f} ".format(epoch, loss.item(), acc))if __name__ == "__main__":parser = argparse.ArgumentParser(description="GraphSAGE")parser.add_argument("--dataset",type=str,default="cora",help="Dataset name ('cora', 'citeseer', 'pubmed')",)parser.add_argument("--dt",type=str,default="float",help="data type(float, bfloat16)",)args = parser.parse_args()print(f"Training with DGL built-in GraphSage module")# load and preprocess datasettransform = (AddSelfLoop())  # by default, it will first remove self-loops to prevent duplicationif args.dataset == "cora":data = CoraGraphDataset(transform=transform)elif args.dataset == "citeseer":data = CiteseerGraphDataset(transform=transform)elif args.dataset == "pubmed":data = PubmedGraphDataset(transform=transform)else:raise ValueError("Unknown dataset: {}".format(args.dataset))g = data[0]device = torch.device("cuda" if torch.cuda.is_available() else "cpu")g = g.int().to(device)features = g.ndata["feat"]labels = g.ndata["label"]masks = g.ndata["train_mask"], g.ndata["val_mask"]# create GraphSAGE modelin_size = features.shape[1]out_size = data.num_classesmodel = SAGE(in_size, 16, out_size).to(device)# convert model and graph to bfloat16 if neededif args.dt == "bfloat16":g = dgl.to_bfloat16(g)features = features.to(dtype=torch.bfloat16)model = model.to(dtype=torch.bfloat16)# model trainingprint("Training...")train(g, features, labels, masks, model)# test the modelprint("Testing...")acc = evaluate(g, features, labels, g.ndata["test_mask"], model)print("Test accuracy {:.4f}".format(acc))

2. GraphSAGE的实现 : SAGEConv 类:

我们先来介绍一下DGL对GraphSAGE这个模型的实现:SAGEConv() 在三方库的下述位置:

这篇关于GraphSAGE 到底在训练什么? 图上的Mini-Batch 是怎么训练的 ?的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/486576

相关文章

MySql死锁怎么排查的方法实现

《MySql死锁怎么排查的方法实现》本文主要介绍了MySql死锁怎么排查的方法实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧... 目录前言一、死锁排查方法1. 查看死锁日志方法 1:启用死锁日志输出方法 2:检查 mysql 错误

Rsnapshot怎么用? 基于Rsync的强大Linux备份工具使用指南

《Rsnapshot怎么用?基于Rsync的强大Linux备份工具使用指南》Rsnapshot不仅可以备份本地文件,还能通过SSH备份远程文件,接下来详细介绍如何安装、配置和使用Rsnaps... Rsnapshot 是一款开源的文件系统快照工具。它结合了 Rsync 和 SSH 的能力,可以帮助你在 li

电脑密码怎么设置? 一文读懂电脑密码的详细指南

《电脑密码怎么设置?一文读懂电脑密码的详细指南》为了保护个人隐私和数据安全,设置电脑密码显得尤为重要,那么,如何在电脑上设置密码呢?详细请看下文介绍... 设置电脑密码是保护个人隐私、数据安全以及系统安全的重要措施,下面以Windows 11系统为例,跟大家分享一下设置电脑密码的具体办php法。Windo

怎么关闭Ubuntu无人值守升级? Ubuntu禁止自动更新的技巧

《怎么关闭Ubuntu无人值守升级?Ubuntu禁止自动更新的技巧》UbuntuLinux系统禁止自动更新的时候,提示“无人值守升级在关机期间,请不要关闭计算机进程”,该怎么解决这个问题?详细请看... 本教程教你如何处理无人值守的升级,即 Ubuntu linux 的自动系统更新。来源:https://

Ubuntu系统怎么安装Warp? 新一代AI 终端神器安装使用方法

《Ubuntu系统怎么安装Warp?新一代AI终端神器安装使用方法》Warp是一款使用Rust开发的现代化AI终端工具,该怎么再Ubuntu系统中安装使用呢?下面我们就来看看详细教程... Warp Terminal 是一款使用 Rust 开发的现代化「AI 终端」工具。最初它只支持 MACOS,但在 20

LinuxMint怎么安装? Linux Mint22下载安装图文教程

《LinuxMint怎么安装?LinuxMint22下载安装图文教程》LinuxMint22发布以后,有很多新功能,很多朋友想要下载并安装,该怎么操作呢?下面我们就来看看详细安装指南... linux Mint 是一款基于 Ubuntu 的流行发行版,凭借其现代、精致、易于使用的特性,深受小伙伴们所喜爱。对

macOS怎么轻松更换App图标? Mac电脑图标更换指南

《macOS怎么轻松更换App图标?Mac电脑图标更换指南》想要给你的Mac电脑按照自己的喜好来更换App图标?其实非常简单,只需要两步就能搞定,下面我来详细讲解一下... 虽然 MACOS 的个性化定制选项已经「缩水」,不如早期版本那么丰富,www.chinasem.cn但我们仍然可以按照自己的喜好来更换

Ubuntu 怎么启用 Universe 和 Multiverse 软件源?

《Ubuntu怎么启用Universe和Multiverse软件源?》在Ubuntu中,软件源是用于获取和安装软件的服务器,通过设置和管理软件源,您可以确保系统能够从可靠的来源获取最新的软件... Ubuntu 是一款广受认可且声誉良好的开源操作系统,允许用户通过其庞大的软件包来定制和增强计算体验。这些软件

Ubuntu 24.04 LTS怎么关闭 Ubuntu Pro 更新提示弹窗?

《Ubuntu24.04LTS怎么关闭UbuntuPro更新提示弹窗?》Ubuntu每次开机都会弹窗提示安全更新,设置里最多只能取消自动下载,自动更新,但无法做到直接让自动更新的弹窗不出现,... 如果你正在使用 Ubuntu 24.04 LTS,可能会注意到——在使用「软件更新器」或运行 APT 命令时,

TP-LINK/水星和hasivo交换机怎么选? 三款网管交换机系统功能对比

《TP-LINK/水星和hasivo交换机怎么选?三款网管交换机系统功能对比》今天选了三款都是”8+1″的2.5G网管交换机,分别是TP-LINK水星和hasivo交换机,该怎么选呢?这些交换机功... TP-LINK、水星和hasivo这三台交换机都是”8+1″的2.5G网管交换机,我手里的China编程has