机器学习周记(第三十一周:文献阅读-GGNN)2024.3.18~2024.3.24

2024-03-25 06:44

本文主要是介绍机器学习周记(第三十一周:文献阅读-GGNN)2024.3.18~2024.3.24,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

目录

摘要

ABSTRACT

1 论文信息

1.1 论文标题

1.2 论文模型

1.2.1 数据处理

1.2.2 门控图神经网络

1.2.3 掩码操作

2 相关知识

2.1 图神经网络(GNN)

2.2 图卷积神经网络(GCN)

3 相关代码


摘要

  本周阅读了一篇利用图神经网络(GNN)与门控循环单元(GRU)进行配水网络(WDN)水质预测的论文。论文模型(GGNN)实现了扩展图邻接矩阵在有向图中加入双向信息流,从而增强了模型的双向学习能力。同时模型还利用掩码操作模拟了站点故障导致数据缺失的情况,根据正常站点数据也能对故障站点进行预测,并且还能解决模型过拟合或者欠拟合的问题。

ABSTRACT

  This week, We read a paper on water quality prediction in water distribution networks (WDNs) using Graph Neural Networks (GNN) and Gated Recurrent Units (GRU). The paper introduces a model called GGNN, which extends the graph adjacency matrix to incorporate bidirectional information flow in directed graphs, thus enhancing the model's bidirectional learning capability. Additionally, the model utilizes masking operations to simulate data missing due to station failures, enabling the prediction of faulty stations based on normal station data. Moreover, it addresses the issues of model overfitting or underfitting.

1 论文信息

1.1 论文标题

Real-time water quality prediction in water distribution networks using graph neural networks with sparse monitoring data

1.2 论文模型

  论文模型(GGNN)旨在利用门控图神经网络(GGNN)处理网络拓扑结构、流向以及水质监测站的历史氯浓度测量数据来预测配水网络(WDN)中的实时水质。该模型由两个主要部分组成:(1)对供水网络信息进行数据处理,输入到图神经网络中;(2)利用收集到的数据构建模型。

Fig.1 基于GGNN的实时水质预测方法示意图

1.2.1 数据处理

  GGNN模型需要两类数据:传感器监测站的WDN拓扑结构和历史水质监测数据。假设一个WDNn个节点和m条管道组成,配备N_{s}个传感器站监测水质。网络拓扑由图G=(V,E)表示,其中V表示由水库、储罐和连接点组成的节点集,E表示由管道、阀门和泵组成的边集。网络的流向信息和空间拓扑细节通常可以从EPANET等水力模型中获得。利用这些数据构建有向图的邻接矩阵A \in \mathbb{R}^{n\times n},其中每个元素A_{ij}表示水是否从节点i流向节点j (A_{ij}=1)或不流向 (A_{ij}=0)。论文仅在边的权重相等时考虑水流方向。更进一步还可以同时考虑流量的动态变化和加权边。

  通过在WDN中实现的监控和数据采集(SCADA)系统,可以获得各监测站的历史水质数据。该数据采集过程包括在指定的时间窗口内采集水质测量数据,记为T_{c},也表示采集历史数据的周期时间。然后将采集到的数据作为数据集中被监测节点的节点属性,对于未被监测节点,将空值替换为0,得到节点属性X\in \mathbb{R}^{n \times N_{c}}N_{c}表示数据采集周期T_{c}内获得的水质测量次数,对应于指定时间窗口内的时间步数。它是预测下一时刻水质所需数据大小的指标。

1.2.2 门控图神经网络

  为了解决WDN的非欧氏图域带来的挑战,将GGNN架构用于水质预测。GGNN是一种图神经网络,用于处理复杂的图结构数据,如WDN拓扑。它扩展了通常定义在欧氏域上的传统神经网络,使其能够直接处理非欧氏图数据。GGNN模型根据相邻节点和边之间传递的消息为每个节点v\in V计算状态向量h_{v}。状态向量h_{v}表示节点学习到的特征表示,编码了关于图的局部和全局信息。它可以被认为是节点的隐藏状态,从其邻域和整个图中捕获相关信息。最终,状态向量可用于水质预测。GGNN的整体工作流程如Fig.2所示。

