MAST: A Memory-Augmented Self-Supervised Tracker论文解读和代码剖析

本文主要是介绍MAST: A Memory-Augmented Self-Supervised Tracker论文解读和代码剖析,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

官方代码
作者开源的官方代码有一处错误,在代码剖析部分将指出。有人已经在github上提出了issue,作者一直没回应。我也是在阅读代码的时候发现了这个错误。

背景

VOS任务很少有使用自监督的,即在训练中不借助mask,只用frame image来训练。
作者巧妙的在STM的基础上,将value换成frame自身,使用过去帧重构当前帧作为代理任务(proxy),实现自监督的vos。效果还不错,在davis val上是64的J&F。

核心思想

在这里插入图片描述
仍然是采用STM的memory bank的思想。memory的特征和query的特征会使用transform,得到attention map。但不同的是,stm使用的是经过backbone得到的value,而MAST是直接使用raw frame或者mask。如果是训练阶段;使用raw frame,如果是test阶段,直接使用得到的mask。
在训练阶段,使用当前帧的特征作为query,和memory中的key,value是对应时刻的raw frame,直接使用qkv三元组重构出一个新的帧。这个输出又以当前帧为GT,用huber loss优化。整个过程没有使用到mask GT。在测试阶段,直接使用mask代替raw frame,则每次预测得到的都是重构出来的mask,作为当前帧的输出。

细节

颜色空间

作者认为,RGB颜色空间不适合作为输入,因为是重构作为代理任务。用huber loss是直接优化像素距离的。
比如说重构出来的输出的第i个像素,和raw frame的第i 个像素的matching 距离很小,但实际上他们可能是落在不同目标上。则说明,根据颜色匹配来优化网络,不适合推动模型学习语义特性。
作者也否决了随意丢弃一个channel的做法,因为RGB是关联的,可以通过其他两个通道推理得知另一个通道的像素。
作者使用LAB空间,在随机丢弃一个channel。lab空间解耦性较好。
在这里插入图片描述
作者统计了davis数据集的RGB数值和LAB数值的分布图。可以看出RGB是线性相关的。
输入的颜色值不是互相关联的,则网络将被push学习更好的表征,而并非仅仅依赖局部颜色信息。

loss

作为GT,raw frame使用RGB颜色空间。使用smooth l1 loss(huber loss)
在这里插入图片描述

 outputs = F.interpolate(outputs, (h, w), mode='bilinear')loss = F.smooth_l1_loss(outputs*20, tar_y*20, reduction='mean')

获取ROI区域

作者分析了STM的劣势,就是memory bank式的matching,需要的内存和计算量都很大,O(T*H*W*H*W)。
如果先获得了目标的大致位置,每一个pixel需要匹配的数目就会少很多(原始的是T*H*W)。
作者提出了一个两阶段的ROI localization。假设对query的第i个位置 q i q_i qi进行匹配。首先使用一个网格(应用空洞技巧),围绕在key的第i个位置上,得到网络上的特征,和 q i q_i qi做匹配(dot运算),得到的相似性系数直接加权相对坐标(和直推式vos的做法类似),这里是应用soft argmax,得到离第i个位置最相似的offset。
第二步就是围绕新的位置(i+offset),resample出一个小区域,作为需要匹配的对象。
在这里插入图片描述

其他细节

网络使用resnet18,修改stride,最低分辨率为1/4。训练也是先pretrain,在main train,接着dynamic train。

代码剖析

主要看看ROI那步。其他的步骤都很好读
作者先是在init里面设置了两种sampler。第一个是带dilate的,第二种是没有dilation的。前者用于long term的sampler,后者用于short term。

self.correlation_sampler_dilated = [SpatialCorrelationSampler(kernel_size=1,patch_size=self.memory_patch_P,stride=1,padding=0,dilation=1,dilation_patch=dirate) for dirate in range(2,6)]self.correlation_sampler = SpatialCorrelationSampler(kernel_size=1,patch_size=self.P,stride=1,padding=0,dilation=1)

