【tensorrt】——双线性上采样插件(提供源码)

2024-06-13 08:48

本文主要是介绍【tensorrt】——双线性上采样插件(提供源码),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

在这里插入图片描述
在这里插入图片描述

简介:
如果用nvidia的gpu,在推理的时候,采用tensorrt进行加速是一个很好的选择,虽然tensorrt没有开源。

我一般选择的模型训练到部署的流程是:

  1. pytorch训练模型
  2. onnx模型导出
  3. onnx模型转ncnn,mnn,tensorrt等模型
  4. 嵌入式推理框架,推理脚本书写。

这里用tensorrt做语义分割网络pspnet的推理加速。技术路线采用:pytorch——onnx——tensorrt。

1. pytorch——onnx

pytorch是内嵌了onnx模型导出的。这里pytorch版本的选择由使用的tensorrt的版本确定。这里我们采用TensorRT-YOLOv4项目中onnx-tensorrt中的tensorrt版本5.1xx。

这个版本上采样onnx中还是upsample,对应到pytorch<=1.0。pytorch1.0是支持nearest,bilinear两种方式的导出的。

2. onnx——tensorrt

TensorRT-YOLOv4中有resizenearest插件是没有双线性插值的。

2.1 写插件

resizebilinear是没有网络权重参数的,所以没有序列化重构,可以需要对以下进行重构。
需要重构:

  • getPluginType:
  • getOutputDimensions:计算网络输出tensor的尺寸
  • initialize:
  • enqueue:前向推理的具体入口
#pragma once#include "plugin.hpp"
#include "serialize.hpp"
#include <cassert>class ResizeBilinearPlugin final : public onnx2trt::Plugin {int   _ndims;float _scale[nvinfer1::Dims::MAX_DIMS];nvinfer1::Dims _output_dims;protected:void deserialize(void const* serialData, size_t serialLength) {deserializeBase(serialData, serialLength);deserialize_value(&serialData, &serialLength, &_ndims);deserialize_value(&serialData, &serialLength, &_scale);}size_t getSerializationSize() override {return serialized_size(_ndims) + serialized_size(_scale) + getBaseSerializationSize();}void serialize(void *buffer) override {serializeBase(buffer);serialize_value(&buffer, _ndims);serialize_value(&buffer, _scale);}public:ResizeBilinearPlugin(std::vector<float> const& scale): _ndims(scale.size()) {assert(scale.size() <= nvinfer1::Dims::MAX_DIMS);std::copy(scale.begin(), scale.end(), _scale);}ResizeBilinearPlugin(void const* serialData, size_t serialLength) {this->deserialize(serialData, serialLength);}virtual const char* getPluginType() const override { return "ResizeBilinear"; }virtual int getNbOutputs() const override { return 1; }virtual nvinfer1::Dims getOutputDimensions(int index,const nvinfer1::Dims *inputs, int nbInputDims) override;virtual int initialize() override;int enqueue(int batchSize,const void *const *inputs, void **outputs,void *workspace, cudaStream_t stream) override;
};

重构之后,一般都会向,tensorrt怎么调了。具体可以参考:【onnx-tensorrt】——源码阅读记录

总结就是:你看不到调用的接口,你只能模仿着写。

2.2 注册插件

builtin_plugins.cpp 中注册插件

REGISTER_BUILTIN_PLUGIN("FancyActivation",       FancyActivationPlugin);        // 相当于入库
REGISTER_BUILTIN_PLUGIN("ResizeNearest",         ResizeNearestPlugin);
REGISTER_BUILTIN_PLUGIN("ResizeBilinear",        ResizeBilinearPlugin);
REGISTER_BUILTIN_PLUGIN("Split"        ,         SplitPlugin);
REGISTER_BUILTIN_PLUGIN("InstanceNormalization", InstanceNormalizationPlugin);
REGISTER_BUILTIN_NVPLUGIN("Concat", ConcatPlugin);
REGISTER_BUILTIN_PLUGIN("DCNv2", DCNv2Plugin);
REGISTER_BUILTIN_PLUGIN("Mish", MishPlugin);
REGISTER_BUILTIN_PLUGIN("YOLO", YOLOPlugin);
REGISTER_BUILTIN_PLUGIN("DarkNetAdd", ADDPlugin);

注意:
注册插件的字符串ResizeBilinear和 virtual const char* getPluginType() const override { return “ResizeBilinear”; }的字符串保持一致。

2.3 使用插件,修改builtin_op_importers.cpp

插件写好了,什么时候使用的呢?我怎么让tensorrt使用我的插件呢?

答案: 具体是在builtin_op_importers.cpp中进行控制的,这里以upsample为例子:

