Pytorch导出FP16 ONNX模型

2024-04-10 09:04
文章标签 模型 导出 pytorch onnx fp16

本文主要是介绍Pytorch导出FP16 ONNX模型,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

一般Pytorch导出ONNX时默认都是用的FP32,但有时需要导出FP16的ONNX模型,这样在部署时能够方便的将计算以及IO改成FP16,并且ONNX文件体积也会更小。想导出FP16的ONNX模型也比较简单,一般情况下只需要在导出FP32 ONNX的基础上调用下model.half()将模型相关权重转为FP16,然后输入的Tensor也改成FP16即可,具体操作可参考如下示例代码。这里需要注意下,当前Pytorch要导出FP16的ONNX必须将模型以及输入Tensor的device设置成GPU,否则会报很多算子不支持FP16计算的提示。

import torch
from torchvision.models import resnet50def main():export_fp16 = Trueexport_onnx_path = f"resnet50_fp{16 if export_fp16 else 32}.onnx"device = torch.device("cuda:0")model = resnet50()model.eval()model.to(device)if export_fp16:model.half()with torch.inference_mode():dtype = torch.float16 if export_fp16 else torch.float32x = torch.randn(size=(1, 3, 224, 224), dtype=dtype, device=device)torch.onnx.export(model=model,args=(x,),f=export_onnx_path,input_names=["image"],output_names=["output"],dynamic_axes={"image": {2: "width", 3: "height"}},opset_version=17)if __name__ == '__main__':main()

通过Netron可视化工具可以看到导出的FP16 ONNX的输入/输出的tensor类型都是float16
在这里插入图片描述

并且通过对比可以看到,FP16的ONNX模型比FP32的文件更小(48.6MB vs 97.3MB)。
在这里插入图片描述
大多数情况可以按照上述操作进行正常转换,但也有一些比较头大的场景,因为你永远无法知道拿到的模型会有多奇葩,例如下面示例:
错误导出FP16 ONNX示例

import torch
import torch.nn as nn
import torch.nn.functional as Fclass MyModel(nn.Module):def __init__(self) -> None:super().__init__()self.conv = nn.Conv2d(3, 1, kernel_size=3, stride=2, padding=1)def forward(self, x):x = self.conv(x)kernel = torch.tensor([[0.1, 0.1, 0.1],[0.1, 0.1, 0.1],[0.1, 0.1, 0.1]], dtype=torch.float32, device=x.device).reshape([1, 1, 3, 3])x = F.conv2d(x, weight=kernel, bias=None, stride=1)return xdef main():export_fp16 = Trueexport_onnx_path = f"my_model_fp{16 if export_fp16 else 32}.onnx"device = torch.device("cuda:0")model = MyModel()model.eval()model.to(device)if export_fp16:model.half()with torch.inference_mode():dtype = torch.float16 if export_fp16 else torch.float32x = torch.randn(size=(1, 3, 224, 224), dtype=dtype, device=device)model(x)torch.onnx.export(model=model,args=(x,),f=export_onnx_path,input_names=["image"],output_names=["output"],dynamic_axes={"image": {2: "width", 3: "height"}},opset_version=17)if __name__ == '__main__':main()

执行以上代码后会报如下错误信息:

