本文主要是介绍ModuleNotFoundError: No module named ‘torch_scatter‘,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
大概率是cuda的版本问题,看了很多解决方案,都不舒服。
直到看到这篇
新建一个名字叫torch_scatter.py的脚本,然后就可以调用本地的脚本了。
#torch_scatter.py
import torch
from typing import Optionaldef scatter_sum(src: torch.Tensor, index: torch.Tensor, dim: int = -1,out: Optional[torch.Tensor] = None,dim_size: Optional[int] = None) -> torch.Tensor:index = broadcast(index, src, dim)if out is None:size = list(src.size())if dim_size is not None:size[dim] = dim_sizeelif index.numel() == 0:size[dim] = 0else:size[dim] = int(index.max()) + 1out = torch.zeros(size, dtype=src.dtype, device=src.device)return out.scatter_add_(dim, index, src)else:return out.scatter_add_(dim, index, src)def scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1,out: Optional[torch.Tensor] = None,dim_size: Optional[int] = None) -> torch.Tensor:return scatter_sum(src, index, dim, out, dim_size)def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int):if dim < 0:dim = other.dim() + dimif src.dim() == 1:for _ in range(0, dim):src = src.unsqueeze(0)for _ in range(src.dim(), other.dim()):src = src.unsqueeze(-1)src = src.expand(other.size())return src
这篇关于ModuleNotFoundError: No module named ‘torch_scatter‘的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!