PaddleOCR识别框架解读[04] 文本检测det模型构建

2024-03-07 11:12

本文主要是介绍PaddleOCR识别框架解读[04] 文本检测det模型构建,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文章目录

    • det_mv3_db.yml
    • build_model函数
      • base_model类
    • build_backbone函数
      • MobileNetV3
    • build_neck函数
    • build_head函数

det_mv3_db.yml

Global:use_gpu: trueuse_xpu: falseepoch_num: 1200log_smooth_window: 20print_batch_step: 10save_model_dir: ./output/db_mv3/save_epoch_step: 1200# evaluation is run every 2000 iterationseval_batch_step: [0, 2000]cal_metric_during_train: Falsepretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrainedcheckpoints:save_inference_dir:use_visualdl: Falseinfer_img: doc/imgs_en/img_10.jpgsave_res_path: ./output/det_db/predicts_db.txtArchitecture:model_type: detalgorithm: DBTransform:Backbone:name: MobileNetV3scale: 0.5model_name: largeNeck:name: DBFPNout_channels: 256Head:name: DBHeadk: 50Loss:name: DBLossbalance_loss: truemain_loss_type: DiceLossalpha: 5beta: 10ohem_ratio: 3Optimizer:name: Adambeta1: 0.9beta2: 0.999lr:learning_rate: 0.001regularizer:name: 'L2'factor: 0PostProcess:name: DBPostProcessthresh: 0.3box_thresh: 0.6max_candidates: 1000unclip_ratio: 1.5Metric:name: DetMetricmain_indicator: hmeanTrain:dataset:name: SimpleDataSetdata_dir: ./train_data/icdar2015/text_localization/label_file_list:- ./train_data/icdar2015/text_localization/train_icdar2015_label.txtratio_list: [1.0]transforms:- DecodeImage: # load imageimg_mode: BGRchannel_first: False- DetLabelEncode: # Class handling label- IaaAugment:augmenter_args:- { 'type': Fliplr, 'args': { 'p': 0.5 } }- { 'type': Affine, 'args': { 'rotate': [-10, 10] } }- { 'type': Resize, 'args': { 'size': [0.5, 3] } }- EastRandomCropData:size: [640, 640]max_tries: 50keep_ratio: true- MakeBorderMap:shrink_ratio: 0.4thresh_min: 0.3thresh_max: 0.7- MakeShrinkMap:shrink_ratio: 0.4min_text_size: 8- NormalizeImage:scale: 1./255.mean: [0.485, 0.456, 0.406]std: [0.229, 0.224, 0.225]order: 'hwc'- ToCHWImage:- KeepKeys:keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader listloader:shuffle: Truedrop_last: Falsebatch_size_per_card: 16num_workers: 8use_shared_memory: TrueEval:dataset:name: SimpleDataSetdata_dir: ./train_data/icdar2015/text_localization/label_file_list:- ./train_data/icdar2015/text_localization/test_icdar2015_label.txttransforms:- DecodeImage: # load imageimg_mode: BGRchannel_first: False- DetLabelEncode: # Class handling label- DetResizeForTest:image_shape: [736, 1280]- NormalizeImage:scale: 1./255.mean: [0.485, 0.456, 0.406]std: [0.229, 0.224, 0.225]order: 'hwc'- ToCHWImage:- KeepKeys:keep_keys: ['image', 'shape', 'polys', 'ignore_tags']loader:shuffle: Falsedrop_last: Falsebatch_size_per_card: 1 # must be 1num_workers: 8use_shared_memory: True

build_model函数

def build_model(config):config = copy.deepcopy(config)if not "name" in config:arch = BaseModel(config)else:name = config.pop("name")mod = importlib.import_module(__name__)arch = getattr(mod, name)(config)return arch

base_model类

