详细讲一下PYG 里面的torch_geometric.nn.conv.transformer_conv函数

2024-05-09 04:44

本文主要是介绍详细讲一下PYG 里面的torch_geometric.nn.conv.transformer_conv函数,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1.首先先讲一下代码

这是官方给的代码:torch_geometric.nn.conv.transformer_conv — pytorch_geometric documentation

import math
import typing
from typing import Optional, Tuple, Unionimport torch
import torch.nn.functional as F
from torch import Tensorfrom torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.typing import (Adj,NoneType,OptTensor,PairTensor,SparseTensor,
)
from torch_geometric.utils import softmaxif typing.TYPE_CHECKING:from typing import overload
else:from torch.jit import _overload_method as overload[docs]class TransformerConv(MessagePassing):r"""The graph transformer operator from the `"Masked Label Prediction:Unified Message Passing Model for Semi-Supervised Classification"<https://arxiv.org/abs/2009.03509>`_ paper... math::\mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i +\sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \mathbf{W}_2 \mathbf{x}_{j},where the attention coefficients :math:`\alpha_{i,j}` are computed viamulti-head dot product attention:.. math::\alpha_{i,j} = \textrm{softmax} \left(\frac{(\mathbf{W}_3\mathbf{x}_i)^{\top} (\mathbf{W}_4\mathbf{x}_j)}{\sqrt{d}} \right)Args:in_channels (int or tuple): Size of each input sample, or :obj:`-1` toderive the size from the first input(s) to the forward method.A tuple corresponds to the sizes of source and targetdimensionalities.out_channels (int): Size of each output sample.heads (int, optional): Number of multi-head-attentions.(default: :obj:`1`)concat (bool, optional): If set to :obj:`False`, the multi-headattentions are averaged instead of concatenated.(default: :obj:`True`)beta (bool, optional): If set, will combine aggregation andskip information via.. math::\mathbf{x}^{\prime}_i = \beta_i \mathbf{W}_1 \mathbf{x}_i +(1 - \beta_i) \underbrace{\left(\sum_{j \in \mathcal{N}(i)}\alpha_{i,j} \mathbf{W}_2 \vec{x}_j \right)}_{=\mathbf{m}_i}with :math:`\beta_i = \textrm{sigmoid}(\mathbf{w}_5^{\top}[ \mathbf{W}_1 \mathbf{x}_i, \mathbf{m}_i, \mathbf{W}_1\mathbf{x}_i - \mathbf{m}_i ])` (default: :obj:`False`)dropout (float, optional): Dropout probability of the normalizedattention coefficients which exposes each node to a stochasticallysampled neighborhood during training. (default: :obj:`0`)edge_dim (int, optional): Edge feature dimensionality (in casethere are any). Edge features are added to the keys afterlinear transformation, that is, prior to computing theattention dot product. They are also added to final valuesafter the same linear transformation. The model is:.. math::\mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i +\sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \left(\mathbf{W}_2 \mathbf{x}_{j} + \mathbf{W}_6 \mathbf{e}_{ij}\right),where the attention coefficients :math:`\alpha_{i,j}` are nowcomputed via:.. math::\alpha_{i,j} = \textrm{softmax} \left(\frac{(\mathbf{W}_3\mathbf{x}_i)^{\top}(\mathbf{W}_4\mathbf{x}_j + \mathbf{W}_6 \mathbf{e}_{ij})}{\sqrt{d}} \right)(default :obj:`None`)bias (bool, optional): If set to :obj:`False`, the layer will not learnan additive bias. (default: :obj:`True`)root_weight (bool, optional): If set to :obj:`False`, the layer willnot add the transformed root node features to the output and theoption  :attr:`beta` is set to :obj:`False`. (default: :obj:`True`)**kwargs (optional): Additional arguments of:class:`torch_geometric.nn.conv.MessagePassing`."""_alpha: OptTensordef __init__(self,in_channels: Union[int, Tuple[int, int]],out_channels: int,heads: int = 1,concat: bool = True,beta: bool = False,dropout: float = 0.,edge_dim: Optional[int] = None,bias: bool = True,root_weight: bool = True,**kwargs,):kwargs.setdefault('aggr', 'add')super().__init__(node_dim=0, **kwargs)self.in_channels = in_channelsself.out_channels = out_channelsself.heads = headsself.beta = beta and root_weightself.root_weight = root_weightself.concat = concatself.dropout = dropoutself.edge_dim = edge_dimself._alpha = Noneif isinstance(in_channels, int):in_channels = (in_channels, in_channels)self.lin_key = Linear(in_channels[0], heads * out_channels)self.lin_query = Linear(in_channels[1], heads * out_channels)self.lin_value = Linear(in_channels[0], heads * out_channels)if edge_dim is not None:self.lin_edge = Linear(edge_dim, heads * out_channels, bias=False)else:self.lin_edge = self.register_parameter('lin_edge', None)if concat:self.lin_skip = Linear(in_channels[1], heads * out_channels,bias=bias)if self.beta:self.lin_beta = Linear(3 * heads * out_channels, 1, bias=False)else:self.lin_beta = self.register_parameter('lin_beta', None)else:self.lin_skip = Linear(in_channels[1], out_channels, bias=bias)if self.beta:self.lin_beta = Linear(3 * out_channels, 1, bias=False)else:self.lin_beta = self.register_parameter('lin_beta', None)self.reset_parameters()[docs]    def reset_parameters(self):super().reset_parameters()self.lin_key.reset_parameters()self.lin_query.reset_parameters()self.lin_value.reset_parameters()if self.edge_dim:self.lin_edge.reset_parameters()self.lin_skip.reset_parameters()if self.beta:self.lin_beta.reset_parameters()@overloaddef forward(self,x: Union[Tensor, PairTensor],edge_index: Adj,edge_attr: OptTensor = None,return_attention_weights: NoneType = None,) -> Tensor:pass@overloaddef forward(  # noqa: F811self,x: Union[Tensor, PairTensor],edge_index: Tensor,edge_attr: OptTensor = None,return_attention_weights: bool = None,) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:pass@overloaddef forward(  # noqa: F811self,x: Union[Tensor, PairTensor],edge_index: SparseTensor,edge_attr: OptTensor = None,return_attention_weights: bool = None,) -> Tuple[Tensor, SparseTensor]:pass[docs]    def forward(  # noqa: F811self,x: Union[Tensor, PairTensor],edge_index: Adj,edge_attr: OptTensor = None,return_attention_weights: Optional[bool] = None,) -> Union[Tensor,Tuple[Tensor, Tuple[Tensor, Tensor]],Tuple[Tensor, SparseTensor],]:r"""Runs the forward pass of the module.Args:x (torch.Tensor or (torch.Tensor, torch.Tensor)): The input nodefeatures.edge_index (torch.Tensor or SparseTensor): The edge indices.edge_attr (torch.Tensor, optional): The edge features.(default: :obj:`None`)return_attention_weights (bool, optional): If set to :obj:`True`,will additionally return the tuple:obj:`(edge_index, attention_weights)`, holding the computedattention weights for each edge. (default: :obj:`None`)"""H, C = self.heads, self.out_channelsif isinstance(x, Tensor):x = (x, x)query = self.lin_query(x[1]).view(-1, H, C)key = self.lin_key(x[0]).view(-1, H, C)value = self.lin_value(x[0]).view(-1, H, C)# propagate_type: (query: Tensor, key:Tensor, value: Tensor,#                  edge_attr: OptTensor)out = self.propagate(edge_index, query=query, key=key, value=value,edge_attr=edge_attr)alpha = self._alphaself._alpha = Noneif self.concat:out = out.view(-1, self.heads * self.out_channels)else:out = out.mean(dim=1)if self.root_weight:x_r = self.lin_skip(x[1])if self.lin_beta is not None:beta = self.lin_beta(torch.cat([out, x_r, out - x_r], dim=-1))beta = beta.sigmoid()out = beta * x_r + (1 - beta) * outelse:out = out + x_rif isinstance(return_attention_weights, bool):assert alpha is not Noneif isinstance(edge_index, Tensor):return out, (edge_index, alpha)elif isinstance(edge_index, SparseTensor):return out, edge_index.set_value(alpha, layout='coo')else:return outdef message(self, query_i: Tensor, key_j: Tensor, value_j: Tensor,edge_attr: OptTensor, index: Tensor, ptr: OptTensor,size_i: Optional[int]) -> Tensor:if self.lin_edge is not None:assert edge_attr is not Noneedge_attr = self.lin_edge(edge_attr).view(-1, self.heads,self.out_channels)key_j = key_j + edge_attralpha = (query_i * key_j).sum(dim=-1) / math.sqrt(self.out_channels)alpha = softmax(alpha, index, ptr, size_i)self._alpha = alphaalpha = F.dropout(alpha, p=self.dropout, training=self.training)out = value_jif edge_attr is not None:out = out + edge_attrout = out * alpha.view(-1, self.heads, 1)return outdef __repr__(self) -> str:return (f'{self.__class__.__name__}({self.in_channels}, 'f'{self.out_channels}, heads={self.heads})')

2.详细解释一下

几个重要的参数

in_channels (int or tuple): Size of each input sample, or :obj:`-1` to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities.

out_channels (int): Size of each output sample.

heads (int, optional): Number of multi-head-attentions. (default: :obj:`1`)

怎么理解这几个参数?

 

  • in_channels 表示每个输入样本的大小。如果设置为整数,则表示所有输入样本的大小相同;如果设置为 -1,则表示输入样本的大小将从 forward 方法的第一个输入中推导出来;如果设置为元组,则表示输入样本的大小对应于源维度和目标维度的大小。

  • out_channels 表示每个输出样本的大小,即经过卷积操作后产生的特征向量的维度大小。

 

当使用 tg.nn.TransformerConv 时,可以通过以下方式理解 in_channelsout_channels

假设我们有一个图数据集,每个节点都有一个 10 维的特征向量表示。那么在这种情况下:

  • 如果我们想将每个节点的特征向量作为输入,然后使用 tg.nn.TransformerConv 进行卷积操作,那么 in_channels 应该设置为 10,表示每个输入样本的大小为 10。

  • 假设我们想将节点的特征向量转换为一个 16 维的特征向量,那么 out_channels 应该设置为 16,表示每个输出样本的大小为 16,即经过卷积操作后每个节点的特征向量将变为 16 维。

  • tg.nn.TransformerConv 中,heads 参数表示多头注意力的数量。举个例子,如果 heads 参数设置为 4,那么模型将学习 4 组注意力权重,每组权重都用于计算输入的不同子空间的注意力,然后将这些头的输出进行合并以产生最终的输出。

 举个整体的例子

我们有一个输入张量 x,它的形状是 (batch_size, seq_length, input_dim),其中:

  • batch_size 表示批量大小;
  • seq_length 表示序列长度;
  • input_dim 表示输入特征的维度。

现在假设我们使用了 tg.nn.TransformerConv,并设置 heads=2,那么模型将学习两组注意力权重,每组用于计算不同的注意力。输出张量的形状将取决于 out_channels 参数,我们假设 out_channels=64

import torch
import torch_geometric.nn as tg# 假设输入张量的形状是 (batch_size, seq_length, input_dim)
x = torch.randn(32, 10, 128)  # 32 个样本,每个样本有 10 个时间步,每个时间步有 128 个特征# 创建 TransformerConv 模型,设置 heads=2,out_channels=64
conv_layer = tg.nn.TransformerConv(in_channels=128, out_channels=64, heads=2)# 使用模型进行前向传播
output = conv_layer(x)print("输出张量的形状:", output.shape)

 2.1将特征映射到键值对中

在这里,通过线性变换层 Linear,输入特征被转换成了键(key)、查询(query)和数值(value)的表示形式,以便用于多头自注意力机制。

具体来说:

  • self.lin_key 用于将输入特征(in_channels[0])映射到键的表示形式。
  • self.lin_query 用于将输入特征(in_channels[1])映射到查询的表示形式。
  • self.lin_value 用于将输入特征(in_channels[0])映射到数值的表示形式。

 具体地,假设输入特征的维度是 (batch_size, num_nodes, in_channels),其中 batch_size 是批量大小,num_nodes 是节点数,in_channels 是输入特征的通道数。在映射到键的过程中,线性变换层的权重矩阵将是一个维度为 (in_channels, heads * out_channels) 的矩阵,其中 heads 是注意力头的数量,out_channels 是输出特征的通道数。因此,通过矩阵乘法运算,输入特征将被映射到一个新的特征空间,其维度为 (batch_size, num_nodes, heads, out_channels)。在这个新的特征空间中,每个节点的每个头都有一个键表示。

这篇关于详细讲一下PYG 里面的torch_geometric.nn.conv.transformer_conv函数的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/972422

相关文章

Java调用DeepSeek API的最佳实践及详细代码示例

《Java调用DeepSeekAPI的最佳实践及详细代码示例》:本文主要介绍如何使用Java调用DeepSeekAPI,包括获取API密钥、添加HTTP客户端依赖、创建HTTP请求、处理响应、... 目录1. 获取API密钥2. 添加HTTP客户端依赖3. 创建HTTP请求4. 处理响应5. 错误处理6.

Spring AI集成DeepSeek的详细步骤

《SpringAI集成DeepSeek的详细步骤》DeepSeek作为一款卓越的国产AI模型,越来越多的公司考虑在自己的应用中集成,对于Java应用来说,我们可以借助SpringAI集成DeepSe... 目录DeepSeek 介绍Spring AI 是什么?1、环境准备2、构建项目2.1、pom依赖2.2

Goland debug失效详细解决步骤(合集)

《Golanddebug失效详细解决步骤(合集)》今天用Goland开发时,打断点,以debug方式运行,发现程序并没有断住,程序跳过了断点,直接运行结束,网上搜寻了大量文章,最后得以解决,特此在这... 目录Bug:Goland debug失效详细解决步骤【合集】情况一:Go或Goland架构不对情况二:

Python itertools中accumulate函数用法及使用运用详细讲解

《Pythonitertools中accumulate函数用法及使用运用详细讲解》:本文主要介绍Python的itertools库中的accumulate函数,该函数可以计算累积和或通过指定函数... 目录1.1前言:1.2定义:1.3衍生用法:1.3Leetcode的实际运用:总结 1.1前言:本文将详

Deepseek R1模型本地化部署+API接口调用详细教程(释放AI生产力)

《DeepseekR1模型本地化部署+API接口调用详细教程(释放AI生产力)》本文介绍了本地部署DeepSeekR1模型和通过API调用将其集成到VSCode中的过程,作者详细步骤展示了如何下载和... 目录前言一、deepseek R1模型与chatGPT o1系列模型对比二、本地部署步骤1.安装oll

Android里面的Service种类以及启动方式

《Android里面的Service种类以及启动方式》Android中的Service分为前台服务和后台服务,前台服务需要亮身份牌并显示通知,后台服务则有启动方式选择,包括startService和b... 目录一句话总结:一、Service 的两种类型:1. 前台服务(必须亮身份牌)2. 后台服务(偷偷干

Spring Boot整合log4j2日志配置的详细教程

《SpringBoot整合log4j2日志配置的详细教程》:本文主要介绍SpringBoot项目中整合Log4j2日志框架的步骤和配置,包括常用日志框架的比较、配置参数介绍、Log4j2配置详解... 目录前言一、常用日志框架二、配置参数介绍1. 日志级别2. 输出形式3. 日志格式3.1 PatternL

Springboot 中使用Sentinel的详细步骤

《Springboot中使用Sentinel的详细步骤》文章介绍了如何在SpringBoot中使用Sentinel进行限流和熔断降级,首先添加依赖,配置Sentinel控制台地址,定义受保护的资源,... 目录步骤 1: 添加 Sentinel 依赖步骤 2: 配置 Sentinel步骤 3: 定义受保护的

轻松上手MYSQL之JSON函数实现高效数据查询与操作

《轻松上手MYSQL之JSON函数实现高效数据查询与操作》:本文主要介绍轻松上手MYSQL之JSON函数实现高效数据查询与操作的相关资料,MySQL提供了多个JSON函数,用于处理和查询JSON数... 目录一、jsON_EXTRACT 提取指定数据二、JSON_UNQUOTE 取消双引号三、JSON_KE

MySQL数据库函数之JSON_EXTRACT示例代码

《MySQL数据库函数之JSON_EXTRACT示例代码》:本文主要介绍MySQL数据库函数之JSON_EXTRACT的相关资料,JSON_EXTRACT()函数用于从JSON文档中提取值,支持对... 目录前言基本语法路径表达式示例示例 1: 提取简单值示例 2: 提取嵌套值示例 3: 提取数组中的值注意