Triplet Loss三元组损失函数

2023-12-22 05:01

本文主要是介绍Triplet Loss三元组损失函数,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

基础知识

三元组损失(Triplet Loss)是一种用于学习深度神经网络嵌入的损失函数,它的主要目标是确保在我们的嵌入空间中,来自相同类别的样本更接近彼此,而不同类别的样本更远离彼此。三元组损失(Triplet Loss)常在人脸识别、图像检索等需要计算相似度的任务中使用

三元组损失需要三个样本来计算损失,这三个样本被称为锚(Anchor)、正(Positive)和负(Negative)样本。其中,锚样本是我们关注的样本,正样本与锚样本具有相同的类别标签,负样本与锚样本具有不同的类别标签。
假设我们已经通过神经网络得到了这三个样本在嵌入空间的位置,分别是 A(锚样本),P(正样本)和 N(负样本)。则三元组损失函数的形式为:
L = max(d(A, P) - d(A, N) + margin, 0)
其中,d(A, P) 和 d(A, N) 分别是锚样本与正样本,锚样本与负样本在嵌入空间的距离,"margin"是一个预设定的阈值,用于控制正样本与负样本之间的差异,我们希望锚样本比与负样本的距离比至少比与正样本的距离大。
例如:
我们有三个样本锚样本A, 正样本P, 负样本N。它们分别被一个神经网络映射到一个三维空间,得到的嵌入向量是:
A = [1, 1, 1]P = [1.1, 1.1, 1.1]
N = [2, 2, 2]
我们可以看到,正样本P比锚样本A更接近,而负样本N则比正样本P和锚样本A更远,这就是我们希望的结果。但如果网络没有很好的训练,可能会得到违背这一原则的嵌入,例如负样本N离锚样本A更近,那么这就需要三元组损失来调整网络的权重,使得同类样本更接近,不同类样本更远离。

Triplet Loss三元组损失函数 在模型训练中,batchsize不能设置太小:

  • 多样性:在一个Batch中,我们需要包含足够多的类别,以便从中选择出质量较好的三元组。如果Batch太小,可能只包含少量的类别,这将限制我们选择三元组的可能性。
  • 稳定性:较大的batch size可以使网络的训练更稳定。每个batch的梯度计算都是对全局梯度的一个估计,batch size越大,这个估计的准确性就越高,训练过程也就越稳定。

代码讲解

Triplet Loss三元组损失函数如下:

def triplet_loss(embedding, targets, margin, norm_feat, hard_mining):r"""Modified from Tong Xiao's open-reid (https://github.com/Cysu/open-reid).Related Triplet Loss theory can be found in paper 'In Defense of the TripletLoss for Person Re-Identification'."""if norm_feat:dist_mat = cosine_dist(embedding, embedding)else:dist_mat = euclidean_dist(embedding, embedding)# For distributed training, gather all features from different process.# if comm.get_world_size() > 1:#     all_embedding = torch.cat(GatherLayer.apply(embedding), dim=0)#     all_targets = concat_all_gather(targets)# else:#     all_embedding = embedding#     all_targets = targets# 获取相似度矩阵dist_mat的行数,即样本数量N = dist_mat.size(0)# 创建两个相同大小的矩阵is_pos和is_neg,分别存储样本之间是否属于相同类别(正样本对)及不同类别(负样本对)is_pos = targets.view(N, 1).expand(N, N).eq(targets.view(N, 1).expand(N, N).t()).float()is_neg = targets.view(N, 1).expand(N, N).ne(targets.view(N, 1).expand(N, N).t()).float()if hard_mining:dist_ap, dist_an = hard_example_mining(dist_mat, is_pos, is_neg)else:dist_ap, dist_an = weighted_example_mining(dist_mat, is_pos, is_neg)y = dist_an.new().resize_as_(dist_an).fill_(1)if margin > 0:loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=margin)else:loss = F.soft_margin_loss(dist_an - dist_ap, y)# fmt: offif loss == float('Inf'): loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=0.3)# fmt: onreturn loss

对上面代码进行解析:

定义函数

def triplet_loss(embedding, targets, margin, norm_feat, hard_mining):

定义了一个名为triplet_loss的函数,输入参数为embedding(嵌入特征)、targets(目标标签)、margin(用于增加正负样本之间间距的值)、norm_feat(决定是否对特征进行归一化)以及hard_mining(决定是否启动困难样本挖掘)。

