MoCo 算法阅读记录

2024-04-10 22:36
文章标签 算法 记录 阅读 moco

本文主要是介绍MoCo 算法阅读记录,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

论文地址:🐰

何凯明大神之作,通过无监督对比学习预训练Image Encoder的表征能力。后也被许多VLP算法作为ITC的底层算法来使用。

一方面由于源代码本身并不复杂,但是要求多GPU分布式训练,以及需要下载ImageNet这个大规模的数据集;另一方面 本次只是测试和阅读算法原理的实现,并不完整使用。因此,重写了一个低配版(流程不变,超参数没有严格要求设置,单GPU跑,数据集自己配置,几十张图片, no Shuffling BN)。

queue 即文中所构建的字典,起名为这个就是因为 C++ 中 的queue 容器,因为它是一种先进先出的数据结构。

目录

一、数据预处理

二、前向传播

网络结构

算法流程


一、数据预处理

对同一张图片进行数据增强操作,得到 query 和 key。

增强操作包括

transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),transforms.RandomGrayscale(p=0.2),transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),transforms.RandomHorizontalFlip(),normalize,

所以,dataloader中的每个输入样本是一个样本对儿。

通过下列方法实现

class TwoCropsTransform:"""Take two random crops of one image as the query and key."""def __init__(self, base_transform):self.base_transform = base_transformdef __call__(self, x):q = self.base_transform(x)k = self.base_transform(x)return [q, k]

二、前向传播

网络结构

代码中 encoder q 和 encoder k的网络结构用的都是ReNet 。ResNet最终的输出层包含了

(avgpool): AdaptiveAvgPool2d(output_size=(1, 1))(fc): Linear(in_features=2048, out_features=128, bias=True)

所以,输出的特征向量维度为 (N,C)。N为文中的Mini batch大小,代码中的超参数为batch size。C应该没有什么具体的含义,只是经验的设置为这一长度了(没找出来C的大小关乎什么)。

其输出还经过了L2归一化。 

算法流程

1、 q 送入 encoder q 得到输出,并经过L2归一化, (N,C)

2、 momentum 更新 key encoder。

3、 Shuffling BN(当然我重写的代码并没有实现这个,因为它需要多GPU,但这并不妨碍认识它的作用)

文中所述

大致意思由于ResNet使用了BN操作,因此由于Batch 数据之间的交互,使得模型利用它欺骗预设任务从而简单的找到一个低损失的解决方案,然而这个解决方案效果并不好,使得模型学习不到好的表征能力。

其提出的Shuffling BN

首先,把所有进程的Tensor的收集起来(如果分布式训练,一般每个GPU包含一个进程,所以收集的数据总量大小为 num GPUs * batch size),参考这里🤖

x_gather = concat_all_gather(x)

接下来制作打乱的索引,整个过程如下所示

    def _batch_shuffle_ddp(self, x):"""Batch shuffle, for making use of BatchNorm.*** Only support DistributedDataParallel (DDP) model. ***"""# gather from all gpusbatch_size_this = x.shape[0]x_gather = concat_all_gather(x)  # 将所有进程的数据收集起来batch_size_all = x_gather.shape[0]num_gpus = batch_size_all // batch_size_this# random shuffle indexidx_shuffle = torch.randperm(batch_size_all).cuda()  # torch.randperm 将[0,n)数随机排列# broadcast to all gpustorch.distributed.broadcast(idx_shuffle, src=0)  # 将这个信息广播到所有其他进程# index for restoringidx_unshuffle = torch.argsort(idx_shuffle)  # 按照值大小顺序返回下标# shuffled index for this gpugpu_idx = torch.distributed.get_rank()  # 返回当前的进程idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]  # idx_shuffle view 后 (num_gpus, batch size) 但是batch size中的索引是打乱顺序的return x_gather[idx_this], idx_unshuffle

最终返回 随机打乱顺序后挑选的当前进程的 batch size 大小的数据,也就是说进行 BN归一化后的数据已经不在 同一个原来的批 中了。

4、k 送入 encoder k 中,在经过L2 归一化, 和q一样。  (N,C)

5、Shuffling BN 对齐 q 和 k

如下面举例

