MAST: A Memory-Augmented Self-Supervised Tracker论文解读和代码剖析

本文主要是介绍MAST: A Memory-Augmented Self-Supervised Tracker论文解读和代码剖析,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

官方代码
作者开源的官方代码有一处错误,在代码剖析部分将指出。有人已经在github上提出了issue,作者一直没回应。我也是在阅读代码的时候发现了这个错误。

背景

VOS任务很少有使用自监督的,即在训练中不借助mask,只用frame image来训练。
作者巧妙的在STM的基础上,将value换成frame自身,使用过去帧重构当前帧作为代理任务(proxy),实现自监督的vos。效果还不错,在davis val上是64的J&F。

核心思想

在这里插入图片描述
仍然是采用STM的memory bank的思想。memory的特征和query的特征会使用transform,得到attention map。但不同的是,stm使用的是经过backbone得到的value,而MAST是直接使用raw frame或者mask。如果是训练阶段;使用raw frame,如果是test阶段,直接使用得到的mask。
在训练阶段,使用当前帧的特征作为query,和memory中的key,value是对应时刻的raw frame,直接使用qkv三元组重构出一个新的帧。这个输出又以当前帧为GT,用huber loss优化。整个过程没有使用到mask GT。在测试阶段,直接使用mask代替raw frame,则每次预测得到的都是重构出来的mask,作为当前帧的输出。

细节

颜色空间

作者认为,RGB颜色空间不适合作为输入,因为是重构作为代理任务。用huber loss是直接优化像素距离的。
比如说重构出来的输出的第i个像素,和raw frame的第i 个像素的matching 距离很小,但实际上他们可能是落在不同目标上。则说明,根据颜色匹配来优化网络,不适合推动模型学习语义特性。
作者也否决了随意丢弃一个channel的做法,因为RGB是关联的,可以通过其他两个通道推理得知另一个通道的像素。
作者使用LAB空间,在随机丢弃一个channel。lab空间解耦性较好。
在这里插入图片描述
作者统计了davis数据集的RGB数值和LAB数值的分布图。可以看出RGB是线性相关的。
输入的颜色值不是互相关联的,则网络将被push学习更好的表征,而并非仅仅依赖局部颜色信息。

loss

作为GT,raw frame使用RGB颜色空间。使用smooth l1 loss(huber loss)
在这里插入图片描述

 outputs = F.interpolate(outputs, (h, w), mode='bilinear')loss = F.smooth_l1_loss(outputs*20, tar_y*20, reduction='mean')

获取ROI区域

作者分析了STM的劣势,就是memory bank式的matching,需要的内存和计算量都很大,O(T*H*W*H*W)。
如果先获得了目标的大致位置,每一个pixel需要匹配的数目就会少很多(原始的是T*H*W)。
作者提出了一个两阶段的ROI localization。假设对query的第i个位置 q i q_i qi进行匹配。首先使用一个网格(应用空洞技巧),围绕在key的第i个位置上,得到网络上的特征,和 q i q_i qi做匹配(dot运算),得到的相似性系数直接加权相对坐标(和直推式vos的做法类似),这里是应用soft argmax,得到离第i个位置最相似的offset。
第二步就是围绕新的位置(i+offset),resample出一个小区域,作为需要匹配的对象。
在这里插入图片描述

其他细节

网络使用resnet18,修改stride,最低分辨率为1/4。训练也是先pretrain,在main train,接着dynamic train。

代码剖析

主要看看ROI那步。其他的步骤都很好读
作者先是在init里面设置了两种sampler。第一个是带dilate的,第二种是没有dilation的。前者用于long term的sampler,后者用于short term。

self.correlation_sampler_dilated = [SpatialCorrelationSampler(kernel_size=1,patch_size=self.memory_patch_P,stride=1,padding=0,dilation=1,dilation_patch=dirate) for dirate in range(2,6)]self.correlation_sampler = SpatialCorrelationSampler(kernel_size=1,patch_size=self.P,stride=1,padding=0,dilation=1)

在forward里面,大致有下面几个步骤:

  • 先对long term key进行第一步粗糙采样,得到ROI的位置,然后在截取主要特征作为matching对象得到系数。
  • 在对short term key同样操作
  • 用得到的offset,对raw frames,也截取对应的value。
  • 所有的attention map以及value都齐了,开始使用qkv公式得到输出。
 for searching_index in range(nsearch):  # long term: need dilation##### GET OFFSET HERE.  (b,h,w,2)samplerindex = dirates[searching_index]-2coarse_search_correlation = self.correlation_sampler_dilated[samplerindex](feats_t, feats_r[searching_index])  # b, p, p, h, wcoarse_search_correlation = coarse_search_correlation.reshape(b, self.memory_patch_N, h*w)coarse_search_correlation = F.softmax(coarse_search_correlation, dim=1)coarse_search_correlation = coarse_search_correlation.reshape(b,self.memory_patch_P,self.memory_patch_P,h,w,1)_y, _x = torch.meshgrid(torch.arange(-self.memory_patch_R,self.memory_patch_R+1),torch.arange(-self.memory_patch_R,self.memory_patch_R+1))grid = torch.stack([_x, _y], dim=-1).unsqueeze(-2).unsqueeze(-2)\.reshape(1,self.memory_patch_P,self.memory_patch_P,1,1,2).contiguous().float().to(coarse_search_correlation.device)# 每个query像素在mem bank中的一帧该以哪个位置为中心采样offset0 = (coarse_search_correlation * grid ).sum(1).sum(1) * dirates[searching_index]  # 1,h,w,2col_0 = deform_im2col(feats_r[searching_index], offset0, kernel_size=self.P)  # b,c*N,h*wcol_0 = col_0.reshape(b,c,N,h,w)##corr = (feats_t.unsqueeze(2) * col_0).sum(1)   # (b, N, h, w)corr = corr.reshape([b, self.P * self.P, h * w])corrs.append(corr)
 for ind in range(nsearch, nref):  # short termcorrs.append(self.correlation_sampler(feats_t, feats_r[ind]))_, _, _, h1, w1 = corrs[-1].size()corrs[ind] = corrs[ind].reshape([b, self.P*self.P, h1*w1])

