pytorch转onnx转mnn并验证

2023-12-18 02:58
文章标签 验证 pytorch onnx mnn

本文主要是介绍pytorch转onnx转mnn并验证,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

pytorch训练的模型在实际使用时往往需要转换成onnx或mnn部署,训练好的模型需先转成onnx:

import sys
import argparse
import torch
import torchvision
import os

import torch.onnx

from mobilenetv2 import MobileNetV2

if __name__ == '__main__':
    # Export the trained 2-class MobileNetV2 checkpoint to an ONNX file that
    # sits next to the checkpoint.
    model = MobileNetV2(2)
    model_path = './model/mobilenetv2.mdl'
    model.eval()
    model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

    # Example input used to trace the graph: batch, channel, height, width.
    dummy_input = torch.randn([1, 3, 32, 32])

    # Swap only the file extension. The previous str.replace('mdl', 'onnx')
    # would also rewrite any other "mdl" occurring in the path.
    onnx_path = os.path.splitext(model_path)[0] + '.onnx'

    torch.onnx.export(
        model,
        dummy_input,
        onnx_path,
        verbose=True,
        input_names=['input'],
        output_names=['output'],
        opset_version=11,
    )
    print('Done!')

转换成功后,再转mnn,通过MNN转换工具:

./MNNConvert -f ONNX --modelFile XXX.onnx --MNNModel XXX.mnn --bizCode biz

测试pytorch的结果:

import argparse
import os
from glob import glob
import cv2
import torch
import torch.backends.cudnn as cudnn
import yaml
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm
from  PIL import Image
from  mobilenetv2  import MobileNetV2
import numpy as np


def parse_args():
    """Parse the command-line options.

    Returns:
        argparse.Namespace whose ``image_path`` is the image file to classify.
    """
    parser = argparse.ArgumentParser()
    # Required: with the former default of None, main() went on to call
    # Image.open(None) and failed with an unrelated error.
    parser.add_argument('--image_path', required=True,
                        help='the path of imgae')
    return parser.parse_args()


