DL基础补全计划(五)---数值稳定性及参数初始化(梯度消失、梯度爆炸)

本文主要是介绍DL基础补全计划(五)---数值稳定性及参数初始化(梯度消失、梯度爆炸),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

PS:要转载请注明出处,本人版权所有。

PS: 这个只是基于《我自己》的理解,

如果和你的原则及想法相冲突,请谅解,勿喷。

环境说明
  • Windows 10
  • VSCode
  • Python 3.8.10
  • Pytorch 1.8.1
  • Cuda 10.2

前言


  如果有计算机背景的相关童鞋,都应该知道数值计算中的上溢和下溢的问题。关于计算机中的数值表示,在我的《数与计算机 (编码、原码、反码、补码、移码、IEEE 754、定点数、浮点数)》 (https://blog.csdn.net/u011728480/article/details/100277582) 一文中有比较好的介绍。计算机中的数值表示,相对于实数数轴来说是离散且有限的,意思就是计算机中的能表示的数有最大值和最小值以及最小单位,特别是浮点数表示,有兴趣的可以看看上文。

  其实很好理解,深度学习里面具有大量的乘法加法,一不小心你就会遇见上溢和下溢的问题,因此我们一不小心就会遇见NAN和INF的问题(NAN和INF详见上文提到的文章)。此外,由于一些特殊的情况,可能会导致我们的参数的偏导数接近于0,让我们的模型收敛的非常的慢。因此我们可能需要从模型的初始化以及相关的模型构造方面来好好的讨论一下我们在训练过程中可能出现的问题。

  一般来说,我们训练的时候都非常的关注我们的损失函数,如果损失函数值异常,会导致相关的偏导数出现接近于0或者接近于无限大,那么就会直接导致模型训练及其困难。此外,我们的权重参数也会参与网络计算,按照上述的描述,权重参数的初始值也可能导致损失函数的值异常。因此大佬们也引入了另外一种常见的初始化方式Xavier,比较具有普适性。下面我们简单的验证一下我们训练过程中出现梯度接近于0和接近于无限大的情况,这里也就是说的梯度消失和梯度爆炸问题。同时也简单说明参数初始化相关的问题。





梯度消失(gradient vanishing)


  在深度学习中有一个激活层叫做Sigmoid层,其定义如下是: S i g m o i d ( x ) = 1 / ( 1 + exp ⁡ ( − x ) ) Sigmoid(x)=1/(1+\exp(-x)) Sigmoid(x)=1/(1+exp(x)),如果我们的模型里面接入了这种激活函数,很容易构造出梯度消失的情况,下面我们看一下其导数和函数值相对于X的相关关系。

  代码如下:

import torch
import numpy as np
import matplotlib.pyplot as pltfig, ax = plt.subplots()
xdata, ydata = [[], []], [[], []]
line0, = ax.plot([], [], 'r-', label='sigmoid')
line1, = ax.plot([], [], 'b-', label='gradient-sigmoid')def init_and_show(xlim_min, xlim_max, ylim_min, ylim_max):ax.set_xlabel('x')ax.set_ylabel('sigmoid(x)')ax.set_title('sigmoid/gradient-sigmoid')ax.set_xlim(xlim_min, xlim_max)ax.set_ylim(ylim_min, ylim_max)ax.legend([line0, line1], ('sigmoid', 'gradient-sigmoid'))line0.set_data(xdata[0], ydata[0])line1.set_data(xdata[1], ydata[1])plt.show()def sigmoid_test():x = np.arange(-10.0, 10.0, 0.1)x = torch.tensor(x, dtype=torch.float, requires_grad=True)sig_fun = torch.nn.Sigmoid()y = sig_fun(x)y.backward(torch.ones_like(y))xdata[0] = x.detach().numpy()xdata[1] = x.detach().numpy()ydata[0] = y.detach().numpy()ydata[1] = x.grad.detach().numpy()init_and_show(-10.0, 10.0, 0, 1)def multi_mat_dot():M = np.random.normal(size=(4, 4))print('⼀个矩阵\n', M)for i in range(10000):M = np.dot(M, np.random.normal(size=(4, 4)))print('乘以100个矩阵后\n', M)if __name__ == '__main__':sigmoid_test()

  结果图如下

rep_img

  我们可以从图中看到,当x小于-5和大于+5的时候,其导数的值接近于0,导致bp的时候,参数更新小,模型收敛的特别的慢。





梯度爆炸(gradient exploding)


  现在我们假设我们有一个模型,其有N个线性层构成,定义输入为X,标签为Y,模型为 M ( X ) = X ∗ W 1 . . . . W n − 2 ∗ W n − 1 ∗ W n M(X) = X*W_1 .... W_{n-2}*W_{n-1}*W_n M(X)=XW1....Wn2Wn1Wn,损失函数为 L ( X ) = M ( X ) − Y = X ∗ W 1 . . . . W n − 2 ∗ W n − 1 ∗ W n − Y L(X) = M(X) - Y = X*W_1 .... W_{n-2}*W_{n-1}*W_n - Y L(X)=M(X)Y=XW1....Wn2Wn1WnY,求W1关于损失函数的偏导数 d L ( X ) d W 1 = X ∗ W 2 . . . . W n − 2 ∗ W n − 1 ∗ W n \frac{dL(X)}{dW_1} = X*W_2 .... W_{n-2}*W_{n-1}*W_n dW1dL(X)=XW2....Wn2Wn1Wn。从这里我们可以看到W2到Wn与输入的X的乘积构成了W1的偏导数。

  下面我们简单的构造一个矩阵,然后让他计算100次乘法。代码如下:

import torch
import numpy as np
import matplotlib.pyplot as pltfig, ax = plt.subplots()
xdata, ydata = [[], []], [[], []]
line0, = ax.plot([], [], 'r-', label='sigmoid')
line1, = ax.plot([], [], 'b-', label='gradient-sigmoid')def init_and_show(xlim_min, xlim_max, ylim_min, ylim_max):ax.set_xlabel('x')ax.set_ylabel('sigmoid(x)')ax.set_title('sigmoid/gradient-sigmoid')ax.set_xlim(xlim_min, xlim_max)ax.set_ylim(ylim_min, ylim_max)ax.legend([line0, line1], ('sigmoid', 'gradient-sigmoid'))line0.set_data(xdata[0], ydata[0])line1.set_data(xdata[1], ydata[1])plt.show()def sigmoid_test():x = np.arange(-10.0, 10.0, 0.1)x = torch.tensor(x, dtype=torch.float, requires_grad=True)sig_fun = torch.nn.Sigmoid()y = sig_fun(x)y.backward(torch.ones_like(y))xdata[0] = x.detach().numpy()xdata[1] = x.detach().numpy()ydata[0] = y.detach().numpy()ydata[1] = x.grad.detach().numpy()init_and_show(-10.0, 10.0, 0, 1)def multi_mat_dot():M = np.random.normal(size=(4, 4))print('⼀个矩阵\n', M)for i in range(100):M = np.dot(M, np.random.normal(size=(4, 4)))print('乘以100个矩阵后\n', M)if __name__ == '__main__':multi_mat_dot()

  他计算100次乘法后结果如下:

rep_img

  我们可以看到,经过100次乘法后,其值已经非常大(小)了指数都是到了25了。这个时候算出来的损失非常大的,这个时候梯度也非常大,很容易导致训练异常。





参数初始化之Xavier


  文首我们提到,我们之前的参数初始化都是基于期望为0,方差为一个指定值初始化的,这里面的指定值是随个人定义的,这个可能会给我们的训练过程带来困扰。

  但是我们可以从以下的角度来看待这个事情,我们的权重参数W是一个期望为0,方差为 δ 2 \delta^2 δ2的特定分布。我们的输入特征X是一个期望为0,方差为 λ 2 \lambda^2 λ2的特定分布(注意这里不仅仅是正态分布)。我们假设我们的模型是线性模型,那么其输出为: O i = ∑ j = 1 n W i j X j O_i = \sum\limits_{j=1}^{n}W_{ij}X_{j} Oi=j=1nWijXj O i O_i Oi是代表第i层的输出。这个时候,我们求出 O i O_i Oi的期望是: E ( O i ) = ∑ j = 1 n E ( W i j X j ) = ∑ j = 1 n E ( W i j ) E ( X j ) = 0 E(O_i) = \sum\limits_{j=1}^{n}E(W_{ij}X_{j}) = \sum\limits_{j=1}^{n}E(W_{ij})E(X_{j}) = 0 E(Oi)=j=1nE(WijXj)=j=1nE(Wij)E(Xj)=0,其方差为: V a r i a n c e ( O i ) = E ( O i 2 ) − ( E ( O i ) ) 2 = ∑ j = 1 n E ( W i j 2 X j 2 ) − 0 = ∑ j = 1 n E ( W i j 2 ) E ( X j 2 ) = n ∗ δ 2 ∗ λ 2 Variance(O_i) = E(O_i^2) - (E(O_i))^2 = \sum\limits_{j=1}^{n}E(W_{ij}^2X_{j}^2) - 0 = \sum\limits_{j=1}^{n}E(W_{ij}^2)E(X_{j}^2) = n*\delta^2*\lambda^2 Variance(Oi)=E(Oi2)(E(Oi))2=j=1nE(Wij2Xj2)0=j=1nE(Wij2)E(Xj2)=nδ2λ2。我们现在假设如果要 O i O_i Oi的方差等于X的方差,那么 n ∗ δ 2 = 1 n*\delta^2 = 1 nδ2=1才能够满足要求。现在我们考虑BP的时候,也需要 n o u t ∗ δ 2 = 1 n_{out}*\delta^2 = 1 noutδ2=1才能够保证方差不会变,至少从数值稳定性来说,我们应该保证方差尽量稳定,不应该放大。我们同时考虑n和 n o u t n_{out} nout,那么我们可以认为当 1 / 2 ∗ ( n + n o u t ) ∗ δ 2 = 1 1/2*(n+n_{out})*\delta^2 = 1 1/2(n+nout)δ2=1时,我们保证了输出O的方差在约定范围内,尽量保证了其数值的稳定性,这就是Xavier方法的核心内容。

  初始化方法有很多,但是Xavier方法有较大的普适性。对于某些模型,特定的初始化方法有奇效。





后记


  到本文结束,其实我们可以训练一些简单的模型了,但是本文所介绍的3个概念会一直伴随着我们以后的学习过程,如果训练出现了INF,NAN这些特殊的值,基本我们就需要往这方面去想和解决问题。

参考文献

  • https://github.com/d2l-ai/d2l-zh/releases (V1.0.0)
  • https://github.com/d2l-ai/d2l-zh/releases (V2.0.0 alpha1)
  • https://blog.csdn.net/u011728480/article/details/100277582 《数与计算机 (编码、原码、反码、补码、移码、IEEE 754、定点数、浮点数)》



打赏、订阅、收藏、丢香蕉、硬币,请关注公众号(攻城狮的搬砖之路)
qrc_img

PS: 请尊重原创,不喜勿喷。

PS: 要转载请注明出处,本人版权所有。

PS: 有问题请留言,看到后我会第一时间回复。

这篇关于DL基础补全计划(五)---数值稳定性及参数初始化(梯度消失、梯度爆炸)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1065761

相关文章

RedHat运维-Linux文本操作基础-AWK进阶

你不用整理,跟着敲一遍,有个印象,然后把它保存到本地,以后要用再去看,如果有了新东西,你自个再添加。这是我参考牛客上的shell编程专项题,只不过换成了问答的方式而已。不用背,就算是我自己亲自敲,我现在好多也记不住。 1. 输出nowcoder.txt文件第5行的内容 2. 输出nowcoder.txt文件第6行的内容 3. 输出nowcoder.txt文件第7行的内容 4. 输出nowcode

Vim使用基础篇

本文内容大部分来自 vimtutor,自带的教程的总结。在终端输入vimtutor 即可进入教程。 先总结一下,然后再分别介绍正常模式,插入模式,和可视模式三种模式下的命令。 目录 看完以后的汇总 1.正常模式(Normal模式) 1.移动光标 2.删除 3.【:】输入符 4.撤销 5.替换 6.重复命令【. ; ,】 7.复制粘贴 8.缩进 2.插入模式 INSERT

零基础STM32单片机编程入门(一)初识STM32单片机

文章目录 一.概要二.单片机型号命名规则三.STM32F103系统架构四.STM32F103C8T6单片机启动流程五.STM32F103C8T6单片机主要外设资源六.编程过程中芯片数据手册的作用1.单片机外设资源情况2.STM32单片机内部框图3.STM32单片机管脚图4.STM32单片机每个管脚可配功能5.单片机功耗数据6.FALSH编程时间,擦写次数7.I/O高低电平电压表格8.外设接口

ABAP怎么把传入的参数刷新到内表里面呢?

1.在执行相关的功能操作之前,优先执行这一段代码,把输入的数据更新入内表里面 DATA: lo_guid TYPE REF TO cl_gui_alv_grid.CALL FUNCTION 'GET_GLOBALS_FROM_SLVC_FULLSCR'IMPORTINGe_grid = lo_guid.CALL METHOD lo_guid->check_changed_data.CALL M

ps基础入门

1.基础      1.1新建文件      1.2创建指定形状      1.4移动工具          1.41移动画布中的任意元素          1.42移动画布          1.43修改画布大小          1.44修改图像大小      1.5框选工具      1.6矩形工具      1.7图层          1.71图层颜色修改          1

Linux系统稳定性的奥秘:探究其背后的机制与哲学

在计算机操作系统的世界里,Linux以其卓越的稳定性和可靠性著称,成为服务器、嵌入式系统乃至个人电脑用户的首选。那么,是什么造就了Linux如此之高的稳定性呢?本文将深入解析Linux系统稳定性的几个关键因素,揭示其背后的技术哲学与实践。 1. 开源协作的力量Linux是一个开源项目,意味着任何人都可以查看、修改和贡献其源代码。这种开放性吸引了全球成千上万的开发者参与到内核的维护与优化中,形成了

Java面试八股之JVM参数-XX:+UseCompressedOops的作用

JVM参数-XX:+UseCompressedOops的作用 JVM参数-XX:+UseCompressedOops的作用是启用对象指针压缩(Ordinary Object Pointers compression)。这一特性主要应用于64位的Java虚拟机中,目的是为了减少内存使用。在传统的64位系统中,对象引用(即指针)通常占用8字节(64位),而大部分应用程序实际上并不需要如此大的地址空间

[FPGA][基础模块]跨时钟域传播脉冲信号

clk_a 周期为10ns clk_b 周期为34ns 代码: module pulse(input clk_a,input clk_b,input signal_a,output reg signal_b);reg [4:0] signal_a_widen_maker = 0;reg signal_a_widen;always @(posedge clk_a)if(signal_a)

00 - React 基础

1. React 基础 安装react指令 可参考: 官网官网使用教程 如: npx create-react-app 项目名如:npx create-react-app react-redux-pro JSX JSX 是一种 JavaScript 的语法扩展,类似于 XML 或 HTML,允许我们在 JavaScript 代码中编写 HTML。 const element =

如何设置windows计划任务

如何设置windows计划任务 前言:在工作过程中写了一个python脚本,用于调用jira接口查询bug单数量,想要在本地定时任务执行,每天发送到钉钉群提醒,写下操作步骤用于记录。 1. 准备 Python 脚本 确保你的 Python 脚本已经保存到一个文件,比如 jira_reminder.py。 2. 创建批处理文件 为了方便任务计划程序运行 Python 脚本,创建一个批处理文