PyTorch随笔 - 获取TensorRT(TRT)模型输入和输出

2024-02-29 03:20

本文主要是介绍PyTorch随笔 - 获取TensorRT(TRT)模型输入和输出,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

获取TensorRT(TRT)模型输入和输出,用于创建TRT的模型服务使用,具体参考脚本check_trt_script.py,如下:

  • 脚本输入:TRT的模型路径和输入图像尺寸
  • 脚本输出:模型的输入和输出结点信息,同时验证TRT模型是否可用
#!/usr/bin/env python
# -- coding: utf-8 --
"""
Copyright (c) 2021. All rights reserved.
Created by C. L. Wang on 16.9.21
"""import argparseimport numpy as npdef check_trt(model_path, image_size):"""检查TRT模型"""import pycuda.driver as cudaimport tensorrt as trt# 必须导入包,import pycuda.autoinit,否则报错import pycuda.autoinitprint('[Info] model_path: {}'.format(model_path))img_shape = (1, 3, image_size, image_size)print('[Info] img_shape: {}'.format(img_shape))trt_logger = trt.Logger(trt.Logger.WARNING)trt_path = model_path  # TRT模型路径with open(trt_path, 'rb') as f, trt.Runtime(trt_logger) as runtime:engine = runtime.deserialize_cuda_engine(f.read())for binding in engine:binding_idx = engine.get_binding_index(binding)size = engine.get_binding_shape(binding_idx)dtype = trt.nptype(engine.get_binding_dtype(binding))print("[Info] binding: {}, binding_idx: {}, size: {}, dtype: {}".format(binding, binding_idx, size, dtype))input_image = np.random.randn(*img_shape).astype(np.float32)  # 图像尺寸input_image = np.ascontiguousarray(input_image)print('[Info] input_image: {}'.format(input_image.shape))with engine.create_execution_context() as context:stream = cuda.Stream()bindings = [0] * len(engine)for binding in engine:idx = engine.get_binding_index(binding)if engine.binding_is_input(idx):input_memory = cuda.mem_alloc(input_image.nbytes)bindings[idx] = int(input_memory)cuda.memcpy_htod_async(input_memory, input_image, stream)else:dtype = trt.nptype(engine.get_binding_dtype(binding))shape = context.get_binding_shape(idx)output_buffer = np.empty(shape, dtype=dtype)output_buffer = np.ascontiguousarray(output_buffer)output_memory = cuda.mem_alloc(output_buffer.nbytes)bindings[idx] = int(output_memory)context.execute_async_v2(bindings, stream.handle)stream.synchronize()cuda.memcpy_dtoh(output_buffer, output_memory)print("[Info] output_buffer: {}".format(output_buffer))def parse_args():"""处理脚本参数"""parser = argparse.ArgumentParser(description='检查TRT模型')parser.add_argument('-m', dest='model_path', required=True, help='TRT模型路径', type=str)parser.add_argument('-s', dest='image_size', required=False, help='图像尺寸,如336', type=int, default=336)args = parser.parse_args()arg_model_path = args.model_pathprint("[Info] 模型路径: {}".format(arg_model_path))arg_image_size = args.image_sizeprint("[Info] image_size: {}".format(arg_image_size))return arg_model_path, arg_image_sizedef main():arg_model_path, arg_image_size = parse_args()check_trt(arg_model_path, arg_image_size)  # 检查TRT模型if __name__ == '__main__':main()

注意:必须导入包,import pycuda.autoinit,否则cuda.Stream()报错,如下:
image-20210916162952425

输出信息如下:

[Info] 模型路径: ../mydata/trt_models/model_best_c2_20210915_cuda.trt
[Info] image_size: 336
[Info] model_path: ../mydata/trt_models/model_best_c2_20210915_cuda.trt
[Info] img_shape: (1, 3, 336, 336)
[Info] binding: input_0, binding_idx: 0, size: (1, 3, 336, 336), dtype: <class 'numpy.float32'>
[Info] binding: output_0, binding_idx: 1, size: (1, 2), dtype: <class 'numpy.float32'>
[Info] input_image: (1, 3, 336, 336)
[Info] output_buffer: [[ 0.23275298 -0.2184143 ]]

