Understanding the difficulty of training deep feedforward neural networks (Xavier)

本文主要是介绍Understanding the difficulty of training deep feedforward neural networks (Xavier),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

转自:http://blog.csdn.net/shuzfan/article/details/51338178
“Xavier”初始化方法是一种很有效的神经网络初始化方法,方法来源于2010年的一篇论文

《Understanding the difficulty of training deep feedforward neural networks》

可惜直到近两年,这个方法才逐渐得到更多人的应用和认可。

为了使得网络中信息更好的流动,每一层输出的方差应该尽量相等。

基于这个目标,现在我们就去推导一下:每一层的权重应该满足哪种条件。

文章先假设的是线性激活函数,而且满足0点处导数为1,即 
这里写图片描述

现在我们先来分析一层卷积: 
这里写图片描述 
其中ni表示输入个数。

根据概率统计知识我们有下面的方差公式: 
这里写图片描述

特别的,当我们假设输入和权重都是0均值时(目前有了BN之后,这一点也较容易满足),上式可以简化为: 
这里写图片描述

进一步假设输入x和权重w独立同分布,则有: 
这里写图片描述

于是,为了保证输入与输出方差一致,则应该有: 
这里写图片描述

对于一个多层的网络,某一层的方差可以用累积的形式表达: 
这里写图片描述

特别的,反向传播计算梯度时同样具有类似的形式: 
这里写图片描述

综上,为了保证前向传播和反向传播时每一层的方差一致,应满足:

这里写图片描述

但是,实际当中输入与输出的个数往往不相等,于是为了均衡考量,最终我们的权重方差应满足

——————————————————————————————————————— 
这里写图片描述 
———————————————————————————————————————

学过概率统计的都知道 [a,b] 间的均匀分布的方差为: 
这里写图片描述

因此,Xavier初始化的实现就是下面的均匀分布:

—————————————————————————————————————————— 
这里写图片描述 
———————————————————————————————————————————

下面,我们来看一下caffe中具体是怎样实现的,代码位于include/caffe/filler.hpp文件中。

