CTPN源码解析3.1-model()函数解析

2024-03-03 18:32
文章标签 源码 函数 解析 model 3.1 ctpn

本文主要是介绍CTPN源码解析3.1-model()函数解析,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文本检测算法一:CTPN

CTPN源码解析1-数据预处理split_label.py

CTPN源码解析2-代码整体结构和框架

CTPN源码解析3.1-model()函数解析

CTPN源码解析3.2-loss()函数解析

CTPN源码解析4-损失函数

CTPN源码解析5-文本线构造算法构造文本行

CTPN训练自己的数据集

由于解析的这个CTPN代码是被banjin-xjyeragonruan大神重新封装过的,所以代码整体结构非常的清晰,简洁!不像上次解析FasterRCNN的代码那样跳来跳去,没跳几步脑子就被跳乱了[捂脸],向大神致敬!PS:里面肯定会有理解和注释错误的,欢迎批评指正!

解析源码地址:https://github.com/eragonruan/text-detection-ctpn

知乎:从代码实现的角度理解CTPN:https://zhuanlan.zhihu.com/p/49588885

知乎:理解文本检测网络CTPN:https://zhuanlan.zhihu.com/p/77883736

知乎:场景文字检测—CTPN原理与实现:https://zhuanlan.zhihu.com/p/34757009

 

model()函数流程

model()函数代码

'''
0)传入图像,图像每个通道数减去相应的值,再将3个通道合并成一个图像
1)通过vgg16获得特征图conv5_3,shape(?,?,?,512)
2)滑动窗口获得特征向量rpn_conv,shape(?,?,?,512)
3)将得到的特征向量rpn_conv输入Bilstm中,得到lstm_output,shape(?,?,?,512)的输出
4)将lstm_output分别送入全连接层,得到 bbox_pred(预测框坐标)shape(?,?,?,40),cls_pred(分类概率值) shape(?,?,?,20)。
5)shape转换,返回相应的值
'''
def model(image):image = mean_image_subtraction(image) #图像每个通道数减去相应的值,再将3个通道合并成一个图像with slim.arg_scope(vgg.vgg_arg_scope()):conv5_3 = vgg.vgg_16(image)  #nets/vgg.py,VGG16作为基础网络,提取特征图  shape(N,H,W,512)rpn_conv = slim.conv2d(conv5_3, 512, 3) #在conv5_3上做3x3滑窗,又卷积一次  shape(N,H,W,512)# B×H×W×C大小的feature map经过BLSTM得到[B*H,W,512]大小的lstm_outputlstm_output = Bilstm(rpn_conv, 512, 128, 512, scope_name='BiLSTM')  # shape(?,?,?,512)# 本代码做了调整:1.[B*H,W,512]大小的lstm_output没有接卷积层(FC代表卷积)# 2.[B*H,W,512]大小的lstm_output直接预测的四个回归量bbox_pred = lstm_fc(lstm_output, 512, 10 * 4, scope_name="bbox_pred") #网络预测回归输出  # shape(?,?,?,40)cls_pred = lstm_fc(lstm_output, 512, 10 * 2, scope_name="cls_pred")   #网络预测分类输出  # shape(?,?,?,20)# transpose: (1, H, W, A x d) -> (1, H, WxA, d)cls_pred_shape = tf.shape(cls_pred) # 将矩阵的维度输出为一个维度矩阵 shape(?,?,?,20)-> shape(4,?)cls_pred_reshape = tf.reshape(cls_pred, [cls_pred_shape[0], cls_pred_shape[1], -1, 2]) # shape(?,?,?,20)-># shape(?,?,?,2)cls_pred_reshape_shape = tf.shape(cls_pred_reshape) # 将矩阵的维度输出为一个维度矩阵 shape(?,?,?,2)-> shape(4,?)cls_prob = tf.reshape(tf.nn.softmax(tf.reshape(cls_pred_reshape, [-1, cls_pred_reshape_shape[3]])),[-1, cls_pred_reshape_shape[1], cls_pred_reshape_shape[2], cls_pred_reshape_shape[3]],name="cls_prob")  # shape(?,?,?,?)return bbox_pred, cls_pred, cls_prob

