【ncnn android】算法移植(七)——pytorch2onnx代码粗看

2024-06-13 09:08

本文主要是介绍【ncnn android】算法移植(七)——pytorch2onnx代码粗看,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

目的:

  • 了解torch2onnx的流程
  • 了解其中的一些技术细节

1. 程序细节

  1. get_graph
    将pytorch的模型转成onnx需要的graph
  • graph, torch_out = _trace_and_get_graph_from_model(model, args, training)

  • trace, torch_out, inputs_states = torch.jit.get_trace_graph(model, args, _force_outplace=True, _return_inputs_states=True) warn_on_static_input_change(inputs_states)

  1. graph_export_onnx
proto, export_map = graph._export_onnx(params_dict, opset_version, dynamic_axes, defer_weight_export,operator_export_type, strip_doc_string, val_keep_init_as_ip)

2. 其他

  1. batchnorm
    在保存成onnx的时候,设置verbose=True,可以看有哪些属性。
%554 : Float(1, 16, 8, 8) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1]](%550, %model.detect.context.inconv.conv.weight), scope: OnnxModel/DBFace[model]/DetectModule[detect]/ContextModule[context]/CBAModule[inconv]/Conv2d[conv] # /home/yangna/yangna/tool/anaconda2/envs/torch130/lib/python3.6/site-packages/torch/nn/modules/conv.py:342:0%555 : Float(1, 16, 8, 8) = onnx::BatchNormalization[epsilon=1e-05, momentum=0.9](%554, %model.detect.context.inconv.bn.weight, %model.detect.context.inconv.bn.bias, %model.detect.context.inconv.bn.running_mean, %model.detect.context.inconv.bn.running_var), scope: OnnxModel/DBFace[model]/DetectModule[detect]/ContextModule[context]/CBAModule[inconv]/BatchNorm2d[bn] # /home/yangna/yangna/tool/anaconda2/envs/torch130/lib/python3.6/site-packages/torch/nn/functional.py:1670:0%556 : Float(1, 16, 8, 8) = onnx::Relu(%555), scope: OnnxModel/DBFace[model]/DetectModule[detect]/ContextModule[context]/CBAModule[inconv]/ReLU[act] # /home/yangna/yangna/tool/anaconda2/envs/torch130/lib/python3.6/site-packages/torch/nn/functional.py:912:0%557 : Float(1, 16, 8, 8) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1]](%556, %model.detect.context.upconv.conv.weight), scope: OnnxModel/DBFace[model]/DetectModule[detect]/ContextModule[context]/CBAModule[upconv]/Conv2d[conv] # /home/yangna/yangna/tool/anaconda2/envs/torch130/lib/python3.6/site-packages/torch/nn/modules/conv.py:342:0%558 : Float(1, 16, 8, 8) = onnx::BatchNormalization[epsilon=1e-05, momentum=0.9](%557, %model.detect.context.upconv.bn.weight, %model.detect.context.upconv.bn.bias, %model.detect.context.upconv.bn.running_mean, %model.detect.context.upconv.bn.running_var), scope: OnnxModel/DBFace[model]/DetectModule[detect]/ContextModule[context]/CBAModule[upconv]/BatchNorm2d[bn] # /home/yangna/yangna/tool/anaconda2/envs/torch130/lib/python3.6/site-packages/torch/nn/functional.py:1670:0%559 : Float(1, 16, 8, 8) = onnx::Relu(%558), scope: OnnxModel/DBFace[model]/DetectModule[detect]/ContextModule[context]/CBAModule[upconv]/ReLU[act] # /home/yangna/yangna/tool/anaconda2/envs/torch130/lib/python3.6/site-packages/torch/nn/functional.py:912:0

这里以batchnorm为例,说明一下:

  • 首先是pytorch中的:
    %558 : Float(1, 16, 8, 8) = onnx::BatchNormalization[epsilon=1e-05, momentum=0.9](%557, %model.detect.context.upconv.bn.weight, %model.detect.context.upconv.bn.bias, %model.detect.context.upconv.bn.running_mean, %model.detect.context.upconv.bn.running_var), scope: OnnxModel/DBFace[model]/DetectModule[detect]/ContextModule[context]/CBAModule[upconv]/BatchNorm2d[bn] # /home/yangna/yangna/tool/anaconda2/envs/torch130/lib/python3.6/site-packages/torch/nn/functional.py:1670:0
    其中小括号中就是要保存的参数的属性有:bn.weight bn.bias bn.running_mean bn.running_var

  • ncnn中onnx2ncnn中如何读取预训练权重。

const onnx::TensorProto& scale = weights[node.input(1)];
const onnx::TensorProto& B = weights[node.input(2)];
const onnx::TensorProto& mean = weights[node.input(3)];
const onnx::TensorProto& var = weights[node.input(4)];
  • node.input(1):bn.weight
  • node.input(2):bn.bias
  • node.input(3):bn.running_mean
  • node.input(4):bn.running_var
    顺序和pytorch2onnx写入的顺序一致
  1. maxpool
  • pytorch的打印信息