Fig.2 GGNN总体架构示意图

  首先,通过扩展邻接矩阵A \in \mathbb{R}^{n \times n},在有向图中加入双向信息流来作为输入。主要通过将邻接矩阵A与其转置连接起来,形成一个扩展的邻接矩阵\widehat{A}=\left [ A,A^{T} \right ]来实现的,这样可以同时考虑输入边和输出边。\widehat{A} \in \mathbb{R}^{n \times 2n}捕获了节点之间的复杂关系和消息传播方向,从而增强了GGNN的双向学习能力。

  然后,通过标准线性组合修正线性单元(rectified linear unit, ReLU)激活函数将节点v的节点属性x_{v}从原始空间\mathbb{R}^{N_{c}}映射到新空间\mathbb{R}^{M}的原始隐藏状态h_{v}^{(0)}。这种映射过程有效地扩大了节点属性的大小,使GGNN能够捕获节点属性之间潜在的重要非线性关系。隐藏状态的大小用M表示,是一个决定模型容量的超参数。然而,至关重要的是要与M取得平衡,以防止过拟合并控制训练期间的计算复杂性。

  GGNN以扩展的邻接矩阵\widehat{A}=\left [ A,A^{T} \right ]和映射的节点属性h^{(0)}为输入,在固定的k步上递归计算节点状态以产生最终的状态矩阵h^{(K)}\in \mathbb{R}^{n \times M}。在聚合阶段,利用扩展邻接矩阵\widehat{A}计算聚合向量a_{v}a_{v}表示节点v和相邻节点状态的聚合,聚合向量的计算公式如下:

a_{v}^{(k)}=\widehat{A}^{T}_{v:}\left [ h_{1}^{(k-1)^{T}},...,h_{n}^{(k-1)^{T}} \right ]^{T}+b                                                                              (1)

其中,上标k表示时间步长,\widehat{A}_{v:}\in \mathbb{R}^{n \times 2}是块\widehat{A}中对应节点v的两列,b是偏移向量。在聚合阶段之后,传播阶段采用门控循环单元(gated recurrent units, GRU)机制更新节点状态。GRU传播方程描述如下:

r_{v}^{(k)}=\sigma (W_{r} \cdot a_{v}^{(k)}+U_{r}\cdot h_{v}^{(k-1)})                                                                                       (2)

z_{v}^{(k)}=\sigma (W_{z} \cdot a_{v}^{(k)}+U_{z}\cdot h_{v}^{(k-1)})                                                                                       (3)

\widetilde{h}_{v}^{(k)}=\tanh (W \cdot a_{v}^{(k)}+U\cdot (r_{v}^{(k)}\bigodot h_{v}^{(k-1)}))                                                                   (4)

h_{v}^{(k)}=(1-z_{v}^{(k)})\bigodot h_{v}^{(k-1)}+z_{v}^{(k)}\bigodot \widetilde{h}_{v}^{(k)}                                                                         (5)

其中rz是重置门和更新门;W_{r},W_{z},WU_{r},U_{z},U是每层的权重和偏差;\sigma (\cdot)sigmoid激活函数;\bigodot是元素点积运算。

  GGNN中的聚合和传播步骤允许模型迭代更新和细化节点状态,合并来自节点先前的特征及其邻近节点的特征信息。这个迭代过程捕获了图结构内的动态和交互规则,使GGNN能够学习和表示节点之间的复杂关系和依赖关系。传播步长K(也即GNN层数)决定了GGNN中信息传播的深度。当K=1时,每个节点只能从其近邻节点学习。随着K的增加,GGNN可以从距离K步的节点捕获信息,包括它们的间接连接。K的选择影响模型的学习能力和效率。较高的K值会导致训练较慢以及增加内存需求,而较低的K值会限制每个节点可以学习的依赖关系的数量。因此,K的选择应该在模型性能和计算效率之间取得平衡。

  在使用GRU模块更新节点状态后,使用线性层将更新后的状态h^{(K)}转换为表示每个节点预测状态的\widehat{Y}\in \mathbb{R}^{n}。在本研究中,节点属性为历史水质浓度数据,其预测状态表示模型对每个节点下一时间步水质浓度的预测。这种转换允许模型根据其更新的表示和从邻近节点传播的信息在每个节点生成对水质的预测。

