深度学习基础知识 BatchNorm、LayerNorm、GroupNorm的用法解析

本文主要是介绍深度学习基础知识 BatchNorm、LayerNorm、GroupNorm的用法解析,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

深度学习基础知识 BatchNorm、LayerNorm、GroupNorm的用法解析

  • 1、BatchNorm
  • 2、LayerNorm
  • 3、GroupNorm
    • 用法:

BatchNorm、LayerNorm 和 GroupNorm 都是深度学习中常用的归一化方式。
它们通过将输入归一化到均值为 0 和方差为 1 的分布中,来防止梯度消失和爆炸,并提高模型的泛化能力

1、BatchNorm

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

import numpy as np
import torch.nn as nn
import torchdef bn_process(feature, mean, var):feature_shape = feature.shapefor i in range(feature_shape[1]):# [batch, channel, height, width]feature_t = feature[:, i, :, :] # 得到每一个channel的height和widthmean_t = feature_t.mean()# 总体标准差std_t1 = feature_t.std()# 样本标准差std_t2 = feature_t.std(ddof=1)# bn process# 这里记得加上eps和pytorch保持一致feature[:, i, :, :] = (feature[:, i, :, :] - mean_t) / np.sqrt(std_t1 ** 2 + 1e-5)# update calculating mean and varmean[i] = mean[i] * 0.9 + mean_t * 0.1var[i] = var[i] * 0.9 + (std_t2 ** 2) * 0.1print(feature)# 随机生成一个batch为2,channel为2,height=width=2的特征向量
# [batch, channel, height, width]
feature1 = torch.randn(2, 2, 2, 2)
# 初始化统计均值和方差
calculate_mean = [0.0, 0.0]
calculate_var = [1.0, 1.0]
# print(feature1.numpy())# 注意要使用copy()深拷贝
bn_process(feature1.numpy().copy(), calculate_mean, calculate_var)bn = nn.BatchNorm2d(2, eps=1e-5)
output = bn(feature1)
print(output)

显示结果如下:
在这里插入图片描述

在这里插入图片描述

代码:

import torch
import torch.nn as nn
import numpy as npfeatuer_array=(np.random.rand(2,4,2,2)).astype(np.float32)
print(featuer_array.dtype)featuer_tensor=torch.tensor(featuer_array,dtype=torch.float32)
bn_out=nn.BatchNorm2d( num_features=featuer_array.shape[1],eps=1e-5)(featuer_tensor)
print(bn_out)print("-----")for i in range(featuer_array.shape[1]):channel=featuer_array[:,i,:,:]mean=channel.mean()var=channel.var()print(f"mean---{mean},var---{var}")featuer_array[:,i,:,:]=(channel-mean) / np.sqrt(var + 1e-5)
print(featuer_array)

打印结果:
在这里插入图片描述

2、LayerNorm

Transformer block 中会使用到 LayerNorm , 一般输入尺寸形为 :(batch_size, token_num, dim),会在最后一个维度做 归一化,其中dim维度为token的特征向量: nn.LayerNorm(dim)

在这里插入图片描述

import torch
import torch.nn as nn
import numpy as npfeature_array=(np.random.rand(2,3,2,2).astype(np.float32))# 需要将其转化为[batch,token_num,dim]的形式
feature_array=feature_array.reshape((2,3,-1)).transpose(0,2,1)
print(feature_array.shape)   # (2, 4, 3)feature_tensor=torch.tensor(feature_array.copy(),dtype=torch.float32)layer_norm=nn.LayerNorm(normalized_shape=feature_array.shape[2])(feature_tensor)
print(layer_norm)print("\n","*"*50,"\n")
batch,token_num,dim=feature_array.shapefeature_array=feature_array.reshape((-1,dim))
for i in range(batch * token_num):mean=feature_array[i,:].mean()var=feature_array[i,:].var()print(f"mean----{mean},var----{var}")feature_array[i,:]=(feature_array[i,:]-mean) / np.sqrt(var + 1e-5)
print(feature_array.reshape(batch,token_num,dim))

打印效果如下所示:
在这里插入图片描述

3、GroupNorm

在这里插入图片描述

用法:

torch.nn.GroupNorm:将channel切分成许多组进行归一化
torch.nn.GroupNorm(num_groups,num_channels)
num_groups:组数
num_channels:通道数量
在这里插入图片描述
代码:

import torch
import torch.nn as nn
import numpy as npfeature_array=(np.random.rand(2,4,2,2)).astype(np.float32)
print(feature_array.dtype)feature_tensor=torch.tensor(feature_array.copy(),dtype=torch.float32)
group_result=nn.GroupNorm(num_groups=2,num_channels=feature_array.shape[1])(feature_tensor)
print(group_result)feature_array = feature_array.reshape((2, 2, 2, 2, 2)).reshape((4, 2, 2, 2))for i in range(feature_array.shape[0]):channel = feature_array[i, :, :, :]mean = feature_array[i, :, :, :].mean()var = feature_array[i, :, :, :].var()print(mean)print(var)feature_array[i, :, :, :] = (feature_array[i, :, :, :] - mean) / np.sqrt(var + 1e-5)
feature_array = feature_array.reshape((2, 2, 2, 2, 2)).reshape((2, 4, 2, 2))
print(feature_array)

打印结果:

在这里插入图片描述

这篇关于深度学习基础知识 BatchNorm、LayerNorm、GroupNorm的用法解析的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/182497

相关文章

java解析jwt中的payload的用法

《java解析jwt中的payload的用法》:本文主要介绍java解析jwt中的payload的用法,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录Java解析jwt中的payload1. 使用 jjwt 库步骤 1:添加依赖步骤 2:解析 JWT2. 使用 N

Python中__init__方法使用的深度解析

《Python中__init__方法使用的深度解析》在Python的面向对象编程(OOP)体系中,__init__方法如同建造房屋时的奠基仪式——它定义了对象诞生时的初始状态,下面我们就来深入了解下_... 目录一、__init__的基因图谱二、初始化过程的魔法时刻继承链中的初始化顺序self参数的奥秘默认

Linux命令之firewalld的用法

《Linux命令之firewalld的用法》:本文主要介绍Linux命令之firewalld的用法,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录linux命令之firewalld1、程序包2、启动firewalld3、配置文件4、firewalld规则定义的九大

SQL BETWEEN 的常见用法小结

《SQLBETWEEN的常见用法小结》BETWEEN操作符是SQL中非常有用的工具,它允许你快速选取某个范围内的值,本文给大家介绍SQLBETWEEN的常见用法,感兴趣的朋友一起看看吧... 在SQL中,BETWEEN是一个操作符,用于选取介于两个值之间的数据。它包含这两个边界值。BETWEEN操作符常用

MySql match against工具详细用法

《MySqlmatchagainst工具详细用法》在MySQL中,MATCH……AGAINST是全文索引(Full-Textindex)的查询语法,它允许你对文本进行高效的全文搜素,支持自然语言搜... 目录一、全文索引的基本概念二、创建全文索引三、自然语言搜索四、布尔搜索五、相关性排序六、全文索引的限制七

Java 正则表达式URL 匹配与源码全解析

《Java正则表达式URL匹配与源码全解析》在Web应用开发中,我们经常需要对URL进行格式验证,今天我们结合Java的Pattern和Matcher类,深入理解正则表达式在实际应用中... 目录1.正则表达式分解:2. 添加域名匹配 (2)3. 添加路径和查询参数匹配 (3) 4. 最终优化版本5.设计思

使用Java将DOCX文档解析为Markdown文档的代码实现

《使用Java将DOCX文档解析为Markdown文档的代码实现》在现代文档处理中,Markdown(MD)因其简洁的语法和良好的可读性,逐渐成为开发者、技术写作者和内容创作者的首选格式,然而,许多文... 目录引言1. 工具和库介绍2. 安装依赖库3. 使用Apache POI解析DOCX文档4. 将解析

Java字符串处理全解析(String、StringBuilder与StringBuffer)

《Java字符串处理全解析(String、StringBuilder与StringBuffer)》:本文主要介绍Java字符串处理全解析(String、StringBuilder与StringBu... 目录Java字符串处理全解析:String、StringBuilder与StringBuffer一、St

Spring Boot循环依赖原理、解决方案与最佳实践(全解析)

《SpringBoot循环依赖原理、解决方案与最佳实践(全解析)》循环依赖指两个或多个Bean相互直接或间接引用,形成闭环依赖关系,:本文主要介绍SpringBoot循环依赖原理、解决方案与最... 目录一、循环依赖的本质与危害1.1 什么是循环依赖?1.2 核心危害二、Spring的三级缓存机制2.1 三

C#中async await异步关键字用法和异步的底层原理全解析

《C#中asyncawait异步关键字用法和异步的底层原理全解析》:本文主要介绍C#中asyncawait异步关键字用法和异步的底层原理全解析,本文给大家介绍的非常详细,对大家的学习或工作具有一... 目录C#异步编程一、异步编程基础二、异步方法的工作原理三、代码示例四、编译后的底层实现五、总结C#异步编程