yolact导出onnx

2024-08-24 22:04
文章标签 导出 onnx yolact

本文主要是介绍yolact导出onnx,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

github上有yolact-onnx仓库可以导出不带有nms和两个分支的矩阵相乘的部分,但是无法导出带有nms的部分。

一、导出的代码

注意opset版本最低要求14, torch.onnx.export的输入要求是真实图片,否则后续推理会报错。

import torch
import cv2from yolact import Yolactdef export_onnx_model(saved_onnx_model):"""将模型导出为onnx格式, opset版本最低要设置14, 11的话有个算子不能导出"""device = torch.device('cpu')net = Yolact()net.load_weights('weights/yolact_base_54_800000.pth')net.to(device)net.eval()img = cv2.imread('images/test/4.jpg')img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)img = cv2.resize(img, (550, 550))img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).float().to(device)# img = torch.randn(1, 3, 550, 550).to(device)torch.onnx.export(net, img, saved_onnx_model, verbose=True, opset_version=17)export_onnx_model('yolact.onnx')

二、Bug及解决

  1. FPN
RuntimeError: Tried to trace <__torch__.yolact.FPN object at 0x6db3f50> but it is not part of the active trace. Modules that are called during a trace must be registered as submodules of the thing being traced.

解决:在yolact.py 第25行将 use_jit 设置为False。

use_jit = torch.cuda.device_count() <= 1
use_jit = False   
if not use_jit:print('Multiple GPUs detected! Turning off JIT.')
  1. numpy
RuntimeError: Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead.

解决:detection.py第208行改为如下:

# preds = torch.cat([boxes[conf_mask], cls_scores[:, None]], dim=1).cpu().numpy()
preds = torch.cat([boxes[conf_mask], cls_scores[:, None]], dim=1).cpu().detach().numpy()

第30行将use_fast_nms改为True:

self.use_fast_nms = True
# self.use_fast_nms = False
  1. tupels
RuntimeError: Only tuples, lists and Variables are supported as JIT inputs/outputs. Dictionaries and strings are also accepted, but their usage is not recommended. Here, received an input of unsupported type: Yolact

解决:将detection.py 第76行改为如下。这里的net是后处理用来eval_mask的,但是那个if语句是False,相当于返回去也没用上,这里直接不返回也没关系。不然输出的tuple无法转为onnx。

# out.append({'detection': result, 'net': net})
out.append(result)

output_utils.py中注释掉net

dets = det_output[batch_idx]
# net = dets['net']
dets = dets['detection']
...
# if cfg.use_maskiou:
#     with timer.env('maskiou_net'):                
#         with torch.no_grad():
#             maskiou_p = net.maskiou_net(masks.unsqueeze(1))
#             maskiou_p = torch.gather(maskiou_p, dim=1, index=classes.unsqueeze(1)).squeeze(1)
#             if cfg.rescore_mask:
#                 if cfg.rescore_bbox:
#                     scores = scores * maskiou_p
#                 else:
#                     scores = [scores, scores * maskiou_p]
  1. output_utils.py

第74行增加这一句,输出增加 ‘priors’,不然后续推理出错:

if result is not None and proto_data is not None:
result['proto'] = proto_data[batch_idx]
result['priors'] = prior_data[batch_idx]   # add, important

第62行注释掉,这里默认为False,不会执行。只是为了后面我的验证代码能够正常运行:

# Test flag, do not upvote
# if cfg.mask_proto_debug:
#     np.save('scripts/proto.npy', proto_data.cpu().numpy())# if visualize_lincomb:
#     display_lincomb(proto_data, masks)
  1. box_utils.py
    (很重要)这里要将@torch.jit.script注释掉,否则到处的结果是错误的:
# @torch.jit.script
def decode(loc, priors, use_yolo_regressors:bool=False):

至此可以导出。

三、验证

注意修改图片路径:

import torch
import onnx
import os
import cv2
import torch
import argparse
from data import COCODetection, get_label_map, MEANS, COLORS
# from eval import parse_args
from eval import args
from layers.output_utils import postprocess
import onnxruntime as rt
import thop
from torch.profiler import profile, record_function, ProfilerActivity
from yolact import Yolact
from torch.utils.data import Dataset
from utils.augmentations import BaseTransform, FastBaseTransform, Resize
from layers import Detect
from collections import defaultdict
from data import cfg
from utils import timerdef prep_display(dets_out, img, h, w, undo_transform=True, class_color=False, mask_alpha=0.45, fps_str=''):"""Note: If undo_transform=False then im_h and im_w are allowed to be None."""# args =parse_args()img_gpu = img / 255.0h, w, _ = img.shape# 后处理 w, h = 612 612with timer.env('Post'):t = postprocess(dets_out, w, h, visualize_lincomb = args.display_lincomb,crop_masks        = args.crop,score_threshold   = args.score_threshold)idx = t[1].argsort(0, descending=True)[:args.top_k]  # 012345   args.top_k=5?# if cfg.eval_mask_branch:# Masks are drawn on the GPU, so don't copy# masks = t[3][idx]classes, scores, boxes, masks = [x[idx].cpu().numpy() for x in t[:4]]  # 5,4    最终后处理的结果masks = torch.tensor(masks)num_dets_to_consider = min(args.top_k, classes.shape[0])  # 指定要检测的最大目标数 vs 检测出来的目标个数,取最小值for j in range(num_dets_to_consider):if scores[j] < args.score_threshold:num_dets_to_consider = jbreak# Quick and dirty lambda for selecting the color for a particular index# Also keeps track of a per-gpu color cache for maximum speeddef get_color(j, on_gpu=None):global color_cachecolor_idx = (classes[j] * 5 if class_color else j * 5) % len(COLORS)if on_gpu is not None and color_idx in color_cache[on_gpu]:return color_cache[on_gpu][color_idx]else:color = COLORS[color_idx]if not undo_transform:# The image might come in as RGB or BRG, dependingcolor = (color[2], color[1], color[0])if on_gpu is not None:color = torch.Tensor(color).to(on_gpu).float() / 255.color_cache[on_gpu][color_idx] = colorreturn color# First, draw the masks on the GPU where we can do it really fast# Beware: very fast but possibly unintelligible mask-drawing code ahead# I wish I had access to OpenGL or Vulkan but alas, I guess Pytorch tensor operations will have to sufficeif args.display_masks and num_dets_to_consider > 0:# After this, mask is of size [num_dets, h, w, 1]masks = masks[:num_dets_to_consider, :, :, None]# Prepare the RGB images for each mask given their color (size [num_dets, h, w, 1])colors = torch.cat([torch.Tensor(get_color(j, on_gpu=img_gpu.device.index)).view(1, 1, 1, 3) for j in range(num_dets_to_consider)], dim=0)masks_color = masks.repeat(1, 1, 1, 3) * colors * mask_alpha  # 3,1,1,3 -->3,h,w,3# This is 1 everywhere except for 1-mask_alpha where the mask isinv_alph_masks = masks * (-mask_alpha) + 1# I did the math for this on pen and paper. This whole block should be equivalent to:#    for j in range(num_dets_to_consider):#        img_gpu = img_gpu * inv_alph_masks[j] + masks_color[j]masks_color_summand = masks_color[0]if num_dets_to_consider > 1:inv_alph_cumul = inv_alph_masks[:(num_dets_to_consider-1)].cumprod(dim=0)masks_color_cumul = masks_color[1:] * inv_alph_cumulmasks_color_summand += masks_color_cumul.sum(dim=0)img_gpu = img_gpu * inv_alph_masks.prod(dim=0) + masks_color_summand# Then draw the stuff that needs to be done on the cpu# Note, make sure this is a uint8 tensor or opencv will not anti alias text for whatever reasonimg_numpy = (img_gpu * 255).byte().cpu().numpy()if num_dets_to_consider == 0:return img_numpyif args.display_text or args.display_bboxes:for j in reversed(range(num_dets_to_consider)):x1, y1, x2, y2 = boxes[j, :]color = get_color(j)score = scores[j]if args.display_bboxes:cv2.rectangle(img_numpy, (x1, y1), (x2, y2), color, 1)if args.display_text:_class = cfg.dataset.class_names[classes[j]]text_str = '%s: %.2f' % (_class, score) if args.display_scores else _classfont_face = cv2.FONT_HERSHEY_DUPLEXfont_scale = 0.6font_thickness = 1text_w, text_h = cv2.getTextSize(text_str, font_face, font_scale, font_thickness)[0]text_pt = (x1, y1 - 3)text_color = [255, 255, 255]cv2.rectangle(img_numpy, (x1, y1), (x1 + text_w, y1 - text_h - 4), color, -1)cv2.putText(img_numpy, text_str, text_pt, font_face, font_scale, text_color, font_thickness, cv2.LINE_AA)return img_numpydef eval_onnx_with_nms(onnx_model_path):"""使用导出的onnx带有nms的模型进行推理"""print("\nRunning eval_onnx_with_nms\n")onnx_model = onnx.load(onnx_model_path)save_path = 'onnx.jpg' # 输出图片路径path = '4.jpg'        # 修改输入图片路径frame = torch.from_numpy(cv2.imread(path)).float()batch = FastBaseTransform()(frame.unsqueeze(0))# 检查模型try:onnx.checker.check_model(onnx_model)print("Model check passed.")except Exception as e:print(f"Model check failed: {e}")sess = rt.InferenceSession(onnx_model_path, providers=['CPUExecutionProvider'])input_name = sess.get_inputs()[0].nameloc_name = sess.get_outputs()[0].nameconf_name = sess.get_outputs()[1].namemask_name = sess.get_outputs()[2].namepriors_name = sess.get_outputs()[3].nameproto_name = sess.get_outputs()[4].nameproto_name2 = sess.get_outputs()[5].namewith timer.env("ONNX Runtime"):preds = sess.run([loc_name, conf_name, mask_name, priors_name, proto_name, proto_name2], {input_name: batch.cpu().detach().numpy()})# preds是一个包含100*4 array的list# """# preds是一个列表, 包含以下元素:# boxes: (N, 4) --> 100, 4# mask: (N, 32) -->100, 32# class: (N,) --> 100# score: (N,) --> 100# proto: 138,138,32 --> 138,138,32 # priors: (4) --> 4# """preds_out = [{'box': torch.tensor(preds[0]), 'mask': torch.tensor(preds[1]), 'class': torch.tensor(preds[2]), 'score': torch.tensor(preds[3]),'proto': torch.tensor(preds[4]),'priors': torch.tensor(preds[5])}]# for k,v in preds_out[0].items():#     print(k, v.shape)img_numpy = prep_display(preds_out, frame, None, None, undo_transform=False)if save_path is None:img_numpy = img_numpy[:, :, (2, 1, 0)]    cv2.imwrite(save_path, img_numpy)  # 保存图片
eval_onnx_with_nms('yolact.onnx')

