【代码解读】LLGC

2024-09-05 19:20
文章标签 代码 解读 llgc

本文主要是介绍【代码解读】LLGC,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

对象创建:

model = LLGC(description.size(1), label.max().item()+1, args.drop_out, args.use_bias).to(device)

模型使用:

output = model(train_features)

LLGC:

# Lorentzian MODEL
class LLGC(nn.Module):def __init__(self, nfeat, nclass, drop_out, use_bias):super(LLGC, self).__init__()self.drop_out = drop_outself.use_bias = use_biasself.nclass = nclassself.c = torch.tensor([1.0]).to("cuda")self.manifold = getattr(manifolds_LLGC, "Lorentzian")()#创建了manifolds中的一个lorentzian类的对象,赋值给self.manifoldself.W = LorentzLinear(self.manifold, nfeat, nclass, self.c, self.drop_out, self.use_bias)def forward(self, x, batch_size):x_loren = self.manifold.normalize_input(x, self.c)#normalize_input操作内部带有对数映射,self.c为曲率。x_loren为对数映射后的结果x_w = self.W(x_loren)x_tan = self.manifold.log_map_zero(x_w, self.c)return x_tan[:batch_size]
  1. 欧式空间中的点到流形的映射
  2. 计算
  3. 流形映射到欧式空间

1. 欧式空间中的点到流形的映射

目标是使用self.manifold.normalize_input将欧式空间的目标特征x映射到流形上,返回x_loren

x_loren = self.manifold.normalize_input(x, self.c)

创建一个全零张量。
将全零张量和输入张量拼接,增加一个额外的维度。
调用exp_map_zero

    def normalize_input(self, x, c):# print('=====normalize original input===========')num_nodes = x.size(0)zeros = torch.zeros(num_nodes, 1, dtype=x.dtype, device=x.device)x_tan = torch.cat((zeros, x), dim=1)return self.exp_map_zero(x_tan, c)

创建流形上的基点

    def exp_map_zero(self, dp, c, is_res_normalize=True, is_dp_normalize=True):zeros = torch.zeros_like(dp)zeros[:, 0] = c ** 0.5return self.exp_map_x(zeros, dp, c, is_res_normalize, is_dp_normalize)

exp_map_x 方法通过指数映射将一个点 p 和切向量 dp 映射回洛伦兹流形上。
首先规范化切向量 dp。
然后计算其洛伦兹范数。
接着通过指数映射公式将其映射到流形上,并可选地对结果进行规范化。

    def exp_map_x(self, p, dp, c, is_res_normalize=True, is_dp_normalize=True):if is_dp_normalize:dp = self.normalize_tangent(p, dp, c)dp_lnorm = self.l_inner(dp, dp, keep_dim=True)dp_lnorm = torch.sqrt(torch.clamp(dp_lnorm + self.eps[p.dtype], 1e-6))dp_lnorm_cut = torch.clamp(dp_lnorm, max=50)sqrt_c = c ** 0.5res = (torch.cosh(dp_lnorm_cut / sqrt_c) * p) + sqrt_c * (torch.sinh(dp_lnorm_cut / sqrt_c) * dp / dp_lnorm)if is_res_normalize:res = self.normalize(res, c)return res

normalize_tangent 方法的目的是规范化洛伦兹流形上切向量,使其满足洛伦兹内积 <p, p_tan>_L = 0

    def normalize_tangent(self, p, p_tan, c):"""Normalize tangent vectors to place the vectors satisfies <p, p_tan>_L=0:param p: the tangent spaces at p. size:[nodes, feature]:param p_tan: the tangent vector in tangent space at p"""d = p_tan.size(1) - 1p_tail = p.narrow(1, 1, d)p_tan_tail = p_tan.narrow(1, 1, d)ptpt = torch.sum(p_tail * p_tan_tail, dim=1, keepdim=True)p_head = torch.sqrt(c + torch.sum(torch.pow(p_tail, 2), dim=1, keepdim=True) + self.eps[p.dtype])return torch.cat((ptpt / p_head, p_tan_tail), dim=1)

计算内积

    def l_inner(self, x, y, keep_dim=False):# input shape [node, features]d = x.size(-1) - 1xy = x * yxy = torch.cat((-xy.narrow(1, 0, 1), xy.narrow(1, 1, d)), dim=1)return torch.sum(xy, dim=1, keepdim=keep_dim)