def main():
    """Classify one image with the PyTorch MobileNetV2 checkpoint.

    Prints the softmax probability (in percent) of the top class, the class
    index, and the elapsed time in seconds. The timed span starts before the
    model is built, so it includes model construction and weight loading.
    """
    args = parse_args()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    start = cv2.getTickCount()

    # create model
    model = MobileNetV2(2).to(device)
    model.load_state_dict(torch.load('models/best-mobilenetv2.mdl',
                                     map_location=torch.device('cpu')))
    model.eval()

    tf = transforms.Compose([
        lambda x: Image.open(x).convert('RGB'),  # string path => image data
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    x = tf(args.image_path).unsqueeze(0).to(device)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(x)

    # Pick the class with the highest score.
    _, indices = torch.max(outputs, 1)
    percentage = torch.nn.functional.softmax(outputs, dim=1)[0] * 100
    perc = percentage[int(indices)].item()
    print('predicted:', perc)
    print('id:', int(indices))

    end = cv2.getTickCount()
    during = (end - start) / cv2.getTickFrequency()
    print("avg_time:", during)


if __name__ == '__main__':
    main()

测试ONNX的结果,与pytorch结果一致:

import argparse
import os
from glob import glob
import onnxruntime
import onnx
import cv2
import torch
import torch.backends.cudnn as cudnn
import yaml
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm
from  PIL import Image
from  mobilenetv2  import MobileNetV2
import numpy as np


def parse_args():
    """Parse the command-line options.

    Returns:
        argparse.Namespace whose ``image_path`` is the image file to classify.
    """
    parser = argparse.ArgumentParser()
    # Required: with the former default of None, main() went on to call
    # Image.open(None) and failed with an unrelated error.
    parser.add_argument('--image_path', required=True,
                        help='the path of imgae')
    return parser.parse_args()


def to_numpy(tensor):
    """Return ``tensor`` as a NumPy array on the CPU.

    A tensor that tracks gradients is detached first, because
    ``Tensor.numpy()`` refuses to convert a tensor that requires grad.
    """
    # The original read `.cpu.numpy()` (missing call parentheses) on the
    # requires_grad branch, which raised AttributeError.
    if tensor.requires_grad:
        return tensor.detach().cpu().numpy()
    return tensor.cpu().numpy()


def main():
    """Run the exported ONNX model on one image and print the raw outputs.

    Also prints the elapsed time in seconds; the timed span includes creating
    the ONNX Runtime session.
    """
    args = parse_args()
    start = cv2.getTickCount()

    model = 'models/best-mobilenetv2.onnx'
    ort_session = onnxruntime.InferenceSession(model)

    tf = transforms.Compose([
        lambda x: Image.open(x).convert('RGB'),  # string path => image data
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    # NOTE(review): no resize is applied here; this presumably relies on the
    # image already matching the input size used at export time - confirm.
    img = tf(args.image_path).unsqueeze(0)

    # ONNX Runtime is fed NumPy arrays keyed by the graph's input name.
    inputs = {ort_session.get_inputs()[0].name: to_numpy(img)}
    outputs = ort_session.run(None, inputs)
    print(outputs)

    end = cv2.getTickCount()
    during = (end - start) / cv2.getTickFrequency()
    print("avg_time:", during)


if __name__ == '__main__':
    main()

测试mnn的结果,与前面的结果一致,但是速度快了近20倍:

import argparse
import os
from glob import glob
import MNN
import cv2
import torch
import torch.backends.cudnn as cudnn
import yaml
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm
from  PIL import Image
from  mobilenetv2  import MobileNetV2
import numpy as np


def parse_args():
    """Parse the command-line options.

    Returns:
        argparse.Namespace whose ``image_path`` is the image file to classify.
    """
    parser = argparse.ArgumentParser()
    # Required: with the former default of None, main() went on to call
    # Image.open(None) and failed with an unrelated error.
    parser.add_argument('--image_path', required=True,
                        help='the path of imgae')
    return parser.parse_args()


def to_numpy(tensor):
    """Return ``tensor`` as a NumPy array on the CPU.

    A tensor that tracks gradients is detached first, because
    ``Tensor.numpy()`` refuses to convert a tensor that requires grad.
    """
    # The original read `.cpu.numpy()` (missing call parentheses) on the
    # requires_grad branch, which raised AttributeError.
    if tensor.requires_grad:
        return tensor.detach().cpu().numpy()
    return tensor.cpu().numpy()


def main():
    """Run the converted MNN model on one image and print the prediction.

    Prints the input and output tensors, the raw scores, the arg-max class
    index, and the elapsed time in seconds. The timed span includes creating
    the interpreter and session.
    """
    args = parse_args()
    start = cv2.getTickCount()

    model = 'models/best-mobilenetv2.mnn'
    interpreter = MNN.Interpreter(model)
    mnn_session = interpreter.createSession()
    input_tensor = interpreter.getSessionInput(mnn_session)

    tf = transforms.Compose([
        lambda x: Image.open(x).convert('RGB'),  # string path => image data
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    img = tf(args.image_path).unsqueeze(0)

    # img[0] drops the batch axis again; the data is handed over together
    # with an explicit (1, 3, 32, 32) shape.
    # NOTE(review): no resize is applied, so this presumably assumes the
    # image file is already 32x32 - confirm.
    tmp_input = MNN.Tensor(
        (1, 3, 32, 32),
        MNN.Halide_Type_Float,
        to_numpy(img[0]),
        MNN.Tensor_DimensionType_Caffe,
    )
    print(tmp_input.getShape())
    print(input_tensor.copyFrom(tmp_input))
    input_tensor.printTensorData()

    interpreter.runSession(mnn_session)
    output_tensor = interpreter.getSessionOutput(mnn_session, 'output')
    output_tensor.printTensorData()

    output_data = np.array(output_tensor.getData())
    print('mnn result is:', output_data)
    print("output belong to class: {}".format(np.argmax(output_data)))

    end = cv2.getTickCount()
    during = (end - start) / cv2.getTickFrequency()
    print("avg_time:", during)


if __name__ == '__main__':
    main()

用c++进行mnn重写测试,结果一致,这样就可以编库了:

// mnn_test.cpp : Defines the entry point for the console application.
#include "stdafx.h"
#include <iostream>
#include <opencv2/opencv.hpp>
#include <MNN/Interpreter.hpp>
#include <MNN/MNNDefine.h>
#include <MNN/Tensor.hpp>
#include <MNN/ImageProcess.hpp>
#include <memory>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <vector>

#define IMAGE_VERIFY_SIZE 32
#define CLASSES_SIZE 2
#define INPUT_NAME "input"
#define OUTPUT_NAME "output"

// Return a copy of `img` with the first and third channels swapped
// (BGR -> RGB). `img` is expected to be a non-empty 3-channel image.
cv::Mat BGRToRGB(cv::Mat img)
{
    cv::Mat image;
    cv::cvtColor(img, image, cv::COLOR_BGR2RGB);
    return image;
}

// Usage: <exe> mnn_model_path image_path
//
// Loads the MNN model, feeds it the image resized to
// IMAGE_VERIFY_SIZE x IMAGE_VERIFY_SIZE and normalised per channel, then
// prints the inference time, the raw scores, the arg-max class id and its
// softmax probability. Returns 0 on success, -1 on bad arguments or when the
// model or image cannot be loaded.
int main(int argc, char* argv[]) {
    // argv[1] and argv[2] are both read below, so three entries are required
    // (the original checked `argc < 2` and then read argv[2]).
    if (argc < 3) {
        printf("Usage:\n\t%s mnn_model_path image_path\n", argv[0]);
        return -1;
    }

    // create net and session
    const char *mnn_model_path = argv[1];
    const char *image_path = argv[2];
    auto mnnNet = std::shared_ptr<MNN::Interpreter>(MNN::Interpreter::createFromFile(mnn_model_path));
    if (!mnnNet) {
        fprintf(stderr, "failed to load model: %s\n", mnn_model_path);
        return -1;
    }
    MNN::ScheduleConfig netConfig;
    netConfig.type = MNN_FORWARD_CPU;
    netConfig.numThread = 4;
    auto session = mnnNet->createSession(netConfig);

    auto input = mnnNet->getSessionInput(session, INPUT_NAME);
    // Give the input an explicit 1x3xHxW shape when it reports at most four
    // elements.
    if (input->elementSize() <= 4) {
        mnnNet->resizeTensor(input, { 1, 3, IMAGE_VERIFY_SIZE, IMAGE_VERIFY_SIZE });
        mnnNet->resizeSession(session);
    }
    std::cout << "input shape: " << input->shape()[0] << " " << input->shape()[1] << " "
              << input->shape()[2] << " " << input->shape()[3] << std::endl;

    // preprocess image
    cv::Mat bgr_image = cv::imread(image_path);
    if (bgr_image.empty()) {
        fprintf(stderr, "failed to read image: %s\n", image_path);
        return -1;
    }
    cv::Mat rgb_image = BGRToRGB(bgr_image);
    cv::Mat norm_image;
    cv::resize(rgb_image, norm_image, cv::Size(IMAGE_VERIFY_SIZE, IMAGE_VERIFY_SIZE));

    // Host-side staging tensor in CAFFE (NCHW) layout.
    MNN::Tensor givenTensor(input, MNN::Tensor::CAFFE);
    auto inputData = givenTensor.host<float>();

    // Per-channel mean / std in R, G, B order (the image was converted to RGB above).
    static const double kMean[3] = { 0.485, 0.456, 0.406 };
    static const double kStd[3] = { 0.229, 0.224, 0.225 };
    const int planeSize = IMAGE_VERIFY_SIZE * IMAGE_VERIFY_SIZE;
    for (int k = 0; k < 3; k++) {
        for (int i = 0; i < norm_image.rows; i++) {
            for (int j = 0; j < norm_image.cols; j++) {
                const auto src = norm_image.at<cv::Vec3b>(i, j)[k];
                // HWC uint8 -> CHW float, scaled to [0, 1] and normalised.
                inputData[k * planeSize + i * IMAGE_VERIFY_SIZE + j] =
                    static_cast<float>((float(src) / 255.0f - kMean[k]) / kStd[k]);
            }
        }
    }
    input->copyFromHostTensor(&givenTensor);

    // run session (only the inference call is timed)
    const int64 st = cv::getTickCount();
    mnnNet->runSession(session);
    const double et = (cv::getTickCount() - st) * 1000.0 / cv::getTickFrequency();
    std::cout << " speed: " << et << " ms" << std::endl;

    // get output data
    auto output = mnnNet->getSessionOutput(session, OUTPUT_NAME);
    auto output_host = std::make_shared<MNN::Tensor>(output, MNN::Tensor::CAFFE);
    output->copyToHostTensor(output_host.get());
    auto values = output_host->host<float>();

    // post process: arg-max plus softmax probability of the winning class
    std::vector<float> output_values(values, values + CLASSES_SIZE);
    int max_index = 0;
    for (int i = 1; i < CLASSES_SIZE; i++) {
        if (output_values[i] > output_values[max_index]) max_index = i;
    }
    // Subtract the maximum before exponentiating so large scores cannot
    // overflow; the resulting probability is mathematically unchanged.
    double exp_sum = 0.0;
    for (int i = 0; i < CLASSES_SIZE; i++) {
        exp_sum += std::exp(static_cast<double>(output_values[i]) - output_values[max_index]);
    }

    std::cout << "output: " << output_values[0] << "," << output_values[1] << std::endl;
    std::cout << "id: " << max_index << std::endl;
    std::cout << "prob: " << 1.0 / exp_sum << std::endl;

    system("pause");
    return 0;
}

 

这篇关于pytorch转onnx转mnn并验证的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/506796

相关文章

Linux内核参数配置与验证详细指南

《Linux内核参数配置与验证详细指南》在Linux系统运维和性能优化中,内核参数(sysctl)的配置至关重要,本文主要来聊聊如何配置与验证这些Linux内核参数,希望对大家有一定的帮助... 目录1. 引言2. 内核参数的作用3. 如何设置内核参数3.1 临时设置(重启失效)3.2 永久设置(重启仍生效

pytorch自动求梯度autograd的实现

《pytorch自动求梯度autograd的实现》autograd是一个自动微分引擎,它可以自动计算张量的梯度,本文主要介绍了pytorch自动求梯度autograd的实现,具有一定的参考价值,感兴趣... autograd是pytorch构建神经网络的核心。在 PyTorch 中,结合以下代码例子,当你

在PyCharm中安装PyTorch、torchvision和OpenCV详解

《在PyCharm中安装PyTorch、torchvision和OpenCV详解》:本文主要介绍在PyCharm中安装PyTorch、torchvision和OpenCV方式,具有很好的参考价值,... 目录PyCharm安装PyTorch、torchvision和OpenCV安装python安装PyTor

pytorch之torch.flatten()和torch.nn.Flatten()的用法

《pytorch之torch.flatten()和torch.nn.Flatten()的用法》:本文主要介绍pytorch之torch.flatten()和torch.nn.Flatten()的用... 目录torch.flatten()和torch.nn.Flatten()的用法下面举例说明总结torch

opencv图像处理之指纹验证的实现

《opencv图像处理之指纹验证的实现》本文主要介绍了opencv图像处理之指纹验证的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学... 目录一、简介二、具体案例实现1. 图像显示函数2. 指纹验证函数3. 主函数4、运行结果三、总结一、

使用PyTorch实现手写数字识别功能

《使用PyTorch实现手写数字识别功能》在人工智能的世界里,计算机视觉是最具魅力的领域之一,通过PyTorch这一强大的深度学习框架,我们将在经典的MNIST数据集上,见证一个神经网络从零开始学会识... 目录当计算机学会“看”数字搭建开发环境MNIST数据集解析1. 认识手写数字数据库2. 数据预处理的

Pytorch微调BERT实现命名实体识别

《Pytorch微调BERT实现命名实体识别》命名实体识别(NER)是自然语言处理(NLP)中的一项关键任务,它涉及识别和分类文本中的关键实体,BERT是一种强大的语言表示模型,在各种NLP任务中显著... 目录环境准备加载预训练BERT模型准备数据集标记与对齐微调 BERT最后总结环境准备在继续之前,确

pytorch+torchvision+python版本对应及环境安装

《pytorch+torchvision+python版本对应及环境安装》本文主要介绍了pytorch+torchvision+python版本对应及环境安装,安装过程中需要注意Numpy版本的降级,... 目录一、版本对应二、安装命令(pip)1. 版本2. 安装全过程3. 命令相关解释参考文章一、版本对

从零教你安装pytorch并在pycharm中使用

《从零教你安装pytorch并在pycharm中使用》本文详细介绍了如何使用Anaconda包管理工具创建虚拟环境,并安装CUDA加速平台和PyTorch库,同时在PyCharm中配置和使用PyTor... 目录背景介绍安装Anaconda安装CUDA安装pytorch报错解决——fbgemm.dll连接p

pycharm远程连接服务器运行pytorch的过程详解

《pycharm远程连接服务器运行pytorch的过程详解》:本文主要介绍在Linux环境下使用Anaconda管理不同版本的Python环境,并通过PyCharm远程连接服务器来运行PyTorc... 目录linux部署pytorch背景介绍Anaconda安装Linux安装pytorch虚拟环境安装cu