爆改YOLOv8|利用全新的聚焦式线性注意力模块Focused Linear Attention 改进yolov8(v1)

本文主要是介绍爆改YOLOv8|利用全新的聚焦式线性注意力模块Focused Linear Attention 改进yolov8(v1),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1,本文介绍

全新的聚焦线性注意力模块(Focused Linear Attention)是一种旨在提高计算效率和准确性的注意力机制。传统的自注意力机制在处理长序列数据时通常计算复杂度较高,限制了其在大规模数据上的应用。聚焦线性注意力模块则通过优化注意力计算的方式,显著降低了计算复杂度。

核心特点:

  1. 线性时间复杂度:与传统的自注意力机制不同,聚焦线性注意力模块采用了线性时间复杂度的计算方法,这使得处理长序列数据时更加高效。

  2. 聚焦机制:该模块专注于关键的上下文信息,通过聚焦策略来提高注意力计算的准确性,同时减少不必要的计算开销。

  3. 改进的性能:通过优化注意力计算,聚焦线性注意力模块能够在保持高性能的同时,显著提升模型的计算效率,尤其适用于大规模数据处理。

应用场景:

  • 自然语言处理:在长文本或大规模语料库的处理上,聚焦线性注意力模块能够提供更高的效率和更低的延迟。
  • 计算机视觉:在处理高分辨率图像或视频数据时,能够加速计算过程,提升模型的实时性。

总体而言,聚焦线性注意力模块为处理大规模和长序列数据提供了一种高效且精确的解决方案,适用于各种需要高效注意力机制的应用场景。

关于Focused Linear Attention的详细介绍可以看论文:https://arxiv.org/pdf/2308.00442

本文将讲解如何将Focused Linear Attention融合进yolov8

话不多说,上代码!

2, 将Focused Linear Attention融合进yolov8


2.1 步骤一

找到如下的目录'ultralytics/nn/modules',然后在这个目录下创建一个FLA.py文件,文件名字可以根据你自己的习惯起,然后将Focused Linear Attention的核心代码复制进去

