YOLOv8 classify介绍

2024-09-05 16:36
文章标签 介绍 yolov8 classify

本文主要是介绍YOLOv8 classify介绍,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

      图像分类器(image classifier)的输出是单个类别标签和置信度分数。当你只需要知道图像属于哪个类别,而不需要知道该类别的目标位于何处或它们的确切形状时,图像分类非常有用。

      YOLOv8支持的预训练分类模型包括:YOLOv8n-cls、YOLOv8s-cls、YOLOv8m-cls、YOLOv8l-cls、YOLOv8x-cls。分类模型在ImageNet数据集上进行预训练,而检测、分割和Pose模型是在COCO数据集上进行预训练。

      数据集格式

      (1).对于Ultralytics YOLO分类任务,数据集必须在根目录下以特定的拆分目录(split-directory)结构进行组织,以便进行正确的训练、测试和验证(可选的)过程。此结构包括用于训练(train)和测试(test)阶段的单独目录,以及用于验证(val)的可选目录。

      (2).每个目录都应包含数据集中每个类别的一个子目录。子目录以相应的类别命名,包含该类别的所有图像。确保每个图像文件都具有唯一的名称,并以JPEG或PNG等通用格式存储。

      (3).例如CIFAR-10数据集的目录结构如下所示:

      这里使用 https://blog.csdn.net/fengbingchun/article/details/141635132 中的数据集,通过YOLOv8 classify进行train和predict:

      train代码如下:

import argparse
import colorama
from ultralytics import YOLO
import torchdef parse_args():parser = argparse.ArgumentParser(description="YOLOv8 train")parser.add_argument("--yaml", required=True, type=str, help="yaml file or datasets path(classify)")parser.add_argument("--epochs", required=True, type=int, help="number of training")parser.add_argument("--task", required=True, type=str, choices=["detect", "segment", "classify"], help="specify what kind of task")parser.add_argument("--imgsz", type=int, default=640, help="input net image size")args = parser.parse_args()return argsdef train(task, yaml, epochs, imgsz):if task == "detect":model = YOLO("yolov8n.pt") # load a pretrained model, should be a *.pt PyTorch model to run this methodelif task == "segment":model = YOLO("yolov8n-seg.pt") # load a pretrained model, should be a *.pt PyTorch model to run this methodelif task == "classify":model = YOLO("yolov8n-cls.pt") # n/s/m/l/xelse:raise ValueError(colorama.Fore.RED + f"Error: unsupported task: {task}")# petience: Training stopped early as no improvement observed in last patience epochs, use patience=0 to disable EarlyStoppingresults = model.train(data=yaml, epochs=epochs, imgsz=imgsz, patience=150, augment=True) # train the model, supported parameter reference, for example: runs/segment(detect)/train3/args.yamlmetrics = model.val() # It'll automatically evaluate the data you trained, no arguments needed, dataset and settings rememberedif task == "classify":print("Top-1 Accuracy:", metrics.top1) # top1 accuracyprint("Top-5 Accuracy:", metrics.top5) # top5 accuracymodel.export(format="onnx", opset=12, simplify=True, dynamic=False, imgsz=imgsz) # onnx, export the model, cannot specify dynamic=True, opencv does not support# model.export(format="torchscript", imgsz=imgsz) # libtorch# model.export(format="engine", imgsz=imgsz, dynamic=False, verbose=False, batch=1, workspace=2) # tensorrt fp32# model.export(format="engine", imgsz=imgsz, dynamic=False, verbose=False, batch=1, workspace=2, half=True) # tensorrt fp16# model.export(format="engine", imgsz=imgsz, dynamic=False, verbose=False, batch=1, workspace=2, int8=True, data=yaml) # tensorrt int8# model.export(format="openvino", imgsz=imgsz) # openvino fp32# model.export(format="openvino", imgsz=imgsz, half=True) # openvino fp16# model.export(format="openvino", imgsz=imgsz, int8=True, data=yaml) # openvino int8, INT8 export requires 'data' arg for calibrationif __name__ == "__main__":# python test_yolov8_train.py --yaml datasets/melon_new_detect/melon_new_detect.yaml --epochs 1000 --task detect --imgsz 640colorama.init(autoreset=True)args = parse_args()print("Runging on GPU") if torch.cuda.is_available() else print("Runting on CPU")train(args.task, args.yaml, args.epochs, args.imgsz)print(colorama.Fore.GREEN + "====== execution completed ======")

      执行结果如下图所示:因数据集很小,训练速度很快

      predict代码如下:

import colorama
import argparse
from ultralytics import YOLO
import os
import torchimport numpy as np
np.bool = np.bool_ # Fix Error: AttributeError: module 'numpy' has no attribute 'bool'. OR: downgrade numpy: pip unistall numpy; pip install numpy==1.23.1def parse_args():parser = argparse.ArgumentParser(description="YOLOv8 predict")parser.add_argument("--model", required=True, type=str, help="model file")parser.add_argument("--task", required=True, type=str, choices=["detect", "segment", "classify"], help="specify what kind of task")parser.add_argument("--dir_images", required=True, type=str, help="directory of test images")parser.add_argument("--dir_result", type=str, default="", help="directory where the image results are saved")args = parser.parse_args()return argsdef get_images(dir):# supported image formatsimg_formats = (".bmp", ".jpeg", ".jpg", ".png", ".webp")images = []for file in os.listdir(dir):if os.path.isfile(os.path.join(dir, file)):# print(file)_, extension = os.path.splitext(file)for format in img_formats:if format == extension.lower():images.append(file)breakreturn imagesdef predict(model, task, dir_images, dir_result):model = YOLO(model) # load an model, support format: *.pt, *.onnx, *.torchscript, *.engine, openvino_model# model.info() # display model information # only *.pt format supportimages = get_images(dir_images)# print("images:", images)if task == "detect" or task =="segment":os.makedirs(dir_result) #, exist_ok=True)for image in images:device = "cuda" if torch.cuda.is_available() else "cpu"results = model.predict(dir_images+"/"+image, verbose=True, device=device)# print("results:", results)if task == "detect" or task =="segment":for result in results:result.save(dir_result+"/"+image)else:print(f"class names:{results[0].names}: top5: {results[0].probs.top5}; conf:{results[0].probs.top5conf}")if __name__ == "__main__":# python test_yolov8_predict.py --model runs/detect/train10/weights/best_int8.engine --dir_images datasets/melon_new_detect/images/test --dir_result result_detect_engine_int8 --task classifycolorama.init(autoreset=True)args = parse_args()print("Runging on GPU") if torch.cuda.is_available() else print("Runting on CPU")predict(args.model, args.task, args.dir_images, args.dir_result)print(colorama.Fore.GREEN + "====== execution completed ======")

      执行结果如下图所示:top1识别率100%

      GitHub:https://github.com/fengbingchun/NN_Test

这篇关于YOLOv8 classify介绍的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1139489

相关文章

性能测试介绍

性能测试是一种测试方法,旨在评估系统、应用程序或组件在现实场景中的性能表现和可靠性。它通常用于衡量系统在不同负载条件下的响应时间、吞吐量、资源利用率、稳定性和可扩展性等关键指标。 为什么要进行性能测试 通过性能测试,可以确定系统是否能够满足预期的性能要求,找出性能瓶颈和潜在的问题,并进行优化和调整。 发现性能瓶颈:性能测试可以帮助发现系统的性能瓶颈,即系统在高负载或高并发情况下可能出现的问题

水位雨量在线监测系统概述及应用介绍

