CTC Loss及OCR经典算法CRNN实现

这里我们将基于深度神经网络CNN+CTC loss进行OCR(图像文本识别),可以使用经典CRNN网络,但是我这里使用的是全卷积网络,因为文本长度并不长,所以并不想考虑文本的序列信息,没有使用LSTM部分。

参考文献:

  1. (CTC loss)Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with Recurrent Neural Networks
  2. (CRNN)An End-to-End Trainable Neural Network for Image-based Sequence Recognition and Its Application to Scene Text Recognition
  3. YC Note - OCR:CRNN+CTC開源加詳細解析
  4. ZhiHu - jax:CTC loss 实作与优化

OCR检测

在OCR检测问题中,设数据集格式为 D={(xi,li)}i=1MD=\{(x_i, l_i)\}_{i=1}^M,其中 xix_i 表示输入的图像数据,lil_i 表示该图像对应的文本串,假设我们的图像长宽一定均为 W×HW\times H 的灰度图像。

注:此处的 xix_i 表示的是已经通过目标检测出来的带有一行文字的图片框(如下所示),也就是说必须确保该图像中具有文字,目标检测技术可以使用YOLO,R-CNN,SDD等等。

1_pontifically
2_Senoritas
3_JERKIER
4_underbellies
5_minion
6_embracing

设数据集中全体字符集大小为 CC(包含空字符,用于占位),考虑一个样本 (x,l)(x,l),设文本串 ll 长度为 NN,设深度神经网络为 f(x,θ)f(x,\theta),于是我们可以利用卷积的平移等变性,将 ff 的输出压缩成一个高度为 11 宽度为 TTCC 维特征,于是输出中第 ii 列特征,就正好代表原始图像 W×HW\times H 中,按照宽度平均划分为 WW' 份中的第 ii 块的信息(这块信息大小也称为卷积的感受野范围)。

f:RW×HRT×1×Cf:\R^{W\times H}\to \R^{T\times 1\times C}

也就是说,如果我们图像中的文本正好就是横向排布的,而且只有一行,那么如果我们的 TT 将原图划分的足够细,那么对于原图中第 ii 个字符,ti{1,,T}\exists t_i\in \{1,\cdots,T\},使得 f(x;θ)t,1f(x;\theta)_{t,1} 表示的特征正好是该字符对应的,下图中的Feature Sequence就是 f(x;θ)f(x;\theta),其中每个特征向量对应的感受野大小就是虚线所指向在图片中的框大小:

receptive field

并且我们网络的输出 f(x;θ)tRCf(x;\theta)_{t}\in\R^{C},所以可以直接作用 softmaxsoftmax 后就是该处对应的字符概率分布,但是这里我们会很容易发现一个问题,我们的预测的特征数目 TT 一定是大于字符长度 NN 的,所以我们将会有非常多的空字符和重复字符(相近的感受野可能识别成同样的字符),举一个例子:

我们将空字符记为 ε\varepsilon,假设标签为 l=CATl = \text{CAT},则 N=3N=3,一个训练好的模型的识别结果可能是

f(x;θ)=εεCCAAεT 或者 εεCεAAεTεεf(x;\theta) = \varepsilon\varepsilon\text{CCAA}\varepsilon\text{T}\text{ 或者 }\varepsilon\varepsilon\text{C}\varepsilon\text{AA}\varepsilon\text{T}\varepsilon\varepsilon\cdots

如果两个字符之间的间距过大则模型可能识别出多个空字符 ε\varepsilon,如果字体过大或者文本特征非常明显,可能识别出重复的字符例如 AA\text{AA},那么我们就需要分别对上述两个问题进行解决,首先将连续出现的相同字符是保留一个,再将剩余的 ε\varepsilon 全部去掉即可,我们将由大小为 CC 的字符集构成的全体字符串集合记为 Π\Pi,其中的字符串记为 π\pi,将上述化简算子记为 B:ΠΠB:\Pi\to\Pi

容易发现一个小问题,如何识别 CAAT\text{CAAT} 呢?这样重复出现的字符之间就必须预测出一个空字符,例如 CAεAT\text{CA}\varepsilon\text{AT}

模型预测方法

不基于字典的理论预测方法非常直接,枚举所有可能字符串 ll 和所有可能的预测串 π\pi,求出通过简化运算得到该串 ll 的概率,取概率最大的一个即可:

lpred=argmaxlΠπ:B(π)=lP(πf(x;θ))l_{pred} = \arg\max_{l\in\Pi}\sum_{\pi:B(\pi) = l}P(\pi|f(x;\theta))

但是实际上我们当然不可能枚举出所有的 llπ\pi,所以简单来说可以通过贪心或者贪心树搜索来寻找,效果其实差不多,贪心就是直接取每个位置 tt 出现概率最大的字符,然后通过简化运算即可得到预测结果。

训练方法

CTC Loss

该问题的难点在于如何训练,简单来说,我们只需要通过极大似然就可以对概率进行优化:

maxθL(w)=(x,l)DlogP(lf(x;θ))=(x,l)Dlogπ:B(π)=lP(πf(x;θ))\max_{\theta}L(w) = -\sum_{(x,l)\in D}\log P(l|f(x;\theta)) = -\sum_{(x,l)\in D}\log\sum_{\pi:B(\pi)=l}P(\pi|f(x;\theta))

难点在于我们如何求出 π:B(π)=lP(πf(x;θ))\sum_{\pi:B(\pi)=l}P(\pi|f(x;\theta)),这就是一个动态规划(DP)问题了:

首先我们将 ll 中间隔插入 ε\varepsilon 变为 ll',例如 CAT 就变为 -C-A-T-,其中 - 表示 ε\varepsilon,假设我们的预测的序列长度为 TT,那么简化运算 BB 生成的所有序列就是下面表格中从红色点开始,对于位置 (i,j)(i,j),只能向右移动到 (i,j+1)(i,j+1)(i+1,j+1)(i+1,j+1) 或者当 lili+2l'_i\neq l'_{i+2} 时可以移动到 (i+2,j+1)(i+2,j+1),最终移动到蓝色点上,全体路径通过 BB 运算均可以得到 CAT,又由于每条路径对应了一个概率,我们只需要考虑所有路径的概率和,就是我们要求的 π:B(π)=lP(πf(x;θ))\sum_{\pi:B(\pi)=l}P(\pi|f(x;\theta)) 了。

图片来源:https://ycc.idv.tw/crnn-ctc.html

假设表格中 (i,j)(i,j) 处的概率大小为 aija_{ij},设状态 f(i,j)f(i,j) 表示从任意一个起点到达 f(i,j)f(i,j) 全部路径的概率之和,即

