本文主要是介绍【CTC】CTC1D原理/代码/资料+2D CTC LOSS,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
1 1D CTC
1.1 简介
就不写了
1.2 核心思想
和大多数有监督学习一样,CTC 使用最大似然标准进行训练。
给定输入 x,输出 l 的条件概率为:
其中,B-1(l)表示了长度为 T 且示经过 B 结果为 l 字符串的集合。
CTC 假设每一步输出的概率是(相对于输入)条件独立的,因此有:
p ( π ∣ x ) = ∏ y π t t , ∀ π ∈ L ′ T p(\pi|x) = \prod y^t_{\pi_t}, \forall \pi \in L^{\prime T} p(π∣x)=∏yπtt,∀π∈L′T
然而,直接按上式我们没有办理有效的计算似然值。下面用动态规划解决似然的计算及梯度计算, 涉及前向算法和后向算法。
1.3 图解原理
转载自[5]
如下图,为了更形象表示问题的搜索空间,用X轴表示时间序列, Y轴表示输出序列,并把输出序列做标准化处理,输出序列中间和头尾都加上blank,用l表示最终标签,l’表示扩展后的形式,则由2|l| + 1 = 2|l’|,比如:l=apple => l’=a_p_p_l_e
图中并不是所有的路径都是合法路径,所有的合法路径需要遵循一些约束,如下图:
所以,依据以上约束规则,遍历所有映射为“apple”的合法路径,最终时序T=8,标签labeling=“apple”的全部路径如下图:
接下来,如何计算这些路径的概率总和?暴力遍历?分而治之?作者借鉴HMM的Forward-Backward算法思路,利用动态规划算法求解,可以将路径集合分为前向和后向两部分,如下图所示:
通过动态规划求解出前向概率之后,可以用前向概率来计算CTC Loss函数,如下图:
说明:可将上面的α(t)理解成一个转移矩阵,走过的路径即为label,矩阵的值表示概率
根据 α 的定义,有如下递归关系:
α t ( s ) = { ( α t − 1 ( s ) + α t − 1 ( s − 1 ) ) y l s ′ t , i f l s ′ = b o r l s − 2 ′ = l s ′ ( α t − 1 ( s ) + α t − 1 ( s − 1 ) + α t − 1 ( s − 2 ) ) y l s ′ t o t h e r w i s e \alpha_t(s) = \{ \begin{array}{l} (\alpha_{t-1}(s)+\alpha_{t-1}(s-1)) y^t_{l^\prime_s},\ \ \ if\ l^\prime_s = b \ or\ l_{s-2}^\prime = l_s^{\prime} \\ (\alpha_{t-1}(s)+\alpha_{t-1}(s-1) + \alpha_{t-1}(s-2)) y^t_{l^\prime_s} \ \ otherwise \end{array} αt(s)={(αt−1(s)+αt−1(s−1))yls′t, if ls′=b or ls−2′=ls′(αt−1(s)+αt−1(s−1)+αt−1(s−2))yls′t otherwise
case 2
递归公式中 case 2 是一般的情形。如图所示,t 时刻字符为 s 为 blank 时,它可能由于两种情况扩展而来:1)重复上一字符,即上个字符也是 a,2)字符发生转换,即上个字符是非 a 的字符。第二种情况又分为两种情形,2.1)上一字符是 blank;2.2)a 由非 blank 字符直接跳转而来() 操作中, blank 最终会被去掉,因此 blank 并不是必须的)。
case 1
递归公式 case 1 是特殊的情形。
如图所示,t 时刻字符为 s 为 blank 时,它只能由于两种情况扩展而来:1)重复上一字符,即上个字符也是 blank,2)字符发生转换,即上个字符是非 blank 字符。t 时刻字符为 s 为非 blank 时,类似于 case 2,但是这时两个相同字符之间的 blank 不能省略(否则无法区分”aa”和”a”),因此,也只有两种跳转情况。
1.4 demo code
必须理解。有相应的注释。主要思路就是:
- 先求当前步的所有可能转移概率的和
- 转移概率和×label的预测概率
import numpy as npnp.random.seed(1111)T, V = 12, 5
m, n = 6, Vx = np.random.random([T, m]) # T x m
w = np.random.random([m, n]) # weights, m x ndef softmax(logits):max_value = np.max(logits, axis=1, keepdims=True)exp = np.exp(logits - max_value)exp_sum = np.sum(exp, axis=1, keepdims=True)dist = exp / exp_sumreturn distdef toy_nw(x):y = np.matmul(x, w) # T x n y = softmax(y)return yy = toy_nw(x)
print(y)
print(y.sum(1, keepdims=True))def forward(y, labels):T, V = y.shapeL = len(labels) # 步长alpha = np.zeros([T, L]) # init初始化第一步的概率alpha[0, 0] = y[0, labels[0]] # 第一步的标签为blank时,pred的概率 // alpha是转移概率?alpha[0, 1] = y[0, labels[1]] # 第一步的标签为第一个字符时,pred的概率for t in range(1, T): # step,第n步的标签为s时for i in range(L): # 标签长度s = labels[i]a = alpha[t - 1, i] if i - 1 >= 0: # case1,有两种方式可以转移到当前位置a += alpha[t - 1, i - 1]if i - 2 >= 0 and s != 0 and s != labels[i - 2]: # case 2,有三种方式可以转移到当前位置,转移概率×lable概率a += alpha[t - 1, i - 2]alpha[t, i] = a * y[t, s]return alphalabels = [0, 3, 0, 3, 0, 4, 0] # 0 for blank
alpha = forward(y, labels)
print(alpha)p = alpha[-1, -1] + alpha[-1, -2]
print(p)
1.5 pytorch code
详细请看:ctc_loss.py
从上面可以知道,涉及到大量的概率值计算,这些概率值往往是很小的浮点数。而且概率值相乘后会越变越小,计算起来会损失精度,为了保持准确度,统一将这些概率值进行log处理,再参与运算。也就是说,在代码中处理的概率是对数域的值。所以网络输出的pred,会先进行torch.log操作。具体的计算请参考[2]
1.6 1D ctc 的局限性
- 1d ctc在高度方向上必须压缩成一维,这样在处理弯曲文本的时候,会存在字符在宽度方向分割不好的情况。于是有了后续的2D CTC LOSS
2 2D CTC LOSS
-
论文
-
2d比1d多了个高度,还是采用转移矩阵的方式来理解。相对于1d,2d多了一个h方向,转移矩阵相当于一个三维矩阵。
下图其实不够具体,没有清晰的解释转移矩阵的效果
2.2 网络结构图
- 网络有两个输出分支,1为batchch*w形状的在c维度的softmax表示每个位置,预测字符的概率。2为batch×1×h×w形状在h维度的softmax,表示在h方向选择的概率
- 而h方向上下跳我们可以利用一个网络进行学习,上面第一个输出是wh各个位置的概率输出向量,而下面第二个输出是各个位置在h方向上跳动的概率,由于最后一列不用跳,因此输出是(w-1) h * h
每个位置在h方向跳动的概率和为1
- 同样在计算2-D CTC loss时依然可以用到动态规划,只是在多了个h方向(将原来的某一个点(一个概率值),换成某一条竖线,变成h个概率值乘以跳转概率的和)
2.3 局限性
- 2D ctc loss还是采用序列(从左到右)的动态规划,所以其相对于1d ctc只是增加了一个h方向。能做弯曲文本的识别,但是还是单行。无法进行多行文本的识别。例如下图:
Reference
- 高大上的动图
- 对数域的计算log_add
- 【Learning Notes】CTC 原理及实现
- 2-D CTC Loss
- CTC Algorithm Explained Part 1:Training the Network(CTC算法详解之训练篇)
这篇关于【CTC】CTC1D原理/代码/资料+2D CTC LOSS的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!