MAST: A Memory-Augmented Self-Supervised Tracker论文解读和代码剖析

本文主要是介绍MAST: A Memory-Augmented Self-Supervised Tracker论文解读和代码剖析,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

官方代码
作者开源的官方代码有一处错误,在代码剖析部分将指出。有人已经在github上提出了issue,作者一直没回应。我也是在阅读代码的时候发现了这个错误。

背景

VOS任务很少有使用自监督的,即在训练中不借助mask,只用frame image来训练。
作者巧妙的在STM的基础上,将value换成frame自身,使用过去帧重构当前帧作为代理任务(proxy),实现自监督的vos。效果还不错,在davis val上是64的J&F。

核心思想

在这里插入图片描述
仍然是采用STM的memory bank的思想。memory的特征和query的特征会使用transform,得到attention map。但不同的是,stm使用的是经过backbone得到的value,而MAST是直接使用raw frame或者mask。如果是训练阶段;使用raw frame,如果是test阶段,直接使用得到的mask。
在训练阶段,使用当前帧的特征作为query,和memory中的key,value是对应时刻的raw frame,直接使用qkv三元组重构出一个新的帧。这个输出又以当前帧为GT,用huber loss优化。整个过程没有使用到mask GT。在测试阶段,直接使用mask代替raw frame,则每次预测得到的都是重构出来的mask,作为当前帧的输出。

细节

颜色空间

作者认为,RGB颜色空间不适合作为输入,因为是重构作为代理任务。用huber loss是直接优化像素距离的。
比如说重构出来的输出的第i个像素,和raw frame的第i 个像素的matching 距离很小,但实际上他们可能是落在不同目标上。则说明,根据颜色匹配来优化网络,不适合推动模型学习语义特性。
作者也否决了随意丢弃一个channel的做法,因为RGB是关联的,可以通过其他两个通道推理得知另一个通道的像素。
作者使用LAB空间,在随机丢弃一个channel。lab空间解耦性较好。
在这里插入图片描述
作者统计了davis数据集的RGB数值和LAB数值的分布图。可以看出RGB是线性相关的。
输入的颜色值不是互相关联的,则网络将被push学习更好的表征,而并非仅仅依赖局部颜色信息。

loss

作为GT,raw frame使用RGB颜色空间。使用smooth l1 loss(huber loss)
在这里插入图片描述

 outputs = F.interpolate(outputs, (h, w), mode='bilinear')loss = F.smooth_l1_loss(outputs*20, tar_y*20, reduction='mean')

获取ROI区域

作者分析了STM的劣势,就是memory bank式的matching,需要的内存和计算量都很大,O(T*H*W*H*W)。
如果先获得了目标的大致位置,每一个pixel需要匹配的数目就会少很多(原始的是T*H*W)。
作者提出了一个两阶段的ROI localization。假设对query的第i个位置 q i q_i qi进行匹配。首先使用一个网格(应用空洞技巧),围绕在key的第i个位置上,得到网络上的特征,和 q i q_i qi做匹配(dot运算),得到的相似性系数直接加权相对坐标(和直推式vos的做法类似),这里是应用soft argmax,得到离第i个位置最相似的offset。
第二步就是围绕新的位置(i+offset),resample出一个小区域,作为需要匹配的对象。
在这里插入图片描述

其他细节

网络使用resnet18,修改stride,最低分辨率为1/4。训练也是先pretrain,在main train,接着dynamic train。

代码剖析

主要看看ROI那步。其他的步骤都很好读
作者先是在init里面设置了两种sampler。第一个是带dilate的,第二种是没有dilation的。前者用于long term的sampler,后者用于short term。

self.correlation_sampler_dilated = [SpatialCorrelationSampler(kernel_size=1,patch_size=self.memory_patch_P,stride=1,padding=0,dilation=1,dilation_patch=dirate) for dirate in range(2,6)]self.correlation_sampler = SpatialCorrelationSampler(kernel_size=1,patch_size=self.P,stride=1,padding=0,dilation=1)

在forward里面,大致有下面几个步骤:

  • 先对long term key进行第一步粗糙采样,得到ROI的位置,然后在截取主要特征作为matching对象得到系数。
  • 在对short term key同样操作
  • 用得到的offset,对raw frames,也截取对应的value。
  • 所有的attention map以及value都齐了,开始使用qkv公式得到输出。
 for searching_index in range(nsearch):  # long term: need dilation##### GET OFFSET HERE.  (b,h,w,2)samplerindex = dirates[searching_index]-2coarse_search_correlation = self.correlation_sampler_dilated[samplerindex](feats_t, feats_r[searching_index])  # b, p, p, h, wcoarse_search_correlation = coarse_search_correlation.reshape(b, self.memory_patch_N, h*w)coarse_search_correlation = F.softmax(coarse_search_correlation, dim=1)coarse_search_correlation = coarse_search_correlation.reshape(b,self.memory_patch_P,self.memory_patch_P,h,w,1)_y, _x = torch.meshgrid(torch.arange(-self.memory_patch_R,self.memory_patch_R+1),torch.arange(-self.memory_patch_R,self.memory_patch_R+1))grid = torch.stack([_x, _y], dim=-1).unsqueeze(-2).unsqueeze(-2)\.reshape(1,self.memory_patch_P,self.memory_patch_P,1,1,2).contiguous().float().to(coarse_search_correlation.device)# 每个query像素在mem bank中的一帧该以哪个位置为中心采样offset0 = (coarse_search_correlation * grid ).sum(1).sum(1) * dirates[searching_index]  # 1,h,w,2col_0 = deform_im2col(feats_r[searching_index], offset0, kernel_size=self.P)  # b,c*N,h*wcol_0 = col_0.reshape(b,c,N,h,w)##corr = (feats_t.unsqueeze(2) * col_0).sum(1)   # (b, N, h, w)corr = corr.reshape([b, self.P * self.P, h * w])corrs.append(corr)
 for ind in range(nsearch, nref):  # short termcorrs.append(self.correlation_sampler(feats_t, feats_r[ind]))_, _, _, h1, w1 = corrs[-1].size()corrs[ind] = corrs[ind].reshape([b, self.P*self.P, h1*w1])

