yolact导出onnx

2024-08-24 22:04
文章标签 导出 onnx yolact

本文主要是介绍yolact导出onnx,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

github上有yolact-onnx仓库可以导出不带有nms和两个分支的矩阵相乘的部分,但是无法导出带有nms的部分。

一、导出的代码

注意opset版本最低要求14, torch.onnx.export的输入要求是真实图片,否则后续推理会报错。

import torch
import cv2from yolact import Yolactdef export_onnx_model(saved_onnx_model):"""将模型导出为onnx格式, opset版本最低要设置14, 11的话有个算子不能导出"""device = torch.device('cpu')net = Yolact()net.load_weights('weights/yolact_base_54_800000.pth')net.to(device)net.eval()img = cv2.imread('images/test/4.jpg')img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)img = cv2.resize(img, (550, 550))img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).float().to(device)# img = torch.randn(1, 3, 550, 550).to(device)torch.onnx.export(net, img, saved_onnx_model, verbose=True, opset_version=17)export_onnx_model('yolact.onnx')

二、Bug及解决

  1. FPN
RuntimeError: Tried to trace <__torch__.yolact.FPN object at 0x6db3f50> but it is not part of the active trace. Modules that are called during a trace must be registered as submodules of the thing being traced.

解决:在yolact.py 第25行将 use_jit 设置为False。

use_jit = torch.cuda.device_count() <= 1
use_jit = False   
if not use_jit:print('Multiple GPUs detected! Turning off JIT.')
  1. numpy
RuntimeError: Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead.

解决:detection.py第208行改为如下:

# preds = torch.cat([boxes[conf_mask], cls_scores[:, None]], dim=1).cpu().numpy()
preds = torch.cat([boxes[conf_mask], cls_scores[:, None]], dim=1).cpu().detach().numpy()

第30行将use_fast_nms改为True:

self.use_fast_nms = True
# self.use_fast_nms = False
  1. tupels
RuntimeError: Only tuples, lists and Variables are supported as JIT inputs/outputs. Dictionaries and strings are also accepted, but their usage is not recommended. Here, received an input of unsupported type: Yolact

解决:将detection.py 第76行改为如下。这里的net是后处理用来eval_mask的,但是那个if语句是False,相当于返回去也没用上,这里直接不返回也没关系。不然输出的tuple无法转为onnx。

# out.append({'detection': result, 'net': net})
out.append(result)

output_utils.py中注释掉net

dets = det_output[batch_idx]
# net = dets['net']
dets = dets['detection']
...
# if cfg.use_maskiou:
#     with timer.env('maskiou_net'):                
#         with torch.no_grad():
#             maskiou_p = net.maskiou_net(masks.unsqueeze(1))
#             maskiou_p = torch.gather(maskiou_p, dim=1, index=classes.unsqueeze(1)).squeeze(1)
#             if cfg.rescore_mask:
#                 if cfg.rescore_bbox:
#                     scores = scores * maskiou_p
#                 else:
#                     scores = [scores, scores * maskiou_p]
  1. output_utils.py

第74行增加这一句,输出增加 ‘priors’,不然后续推理出错:

if result is not None and proto_data is not None:
result['proto'] = proto_data[batch_idx]
result['priors'] = prior_data[batch_idx]   # add, important

第62行注释掉,这里默认为False,不会执行。只是为了后面我的验证代码能够正常运行:

# Test flag, do not upvote
# if cfg.mask_proto_debug:
#     np.save('scripts/proto.npy', proto_data.cpu().numpy())# if visualize_lincomb:
#     display_lincomb(proto_data, masks)
  1. box_utils.py
    (很重要)这里要将@torch.jit.script注释掉,否则到处的结果是错误的:
# @torch.jit.script
def decode(loc, priors, use_yolo_regressors:bool=False):

至此可以导出。

三、验证

注意修改图片路径:

import torch
import onnx
import os
import cv2
import torch
import argparse
from data import COCODetection, get_label_map, MEANS, COLORS
# from eval import parse_args
from eval import args
from layers.output_utils import postprocess
import onnxruntime as rt
import thop
from torch.profiler import profile, record_function, ProfilerActivity
from yolact import Yolact
from torch.utils.data import Dataset
from utils.augmentations import BaseTransform, FastBaseTransform, Resize
from layers import Detect
from collections import defaultdict
from data import cfg
from utils import timerdef prep_display(dets_out, img, h, w, undo_transform=True, class_color=False, mask_alpha=0.45, fps_str=''):"""Note: If undo_transform=False then im_h and im_w are allowed to be None."""# args =parse_args()img_gpu = img / 255.0h, w, _ = img.shape# 后处理 w, h = 612 612with timer.env('Post'):t = postprocess(dets_out, w, h, visualize_lincomb = args.display_lincomb,crop_masks        = args.crop,score_threshold   = args.score_threshold)idx = t[1].argsort(0, descending=True)[:args.top_k]  # 012345   args.top_k=5?# if cfg.eval_mask_branch:# Masks are drawn on the GPU, so don't copy# masks = t[3][idx]classes, scores, boxes, masks = [x[idx].cpu().numpy() for x in t[:4]]  # 5,4    最终后处理的结果masks = torch.tensor(masks)num_dets_to_consider = min(args.top_k, classes.shape[0])  # 指定要检测的最大目标数 vs 检测出来的目标个数,取最小值for j in range(num_dets_to_consider):if scores[j] < args.score_threshold:num_dets_to_consider = jbreak# Quick and dirty lambda for selecting the color for a particular index# Also keeps track of a per-gpu color cache for maximum speeddef get_color(j, on_gpu=None):global color_cachecolor_idx = (classes[j] * 5 if class_color else j * 5) % len(COLORS)if on_gpu is not None and color_idx in color_cache[on_gpu]:return color_cache[on_gpu][color_idx]else:color = COLORS[color_idx]if not undo_transform:# The image might come in as RGB or BRG, dependingcolor = (color[2], color[1], color[0])if on_gpu is not None:color = torch.Tensor(color).to(on_gpu).float() / 255.color_cache[on_gpu][color_idx] = colorreturn color# First, draw the masks on the GPU where we can do it really fast# Beware: very fast but possibly unintelligible mask-drawing code ahead# I wish I had access to OpenGL or Vulkan but alas, I guess Pytorch tensor operations will have to sufficeif args.display_masks and num_dets_to_consider > 0:# After this, mask is of size [num_dets, h, w, 1]masks = masks[:num_dets_to_consider, :, :, None]# Prepare the RGB images for each mask given their color (size [num_dets, h, w, 1])colors = torch.cat([torch.Tensor(get_color(j, on_gpu=img_gpu.device.index)).view(1, 1, 1, 3) for j in range(num_dets_to_consider)], dim=0)masks_color = masks.repeat(1, 1, 1, 3) * colors * mask_alpha  # 3,1,1,3 -->3,h,w,3# This is 1 everywhere except for 1-mask_alpha where the mask isinv_alph_masks = masks * (-mask_alpha) + 1# I did the math for this on pen and paper. This whole block should be equivalent to:#    for j in range(num_dets_to_consider):#        img_gpu = img_gpu * inv_alph_masks[j] + masks_color[j]masks_color_summand = masks_color[0]if num_dets_to_consider > 1:inv_alph_cumul = inv_alph_masks[:(num_dets_to_consider-1)].cumprod(dim=0)masks_color_cumul = masks_color[1:] * inv_alph_cumulmasks_color_summand += masks_color_cumul.sum(dim=0)img_gpu = img_gpu * inv_alph_masks.prod(dim=0) + masks_color_summand# Then draw the stuff that needs to be done on the cpu# Note, make sure this is a uint8 tensor or opencv will not anti alias text for whatever reasonimg_numpy = (img_gpu * 255).byte().cpu().numpy()if num_dets_to_consider == 0:return img_numpyif args.display_text or args.display_bboxes:for j in reversed(range(num_dets_to_consider)):x1, y1, x2, y2 = boxes[j, :]color = get_color(j)score = scores[j]if args.display_bboxes:cv2.rectangle(img_numpy, (x1, y1), (x2, y2), color, 1)if args.display_text:_class = cfg.dataset.class_names[classes[j]]text_str = '%s: %.2f' % (_class, score) if args.display_scores else _classfont_face = cv2.FONT_HERSHEY_DUPLEXfont_scale = 0.6font_thickness = 1text_w, text_h = cv2.getTextSize(text_str, font_face, font_scale, font_thickness)[0]text_pt = (x1, y1 - 3)text_color = [255, 255, 255]cv2.rectangle(img_numpy, (x1, y1), (x1 + text_w, y1 - text_h - 4), color, -1)cv2.putText(img_numpy, text_str, text_pt, font_face, font_scale, text_color, font_thickness, cv2.LINE_AA)return img_numpydef eval_onnx_with_nms(onnx_model_path):"""使用导出的onnx带有nms的模型进行推理"""print("\nRunning eval_onnx_with_nms\n")onnx_model = onnx.load(onnx_model_path)save_path = 'onnx.jpg' # 输出图片路径path = '4.jpg'        # 修改输入图片路径frame = torch.from_numpy(cv2.imread(path)).float()batch = FastBaseTransform()(frame.unsqueeze(0))# 检查模型try:onnx.checker.check_model(onnx_model)print("Model check passed.")except Exception as e:print(f"Model check failed: {e}")sess = rt.InferenceSession(onnx_model_path, providers=['CPUExecutionProvider'])input_name = sess.get_inputs()[0].nameloc_name = sess.get_outputs()[0].nameconf_name = sess.get_outputs()[1].namemask_name = sess.get_outputs()[2].namepriors_name = sess.get_outputs()[3].nameproto_name = sess.get_outputs()[4].nameproto_name2 = sess.get_outputs()[5].namewith timer.env("ONNX Runtime"):preds = sess.run([loc_name, conf_name, mask_name, priors_name, proto_name, proto_name2], {input_name: batch.cpu().detach().numpy()})# preds是一个包含100*4 array的list# """# preds是一个列表, 包含以下元素:# boxes: (N, 4) --> 100, 4# mask: (N, 32) -->100, 32# class: (N,) --> 100# score: (N,) --> 100# proto: 138,138,32 --> 138,138,32 # priors: (4) --> 4# """preds_out = [{'box': torch.tensor(preds[0]), 'mask': torch.tensor(preds[1]), 'class': torch.tensor(preds[2]), 'score': torch.tensor(preds[3]),'proto': torch.tensor(preds[4]),'priors': torch.tensor(preds[5])}]# for k,v in preds_out[0].items():#     print(k, v.shape)img_numpy = prep_display(preds_out, frame, None, None, undo_transform=False)if save_path is None:img_numpy = img_numpy[:, :, (2, 1, 0)]    cv2.imwrite(save_path, img_numpy)  # 保存图片
eval_onnx_with_nms('yolact.onnx')