数据归一化处理

    if norm_feat:dist_mat = cosine_dist(embedding, embedding)else:dist_mat = euclidean_dist(embedding, embedding)

判断是否对特征进行归一化,若决定归一化,就用余弦距离度量相似度;若不归一化,则用欧氏距离度量相似度。

cosine_dist(embedding, embedding)是将embedding中的每一个向量与embedding中的每一个向量都计算一遍余弦距离。

  • 举一个简单的例子:
假设你的embedding是一个(3, 2)的张量,内容如下:
[[a1, a2],[b1, b2],[c1, c2]]
其中,[a1, a2],[b1, b2]和[c1, c2]是这个embedding中的3个向量。
当你执行cosine_dist(embedding, embedding)时,实际上计算的是:
[[cosine_dist([a1, a2], [a1, a2]), cosine_dist([a1, a2], [b1, b2]), cosine_dist([a1, a2], [c1, c2])],[cosine_dist([b1, b2], [a1, a2]), cosine_dist([b1, b2], [b1, b2]), cosine_dist([b1, b2], [c1, c2])],[cosine_dist([c1, c2], [a1, a2]), cosine_dist([c1, c2], [b1, b2]), cosine_dist([c1, c2], [c1, c2])]]
这个结果是一个(3, 3)的矩阵,表示embedding中的每一个向量与embedding中的每一个向量之间的余弦距离。
当if norm_feat:这个条件语句为真时,即当我们想对embedding进行归一化处理时,就会使用这种方法计算embedding中所有向量之间的余弦距离。

矩阵is_pos和is_neg构建

    N = dist_mat.size(0)is_pos = targets.view(N, 1).expand(N, N).eq(targets.view(N, 1).expand(N, N).t()).float()is_neg = targets.view(N, 1).expand(N, N).ne(targets.view(N, 1).expand(N, N).t()).float()

创建两个相同大小的矩阵is_pos和is_neg,分别存储样本之间是否属于相同类别(正样本对)及不同类别(负样本对)。

  • 假设我们有4个样本,它们的类标签targets是[1, 2, 1, 2],矩阵的行和列分别代表样本的索引,而值则表示相对应的两个样本是否属于同一类别(is_pos)或不同类别(is_neg)。
targets.view(N, 1).expand(N, N),得到的结果是:
1 1 1 1
2 2 2 2
1 1 1 1
2 2 2 2
执行targets.view(N, 1).expand(N, N).t(),得到的结果是:
1 2 1 2
1 2 1 2
1 2 1 2
1 2 1 2
当我们用eq()去判断两个矩阵对应位置是否相等时,得到的结果(is_pos)是:
1 0 1 0
0 1 0 1
1 0 1 0
0 1 0 1
对应位置用ne()去判断是否不相等,得到的结果(is_neg)是:
0 1 0 1
1 0 1 0
0 1 0 1
1 0 1 0

样本挖掘

if hard_mining:dist_ap, dist_an = hard_example_mining(dist_mat, is_pos, is_neg)
else:dist_ap, dist_an = weighted_example_mining(dist_mat, is_pos, is_neg)

根据是否进行困难样本挖掘,采用不同的挖掘方法获取到每个样本对的距离。

