CS231n作业笔记2.4:Batchnorm的实现与使用

2023-12-15 23:38

本文主要是介绍CS231n作业笔记2.4:Batchnorm的实现与使用,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

CS231n简介

详见 CS231n课程笔记1:Introduction。
本文都是作者自己的思考,正确性未经过验证,欢迎指教。

作业笔记

Batchnorm的思想简单易懂,实现起来也很轻松,但是却具有很多优良的性质,具体请参考课程笔记。下图简要介绍了一下Batchnorm需要完成的工作以及优点(详情请见CS231n课程笔记5.3:Batch Normalization):
batchnorm
需要注意的有:

  1. 最后一步对归一化后的数据进行平移与缩放,且此参数可学习。
  2. 上诉参数对于x的每一维都具有相应的参数,故假设X.shape = [N,D],那么gamma.shape = [D,]

1. 前向传播

这里即实现上图所诉功能,需要注意的有:

  1. out_media的命名是为了后向传播的时候处理方便
  2. 使用var而不是std,这既符合图中公式,又方便了后向传播
  3. 计算out_media的时候,减法以及除法都做了broadcasting,对应于反向传播的时候sum
  4. 同理,计算out的时候,加法以及乘法也都做了broadcasting

注:broadcasting的部分请参考python、numpy、scipy、matplotlib的一些小技巧。

  if mode == 'train':mean = np.mean(x,axis = 0)var = np.var(x,axis = 0)running_mean = running_mean * momentum + (1-momentum) * meanrunning_var = running_var * momentum + (1-momentum) * varout_media = (x-mean)/np.sqrt(var + eps)out = (out_media + beta) * gammacache = (out_media,x,mean,var,beta,gamma,eps)elif mode == 'test':out = (x-running_mean)/np.sqrt(running_var+eps)out = (out + beta) * gammacache = (out,x,running_mean,running_var,beta,gamma,eps)

2. 后向传播

对前面所诉的前向传播过程做BP(详情参考CS231n课程笔记4.1:反向传播BP),值得注意的有:

  1. dgamma以及dbeta求值的时候由于前向传播那里使用了broadcasting,这里需要做求和。
  2. dvar也会向dmean传播,所以先分解dvar。
  3. 直接求解dvar过于复杂,使用dstd过渡。
  4. 每次求解的时候不要忘记乘以全局梯度。
  5. 对于dvar的分解,分别使用(x-mean)^2以及(x-mean)过渡。
  dout_media = dout * gammadgamma = np.sum(dout * (out_media + beta),axis = 0)dbeta = np.sum(dout * gamma,axis = 0)dx = dout_media / np.sqrt(var + eps)dmean = -np.sum(dout_media / np.sqrt(var+eps),axis = 0)dstd = np.sum(-dout_media * (x - mean) / (var + eps),axis = 0)dvar = 1./2./np.sqrt(var+eps) * dstddx_minus_mean_square = dvar / x.shape[0]dx_minus_mean = 2 * (x-mean) * dx_minus_mean_squaredx += dx_minus_meandmean += np.sum(-dx_minus_mean,axis = 0)dx += dmean / x.shape[0] 

3. 应用:带Batchnorm的多层神经网络

不带Batchnorm多层神经网络的实现参考CS231n作业笔记2.2:多层神经网络的实现。

3.1. 初始化代码

初始化参数,注意beta以及gamma都需要初始化,而且对于x的每一维都存在相应独立的参数。

    self.bn_params = []if self.use_batchnorm:self.bn_params = [{'mode': 'train'} for i in xrange(self.num_layers - 1)]for i in xrange(self.num_layers-1):self.params['beta'+str(i+1)] = np.zeros(hidden_dims[i])self.params['gamma'+str(i+1)] = np.ones(hidden_dims[i])

3.2. 前向传播代码

计算scores,注意对于最后一层全连接的输出不做BN;以及running_mean以及running_var是内部变量,每次只在自己内部更新,不同层的mean与var无关。

    cache = {}hidden_value = Nonehidden_value,cache['fc1'] = affine_forward(X,self.params['W1'],self.params['b1'])if self.use_batchnorm:hidden_value,cache['bn1'] = batchnorm_forward(hidden_value, self.params['gamma1'], self.params['beta1'], self.bn_params[0])hidden_value,cache['relu1'] = relu_forward(hidden_value)for index in range(2,self.num_layers):hidden_value,cache['fc'+str(index)] = affine_forward(hidden_value,self.params['W'+str(index)],self.params['b'+str(index)])if self.use_batchnorm:hidden_value,cache['bn'+str(index)] = batchnorm_forward(hidden_value,  self.params['gamma'+str(index)], self.params['beta'+str(index)], self.bn_params[index-1])hidden_value,cache['relu'+str(index)] = relu_forward(hidden_value)scores,cache['score'] = affine_forward(hidden_value,self.params['W'+str(self.num_layers)],self.params['b'+str(self.num_layers)])

3.3. 后向传播代码

