pytorch中对nn.BatchNorm2d()函数的理解

2023-11-10 23:21

本文主要是介绍pytorch中对nn.BatchNorm2d()函数的理解,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

pytorch中对BatchNorm2d函数的理解

  • 简介
  • 计算
  • 3. Pytorch的nn.BatchNorm2d()函数
  • 4 代码示例

简介

机器学习中,进行模型训练之前,需对数据做归一化处理,使其分布一致。在深度神经网络训练过程中,通常一次训练是一个batch,而非全体数据。每个batch具有不同的分布产生了internal covarivate shift问题——在训练过程中,数据分布会发生变化,对下一层网络的学习带来困难。Batch Normalization强行将数据拉回到均值为0,方差为1的正太分布上,一方面使得数据分布一致,另一方面避免梯度消失。

计算

如图所示:
在这里插入图片描述
在这里插入图片描述

3. Pytorch的nn.BatchNorm2d()函数

其主要需要输入4个参数:
(1)num_features:输入数据的shape一般为[batch_size, channel, height, width], num_features为其中的channel;
(2)eps: 分母中添加的一个值,目的是为了计算的稳定性,默认:1e-5;
(3)momentum: 一个用于运行过程中均值和方差的一个估计参数,默认值为0.1.
在这里插入图片描述
(4)affine:当设为true时,给定可以学习的系数矩阵 γ \gamma γ β \beta β

4 代码示例

import torchdata = torch.ones(size=(2, 2, 3, 4))
data[0][0][0][0] = 25
print("data = ", data)print("\n")print("=========================使用封装的BatchNorm2d()计算================================")
BN = torch.nn.BatchNorm2d(num_features=2, eps=0, momentum=0)
BN_data = BN(data)
print("BN_data = ", BN_data)print("\n")print("=========================自行计算================================")
x = torch.cat((data[0][0], data[1][0]), dim=1)      # 1.将同一通道进行拼接(即把同一通道当作一个整体)
x_mean = torch.Tensor.mean(x)                       # 2.计算同一通道所有制的均值(即拼接后的均值)
x_var = torch.Tensor.var(x, False)                  # 3.计算同一通道所有制的方差(即拼接后的方差)# 4.使用第一个数按照公式来求BatchNorm后的值
bn_first = ((data[0][0][0][0] - x_mean) / ( torch.pow(x_var, 0.5))) * BN.weight[0] + BN.bias[0]
print("bn_first = ", bn_first)

在这里插入图片描述
在这里插入图片描述

这篇关于pytorch中对nn.BatchNorm2d()函数的理解的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/385967

相关文章

Java function函数式接口的使用方法与实例

《Javafunction函数式接口的使用方法与实例》:本文主要介绍Javafunction函数式接口的使用方法与实例,函数式接口如一支未完成的诗篇,用Lambda表达式作韵脚,将代码的机械美感... 目录引言-当代码遇见诗性一、函数式接口的生物学解构1.1 函数式接口的基因密码1.2 六大核心接口的形态学

一文带你理解Python中import机制与importlib的妙用

《一文带你理解Python中import机制与importlib的妙用》在Python编程的世界里,import语句是开发者最常用的工具之一,它就像一把钥匙,打开了通往各种功能和库的大门,下面就跟随小... 目录一、python import机制概述1.1 import语句的基本用法1.2 模块缓存机制1.

深入理解C语言的void*

《深入理解C语言的void*》本文主要介绍了C语言的void*,包括它的任意性、编译器对void*的类型检查以及需要显式类型转换的规则,具有一定的参考价值,感兴趣的可以了解一下... 目录一、void* 的类型任意性二、编译器对 void* 的类型检查三、需要显式类型转换占用的字节四、总结一、void* 的

PyTorch使用教程之Tensor包详解

《PyTorch使用教程之Tensor包详解》这篇文章介绍了PyTorch中的张量(Tensor)数据结构,包括张量的数据类型、初始化、常用操作、属性等,张量是PyTorch框架中的核心数据结构,支持... 目录1、张量Tensor2、数据类型3、初始化(构造张量)4、常用操作5、常用属性5.1 存储(st

深入理解Redis大key的危害及解决方案

《深入理解Redis大key的危害及解决方案》本文主要介绍了深入理解Redis大key的危害及解决方案,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着... 目录一、背景二、什么是大key三、大key评价标准四、大key 产生的原因与场景五、大key影响与危

Oracle的to_date()函数详解

《Oracle的to_date()函数详解》Oracle的to_date()函数用于日期格式转换,需要注意Oracle中不区分大小写的MM和mm格式代码,应使用mi代替分钟,此外,Oracle还支持毫... 目录oracle的to_date()函数一.在使用Oracle的to_date函数来做日期转换二.日

深入理解C++ 空类大小

《深入理解C++空类大小》本文主要介绍了C++空类大小,规定空类大小为1字节,主要是为了保证对象的唯一性和可区分性,满足数组元素地址连续的要求,下面就来了解一下... 目录1. 保证对象的唯一性和可区分性2. 满足数组元素地址连续的要求3. 与C++的对象模型和内存管理机制相适配查看类对象内存在C++中,规

C++11的函数包装器std::function使用示例

《C++11的函数包装器std::function使用示例》C++11引入的std::function是最常用的函数包装器,它可以存储任何可调用对象并提供统一的调用接口,以下是关于函数包装器的详细讲解... 目录一、std::function 的基本用法1. 基本语法二、如何使用 std::function

认识、理解、分类——acm之搜索

普通搜索方法有两种:1、广度优先搜索;2、深度优先搜索; 更多搜索方法: 3、双向广度优先搜索; 4、启发式搜索(包括A*算法等); 搜索通常会用到的知识点:状态压缩(位压缩,利用hash思想压缩)。

hdu1171(母函数或多重背包)

题意:把物品分成两份,使得价值最接近 可以用背包,或者是母函数来解,母函数(1 + x^v+x^2v+.....+x^num*v)(1 + x^v+x^2v+.....+x^num*v)(1 + x^v+x^2v+.....+x^num*v) 其中指数为价值,每一项的数目为(该物品数+1)个 代码如下: #include<iostream>#include<algorithm>