随机梯度下降SGD的理解和现象分析

2024-05-09 20:04

本文主要是介绍随机梯度下降SGD的理解和现象分析,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

提出问题:令人疑惑的损失值

在某次瞎炼丹的过程中,出现了如下令人疑惑的损失值变化图像:

嗯,看起来还挺工整,来看看前10轮打印的具体损失值变化:

| epoch 1 |  iter 5 / 10 | time 1[s] | loss 2.3137 | lr 0.0010
| epoch 1 |  iter 10 / 10 | time 1[s] | loss 2.2976 | lr 0.0010
| epoch 2 |  iter 5 / 10 | time 1[s] | loss 2.3135 | lr 0.0010
| epoch 2 |  iter 10 / 10 | time 1[s] | loss 2.2973 | lr 0.0010
| epoch 3 |  iter 5 / 10 | time 1[s] | loss 2.3132 | lr 0.0010
| epoch 3 |  iter 10 / 10 | time 1[s] | loss 2.2970 | lr 0.0010
| epoch 4 |  iter 5 / 10 | time 1[s] | loss 2.3129 | lr 0.0010
| epoch 4 |  iter 10 / 10 | time 1[s] | loss 2.2968 | lr 0.0010
| epoch 5 |  iter 5 / 10 | time 1[s] | loss 2.3127 | lr 0.0010
| epoch 5 |  iter 10 / 10 | time 1[s] | loss 2.2965 | lr 0.0010
| epoch 6 |  iter 5 / 10 | time 1[s] | loss 2.3124 | lr 0.0010
| epoch 6 |  iter 10 / 10 | time 1[s] | loss 2.2962 | lr 0.0010
| epoch 7 |  iter 5 / 10 | time 1[s] | loss 2.3122 | lr 0.0010
| epoch 7 |  iter 10 / 10 | time 1[s] | loss 2.2960 | lr 0.0010
| epoch 8 |  iter 5 / 10 | time 1[s] | loss 2.3119 | lr 0.0010
| epoch 8 |  iter 10 / 10 | time 1[s] | loss 2.2957 | lr 0.0010
| epoch 9 |  iter 5 / 10 | time 1[s] | loss 2.3116 | lr 0.0010
| epoch 9 |  iter 10 / 10 | time 1[s] | loss 2.2954 | lr 0.0010
| epoch 10 |  iter 5 / 10 | time 1[s] | loss 2.3114 | lr 0.0010
| epoch 10 |  iter 10 / 10 | time 1[s] | loss 2.2952 | lr 0.0010

可以明显看到两列出现递减的子序列:奇数序列和偶数序列。奇数序列的损失值为2.3137, 2.3135, 2.3132, 2.3129,...;奇数序列的损失值为2.2976, 2.2973, 2.2970, 2.2968,...。事出反常必有妖,那么究竟是什么样的东西导致如此的怪象?

在尝试找具体的原因之前,我们先把涉及的具体参数描述清楚。

模型就是一个很简单的序列模型,其网络结构如下:

layers = [MatMul(W1), Sigmoid(), MatMul(W2), Sigmoid(), MSE()]

网络结构就是两层重复结构,单层为一个矩阵乘法层MatMul加上一个激活函数Sigmoid,两层计算完后用均方误差MSE计算损失值,其中参数W1,W2的赋值如下:

rn = np.random.randn
W1 = (rn(10, 1000)).astype(np.float32)
W2 = (rn(1000, 10)).astype(np.float32)

数据和标签的赋值如下:

x = (rn(1000, 10)).astype(np.float32)
t = x**2

数据就是按照正态分布随机化初始1000个10维的向量,而标签就是原来的向量按元素乘方,而炼丹的目的就是观察模型如何学习二次函数的运算法则的。
相关训练的参数如下:

epochs = 100
batch_size = 100
eval_interval = 5
lr = 0.001