计算gradient,注意本作业对于beta以及gamma不做正则化,但是keras等开源库提供了相应正则化的接口。

    loss, grads = 0.0, {}loss,dscores = softmax_loss(scores,y)for index in range(1,self.num_layers+1):loss += 0.5*self.reg*np.sum(self.params['W'+str(index)]**2)dhidden_value,grads['W'+str(self.num_layers)],grads['b'+str(self.num_layers)] = affine_backward(dscores,cache['score'])for index in range(self.num_layers-1,1,-1):dhidden_value = relu_backward(dhidden_value,cache['relu'+str(index)])if self.use_batchnorm:dhidden_value, grads['gamma'+str(index)], grads['beta'+str(index)] = batchnorm_backward(dhidden_value, cache['bn'+str(index)])dhidden_value,grads['W'+str(index)],grads['b'+str(index)] = affine_backward(dhidden_value,cache['fc'+str(index)])dhidden_value = relu_backward(dhidden_value,cache['relu1'])if self.use_batchnorm:dhidden_value, grads['gamma1'], grads['beta1'] = batchnorm_backward(dhidden_value, cache['bn1'])dhidden_value,grads['W1'],grads['b1'] = affine_backward(dhidden_value,cache['fc1'])for index in range(1,self.num_layers+1):grads['W'+str(index)] += self.reg * self.params['W'+str(index)] 

这篇关于CS231n作业笔记2.4:Batchnorm的实现与使用的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/498323

相关文章

Spring Security+JWT如何实现前后端分离权限控制

《SpringSecurity+JWT如何实现前后端分离权限控制》本篇将手把手教你用SpringSecurity+JWT搭建一套完整的登录认证与权限控制体系,具有很好的参考价值,希望对大家... 目录Spring Security+JWT实现前后端分离权限控制实战一、为什么要用 JWT?二、JWT 基本结构

Java实现优雅日期处理的方案详解

《Java实现优雅日期处理的方案详解》在我们的日常工作中,需要经常处理各种格式,各种类似的的日期或者时间,下面我们就来看看如何使用java处理这样的日期问题吧,感兴趣的小伙伴可以跟随小编一起学习一下... 目录前言一、日期的坑1.1 日期格式化陷阱1.2 时区转换二、优雅方案的进阶之路2.1 线程安全重构2

Android实现两台手机屏幕共享和远程控制功能

《Android实现两台手机屏幕共享和远程控制功能》在远程协助、在线教学、技术支持等多种场景下,实时获得另一部移动设备的屏幕画面,并对其进行操作,具有极高的应用价值,本项目旨在实现两台Android手... 目录一、项目概述二、相关知识2.1 MediaProjection API2.2 Socket 网络

使用Python实现图像LBP特征提取的操作方法

《使用Python实现图像LBP特征提取的操作方法》LBP特征叫做局部二值模式,常用于纹理特征提取,并在纹理分类中具有较强的区分能力,本文给大家介绍了如何使用Python实现图像LBP特征提取的操作方... 目录一、LBP特征介绍二、LBP特征描述三、一些改进版本的LBP1.圆形LBP算子2.旋转不变的LB

Maven的使用和配置国内源的保姆级教程

《Maven的使用和配置国内源的保姆级教程》Maven是⼀个项目管理工具,基于POM(ProjectObjectModel,项目对象模型)的概念,Maven可以通过一小段描述信息来管理项目的构建,报告... 目录1. 什么是Maven?2.创建⼀个Maven项目3.Maven 核心功能4.使用Maven H

Redis消息队列实现异步秒杀功能

《Redis消息队列实现异步秒杀功能》在高并发场景下,为了提高秒杀业务的性能,可将部分工作交给Redis处理,并通过异步方式执行,Redis提供了多种数据结构来实现消息队列,总结三种,本文详细介绍Re... 目录1 Redis消息队列1.1 List 结构1.2 Pub/Sub 模式1.3 Stream 结

C# Where 泛型约束的实现

《C#Where泛型约束的实现》本文主要介绍了C#Where泛型约束的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧... 目录使用的对象约束分类where T : structwhere T : classwhere T : ne

Python中__init__方法使用的深度解析

《Python中__init__方法使用的深度解析》在Python的面向对象编程(OOP)体系中,__init__方法如同建造房屋时的奠基仪式——它定义了对象诞生时的初始状态,下面我们就来深入了解下_... 目录一、__init__的基因图谱二、初始化过程的魔法时刻继承链中的初始化顺序self参数的奥秘默认

将Java程序打包成EXE文件的实现方式

《将Java程序打包成EXE文件的实现方式》:本文主要介绍将Java程序打包成EXE文件的实现方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录如何将Java程序编程打包成EXE文件1.准备Java程序2.生成JAR包3.选择并安装打包工具4.配置Launch4

SpringBoot使用GZIP压缩反回数据问题

《SpringBoot使用GZIP压缩反回数据问题》:本文主要介绍SpringBoot使用GZIP压缩反回数据问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录SpringBoot使用GZIP压缩反回数据1、初识gzip2、gzip是什么,可以干什么?3、Spr