YOLOv7输出层之间的热力图

2024-08-31 20:36
文章标签 输出 之间 力图 yolov7

本文主要是介绍YOLOv7输出层之间的热力图,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

我们经常看到一些论文里绘制了不同的热力图,来直观的感受其模型的有效性。特别是使用了注意力模块的网络,热力图就可以验证注意力机制是否真正聚焦到了预期的重要特征上,以便对模型的有效性和合理性进行评估。

例如Centralized Feature Pyramid for Object Detection这篇文章中展示的,就很能够表达作者改进后的模型相比之前模型的一个优越性。

在这里插入图片描述
本文就来记录一下如何使用python脚本来输出YOLOv7层之间的热力图。

添加步骤

1️⃣ 在本地的YOLOv7项目的根目录下新建heatmap.py,将以下代码复制到其中

import warnings
warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')
import torch, yaml, cv2, os, shutil
import torch.nn as nn
import numpy as np
np.random.seed(0)
import matplotlib.pyplot as plt
from tqdm import trange
from PIL import Image
from models.yolo import Model
from utils.torch_utils import intersect_dicts
from utils.datasets import letterbox
from utils.general import xywh2xyxy
from pytorch_grad_cam import GradCAMPlusPlus, GradCAM, XGradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.activations_and_gradients import ActivationsAndGradientsclass yolov7_heatmap:def __init__(self, weight, cfg, device, method, layer, backward_type, conf_threshold, ratio):device = torch.device(device)ckpt = torch.load(weight)model_names = ckpt['model'].namescsd = ckpt['model'].float().state_dict()  # checkpoint state_dict as FP32model = Model(cfg, ch=3, nc=len(model_names)).to(device)csd = intersect_dicts(csd, model.state_dict(), exclude=['anchor'])  # intersectmodel.load_state_dict(csd, strict=False)  # loadmodel.eval()print(f'Transferred {len(csd)}/{len(model.state_dict())} items')target_layers = [eval(layer)]method = eval(method)colors = np.random.uniform(0, 255, size=(len(model_names), 3)).astype(np.int)self.__dict__.update(locals())def post_process(self, result):boxes_ = result[0][..., :4]logits_ = []for data in result[1]:bs, n, w, h, _ = data.size()logits_.append(data.reshape((bs, n * w * h, _)))logits_ = torch.cat(logits_, dim=1)[..., 4:]sorted, indices = torch.sort(logits_[..., 0], descending=True)logits_ = logits_[0][indices[0]]logits_[:, 0] = torch.sigmoid(logits_[:, 0])return logits_, xywh2xyxy(boxes_[0][indices[0]]).cpu().detach().numpy()def draw_detections(self, box, color, name, img):xmin, ymin, xmax, ymax = list(map(int, list(box)))cv2.rectangle(img, (xmin, ymin), (xmax, ymax), tuple(int(x) for x in color), 2)cv2.putText(img, str(name), (xmin, ymin - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.8, tuple(int(x) for x in color), 2, lineType=cv2.LINE_AA)return imgdef __call__(self, img_path, save_path):# remove dir if existif os.path.exists(save_path):shutil.rmtree(save_path)# make dir if not existos.makedirs(save_path, exist_ok=True)# img processimg = cv2.imread(img_path)img = letterbox(img)[0]img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)img = np.float32(img) / 255.0tensor = torch.from_numpy(np.transpose(img, axes=[2, 0, 1])).unsqueeze(0).to(self.device)# init ActivationsAndGradientsgrads = ActivationsAndGradients(self.model, self.target_layers, reshape_transform=None)# get ActivationsAndResultresult = grads(tensor)activations = grads.activations[0].cpu().detach().numpy()# postprocess to yolo outputpost_result, post_boxes = self.post_process(result)for i in trange(int(post_result.size(0) * self.ratio)):if post_result[i][0] < self.conf_threshold:breakself.model.zero_grad()if self.backward_type == 'conf':post_result[i, 0].backward(retain_graph=True)else:# get max probability for this predictionscore = post_result[i, 1:].max()score.backward(retain_graph=True)# process heatmapgradients = grads.gradients[0]b, k, u, v = gradients.size()weights = self.method.get_cam_weights(self.method, None, None, None, activations, gradients.detach().numpy())weights = weights.reshape((b, k, 1, 1))saliency_map = np.sum(weights * activations, axis=1)saliency_map = np.squeeze(np.maximum(saliency_map, 0))saliency_map = cv2.resize(saliency_map, (tensor.size(3), tensor.size(2)))saliency_map_min, saliency_map_max = saliency_map.min(), saliency_map.max()if (saliency_map_max - saliency_map_min) == 0:continuesaliency_map = (saliency_map - saliency_map_min) / (saliency_map_max - saliency_map_min)# add heatmap and box to imagecam_image = show_cam_on_image(img.copy(), saliency_map, use_rgb=True)#cam_image = self.draw_detections(post_boxes[i], self.colors[int(post_result[i, 1:].argmax())], f'{self.model_names[int(post_result[i, 1:].argmax())]} {post_result[i][0]:.2f}', cam_image)cam_image = Image.fromarray(cam_image)cam_image.save(f'{save_path}/{i}.png')def get_params():params = {'weight': 'runs/train/exp/weights/best.pt',  'cfg': 'cfg/training/yolov7_test.yaml','device': 'cuda:0','method': 'GradCAM', # GradCAMPlusPlus, GradCAM, XGradCAM'layer': 'model.model[-2]',  'backward_type': 'class', # class or conf'conf_threshold': 0.6, # 0.6'ratio': 0.02 # 0.02-0.1}return paramsif __name__ == '__main__':model = yolov7_heatmap(**get_params())model('inference/heat_image/001.jpg', 'heat_result')

