爆改YOLOv8|利用yolov10的PSA注意力机制改进yolov8-高效涨点

2024-09-04 23:44

本文主要是介绍爆改YOLOv8|利用yolov10的PSA注意力机制改进yolov8-高效涨点,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1,本文介绍

PSA是一种改进的自注意力机制,旨在提升模型的效率和准确性。传统的自注意力机制需要计算所有位置对之间的注意力,这会导致计算复杂度高和训练时间长。PSA通过引入极化因子来减少需要计算的注意力对的数量,从而降低计算负担。极化因子是一个向量,通过与每个位置的向量点积,确定哪些位置需要计算注意力。这种方法可以在保持模型准确度的前提下,显著减少计算量,从而提升自注意力机制的效率。

关于PSA 的详细介绍可以看论文:https://arxiv.org/pdf/2405.14458

本文将讲解如何将PSA 融合进yolov8

话不多说,上代码!

2, 将PSA融合进yolov8

2.1 步骤一

找到如下的目录'ultralytics/nn/modules',然后在这个目录下创建一个PSA.py文件,文件名字可以根据你自己的习惯起,然后将PSA的核心代码复制进去


import torch
import torch.nn as nn__all__ = ['PSA']def autopad(k, p=None, d=1):  # kernel, padding, dilation"""Pad to 'same' shape outputs."""if d > 1:k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k]  # actual kernel-sizeif p is None:p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-padreturn pclass Conv(nn.Module):"""Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)."""default_act = nn.SiLU()  # default activationdef __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):"""Initialize Conv layer with given arguments including activation."""super().__init__()self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)self.bn = nn.BatchNorm2d(c2)self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()def forward(self, x):"""Apply convolution, batch normalization and activation to input tensor."""return self.act(self.bn(self.conv(x)))def forward_fuse(self, x):"""Perform transposed convolution of 2D data."""return self.act(self.conv(x))class Attention(nn.Module):def __init__(self, dim, num_heads=8,attn_ratio=0.5):super().__init__()self.num_heads = num_headsself.head_dim = dim // num_headsself.key_dim = int(self.head_dim * attn_ratio)self.scale = self.key_dim ** -0.5nh_kd = nh_kd = self.key_dim * num_headsh = dim + nh_kd * 2self.qkv = Conv(dim, h, 1, act=False)self.proj = Conv(dim, dim, 1, act=False)self.pe = Conv(dim, dim, 3, 1, g=dim, act=False)def forward(self, x):B, _, H, W = x.shapeN = H * Wqkv = self.qkv(x)q, k, v = qkv.view(B, self.num_heads, -1, N).split([self.key_dim, self.key_dim, self.head_dim], dim=2)attn = ((q.transpose(-2, -1) @ k) * self.scale)attn = attn.softmax(dim=-1)x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W) + self.pe(v.reshape(B, -1, H, W))x = self.proj(x)return xclass PSA(nn.Module):def __init__(self, c1, c2, e=0.5):super().__init__()assert (c1 == c2)self.c = int(c1 * e)self.cv1 = Conv(c1, 2 * self.c, 1, 1)self.cv2 = Conv(2 * self.c, c1, 1)self.attn = Attention(self.c, attn_ratio=0.5, num_heads=self.c // 64)self.ffn = nn.Sequential(Conv(self.c, self.c * 2, 1),Conv(self.c * 2, self.c, 1, act=False))def forward(self, x):a, b = self.cv1(x).split((self.c, self.c), dim=1)b = b + self.attn(b)b = b + self.ffn(b)return self.cv2(torch.cat((a, b), 1))

2.2 步骤二

在task.py导入我们的模块

from .modules.PSA import PSA

2.3 步骤三

在task.py的parse_model方法里面注册我们的模块

到此注册成功,复制后面的yaml文件直接运行即可

yaml文件


# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'# [depth, width, max_channels]n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPss: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPsm: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPsl: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPsx: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOP# YOLOv8.0n backbone
backbone:# [from, repeats, module, args]- [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2- [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4- [-1, 3, C2f, [128, True]]- [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8- [-1, 6, C2f, [256, True]]- [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16- [-1, 6, C2f, [512, True]]- [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32- [-1, 3, C2f, [1024, True]]- [-1, 1, SPPF, [1024, 5]]  # 9- [-1, 1, PSA, [1024]]  # 10# YOLOv8.0n head
head:- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 6], 1, Concat, [1]]  # cat backbone P4- [-1, 3, C2f, [512]]  # 13- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 4], 1, Concat, [1]]  # cat backbone P3- [-1, 3, C2f, [256]]  # 16 (P3/8-small)- [-1, 1, Conv, [256, 3, 2]]- [[-1, 13], 1, Concat, [1]]  # cat head P4- [-1, 3, C2f, [512]]  # 19 (P4/16-medium)- [-1, 1, Conv, [512, 3, 2]]- [[-1, 10], 1, Concat, [1]]  # cat head P5- [-1, 3, C2f, [1024]]  # 22 (P5/32-large)- [[16, 19, 22], 1, Detect, [nc]]  # Detect(P3, P4, P5)

