This article introduces the ASTER text recognition model.
Original article: ASTER文字识别 - Zhihu (zhihu.com)
Paper:
https://ieeexplore.ieee.org/document/8395027
Code:
https://github.com/ayumiymk/aster.pytorch
Curved text is common in natural scenes and is hard to recognize. Recognizing it typically involves first detecting the text region, then rectifying the cropped text image, and finally recognizing it. ASTER proposes an explicit image rectification mechanism that significantly improves the recognition network's accuracy without any extra annotation.
1. ASTER's Network Architecture
ASTER consists of a rectification network and a recognition network.
```python
class ModelBuilder(nn.Module):
    """This is the integrated model."""

    def __init__(self, arch, rec_num_classes, sDim, attDim, max_len_labels,
                 eos, STN_ON=False):
        super(ModelBuilder, self).__init__()
        self.arch = arch
        self.rec_num_classes = rec_num_classes
        self.sDim = sDim
        self.attDim = attDim
        self.max_len_labels = max_len_labels
        self.eos = eos
        self.STN_ON = STN_ON
        self.tps_inputsize = global_args.tps_inputsize

        self.encoder = create(self.arch,
                              with_lstm=global_args.with_lstm,
                              n_group=global_args.n_group)
        encoder_out_planes = self.encoder.out_planes

        self.decoder = AttentionRecognitionHead(
            num_classes=rec_num_classes,
            in_planes=encoder_out_planes,
            sDim=sDim,
            attDim=attDim,
            max_len_labels=max_len_labels)
        self.rec_crit = SequenceCrossEntropyLoss()

        if self.STN_ON:
            self.tps = TPSSpatialTransformer(
                output_image_size=tuple(global_args.tps_outputsize),
                num_control_points=global_args.num_control_points,
                margins=tuple(global_args.tps_margins))
            self.stn_head = STNHead(
                in_planes=3,
                num_ctrlpoints=global_args.num_control_points,
                activation=global_args.stn_activation)

    def forward(self, input_dict):
        return_dict = {}
        return_dict['losses'] = {}
        return_dict['output'] = {}

        x, rec_targets, rec_lengths = (input_dict['images'],
                                       input_dict['rec_targets'],
                                       input_dict['rec_lengths'])

        # rectification
        if self.STN_ON:
            # input images are downsampled before being fed into stn_head.
            stn_input = F.interpolate(x, self.tps_inputsize, mode='bilinear',
                                      align_corners=True)
            stn_img_feat, ctrl_points = self.stn_head(stn_input)
            x, _ = self.tps(x, ctrl_points)
            if not self.training:
                # save for visualization
                return_dict['output']['ctrl_points'] = ctrl_points
                return_dict['output']['rectified_images'] = x

        encoder_feats = self.encoder(x)
        encoder_feats = encoder_feats.contiguous()

        if self.training:
            rec_pred = self.decoder([encoder_feats, rec_targets, rec_lengths])
            loss_rec = self.rec_crit(rec_pred, rec_targets, rec_lengths)
            return_dict['losses']['loss_rec'] = loss_rec
        else:
            rec_pred, rec_pred_scores = self.decoder.beam_search(
                encoder_feats, global_args.beam_width, self.eos)
            # rec_pred, rec_pred_scores = self.decoder.sample(
            #     [encoder_feats, rec_targets, rec_lengths])
            rec_pred_ = self.decoder([encoder_feats, rec_targets, rec_lengths])
            loss_rec = self.rec_crit(rec_pred_, rec_targets, rec_lengths)
            return_dict['losses']['loss_rec'] = loss_rec
            return_dict['output']['pred_rec'] = rec_pred
            return_dict['output']['pred_rec_score'] = rec_pred_scores

        # pytorch0.4 bug on gathering scalar(0-dim) tensors
        for k, v in return_dict['losses'].items():
            return_dict['losses'][k] = v.unsqueeze(0)

        return return_dict
```
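The loop at the end of `forward` exists because PyTorch 0.4's `DataParallel` could not gather 0-dim (scalar) tensors across GPUs; `unsqueeze(0)` promotes each scalar loss to shape `(1,)`. A minimal illustration:

```python
import torch

loss = torch.tensor(0.5)   # a 0-dim (scalar) tensor, as a loss criterion returns
assert loss.dim() == 0
loss = loss.unsqueeze(0)   # now 1-dim with shape (1,), safe to gather/concat
print(loss.shape)  # torch.Size([1])
```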
2. The Rectification Network
The rectification network adaptively rectifies the text in the input image and transforms it into a new image.
The core idea of the STN is to cast the spatial rectification of an image as a learnable model; the pipeline is shown in the figure:
The input image is first downsampled to Id. The localization network and the grid generator then produce the parameters of a TPS transformation, and a sampler generates the rectified image Ir.
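This downsample → predict-parameters → sample pipeline can be sketched with plain PyTorch ops. The sizes are arbitrary, and the identity grid below is only a stand-in for ASTER's actual localization network and TPS grid generator:

```python
import torch
import torch.nn.functional as F

def rectify_sketch(x, predict_grid):
    """Toy STN-style pipeline: downsample -> predict sampling grid -> resample.

    x: input images, shape (N, C, H, W)
    predict_grid: callable standing in for localization network + grid
                  generator; must return a grid of shape (N, H', W', 2) with
                  coordinates in [-1, 1] (the grid_sample convention).
    """
    # 1) downsample the input to a small fixed size (I_d) for the localizer
    x_down = F.interpolate(x, size=(32, 64), mode='bilinear', align_corners=True)
    # 2) predict the sampling grid from the downsampled image
    grid = predict_grid(x_down)
    # 3) sample the *full-resolution* input with the grid to obtain I_r
    return F.grid_sample(x, grid, align_corners=True)

def identity_grid(x_down, out_hw=(32, 100)):
    """An identity transformation, standing in for a predicted one."""
    n = x_down.size(0)
    theta = torch.eye(2, 3).unsqueeze(0).repeat(n, 1, 1)
    return F.affine_grid(theta, (n, 3, *out_hw), align_corners=True)

imgs = torch.rand(2, 3, 64, 200)
rectified = rectify_sketch(imgs, identity_grid)
print(rectified.shape)  # torch.Size([2, 3, 32, 100])
```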
TPS (Thin Plate Spline) applies a non-rigid deformation to the image, and it rectifies the two typical kinds of irregular text, perspective-distorted and curved, very well.
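For intuition, a TPS interpolant is obtained by solving one linear system over the control points; a minimal NumPy sketch (not the repository's implementation):

```python
import numpy as np

def tps_fit(src, dst):
    """Solve for 2D thin-plate-spline coefficients mapping src -> dst.

    src, dst: (K, 2) control-point arrays. Returns the (K + 3, 2)
    coefficient matrix of the interpolating TPS.
    """
    K = src.shape[0]
    d2 = np.sum((src[:, None, :] - src[None, :, :]) ** 2, axis=-1)
    U = d2 * np.log(d2 + 1e-9)             # TPS radial kernel U(r) = r^2 log r^2
    P = np.hstack([np.ones((K, 1)), src])  # affine part [1, x, y]
    L = np.zeros((K + 3, K + 3))
    L[:K, :K] = U
    L[:K, K:] = P
    L[K:, :K] = P.T
    Y = np.zeros((K + 3, 2))
    Y[:K] = dst
    return np.linalg.solve(L, Y)

def tps_apply(coef, src, pts):
    """Map query points pts (M, 2) through the fitted spline."""
    d2 = np.sum((pts[:, None, :] - src[None, :, :]) ** 2, axis=-1)
    U = d2 * np.log(d2 + 1e-9)
    P = np.hstack([np.ones((len(pts), 1)), pts])
    return U @ coef[:len(src)] + P @ coef[len(src):]

src = np.array([[0., 0.], [1., 0.], [0., 1.], [1., 1.], [0.5, 0.2]])
dst = src + np.array([[0., 0.1], [0., -0.1], [0., 0.], [0., 0.], [0., 0.3]])
coef = tps_fit(src, dst)
print(np.allclose(tps_apply(coef, src, src), dst, atol=1e-6))  # True
```

At the control points the fitted spline reproduces the targets exactly; between them it deforms the plane smoothly, which is what makes it suitable for perspective and curvature corrections.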
2.1 The Localization Network
The rectification is parameterized by two sets of K control points each. The reference control points on the rectified image are denoted C = [c1, …, cK] ∈ R^(2×K), and the predicted control points on the input image are denoted C' = [c'1, …, c'K] ∈ R^(2×K).
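The reference points C are fixed: they are spaced evenly along the top and bottom edges of the rectified image, inset by small margins (`tps_margins` in the repository). A sketch with illustrative margin values:

```python
import numpy as np

def build_ref_ctrl_points(K, margin_x=0.05, margin_y=0.05):
    """Reference points C: K/2 evenly spaced along the top edge and K/2
    along the bottom edge, in normalized [0, 1] image coordinates."""
    xs = np.linspace(margin_x, 1.0 - margin_x, K // 2)
    top = np.stack([xs, np.full(K // 2, margin_y)], axis=1)
    bottom = np.stack([xs, np.full(K // 2, 1.0 - margin_y)], axis=1)
    return np.concatenate([top, bottom], axis=0)  # shape (K, 2)

C = build_ref_ctrl_points(20)
print(C.shape)  # (20, 2)
```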
The localization network is a CNN that predicts the control-point coordinates. It needs no coordinate annotations at all during training: supervision comes solely from the recognition network's text ground truth, via gradients back-propagated through the whole model.
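A heavily simplified localization head might look like the following. The real `STNHead` uses a deeper convolutional stack, but the essential shape of the computation (conv features → fully connected layer → 2K coordinates squashed into the image) is the same; names and layer sizes here are illustrative:

```python
import torch
from torch import nn

class TinyLocalizer(nn.Module):
    """Illustrative localization head: predicts K (x, y) control points in [0, 1]."""

    def __init__(self, num_ctrlpoints=20):
        super().__init__()
        self.num_ctrlpoints = num_ctrlpoints
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, 3, stride=2, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(16, 32, 3, stride=2, padding=1), nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1))                  # global pooling -> (N, 32, 1, 1)
        self.fc = nn.Linear(32, num_ctrlpoints * 2)   # regress 2K coordinates

    def forward(self, x):
        feat = self.features(x).flatten(1)
        # sigmoid keeps every predicted point inside the image ([0, 1] coords)
        return torch.sigmoid(self.fc(feat)).view(-1, self.num_ctrlpoints, 2)

pts = TinyLocalizer()(torch.rand(4, 3, 32, 64))
print(pts.shape)  # torch.Size([4, 20, 2])
```

Because the output is just a differentiable function of the image, the recognition loss can flow back through the sampler into these weights, which is why no point annotations are needed.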
```python
from __future__ import absolute_import
import math
import sys

import numpy as np

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn import init


def conv3x3_block(in_planes, out_planes, stride=1):
    """3x3 convolution with padding, followed by BN and ReLU"""
    conv_layer = nn.Conv2d(in_planes, out_planes, kernel_size=3,
                           stride=stride, padding=1)
    block = nn.Sequential(
        conv_layer,
        nn.BatchNorm2d(out_planes),
        nn.ReLU(inplace=True),
    )
    return block


class STNHead(nn.Module):
    def __init__(self, in_planes, num_ctrlpoints, activation='none'):
        super(STNHead, self).__init__()
        self.in_planes = in_planes
        self.num_ctrlpoints = num_ctrlpoints
        self.activation = activation
        # ... (the rest of STNHead is truncated in the original article)
```