本文主要是介绍Triplet Loss三元组损失函数,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
基础知识
三元组损失(Triplet Loss)是一种用于学习深度神经网络嵌入的损失函数,它的主要目标是确保在我们的嵌入空间中,来自相同类别的样本更接近彼此,而不同类别的样本更远离彼此。三元组损失(Triplet Loss)常在人脸识别、图像检索等需要计算相似度的任务中使用
三元组损失需要三个样本来计算损失,这三个样本被称为锚(Anchor)、正(Positive)和负(Negative)样本。其中,锚样本是我们关注的样本,正样本与锚样本具有相同的类别标签,负样本与锚样本具有不同的类别标签。
假设我们已经通过神经网络得到了这三个样本在嵌入空间的位置,分别是 A(锚样本),P(正样本)和 N(负样本)。则三元组损失函数的形式为:
L = max(d(A, P) - d(A, N) + margin, 0)
其中,d(A, P) 和 d(A, N) 分别是锚样本与正样本,锚样本与负样本在嵌入空间的距离,"margin"是一个预设定的阈值,用于控制正样本与负样本之间的差异,我们希望锚样本比与负样本的距离比至少比与正样本的距离大。
例如:
我们有三个样本锚样本A, 正样本P, 负样本N。它们分别被一个神经网络映射到一个三维空间,得到的嵌入向量是:
A = [1, 1, 1]P = [1.1, 1.1, 1.1]
N = [2, 2, 2]
我们可以看到,正样本P比锚样本A更接近,而负样本N则比正样本P和锚样本A更远,这就是我们希望的结果。但如果网络没有很好的训练,可能会得到违背这一原则的嵌入,例如负样本N离锚样本A更近,那么这就需要三元组损失来调整网络的权重,使得同类样本更接近,不同类样本更远离。
Triplet Loss三元组损失函数 在模型训练中,batchsize不能设置太小:
- 多样性:在一个Batch中,我们需要包含足够多的类别,以便从中选择出质量较好的三元组。如果Batch太小,可能只包含少量的类别,这将限制我们选择三元组的可能性。
- 稳定性:较大的batch size可以使网络的训练更稳定。每个batch的梯度计算都是对全局梯度的一个估计,batch size越大,这个估计的准确性就越高,训练过程也就越稳定。
代码讲解
Triplet Loss三元组损失函数如下:
def triplet_loss(embedding, targets, margin, norm_feat, hard_mining):r"""Modified from Tong Xiao's open-reid (https://github.com/Cysu/open-reid).Related Triplet Loss theory can be found in paper 'In Defense of the TripletLoss for Person Re-Identification'."""if norm_feat:dist_mat = cosine_dist(embedding, embedding)else:dist_mat = euclidean_dist(embedding, embedding)# For distributed training, gather all features from different process.# if comm.get_world_size() > 1:# all_embedding = torch.cat(GatherLayer.apply(embedding), dim=0)# all_targets = concat_all_gather(targets)# else:# all_embedding = embedding# all_targets = targets# 获取相似度矩阵dist_mat的行数,即样本数量N = dist_mat.size(0)# 创建两个相同大小的矩阵is_pos和is_neg,分别存储样本之间是否属于相同类别(正样本对)及不同类别(负样本对)is_pos = targets.view(N, 1).expand(N, N).eq(targets.view(N, 1).expand(N, N).t()).float()is_neg = targets.view(N, 1).expand(N, N).ne(targets.view(N, 1).expand(N, N).t()).float()if hard_mining:dist_ap, dist_an = hard_example_mining(dist_mat, is_pos, is_neg)else:dist_ap, dist_an = weighted_example_mining(dist_mat, is_pos, is_neg)y = dist_an.new().resize_as_(dist_an).fill_(1)if margin > 0:loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=margin)else:loss = F.soft_margin_loss(dist_an - dist_ap, y)# fmt: offif loss == float('Inf'): loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=0.3)# fmt: onreturn loss
对上面代码进行解析:
定义函数
def triplet_loss(embedding, targets, margin, norm_feat, hard_mining):
定义了一个名为triplet_loss的函数,输入参数为embedding(嵌入特征)、targets(目标标签)、margin(用于增加正负样本之间间距的值)、norm_feat(决定是否对特征进行归一化)以及hard_mining(决定是否启动困难样本挖掘)。
数据归一化处理
if norm_feat:dist_mat = cosine_dist(embedding, embedding)else:dist_mat = euclidean_dist(embedding, embedding)
判断是否对特征进行归一化,若决定归一化,就用余弦距离度量相似度;若不归一化,则用欧氏距离度量相似度。
cosine_dist(embedding, embedding)是将embedding中的每一个向量与embedding中的每一个向量都计算一遍余弦距离。
- 举一个简单的例子:
假设你的embedding是一个(3, 2)的张量,内容如下:
[[a1, a2],[b1, b2],[c1, c2]]
其中,[a1, a2],[b1, b2]和[c1, c2]是这个embedding中的3个向量。
当你执行cosine_dist(embedding, embedding)时,实际上计算的是:
[[cosine_dist([a1, a2], [a1, a2]), cosine_dist([a1, a2], [b1, b2]), cosine_dist([a1, a2], [c1, c2])],[cosine_dist([b1, b2], [a1, a2]), cosine_dist([b1, b2], [b1, b2]), cosine_dist([b1, b2], [c1, c2])],[cosine_dist([c1, c2], [a1, a2]), cosine_dist([c1, c2], [b1, b2]), cosine_dist([c1, c2], [c1, c2])]]
这个结果是一个(3, 3)的矩阵,表示embedding中的每一个向量与embedding中的每一个向量之间的余弦距离。
当if norm_feat:这个条件语句为真时,即当我们想对embedding进行归一化处理时,就会使用这种方法计算embedding中所有向量之间的余弦距离。
矩阵is_pos和is_neg构建
N = dist_mat.size(0)is_pos = targets.view(N, 1).expand(N, N).eq(targets.view(N, 1).expand(N, N).t()).float()is_neg = targets.view(N, 1).expand(N, N).ne(targets.view(N, 1).expand(N, N).t()).float()
创建两个相同大小的矩阵is_pos和is_neg,分别存储样本之间是否属于相同类别(正样本对)及不同类别(负样本对)。
- 假设我们有4个样本,它们的类标签targets是[1, 2, 1, 2],矩阵的行和列分别代表样本的索引,而值则表示相对应的两个样本是否属于同一类别(is_pos)或不同类别(is_neg)。
targets.view(N, 1).expand(N, N),得到的结果是:
1 1 1 1
2 2 2 2
1 1 1 1
2 2 2 2
执行targets.view(N, 1).expand(N, N).t(),得到的结果是:
1 2 1 2
1 2 1 2
1 2 1 2
1 2 1 2
当我们用eq()去判断两个矩阵对应位置是否相等时,得到的结果(is_pos)是:
1 0 1 0
0 1 0 1
1 0 1 0
0 1 0 1
对应位置用ne()去判断是否不相等,得到的结果(is_neg)是:
0 1 0 1
1 0 1 0
0 1 0 1
1 0 1 0
样本挖掘
if hard_mining:dist_ap, dist_an = hard_example_mining(dist_mat, is_pos, is_neg)
else:dist_ap, dist_an = weighted_example_mining(dist_mat, is_pos, is_neg)
根据是否进行困难样本挖掘,采用不同的挖掘方法获取到每个样本对的距离。
# 对于每个锚点样本,找到最难正样本(最远的具有相同类别标签的样本)和最难负样本(最近的具有不同类别标签的样本)。
def hard_example_mining(dist_mat, is_pos, is_neg):"""For each anchor, find the hardest positive and negative sample.Args:dist_mat: pair wise distance between samples, shape [N, M]is_pos: positive index with shape [N, M]is_neg: negative index with shape [N, M]Returns:dist_ap: pytorch Variable, distance(anchor, positive); shape [N]dist_an: pytorch Variable, distance(anchor, negative); shape [N]p_inds: pytorch LongTensor, with shape [N];indices of selected hard positive samples; 0 <= p_inds[i] <= N - 1n_inds: pytorch LongTensor, with shape [N];indices of selected hard negative samples; 0 <= n_inds[i] <= N - 1NOTE: Only consider the case in which all labels have same num of samples,thus we can cope with all anchors in parallel."""assert len(dist_mat.size()) == 2# `dist_ap` means distance(anchor, positive)# both `dist_ap` and `relative_p_inds` with shape [N]# dist_ap表示锚点样本与正样本之间的距离。通过在距离矩阵和正样本矩阵做逐元素相乘后,取每行(每个锚点)的最大值。dist_ap, _ = torch.max(dist_mat * is_pos, dim=1)# `dist_an` means distance(anchor, negative)# both `dist_an` and `relative_n_inds` with shape [N]# dist_an表示锚点样本与负样本之间的距离。首先,通过在距离矩阵和负样本矩阵做逐元素相乘后,再将正样本矩阵与大数(1e9)相乘并加到上述结果上,旨在将负样本对里的正样本对的距离设置地非常大。之后取每行的最小值,找出与锚点样本最近且类别不同的样本。dist_an, _ = torch.min(dist_mat * is_neg + is_pos * 1e9, dim=1)return dist_ap, dist_andef weighted_example_mining(dist_mat, is_pos, is_neg):"""For each anchor, find the weighted positive and negative sample.Args:dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N]is_pos:is_neg:Returns:dist_ap: pytorch Variable, distance(anchor, positive); shape [N]dist_an: pytorch Variable, distance(anchor, negative); shape [N]"""assert len(dist_mat.size()) == 2is_pos = is_posis_neg = is_neg# 对于每个锚点样本,找到正样本和负样本的加权距离dist_ap = dist_mat * is_posdist_an = dist_mat * is_neg# 分别通过softmax函数计算正样本和负样本的权重,注意负样本在计算权重之前要取负数。weights_ap = softmax_weights(dist_ap, is_pos)weights_an = softmax_weights(-dist_an, is_neg)# 计算的是加权距离,将距离与对应的权重相乘,然后对结果进行累加求和,得到最后的加权距离。dist_ap = torch.sum(dist_ap * weights_ap, dim=1)dist_an = torch.sum(dist_an * weights_an, dim=1)return dist_ap, dist_an
loss计算
y = dist_an.new().resize_as_(dist_an).fill_(1)
创建一个和dist_an相同大小并内容全部为1的向量。
- y在F.margin_ranking_loss函数中起到了标记的作用,决定了两个输入之间期望的相对大小和顺序。当我们设置y为1时,表示我们期望dist_an(锚点到负样本的距离)大于dist_ap(锚点到正样本的距离)。这也符合我们在训练过程中的期望:即我们希望模型将锚点与其类别内(正样本)的距离保持小,将其与其他类别(负样本)的距离保持大。
- 如果y不设置为1,而是设置为-1,那么其含义将完全颠倒,此时,我们期望dist_an(锚点到负样本的距离)小于dist_ap(锚点到正样本的距离)。这显然违背了我们在进行特征学习时的初衷,无法良好地反映出同类间的聚合性和异类间的分离性。
if margin > 0:loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=margin)else:loss = F.soft_margin_loss(dist_an - dist_ap, y)# fmt: offif loss == float('Inf'): loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=0.3)# fmt: on
计算最终的三元组损失:
- 如果margin值大于0,那就使用margin ranking loss。这将试图确保正样本对的距离比负样本对的距离小于margin;
- 如果margin值不大于0,那就使用soft margin loss,它是margin ranking loss的一个变体,其中margin被设置为0,并在损失函数中引入了一个logistic损失。在计算soft margin loss之后,如果得到的loss值为infinity,则将margin值手动设置为0.3,再次使用margin ranking loss计算损失。
F.margin_ranking_loss函数是用来实现三元组损失的一个实用方法,它接受两组数据和一个目标向量作为输入来计算定制的秩序损失。
dist_ap代表锚点和正样本之间的距离,最大距离;dist_an代表锚点和负样本之间的距离,最小距离。y是目标向量,经常被设置为1,表示我们希望dist_an(锚点和负样本之间的距离)比dist_ap(锚点和正样本之间的距离)大。margin是我们希望两者之间的最小差距。
这篇关于Triplet Loss三元组损失函数的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!