在当今社会,随着科技的飞速发展,各种智能监测系统已成为保障公共安全、促进资源管理和环境保护的重要工具。其中,水位雨量在线监测系统作为自然灾害预警、水资源管理及水利工程运行的关键技术,其重要性不言而喻。 一、水位雨量在线监测系统的基本原理 水位雨量在线监测系统主要由数据采集单元、数据传输网络、数据处理中心及用户终端四大部分构成,形成了一个完整的闭环系统。 数据采集单元:这是系统的“眼睛”,

Hadoop数据压缩使用介绍

一、压缩原则 (1)运算密集型的Job,少用压缩 (2)IO密集型的Job,多用压缩 二、压缩算法比较 三、压缩位置选择 四、压缩参数配置 1)为了支持多种压缩/解压缩算法,Hadoop引入了编码/解码器 2)要在Hadoop中启用压缩,可以配置如下参数

图神经网络模型介绍(1)

我们将图神经网络分为基于谱域的模型和基于空域的模型,并按照发展顺序详解每个类别中的重要模型。 1.1基于谱域的图神经网络         谱域上的图卷积在图学习迈向深度学习的发展历程中起到了关键的作用。本节主要介绍三个具有代表性的谱域图神经网络:谱图卷积网络、切比雪夫网络和图卷积网络。 (1)谱图卷积网络 卷积定理:函数卷积的傅里叶变换是函数傅里叶变换的乘积,即F{f*g}

C++——stack、queue的实现及deque的介绍

目录 1.stack与queue的实现 1.1stack的实现  1.2 queue的实现 2.重温vector、list、stack、queue的介绍 2.1 STL标准库中stack和queue的底层结构  3.deque的简单介绍 3.1为什么选择deque作为stack和queue的底层默认容器  3.2 STL中对stack与queue的模拟实现 ①stack模拟实现

Mysql BLOB类型介绍

BLOB类型的字段用于存储二进制数据 在MySQL中,BLOB类型,包括:TinyBlob、Blob、MediumBlob、LongBlob,这几个类型之间的唯一区别是在存储的大小不同。 TinyBlob 最大 255 Blob 最大 65K MediumBlob 最大 16M LongBlob 最大 4G

FreeRTOS-基本介绍和移植STM32

FreeRTOS-基本介绍和STM32移植 一、裸机开发和操作系统开发介绍二、任务调度和任务状态介绍2.1 任务调度2.1.1 抢占式调度2.1.2 时间片调度 2.2 任务状态 三、FreeRTOS源码和移植STM323.1 FreeRTOS源码3.2 FreeRTOS移植STM323.2.1 代码移植3.2.2 时钟中断配置 一、裸机开发和操作系统开发介绍 裸机:前后台系

nginx介绍及常用功能

什么是nginx nginx跟Apache一样,是一个web服务器(网站服务器),通过HTTP协议提供各种网络服务。 Apache:重量级的,不支持高并发的服务器。在Apache上运行数以万计的并发访问,会导致服务器消耗大量内存。操作系统对其进行进程或线程间的切换也消耗了大量的CPU资源,导致HTTP请求的平均响应速度降低。这些都决定了Apache不可能成为高性能WEB服务器  nginx:

多路转接之select(fd_set介绍,参数详细介绍),实现非阻塞式网络通信

目录 多路转接之select 引入 介绍 fd_set 函数原型 nfds readfds / writefds / exceptfds readfds  总结  fd_set操作接口  timeout timevalue 结构体 传入值 返回值 代码 注意点 -- 调用函数 select的参数填充  获取新连接 注意点 -- 通信时的调用函数 添加新fd到

火语言RPA流程组件介绍--浏览网页

🚩【组件功能】:浏览器打开指定网址或本地html文件 配置预览 配置说明 网址URL 支持T或# 默认FLOW输入项 输入需要打开的网址URL 超时时间 支持T或# 打开网页超时时间 执行后后等待时间(ms) 支持T或# 当前组件执行完成后继续等待的时间 UserAgent 支持T或# User Agent中文名为用户代理,简称 UA,它是一个特殊字符串头,使得服务器