本文主要介绍《PaddleOCR识别框架解读[04] 文本检测det模型构建》,希望能为大家解决编程问题提供一定的参考,需要的开发者们跟着小编一起来学习吧!
文章目录
- det_mv3_db.yml
- build_model函数
- base_model类
- build_backbone函数
- MobileNetV3
- build_neck函数
- build_head函数
det_mv3_db.yml
Global:
  use_gpu: true
  use_xpu: false
  epoch_num: 1200
  log_smooth_window: 20
  print_batch_step: 10
  save_model_dir: ./output/db_mv3/
  save_epoch_step: 1200
  # evaluation is run every 2000 iterations
  eval_batch_step: [0, 2000]
  cal_metric_during_train: False
  pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
  checkpoints:
  save_inference_dir:
  use_visualdl: False
  infer_img: doc/imgs_en/img_10.jpg
  save_res_path: ./output/det_db/predicts_db.txt

Architecture:
  model_type: det
  algorithm: DB
  Transform:
  Backbone:
    name: MobileNetV3
    scale: 0.5
    model_name: large
  Neck:
    name: DBFPN
    out_channels: 256
  Head:
    name: DBHead
    k: 50

Loss:
  name: DBLoss
  balance_loss: true
  main_loss_type: DiceLoss
  alpha: 5
  beta: 10
  ohem_ratio: 3

Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  lr:
    learning_rate: 0.001
  regularizer:
    name: 'L2'
    factor: 0

PostProcess:
  name: DBPostProcess
  thresh: 0.3
  box_thresh: 0.6
  max_candidates: 1000
  unclip_ratio: 1.5

Metric:
  name: DetMetric
  main_indicator: hmean

