YOLOv7输出层之间的热力图

2024-08-31 20:36
文章标签 输出 之间 力图 yolov7

本文主要是介绍YOLOv7输出层之间的热力图,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

我们经常看到一些论文里绘制了不同的热力图,来直观的感受其模型的有效性。特别是使用了注意力模块的网络,热力图就可以验证注意力机制是否真正聚焦到了预期的重要特征上,以便对模型的有效性和合理性进行评估。

例如Centralized Feature Pyramid for Object Detection这篇文章中展示的,就很能够表达作者改进后的模型相比之前模型的一个优越性。

在这里插入图片描述
本文就来记录一下如何使用python脚本来输出YOLOv7层之间的热力图。

添加步骤

1️⃣ 在本地的YOLOv7项目的根目录下新建heatmap.py,将以下代码复制到其中

import warnings
warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')
import torch, yaml, cv2, os, shutil
import torch.nn as nn
import numpy as np
np.random.seed(0)
import matplotlib.pyplot as plt
from tqdm import trange
from PIL import Image
from models.yolo import Model
from utils.torch_utils import intersect_dicts
from utils.datasets import letterbox
from utils.general import xywh2xyxy
from pytorch_grad_cam import GradCAMPlusPlus, GradCAM, XGradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.activations_and_gradients import ActivationsAndGradientsclass yolov7_heatmap:def __init__(self, weight, cfg, device, method, layer, backward_type, conf_threshold, ratio):device = torch.device(device)ckpt = torch.load(weight)model_names = ckpt['model'].namescsd = ckpt['model'].float().state_dict()  # checkpoint state_dict as FP32model = Model(cfg, ch=3, nc=len(model_names)).to(device)csd = intersect_dicts(csd, model.state_dict(), exclude=['anchor'])  # intersectmodel.load_state_dict(csd, strict=False)  # loadmodel.eval()print(f'Transferred {len(csd)}/{len(model.state_dict())} items')target_layers = [eval(layer)]method = eval(method)colors = np.random.uniform(0, 255, size=(len(model_names), 3)).astype(np.int)self.__dict__.update(locals())def post_process(self, result):boxes_ = result[0][..., :4]logits_ = []for data in result[1]:bs, n, w, h, _ = data.size()logits_.append(data.reshape((bs, n * w * h, _)))logits_ = torch.cat(logits_, dim=1)[..., 4:]sorted, indices = torch.sort(logits_[..., 0], descending=True)logits_ = logits_[0][indices[0]]logits_[:, 0] = torch.sigmoid(logits_[:, 0])return logits_, xywh2xyxy(boxes_[0][indices[0]]).cpu().detach().numpy()def draw_detections(self, box, color, name, img):xmin, ymin, xmax, ymax = list(map(int, list(box)))cv2.rectangle(img, (xmin, ymin), (xmax, ymax), tuple(int(x) for x in color), 2)cv2.putText(img, str(name), (xmin, ymin - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.8, tuple(int(x) for x in color), 2, lineType=cv2.LINE_AA)return imgdef __call__(self, img_path, save_path):# remove dir if existif os.path.exists(save_path):shutil.rmtree(save_path)# make dir if not existos.makedirs(save_path, exist_ok=True)# img processimg = cv2.imread(img_path)img = letterbox(img)[0]img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)img = np.float32(img) / 255.0tensor = torch.from_numpy(np.transpose(img, axes=[2, 0, 1])).unsqueeze(0).to(self.device)# init ActivationsAndGradientsgrads = ActivationsAndGradients(self.model, self.target_layers, reshape_transform=None)# get ActivationsAndResultresult = grads(tensor)activations = grads.activations[0].cpu().detach().numpy()# postprocess to yolo outputpost_result, post_boxes = self.post_process(result)for i in trange(int(post_result.size(0) * self.ratio)):if post_result[i][0] < self.conf_threshold:breakself.model.zero_grad()if self.backward_type == 'conf':post_result[i, 0].backward(retain_graph=True)else:# get max probability for this predictionscore = post_result[i, 1:].max()score.backward(retain_graph=True)# process heatmapgradients = grads.gradients[0]b, k, u, v = gradients.size()weights = self.method.get_cam_weights(self.method, None, None, None, activations, gradients.detach().numpy())weights = weights.reshape((b, k, 1, 1))saliency_map = np.sum(weights * activations, axis=1)saliency_map = np.squeeze(np.maximum(saliency_map, 0))saliency_map = cv2.resize(saliency_map, (tensor.size(3), tensor.size(2)))saliency_map_min, saliency_map_max = saliency_map.min(), saliency_map.max()if (saliency_map_max - saliency_map_min) == 0:continuesaliency_map = (saliency_map - saliency_map_min) / (saliency_map_max - saliency_map_min)# add heatmap and box to imagecam_image = show_cam_on_image(img.copy(), saliency_map, use_rgb=True)#cam_image = self.draw_detections(post_boxes[i], self.colors[int(post_result[i, 1:].argmax())], f'{self.model_names[int(post_result[i, 1:].argmax())]} {post_result[i][0]:.2f}', cam_image)cam_image = Image.fromarray(cam_image)cam_image.save(f'{save_path}/{i}.png')def get_params():params = {'weight': 'runs/train/exp/weights/best.pt',  'cfg': 'cfg/training/yolov7_test.yaml','device': 'cuda:0','method': 'GradCAM', # GradCAMPlusPlus, GradCAM, XGradCAM'layer': 'model.model[-2]',  'backward_type': 'class', # class or conf'conf_threshold': 0.6, # 0.6'ratio': 0.02 # 0.02-0.1}return paramsif __name__ == '__main__':model = yolov7_heatmap(**get_params())model('inference/heat_image/001.jpg', 'heat_result')

