CLIP算法的Loss详解 和 交叉熵CrossEntropy实现

2024-01-08 07:40

本文主要是介绍CLIP算法的Loss详解 和 交叉熵CrossEntropy实现,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

CLIP:Contrastive Language–Image Pre-training(可对比语言-图像预训练算法)是OpenAI提出的多模态预训练的算法,在各种各样的**样本对(图像、文本)**上训练的神经网络。

具体参考:CLIP、OpenCLIP

image-20220601180224080

其中,流程:

image-20220601180639145

loss_iloss_t的具体源码如下,参考 model.py:

    def forward(self, image, text):image_features = self.encode_image(image)text_features = self.encode_text(text)# normalized featuresimage_features = image_features / image_features.norm(dim=1, keepdim=True)text_features = text_features / text_features.norm(dim=1, keepdim=True)# cosine similarity as logitslogit_scale = self.logit_scale.exp()logits_per_image = logit_scale * image_features @ text_features.t()logits_per_text = logits_per_image.t()# shape = [global_batch_size, global_batch_size]return logits_per_image, logits_per_text

其中,labels是torch.arange(batch_size, device=device).long(),参考train.py,具体如下

        with torch.no_grad():for i, batch in enumerate(dataloader):images, texts = batchimages = images.to(device=device, non_blocking=True)texts = texts.to(device=device, non_blocking=True)with autocast():image_features, text_features, logit_scale = model(images, texts)# features are accumulated in CPU tensors, otherwise GPU memory exhausted quickly# however, system RAM is easily exceeded and compute time becomes problematicall_image_features.append(image_features.cpu())all_text_features.append(text_features.cpu())logit_scale = logit_scale.mean()logits_per_image = logit_scale * image_features @ text_features.t()logits_per_text = logits_per_image.t()batch_size = images.shape[0]labels = torch.arange(batch_size, device=device).long()total_loss = (F.cross_entropy(logits_per_image, labels) +F.cross_entropy(logits_per_text, labels)) / 2

交叉熵函数:y就是label,x_softmax[i][y[i]],表示在x_softmax中筛选第i个sample的第y[i]个值,作为log的输入,全部log负向求和,再求均值。

  • y所对应的就是CLIP的np.arange(n),也就是依次是第0个位置~第n-1个位置,计算log。
# 定义softmax函数
def softmax(x):return np.exp(x) / np.sum(np.exp(x))# 利用numpy计算
def cross_entropy_np(x, y):x_softmax = [softmax(x[i]) for i in range(len(x))]x_log = [np.log(x_softmax[i][y[i]]) for i in range(len(y))]loss = - np.sum(x_log) / len(y)return loss# 测试逻辑
x = [[1.9269, 1.4873, 0.9007, -2.1055]]
y = [[2]]
v1 = cross_entropy_np(x, y)
print(f"v1: {v1}")x = torch.unsqueeze(torch.Tensor(x), dim=0)
x = x.transpose(1, 2)  # CrossEntropy输入期望: Class放在第2维,Batch放在第1维y = torch.Tensor(y)
y = y.to(torch.long)  # label的类型为longv2 = F.cross_entropy(x, y, reduction="none")
print(f"v2: {v2}")

输出:

v1: 1.729491540989093
v2: tensor([[1.7295]])

参考:

  • arxiv文章下载很慢怎么办?
  • CLIP-对比图文多模态预训练的读后感
  • CrossEntropy的numpy实现和Pytorch调用

这篇关于CLIP算法的Loss详解 和 交叉熵CrossEntropy实现的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/582742

相关文章

使用Python删除Excel中的行列和单元格示例详解

《使用Python删除Excel中的行列和单元格示例详解》在处理Excel数据时,删除不需要的行、列或单元格是一项常见且必要的操作,本文将使用Python脚本实现对Excel表格的高效自动化处理,感兴... 目录开发环境准备使用 python 删除 Excphpel 表格中的行删除特定行删除空白行删除含指定

Linux下删除乱码文件和目录的实现方式

《Linux下删除乱码文件和目录的实现方式》:本文主要介绍Linux下删除乱码文件和目录的实现方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录linux下删除乱码文件和目录方法1方法2总结Linux下删除乱码文件和目录方法1使用ls -i命令找到文件或目录

MySQL中的LENGTH()函数用法详解与实例分析

《MySQL中的LENGTH()函数用法详解与实例分析》MySQLLENGTH()函数用于计算字符串的字节长度,区别于CHAR_LENGTH()的字符长度,适用于多字节字符集(如UTF-8)的数据验证... 目录1. LENGTH()函数的基本语法2. LENGTH()函数的返回值2.1 示例1:计算字符串

Spring Boot spring-boot-maven-plugin 参数配置详解(最新推荐)

《SpringBootspring-boot-maven-plugin参数配置详解(最新推荐)》文章介绍了SpringBootMaven插件的5个核心目标(repackage、run、start... 目录一 spring-boot-maven-plugin 插件的5个Goals二 应用场景1 重新打包应用

SpringBoot+EasyExcel实现自定义复杂样式导入导出

《SpringBoot+EasyExcel实现自定义复杂样式导入导出》这篇文章主要为大家详细介绍了SpringBoot如何结果EasyExcel实现自定义复杂样式导入导出功能,文中的示例代码讲解详细,... 目录安装处理自定义导出复杂场景1、列不固定,动态列2、动态下拉3、自定义锁定行/列,添加密码4、合并

mybatis执行insert返回id实现详解

《mybatis执行insert返回id实现详解》MyBatis插入操作默认返回受影响行数,需通过useGeneratedKeys+keyProperty或selectKey获取主键ID,确保主键为自... 目录 两种方式获取自增 ID:1. ​​useGeneratedKeys+keyProperty(推

Spring Boot集成Druid实现数据源管理与监控的详细步骤

《SpringBoot集成Druid实现数据源管理与监控的详细步骤》本文介绍如何在SpringBoot项目中集成Druid数据库连接池,包括环境搭建、Maven依赖配置、SpringBoot配置文件... 目录1. 引言1.1 环境准备1.2 Druid介绍2. 配置Druid连接池3. 查看Druid监控

Python通用唯一标识符模块uuid使用案例详解

《Python通用唯一标识符模块uuid使用案例详解》Pythonuuid模块用于生成128位全局唯一标识符,支持UUID1-5版本,适用于分布式系统、数据库主键等场景,需注意隐私、碰撞概率及存储优... 目录简介核心功能1. UUID版本2. UUID属性3. 命名空间使用场景1. 生成唯一标识符2. 数

Linux在线解压jar包的实现方式

《Linux在线解压jar包的实现方式》:本文主要介绍Linux在线解压jar包的实现方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录linux在线解压jar包解压 jar包的步骤总结Linux在线解压jar包在 Centos 中解压 jar 包可以使用 u

Linux系统性能检测命令详解

《Linux系统性能检测命令详解》本文介绍了Linux系统常用的监控命令(如top、vmstat、iostat、htop等)及其参数功能,涵盖进程状态、内存使用、磁盘I/O、系统负载等多维度资源监控,... 目录toppsuptimevmstatIOStatiotopslabtophtopdstatnmon