BatchNormalization和Layer Normalization解析

2024-06-17 20:12

本文主要是介绍BatchNormalization和Layer Normalization解析,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

Batch Normalization

是google团队2015年提出的,能够加速网络的收敛并提升准确率

1.Batch Normalization原理

图像预处理过程中通常会对图像进行标准化处理,能够加速网络的收敛,如下图所示,对于Conv1来说输入的就是满足某一分布的特征矩阵,但对于Conv2而言输入的feature map就不一定满足某一分布规律了(注意这里所说满足某一分布规律并不是指某一个feature map的数据要满足分布规律,理论上是指整个训练样本集所对应的feature map的数据要满足分布规律)。而我们BN的目的就是使feature map满足均值为0,方差为1的分布规律。

对于一个拥有d维的输入x,我们将对它的每一个维度进行标准化处理。假设我们输入的x是RGB三通道的彩色图像,那么这里的d就是输入图像的channels即d=3,其中x^1就代表我们的R通道所对应的特征矩阵,依次类推。标准化处理也就是分别对R通道,G通道,B通道进行处理。

让feature map满足某一分布规律,理论上是指整个训练样本集所对应feature map的数据要满足分布规律,也就是说要计算出整个训练集的feature map然后再进行标准化处理,对于一个大型的数据集明显是不可能的,所以论文中说的BN,也就是计算一个Batch数据的feature map然后进行标准化(batch越大越接近整个数据集的分布,效果越好)。

上图展示了一个batch size为2(两张图片)的Batch Normalization的计算过程,假设feature1、feature2分别是由image1、image2经过一系列卷积池化后得到的特征矩阵,feature的channel为2,那么x^1代表batch的所有的feature的channel1的数据。然后分别计算x^1和x^2的均值和方差。然后再根据标准差计算公式分别计算每个channel 的值(\varepsilon是很小的常量,放置分母为0的情况)。在训练过程中要去不断地计算每个batch的均值和方差,并使用移动平均(moving average)的方法记录统计的均值和方差,在训练完后我们可以近似认为所统计的均值和方差就等于整个训练集的均值和方差。然后再我们的验证以及预测过程中,就使用统计得到的均值和方差进行标准化处理。

\gamma是用来调整数值分布的方差大小,默认为1,\beta是用来调节数值均值的位置,默认值为0。这两个参数实在反向传播过程中学习到的。

2.使用Pytorch进行实验

在训练过程中,均值和方差是同通过计算当前批次数据得到的记录为\mu _{now},\delta_{now} ^{2},而我们的验证以及预测过程中使用的均值方差是一个统计量为\mu _{statistic},\delta _{statistic}^{2}。具体更新策略如下,其中momentum默认取0.1:

\mu _{statistic+1} = 0.9*\mu _{statistic}+0.1*\mu _{now}\\ \delta _{statistic+1}^{2} = 0.9*\delta _{statistic}^{2}+0.1*\delta _{now}^{2}

(1)bn_process函数是自定义的bn处理方法验证是否和使用官方bn处理方法结果一致。在bn_process中计算输入batch数据的每个维度(这里的维度是channel维度)的均值和标准差(标准差等于方差开平方),然后通过计算得到的均值和总体标准差对feature每个维度进行标准化,然后使用均值和样本标准差更新统计均值和标准差。

(2)初始化统计均值是一个元素为0的向量,元素个数等于channel深度;初始化统计方差是一个元素为1的向量,元素个数等于channel深度,初始化\beta=0,\gamma=1。

import numpy as np
import torch.nn as nn
import torchdef bn_process(feature, mean, var):feature_shape = feature.shapefor i in range(feature_shape[1]):# [batch,channel, height, weight]feature_t = feature[:, i, :, :]mean_t = feature_t.mean()#总体标准差std_t1 = feature_t.std()#样本标准差std_t2 = feature_t.std(ddof = 1)#bn process#这里记得加上eps和pytorch保持一致feature[:, i, :, :] = (feature[:, i, :, :] - mean_t) / np.sqrt(std_t1 ** 2+ 1e-5)#更新计算均值mean[i]  = mean[i]*0.9 + mean_t * 0.1var[i] = var[i] * 0.9 + (std_t2 ** 2) * 0.1print(feature)#随机生成一个batch为2,channel为2,height=width=2的特征向量
#[batch, channel, height, width]
feature1 = torch.randn(2, 2, 2, 2)
#初始化统计均值和方差
calculate_mean = [0.0, 0.0]
calculate_var = [1.0, 1.0]
#print(feature1.numpy())#注意要使用copy()深拷贝
bn_process(feature1.numpy().copy(), calculate_mean, calculate_var)bn = nn.BatchNorm2d(2, eps =  1e-5)
output = bn(feature1)
print(output)

 