下面按model()函数的处理步骤分别解析源码

0)传入图像,图像每个通道数减去相应的值,再将3个通道合并成一个图像

这一步在model()函数中的执行语句是:

image = mean_image_subtraction(image) #图像每个通道数减去相应的值,再将3个通道合并成一个图像
'''
图像每个通道数减去相应的值,再将3个通道合并成一个图像
'''
def mean_image_subtraction(images, means=[123.68, 116.78, 103.94]):num_channels = images.get_shape().as_list()[-1]  #获取图像通道数if len(means) != num_channels:raise ValueError('len(means) must match the number of channels')channels = tf.split(axis=3, num_or_size_splits=num_channels, value=images)for i in range(num_channels):channels[i] -= means[i]  #图像每个通道数减去相应的值return tf.concat(axis=3, values=channels)  #再将3个通道合并成一个图像

1)通过vgg16获得特征图conv5_3,shape(?,?,?,512)

这一步在model()函数中的执行语句是:

rpn_conv = slim.conv2d(conv5_3, 512, 3) #在conv5_3上做3x3滑窗,又卷积一次  shape(N,H,W,512)

我就不贴vgg16卷积的代码了。

2)滑动窗口获得特征向量rpn_conv,shape(?,?,?,512)

这一步在model()函数中的执行语句是:

 rpn_conv = slim.conv2d(conv5_3, 512, 3) #在conv5_3上做3x3滑窗,又卷积一次  shape(N,H,W,512)

原意是结合该点周边9个点的信息,但在tensorflow中就用卷积代替了。

3)将得到的特征向量rpn_conv输入Bilstm中,得到lstm_output,shape(?,?,?,512)的输出

这一步在model()函数中的执行语句是:

# B×H×W×C大小的feature map经过BLSTM得到[B*H,W,512]大小的lstm_outputlstm_output = Bilstm(rpn_conv, 512, 128, 512, scope_name='BiLSTM')  # shape(?,?,?,512)

双向lstm获取横向(宽度方向)序列特征

'''
#BLSTM 双向LSTM
net,  特征图
input_channel,  输入的通道数 
hidden_unit_num, 隐藏层单元数目
output_channel,  输出的通道数
scope_name       #名称
'''
def Bilstm(net, input_channel, hidden_unit_num, output_channel, scope_name):# width--->time step  width方向作为序列方向with tf.variable_scope(scope_name) as scope:shape = tf.shape(net) #获取特征图的维度信息N, H, W, C = shape[0], shape[1], shape[2], shape[3]net = tf.reshape(net, [N * H, W, C])   # 改变数据格式  # shape(N * H, W, C)net.set_shape([None, None, input_channel])    # shape(?,?,input_channel)lstm_fw_cell = tf.contrib.rnn.LSTMCell(hidden_unit_num, state_is_tuple=True) #前向lstmlstm_bw_cell = tf.contrib.rnn.LSTMCell(hidden_unit_num, state_is_tuple=True) #反向lstmlstm_out, last_state = tf.nn.bidirectional_dynamic_rnn(lstm_fw_cell, lstm_bw_cell, net, dtype=tf.float32)lstm_out = tf.concat(lstm_out, axis=-1) # axis=1 代表在第1个维度拼接lstm_out = tf.reshape(lstm_out, [N * H * W, 2 * hidden_unit_num])# 这种初始化方法比常规高斯分布初始化、截断高斯分布初始化及 Xavier 初始化的泛化/缩放性能更好init_weights = tf.contrib.layers.variance_scaling_initializer(factor=0.01, mode='FAN_AVG', uniform=False)init_biases = tf.constant_initializer(0.0)weights = make_var('weights', [2 * hidden_unit_num, output_channel], init_weights)  # 初始化权重biases = make_var('biases', [output_channel], init_biases)  # 初始化偏移outputs = tf.matmul(lstm_out, weights) + biasesoutputs = tf.reshape(outputs, [N, H, W, output_channel]) #还原成原来的形状return outputs

