【tensorrt】——双线性上采样插件(提供源码)

2024-06-13 08:48

本文主要是介绍【tensorrt】——双线性上采样插件(提供源码),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

在这里插入图片描述
在这里插入图片描述

简介:
如果用nvidia的gpu,在推理的时候,采用tensorrt进行加速是一个很好的选择,虽然tensorrt没有开源。

我一般选择的模型训练到部署的流程是:

  1. pytorch训练模型
  2. onnx模型导出
  3. onnx模型转ncnn,mnn,tensorrt等模型
  4. 嵌入式推理框架,推理脚本书写。

这里用tensorrt做语义分割网络pspnet的推理加速。技术路线采用:pytorch——onnx——tensorrt。

1. pytorch——onnx

pytorch是内嵌了onnx模型导出的。这里pytorch版本的选择由使用的tensorrt的版本确定。这里我们采用TensorRT-YOLOv4项目中onnx-tensorrt中的tensorrt版本5.1xx。

这个版本上采样onnx中还是upsample,对应到pytorch<=1.0。pytorch1.0是支持nearest,bilinear两种方式的导出的。

2. onnx——tensorrt

TensorRT-YOLOv4中有resizenearest插件是没有双线性插值的。

2.1 写插件

resizebilinear是没有网络权重参数的,所以没有序列化重构,可以需要对以下进行重构。
需要重构:

  • getPluginType:
  • getOutputDimensions:计算网络输出tensor的尺寸
  • initialize:
  • enqueue:前向推理的具体入口
#pragma once#include "plugin.hpp"
#include "serialize.hpp"
#include <cassert>class ResizeBilinearPlugin final : public onnx2trt::Plugin {int   _ndims;float _scale[nvinfer1::Dims::MAX_DIMS];nvinfer1::Dims _output_dims;protected:void deserialize(void const* serialData, size_t serialLength) {deserializeBase(serialData, serialLength);deserialize_value(&serialData, &serialLength, &_ndims);deserialize_value(&serialData, &serialLength, &_scale);}size_t getSerializationSize() override {return serialized_size(_ndims) + serialized_size(_scale) + getBaseSerializationSize();}void serialize(void *buffer) override {serializeBase(buffer);serialize_value(&buffer, _ndims);serialize_value(&buffer, _scale);}public:ResizeBilinearPlugin(std::vector<float> const& scale): _ndims(scale.size()) {assert(scale.size() <= nvinfer1::Dims::MAX_DIMS);std::copy(scale.begin(), scale.end(), _scale);}ResizeBilinearPlugin(void const* serialData, size_t serialLength) {this->deserialize(serialData, serialLength);}virtual const char* getPluginType() const override { return "ResizeBilinear"; }virtual int getNbOutputs() const override { return 1; }virtual nvinfer1::Dims getOutputDimensions(int index,const nvinfer1::Dims *inputs, int nbInputDims) override;virtual int initialize() override;int enqueue(int batchSize,const void *const *inputs, void **outputs,void *workspace, cudaStream_t stream) override;
};

重构之后,一般都会向,tensorrt怎么调了。具体可以参考:【onnx-tensorrt】——源码阅读记录

总结就是:你看不到调用的接口,你只能模仿着写。

2.2 注册插件

builtin_plugins.cpp 中注册插件

REGISTER_BUILTIN_PLUGIN("FancyActivation",       FancyActivationPlugin);        // 相当于入库
REGISTER_BUILTIN_PLUGIN("ResizeNearest",         ResizeNearestPlugin);
REGISTER_BUILTIN_PLUGIN("ResizeBilinear",        ResizeBilinearPlugin);
REGISTER_BUILTIN_PLUGIN("Split"        ,         SplitPlugin);
REGISTER_BUILTIN_PLUGIN("InstanceNormalization", InstanceNormalizationPlugin);
REGISTER_BUILTIN_NVPLUGIN("Concat", ConcatPlugin);
REGISTER_BUILTIN_PLUGIN("DCNv2", DCNv2Plugin);
REGISTER_BUILTIN_PLUGIN("Mish", MishPlugin);
REGISTER_BUILTIN_PLUGIN("YOLO", YOLOPlugin);
REGISTER_BUILTIN_PLUGIN("DarkNetAdd", ADDPlugin);

注意:
注册插件的字符串ResizeBilinear和 virtual const char* getPluginType() const override { return “ResizeBilinear”; }的字符串保持一致。

2.3 使用插件,修改builtin_op_importers.cpp

插件写好了,什么时候使用的呢?我怎么让tensorrt使用我的插件呢?

答案: 具体是在builtin_op_importers.cpp中进行控制的,这里以upsample为例子:

DEFINE_BUILTIN_OP_IMPORTER(Upsample) {ASSERT(inputs.at(0).is_tensor(), ErrorCode::kUNSUPPORTED_NODE);nvinfer1::ITensor &tensor = inputs.at(0).tensor();ASSERT(tensor.getDimensions().nbDims == 3, ErrorCode::kUNSUPPORTED_NODE);OnnxAttrs attrs(node);float height_scale, width_scale;if (ctx->getOpsetVersion() >= 9) {ASSERT(inputs.size() == 2, ErrorCode::kINVALID_NODE);auto scales_input = inputs.at(1);ASSERT(scales_input.is_weights(), ErrorCode::kUNSUPPORTED_NODE);ShapedWeights scales_weights = scales_input.weights();ASSERT(scales_weights.shape.nbDims == 1, ErrorCode::kUNSUPPORTED_NODE);ASSERT(scales_weights.count() == 4, ErrorCode::kUNSUPPORTED_NODE);ASSERT(scales_weights.type == ::ONNX_NAMESPACE::TensorProto::FLOAT,ErrorCode::kINVALID_NODE);float const *scales_ptr = static_cast<float const *>(scales_weights.values);ASSERT(scales_ptr[0] == 1 && scales_ptr[1] == 1,ErrorCode::kUNSUPPORTED_NODE);height_scale = scales_ptr[2];width_scale = scales_ptr[3];} else {if (!attrs.count("scales")) {height_scale = attrs.get<float>("height_scale");width_scale = attrs.get<float>("width_scale");} else {auto scales = attrs.get<std::vector<float>>("scales");ASSERT(scales.size() == 4, ErrorCode::kUNSUPPORTED_NODE);ASSERT(scales[0] == 1 && scales[1] == 1, ErrorCode::kUNSUPPORTED_NODE);height_scale = scales[2];width_scale = scales[3];}}auto scale = {height_scale, width_scale};auto mode = attrs.get<std::string>("mode", "nearest");        // 默认采用 nearest 上采样方式ASSERT(mode == "nearest" || "linear", ErrorCode::kUNSUPPORTED_NODE);if (mode == "nearest")RETURN_FIRST_OUTPUT(ctx->addPlugin(new ResizeNearestPlugin(scale), {&inputs.at(0).tensor()}));        // 这里确定使用何种自定义的类别插件else if (mode == "linear")RETURN_FIRST_OUTPUT(ctx->addPlugin(new ResizeBilinearPlugin(scale), {&inputs.at(0).tensor()}));}
  • DEFINE_BUILTIN_OP_IMPORTER(Upsample)是在onnx模型解析导入的时候调用的。
  • ctx->addPlugin(new ResizeBilinearPlugin(scale), {&inputs.at(0).tensor()})就是初始化一个类,后续tensorrt模型的序列化,推理就会使用新定义的(自己定义的)网络层

2 动态库编译

前面插件也写好了,onnx模型也能解析了,别人怎么用呢?

  • 参考TensorRT-YOLOv4直接给源码编译
  • 编译成动态库,给别人动态库,别人直接用动态库

这里还是使用TensorRT-YOLOv4编译成动态库。
其实动态库中已经包含了我们刚才修改的文件:

  • resizebilinear.cu
  • resizebilinear.h
  • builtin_op_importers.cpp

后续在使用的时候还是用tensorrt原有的头文件,链接时候,链接上前面编译好的动态库就好了。比如:

cmake_minimum_required(VERSION 2.8)
project(net)find_package(CUDA REQUIRED)include_directories(../include   
)set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -Wall -Ofast ")
set(CUDA_NVCC_FLAGS  "-D_FORCE_INLINES -Xcompiler -fPIC -gencode arch=compute_${GPU_ARCHS},code=sm_${GPU_ARCHS} -gencode arch=compute_${GPU_ARCHS},code=compute_${GPU_ARCHS}")# packed so library
set(srcs net.cpp resize.cu)
cuda_add_library(megengine SHARED ${srcs})
target_link_libraries(megengine mynvonnxparser              # 不用包含tensorrt plugin层的头文件,采用原有的头文件就可以mynvonnxparser_runtime)# 1. 不需要包含opencv_libs, 因为没有使用opencv的操作

说明:

  • 这里的mynvonnxparser, mynvonnxparser_runtime就是前面编译的动态库
  • tensorrt的头文件我放到了/usr/include下,所有cmake中没有指定

other

  • 下载地址
  • 一定要主要tensorrt的版本,不同版本插件的书写是不一样的
    没有积分的留下邮箱吧

这篇关于【tensorrt】——双线性上采样插件(提供源码)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1056828

相关文章

springboot家政服务管理平台 LW +PPT+源码+讲解

3系统的可行性研究及需求分析 3.1可行性研究 3.1.1技术可行性分析 经过大学四年的学习,已经掌握了JAVA、Mysql数据库等方面的编程技巧和方法,对于这些技术该有的软硬件配置也是齐全的,能够满足开发的需要。 本家政服务管理平台采用的是Mysql作为数据库,可以绝对地保证用户数据的安全;可以与Mysql数据库进行无缝连接。 所以,家政服务管理平台在技术上是可以实施的。 3.1

高仿精仿愤怒的小鸟android版游戏源码

这是一款很完美的高仿精仿愤怒的小鸟android版游戏源码,大家可以研究一下吧、 为了报复偷走鸟蛋的肥猪们,鸟儿以自己的身体为武器,仿佛炮弹一样去攻击肥猪们的堡垒。游戏是十分卡通的2D画面,看着愤怒的红色小鸟,奋不顾身的往绿色的肥猪的堡垒砸去,那种奇妙的感觉还真是令人感到很欢乐。而游戏的配乐同样充满了欢乐的感觉,轻松的节奏,欢快的风格。 源码下载

上采样(upsample)的方法

上采样(upsample)的方法   在神经网络中,扩大特征图的方法,即upsample/上采样的方法   1)unpooling:恢复max的位置,其余部分补零   2)deconvolution(反卷积):先对input补零,再conv   3)插值方法,双线性插值等;   4)扩张卷积,dilated conv;

