pytorch 计算点集内或矩阵内两两元素之间的距离矩阵

2024-02-07 14:08

本文主要是介绍pytorch 计算点集内或矩阵内两两元素之间的距离矩阵,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

前言:时间紧可直接看4,5两条。


1. 这一功能在python内或numpy内有现成的工具包

from scipy.spatial import distance
# 以下两种方式视情况选择
scipy.spatial.distance.pdist()
scipy.spatial.distance.cdist()

在神经网络的训练过程中,应用以上工具包需要把torch.tensor转变成numpy格式再计算,存在两个缺点:一是耗时,格式变来变去,而且从GPU迁移到CPU再返回到GPU;二是会造成梯度丢失。

2. pytorch中自带pdist()函数,但是这个函数输出结果为距离向量,而不是距离矩阵。距离向量是距离矩阵中上三角的元素。

import torch
import torch.tensor as tensor
import torch.nn.functional as F
a = tensor([[1., 1., 1.],[2., 2., 2.],[3., 3., 3.],[4., 4., 4.]])  #建立tensord=F.pdist(a, p=2)
print(d)
"""
输出:tensor([1.7321, 3.4641, 5.1962, 1.7321, 3.4641, 1.7321])
"""

3. 自定义pdist()函数计算欧氏距离,如下所示,但是该函数只能用来计算欧式距离(L2范数),而且对角线上的元素不是0,而是一个极小的数1e-4。

import torch
import torch.tensor as tensor
"""
自定义的距离矩阵函数
"""
def pdists(A, squared = False, eps = 1e-8):prod = torch.mm(A, A.t())norm = prod.diag().unsqueeze(1).expand_as(prod)res = (norm + norm.t() - 2 * prod).clamp(min = 0)if squared:return reselse:res = res.clamp(min = eps).sqrt()return res"""应用示例"""
a = tensor([[1., 2., 3.],[4., 5., 6.],[7., 8., 9.],[10., 11., 12.]])c=pdists(a, squared = False)
print(c)
"""打印结果
tensor([[1.0000e-04, 5.1962e+00, 1.0392e+01, 1.5588e+01],[5.1962e+00, 1.0000e-04, 5.1962e+00, 1.0392e+01],[1.0392e+01, 5.1962e+00, 1.0000e-04, 5.1962e+00],[1.5588e+01, 1.0392e+01, 5.1962e+00, 1.0000e-04]])
"""

4. pytorch中的torch.norm(input[:, None] - input, dim=2, p=p)函数可以实现该功能

    在torch.nn.functional.pdist的文档介绍中有这么一句话:

 简单翻译:计算输入中每​​对行向量之间的p范数距离。 这与torch.norm(input[:, None] - input, dim=2, p=p)的对角线以外的上部三角形部分相同。 如果行是连续的,此功能将更快。

这句话暗示:torch.norm函数可用于计算距离矩阵,而且可以选择L1、L2范数或者其他范数。
应用示例:

import torch
import torch.tensor as tensor
a = tensor([[1., 1., 1.],[2., 2., 2.],[3., 3., 3.],[4., 4., 4.]])  #建立tensor
b=torch.norm(a[:, None]-a, dim=2, p=2)
print(b)
"""
tensor([[0.0000, 1.7321, 3.4641, 5.1962],[1.7321, 0.0000, 1.7321, 3.4641],[3.4641, 1.7321, 0.0000, 1.7321],[5.1962, 3.4641, 1.7321, 0.0000]])
"""

对应的,可以把torch.norm封装成新的pdist函数:

import torch
import torch.tensor as tensor
"""函数封装"""
def pdist(a,dim=2, p=2):dist_matrix = torch.norm(a[:, None]-a, dim, p)return dist_matrix 

5. 自定义余弦距离矩阵

import torch
def cosinematrix(A):prod = torch.mm(A, A.t())#分子norm = torch.norm(A,p=2,dim=1).unsqueeze(0)#分母cos = prod.div(torch.mm(norm.t(),norm))return cos# 使用
d_matrix=cosinematrix(inputs)

文章参考:pytorch不用for循环计算一个矩阵各行之间的L1 、L2范数距离和余弦距离_小鱼的代码世界-CSDN博客_pytorch计算距离矩阵

