DBNet详解及训练ICDAR2015数据集

2024-02-05 19:28

本文主要是介绍DBNet详解及训练ICDAR2015数据集,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

论文地址:https://arxiv.org/pdf/1911.08947.pdf

开源代码pytorch版本:GitHub - WenmuZhou/DBNet.pytorch: A pytorch re-implementation of Real-time Scene Text Detection with Differentiable Binarization

前言

在这篇论文之前,文字检测算法主要分为两类:基于回归的方法和基于分割的方法。基于分割的方法通常涉及以下流程,如下图蓝色箭头所示:首先,通过网络输出图像的文本分割结果,即概率图,其中每个像素表示是否属于正样本的概率。然后,通过使用预设的阈值将分割结果图转换为二值图。最后,通过一些聚合操作,例如连通域分析,将像素级的结果转换为最终的文本检测结果。然而,由于涉及使用阈值来判定前景和背景的不可微分操作,因此这一部分流程无法被直接放入网络中进行训练。所以本文引入了一种新的方法。具体而言,通过学习阈值映射(threshmap)并采用可微分的操作,将阈值的转换过程嵌入到网络中进行训练。这一创新的流程如下图中红色箭头所示,通过可微分的操作来处理阈值的学习,使得整个流程可以在神经网络的训练中进行端到端的优化。通过这种方式,文本检测模型能够自适应地学习阈值,更有效地捕捉文本的分割信息,提高了检测性能。这一方法有助于简化原有基于分割方法的后处理流程,同时使整个模型更具可训练性。

网络结构

其实从下图的网络结构中不难看出,相比较于PSENet,多了一条threshold map分支罢了,该分支的主要目的是和分割图联合得到更接近二值化的二值图,属于辅助分支。

整个网络结构流程:

图像输入特征提取主干: 使用图像输入,经过一个特征提取的主干网络,该网络负责从输入图像中提取高层次的语义特征。这可以是一个卷积神经网络(CNN)的主要部分,如ResNet或其他先进的架构。

特征金字塔上采样和级联: 从特征提取主干获得的特征被送入特征金字塔。在特征金字塔中,通过上采样将不同尺寸的特征图调整到相同的尺寸,并将它们级联在一起,形成一个具有丰富多尺度信息的特征F。这有助于模型对不同大小和尺度的目标进行有效的检测和分割。

预测概率图和阈值图: 利用级联的特征F,进行概率图(probability map P)和阈值图(threshold map T)的预测。概率图通常表示每个像素属于某个类别(在这里可能是目标文本与非文本的概率),而阈值图则用于指导后续的二值化操作。这一步的目的是产生用于后续计算的中间结果。

计算近似二值图: 利用概率图P和阈值图T,通过一定的计算过程(可能是使用阈值或其他运算),得到一个近似的二值图B。这个近似二值图用于最终的文本检测,其中文本区域被二值化为前景,而非文本区域为背景。

在训练过程中,该模型通过使用相同的监督信号对概率图 P 和近似二值图 B 进行监督训练,其中概率图表示文本区域的概率,而近似二值图是文本二值化结果。在推理阶段,只需使用概率图 P 或者近似二值图B 中的任一即可获取文本检测结果,无需依赖额外的阈值图。这种设计简化了推理流程,提高了模型的实际应用效率。

模型的输出

Probability Map(概率图): 这是一个大小为w×h×1 的张量,其中 w 和 ℎ分别表示图像的宽度和高度。概率图的每个像素表示相应位置是否为文本的概率。对于二进制文本检测任务,概率图的值通常在 0 到 1 之间,表示每个像素点属于文本的概率,1 表示高置信度是文本,0 表示低置信度是文本。

Threshold Map(阈值图): 阈值图也是一个大小为 w×h×1 的张量,其中每个像素点包含一个阈值。这些阈值用于二值化概率图,将其转换为最终的二值图。阈值图的每个值表示相应位置的二值化操作的阈值。

Binary Map(二值图): 由概率图和阈值图计算得到,也是一个大小为 w×h×1 的张量。它表示最终的文本检测结果,其中每个像素点被二值化为前景(文本)或背景(非文本)。这里提到使用了 "DB 公式" 来计算二值图,而 DB(Differentiable Binarization)通常是一个近似二值化的函数,通过可微分的操作来实现对阈值的学习和调整。

DB公式

标准二值化

一般使用分割网络(segmentation network)产生的概率图(probability map P),将P转化为一个二值图P,当像素为1的时候,认定其为有效的文本区域。i和j代表了坐标点的坐标,t是预定义的阈值

可微二值化(differentiable Binarization)

可微二值化的公式如下,其实就是带一个系数的sigmoid,其中其中T是阈值图,k取50

