hk.LayerNorm 模块介绍

2024-01-08 04:52
文章标签 模块 介绍 layernorm hk

本文主要是介绍hk.LayerNorm 模块介绍,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

hk.LayerNorm 是 Haiku 库中用于实现 Layer Normalization(层归一化)的模块。Layer Normalization 是一种神经网络归一化的技术,旨在提高神经网络的训练稳定性和泛化性。

主要参数:

  • axis(默认为-1): 沿着哪个轴进行归一化。通常选择最后一个轴,对输入的特征进行归一化。

  • create_scale(默认为True): 是否创建可学习的缩放参数。如果为 True,则会创建一个可学习的缩放参数,用于调整归一化后的值的幅度。

  • create_offset(默认为True): 是否创建可学习的偏置参数。如果为 True,则会创建一个可学习的偏置参数,用于调整归一化后的值的偏移。

  • epsilon(默认为1e-5): 一个小的正数,用于防止除以零的情况。

import haiku as hk
import jax
import jax.numpy as jnp
import pickle### 自定义LayerNorm模块
class LayerNorm(hk.LayerNorm):"""LayerNorm module.Equivalent to hk.LayerNorm but with different parameter shapes: they arealways vectors rather than possibly higher-rank tensors. This makes it easierto change the layout whilst keep the model weight-compatible."""def __init__(self,axis,create_scale: bool,create_offset: bool,eps: float = 1e-5,scale_init=None,offset_init=None,use_fast_variance: bool = False,name=None,param_axis=None):super().__init__(axis=axis,create_scale=False,create_offset=False,eps=eps,scale_init=None,offset_init=None,use_fast_variance=use_fast_variance,name=name,param_axis=param_axis)self._temp_create_scale = create_scaleself._temp_create_offset = create_offset#self.scale_init = hk.initializers.Constant(1)#self.offset_init = hk.initializers.Constant(0)def __call__(self, x: jnp.ndarray) -> jnp.ndarray:is_bf16 = (x.dtype == jnp.bfloat16)if is_bf16:x = x.astype(jnp.float32)param_axis = self.param_axis[0] if self.param_axis else -1param_shape = (x.shape[param_axis],)param_broadcast_shape = [1] * x.ndimparam_broadcast_shape[param_axis] = x.shape[param_axis]scale = Noneoffset = None# scale,offset张量的形状必须可扩展到输入数据的形状。# 没有显式指定 self.scale_init,self.offset_init参数,# 则默认使用 Haiku 库中的默认初始化方法。同 def __init__()中注释的显式指定if self._temp_create_scale:scale = hk.get_parameter('scale', param_shape, x.dtype, init=self.scale_init)scale = scale.reshape(param_broadcast_shape)if self._temp_create_offset:offset = hk.get_parameter('offset', param_shape, x.dtype, init=self.offset_init)offset = offset.reshape(param_broadcast_shape)out = super().__call__(x, scale=scale, offset=offset)if is_bf16:out = out.astype(jnp.bfloat16)return outwith open("Human_HBB_tensor_dict_ensembled.pkl",'rb') as f:Human_HBB_tensor_dict = pickle.load(f)input_data = jnp.array(Human_HBB_tensor_dict['msa_feat'])
print(input_data.shape)# 转换为Haiku模块
# LayerNorm层,在数据最后一个维度/轴(特征)做归一化,并创建可学习的缩放参数和偏置参数
model = hk.transform(lambda x: LayerNorm(axis=[-1], create_scale=True,create_offset=True,name='msa_feat_norm')(x))print(model)## 获取初始化的参数,参数的形状需要输入数据的形状以及模型的结构
rng = jax.random.PRNGKey(42)
params = model.init(rng, input_data)
print(params) 
print("params scale shape:") 
#print(params['msa_feat_norm']['scale'].shape)
#print("params offset bias:")
#print(params['msa_feat_norm']['offset'].shape)output_data = model.apply(params, rng, input_data)
print("input_data shape:", input_data.shape) 
print("Output Data shape:", output_data.shape)
#print("原始数据:", input_data)
print("经过LayerNorm后:", output_data)### 使用原始的hk.LayerNorm模块
model2 = hk.transform(lambda x: hk.LayerNorm(axis=[-1], create_scale=True,create_offset=True,name='msa_feat_norm')(x))print(model2)params2 = model2.init(rng, input_data)
print(params2) 
print("params2 scale shape:") 
print(params2['msa_feat_norm']['scale'].shape)
print("params2 offset bias:")
print(params2['msa_feat_norm']['offset'].shape)output_data2 = model2.apply(params2, rng, input_data)
print("input_data shape:", input_data.shape) 
print("Output Data shape:", output_data2.shape)
#print("原始数据:", input_data)
print("经过LayerNorm后:", output_data2)

