解决PyG 报错 from torch_geometric.nn.pool.topk_pool import topk, filter_adj

2023-11-27 10:15

本文主要是介绍解决PyG 报错 from torch_geometric.nn.pool.topk_pool import topk, filter_adj,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

问题:

使用Pytorch 的 PyG 搭建 图神经网络 报错

can not import topk, filter_adj from torch_geometric.nn.pool.topk_pool 

解决

版本问题 语法变化
topk => SelectTopk
filter_adj => FilterEdges

from torch_geometric.nn.pool.connect import FilterEdges
from torch_geometric.nn.pool.select import SelectTopK

发现替换后不可以
于是进去看SelectTopK\FilterEdges 源码
发现里面有 topk, filter_adj 方法 但是直接 import 也不能用
于是手动写函数出来再 layers.py 里即可运行

def topk(x: Tensor,ratio: Optional[Union[float, int]],batch: Tensor,min_score: Optional[float] = None,tol: float = 1e-7,
) -> Tensor:if min_score is not None:# Make sure that we do not drop all nodes in a graph.scores_max = scatter(x, batch, reduce='max')[batch] - tolscores_min = scores_max.clamp(max=min_score)perm = (x > scores_min).nonzero().view(-1)return permif ratio is not None:num_nodes = scatter(batch.new_ones(x.size(0)), batch, reduce='sum')if ratio >= 1:k = num_nodes.new_full((num_nodes.size(0),), int(ratio))else:k = (float(ratio) * num_nodes.to(x.dtype)).ceil().to(torch.long)x, x_perm = torch.sort(x.view(-1), descending=True)batch = batch[x_perm]batch, batch_perm = torch.sort(batch, descending=False, stable=True)arange = torch.arange(x.size(0), dtype=torch.long, device=x.device)ptr = cumsum(num_nodes)batched_arange = arange - ptr[batch]mask = batched_arange < k[batch]return x_perm[batch_perm[mask]]def filter_adj(edge_index: Tensor,edge_attr: Optional[Tensor],node_index: Tensor,cluster_index: Optional[Tensor] = None,num_nodes: Optional[int] = None,
) -> Tuple[Tensor, Optional[Tensor]]:num_nodes = maybe_num_nodes(edge_index, num_nodes)if cluster_index is None:cluster_index = torch.arange(node_index.size(0),device=node_index.device)mask = node_index.new_full((num_nodes,), -1)mask[node_index] = cluster_indexrow, col = edge_index[0], edge_index[1]row, col = mask[row], mask[col]mask = (row >= 0) & (col >= 0)row, col = row[mask], col[mask]if edge_attr is not None:edge_attr = edge_attr[mask]return torch.stack([row, col], dim=0), edge_attr

参考官方文档

https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/pool/topk_pool.html

这篇关于解决PyG 报错 from torch_geometric.nn.pool.topk_pool import topk, filter_adj的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/427496

相关文章

Spring事务中@Transactional注解不生效的原因分析与解决

《Spring事务中@Transactional注解不生效的原因分析与解决》在Spring框架中,@Transactional注解是管理数据库事务的核心方式,本文将深入分析事务自调用的底层原理,解释为... 目录1. 引言2. 事务自调用问题重现2.1 示例代码2.2 问题现象3. 为什么事务自调用会失效3

mysql出现ERROR 2003 (HY000): Can‘t connect to MySQL server on ‘localhost‘ (10061)的解决方法

《mysql出现ERROR2003(HY000):Can‘tconnecttoMySQLserveron‘localhost‘(10061)的解决方法》本文主要介绍了mysql出现... 目录前言:第一步:第二步:第三步:总结:前言:当你想通过命令窗口想打开mysql时候发现提http://www.cpp

SpringBoot启动报错的11个高频问题排查与解决终极指南

《SpringBoot启动报错的11个高频问题排查与解决终极指南》这篇文章主要为大家详细介绍了SpringBoot启动报错的11个高频问题的排查与解决,文中的示例代码讲解详细,感兴趣的小伙伴可以了解一... 目录1. 依赖冲突:NoSuchMethodError 的终极解法2. Bean注入失败:No qu

springboot报错Invalid bound statement (not found)的解决

《springboot报错Invalidboundstatement(notfound)的解决》本文主要介绍了springboot报错Invalidboundstatement(not... 目录一. 问题描述二.解决问题三. 添加配置项 四.其他的解决方案4.1 Mapper 接口与 XML 文件不匹配

Python中ModuleNotFoundError: No module named ‘timm’的错误解决

《Python中ModuleNotFoundError:Nomodulenamed‘timm’的错误解决》本文主要介绍了Python中ModuleNotFoundError:Nomodulen... 目录一、引言二、错误原因分析三、解决办法1.安装timm模块2. 检查python环境3. 解决安装路径问题

如何解决mysql出现Incorrect string value for column ‘表项‘ at row 1错误问题

《如何解决mysql出现Incorrectstringvalueforcolumn‘表项‘atrow1错误问题》:本文主要介绍如何解决mysql出现Incorrectstringv... 目录mysql出现Incorrect string value for column ‘表项‘ at row 1错误报错

如何解决Spring MVC中响应乱码问题

《如何解决SpringMVC中响应乱码问题》:本文主要介绍如何解决SpringMVC中响应乱码问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录Spring MVC最新响应中乱码解决方式以前的解决办法这是比较通用的一种方法总结Spring MVC最新响应中乱码解

Java报NoClassDefFoundError异常的原因及解决

《Java报NoClassDefFoundError异常的原因及解决》在Java开发过程中,java.lang.NoClassDefFoundError是一个令人头疼的运行时错误,本文将深入探讨这一问... 目录一、问题分析二、报错原因三、解决思路四、常见场景及原因五、深入解决思路六、预http://www

java常见报错及解决方案总结

《java常见报错及解决方案总结》:本文主要介绍Java编程中常见错误类型及示例,包括语法错误、空指针异常、数组下标越界、类型转换异常、文件未找到异常、除以零异常、非法线程操作异常、方法未定义异常... 目录1. 语法错误 (Syntax Errors)示例 1:解决方案:2. 空指针异常 (NullPoi

pip无法安装osgeo失败的问题解决

《pip无法安装osgeo失败的问题解决》本文主要介绍了pip无法安装osgeo失败的问题解决,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一... 进入官方提供的扩展包下载网站寻找版本适配的whl文件注意:要选择cp(python版本)和你py