目的是将一个向量 p 规范化,以确保它位于双曲面上。
这个过程可以理解为确保该向量符合双曲空间的几何结构。

    def normalize(self, p, c):"""Normalize vector to confirm it is located on the hyperboloid:param p: [nodes, features(d + 1)]:param c: parameter of curvature"""d = p.size(-1) - 1narrowed = p.narrow(-1, 1, d)if self.max_norm:narrowed = torch.renorm(narrowed.view(-1, d), 2, 0, self.max_norm)first = c + torch.sum(torch.pow(narrowed, 2), dim=-1, keepdim=True)first = torch.sqrt(first)return torch.cat((first, narrowed), dim=1)

2. 计算

x_w = self.W(x_loren)

其中LorentzLinear的类定义如下:

class LorentzLinear(nn.Module):# Lorentz Hyperbolic Graph Neural Layerdef __init__(self, manifold, in_features, out_features, c, drop_out, use_bias):super(LorentzLinear, self).__init__()# print("LorentzLinear")self.manifold = manifoldself.in_features = in_featuresself.out_features = out_featuresself.c = cself.drop_out = drop_outself.use_bias = use_biasself.bias = nn.Parameter(torch.Tensor(out_features-1))   # -1 when use mine mat-vec multiplyself.weight = nn.Parameter(torch.Tensor(out_features - 1, in_features))  # -1, 0 when use mine mat-vec multiplyself.reset_parameters()def reset_parameters(self):init.xavier_uniform_(self.weight, gain=math.sqrt(2))init.constant_(self.bias, 0)def forward(self, x):drop_weight = F.dropout(self.weight, self.drop_out, training=self.training)mv = self.manifold.matvec_regular(drop_weight, x, self.bias, self.c, self.use_bias)return mv

dropout的输入可以是特征,也可以是权值矩阵。
总归返回的是,以概率p随机给元素置零之后的输入。
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述


对输入执行对数映射。
分割映射结果。
矩阵乘法(将 x_tail 和权重矩阵 m 进行矩阵乘法。注意,这里对 m 执行了转置操作,以确保维度匹配)。
拼接结果,恢复到原有的维度。
首先执行 normalize_tangent_zero,将数据归一化到洛伦兹流形的切空间。再通过 执行指数映射,将数据映射回洛伦兹流形。
检查 mx 中的元素是否为零,用零替换掉 mx 中满足某个条件的部分。

    def matvec_regular(self, m, x, b, c, use_bias):d = x.size(1) - 1x_tan = self.log_map_zero(x, c)x_head = x_tan.narrow(1, 0, 1)x_tail = x_tan.narrow(1, 1, d)mx = x_tail @ m.transpose(-1, -2)if use_bias:mx_b = mx + belse:mx_b = mxmx = torch.cat((x_head, mx_b), dim=1)mx = self.normalize_tangent_zero(mx, c)mx = self.exp_map_zero(mx, c)cond = (mx==0).prod(-1, keepdim=True, dtype=torch.uint8)res = torch.zeros(1, dtype=mx.dtype, device=mx.device)res = torch.where(cond, res, mx)return res
    def log_map_zero(self, y, c, is_tan_normalize=True):zeros = torch.zeros_like(y)zeros[:, 0] = c ** 0.5return self.log_map_x(zeros, y, c, is_tan_normalize)

对数映射的作用是将洛伦兹流形上的点投影到某个点 x 的切空间中(即欧几里得空间)。
通过内积调整 y,得到一个新的向量 tmp_vector。
计算 tmp_vector 的范数。
计算切向量 y_tan。
如果 is_tan_normalize 为真,则对计算得到的切向量 y_tan 进行归一化处理,确保它满足洛伦兹切空间的约束。

    def log_map_x(self, x, y, c, is_tan_normalize=True):"""Logarithmic map at x: project hyperboloid vectors to a tangent space at x:param x: vector on hyperboloid:param y: vector to project a tangent space at x:param normalize: whether normalize the y_tangent:return: y_tangent"""xy_distance = self.induced_distance(x, y, c)tmp_vector = y + self.l_inner(x, y, keep_dim=True) / c * xtmp_norm = torch.sqrt(self.l_inner(tmp_vector, tmp_vector) + self.eps[x.dtype])y_tan = xy_distance.unsqueeze(-1) / tmp_norm.unsqueeze(-1) * tmp_vectorif is_tan_normalize:y_tan = self.normalize_tangent(x, y_tan, c)return y_tan