2️⃣ 修改配置参数

文件中的主要参数配置如下:

在这里插入图片描述

参数解释
weight权重路径,训练完成后的权重文件
cfg模型文件路径,与权重所训练出来的模型文件一致
device运行的设备,和模型训练时的device参数设置一致
method可选择GradCAM,GradCAMPlusPlus和XGradCAM ,可以都试试,效果不同
layer想要输出第几层的热力图就写几,我这里写的的-2,即倒数第二层,可以多换换,看看效果
backward_type反向传播的计算类型,class表示按照类别最大概率进行计算 或 conf 通过置信度计算梯度
conf_threshold置信度阈值,设置成0.6
ratio取前多少数据,设置成0.02

在这里插入图片描述

箭头指向的数据就是行号。

3️⃣ 数据源

在这里插入图片描述
model('inference/heat_image/001.jpg', 'heat_result')中:

第一个参数inference/heat_image/001.jpg表示想要进行热力图绘制的原图像路径。

第二个参数'heat_result'表示绘制完成后输出的文件夹路径。

4️⃣ 调试

在这里插入图片描述
此时就已经绘制完成了,在指定的文件夹下就已经输出了热力图了。进度条还没有满就停止,是因为后面的目标已经不满足置信度conf_threshold的设定值。

这个进度条的长度151是之前设定的参数ratio的结果,其只会选择前0.02的目标进行热力图可视化。

博客参考链接
代码参考链接

这篇关于YOLOv7输出层之间的热力图的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1124930

相关文章

Java中Date、LocalDate、LocalDateTime、LocalTime、时间戳之间的相互转换代码

《Java中Date、LocalDate、LocalDateTime、LocalTime、时间戳之间的相互转换代码》:本文主要介绍Java中日期时间转换的多种方法,包括将Date转换为LocalD... 目录一、Date转LocalDateTime二、Date转LocalDate三、LocalDateTim

golang获取当前时间、时间戳和时间字符串及它们之间的相互转换方法

《golang获取当前时间、时间戳和时间字符串及它们之间的相互转换方法》:本文主要介绍golang获取当前时间、时间戳和时间字符串及它们之间的相互转换,本文通过实例代码给大家介绍的非常详细,感兴趣... 目录1、获取当前时间2、获取当前时间戳3、获取当前时间的字符串格式4、它们之间的相互转化上篇文章给大家介

Vue中组件之间传值的六种方式(完整版)

《Vue中组件之间传值的六种方式(完整版)》组件是vue.js最强大的功能之一,而组件实例的作用域是相互独立的,这就意味着不同组件之间的数据无法相互引用,针对不同的使用场景,如何选择行之有效的通信方式... 目录前言方法一、props/$emit1.父组件向子组件传值2.子组件向父组件传值(通过事件形式)方

Python实现PDF与多种图片格式之间互转(PNG, JPG, BMP, EMF, SVG)

《Python实现PDF与多种图片格式之间互转(PNG,JPG,BMP,EMF,SVG)》PDF和图片是我们日常生活和工作中常用的文件格式,有时候,我们可能需要将PDF和图片进行格式互转来满足... 目录一、介绍二、安装python库三、Python实现多种图片格式转PDF1、单张图片转换为PDF2、多张图

python多种数据类型输出为Excel文件

《python多种数据类型输出为Excel文件》本文主要介绍了将Python中的列表、元组、字典和集合等数据类型输出到Excel文件中,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参... 目录一.列表List二.字典dict三.集合set四.元组tuplepython中的列表、元组、字典

Spring AI集成DeepSeek实现流式输出的操作方法

《SpringAI集成DeepSeek实现流式输出的操作方法》本文介绍了如何在SpringBoot中使用Sse(Server-SentEvents)技术实现流式输出,后端使用SpringMVC中的S... 目录一、后端代码二、前端代码三、运行项目小天有话说题外话参考资料前面一篇文章我们实现了《Spring

Java对象和JSON字符串之间的转换方法(全网最清晰)

《Java对象和JSON字符串之间的转换方法(全网最清晰)》:本文主要介绍如何在Java中使用Jackson库将对象转换为JSON字符串,并提供了一个简单的工具类示例,该工具类支持基本的转换功能,... 目录前言1. 引入 Jackson 依赖2. 创建 jsON 工具类3. 使用示例转换 Java 对象为

Rust格式化输出方式总结

《Rust格式化输出方式总结》Rust提供了强大的格式化输出功能,通过std::fmt模块和相关的宏来实现,主要的输出宏包括println!和format!,它们支持多种格式化占位符,如{}、{:?}... 目录Rust格式化输出方式基本的格式化输出格式化占位符Format 特性总结Rust格式化输出方式

java父子线程之间实现共享传递数据

《java父子线程之间实现共享传递数据》本文介绍了Java中父子线程间共享传递数据的几种方法,包括ThreadLocal变量、并发集合和内存队列或消息队列,并提醒注意并发安全问题... 目录通过 ThreadLocal 变量共享数据通过并发集合共享数据通过内存队列或消息队列共享数据注意并发安全问题总结在 J

Java文件与Base64之间的转化方式

《Java文件与Base64之间的转化方式》这篇文章介绍了如何使用Java将文件(如图片、视频)转换为Base64编码,以及如何将Base64编码转换回文件,通过提供具体的工具类实现,作者希望帮助读者... 目录Java文件与Base64之间的转化1、文件转Base64工具类2、Base64转文件工具类3、