BatchNorm介绍:卷积神经网络中的BN

2024-02-11 20:36

本文主要是介绍BatchNorm介绍:卷积神经网络中的BN,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

一、BN介绍

1.原理

在机器学习中让输入的数据之间相关性越少越好,最好输入的每个样本都是均值为0方差为1。在输入神经网络之前可以对数据进行处理让数据消除共线性,但是这样的话输入层的激活层看到的是一个分布良好的数据,但是较深的激活层看到的的分布就没那么完美了,分布将变化的很严重。这样会使得训练神经网络变得更加困难。所以添加BatchNorm层,在训练的时候BN层使用batch来估计数据的均值和方差,然后用均值和方差来标准化这个batch的数据,并且随着不同的batch经过网络,均值和方差都在做累计平均。在测试的时候就直接作为标准化的依据。

这样的方法也有可能导致降低神经网络的表示能力,因为某些层的全局最优的特征可能不是均值为0或者方差为1的。所以BN层也是能够进行学习每个特征维度的缩放gamma和平移beta的来避免这样的情况。

2.BN层前向传播

def batchnorm_forward(x, gamma, beta, bn_param):"""先进行标准化再进行平移缩放running_mean = momentum * running_mean + (1 - momentum) * sample_meanrunning_var = momentum * running_var + (1 - momentum) * sample_varInput:- x: (N, D) 输入的数据- gamma: (D,) 每个特征维度数据的缩放- beta: (D,) 每个特征维度数据的偏移- bn_param: 字典,有如下键值:- mode: 'train'/'test' 必须指定- eps: 一个常量为了维持数值稳定,保证不会除0- momentum: 动量- running_mean: (D,) 积累的均值- running_var: (D,) 积累的方差Returns:- out: (N,D)- cache: 反向传播时需要的数据"""mode = bn_param['mode']eps = bn_param.get('eps', 1e-5)momentum = bn_param.get('momentum', 0.9)N, D = x.shaperunning_mean = bn_param.get('running_mean', np.zeros(D, dtype=x.dtype))running_var = bn_param.get('running_var', np.zeros(D, dtype=x.dtype))out, cache = None, Noneif mode == 'train':sample_mean = np.mean(x, axis=0)sample_var = np.var(x, axis=0)# 先标准化x_hat = (x - sample_mean)/(np.sqrt(sample_var + eps))# 再做缩放偏移out = gamma * x_hat + betacache = (gamma, x, sample_mean, sample_var, eps, x_hat)running_mean = momentum * running_mean + (1-momuntum)*sample_meanrunning_var = momentum * running_var + (1-momentum)*sample_varelif mode == 'test':# 先标准化#x_hat = (x - running_mean)/(np.sqrt(running_var+eps))# 再做缩放偏移#out = gamma * x_hat + beta# 或者是下面的骚写法scale = gamma/(np.sqrt(running_var + eps))out = x*scale + (beta - running_mean*scale)else:raise ValueError('Invalid forward batchnorm mode "%s"' % mode)bn_param['running_mean'] = running_meanbn_param['running_var'] = running_varreturn out, cache

3.BN层反向传播

def batchnorm_barckward(out, cache):"""反向传播的简单写法,易于理解Inputs:- dout: (N,D) dloss/dout- cache: (gamma, x, sample_mean, sample_var, eps, x_hat)Returns:- dx: (N,D)- dgamma: (D,) 每个维度的缩放和平移参数不同- dbeta: (D,)"""dx, dgamma, dbeta = None, None, None# unpack cachegamma, x, u_b, sigma_squared_b, eps, x_hat = cacheN = x.shape[0]dx_1 = gamma * dout # dloss/dx_hat = dloss/dout * gamma (N, D)dx_2_b = np.sum((x - u_b) * dx_1, axis=0)dx_2_a = ((sigma_squared_b + eps)**-0.5)*dx_1dx_3_b = (-0.5) * ((sigma_squared_b + eps)**-1.5)*dx_2_bdx_4_b = dx_3_b * 1dx_5_b = np.ones_like(x)/N * dx_4_bdx_6_b = 2*(x-u_b)*dx_5_bdx_7_a = dx_6_b*1 + dx_2_a*1dx_7_b = dx_6_b*1 * dx_2_a*1dx_8_b = -1*np.sum(dx_7_b, axis=0)dx_9_b = np.ones_like(x)/N * dx_8_bdx_10 = dx_9_b + dx_7_adgamma = np.sum(x_hat * dout, axis=0)dbeta = np.sum(dout, axis=0)dx = dx_10return dx, dgamma, dbeta

