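[YOLOv10 Improvement | Backbone] Boosting YOLOv10 Object Detection with the Image Restoration Network AirNet: Full Code, Detailed Modification Steps, and Hand-Drawn Structure Diagrams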

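This article introduces the image restoration network AirNet, which consists of two modules: the Contrastive-Based Degraded Encoder (CBDE) and the Degradation-Guided Restoration Network (DGRN). It can restore images with various kinds of degradation in a single network. AirNet needs no prior knowledge of the corruption type or level and runs inference using only the observed corrupted image. Here AirNet is used to improve YOLOv10's object detection; the article includes all the code, the detailed modification steps, and structure diagrams.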

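Contents

I AirNet

II Boosting YOLOv10 Object Detection with the Image Restoration Network AirNet

1 Overall Modifications

① Add the AirNet.py file

② Modify the ultralytics/nn/tasks.py file

2 Configuration File

3 Training

Other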

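I AirNet

Official paper: https://openaccess.thecvf.com/content/CVPR2022/papers/Li_All-in-One_Image_Restoration_for_Unknown_Corruption_CVPR_2022_paper.pdf

The paper studies a challenging problem in image restoration: how to develop an all-in-one method that can recover images from a variety of unknown corruption types and levels. To this end, the authors propose the All-in-one Image Restoration Network (AirNet), which consists of two neural modules: the Contrastive-Based Degraded Encoder (CBDE) and the Degradation-Guided Restoration Network (DGRN). AirNet has two main advantages. First, it is an all-in-one solution that can restore images with various degradations in a single network. Second, AirNet needs no prior knowledge of the corruption type or level and runs inference using only the observed corrupted image. These two advantages make AirNet more flexible and more economical in real-world scenarios, where prior information about the corruption is hard to obtain and the degradation varies over space and time. Extensive experiments show that the proposed method outperforms 17 image restoration baselines on four challenging datasets.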

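Figure 1. Illustration of the basic idea. Most existing multi-degradation methods handle each corruption by sending the input to a specially designed head and taking the output of the corresponding tail. They therefore need the corruption information in advance in order to select the right head and tail. In contrast, the All-in-one Image Restoration Network (AirNet) needs no prior knowledge of the corruption type or level, which makes it more flexible and more economical in real-world scenarios.

Figure 2. Architecture of the proposed AirNet.

(a) All-In-One Image Restoration Network (AirNet);

(b) Contrastive-Based Degraded Encoder (CBDE);

(c) Degradation-Guided Group (DGG);

(d) Degradation-Guided Module (DGM).

Figure 3. Comparison with SOTA denoising methods on the BSD68 dataset. For easier visualization and comparison, some regions are highlighted with colored rectangles.

Figure 4. Comparison with SOTA deraining methods on the Rain100L dataset. For easier visualization and comparison, some regions are highlighted with colored rectangles.

Figure 5. Comparison with SOTA dehazing methods on the SOTS dataset. For easier visualization and comparison, some regions are highlighted with colored rectangles.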

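The paper proposes an all-in-one image restoration network (AirNet) that does not depend on prior knowledge of the corruption type or level. It is also an all-in-one solution that can recover images from different corruptions, which makes it competitive in real-world scenarios where such priors are hard to obtain or where the degradation may change over time and space. Extensive experiments demonstrate AirNet's superiority in both qualitative and quantitative comparisons.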

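II Boosting YOLOv10 Object Detection with the Image Restoration Network AirNet

The structure diagram of YOLOv10 is shown below:

1 Overall Modifications

A schematic of the improved structure is shown below:

① Add the AirNet.py file

Create a new file named AirNet.py in the ultralytics/nn/modules directory with the following content:

Note: this code is adapted to run on the CPU, where training is extremely slow and may even crash.

If you have a GPU, create the labels on the GPU instead: labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()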
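The CPU/GPU switch described above can also be written once so that it works on either device. The sketch below is an assumption about how the labels are used, not the article's AirNet.py: it assumes `logits` is a 2-D tensor with one row per sample and that the labels feed a cross-entropy loss. The helper name `make_contrastive_labels` is hypothetical.

```python
import torch


def make_contrastive_labels(logits: torch.Tensor) -> torch.Tensor:
    """Return all-zero class indices on the same device as `logits` (hypothetical helper)."""
    # One label per row of `logits`; device=logits.device avoids hard-coding .cuda()
    return torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)


# Stand-in for the encoder's logits: 4 samples, 10 candidate classes (illustrative shape)
logits = torch.randn(4, 10)
labels = make_contrastive_labels(logits)

# Assumed usage: a standard cross-entropy loss over the logits
loss = torch.nn.functional.cross_entropy(logits, labels)
```

With this form the same line runs unchanged whether the model has been moved to the CPU or to a GPU, because the labels follow the device of the logits.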

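② Modify the ultralytics/nn/tasks.py file

The specific changes are shown in the figure below. At the top of the file, add the import: from .modules.AirNet import AirNet

The modified part of the parse_model function is shown below. AirNet is added to the branch that passes only the input channel count to the module: elif m in {OREPA, EnhanceNetwork, ADNet, SEAM, MultiSEAM, AirNet}: args = [ch[f], *args]

That completes the modifications.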
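To see what the added parse_model branch does to the channel bookkeeping, here is a minimal pure-Python sketch. It is a simplification under stated assumptions: only the Conv and AirNet branches are modeled, width scaling is ignored, and the function name `track_channels` is hypothetical. As in the tasks.py listed at the end of this article, the AirNet branch prepends the input channel count to the arguments and leaves the running output-channel value `c2` unchanged.

```python
def track_channels(layers, in_ch=3):
    """Mimic parse_model's channel tracking for a tiny list of (from, module name, args)."""
    ch = [in_ch]          # channel history, indexed by the 'from' field
    c2 = ch[-1]           # running output-channel value, starts at the input channels
    rows = []
    for i, (f, name, args) in enumerate(layers):
        if name == "Conv":
            c1, c2 = ch[f], args[0]          # output channels come from the YAML args
            args = [c1, c2, *args[1:]]
        elif name == "AirNet":
            args = [ch[f], *args]            # only the input channels are passed; c2 is untouched
        else:
            c2 = ch[f]                       # default: output channels equal input channels
        rows.append((i, name, args, c2))
        if i == 0:
            ch = []                          # drop the raw input entry after the first layer
        ch.append(c2)
    return rows


rows = track_channels([(-1, "AirNet", []), (-1, "Conv", [64, 3, 2])])
for row in rows:
    print(row)
```

Because `c2` is not updated for AirNet, the next layer sees the same channel count that went in (3 for an RGB input), which is what a restoration network that outputs an image of the same shape needs.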

2 Configuration File

The content of the yolov10n_AirNet.yaml file is as follows:

# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, AirNet, []] # 0-P1/2
  - [-1, 1, Conv, [64, 3, 2]] # 1-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 2-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]] # 4-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, SCDown, [512, 3, 2]] # 6-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, SCDown, [1024, 3, 2]] # 8-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]] # 10
  - [-1, 1, PSA, [1024]] # 11

# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 7], 1, Concat, [1]] # cat backbone P4
  - [-1, 3, C2f, [512]] # 14

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 5], 1, Concat, [1]] # cat backbone P3
  - [-1, 3, C2f, [256]] # 17 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 14], 1, Concat, [1]] # cat head P4
  - [-1, 3, C2f, [512]] # 20 (P4/16-medium)

  - [-1, 1, SCDown, [512, 3, 2]]
  - [[-1, 11], 1, Concat, [1]] # cat head P5
  - [-1, 3, C2fCIB, [1024, True, True]] # 23 (P5/32-large)

  - [[17, 20, 23], 1, v10Detect, [nc]] # Detect(P3, P4, P5)
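The channel and repeat numbers in the YAML are nominal values; parse_model scales them with the `n` entry of `scales`, i.e. depth 0.33, width 0.25, and max_channels 1024. The short sketch below reproduces that arithmetic in pure Python. The two formulas are taken from the parse_model code listed at the end of this article; `make_divisible` is re-implemented here as a ceiling to the nearest multiple, which is an assumption about the Ultralytics helper, and the wrapper names `scaled_channels` and `scaled_repeats` are hypothetical.

```python
import math


def make_divisible(x, divisor):
    """Round x up to the nearest multiple of divisor (assumed behavior of the Ultralytics helper)."""
    return math.ceil(x / divisor) * divisor


def scaled_channels(c2, width, max_channels):
    """Width scaling as in parse_model: c2 = make_divisible(min(c2, max_channels) * width, 8)."""
    return make_divisible(min(c2, max_channels) * width, 8)


def scaled_repeats(n, depth):
    """Depth scaling as in parse_model: n = max(round(n * depth), 1) if n > 1 else n."""
    return max(round(n * depth), 1) if n > 1 else n


depth, width, max_channels = 0.33, 0.25, 1024  # the 'n' scale from the YAML above

channels = [scaled_channels(c, width, max_channels) for c in (64, 128, 256, 512, 1024)]
repeats = [scaled_repeats(n, depth) for n in (1, 3, 6)]

print(channels)  # actual channel widths built for the nominal 64..1024
print(repeats)   # actual repeat counts built for the nominal 1, 3, 6
```

So a layer written as `Conv, [64, 3, 2]` is built with 16 output channels at scale `n`, and a `C2f` with 6 nominal repeats is built with 2.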

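3 Training

With the modifications above in place, you can start training.

An example training command:

yolo detect train data=coco128.yaml model=cfg/models/v10/yolov10n_AirNet.yaml epochs=100 batch=8 imgsz=640 device=cpu project=yolov10

The improved model in its CPU-adapted form trains very slowly and may even crash.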
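If you switch between CPU and GPU runs often, it can help to build the command from one place. The sketch below only assembles the same `yolo detect train` arguments shown above; the function name `build_train_command` is hypothetical, and `device="0"` as the value for the first GPU is an assumption about the CLI that you should check against your installed version.

```python
import shlex


def build_train_command(device="cpu", epochs=100, batch=8, imgsz=640):
    """Assemble the training command from the article as an argument list (hypothetical helper)."""
    args = {
        "data": "coco128.yaml",
        "model": "cfg/models/v10/yolov10n_AirNet.yaml",
        "epochs": epochs,
        "batch": batch,
        "imgsz": imgsz,
        "device": device,
        "project": "yolov10",
    }
    return ["yolo", "detect", "train", *[f"{key}={value}" for key, value in args.items()]]


cmd = build_train_command(device="0")  # assumed GPU index; use device="cpu" for the CPU run
print(shlex.join(cmd))
```

The resulting list can be passed to `subprocess.run(cmd, check=True)`, or the printed line can be pasted into a shell.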

Other

[Error] ModuleNotFoundError: No module named 'mmcv'

[Fix] Install mmcv:

pip --default-timeout=100 install mmcv -i https://pypi.tuna.tsinghua.edu.cn/simple
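Before starting a long training run, you can check whether the module is importable instead of waiting for the error. This is a small standard-library sketch; the helper name `module_available` is hypothetical.

```python
import importlib.util


def module_available(name: str) -> bool:
    """Return True if a top-level module with this name can be found, without importing it."""
    return importlib.util.find_spec(name) is not None


if not module_available("mmcv"):
    print("mmcv is not installed; install it with pip before training.")
```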


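The tasks.py below also contains mechanisms accumulated from earlier modifications; if you use it, remove the unrelated parts yourself. The file content is as follows: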

# Ultralytics YOLO 🚀, AGPL-3.0 licenseimport contextlib
from copy import deepcopy
from pathlib import Pathimport torch
import torch.nn as nnfrom ultralytics.nn.modules import (AIFI,C1,C2,C3,C3TR,OBB,SPP,SPPF,Bottleneck,BottleneckCSP,C2f,C2fAttn,ImagePoolingAttn,C3Ghost,C3x,Classify,Concat,Conv,Conv2,ConvTranspose,Detect,DWConv,DWConvTranspose2d,Focus,GhostBottleneck,GhostConv,HGBlock,HGStem,Pose,RepC3,RepConv,ResNetLayer,RTDETRDecoder,Segment,WorldDetect,RepNCSPELAN4,ADown,SPPELAN,CBFuse,CBLinear,Silence,C2fCIB,PSA,SCDown,RepVGGDW,v10Detect
)
from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
from ultralytics.utils.checks import check_requirements, check_suffix, check_yaml
from ultralytics.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8OBBLoss, v8PoseLoss, v8SegmentationLoss, v10DetectLoss
from ultralytics.utils.plotting import feature_visualization
from ultralytics.utils.torch_utils import (fuse_conv_and_bn,fuse_deconv_and_bn,initialize_weights,intersect_dicts,make_divisible,model_info,scale_img,time_sync,
)
from .modules.OREPA import OREPA #导入OREPA
from .modules.haar_HWD import Down_wt
from .modules.SCINet import EnhanceNetwork
from .modules.ADNet import ADNet
from .modules.SEAM import SEAM, MultiSEAM
from .modules.AirNet import AirNettry:import thop
except ImportError:thop = Noneclass BaseModel(nn.Module):"""The BaseModel class serves as a base class for all the models in the Ultralytics YOLO family."""def forward(self, x, *args, **kwargs):"""Forward pass of the model on a single scale. Wrapper for `_forward_once` method.Args:x (torch.Tensor | dict): The input image tensor or a dict including image tensor and gt labels.Returns:(torch.Tensor): The output of the network."""if isinstance(x, dict):  # for cases of training and validating while training.return self.loss(x, *args, **kwargs)return self.predict(x, *args, **kwargs)def predict(self, x, profile=False, visualize=False, augment=False, embed=None):"""Perform a forward pass through the network.Args:x (torch.Tensor): The input tensor to the model.profile (bool):  Print the computation time of each layer if True, defaults to False.visualize (bool): Save the feature maps of the model if True, defaults to False.augment (bool): Augment image during prediction, defaults to False.embed (list, optional): A list of feature vectors/embeddings to return.Returns:(torch.Tensor): The last output of the model."""if augment:return self._predict_augment(x)return self._predict_once(x, profile, visualize, embed)def _predict_once(self, x, profile=False, visualize=False, embed=None):"""Perform a forward pass through the network.Args:x (torch.Tensor): The input tensor to the model.profile (bool):  Print the computation time of each layer if True, defaults to False.visualize (bool): Save the feature maps of the model if True, defaults to False.embed (list, optional): A list of feature vectors/embeddings to return.Returns:(torch.Tensor): The last output of the model."""y, dt, embeddings = [], [], []  # outputsfor m in self.model:if m.f != -1:  # if not from previous layerx = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layersif profile:self._profile_one_layer(m, x, dt)x = m(x)  # runy.append(x if m.i in self.save else None)  # save 
outputif visualize:feature_visualization(x, m.type, m.i, save_dir=visualize)if embed and m.i in embed:embeddings.append(nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1))  # flattenif m.i == max(embed):return torch.unbind(torch.cat(embeddings, 1), dim=0)return xdef _predict_augment(self, x):"""Perform augmentations on input image x and return augmented inference."""LOGGER.warning(f"WARNING ⚠️ {self.__class__.__name__} does not support augmented inference yet. "f"Reverting to single-scale inference instead.")return self._predict_once(x)def _profile_one_layer(self, m, x, dt):"""Profile the computation time and FLOPs of a single layer of the model on a given input. Appends the results tothe provided list.Args:m (nn.Module): The layer to be profiled.x (torch.Tensor): The input data to the layer.dt (list): A list to store the computation time of the layer.Returns:None"""c = m == self.model[-1] and isinstance(x, list)  # is final layer list, copy input as inplace fixflops = thop.profile(m, inputs=[x.copy() if c else x], verbose=False)[0] / 1e9 * 2 if thop else 0  # FLOPst = time_sync()for _ in range(10):m(x.copy() if c else x)dt.append((time_sync() - t) * 100)if m == self.model[0]:LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s}  module")LOGGER.info(f"{dt[-1]:10.2f} {flops:10.2f} {m.np:10.0f}  {m.type}")if c:LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s}  Total")def fuse(self, verbose=True):"""Fuse the `Conv2d()` and `BatchNorm2d()` layers of the model into a single layer, in order to improve thecomputation efficiency.Returns:(nn.Module): The fused model is returned."""if not self.is_fused():for m in self.model.modules():if isinstance(m, (Conv, Conv2, DWConv)) and hasattr(m, "bn"):if isinstance(m, Conv2):m.fuse_convs()m.conv = fuse_conv_and_bn(m.conv, m.bn)  # update convdelattr(m, "bn")  # remove batchnormm.forward = m.forward_fuse  # update forwardif isinstance(m, ConvTranspose) and hasattr(m, "bn"):m.conv_transpose = 
fuse_deconv_and_bn(m.conv_transpose, m.bn)delattr(m, "bn")  # remove batchnormm.forward = m.forward_fuse  # update forwardif isinstance(m, RepConv):m.fuse_convs()m.forward = m.forward_fuse  # update forwardif isinstance(m, RepVGGDW):m.fuse()m.forward = m.forward_fuseself.info(verbose=verbose)return selfdef is_fused(self, thresh=10):"""Check if the model has less than a certain threshold of BatchNorm layers.Args:thresh (int, optional): The threshold number of BatchNorm layers. Default is 10.Returns:(bool): True if the number of BatchNorm layers in the model is less than the threshold, False otherwise."""bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k)  # normalization layers, i.e. BatchNorm2d()return sum(isinstance(v, bn) for v in self.modules()) < thresh  # True if < 'thresh' BatchNorm layers in modeldef info(self, detailed=False, verbose=True, imgsz=640):"""Prints model information.Args:detailed (bool): if True, prints out detailed information about the model. Defaults to Falseverbose (bool): if True, prints out the model information. Defaults to Falseimgsz (int): the size of the image that the model will be trained on. Defaults to 640"""return model_info(self, detailed=detailed, verbose=verbose, imgsz=imgsz)def _apply(self, fn):"""Applies a function to all the tensors in the model that are not parameters or registered buffers.Args:fn (function): the function to apply to the modelReturns:(BaseModel): An updated BaseModel object."""self = super()._apply(fn)m = self.model[-1]  # Detect()if isinstance(m, Detect):  # includes all Detect subclasses like Segment, Pose, OBB, WorldDetectm.stride = fn(m.stride)m.anchors = fn(m.anchors)m.strides = fn(m.strides)return selfdef load(self, weights, verbose=True):"""Load the weights into the model.Args:weights (dict | torch.nn.Module): The pre-trained weights to be loaded.verbose (bool, optional): Whether to log the transfer progress. 
Defaults to True."""model = weights["model"] if isinstance(weights, dict) else weights  # torchvision models are not dictscsd = model.float().state_dict()  # checkpoint state_dict as FP32csd = intersect_dicts(csd, self.state_dict())  # intersectself.load_state_dict(csd, strict=False)  # loadif verbose:LOGGER.info(f"Transferred {len(csd)}/{len(self.model.state_dict())} items from pretrained weights")def loss(self, batch, preds=None):"""Compute loss.Args:batch (dict): Batch to compute loss onpreds (torch.Tensor | List[torch.Tensor]): Predictions."""if not hasattr(self, "criterion"):self.criterion = self.init_criterion()preds = self.forward(batch["img"]) if preds is None else predsreturn self.criterion(preds, batch)def init_criterion(self):"""Initialize the loss criterion for the BaseModel."""raise NotImplementedError("compute_loss() needs to be implemented by task heads")class DetectionModel(BaseModel):"""YOLOv8 detection model."""def __init__(self, cfg="yolov8n.yaml", ch=3, nc=None, verbose=True):  # model, input channels, number of classes"""Initialize the YOLOv8 detection model with the given config and parameters."""super().__init__()self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg)  # cfg dict# Define modelch = self.yaml["ch"] = self.yaml.get("ch", ch)  # input channelsif nc and nc != self.yaml["nc"]:LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")self.yaml["nc"] = nc  # override YAML valueself.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose)  # model, savelistself.names = {i: f"{i}" for i in range(self.yaml["nc"])}  # default names dictself.inplace = self.yaml.get("inplace", True)# Build stridesm = self.model[-1]  # Detect()if isinstance(m, Detect):  # includes all Detect subclasses like Segment, Pose, OBB, WorldDetects = 256  # 2x min stridem.inplace = self.inplaceforward = lambda x: self.forward(x)[0] if isinstance(m, (Segment, Pose, OBB)) else self.forward(x)if isinstance(m, 
v10Detect):forward = lambda x: self.forward(x)["one2many"]m.stride = torch.tensor([s / x.shape[-2] for x in forward(torch.zeros(1, ch, s, s))])  # forwardself.stride = m.stridem.bias_init()  # only run onceelse:self.stride = torch.Tensor([32])  # default stride for i.e. RTDETR# Init weights, biasesinitialize_weights(self)if verbose:self.info()LOGGER.info("")def _predict_augment(self, x):"""Perform augmentations on input image x and return augmented inference and train outputs."""img_size = x.shape[-2:]  # height, widths = [1, 0.83, 0.67]  # scalesf = [None, 3, None]  # flips (2-ud, 3-lr)y = []  # outputsfor si, fi in zip(s, f):xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))yi = super().predict(xi)[0]  # forwardyi = self._descale_pred(yi, fi, si, img_size)y.append(yi)y = self._clip_augmented(y)  # clip augmented tailsreturn torch.cat(y, -1), None  # augmented inference, train@staticmethoddef _descale_pred(p, flips, scale, img_size, dim=1):"""De-scale predictions following augmented inference (inverse operation)."""p[:, :4] /= scale  # de-scalex, y, wh, cls = p.split((1, 1, 2, p.shape[dim] - 4), dim)if flips == 2:y = img_size[0] - y  # de-flip udelif flips == 3:x = img_size[1] - x  # de-flip lrreturn torch.cat((x, y, wh, cls), dim)def _clip_augmented(self, y):"""Clip YOLO augmented inference tails."""nl = self.model[-1].nl  # number of detection layers (P3-P5)g = sum(4**x for x in range(nl))  # grid pointse = 1  # exclude layer counti = (y[0].shape[-1] // g) * sum(4**x for x in range(e))  # indicesy[0] = y[0][..., :-i]  # largei = (y[-1].shape[-1] // g) * sum(4 ** (nl - 1 - x) for x in range(e))  # indicesy[-1] = y[-1][..., i:]  # smallreturn ydef init_criterion(self):"""Initialize the loss criterion for the DetectionModel."""return v8DetectionLoss(self)class OBBModel(DetectionModel):"""YOLOv8 Oriented Bounding Box (OBB) model."""def __init__(self, cfg="yolov8n-obb.yaml", ch=3, nc=None, verbose=True):"""Initialize YOLOv8 OBB model with given 
config and parameters."""super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)def init_criterion(self):"""Initialize the loss criterion for the model."""return v8OBBLoss(self)class SegmentationModel(DetectionModel):"""YOLOv8 segmentation model."""def __init__(self, cfg="yolov8n-seg.yaml", ch=3, nc=None, verbose=True):"""Initialize YOLOv8 segmentation model with given config and parameters."""super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)def init_criterion(self):"""Initialize the loss criterion for the SegmentationModel."""return v8SegmentationLoss(self)class PoseModel(DetectionModel):"""YOLOv8 pose model."""def __init__(self, cfg="yolov8n-pose.yaml", ch=3, nc=None, data_kpt_shape=(None, None), verbose=True):"""Initialize YOLOv8 Pose model."""if not isinstance(cfg, dict):cfg = yaml_model_load(cfg)  # load model YAMLif any(data_kpt_shape) and list(data_kpt_shape) != list(cfg["kpt_shape"]):LOGGER.info(f"Overriding model.yaml kpt_shape={cfg['kpt_shape']} with kpt_shape={data_kpt_shape}")cfg["kpt_shape"] = data_kpt_shapesuper().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)def init_criterion(self):"""Initialize the loss criterion for the PoseModel."""return v8PoseLoss(self)class ClassificationModel(BaseModel):"""YOLOv8 classification model."""def __init__(self, cfg="yolov8n-cls.yaml", ch=3, nc=None, verbose=True):"""Init ClassificationModel with YAML, channels, number of classes, verbose flag."""super().__init__()self._from_yaml(cfg, ch, nc, verbose)def _from_yaml(self, cfg, ch, nc, verbose):"""Set YOLOv8 model configurations and define the model architecture."""self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg)  # cfg dict# Define modelch = self.yaml["ch"] = self.yaml.get("ch", ch)  # input channelsif nc and nc != self.yaml["nc"]:LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")self.yaml["nc"] = nc  # override YAML valueelif not nc and not self.yaml.get("nc", None):raise ValueError("nc not specified. 
Must specify nc in model.yaml or function arguments.")self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose)  # model, savelistself.stride = torch.Tensor([1])  # no stride constraintsself.names = {i: f"{i}" for i in range(self.yaml["nc"])}  # default names dictself.info()@staticmethoddef reshape_outputs(model, nc):"""Update a TorchVision classification model to class count 'n' if required."""name, m = list((model.model if hasattr(model, "model") else model).named_children())[-1]  # last moduleif isinstance(m, Classify):  # YOLO Classify() headif m.linear.out_features != nc:m.linear = nn.Linear(m.linear.in_features, nc)elif isinstance(m, nn.Linear):  # ResNet, EfficientNetif m.out_features != nc:setattr(model, name, nn.Linear(m.in_features, nc))elif isinstance(m, nn.Sequential):types = [type(x) for x in m]if nn.Linear in types:i = types.index(nn.Linear)  # nn.Linear indexif m[i].out_features != nc:m[i] = nn.Linear(m[i].in_features, nc)elif nn.Conv2d in types:i = types.index(nn.Conv2d)  # nn.Conv2d indexif m[i].out_channels != nc:m[i] = nn.Conv2d(m[i].in_channels, nc, m[i].kernel_size, m[i].stride, bias=m[i].bias is not None)def init_criterion(self):"""Initialize the loss criterion for the ClassificationModel."""return v8ClassificationLoss()class RTDETRDetectionModel(DetectionModel):"""RTDETR (Real-time DEtection and Tracking using Transformers) Detection Model class.This class is responsible for constructing the RTDETR architecture, defining loss functions, and facilitating boththe training and inference processes. RTDETR is an object detection and tracking model that extends from theDetectionModel base class.Attributes:cfg (str): The configuration file path or preset string. Default is 'rtdetr-l.yaml'.ch (int): Number of input channels. Default is 3 (RGB).nc (int, optional): Number of classes for object detection. Default is None.verbose (bool): Specifies if summary statistics are shown during initialization. 
Default is True.Methods:init_criterion: Initializes the criterion used for loss calculation.loss: Computes and returns the loss during training.predict: Performs a forward pass through the network and returns the output."""def __init__(self, cfg="rtdetr-l.yaml", ch=3, nc=None, verbose=True):"""Initialize the RTDETRDetectionModel.Args:cfg (str): Configuration file name or path.ch (int): Number of input channels.nc (int, optional): Number of classes. Defaults to None.verbose (bool, optional): Print additional information during initialization. Defaults to True."""super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)def init_criterion(self):"""Initialize the loss criterion for the RTDETRDetectionModel."""from ultralytics.models.utils.loss import RTDETRDetectionLossreturn RTDETRDetectionLoss(nc=self.nc, use_vfl=True)def loss(self, batch, preds=None):"""Compute the loss for the given batch of data.Args:batch (dict): Dictionary containing image and label data.preds (torch.Tensor, optional): Precomputed model predictions. 
Defaults to None.Returns:(tuple): A tuple containing the total loss and main three losses in a tensor."""if not hasattr(self, "criterion"):self.criterion = self.init_criterion()img = batch["img"]# NOTE: preprocess gt_bbox and gt_labels to list.bs = len(img)batch_idx = batch["batch_idx"]gt_groups = [(batch_idx == i).sum().item() for i in range(bs)]targets = {"cls": batch["cls"].to(img.device, dtype=torch.long).view(-1),"bboxes": batch["bboxes"].to(device=img.device),"batch_idx": batch_idx.to(img.device, dtype=torch.long).view(-1),"gt_groups": gt_groups,}preds = self.predict(img, batch=targets) if preds is None else predsdec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta = preds if self.training else preds[1]if dn_meta is None:dn_bboxes, dn_scores = None, Noneelse:dn_bboxes, dec_bboxes = torch.split(dec_bboxes, dn_meta["dn_num_split"], dim=2)dn_scores, dec_scores = torch.split(dec_scores, dn_meta["dn_num_split"], dim=2)dec_bboxes = torch.cat([enc_bboxes.unsqueeze(0), dec_bboxes])  # (7, bs, 300, 4)dec_scores = torch.cat([enc_scores.unsqueeze(0), dec_scores])loss = self.criterion((dec_bboxes, dec_scores), targets, dn_bboxes=dn_bboxes, dn_scores=dn_scores, dn_meta=dn_meta)# NOTE: There are like 12 losses in RTDETR, backward with all losses but only show the main three losses.return sum(loss.values()), torch.as_tensor([loss[k].detach() for k in ["loss_giou", "loss_class", "loss_bbox"]], device=img.device)def predict(self, x, profile=False, visualize=False, batch=None, augment=False, embed=None):"""Perform a forward pass through the model.Args:x (torch.Tensor): The input tensor.profile (bool, optional): If True, profile the computation time for each layer. Defaults to False.visualize (bool, optional): If True, save feature maps for visualization. Defaults to False.batch (dict, optional): Ground truth data for evaluation. Defaults to None.augment (bool, optional): If True, perform data augmentation during inference. 
Defaults to False.embed (list, optional): A list of feature vectors/embeddings to return.Returns:(torch.Tensor): Model's output tensor."""y, dt, embeddings = [], [], []  # outputsfor m in self.model[:-1]:  # except the head partif m.f != -1:  # if not from previous layerx = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layersif profile:self._profile_one_layer(m, x, dt)x = m(x)  # runy.append(x if m.i in self.save else None)  # save outputif visualize:feature_visualization(x, m.type, m.i, save_dir=visualize)if embed and m.i in embed:embeddings.append(nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1))  # flattenif m.i == max(embed):return torch.unbind(torch.cat(embeddings, 1), dim=0)head = self.model[-1]x = head([y[j] for j in head.f], batch)  # head inferencereturn xclass WorldModel(DetectionModel):"""YOLOv8 World Model."""def __init__(self, cfg="yolov8s-world.yaml", ch=3, nc=None, verbose=True):"""Initialize YOLOv8 world model with given config and parameters."""self.txt_feats = torch.randn(1, nc or 80, 512)  # features placeholderself.clip_model = None  # CLIP model placeholdersuper().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)def set_classes(self, text):"""Perform a forward pass with optional profiling, visualization, and embedding extraction."""try:import clipexcept ImportError:check_requirements("git+https://github.com/openai/CLIP.git")import clipif not getattr(self, "clip_model", None):  # for backwards compatibility of models lacking clip_model attributeself.clip_model = clip.load("ViT-B/32")[0]device = next(self.clip_model.parameters()).devicetext_token = clip.tokenize(text).to(device)txt_feats = self.clip_model.encode_text(text_token).to(dtype=torch.float32)txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)self.txt_feats = txt_feats.reshape(-1, len(text), txt_feats.shape[-1]).detach()self.model[-1].nc = len(text)def init_criterion(self):"""Initialize the loss criterion for 
the model."""raise NotImplementedErrordef predict(self, x, profile=False, visualize=False, augment=False, embed=None):"""Perform a forward pass through the model.Args:x (torch.Tensor): The input tensor.profile (bool, optional): If True, profile the computation time for each layer. Defaults to False.visualize (bool, optional): If True, save feature maps for visualization. Defaults to False.augment (bool, optional): If True, perform data augmentation during inference. Defaults to False.embed (list, optional): A list of feature vectors/embeddings to return.Returns:(torch.Tensor): Model's output tensor."""txt_feats = self.txt_feats.to(device=x.device, dtype=x.dtype)if len(txt_feats) != len(x):txt_feats = txt_feats.repeat(len(x), 1, 1)ori_txt_feats = txt_feats.clone()y, dt, embeddings = [], [], []  # outputsfor m in self.model:  # except the head partif m.f != -1:  # if not from previous layerx = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layersif profile:self._profile_one_layer(m, x, dt)if isinstance(m, C2fAttn):x = m(x, txt_feats)elif isinstance(m, WorldDetect):x = m(x, ori_txt_feats)elif isinstance(m, ImagePoolingAttn):txt_feats = m(x, txt_feats)else:x = m(x)  # runy.append(x if m.i in self.save else None)  # save outputif visualize:feature_visualization(x, m.type, m.i, save_dir=visualize)if embed and m.i in embed:embeddings.append(nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1))  # flattenif m.i == max(embed):return torch.unbind(torch.cat(embeddings, 1), dim=0)return xclass YOLOv10DetectionModel(DetectionModel):def init_criterion(self):return v10DetectLoss(self)class Ensemble(nn.ModuleList):"""Ensemble of models."""def __init__(self):"""Initialize an ensemble of models."""super().__init__()def forward(self, x, augment=False, profile=False, visualize=False):"""Function generates the YOLO network's final layer."""y = [module(x, augment, profile, visualize)[0] for module in self]# y = 
torch.stack(y).max(0)[0]  # max ensemble# y = torch.stack(y).mean(0)  # mean ensembley = torch.cat(y, 2)  # nms ensemble, y shape(B, HW, C)return y, None  # inference, train output# Functions ------------------------------------------------------------------------------------------------------------@contextlib.contextmanager
def temporary_modules(modules=None):"""Context manager for temporarily adding or modifying modules in Python's module cache (`sys.modules`).This function can be used to change the module paths during runtime. It's useful when refactoring code,where you've moved a module from one location to another, but you still want to support the old importpaths for backwards compatibility.Args:modules (dict, optional): A dictionary mapping old module paths to new module paths.Example:```pythonwith temporary_modules({'old.module.path': 'new.module.path'}):import old.module.path  # this will now import new.module.path```Note:The changes are only in effect inside the context manager and are undone once the context manager exits.Be aware that directly manipulating `sys.modules` can lead to unpredictable results, especially in largerapplications or libraries. Use this function with caution."""if not modules:modules = {}import importlibimport systry:# Set modules in sys.modules under their old namefor old, new in modules.items():sys.modules[old] = importlib.import_module(new)yieldfinally:# Remove the temporary module pathsfor old in modules:if old in sys.modules:del sys.modules[old]def torch_safe_load(weight):"""This function attempts to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised,it catches the error, logs a warning message, and attempts to install the missing module via thecheck_requirements() function. 
After installation, the function again attempts to load the model using torch.load().Args:weight (str): The file path of the PyTorch model.Returns:(dict): The loaded PyTorch model."""from ultralytics.utils.downloads import attempt_download_assetcheck_suffix(file=weight, suffix=".pt")file = attempt_download_asset(weight)  # search online if missing locallytry:with temporary_modules({"ultralytics.yolo.utils": "ultralytics.utils","ultralytics.yolo.v8": "ultralytics.models.yolo","ultralytics.yolo.data": "ultralytics.data",}):  # for legacy 8.0 Classify and Pose modelsckpt = torch.load(file, map_location="cpu")except ModuleNotFoundError as e:  # e.name is missing module nameif e.name == "models":raise TypeError(emojis(f"ERROR ❌️ {weight} appears to be an Ultralytics YOLOv5 model originally trained "f"with https://github.com/ultralytics/yolov5.\nThis model is NOT forwards compatible with "f"YOLOv8 at https://github.com/ultralytics/ultralytics."f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to "f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'")) from eLOGGER.warning(f"WARNING ⚠️ {weight} appears to require '{e.name}', which is not in ultralytics requirements."f"\nAutoInstall will run now for '{e.name}' but this feature will be removed in the future."f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to "f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'")check_requirements(e.name)  # install missing moduleckpt = torch.load(file, map_location="cpu")if not isinstance(ckpt, dict):# File is likely a YOLO instance saved with i.e. torch.save(model, "saved_model.pt")LOGGER.warning(f"WARNING ⚠️ The file '{weight}' appears to be improperly saved or formatted. 
"f"For optimal results, use model.save('filename.pt') to correctly save YOLO models.")ckpt = {"model": ckpt.model}return ckpt, file  # loaddef attempt_load_weights(weights, device=None, inplace=True, fuse=False):"""Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a."""ensemble = Ensemble()for w in weights if isinstance(weights, list) else [weights]:ckpt, w = torch_safe_load(w)  # load ckptargs = {**DEFAULT_CFG_DICT, **ckpt["train_args"]} if "train_args" in ckpt else None  # combined argsmodel = (ckpt.get("ema") or ckpt["model"]).to(device).float()  # FP32 model# Model compatibility updatesmodel.args = args  # attach args to modelmodel.pt_path = w  # attach *.pt file path to modelmodel.task = guess_model_task(model)if not hasattr(model, "stride"):model.stride = torch.tensor([32.0])# Appendensemble.append(model.fuse().eval() if fuse and hasattr(model, "fuse") else model.eval())  # model in eval mode# Module updatesfor m in ensemble.modules():if hasattr(m, "inplace"):m.inplace = inplaceelif isinstance(m, nn.Upsample) and not hasattr(m, "recompute_scale_factor"):m.recompute_scale_factor = None  # torch 1.11.0 compatibility# Return modelif len(ensemble) == 1:return ensemble[-1]# Return ensembleLOGGER.info(f"Ensemble created with {weights}\n")for k in "names", "nc", "yaml":setattr(ensemble, k, getattr(ensemble[0], k))ensemble.stride = ensemble[int(torch.argmax(torch.tensor([m.stride.max() for m in ensemble])))].strideassert all(ensemble[0].nc == m.nc for m in ensemble), f"Models differ in class counts {[m.nc for m in ensemble]}"return ensembledef attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):"""Loads a single model weights."""ckpt, weight = torch_safe_load(weight)  # load ckptargs = {**DEFAULT_CFG_DICT, **(ckpt.get("train_args", {}))}  # combine model and default args, preferring model argsmodel = (ckpt.get("ema") or ckpt["model"]).to(device).float()  # FP32 model# Model compatibility updatesmodel.args = {k: v 
        for k, v in args.items() if k in DEFAULT_CFG_KEYS}  # attach args to model
    model.pt_path = weight  # attach *.pt file path to model
    model.task = guess_model_task(model)
    if not hasattr(model, "stride"):
        model.stride = torch.tensor([32.0])

    model = model.fuse().eval() if fuse and hasattr(model, "fuse") else model.eval()  # model in eval mode

    # Module updates
    for m in model.modules():
        if hasattr(m, "inplace"):
            m.inplace = inplace
        elif isinstance(m, nn.Upsample) and not hasattr(m, "recompute_scale_factor"):
            m.recompute_scale_factor = None  # torch 1.11.0 compatibility

    # Return model and ckpt
    return model, ckpt


def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
    """Parse a YOLO model.yaml dictionary into a PyTorch model."""
    import ast

    # Args
    max_channels = float("inf")
    nc, act, scales = (d.get(x) for x in ("nc", "activation", "scales"))
    depth, width, kpt_shape = (d.get(x, 1.0) for x in ("depth_multiple", "width_multiple", "kpt_shape"))
    if scales:
        scale = d.get("scale")
        if not scale:
            scale = tuple(scales.keys())[0]
            LOGGER.warning(f"WARNING ⚠️ no model scale passed. Assuming scale='{scale}'.")
        depth, width, max_channels = scales[scale]

    if act:
        Conv.default_act = eval(act)  # redefine default activation, i.e. Conv.default_act = nn.SiLU()
        if verbose:
            LOGGER.info(f"{colorstr('activation:')} {act}")  # print

    if verbose:
        LOGGER.info(f"\n{'':>3}{'from':>20}{'n':>3}{'params':>10}  {'module':<45}{'arguments':<30}")
    ch = [ch]
    layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
    for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]):  # from, number, module, args
        m = getattr(torch.nn, m[3:]) if "nn." in m else globals()[m]  # get module
        for j, a in enumerate(args):
            if isinstance(a, str):
                with contextlib.suppress(ValueError):
                    args[j] = locals()[a] if a in locals() else ast.literal_eval(a)

        n = n_ = max(round(n * depth), 1) if n > 1 else n  # depth gain
        if m in {
            Classify, Conv, ConvTranspose, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, Focus,
            BottleneckCSP, C1, C2, C2f, RepNCSPELAN4, ADown, SPPELAN, C2fAttn, C3, C3TR, C3Ghost,
            nn.ConvTranspose2d, DWConvTranspose2d, C3x, RepC3, PSA, SCDown, C2fCIB, Down_wt,
        }:
            c1, c2 = ch[f], args[0]
            if c2 != nc:  # if c2 not equal to number of classes (i.e. for Classify() output)
                c2 = make_divisible(min(c2, max_channels) * width, 8)
            if m is C2fAttn:
                args[1] = make_divisible(min(args[1], max_channels // 2) * width, 8)  # embed channels
                args[2] = int(max(round(min(args[2], max_channels // 2 // 32)) * width, 1) if args[2] > 1 else args[2])  # num heads

            args = [c1, c2, *args[1:]]
            if m in (BottleneckCSP, C1, C2, C2f, C2fAttn, C3, C3TR, C3Ghost, C3x, RepC3, C2fCIB):
                args.insert(2, n)  # number of repeats
                n = 1
        elif m is AIFI:
            args = [ch[f], *args]
        elif m in {OREPA, EnhanceNetwork, ADNet, SEAM, MultiSEAM, AirNet}:
            args = [ch[f], *args]
        elif m in {HGStem, HGBlock}:
            c1, cm, c2 = ch[f], args[0], args[1]
            args = [c1, cm, c2, *args[2:]]
            if m is HGBlock:
                args.insert(4, n)  # number of repeats
                n = 1
        elif m is ResNetLayer:
            c2 = args[1] if args[3] else args[1] * 4
        elif m is nn.BatchNorm2d:
            args = [ch[f]]
        elif m is Concat:
            c2 = sum(ch[x] for x in f)
        elif m in {Detect, WorldDetect, Segment, Pose, OBB, ImagePoolingAttn, v10Detect}:
            args.append([ch[x] for x in f])
            if m is Segment:
                args[2] = make_divisible(min(args[2], max_channels) * width, 8)
        elif m is RTDETRDecoder:  # special case, channels arg must be passed in index 1
            args.insert(1, [ch[x] for x in f])
        elif m is CBLinear:
            c2 = args[0]
            c1 = ch[f]
            args = [c1, c2, *args[1:]]
        elif m is CBFuse:
            c2 = ch[f[-1]]
        else:
            c2 = ch[f]

        m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)  # module
        t = str(m)[8:-2].replace("__main__.", "")  # module type
        m.np = sum(x.numel() for x in m_.parameters())  # number params
        m_.i, m_.f, m_.type = i, f, t  # attach index, 'from' index, type
        if verbose:
            LOGGER.info(f"{i:>3}{str(f):>20}{n_:>3}{m.np:10.0f}  {t:<45}{str(args):<30}")  # print
        save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
        layers.append(m_)
        if i == 0:
            ch = []
        ch.append(c2)
    return nn.Sequential(*layers), sorted(save)


def yaml_model_load(path):
    """Load a YOLOv8 model from a YAML file."""
    import re

    path = Path(path)
    if path.stem in (f"yolov{d}{x}6" for x in "nsmlx" for d in (5, 8)):
        new_stem = re.sub(r"(\d+)([nslmx])6(.+)?$", r"\1\2-p6\3", path.stem)
        LOGGER.warning(f"WARNING ⚠️ Ultralytics YOLO P6 models now use -p6 suffix. Renaming {path.stem} to {new_stem}.")
        path = path.with_name(new_stem + path.suffix)

    if "v10" not in str(path):
        unified_path = re.sub(r"(\d+)([nsblmx])(.+)?$", r"\1\3", str(path))  # i.e. yolov8x.yaml -> yolov8.yaml
    else:
        unified_path = path
    yaml_file = check_yaml(unified_path, hard=False) or check_yaml(path)
    d = yaml_load(yaml_file)  # model dict
    d["scale"] = guess_model_scale(path)
    d["yaml_file"] = str(path)
    return d


def guess_model_scale(model_path):
    """
    Takes a path to a YOLO model's YAML file as input and extracts the size character of the model's scale. The function
    uses regular expression matching to find the pattern of the model scale in the YAML file name, which is denoted by
    n, s, m, l, or x. The function returns the size character of the model scale as a string.

    Args:
        model_path (str | Path): The path to the YOLO model's YAML file.

    Returns:
        (str): The size character of the model's scale, which can be n, s, m, l, or x.
    """
    with contextlib.suppress(AttributeError):
        import re

        return re.search(r"yolov\d+([nsblmx])", Path(model_path).stem).group(1)  # n, s, m, l, or x
    return ""


def guess_model_task(model):
    """
    Guess the task of a PyTorch model from its architecture or configuration.

    Args:
        model (nn.Module | dict): PyTorch model or model configuration in YAML format.

    Returns:
        (str): Task of the model ('detect', 'segment', 'classify', 'pose').

    Raises:
        SyntaxError: If the task of the model could not be determined.
    """

    def cfg2task(cfg):
        """Guess from YAML dictionary."""
        m = cfg["head"][-1][-2].lower()  # output module name
        if m in {"classify", "classifier", "cls", "fc"}:
            return "classify"
        if m == "detect" or m == "v10detect":
            return "detect"
        if m == "segment":
            return "segment"
        if m == "pose":
            return "pose"
        if m == "obb":
            return "obb"

    # Guess from model cfg
    if isinstance(model, dict):
        with contextlib.suppress(Exception):
            return cfg2task(model)

    # Guess from PyTorch model
    if isinstance(model, nn.Module):  # PyTorch model
        for x in "model.args", "model.model.args", "model.model.model.args":
            with contextlib.suppress(Exception):
                return eval(x)["task"]
        for x in "model.yaml", "model.model.yaml", "model.model.model.yaml":
            with contextlib.suppress(Exception):
                return cfg2task(eval(x))

        for m in model.modules():
            if isinstance(m, Segment):
                return "segment"
            elif isinstance(m, Classify):
                return "classify"
            elif isinstance(m, Pose):
                return "pose"
            elif isinstance(m, OBB):
                return "obb"
            elif isinstance(m, (Detect, WorldDetect, v10Detect)):
                return "detect"

    # Guess from model filename
    if isinstance(model, (str, Path)):
        model = Path(model)
        if "-seg" in model.stem or "segment" in model.parts:
            return "segment"
        elif "-cls" in model.stem or "classify" in model.parts:
            return "classify"
        elif "-pose" in model.stem or "pose" in model.parts:
            return "pose"
        elif "-obb" in model.stem or "obb" in model.parts:
            return "obb"
        elif "detect" in model.parts:
            return "detect"

    # Unable to determine task from model
    LOGGER.warning(
        "WARNING ⚠️ Unable to automatically guess model task, assuming 'task=detect'. "
        "Explicitly define task for your model, i.e. 'task=detect', 'segment', 'classify','pose' or 'obb'."
    )
    return "detect"  # assume detect

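The added `AirNet` branch in `parse_model` only prepends the input channel count to the module's arguments and leaves `c2` untouched. To make that bookkeeping easier to follow, here is a minimal standalone sketch. It is not the Ultralytics implementation: `resolve_channels`, the two string module names, and the two-layer `demo` list are all hypothetical, and width scaling via `make_divisible` is deliberately left out.

```python
def resolve_channels(layers, in_ch=3):
    """Toy model of the channel bookkeeping in parse_model (no width scaling).

    `layers` holds (from_index, module_name, args) tuples. Only two
    illustrative module kinds are handled:
      - "Conv":   args[0] is the output width, input width is prepended
      - "AirNet": input width is prepended, output width is left untouched
    Any other name simply passes its input width through.
    """
    ch = [in_ch]  # channel count produced by each layer so far
    c2 = ch[-1]   # current output width, starts at the image channels
    resolved = []
    for i, (f, name, args) in enumerate(layers):
        if name == "Conv":
            c1, c2 = ch[f], args[0]
            args = [c1, c2, *args[1:]]
        elif name == "AirNet":
            args = [ch[f], *args]  # c2 is not reassigned in this branch
        else:
            c2 = ch[f]
        resolved.append((name, args, c2))
        if i == 0:
            ch = []  # drop the placeholder entry for the raw input
        ch.append(c2)
    return resolved


# Hypothetical two-layer backbone: AirNet first, then a stride-2 Conv.
demo = [(-1, "AirNet", []), (-1, "Conv", [64, 3, 2])]
for name, args, out_ch in resolve_channels(demo):
    print(name, args, out_ch)
# AirNet [3] 3
# Conv [3, 64, 3, 2] 64
```

Because `c2` is not reassigned in that branch, the output width recorded for an AirNet layer is whatever the previous layer recorded, or the image channel count when AirNet comes first. The following layers therefore only get the right input width if the module really does return that many channels.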
