如何Fine-Tune微调SAM

2023-12-27 12:36
文章标签 微调 sam fine tune

本文主要是介绍如何Fine-Tune微调SAM,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

转眼已经到了2023年的末尾,年初ChatGPT爆火,随后SAM横空出世,给今年的科技圈带来了众多看点,在SAM刚刚发布的时候我们也做过相关的实践,感兴趣的话可以自行移步阅读:

《Segment Anything Model (SAM)——卷起来了,那个号称分割一切的CV大模型他来了》

《Segment Anything Model (SAM)——分割一切,具有预测提示输入的图像分割实践》

《SAM-FAST:Accelerating Generative AI with PyTorch: Segment Anything, Fast基于官方PyTorch团队开发原生SAM提速8倍》

Segment Anything Model(SAM)的发布让计算机视觉迎来了ChatGPT时刻。SAM经过超过110亿个分割掩码的训练,是预测性人工智能用例而非生成性人工智能的基础模型。虽然它在广泛的图像模式和问题空间上表现出了令人难以置信的灵活性,但它的发布没有“微调”功能。

本教程将概述使用掩码解码器微调SAM的一些关键步骤,特别是描述SAM的哪些函数用于预/后处理数据,使其处于良好的微调状态。

什么是(SAM)?

分割一切模型(SAM)是Meta AI开发的一个分段模型。它被认为是计算机视觉的第一个基础模型。SAM是在包含数百万张图像和数十亿个mask的庞大数据库上进行训练的,这使得它非常强大。顾名思义,SAM能够为各种图像生成准确的分割掩模。SAM的设计使其能够将人工提示考虑在内,使其对“循环中的人工”注释特别强大。这些提示可以是多模式的:它们可以是要分割的区域上的点、要分割的对象周围的边界框或关于应该分割的内容的文本提示。

该模型分为三个部分:图像编码器、提示编码器和掩码解码器。

显示Segment Anything(SA)模型的基础模型体系结构的图像

官方论文在这里,如下所示:

图像编码器为被分割的图像生成嵌入,而提示编码器为提示生成嵌入。图像编码器是模型中一个特别大的组件。这与基于嵌入预测分割掩码的轻量级掩码解码器形成对比。Meta AI已经将在Segment Anything 10 Billion Mask(SA-1B)数据集上训练的模型的权重和偏差作为模型检查点。

什么是模型微调?

公开可用的现有技术模型具有自定义架构,并且通常提供有预先训练的模型权重。如果这些架构是在没有权重的情况下提供的,那么用户将需要从头开始训练模型,用户将需要使用大量数据集来获得最先进的性能。

模型微调是采用预先训练好的模型(体系结构+权重)并向其显示特定用例的数据的过程。这通常是模型以前从未见过的数据,或者在其原始训练数据集中代表性不足的数据。

微调模型和从头开始之间的区别在于权重和偏差的起始值。如果我们从头开始训练,这些将根据一些策略随机初始化。在这样的启动配置中,模型将对手头的任务“一无所知”,并表现不佳。通过使用预先存在的权重和偏差作为起点,我们可以“微调”权重和偏差,以便我们的模型在自定义数据集上更好地工作。例如:学会识别猫的信息(边缘检测、计数爪子)将有助于识别狗。

为什么要微调模型?

微调模型的目的是在预先训练的模型以前没有看到的数据上获得更高的性能。例如,在从手机摄像头收集的大量数据上训练的图像分割模型将主要从水平角度看到图像。

如果我们试图将这个模型用于从垂直角度拍摄的卫星图像,它可能不会表现得那么好。如果我们试图分割屋顶,该模型可能不会产生最佳结果。预训练是有用的,因为模型通常已经学会了如何分割对象,所以我们想利用这个起点来建立一个可以准确分割屋顶的模型。此外,我们的自定义数据集可能没有数百万个示例,因此我们希望进行微调,而不是从头开始训练模型。

微调是可取的,这样我们就可以在特定的用例中获得更好的性能,而不必承担从头开始训练模型的计算成本。

如何微调分段任意模型?

背景与架构

我们在介绍部分概述了SAM体系结构。图像编码器具有具有许多参数的复杂结构。为了微调模型,我们有必要关注掩码解码器,它重量轻,因此更容易、更快、更高效地进行微调。

为了微调SAM,我们需要提取其架构的底层部分(图像和提示编码器、掩码解码器)。由于两个原因,我们无法使用SamPredictor.predict:

1、我们只想微调掩码解码器

2、这个函数调用SamPredictor.predict_tarch,它有@torch.no_grad()装饰器,它阻止我们计算梯度

因此,我们需要检查SamPredictor.prpredict函数,并在我们想要微调的部分(掩码解码器)启用梯度计算的情况下调用适当的函数。这样做也是了解更多SAM如何工作的好方法。

创建自定义数据集

我们需要完成三件事来微调我们的模型:

1、要在其上绘制分割的图像

2、分割实况掩码

3、提示输入到模型中

我们选择了印章验证数据集因为它有SAM在训练中可能没有看到的数据(即,在文档上盖章)。我们可以通过使用预先训练的权重运行推理来验证它在该数据集上的表现良好,但并不完美。地面实况面具也非常精确,这将使我们能够计算出准确的损失。最后,这个数据集包含分割掩码周围的边界框,我们可以将其用作SAM的提示。下面显示了一个示例图像。这些边界框与人工注释器在生成分段时要经过的工作流程非常一致。

输入数据预处理

我们需要对从numpy数组到pytorch张量的扫描进行预处理。要做到这一点,我们可以遵循SamPredictor.set_image和预处理图像的SamPredictor.set_arch_image内部发生的情况。首先,我们可以使用utils.transform.ResizeLongestSide来调整图像的大小,因为这是预测器内部使用的转换器。然后,我们可以将图像转换为pytorch张量,并使用SAM预处理方法完成预处理。

训练设置

我们下载vit_b模型的模型检查点,并将其加载到:

sam_model = sam_model_registry['vit_b'](checkpoint='sam_vit_b_01ec64.pth')

我们可以使用默认值设置Adam优化器,并指定要调整的参数是掩码解码器的参数:

optimizer = torch.optim.Adam(sam_model.mask_decoder.parameters()) 

同时,我们可以设置我们的损失函数,例如均方误差

loss_fn = torch.nn.MSELoss()

训练循环

在主训练循环中,我们将迭代我们的数据项,生成掩码,并将其与我们的地面实况掩码进行比较,以便我们可以基于损失函数优化模型参数。

在这个例子中,我们使用GPU进行训练,因为它比使用CPU快得多。在适当的张量上使用.to(设备)是很重要的,以确保CPU上没有某些张量,GPU上没有其他张量。

我们希望通过将编码器封装在torch.no.grad()上下文管理器中来嵌入图像,因为否则我们将出现内存问题,同时我们不希望微调图像编码器。

with torch.no_grad():image_embedding = sam_model.image_encoder(input_image)

我们还可以在no.grad上下文管理器中生成提示嵌入。我们使用边界框坐标,转换为pytorch张量。

with torch.no_grad():sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(points=None,boxes=box_torch,masks=None,)

最后,我们可以生成掩码。请注意,这里我们处于单掩码生成模式(与正常输出的3个掩码形成对比)。

low_res_masks, iou_predictions = sam_model.mask_decoder(image_embeddings=image_embedding,image_pe=sam_model.prompt_encoder.get_dense_pe(),sparse_prompt_embeddings=sparse_embeddings,dense_prompt_embeddings=dense_embeddings,multimask_output=False,
)

这里的最后一步是将掩码升级回原始图像大小,因为它们的分辨率较低。我们可以使用Sam.postprocess_masks来实现这一点。我们还希望从预测的掩码中生成二进制掩码,以便将其与我们的基本事实进行比较。为了不破坏反向传播,使用torch泛函是很重要的。

upscaled_masks = sam_model.postprocess_masks(low_res_masks, input_size, original_image_size).to(device)from torch.nn.functional import threshold, normalizebinary_mask = normalize(threshold(upscaled_masks, 0.0, 0)).to(device)

最后,我们可以计算损失并运行优化步骤:

loss = loss_fn(binary_mask, gt_binary_mask)
optimizer.zero_grad()
loss.backward()
optimizer.step()

通过在多个epoch和批次上重复这一过程,我们可以微调SAM解码器。

保存检查点并从中启动模型

一旦我们完成了训练并对性能提升感到满意,我们就可以使用以下方法保存调整模型的状态dict:

torch.save(model.state_dict(), PATH)

然后,当我们想对与我们用来微调模型的数据相似的数据执行推理时,我们可以加载这个状态dict。

针对下游应用的微调

虽然SAM目前不提供开箱即用的微调,但我们正在构建一个与Encord平台集成的自定义微调调谐器。如本文所示,为了实现这一点,我们对解码器进行了微调。这在web应用程序中是一个开箱即用的一键过程,可以自动设置超参数。

微调前

微调后

我们可以看到,这个掩码比原来的掩码更紧。这是对印章验证数据集中的一小部分图像进行微调的结果,然后在一个以前看不见的例子上运行调整后的模型。通过进一步的训练和更多的例子,我们可以获得更好的结果。

结论

现在已经学会了如何微调分段任意模型(SAM)。如果您想开箱即用地微调SAM,您可能也有兴趣了解我们最近在Encord中发布了Segment Anything模型,允许您在不编写任何代码的情况下微调模型。

参考

How To Fine-Tune Segment Anything

这篇关于如何Fine-Tune微调SAM的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/543150

相关文章

AI Toolkit + H100 GPU,一小时内微调最新热门文生图模型 FLUX