基于Java医院药品交易系统详细设计和实现(源码+LW+调试文档+讲解等)

💗博主介绍:✌全网粉丝10W+,CSDN作者、博客专家、全栈领域优质创作者,博客之星、平台优质作者、专注于Java、小程序技术领域和毕业项目实战✌💗 🌟文末获取源码+数据库🌟 感兴趣的可以先收藏起来,还有大家在毕设选题,项目以及论文编写等相关问题都可以给我留言咨询,希望帮助更多的人  Java精品实战案例《600套》 2023-2025年最值得选择的Java毕业设计选题大全:1000个热

WordPress网创自动采集并发布插件

网创教程:WordPress插件网创自动采集并发布 阅读更新:随机添加文章的阅读数量,购买数量,喜欢数量。 使用插件注意事项 如果遇到404错误,请先检查并调整网站的伪静态设置,这是最常见的问题。需要定制化服务,请随时联系我。 本次更新内容 我们进行了多项更新和优化,主要包括: 界面设置:用户现在可以更便捷地设置文章分类和发布金额。代码优化:改进了采集和发布代码,提高了插件的稳定

vscode-创建vue3项目-修改暗黑主题-常见错误-element插件标签-用法涉及问题

文章目录 1.vscode创建运行编译vue3项目2.添加项目资源3.添加element-plus元素4.修改为暗黑主题4.1.在main.js主文件中引入暗黑样式4.2.添加自定义样式文件4.3.html页面html标签添加样式 5.常见错误5.1.未使用变量5.2.关闭typescript检查5.3.调试器支持5.4.允许未到达代码和未定义代码 6.element常用标签6.1.下拉列表

美容美发店营销版微信小程序源码

打造线上生意新篇章 一、引言:微信小程序,开启美容美发行业新纪元 在数字化时代,微信小程序以其便捷、高效的特点,成为了美容美发行业营销的新宠。本文将带您深入了解美容美发营销微信小程序,探讨其独特优势及如何助力商家实现业务增长。 二、微信小程序:美容美发行业的得力助手 拓宽客源渠道:微信小程序基于微信社交平台,轻松实现线上线下融合,帮助商家快速吸引潜在客户,拓宽客源渠道。 提升用户体验:

风水研究会官网源码系统-可展示自己的领域内容-商品售卖等

一款用于展示风水行业,周易测算行业,玄学行业的系统,并支持售卖自己的商品。 整洁大气,非常漂亮,前端内容均可通过后台修改。 大致功能: 支持前端内容通过后端自定义支持开启关闭会员功能,会员等级设置支持对接官方支付支持添加商品类支持添加虚拟下载类支持自定义其他类型字段支持生成虚拟激活卡支持采集其他站点文章支持对接收益广告支持文章评论支持积分功能支持推广功能更多功能,搭建完成自行体验吧! 原文

HTML5文旅文化旅游网站模板源码

文章目录 1.设计来源文旅宣传1.1 登录界面演示1.2 注册界面演示1.3 首页界面演示1.4 文旅之行界面演示1.5 文旅之行文章内容界面演示1.6 关于我们界面演示1.7 文旅博客界面演示1.8 文旅博客文章内容界面演示1.9 联系我们界面演示 2.效果和源码2.1 动态效果2.2 源代码2.3 源码目录 源码下载万套模板,程序开发,在线开发,在线沟通 作者:xcLeigh

ROS2从入门到精通4-4:局部控制插件开发案例(以PID算法为例)

目录 0 专栏介绍1 控制插件编写模板1.1 构造控制插件类1.2 注册并导出插件1.3 编译与使用插件 2 基于PID的路径跟踪原理3 控制插件开发案例(PID算法)常见问题 0 专栏介绍 本专栏旨在通过对ROS2的系统学习,掌握ROS2底层基本分布式原理,并具有机器人建模和应用ROS2进行实际项目的开发和调试的工程能力。 🚀详情:《ROS2从入门到精通》 1 控制插