import torch
import torch.nn as nn
from einops import rearrangeclass FocusedLinearAttention(nn.Module):r""" Window based multi-head self attention (W-MSA) module with relative position bias.It supports both of shifted and non-shifted window.Args:dim (int): Number of input channels.window_size (tuple[int]): The height and width of the window.num_heads (int): Number of attention heads.qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: Trueqk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if setattn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0proj_drop (float, optional): Dropout ratio of output. Default: 0.0"""def __init__(self, dim, window_size=[20, 20], num_heads=8, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.,focusing_factor=3, kernel_size=5):super().__init__()device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")self.dim = dimself.num_heads = num_headshead_dim = dim // num_headsself.focusing_factor = focusing_factorself.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)self.attn_drop = nn.Dropout(attn_drop)self.proj = nn.Linear(dim, dim)self.proj_drop = nn.Dropout(proj_drop)self.window_size = window_sizeself.positional_encoding = nn.Parameter(torch.zeros(size=(1, window_size[0] * window_size[1], dim)))self.softmax = nn.Softmax(dim=-1)self.dwc = nn.Conv2d(in_channels=head_dim, out_channels=head_dim, kernel_size=kernel_size,groups=head_dim, padding=kernel_size // 2)self.scale = nn.Parameter(torch.zeros(size=(1, 1, dim)))def forward(self, x, mask=None):"""Args:x: input features with shape of (num_windows*B, N, C)mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None"""# flatten: [B, C, H, W] -> [B, C, HW]# transpose: [B, C, HW] -> [B, HW, C]x = x.flatten(2).transpose(1, 2)B, N, C = x.shapeqkv = self.qkv(x).reshape(B, N, 3, C).permute(2, 0, 1, 3)q, k, v = qkv.unbind(0)k = k + self.positional_encoding[:, :k.shape[1], :]focusing_factor = self.focusing_factorkernel_function = nn.ReLU()q = kernel_function(q) + 1e-6k = kernel_function(k) + 1e-6scale = nn.Softplus()(self.scale)q = q / scalek = k / scaleq_norm = q.norm(dim=-1, keepdim=True)k_norm = k.norm(dim=-1, keepdim=True)if float(focusing_factor) <= 6:q = q ** focusing_factork = k ** focusing_factorelse:q = (q / q.max(dim=-1, keepdim=True)[0]) ** focusing_factork = (k / k.max(dim=-1, keepdim=True)[0]) ** focusing_factorq = (q / q.norm(dim=-1, keepdim=True)) * q_normk = (k / k.norm(dim=-1, keepdim=True)) * k_normq, k, v = (rearrange(x, "b n (h c) -> (b h) n c", h=self.num_heads) for x in [q, k, v])i, j, c, d = q.shape[-2], k.shape[-2], k.shape[-1], v.shape[-1]z = 1 / (torch.einsum("b i c, b c -> b i", q, k.sum(dim=1)) + 1e-6)if i * j * (c + d) > c * d * (i + j):kv = torch.einsum("b j c, b j d -> b c d", k, v)x = torch.einsum("b i c, b c d, b i -> b i d", q, kv, z)else:qk = torch.einsum("b i c, b j c -> b i j", q, k)x = torch.einsum("b i j, b j d, b i -> b i d", qk, v, z)num = int(v.shape[1] ** 0.5)feature_map = rearrange(v, "b (w h) c -> b c w h", w=num, h=num)feature_map = rearrange(self.dwc(feature_map), "b c w h -> b (w h) c")x = x + feature_mapx = rearrange(x, "(b h) n c -> b n (h c)", h=self.num_heads)x = self.proj(x)x = self.proj_drop(x)x = rearrange(x, "b (w h) c -> b c w h", b=B, c=self.dim, w=num, h=num)return x

2.2 步骤二

在task.py导入我们的模块

2.3 步骤三

在task.py的parse_model方法里面注册我们的模块

到此注册成功,复制后面的yaml文件直接运行即可

yaml文件

# Ultralytics YOLO 🚀, GPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect# Parameters
nc: 1  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'# [depth, width, max_channels]n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPss: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPsm: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPsl: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPsx: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs# YOLOv8.0n backbone
backbone:# [from, repeats, module, args]- [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2- [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4- [-1, 3, C2f, [128, True]]- [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8- [-1, 6, C2f, [256, True]]- [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16- [-1, 6, C2f, [512, True]]- [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32- [-1, 3, C2f, [1024, True]]- [-1, 1, SPPF, [1024, 5]]  # 9- [-1, 1, FocusedLinearAttention, [256]]  # 10# YOLOv8.0n head
head:- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 6], 1, Concat, [1]]  # cat backbone P4- [-1, 3, C2f, [512]]  # 13- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 4], 1, Concat, [1]]  # cat backbone P3- [-1, 3, C2f, [256]]  # 16 (P3/8-small)- [-1, 1, Conv, [256, 3, 2]]- [[-1, 13], 1, Concat, [1]]  # cat head P4- [-1, 3, C2f, [512]]  # 19 (P4/16-medium)- [-1, 1, Conv, [512, 3, 2]]- [[-1, 10], 1, Concat, [1]]  # cat head P5- [-1, 3, C2f, [1024]]  # 22 (P5/32-large)- [[16, 19, 22], 1, Detect, [nc]]  # Detect(P3, P4, P5)

可能会出现如下所示报错

 raise EinopsError(message + "\n {}".format(e))
einops.EinopsError:  Error while processing rearrange-reduction pattern "b (w h) c -> b c w h".Input tensor shape: torch.Size([128, 294, 32]). Additional info: {'w': 17, 'h': 17}.Shape mismatch, 294 != 289

则按以下方法进行修改

找到ultralytics/models/yolo/detect/train.py的如下所示代码

 def build_dataset(self, img_path, mode='train', batch=None):"""Build YOLO Dataset.Args:img_path (str): Path to the folder containing images.mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.batch (int, optional): Size of batches, this is for `rect`. Defaults to None."""gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == 'val', stride=gs)

将该方法的最后一句代码修改为下面的代码

return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=(False if mode == 'val' else False), stride=gs)

不知不觉已经看完了哦,动动小手留个点赞收藏吧--_--

这篇关于爆改YOLOv8|利用全新的聚焦式线性注意力模块Focused Linear Attention 改进yolov8(v1)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1116077

相关文章

python: 多模块(.py)中全局变量的导入

文章目录 global关键字可变类型和不可变类型数据的内存地址单模块(单个py文件)的全局变量示例总结 多模块(多个py文件)的全局变量from x import x导入全局变量示例 import x导入全局变量示例 总结 global关键字 global 的作用范围是模块(.py)级别: 当你在一个模块(文件)中使用 global 声明变量时,这个变量只在该模块的全局命名空

深入探索协同过滤:从原理到推荐模块案例

文章目录 前言一、协同过滤1. 基于用户的协同过滤(UserCF)2. 基于物品的协同过滤(ItemCF)3. 相似度计算方法 二、相似度计算方法1. 欧氏距离2. 皮尔逊相关系数3. 杰卡德相似系数4. 余弦相似度 三、推荐模块案例1.基于文章的协同过滤推荐功能2.基于用户的协同过滤推荐功能 前言     在信息过载的时代,推荐系统成为连接用户与内容的桥梁。本文聚焦于

线性因子模型 - 独立分量分析(ICA)篇

序言 线性因子模型是数据分析与机器学习中的一类重要模型,它们通过引入潜变量( latent variables \text{latent variables} latent variables)来更好地表征数据。其中,独立分量分析( ICA \text{ICA} ICA)作为线性因子模型的一种,以其独特的视角和广泛的应用领域而备受关注。 ICA \text{ICA} ICA旨在将观察到的复杂信号

【Tools】大模型中的自注意力机制

摇来摇去摇碎点点的金黄 伸手牵来一片梦的霞光 南方的小巷推开多情的门窗 年轻和我们歌唱 摇来摇去摇着温柔的阳光 轻轻托起一件梦的衣裳 古老的都市每天都改变模样                      🎵 方芳《摇太阳》 自注意力机制(Self-Attention)是一种在Transformer等大模型中经常使用的注意力机制。该机制通过对输入序列中的每个元素计算与其他元素之间的相似性,

什么是 Flash Attention

Flash Attention 是 由 Tri Dao 和 Dan Fu 等人在2022年的论文 FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness 中 提出的, 论文可以从 https://arxiv.org/abs/2205.14135 页面下载,点击 View PDF 就可以下载。 下面我

如何通俗理解注意力机制?

1、注意力机制(Attention Mechanism)是机器学习和深度学习中一种模拟人类注意力的方法,用于提高模型在处理大量信息时的效率和效果。通俗地理解,它就像是在一堆信息中找到最重要的部分,把注意力集中在这些关键点上,从而更好地完成任务。以下是几个简单的比喻来帮助理解注意力机制: 2、寻找重点:想象一下,你在阅读一篇文章的时候,有些段落特别重要,你会特别注意这些段落,反复阅读,而对其他部分

【Tools】大模型中的注意力机制

摇来摇去摇碎点点的金黄 伸手牵来一片梦的霞光 南方的小巷推开多情的门窗 年轻和我们歌唱 摇来摇去摇着温柔的阳光 轻轻托起一件梦的衣裳 古老的都市每天都改变模样                      🎵 方芳《摇太阳》 在大模型中,注意力机制是一种重要的技术,它被广泛应用于自然语言处理领域,特别是在机器翻译和语言模型中。 注意力机制的基本思想是通过计算输入序列中各个位置的权重,以确

Jenkins构建Maven聚合工程,指定构建子模块

一、设置单独编译构建子模块 配置: 1、Root POM指向父pom.xml 2、Goals and options指定构建模块的参数: mvn -pl project1/project1-son -am clean package 单独构建project1-son项目以及它所依赖的其它项目。 说明: mvn clean package -pl 父级模块名/子模块名 -am参数

寻迹模块TCRT5000的应用原理和功能实现(基于STM32)

目录 概述 1 认识TCRT5000 1.1 模块介绍 1.2 电气特性 2 系统应用 2.1 系统架构 2.2 STM32Cube创建工程 3 功能实现 3.1 代码实现 3.2 源代码文件 4 功能测试 4.1 检测黑线状态 4.2 未检测黑线状态 概述 本文主要介绍TCRT5000模块的使用原理,包括该模块的硬件实现方式,电路实现原理,还使用STM32类

理解分类器(linear)为什么可以做语义方向的指导?(解纠缠)

Attribute Manipulation(属性编辑)、disentanglement(解纠缠)常用的两种做法:线性探针和PCA_disentanglement和alignment-CSDN博客 在解纠缠的过程中,有一种非常简单的方法来引导G向某个方向进行生成,然后我们通过向不同的方向进行行走,那么就会得到这个属性上的图像。那么你利用多个方向进行生成,便得到了各种方向的图像,每个方向对应了很多