得到T帧的匹配系数的softmax值

  corr = torch.cat(corrs, 1)  # b,nref*N,HWcorr = F.softmax(corr, dim=1)corr = corr.unsqueeze(1)

得到value

im_col0 = [deform_im2col(qr[i], offset0, kernel_size=self.P)  for i in range(nsearch)]# b, 3*N, h*w
im_col1 = [F.unfold(r, kernel_size=self.P, padding=self.R) for r in qr[nsearch:]]
image_uf = im_col0 + im_col1  # memory value list.

得到预测结果

  out = (corr * image_uf).sum(2).reshape([b,qr[0].size(1),h,w])

采用使用的是spatial correlation sapmle,是计算光流的cost valume的重要操作。不知道啥是cost valume可以去知乎搜索一下。作者这里用他是计算 q i q_i qi和在key上以i为中心的网格中被选取的特征的相似度。

所谓的截取,就是已知 q i q_i qi应该在哪个位置截取,就使用grid sample取出来。

这篇关于MAST: A Memory-Augmented Self-Supervised Tracker论文解读和代码剖析的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/247398

相关文章

什么是 Java 的 CyclicBarrier(代码示例)

《什么是Java的CyclicBarrier(代码示例)》CyclicBarrier是多线程协同的利器,适合需要多次同步的场景,本文通过代码示例讲解什么是Java的CyclicBarrier,感... 你的回答(口语化,面试场景)面试官:什么是 Java 的 CyclicBarrier?你:好的,我来举个例

基于Canvas的Html5多时区动态时钟实战代码

《基于Canvas的Html5多时区动态时钟实战代码》:本文主要介绍了如何使用Canvas在HTML5上实现一个多时区动态时钟的web展示,通过Canvas的API,可以绘制出6个不同城市的时钟,并且这些时钟可以动态转动,每个时钟上都会标注出对应的24小时制时间,详细内容请阅读本文,希望能对你有所帮助...

HTML5 data-*自定义数据属性的示例代码

《HTML5data-*自定义数据属性的示例代码》HTML5的自定义数据属性(data-*)提供了一种标准化的方法在HTML元素上存储额外信息,可以通过JavaScript访问、修改和在CSS中使用... 目录引言基本概念使用自定义数据属性1. 在 html 中定义2. 通过 JavaScript 访问3.

Linux系统之authconfig命令的使用解读

《Linux系统之authconfig命令的使用解读》authconfig是一个用于配置Linux系统身份验证和账户管理设置的命令行工具,主要用于RedHat系列的Linux发行版,它提供了一系列选项... 目录linux authconfig命令的使用基本语法常用选项示例总结Linux authconfi

Flutter监听当前页面可见与隐藏状态的代码详解

《Flutter监听当前页面可见与隐藏状态的代码详解》文章介绍了如何在Flutter中使用路由观察者来监听应用进入前台或后台状态以及页面的显示和隐藏,并通过代码示例讲解的非常详细,需要的朋友可以参考下... flutter 可以监听 app 进入前台还是后台状态,也可以监听当http://www.cppcn

Python使用PIL库将PNG图片转换为ICO图标的示例代码

《Python使用PIL库将PNG图片转换为ICO图标的示例代码》在软件开发和网站设计中,ICO图标是一种常用的图像格式,特别适用于应用程序图标、网页收藏夹图标等场景,本文将介绍如何使用Python的... 目录引言准备工作代码解析实践操作结果展示结语引言在软件开发和网站设计中,ICO图标是一种常用的图像

解读docker运行时-itd参数是什么意思

《解读docker运行时-itd参数是什么意思》在Docker中,-itd参数组合用于在后台运行一个交互式容器,同时保持标准输入和分配伪终端,这种方式适合需要在后台运行容器并保持交互能力的场景... 目录docker运行时-itd参数是什么意思1. -i(或 --interactive)2. -t(或 --

Java中有什么工具可以进行代码反编译详解

《Java中有什么工具可以进行代码反编译详解》:本文主要介绍Java中有什么工具可以进行代码反编译的相关资,料,包括JD-GUI、CFR、Procyon、Fernflower、Javap、Byte... 目录1.JD-GUI2.CFR3.Procyon Decompiler4.Fernflower5.Jav

解读为什么@Autowired在属性上被警告,在setter方法上不被警告问题

《解读为什么@Autowired在属性上被警告,在setter方法上不被警告问题》在Spring开发中,@Autowired注解常用于实现依赖注入,它可以应用于类的属性、构造器或setter方法上,然... 目录1. 为什么 @Autowired 在属性上被警告?1.1 隐式依赖注入1.2 IDE 的警告:

javaScript在表单提交时获取表单数据的示例代码

《javaScript在表单提交时获取表单数据的示例代码》本文介绍了五种在JavaScript中获取表单数据的方法:使用FormData对象、手动提取表单数据、使用querySelector获取单个字... 方法 1:使用 FormData 对象FormData 是一个方便的内置对象,用于获取表单中的键值