YOLO的Anchor聚类代码

2024-05-01 15:32
文章标签 代码 yolo anchor 聚类

本文主要是介绍YOLO的Anchor聚类代码,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

代码来源于GiantPandaCV ,作者BBuf

 


#coding=utf-8import xml.etree.ElementTree as ET
import numpy as npdef iou(box, clusters):"""计算一个ground truth边界盒和k个先验框(Anchor)的交并比(IOU)值。参数box: 元组或者数据,代表ground truth的长宽。参数clusters: 形如(k,2)的numpy数组,其中k是聚类Anchor框的个数返回:ground truth和每个Anchor框的交并比。"""x = np.minimum(clusters[:, 0], box[0])y = np.minimum(clusters[:, 1], box[1])if np.count_nonzero(x == 0) > 0 or np.count_nonzero(y == 0) > 0:raise ValueError("Box has no area")intersection = x * ybox_area = box[0] * box[1]cluster_area = clusters[:, 0] * clusters[:, 1]iou_ = intersection / (box_area + cluster_area - intersection)return iou_def avg_iou(boxes, clusters):"""计算一个ground truth和k个Anchor的交并比的均值。"""return np.mean([np.max(iou(boxes[i], clusters)) for i in range(boxes.shape[0])])def kmeans(boxes, k, dist=np.median):"""利用IOU值进行K-means聚类参数boxes: 形状为(r, 2)的ground truth框,其中r是ground truth的个数参数k: Anchor的个数参数dist: 距离函数返回值:形状为(k, 2)的k个Anchor框"""# 即是上面提到的rrows = boxes.shape[0]# 距离数组,计算每个ground truth和k个Anchor的距离distances = np.empty((rows, k))# 上一次每个ground truth"距离"最近的Anchor索引last_clusters = np.zeros((rows,))# 设置随机数种子np.random.seed()# 初始化聚类中心,k个簇,从r个ground truth随机选k个clusters = boxes[np.random.choice(rows, k, replace=False)]# 开始聚类while True:# 计算每个ground truth和k个Anchor的距离,用1-IOU(box,anchor)来计算for row in range(rows):distances[row] = 1 - iou(boxes[row], clusters)# 对每个ground truth,选取距离最小的那个Anchor,并存下索引nearest_clusters = np.argmin(distances, axis=1)# 如果当前每个ground truth"距离"最近的Anchor索引和上一次一样,聚类结束if (last_clusters == nearest_clusters).all():break# 更新簇中心为簇里面所有的ground truth框的均值for cluster in range(k):clusters[cluster] = dist(boxes[nearest_clusters == cluster], axis=0)# 更新每个ground truth"距离"最近的Anchor索引last_clusters = nearest_clustersreturn clusters# 加载自己的数据集,只需要所有labelimg标注出来的xml文件即可
def load_dataset(path):dataset = []for xml_file in glob.glob("{}/*xml".format(path)):tree = ET.parse(xml_file)# 图片高度height = int(tree.findtext("./size/height"))# 图片宽度width = int(tree.findtext("./size/width"))for obj in tree.iter("object"):# 偏移量xmin = int(obj.findtext("bndbox/xmin")) / widthymin = int(obj.findtext("bndbox/ymin")) / heightxmax = int(obj.findtext("bndbox/xmax")) / widthymax = int(obj.findtext("bndbox/ymax")) / heightxmin = np.float64(xmin)ymin = np.float64(ymin)xmax = np.float64(xmax)ymax = np.float64(ymax)if xmax == xmin or ymax == ymin:print(xml_file)# 将Anchor的长宽放入dateset,运行kmeans获得Anchordataset.append([xmax - xmin, ymax - ymin])return np.array(dataset)if __name__ == '__main__':ANNOTATIONS_PATH = "F:\Annotations" #xml文件所在文件夹CLUSTERS = 9 #聚类数量,anchor数量INPUTDIM = 416 #输入网络大小data = load_dataset(ANNOTATIONS_PATH)out = kmeans(data, k=CLUSTERS)print('Boxes:')print(np.array(out)*INPUTDIM)print("Accuracy: {:.2f}%".format(avg_iou(data, out) * 100))final_anchors = np.around(out[:, 0] / out[:, 1], decimals=2).tolist()print("Before Sort Ratios:\n {}".format(final_anchors))print("After Sort Ratios:\n {}".format(sorted(final_anchors)))

 