上个月,FLUX 席卷了互联网,这并非没有原因。他们声称优于 DALLE 3、Ideogram 和 Stable Diffusion 3 等模型,而这一点已被证明是有依据的。随着越来越多的流行图像生成工具(如 Stable Diffusion Web UI Forge 和 ComyUI)开始支持这些模型,FLUX 在 Stable Diffusion 领域的扩展将会持续下去。 自 FLU

可选择的反思指令微调

论文:https://arxiv.org/pdf/2402.10110代码:GitHub - tianyi-lab/Reflection_Tuning: [ACL'24] Selective Reflection-Tuning: Student-Selected Data Recycling for LLM Instruction-Tuning机构:马里兰大学, Adobe Research领

【LVI-SAM】激光雷达点云处理特征提取LIO-SAM 之FeatureExtraction实现细节

激光雷达点云处理特征提取LIO-SAM 之FeatureExtraction实现细节 1. 特征提取实现过程总结1.0 特征提取过程小结1.1 类 `FeatureExtraction` 的整体结构与作用1.2 详细特征提取的过程1. 平滑度计算(`calculateSmoothness()`)2. 标记遮挡点(`markOccludedPoints()`)3. 特征提取(`extractF

文本分类场景下微调BERT

How to Fine-Tune BERT for Text Classification 论文《How to Fine-Tune BERT for Text Classification?》是2019年发表的一篇论文。这篇文章做了一些实验来分析了如何在文本分类场景下微调BERT,是后面网上讨论如何微调BERT时经常提到的论文。 结论与思路 先来看一下论文的实验结论: BERT模型上面的

从零开始构建大语言模型并进行微调:全面指南

要从0开始搭建并训练一个大语言模型(LLM),涉及到多个步骤和资源,包括理论理解、工具使用、数据准备、模型训练与微调。以下是一个从基础到应用的指南,帮助你理解并逐步实现这一目标。 1. 理解基础概念 在开始搭建大语言模型之前,了解以下基本概念至关重要: 生成式AI:通过大语言模型生成自然语言文本,例如GPT、BERT等。机器学习:通过数据训练模型,使其具备从数据中学习规律的能力。深度学习:机

什么是GPT-3的自回归架构?为什么GPT-3无需梯度更新和微调

文章目录 知识回顾GPT-3的自回归架构何为自回归架构为什么架构会影响任务表现自回归架构的局限性与双向模型的对比小结 为何无需梯度更新和微调为什么不需要怎么做到不需要 🍃作者介绍:双非本科大四网络工程专业在读,阿里云专家博主,专注于Java领域学习,擅长web应用开发,目前开始人工智能领域相关知识的学习 🦅个人主页:@逐梦苍穹 📕所属专栏:人工智能 🌻gitee地址:x

R-Adapter:零样本模型微调新突破,提升鲁棒性与泛化能力 | ECCV 2024

大规模图像-文本预训练模型实现了零样本分类,并在不同数据分布下提供了一致的准确性。然而,这些模型在下游任务中通常需要微调优化,这会降低对于超出分布范围的数据的泛化能力,并需要大量的计算资源。论文提出新颖的Robust Adapter(R-Adapter),可以在微调零样本模型用于下游任务的同时解决这两个问题。该方法将轻量级模块集成到预训练模型中,并采用新颖的自我集成技术以提高超出分布范围的鲁棒性

Segment Anything Model(SAM)中的Adapter是什么?

在META团队发布的Segment Anything Model (SAM) 中,Adapter 是一种用于提升模型在特定任务或领域上的性能的机制。具体来说,SAM 是一个通用的分割模型,能够处理多种不同类型的图像分割任务,而 Adapter 的引入是为了更好地让模型适应不同的任务需求。 Adapter 的主要功能是: 模块化设计:Adapter 是一种小规模的、可插拔的网络模块,可以在不改

欺诈文本分类检测(十一):LLamaFactory多卡微调

1. 引言 前文训练时都做了一定的编码工作,其实有一些框架可以支持我们零代码微调,LLama-Factory就是其中一个。这是一个专门针对大语言模型的微调和训练平台,有如下特性: 支持常见的模型种类:LLaMA、Mixtral-MoE、Qwen、Baichuan、ChatGLM等等。支持单GPU和多GPU训练。支持全参微调、Lora微调、QLora微调。 …… 还有很多优秀的特性,详细参考

大模型微调训练营毕业总结

我目前在一家零售公司从事大数据架构方面的工作。 之所以选择参加AI大模型微调训练营,主要是考虑到当前无论是大数据这条技术赛道,还是个人职业发展都处在平台期,短期内看不到突破点。所以想看看在大模型这个技术领域有没有可能有所突破。大数据经过多年的发展,在理论和技术层面都已经到达了一个比较成熟的高度,用户使用也属于普惠期。不仅仅有支持度良好的商业化的产品,而且开源系统也能满足用户的基本使用。另外经过这