Node Classification with DGL

2024-02-20 09:30

This article walks through how to perform a node classification task with DGL.


Overview of the Node Classification Task

Node classification is one of the most common tasks on graph data: given a graph, a model must predict the class to which each node belongs.

Before graph neural networks emerged, methods for node classification fell into two broad categories:

  • Methods that use only graph connectivity (e.g., DeepWalk or node2vec)
  • Methods that naively combine connectivity with the nodes' own features

By contrast, GNNs obtain node representations by combining the connectivity and the features of each node's local neighborhood (neighbors in the broad sense, including the node itself).

Kipf et al. formulated node classification as a semi-supervised task: given labels on only a small fraction of nodes, a graph neural network can accurately predict the classes of the remaining nodes.

This article shows how to build a GNN for such semi-supervised node classification on the Cora dataset, a citation network whose nodes are papers and whose edges are citations between papers. The task is to predict the category of a given paper. Each paper node carries a word count vector as its feature, normalized so that its entries sum to 1; see Section 5.2 of the paper.

Performing Node Classification with DGL

Import the required packages

import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.data
Using backend: pytorch

Load the dataset

dataset = dgl.data.CoraGraphDataset()
  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.

A DGL dataset object may contain one or more graphs. In general, datasets for graph-level classification contain multiple graphs, while datasets for edge prediction and node classification contain a single graph; the Cora dataset used here contains exactly one graph.
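We can confirm this directly, since DGL dataset objects support len (a quick check added here for illustration):

print(len(dataset))  # 1, since Cora ships as a single graph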

g = dataset[0]

A DGL graph stores node features and edge features in two dictionary-like attributes, ndata and edata. In the Cora dataset, the graph carries the following node features (other datasets are similar):

  • train_mask: a boolean tensor indicating whether each node is in the training set.
  • val_mask: a boolean tensor indicating whether each node is in the validation set.
  • test_mask: a boolean tensor indicating whether each node is in the test set.
  • label: the ground-truth class of each node.
  • feat: the node features.
print("Node feature")
print(g.ndata)
print("Edge feature")
print(g.edata)
Node feature
{'train_mask': tensor([ True,  True,  True,  ..., False, False, False]),
 'label': tensor([3, 4, 4,  ..., 3, 3, 3]),
 'val_mask': tensor([False, False, False,  ..., False, False, False]),
 'test_mask': tensor([False, False, False,  ...,  True,  True,  True]),
 'feat': tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])}
Edge feature
{}
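As a quick sanity check (a small sketch added here, not part of the original tutorial), we can count the nodes in each split with the boolean masks and verify that each feature vector sums to 1:

# The mask sums should match the dataset summary above: 140 / 500 / 1000.
print(g.ndata["train_mask"].sum().item())
print(g.ndata["val_mask"].sum().item())
print(g.ndata["test_mask"].sum().item())
# Each word-count vector was normalized to sum to 1 (up to float tolerance).
print(torch.allclose(g.ndata["feat"].sum(dim=1), torch.ones(g.num_nodes())))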

Define the Graph Convolutional Network (GCN)

This article builds a two-layer graph convolutional network (GCN). Each layer computes new node representations by aggregating information from neighboring nodes. To build a deeper GCN, we can simply stack dgl.nn.GraphConv modules, all of which inherit from torch.nn.Module (assuming DGL is using the PyTorch backend).

from dgl.nn import GraphConv

class GCN(nn.Module):
    def __init__(self, in_feats, h_feats, num_class):
        super(GCN, self).__init__()
        self.conv1 = GraphConv(in_feats, h_feats)
        self.conv2 = GraphConv(h_feats, num_class)

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        return h

# Set the hyperparameters
in_feats = g.ndata["feat"].shape[1]
h_feats = 16
num_class = (torch.max(g.ndata["label"]) + 1).item()  # or: num_class = dataset.num_classes
# Create the model
model = GCN(in_feats, h_feats, num_class)

DGL provides implementations of many popular neighbor-aggregation modules, each of which can be invoked with a single line of code.
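For instance, swapping in a different aggregation scheme is a one-line change per layer. Below is a hypothetical GraphSAGE variant of the same two-layer network, using dgl.nn.SAGEConv with mean aggregation (a sketch for illustration, not part of the original tutorial):

from dgl.nn import SAGEConv

class SAGE(nn.Module):
    def __init__(self, in_feats, h_feats, num_class):
        super(SAGE, self).__init__()
        # "mean" averages neighbor features before the linear projection.
        self.conv1 = SAGEConv(in_feats, h_feats, aggregator_type="mean")
        self.conv2 = SAGEConv(h_feats, num_class, aggregator_type="mean")

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        return h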

Train the GCN model

Training the GCN is similar to training any other PyTorch neural network.

def train(g, model, learning_rate=0.01, num_epoch=100):
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    best_val_acc = 0
    best_test_acc = 0
    features = g.ndata["feat"]
    labels = g.ndata["label"]
    train_mask = g.ndata["train_mask"]
    test_mask = g.ndata["test_mask"]
    val_mask = g.ndata["val_mask"]
    for epoch in range(num_epoch):
        result = model(g, features)
        pred = result.argmax(1)
        loss = F.cross_entropy(result[train_mask], labels[train_mask])
        train_acc = (pred[train_mask] == labels[train_mask]).float().mean()
        val_acc = (pred[val_mask] == labels[val_mask]).float().mean()
        test_acc = (pred[test_mask] == labels[test_mask]).float().mean()
        if best_val_acc < val_acc:
            best_val_acc, best_test_acc = val_acc, test_acc
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if epoch % 5 == 0:
            print('In epoch {}, loss: {}, val acc: {:.3f} (best {:.3f}), test acc: {:.3f} (best {:.3f})'.format(
                epoch, loss, val_acc, best_val_acc, test_acc, best_test_acc))

if __name__ == "__main__":
    train(g, model, num_epoch=200, learning_rate=0.002)
In epoch 0, loss: 1.0601081612549024e-06, val acc: 0.764 (best 0.764), test acc: 0.764 (best 0.764)
In epoch 5, loss: 9.979492006095825e-07, val acc: 0.760 (best 0.764), test acc: 0.764 (best 0.764)
In epoch 10, loss: 9.494142432231456e-07, val acc: 0.762 (best 0.764), test acc: 0.764 (best 0.764)
In epoch 15, loss: 9.017308570946625e-07, val acc: 0.764 (best 0.764), test acc: 0.764 (best 0.764)
In epoch 20, loss: 8.557504429518303e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 25, loss: 8.157304023370671e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 30, loss: 7.71452903336467e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 35, loss: 7.322842634494009e-07, val acc: 0.764 (best 0.764), test acc: 0.764 (best 0.764)
In epoch 40, loss: 6.948185955479858e-07, val acc: 0.764 (best 0.764), test acc: 0.764 (best 0.764)
In epoch 45, loss: 6.624618436035234e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 50, loss: 6.292536340879451e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 55, loss: 6.028573125149705e-07, val acc: 0.764 (best 0.764), test acc: 0.764 (best 0.764)
In epoch 60, loss: 5.807185630146705e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 65, loss: 5.534708407139988e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 70, loss: 5.381440359997214e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 75, loss: 5.117477144267468e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 80, loss: 4.913119937555166e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 85, loss: 4.759851037761109e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 90, loss: 4.5640075541086844e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 95, loss: 4.368164354673354e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 100, loss: 4.2319251747358066e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 105, loss: 4.07865627494175e-07, val acc: 0.764 (best 0.764), test acc: 0.766 (best 0.764)
In epoch 110, loss: 3.993507107225014e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 115, loss: 3.840238207430957e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 120, loss: 3.755089039714221e-07, val acc: 0.762 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 125, loss: 3.6358795796331833e-07, val acc: 0.762 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 130, loss: 3.5081561122751737e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 135, loss: 3.414492084630183e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 140, loss: 3.363402356626466e-07, val acc: 0.762 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 145, loss: 3.218648316760664e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 150, loss: 3.159043444611598e-07, val acc: 0.762 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 155, loss: 3.0568645570383524e-07, val acc: 0.762 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 160, loss: 2.988745109178126e-07, val acc: 0.764 (best 0.764), test acc: 0.765 (best 0.764)
In epoch 165, loss: 2.895080797316041e-07, val acc: 0.766 (best 0.766), test acc: 0.765 (best 0.765)
In epoch 170, loss: 2.792901625525701e-07, val acc: 0.766 (best 0.766), test acc: 0.765 (best 0.765)
In epoch 175, loss: 2.733296753376635e-07, val acc: 0.766 (best 0.766), test acc: 0.765 (best 0.765)
In epoch 180, loss: 2.673692165444663e-07, val acc: 0.764 (best 0.766), test acc: 0.765 (best 0.765)
In epoch 185, loss: 2.614087861729786e-07, val acc: 0.762 (best 0.766), test acc: 0.765 (best 0.765)
In epoch 190, loss: 2.53745326972421e-07, val acc: 0.762 (best 0.766), test acc: 0.765 (best 0.765)
In epoch 195, loss: 2.486363541720493e-07, val acc: 0.762 (best 0.766), test acc: 0.765 (best 0.765)

Note that in the log above the loss is already near zero at epoch 0, which suggests the model instance had been trained earlier in the same session before train was called again. The complete, self-contained script below starts from a freshly initialized model:

import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.data import CoraGraphDataset
from dgl.nn import GraphConv


class GCN(nn.Module):
    """GCN network"""
    def __init__(self, in_feats, h_feats, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GraphConv(in_feats, h_feats)
        self.conv2 = GraphConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        return h


def train(g, model, num_epoch=100, learning_rate=0.001):
    """train function"""
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    best_val_accurate = 0
    best_test_accurate = 0
    features = g.ndata["feat"]
    labels = g.ndata["label"]
    train_mask = g.ndata["train_mask"]
    test_mask = g.ndata["test_mask"]
    val_mask = g.ndata["val_mask"]
    for e in range(num_epoch):
        # forward
        result = model(g, features)
        # prediction
        pred = result.argmax(dim=1)
        # loss
        loss = F.cross_entropy(result[train_mask], labels[train_mask])
        # compute accuracy
        train_accurate = (pred[train_mask] == labels[train_mask]).float().mean()
        test_accurate = (pred[test_mask] == labels[test_mask]).float().mean()
        val_accurate = (pred[val_mask] == labels[val_mask]).float().mean()
        if best_val_accurate < val_accurate:
            best_val_accurate, best_test_accurate = val_accurate, test_accurate
        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if e % 5 == 0:
            print('In epoch {}, loss: {:.3f}, val acc: {:.3f} (best {:.3f}), test acc: {:.3f} (best {:.3f})'.format(
                e, loss, val_accurate, best_val_accurate, test_accurate, best_test_accurate))


def main():
    dataset = CoraGraphDataset()
    g = dataset[0]
    in_feats = g.ndata["feat"].shape[1]
    h_feats = 16
    num_classes = dataset.num_classes
    model = GCN(in_feats, h_feats, num_classes)
    train(g, model)


if __name__ == "__main__":
    main()
  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.
In epoch 0, loss: 1.946, val acc: 0.104 (best 0.104), test acc: 0.114 (best 0.114)
In epoch 5, loss: 1.942, val acc: 0.276 (best 0.276), test acc: 0.314 (best 0.314)
In epoch 10, loss: 1.936, val acc: 0.452 (best 0.452), test acc: 0.452 (best 0.452)
In epoch 15, loss: 1.929, val acc: 0.546 (best 0.546), test acc: 0.549 (best 0.549)
In epoch 20, loss: 1.921, val acc: 0.612 (best 0.612), test acc: 0.631 (best 0.631)
In epoch 25, loss: 1.913, val acc: 0.640 (best 0.640), test acc: 0.647 (best 0.647)
In epoch 30, loss: 1.904, val acc: 0.654 (best 0.654), test acc: 0.670 (best 0.670)
In epoch 35, loss: 1.895, val acc: 0.684 (best 0.684), test acc: 0.692 (best 0.692)
In epoch 40, loss: 1.886, val acc: 0.690 (best 0.692), test acc: 0.695 (best 0.693)
In epoch 45, loss: 1.876, val acc: 0.700 (best 0.700), test acc: 0.694 (best 0.694)
In epoch 50, loss: 1.866, val acc: 0.706 (best 0.708), test acc: 0.701 (best 0.699)
In epoch 55, loss: 1.855, val acc: 0.710 (best 0.710), test acc: 0.698 (best 0.698)
In epoch 60, loss: 1.844, val acc: 0.708 (best 0.712), test acc: 0.702 (best 0.699)
In epoch 65, loss: 1.833, val acc: 0.704 (best 0.712), test acc: 0.702 (best 0.699)
In epoch 70, loss: 1.821, val acc: 0.702 (best 0.712), test acc: 0.704 (best 0.699)
In epoch 75, loss: 1.809, val acc: 0.704 (best 0.712), test acc: 0.705 (best 0.699)
In epoch 80, loss: 1.796, val acc: 0.706 (best 0.712), test acc: 0.704 (best 0.699)
In epoch 85, loss: 1.783, val acc: 0.702 (best 0.712), test acc: 0.706 (best 0.699)
In epoch 90, loss: 1.769, val acc: 0.694 (best 0.712), test acc: 0.703 (best 0.699)
In epoch 95, loss: 1.755, val acc: 0.692 (best 0.712), test acc: 0.706 (best 0.699)

Training on a GPU

To train on a GPU, move both the model and the graph onto the GPU with the to method; everything else works just as it does when training any other PyTorch model.

g = g.to('cuda')
model = GCN(g.ndata['feat'].shape[1], 16, dataset.num_classes).to('cuda')
train(g, model)
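Once training has finished, the model can label every node in a single forward pass. Below is a minimal inference sketch (assuming model and g are the trained model and graph from above):

# Switch to evaluation mode and predict a class for every node.
model.eval()
with torch.no_grad():
    logits = model(g, g.ndata['feat'])
    pred = logits.argmax(dim=1)
print(pred[:10])  # predicted class ids of the first ten papers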

References

Translated and adapted from Node Classification with DGL.
