GiantPandaCV | 提升分类模型acc(一):BatchSizeLARS

2024-06-08 22:12

本文主要是介绍GiantPandaCV | 提升分类模型acc(一):BatchSizeLARS,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

本文来源公众号“GiantPandaCV”,仅用于学术分享,侵权删,干货满满。

原文链接:提升分类模型acc(一):BatchSize&LARS

在使用大的bs训练情况下,会对精度有一定程度的损失,本文探讨了训练的bs大小对精度的影响,同时探究Layer-wise Adaptive Rate Scaling(LARS)是否可以有效的提升精度。

论文链接:https://arxiv.org/abs/1708.03888

论文代码: https://github.com/NVIDIA/apex/blob/master/apex/parallel/LARC.py
知乎专栏: https://zhuanlan.zhihu.com/p/406882110

1 引言

如何提升业务分类模型的性能,一直是个难题,毕竟没有99.999%的性能都会带来一定程度的风险,所以很多时候我们只能通过控制阈值来调整准召以达到想要的效果。本系列主要探究哪些模型trick和数据的方法可以大幅度让你的分类性能更上一层楼,不过要注意一点的是,tirck不一定是适用于不同的数据场景的,但是数据处理方法是普适的。本篇文章主要是对于大的bs下训练分类模型的情况,如果bs比较小的可以忽略,直接看最后的结论就好了(这个系列以后的文章讲述的方法是通用的,无论bs大小都可以用)。

2 实验配置

  • 模型:ResNet50

  • 数据:ImageNet1k

  • 环境:8xV100

3 BatchSize对精度的影响

我这里设计了4组对照实验,256, 1024, 2048和4096的batchsize,开了FP16也只能跑到了4096了。采用的是分布式训练,所以单张卡的bs就是bs = total_bs / ngpus_per_node。这里我没有使用跨卡bn,对于bs 64单卡来说理论上已经很大了,bn的作用是约束数据分布,64的bs已经可以表达一个分布的subset了,再大的bs还是同分布的,意义不大,跨卡bn的速度也更慢,所以大的bs基本可以忽略这个问题。但是对于检测的任务,跨卡bn还是有价值的,毕竟输入的分辨率大,单卡的bs比较小,一般4,8,16,这时候统计更大的bn会对模型收敛更好。

很明显可以看出来,当bs增加到4k的时候,acc下降了将近0.8%个点,1k的时候,下降了0.2%个点,所以,通常我们用大的bs训练的时候,是没办法达到最优的精度的。个人建议,使用1k的bs和0.4的学习率最优。

4 LARS(Layer-wise Adaptive Rate Scaling)

4.1. 理论分析

由于bs的增加,在同样的epoch的情况下,会使网络的weights更新迭代的次数变少,所以需要对LR随着bs的增加而线性增加,但是这样会导致上面我们看到的问题,过大的lr会导致最终的收敛不稳定,精度有所下降。

LARS代码如下:

class LARC(object):def __init__(self, optimizer, trust_coefficient=0.02, clip=True, eps=1e-8):self.optim = optimizerself.trust_coefficient = trust_coefficientself.eps = epsself.clip = clipdef step(self):with torch.no_grad():weight_decays = []for group in self.optim.param_groups:# absorb weight decay control from optimizerweight_decay = group['weight_decay'] if 'weight_decay' in group else 0weight_decays.append(weight_decay)group['weight_decay'] = 0for p in group['params']:if p.grad is None:continueparam_norm = torch.norm(p.data)grad_norm = torch.norm(p.grad.data)if param_norm != 0 and grad_norm != 0:# calculate adaptive lr + weight decayadaptive_lr = self.trust_coefficient * (param_norm) / (grad_norm + param_norm * weight_decay + self.eps)# clip learning rate for LARCif self.clip:# calculation of adaptive_lr so that when multiplied by lr it equals `min(adaptive_lr, lr)`adaptive_lr = min(adaptive_lr / group['lr'], 1)p.grad.data += weight_decay * p.datap.grad.data *= adaptive_lrself.optim.step()# return weight decay control to optimizerfor i, group in enumerate(self.optim.param_groups):group['weight_decay'] = weight_decays[i]

这里有一个超参数,trust_coefficient,也就是公式里面所提到的, 这个参数对精度的影响比较大,实验部分我们会给出结论。

4.2. 实验结论

可以很明显发现,使用了LARS,设置turst_confidence为1e-3的情况下,有着明显的掉点,设置为2e-2的时候,在1k和4k的情况下,有着明显的提升,但是2k的情况下有所下降。

LARS一定程度上可以提升精度,但是强依赖超参,还是需要细致的调参训练。

5 结论

  • 8卡进行分布式训练,使用1k的bs可以很好的平衡acc&speed。

  • LARS一定程度上可以提升精度,但是需要调参,做业务可以不用考虑,刷点的话要好好训练。