from __future__ import absolute_import
from __future__ import division
from __future__ import print_functionfrom paddle import nn
from ppocr.modeling.transforms import build_transform
from ppocr.modeling.backbones import build_backbone
from ppocr.modeling.necks import build_neck
from ppocr.modeling.heads import build_head__all__ = ['BaseModel']class BaseModel(nn.Layer):def __init__(self, config):super(BaseModel, self).__init__()# 输入通道in_channels = config.get('in_channels', 3)# 网络类型, 目前支持det, rec, cls.model_type = config['model_type']# ==========构建transfrom==========# 识别rec任务, transfrom可以设置为TPS、None;# 检测det和分类cls任务, transform可以设置为None;# if you make model differently, you can use transfrom in det and cls.if 'Transform' not in config or config['Transform'] is None:self.use_transform = Falseelse:self.use_transform = Trueconfig['Transform']['in_channels'] = in_channelsself.transform = build_transform(config['Transform'])in_channels = self.transform.out_channels# ==========构建backbone==========if 'Backbone' not in config or config['Backbone'] is None:self.use_backbone = Falseelse:self.use_backbone = Trueconfig["Backbone"]['in_channels'] = in_channelsself.backbone = build_backbone(config["Backbone"], model_type)in_channels = self.backbone.out_channels# ==========构建neck==========# 识别rec任务, neck可以是cnn、rnn或者reshape(None);# 检测det任务, neck可以是FPN、BIFPN等;# 分类cls任务, neck是none.if 'Neck' not in config or config['Neck'] is None:self.use_neck = Falseelse:self.use_neck = Trueconfig['Neck']['in_channels'] = in_channelsself.neck = build_neck(config['Neck'])in_channels = self.neck.out_channels# ==========构建head==========if 'Head' not in config or config['Head'] is None:self.use_head = Falseelse:self.use_head = Trueconfig["Head"]['in_channels'] = in_channelsself.head = build_head(config["Head"])self.return_all_feats = config.get("return_all_feats", False)def forward(self, x, data=None):# 以rec任务为例,输入x, 即data['image']的shape为[bs,3,48,320], # data['label_ctc']的shape为[bs,30]# data['label_sar']的shape为[bs,30]# data['length']的shape为[bs]# data['valid_ratio']的shape为[bs]y = dict()if self.use_transform:x = self.transform(x)# 以det任务为例,骨干网络MobileNetv3_large输出为列表# 特征图大小分别为原图的1/4, 1/8, 1/16, 1/32# [bs, 16, 160, 160], [bs, 24, 80, 80], [bs, 56, 40, 40], [bs, 480, 20, 20]# 以rec任务为例,骨干网络MobileNetV1Enhance输出为[bs, 512, 1, 40]if self.use_backbone:x = self.backbone(x)if isinstance(x, dict):y.update(x)else:y["backbone_out"] = xfinal_name = "backbone_out"# 以det任务为例,Neck网络DBFPN输出为特征图为原图的1/4大小 # [bs, 256, 160, 160]if self.use_neck:x = self.neck(x)if isinstance(x, dict):y.update(x)else:y["neck_out"] = xfinal_name = "neck_out"# 以det任务为例,Head网络DBHead输出为字典# 特征图大小为原图的大小,{'maps': y}  [bs, 3, 160, 160]# 以rec任务为例,Head网络MultiHead(CTCHead + SARHead)输出为字典# 'ctc_neck': [bs, 40, 64]# 'ctc_head': [bs, 40, 35], 35个字符是因为character_dict_path + blank + " "# 'sar_head': [bs, 30, 36], 36个字符是因为character_dict_path + " " + "<UKN>" + "<BOS/EOS>" + "<PAD>"if self.use_head:x = self.head(x, targets=data)if isinstance(x, dict) and 'ctc_neck' in x.keys():y["neck_out"] = x["ctc_neck"]y["head_out"] = xelif isinstance(x, dict):y.update(x)else:y["head_out"] = xfinal_name = "head_out"if self.return_all_feats:if self.training:return yelif isinstance(x, dict):return xelse:return {final_name: x}else:return x

build_backbone函数

