本文主要是介绍PyTorch神经网络打印存储所有权重+激活值(运行时中间值),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
很多时候嵌入式或者新硬件需要纯净的权重模型和激活值(运行时中间值),本文提供一种最简洁的方法。
假设已经有模型model和pt文件了,在当前目录下新建weights文件夹,运行这段代码,就可以得到模型的权重(文本形式和二进制形式)
model.load_state_dict(state_dict)global_index = 0
for name, param in model.named_parameters():print(name, param.size())print(param.data.numpy(),file=open(f"weights/{global_index}-{name}.txt", "w"))param.data.numpy().tofile(f"weights/{global_index}-{name}.bin")global_index += 1
对于二进制形式的文件,可以通过od -t f4 <binary file name>
查看其对应的浮点数值。f4
表示fp32.
打印forward的中间值:(这么复杂是必要的)
global_index = 0
def hook_fn(module, input, output):global global_indexmodule_name = str(module)module_name=module_name.replace(" ", "")module_name=module_name.replace("\n", "")# print(name)intermediate_outputs = {}# input is a tuple, output is a tensorfor i, inp in enumerate(input):intermediate_outputs[f"{global_index}-{module_name}-input-{i}"] = inpintermediate_outputs[f"{global_index}-{module_name}-output"] = outputmodule_name = module_name[0:200] # make sure full path <= 255print(intermediate_outputs)print(f"Size input:",end=" ")if(type(input) == tuple):for i, inp in enumerate(input):if type(inp) == torch.Tensor:print(f"{i}-th Size: {inp.size()}", end=", ")inp.numpy().tofile(f"activations/{global_index}-{module_name}-input-{i}.bin")else:print(f"{i}-th : {inp}", end=", ")elif type(input) == torch.Tensor:print(f"Size: {input.size()}")input.numpy().tofile(f"activations/{global_index}-{module_name}-input.bin")print(f"Size output: {output.size()}")global_index += 1output.numpy().tofile(f"activations/{global_index}-{module_name}-output.bin")def register_hooks(model):for name, layer in model.named_children():# print(name, layer) # dump all layers, > layers.txt# Register the hook to the current layerlayer.register_forward_hook(hook_fn)# Recursively apply the same to all submodulesregister_hooks(layer)register_hooks(model)
其中regster_hooks
和以下等价(不需要recursive了)
def register_hooks(model):for name, layer in model.named_modules():# print(name, layer) # dump all layerslayer.register_forward_hook(hook_fn)
其中nn.sequential
作为一个整体,目前没办法拆开来看其内部的中间值。
这篇关于PyTorch神经网络打印存储所有权重+激活值(运行时中间值)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!