Train:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/icdar2015/text_localization/
    label_file_list:
      - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
    ratio_list: [1.0]
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - DetLabelEncode: # Class handling label
      - IaaAugment:
          augmenter_args:
            - { 'type': Fliplr, 'args': { 'p': 0.5 } }
            - { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
            - { 'type': Resize, 'args': { 'size': [0.5, 3] } }
      - EastRandomCropData:
          size: [640, 640]
          max_tries: 50
          keep_ratio: true
      - MakeBorderMap:
          shrink_ratio: 0.4
          thresh_min: 0.3
          thresh_max: 0.7
      - MakeShrinkMap:
          shrink_ratio: 0.4
          min_text_size: 8
      - NormalizeImage:
          scale: 1./255.
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
  loader:
    shuffle: True
    drop_last: False
    batch_size_per_card: 16
    num_workers: 8
    use_shared_memory: True

Eval:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/icdar2015/text_localization/
    label_file_list:
      - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - DetLabelEncode: # Class handling label
      - DetResizeForTest:
          image_shape: [736, 1280]
      - NormalizeImage:
          scale: 1./255.
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 1 # must be 1
    num_workers: 8
    use_shared_memory: True
build_model函数
def build_model(config):
    """Build the model described by ``config``.

    The configuration is deep-copied first, so the caller's object is never
    modified. Without a ``"name"`` entry a ``BaseModel`` is built. Otherwise
    ``"name"`` is removed from the copy, the attribute of that name is looked
    up on the current module, and it is called with the remaining config.

    Args:
        config: Mapping holding the model settings.

    Returns:
        The constructed model object.

    Raises:
        AttributeError: If ``"name"`` does not match an attribute of this
            module.
    """
    config = copy.deepcopy(config)
    if "name" not in config:
        return BaseModel(config)

    name = config.pop("name")
    # Resolve the architecture by name on this very module.
    mod = importlib.import_module(__name__)
    return getattr(mod, name)(config)
base_model类
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from paddle import nn
from ppocr.modeling.transforms import build_transform
from ppocr.modeling.backbones import build_backbone
from ppocr.modeling.necks import build_neck
from ppocr.modeling.heads import build_head

__all__ = ['BaseModel']


class BaseModel(nn.Layer):
    """Model assembled from optional transform, backbone, neck and head stages.

    A stage is built only when its section is present in ``config`` and is
    not ``None``. The ``out_channels`` of each built stage is passed on as
    ``in_channels`` of the next one.
    """

    def __init__(self, config):
        """Build the sub-networks described by ``config``.

        Args:
            config (dict): Architecture settings. Reads ``model_type``
                (required; the original notes list det, rec and cls),
                ``in_channels`` (default 3), ``return_all_feats`` (default
                False) and the optional ``Transform``, ``Backbone``, ``Neck``
                and ``Head`` sections. ``in_channels`` is written into every
                section that is present, so ``config`` is modified in place.
        """
        super(BaseModel, self).__init__()
        # Number of input channels of the first stage.
        in_channels = config.get('in_channels', 3)
        # Task type; forwarded to build_backbone.
        model_type = config['model_type']

        # ---------- transform ----------
        # Original note: for rec the transform may be TPS or None; for det
        # and cls it is None unless the model is customised.
        transform_cfg = config.get('Transform')
        self.use_transform = transform_cfg is not None
        if self.use_transform:
            transform_cfg['in_channels'] = in_channels
            self.transform = build_transform(transform_cfg)
            in_channels = self.transform.out_channels

        # ---------- backbone ----------
        backbone_cfg = config.get('Backbone')
        self.use_backbone = backbone_cfg is not None
        if self.use_backbone:
            backbone_cfg['in_channels'] = in_channels
            self.backbone = build_backbone(backbone_cfg, model_type)
            in_channels = self.backbone.out_channels

        # ---------- neck ----------
        # Original note: rec necks are cnn, rnn or reshape (None); det necks
        # are FPN, BIFPN, etc.; cls has no neck.
        neck_cfg = config.get('Neck')
        self.use_neck = neck_cfg is not None
        if self.use_neck:
            neck_cfg['in_channels'] = in_channels
            self.neck = build_neck(neck_cfg)
            in_channels = self.neck.out_channels

        # ---------- head ----------
        head_cfg = config.get('Head')
        self.use_head = head_cfg is not None
        if self.use_head:
            head_cfg['in_channels'] = in_channels
            self.head = build_head(head_cfg)

        self.return_all_feats = config.get("return_all_feats", False)

    def forward(self, x, data=None):
        """Run the enabled stages in order and return their output.

        Args:
            x: Network input. Example from the original notes (rec task):
                ``data['image']`` with shape [bs, 3, 48, 320].
            data: Extra targets handed to the head as ``targets``. Example
                from the original notes (rec task): ``label_ctc`` [bs, 30],
                ``label_sar`` [bs, 30], ``length`` [bs], ``valid_ratio`` [bs].

        Returns:
            With ``return_all_feats`` disabled, the output of the last enabled
            stage. With it enabled: the dict of all collected outputs while
            training; otherwise the last output itself when it is a dict, or
            ``{final_name: output}`` when it is not.
        """
        y = {}
        if self.use_transform:
            x = self.transform(x)

        # Examples from the original notes:
        #   det, MobileNetV3 large: list of maps at 1/4, 1/8, 1/16, 1/32 of
        #   the input, [bs, 16, 160, 160], [bs, 24, 80, 80], [bs, 56, 40, 40],
        #   [bs, 480, 20, 20];
        #   rec, MobileNetV1Enhance: [bs, 512, 1, 40].
        if self.use_backbone:
            x = self.backbone(x)
        # Recorded even when no backbone is configured, so that final_name is
        # always bound before the return_all_feats branch below.
        if isinstance(x, dict):
            y.update(x)
        else:
            y["backbone_out"] = x
        final_name = "backbone_out"

        # Example from the original notes (det, DBFPN): one map at 1/4 of the
        # input, [bs, 256, 160, 160].
        if self.use_neck:
            x = self.neck(x)
            if isinstance(x, dict):
                y.update(x)
            else:
                y["neck_out"] = x
            final_name = "neck_out"

        # Examples from the original notes:
        #   det, DBHead: dict {'maps': y}, listed as [bs, 3, 160, 160];
        #   rec, MultiHead (CTCHead + SARHead): dict with
        #     'ctc_neck': [bs, 40, 64]
        #     'ctc_head': [bs, 40, 35]  (character_dict_path + blank + " ")
        #     'sar_head': [bs, 30, 36]  (character_dict_path + " " + "<UKN>"
        #                                + "<BOS/EOS>" + "<PAD>")
        if self.use_head:
            x = self.head(x, targets=data)
            if isinstance(x, dict) and 'ctc_neck' in x:
                # Multi-head output: expose the CTC neck features as neck_out.
                y["neck_out"] = x["ctc_neck"]
                y["head_out"] = x
            elif isinstance(x, dict):
                y.update(x)
            else:
                y["head_out"] = x
            final_name = "head_out"

        if not self.return_all_feats:
            return x
        if self.training:
            return y
        if isinstance(x, dict):
            return x
        return {final_name: x}
build_backbone函数
__all__ = ["build_backbone"]


def build_backbone(config, model_type):
    """Instantiate the backbone named in ``config`` for the given task type.

    Only the backbones belonging to ``model_type`` are imported. ``"name"``
    is popped from ``config`` (which is therefore modified in place) and the
    remaining entries are passed to the backbone class as keyword arguments.

    Args:
        config (dict): Backbone settings; must contain ``"name"``.
        model_type (str): One of "det", "table", "rec", "cls", "e2e", "kie".

    Returns:
        The constructed backbone instance.

    Raises:
        NotImplementedError: If ``model_type`` is not one of the above.
        AssertionError: If the named backbone is not supported for
            ``model_type``.
    """
    if model_type in ("det", "table"):
        from .det_mobilenet_v3 import MobileNetV3
        from .det_resnet import ResNet
        from .det_resnet_vd import ResNet_vd
        from .det_resnet_vd_sast import ResNet_SAST
        from .det_pp_lcnet import PPLCNet
        support_dict = [
            "MobileNetV3", "ResNet", "ResNet_vd", "ResNet_SAST", "PPLCNet"
        ]
        if model_type == "table":
            from .table_master_resnet import TableResNetExtra
            support_dict.append('TableResNetExtra')
    elif model_type in ("rec", "cls"):
        from .rec_mobilenet_v3 import MobileNetV3
        from .rec_resnet_vd import ResNet
        from .rec_resnet_fpn import ResNetFPN
        from .rec_mv1_enhance import MobileNetV1Enhance
        from .rec_nrtr_mtb import MTB
        from .rec_resnet_31 import ResNet31
        from .rec_resnet_32 import ResNet32
        from .rec_resnet_45 import ResNet45
        from .rec_resnet_aster import ResNet_ASTER
        from .rec_micronet import MicroNet
        from .rec_efficientb3_pren import EfficientNetb3_PREN
        from .rec_svtrnet import SVTRNet
        from .rec_vitstr import ViTSTR
        from .rec_resnet_rfl import ResNetRFL
        from .rec_densenet import DenseNet
        support_dict = [
            'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB',
            'ResNet31', 'ResNet45', 'ResNet_ASTER', 'MicroNet',
            'EfficientNetb3_PREN', 'SVTRNet', 'ViTSTR', 'ResNet32',
            'ResNetRFL', 'DenseNet'
        ]
    elif model_type == 'e2e':
        from .e2e_resnet_vd_pg import ResNet
        support_dict = ['ResNet']
    elif model_type == 'kie':
        from .kie_unet_sdmgr import Kie_backbone
        from .vqa_layoutlm import LayoutLMForSer, LayoutLMv2ForSer, LayoutLMv2ForRe, LayoutXLMForSer, LayoutXLMForRe
        support_dict = [
            'Kie_backbone', 'LayoutLMForSer', 'LayoutLMv2ForSer',
            'LayoutLMv2ForRe', 'LayoutXLMForSer', 'LayoutXLMForRe'
        ]
    else:
        # The original had a second `elif model_type == 'table'` branch here.
        # It could never run because "table" is handled by the first branch,
        # so it has been dropped.
        raise NotImplementedError(
            "model_type {} is not supported".format(model_type))

    module_name = config.pop('name')
    # Raised explicitly instead of via `assert`, so the whitelist check is
    # not stripped under `python -O` and always guards the eval() below.
    if module_name not in support_dict:
        raise AssertionError(
            "when model typs is {}, backbone only support {}".format(
                model_type, support_dict))
    # NOTE(review): eval() resolves the class imported above by its name. It
    # is only reached for names in support_dict; never relax that check.
    module_class = eval(module_name)(**config)
    return module_class
MobileNetV3
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
from paddle import nn
import paddle.nn.functional as F
from paddle import ParamAttr

__all__ = ['MobileNetV3']


def make_divisible(v, divisor=8, min_value=None):
    """Round ``v`` to a nearby multiple of ``divisor``.

    Args:
        v: Value to round (here: a scaled channel count).
        divisor: The result is a multiple of this value.
        min_value: Lower bound of the result; defaults to ``divisor``.

    Returns:
        The rounded value, bumped up by one ``divisor`` when plain rounding
        would fall below 90% of ``v``.
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure rounding down does not lose more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


class ConvBNLayer(nn.Layer):
    """Bias-free Conv2D followed by BatchNorm and an optional activation."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 padding,
                 groups=1,
                 if_act=True,
                 act=None):
        """Create the convolution and batch-norm layers.

        Args:
            in_channels: Input channels of the convolution.
            out_channels: Output channels of the convolution.
            kernel_size: Convolution kernel size.
            stride: Convolution stride.
            padding: Convolution padding.
            groups: Convolution groups.
            if_act: Whether ``forward`` applies an activation.
            act: Activation name, "relu" or "hardswish"; only read when
                ``if_act`` is true.
        """
        super(ConvBNLayer, self).__init__()
        self.if_act = if_act
        self.act = act
        self.conv = nn.Conv2D(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
            bias_attr=False)
        self.bn = nn.BatchNorm(num_channels=out_channels, act=None)

    def forward(self, x):
        """Apply conv, batch norm and, if enabled, the activation.

        Raises:
            ValueError: If ``if_act`` is true and ``act`` is neither "relu"
                nor "hardswish".
        """
        x = self.conv(x)
        x = self.bn(x)
        if self.if_act:
            if self.act == "relu":
                x = F.relu(x)
            elif self.act == "hardswish":
                x = F.hardswish(x)
            else:
                # The original printed this message and called exit(); raise
                # instead so library code never terminates the interpreter.
                raise ValueError(
                    "The activation function({}) is selected incorrectly.".
                    format(self.act))
        return x


