【DGL】节点分类(GCN、SAGE、自定义)

2024-02-20 09:30

本文主要是介绍【DGL】节点分类(GCN、SAGE、自定义),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

目录

    • 使用dgl进行节点分类(GCN)
      • 数据集
      • 搭建网络
      • 训练
    • 使用dgl进行节点分类(SAGE)
      • 实现SAGE
      • 引入边权
      • 更多自定义操作

使用dgl进行节点分类(GCN)

数据集

dataset = dgl.data.CoraGraphDataset()
print("Number of categories:", dataset.num_classes)
g = dataset[0]

数据集信息:
Cora dataset,引用网络图,其中,节点表示论文,边表示论文的引用。任务是预测给定论文的类别。

NumNodes: 2708NumEdges: 10556NumFeats: 1433NumClasses: 7NumTrainingSamples: 140NumValidationSamples: 500NumTestSamples: 1000
Done loading data from cached files.
Number of categories: 7

其中,含有一个graph:

Graph(num_nodes=2708, num_edges=10556,ndata_schemes={'train_mask': Scheme(shape=(), dtype=torch.bool), 'label': Scheme(shape=(), dtype=torch.int64), 'val_mask': Scheme(shape=(), dtype=torch.bool), 'test_mask': Scheme(shape=(), dtype=torch.bool), 'feat': Scheme(shape=(1433,), dtype=torch.float32)}edata_schemes={})

train_mask: A boolean tensor indicating whether the node is in the training set.
val_mask: A boolean tensor indicating whether the node is in the validation set.
test_mask: A boolean tensor indicating whether the node is in the test set.
label: The ground truth node category.
feat: The node features.

搭建网络

根据Graph Convolutional Network (GCN)搭建两层的图卷积神经网络。每一层通过聚合邻居节点的信息来计算新的节点表示。
在这里插入图片描述

class GCN(nn.Module):def __init__(self, in_feats, h_feats, num_classes):super(GCN, self).__init__()self.conv1 = GraphConv(in_feats, h_feats)self.conv2 = GraphConv(h_feats, num_classes)def forward(self, g, in_feat):h = self.conv1(g, in_feat)h = F.relu(h)h = self.conv2(g, h)return hmodel = GCN(g.ndata['feat'].shape[1], 16, dataset.num_classes)
print(model)

数学上表示成1 h i ( l + 1 ) = σ ( b ( l ) + ∑ j ∈ N ( i ) 1 c j i h j ( l ) W ( l ) ) h_i^{(l+1)} = \sigma(b^{(l)} + \sum_{j\in\mathcal{N}(i)}\frac{1}{c_{ji}}h_j^{(l)}W^{(l)}) hi(l+1)=σ(b(l)+jN(i)cji1hj(l)W(l))

模型结构:

GCN((conv1): GraphConv(in=1433, out=16, normalization=both, activation=None)(conv2): GraphConv(in=16, out=7, normalization=both, activation=None)
)

训练

def train(g, model):optimizer = torch.optim.Adam(model.parameters(), lr=0.01)best_val_acc = 0best_test_acc = 0features = g.ndata['feat']labels = g.ndata['label']train_mask = g.ndata['train_mask']val_mask = g.ndata['val_mask']test_mask = g.ndata['test_mask']for e in range(100):logits = model(g, features)pred = logits.argmax(1)loss = F.cross_entropy(logits[train_mask], labels[train_mask])train_acc = (pred[train_mask] == labels[train_mask]).float().mean()val_acc = (pred[val_mask] == labels[val_mask]).float().mean()test_acc = (pred[test_mask] == labels[test_mask]).float().mean()if(best_val_acc < val_acc):best_val_acc = val_accbest_test_acc = test_accoptimizer.zero_grad()loss.backward()optimizer.step()if(e%5==0):print("In epoch {}, loss: {:.3f}, val acc: {:.3f} (best {:.3f}), test acc: {:.3f} (best {:.3f})".format(e, loss, val_acc, best_val_acc, test_acc, best_test_acc))train(g, model)
In epoch 0, loss: 1.946, val acc: 0.240 (best 0.240), test acc: 0.254 (best 0.254)
In epoch 5, loss: 1.903, val acc: 0.642 (best 0.642), test acc: 0.639 (best 0.639)
In epoch 10, loss: 1.837, val acc: 0.696 (best 0.700), test acc: 0.711 (best 0.715)
In epoch 15, loss: 1.746, val acc: 0.674 (best 0.700), test acc: 0.685 (best 0.715)
In epoch 20, loss: 1.628, val acc: 0.694 (best 0.700), test acc: 0.710 (best 0.715)
In epoch 25, loss: 1.484, val acc: 0.690 (best 0.700), test acc: 0.715 (best 0.715)
In epoch 30, loss: 1.321, val acc: 0.710 (best 0.710), test acc: 0.732 (best 0.732)
In epoch 35, loss: 1.144, val acc: 0.714 (best 0.720), test acc: 0.738 (best 0.737)
In epoch 40, loss: 0.966, val acc: 0.730 (best 0.730), test acc: 0.742 (best 0.742)
In epoch 45, loss: 0.797, val acc: 0.742 (best 0.742), test acc: 0.745 (best 0.745)
In epoch 50, loss: 0.647, val acc: 0.756 (best 0.756), test acc: 0.756 (best 0.756)
In epoch 55, loss: 0.520, val acc: 0.762 (best 0.762), test acc: 0.759 (best 0.759)
In epoch 60, loss: 0.416, val acc: 0.768 (best 0.768), test acc: 0.767 (best 0.765)
In epoch 65, loss: 0.334, val acc: 0.762 (best 0.768), test acc: 0.771 (best 0.765)
In epoch 70, loss: 0.270, val acc: 0.758 (best 0.768), test acc: 0.774 (best 0.765)
In epoch 75, loss: 0.220, val acc: 0.760 (best 0.768), test acc: 0.777 (best 0.765)
In epoch 80, loss: 0.182, val acc: 0.764 (best 0.768), test acc: 0.779 (best 0.765)
In epoch 85, loss: 0.151, val acc: 0.764 (best 0.768), test acc: 0.780 (best 0.765)
In epoch 90, loss: 0.128, val acc: 0.764 (best 0.768), test acc: 0.782 (best 0.765)
In epoch 95, loss: 0.109, val acc: 0.766 (best 0.768), test acc: 0.779 (best 0.765)Process finished with exit code 0

