CenterLoss | 减小类间距离

2024-03-07 23:30

本文主要是介绍CenterLoss | 减小类间距离,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1.centerloss原理

centerloss中心损失它仅仅用来减少类内的差异,而不能有效增大类间的差异性。下图中,图(a)表示softmax loss学习到的特征描述 。图(b)表示softmax loss + center loss 学习到的特征描述,他能把同一类的样本之间的距离拉近一些,使其相似性变大,尽量的往样本中心靠拢,但可以看出他没有把不同类样本之间的样本距离拉大。

centerloss的主要思路为:让每一类特征尽可能的在输出特征空间内聚集在一起。更直白的描述就是每一类的特征在特征空间中尽可能的聚集在某一个中心点附近。正常情况下,如果我们先验的知道了所有样本的GT中心点,那这个任务就好解决了,然而事实是我们无法预先获取类中心特征空间的分布。因此我们只能从训练的过程中动态的获取类中心特征,并对整体的训练过程产生约束。需要注意的是在训练的过程中,受限于GPU的显存等问题,我们不可能直接获取所有样本的特征中心,因此整个过程是基于batch进行的,而且当网络还未收敛的情况下,网络得到的特征中心也是不正确的。基于这两点,特征中心的确定势必是一个基于batch的动态过程。

2.中心点是如何维护的

接下来就详细讲一下这个动态过程,首先提出一个问题:中心点明明是不确定的,那如何让特征去聚集在这个不确定的特征中心点呢?

这要从centerloss的更新机制说起,从下面的两组公式可以看出,center中心点的更新方向是特征值和中心点的二范数,简单来说最终通过这种更新方式会使得某一类特征值对应的中心点被更新成与所有该类样本特征值的二范数和最小的位置,而这个位置我们可以广义的理解为所以特征的中心点位置。因此整体的centerloss是在边学习边找中心点的,最终中心点的确定和整体分类任务的收敛是同步进行的。

用知乎上比较概括性的话来讲就是:
center loss的原理主要是在softmax loss的基础上,通过对训练集的每个类别在特征空间分别维护一个类中心,在训练过程,增加样本经过网络映射后在特征空间与类中心的距离约束,从而兼顾了类内聚合与类间分离。

最终通过将centerloss和softmaxloss进行加权求和,实现整体的分类任务的学习。

centerloss的计算代码:

def forward(self, output_features, y_truth):"""损失计算:param output_features: conv层输出的特征,  [b,c,h,w]:param y_truth:  标签值  [b,]:return:"""batch_size = y_truth.size(0)output_features = output_features.view(batch_size, -1)assert output_features.size(-1) == self.feat_dimfactor = self.scale / batch_size# return self.lamda * factor * self.lossfunc(output_features, y_truth, self.feature_centers))centers_batch = self.feature_centers.index_select(0, y_truth.long())  # [b,features_dim]diff = output_features - centers_batchloss = self.lamda * 0.5 * factor * (diff.pow(2).sum())#########return loss

center的更新代码:

# 改段代码需要注意的是backward返回值需要与对应的forward的输入参数一一对应。
class CenterlossFunc(Function):@staticmethoddef forward(ctx, feature, label, centers, batch_size):ctx.save_for_backward(feature, label, centers, batch_size)centers_batch = centers.index_select(0, label.long())return (feature - centers_batch).pow(2).sum() / 2.0 / batch_size@staticmethoddef backward(ctx, grad_output):feature, label, centers, batch_size = ctx.saved_tensorscenters_batch = centers.index_select(0, label.long())diff = centers_batch - feature# init every iterationcounts = centers.new_ones(centers.size(0))ones = centers.new_ones(label.size(0))grad_centers = centers.new_zeros(centers.size())counts = counts.scatter_add_(0, label.long(), ones)grad_centers.scatter_add_(0, label.unsqueeze(1).expand(feature.size()).long(), diff)grad_centers = grad_centers/counts.view(-1, 1)return - grad_output * diff / batch_size, None, grad_centers / batch_size, None

pytorch代码
https://www.cnblogs.com/dxscode/p/12059548.html
https://github.com/jxgu1016/MNIST_center_loss_pytorch/blob/master/CenterLoss.py

这篇关于CenterLoss | 减小类间距离的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/785197

相关文章

线性代数|机器学习-P35距离矩阵和普鲁克问题

