pytorch 计算点集内或矩阵内两两元素之间的距离矩阵

2024-02-07 14:08

本文主要是介绍pytorch 计算点集内或矩阵内两两元素之间的距离矩阵,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

前言:时间紧可直接看4,5两条。


1. 这一功能在python内或numpy内有现成的工具包

from scipy.spatial import distance
# 以下两种方式视情况选择
scipy.spatial.distance.pdist()
scipy.spatial.distance.cdist()

在神经网络的训练过程中,应用以上工具包需要把torch.tensor转变成numpy格式再计算,存在两个缺点:一是耗时,格式变来变去,而且从GPU迁移到CPU再返回到GPU;二是会造成梯度丢失。

2. pytorch中自带pdist()函数,但是这个函数输出结果为距离向量,而不是距离矩阵。距离向量是距离矩阵中上三角的元素。

import torch
import torch.tensor as tensor
import torch.nn.functional as F
a = tensor([[1., 1., 1.],[2., 2., 2.],[3., 3., 3.],[4., 4., 4.]])  #建立tensord=F.pdist(a, p=2)
print(d)
"""
输出:tensor([1.7321, 3.4641, 5.1962, 1.7321, 3.4641, 1.7321])
"""

3. 自定义pdist()函数计算欧氏距离,如下所示,但是该函数只能用来计算欧式距离(L2范数),而且对角线上的元素不是0,而是一个极小的数1e-4。

import torch
import torch.tensor as tensor
"""
自定义的距离矩阵函数
"""
def pdists(A, squared = False, eps = 1e-8):prod = torch.mm(A, A.t())norm = prod.diag().unsqueeze(1).expand_as(prod)res = (norm + norm.t() - 2 * prod).clamp(min = 0)if squared:return reselse:res = res.clamp(min = eps).sqrt()return res"""应用示例"""
a = tensor([[1., 2., 3.],[4., 5., 6.],[7., 8., 9.],[10., 11., 12.]])c=pdists(a, squared = False)
print(c)
"""打印结果
tensor([[1.0000e-04, 5.1962e+00, 1.0392e+01, 1.5588e+01],[5.1962e+00, 1.0000e-04, 5.1962e+00, 1.0392e+01],[1.0392e+01, 5.1962e+00, 1.0000e-04, 5.1962e+00],[1.5588e+01, 1.0392e+01, 5.1962e+00, 1.0000e-04]])
"""

4. pytorch中的torch.norm(input[:, None] - input, dim=2, p=p)函数可以实现该功能

    在torch.nn.functional.pdist的文档介绍中有这么一句话:

 简单翻译:计算输入中每​​对行向量之间的p范数距离。 这与torch.norm(input[:, None] - input, dim=2, p=p)的对角线以外的上部三角形部分相同。 如果行是连续的,此功能将更快。

这句话暗示:torch.norm函数可用于计算距离矩阵,而且可以选择L1、L2范数或者其他范数。
应用示例:

import torch
import torch.tensor as tensor
a = tensor([[1., 1., 1.],[2., 2., 2.],[3., 3., 3.],[4., 4., 4.]])  #建立tensor
b=torch.norm(a[:, None]-a, dim=2, p=2)
print(b)
"""
tensor([[0.0000, 1.7321, 3.4641, 5.1962],[1.7321, 0.0000, 1.7321, 3.4641],[3.4641, 1.7321, 0.0000, 1.7321],[5.1962, 3.4641, 1.7321, 0.0000]])
"""

对应的,可以把torch.norm封装成新的pdist函数:

import torch
import torch.tensor as tensor
"""函数封装"""
def pdist(a,dim=2, p=2):dist_matrix = torch.norm(a[:, None]-a, dim, p)return dist_matrix 

5. 自定义余弦距离矩阵

import torch
def cosinematrix(A):prod = torch.mm(A, A.t())#分子norm = torch.norm(A,p=2,dim=1).unsqueeze(0)#分母cos = prod.div(torch.mm(norm.t(),norm))return cos# 使用
d_matrix=cosinematrix(inputs)