使用dgl进行节点分类(SAGE)

dgl遵循消息传递网络范式2。GraphSAGE convolution (Hamilton et al., 2017)具有以下形式:

h N ( v ) k ← A v e r a g e { h u k − 1 , ∀ u ∈ N ( v ) } h v k ← R e L U ( W k ⋅ C O N C A T ( h v k − 1 , h N ( v ) k ) ) h_\mathcal{N(v)}^k \gets Average\{ h_u ^{k-1} , \forall u \in \mathcal{N}(v) \} \\ h_v^k \gets ReLU(W^k \cdot CONCAT(h_v^{k-1}, h^k _{\mathcal{N}(v)})) hN(v)kAverage{huk1,uN(v)}hvkReLU(WkCONCAT(hvk1,hN(v)k))

实现SAGE

在dgl中有内置的SAGEConv。下面来自己实现:

class SAGEConv(nn.Module):def __init__(self, in_feat, out_feat):super(SAGEConv, self).__init__()# A linear submodule for projecting the input and neighbor feature to the output.self.linear = nn.Linear(in_feat*2, out_feat) # Wdef forward(self, g, h):with g.local_scope():#在这个区域内对g的修改不会同步到原始的图上g.ndata['h'] = hg.update_all(    #对所有的节点和边采用下面的message函数和reduce函数message_func=fn.copy_u("h", "m"), #message函数:将节点特征'h'作为消息传递给邻居,命名为'm'reduce_func=fn.mean("m", "h_N"),  #reduce函数:将接收到的'm'信息取平均,保存至节点特征'h_N')h_N = g.ndata["h_N"]h_total = torch.cat([h, h_N], dim=1)return self.linear(h_total)

依此搭建新的网络:

class Model(nn.Module):def __init__(self, in_feats, h_feats, num_classes):super(Model, self).__init__()self.conv1 = SAGEConv(in_feats, h_feats)self.conv2 = SAGEConv(h_feats, num_classes)def forward(self, g, in_feat):h = self.conv1(g, in_feat)h = F.relu(h)h = self.conv2(g, h)return hmodel = Model(g.ndata['feat'].shape[1], 16, dataset.num_classes)

效果和GCN差不多吧

引入边权

class WeightedSAGEConv(nn.Module):def __init__(self, in_feat, out_feat):super(WeightedSAGEConv, self).__init__()# A linear submodule for projecting the input and neighbor feature to the output.self.linear = nn.Linear(in_feat * 2, out_feat)def forward(self, g, h, w):with g.local_scope():g.ndata["h"] = hg.edata["w"] = wg.update_all(message_func=fn.u_mul_e("h", "w", "m"), #节点特征'h' 与 邻居间的边特征'w' 的乘积作为消息传递给邻居,记作'm'reduce_func=fn.mean("m", "h_N"), #将接收到的'm'信息取平均,保存至节点特征'h_N')h_N = g.ndata["h_N"]h_total = torch.cat([h, h_N], dim=1)return self.linear(h_total)class Model(nn.Module):def __init__(self, in_feats, h_feats, num_classes):super(Model, self).__init__()self.conv1 = WeightedSAGEConv(in_feats, h_feats)self.conv2 = WeightedSAGEConv(h_feats, num_classes)def forward(self, g, in_feat):h = self.conv1(g, in_feat, torch.ones(g.num_edges(), 1).to(g.device))#数据中没有边特征,在这里手动添加h = F.relu(h)h = self.conv2(g, h, torch.ones(g.num_edges(), 1).to(g.device))return hmodel = Model(g.ndata["feat"].shape[1], 16, dataset.num_classes)