# idx_shuffle
tensor([10, 16, 13,  2,  4,  0,  6, 21, 22, 31, 29,  3, 19, 17, 14, 30, 28, 12,24, 26,  8, 25, 11, 18,  5,  7, 27,  1, 15, 23, 20,  9])# idx_unshuffle
tensor([ 5, 27,  3, 11,  4, 24,  6, 25, 20, 31,  0, 22, 17,  2, 14, 28,  1, 13,23, 12, 30,  7,  8, 29, 18, 21, 19, 26, 16, 10, 15,  9])# q 的 idx_this
tensor([10, 16, 13,  2,  4,  0,  6, 21])# k 的 idx_this
tensor([ 5, 27,  3, 11,  4, 24,  6, 25])

这里主要关注的点是 这步是为了使 k对齐打乱顺序的q。q之前是打乱了顺序从而改变了每个batch的内容,相当于从所有的batch中随机挑选了 batch size的q,从而保证去除BN的影响。

而 k 不需要 再打乱了, 只需要从原有的batch size 数据分布中挑选出与q对应的数据即可。所以才在 shuffle BN q的过程中记录了indx unshuffle。

这里的对应关系举例,比如 index shuffle 中的 0 现在位于原来没打乱状态的索引 5处, 类似的 1 -->27, 2-->3, 以此类推。

注:不要被上面单进程的(即idx this)不对齐所迷惑,上面的只是分进程处理的,分布式训练最终会把所有进程的数据拼接起来一起处理,所以所有进程的数据对齐就行。

6、计算损失,即文中公式1

其中 用到的计算方法举例如下,分别用爱因斯坦求和公式实现,参考这里🤖

a = torch.tensor([[1, 2, 3], [1, 1, 1], [2, 2, 2]])
b = torch.tensor([[2, 2, 2], [2, 2, 2], [1, 1, 1]])
print(a)
print(b)
c = torch.einsum("nc, nc->n", [a, b])  # (3)
d = c.unsqueeze(-1)  # (3,1)
print(c)#=== 输出
tensor([[1, 2, 3],[1, 1, 1],[2, 2, 3]])
tensor([[2, 2, 2],[2, 2, 2],[1, 1, 1]])
tensor([12,  6,  7])
tensor([[12],[ 6],[ 7]])
a = torch.tensor([[1, 2, 3], [1, 1, 1], [2, 2, 3]])  # (3,3)
a1 = torch.tensor([[1, 2], [1, 1], [2, 2]])  # (3,2)
c = torch.einsum("nc,ck->nk", [a, a1])
print(a)
print(a1)
print(c)# ===输出
tensor([[1, 2, 3],[1, 1, 1],[2, 2, 3]])
tensor([[1, 2],[1, 1],[2, 2]])
tensor([[ 9, 10],[ 4,  5],[10, 12]])

这里的self.queue 即文中的字典 queue,初始化为

self.register_buffer("queue", torch.randn(dim, K))
self.queue = nn.functional.normalize(self.queue, dim=0)

K为字典的长度,默认设置65536。这里为什么设置为这个可能是由于ImageNet数据集比较大,所以设置的字典比较长,具体的长度设置好像没有做固定的要求,

来源于github官网。但代码中有要求,K必须是batch size 的倍数,这个为了确保字典的更新,方便执行入栈和弹出操作。这个字典像是C++的 queue容器的FIFO数据结构,即先进先出

self.K % batch_size == 0
        l_pos = torch.einsum("nc,nc->n", [q, k]).unsqueeze(-1)  #  (8,1)  对应元素相乘并第一维加和# negative logits: NxKl_neg = torch.einsum("nc,ck->nk", [q, self.queue.clone().detach()])  # (8,65536)  矩阵相乘# logits: Nx(1+K)logits = torch.cat([l_pos, l_neg], dim=1)  # (8,65537)# apply temperaturelogits /= self.Tlabels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()  # (8,)loss = criterion(output, target)

这里看标签都是0,即第一个也就是0维数据为正样本。因为在拼接cat的时候正样本是在前面的。

7、更新字典

按mini batch 更新。具体地,如果 训练次数*mini batch size 小于字典长度,则字典queue每次都会填充新的key。若训练次数*mini batch size 大于 字典长度,则之前的被替换掉。

ptr = (ptr + batch_size) % self.K  # move pointer  8

这篇关于MoCo 算法阅读记录的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/892339

相关文章

Oracle查询优化之高效实现仅查询前10条记录的方法与实践

《Oracle查询优化之高效实现仅查询前10条记录的方法与实践》:本文主要介绍Oracle查询优化之高效实现仅查询前10条记录的相关资料,包括使用ROWNUM、ROW_NUMBER()函数、FET... 目录1. 使用 ROWNUM 查询2. 使用 ROW_NUMBER() 函数3. 使用 FETCH FI

