PyTorch随笔 - 获取TensorRT(TRT)模型输入和输出

2024-02-29 03:20

本文主要是介绍PyTorch随笔 - 获取TensorRT(TRT)模型输入和输出,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

获取TensorRT(TRT)模型输入和输出,用于创建TRT的模型服务使用,具体参考脚本check_trt_script.py,如下:

  • 脚本输入:TRT的模型路径和输入图像尺寸
  • 脚本输出:模型的输入和输出结点信息,同时验证TRT模型是否可用
#!/usr/bin/env python
# -- coding: utf-8 --
"""
Copyright (c) 2021. All rights reserved.
Created by C. L. Wang on 16.9.21
"""import argparseimport numpy as npdef check_trt(model_path, image_size):"""检查TRT模型"""import pycuda.driver as cudaimport tensorrt as trt# 必须导入包,import pycuda.autoinit,否则报错import pycuda.autoinitprint('[Info] model_path: {}'.format(model_path))img_shape = (1, 3, image_size, image_size)print('[Info] img_shape: {}'.format(img_shape))trt_logger = trt.Logger(trt.Logger.WARNING)trt_path = model_path  # TRT模型路径with open(trt_path, 'rb') as f, trt.Runtime(trt_logger) as runtime:engine = runtime.deserialize_cuda_engine(f.read())for binding in engine:binding_idx = engine.get_binding_index(binding)size = engine.get_binding_shape(binding_idx)dtype = trt.nptype(engine.get_binding_dtype(binding))print("[Info] binding: {}, binding_idx: {}, size: {}, dtype: {}".format(binding, binding_idx, size, dtype))input_image = np.random.randn(*img_shape).astype(np.float32)  # 图像尺寸input_image = np.ascontiguousarray(input_image)print('[Info] input_image: {}'.format(input_image.shape))with engine.create_execution_context() as context:stream = cuda.Stream()bindings = [0] * len(engine)for binding in engine:idx = engine.get_binding_index(binding)if engine.binding_is_input(idx):input_memory = cuda.mem_alloc(input_image.nbytes)bindings[idx] = int(input_memory)cuda.memcpy_htod_async(input_memory, input_image, stream)else:dtype = trt.nptype(engine.get_binding_dtype(binding))shape = context.get_binding_shape(idx)output_buffer = np.empty(shape, dtype=dtype)output_buffer = np.ascontiguousarray(output_buffer)output_memory = cuda.mem_alloc(output_buffer.nbytes)bindings[idx] = int(output_memory)context.execute_async_v2(bindings, stream.handle)stream.synchronize()cuda.memcpy_dtoh(output_buffer, output_memory)print("[Info] output_buffer: {}".format(output_buffer))def parse_args():"""处理脚本参数"""parser = argparse.ArgumentParser(description='检查TRT模型')parser.add_argument('-m', dest='model_path', required=True, help='TRT模型路径', type=str)parser.add_argument('-s', dest='image_size', required=False, help='图像尺寸,如336', type=int, default=336)args = parser.parse_args()arg_model_path = args.model_pathprint("[Info] 模型路径: {}".format(arg_model_path))arg_image_size = args.image_sizeprint("[Info] image_size: {}".format(arg_image_size))return arg_model_path, arg_image_sizedef main():arg_model_path, arg_image_size = parse_args()check_trt(arg_model_path, arg_image_size)  # 检查TRT模型if __name__ == '__main__':main()

注意:必须导入包,import pycuda.autoinit,否则cuda.Stream()报错,如下:
image-20210916162952425

输出信息如下:

[Info] 模型路径: ../mydata/trt_models/model_best_c2_20210915_cuda.trt
[Info] image_size: 336
[Info] model_path: ../mydata/trt_models/model_best_c2_20210915_cuda.trt
[Info] img_shape: (1, 3, 336, 336)
[Info] binding: input_0, binding_idx: 0, size: (1, 3, 336, 336), dtype: <class 'numpy.float32'>
[Info] binding: output_0, binding_idx: 1, size: (1, 2), dtype: <class 'numpy.float32'>
[Info] input_image: (1, 3, 336, 336)
[Info] output_buffer: [[ 0.23275298 -0.2184143 ]]

有效信息为:

  • 输入结点binding: input_0,输入尺寸size: (1, 3, 336, 336),输入类型dtype: <class 'numpy.float32'>
  • 输出结果binding: output_0,输出尺寸size: (1, 2),输出类型dtype: <class 'numpy.float32'>