在forward里面,大致有下面几个步骤:

  • 先对long term key进行第一步粗糙采样,得到ROI的位置,然后在截取主要特征作为matching对象得到系数。
  • 在对short term key同样操作
  • 用得到的offset,对raw frames,也截取对应的value。
  • 所有的attention map以及value都齐了,开始使用qkv公式得到输出。
 for searching_index in range(nsearch):  # long term: need dilation##### GET OFFSET HERE.  (b,h,w,2)samplerindex = dirates[searching_index]-2coarse_search_correlation = self.correlation_sampler_dilated[samplerindex](feats_t, feats_r[searching_index])  # b, p, p, h, wcoarse_search_correlation = coarse_search_correlation.reshape(b, self.memory_patch_N, h*w)coarse_search_correlation = F.softmax(coarse_search_correlation, dim=1)coarse_search_correlation = coarse_search_correlation.reshape(b,self.memory_patch_P,self.memory_patch_P,h,w,1)_y, _x = torch.meshgrid(torch.arange(-self.memory_patch_R,self.memory_patch_R+1),torch.arange(-self.memory_patch_R,self.memory_patch_R+1))grid = torch.stack([_x, _y], dim=-1).unsqueeze(-2).unsqueeze(-2)\.reshape(1,self.memory_patch_P,self.memory_patch_P,1,1,2).contiguous().float().to(coarse_search_correlation.device)# 每个query像素在mem bank中的一帧该以哪个位置为中心采样offset0 = (coarse_search_correlation * grid ).sum(1).sum(1) * dirates[searching_index]  # 1,h,w,2col_0 = deform_im2col(feats_r[searching_index], offset0, kernel_size=self.P)  # b,c*N,h*wcol_0 = col_0.reshape(b,c,N,h,w)##corr = (feats_t.unsqueeze(2) * col_0).sum(1)   # (b, N, h, w)corr = corr.reshape([b, self.P * self.P, h * w])corrs.append(corr)
 for ind in range(nsearch, nref):  # short termcorrs.append(self.correlation_sampler(feats_t, feats_r[ind]))_, _, _, h1, w1 = corrs[-1].size()corrs[ind] = corrs[ind].reshape([b, self.P*self.P, h1*w1])

得到T帧的匹配系数的softmax值

  corr = torch.cat(corrs, 1)  # b,nref*N,HWcorr = F.softmax(corr, dim=1)corr = corr.unsqueeze(1)

得到value

im_col0 = [deform_im2col(qr[i], offset0, kernel_size=self.P)  for i in range(nsearch)]# b, 3*N, h*w
im_col1 = [F.unfold(r, kernel_size=self.P, padding=self.R) for r in qr[nsearch:]]
image_uf = im_col0 + im_col1  # memory value list.

得到预测结果

  out = (corr * image_uf).sum(2).reshape([b,qr[0].size(1),h,w])

采用使用的是spatial correlation sapmle,是计算光流的cost valume的重要操作。不知道啥是cost valume可以去知乎搜索一下。作者这里用他是计算 q i q_i qi和在key上以i为中心的网格中被选取的特征的相似度。

所谓的截取,就是已知 q i q_i qi应该在哪个位置截取,就使用grid sample取出来。

这篇关于MAST: A Memory-Augmented Self-Supervised Tracker论文解读和代码剖析的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/247398

相关文章

活用c4d官方开发文档查询代码

当你问AI助手比如豆包,如何用python禁止掉xpresso标签时候,它会提示到 这时候要用到两个东西。https://developers.maxon.net/论坛搜索和开发文档 比如这里我就在官方找到正确的id描述 然后我就把参数标签换过来

poj 1258 Agri-Net(最小生成树模板代码)

感觉用这题来当模板更适合。 题意就是给你邻接矩阵求最小生成树啦。~ prim代码:效率很高。172k...0ms。 #include<stdio.h>#include<algorithm>using namespace std;const int MaxN = 101;const int INF = 0x3f3f3f3f;int g[MaxN][MaxN];int n