这篇关于YOLO的Anchor聚类代码的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/952028

相关文章

SpringCloud集成AlloyDB的示例代码

《SpringCloud集成AlloyDB的示例代码》AlloyDB是GoogleCloud提供的一种高度可扩展、强性能的关系型数据库服务,它兼容PostgreSQL,并提供了更快的查询性能... 目录1.AlloyDBjavascript是什么?AlloyDB 的工作原理2.搭建测试环境3.代码工程1.

Java调用Python代码的几种方法小结

《Java调用Python代码的几种方法小结》Python语言有丰富的系统管理、数据处理、统计类软件包,因此从java应用中调用Python代码的需求很常见、实用,本文介绍几种方法从java调用Pyt... 目录引言Java core使用ProcessBuilder使用Java脚本引擎总结引言python

Java中ArrayList的8种浅拷贝方式示例代码

《Java中ArrayList的8种浅拷贝方式示例代码》:本文主要介绍Java中ArrayList的8种浅拷贝方式的相关资料,讲解了Java中ArrayList的浅拷贝概念,并详细分享了八种实现浅... 目录引言什么是浅拷贝?ArrayList 浅拷贝的重要性方法一:使用构造函数方法二:使用 addAll(

JAVA利用顺序表实现“杨辉三角”的思路及代码示例

《JAVA利用顺序表实现“杨辉三角”的思路及代码示例》杨辉三角形是中国古代数学的杰出研究成果之一,是我国北宋数学家贾宪于1050年首先发现并使用的,:本文主要介绍JAVA利用顺序表实现杨辉三角的思... 目录一:“杨辉三角”题目链接二:题解代码:三:题解思路:总结一:“杨辉三角”题目链接题目链接:点击这里

SpringBoot使用注解集成Redis缓存的示例代码

《SpringBoot使用注解集成Redis缓存的示例代码》:本文主要介绍在SpringBoot中使用注解集成Redis缓存的步骤,包括添加依赖、创建相关配置类、需要缓存数据的类(Tes... 目录一、创建 Caching 配置类二、创建需要缓存数据的类三、测试方法Spring Boot 熟悉后,集成一个外

轻松掌握python的dataclass让你的代码更简洁优雅

《轻松掌握python的dataclass让你的代码更简洁优雅》本文总结了几个我在使用Python的dataclass时常用的技巧,dataclass装饰器可以帮助我们简化数据类的定义过程,包括设置默... 目录1. 传统的类定义方式2. dataclass装饰器定义类2.1. 默认值2.2. 隐藏敏感信息

opencv实现像素统计的示例代码

《opencv实现像素统计的示例代码》本文介绍了OpenCV中统计图像像素信息的常用方法和函数,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一... 目录1. 统计像素值的基本信息2. 统计像素值的直方图3. 统计像素值的总和4. 统计非零像素的数量

IDEA常用插件之代码扫描SonarLint详解

《IDEA常用插件之代码扫描SonarLint详解》SonarLint是一款用于代码扫描的插件,可以帮助查找隐藏的bug,下载并安装插件后,右键点击项目并选择“Analyze”、“Analyzewit... 目录SonajavascriptrLint 查找隐藏的bug下载安装插件扫描代码查看结果总结Sona

Python开发围棋游戏的实例代码(实现全部功能)

《Python开发围棋游戏的实例代码(实现全部功能)》围棋是一种古老而复杂的策略棋类游戏,起源于中国,已有超过2500年的历史,本文介绍了如何用Python开发一个简单的围棋游戏,实例代码涵盖了游戏的... 目录1. 围棋游戏概述1.1 游戏规则1.2 游戏设计思路2. 环境准备3. 创建棋盘3.1 棋盘类

Java实现批量化操作Excel文件的示例代码

《Java实现批量化操作Excel文件的示例代码》在操作Excel的场景中,通常会有一些针对Excel的批量操作,这篇文章主要为大家详细介绍了如何使用GcExcel实现批量化操作Excel,感兴趣的可... 目录前言 | 问题背景什么是GcExcel场景1 批量导入Excel文件,并读取特定区域的数据场景2