训练一共进行100轮,每一轮的每一批数据有100个,对于1000个数据,那么单个轮次可以分10个批次。每个批次都会计算当前批次100个数据的平均损失值,5个批次评估一次平均损失值,然后打印出来。也就是单个轮次可以看到2次打印出来的评估数据。

显然,第1次评估的平均损失值是用前一半的数据计算出来的,而第2次的则是后一半的数据进行运算。那么可以简单猜测:造成如此令人困惑的损失值变化图像,很可能原因就在数据分批上。

本质思考:推导数学公式解释

我们先把模型抽象为数学上的函数 F F F,其具体形式如下:
L o s s = F ( x , t , w ) Loss = F(x,t,w) Loss=F(x,t,w)
其中, x x x为数据, t t t为标签, w w w为权重, L o s s Loss Loss为损失值。
考虑到数据分批,对数据分成 m m m批的情况,实际上存在 m m m个子函数,如下:
L 1 = F 1 ( x 1 , t 1 , w ) L 2 = F 2 ( x 2 , t 2 , w ) L 3 = F 3 ( x 3 , t 3 , w ) . . . L m = F m ( x m , t m , w ) \begin{matrix} L_{1} = F_{1} (x_{1},t_{1},w)\\L_{2} = F_{2} (x_{2},t_{2},w) \\L_{3} = F_{3} (x_{3},t_{3},w) \\... \\L_{m} = F_{m} (x_{m},t_{m},w) \end{matrix} L1=F1(x1,t1,w)L2=F2(x2,t2,w)L3=F3(x3,t3,w)...Lm=Fm(xm,tm,w)
如果将 w ( i , j ) w_{(i,j)} w(i,j)表示为第 i 轮 i轮 i j j j批的权重值,那么很显然对第 i i i轮的训练批次来说,存在如下关系:
w i , 0 = w i − 1 , m w i , 1 = w i , 0 + k ∂ F 1 ∂ w ∣ w = w i , 0 w i , 2 = w i , 1 + k ∂ F 2 ∂ w ∣ w = w i , 1 w i , 3 = w i , 2 + k ∂ F 3 ∂ w ∣ w = w i , 2 . . . w i , m = w i , m − 1 + k ∂ F m ∂ w ∣ w = w i , m − 1 \begin{matrix} w_{i,0}=w_{i-1,m}\\w_{i,1} = w_{i,0}+k\frac{\partial F_{1}}{\partial w}|_{w=w_{i,0}} \\w_{i,2} = w_{i,1}+k\frac{\partial F_{2}}{\partial w}|_{w=w_{i,1}} \\w_{i,3} = w_{i,2}+k\frac{\partial F_{3}}{\partial w}|_{w=w_{i,2}} \\... \\w_{i,m} = w_{i,m-1}+k\frac{\partial F_{m}}{\partial w}|_{w=w_{i,m-1}} \end{matrix} wi,0=wi1,mwi,1=wi,0+kwF1w=wi,0wi,2=wi,1+kwF2w=wi,1wi,3=wi,2+kwF3w=wi,2...wi,m=wi,m1+kwFmw=wi,m1
其中 k k k为学习率的相反数,且一般情况下取值都较小(如取 k = − 0.001 k=-0.001 k=0.001)。考虑到 k k k取值较小,所以有如下近似公式:
w i , 0 = w i − 1 , m w i , 1 = w i , 0 + k ∂ F 1 ∂ w ∣ w = w i , 0 w i , 2 ≈ w i , 1 + k ∂ F 2 ∂ w ∣ w = w i , 0 w i , 3 ≈ w i , 2 + k ∂ F 3 ∂ w ∣ w = w i , 0 . . . w i , m ≈ w i , m − 1 + k ∂ F m ∂ w ∣ w = w i , 0 \begin{matrix} w_{i,0}=w_{i-1,m}\\w_{i,1} = w_{i,0}+k\frac{\partial F_{1}}{\partial w}|_{w=w_{i,0}} \\w_{i,2} \approx w_{i,1}+k\frac{\partial F_{2}}{\partial w}|_{w=w_{i,0}} \\w_{i,3} \approx w_{i,2}+k\frac{\partial F_{3}}{\partial w}|_{w=w_{i,0}} \\... \\w_{i,m} \approx w_{i,m-1}+k\frac{\partial F_{m}}{\partial w}|_{w=w_{i,0}} \end{matrix} wi,0=wi1,mwi,1=wi,0+kwF1w=wi,0wi,2wi,1+kwF2w=wi,0wi,3wi,2+kwF3w=wi,0...wi,mwi,m1+kwFmw=wi,0
从而进一步得到如下具体的近似公式:
w i , j ≈ w i − 1 , j + ∑ t = 1 m k ∂ F t ∂ w ∣ w = w i − 1 , j w_{i,j} \approx w_{i-1,j}+\sum_{t=1}^{m} k\frac{\partial F_{t}}{\partial w}|_{w=w_{i-1,j}} wi,jwi1,j+t=1mkwFtw=wi1,j
为了直观得到结论,采用如下表示:
v t = k ∂ F t ∂ w ∣ w = w i − 1 , j v_{t} = k\frac{\partial F_{t}}{\partial w}|_{w=w_{i-1,j}} vt=kwFtw=wi1,j
那么之前的表达式就可以简写为:
w i , j ≈ w i − 1 , j + ∑ t = 1 m v t w_{i,j} \approx w_{i-1,j}+\sum_{t=1}^{m} v_{t} wi,jwi1,j+t=1mvt
对于 w i , j w_{i,j} wi,j来说, v j v_{j} vj才是其让损失值下降最快的方向,其他的向量代表其他批的数据,往往得到的方向与该方向比较随机,最后得到的和可能趋于0或者其他损失值下降不太快的方向。