1.2.3 掩码操作

  虽然之前的研究主要采用掩码操作(Maskng Operation来模拟传感器故障,特别是在不利条件下测试模型的鲁棒性,但本文方法在训练阶段利用掩码操作来增强模型对未监测节点的预测能力。在训练过程中,结合掩码操作对解决两个重大挑战至关重要。首先,现有研究通常假设传感器节点的输入,并根据模拟的网络中所有节点的值来计算损失,这在现实世界中是不切实际的,因为获取非传感器节点的测量数据很困难。论文使用模拟模型的合成数据,这样数据虽然完整,但作者并没有使用所有网络节点的所有数据进行训练。相反,只使用了一小部分节点数据。其次,如果模型仅基于传感器节点的输入进行训练,并基于这些节点计算损失,可能会导致过拟合,阻碍模型预测未监测节点的水质的能力。为了克服这些挑战,在训练过程中引入了掩码操作。随机选择指定比例(例如20%)的传感器节点,并通过在每个训练批次中将其输入替换为零进行掩盖。这个屏蔽操作有两个目的。首先,在训练过程中模拟非传感器节点数据的不可用性,使模型能够在观测到的传感器数据之外进行泛化,并学习预测无监测节点的值;其次,它作为正则化技术,防止模型仅依赖有限的传感器输入。通过鼓励模型捕捉传感器节点和非监测节点之间的关系,提高模型的泛化能力,降低过拟合的可能性。需要研究掩码节点的比例,因为它可以平衡模型性能和过拟合。更高的比率会减少可用的信息,增加欠拟合的风险。较低的速率可以提供更多的信息,但可能会导致过拟合。因此,掩码率也是一个十分重要的超参数。

2 相关知识

2.1 图神经网络(GNN)

2.2 图卷积神经网络(GCN)

  需要注意的是,常规任务情境下不会需要节点的信息传播太远。经过6~7个hops,基本上就可以使节点的信息传播到整个网络,这也使得聚合不那么有意义。实验结果也表明,2~3层的网络应该是比较好的,当GCN达到7层时,效果已经变得较差,但是通过在隐藏层间加上残差连接(Residual Connections)可以使效果变好。

3 相关代码

GCN模型定义与图结构数据定义:

import torch
import torch.nn as nn
import networkx as nx
import matplotlib.pyplot as plt
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import KarateClub
from torch_geometric.utils import to_networkxclass GCN(nn.Module):def __init__(self):super().__init__()torch.manual_seed(1234)self.conv1 = GCNConv(dataset.num_features, 4)self.conv2 = GCNConv(4, 4)self.conv3 = GCNConv(4, 2)self.classifier = nn.Linear(2, dataset.num_classes)def forward(self, x, edge_index):h = self.conv1(x, edge_index)  # 输入特征与邻接矩阵h = h.tanh()h = self.conv2(h, edge_index)h = h.tanh()h = self.conv3(h, edge_index)h = h.tanh()out = self.classifier(h)return out, hdef visualize_graph(G, color):plt.figure(figsize=(7, 7))plt.xticks([])plt.yticks([])nx.draw_networkx(G, pos=nx.spring_layout(G, seed=42), with_labels=False, node_color=color, cmap="Set2")plt.show()def visualize_embedding(h, color, epoch=None, loss=None):plt.figure(figsize=(7, 7))plt.xticks([])plt.yticks([])h = h.detach().cpu().numpy()plt.scatter(h[:, 0], h[:, 1], s=140, c=color, cmap="Set2")if epoch is not None and loss is not None:plt.xlabel(f'Epoch: {epoch}, Loss: {loss.item():.4f}', fontsize=16)plt.show()dataset = KarateClub()
print(f'Dataset: {dataset}')
print('======================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')data = dataset[0]
# x:[34, 34](M*F,M:样本数,F:特征维度)
# edge_index:[2, 156](两个数组,第一个为source,第二个为target,156条边)
# y:[34](标签)
# train_mask:[34](指定节点是否有标签,通过此数组可以选择哪些节点计算损失,元素类型为bool)
print(data)
print(dataset.edge_index)G = to_networkx(data, to_undirected=True)
visualize_graph(G, color=data.y)

数据集KarateClub的图结构:

这篇关于机器学习周记(第三十一周:文献阅读-GGNN)2024.3.18~2024.3.24的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/844148

相关文章

HarmonyOS学习(七)——UI(五)常用布局总结

自适应布局 1.1、线性布局(LinearLayout) 通过线性容器Row和Column实现线性布局。Column容器内的子组件按照垂直方向排列,Row组件中的子组件按照水平方向排列。 属性说明space通过space参数设置主轴上子组件的间距,达到各子组件在排列上的等间距效果alignItems设置子组件在交叉轴上的对齐方式,且在各类尺寸屏幕上表现一致,其中交叉轴为垂直时,取值为Vert

Ilya-AI分享的他在OpenAI学习到的15个提示工程技巧

Ilya(不是本人,claude AI)在社交媒体上分享了他在OpenAI学习到的15个Prompt撰写技巧。 以下是详细的内容: 提示精确化:在编写提示时,力求表达清晰准确。清楚地阐述任务需求和概念定义至关重要。例:不用"分析文本",而用"判断这段话的情感倾向:积极、消极还是中性"。 快速迭代:善于快速连续调整提示。熟练的提示工程师能够灵活地进行多轮优化。例:从"总结文章"到"用

【前端学习】AntV G6-08 深入图形与图形分组、自定义节点、节点动画(下)

【课程链接】 AntV G6:深入图形与图形分组、自定义节点、节点动画(下)_哔哩哔哩_bilibili 本章十吾老师讲解了一个复杂的自定义节点中,应该怎样去计算和绘制图形,如何给一个图形制作不间断的动画,以及在鼠标事件之后产生动画。(有点难,需要好好理解) <!DOCTYPE html><html><head><meta charset="UTF-8"><title>06

学习hash总结

2014/1/29/   最近刚开始学hash,名字很陌生,但是hash的思想却很熟悉,以前早就做过此类的题,但是不知道这就是hash思想而已,说白了hash就是一个映射,往往灵活利用数组的下标来实现算法,hash的作用:1、判重;2、统计次数;

JAVA智听未来一站式有声阅读平台听书系统小程序源码

智听未来,一站式有声阅读平台听书系统 🌟&nbsp;开篇:遇见未来,从“智听”开始 在这个快节奏的时代,你是否渴望在忙碌的间隙,找到一片属于自己的宁静角落?是否梦想着能随时随地,沉浸在知识的海洋,或是故事的奇幻世界里?今天,就让我带你一起探索“智听未来”——这一站式有声阅读平台听书系统,它正悄悄改变着我们的阅读方式,让未来触手可及! 📚&nbsp;第一站:海量资源,应有尽有 走进“智听

零基础学习Redis(10) -- zset类型命令使用

zset是有序集合,内部除了存储元素外,还会存储一个score,存储在zset中的元素会按照score的大小升序排列,不同元素的score可以重复,score相同的元素会按照元素的字典序排列。 1. zset常用命令 1.1 zadd  zadd key [NX | XX] [GT | LT]   [CH] [INCR] score member [score member ...]

【机器学习】高斯过程的基本概念和应用领域以及在python中的实例

引言 高斯过程(Gaussian Process,简称GP)是一种概率模型,用于描述一组随机变量的联合概率分布,其中任何一个有限维度的子集都具有高斯分布 文章目录 引言一、高斯过程1.1 基本定义1.1.1 随机过程1.1.2 高斯分布 1.2 高斯过程的特性1.2.1 联合高斯性1.2.2 均值函数1.2.3 协方差函数(或核函数) 1.3 核函数1.4 高斯过程回归(Gauss

【学习笔记】 陈强-机器学习-Python-Ch15 人工神经网络(1)sklearn

系列文章目录 监督学习:参数方法 【学习笔记】 陈强-机器学习-Python-Ch4 线性回归 【学习笔记】 陈强-机器学习-Python-Ch5 逻辑回归 【课后题练习】 陈强-机器学习-Python-Ch5 逻辑回归(SAheart.csv) 【学习笔记】 陈强-机器学习-Python-Ch6 多项逻辑回归 【学习笔记 及 课后题练习】 陈强-机器学习-Python-Ch7 判别分析 【学

系统架构师考试学习笔记第三篇——架构设计高级知识(20)通信系统架构设计理论与实践

本章知识考点:         第20课时主要学习通信系统架构设计的理论和工作中的实践。根据新版考试大纲,本课时知识点会涉及案例分析题(25分),而在历年考试中,案例题对该部分内容的考查并不多,虽在综合知识选择题目中经常考查,但分值也不高。本课时内容侧重于对知识点的记忆和理解,按照以往的出题规律,通信系统架构设计基础知识点多来源于教材内的基础网络设备、网络架构和教材外最新时事热点技术。本课时知识

线性代数|机器学习-P36在图中找聚类

文章目录 1. 常见图结构2. 谱聚类 感觉后面几节课的内容跨越太大,需要补充太多的知识点,教授讲得内容跨越较大,一般一节课的内容是书本上的一章节内容,所以看视频比较吃力,需要先预习课本内容后才能够很好的理解教授讲解的知识点。 1. 常见图结构 假设我们有如下图结构: Adjacency Matrix:行和列表示的是节点的位置,A[i,j]表示的第 i 个节点和第 j 个