DBNet的简单复现

       DBNet即Real-time Scene Text Detection with Differentiable Binarization,用于OCR文本检测。网上有较多对该论文解读的文章,为了更熟悉pytorch的使用以及对DBNet更深入的理解,我利用业余时间以及参考了大佬的代码,对DBNet进行了简单版本的复现。

        简单回顾一下DBNet网络。

        网络结构:

图1  网络结构图

        Backbone采用resnet,后接FPN,再对不同尺寸的特征图进行concat,最终由两个不同的输出头给出结果。

        创新点:

        加入自适应二值化,二值化阈值由网络学习得到。

        label的生成:

        论文中对label的生成讲解得较为简洁,label分为二值化label和阈值label,这里以我自己采用的仿真数据为例子说一下。

图2  仿真数据原图

        二值化label的生成相对比较简单。

        将文本框的轮廓按照公式D=A(1-r2)/L的偏移量进行缩小,其中L是标注框的周长,A是标注框的面积,r为预设的缩放因子,论文中为0.4。

图3  仿真数据二值化label(为了显示,填充了蓝色)

        阈值label的生成

        1). 将文本框的轮廓按照上述偏移量D进行扩大

        2). 取得扩大后的轮廓的外接矩形(方便采用numpy的广播进行快速计算)

        3). 计算外接矩形中的每个点到文本框原始轮廓每条线段的距离,并取其中的最小距离

        4). 将所求的最小距离除以偏移量D进行归一化

        5). 1-归一化结果,小于0的值变为0,大于1的值变为1

        6). 因为是阈值label,需要进行缩放,论文中将1缩放到0.7,0缩放到0.3

图3  仿真数据阈值label(为了显示,该图阈值未进行缩放)

图4  仿真数据阈值label与二值化label叠加

        Loss函数:

        L = Ls + α×Lb + β×Lt,其中,Ls为概率图的loss, Lb为二值图的loss, Lt 为阈值图的loss。本文中α和β取值分别为1.0和10。Ls和Lb采用二值交叉熵(BCE)求解,并使用了hard negative mining,Lt使用的是L1 loss。

        我的复现:

        完全从零搭建了Resnet50、FPN以及DB_Head,搭建的代码不够简洁,但逻辑清晰,便于理解和修改。大佬的代码中loss函数修改为了BalanceCrossEntropyLoss、DiceLoss以及MaskL1Loss,而我的只用了pytorch自带的BCE以及L1loss。数据的读取,标签的生成,网络结构,以及模型的训练和推测都放在DBnet_pytorch.py文件中。

https://github.com/yts2020/DBnet_pytorch

在仿真数据集上进行了训练,模型的推测结果如下:

图5  模型预测结果

猜你喜欢

转载自blog.csdn.net/ytsaiztt/article/details/118090611