3.使用BN时需要注意的问题

(1)训练时要将training采纳数设置为True,在验证时将training参数设置为False。在Pytorch中了可以通过创建模型的model.train()和model.eval()方法控制。

(2)batch size尽可能设置大点,设置小后表现很糟糕,设置的越大求的均值和方差越接近整个训练集的均值和方差。

(3)建议将bn层放在卷积层和激活层之间,且卷积层不要使用偏置bias,因为没有用,参考下图推理,及时使用了偏置bias求出的结果也是一样的。

 


Layer Normalization

Layer Normalization针对自然语言处理提出的,为什么不用BN呢,因为在RNN这类时序网络中,时序的长度并不是一个定值(网络深度不一定相同),比如每句话的长短都不一定相同,所以很难去使用BN,所以作者提出了Layer Normalization(图像处理领域BN比LN更有效),但现在很多人将自然语言领域的模型用来处理图像,比如Vision Transformer,此时会涉及到LN。

直接看Pytorch 官方给出的关于LayerNorm 的介绍。不同的是,BN是对一个batch数据的每个channel进行Norm处理,一个for循环,但LN是对单个数据的制定维度进行Norm处理与batch无关而且BN中训练时是需要累计moving_mean和moving_var两个变量的(所以BN中有4个参数moving_mean,moving_var,\beta ,\gamma),但LN不需要累计只有\beta ,\gamma两个参数。

在Pytorch的LayerNorm类中有个normalized_shape参数,可以指定要Norm的维度(注意,函数说明中the last certain number of dimensions,指定的维度必须是从最后一维开始)。比如我们的数据shape是[4,2,3],那么normalized_shape可以是[3](最后一维进行Norm处理),也可以是[2,3](Norm最后两个维度),也可以是整个维度[4,2,3],但不能是[2]或者[4,2],否则会报错。

y = \frac{x-E[X]}{\sqrt{Var[x]+\varepsilon}}*\gamma +\beta

import torch
import torch.nn as nndef layer_norm_process(feature:torch.Tensor, beta=0.,gamma = 1.,eps=1e-5):var_mean = torch.var_mean(feature, dim = -1, unbiased = False)#均值mean = var_mean[1]#方差var = var_mean[0]#layer norm processfeature  = (feature - mean[..., None]) / torch.sqrt(var[..., None] + eps)feature = feature*gamma+betareturn featuredef main():t = torch.randn(4, 2, 3)print(t)#仅在最后一个维度上做norm处理norm = nn.LayerNorm(normalized_shape= t.shape[-1], eps = 1e-5)#官方layer norm处理t1 = norm(t)#自己实现的layer norm处理t2 = layer_norm_process(t, eps = 1e-5)print("t1:\n",t1)print("t2:\n",t2)if __name__ == '__main__':main()
tensor([[[ 0.8512,  0.4201, -0.3457],[ 0.4701, -0.0647,  0.0733]],[[-0.9950, -0.4634,  0.0540],[ 0.4096,  0.4037, -0.0914]],[[-2.3165,  1.3059,  0.3183],[-0.9716,  0.4956,  0.4524]],[[-0.6209, -0.5958,  0.3212],[-0.8762,  0.3176, -0.5427]]])
t1:tensor([[[ 1.0963,  0.2254, -1.3218],[ 1.3697, -0.9893, -0.3804]],[[-1.2302,  0.0110,  1.2192],[ 0.7198,  0.6942, -1.4140]],[[-1.3642,  1.0050,  0.3591],[-1.4137,  0.7385,  0.6752]],[[-0.7355, -0.6783,  1.4138],[-1.0123,  1.3614, -0.3490]]], grad_fn=<NativeLayerNormBackward0>)
t2:tensor([[[ 1.0963,  0.2254, -1.3218],[ 1.3697, -0.9893, -0.3804]],[[-1.2302,  0.0110,  1.2192],[ 0.7198,  0.6942, -1.4140]],[[-1.3642,  1.0050,  0.3591],[-1.4137,  0.7385,  0.6752]],[[-0.7355, -0.6783,  1.4138],[-1.0123,  1.3614, -0.3490]]])