最终推理结果:
在这里插入图片描述

这篇关于yolact导出onnx的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1103728

相关文章

MySQL使用mysqldump导出数据

mysql mysqldump只导出表结构或只导出数据的实现方法 备份数据库: #mysqldump 数据库名 >数据库备份名 #mysqldump -A -u用户名 -p密码 数据库名>数据库备份名 #mysqldump -d -A --add-drop-table -uroot -p >xxx.sql 1.导出结构不导出数据 mysqldump --opt -d 数据库名 -u

一步一步将PlantUML类图导出为自定义格式的XMI文件

一步一步将PlantUML类图导出为自定义格式的XMI文件 说明: 首次发表日期:2024-09-08PlantUML官网: https://plantuml.com/zh/PlantUML命令行文档: https://plantuml.com/zh/command-line#6a26f548831e6a8cPlantUML XMI文档: https://plantuml.com/zh/xmi

SpringBoot中利用EasyExcel+aop实现一个通用Excel导出功能

一、结果展示 主要功能:可以根据前端传递的参数,导出指定列、指定行 1.1 案例一 前端页面 传递参数 {"excelName": "导出用户信息1725738666946","sheetName": "导出用户信息","fieldList": [{"fieldName": "userId","fieldDesc": "用户id"},{"fieldName": "age","fieldDe

F12抓包06-4:导出metersphere脚本

metersphere是一站式的开源持续测试平台,我们可以将浏览器请求导出为HAR文件,导入到metersphere,生成接口测试。 metersphere有2种导入入口(方式),导入结果不同:         1.导入到“接口定义”:自动生成接口API和单接口case(接口自动去重;每个请求生成不同case,重复的请求生成重复的case,名称自动加数字后缀,自动与接口关联)。

mysql导出导入数据和修改登录密码

导出表结构: mysqldump -uroot -ppassword -d dbname tablename>db.sql; 导出表数据: mysqldump -t dbname -uroot -ppassword > db.sql 导出表结构和数据(不加-d): mysqldump -uroot -ppassword dbname tablename > db.sql;

.Net Mvc-导出PDF-思路方案

效果图: 导语:     在我们做项目的过程中,经常会遇到一些服务性的需求,感到特别困扰,明明实用的价值不高,但是还是得实现;     因此小客在这里整理一下自己导出PDF的一些思路,供大家参考。     网上有很多导出PDF运用到的插件,大家也可以看看其他插件的使用,学习学习; 提要:     这里我使用的是-iTextSharp,供大家参考参考,借鉴方案,完善思路,补充自己,一起学习

.net MVC 导出Word--思路详解

序言:          一般在项目的开发过程中,总会接收到一个个需求,其中将数据转换成Work来下载,是一个很常见的需求;          那么,我们改如何处理这种需求,并输出实现呢?          在做的过程中,去思考 1、第一步:首先确认,Work的存在位置,并创建字符输出路:             //在的项目中创建一个存储work的文件夹             string

yolov8 pt转onnx

第一步: 安装onnx pip install --upgrade onnx 第二步: 将以下代码创建、拷贝到yolov8根目录下。具体代码test.py: from ultralytics import YOLO# Load a modelmodel = YOLO('yolov8n.pt') # load an official model# Export the model

如何将Product依赖的LibraryModule导出成jar

在Android Studio新建Module时可以选择创建的module是工程module还是Android Library。 或者可以在工程module中的build.gradle文件中将 apply plugin: 'com.android.application'改为apply plugin: 'com.android.library' 同时将applicati

Java 导出数据到Excel中(详细代码)

前言 平时开发中,经常会用到导入导出,绝大部分是excel表格,所以开发对office的处理需要熟悉的。office的处理上我认为还是C#最好,功能最全,基本什么 功能都能实现。毕竟一家的东西,其它像java,c++,都有解决方案,下面说java如何处理的excel的。使用的是Apache POI,感觉是java处理excel中最好的。 先看结果: Java实现代码 1.pom 引包