下面是直接使用公式来计算:

def batchnorm_backward_alt(dout, cache):dx, dgamma, dbeta = None, None, None# unpack cachegamma, x, u_b, sigma_squared_b, eps, x_hat = cacheN = x.shape[0]dx_hat = dout * gammadvar = np.sum(dx_hat* (x - sample_mean) * -0.5 * np.power(sample_var + eps, -1.5), axis = 0)dmean = np.sum(dx_hat * -1 / np.sqrt(sample_var +eps), axis = 0) + dvar * np.mean(-2 * (x - sample_mean), axis =0)dx = 1 / np.sqrt(sample_var + eps) * dx_hat + dvar * 2.0 / N * (x-sample_mean) + 1.0 / N * dmeandgamma = np.sum(x_hat * dout, axis = 0)dbeta = np.sum(dout , axis = 0)return dx, dgamma, dbeta

4.BN有什么作用

  1. 对于不好的权重初始化有更高的鲁棒性,仍然能得到较好的效果。
  2. 能更好的避免过拟合。
  3. 解决梯度消失/爆炸问题,BN防止了前向传播的时候数值过大或者过小,这样就能让反向传播时梯度处于一个较好的区间内。

二、卷积神经网络中的BN

1.前向传播

def spatial_batchnorm_forward(x, gamma, beta, bn_param):"""利用普通神经网络的BN来实现卷积神经网络的BNInputs:- x: (N, C, H, W)- gamma: (C,)缩放系数- beta: (C,)平移系数- bn_param: 包含如下键的字典- mode: 'train'/'test'必须的键- eps: 数值稳定需要的一个较小的值- momentum: 一个常量,用来处理running mean和var的。如果momentum=0 那么之前不利用之前的均值和方差。momentum=1表示不利用现在的均值和方差,一般设置momentum=0.9- running_mean: (C,)- running_var: (C,)Returns:- out: (N, C, H, W)- cache: 反向传播需要的数据,这里直接使用了普通神经网络的cache"""N, C, H, W = x.shape# transpose之后(N, W, H, C) channel在这里就可以看成是特征temp_out, cache = batchnorm_forward(x.transpose(0, 3, 2, 1).reshape((N*H*W, C)), gamma, beta, bn_param)# 再恢复shapeout = temp_output.reshape(N, W, H, C).transpose(0, 3, 2, 1)return out, cache

2.反向传播

def spatial_batchnorm_backward(dout, cache):"""利用普通神经网络的BN反向传播实现卷积神经网络中的BN反向传播Inputs:- dout: (N, C, H, W) 反向传播回来的导数- cache: 前向传播时的中间数据Returns:- dx: (N, C, H, W)- dgamma: (C,) 缩放系数的导数- dbeta: (C,) 偏移系数的导数"""dx, dgamma, dbeta = None, None, NoneN, C, H, W = dout.shape# 利用普通神经网络的BN进行计算 (N*H*W, C)channel看成是特征维度dx_temp, dgamma, dbeta = batchnorm_backward_alt(dout.transpose(0, 3, 2, 1).reshape((N*H*W, C)), cache)# 将shape恢复dx = dx_temp.reshape(N, W, H, C).transpose(0, 3, 2, 1)return dx, dgamma, dbeta

这篇关于BatchNorm介绍:卷积神经网络中的BN的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/700751

相关文章

Python进阶之Excel基本操作介绍

《Python进阶之Excel基本操作介绍》在现实中,很多工作都需要与数据打交道,Excel作为常用的数据处理工具,一直备受人们的青睐,本文主要为大家介绍了一些Python中Excel的基本操作,希望... 目录概述写入使用 xlwt使用 XlsxWriter读取修改概述在现实中,很多工作都需要与数据打交

