这篇论文已经有很多人写过解析了,方法就大致说一说就好了。看这篇论文的时候学习了他们代码的实现,感觉学到了很多东西,就在这里讲一讲和代码实现有关的东西。
论文提要
我看的代码,原作者的博客讲了原理方面的内容。首先用FPN在图片中生成“推荐区域”,将“推荐区域”利用广度优先搜索进行合并,得到最后的结果。
论文中主要的东西是,使用FPN提取出图片中不同粗细的可能是文字的部分,然后使用广度优先搜索将FPN得到的几个部分给区分开。
代码实现
复现的论文神经网络部分使用的是tensorflow,广度优先搜索部分使用C++实现。
先从train.py开始看,103行定义了损失函数,在tower_loss函数中,构建了模型。模型的输出seg_maps是一个6通道的tensor,对应了论文中segmentation result。在train.py中没有引用到pse,pse在训练的过程中没有用到。
预测的过程在eval.py中。
# eval.py,第76行
def detect(seg_maps, timer, image_w, image_h, min_area_thresh=10, seg_map_thresh=0.9, ratio = 1)
其中,min_area_thresh是一个连通分量中至少有10个像素,seg_map_thresh是在返回的seg_map中,0.9以下的被变成0,以上的变为1,以此将seg_map变成二值图。在论文中出现的kernel就是图中变成1的部分。在detect函数中,调用了pse,这部分是使用C++实现的。python中调用C++使用了pybind11。
pybind11
pse文件夹中包含了pse的实现。其中include目录是pybind11的开源代码。广度优先的过程都在pse.cpp中。
在pse/init,py中队pse.cpp进行了编译。
pybind11中,规定PYBIND!!_MODULE作为一个接口,写在C++文件中,编译的时候会将函数与python中的函数绑定。
m.def("pse_cpp", &pse::pse, " re-implementation pse algorithm(cpp)", py::arg("label_map"), py::arg("Sn"), py::arg("c")=6);
第一个pse_cpp是python中绑定的函数名,第二个&pse::pse是在C++文件中待绑定的函数py::arg声明了参数以及默认值。在pse::pse中实现了一个广度优先搜索。
__init__.py中,对pse进行了一次封装。
label_num, label = cv2.connectedComponents(kernals[kernal_num - 1].astype(np.uint8), connectivity=4)
cv2.connectedComponents对模型求出来的最后一个kernel求了一次连通分量。label_num是图中连通分量的个数,label是带有标签的图,如果在连通分量里面,那个像素的值就是对应的连通分量编号,否则就是0。
接下来的for循环将小于10个像素的连通分量删除。
dataloader
在论文的readme中有一句话,只支持icdar2017的格式。
right now , only support icdar2017 data format input, like (116,1179,206,1179,206,1207,116,1207,"###"), but you can modify data_provider.py to support polygon format input
在train.py中,数据生成使用的是
data_generator = data_provider.get_batch(num_workers=FLAGS.num_readers,
input_size=FLAGS.input_size,
batch_size=FLAGS.batch_size_per_gpu * len(gpus))
get_batch中,类GeneratorEnqueuer使用的数据生成器是generator(**kargs),生成器的返回值有images, image_fns, seg_maps, training_masks,其中有用的是images, seg_maps, training_masks
读取标注的函数是data_provider.py中的load_annotation,在第280行调用了这个函数,这个函数就是读取标记用的,如果要兼容别的数据集需要修改这个函数。返回值text_polys是一个三维数组
其中,每一层保存了一个多边形,由于数据集中只支持矩形,所以就只有四个点。第三个维度表示一张图片中存在多个文字块。text_tags是一个布尔型的数组。
text_polys, text_tags = load_annoataion(txt_fn)
在load_anotation之后调用check_and_validate_polys对text_polys和text_tags进行矫正。在这个函数中pyclipper.Area计算多边形内的面积,如果面积小于1,则舍去。使用pyclipper.Orientation使点的方向变成顺时针。
然后对图片进行随机放缩
im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale)
crop_area会随机选择一块区域,如果有文字则为样本
否则为背景。
然后,generate_seg会将文字的矩形区域缩放成6种不同的大小,供金字塔结构使用
seg_map_per_image, training_mask = generate_seg((new_h, new_w), text_polys, text_tags,
image_list[i], scale_ratio)
在generage_seg函数中,调用了函数shrink_poly这个函数用来将ground_truth进行不同比例的缩小(论文的3.3节label generation)
shrinked_polys = []
if poly_idx not in ignore_poly_mark:
shrinked_polys = shrink_poly(poly.copy(), scale_ratio[i])
模型实现:
在model.py中,首先建立金字塔特征:
feature_pyramid = build_feature_pyramid(end_points, weight_decay=weight_decay)
其中endpoints是resnet中几个特征图。
然后讲feature_pyramid进行concat,由于每一层的feature_pyramid的大小不一定一样,所以需要先进行缩放(unpool函数)
然后经过两个卷积层,得到seg_S_pred