# 关于PSA添加的位置可以自行调试,针对不同数据集位置不同,效果不同

不知不觉已经看完了哦,动动小手留个点赞吧--_--

这篇关于爆改YOLOv8|利用yolov10的PSA注意力机制改进yolov8-高效涨点的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1137370

相关文章

Python实现高效地读写大型文件

《Python实现高效地读写大型文件》Python如何读写的是大型文件,有没有什么方法来提高效率呢,这篇文章就来和大家聊聊如何在Python中高效地读写大型文件,需要的可以了解下... 目录一、逐行读取大型文件二、分块读取大型文件三、使用 mmap 模块进行内存映射文件操作(适用于大文件)四、使用 pand

一文带你理解Python中import机制与importlib的妙用

《一文带你理解Python中import机制与importlib的妙用》在Python编程的世界里,import语句是开发者最常用的工具之一,它就像一把钥匙,打开了通往各种功能和库的大门,下面就跟随小... 目录一、python import机制概述1.1 import语句的基本用法1.2 模块缓存机制1.

高效管理你的Linux系统: Debian操作系统常用命令指南

《高效管理你的Linux系统:Debian操作系统常用命令指南》在Debian操作系统中,了解和掌握常用命令对于提高工作效率和系统管理至关重要,本文将详细介绍Debian的常用命令,帮助读者更好地使... Debian是一个流行的linux发行版,它以其稳定性、强大的软件包管理和丰富的社区资源而闻名。在使用

Redis主从/哨兵机制原理分析

《Redis主从/哨兵机制原理分析》本文介绍了Redis的主从复制和哨兵机制,主从复制实现了数据的热备份和负载均衡,而哨兵机制可以监控Redis集群,实现自动故障转移,哨兵机制通过监控、下线、选举和故... 目录一、主从复制1.1 什么是主从复制1.2 主从复制的作用1.3 主从复制原理1.3.1 全量复制

Redis缓存问题与缓存更新机制详解

《Redis缓存问题与缓存更新机制详解》本文主要介绍了缓存问题及其解决方案,包括缓存穿透、缓存击穿、缓存雪崩等问题的成因以及相应的预防和解决方法,同时,还详细探讨了缓存更新机制,包括不同情况下的缓存更... 目录一、缓存问题1.1 缓存穿透1.1.1 问题来源1.1.2 解决方案1.2 缓存击穿1.2.1

Java如何通过反射机制获取数据类对象的属性及方法

《Java如何通过反射机制获取数据类对象的属性及方法》文章介绍了如何使用Java反射机制获取类对象的所有属性及其对应的get、set方法,以及如何通过反射机制实现类对象的实例化,感兴趣的朋友跟随小编一... 目录一、通过反射机制获取类对象的所有属性以及相应的get、set方法1.遍历类对象的所有属性2.获取

MySQL中的锁和MVCC机制解读

《MySQL中的锁和MVCC机制解读》MySQL事务、锁和MVCC机制是确保数据库操作原子性、一致性和隔离性的关键,事务必须遵循ACID原则,锁的类型包括表级锁、行级锁和意向锁,MVCC通过非锁定读和... 目录mysql的锁和MVCC机制事务的概念与ACID特性锁的类型及其工作机制锁的粒度与性能影响多版本

Spring使用@Retryable实现自动重试机制

《Spring使用@Retryable实现自动重试机制》在微服务架构中,服务之间的调用可能会因为一些暂时性的错误而失败,例如网络波动、数据库连接超时或第三方服务不可用等,在本文中,我们将介绍如何在Sp... 目录引言1. 什么是 @Retryable?2. 如何在 Spring 中使用 @Retryable

JVM 的类初始化机制

前言 当你在 Java 程序中new对象时,有没有考虑过 JVM 是如何把静态的字节码(byte code)转化为运行时对象的呢,这个问题看似简单,但清楚的同学相信也不会太多,这篇文章首先介绍 JVM 类初始化的机制,然后给出几个易出错的实例来分析,帮助大家更好理解这个知识点。 JVM 将字节码转化为运行时对象分为三个阶段,分别是:loading 、Linking、initialization

高效+灵活,万博智云全球发布AWS无代理跨云容灾方案!

摘要 近日,万博智云推出了基于AWS的无代理跨云容灾解决方案,并与拉丁美洲,中东,亚洲的合作伙伴面向全球开展了联合发布。这一方案以AWS应用环境为基础,将HyperBDR平台的高效、灵活和成本效益优势与无代理功能相结合,为全球企业带来实现了更便捷、经济的数据保护。 一、全球联合发布 9月2日,万博智云CEO Michael Wong在线上平台发布AWS无代理跨云容灾解决方案的阐述视频,介绍了