__all__ = ["build_backbone"]def build_backbone(config, model_type):if model_type == "det" or model_type == "table":from .det_mobilenet_v3 import MobileNetV3from .det_resnet import ResNetfrom .det_resnet_vd import ResNet_vdfrom .det_resnet_vd_sast import ResNet_SASTfrom .det_pp_lcnet import PPLCNetsupport_dict = ["MobileNetV3", "ResNet", "ResNet_vd", "ResNet_SAST", "PPLCNet"]if model_type == "table":from .table_master_resnet import TableResNetExtrasupport_dict.append('TableResNetExtra')elif model_type == "rec" or model_type == "cls":from .rec_mobilenet_v3 import MobileNetV3from .rec_resnet_vd import ResNetfrom .rec_resnet_fpn import ResNetFPNfrom .rec_mv1_enhance import MobileNetV1Enhancefrom .rec_nrtr_mtb import MTBfrom .rec_resnet_31 import ResNet31from .rec_resnet_32 import ResNet32from .rec_resnet_45 import ResNet45from .rec_resnet_aster import ResNet_ASTERfrom .rec_micronet import MicroNetfrom .rec_efficientb3_pren import EfficientNetb3_PRENfrom .rec_svtrnet import SVTRNetfrom .rec_vitstr import ViTSTRfrom .rec_resnet_rfl import ResNetRFLfrom .rec_densenet import DenseNetsupport_dict = ['MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB','ResNet31', 'ResNet45', 'ResNet_ASTER', 'MicroNet','EfficientNetb3_PREN', 'SVTRNet', 'ViTSTR', 'ResNet32', 'ResNetRFL','DenseNet']elif model_type == 'e2e':from .e2e_resnet_vd_pg import ResNetsupport_dict = ['ResNet']elif model_type == 'kie':from .kie_unet_sdmgr import Kie_backbonefrom .vqa_layoutlm import LayoutLMForSer, LayoutLMv2ForSer, LayoutLMv2ForRe, LayoutXLMForSer, LayoutXLMForResupport_dict = ['Kie_backbone', 'LayoutLMForSer', 'LayoutLMv2ForSer','LayoutLMv2ForRe', 'LayoutXLMForSer', 'LayoutXLMForRe']elif model_type == 'table':from .table_resnet_vd import ResNetfrom .table_mobilenet_v3 import MobileNetV3support_dict = ['ResNet', 'MobileNetV3']else:raise NotImplementedErrormodule_name = config.pop('name')assert module_name in support_dict, Exception("when model typs is {}, backbone only support {}".format(model_type, support_dict))module_class = eval(module_name)(**config)return module_class

MobileNetV3