这篇关于pytorch 计算点集内或矩阵内两两元素之间的距离矩阵的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/687969

相关文章

使用C++实现链表元素的反转

《使用C++实现链表元素的反转》反转链表是链表操作中一个经典的问题,也是面试中常见的考题,本文将从思路到实现一步步地讲解如何实现链表的反转,帮助初学者理解这一操作,我们将使用C++代码演示具体实现,同... 目录问题定义思路分析代码实现带头节点的链表代码讲解其他实现方式时间和空间复杂度分析总结问题定义给定

Python如何计算两个不同类型列表的相似度

《Python如何计算两个不同类型列表的相似度》在编程中,经常需要比较两个列表的相似度,尤其是当这两个列表包含不同类型的元素时,下面小编就来讲讲如何使用Python计算两个不同类型列表的相似度吧... 目录摘要引言数字类型相似度欧几里得距离曼哈顿距离字符串类型相似度Levenshtein距离Jaccard相

java父子线程之间实现共享传递数据

《java父子线程之间实现共享传递数据》本文介绍了Java中父子线程间共享传递数据的几种方法,包括ThreadLocal变量、并发集合和内存队列或消息队列,并提醒注意并发安全问题... 目录通过 ThreadLocal 变量共享数据通过并发集合共享数据通过内存队列或消息队列共享数据注意并发安全问题总结在 J

Java文件与Base64之间的转化方式

《Java文件与Base64之间的转化方式》这篇文章介绍了如何使用Java将文件(如图片、视频)转换为Base64编码,以及如何将Base64编码转换回文件,通过提供具体的工具类实现,作者希望帮助读者... 目录Java文件与Base64之间的转化1、文件转Base64工具类2、Base64转文件工具类3、

CSS3中使用flex和grid实现等高元素布局的示例代码

《CSS3中使用flex和grid实现等高元素布局的示例代码》:本文主要介绍了使用CSS3中的Flexbox和Grid布局实现等高元素布局的方法,通过简单的两列实现、每行放置3列以及全部代码的展示,展示了这两种布局方式的实现细节和效果,详细内容请阅读本文,希望能对你有所帮助... 过往的实现方法是使用浮动加

在MyBatis的XML映射文件中<trim>元素所有场景下的完整使用示例代码

《在MyBatis的XML映射文件中<trim>元素所有场景下的完整使用示例代码》在MyBatis的XML映射文件中,trim元素用于动态添加SQL语句的一部分,处理前缀、后缀及多余的逗号或连接符,示... 在MyBATis的XML映射文件中,<trim>元素用于动态地添加SQL语句的一部分,例如SET或W

使用C#代码计算数学表达式实例

《使用C#代码计算数学表达式实例》这段文字主要讲述了如何使用C#语言来计算数学表达式,该程序通过使用Dictionary保存变量,定义了运算符优先级,并实现了EvaluateExpression方法来... 目录C#代码计算数学表达式该方法很长,因此我将分段描述下面的代码片段显示了下一步以下代码显示该方法如

PyTorch使用教程之Tensor包详解

《PyTorch使用教程之Tensor包详解》这篇文章介绍了PyTorch中的张量(Tensor)数据结构,包括张量的数据类型、初始化、常用操作、属性等,张量是PyTorch框架中的核心数据结构,支持... 目录1、张量Tensor2、数据类型3、初始化(构造张量)4、常用操作5、常用属性5.1 存储(st

如何用Java结合经纬度位置计算目标点的日出日落时间详解

《如何用Java结合经纬度位置计算目标点的日出日落时间详解》这篇文章主详细讲解了如何基于目标点的经纬度计算日出日落时间,提供了在线API和Java库两种计算方法,并通过实际案例展示了其应用,需要的朋友... 目录前言一、应用示例1、天安门升旗时间2、湖南省日出日落信息二、Java日出日落计算1、在线API2

day-51 合并零之间的节点

思路 直接遍历链表即可,遇到val=0跳过,val非零则加在一起,最后返回即可 解题过程 返回链表可以有头结点,方便插入,返回head.next Code /*** Definition for singly-linked list.* public class ListNode {* int val;* ListNode next;* ListNode() {}*