因此,要想让第 j j j批的数据对应的损失值稳定下降,还得靠一轮一轮的循环才行,靠同一轮的其他批次是不太合理的(只有一部分情况才能如此)

合理外推:实验数据验证想法

如果看懂了前面的数学推导,那么很自然就能想到:对于批次 m m m较大的情况下,损失函数图像会呈现整体趋势下降的条带,如下图:

其中训练参数改动如下:

x = (rn(2000, 10)).astype(np.float32)
t = x**2
epochs = 200

你说啥?数学推导没看懂?那也没关系,其实到最后只是为了说明一个事情:你把训练数据分成很多个批次去炼丹,对于具体的某个批次的损失值下降,主要是依赖该批次的下一轮迭代,而不是同一轮的其他批次。

如果你感觉条带形状的损失值碍眼,感觉损失值起起伏伏的,很多计算资源都浪费了,那么用一招就能“瞒天过海”:把损失值的评估计算改为一整轮的平均损失,比如有 m m m批数据,那么统计损失值时使用这 m m m个批次的损失值总平均值即可,效果绝对立竿见影:

其中训练参数改动如下:

x = (rn(2000, 10)).astype(np.float32)
t = x**2
epochs = 200
batch_size = 100
eval_interval = 20

这参数里面,一共有2000个数据,100个数据为1批,共20批数据,然后20批数据评估一次整体平均损失值,训练200轮。

这篇关于随机梯度下降SGD的理解和现象分析的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/974360

相关文章

关于 ZigBee WiFi Bluetooth 的分析

ZigBee: Zigbee是IEEE 802.15.4协议的代名词。根据这个协议规定的技术是一种短距离、低功耗的无线通信技术。其特点是近距离、低复杂度、自组织、低功耗、低数据速率、低成本。主要适合用于自动控制和远程控制领域,可以嵌入各种设备。简而言之,ZigBee就是一种便宜的,低功耗的近距离无线组网通讯技术。             ZigBee的主要优势:

【Python3-API】情感倾向分析示例代码