java脚本使用不同版本jdk的说明介绍

《java脚本使用不同版本jdk的说明介绍》本文介绍了在Java中执行JavaScript脚本的几种方式,包括使用ScriptEngine、Nashorn和GraalVM,ScriptEngine适用... 目录Java脚本使用不同版本jdk的说明1.使用ScriptEngine执行javascript2.

Python实现NLP的完整流程介绍

《Python实现NLP的完整流程介绍》这篇文章主要为大家详细介绍了Python实现NLP的完整流程,文中的示例代码讲解详细,具有一定的借鉴价值,感兴趣的小伙伴可以跟随小编一起学习一下... 目录1. 编程安装和导入必要的库2. 文本数据准备3. 文本预处理3.1 小写化3.2 分词(Tokenizatio

性能测试介绍

性能测试是一种测试方法,旨在评估系统、应用程序或组件在现实场景中的性能表现和可靠性。它通常用于衡量系统在不同负载条件下的响应时间、吞吐量、资源利用率、稳定性和可扩展性等关键指标。 为什么要进行性能测试 通过性能测试,可以确定系统是否能够满足预期的性能要求,找出性能瓶颈和潜在的问题,并进行优化和调整。 发现性能瓶颈:性能测试可以帮助发现系统的性能瓶颈,即系统在高负载或高并发情况下可能出现的问题

水位雨量在线监测系统概述及应用介绍

在当今社会,随着科技的飞速发展,各种智能监测系统已成为保障公共安全、促进资源管理和环境保护的重要工具。其中,水位雨量在线监测系统作为自然灾害预警、水资源管理及水利工程运行的关键技术,其重要性不言而喻。 一、水位雨量在线监测系统的基本原理 水位雨量在线监测系统主要由数据采集单元、数据传输网络、数据处理中心及用户终端四大部分构成,形成了一个完整的闭环系统。 数据采集单元:这是系统的“眼睛”,

Hadoop数据压缩使用介绍

一、压缩原则 (1)运算密集型的Job,少用压缩 (2)IO密集型的Job,多用压缩 二、压缩算法比较 三、压缩位置选择 四、压缩参数配置 1)为了支持多种压缩/解压缩算法,Hadoop引入了编码/解码器 2)要在Hadoop中启用压缩,可以配置如下参数

图神经网络模型介绍(1)

我们将图神经网络分为基于谱域的模型和基于空域的模型,并按照发展顺序详解每个类别中的重要模型。 1.1基于谱域的图神经网络         谱域上的图卷积在图学习迈向深度学习的发展历程中起到了关键的作用。本节主要介绍三个具有代表性的谱域图神经网络:谱图卷积网络、切比雪夫网络和图卷积网络。 (1)谱图卷积网络 卷积定理:函数卷积的傅里叶变换是函数傅里叶变换的乘积,即F{f*g}

C++——stack、queue的实现及deque的介绍

目录 1.stack与queue的实现 1.1stack的实现  1.2 queue的实现 2.重温vector、list、stack、queue的介绍 2.1 STL标准库中stack和queue的底层结构  3.deque的简单介绍 3.1为什么选择deque作为stack和queue的底层默认容器  3.2 STL中对stack与queue的模拟实现 ①stack模拟实现

Mysql BLOB类型介绍

BLOB类型的字段用于存储二进制数据 在MySQL中,BLOB类型,包括:TinyBlob、Blob、MediumBlob、LongBlob,这几个类型之间的唯一区别是在存储的大小不同。 TinyBlob 最大 255 Blob 最大 65K MediumBlob 最大 16M LongBlob 最大 4G

FreeRTOS-基本介绍和移植STM32

FreeRTOS-基本介绍和STM32移植 一、裸机开发和操作系统开发介绍二、任务调度和任务状态介绍2.1 任务调度2.1.1 抢占式调度2.1.2 时间片调度 2.2 任务状态 三、FreeRTOS源码和移植STM323.1 FreeRTOS源码3.2 FreeRTOS移植STM323.2.1 代码移植3.2.2 时钟中断配置 一、裸机开发和操作系统开发介绍 裸机:前后台系