相应的json文件如下:

{"model_path": "model_best_c2_20210915_cuda.trt","model_format": "trt","quant_type": "FP32","gpu_index": 0,"inputs": {"input_0": {"shapes": [1,3,336,336],"type": "FP32"}},"outputs": {"output_0": {"shapes": [1,2],"type": "FP32"}}
}

这篇关于PyTorch随笔 - 获取TensorRT(TRT)模型输入和输出的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/757429

相关文章

0基础租个硬件玩deepseek,蓝耘元生代智算云|本地部署DeepSeek R1模型的操作流程

《0基础租个硬件玩deepseek,蓝耘元生代智算云|本地部署DeepSeekR1模型的操作流程》DeepSeekR1模型凭借其强大的自然语言处理能力,在未来具有广阔的应用前景,有望在多个领域发... 目录0基础租个硬件玩deepseek,蓝耘元生代智算云|本地部署DeepSeek R1模型,3步搞定一个应

如何利用Java获取当天的开始和结束时间

《如何利用Java获取当天的开始和结束时间》:本文主要介绍如何使用Java8的LocalDate和LocalDateTime类获取指定日期的开始和结束时间,展示了如何通过这些类进行日期和时间的处... 目录前言1. Java日期时间API概述2. 获取当天的开始和结束时间代码解析运行结果3. 总结前言在J

Deepseek R1模型本地化部署+API接口调用详细教程(释放AI生产力)

《DeepseekR1模型本地化部署+API接口调用详细教程(释放AI生产力)》本文介绍了本地部署DeepSeekR1模型和通过API调用将其集成到VSCode中的过程,作者详细步骤展示了如何下载和... 目录前言一、deepseek R1模型与chatGPT o1系列模型对比二、本地部署步骤1.安装oll

Spring AI Alibaba接入大模型时的依赖问题小结

《SpringAIAlibaba接入大模型时的依赖问题小结》文章介绍了如何在pom.xml文件中配置SpringAIAlibaba依赖,并提供了一个示例pom.xml文件,同时,建议将Maven仓... 目录(一)pom.XML文件:(二)application.yml配置文件(一)pom.xml文件:首

java获取图片的大小、宽度、高度方式

《java获取图片的大小、宽度、高度方式》文章介绍了如何将File对象转换为MultipartFile对象的过程,并分享了个人经验,希望能为读者提供参考... 目China编程录Java获取图片的大小、宽度、高度File对象(该对象里面是图片)MultipartFile对象(该对象里面是图片)总结java获取图片

Java通过反射获取方法参数名的方式小结

《Java通过反射获取方法参数名的方式小结》这篇文章主要为大家详细介绍了Java如何通过反射获取方法参数名的方式,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录1、前言2、解决方式方式2.1: 添加编译参数配置 -parameters方式2.2: 使用Spring的内部工具类 -

Java如何获取视频文件的视频时长

《Java如何获取视频文件的视频时长》文章介绍了如何使用Java获取视频文件的视频时长,包括导入maven依赖和代码案例,同时,也讨论了在运行过程中遇到的SLF4J加载问题,并给出了解决方案... 目录Java获取视频文件的视频时长1、导入maven依赖2、代码案例3、SLF4J: Failed to lo

如何在本地部署 DeepSeek Janus Pro 文生图大模型

《如何在本地部署DeepSeekJanusPro文生图大模型》DeepSeekJanusPro模型在本地成功部署,支持图片理解和文生图功能,通过Gradio界面进行交互,展示了其强大的多模态处... 目录什么是 Janus Pro1. 安装 conda2. 创建 python 虚拟环境3. 克隆 janus

使用Java实现获取客户端IP地址

《使用Java实现获取客户端IP地址》这篇文章主要为大家详细介绍了如何使用Java实现获取客户端IP地址,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 首先是获取 IP,直接上代码import org.springframework.web.context.request.Requ

本地私有化部署DeepSeek模型的详细教程

《本地私有化部署DeepSeek模型的详细教程》DeepSeek模型是一种强大的语言模型,本地私有化部署可以让用户在自己的环境中安全、高效地使用该模型,避免数据传输到外部带来的安全风险,同时也能根据自... 目录一、引言二、环境准备(一)硬件要求(二)软件要求(三)创建虚拟环境三、安装依赖库四、获取 Dee