DGL入坑

2023-10-20 17:40
文章标签 入坑 dgl

本文主要是介绍DGL入坑,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

DGL教学

文章目录

  • DGL教学
    • 1. 数据集
    • 2. 图特征
    • 3. Graph Loader and Training
    • 4. 自定义图神经网络

DGL官方文档:https://docs.dgl.ai/index.html

1. 数据集

from dgl.data import DGLDatasetclass MyDataset(DGLDataset):def __init__(self,url=None,raw_dir=None,save_dir=None,force_reload=False,verbose=False):super(MyDataset, self).__init__(name='dataset_name',url=url,raw_dir=raw_dir,save_dir=save_dir,force_reload=force_reload,verbose=verbose)def process(self):# 将原始数据处理为图、标签和数据集划分的掩码passdef __getitem__(self, idx):# 通过idx得到与之对应的一个样本return self.reactant_graphs[i], self.prod_graphs[i], self.labels[i]def __len__(self):# 数据样本的数量return len(self.reactant_graphs)def save(self):# 将处理后的数据保存至 `self.save_path`print('saving dataset to ' + self.path + '.bin')save_info(self.path + '_info.pkl', {'labels': self.labels})dgl.save_graphs(self.path + '_reactant_graphs.bin', self.reactant_graphs)dgl.save_graphs(self.path + '_product_graphs.bin', self.prod_graphsdef load(self):# 从 `self.save_path` 导入处理后的数据print('loading dataset from ' + self.path + '.bin')self.reactant_graphs = dgl.load_graphs(self.path + '_reactant_graphs.bin')[0]self.prod_graphs = dgl.load_graphs(self.path + '_product_graphs.bin')[0]self.labels = load_info(self.path + '_info.pkl')['labels']def has_cache(self):# 检查在 `self.save_path` 中是否存有处理后的数据pass

读取数据到这个类中,数据处理流程如下:对应模板中的process, save, load
在这里插入图片描述

2. 图特征

DGL使用自身的定义的数据结构,这部分应该在上述的process函数中处理,将读入的图转换为DGL图结构

import dgl
graph = dgl.graph((src, dst), num_nodes=n_node) #其中一种定义方式

常用的接口

graph.adj()
graph.ndata['']
graph.edata['']

其中ndataedata对应图的点特征和边特征,可以多个

3. Graph Loader and Training

该步骤将步骤1的自定义数据集类放入图的迭代器中

from dgl.dataloading import GraphDataLoadertrain_dataloader = GraphDataLoader(train_data, batch_size=batch_size, shuffle=True, drop_last=True)

然后开始训练

for i in range(args.epoch):model.train()for batch in train_dataloader:data, label = batchy = model(data)loss = ...optimizer.zero_grad()loss.backward()optimizer.step()

4. 自定义图神经网络

常用的图神经网络

from dgl.nn import GraphConv, GATConv, SAGEConv, SGConv, TAGConv

初始化对应好输入和输出就行,源码在githubdgl-master\python\dgl\nn\pytorch\conv上(pytorch)

forward输入为DGL的graph

自定义图神经网络,要搞懂两个函数,graph内置函数

  • update_all
  • apply_edges