6 结束语

本文是提升分类模型acc系列的第一篇,后续会讲解一些通用的trick和数据处理的方法,敬请关注。

THE END !

文章结束,感谢阅读。您的点赞,收藏,评论是我继续更新的动力。大家有推荐的公众号可以评论区留言,共同学习,一起进步。

这篇关于GiantPandaCV | 提升分类模型acc(一):BatchSizeLARS的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1043448

相关文章

Java的IO模型、Netty原理解析

《Java的IO模型、Netty原理解析》Java的I/O是以流的方式进行数据输入输出的,Java的类库涉及很多领域的IO内容:标准的输入输出,文件的操作、网络上的数据传输流、字符串流、对象流等,这篇... 目录1.什么是IO2.同步与异步、阻塞与非阻塞3.三种IO模型BIO(blocking I/O)NI

基于Flask框架添加多个AI模型的API并进行交互

《基于Flask框架添加多个AI模型的API并进行交互》:本文主要介绍如何基于Flask框架开发AI模型API管理系统,允许用户添加、删除不同AI模型的API密钥,感兴趣的可以了解下... 目录1. 概述2. 后端代码说明2.1 依赖库导入2.2 应用初始化2.3 API 存储字典2.4 路由函数2.5 应

C#集成DeepSeek模型实现AI私有化的流程步骤(本地部署与API调用教程)

《C#集成DeepSeek模型实现AI私有化的流程步骤(本地部署与API调用教程)》本文主要介绍了C#集成DeepSeek模型实现AI私有化的方法,包括搭建基础环境,如安装Ollama和下载DeepS... 目录前言搭建基础环境1、安装 Ollama2、下载 DeepSeek R1 模型客户端 ChatBo

SpringBoot快速接入OpenAI大模型的方法(JDK8)

《SpringBoot快速接入OpenAI大模型的方法(JDK8)》本文介绍了如何使用AI4J快速接入OpenAI大模型,并展示了如何实现流式与非流式的输出,以及对函数调用的使用,AI4J支持JDK8... 目录使用AI4J快速接入OpenAI大模型介绍AI4J-github快速使用创建SpringBoot

0基础租个硬件玩deepseek,蓝耘元生代智算云|本地部署DeepSeek R1模型的操作流程

《0基础租个硬件玩deepseek,蓝耘元生代智算云|本地部署DeepSeekR1模型的操作流程》DeepSeekR1模型凭借其强大的自然语言处理能力,在未来具有广阔的应用前景,有望在多个领域发... 目录0基础租个硬件玩deepseek,蓝耘元生代智算云|本地部署DeepSeek R1模型,3步搞定一个应

Deepseek R1模型本地化部署+API接口调用详细教程(释放AI生产力)

《DeepseekR1模型本地化部署+API接口调用详细教程(释放AI生产力)》本文介绍了本地部署DeepSeekR1模型和通过API调用将其集成到VSCode中的过程,作者详细步骤展示了如何下载和... 目录前言一、deepseek R1模型与chatGPT o1系列模型对比二、本地部署步骤1.安装oll

Spring AI Alibaba接入大模型时的依赖问题小结

《SpringAIAlibaba接入大模型时的依赖问题小结》文章介绍了如何在pom.xml文件中配置SpringAIAlibaba依赖,并提供了一个示例pom.xml文件,同时,建议将Maven仓... 目录(一)pom.XML文件:(二)application.yml配置文件(一)pom.xml文件:首

如何在本地部署 DeepSeek Janus Pro 文生图大模型

《如何在本地部署DeepSeekJanusPro文生图大模型》DeepSeekJanusPro模型在本地成功部署,支持图片理解和文生图功能,通过Gradio界面进行交互,展示了其强大的多模态处... 目录什么是 Janus Pro1. 安装 conda2. 创建 python 虚拟环境3. 克隆 janus

本地私有化部署DeepSeek模型的详细教程

《本地私有化部署DeepSeek模型的详细教程》DeepSeek模型是一种强大的语言模型,本地私有化部署可以让用户在自己的环境中安全、高效地使用该模型,避免数据传输到外部带来的安全风险,同时也能根据自... 目录一、引言二、环境准备(一)硬件要求(二)软件要求(三)创建虚拟环境三、安装依赖库四、获取 Dee

C#使用DeepSeek API实现自然语言处理,文本分类和情感分析

《C#使用DeepSeekAPI实现自然语言处理,文本分类和情感分析》在C#中使用DeepSeekAPI可以实现多种功能,例如自然语言处理、文本分类、情感分析等,本文主要为大家介绍了具体实现步骤,... 目录准备工作文本生成文本分类问答系统代码生成翻译功能文本摘要文本校对图像描述生成总结在C#中使用Deep