从图像上不难看出,二值化和标准二值化很相似,且可微分,因此可以和分割网络一起联合优化

从(b)(c)图我们不难看出通过增加参数 K,可以在模型的训练过程中加速对正确预测区域和错误预测区域的学习,以更快地收敛到最优解。这样的调整可以在某些情况下提高模型的训练效率和性能。原图,gt图,threshold map图如下所示

模型训练

自动下载的预训练模型下载地址:/home/xuzhen/.cache/torch/hub/checkpoints/resnet18-5c106cde.pth(我看了代码,他是判断有没有预训练模型没有的话才下载)

这个源代码在配置文件中加载的是train.txt和test.txt,所以我写了一个脚本,根据img文件夹和gt文件夹自动生成这两个文件的脚本

import osdef create_gt_file(img_dir, gt_dir, output_file_path):# 检查文件夹是否存在if not os.path.exists(img_dir) or not os.path.exists(gt_dir):print("Error: One or both folders do not exist.")returnimg_paths = []  gt_paths = []   # 循环读取文件夹1中的文件名for filename in os.listdir(img_dir):img_path = os.path.join(img_dir,filename)img_paths.append(img_path)# 去掉后缀并在前面加上 "gt_"gt_path = os.path.join(gt_dir, "gt_" + os.path.splitext(filename)[0] + ".txt")gt_paths.append(gt_path)# 写入文件with open(output_file_path, 'w') as output_file:# 将 img_paths 和 gt_paths 写入文件for img_path, gt_path in zip(img_paths, gt_paths):output_file.write(f"{img_path}\t{gt_path}\n")print(f"{img_path}\t{gt_path}Strings written to {output_file_path}")# 主函数
def main():img_dir = "/data2/xuzhen8/yzh/datasets/ICDAR2015/test_images"gt_dir = "/data2/xuzhen8/yzh/datasets/ICDAR2015/testing_localization_transcription_gt"output_file_path = "/data2/xuzhen8/yzh/projects/DBNet.pytorch/datasets/test.txt"create_gt_file(img_dir, gt_dir, output_file_path)if __name__ == "__main__":main()

每一轮训练都会打印信息,我想对这个打印信息说明一下,以便后面复习

FPS(Frames Per Second): 99.37

表示每秒处理的图像帧数。在这个上下文中,表示模型在测试阶段的推断速度。这是通过测量模型在测试集上处理图像的速度来得到的,其单位是帧数/秒。

test: recall: 0.031477, precision: 0.596330, f1: 0.059798

提供了模型在测试集上的性能指标。在这里,包括了召回率(recall)、精确度(precision)和 F1 分数(f1-score)。这些指标用于衡量模型在检测任务中的性能,其中:

召回率表示正确检测到的正类别样本占所有实际正类别样本的比例。

精确度表示模型正确检测的正类别样本占所有模型检测为正类别的样本的比例。

F1 分数是召回率和精确度的调和平均数,综合考虑了这两个指标。

current best, recall: 0.101695, precision: 0.726644, hmean: 0.178420, train_loss: 1.706732, best_model_epoch: 5.000000

提供了模型在测试集上的当前最佳性能以及训练期间的一些指标。其中:

recall、precision、hmean 是测试集上的召回率、精确度和 F1 分数。

train_loss 表示模型在训练集上的损失值,用于衡量训练过程中模型的拟合情况。

best_model_epoch 表示在训练过程中取得最佳性能的模型所对应的训练轮次。

Saving checkpoint:DBNet.pytorch/output/DBNet_resnet18_FPN_DBHead/checkpoint/

model_latest.pth

表示当前训练轮次的模型参数被保存到了指定路径下的 model_latest.pth 文件中。这通常发生在模型在测试集上取得了更好性能后,保存了当前状态的模型参数,以备将来使用或继续训练。

 

小辉问:能不能举个例子说明一下召回率、精确度、F1 分数。以便更好的理解

小G答:假设有一个二分类任务,目标是检测患有某种疾病的患者。我们的模型对每个样本都进行预测,可以分为以下四种情况:

True Positive (TP): 模型正确地预测了患有疾病的患者。

True Negative (TN): 模型正确地预测了没有患疾病的健康人。

False Positive (FP): 模型错误地预测了没有患疾病的健康人为患病。

False Negative (FN): 模型错误地预测了患有疾病的患者为健康人。

现在,我们可以使用这些概念来解释这些指标:

召回率(Recall):

召回率衡量了模型在所有实际患有疾病的样本中,有多少被成功地检测到。计算公式:

例如,如果总共有 100 名患有疾病的患者,而模型成功地检测到其中的 80 人,则召回率为 80/80+20 =0.8 或 80%。