得到T帧的匹配系数的softmax值

  corr = torch.cat(corrs, 1)  # b,nref*N,HWcorr = F.softmax(corr, dim=1)corr = corr.unsqueeze(1)

得到value

im_col0 = [deform_im2col(qr[i], offset0, kernel_size=self.P)  for i in range(nsearch)]# b, 3*N, h*w
im_col1 = [F.unfold(r, kernel_size=self.P, padding=self.R) for r in qr[nsearch:]]
image_uf = im_col0 + im_col1  # memory value list.

得到预测结果

  out = (corr * image_uf).sum(2).reshape([b,qr[0].size(1),h,w])

采用使用的是spatial correlation sapmle,是计算光流的cost valume的重要操作。不知道啥是cost valume可以去知乎搜索一下。作者这里用他是计算 q i q_i qi和在key上以i为中心的网格中被选取的特征的相似度。

所谓的截取,就是已知 q i q_i qi应该在哪个位置截取,就使用grid sample取出来。

这篇关于MAST: A Memory-Augmented Self-Supervised Tracker论文解读和代码剖析的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/247398

相关文章

MySQL中时区参数time_zone解读

《MySQL中时区参数time_zone解读》MySQL时区参数time_zone用于控制系统函数和字段的DEFAULTCURRENT_TIMESTAMP属性,修改时区可能会影响timestamp类型... 目录前言1.时区参数影响2.如何设置3.字段类型选择总结前言mysql 时区参数 time_zon

python实现pdf转word和excel的示例代码

《python实现pdf转word和excel的示例代码》本文主要介绍了python实现pdf转word和excel的示例代码,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价... 目录一、引言二、python编程1,PDF转Word2,PDF转Excel三、前端页面效果展示总结一

在MyBatis的XML映射文件中<trim>元素所有场景下的完整使用示例代码

《在MyBatis的XML映射文件中<trim>元素所有场景下的完整使用示例代码》在MyBatis的XML映射文件中,trim元素用于动态添加SQL语句的一部分,处理前缀、后缀及多余的逗号或连接符,示... 在MyBATis的XML映射文件中,<trim>元素用于动态地添加SQL语句的一部分,例如SET或W

使用C#代码计算数学表达式实例

《使用C#代码计算数学表达式实例》这段文字主要讲述了如何使用C#语言来计算数学表达式,该程序通过使用Dictionary保存变量,定义了运算符优先级,并实现了EvaluateExpression方法来... 目录C#代码计算数学表达式该方法很长,因此我将分段描述下面的代码片段显示了下一步以下代码显示该方法如

MySQL中的锁和MVCC机制解读

《MySQL中的锁和MVCC机制解读》MySQL事务、锁和MVCC机制是确保数据库操作原子性、一致性和隔离性的关键,事务必须遵循ACID原则,锁的类型包括表级锁、行级锁和意向锁,MVCC通过非锁定读和... 目录mysql的锁和MVCC机制事务的概念与ACID特性锁的类型及其工作机制锁的粒度与性能影响多版本

Redis过期键删除策略解读

《Redis过期键删除策略解读》Redis通过惰性删除策略和定期删除策略来管理过期键,惰性删除策略在键被访问时检查是否过期并删除,节省CPU开销但可能导致过期键滞留,定期删除策略定期扫描并删除过期键,... 目录1.Redis使用两种不同的策略来删除过期键,分别是惰性删除策略和定期删除策略1.1惰性删除策略

python多进程实现数据共享的示例代码

《python多进程实现数据共享的示例代码》本文介绍了Python中多进程实现数据共享的方法,包括使用multiprocessing模块和manager模块这两种方法,具有一定的参考价值,感兴趣的可以... 目录背景进程、进程创建进程间通信 进程间共享数据共享list实践背景 安卓ui自动化框架,使用的是

SpringBoot生成和操作PDF的代码详解

《SpringBoot生成和操作PDF的代码详解》本文主要介绍了在SpringBoot项目下,通过代码和操作步骤,详细的介绍了如何操作PDF,希望可以帮助到准备通过JAVA操作PDF的你,项目框架用的... 目录本文简介PDF文件简介代码实现PDF操作基于PDF模板生成,并下载完全基于代码生成,并保存合并P

SpringBoot基于MyBatis-Plus实现Lambda Query查询的示例代码

《SpringBoot基于MyBatis-Plus实现LambdaQuery查询的示例代码》MyBatis-Plus是MyBatis的增强工具,简化了数据库操作,并提高了开发效率,它提供了多种查询方... 目录引言基础环境配置依赖配置(Maven)application.yml 配置表结构设计demo_st

Redis与缓存解读

《Redis与缓存解读》文章介绍了Redis作为缓存层的优势和缺点,并分析了六种缓存更新策略,包括超时剔除、先删缓存再更新数据库、旁路缓存、先更新数据库再删缓存、先更新数据库再更新缓存、读写穿透和异步... 目录缓存缓存优缺点缓存更新策略超时剔除先删缓存再更新数据库旁路缓存(先更新数据库,再删缓存)先更新数