参考:

https://dm-haiku.readthedocs.io/en/latest/api.html?highlight=layernorm#layernorm

这篇关于hk.LayerNorm 模块介绍的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/582385

相关文章

Python使用date模块进行日期处理的终极指南

《Python使用date模块进行日期处理的终极指南》在处理与时间相关的数据时,Python的date模块是开发者最趁手的工具之一,本文将用通俗的语言,结合真实案例,带您掌握date模块的六大核心功能... 目录引言一、date模块的核心功能1.1 日期表示1.2 日期计算1.3 日期比较二、六大常用方法详

MySQL中慢SQL优化的不同方式介绍

《MySQL中慢SQL优化的不同方式介绍》慢SQL的优化,主要从两个方面考虑,SQL语句本身的优化,以及数据库设计的优化,下面小编就来给大家介绍一下有哪些方式可以优化慢SQL吧... 目录避免不必要的列分页优化索引优化JOIN 的优化排序优化UNION 优化慢 SQL 的优化,主要从两个方面考虑,SQL 语

C++中函数模板与类模板的简单使用及区别介绍

《C++中函数模板与类模板的简单使用及区别介绍》这篇文章介绍了C++中的模板机制,包括函数模板和类模板的概念、语法和实际应用,函数模板通过类型参数实现泛型操作,而类模板允许创建可处理多种数据类型的类,... 目录一、函数模板定义语法真实示例二、类模板三、关键区别四、注意事项 ‌在C++中,模板是实现泛型编程

Python实现html转png的完美方案介绍

《Python实现html转png的完美方案介绍》这篇文章主要为大家详细介绍了如何使用Python实现html转png功能,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 1.增强稳定性与错误处理建议使用三层异常捕获结构:try: with sync_playwright(

Java使用多线程处理未知任务数的方案介绍

《Java使用多线程处理未知任务数的方案介绍》这篇文章主要为大家详细介绍了Java如何使用多线程实现处理未知任务数,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 知道任务个数,你可以定义好线程数规则,生成线程数去跑代码说明:1.虚拟线程池:使用 Executors.newVir

python中time模块的常用方法及应用详解

《python中time模块的常用方法及应用详解》在Python开发中,时间处理是绕不开的刚需场景,从性能计时到定时任务,从日志记录到数据同步,时间模块始终是开发者最得力的工具之一,本文将通过真实案例... 目录一、时间基石:time.time()典型场景:程序性能分析进阶技巧:结合上下文管理器实现自动计时

JAVA SE包装类和泛型详细介绍及说明方法

《JAVASE包装类和泛型详细介绍及说明方法》:本文主要介绍JAVASE包装类和泛型的相关资料,包括基本数据类型与包装类的对应关系,以及装箱和拆箱的概念,并重点讲解了自动装箱和自动拆箱的机制,文... 目录1. 包装类1.1 基本数据类型和对应的包装类1.2 装箱和拆箱1.3 自动装箱和自动拆箱2. 泛型2

Node.js net模块的使用示例

《Node.jsnet模块的使用示例》本文主要介绍了Node.jsnet模块的使用示例,net模块支持TCP通信,处理TCP连接和数据传输,具有一定的参考价值,感兴趣的可以了解一下... 目录简介引入 net 模块核心概念TCP (传输控制协议)Socket服务器TCP 服务器创建基本服务器服务器配置选项服

Python利用自带模块实现屏幕像素高效操作

《Python利用自带模块实现屏幕像素高效操作》这篇文章主要为大家详细介绍了Python如何利用自带模块实现屏幕像素高效操作,文中的示例代码讲解详,感兴趣的小伙伴可以跟随小编一起学习一下... 目录1、获取屏幕放缩比例2、获取屏幕指定坐标处像素颜色3、一个简单的使用案例4、总结1、获取屏幕放缩比例from

nginx-rtmp-module模块实现视频点播的示例代码

《nginx-rtmp-module模块实现视频点播的示例代码》本文主要介绍了nginx-rtmp-module模块实现视频点播,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习... 目录预置条件Nginx点播基本配置点播远程文件指定多个播放位置参考预置条件配置点播服务器 192.