文章目录 1. 距离矩阵2. 正交普鲁克问题3. 实例说明 1. 距离矩阵 假设有三个点 x 1 , x 2 , x 3 x_1,x_2,x_3 x1​,x2​,x3​,三个点距离如下: ∣ ∣ x 1 − x 2 ∣ ∣ 2 = 1 , ∣ ∣ x 2 − x 3 ∣ ∣ 2 = 1 , ∣ ∣ x 1 − x 3 ∣ ∣ 2 = 6 \begin{equation} ||x

模拟退火求n个点到某点距离和最短

/*找出一个点使得这个店到n个点的最长距离最短,即求最小覆盖圆的半径用一个点往各个方向扩展,如果结果更优,则继续以当前步长扩展,否则缩小步长*/#include<stdio.h>#include<math.h>#include<string.h>const double pi = acos(-1.0);struct point {double x,y;}p[1010];int

黑神话:悟空》增加草地绘制距离MOD使游戏场景看起来更加广阔与自然,增强了游戏的沉浸式体验

《黑神话:悟空》增加草地绘制距离MOD为玩家提供了一种全新的视觉体验,通过扩展游戏中草地的绘制距离,增加了场景的深度和真实感。该MOD通过增加草地的绘制距离,使游戏场景看起来更加广阔与自然,增强了游戏的沉浸式体验。 增加草地绘制距离MOD安装 1、在%userprofile%AppDataLocalb1SavedConfigWindows目录下找到Engine.ini文件。 2、使用记事本编辑

SimD:基于相似度距离的小目标检测标签分配

摘要 https://arxiv.org/pdf/2407.02394 由于物体尺寸有限且信息不足,小物体检测正成为计算机视觉领域最具挑战性的任务之一。标签分配策略是影响物体检测精度的关键因素。尽管已经存在一些针对小物体的有效标签分配策略,但大多数策略都集中在降低对边界框的敏感性以增加正样本数量上,并且需要设置一些固定的超参数。然而,更多的正样本并不一定会带来更好的检测结果,事实上,过多的正样本

Matlab)实现HSV非等间隔量化--相似判断:欧式距离--输出图片-

%************************************************************************** %                                 图像检索——提取颜色特征 %HSV空间颜色直方图(将RGB空间转化为HS

C/C++两点坐标求距离以及C++保留两位小数输出,秒了

目录 1. 前言 2. 正文 2.1 问题 2.2 解决办法 2.2.1 思路 2.2.2 代码实现 3. 备注 1. 前言 依旧是带来一个练手的题目,目的就一个,方法千千万,通向终点的方式有很多种,没有谁与谁,我们都是为了成为更好的自己。 2. 正文 2.1 问题 题目描述: 输入两点坐标(X1,Y1),(X2,Y2),计算并输出两点间的距离。 输入格式:

mysql5.6根据经纬度查询距离二

在MySQL 5.6中,您可以使用Haversine公式来根据经纬度查询距离。以下是一个示例SQL查询,它计算出所有点与给定点(经度lon和纬度lat)的距离,并按距离排序: SELECT id, (2 * 6378.137 * ASIN(SQRT(POW( SIN( PI( ) * ( $lng- `long` ) / 360 ), 2 ) + COS( PI( ) * $lat / 180

像素间的关系(邻接、连通、区域、边界、距离定义)

文章目录 像素的相邻像素4邻域D邻域8邻域 邻接、连通、区域和边界邻接类型连通区域边界 距离测度欧氏距离城市街区距离(city-block distance)棋盘距离(chessboard distance) 参考 像素的相邻像素 4邻域 坐标 ( x , y ) (x,y) (x,y)处的像素 p p p有2个水平的相邻像素和2个垂直的相邻像素,它们的坐标是: ( x

【go语言计算两个经纬度距离】根据经纬度计算两点之间距离

一、需求分析: 输入两个经纬度,计算它们之间的距离 lat1,lng1 := 32.060255,118.796877lat2,lng2 := 39.904211,116.407395 二、计算公式 //C = sin(LatA*Pi/180)*sin(LatB*Pi/180) + cos(LatA*Pi/180)*cos(LatB*Pi/180)*cos((MLonA-MLonB)

【python 走进NLP】文本相似度各种距离计算

计算文本相似度有什么用? 1、反垃圾文本的捞取 “诚聘淘宝兼职”、“诚聘打字员”…这样的小广告满天飞,作为网站或者APP的运营者,不可能手动将所有的广告文本放入屏蔽名单里,挑几个典型广告文本,与它满足一定相似度就进行屏蔽。 2、推荐系统 在微博和各大BBS上,每一篇文章/帖子的下面都有一个推荐阅读,那就是根据一定算法计算出来的相似文章。 3、冗余过滤 我们每天接触过量的信息,信息之间存在大量