AI hospital 论文Idea

一、Benchmarking Large Language Models on Communicative Medical Coaching: A Dataset and a Novel System论文地址含代码 大多数现有模型和工具主要迎合以患者为中心的服务。这项工作深入探讨了LLMs在提高医疗专业人员的沟通能力。目标是构建一个模拟实践环境,人类医生(即医学学习者)可以在其中与患者代理进行医学

MCU7.keil中build产生的hex文件解读

1.hex文件大致解读 闲来无事,查看了MCU6.用keil新建项目的hex文件 用FlexHex打开 给我的第一印象是:经过软件的解释之后,发现这些数据排列地十分整齐 :02000F0080FE71:03000000020003F8:0C000300787FE4F6D8FD75810702000F3D:00000001FF 把解释后的数据当作十六进制来观察 1.每一行数据

Java ArrayList扩容机制 (源码解读)

结论:初始长度为10,若所需长度小于1.5倍原长度,则按照1.5倍扩容。若不够用则按照所需长度扩容。 一. 明确类内部重要变量含义         1:数组默认长度         2:这是一个共享的空数组实例,用于明确创建长度为0时的ArrayList ,比如通过 new ArrayList<>(0),ArrayList 内部的数组 elementData 会指向这个 EMPTY_EL

计算机毕业设计 大学志愿填报系统 Java+SpringBoot+Vue 前后端分离 文档报告 代码讲解 安装调试

🍊作者:计算机编程-吉哥 🍊简介:专业从事JavaWeb程序开发,微信小程序开发,定制化项目、 源码、代码讲解、文档撰写、ppt制作。做自己喜欢的事,生活就是快乐的。 🍊心愿:点赞 👍 收藏 ⭐评论 📝 🍅 文末获取源码联系 👇🏻 精彩专栏推荐订阅 👇🏻 不然下次找不到哟~Java毕业设计项目~热门选题推荐《1000套》 目录 1.技术选型 2.开发工具 3.功能

代码随想录冲冲冲 Day39 动态规划Part7

198. 打家劫舍 dp数组的意义是在第i位的时候偷的最大钱数是多少 如果nums的size为0 总价值当然就是0 如果nums的size为1 总价值是nums[0] 遍历顺序就是从小到大遍历 之后是递推公式 对于dp[i]的最大价值来说有两种可能 1.偷第i个 那么最大价值就是dp[i-2]+nums[i] 2.不偷第i个 那么价值就是dp[i-1] 之后取这两个的最大值就是d

pip-tools:打造可重复、可控的 Python 开发环境,解决依赖关系,让代码更稳定

在 Python 开发中,管理依赖关系是一项繁琐且容易出错的任务。手动更新依赖版本、处理冲突、确保一致性等等,都可能让开发者感到头疼。而 pip-tools 为开发者提供了一套稳定可靠的解决方案。 什么是 pip-tools? pip-tools 是一组命令行工具,旨在简化 Python 依赖关系的管理,确保项目环境的稳定性和可重复性。它主要包含两个核心工具:pip-compile 和 pip

论文翻译:arxiv-2024 Benchmark Data Contamination of Large Language Models: A Survey

Benchmark Data Contamination of Large Language Models: A Survey https://arxiv.org/abs/2406.04244 大规模语言模型的基准数据污染:一项综述 文章目录 大规模语言模型的基准数据污染:一项综述摘要1 引言 摘要 大规模语言模型(LLMs),如GPT-4、Claude-3和Gemini的快

D4代码AC集

贪心问题解决的步骤: (局部贪心能导致全局贪心)    1.确定贪心策略    2.验证贪心策略是否正确 排队接水 #include<bits/stdc++.h>using namespace std;int main(){int w,n,a[32000];cin>>w>>n;for(int i=1;i<=n;i++){cin>>a[i];}sort(a+1,a+n+1);int i=1