详细讲一下PYG 里面的torch_geometric.nn.conv.transformer_conv函数

2024-05-09 04:44

本文主要是介绍详细讲一下PYG 里面的torch_geometric.nn.conv.transformer_conv函数,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1.首先先讲一下代码

这是官方给的代码:torch_geometric.nn.conv.transformer_conv — pytorch_geometric documentation

import math
import typing
from typing import Optional, Tuple, Unionimport torch
import torch.nn.functional as F
from torch import Tensorfrom torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.typing import (Adj,NoneType,OptTensor,PairTensor,SparseTensor,
)
from torch_geometric.utils import softmaxif typing.TYPE_CHECKING:from typing import overload
else:from torch.jit import _overload_method as overload[docs]class TransformerConv(MessagePassing):r"""The graph transformer operator from the `"Masked Label Prediction:Unified Message Passing Model for Semi-Supervised Classification"<https://arxiv.org/abs/2009.03509>`_ paper... math::\mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i +\sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \mathbf{W}_2 \mathbf{x}_{j},where the attention coefficients :math:`\alpha_{i,j}` are computed viamulti-head dot product attention:.. math::\alpha_{i,j} = \textrm{softmax} \left(\frac{(\mathbf{W}_3\mathbf{x}_i)^{\top} (\mathbf{W}_4\mathbf{x}_j)}{\sqrt{d}} \right)Args:in_channels (int or tuple): Size of each input sample, or :obj:`-1` toderive the size from the first input(s) to the forward method.A tuple corresponds to the sizes of source and targetdimensionalities.out_channels (int): Size of each output sample.heads (int, optional): Number of multi-head-attentions.(default: :obj:`1`)concat (bool, optional): If set to :obj:`False`, the multi-headattentions are averaged instead of concatenated.(default: :obj:`True`)beta (bool, optional): If set, will combine aggregation andskip information via.. math::\mathbf{x}^{\prime}_i = \beta_i \mathbf{W}_1 \mathbf{x}_i +(1 - \beta_i) \underbrace{\left(\sum_{j \in \mathcal{N}(i)}\alpha_{i,j} \mathbf{W}_2 \vec{x}_j \right)}_{=\mathbf{m}_i}with :math:`\beta_i = \textrm{sigmoid}(\mathbf{w}_5^{\top}[ \mathbf{W}_1 \mathbf{x}_i, \mathbf{m}_i, \mathbf{W}_1\mathbf{x}_i - \mathbf{m}_i ])` (default: :obj:`False`)dropout (float, optional): Dropout probability of the normalizedattention coefficients which exposes each node to a stochasticallysampled neighborhood during training. (default: :obj:`0`)edge_dim (int, optional): Edge feature dimensionality (in casethere are any). Edge features are added to the keys afterlinear transformation, that is, prior to computing theattention dot product. They are also added to final valuesafter the same linear transformation. The model is:.. math::\mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i +\sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \left(\mathbf{W}_2 \mathbf{x}_{j} + \mathbf{W}_6 \mathbf{e}_{ij}\right),where the attention coefficients :math:`\alpha_{i,j}` are nowcomputed via:.. math::\alpha_{i,j} = \textrm{softmax} \left(\frac{(\mathbf{W}_3\mathbf{x}_i)^{\top}(\mathbf{W}_4\mathbf{x}_j + \mathbf{W}_6 \mathbf{e}_{ij})}{\sqrt{d}} \right)(default :obj:`None`)bias (bool, optional): If set to :obj:`False`, the layer will not learnan additive bias. (default: :obj:`True`)root_weight (bool, optional): If set to :obj:`False`, the layer willnot add the transformed root node features to the output and theoption  :attr:`beta` is set to :obj:`False`. (default: :obj:`True`)**kwargs (optional): Additional arguments of:class:`torch_geometric.nn.conv.MessagePassing`."""_alpha: OptTensordef __init__(self,in_channels: Union[int, Tuple[int, int]],out_channels: int,heads: int = 1,concat: bool = True,beta: bool = False,dropout: float = 0.,edge_dim: Optional[int] = None,bias: bool = True,root_weight: bool = True,**kwargs,):kwargs.setdefault('aggr', 'add')super().__init__(node_dim=0, **kwargs)self.in_channels = in_channelsself.out_channels = out_channelsself.heads = headsself.beta = beta and root_weightself.root_weight = root_weightself.concat = concatself.dropout = dropoutself.edge_dim = edge_dimself._alpha = Noneif isinstance(in_channels, int):in_channels = (in_channels, in_channels)self.lin_key = Linear(in_channels[0], heads * out_channels)self.lin_query = Linear(in_channels[1], heads * out_channels)self.lin_value = Linear(in_channels[0], heads * out_channels)if edge_dim is not None:self.lin_edge = Linear(edge_dim, heads * out_channels, bias=False)else:self.lin_edge = self.register_parameter('lin_edge', None)if concat:self.lin_skip = Linear(in_channels[1], heads * out_channels,bias=bias)if self.beta:self.lin_beta = Linear(3 * heads * out_channels, 1, bias=False)else:self.lin_beta = self.register_parameter('lin_beta', None)else:self.lin_skip = Linear(in_channels[1], out_channels, bias=bias)if self.beta:self.lin_beta = Linear(3 * out_channels, 1, bias=False)else:self.lin_beta = self.register_parameter('lin_beta', None)self.reset_parameters()[docs]    def reset_parameters(self):super().reset_parameters()self.lin_key.reset_parameters()self.lin_query.reset_parameters()self.lin_value.reset_parameters()if self.edge_dim:self.lin_edge.reset_parameters()self.lin_skip.reset_parameters()if self.beta:self.lin_beta.reset_parameters()@overloaddef forward(self,x: Union[Tensor, PairTensor],edge_index: Adj,edge_attr: OptTensor = None,return_attention_weights: NoneType = None,) -> Tensor:pass@overloaddef forward(  # noqa: F811self,x: Union[Tensor, PairTensor],edge_index: Tensor,edge_attr: OptTensor = None,return_attention_weights: bool = None,) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:pass@overloaddef forward(  # noqa: F811self,x: Union[Tensor, PairTensor],edge_index: SparseTensor,edge_attr: OptTensor = None,return_attention_weights: bool = None,) -> Tuple[Tensor, SparseTensor]:pass[docs]    def forward(  # noqa: F811self,x: Union[Tensor, PairTensor],edge_index: Adj,edge_attr: OptTensor = None,return_attention_weights: Optional[bool] = None,) -> Union[Tensor,Tuple[Tensor, Tuple[Tensor, Tensor]],Tuple[Tensor, SparseTensor],]:r"""Runs the forward pass of the module.Args:x (torch.Tensor or (torch.Tensor, torch.Tensor)): The input nodefeatures.edge_index (torch.Tensor or SparseTensor): The edge indices.edge_attr (torch.Tensor, optional): The edge features.(default: :obj:`None`)return_attention_weights (bool, optional): If set to :obj:`True`,will additionally return the tuple:obj:`(edge_index, attention_weights)`, holding the computedattention weights for each edge. (default: :obj:`None`)"""H, C = self.heads, self.out_channelsif isinstance(x, Tensor):x = (x, x)query = self.lin_query(x[1]).view(-1, H, C)key = self.lin_key(x[0]).view(-1, H, C)value = self.lin_value(x[0]).view(-1, H, C)# propagate_type: (query: Tensor, key:Tensor, value: Tensor,#                  edge_attr: OptTensor)out = self.propagate(edge_index, query=query, key=key, value=value,edge_attr=edge_attr)alpha = self._alphaself._alpha = Noneif self.concat:out = out.view(-1, self.heads * self.out_channels)else:out = out.mean(dim=1)if self.root_weight:x_r = self.lin_skip(x[1])if self.lin_beta is not None:beta = self.lin_beta(torch.cat([out, x_r, out - x_r], dim=-1))beta = beta.sigmoid()out = beta * x_r + (1 - beta) * outelse:out = out + x_rif isinstance(return_attention_weights, bool):assert alpha is not Noneif isinstance(edge_index, Tensor):return out, (edge_index, alpha)elif isinstance(edge_index, SparseTensor):return out, edge_index.set_value(alpha, layout='coo')else:return outdef message(self, query_i: Tensor, key_j: Tensor, value_j: Tensor,edge_attr: OptTensor, index: Tensor, ptr: OptTensor,size_i: Optional[int]) -> Tensor:if self.lin_edge is not None:assert edge_attr is not Noneedge_attr = self.lin_edge(edge_attr).view(-1, self.heads,self.out_channels)key_j = key_j + edge_attralpha = (query_i * key_j).sum(dim=-1) / math.sqrt(self.out_channels)alpha = softmax(alpha, index, ptr, size_i)self._alpha = alphaalpha = F.dropout(alpha, p=self.dropout, training=self.training)out = value_jif edge_attr is not None:out = out + edge_attrout = out * alpha.view(-1, self.heads, 1)return outdef __repr__(self) -> str:return (f'{self.__class__.__name__}({self.in_channels}, 'f'{self.out_channels}, heads={self.heads})')

2.详细解释一下

几个重要的参数

in_channels (int or tuple): Size of each input sample, or :obj:`-1` to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities.

out_channels (int): Size of each output sample.

heads (int, optional): Number of multi-head-attentions. (default: :obj:`1`)

怎么理解这几个参数?

 

  • in_channels 表示每个输入样本的大小。如果设置为整数,则表示所有输入样本的大小相同;如果设置为 -1,则表示输入样本的大小将从 forward 方法的第一个输入中推导出来;如果设置为元组,则表示输入样本的大小对应于源维度和目标维度的大小。

  • out_channels 表示每个输出样本的大小,即经过卷积操作后产生的特征向量的维度大小。

 

当使用 tg.nn.TransformerConv 时,可以通过以下方式理解 in_channelsout_channels

假设我们有一个图数据集,每个节点都有一个 10 维的特征向量表示。那么在这种情况下:

  • 如果我们想将每个节点的特征向量作为输入,然后使用 tg.nn.TransformerConv 进行卷积操作,那么 in_channels 应该设置为 10,表示每个输入样本的大小为 10。

  • 假设我们想将节点的特征向量转换为一个 16 维的特征向量,那么 out_channels 应该设置为 16,表示每个输出样本的大小为 16,即经过卷积操作后每个节点的特征向量将变为 16 维。

  • tg.nn.TransformerConv 中,heads 参数表示多头注意力的数量。举个例子,如果 heads 参数设置为 4,那么模型将学习 4 组注意力权重,每组权重都用于计算输入的不同子空间的注意力,然后将这些头的输出进行合并以产生最终的输出。

 举个整体的例子

我们有一个输入张量 x,它的形状是 (batch_size, seq_length, input_dim),其中:

  • batch_size 表示批量大小;
  • seq_length 表示序列长度;
  • input_dim 表示输入特征的维度。

现在假设我们使用了 tg.nn.TransformerConv,并设置 heads=2,那么模型将学习两组注意力权重,每组用于计算不同的注意力。输出张量的形状将取决于 out_channels 参数,我们假设 out_channels=64

import torch
import torch_geometric.nn as tg# 假设输入张量的形状是 (batch_size, seq_length, input_dim)
x = torch.randn(32, 10, 128)  # 32 个样本,每个样本有 10 个时间步,每个时间步有 128 个特征# 创建 TransformerConv 模型,设置 heads=2,out_channels=64
conv_layer = tg.nn.TransformerConv(in_channels=128, out_channels=64, heads=2)# 使用模型进行前向传播
output = conv_layer(x)print("输出张量的形状:", output.shape)

 2.1将特征映射到键值对中

在这里,通过线性变换层 Linear,输入特征被转换成了键(key)、查询(query)和数值(value)的表示形式,以便用于多头自注意力机制。

具体来说:

  • self.lin_key 用于将输入特征(in_channels[0])映射到键的表示形式。
  • self.lin_query 用于将输入特征(in_channels[1])映射到查询的表示形式。
  • self.lin_value 用于将输入特征(in_channels[0])映射到数值的表示形式。

 具体地,假设输入特征的维度是 (batch_size, num_nodes, in_channels),其中 batch_size 是批量大小,num_nodes 是节点数,in_channels 是输入特征的通道数。在映射到键的过程中,线性变换层的权重矩阵将是一个维度为 (in_channels, heads * out_channels) 的矩阵,其中 heads 是注意力头的数量,out_channels 是输出特征的通道数。因此,通过矩阵乘法运算,输入特征将被映射到一个新的特征空间,其维度为 (batch_size, num_nodes, heads, out_channels)。在这个新的特征空间中,每个节点的每个头都有一个键表示。

这篇关于详细讲一下PYG 里面的torch_geometric.nn.conv.transformer_conv函数的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/972422

相关文章

Java操作PDF文件实现签订电子合同详细教程

《Java操作PDF文件实现签订电子合同详细教程》:本文主要介绍如何在PDF中加入电子签章与电子签名的过程,包括编写Word文件、生成PDF、为PDF格式做表单、为表单赋值、生成文档以及上传到OB... 目录前言:先看效果:1.编写word文件1.2然后生成PDF格式进行保存1.3我这里是将文件保存到本地后

windows系统下shutdown重启关机命令超详细教程

《windows系统下shutdown重启关机命令超详细教程》shutdown命令是一个强大的工具,允许你通过命令行快速完成关机、重启或注销操作,本文将为你详细解析shutdown命令的使用方法,并提... 目录一、shutdown 命令简介二、shutdown 命令的基本用法三、远程关机与重启四、实际应用

使用SpringBoot创建一个RESTful API的详细步骤

《使用SpringBoot创建一个RESTfulAPI的详细步骤》使用Java的SpringBoot创建RESTfulAPI可以满足多种开发场景,它提供了快速开发、易于配置、可扩展、可维护的优点,尤... 目录一、创建 Spring Boot 项目二、创建控制器类(Controller Class)三、运行

springboot整合gateway的详细过程

《springboot整合gateway的详细过程》本文介绍了如何配置和使用SpringCloudGateway构建一个API网关,通过实例代码介绍了springboot整合gateway的过程,需要... 目录1. 添加依赖2. 配置网关路由3. 启用Eureka客户端(可选)4. 创建主应用类5. 自定

Oracle的to_date()函数详解

《Oracle的to_date()函数详解》Oracle的to_date()函数用于日期格式转换,需要注意Oracle中不区分大小写的MM和mm格式代码,应使用mi代替分钟,此外,Oracle还支持毫... 目录oracle的to_date()函数一.在使用Oracle的to_date函数来做日期转换二.日

最新版IDEA配置 Tomcat的详细过程

《最新版IDEA配置Tomcat的详细过程》本文介绍如何在IDEA中配置Tomcat服务器,并创建Web项目,首先检查Tomcat是否安装完成,然后在IDEA中创建Web项目并添加Web结构,接着,... 目录配置tomcat第一步,先给项目添加Web结构查看端口号配置tomcat    先检查自己的to

使用Nginx来共享文件的详细教程

《使用Nginx来共享文件的详细教程》有时我们想共享电脑上的某些文件,一个比较方便的做法是,开一个HTTP服务,指向文件所在的目录,这次我们用nginx来实现这个需求,本文将通过代码示例一步步教你使用... 在本教程中,我们将向您展示如何使用开源 Web 服务器 Nginx 设置文件共享服务器步骤 0 —

SpringBoot集成SOL链的详细过程

《SpringBoot集成SOL链的详细过程》Solanaj是一个用于与Solana区块链交互的Java库,它为Java开发者提供了一套功能丰富的API,使得在Java环境中可以轻松构建与Solana... 目录一、什么是solanaj?二、Pom依赖三、主要类3.1 RpcClient3.2 Public

手把手教你idea中创建一个javaweb(webapp)项目详细图文教程

《手把手教你idea中创建一个javaweb(webapp)项目详细图文教程》:本文主要介绍如何使用IntelliJIDEA创建一个Maven项目,并配置Tomcat服务器进行运行,过程包括创建... 1.启动idea2.创建项目模板点击项目-新建项目-选择maven,显示如下页面输入项目名称,选择

Python基于火山引擎豆包大模型搭建QQ机器人详细教程(2024年最新)

《Python基于火山引擎豆包大模型搭建QQ机器人详细教程(2024年最新)》:本文主要介绍Python基于火山引擎豆包大模型搭建QQ机器人详细的相关资料,包括开通模型、配置APIKEY鉴权和SD... 目录豆包大模型概述开通模型付费安装 SDK 环境配置 API KEY 鉴权Ark 模型接口Prompt