Pytorch导出FP16 ONNX模型

2024-04-10 09:04
文章标签 模型 导出 pytorch onnx fp16

本文主要是介绍Pytorch导出FP16 ONNX模型,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

一般Pytorch导出ONNX时默认都是用的FP32,但有时需要导出FP16的ONNX模型,这样在部署时能够方便的将计算以及IO改成FP16,并且ONNX文件体积也会更小。想导出FP16的ONNX模型也比较简单,一般情况下只需要在导出FP32 ONNX的基础上调用下model.half()将模型相关权重转为FP16,然后输入的Tensor也改成FP16即可,具体操作可参考如下示例代码。这里需要注意下,当前Pytorch要导出FP16的ONNX必须将模型以及输入Tensor的device设置成GPU,否则会报很多算子不支持FP16计算的提示。

import torch
from torchvision.models import resnet50def main():export_fp16 = Trueexport_onnx_path = f"resnet50_fp{16 if export_fp16 else 32}.onnx"device = torch.device("cuda:0")model = resnet50()model.eval()model.to(device)if export_fp16:model.half()with torch.inference_mode():dtype = torch.float16 if export_fp16 else torch.float32x = torch.randn(size=(1, 3, 224, 224), dtype=dtype, device=device)torch.onnx.export(model=model,args=(x,),f=export_onnx_path,input_names=["image"],output_names=["output"],dynamic_axes={"image": {2: "width", 3: "height"}},opset_version=17)if __name__ == '__main__':main()

通过Netron可视化工具可以看到导出的FP16 ONNX的输入/输出的tensor类型都是float16
在这里插入图片描述

并且通过对比可以看到,FP16的ONNX模型比FP32的文件更小(48.6MB vs 97.3MB)。
在这里插入图片描述
大多数情况可以按照上述操作进行正常转换,但也有一些比较头大的场景,因为你永远无法知道拿到的模型会有多奇葩,例如下面示例:
错误导出FP16 ONNX示例

import torch
import torch.nn as nn
import torch.nn.functional as Fclass MyModel(nn.Module):def __init__(self) -> None:super().__init__()self.conv = nn.Conv2d(3, 1, kernel_size=3, stride=2, padding=1)def forward(self, x):x = self.conv(x)kernel = torch.tensor([[0.1, 0.1, 0.1],[0.1, 0.1, 0.1],[0.1, 0.1, 0.1]], dtype=torch.float32, device=x.device).reshape([1, 1, 3, 3])x = F.conv2d(x, weight=kernel, bias=None, stride=1)return xdef main():export_fp16 = Trueexport_onnx_path = f"my_model_fp{16 if export_fp16 else 32}.onnx"device = torch.device("cuda:0")model = MyModel()model.eval()model.to(device)if export_fp16:model.half()with torch.inference_mode():dtype = torch.float16 if export_fp16 else torch.float32x = torch.randn(size=(1, 3, 224, 224), dtype=dtype, device=device)model(x)torch.onnx.export(model=model,args=(x,),f=export_onnx_path,input_names=["image"],output_names=["output"],dynamic_axes={"image": {2: "width", 3: "height"}},opset_version=17)if __name__ == '__main__':main()

执行以上代码后会报如下错误信息:

/src/ATen/native/cudnn/Conv_v8.cpp:80.)return F.conv2d(input, weight, bias, self.stride,
Traceback (most recent call last):File "/home/wz/my_projects/py_projects/export_fp16/example.py", line 47, in <module>main()File "/home/wz/my_projects/py_projects/export_fp16/example.py", line 36, in mainmodel(x)File "/home/wz/miniconda3/envs/torch2.0.1/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_implreturn forward_call(*args, **kwargs)File "/home/wz/my_projects/py_projects/export_fp16/example.py", line 17, in forwardx = F.conv2d(x, weight=kernel, bias=None, stride=1)RuntimeError: Input type (torch.cuda.HalfTensor) and weight type (torch.cuda.FloatTensor) should be the same

简单来说就是在推理过程中遇到两种不同类型的数据要计算,torch.cuda.HalfTensor(FP16) 和torch.cuda.FloatTensor(FP32)。遇到这种情况一般常见有两种解法:

  • 一种是找到数据类型与我们预期不一致的地方,然后改成我们要想的dtype,例如上面示例是将kernel的dtype写死成了torch.float32,我们可以改成torch.float16或者写成x.dtype(这种会比较通用,会根据输入的Tensor类型自动切换)。这种方法有个弊端,如果代码里写死dtype的位置很多,改起来会比较头大。
  • 另一种是使用torch.autocast上下文管理器,该上下文管理器能够实现推理过程中自动进行混合精度计算,例如遇到能进行float16/bfloat16计算的场景会自动切换。具体使用方法可以查看官方文档。下面示例代码就是用torch.autocast上下文管理器来做自动转换。
import torch
import torch.nn as nn
import torch.nn.functional as Fclass MyModel(nn.Module):def __init__(self) -> None:super().__init__()self.conv = nn.Conv2d(3, 1, kernel_size=3, stride=2, padding=1)def forward(self, x):x = self.conv(x)kernel = torch.tensor([[0.1, 0.1, 0.1],[0.1, 0.1, 0.1],[0.1, 0.1, 0.1]], dtype=torch.float32, device=x.device).reshape([1, 1, 3, 3])x = F.conv2d(x, weight=kernel, bias=None, stride=1)return xdef main():export_fp16 = Trueexport_onnx_path = f"my_model_fp{16 if export_fp16 else 32}.onnx"device = torch.device("cuda:0")model = MyModel()model.eval()model.to(device)if export_fp16:model.half()with torch.autocast(device_type="cuda", dtype=torch.float16):with torch.inference_mode():dtype = torch.float16 if export_fp16 else torch.float32x = torch.randn(size=(1, 3, 224, 224), dtype=dtype, device=device)model(x)torch.onnx.export(model=model,args=(x,),f=export_onnx_path,input_names=["image"],output_names=["output"],dynamic_axes={"image": {2: "width", 3: "height"}},opset_version=17)if __name__ == '__main__':main()

使用上述代码能够正常导出ONNX模型,并且使用Netron可视化后可以看到导出的FP16 ONNX模型是符合预期的。
在这里插入图片描述

这篇关于Pytorch导出FP16 ONNX模型的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/890626

相关文章

详解Vue如何使用xlsx库导出Excel文件

《详解Vue如何使用xlsx库导出Excel文件》第三方库xlsx提供了强大的功能来处理Excel文件,它可以简化导出Excel文件这个过程,本文将为大家详细介绍一下它的具体使用,需要的小伙伴可以了解... 目录1. 安装依赖2. 创建vue组件3. 解释代码在Vue.js项目中导出Excel文件,使用第三

Golang的CSP模型简介(最新推荐)

《Golang的CSP模型简介(最新推荐)》Golang采用了CSP(CommunicatingSequentialProcesses,通信顺序进程)并发模型,通过goroutine和channe... 目录前言一、介绍1. 什么是 CSP 模型2. Goroutine3. Channel4. Channe

PyTorch使用教程之Tensor包详解

《PyTorch使用教程之Tensor包详解》这篇文章介绍了PyTorch中的张量(Tensor)数据结构,包括张量的数据类型、初始化、常用操作、属性等,张量是PyTorch框架中的核心数据结构,支持... 目录1、张量Tensor2、数据类型3、初始化(构造张量)4、常用操作5、常用属性5.1 存储(st

Python实现将实体类列表数据导出到Excel文件

《Python实现将实体类列表数据导出到Excel文件》在数据处理和报告生成中,将实体类的列表数据导出到Excel文件是一项常见任务,Python提供了多种库来实现这一目标,下面就来跟随小编一起学习一... 目录一、环境准备二、定义实体类三、创建实体类列表四、将实体类列表转换为DataFrame五、导出Da

Python数据处理之导入导出Excel数据方式

《Python数据处理之导入导出Excel数据方式》Python是Excel数据处理的绝佳工具,通过Pandas和Openpyxl等库可以实现数据的导入、导出和自动化处理,从基础的数据读取和清洗到复杂... 目录python导入导出Excel数据开启数据之旅:为什么Python是Excel数据处理的最佳拍档

Oracle Expdp按条件导出指定表数据的方法实例

《OracleExpdp按条件导出指定表数据的方法实例》:本文主要介绍Oracle的expdp数据泵方式导出特定机构和时间范围的数据,并通过parfile文件进行条件限制和配置,文中通过代码介绍... 目录1.场景描述 2.方案分析3.实验验证 3.1 parfile文件3.2 expdp命令导出4.总结

Python基于火山引擎豆包大模型搭建QQ机器人详细教程(2024年最新)

《Python基于火山引擎豆包大模型搭建QQ机器人详细教程(2024年最新)》:本文主要介绍Python基于火山引擎豆包大模型搭建QQ机器人详细的相关资料,包括开通模型、配置APIKEY鉴权和SD... 目录豆包大模型概述开通模型付费安装 SDK 环境配置 API KEY 鉴权Ark 模型接口Prompt

java poi实现Excel多级表头导出方式(多级表头,复杂表头)

《javapoi实现Excel多级表头导出方式(多级表头,复杂表头)》文章介绍了使用javapoi库实现Excel多级表头导出的方法,通过主代码、合并单元格、设置表头单元格宽度、填充数据、web下载... 目录Java poi实现Excel多级表头导出(多级表头,复杂表头)上代码1.主代码2.合并单元格3.

大模型研发全揭秘:客服工单数据标注的完整攻略

在人工智能(AI)领域,数据标注是模型训练过程中至关重要的一步。无论你是新手还是有经验的从业者,掌握数据标注的技术细节和常见问题的解决方案都能为你的AI项目增添不少价值。在电信运营商的客服系统中,工单数据是客户问题和解决方案的重要记录。通过对这些工单数据进行有效标注,不仅能够帮助提升客服自动化系统的智能化水平,还能优化客户服务流程,提高客户满意度。本文将详细介绍如何在电信运营商客服工单的背景下进行

Andrej Karpathy最新采访:认知核心模型10亿参数就够了,AI会打破教育不公的僵局

夕小瑶科技说 原创  作者 | 海野 AI圈子的红人,AI大神Andrej Karpathy,曾是OpenAI联合创始人之一,特斯拉AI总监。上一次的动态是官宣创办一家名为 Eureka Labs 的人工智能+教育公司 ,宣布将长期致力于AI原生教育。 近日,Andrej Karpathy接受了No Priors(投资博客)的采访,与硅谷知名投资人 Sara Guo 和 Elad G