4)将lstm_output分别送入全连接层,得到 bbox_pred(预测框坐标)shape(?,?,?,40),cls_pred(分类概率值) shape(?,?,?,20)。

这一步在model()函数中的执行语句是:

    # 本代码做了调整:1.[B*H,W,512]大小的lstm_output没有接卷积层(FC代表卷积)# 2.[B*H,W,512]大小的lstm_output直接预测的四个回归量bbox_pred = lstm_fc(lstm_output, 512, 10 * 4, scope_name="bbox_pred") #网络预测回归输出  # shape(?,?,?,40)cls_pred = lstm_fc(lstm_output, 512, 10 * 2, scope_name="cls_pred")   #网络预测分类输出  # shape(?,?,?,20)
'''
全连接层,改变输出通道数
'''
def lstm_fc(net, input_channel, output_channel, scope_name):with tf.variable_scope(scope_name) as scope:shape = tf.shape(net)N, H, W, C = shape[0], shape[1], shape[2], shape[3]net = tf.reshape(net, [N * H * W, C])init_weights = tf.contrib.layers.variance_scaling_initializer(factor=0.01, mode='FAN_AVG', uniform=False)init_biases = tf.constant_initializer(0.0)weights = make_var('weights', [input_channel, output_channel], init_weights) #全连接层512-》output_channelbiases = make_var('biases', [output_channel], init_biases)output = tf.matmul(net, weights) + biasesoutput = tf.reshape(output, [N, H, W, output_channel])return output

5)shape转换,返回相应的值

这一步在model()函数中的执行语句是:

    # transpose: (1, H, W, A x d) -> (1, H, WxA, d)cls_pred_shape = tf.shape(cls_pred) # 将矩阵的维度输出为一个维度矩阵 shape(?,?,?,20)-> shape(4,?)cls_pred_reshape = tf.reshape(cls_pred, [cls_pred_shape[0], cls_pred_shape[1], -1, 2]) # shape(?,?,?,20)-># shape(?,?,?,2)cls_pred_reshape_shape = tf.shape(cls_pred_reshape) # 将矩阵的维度输出为一个维度矩阵 shape(?,?,?,2)-> shape(4,?)cls_prob = tf.reshape(tf.nn.softmax(tf.reshape(cls_pred_reshape, [-1, cls_pred_reshape_shape[3]])),[-1, cls_pred_reshape_shape[1], cls_pred_reshape_shape[2], cls_pred_reshape_shape[3]],name="cls_prob")  # shape(?,?,?,?)return bbox_pred, cls_pred, cls_prob

然后整个model()操作就结束了。

这篇关于CTPN源码解析3.1-model()函数解析的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/770516

相关文章

网页解析 lxml 库--实战

lxml库使用流程 lxml 是 Python 的第三方解析库,完全使用 Python 语言编写,它对 XPath表达式提供了良好的支 持,因此能够了高效地解析 HTML/XML 文档。本节讲解如何通过 lxml 库解析 HTML 文档。 pip install lxml lxm| 库提供了一个 etree 模块,该模块专门用来解析 HTML/XML 文档,下面来介绍一下 lxml 库

hdu1171(母函数或多重背包)

题意:把物品分成两份,使得价值最接近 可以用背包,或者是母函数来解,母函数(1 + x^v+x^2v+.....+x^num*v)(1 + x^v+x^2v+.....+x^num*v)(1 + x^v+x^2v+.....+x^num*v) 其中指数为价值,每一项的数目为(该物品数+1)个 代码如下: #include<iostream>#include<algorithm>