2️⃣ 修改配置参数

文件中的主要参数配置如下:

在这里插入图片描述

参数解释
weight权重路径,训练完成后的权重文件
cfg模型文件路径,与权重所训练出来的模型文件一致
device运行的设备,和模型训练时的device参数设置一致
method可选择GradCAM,GradCAMPlusPlus和XGradCAM ,可以都试试,效果不同
layer想要输出第几层的热力图就写几,我这里写的的-2,即倒数第二层,可以多换换,看看效果
backward_type反向传播的计算类型,class表示按照类别最大概率进行计算 或 conf 通过置信度计算梯度
conf_threshold置信度阈值,设置成0.6
ratio取前多少数据,设置成0.02

在这里插入图片描述

箭头指向的数据就是行号。

3️⃣ 数据源

在这里插入图片描述
model('inference/heat_image/001.jpg', 'heat_result')中:

第一个参数inference/heat_image/001.jpg表示想要进行热力图绘制的原图像路径。

第二个参数'heat_result'表示绘制完成后输出的文件夹路径。

4️⃣ 调试

在这里插入图片描述
此时就已经绘制完成了,在指定的文件夹下就已经输出了热力图了。进度条还没有满就停止,是因为后面的目标已经不满足置信度conf_threshold的设定值。

这个进度条的长度151是之前设定的参数ratio的结果,其只会选择前0.02的目标进行热力图可视化。

博客参考链接
代码参考链接

这篇关于YOLOv7输出层之间的热力图的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!


原文地址:
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.chinasem.cn/article/1124930

相关文章

python多种数据类型输出为Excel文件

《python多种数据类型输出为Excel文件》本文主要介绍了将Python中的列表、元组、字典和集合等数据类型输出到Excel文件中,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参... 目录一.列表List二.字典dict三.集合set四.元组tuplepython中的列表、元组、字典

Spring AI集成DeepSeek实现流式输出的操作方法

《SpringAI集成DeepSeek实现流式输出的操作方法》本文介绍了如何在SpringBoot中使用Sse(Server-SentEvents)技术实现流式输出,后端使用SpringMVC中的S... 目录一、后端代码二、前端代码三、运行项目小天有话说题外话参考资料前面一篇文章我们实现了《Spring

Java对象和JSON字符串之间的转换方法(全网最清晰)

《Java对象和JSON字符串之间的转换方法(全网最清晰)》:本文主要介绍如何在Java中使用Jackson库将对象转换为JSON字符串,并提供了一个简单的工具类示例,该工具类支持基本的转换功能,... 目录前言1. 引入 Jackson 依赖2. 创建 jsON 工具类3. 使用示例转换 Java 对象为

Rust格式化输出方式总结

《Rust格式化输出方式总结》Rust提供了强大的格式化输出功能,通过std::fmt模块和相关的宏来实现,主要的输出宏包括println!和format!,它们支持多种格式化占位符,如{}、{:?}... 目录Rust格式化输出方式基本的格式化输出格式化占位符Format 特性总结Rust格式化输出方式

java父子线程之间实现共享传递数据

《java父子线程之间实现共享传递数据》本文介绍了Java中父子线程间共享传递数据的几种方法,包括ThreadLocal变量、并发集合和内存队列或消息队列,并提醒注意并发安全问题... 目录通过 ThreadLocal 变量共享数据通过并发集合共享数据通过内存队列或消息队列共享数据注意并发安全问题总结在 J

Java文件与Base64之间的转化方式

《Java文件与Base64之间的转化方式》这篇文章介绍了如何使用Java将文件(如图片、视频)转换为Base64编码,以及如何将Base64编码转换回文件,通过提供具体的工具类实现,作者希望帮助读者... 目录Java文件与Base64之间的转化1、文件转Base64工具类2、Base64转文件工具类3、

使用TomCat,service输出台出现乱码的解决

《使用TomCat,service输出台出现乱码的解决》本文介绍了解决Tomcat服务输出台中文乱码问题的两种方法,第一种方法是修改`logging.properties`文件中的`prefix`和`... 目录使用TomCat,service输出台出现乱码问题1解决方案问题2解决方案总结使用TomCat,

C++中实现调试日志输出

《C++中实现调试日志输出》在C++编程中,调试日志对于定位问题和优化代码至关重要,本文将介绍几种常用的调试日志输出方法,并教你如何在日志中添加时间戳,希望对大家有所帮助... 目录1. 使用 #ifdef _DEBUG 宏2. 加入时间戳:精确到毫秒3.Windows 和 MFC 中的调试日志方法MFC

Python使用Colorama库美化终端输出的操作示例

《Python使用Colorama库美化终端输出的操作示例》在开发命令行工具或调试程序时,我们可能会希望通过颜色来区分重要信息,比如警告、错误、提示等,而Colorama是一个简单易用的Python库... 目录python Colorama 库详解:终端输出美化的神器1. Colorama 是什么?2.

day-51 合并零之间的节点

思路 直接遍历链表即可,遇到val=0跳过,val非零则加在一起,最后返回即可 解题过程 返回链表可以有头结点,方便插入,返回head.next Code /*** Definition for singly-linked list.* public class ListNode {* int val;* ListNode next;* ListNode() {}*