class ResidualUnit(nn.Layer):
    """Block of 1x1 expand conv, depthwise conv, optional SE, 1x1 linear conv.

    The input is added back to the output when ``stride == 1`` and the input
    and output channel counts are equal.
    """

    def __init__(self,
                 in_channels,
                 mid_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 use_se,
                 act=None):
        """Create the layers of the unit.

        Args:
            in_channels: Input channels.
            mid_channels: Channels after the expand convolution.
            out_channels: Output channels.
            kernel_size: Kernel size of the depthwise convolution.
            stride: Stride of the depthwise convolution.
            use_se: Whether to insert an ``SEModule`` after the depthwise
                convolution.
            act: Activation name used by the expand and depthwise convs.
        """
        super(ResidualUnit, self).__init__()
        self.if_shortcut = stride == 1 and in_channels == out_channels
        self.if_se = use_se

        # 1x1 convolution expanding to mid_channels.
        self.expand_conv = ConvBNLayer(
            in_channels=in_channels,
            out_channels=mid_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            if_act=True,
            act=act)
        # k x k convolution with groups == mid_channels (one group per
        # channel); this is where the stride is applied.
        self.bottleneck_conv = ConvBNLayer(
            in_channels=mid_channels,
            out_channels=mid_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=int((kernel_size - 1) // 2),
            groups=mid_channels,
            if_act=True,
            act=act)
        # Squeeze-and-excitation attention.
        if self.if_se:
            self.mid_se = SEModule(mid_channels)
        # 1x1 convolution projecting to out_channels, without activation.
        self.linear_conv = ConvBNLayer(
            in_channels=mid_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            if_act=False,
            act=None)

    def forward(self, inputs):
        """Run the unit, adding the shortcut when it is enabled."""
        x = self.expand_conv(inputs)
        x = self.bottleneck_conv(x)
        if self.if_se:
            x = self.mid_se(x)
        x = self.linear_conv(x)
        if self.if_shortcut:
            x = paddle.add(inputs, x)
        return x


class SEModule(nn.Layer):
    """Squeeze-and-excitation: rescale the input by a pooled per-channel gate."""

    def __init__(self, in_channels, reduction=4):
        """Create the pooling layer and the two 1x1 convolutions.

        Args:
            in_channels: Channels of the input feature map.
            reduction: Divisor giving the hidden channel count
                ``in_channels // reduction``.
        """
        super(SEModule, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2D(1)
        self.conv1 = nn.Conv2D(
            in_channels=in_channels,
            out_channels=in_channels // reduction,
            kernel_size=1,
            stride=1,
            padding=0)
        self.conv2 = nn.Conv2D(
            in_channels=in_channels // reduction,
            out_channels=in_channels,
            kernel_size=1,
            stride=1,
            padding=0)

    def forward(self, inputs):
        """Return ``inputs`` multiplied by the computed gate."""
        outputs = self.avg_pool(inputs)
        outputs = self.conv1(outputs)
        outputs = F.relu(outputs)
        outputs = self.conv2(outputs)
        outputs = F.hardsigmoid(outputs, slope=0.2, offset=0.5)
        return inputs * outputs


class MobileNetV3(nn.Layer):
    """MobileNetV3 backbone that returns one feature map per stage.

    ``out_channels`` lists the channel count of every stage output, in the
    same order as the list returned by ``forward``.
    """

    def __init__(self,
                 in_channels=3,
                 model_name='large',
                 scale=0.5,
                 disable_se=False,
                 **kwargs):
        """Build the stem convolution and the stages.

        Args:
            in_channels: Channels of the input image.
            model_name: "large" or "small".
            scale: Channel multiplier; one of 0.35, 0.5, 0.75, 1.0, 1.25.
            disable_se: When true, no block uses squeeze-and-excitation.
            **kwargs: Accepted and ignored.

        Raises:
            NotImplementedError: If ``model_name`` is not "large"/"small".
            AssertionError: If ``scale`` is not a supported value.
        """
        super(MobileNetV3, self).__init__()

        # When set, squeeze-and-excitation is switched off everywhere.
        self.disable_se = disable_se

        # Column meaning of each cfg row:
        #   k   - kernel size
        #   exp - hidden (expanded) channels
        #   c   - output channels
        #   se  - whether to use the SE module
        #   nl  - activation name
        #   s   - stride
        if model_name == "large":
            cfg = [
                # k, exp, c, se, nl, s,
                [3, 16, 16, False, 'relu', 1],
                [3, 64, 24, False, 'relu', 2],
                [3, 72, 24, False, 'relu', 1],
                [5, 72, 40, True, 'relu', 2],
                [5, 120, 40, True, 'relu', 1],
                [5, 120, 40, True, 'relu', 1],
                [3, 240, 80, False, 'hardswish', 2],
                [3, 200, 80, False, 'hardswish', 1],
                [3, 184, 80, False, 'hardswish', 1],
                [3, 184, 80, False, 'hardswish', 1],
                [3, 480, 112, True, 'hardswish', 1],
                [3, 672, 112, True, 'hardswish', 1],
                [5, 672, 160, True, 'hardswish', 2],
                [5, 960, 160, True, 'hardswish', 1],
                [5, 960, 160, True, 'hardswish', 1],
            ]
            cls_ch_squeeze = 960
        elif model_name == "small":
            cfg = [
                # k, exp, c, se, nl, s,
                [3, 16, 16, True, 'relu', 2],
                [3, 72, 24, False, 'relu', 2],
                [3, 88, 24, False, 'relu', 1],
                [5, 96, 40, True, 'hardswish', 2],
                [5, 240, 40, True, 'hardswish', 1],
                [5, 240, 40, True, 'hardswish', 1],
                [5, 120, 48, True, 'hardswish', 1],
                [5, 144, 48, True, 'hardswish', 1],
                [5, 288, 96, True, 'hardswish', 2],
                [5, 576, 96, True, 'hardswish', 1],
                [5, 576, 96, True, 'hardswish', 1],
            ]
            cls_ch_squeeze = 576
        else:
            raise NotImplementedError("mode[" + model_name +
                                      "_model] is not implemented!")

        supported_scale = [0.35, 0.5, 0.75, 1.0, 1.25]
        # Raised explicitly (not via `assert`) so the check survives -O.
        if scale not in supported_scale:
            raise AssertionError(
                "supported scale are {} but input scale is {}".format(
                    supported_scale, scale))

        # Stem convolution (stride 2).
        inplanes = make_divisible(16 * scale)
        self.conv = ConvBNLayer(
            in_channels=in_channels,
            out_channels=inplanes,
            kernel_size=3,
            stride=2,
            padding=1,
            groups=1,
            if_act=True,
            act='hardswish')

        self.stages = []
        self.out_channels = []
        block_list = []
        # A stride-2 block only opens a new stage after this row index.
        start_idx = 2 if model_name == 'large' else 0
        for i, (k, exp, c, se, nl, s) in enumerate(cfg):
            se = se and not self.disable_se
            if s == 2 and i > start_idx:
                # Close the current stage before the resolution drops.
                self.out_channels.append(inplanes)
                self.stages.append(nn.Sequential(*block_list))
                block_list = []
            block_out = make_divisible(scale * c)
            block_list.append(
                ResidualUnit(
                    in_channels=inplanes,
                    mid_channels=make_divisible(scale * exp),
                    out_channels=block_out,
                    kernel_size=k,
                    stride=s,
                    use_se=se,
                    act=nl))
            inplanes = block_out

        # Final 1x1 convolution, appended to the last stage.
        last_channels = make_divisible(scale * cls_ch_squeeze)
        block_list.append(
            ConvBNLayer(
                in_channels=inplanes,
                out_channels=last_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                groups=1,
                if_act=True,
                act='hardswish'))
        self.stages.append(nn.Sequential(*block_list))
        self.out_channels.append(last_channels)

        # self.stages is a plain list, so register each stage explicitly.
        for i, stage in enumerate(self.stages):
            self.add_sublayer(sublayer=stage, name="stage{}".format(i))

    def forward(self, x):
        """Return the list of stage outputs for input ``x``.

        Example from the original notes (large, scale 0.5, input
        [16, 3, 640, 640]): four maps at 1/4, 1/8, 1/16 and 1/32 of the
        input, [bs, 16, 160, 160], [bs, 24, 80, 80], [bs, 56, 40, 40] and
        [bs, 480, 20, 20].
        """
        x = self.conv(x)
        out_list = []
        for stage in self.stages:
            x = stage(x)
            out_list.append(x)
        return out_list
build_neck函数
__all__ = ['build_neck']


def build_neck(config):
    """Instantiate the neck named in ``config``.

    ``"name"`` is popped from ``config`` (which is therefore modified in
    place) and the remaining entries are passed to the neck class as keyword
    arguments.

    Args:
        config (dict): Neck settings; must contain ``"name"``.

    Returns:
        The constructed neck instance.

    Raises:
        AssertionError: If the named neck is not supported.
    """
    support_dict = [
        'FPN', 'FCEFPN', 'LKPAN', 'DBFPN', 'RSEFPN', 'EASTFPN', 'SASTFPN',
        'SequenceEncoder', 'PGFPN', 'TableFPN', 'PRENFPN', 'CSPPAN', 'CTFPN',
        'RFAdaptor', 'FPN_UNet'
    ]

    module_name = config.pop('name')
    # Validated before the imports so an unknown name fails fast, and raised
    # explicitly so the check is not stripped under `python -O`.
    if module_name not in support_dict:
        raise AssertionError('neck only support {}'.format(support_dict))

    from .db_fpn import DBFPN, RSEFPN, LKPAN
    from .east_fpn import EASTFPN
    from .sast_fpn import SASTFPN
    from .rnn import SequenceEncoder
    from .pg_fpn import PGFPN
    from .table_fpn import TableFPN
    from .fpn import FPN
    from .fce_fpn import FCEFPN
    from .pren_fpn import PRENFPN
    from .csp_pan import CSPPAN
    from .ct_fpn import CTFPN
    from .fpn_unet import FPN_UNet
    from .rf_adaptor import RFAdaptor

    # NOTE(review): eval() resolves the class imported above by its name. It
    # is only reached for names in support_dict; never relax that check.
    module_class = eval(module_name)(**config)
    return module_class
build_head函数
__all__ = ['build_head']


def build_head(config):
    """Instantiate the head named in ``config``.

    ``"name"`` is popped from ``config`` (which is therefore modified in
    place) and the remaining entries are passed to the head class as keyword
    arguments. ``DRRGHead`` is imported only when it is the requested head.

    Args:
        config (dict): Head settings; must contain ``"name"``.

    Returns:
        The constructed head instance.

    Raises:
        AssertionError: If the named head is not supported.
    """
    # 'DRRGHead' is already listed here, so the original's extra
    # support_dict.append('DRRGHead') was redundant and has been dropped.
    support_dict = [
        'DBHead', 'PSEHead', 'FCEHead', 'EASTHead', 'SASTHead', 'CTCHead',
        'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer',
        'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead',
        'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead',
        'VLHead', 'SLAHead', 'RobustScannerHead', 'CT_Head', 'RFLHead',
        'DRRGHead', 'CANHead'
    ]

    module_name = config.pop('name')
    # Validated before the imports so an unknown name fails fast, and raised
    # explicitly so the check is not stripped under `python -O`.
    if module_name not in support_dict:
        raise AssertionError('head only support {}'.format(support_dict))

    # det head
    from .det_db_head import DBHead
    from .det_east_head import EASTHead
    from .det_sast_head import SASTHead
    from .det_pse_head import PSEHead
    from .det_fce_head import FCEHead
    from .e2e_pg_head import PGHead
    from .det_ct_head import CT_Head
    # rec head
    from .rec_ctc_head import CTCHead
    from .rec_att_head import AttentionHead
    from .rec_srn_head import SRNHead
    from .rec_nrtr_head import Transformer
    from .rec_sar_head import SARHead
    from .rec_aster_head import AsterHead
    from .rec_pren_head import PRENHead
    from .rec_multi_head import MultiHead
    from .rec_spin_att_head import SPINAttentionHead
    from .rec_abinet_head import ABINetHead
    from .rec_robustscanner_head import RobustScannerHead
    from .rec_visionlan_head import VLHead
    from .rec_rfl_head import RFLHead
    from .rec_can_head import CANHead
    # cls head
    from .cls_head import ClsHead
    # kie head
    from .kie_sdmgr_head import SDMGRHead
    # table head
    from .table_att_head import TableAttentionHead, SLAHead
    from .table_master_head import TableMasterHead

    # Imported on demand only, as in the original.
    if module_name == 'DRRGHead':
        from .det_drrg_head import DRRGHead

    # NOTE(review): eval() resolves the class imported above by its name. It
    # is only reached for names in support_dict; never relax that check.
    module_class = eval(module_name)(**config)
    return module_class
这篇关于《PaddleOCR识别框架解读[04] 文本检测det模型构建》的文章就介绍到这儿,希望我们推荐的文章对程序员们有所帮助!