f(i,j)=π:B(π)=l1iπ=jk=1jaπk,kf(i,j) = \sum_{\substack{\pi:B(\pi)=l'_{1\sim i}\\|\pi|=j}}\prod_{k=1}^ja_{\pi_k,k}

考虑点 f(i,j)f(i,j) 的状态只能来自于 f(i,j1),f(i1,j1)f(i,j-1), f(i-1,j-1),或者当前 liεl'_i\neq \varepsilonlili2l'_i\neq l'_{i-2} 时可以从 f(i2,j1)f(i-2,j-1) 转移得到,则状态转移方程为

f(i,j)= {(f(i,j1)+f(i1,j1))aij,li=ε 或 li=li2,(f(i,j1)+f(i1,j1)+f(i2,j1))aij,否则.= (f(i,j1)+f(i1,j1)+f(i2,j1)[lklk1,i=2k])aij\begin{aligned} f(i,j) =&\ \begin{cases} \big(f(i,j-1)+f(i-1,j-1)\big)a_{ij},&\quad l'_i = \varepsilon\text{ 或 }l'_i = l'_{i-2},\\ \big(f(i,j-1)+f(i-1,j-1)+f(i-2,j-1)\big)a_{ij},&\quad \text{否则}. \end{cases}\\ =&\ \big(f(i,j-1)+f(i-1,j-1)+f(i-2,j-1)[l_k\neq l_{k-1},i=2k]\big)a_{ij} \end{aligned}

其中 [条件判断][\text{条件判断}] 只有当其中的条件判断为真时为 11 否则为 00

初始化:f(1,1)=a11,f(2,1)=a21,f(i,1)=0,(i=3,4,,2N+1)f(1,1) = a_{11}, f(2,1) = a_{21}, f(i,1) = 0, (i =3,4,\cdots,2N+1)

时间复杂度:O(NT)\mathcal{O}(NT),最终CTC损失函数

LCTC=log(f(2N+1,T)+f(2N,T))\mathcal{L}_{CTC} = -\log \big(f(2N+1,T) + f(2N, T)\big)

JAX 实现

但是,这里不是C++,不仅要正向求出损失函数,还需反向计算梯度,在论文2^2YC Note3^3中介绍了如何计算反向梯度,于正向方法类似,也是利用DP。

但是现在自动微分技术这么强,并且能用GPU加速,不想这么麻烦。所以下面将利用JAX实现GPU加速的CTC损失计算:

首先要有JAX的基础,知道JAX中不能直接使用 if, for 这样会使得代码变得非常慢,常用的技巧是,用 mask 或者 jax.lax.cond 代替 if,用 jax.lax.scan 代替 for,参考 optax.ctc_loss 的代码设计方法(实际运用中推荐直接调用 optax.ctc_loss),我也自己重新实现了一遍,并且在小数据上速度略高于 optax.ctc_loss,这里具体对其原理进行解释:

我们先要将概率转化为 log\log 域下的运算,即做变换 f(i,j)logf(i,j),aijlogaijf(i,j)\gets \log f(i,j),a_{ij}\gets \log a_{ij},这样可以避免浮点数的精度问题,在数值计算中我们还需要一个 ε=0+\varepsilon = 0^+,用于表示概率为 00(代码中取为 logε=105\log\varepsilon = -10^5),则原状态转移方程变化为:

f(i,j)=log[ef(i,j1)+ef(i1,j1)+ef(i2,j1)+log([lklk1,i=2k]+ε)]+aijf(i,j) = \log\left[e^{f(i,j-1)}+e^{f(i-1,j-1)}+e^{f(i-2,j-1) + \log([l_k'\neq l'_{k-1},i=2k]+\varepsilon)}\right] + a_{ij}

ff 的奇偶行分别提取出来:令g(i,j)=gij=(g1,g2,,gN),h(i,j)=hij=(h1,h2,,hN,hN+1)g(i,j) = g_{ij} = (g_1,g_2,\cdots,g_N), h(i,j) = h_{ij} = (h_1,h_2,\cdots,h_N,h_{N+1}),其中

giT= (f(2i,1),f(2i,2),,f(2i,T)),hiT= (f(2i1,1),f(2i1,2),,f(2i1,T))\begin{aligned} g_i^T =&\ \big(f(2i,1), f(2i,2),\cdots,f(2i,T)\big),\\ h_i^T =&\ \big(f(2i-1,1), f(2i-1,2),\cdots,f(2i-1,T)\big) \end{aligned}

那么 gg 就表示到达原字符串 ll 中点的概率大小,hh 表示到达空字符中点的概率大小,于是我们可以分别对他们俩进行求解:

g(i,j)= log[eg(i,j1)+eh(i,j1)+eg(i1,j1)+log([lili1]+ε)]+aijchar,h(i,j)= log[eh(i,j1)+eg(i1,j1)]+aijblank\begin{aligned} g(i,j) =&\ \log\left[{\color{red}e^{g(i,j-1)}}+e^{h(i,j-1)} + {\color{red}e^{g(i-1,j-1) + \log([l_i\neq l_{i-1}]+\varepsilon)}}\right]+a_{ij}^{char},\\ h(i,j) =&\ \log\left[{\color{red}e^{h(i,j-1)} + e^{g(i-1,j-1)}}\right] + a_{ij}^{blank} \end{aligned}

其中 aijchar=a2i,j,aijblank=a2i1,ja_{ij}^{char} = a_{2i,j}, a_{ij}^{blank} = a_{2i-1,j}

并且由红色表出的部分为二者共同的一项,只不过 g(i,j)g(i,j) 计算中需要取去除掉重复的部分,所以我们可以先计算 g(i,j)g(i,j) 中红色的项,然后补全重复的项就可以得到 h(i,j)h(i,j) 了,部分代码如下:

def update_h(h, delta):  # 错位加和
    return jnp.logaddexp(
        h,
        jnp.pad(delta, ((0,0),(1,0)), constant_values=log_eps)
    )

tmp = update_h(pre_log_h, pre_log_g + repeat * log_eps)  # 上式中红色部分
log_g = jnp.logaddexp(pre_log_g, tmp[:,:-1]) + logprob_char
log_h = update_h(
    tmp,
    pre_log_g + (1.0 - repeat) * log_eps  # 补全去除掉的部分
) + logprob_blank

最后利用 jax.lax.scanTT 对应的维度进行遍历即可,初始化 h(1,0)=log1=0,h(i,0)=g(j,0)=logε,(i2,j1)h(1,0) = \log 1 = 0, h(i,0) = g(j,0) = \log \varepsilon,\quad(i\geqslant 2,j\geqslant 1)

还有一些小技巧,例如计算 aijchara_{ij}^{char} 用到了爱因斯坦求和约定 jnp.einsum

logprobs = jax.nn.log_softmax(logits)
B, T, C = logits.shape
B, N = labels.shape
one_hot = jax.nn.one_hot(labels, C)  # (B,N,C)
logprobs_char = jnp.einsum('btc,bnc->tbn', logprobs, one_hot)  # (T,B,N)

这份代码和 optax.ctc_loss 的唯一区别就在于更新 hh 的方法,我这里使用的是 jnp.pad 然后做加法,而 optax.ctc_loss 中用的是 jnp.concatenate[h[:,:1], jnp.logaddexp(h[:,1:], delta)] 进行拼接,使用 jnp.pad 在CPU上速度能快接近一倍,在GPU上速度仍有微小优化。

我在代码 ctc_loss.py 中对 optax.ctc_loss ,我写的 ctc_loss ,知乎上 jax:CTC loss 实作与优化 中两个版本的代码,还有PyTorch官方代码 torch.ctc_loss ,总共四个代码进行了速度比较,如果在CPU上跑JAX并无优势,速度反而慢了接近一倍,但是在GPU上,JAX速度比PyTorch能更快一倍,在实际训练中训练速度应该能够更快:

GPU测试结果

网络架构

CRNN顾名思义,就是将网络拆分成CNN和RNN两部分:

  • CNN为Backbone部分,用于提取图像特征,由于我们的输入图像大小仅有 100×32100\times 32,所以这部分使用的是VGG模型(3×33\times 3卷积+2×22\times 2最大池化),同样能保持较小的参数量。
  • RNN为BiLSTM,假设我们的输入维度为 (B,T,N)(B,T,N),分别表示Batch大小、时间序列长度 TT 和特征维度 NN,两个LSTM分别按照维度 TT 的正向和反向对 (B,T,N)(B,T,N) 分别求出两个输出结果 (B,T,Nforward)(B,T,N_{forward})(B,T,Nbackward)(B,T,N_{backward}),最后对每个 TT 按照最后一个维度进行合并就得到了BiLSTM的输出 (B,T,Nforward+Nbackward)(B,T,N_{forward}+N_{backward})。简单来说就是创建了两个LSTM,分别对序列的特征进行了正向和反向的提取。

网络结构与论文2^2所给出的基本一致,如下图所示:
Network Struct

这里右侧Shape是我后续加的,现实了左侧层输出的宽度和高度或者特征的维数,输入图像为灰度,宽度必须为 44 的倍数,常用宽度为 100100,则输入尺度为 (B,32,100,1)(B,32,100,1) 输出尺度为 (B,24,C)(B,24,C),其中 CC 为不同字符的类别数目。做的一点改进在于将所有的激活函数换成了Mish,所有的卷积后都会跟上BatchNormalization。

代码实现

使用MJSynth数据集,该数据集共包含8919273个样本(但是我解压后损坏了29个图像),都是通过打印体字体进行数据增强来模仿真是环境中的各种字体,总共包含 6262 种字符,英文大小写共 5252 个,数字共 1010 个。代码包括两部分:

cd KataCV
python katacv/utils/ocr/translate_tfrecord.py --path-dataset "your/mjsynth_path"  # 创建tfreocrd文件到数据集同级目录下
python katacv/ocr/ocr_ctc.py --train --path-dataset-tfrecord "your/mjsynth_path/tfrecord"  # 开始训练

可以在 ocr_ctc.py 中加入 --wandb-track 参数使用 wandb 在线查看训练情况。模型训练参数可以在 katacv/ocr/constant.py 中进行修改,也可以通过 ocr_ctc.py

训练结果

训练上使用了cosine学习率调整,初始学习率为 5×1045\times 10^{-4},训练20个epochs,batch size大小为 128128,没有使用 2\ell^2 正则项:训练结果 wandb-OCR CRNN

通过贪心方法对字符进行预测,则整个字符串完全对应的准确率分别为:

  • 验证集准确率:91.23%91.23\%
  • 训练集准确率:96.9%96.9\%

CTC Loss及OCR经典算法CRNN实现
https://wty-yy.github.io/posts/62694/
作者
wty
发布于
2023年10月13日
许可协议