Node Classification with DGL

2024-02-20 09:30

This article introduces how to perform a node classification task with DGL, and aims to serve as a practical reference for developers working on similar problems.


Overview of the node classification task

Node classification is one of the most common tasks in graph data processing: a model has to predict which category each node belongs to.

Before graph neural networks appeared, methods for node classification fell into two broad categories:

  • Methods that use connectivity alone (e.g., DeepWalk or node2vec)
  • Methods that naively combine connectivity with the nodes' own features

In contrast, a GNN obtains node representations by combining the connectivity of a node's local neighborhood (neighbors in the broad sense, including the node itself) with the features of those nodes.

Kipf et al. formulated this problem as semi-supervised node classification: with only a small fraction of labeled nodes, a graph neural network can accurately predict the classes of the remaining nodes.
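As a brief illustration (not part of the original tutorial), the layer-wise propagation rule of the GCN proposed by Kipf & Welling combines exactly these two ingredients, connectivity and node features:

$$H^{(l+1)} = \sigma\left(\tilde{D}^{-1/2}\,\tilde{A}\,\tilde{D}^{-1/2}\,H^{(l)}\,W^{(l)}\right)$$

where $\tilde{A} = A + I$ is the adjacency matrix with added self-loops, $\tilde{D}$ its degree matrix, $H^{(l)}$ the node representations at layer $l$ (with $H^{(0)}$ being the input features), $W^{(l)}$ a learnable weight matrix, and $\sigma$ a nonlinearity such as ReLU.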

This article shows how to build a GNN for semi-supervised node classification with only a small number of labels on the Cora dataset, a citation network whose nodes are papers and whose edges are citations between papers. The task is to predict the category of a given paper. Each paper node carries a word-count vector as its feature, normalized so that it sums to 1; see Section 5.2 of the paper.

Node classification with DGL

Import the required packages

import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.data
Using backend: pytorch

Load the dataset

dataset = dgl.data.CoraGraphDataset()
  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.

A DGL dataset object can contain one or more graphs. In general, a graph-classification dataset contains many graphs, while edge-prediction and node-classification datasets contain a single graph; the Cora dataset used for node classification here contains exactly one graph.

g = dataset[0]
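As an optional sanity check (not in the original tutorial), the following lines confirm that the dataset holds a single graph and report its size; len(dataset), g.num_nodes(), and g.num_edges() are standard DGL calls:

print(len(dataset))   # 1: Cora is a single-graph dataset
print(g.num_nodes())  # 2708 paper nodes
print(g.num_edges())  # 10556 citation edges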

A DGL graph stores node features and edge features in two dictionary-like attributes, ndata and edata. In the Cora dataset, the graph carries the following node features (other datasets are similar):

  • train_mask: a boolean tensor indicating whether a node belongs to the training set.
  • val_mask: a boolean tensor indicating whether a node belongs to the validation set.
  • test_mask: a boolean tensor indicating whether a node belongs to the test set.
  • label: the category of each node.
  • feat: the node features.

print("Node feature")
print(g.ndata)print("Edge feature")
print(g.edata)
Node feature
{'train_mask': tensor([ True,  True,  True,  ..., False, False, False]), 'label': tensor([3, 4, 4,  ..., 3, 3, 3]), 'val_mask': tensor([False, False, False,  ..., False, False, False]), 'test_mask': tensor([False, False, False,  ...,  True,  True,  True]), 'feat': tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])}
Edge feature
{}
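The masks also make it easy to verify the split sizes reported when the dataset was loaded, and to check the feature normalization mentioned earlier. This is an optional sketch, not part of the original tutorial:

# count the nodes in each split (140 / 500 / 1000 according to the dataset summary)
print(g.ndata["train_mask"].sum().item())
print(g.ndata["val_mask"].sum().item())
print(g.ndata["test_mask"].sum().item())
# if the word-count features are row-normalized, each row should sum to roughly 1
print(g.ndata["feat"].sum(dim=1)[:5])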

Define a graph convolutional network (GCN)

This article builds a two-layer graph convolutional network (GCN). Each layer computes new node representations by aggregating information from a node's neighbors. To build a multi-layer GCN, we simply stack dgl.nn.GraphConv modules, all of which inherit from torch.nn.Module (assuming DGL is using the PyTorch backend).

from dgl.nn import GraphConv

class GCN(nn.Module):
    def __init__(self, in_feats, h_feats, num_class):
        super(GCN, self).__init__()
        self.conv1 = GraphConv(in_feats, h_feats)
        self.conv2 = GraphConv(h_feats, num_class)

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        return h

# set hyperparameters
in_feats = g.ndata["feat"].shape[1]
h_feats = 16
num_class = (torch.max(g.ndata["label"]) + 1).item()  # or: num_class = dataset.num_classes
# create the model
model = GCN(in_feats, h_feats, num_class)

DGL provides implementations of many popular neighbor aggregation modules, each of which can be invoked with a single line of code.
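For instance, assuming we preferred GraphSAGE-style aggregation, the GraphConv layers above could be swapped for dgl.nn.SAGEConv while keeping the same two-layer structure; this is only a sketch of the idea, not part of the original tutorial:

from dgl.nn import SAGEConv

class SAGE(nn.Module):
    def __init__(self, in_feats, h_feats, num_class):
        super(SAGE, self).__init__()
        # 'mean' averages the features of a node's neighbors
        self.conv1 = SAGEConv(in_feats, h_feats, aggregator_type="mean")
        self.conv2 = SAGEConv(h_feats, num_class, aggregator_type="mean")

    def forward(self, g, in_feat):
        h = F.relu(self.conv1(g, in_feat))
        return self.conv2(g, h)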

Train the GCN model

The training procedure for the GCN is the same as for any other PyTorch neural network.

def train(g, model, learning_rate=0.01, num_epoch=100):
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    best_val_acc = 0
    best_test_acc = 0
    features = g.ndata["feat"]
    labels = g.ndata["label"]
    train_mask = g.ndata["train_mask"]
    test_mask = g.ndata["test_mask"]
    val_mask = g.ndata["val_mask"]
    for epoch in range(num_epoch):
        result = model(g, features)
        pred = result.argmax(1)
        loss = F.cross_entropy(result[train_mask], labels[train_mask])
        train_acc = (pred[train_mask] == labels[train_mask]).float().mean()
        val_acc = (pred[val_mask] == labels[val_mask]).float().mean()
        test_acc = (pred[test_mask] == labels[test_mask]).float().mean()
        if best_val_acc < val_acc:
            best_val_acc, best_test_acc = val_acc, test_acc
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if epoch % 5 == 0:
            print('In epoch {}, loss: {}, val acc: {:.3f} (best {:.3f}), test acc: {:.3f} (best {:.3f})'.format(
                epoch, loss, val_acc, best_val_acc, test_acc, best_test_acc))

if __name__ == "__main__":
    train(g, model, num_epoch=200, learning_rate=0.002)
In epoch 0, loss: 1.0601081612549024e-06, val acc: 0.764 (best 0.764), test acc: 0.764 (best 0.764)
In epoch 5, loss: 9.979492006095825e-07, val acc: 0.760 (best 0.764), test acc: 0.764 (best 0.764)
In epoch 10, loss: 9.494142432231456e-07, val acc: 0.762 (best 0.764), test acc: 0.764 (best 0.764)
In epoch 15, loss: 9.017308570946625e-07, val acc: 0.764 (best 0.764), test acc: 0.764 (best 0.764)
In epoch 20, loss: 8.557504429518303e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 25, loss: 8.157304023370671e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 30, loss: 7.71452903336467e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 35, loss: 7.322842634494009e-07, val acc: 0.764 (best 0.764), test acc: 0.764 (best 0.764)
In epoch 40, loss: 6.948185955479858e-07, val acc: 0.764 (best 0.764), test acc: 0.764 (best 0.764)
In epoch 45, loss: 6.624618436035234e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 50, loss: 6.292536340879451e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 55, loss: 6.028573125149705e-07, val acc: 0.764 (best 0.764), test acc: 0.764 (best 0.764)
In epoch 60, loss: 5.807185630146705e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 65, loss: 5.534708407139988e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 70, loss: 5.381440359997214e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 75, loss: 5.117477144267468e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 80, loss: 4.913119937555166e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 85, loss: 4.759851037761109e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 90, loss: 4.5640075541086844e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 95, loss: 4.368164354673354e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 100, loss: 4.2319251747358066e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 105, loss: 4.07865627494175e-07, val acc: 0.764 (best 0.764), test acc: 0.766 (best 0.764)
In epoch 110, loss: 3.993507107225014e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 115, loss: 3.840238207430957e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 120, loss: 3.755089039714221e-07, val acc: 0.762 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 125, loss: 3.6358795796331833e-07, val acc: 0.762 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 130, loss: 3.5081561122751737e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 135, loss: 3.414492084630183e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 140, loss: 3.363402356626466e-07, val acc: 0.762 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 145, loss: 3.218648316760664e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 150, loss: 3.159043444611598e-07, val acc: 0.762 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 155, loss: 3.0568645570383524e-07, val acc: 0.762 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 160, loss: 2.988745109178126e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 165, loss: 2.895080797316041e-07, val acc: 0.766 (best 0.766), test acc: 0.765 (best 0.765)
In epoch 170, loss: 2.792901625525701e-07, val acc: 0.766 (best 0.766), test acc: 0.765 (best 0.765)
In epoch 175, loss: 2.733296753376635e-07, val acc: 0.766 (best 0.766), test acc: 0.765 (best 0.765)
In epoch 180, loss: 2.673692165444663e-07, val acc: 0.764 (best 0.766), test acc: 0.765 (best 0.765)
In epoch 185, loss: 2.614087861729786e-07, val acc: 0.762 (best 0.766), test acc: 0.765 (best 0.765)
In epoch 190, loss: 2.53745326972421e-07, val acc: 0.762 (best 0.766), test acc: 0.765 (best 0.765)
In epoch 195, loss: 2.486363541720493e-07, val acc: 0.762 (best 0.766), test acc: 0.765 (best 0.765)
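Once training has finished, the trained model can be used for inference with a plain forward pass. The snippet below is an optional sketch, not part of the original tutorial, reusing the g and model objects defined above:

# evaluate the trained model on the test split
model.eval()
with torch.no_grad():
    logits = model(g, g.ndata["feat"])
    pred = logits.argmax(dim=1)
    test_mask = g.ndata["test_mask"]
    test_acc = (pred[test_mask] == g.ndata["label"][test_mask]).float().mean()
    print("Test accuracy: {:.3f}".format(test_acc.item()))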

The complete code:

import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.data import CoraGraphDataset
from dgl.nn import GraphConv


class GCN(nn.Module):
    """GCN network"""

    def __init__(self, in_feats, h_feats, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GraphConv(in_feats, h_feats)
        self.conv2 = GraphConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        return h


def train(g, model, num_epoch=100, learning_rate=0.001):
    """train function"""
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    best_val_accurate = 0
    best_test_accurate = 0
    features = g.ndata["feat"]
    labels = g.ndata["label"]
    train_mask = g.ndata["train_mask"]
    test_mask = g.ndata["test_mask"]
    val_mask = g.ndata["val_mask"]
    for e in range(num_epoch):
        # forward
        result = model(g, features)
        # prediction
        pred = result.argmax(dim=1)
        # loss
        loss = F.cross_entropy(result[train_mask], labels[train_mask])
        # compute accuracy
        train_accurate = (pred[train_mask] == labels[train_mask]).float().mean()
        test_accurate = (pred[test_mask] == labels[test_mask]).float().mean()
        val_accurate = (pred[val_mask] == labels[val_mask]).float().mean()
        if best_val_accurate < val_accurate:
            best_val_accurate, best_test_accurate = val_accurate, test_accurate
        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if e % 5 == 0:
            print('In epoch {}, loss: {:.3f}, val acc: {:.3f} (best {:.3f}), test acc: {:.3f} (best {:.3f})'.format(
                e, loss, val_accurate, best_val_accurate, test_accurate, best_test_accurate))


def main():
    dataset = CoraGraphDataset()
    g = dataset[0]
    in_feats = g.ndata["feat"].shape[1]
    h_feats = 16
    num_classes = dataset.num_classes
    model = GCN(in_feats, h_feats, num_classes)
    train(g, model)


if __name__ == "__main__":
    main()
  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.
In epoch 0, loss: 1.946, val acc: 0.104 (best 0.104), test acc: 0.114 (best 0.114)
In epoch 5, loss: 1.942, val acc: 0.276 (best 0.276), test acc: 0.314 (best 0.314)
In epoch 10, loss: 1.936, val acc: 0.452 (best 0.452), test acc: 0.452 (best 0.452)
In epoch 15, loss: 1.929, val acc: 0.546 (best 0.546), test acc: 0.549 (best 0.549)
In epoch 20, loss: 1.921, val acc: 0.612 (best 0.612), test acc: 0.631 (best 0.631)
In epoch 25, loss: 1.913, val acc: 0.640 (best 0.640), test acc: 0.647 (best 0.647)
In epoch 30, loss: 1.904, val acc: 0.654 (best 0.654), test acc: 0.670 (best 0.670)
In epoch 35, loss: 1.895, val acc: 0.684 (best 0.684), test acc: 0.692 (best 0.692)
In epoch 40, loss: 1.886, val acc: 0.690 (best 0.692), test acc: 0.695 (best 0.693)
In epoch 45, loss: 1.876, val acc: 0.700 (best 0.700), test acc: 0.694 (best 0.694)
In epoch 50, loss: 1.866, val acc: 0.706 (best 0.708), test acc: 0.701 (best 0.699)
In epoch 55, loss: 1.855, val acc: 0.710 (best 0.710), test acc: 0.698 (best 0.698)
In epoch 60, loss: 1.844, val acc: 0.708 (best 0.712), test acc: 0.702 (best 0.699)
In epoch 65, loss: 1.833, val acc: 0.704 (best 0.712), test acc: 0.702 (best 0.699)
In epoch 70, loss: 1.821, val acc: 0.702 (best 0.712), test acc: 0.704 (best 0.699)
In epoch 75, loss: 1.809, val acc: 0.704 (best 0.712), test acc: 0.705 (best 0.699)
In epoch 80, loss: 1.796, val acc: 0.706 (best 0.712), test acc: 0.704 (best 0.699)
In epoch 85, loss: 1.783, val acc: 0.702 (best 0.712), test acc: 0.706 (best 0.699)
In epoch 90, loss: 1.769, val acc: 0.694 (best 0.712), test acc: 0.703 (best 0.699)
In epoch 95, loss: 1.755, val acc: 0.692 (best 0.712), test acc: 0.706 (best 0.699)

Training on a GPU

To train on a GPU, move both the graph and the model onto the GPU with the to method; everything else is the same as training any other PyTorch model.

g = g.to('cuda')
model = GCN(g.ndata['feat'].shape[1], 16, dataset.num_classes).to('cuda')
train(g, model)
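If the same script should also run on machines without a GPU, a common (optional) pattern is to select the device dynamically with torch.cuda.is_available(); this sketch is not part of the original tutorial:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
g = g.to(device)
model = GCN(g.ndata["feat"].shape[1], 16, dataset.num_classes).to(device)
train(g, model)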

References

Translated and adapted from the DGL tutorial Node Classification with DGL.