from __future__ import absolute_import
from __future__ import division
from __future__ import print_functionimport paddle
from paddle import nn
import paddle.nn.functional as F
from paddle import ParamAttr__all__ = ['MobileNetV3']def make_divisible(v, divisor=8, min_value=None):if min_value is None:min_value = divisornew_v = max(min_value, int(v+divisor/2)//divisor*divisor)if new_v < 0.9*v:new_v += divisorreturn new_vclass ConvBNLayer(nn.Layer):def __init__(self, in_channels, out_channels, kernel_size, stride, padding, groups=1, if_act=True, act=None):super(ConvBNLayer, self).__init__()self.if_act = if_actself.act = actself.conv = nn.Conv2D(in_channels=in_channels,out_channels=out_channels,kernel_size=kernel_size,stride=stride,padding=padding,groups=groups,bias_attr=False)self.bn = nn.BatchNorm(num_channels=out_channels, act=None)def forward(self, x):x = self.conv(x)x = self.bn(x)if self.if_act:if self.act == "relu":x = F.relu(x)elif self.act == "hardswish":x = F.hardswish(x)else:print("The activation function({}) is selected incorrectly.".format(self.act))exit()return xclass ResidualUnit(nn.Layer):def __init__(self, in_channels, mid_channels, out_channels, kernel_size, stride, use_se, act=None):super(ResidualUnit, self).__init__()self.if_shortcut = stride == 1 and in_channels == out_channelsself.if_se = use_se# 1x1卷积self.expand_conv = ConvBNLayer(in_channels=in_channels,out_channels=mid_channels,kernel_size=1,stride=1,padding=0,if_act=True,act=act)# 膨胀卷积self.bottleneck_conv = ConvBNLayer(in_channels=mid_channels,out_channels=mid_channels,kernel_size=kernel_size,stride=stride,padding=int((kernel_size - 1) // 2),groups=mid_channels,if_act=True,act=act)# SE注意力机制if self.if_se:self.mid_se = SEModule(mid_channels)# 1x1卷积self.linear_conv = ConvBNLayer(in_channels=mid_channels,out_channels=out_channels,kernel_size=1,stride=1,padding=0,if_act=False,act=None)def forward(self, inputs):x = self.expand_conv(inputs)x = self.bottleneck_conv(x)if self.if_se:x = self.mid_se(x)x = self.linear_conv(x)if self.if_shortcut:x = paddle.add(inputs, x)return xclass SEModule(nn.Layer):def __init__(self, in_channels, reduction=4):super(SEModule, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2D(1)self.conv1 = nn.Conv2D(in_channels=in_channels,out_channels=in_channels // reduction,kernel_size=1,stride=1,padding=0)self.conv2 = nn.Conv2D(in_channels=in_channels // reduction,out_channels=in_channels,kernel_size=1,stride=1,padding=0)def forward(self, inputs):outputs = self.avg_pool(inputs)outputs = self.conv1(outputs)outputs = F.relu(outputs)outputs = self.conv2(outputs)outputs = F.hardsigmoid(outputs, slope=0.2, offset=0.5)return inputs * outputsclass MobileNetV3(nn.Layer):def __init__(self, in_channels=3, model_name='large', scale=0.5, disable_se=False, **kwargs):super(MobileNetV3, self).__init__()# 不启用注意力机制SEself.disable_se = disable_seif model_name == "large":cfg = [# k, exp, c,  se,     nl,  s,[3, 16, 16, False, 'relu', 1],[3, 64, 24, False, 'relu', 2],[3, 72, 24, False, 'relu', 1],[5, 72, 40, True, 'relu', 2],[5, 120, 40, True, 'relu', 1],[5, 120, 40, True, 'relu', 1],[3, 240, 80, False, 'hardswish', 2],[3, 200, 80, False, 'hardswish', 1],[3, 184, 80, False, 'hardswish', 1],[3, 184, 80, False, 'hardswish', 1],[3, 480, 112, True, 'hardswish', 1],[3, 672, 112, True, 'hardswish', 1],[5, 672, 160, True, 'hardswish', 2],[5, 960, 160, True, 'hardswish', 1],[5, 960, 160, True, 'hardswish', 1],]cls_ch_squeeze = 960elif model_name == "small":cfg = [# k, exp, c,  se,     nl,  s,[3, 16, 16, True, 'relu', 2],[3, 72, 24, False, 'relu', 2],[3, 88, 24, False, 'relu', 1],[5, 96, 40, True, 'hardswish', 2],[5, 240, 40, True, 'hardswish', 1],[5, 240, 40, True, 'hardswish', 1],[5, 120, 48, True, 'hardswish', 1],[5, 144, 48, True, 'hardswish', 1],[5, 288, 96, True, 'hardswish', 2],[5, 576, 96, True, 'hardswish', 1],[5, 576, 96, True, 'hardswish', 1],]cls_ch_squeeze = 576else:raise NotImplementedError("mode[" + model_name + "_model] is not implemented!")supported_scale = [0.35, 0.5, 0.75, 1.0, 1.25]assert scale in supported_scale, "supported scale are {} but input scale is {}".format(supported_scale, scale)inplanes = 16# conv1self.conv = ConvBNLayer(in_channels=in_channels,out_channels=make_divisible(inplanes * scale),kernel_size=3,stride=2,padding=1,groups=1,if_act=True,act='hardswish')self.stages = []self.out_channels = []block_list = []i = 0inplanes = make_divisible(inplanes * scale)# k表示卷积核大小,kernal_size;# exp表示隐藏层通道数;# c表示输出通道数;# se表示是否使用SENet;# nl表示激活函数;# s表示stride;for (k, exp, c, se, nl, s) in cfg:se = se and not self.disable_sestart_idx = 2 if model_name == 'large' else 0if s == 2 and i > start_idx:self.out_channels.append(inplanes)self.stages.append(nn.Sequential(*block_list))block_list = []block_list.append(ResidualUnit(in_channels=inplanes,mid_channels=make_divisible(scale * exp),out_channels=make_divisible(scale * c),kernel_size=k,stride=s,use_se=se,act=nl))inplanes = make_divisible(scale * c)i += 1block_list.append(ConvBNLayer(in_channels=inplanes,out_channels=make_divisible(scale * cls_ch_squeeze),kernel_size=1,stride=1,padding=0,groups=1,if_act=True,act='hardswish'))self.stages.append(nn.Sequential(*block_list))self.out_channels.append(make_divisible(scale * cls_ch_squeeze))for i, stage in enumerate(self.stages):self.add_sublayer(sublayer=stage, name="stage{}".format(i))def forward(self, x):# 输入shape [16, 3, 640, 640]x = self.conv(x)    out_list = []# 有四个stage, 1/4, 1/8, 1/16, 1/32# [bs, 16, 160, 160]# [bs, 24, 80, 80]# [bs, 56, 40, 40]# [bs, 480, 20, 20]for stage in self.stages:x = stage(x)out_list.append(x)return out_list

build_neck函数

__all__ = ['build_neck']def build_neck(config):from .db_fpn import DBFPN, RSEFPN, LKPANfrom .east_fpn import EASTFPNfrom .sast_fpn import SASTFPNfrom .rnn import SequenceEncoderfrom .pg_fpn import PGFPNfrom .table_fpn import TableFPNfrom .fpn import FPNfrom .fce_fpn import FCEFPNfrom .pren_fpn import PRENFPNfrom .csp_pan import CSPPANfrom .ct_fpn import CTFPNfrom .fpn_unet import FPN_UNetfrom .rf_adaptor import RFAdaptorsupport_dict = ['FPN', 'FCEFPN', 'LKPAN', 'DBFPN', 'RSEFPN', 'EASTFPN', 'SASTFPN','SequenceEncoder', 'PGFPN', 'TableFPN', 'PRENFPN', 'CSPPAN', 'CTFPN','RFAdaptor', 'FPN_UNet']module_name = config.pop('name')assert module_name in support_dict, Exception('neck only support {}'.format(support_dict))module_class = eval(module_name)(**config)return module_class

build_head函数

__all__ = ['build_head']def build_head(config):# det headfrom .det_db_head import DBHeadfrom .det_east_head import EASTHeadfrom .det_sast_head import SASTHeadfrom .det_pse_head import PSEHeadfrom .det_fce_head import FCEHeadfrom .e2e_pg_head import PGHeadfrom .det_ct_head import CT_Head# rec headfrom .rec_ctc_head import CTCHeadfrom .rec_att_head import AttentionHeadfrom .rec_srn_head import SRNHeadfrom .rec_nrtr_head import Transformerfrom .rec_sar_head import SARHeadfrom .rec_aster_head import AsterHeadfrom .rec_pren_head import PRENHeadfrom .rec_multi_head import MultiHeadfrom .rec_spin_att_head import SPINAttentionHeadfrom .rec_abinet_head import ABINetHeadfrom .rec_robustscanner_head import RobustScannerHeadfrom .rec_visionlan_head import VLHeadfrom .rec_rfl_head import RFLHeadfrom .rec_can_head import CANHead# cls headfrom .cls_head import ClsHead# kie headfrom .kie_sdmgr_head import SDMGRHead# table headfrom .table_att_head import TableAttentionHead, SLAHeadfrom .table_master_head import TableMasterHeadsupport_dict = ['DBHead', 'PSEHead', 'FCEHead', 'EASTHead', 'SASTHead', 'CTCHead','ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer','TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead','MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead','VLHead', 'SLAHead', 'RobustScannerHead', 'CT_Head', 'RFLHead','DRRGHead', 'CANHead']if config['name'] == 'DRRGHead':from .det_drrg_head import DRRGHeadsupport_dict.append('DRRGHead')module_name = config.pop('name')assert module_name in support_dict, Exception('head only support {}'.format(support_dict))module_class = eval(module_name)(**config)return module_class

这篇关于PaddleOCR识别框架解读[04] 文本检测det模型构建的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/783310

相关文章

java之Objects.nonNull用法代码解读

《java之Objects.nonNull用法代码解读》:本文主要介绍java之Objects.nonNull用法代码,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐... 目录Java之Objects.nonwww.chinasem.cnNull用法代码Objects.nonN

Java的IO模型、Netty原理解析

《Java的IO模型、Netty原理解析》Java的I/O是以流的方式进行数据输入输出的,Java的类库涉及很多领域的IO内容:标准的输入输出,文件的操作、网络上的数据传输流、字符串流、对象流等,这篇... 目录1.什么是IO2.同步与异步、阻塞与非阻塞3.三种IO模型BIO(blocking I/O)NI

一文详解如何从零构建Spring Boot Starter并实现整合

《一文详解如何从零构建SpringBootStarter并实现整合》SpringBoot是一个开源的Java基础框架,用于创建独立、生产级的基于Spring框架的应用程序,:本文主要介绍如何从... 目录一、Spring Boot Starter的核心价值二、Starter项目创建全流程2.1 项目初始化(

Python Dash框架在数据可视化仪表板中的应用与实践记录

《PythonDash框架在数据可视化仪表板中的应用与实践记录》Python的PlotlyDash库提供了一种简便且强大的方式来构建和展示互动式数据仪表板,本篇文章将深入探讨如何使用Dash设计一... 目录python Dash框架在数据可视化仪表板中的应用与实践1. 什么是Plotly Dash?1.1

使用Java实现通用树形结构构建工具类

《使用Java实现通用树形结构构建工具类》这篇文章主要为大家详细介绍了如何使用Java实现通用树形结构构建工具类,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录完整代码一、设计思想与核心功能二、核心实现原理1. 数据结构准备阶段2. 循环依赖检测算法3. 树形结构构建4. 搜索子

基于Flask框架添加多个AI模型的API并进行交互

《基于Flask框架添加多个AI模型的API并进行交互》:本文主要介绍如何基于Flask框架开发AI模型API管理系统,允许用户添加、删除不同AI模型的API密钥,感兴趣的可以了解下... 目录1. 概述2. 后端代码说明2.1 依赖库导入2.2 应用初始化2.3 API 存储字典2.4 路由函数2.5 应

Python GUI框架中的PyQt详解

《PythonGUI框架中的PyQt详解》PyQt是Python语言中最强大且广泛应用的GUI框架之一,基于Qt库的Python绑定实现,本文将深入解析PyQt的核心模块,并通过代码示例展示其应用场... 目录一、PyQt核心模块概览二、核心模块详解与示例1. QtCore - 核心基础模块2. QtWid

使用Python实现文本转语音(TTS)并播放音频

《使用Python实现文本转语音(TTS)并播放音频》在开发涉及语音交互或需要语音提示的应用时,文本转语音(TTS)技术是一个非常实用的工具,下面我们来看看如何使用gTTS和playsound库将文本... 目录什么是 gTTS 和 playsound安装依赖库实现步骤 1. 导入库2. 定义文本和语言 3

使用PyTorch实现手写数字识别功能

《使用PyTorch实现手写数字识别功能》在人工智能的世界里,计算机视觉是最具魅力的领域之一,通过PyTorch这一强大的深度学习框架,我们将在经典的MNIST数据集上,见证一个神经网络从零开始学会识... 目录当计算机学会“看”数字搭建开发环境MNIST数据集解析1. 认识手写数字数据库2. 数据预处理的

SpringCloud负载均衡spring-cloud-starter-loadbalancer解读

《SpringCloud负载均衡spring-cloud-starter-loadbalancer解读》:本文主要介绍SpringCloud负载均衡spring-cloud-starter-loa... 目录简述主要特点使用负载均衡算法1. 轮询负载均衡策略(Round Robin)2. 随机负载均衡策略(