[work] 优化方法总结:SGD,Momentum,AdaGrad,RMSProp,Adam

2023-11-05 02:38

本文主要是介绍[work] 优化方法总结:SGD,Momentum,AdaGrad,RMSProp,Adam,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1. SGD

Batch Gradient Descent

在每一轮的训练过程中,Batch Gradient Descent算法用整个训练集的数据计算cost fuction的梯度,并用该梯度对模型参数进行更新:

 

Θ=Θ−α⋅▿ΘJ(Θ)Θ=Θ−α⋅▽ΘJ(Θ)

 

优点:

  • cost fuction若为凸函数,能够保证收敛到全局最优值;若为非凸函数,能够收敛到局部最优值

缺点:

  • 由于每轮迭代都需要在整个数据集上计算一次,所以批量梯度下降可能非常慢
  • 训练数较多时,需要较大内存
  • 批量梯度下降不允许在线更新模型,例如新增实例。

Stochastic Gradient Descent

和批梯度下降算法相反,Stochastic gradient descent 算法每读入一个数据,便立刻计算cost fuction的梯度来更新参数: 

Θ=Θ−α⋅▿ΘJ(Θ;x(i),y(i))Θ=Θ−α⋅▽ΘJ(Θ;x(i),y(i))

 

优点:

  • 算法收敛速度快(在Batch Gradient Descent算法中, 每轮会计算很多相似样本的梯度, 这部分是冗余的)
  • 可以在线更新
  • 有几率跳出一个比较差的局部最优而收敛到一个更好的局部最优甚至是全局最优

缺点:

  • 容易收敛到局部最优,并且容易被困在鞍点

Mini-batch Gradient Descent

mini-batch Gradient Descent的方法是在上述两个方法中取折衷, 每次从所有训练数据中取一个子集(mini-batch) 用于计算梯度: 

Θ=Θ−α⋅▿ΘJ(Θ;x(i:i+n),y(i:i+n))Θ=Θ−α⋅▽ΘJ(Θ;x(i:i+n),y(i:i+n))

 

Mini-batch Gradient Descent在每轮迭代中仅仅计算一个mini-batch的梯度,不仅计算效率高,而且收敛较为稳定。该方法是目前深度学训练中的主流方法

上述三个方法面临的主要挑战如下:

  • 选择适当的学习率αα 较为困难。太小的学习率会导致收敛缓慢,而学习速度太块会造成较大波动,妨碍收敛。
  • 目前可采用的方法是在训练过程中调整学习率大小,例如模拟退火算法:预先定义一个迭代次数m,每执行完m次训练便减小学习率,或者当cost function的值低于一个阈值时减小学习率。然而迭代次数和阈值必须事先定义,因此无法适应数据集的特点。
  • 上述方法中, 每个参数的 learning rate 都是相同的,这种做法是不合理的:如果训练数据是稀疏的,并且不同特征的出现频率差异较大,那么比较合理的做法是对于出现频率低的特征设置较大的学习速率,对于出现频率较大的特征数据设置较小的学习速率。
  • 近期的的研究表明,深层神经网络之所以比较难训练,并不是因为容易进入local minimum。相反,由于网络结构非常复杂,在绝大多数情况下即使是 local minimum 也可以得到非常好的结果。而之所以难训练是因为学习过程容易陷入到马鞍面中,即在坡面上,一部分点是上升的,一部分点是下降的。而这种情况比较容易出现在平坦区域,在这种区域中,所有方向的梯度值都几乎是 0。

2. Momentum

SGD方法的一个缺点是其更新方向完全依赖于当前batch计算出的梯度,因而十分不稳定。Momentum算法借用了物理中的动量概念,它模拟的是物体运动时的惯性,即更新的时候在一定程度上保留之前更新的方向,同时利用当前batch的梯度微调最终的更新方向。这样一来,可以在一定程度上增加稳定性,从而学习地更快,并且还有一定摆脱局部最优的能力:

 

vt=γ⋅vt−1+α⋅▿ΘJ(Θ)vt=γ⋅vt−1+α⋅▽ΘJ(Θ)

 

Θ=Θ−vtΘ=Θ−vt

 

Momentum算法会观察历史梯度vt−1vt−1,若当前梯度的方向与历史梯度一致(表明当前样本不太可能为异常点),则会增强这个方向的梯度,若当前梯度与历史梯方向不一致,则梯度会衰减。一种形象的解释是:我们把一个球推下山,球在下坡时积聚动量,在途中变得越来越快,γ可视为空气阻力,若球的方向发生变化,则动量会衰减。