最终推理结果:
在这里插入图片描述

这篇关于yolact导出onnx的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1103728

相关文章

Java导出Excel动态表头的示例详解

《Java导出Excel动态表头的示例详解》这篇文章主要为大家详细介绍了Java导出Excel动态表头的相关知识,文中的示例代码简洁易懂,具有一定的借鉴价值,有需要的小伙伴可以了解下... 目录前言一、效果展示二、代码实现1.固定头实体类2.动态头实现3.导出动态头前言本文只记录大致思路以及做法,代码不进

详解Vue如何使用xlsx库导出Excel文件

《详解Vue如何使用xlsx库导出Excel文件》第三方库xlsx提供了强大的功能来处理Excel文件,它可以简化导出Excel文件这个过程,本文将为大家详细介绍一下它的具体使用,需要的小伙伴可以了解... 目录1. 安装依赖2. 创建vue组件3. 解释代码在Vue.js项目中导出Excel文件,使用第三

Python实现将实体类列表数据导出到Excel文件

《Python实现将实体类列表数据导出到Excel文件》在数据处理和报告生成中,将实体类的列表数据导出到Excel文件是一项常见任务,Python提供了多种库来实现这一目标,下面就来跟随小编一起学习一... 目录一、环境准备二、定义实体类三、创建实体类列表四、将实体类列表转换为DataFrame五、导出Da

Python数据处理之导入导出Excel数据方式

《Python数据处理之导入导出Excel数据方式》Python是Excel数据处理的绝佳工具,通过Pandas和Openpyxl等库可以实现数据的导入、导出和自动化处理,从基础的数据读取和清洗到复杂... 目录python导入导出Excel数据开启数据之旅:为什么Python是Excel数据处理的最佳拍档

Oracle Expdp按条件导出指定表数据的方法实例

《OracleExpdp按条件导出指定表数据的方法实例》:本文主要介绍Oracle的expdp数据泵方式导出特定机构和时间范围的数据,并通过parfile文件进行条件限制和配置,文中通过代码介绍... 目录1.场景描述 2.方案分析3.实验验证 3.1 parfile文件3.2 expdp命令导出4.总结

java poi实现Excel多级表头导出方式(多级表头,复杂表头)

《javapoi实现Excel多级表头导出方式(多级表头,复杂表头)》文章介绍了使用javapoi库实现Excel多级表头导出的方法,通过主代码、合并单元格、设置表头单元格宽度、填充数据、web下载... 目录Java poi实现Excel多级表头导出(多级表头,复杂表头)上代码1.主代码2.合并单元格3.

MySQL使用mysqldump导出数据

mysql mysqldump只导出表结构或只导出数据的实现方法 备份数据库: #mysqldump 数据库名 >数据库备份名 #mysqldump -A -u用户名 -p密码 数据库名>数据库备份名 #mysqldump -d -A --add-drop-table -uroot -p >xxx.sql 1.导出结构不导出数据 mysqldump --opt -d 数据库名 -u

一步一步将PlantUML类图导出为自定义格式的XMI文件

一步一步将PlantUML类图导出为自定义格式的XMI文件 说明: 首次发表日期:2024-09-08PlantUML官网: https://plantuml.com/zh/PlantUML命令行文档: https://plantuml.com/zh/command-line#6a26f548831e6a8cPlantUML XMI文档: https://plantuml.com/zh/xmi

SpringBoot中利用EasyExcel+aop实现一个通用Excel导出功能

一、结果展示 主要功能:可以根据前端传递的参数,导出指定列、指定行 1.1 案例一 前端页面 传递参数 {"excelName": "导出用户信息1725738666946","sheetName": "导出用户信息","fieldList": [{"fieldName": "userId","fieldDesc": "用户id"},{"fieldName": "age","fieldDe

F12抓包06-4:导出metersphere脚本

metersphere是一站式的开源持续测试平台,我们可以将浏览器请求导出为HAR文件,导入到metersphere,生成接口测试。 metersphere有2种导入入口(方式),导入结果不同:         1.导入到“接口定义”:自动生成接口API和单接口case(接口自动去重;每个请求生成不同case,重复的请求生成重复的case,名称自动加数字后缀,自动与接口关联)。