本文主要是介绍基于 YOLO V8 Cls Fine-Tuning 训练花卉图像分类模型,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
一、YOLO V8 Cls
在本专栏的前面文章中,我们基于 YOLO V8 Fine-Tuning
训练了自定义的目标检测模型,以及 15
点人脸关键点检测模型,从结果中可以看出,在模型如此轻量的同时还拥有者如此好的效果。本文基于 yolov8n-cls
模型实验 Fine-Tuning
训练花卉图像分类模型。
YOLO V8
的细节可以参考下面官方的介绍:
https://docs.ultralytics.com/zh/models/yolov8/#citations-and-acknowledgements
本文依旧使用 ultralytics
框架进行训练和测试,其中 ultralytics
和 pytorch
的版本如下:
torch==1.13.1+cu116
ultralytics==8.1.37
YOLO V8 Cls
调用示例如下:
测试图像:
这里使用 yolov8n-pose
模型,如果模型不存在会自动下载:
from ultralytics import YOLO
from matplotlib import pyplot as plt
import requests
plt.rcParams['font.sans-serif'] = ['SimHei']def load_labels():url = 'https://storage.googleapis.com/download.tensorflow.org/data/imagenet_class_index.json'response = requests.get(url)if response.status_code == 200:return response.json()return []def main():# 加载模型model = YOLO('yolov8n-cls.pt')# 加载 ImageNet 类别标签labels = load_labels()# 图像预测image = plt.imread('./img/cat.jpg')results = model.predict(image, device='cpu')probs = results[0].probs# top1的分类值classify = probs.top1# top1的置信度conf = probs.top1confplt.imshow(image)plt.title('预测结果:' + labels[str(classify)][1] + ',概率:' + str("%.2f" % (conf) + '%'))plt.show()if __name__ == '__main__':main()
二、数据集准备和拆分
本篇文章使用数千张花卉照片作为数据集,共分为5个分类:雏菊(daisy)、蒲公英(dandelion)、玫瑰(roses)、向日葵(sunflowers)、郁金香(tulips)
。
数据集下载地址:
https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
每个分类的图片放在单独的子目录下,下载完毕后解压可以看到如下所示:
图像示例如下:
如果想要训练自己的图片,也可以像这样的方式,将每个类别的图片放在相应的子目录下。
ultralytics
对数据格式的解释如下:
https://docs.ultralytics.com/zh/datasets/classify/
下面拆分 80%
的图像为训练集,20%
为验证集:
import os
from tqdm import tqdm
import shutil# 数据集位置
image_path = "./data/flower_photos"
# 训练集的比例
training_ratio = 0.8
# 拆分后数据的位置
train_dir = "train_data"def split_data():for classify in os.listdir(image_path):if not os.path.isdir(os.path.join(image_path, classify)):continue# 创建目录os.makedirs(os.path.join(train_dir, "train", classify), exist_ok=True)os.makedirs(os.path.join(train_dir, "val", classify), exist_ok=True)images = os.listdir(os.path.join(image_path, classify))# 数据拆分train_size = int(len(images) * training_ratio)train_images = images[:train_size]val_images = images[train_size:]# 训练数据for image in tqdm(train_images, desc="Dispose Train Classify: " + classify):shutil.copy(os.path.join(image_path, classify, image), os.path.join(train_dir, "train", classify, image))# 验证数据for image in tqdm(val_images, desc="Dispose Val Classify: " + classify):shutil.copy(os.path.join(image_path, classify, image), os.path.join(train_dir, "val", classify, image))if __name__ == '__main__':split_data()
处理后的结构:
三、训练
使用 ultralytics
框架训练非常简单,仅需三行代码即可完成训练:
from ultralytics import YOLO# Load a model
model = YOLO('yolov8n-cls.pt') # load a pretrained model (recommended for training)# Train the model
model.train(data='./train_data', # 数据集地址epochs=30, # 训练的周期imgsz=640, # 图像的大小device=[0], # 设备,如果是 cpu 则是 device='cpu'workers=0,lr0=1e-4, # 学习率batch=25, # 批次大小amp=False # 是否启用混合精度训练
)
运行后可以看到打印的网络结构:
训练中:
训练结束后可以在 runs
目录下面看到训练的结果:
训练时 loss
的变化图:
四、模型预测
from ultralytics import YOLO
import os
import random
import matplotlib.pyplot as pltplt.rcParams['font.sans-serif'] = ['SimHei']def get_data(val_path):images = []for classify in os.listdir(val_path):for image in os.listdir(os.path.join(val_path, classify)):image_path = os.path.join(val_path, classify, image)images.append({"path": image_path,"label": classify})# 打乱数据random.shuffle(images)# 9个拆分为一组return [images[i:i + 9] for i in range(0, len(images), 9)]def main():val_path = "./train_data/val"# 分类标签class_names = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']# 加载模型model = YOLO('runs/classify/train/weights/best.pt')# 获取数据集datas = get_data(val_path)for item in datas:plt.figure(figsize=(10, 10))for i, image in enumerate(item):plt.subplot(3, 3, i + 1)image_path = image["path"]label = image["label"]img = plt.imread(image_path)# 预测results = model.predict(img, device='cpu')probs = results[0].probsclassify = probs.top1conf = probs.top1conf# 展示plt.imshow(img)plt.title('预测结果:' + class_names[classify] + ',概率:' + str("%.2f" % conf) + ',真实结果:' + label)plt.axis('off')plt.show()if __name__ == '__main__':main()
预测效果:
这篇关于基于 YOLO V8 Cls Fine-Tuning 训练花卉图像分类模型的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!