本文主要是介绍hk.LayerNorm 模块介绍,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
hk.LayerNorm
是 Haiku 库中用于实现 Layer Normalization(层归一化)的模块。Layer Normalization 是一种神经网络归一化的技术,旨在提高神经网络的训练稳定性和泛化性。
主要参数:
-
axis
(默认为-1): 沿着哪个轴进行归一化。通常选择最后一个轴,对输入的特征进行归一化。 -
create_scale
(默认为True): 是否创建可学习的缩放参数。如果为 True,则会创建一个可学习的缩放参数,用于调整归一化后的值的幅度。 -
create_offset
(默认为True): 是否创建可学习的偏置参数。如果为 True,则会创建一个可学习的偏置参数,用于调整归一化后的值的偏移。 -
epsilon
(默认为1e-5): 一个小的正数,用于防止除以零的情况。
import haiku as hk
import jax
import jax.numpy as jnp
import pickle### 自定义LayerNorm模块
class LayerNorm(hk.LayerNorm):"""LayerNorm module.Equivalent to hk.LayerNorm but with different parameter shapes: they arealways vectors rather than possibly higher-rank tensors. This makes it easierto change the layout whilst keep the model weight-compatible."""def __init__(self,axis,create_scale: bool,create_offset: bool,eps: float = 1e-5,scale_init=None,offset_init=None,use_fast_variance: bool = False,name=None,param_axis=None):super().__init__(axis=axis,create_scale=False,create_offset=False,eps=eps,scale_init=None,offset_init=None,use_fast_variance=use_fast_variance,name=name,param_axis=param_axis)self._temp_create_scale = create_scaleself._temp_create_offset = create_offset#self.scale_init = hk.initializers.Constant(1)#self.offset_init = hk.initializers.Constant(0)def __call__(self, x: jnp.ndarray) -> jnp.ndarray:is_bf16 = (x.dtype == jnp.bfloat16)if is_bf16:x = x.astype(jnp.float32)param_axis = self.param_axis[0] if self.param_axis else -1param_shape = (x.shape[param_axis],)param_broadcast_shape = [1] * x.ndimparam_broadcast_shape[param_axis] = x.shape[param_axis]scale = Noneoffset = None# scale,offset张量的形状必须可扩展到输入数据的形状。# 没有显式指定 self.scale_init,self.offset_init参数,# 则默认使用 Haiku 库中的默认初始化方法。同 def __init__()中注释的显式指定if self._temp_create_scale:scale = hk.get_parameter('scale', param_shape, x.dtype, init=self.scale_init)scale = scale.reshape(param_broadcast_shape)if self._temp_create_offset:offset = hk.get_parameter('offset', param_shape, x.dtype, init=self.offset_init)offset = offset.reshape(param_broadcast_shape)out = super().__call__(x, scale=scale, offset=offset)if is_bf16:out = out.astype(jnp.bfloat16)return outwith open("Human_HBB_tensor_dict_ensembled.pkl",'rb') as f:Human_HBB_tensor_dict = pickle.load(f)input_data = jnp.array(Human_HBB_tensor_dict['msa_feat'])
print(input_data.shape)# 转换为Haiku模块
# LayerNorm层,在数据最后一个维度/轴(特征)做归一化,并创建可学习的缩放参数和偏置参数
model = hk.transform(lambda x: LayerNorm(axis=[-1], create_scale=True,create_offset=True,name='msa_feat_norm')(x))print(model)## 获取初始化的参数,参数的形状需要输入数据的形状以及模型的结构
rng = jax.random.PRNGKey(42)
params = model.init(rng, input_data)
print(params)
print("params scale shape:")
#print(params['msa_feat_norm']['scale'].shape)
#print("params offset bias:")
#print(params['msa_feat_norm']['offset'].shape)output_data = model.apply(params, rng, input_data)
print("input_data shape:", input_data.shape)
print("Output Data shape:", output_data.shape)
#print("原始数据:", input_data)
print("经过LayerNorm后:", output_data)### 使用原始的hk.LayerNorm模块
model2 = hk.transform(lambda x: hk.LayerNorm(axis=[-1], create_scale=True,create_offset=True,name='msa_feat_norm')(x))print(model2)params2 = model2.init(rng, input_data)
print(params2)
print("params2 scale shape:")
print(params2['msa_feat_norm']['scale'].shape)
print("params2 offset bias:")
print(params2['msa_feat_norm']['offset'].shape)output_data2 = model2.apply(params2, rng, input_data)
print("input_data shape:", input_data.shape)
print("Output Data shape:", output_data2.shape)
#print("原始数据:", input_data)
print("经过LayerNorm后:", output_data2)
参考:
https://dm-haiku.readthedocs.io/en/latest/api.html?highlight=layernorm#layernorm
这篇关于hk.LayerNorm 模块介绍的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!