DEFINE_BUILTIN_OP_IMPORTER(Upsample) {ASSERT(inputs.at(0).is_tensor(), ErrorCode::kUNSUPPORTED_NODE);nvinfer1::ITensor &tensor = inputs.at(0).tensor();ASSERT(tensor.getDimensions().nbDims == 3, ErrorCode::kUNSUPPORTED_NODE);OnnxAttrs attrs(node);float height_scale, width_scale;if (ctx->getOpsetVersion() >= 9) {ASSERT(inputs.size() == 2, ErrorCode::kINVALID_NODE);auto scales_input = inputs.at(1);ASSERT(scales_input.is_weights(), ErrorCode::kUNSUPPORTED_NODE);ShapedWeights scales_weights = scales_input.weights();ASSERT(scales_weights.shape.nbDims == 1, ErrorCode::kUNSUPPORTED_NODE);ASSERT(scales_weights.count() == 4, ErrorCode::kUNSUPPORTED_NODE);ASSERT(scales_weights.type == ::ONNX_NAMESPACE::TensorProto::FLOAT,ErrorCode::kINVALID_NODE);float const *scales_ptr = static_cast<float const *>(scales_weights.values);ASSERT(scales_ptr[0] == 1 && scales_ptr[1] == 1,ErrorCode::kUNSUPPORTED_NODE);height_scale = scales_ptr[2];width_scale = scales_ptr[3];} else {if (!attrs.count("scales")) {height_scale = attrs.get<float>("height_scale");width_scale = attrs.get<float>("width_scale");} else {auto scales = attrs.get<std::vector<float>>("scales");ASSERT(scales.size() == 4, ErrorCode::kUNSUPPORTED_NODE);ASSERT(scales[0] == 1 && scales[1] == 1, ErrorCode::kUNSUPPORTED_NODE);height_scale = scales[2];width_scale = scales[3];}}auto scale = {height_scale, width_scale};auto mode = attrs.get<std::string>("mode", "nearest");        // 默认采用 nearest 上采样方式ASSERT(mode == "nearest" || "linear", ErrorCode::kUNSUPPORTED_NODE);if (mode == "nearest")RETURN_FIRST_OUTPUT(ctx->addPlugin(new ResizeNearestPlugin(scale), {&inputs.at(0).tensor()}));        // 这里确定使用何种自定义的类别插件else if (mode == "linear")RETURN_FIRST_OUTPUT(ctx->addPlugin(new ResizeBilinearPlugin(scale), {&inputs.at(0).tensor()}));}
  • DEFINE_BUILTIN_OP_IMPORTER(Upsample)是在onnx模型解析导入的时候调用的。
  • ctx->addPlugin(new ResizeBilinearPlugin(scale), {&inputs.at(0).tensor()})就是初始化一个类,后续tensorrt模型的序列化,推理就会使用新定义的(自己定义的)网络层

2 动态库编译

前面插件也写好了,onnx模型也能解析了,别人怎么用呢?

  • 参考TensorRT-YOLOv4直接给源码编译
  • 编译成动态库,给别人动态库,别人直接用动态库

这里还是使用TensorRT-YOLOv4编译成动态库。
其实动态库中已经包含了我们刚才修改的文件:

  • resizebilinear.cu
  • resizebilinear.h
  • builtin_op_importers.cpp

后续在使用的时候还是用tensorrt原有的头文件,链接时候,链接上前面编译好的动态库就好了。比如:

cmake_minimum_required(VERSION 2.8)
project(net)find_package(CUDA REQUIRED)include_directories(../include   
)set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -Wall -Ofast ")
set(CUDA_NVCC_FLAGS  "-D_FORCE_INLINES -Xcompiler -fPIC -gencode arch=compute_${GPU_ARCHS},code=sm_${GPU_ARCHS} -gencode arch=compute_${GPU_ARCHS},code=compute_${GPU_ARCHS}")# packed so library
set(srcs net.cpp resize.cu)
cuda_add_library(megengine SHARED ${srcs})
target_link_libraries(megengine mynvonnxparser              # 不用包含tensorrt plugin层的头文件,采用原有的头文件就可以mynvonnxparser_runtime)# 1. 不需要包含opencv_libs, 因为没有使用opencv的操作

说明:

  • 这里的mynvonnxparser, mynvonnxparser_runtime就是前面编译的动态库
  • tensorrt的头文件我放到了/usr/include下,所有cmake中没有指定

other

  • 下载地址
  • 一定要主要tensorrt的版本,不同版本插件的书写是不一样的
    没有积分的留下邮箱吧

这篇关于【tensorrt】——双线性上采样插件(提供源码)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1056828

相关文章

Java汇编源码如何查看环境搭建

《Java汇编源码如何查看环境搭建》:本文主要介绍如何在IntelliJIDEA开发环境中搭建字节码和汇编环境,以便更好地进行代码调优和JVM学习,首先,介绍了如何配置IntelliJIDEA以方... 目录一、简介二、在IDEA开发环境中搭建汇编环境2.1 在IDEA中搭建字节码查看环境2.1.1 搭建步

IDEA常用插件之代码扫描SonarLint详解

《IDEA常用插件之代码扫描SonarLint详解》SonarLint是一款用于代码扫描的插件,可以帮助查找隐藏的bug,下载并安装插件后,右键点击项目并选择“Analyze”、“Analyzewit... 目录SonajavascriptrLint 查找隐藏的bug下载安装插件扫描代码查看结果总结Sona

JAVA智听未来一站式有声阅读平台听书系统小程序源码

智听未来,一站式有声阅读平台听书系统 🌟&nbsp;开篇:遇见未来,从“智听”开始 在这个快节奏的时代,你是否渴望在忙碌的间隙,找到一片属于自己的宁静角落?是否梦想着能随时随地,沉浸在知识的海洋,或是故事的奇幻世界里?今天,就让我带你一起探索“智听未来”——这一站式有声阅读平台听书系统,它正悄悄改变着我们的阅读方式,让未来触手可及! 📚&nbsp;第一站:海量资源,应有尽有 走进“智听

Java ArrayList扩容机制 (源码解读)

结论:初始长度为10,若所需长度小于1.5倍原长度,则按照1.5倍扩容。若不够用则按照所需长度扩容。 一. 明确类内部重要变量含义         1:数组默认长度         2:这是一个共享的空数组实例,用于明确创建长度为0时的ArrayList ,比如通过 new ArrayList<>(0),ArrayList 内部的数组 elementData 会指向这个 EMPTY_EL

如何在Visual Studio中调试.NET源码

今天偶然在看别人代码时,发现在他的代码里使用了Any判断List<T>是否为空。 我一般的做法是先判断是否为null,再判断Count。 看了一下Count的源码如下: 1 [__DynamicallyInvokable]2 public int Count3 {4 [__DynamicallyInvokable]5 get

工厂ERP管理系统实现源码(JAVA)

工厂进销存管理系统是一个集采购管理、仓库管理、生产管理和销售管理于一体的综合解决方案。该系统旨在帮助企业优化流程、提高效率、降低成本,并实时掌握各环节的运营状况。 在采购管理方面,系统能够处理采购订单、供应商管理和采购入库等流程,确保采购过程的透明和高效。仓库管理方面,实现库存的精准管理,包括入库、出库、盘点等操作,确保库存数据的准确性和实时性。 生产管理模块则涵盖了生产计划制定、物料需求计划、

Spring 源码解读:自定义实现Bean定义的注册与解析

引言 在Spring框架中,Bean的注册与解析是整个依赖注入流程的核心步骤。通过Bean定义,Spring容器知道如何创建、配置和管理每个Bean实例。本篇文章将通过实现一个简化版的Bean定义注册与解析机制,帮助你理解Spring框架背后的设计逻辑。我们还将对比Spring中的BeanDefinition和BeanDefinitionRegistry,以全面掌握Bean注册和解析的核心原理。

音视频入门基础:WAV专题(10)——FFmpeg源码中计算WAV音频文件每个packet的pts、dts的实现

一、引言 从文章《音视频入门基础:WAV专题(6)——通过FFprobe显示WAV音频文件每个数据包的信息》中我们可以知道,通过FFprobe命令可以打印WAV音频文件每个packet(也称为数据包或多媒体包)的信息,这些信息包含该packet的pts、dts: 打印出来的“pts”实际是AVPacket结构体中的成员变量pts,是以AVStream->time_base为单位的显

kubelet组件的启动流程源码分析

概述 摘要: 本文将总结kubelet的作用以及原理,在有一定基础认识的前提下,通过阅读kubelet源码,对kubelet组件的启动流程进行分析。 正文 kubelet的作用 这里对kubelet的作用做一个简单总结。 节点管理 节点的注册 节点状态更新 容器管理(pod生命周期管理) 监听apiserver的容器事件 容器的创建、删除(CRI) 容器的网络的创建与删除

Maven(插件配置和生命周期的绑定)

1.这篇文章很好,介绍的maven插件的。 2.maven的source插件为例,可以把源代码打成包。 Goals Overview就可以查看该插件下面所有的目标。 这里我们要使用的是source:jar-no-fork。 3.查看source插件的example,然后配置到riil-collect.xml中。  <build>   <plugins>    <pl