hk.LayerNorm 模块介绍

2024-01-08 04:52
文章标签 模块 介绍 layernorm hk

本文主要是介绍hk.LayerNorm 模块介绍,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

hk.LayerNorm 是 Haiku 库中用于实现 Layer Normalization(层归一化)的模块。Layer Normalization 是一种神经网络归一化的技术,旨在提高神经网络的训练稳定性和泛化性。

主要参数:

  • axis(默认为-1): 沿着哪个轴进行归一化。通常选择最后一个轴,对输入的特征进行归一化。

  • create_scale(默认为True): 是否创建可学习的缩放参数。如果为 True,则会创建一个可学习的缩放参数,用于调整归一化后的值的幅度。

  • create_offset(默认为True): 是否创建可学习的偏置参数。如果为 True,则会创建一个可学习的偏置参数,用于调整归一化后的值的偏移。

  • epsilon(默认为1e-5): 一个小的正数,用于防止除以零的情况。

import haiku as hk
import jax
import jax.numpy as jnp
import pickle### 自定义LayerNorm模块
class LayerNorm(hk.LayerNorm):"""LayerNorm module.Equivalent to hk.LayerNorm but with different parameter shapes: they arealways vectors rather than possibly higher-rank tensors. This makes it easierto change the layout whilst keep the model weight-compatible."""def __init__(self,axis,create_scale: bool,create_offset: bool,eps: float = 1e-5,scale_init=None,offset_init=None,use_fast_variance: bool = False,name=None,param_axis=None):super().__init__(axis=axis,create_scale=False,create_offset=False,eps=eps,scale_init=None,offset_init=None,use_fast_variance=use_fast_variance,name=name,param_axis=param_axis)self._temp_create_scale = create_scaleself._temp_create_offset = create_offset#self.scale_init = hk.initializers.Constant(1)#self.offset_init = hk.initializers.Constant(0)def __call__(self, x: jnp.ndarray) -> jnp.ndarray:is_bf16 = (x.dtype == jnp.bfloat16)if is_bf16:x = x.astype(jnp.float32)param_axis = self.param_axis[0] if self.param_axis else -1param_shape = (x.shape[param_axis],)param_broadcast_shape = [1] * x.ndimparam_broadcast_shape[param_axis] = x.shape[param_axis]scale = Noneoffset = None# scale,offset张量的形状必须可扩展到输入数据的形状。# 没有显式指定 self.scale_init,self.offset_init参数,# 则默认使用 Haiku 库中的默认初始化方法。同 def __init__()中注释的显式指定if self._temp_create_scale:scale = hk.get_parameter('scale', param_shape, x.dtype, init=self.scale_init)scale = scale.reshape(param_broadcast_shape)if self._temp_create_offset:offset = hk.get_parameter('offset', param_shape, x.dtype, init=self.offset_init)offset = offset.reshape(param_broadcast_shape)out = super().__call__(x, scale=scale, offset=offset)if is_bf16:out = out.astype(jnp.bfloat16)return outwith open("Human_HBB_tensor_dict_ensembled.pkl",'rb') as f:Human_HBB_tensor_dict = pickle.load(f)input_data = jnp.array(Human_HBB_tensor_dict['msa_feat'])
print(input_data.shape)# 转换为Haiku模块
# LayerNorm层,在数据最后一个维度/轴(特征)做归一化,并创建可学习的缩放参数和偏置参数
model = hk.transform(lambda x: LayerNorm(axis=[-1], create_scale=True,create_offset=True,name='msa_feat_norm')(x))print(model)## 获取初始化的参数,参数的形状需要输入数据的形状以及模型的结构
rng = jax.random.PRNGKey(42)
params = model.init(rng, input_data)
print(params) 
print("params scale shape:") 
#print(params['msa_feat_norm']['scale'].shape)
#print("params offset bias:")
#print(params['msa_feat_norm']['offset'].shape)output_data = model.apply(params, rng, input_data)
print("input_data shape:", input_data.shape) 
print("Output Data shape:", output_data.shape)
#print("原始数据:", input_data)
print("经过LayerNorm后:", output_data)### 使用原始的hk.LayerNorm模块
model2 = hk.transform(lambda x: hk.LayerNorm(axis=[-1], create_scale=True,create_offset=True,name='msa_feat_norm')(x))print(model2)params2 = model2.init(rng, input_data)
print(params2) 
print("params2 scale shape:") 
print(params2['msa_feat_norm']['scale'].shape)
print("params2 offset bias:")
print(params2['msa_feat_norm']['offset'].shape)output_data2 = model2.apply(params2, rng, input_data)
print("input_data shape:", input_data.shape) 
print("Output Data shape:", output_data2.shape)
#print("原始数据:", input_data)
print("经过LayerNorm后:", output_data2)