精确度(Precision):

精确度衡量了模型在所有预测为患有疾病的样本中,有多少实际上是真正患有疾病的人。计算公式:

例如,如果模型预测了 90 个人患有疾病,而其中有 80 人确实是患有疾病的,则精确度为 80/80+10=0.888 或 88.8%。

F1 分数:

F1 分数是召回率和精确度的调和平均数,它综合考虑了两者的性能。计算公式:

F1 分数的取值范围在 [0,1],越接近 1 表示模型在召回率和精确度之间取得了更好的平衡。

这篇关于DBNet详解及训练ICDAR2015数据集的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/681916

相关文章

Debezium 与 Apache Kafka 的集成方式步骤详解

《Debezium与ApacheKafka的集成方式步骤详解》本文详细介绍了如何将Debezium与ApacheKafka集成,包括集成概述、步骤、注意事项等,通过KafkaConnect,D... 目录一、集成概述二、集成步骤1. 准备 Kafka 环境2. 配置 Kafka Connect3. 安装 D

Java中ArrayList和LinkedList有什么区别举例详解

《Java中ArrayList和LinkedList有什么区别举例详解》:本文主要介绍Java中ArrayList和LinkedList区别的相关资料,包括数据结构特性、核心操作性能、内存与GC影... 目录一、底层数据结构二、核心操作性能对比三、内存与 GC 影响四、扩容机制五、线程安全与并发方案六、工程

Spring Cloud LoadBalancer 负载均衡详解

《SpringCloudLoadBalancer负载均衡详解》本文介绍了如何在SpringCloud中使用SpringCloudLoadBalancer实现客户端负载均衡,并详细讲解了轮询策略和... 目录1. 在 idea 上运行多个服务2. 问题引入3. 负载均衡4. Spring Cloud Load

Springboot中分析SQL性能的两种方式详解

《Springboot中分析SQL性能的两种方式详解》文章介绍了SQL性能分析的两种方式:MyBatis-Plus性能分析插件和p6spy框架,MyBatis-Plus插件配置简单,适用于开发和测试环... 目录SQL性能分析的两种方式:功能介绍实现方式:实现步骤:SQL性能分析的两种方式:功能介绍记录

在 Spring Boot 中使用 @Autowired和 @Bean注解的示例详解

《在SpringBoot中使用@Autowired和@Bean注解的示例详解》本文通过一个示例演示了如何在SpringBoot中使用@Autowired和@Bean注解进行依赖注入和Bean... 目录在 Spring Boot 中使用 @Autowired 和 @Bean 注解示例背景1. 定义 Stud

如何通过海康威视设备网络SDK进行Java二次开发摄像头车牌识别详解

《如何通过海康威视设备网络SDK进行Java二次开发摄像头车牌识别详解》:本文主要介绍如何通过海康威视设备网络SDK进行Java二次开发摄像头车牌识别的相关资料,描述了如何使用海康威视设备网络SD... 目录前言开发流程问题和解决方案dll库加载不到的问题老旧版本sdk不兼容的问题关键实现流程总结前言作为

SQL 中多表查询的常见连接方式详解

《SQL中多表查询的常见连接方式详解》本文介绍SQL中多表查询的常见连接方式,包括内连接(INNERJOIN)、左连接(LEFTJOIN)、右连接(RIGHTJOIN)、全外连接(FULLOUTER... 目录一、连接类型图表(ASCII 形式)二、前置代码(创建示例表)三、连接方式代码示例1. 内连接(I

Go路由注册方法详解

《Go路由注册方法详解》Go语言中,http.NewServeMux()和http.HandleFunc()是两种不同的路由注册方式,前者创建独立的ServeMux实例,适合模块化和分层路由,灵活性高... 目录Go路由注册方法1. 路由注册的方式2. 路由器的独立性3. 灵活性4. 启动服务器的方式5.

Java中八大包装类举例详解(通俗易懂)

《Java中八大包装类举例详解(通俗易懂)》:本文主要介绍Java中的包装类,包括它们的作用、特点、用途以及如何进行装箱和拆箱,包装类还提供了许多实用方法,如转换、获取基本类型值、比较和类型检测,... 目录一、包装类(Wrapper Class)1、简要介绍2、包装类特点3、包装类用途二、装箱和拆箱1、装

Go语言中三种容器类型的数据结构详解

《Go语言中三种容器类型的数据结构详解》在Go语言中,有三种主要的容器类型用于存储和操作集合数据:本文主要介绍三者的使用与区别,感兴趣的小伙伴可以跟随小编一起学习一下... 目录基本概念1. 数组(Array)2. 切片(Slice)3. 映射(Map)对比总结注意事项基本概念在 Go 语言中,有三种主要