有效信息为:

  • 输入结点binding: input_0,输入尺寸size: (1, 3, 336, 336),输入类型dtype: <class 'numpy.float32'>
  • 输出结果binding: output_0,输出尺寸size: (1, 2),输出类型dtype: <class 'numpy.float32'>

相应的json文件如下:

{"model_path": "model_best_c2_20210915_cuda.trt","model_format": "trt","quant_type": "FP32","gpu_index": 0,"inputs": {"input_0": {"shapes": [1,3,336,336],"type": "FP32"}},"outputs": {"output_0": {"shapes": [1,2],"type": "FP32"}}
}

这篇关于PyTorch随笔 - 获取TensorRT(TRT)模型输入和输出的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/757429

相关文章

Python使用OpenCV实现获取视频时长的小工具

《Python使用OpenCV实现获取视频时长的小工具》在处理视频数据时,获取视频的时长是一项常见且基础的需求,本文将详细介绍如何使用Python和OpenCV获取视频时长,并对每一行代码进行深入解析... 目录一、代码实现二、代码解析1. 导入 OpenCV 库2. 定义获取视频时长的函数3. 打开视频文

MySQL 获取字符串长度及注意事项

《MySQL获取字符串长度及注意事项》本文通过实例代码给大家介绍MySQL获取字符串长度及注意事项,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧... 目录mysql 获取字符串长度详解 核心长度函数对比⚠️ 六大关键注意事项1. 字符编码决定字节长度2

python3如何找到字典的下标index、获取list中指定元素的位置索引

《python3如何找到字典的下标index、获取list中指定元素的位置索引》:本文主要介绍python3如何找到字典的下标index、获取list中指定元素的位置索引问题,具有很好的参考价值,... 目录enumerate()找到字典的下标 index获取list中指定元素的位置索引总结enumerat

SpringMVC高效获取JavaBean对象指南

《SpringMVC高效获取JavaBean对象指南》SpringMVC通过数据绑定自动将请求参数映射到JavaBean,支持表单、URL及JSON数据,需用@ModelAttribute、@Requ... 目录Spring MVC 获取 JavaBean 对象指南核心机制:数据绑定实现步骤1. 定义 Ja

C++中RAII资源获取即初始化

《C++中RAII资源获取即初始化》RAII通过构造/析构自动管理资源生命周期,确保安全释放,本文就来介绍一下C++中的RAII技术及其应用,具有一定的参考价值,感兴趣的可以了解一下... 目录一、核心原理与机制二、标准库中的RAII实现三、自定义RAII类设计原则四、常见应用场景1. 内存管理2. 文件操

SpringBoot服务获取Pod当前IP的两种方案

《SpringBoot服务获取Pod当前IP的两种方案》在Kubernetes集群中,SpringBoot服务获取Pod当前IP的方案主要有两种,通过环境变量注入或通过Java代码动态获取网络接口IP... 目录方案一:通过 Kubernetes Downward API 注入环境变量原理步骤方案二:通过

使用Python实现获取屏幕像素颜色值

《使用Python实现获取屏幕像素颜色值》这篇文章主要为大家详细介绍了如何使用Python实现获取屏幕像素颜色值,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 一、一个小工具,按住F10键,颜色值会跟着显示。完整代码import tkinter as tkimport pyau

python获取cmd环境变量值的实现代码

《python获取cmd环境变量值的实现代码》:本文主要介绍在Python中获取命令行(cmd)环境变量的值,可以使用标准库中的os模块,需要的朋友可以参考下... 前言全局说明在执行py过程中,总要使用到系统环境变量一、说明1.1 环境:Windows 11 家庭版 24H2 26100.4061

Pytorch介绍与安装过程

《Pytorch介绍与安装过程》PyTorch因其直观的设计、卓越的灵活性以及强大的动态计算图功能,迅速在学术界和工业界获得了广泛认可,成为当前深度学习研究和开发的主流工具之一,本文给大家介绍Pyto... 目录1、Pytorch介绍1.1、核心理念1.2、核心组件与功能1.3、适用场景与优势总结1.4、优

conda安装GPU版pytorch默认却是cpu版本

《conda安装GPU版pytorch默认却是cpu版本》本文主要介绍了遇到Conda安装PyTorchGPU版本却默认安装CPU的问题,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的... 目录一、问题描述二、网上解决方案罗列【此节为反面方案罗列!!!】三、发现的根本原因[独家]3.1 p