文章参考:pytorch不用for循环计算一个矩阵各行之间的L1 、L2范数距离和余弦距离_小鱼的代码世界-CSDN博客_pytorch计算距离矩阵

这篇关于pytorch 计算点集内或矩阵内两两元素之间的距离矩阵的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/687969

相关文章

如何用Java结合经纬度位置计算目标点的日出日落时间详解

《如何用Java结合经纬度位置计算目标点的日出日落时间详解》这篇文章主详细讲解了如何基于目标点的经纬度计算日出日落时间,提供了在线API和Java库两种计算方法,并通过实际案例展示了其应用,需要的朋友... 目录前言一、应用示例1、天安门升旗时间2、湖南省日出日落信息二、Java日出日落计算1、在线API2

day-51 合并零之间的节点

思路 直接遍历链表即可,遇到val=0跳过,val非零则加在一起,最后返回即可 解题过程 返回链表可以有头结点,方便插入,返回head.next Code /*** Definition for singly-linked list.* public class ListNode {* int val;* ListNode next;* ListNode() {}*

poj 1113 凸包+简单几何计算

题意: 给N个平面上的点,现在要在离点外L米处建城墙,使得城墙把所有点都包含进去且城墙的长度最短。 解析: 韬哥出的某次训练赛上A出的第一道计算几何,算是大水题吧。 用convexhull算法把凸包求出来,然后加加减减就A了。 计算见下图: 好久没玩画图了啊好开心。 代码: #include <iostream>#include <cstdio>#inclu

uva 1342 欧拉定理(计算几何模板)

题意: 给几个点,把这几个点用直线连起来,求这些直线把平面分成了几个。 解析: 欧拉定理: 顶点数 + 面数 - 边数= 2。 代码: #include <iostream>#include <cstdio>#include <cstdlib>#include <algorithm>#include <cstring>#include <cmath>#inc

uva 11178 计算集合模板题

题意: 求三角形行三个角三等分点射线交出的内三角形坐标。 代码: #include <iostream>#include <cstdio>#include <cstdlib>#include <algorithm>#include <cstring>#include <cmath>#include <stack>#include <vector>#include <

XTU 1237 计算几何

题面: Magic Triangle Problem Description: Huangriq is a respectful acmer in ACM team of XTU because he brought the best place in regional contest in history of XTU. Huangriq works in a big compa

hdu 4565 推倒公式+矩阵快速幂

题意 求下式的值: Sn=⌈ (a+b√)n⌉%m S_n = \lceil\ (a + \sqrt{b}) ^ n \rceil\% m 其中: 0<a,m<215 0< a, m < 2^{15} 0<b,n<231 0 < b, n < 2^{31} (a−1)2<b<a2 (a-1)^2< b < a^2 解析 令: An=(a+b√)n A_n = (a +

【每日一题】LeetCode 2181.合并零之间的节点(链表、模拟)

【每日一题】LeetCode 2181.合并零之间的节点(链表、模拟) 题目描述 给定一个链表,链表中的每个节点代表一个整数。链表中的整数由 0 分隔开,表示不同的区间。链表的开始和结束节点的值都为 0。任务是将每两个相邻的 0 之间的所有节点合并成一个节点,新节点的值为原区间内所有节点值的和。合并后,需要移除所有的 0,并返回修改后的链表头节点。 思路分析 初始化:创建一个虚拟头节点

hdu 6198 dfs枚举找规律+矩阵乘法

number number number Time Limit: 2000/1000 MS (Java/Others)    Memory Limit: 32768/32768 K (Java/Others) Problem Description We define a sequence  F : ⋅   F0=0,F1=1 ; ⋅   Fn=Fn

音视频入门基础:WAV专题(10)——FFmpeg源码中计算WAV音频文件每个packet的pts、dts的实现

一、引言 从文章《音视频入门基础:WAV专题(6)——通过FFprobe显示WAV音频文件每个数据包的信息》中我们可以知道,通过FFprobe命令可以打印WAV音频文件每个packet(也称为数据包或多媒体包)的信息,这些信息包含该packet的pts、dts: 打印出来的“pts”实际是AVPacket结构体中的成员变量pts,是以AVStream->time_base为单位的显