第一个是对所有的点进行操作,第二个是对所有的边进行操作,这两个函数有两个输入,分布是message passing函数reduce函数

  • message passing函数

    • 有如下已经定义好的
      在这里插入图片描述

    ​ 举例展示其操作:

    copy_e(‘x’, ‘y’) : 就是将每个节点v,与之关联的边(指向v的边)的特征xgraph.edata['v'])放到点v的点特征y上(graph.ndata['y']

    ​ 这步操作完后,graph.ndata['y']的维度可以写作(为了方便理解):

    n × n e × h n\times n_e \times h n×ne×h n n n表示节点数, n e n_e ne表示每个节点关联的边数(入度,每个节点不同), h h h表示特征维度

    u_add_v(‘x’, ‘x’, ‘y’) :就是将每个节点v,其特征xgraph.ndata['x'])与其邻居节点(指向自己)的特征x相加,放到点v的特征y上(graph.ndata['y']

    ​ 这步操作完后,graph.ndata['y']的维度可以写作(为了方便理解):

    n × n i × h n\times n_i \times h n×ni×h n n n表示节点数, n i n_i ni表示每个节点的邻居数(每个节点不同), h h h表示特征维度

    ​ 其他操作类似,注意:这里u表示是源节点,v表示是目标节点

  • Reduce函数

    就是将上述操作完的数据进行聚合,有如下:
    在这里插入图片描述

    举例:

    sum(‘y’, ‘m’) : 就是将每个节点或者每条边的y特征相加放到m

  • 自定义

    • message passing
    def message(self, edges):f = torch.cat([edges.src['h'], edges.dst['h'], edges.data['radial']], dim=-1)msg_h = self.edge_mlp(f)msg_x = self.coord_mlp(msg_h) * edges.data['x_diff']return {'msg_x': msg_x, 'msg_h': msg_h}
    
    • reduce
    def reducer(self, node):msg = torch.sum(node.mailbox['a'], dim=1) * torch.max(node.mailbox['a'], dim=1)[0]return {'m': msg}
    

这篇关于DGL入坑的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/248643

相关文章

图神经网络框架DGL实现Graph Attention Network (GAT)笔记

参考列表: [1]深入理解图注意力机制 [2]DGL官方学习教程一 ——基础操作&消息传递 [3]Cora数据集介绍+python读取 一、DGL实现GAT分类机器学习论文 程序摘自[1],该程序实现了利用图神经网络框架——DGL,实现图注意网络(GAT)。应用demo为对机器学习论文数据集——Cora,对论文所属类别进行分类。(下图摘自[3]) 1. 程序 Ubuntu:18.04

SpringCloud Alibaba 入坑(二)Nacos 配置中心

超详细的Java知识点路线图 文章目录 前言nacos配置中心介绍使用步骤总结 前言 SpringCloud Alibaba 入坑(一)Nacos 服务注册与发现 上篇文章介绍了nacos作为服务注册中心的用法,本文将介绍下nacos作为配置中心的用法。 nacos配置中心介绍 入坑Spring Cloud Alibaba后发现的nacos确实

GNN-第三方库:DGL【图神经网络框架,支持对异构图的处理,开源相关异构图神经网络的代码,在GCMC、RGCN等业内知名的模型实现上也取得了很好的效果】

一、DGL库的实现与性能 实现GNN并不容易,因为它需要在不规则数据上实现较高的GPU吞吐量。 1、DGL库简介 DGL库的逻辑层使用了顶点域的处理方式,使代码更容易理解。同时,又在底层的内存和运行效率方面做了大量的工作,使得框架可以发挥出更好的性能。 2、DGL库特点 GCMC:DGL的内存优化支持在一个GPU上对MovieLens10M数据集进行训练(原实现需要从CPU中动态加载数据

入坑爬坑必备!vot2016 配置(matlab,python)

环境ubantu18.4 +matlab2017b  A.下载预备 a.vot-toolkit  :https://github.com/votchallenge/vot-toolkit b.trax包:https://github.com/votchallenge/trax       在vot-toolkit下新建native文件夹 把trax放入 c.vot2016 :https:

前端新手小白的Vue3入坑指南

昨天有同学说想暑假在家学一学Vue3,问我有没有什么好的文档,我给他找了一些,然后顺带着,自己也写一篇吧,希望可以给新手小白们一些指引,Vue3欢迎你。 目录 1 项目安装 1.1 初始化项目 1.2 安装初始化依赖 1.3 启动项目  2  一定会用的第三方库 2.1 js-tool-big-box 2.2 less或者sass预处理器 2.3 axios请求库 2.4 UI

PyTorch 入坑十:模型泛化误差与偏差(Bias)、方差(Variance)

问题 阅读正文之前尝试回答以下问题,如果能准确回答,这篇文章不适合你;如果不是,可参考下文。 为什么会有偏差和方差?偏差、方差、噪声是什么?泛化误差、偏差和方差的关系?用图形解释偏差和方差。偏差、方差窘境。偏差、方差与过拟合、欠拟合的关系?偏差、方差与模型复杂度的关系?偏差、方差与bagging、boosting的关系?偏差、方差和K折交叉验证的关系?如何解决偏差、方差问题? 本文主要参考知

linux安装dgl

1.DGL官网、选择与自己cuda、python版本匹配的dgl的whl文件CUDA11.8、python10并下载 2.用pip install运行 pip install /home/u2023170749/download/dgl-2.2.0+cu118-cp310-cp310-manylinux1_x86_64.whl

termux入坑

先安装些东西 pkg install proot python git curl wget vim 获取root权限 termux-chroot 换清华源 termux-change-repo 修改vim的配置文件 termux-chrootvim /usr/share/vim/vimrc 扩充按键 mkdir -p ~/.termux && echo "extra

PWN入坑指南

CTF的PWN题想必是很多小伙伴心里的痛,大多小伙伴不知道PWN该如何入门,不知道该如何系统性学习 0x01开篇介绍 PWN 是一个黑客语法的俚语词 ,是指攻破设备或者系统 。发音类似"砰",对黑客而言,这就是成功实施黑客攻击的声音--砰的一声,被"黑"的电脑或手机就被你操纵了 。 斗哥认为解决PWN题就是利用简单逆向工程后得到代码(源码、字节码、汇编等),分析与研究代码最终发现

入坑,使用第三方SDK开发mavenJspWeb项目

最近公司给了我一个奇怪的任务:用国外一个特定的SDK新建maven项目用开发一个运行在手机上的jsp网站,刚听到这个任务的时候我是懵逼的,除了jsp会一点外,各种没接触过。 刚开始,我的电脑没有这个环境,可视化编辑器都没有,为了能快速完成,我选择用myeclipse8.5,装上去,发现缺少m2eclipse插件插件,后来才知道myeclipse10是有自带的。 安装maven,这个比较简单就不