<code class="language-C++ hljs haskell has-numbering" style="display: block; padding: 0px; color: inherit; box-sizing: border-box; font-family: "Source Code Pro", monospace;font-size:undefined; white-space: pre; border-radius: 0px; word-wrap: normal; background: transparent;"><span class="hljs-title" style="box-sizing: border-box;">template</span> <typename <span class="hljs-type" style="box-sizing: border-box; color: rgb(102, 0, 102);">Dtype</span>>
<span class="hljs-class" style="box-sizing: border-box;"><span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">class</span> <span class="hljs-type" style="box-sizing: border-box; color: rgb(102, 0, 102);">XavierFiller</span> : public <span class="hljs-type" style="box-sizing: border-box; color: rgb(102, 0, 102);">Filler</span><<span class="hljs-type" style="box-sizing: border-box; color: rgb(102, 0, 102);">Dtype</span>> {public:explicit <span class="hljs-type" style="box-sizing: border-box; color: rgb(102, 0, 102);">XavierFiller</span><span class="hljs-container" style="box-sizing: border-box;">(<span class="hljs-title" style="box-sizing: border-box; color: rgb(102, 0, 102);">const</span> <span class="hljs-type" style="box-sizing: border-box; color: rgb(102, 0, 102);">FillerParameter</span>& <span class="hljs-title" style="box-sizing: border-box; color: rgb(102, 0, 102);">param</span>)</span>: <span class="hljs-type" style="box-sizing: border-box; color: rgb(102, 0, 102);">Filler</span><<span class="hljs-type" style="box-sizing: border-box; color: rgb(102, 0, 102);">Dtype</span>><span class="hljs-container" style="box-sizing: border-box;">(<span class="hljs-title" style="box-sizing: border-box; color: rgb(102, 0, 102);">param</span>)</span> {}virtual void <span class="hljs-type" style="box-sizing: border-box; color: rgb(102, 0, 102);">Fill</span><span class="hljs-container" style="box-sizing: border-box;">(<span class="hljs-type" style="box-sizing: border-box; color: rgb(102, 0, 102);">Blob</span><<span class="hljs-type" style="box-sizing: border-box; color: rgb(102, 0, 102);">Dtype</span>>* <span class="hljs-title" style="box-sizing: border-box; color: rgb(102, 0, 102);">blob</span>)</span> {<span class="hljs-type" style="box-sizing: border-box; color: rgb(102, 0, 102);">CHECK</span><span class="hljs-container" style="box-sizing: border-box;">(<span class="hljs-title" style="box-sizing: border-box; color: rgb(102, 0, 102);">blob</span>-><span class="hljs-title" style="box-sizing: border-box; color: rgb(102, 0, 102);">count</span>()</span>);int fan_in = blob->count<span class="hljs-container" style="box-sizing: border-box;">()</span> / blob->num<span class="hljs-container" style="box-sizing: border-box;">()</span>;int fan_out = blob->count<span class="hljs-container" style="box-sizing: border-box;">()</span> / blob->channels<span class="hljs-container" style="box-sizing: border-box;">()</span>;<span class="hljs-type" style="box-sizing: border-box; color: rgb(102, 0, 102);">Dtype</span> n = fan_in;  // default to fan_inif <span class="hljs-container" style="box-sizing: border-box;">(<span class="hljs-title" style="box-sizing: border-box; color: rgb(102, 0, 102);">this</span>-><span class="hljs-title" style="box-sizing: border-box; color: rgb(102, 0, 102);">filler_param_</span>.<span class="hljs-title" style="box-sizing: border-box; color: rgb(102, 0, 102);">variance_norm</span>()</span> ==<span class="hljs-type" style="box-sizing: border-box; color: rgb(102, 0, 102);">FillerParameter_VarianceNorm_AVERAGE</span>) {n = <span class="hljs-container" style="box-sizing: border-box;">(<span class="hljs-title" style="box-sizing: border-box; color: rgb(102, 0, 102);">fan_in</span> + <span class="hljs-title" style="box-sizing: border-box; color: rgb(102, 0, 102);">fan_out</span>)</span> / <span class="hljs-type" style="box-sizing: border-box; color: rgb(102, 0, 102);">Dtype</span><span class="hljs-container" style="box-sizing: border-box;">(2)</span>;} else if <span class="hljs-container" style="box-sizing: border-box;">(<span class="hljs-title" style="box-sizing: border-box; color: rgb(102, 0, 102);">this</span>-><span class="hljs-title" style="box-sizing: border-box; color: rgb(102, 0, 102);">filler_param_</span>.<span class="hljs-title" style="box-sizing: border-box; color: rgb(102, 0, 102);">variance_norm</span>()</span> ==<span class="hljs-type" style="box-sizing: border-box; color: rgb(102, 0, 102);">FillerParameter_VarianceNorm_FAN_OUT</span>) {n = fan_out;}<span class="hljs-type" style="box-sizing: border-box; color: rgb(102, 0, 102);">Dtype</span> scale = sqrt<span class="hljs-container" style="box-sizing: border-box;">(<span class="hljs-type" style="box-sizing: border-box; color: rgb(102, 0, 102);">Dtype(3)</span> / <span class="hljs-title" style="box-sizing: border-box; color: rgb(102, 0, 102);">n</span>)</span>;caffe_rng_uniform<<span class="hljs-type" style="box-sizing: border-box; color: rgb(102, 0, 102);">Dtype</span>><span class="hljs-container" style="box-sizing: border-box;">(<span class="hljs-title" style="box-sizing: border-box; color: rgb(102, 0, 102);">blob</span>-><span class="hljs-title" style="box-sizing: border-box; color: rgb(102, 0, 102);">count</span>()</span>, -scale, scale,blob->mutable_cpu_data<span class="hljs-container" style="box-sizing: border-box;">()</span>);<span class="hljs-type" style="box-sizing: border-box; color: rgb(102, 0, 102);">CHECK_EQ</span><span class="hljs-container" style="box-sizing: border-box;">(<span class="hljs-title" style="box-sizing: border-box; color: rgb(102, 0, 102);">this</span>-><span class="hljs-title" style="box-sizing: border-box; color: rgb(102, 0, 102);">filler_param_</span>.<span class="hljs-title" style="box-sizing: border-box; color: rgb(102, 0, 102);">sparse</span>()</span>, -1)<< "<span class="hljs-type" style="box-sizing: border-box; color: rgb(102, 0, 102);">Sparsity</span> not supported by this <span class="hljs-type" style="box-sizing: border-box; color: rgb(102, 0, 102);">Filler</span>.";}
};</span></code><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right-width: 1px; border-right-style: solid; border-right-color: rgb(221, 221, 221); list-style: none; text-align: right; background-color: rgb(238, 238, 238);"><li style="box-sizing: border-box; padding: 0px 5px;">1</li><li style="box-sizing: border-box; padding: 0px 5px;">2</li><li style="box-sizing: border-box; padding: 0px 5px;">3</li><li style="box-sizing: border-box; padding: 0px 5px;">4</li><li style="box-sizing: border-box; padding: 0px 5px;">5</li><li style="box-sizing: border-box; padding: 0px 5px;">6</li><li style="box-sizing: border-box; padding: 0px 5px;">7</li><li style="box-sizing: border-box; padding: 0px 5px;">8</li><li style="box-sizing: border-box; padding: 0px 5px;">9</li><li style="box-sizing: border-box; padding: 0px 5px;">10</li><li style="box-sizing: border-box; padding: 0px 5px;">11</li><li style="box-sizing: border-box; padding: 0px 5px;">12</li><li style="box-sizing: border-box; padding: 0px 5px;">13</li><li style="box-sizing: border-box; padding: 0px 5px;">14</li><li style="box-sizing: border-box; padding: 0px 5px;">15</li><li style="box-sizing: border-box; padding: 0px 5px;">16</li><li style="box-sizing: border-box; padding: 0px 5px;">17</li><li style="box-sizing: border-box; padding: 0px 5px;">18</li><li style="box-sizing: border-box; padding: 0px 5px;">19</li><li style="box-sizing: border-box; padding: 0px 5px;">20</li><li style="box-sizing: border-box; padding: 0px 5px;">21</li><li style="box-sizing: border-box; padding: 0px 5px;">22</li><li style="box-sizing: border-box; padding: 0px 5px;">23</li><li style="box-sizing: border-box; padding: 0px 5px;">24</li></ul><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right-width: 1px; border-right-style: solid; border-right-color: rgb(221, 221, 221); list-style: none; text-align: right; background-color: rgb(238, 238, 238);"><li style="box-sizing: border-box; padding: 0px 5px;">1</li><li style="box-sizing: border-box; padding: 0px 5px;">2</li><li style="box-sizing: border-box; padding: 0px 5px;">3</li><li style="box-sizing: border-box; padding: 0px 5px;">4</li><li style="box-sizing: border-box; padding: 0px 5px;">5</li><li style="box-sizing: border-box; padding: 0px 5px;">6</li><li style="box-sizing: border-box; padding: 0px 5px;">7</li><li style="box-sizing: border-box; padding: 0px 5px;">8</li><li style="box-sizing: border-box; padding: 0px 5px;">9</li><li style="box-sizing: border-box; padding: 0px 5px;">10</li><li style="box-sizing: border-box; padding: 0px 5px;">11</li><li style="box-sizing: border-box; padding: 0px 5px;">12</li><li style="box-sizing: border-box; padding: 0px 5px;">13</li><li style="box-sizing: border-box; padding: 0px 5px;">14</li><li style="box-sizing: border-box; padding: 0px 5px;">15</li><li style="box-sizing: border-box; padding: 0px 5px;">16</li><li style="box-sizing: border-box; padding: 0px 5px;">17</li><li style="box-sizing: border-box; padding: 0px 5px;">18</li><li style="box-sizing: border-box; padding: 0px 5px;">19</li><li style="box-sizing: border-box; padding: 0px 5px;">20</li><li style="box-sizing: border-box; padding: 0px 5px;">21</li><li style="box-sizing: border-box; padding: 0px 5px;">22</li><li style="box-sizing: border-box; padding: 0px 5px;">23</li><li style="box-sizing: border-box; padding: 0px 5px;">24</li></ul>