# 对于每个锚点样本,找到最难正样本(最远的具有相同类别标签的样本)和最难负样本(最近的具有不同类别标签的样本)。
def hard_example_mining(dist_mat, is_pos, is_neg):"""For each anchor, find the hardest positive and negative sample.Args:dist_mat: pair wise distance between samples, shape [N, M]is_pos: positive index with shape [N, M]is_neg: negative index with shape [N, M]Returns:dist_ap: pytorch Variable, distance(anchor, positive); shape [N]dist_an: pytorch Variable, distance(anchor, negative); shape [N]p_inds: pytorch LongTensor, with shape [N];indices of selected hard positive samples; 0 <= p_inds[i] <= N - 1n_inds: pytorch LongTensor, with shape [N];indices of selected hard negative samples; 0 <= n_inds[i] <= N - 1NOTE: Only consider the case in which all labels have same num of samples,thus we can cope with all anchors in parallel."""assert len(dist_mat.size()) == 2# `dist_ap` means distance(anchor, positive)# both `dist_ap` and `relative_p_inds` with shape [N]# dist_ap表示锚点样本与正样本之间的距离。通过在距离矩阵和正样本矩阵做逐元素相乘后,取每行(每个锚点)的最大值。dist_ap, _ = torch.max(dist_mat * is_pos, dim=1)# `dist_an` means distance(anchor, negative)# both `dist_an` and `relative_n_inds` with shape [N]# dist_an表示锚点样本与负样本之间的距离。首先,通过在距离矩阵和负样本矩阵做逐元素相乘后,再将正样本矩阵与大数(1e9)相乘并加到上述结果上,旨在将负样本对里的正样本对的距离设置地非常大。之后取每行的最小值,找出与锚点样本最近且类别不同的样本。dist_an, _ = torch.min(dist_mat * is_neg + is_pos * 1e9, dim=1)return dist_ap, dist_andef weighted_example_mining(dist_mat, is_pos, is_neg):"""For each anchor, find the weighted positive and negative sample.Args:dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N]is_pos:is_neg:Returns:dist_ap: pytorch Variable, distance(anchor, positive); shape [N]dist_an: pytorch Variable, distance(anchor, negative); shape [N]"""assert len(dist_mat.size()) == 2is_pos = is_posis_neg = is_neg# 对于每个锚点样本,找到正样本和负样本的加权距离dist_ap = dist_mat * is_posdist_an = dist_mat * is_neg# 分别通过softmax函数计算正样本和负样本的权重,注意负样本在计算权重之前要取负数。weights_ap = softmax_weights(dist_ap, is_pos)weights_an = softmax_weights(-dist_an, is_neg)# 计算的是加权距离,将距离与对应的权重相乘,然后对结果进行累加求和,得到最后的加权距离。dist_ap = torch.sum(dist_ap * weights_ap, dim=1)dist_an = torch.sum(dist_an * weights_an, dim=1)return dist_ap, dist_an

loss计算

y = dist_an.new().resize_as_(dist_an).fill_(1)

创建一个和dist_an相同大小并内容全部为1的向量。

  • y在F.margin_ranking_loss函数中起到了标记的作用,决定了两个输入之间期望的相对大小和顺序。当我们设置y为1时,表示我们期望dist_an(锚点到负样本的距离)大于dist_ap(锚点到正样本的距离)。这也符合我们在训练过程中的期望:即我们希望模型将锚点与其类别内(正样本)的距离保持小,将其与其他类别(负样本)的距离保持大。
  • 如果y不设置为1,而是设置为-1,那么其含义将完全颠倒,此时,我们期望dist_an(锚点到负样本的距离)小于dist_ap(锚点到正样本的距离)。这显然违背了我们在进行特征学习时的初衷,无法良好地反映出同类间的聚合性和异类间的分离性。
    if margin > 0:loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=margin)else:loss = F.soft_margin_loss(dist_an - dist_ap, y)# fmt: offif loss == float('Inf'): loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=0.3)# fmt: on

计算最终的三元组损失:

  • 如果margin值大于0,那就使用margin ranking loss。这将试图确保正样本对的距离比负样本对的距离小于margin;
  • 如果margin值不大于0,那就使用soft margin loss,它是margin ranking loss的一个变体,其中margin被设置为0,并在损失函数中引入了一个logistic损失。在计算soft margin loss之后,如果得到的loss值为infinity,则将margin值手动设置为0.3,再次使用margin ranking loss计算损失。

F.margin_ranking_loss函数是用来实现三元组损失的一个实用方法,它接受两组数据和一个目标向量作为输入来计算定制的秩序损失。

dist_ap代表锚点和正样本之间的距离,最大距离;dist_an代表锚点和负样本之间的距离,最小距离。y是目标向量,经常被设置为1,表示我们希望dist_an(锚点和负样本之间的距离)比dist_ap(锚点和正样本之间的距离)大。margin是我们希望两者之间的最小差距。

这篇关于Triplet Loss三元组损失函数的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/522682

相关文章

hdu1171(母函数或多重背包)

