本文主要是介绍DBnet源码解析,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
文章目录
- 前言
- 1. YAML配置文件
- 2. 配置文件读取
- 3. 构造并初始化配置文件中的类
- 4. DBNet训练过程
- 4.1 从train.py开始。
- 4.2 Trainer类实现DBNet训练过程
- 4.2.1 训练与测试数据组织方式
- 4.2 DBNet训练过程详解
- 4.2.1 模型加载
- 4.2.2 数据集加载
- 4.2.2.1 训练数据增强
- 4.2.2.2 生成标签(label generation)
- 4.2.3 训练
- 4.2.3.1 训练模型forward
- 4.2.3.2 训练过程损失计算
- 4.3 DBNet推理过程详解
- 总结
前言
之前阅读DBNet论文时,发现很多地方讲的不是很清楚,比如为什么推理阶段速度比以前的方法快 - 没有使用启发式聚类方法根据二值图产生最终的文本框?希望通过阅读源码能够解惑,并且能够训练自己的文本检测数据集。
DBnet 源码github 地址: https://github.com/MhLiao/DB
这里分享下个人的论文阅读笔记:https://blog.csdn.net/DU_YULIN/article/details/118365298
1. YAML配置文件
DBNet的配置文件包括三个部分:
(1)import: 依赖的其它配置文件,如果没有则为空;
(2)package: 当前配置文件需要加载的python package, 如果没有则为空;
(3)define:当前配置文件中的配置项,如果没有则为空;
define中包括实现DBNet模型相应功能所用的类,类中包括的属性,以及属性的值,允许通过配置的方式来修改DBNet的模型架构,比如更换DBNet的backbone, 修改数据集加载方式,修改数据增强方式等。
DBnet的配置文件分为三类:
(1)“dataset_backbone_deform/large_thre.yaml”:
(1.1)import:依赖的数据集配置文件;
(1.2)define:DBNet模型相关配置项(structure), 训练相关配置项(train),验证相关配置项(validation),日志相关配置项(logger),评估相关配置项(evaluation),这些配置项都在类Experiment有定义,同时每一配置项都有相应的类的定义,具体如下列所示:
ic15_resnet18_deform_thre.yaml:
(2)“base_dataset.yaml”: 定义了DBNet中数据集相关的配置,包括依赖的其它配置,需要用到的python包,以及训练集与验证集相关配置:数据集地址,加载数据集的类定义,数据增强类,数据增强参数等;
base_ic15.yaml:
(3)“base.yaml”: 基本的配置文件,数据集配置依赖的配置文件,配置了DBNet需要加载的python包。
base.yaml:
这里详细介绍下配置文件中define与源码的关系(以ic15_resnet18_deform_thre.yaml为例):
上图展示了DBNet源码中UML类图(仅仅包含一部分类),与配置文件define中的"name Experiment"对应。配置文件中"class"对应源码中的类名,位于“class”下面且同一级缩进的其它定义一般对应类中的成员属性,比如位于“class:Experiment”下方且与之同属于一个缩进级别的定义有“structure”, “train”, “validation”, “logger”, “evaluation”,这几个定义在源码类“Experiment”中都有对应的成员与之对应(见上图类模块-Experiment)
配置文件中对应源码的类大部分继承自类-Configuable,类Configuable以元类方式继承StateMeta这一元类,元类StateMeta主要对继承Configuable的子类中State数据类型的成员进行过滤,即将子类中所有的State数据类型的成员映射到新定义的“states”成员属性中(dict类型),并修改原有State数据类型的成员类型:改变为State类型中default成员的类型。这里State类型类似c++ void 类型,经过配置文件的解析并构造对应类后将它转换为对应的实际数据类型。
关于元类的用法如果不清楚,大家可以在网上搜索下,这里主要是动态改变子类中成员的数据类型并添加新的成员states(dict类型)。所以如果大家看代码时发现类中用到了self.states成员却找不到定义和初始化的地方,这里可以解答你的疑惑。
2. 配置文件读取
源码中应用python 模块:anyconfig 来解读YAML配置文件,并按照dictory数据类型来存储YAML中的数据,应用munch模块将读取的dictory数据转换为Munch数据类型,这样在读取YAML配置文件中的数据时不用这样coding: conf[‘key’], 可以这样coding:conf.key。
源码中通过concern\config.py中定义的类Config来读取YAML配置文件,
这里简要介绍下Config类中每个函数的作用:
(1)load(): 读取配置文件内容存储到munch变量中并返回;
(2)compile(): 解析配置文件的内容,这里主要包括将配置文件中的import即依赖的配置文件进行加载并解析,通过python 加载配置文件中的package并保存,在加载的package中查找配置文件中的define中每一个“class”并替换为相对DBNet根目录的源码路径,比如class="Configuable"替换为class=“concern.Configuable”, concern为包含类Configuable的包名(目录名)。最后返回所有配置文件中package应用python加载后存储的结果list,以及所有配置文件define中“class”被修正后的存储结果dictory。
(3)compile_conf(): 配置文件中define是一个list,调用这个方法对list中每项进行解析,compile()调用这里函数来修正“class”的值。
(4)find_class_in_modules():compile_conf()调用这一函数确认define中“class”是否在python 加载的所有配置文件package中,如果在则用“模块名.类名”替换之前的类名,否则报错。
在train.py中调用Config类解析配置文件的代码如下:
conf = Config()experiment_args = conf.compile(conf.load(args['exp']))['Experiment']
3. 构造并初始化配置文件中的类
这里以Experiment类的构造为例讲解源码中是如何根据配置文件中的类的信息来构造相应的类。
类的构造主要通过concern\config.py中Configurable类实现:
源码中定义的类基本都继承Configurable类,这里主要介绍Configurable类如何根据配置文件中的类信息构造相应的类。
根据配置文件中类的信息构造类主要通过Configurable类中静态方法:
@staticmethoddef construct_class_from_config(args):cls = Configurable.extract_class_from_args(args)return cls(**args)
该静态方法的参数args为解析后的配置文件的数据(directory), 该方法调用Configurable类中另一个静态方法获取类并返回初始化后的类的对象。
@staticmethoddef extract_class_from_args(args):cls = args.copy().pop('class')package, cls = cls.rsplit('.', 1)module = importlib.import_module(package)cls = getattr(module, cls)return cls
上述静态方法很简单,先从配置文件参数中获取class的信息,即"模块名.类名”,然后字符串分割分别得到模块名和类名,接着根据模块名应用python加载模块,最后根据类名查找模块中的属性并返回模块中对应的类。
如果上述构造函数返回Experiment类,则返回语句retrun cls(**args)
会先执行Experiment类的对象初始化函数,然后才返回初始化后的Experiment对象。
class Experiment(Configurable):structure = State(autoload=False)train = State()validation = State(autoload=False)evaluation = State(autoload=False)logger = State(autoload=True)def __init__(self, **kwargs):self.load('structure', **kwargs)cmd = kwargs.get('cmd', {})if 'name' not in cmd:cmd['name'] = self.structure.model_nameself.load_all(**kwargs)self.distributed = cmd.get('distributed', False)self.local_rank = cmd.get('local_rank', 0)if cmd.get('validate', False):self.load('validation'
这篇关于DBnet源码解析的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!