JAVA智听未来一站式有声阅读平台听书系统小程序源码

智听未来,一站式有声阅读平台听书系统 🌟&nbsp;开篇:遇见未来,从“智听”开始 在这个快节奏的时代,你是否渴望在忙碌的间隙,找到一片属于自己的宁静角落?是否梦想着能随时随地,沉浸在知识的海洋,或是故事的奇幻世界里?今天,就让我带你一起探索“智听未来”——这一站式有声阅读平台听书系统,它正悄悄改变着我们的阅读方式,让未来触手可及! 📚&nbsp;第一站:海量资源,应有尽有 走进“智听

【C++】_list常用方法解析及模拟实现

相信自己的力量,只要对自己始终保持信心,尽自己最大努力去完成任何事,就算事情最终结果是失败了,努力了也不留遗憾。💓💓💓 目录   ✨说在前面 🍋知识点一:什么是list? •🌰1.list的定义 •🌰2.list的基本特性 •🌰3.常用接口介绍 🍋知识点二:list常用接口 •🌰1.默认成员函数 🔥构造函数(⭐) 🔥析构函数 •🌰2.list对象

Java ArrayList扩容机制 (源码解读)

结论:初始长度为10,若所需长度小于1.5倍原长度,则按照1.5倍扩容。若不够用则按照所需长度扩容。 一. 明确类内部重要变量含义         1:数组默认长度         2:这是一个共享的空数组实例,用于明确创建长度为0时的ArrayList ,比如通过 new ArrayList<>(0),ArrayList 内部的数组 elementData 会指向这个 EMPTY_EL

如何在Visual Studio中调试.NET源码

今天偶然在看别人代码时,发现在他的代码里使用了Any判断List<T>是否为空。 我一般的做法是先判断是否为null,再判断Count。 看了一下Count的源码如下: 1 [__DynamicallyInvokable]2 public int Count3 {4 [__DynamicallyInvokable]5 get

工厂ERP管理系统实现源码(JAVA)

工厂进销存管理系统是一个集采购管理、仓库管理、生产管理和销售管理于一体的综合解决方案。该系统旨在帮助企业优化流程、提高效率、降低成本,并实时掌握各环节的运营状况。 在采购管理方面,系统能够处理采购订单、供应商管理和采购入库等流程,确保采购过程的透明和高效。仓库管理方面,实现库存的精准管理,包括入库、出库、盘点等操作,确保库存数据的准确性和实时性。 生产管理模块则涵盖了生产计划制定、物料需求计划、

C++操作符重载实例(独立函数)

C++操作符重载实例,我们把坐标值CVector的加法进行重载,计算c3=c1+c2时,也就是计算x3=x1+x2,y3=y1+y2,今天我们以独立函数的方式重载操作符+(加号),以下是C++代码: c1802.cpp源代码: D:\YcjWork\CppTour>vim c1802.cpp #include <iostream>using namespace std;/*** 以独立函数

OWASP十大安全漏洞解析

OWASP(开放式Web应用程序安全项目)发布的“十大安全漏洞”列表是Web应用程序安全领域的权威指南,它总结了Web应用程序中最常见、最危险的安全隐患。以下是对OWASP十大安全漏洞的详细解析: 1. 注入漏洞(Injection) 描述:攻击者通过在应用程序的输入数据中插入恶意代码,从而控制应用程序的行为。常见的注入类型包括SQL注入、OS命令注入、LDAP注入等。 影响:可能导致数据泄

从状态管理到性能优化:全面解析 Android Compose

文章目录 引言一、Android Compose基本概念1.1 什么是Android Compose?1.2 Compose的优势1.3 如何在项目中使用Compose 二、Compose中的状态管理2.1 状态管理的重要性2.2 Compose中的状态和数据流2.3 使用State和MutableState处理状态2.4 通过ViewModel进行状态管理 三、Compose中的列表和滚动