Python3-urllib3-API情感倾向分析示例代码 AccessToken获取可以参考:http://ai.baidu.com/forum/topic/show/497663(Python3-urllib3示例)Python安装什么的。大家百度经验即可 -----------------------------------------------------下面开始代码-----

JAVA中死锁例子分析和上下文切换的实战

本篇是整理《java多线程编程核心技术》和《java并发编程的艺术》中对于死锁和上下文切换的查看命令的总结。 一、死锁 1.死锁的图解 死锁 死锁是两个甚至多个线程被永久阻塞时的一种运行局面。死锁的原因:由于两个甚至多个线程相互等待对方已被锁定的资源。 2.死锁的例子 public class DeadThreadLockTest implements Runnable{priv

DoNet:浅淡对delegate的理解

1 前言 C#的相关文档,MSDN上其实已经很详细了,关于delegate的使用可以参 考MSDN上的文档https://msdn.microsoft.com/zh-cn/library/900fyy8e.aspx 2 官方示例 委托类型的声明与方法签名相似, 有一个返回值和任意数目任意类型的参数: public delegate void TestDelegate(string mes

u-boot 源码分析(1) 启动过程分析

u-boot 源码分析(1) 启动过程分析 文章目录 u-boot 源码分析(1) 启动过程分析前言配置源码结构apiarchboardcommoncmddriversfs Kbuild启动过程第一阶段start.Scrt0_64.S 第二阶段board_f.cboard_r.c 总结参考 前言 对于uboot,我一直是云里雾里的一个状态,这部分让我感到自己很菜,不用纵向深

STM32 标准库V3.5启动文件startup_stm32f10xxx.s分析

layout: post tags: [STM32] comments: true 文章目录 layout: post tags: [STM32] comments: true前言分析startup_stm32f10x_md.s汇编指令EQUAREASPACEDCDPROCLDR常见的转移指令 源码分析 前言 这里以stm32f103为硬件平台,搭建了stm32f1系列

Pytorch中的梯度下降及优化

在PyTorch中使用Mini-batch这种方法进行训练 Mini-batch的梯度下降法 对整个训练集进行梯度下降法的时候,我们必须处理整个训练数据集,然后才能进行一步梯度下降,即每一步梯度下降法需要对整个训练集进行一次处理,如果训练数据集很大的时候处理速度会很慢,而且也不可能一次的载入到内存或者显存中 所以我们会把大数据集分成小数据集,一部分一部分的训练,这个训练子集即称为Mini-b

【Ruby】Ruby框架的使用(Ruby on Rails项目结构分析)

在阅读的过程中有任何问题,欢迎一起交流 邮箱:1494713801@qq.com    QQ:1494713801   目录结构 Ruby on Rails框架项目目录结构如下图所示: app:是web应用的主程序目录,包含了mvc组件的代码(controllers、models、views)config:runtime rules,routes,database等的配置db

理解 Python 中的 `super()` 与 `__init__()` 方法

在 Python 的面向对象编程中,super() 函数和 __init__() 方法是两个非常重要的概念。它们在类的继承和初始化过程中扮演着关键的角色。本文将深入探讨这两个概念的工作原理,并通过示例代码来展示它们的使用。 基本原理 __init__() 方法 __init__() 是一个特殊的方法,也称为类的构造器。当你创建一个类的新实例时,Python 会自动调用这个方法。它通常用于初始

24长三角A题思路+分析选题

需要资料的宝子们可以进企鹅获取 A题 问题1:西湖游船上掉落华为 mate 60 pro 手机 1. 手机掉落范围分析 物品特征:华为 mate 60 pro 手机的尺寸、重量、形状等特性。静水假设:西湖水面平静,不考虑水流影响。掉落位置:游客在游船上的具体位置,如船头、船尾、船边等。手机落水后行为:根据手机的重量和形状,它可能会沉入水底或漂浮在水面。 基于上述信息,可以初步判断