参考:

https://dm-haiku.readthedocs.io/en/latest/api.html?highlight=layernorm#layernorm

这篇关于hk.LayerNorm 模块介绍的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/582385

相关文章

四种Flutter子页面向父组件传递数据的方法介绍

《四种Flutter子页面向父组件传递数据的方法介绍》在Flutter中,如果父组件需要调用子组件的方法,可以通过常用的四种方式实现,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录方法 1:使用 GlobalKey 和 State 调用子组件方法方法 2:通过回调函数(Callb

Python进阶之Excel基本操作介绍

《Python进阶之Excel基本操作介绍》在现实中,很多工作都需要与数据打交道,Excel作为常用的数据处理工具,一直备受人们的青睐,本文主要为大家介绍了一些Python中Excel的基本操作,希望... 目录概述写入使用 xlwt使用 XlsxWriter读取修改概述在现实中,很多工作都需要与数据打交

java脚本使用不同版本jdk的说明介绍

《java脚本使用不同版本jdk的说明介绍》本文介绍了在Java中执行JavaScript脚本的几种方式,包括使用ScriptEngine、Nashorn和GraalVM,ScriptEngine适用... 目录Java脚本使用不同版本jdk的说明1.使用ScriptEngine执行javascript2.

Python实现NLP的完整流程介绍

《Python实现NLP的完整流程介绍》这篇文章主要为大家详细介绍了Python实现NLP的完整流程,文中的示例代码讲解详细,具有一定的借鉴价值,感兴趣的小伙伴可以跟随小编一起学习一下... 目录1. 编程安装和导入必要的库2. 文本数据准备3. 文本预处理3.1 小写化3.2 分词(Tokenizatio

多模块的springboot项目发布指定模块的脚本方式

《多模块的springboot项目发布指定模块的脚本方式》该文章主要介绍了如何在多模块的SpringBoot项目中发布指定模块的脚本,作者原先的脚本会清理并编译所有模块,导致发布时间过长,通过简化脚本... 目录多模块的springboot项目发布指定模块的脚本1、不计成本地全部发布2、指定模块发布总结多模

Python中构建终端应用界面利器Blessed模块的使用

《Python中构建终端应用界面利器Blessed模块的使用》Blessed库作为一个轻量级且功能强大的解决方案,开始在开发者中赢得口碑,今天,我们就一起来探索一下它是如何让终端UI开发变得轻松而高... 目录一、安装与配置:简单、快速、无障碍二、基本功能:从彩色文本到动态交互1. 显示基本内容2. 创建链

Node.js 中 http 模块的深度剖析与实战应用小结

《Node.js中http模块的深度剖析与实战应用小结》本文详细介绍了Node.js中的http模块,从创建HTTP服务器、处理请求与响应,到获取请求参数,每个环节都通过代码示例进行解析,旨在帮... 目录Node.js 中 http 模块的深度剖析与实战应用一、引言二、创建 HTTP 服务器:基石搭建(一

python中的与时间相关的模块应用场景分析

《python中的与时间相关的模块应用场景分析》本文介绍了Python中与时间相关的几个重要模块:`time`、`datetime`、`calendar`、`timeit`、`pytz`和`dateu... 目录1. time 模块2. datetime 模块3. calendar 模块4. timeit

Python模块导入的几种方法实现

《Python模块导入的几种方法实现》本文主要介绍了Python模块导入的几种方法实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学... 目录一、什么是模块?二、模块导入的基本方法1. 使用import整个模块2.使用from ... i

性能测试介绍

性能测试是一种测试方法,旨在评估系统、应用程序或组件在现实场景中的性能表现和可靠性。它通常用于衡量系统在不同负载条件下的响应时间、吞吐量、资源利用率、稳定性和可扩展性等关键指标。 为什么要进行性能测试 通过性能测试,可以确定系统是否能够满足预期的性能要求,找出性能瓶颈和潜在的问题,并进行优化和调整。 发现性能瓶颈:性能测试可以帮助发现系统的性能瓶颈,即系统在高负载或高并发情况下可能出现的问题