三种梯度下降算法的区别(BGD, SGD, MBGD)

2023-10-21 06:30

本文主要是介绍三种梯度下降算法的区别(BGD, SGD, MBGD),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

前言

我们在训练网络的时候经常会设置 batch_size,这个 batch_size 究竟是做什么用的,一万张图的数据集,应该设置为多大呢,设置为 1、10、100 或者是 10000 究竟有什么区别呢?

# 手写数字识别网络训练方法
network.fit(train_images,train_labels,epochs=5,batch_size=128)

批量梯度下降(Batch Gradient Descent,BGD)

梯度下降算法一般用来最小化损失函数:把原始的数据网络喂给网络,网络会进行一定的计算,会求得一个损失函数,代表着网络的计算结果与实际的差距,梯度下降算法用来调整参数,使得训练出的结果与实际更好的拟合,这是梯度下降的含义。

批量梯度下降是梯度下降最原始的形式,它的思想是使用所有的训练数据一起进行梯度的更新,梯度下降算法需要对损失函数求导数,可以想象,如果训练数据集比较大,所有的数据需要一起读入进来,一起在网络中去训练,一起求和,会是一个庞大的矩阵,这个计算量将非常巨大。当然,这也是有优点的,那就是因为考虑到所有训练集的情况,因此网络一定在向最优(极值)的方向在优化。

随机梯度下降(Stochastic Gradient Descent,SGD)

与批量梯度下降不同,随机梯度下降的思想是每次拿出训练集中的一个,进行拟合训练,进行迭代去训练。训练的过程就是先拿出一个训练数据,网络修改参数去拟合它并修改参数,然后拿出下一个训练数据,用刚刚修改好的网络再去拟合和修改参数,如此迭代,直到每个数据都输入过网络,再从头再来一遍,直到参数比较稳定,优点就是每次拟合都只用了一个训练数据,每一轮更新迭代速度特别快,缺点是每次进行拟合的时候,只考虑了一个训练数据,优化的方向不一定是网络在训练集整体最优的方向,经常会抖动或收敛到局部最优。

小批量梯度下降(Mini-Batch Gradient Descent,MBGD)

小批量梯度下降采用的还是计算机中最常用的折中的解决办法,每次输入网络进行训练的既不是训练数据集全体,也不是训练数据集中的某一个,而是其中的一部分,比如每次输入 20 个。可以想象,这既不会造成数据量过大计算缓慢,也不会因为某一个训练样本的某些噪声特点引起网络的剧烈抖动或向非最优的方向优化。

对比一下这三种梯度下降算法的计算方式:批量梯度下降是大矩阵的运算,可以考虑采用矩阵计算优化的方式进行并行计算,对内存等硬件性能要求较高;随机梯度下降每次迭代都依赖于前一次的计算结果,因此无法并行计算,对硬件要求较低;而小批量梯度下降,每一个次迭代中,都是一个较小的矩阵,对硬件的要求也不高,同时矩阵运算可以采用并行计算,多次迭代之间采用串行计算,整体来说会节省时间。

看下面一张图,可以较好的体现出三种剃度下降算法优化网络的迭代过程,会有一个更加直观的印象。

对比图

总结

梯度下降算法的调优,训练数据集很小,直接采用批量梯度下降;每次只能拿到一个训练数据,或者是在线实时传输过来的训练数据,采用随机梯度下降;其他情况或一般情况采用批量梯度下降算法更好。

  • 本文首发自: RAIS

这篇关于三种梯度下降算法的区别(BGD, SGD, MBGD)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/252560

相关文章

pytorch自动求梯度autograd的实现

《pytorch自动求梯度autograd的实现》autograd是一个自动微分引擎,它可以自动计算张量的梯度,本文主要介绍了pytorch自动求梯度autograd的实现,具有一定的参考价值,感兴趣... autograd是pytorch构建神经网络的核心。在 PyTorch 中,结合以下代码例子,当你

go 指针接收者和值接收者的区别小结

《go指针接收者和值接收者的区别小结》在Go语言中,值接收者和指针接收者是方法定义中的两种接收者类型,本文主要介绍了go指针接收者和值接收者的区别小结,文中通过示例代码介绍的非常详细,需要的朋友们下... 目录go 指针接收者和值接收者的区别易错点辨析go 指针接收者和值接收者的区别指针接收者和值接收者的

售价599元起! 华为路由器X1/Pro发布 配置与区别一览

《售价599元起!华为路由器X1/Pro发布配置与区别一览》华为路由器X1/Pro发布,有朋友留言问华为路由X1和X1Pro怎么选择,关于这个问题,本期图文将对这二款路由器做了期参数对比,大家看... 华为路由 X1 系列已经正式发布并开启预售,将在 4 月 25 日 10:08 正式开售,两款产品分别为华

如何将Python彻底卸载的三种方法

《如何将Python彻底卸载的三种方法》通常我们在一些软件的使用上有碰壁,第一反应就是卸载重装,所以有小伙伴就问我Python怎么卸载才能彻底卸载干净,今天这篇文章,小编就来教大家如何彻底卸载Pyth... 目录软件卸载①方法:②方法:③方法:清理相关文件夹软件卸载①方法:首先,在安装python时,下

openCV中KNN算法的实现

《openCV中KNN算法的实现》KNN算法是一种简单且常用的分类算法,本文主要介绍了openCV中KNN算法的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的... 目录KNN算法流程使用OpenCV实现KNNOpenCV 是一个开源的跨平台计算机视觉库,它提供了各

Redis实现延迟任务的三种方法详解

《Redis实现延迟任务的三种方法详解》延迟任务(DelayedTask)是指在未来的某个时间点,执行相应的任务,本文为大家整理了三种常见的实现方法,感兴趣的小伙伴可以参考一下... 目录1.前言2.Redis如何实现延迟任务3.代码实现3.1. 过期键通知事件实现3.2. 使用ZSet实现延迟任务3.3

Java图片压缩三种高效压缩方案详细解析

《Java图片压缩三种高效压缩方案详细解析》图片压缩通常涉及减少图片的尺寸缩放、调整图片的质量(针对JPEG、PNG等)、使用特定的算法来减少图片的数据量等,:本文主要介绍Java图片压缩三种高效... 目录一、基于OpenCV的智能尺寸压缩技术亮点:适用场景:二、JPEG质量参数压缩关键技术:压缩效果对比

springboot+dubbo实现时间轮算法

《springboot+dubbo实现时间轮算法》时间轮是一种高效利用线程资源进行批量化调度的算法,本文主要介绍了springboot+dubbo实现时间轮算法,文中通过示例代码介绍的非常详细,对大家... 目录前言一、参数说明二、具体实现1、HashedwheelTimer2、createWheel3、n

kotlin中const 和val的区别及使用场景分析

《kotlin中const和val的区别及使用场景分析》在Kotlin中,const和val都是用来声明常量的,但它们的使用场景和功能有所不同,下面给大家介绍kotlin中const和val的区别,... 目录kotlin中const 和val的区别1. val:2. const:二 代码示例1 Java

CSS Padding 和 Margin 区别全解析

《CSSPadding和Margin区别全解析》CSS中的padding和margin是两个非常基础且重要的属性,它们用于控制元素周围的空白区域,本文将详细介绍padding和... 目录css Padding 和 Margin 全解析1. Padding: 内边距2. Margin: 外边距3. Padd