题意:把物品分成两份,使得价值最接近 可以用背包,或者是母函数来解,母函数(1 + x^v+x^2v+.....+x^num*v)(1 + x^v+x^2v+.....+x^num*v)(1 + x^v+x^2v+.....+x^num*v) 其中指数为价值,每一项的数目为(该物品数+1)个 代码如下: #include<iostream>#include<algorithm>

C++操作符重载实例(独立函数)

C++操作符重载实例,我们把坐标值CVector的加法进行重载,计算c3=c1+c2时,也就是计算x3=x1+x2,y3=y1+y2,今天我们以独立函数的方式重载操作符+(加号),以下是C++代码: c1802.cpp源代码: D:\YcjWork\CppTour>vim c1802.cpp #include <iostream>using namespace std;/*** 以独立函数

函数式编程思想

我们经常会用到各种各样的编程思想,例如面向过程、面向对象。不过笔者在该博客简单介绍一下函数式编程思想. 如果对函数式编程思想进行概括,就是f(x) = na(x) , y=uf(x)…至于其他的编程思想,可能是y=a(x)+b(x)+c(x)…,也有可能是y=f(x)=f(x)/a + f(x)/b+f(x)/c… 面向过程的指令式编程 面向过程,简单理解就是y=a(x)+b(x)+c(x)

利用matlab bar函数绘制较为复杂的柱状图,并在图中进行适当标注

示例代码和结果如下:小疑问:如何自动选择合适的坐标位置对柱状图的数值大小进行标注?😂 clear; close all;x = 1:3;aa=[28.6321521955954 26.2453660695847 21.69102348512086.93747104431360 6.25442246899816 3.342835958564245.51365061796319 4.87

OpenCV结构分析与形状描述符(11)椭圆拟合函数fitEllipse()的使用

操作系统:ubuntu22.04 OpenCV版本:OpenCV4.9 IDE:Visual Studio Code 编程语言:C++11 算法描述 围绕一组2D点拟合一个椭圆。 该函数计算出一个椭圆,该椭圆在最小二乘意义上最好地拟合一组2D点。它返回一个内切椭圆的旋转矩形。使用了由[90]描述的第一个算法。开发者应该注意,由于数据点靠近包含的 Mat 元素的边界,返回的椭圆/旋转矩形数据

Unity3D 运动之Move函数和translate

CharacterController.Move 移动 function Move (motion : Vector3) : CollisionFlags Description描述 A more complex move function taking absolute movement deltas. 一个更加复杂的运动函数,每次都绝对运动。 Attempts to

SigLIP——采用sigmoid损失的图文预训练方式

SigLIP——采用sigmoid损失的图文预训练方式 FesianXu 20240825 at Wechat Search Team 前言 CLIP中的infoNCE损失是一种对比性损失,在SigLIP这个工作中,作者提出采用非对比性的sigmoid损失,能够更高效地进行图文预训练,本文进行介绍。如有谬误请见谅并联系指出,本文遵守CC 4.0 BY-SA版权协议,转载请联系作者并注

✨机器学习笔记(二)—— 线性回归、代价函数、梯度下降

1️⃣线性回归(linear regression) f w , b ( x ) = w x + b f_{w,b}(x) = wx + b fw,b​(x)=wx+b 🎈A linear regression model predicting house prices: 如图是机器学习通过监督学习运用线性回归模型来预测房价的例子,当房屋大小为1250 f e e t 2 feet^

JavaSE(十三)——函数式编程(Lambda表达式、方法引用、Stream流)

函数式编程 函数式编程 是 Java 8 引入的一个重要特性,它允许开发者以函数作为一等公民(first-class citizens)的方式编程,即函数可以作为参数传递给其他函数,也可以作为返回值。 这极大地提高了代码的可读性、可维护性和复用性。函数式编程的核心概念包括高阶函数、Lambda 表达式、函数式接口、流(Streams)和 Optional 类等。 函数式编程的核心是Lambda

PHP APC缓存函数使用教程

APC,全称是Alternative PHP Cache,官方翻译叫”可选PHP缓存”。它为我们提供了缓存和优化PHP的中间代码的框架。 APC的缓存分两部分:系统缓存和用户数据缓存。(Linux APC扩展安装) 系统缓存 它是指APC把PHP文件源码的编译结果缓存起来,然后在每次调用时先对比时间标记。如果未过期,则使用缓存的中间代码运行。默认缓存 3600s(一小时)。但是这样仍会浪费大量C