更多自定义操作

见dgl.function

内置函数 dgl.function.u_add_v('hu','hv',' he')等价于:

def message_func(edges):#返回值为字典形式return {'he': edges.src['hu'] + edges.dst['hv']}

  1. https://docs.dgl.ai/generated/dgl.nn.pytorch.conv.GraphConv.html#dgl.nn.pytorch.conv.GraphConv ↩︎

  2. Neural Message Passing for Quantum Chemistry ↩︎

这篇关于【DGL】节点分类(GCN、SAGE、自定义)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/727760

相关文章

基于人工智能的图像分类系统

目录 引言项目背景环境准备 硬件要求软件安装与配置系统设计 系统架构关键技术代码示例 数据预处理模型训练模型预测应用场景结论 1. 引言 图像分类是计算机视觉中的一个重要任务,目标是自动识别图像中的对象类别。通过卷积神经网络(CNN)等深度学习技术,我们可以构建高效的图像分类系统,广泛应用于自动驾驶、医疗影像诊断、监控分析等领域。本文将介绍如何构建一个基于人工智能的图像分类系统,包括环境

【前端学习】AntV G6-08 深入图形与图形分组、自定义节点、节点动画(下)

【课程链接】 AntV G6:深入图形与图形分组、自定义节点、节点动画(下)_哔哩哔哩_bilibili 本章十吾老师讲解了一个复杂的自定义节点中,应该怎样去计算和绘制图形,如何给一个图形制作不间断的动画,以及在鼠标事件之后产生动画。(有点难,需要好好理解) <!DOCTYPE html><html><head><meta charset="UTF-8"><title>06

认识、理解、分类——acm之搜索

普通搜索方法有两种:1、广度优先搜索;2、深度优先搜索; 更多搜索方法: 3、双向广度优先搜索; 4、启发式搜索(包括A*算法等); 搜索通常会用到的知识点:状态压缩(位压缩,利用hash思想压缩)。

day-51 合并零之间的节点

思路 直接遍历链表即可,遇到val=0跳过,val非零则加在一起,最后返回即可 解题过程 返回链表可以有头结点,方便插入,返回head.next Code /*** Definition for singly-linked list.* public class ListNode {* int val;* ListNode next;* ListNode() {}*

自定义类型:结构体(续)

目录 一. 结构体的内存对齐 1.1 为什么存在内存对齐? 1.2 修改默认对齐数 二. 结构体传参 三. 结构体实现位段 一. 结构体的内存对齐 在前面的文章里我们已经讲过一部分的内存对齐的知识,并举出了两个例子,我们再举出两个例子继续说明: struct S3{double a;int b;char c;};int mian(){printf("%zd\n",s

Spring 源码解读:自定义实现Bean定义的注册与解析

引言 在Spring框架中,Bean的注册与解析是整个依赖注入流程的核心步骤。通过Bean定义,Spring容器知道如何创建、配置和管理每个Bean实例。本篇文章将通过实现一个简化版的Bean定义注册与解析机制,帮助你理解Spring框架背后的设计逻辑。我们还将对比Spring中的BeanDefinition和BeanDefinitionRegistry,以全面掌握Bean注册和解析的核心原理。

【每日一题】LeetCode 2181.合并零之间的节点(链表、模拟)

【每日一题】LeetCode 2181.合并零之间的节点(链表、模拟) 题目描述 给定一个链表,链表中的每个节点代表一个整数。链表中的整数由 0 分隔开,表示不同的区间。链表的开始和结束节点的值都为 0。任务是将每两个相邻的 0 之间的所有节点合并成一个节点,新节点的值为原区间内所有节点值的和。合并后,需要移除所有的 0,并返回修改后的链表头节点。 思路分析 初始化:创建一个虚拟头节点

Oracle type (自定义类型的使用)

oracle - type   type定义: oracle中自定义数据类型 oracle中有基本的数据类型,如number,varchar2,date,numeric,float....但有时候我们需要特殊的格式, 如将name定义为(firstname,lastname)的形式,我们想把这个作为一个表的一列看待,这时候就要我们自己定义一个数据类型 格式 :create or repla

JS和jQuery获取节点的兄弟,父级,子级元素

原文转自http://blog.csdn.net/duanshuyong/article/details/7562423 先说一下JS的获取方法,其要比JQUERY的方法麻烦很多,后面以JQUERY的方法作对比。 JS的方法会比JQUERY麻烦很多,主要则是因为FF浏览器,FF浏览器会把你的换行也当最DOM元素。 <div id="test"><div></div><div></div

HTML5自定义属性对象Dataset

原文转自HTML5自定义属性对象Dataset简介 一、html5 自定义属性介绍 之前翻译的“你必须知道的28个HTML5特征、窍门和技术”一文中对于HTML5中自定义合法属性data-已经做过些介绍,就是在HTML5中我们可以使用data-前缀设置我们需要的自定义属性,来进行一些数据的存放,例如我们要在一个文字按钮上存放相对应的id: <a href="javascript:" d