这里通过 induced_distance 方法计算向量 x 和 y 在洛伦兹流形上的距离,这实际上是两点在洛伦兹空间的测地线距离。

    def induced_distance(self, x, y, c):xy_inner = self.l_inner(x, y)sqrt_c = c ** 0.5return sqrt_c * arcosh(-xy_inner / c + self.eps[x.dtype])

3. 流形映射到欧式空间

x_tan = self.manifold.log_map_zero(x_w, self.c)

这篇关于【代码解读】LLGC的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1139842

相关文章

解读GC日志中的各项指标用法

《解读GC日志中的各项指标用法》:本文主要介绍GC日志中的各项指标用法,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、基础 GC 日志格式(以 G1 为例)1. Minor GC 日志2. Full GC 日志二、关键指标解析1. GC 类型与触发原因2. 堆

Java设计模式---迭代器模式(Iterator)解读

《Java设计模式---迭代器模式(Iterator)解读》:本文主要介绍Java设计模式---迭代器模式(Iterator),具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,... 目录1、迭代器(Iterator)1.1、结构1.2、常用方法1.3、本质1、解耦集合与遍历逻辑2、统一

Java中调用数据库存储过程的示例代码

《Java中调用数据库存储过程的示例代码》本文介绍Java通过JDBC调用数据库存储过程的方法,涵盖参数类型、执行步骤及数据库差异,需注意异常处理与资源管理,以优化性能并实现复杂业务逻辑,感兴趣的朋友... 目录一、存储过程概述二、Java调用存储过程的基本javascript步骤三、Java调用存储过程示

Visual Studio 2022 编译C++20代码的图文步骤

《VisualStudio2022编译C++20代码的图文步骤》在VisualStudio中启用C++20import功能,需设置语言标准为ISOC++20,开启扫描源查找模块依赖及实验性标... 默认创建Visual Studio桌面控制台项目代码包含C++20的import方法。右键项目的属性:

MySQL之InnoDB存储页的独立表空间解读

《MySQL之InnoDB存储页的独立表空间解读》:本文主要介绍MySQL之InnoDB存储页的独立表空间,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录1、背景2、独立表空间【1】表空间大小【2】区【3】组【4】段【5】区的类型【6】XDES Entry区结构【

MySQL数据库的内嵌函数和联合查询实例代码

《MySQL数据库的内嵌函数和联合查询实例代码》联合查询是一种将多个查询结果组合在一起的方法,通常使用UNION、UNIONALL、INTERSECT和EXCEPT关键字,下面:本文主要介绍MyS... 目录一.数据库的内嵌函数1.1聚合函数COUNT([DISTINCT] expr)SUM([DISTIN

Java实现自定义table宽高的示例代码

《Java实现自定义table宽高的示例代码》在桌面应用、管理系统乃至报表工具中,表格(JTable)作为最常用的数据展示组件,不仅承载对数据的增删改查,还需要配合布局与视觉需求,而JavaSwing... 目录一、项目背景详细介绍二、项目需求详细介绍三、相关技术详细介绍四、实现思路详细介绍五、完整实现代码

Go语言代码格式化的技巧分享

《Go语言代码格式化的技巧分享》在Go语言的开发过程中,代码格式化是一个看似细微却至关重要的环节,良好的代码格式化不仅能提升代码的可读性,还能促进团队协作,减少因代码风格差异引发的问题,Go在代码格式... 目录一、Go 语言代码格式化的重要性二、Go 语言代码格式化工具:gofmt 与 go fmt(一)

MySQL主从复制与读写分离的用法解读

《MySQL主从复制与读写分离的用法解读》:本文主要介绍MySQL主从复制与读写分离的用法,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、主从复制mysql主从复制原理实验案例二、读写分离实验案例安装并配置mycat 软件设置mycat读写分离验证mycat读

Python的端到端测试框架SeleniumBase使用解读

《Python的端到端测试框架SeleniumBase使用解读》:本文主要介绍Python的端到端测试框架SeleniumBase使用,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全... 目录SeleniumBase详细介绍及用法指南什么是 SeleniumBase?SeleniumBase