Python MySQL如何通过Binlog获取变更记录恢复数据

《PythonMySQL如何通过Binlog获取变更记录恢复数据》本文介绍了如何使用Python和pymysqlreplication库通过MySQL的二进制日志(Binlog)获取数据库的变更记录... 目录python mysql通过Binlog获取变更记录恢复数据1.安装pymysqlreplicat

Python中的随机森林算法与实战

《Python中的随机森林算法与实战》本文详细介绍了随机森林算法,包括其原理、实现步骤、分类和回归案例,并讨论了其优点和缺点,通过面向对象编程实现了一个简单的随机森林模型,并应用于鸢尾花分类和波士顿房... 目录1、随机森林算法概述2、随机森林的原理3、实现步骤4、分类案例:使用随机森林预测鸢尾花品种4.1

Servlet中配置和使用过滤器的步骤记录

《Servlet中配置和使用过滤器的步骤记录》:本文主要介绍在Servlet中配置和使用过滤器的方法,包括创建过滤器类、配置过滤器以及在Web应用中使用过滤器等步骤,文中通过代码介绍的非常详细,需... 目录创建过滤器类配置过滤器使用过滤器总结在Servlet中配置和使用过滤器主要包括创建过滤器类、配置过滤

正则表达式高级应用与性能优化记录

《正则表达式高级应用与性能优化记录》本文介绍了正则表达式的高级应用和性能优化技巧,包括文本拆分、合并、XML/HTML解析、数据分析、以及性能优化方法,通过这些技巧,可以更高效地利用正则表达式进行复杂... 目录第6章:正则表达式的高级应用6.1 模式匹配与文本处理6.1.1 文本拆分6.1.2 文本合并6

python与QT联合的详细步骤记录

《python与QT联合的详细步骤记录》:本文主要介绍python与QT联合的详细步骤,文章还展示了如何在Python中调用QT的.ui文件来实现GUI界面,并介绍了多窗口的应用,文中通过代码介绍... 目录一、文章简介二、安装pyqt5三、GUI页面设计四、python的使用python文件创建pytho

不懂推荐算法也能设计推荐系统

本文以商业化应用推荐为例,告诉我们不懂推荐算法的产品,也能从产品侧出发, 设计出一款不错的推荐系统。 相信很多新手产品,看到算法二字,多是懵圈的。 什么排序算法、最短路径等都是相对传统的算法(注:传统是指科班出身的产品都会接触过)。但对于推荐算法,多数产品对着网上搜到的资源,都会无从下手。特别当某些推荐算法 和 “AI”扯上关系后,更是加大了理解的难度。 但,不了解推荐算法,就无法做推荐系

康拓展开(hash算法中会用到)

康拓展开是一个全排列到一个自然数的双射(也就是某个全排列与某个自然数一一对应) 公式: X=a[n]*(n-1)!+a[n-1]*(n-2)!+...+a[i]*(i-1)!+...+a[1]*0! 其中,a[i]为整数,并且0<=a[i]<i,1<=i<=n。(a[i]在不同应用中的含义不同); 典型应用: 计算当前排列在所有由小到大全排列中的顺序,也就是说求当前排列是第

csu 1446 Problem J Modified LCS (扩展欧几里得算法的简单应用)

这是一道扩展欧几里得算法的简单应用题,这题是在湖南多校训练赛中队友ac的一道题,在比赛之后请教了队友,然后自己把它a掉 这也是自己独自做扩展欧几里得算法的题目 题意:把题意转变下就变成了:求d1*x - d2*y = f2 - f1的解,很明显用exgcd来解 下面介绍一下exgcd的一些知识点:求ax + by = c的解 一、首先求ax + by = gcd(a,b)的解 这个

综合安防管理平台LntonAIServer视频监控汇聚抖动检测算法优势

LntonAIServer视频质量诊断功能中的抖动检测是一个专门针对视频稳定性进行分析的功能。抖动通常是指视频帧之间的不必要运动,这种运动可能是由于摄像机的移动、传输中的错误或编解码问题导致的。抖动检测对于确保视频内容的平滑性和观看体验至关重要。 优势 1. 提高图像质量 - 清晰度提升:减少抖动,提高图像的清晰度和细节表现力,使得监控画面更加真实可信。 - 细节增强:在低光条件下,抖