这篇关于BatchNormalization和Layer Normalization解析的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!


原文地址:
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.chinasem.cn/article/1070404

相关文章

使用Jackson进行JSON生成与解析的新手指南

《使用Jackson进行JSON生成与解析的新手指南》这篇文章主要为大家详细介绍了如何使用Jackson进行JSON生成与解析处理,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录1. 核心依赖2. 基础用法2.1 对象转 jsON(序列化)2.2 JSON 转对象(反序列化)3.

Springboot @Autowired和@Resource的区别解析

《Springboot@Autowired和@Resource的区别解析》@Resource是JDK提供的注解,只是Spring在实现上提供了这个注解的功能支持,本文给大家介绍Springboot@... 目录【一】定义【1】@Autowired【2】@Resource【二】区别【1】包含的属性不同【2】@

SpringCloud动态配置注解@RefreshScope与@Component的深度解析

《SpringCloud动态配置注解@RefreshScope与@Component的深度解析》在现代微服务架构中,动态配置管理是一个关键需求,本文将为大家介绍SpringCloud中相关的注解@Re... 目录引言1. @RefreshScope 的作用与原理1.1 什么是 @RefreshScope1.

Java并发编程必备之Synchronized关键字深入解析

《Java并发编程必备之Synchronized关键字深入解析》本文我们深入探索了Java中的Synchronized关键字,包括其互斥性和可重入性的特性,文章详细介绍了Synchronized的三种... 目录一、前言二、Synchronized关键字2.1 Synchronized的特性1. 互斥2.

Java的IO模型、Netty原理解析

《Java的IO模型、Netty原理解析》Java的I/O是以流的方式进行数据输入输出的,Java的类库涉及很多领域的IO内容:标准的输入输出,文件的操作、网络上的数据传输流、字符串流、对象流等,这篇... 目录1.什么是IO2.同步与异步、阻塞与非阻塞3.三种IO模型BIO(blocking I/O)NI

Python 中的异步与同步深度解析(实践记录)

《Python中的异步与同步深度解析(实践记录)》在Python编程世界里,异步和同步的概念是理解程序执行流程和性能优化的关键,这篇文章将带你深入了解它们的差异,以及阻塞和非阻塞的特性,同时通过实际... 目录python中的异步与同步:深度解析与实践异步与同步的定义异步同步阻塞与非阻塞的概念阻塞非阻塞同步

Redis中高并发读写性能的深度解析与优化

《Redis中高并发读写性能的深度解析与优化》Redis作为一款高性能的内存数据库,广泛应用于缓存、消息队列、实时统计等场景,本文将深入探讨Redis的读写并发能力,感兴趣的小伙伴可以了解下... 目录引言一、Redis 并发能力概述1.1 Redis 的读写性能1.2 影响 Redis 并发能力的因素二、

Spring MVC使用视图解析的问题解读

《SpringMVC使用视图解析的问题解读》:本文主要介绍SpringMVC使用视图解析的问题解读,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录Spring MVC使用视图解析1. 会使用视图解析的情况2. 不会使用视图解析的情况总结Spring MVC使用视图

利用Python和C++解析gltf文件的示例详解

《利用Python和C++解析gltf文件的示例详解》gltf,全称是GLTransmissionFormat,是一种开放的3D文件格式,Python和C++是两个非常强大的工具,下面我们就来看看如何... 目录什么是gltf文件选择语言的原因安装必要的库解析gltf文件的步骤1. 读取gltf文件2. 提

Java中的runnable 和 callable 区别解析

《Java中的runnable和callable区别解析》Runnable接口用于定义不需要返回结果的任务,而Callable接口可以返回结果并抛出异常,通常与Future结合使用,Runnab... 目录1. Runnable接口1.1 Runnable的定义1.2 Runnable的特点1.3 使用Ru