3. Nesterov Momentum

在小球向下滚动的过程中,我们希望小球能够提前知道在哪些地方坡面会上升,这样在遇到上升坡面之前,小球就开始减速。这方法就是Nesterov Momentum,其在凸优化中有较强的理论保证收敛。并且,在实践中Nesterov Momentum也比单纯的 Momentum 的效果好:

 

vt=γ⋅vt−1+α⋅▿ΘJ(Θ−γvt−1)vt=γ⋅vt−1+α⋅▽ΘJ(Θ−γvt−1)

 

Θ=Θ−vtΘ=Θ−vt

 

其核心思想是:注意到 momentum 方法,如果只看 γ * v 项,那么当前的 θ经过 momentum 的作用会变成 θ+γ * v。因此可以把 θ+γ * v这个位置看做是当前优化的一个”展望”位置。所以,可以在 θ+γ * v求导, 而不是原始的θ。 

这里写图片描述

 

4. Adagrad

上述方法中,对于每一个参数θiθi 的训练都使用了相同的学习率α。Adagrad算法能够在训练中自动的对learning rate进行调整,对于出现频率较低参数采用较大的α更新;相反,对于出现频率较高的参数采用较小的α更新。因此,Adagrad非常适合处理稀疏数据。

我们设gt,igt,i为第t轮第i个参数的梯度,即gt,i=▿ΘJ(Θi)gt,i=▽ΘJ(Θi)。因此,SGD中参数更新的过程可写为:

 

Θt+1,i=Θt,i−α⋅gt,iΘt+1,i=Θt,i−α⋅gt,i

 

Adagrad在每轮训练中对每个参数θiθi的学习率进行更新,参数更新公式如下:

 

Θt+1,i=Θt,i−αGt,ii+ϵ√⋅gt,iΘt+1,i=Θt,i−αGt,ii+ϵ⋅gt,i

 

其中,Gt∈ℝd×dGt∈Rd×d为对角矩阵,每个对角线位置i,ii,i为对应参数θiθi从第1轮到第t轮梯度的平方和。ϵ是平滑项,用于避免分母为0,一般取值1e−8。Adagrad的缺点是在训练的中后期,分母上梯度平方的累加将会越来越大,从而梯度趋近于0,使得训练提前结束。

5. RMSprop

RMSprop是Geoff Hinton提出的一种自适应学习率方法。Adagrad会累加之前所有的梯度平方,而RMSprop仅仅是计算对应的平均值,因此可缓解Adagrad算法学习率下降较快的问题。 

E[g2]t=0.9E[g2]t−1+0.1g2tE[g2]t=0.9E[g2]t−1+0.1gt2

 

Θt+1=Θt−αE[g2]t+ϵ√⋅gtΘt+1=Θt−αE[g2]t+ϵ⋅gt

 

6. Adam

Adam(Adaptive Moment Estimation)是另一种自适应学习率的方法。它利用梯度的一阶矩估计和二阶矩估计动态调整每个参数的学习率。Adam的优点主要在于经过偏置校正后,每一次迭代学习率都有个确定范围,使得参数比较平稳。公式如下: 

mt=β1mt−1+(1−β1)gtmt=β1mt−1+(1−β1)gt 
vt=β1vt−1+(1−β1)g2tvt=β1vt−1+(1−β1)gt2 
m̂ t=mt1−βt1m^t=mt1−β1t 
v̂ t=vt1−βt2v^t=vt1−β2t 
Θt+1=Θt−αv̂ t√+ϵm̂ tΘt+1=Θt−αv^t+ϵm^t

 

其中,mtmt,vtvt分别是对梯度的一阶矩估计和二阶矩估计,可以看作对期望E[gt]E[gt],E[g2t]E[gt2]的近似;mt^mt^,vt^vt^是对mtmt,vtvt的校正,这样可以近似为对期望的无偏估计。 Adam算法的提出者建议β1β1 的默认值为0.9,β2β2的默认值为.999,ϵϵ默认为10−810−8。 另外,在数据比较稀疏的时候,adaptive的方法能得到更好的效果,例如Adagrad,RMSprop, Adam 等。Adam 方法也会比 RMSprop方法收敛的结果要好一些, 所以在实际应用中 ,Adam为最常用的方法,可以比较快地得到一个预估结果。