%pool_hm : Float(1, 1, 8, 8) = onnx::MaxPool[ceil_mode=0, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1]](%hm), scope: OnnxModel # /home/yangna/yangna/tool/anaconda2/envs/torch130/lib/python3.6/site-packages/torch/nn/functional.py:488:0
  • ncnn中如何读取结构参数
    因为maxpool层是没有预训练权重的,只有一些结构参数
std::string auto_pad = get_node_attr_s(node, "auto_pad");//TODO
std::vector<int> kernel_shape = get_node_attr_ai(node, "kernel_shape");
std::vector<int> strides = get_node_attr_ai(node, "strides");
std::vector<int> pads = get_node_attr_ai(node, "pads");
  • 注意:这里“auto_pad”字段和pytorch中的“ceil_model”字段是不一样的。这是因为pytorch2onnx版本和ncnn版本不对应造成的。可能ncnn20180704版时,maxpool的onnx表达中有“auto_pad”属性。

这篇关于【ncnn android】算法移植(七)——pytorch2onnx代码粗看的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1056868

相关文章

openCV中KNN算法的实现

《openCV中KNN算法的实现》KNN算法是一种简单且常用的分类算法,本文主要介绍了openCV中KNN算法的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的... 目录KNN算法流程使用OpenCV实现KNNOpenCV 是一个开源的跨平台计算机视觉库,它提供了各

Android实现打开本地pdf文件的两种方式

《Android实现打开本地pdf文件的两种方式》在现代应用中,PDF格式因其跨平台、稳定性好、展示内容一致等特点,在Android平台上,如何高效地打开本地PDF文件,不仅关系到用户体验,也直接影响... 目录一、项目概述二、相关知识2.1 PDF文件基本概述2.2 android 文件访问与存储权限2.

使用Python实现全能手机虚拟键盘的示例代码

《使用Python实现全能手机虚拟键盘的示例代码》在数字化办公时代,你是否遇到过这样的场景:会议室投影电脑突然键盘失灵、躺在沙发上想远程控制书房电脑、或者需要给长辈远程协助操作?今天我要分享的Pyth... 目录一、项目概述:不止于键盘的远程控制方案1.1 创新价值1.2 技术栈全景二、需求实现步骤一、需求

Android Studio 配置国内镜像源的实现步骤

《AndroidStudio配置国内镜像源的实现步骤》本文主要介绍了AndroidStudio配置国内镜像源的实现步骤,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,... 目录一、修改 hosts,解决 SDK 下载失败的问题二、修改 gradle 地址,解决 gradle

Java中Date、LocalDate、LocalDateTime、LocalTime、时间戳之间的相互转换代码

《Java中Date、LocalDate、LocalDateTime、LocalTime、时间戳之间的相互转换代码》:本文主要介绍Java中日期时间转换的多种方法,包括将Date转换为LocalD... 目录一、Date转LocalDateTime二、Date转LocalDate三、LocalDateTim

jupyter代码块没有运行图标的解决方案

《jupyter代码块没有运行图标的解决方案》:本文主要介绍jupyter代码块没有运行图标的解决方案,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录jupyter代码块没有运行图标的解决1.找到Jupyter notebook的系统配置文件2.这时候一般会搜索到

在Android平台上实现消息推送功能

《在Android平台上实现消息推送功能》随着移动互联网应用的飞速发展,消息推送已成为移动应用中不可或缺的功能,在Android平台上,实现消息推送涉及到服务端的消息发送、客户端的消息接收、通知渠道(... 目录一、项目概述二、相关知识介绍2.1 消息推送的基本原理2.2 Firebase Cloud Me

Python通过模块化开发优化代码的技巧分享

《Python通过模块化开发优化代码的技巧分享》模块化开发就是把代码拆成一个个“零件”,该封装封装,该拆分拆分,下面小编就来和大家简单聊聊python如何用模块化开发进行代码优化吧... 目录什么是模块化开发如何拆分代码改进版:拆分成模块让模块更强大:使用 __init__.py你一定会遇到的问题模www.

springboot+dubbo实现时间轮算法

《springboot+dubbo实现时间轮算法》时间轮是一种高效利用线程资源进行批量化调度的算法,本文主要介绍了springboot+dubbo实现时间轮算法,文中通过示例代码介绍的非常详细,对大家... 目录前言一、参数说明二、具体实现1、HashedwheelTimer2、createWheel3、n

springboot循环依赖问题案例代码及解决办法

《springboot循环依赖问题案例代码及解决办法》在SpringBoot中,如果两个或多个Bean之间存在循环依赖(即BeanA依赖BeanB,而BeanB又依赖BeanA),会导致Spring的... 目录1. 什么是循环依赖?2. 循环依赖的场景案例3. 解决循环依赖的常见方法方法 1:使用 @La