PSENet源码阅读笔记

论文在这里

这篇论文已经有很多人写过解析了,方法就大致说一说就好了。看这篇论文的时候学习了他们代码的实现,感觉学到了很多东西,就在这里讲一讲和代码实现有关的东西。

论文提要

我看的代码,原作者的博客讲了原理方面的内容。首先用FPN在图片中生成“推荐区域”,将“推荐区域”利用广度优先搜索进行合并,得到最后的结果。
论文中主要的东西是,使用FPN提取出图片中不同粗细的可能是文字的部分,然后使用广度优先搜索将FPN得到的几个部分给区分开。

代码实现

复现的论文神经网络部分使用的是tensorflow,广度优先搜索部分使用C++实现。
先从train.py开始看,103行定义了损失函数,在tower_loss函数中,构建了模型。模型的输出seg_maps是一个6通道的tensor,对应了论文中segmentation result。在train.py中没有引用到pse,pse在训练的过程中没有用到。
预测的过程在eval.py中。

# eval.py,第76行
def detect(seg_maps, timer, image_w, image_h, min_area_thresh=10, seg_map_thresh=0.9, ratio = 1)

其中,min_area_thresh是一个连通分量中至少有10个像素,seg_map_thresh是在返回的seg_map中,0.9以下的被变成0,以上的变为1,以此将seg_map变成二值图。在论文中出现的kernel就是图中变成1的部分。在detect函数中,调用了pse,这部分是使用C++实现的。python中调用C++使用了pybind11。

pybind11

pse文件夹中包含了pse的实现。其中include目录是pybind11的开源代码。广度优先的过程都在pse.cpp中。
在pse/init,py中队pse.cpp进行了编译。
pybind11中,规定PYBIND!!_MODULE作为一个接口,写在C++文件中,编译的时候会将函数与python中的函数绑定。

m.def("pse_cpp", &pse::pse, " re-implementation pse algorithm(cpp)", py::arg("label_map"), py::arg("Sn"), py::arg("c")=6);

第一个pse_cpp是python中绑定的函数名,第二个&pse::pse是在C++文件中待绑定的函数py::arg声明了参数以及默认值。在pse::pse中实现了一个广度优先搜索。

__init__.py中,对pse进行了一次封装。

label_num, label = cv2.connectedComponents(kernals[kernal_num - 1].astype(np.uint8), connectivity=4)

cv2.connectedComponents对模型求出来的最后一个kernel求了一次连通分量。label_num是图中连通分量的个数,label是带有标签的图,如果在连通分量里面,那个像素的值就是对应的连通分量编号,否则就是0。
接下来的for循环将小于10个像素的连通分量删除。

dataloader

在论文的readme中有一句话,只支持icdar2017的格式。

right now , only support icdar2017 data format input, like (116,1179,206,1179,206,1207,116,1207,"###"), but you can modify data_provider.py to support polygon format input

在train.py中,数据生成使用的是

data_generator = data_provider.get_batch(num_workers=FLAGS.num_readers,
                                                 input_size=FLAGS.input_size,
                                                 batch_size=FLAGS.batch_size_per_gpu * len(gpus))

get_batch中,类GeneratorEnqueuer使用的数据生成器是generator(**kargs),生成器的返回值有images, image_fns, seg_maps, training_masks,其中有用的是images, seg_maps, training_masks

读取标注的函数是data_provider.py中的load_annotation,在第280行调用了这个函数,这个函数就是读取标记用的,如果要兼容别的数据集需要修改这个函数。返回值text_polys是一个三维数组
text_polys
其中,每一层保存了一个多边形,由于数据集中只支持矩形,所以就只有四个点。第三个维度表示一张图片中存在多个文字块。text_tags是一个布尔型的数组。

text_polys, text_tags = load_annoataion(txt_fn)

在load_anotation之后调用check_and_validate_polys对text_polys和text_tags进行矫正。在这个函数中pyclipper.Area计算多边形内的面积,如果面积小于1,则舍去。使用pyclipper.Orientation使点的方向变成顺时针。
然后对图片进行随机放缩

im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale)

crop_area会随机选择一块区域,如果有文字则为样本
否则为背景。
然后,generate_seg会将文字的矩形区域缩放成6种不同的大小,供金字塔结构使用

seg_map_per_image, training_mask = generate_seg((new_h, new_w), text_polys, text_tags,
                                                                  image_list[i], scale_ratio)

在generage_seg函数中,调用了函数shrink_poly这个函数用来将ground_truth进行不同比例的缩小(论文的3.3节label generation)

shrinked_polys = []
if poly_idx not in ignore_poly_mark:
    shrinked_polys = shrink_poly(poly.copy(), scale_ratio[i])

模型实现:

在model.py中,首先建立金字塔特征:

feature_pyramid = build_feature_pyramid(end_points, weight_decay=weight_decay)

其中endpoints是resnet中几个特征图。
然后讲feature_pyramid进行concat,由于每一层的feature_pyramid的大小不一定一样,所以需要先进行缩放(unpool函数)
然后经过两个卷积层,得到seg_S_pred

发布了267 篇原创文章 · 获赞 12 · 访问量 15万+

猜你喜欢

转载自blog.csdn.net/u010734277/article/details/90813756