最后两张动图从直观上展现了算法的优化过程。第一张图为不同算法在损失平面等高线上随时间的变化情况,第二张图为不同算法在鞍点处的行为比较。 

这里写图片描述

 

 

这里写图片描述

 

7. 参考资料

  • An overview of gradient descent optimization algorithms

这篇关于[work] 优化方法总结:SGD,Momentum,AdaGrad,RMSProp,Adam的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/346844

相关文章

Nginx设置连接超时并进行测试的方法步骤

《Nginx设置连接超时并进行测试的方法步骤》在高并发场景下,如果客户端与服务器的连接长时间未响应,会占用大量的系统资源,影响其他正常请求的处理效率,为了解决这个问题,可以通过设置Nginx的连接... 目录设置连接超时目的操作步骤测试连接超时测试方法:总结:设置连接超时目的设置客户端与服务器之间的连接

Java判断多个时间段是否重合的方法小结

《Java判断多个时间段是否重合的方法小结》这篇文章主要为大家详细介绍了Java中判断多个时间段是否重合的方法,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录判断多个时间段是否有间隔判断时间段集合是否与某时间段重合判断多个时间段是否有间隔实体类内容public class D

Python使用国内镜像加速pip安装的方法讲解

《Python使用国内镜像加速pip安装的方法讲解》在Python开发中,pip是一个非常重要的工具,用于安装和管理Python的第三方库,然而,在国内使用pip安装依赖时,往往会因为网络问题而导致速... 目录一、pip 工具简介1. 什么是 pip?2. 什么是 -i 参数?二、国内镜像源的选择三、如何

IDEA编译报错“java: 常量字符串过长”的原因及解决方法

《IDEA编译报错“java:常量字符串过长”的原因及解决方法》今天在开发过程中,由于尝试将一个文件的Base64字符串设置为常量,结果导致IDEA编译的时候出现了如下报错java:常量字符串过长,... 目录一、问题描述二、问题原因2.1 理论角度2.2 源码角度三、解决方案解决方案①:StringBui

Linux使用nload监控网络流量的方法

《Linux使用nload监控网络流量的方法》Linux中的nload命令是一个用于实时监控网络流量的工具,它提供了传入和传出流量的可视化表示,帮助用户一目了然地了解网络活动,本文给大家介绍了Linu... 目录简介安装示例用法基础用法指定网络接口限制显示特定流量类型指定刷新率设置流量速率的显示单位监控多个

Java覆盖第三方jar包中的某一个类的实现方法

《Java覆盖第三方jar包中的某一个类的实现方法》在我们日常的开发中,经常需要使用第三方的jar包,有时候我们会发现第三方的jar包中的某一个类有问题,或者我们需要定制化修改其中的逻辑,那么应该如何... 目录一、需求描述二、示例描述三、操作步骤四、验证结果五、实现原理一、需求描述需求描述如下:需要在

JavaScript中的reduce方法执行过程、使用场景及进阶用法

《JavaScript中的reduce方法执行过程、使用场景及进阶用法》:本文主要介绍JavaScript中的reduce方法执行过程、使用场景及进阶用法的相关资料,reduce是JavaScri... 目录1. 什么是reduce2. reduce语法2.1 语法2.2 参数说明3. reduce执行过程

C#中读取XML文件的四种常用方法

《C#中读取XML文件的四种常用方法》Xml是Internet环境中跨平台的,依赖于内容的技术,是当前处理结构化文档信息的有力工具,下面我们就来看看C#中读取XML文件的方法都有哪些吧... 目录XML简介格式C#读取XML文件方法使用XmlDocument使用XmlTextReader/XmlTextWr

C++初始化数组的几种常见方法(简单易懂)

《C++初始化数组的几种常见方法(简单易懂)》本文介绍了C++中数组的初始化方法,包括一维数组和二维数组的初始化,以及用new动态初始化数组,在C++11及以上版本中,还提供了使用std::array... 目录1、初始化一维数组1.1、使用列表初始化(推荐方式)1.2、初始化部分列表1.3、使用std::

oracle DBMS_SQL.PARSE的使用方法和示例

《oracleDBMS_SQL.PARSE的使用方法和示例》DBMS_SQL是Oracle数据库中的一个强大包,用于动态构建和执行SQL语句,DBMS_SQL.PARSE过程解析SQL语句或PL/S... 目录语法示例注意事项DBMS_SQL 是 oracle 数据库中的一个强大包,它允许动态地构建和执行