/src/ATen/native/cudnn/Conv_v8.cpp:80.)return F.conv2d(input, weight, bias, self.stride,
Traceback (most recent call last):File "/home/wz/my_projects/py_projects/export_fp16/example.py", line 47, in <module>main()File "/home/wz/my_projects/py_projects/export_fp16/example.py", line 36, in mainmodel(x)File "/home/wz/miniconda3/envs/torch2.0.1/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_implreturn forward_call(*args, **kwargs)File "/home/wz/my_projects/py_projects/export_fp16/example.py", line 17, in forwardx = F.conv2d(x, weight=kernel, bias=None, stride=1)RuntimeError: Input type (torch.cuda.HalfTensor) and weight type (torch.cuda.FloatTensor) should be the same

简单来说就是在推理过程中遇到两种不同类型的数据要计算,torch.cuda.HalfTensor(FP16) 和torch.cuda.FloatTensor(FP32)。遇到这种情况一般常见有两种解法:

  • 一种是找到数据类型与我们预期不一致的地方,然后改成我们要想的dtype,例如上面示例是将kernel的dtype写死成了torch.float32,我们可以改成torch.float16或者写成x.dtype(这种会比较通用,会根据输入的Tensor类型自动切换)。这种方法有个弊端,如果代码里写死dtype的位置很多,改起来会比较头大。
  • 另一种是使用torch.autocast上下文管理器,该上下文管理器能够实现推理过程中自动进行混合精度计算,例如遇到能进行float16/bfloat16计算的场景会自动切换。具体使用方法可以查看官方文档。下面示例代码就是用torch.autocast上下文管理器来做自动转换。
import torch
import torch.nn as nn
import torch.nn.functional as Fclass MyModel(nn.Module):def __init__(self) -> None:super().__init__()self.conv = nn.Conv2d(3, 1, kernel_size=3, stride=2, padding=1)def forward(self, x):x = self.conv(x)kernel = torch.tensor([[0.1, 0.1, 0.1],[0.1, 0.1, 0.1],[0.1, 0.1, 0.1]], dtype=torch.float32, device=x.device).reshape([1, 1, 3, 3])x = F.conv2d(x, weight=kernel, bias=None, stride=1)return xdef main():export_fp16 = Trueexport_onnx_path = f"my_model_fp{16 if export_fp16 else 32}.onnx"device = torch.device("cuda:0")model = MyModel()model.eval()model.to(device)if export_fp16:model.half()with torch.autocast(device_type="cuda", dtype=torch.float16):with torch.inference_mode():dtype = torch.float16 if export_fp16 else torch.float32x = torch.randn(size=(1, 3, 224, 224), dtype=dtype, device=device)model(x)torch.onnx.export(model=model,args=(x,),f=export_onnx_path,input_names=["image"],output_names=["output"],dynamic_axes={"image": {2: "width", 3: "height"}},opset_version=17)if __name__ == '__main__':main()

使用上述代码能够正常导出ONNX模型,并且使用Netron可视化后可以看到导出的FP16 ONNX模型是符合预期的。
在这里插入图片描述

这篇关于Pytorch导出FP16 ONNX模型的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/890626

相关文章

java中使用POI生成Excel并导出过程

《java中使用POI生成Excel并导出过程》:本文主要介绍java中使用POI生成Excel并导出过程,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录需求说明及实现方式需求完成通用代码版本1版本2结果展示type参数为atype参数为b总结注:本文章中代码均为

Java的IO模型、Netty原理解析

《Java的IO模型、Netty原理解析》Java的I/O是以流的方式进行数据输入输出的,Java的类库涉及很多领域的IO内容:标准的输入输出,文件的操作、网络上的数据传输流、字符串流、对象流等,这篇... 目录1.什么是IO2.同步与异步、阻塞与非阻塞3.三种IO模型BIO(blocking I/O)NI

基于Flask框架添加多个AI模型的API并进行交互

《基于Flask框架添加多个AI模型的API并进行交互》:本文主要介绍如何基于Flask框架开发AI模型API管理系统,允许用户添加、删除不同AI模型的API密钥,感兴趣的可以了解下... 目录1. 概述2. 后端代码说明2.1 依赖库导入2.2 应用初始化2.3 API 存储字典2.4 路由函数2.5 应

Python实现将MySQL中所有表的数据都导出为CSV文件并压缩

《Python实现将MySQL中所有表的数据都导出为CSV文件并压缩》这篇文章主要为大家详细介绍了如何使用Python将MySQL数据库中所有表的数据都导出为CSV文件到一个目录,并压缩为zip文件到... python将mysql数据库中所有表的数据都导出为CSV文件到一个目录,并压缩为zip文件到另一个

使用PyTorch实现手写数字识别功能

《使用PyTorch实现手写数字识别功能》在人工智能的世界里,计算机视觉是最具魅力的领域之一,通过PyTorch这一强大的深度学习框架,我们将在经典的MNIST数据集上,见证一个神经网络从零开始学会识... 目录当计算机学会“看”数字搭建开发环境MNIST数据集解析1. 认识手写数字数据库2. 数据预处理的

Pytorch微调BERT实现命名实体识别

《Pytorch微调BERT实现命名实体识别》命名实体识别(NER)是自然语言处理(NLP)中的一项关键任务,它涉及识别和分类文本中的关键实体,BERT是一种强大的语言表示模型,在各种NLP任务中显著... 目录环境准备加载预训练BERT模型准备数据集标记与对齐微调 BERT最后总结环境准备在继续之前,确

Java导入、导出excel用法步骤保姆级教程(附封装好的工具类)

《Java导入、导出excel用法步骤保姆级教程(附封装好的工具类)》:本文主要介绍Java导入、导出excel的相关资料,讲解了使用Java和ApachePOI库将数据导出为Excel文件,包括... 目录前言一、引入Apache POI依赖二、用法&步骤2.1 创建Excel的元素2.3 样式和字体2.

pytorch+torchvision+python版本对应及环境安装

《pytorch+torchvision+python版本对应及环境安装》本文主要介绍了pytorch+torchvision+python版本对应及环境安装,安装过程中需要注意Numpy版本的降级,... 目录一、版本对应二、安装命令(pip)1. 版本2. 安装全过程3. 命令相关解释参考文章一、版本对

java导出pdf文件的详细实现方法

《java导出pdf文件的详细实现方法》:本文主要介绍java导出pdf文件的详细实现方法,包括制作模板、获取中文字体文件、实现后端服务以及前端发起请求并生成下载链接,需要的朋友可以参考下... 目录使用注意点包含内容1、制作pdf模板2、获取pdf导出中文需要的文件3、实现4、前端发起请求并生成下载链接使

SpringBoot实现导出复杂对象到Excel文件

《SpringBoot实现导出复杂对象到Excel文件》这篇文章主要为大家详细介绍了如何使用Hutool和EasyExcel两种方式来实现在SpringBoot项目中导出复杂对象到Excel文件,需要... 在Spring Boot项目中导出复杂对象到Excel文件,可以利用Hutool或EasyExcel