由上面可以看出,caffe的Xavier实现有三种选择

(1) 默认情况,方差只考虑输入个数: 
这里写图片描述

(2) FillerParameter_VarianceNorm_FAN_OUT,方差只考虑输出个数: 
这里写图片描述

(3) FillerParameter_VarianceNorm_AVERAGE,方差同时考虑输入和输出个数: 
这里写图片描述

之所以默认只考虑输入,我个人觉得是因为前向信息的传播更重要一些

这篇关于Understanding the difficulty of training deep feedforward neural networks (Xavier)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1050220

相关文章

2014 Multi-University Training Contest 8小记

1002 计算几何 最大的速度才可能拥有无限的面积。 最大的速度的点 求凸包, 凸包上的点( 注意不是端点 ) 才拥有无限的面积 注意 :  凸包上如果有重点则不满足。 另外最大的速度为0也不行的。 int cmp(double x){if(fabs(x) < 1e-8) return 0 ;if(x > 0) return 1 ;return -1 ;}struct poin

2014 Multi-University Training Contest 7小记

1003   数学 , 先暴力再解方程。 在b进制下是个2 , 3 位数的 大概是10000进制以上 。这部分解方程 2-10000 直接暴力 typedef long long LL ;LL n ;int ok(int b){LL m = n ;int c ;while(m){c = m % b ;if(c == 3 || c == 4 || c == 5 ||

2014 Multi-University Training Contest 6小记

1003  贪心 对于111...10....000 这样的序列,  a 为1的个数,b为0的个数,易得当 x= a / (a + b) 时 f最小。 讲串分成若干段  1..10..0   ,  1..10..0 ,  要满足x非递减 。  对于 xi > xi+1  这样的合并 即可。 const int maxn = 100008 ;struct Node{int

MonoHuman: Animatable Human Neural Field from Monocular Video 翻译

MonoHuman:来自单目视频的可动画人类神经场 摘要。利用自由视图控制来动画化虚拟化身对于诸如虚拟现实和数字娱乐之类的各种应用来说是至关重要的。已有的研究试图利用神经辐射场(NeRF)的表征能力从单目视频中重建人体。最近的工作提出将变形网络移植到NeRF中,以进一步模拟人类神经场的动力学,从而动画化逼真的人类运动。然而,这种流水线要么依赖于姿态相关的表示,要么由于帧无关的优化而缺乏运动一致性

Post-Training有多重要?一文带你了解全部细节

1. 简介 随着LLM学界和工业界日新月异的发展,不仅预训练所用的算力和数据正在疯狂内卷,后训练(post-training)的对齐和微调方法也在不断更新。InstructGPT、WebGPT等较早发布的模型使用标准RLHF方法,其中的数据管理风格和规模似乎已经过时。近来,Meta、谷歌和英伟达等AI巨头纷纷发布开源模型,附带发布详尽的论文或报告,包括Llama 3.1、Nemotron 340

Understanding the GitHub Flow

这里看下Github的入门介绍    --链接 GitHub Flow is a lightweight, branch-based workflow that supports teams and projects where deployments are made regularly. This guide explains how and why GitHub Flow works

A Comprehensive Survey on Graph Neural Networks笔记

一、摘要-Abstract 1、传统的深度学习模型主要处理欧几里得数据(如图像、文本),而图神经网络的出现和发展是为了有效处理和学习非欧几里得域(即图结构数据)的信息。 2、将GNN划分为四类:recurrent GNNs(RecGNN), convolutional GNNs,(GCN), graph autoencoders(GAE), and spatial–temporal GNNs(S

Deep Ocr

1.圈出内容,文本那里要有内容.然后你保存,并'导出数据集'. 2.找出deep_ocr_recognition_training_workflow.hdev 文件.修改“DatasetFilename := 'Test.hdict'” 310行 write_deep_ocr (DeepOcrHandle, BestModelDeepOCRFilename) 3.推理test.hdev

OpenSNN推文:神经网络(Neural Network)相关论文最新推荐(九月份)(一)

基于卷积神经网络的活动识别分析系统及应用 论文链接:oalib简介:  活动识别技术在智能家居、运动评估和社交等领域得到广泛应用。本文设计了一种基于卷积神经网络的活动识别分析与应用系统,通过分析基于Android搭建的前端采所集的三向加速度传感器数据,对用户的当前活动进行识别。实验表明活动识别准确率满足了应用需求。本文基于识别的活动进行卡路里消耗计算,根据用户具体的活动、时间以及体重计算出相应活

Complex Networks Package for MatLab

http://www.levmuchnik.net/Content/Networks/ComplexNetworksPackage.html 翻译: 复杂网络的MATLAB工具包提供了一个高效、可扩展的框架,用于在MATLAB上的网络研究。 可以帮助描述经验网络的成千上万的节点,生成人工网络,运行鲁棒性实验,测试网络在不同的攻击下的可靠性,模拟任意复杂的传染病的传