Use SciPy peak detection to locate the table lines, then reconstruct the table structure from them

1. Use SciPy to find the peaks of the table-line projections

import cv2
import numpy as np
from scipy.signal import find_peaks, peak_widths


def get_lines_from_image(img_bin, axis, kernel_len_div = 20, kernel_len = None, iters = 3):
    """
    Keep only the long straight strokes of one orientation in an image.

    The image is eroded and then dilated (``iters`` iterations each) with a
    one-pixel-wide rectangular kernel laid out along the requested direction.

    :param img_bin: OpenCV image to filter.
    :param axis: 0 selects vertical lines, 1 selects horizontal lines.
    :param kernel_len_div: when ``kernel_len`` is not given, the kernel length
        is ``shape[axis] // kernel_len_div`` (never less than 1).
    :param kernel_len: explicit kernel length; must be positive and, when
        given, takes precedence over ``kernel_len_div``.
    :param iters: number of iterations used for both the erosion and the
        dilation.
    :return: the filtered image containing the lines of the chosen orientation.
    """
    DEBUG = True

    # Resolve the kernel length: an explicit value wins over the ratio.
    if kernel_len is None:
        kernel_length = max(np.array(img_bin).shape[axis] // kernel_len_div, 1)
    else:
        assert kernel_len > 0
        kernel_length = kernel_len

    # getStructuringElement takes (width, height): a 1-wide, tall kernel
    # responds to vertical strokes, a wide, 1-tall kernel to horizontal ones.
    if axis == 0:
        kernel_size, debug_name = (1, kernel_length), "verticle_lines.jpg"
    else:
        kernel_size, debug_name = (kernel_length, 1), "horizontal_lines.jpg"

    line_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, kernel_size)
    eroded = cv2.erode(img_bin, line_kernel, iterations=iters)
    lines_img = cv2.dilate(eroded, line_kernel, iterations=iters)

    if DEBUG:
        cv2.imwrite(debug_name, lines_img)
    return lines_img

def line_img_add(verticle_lines_img, horizontal_lines_img):
    """Blend the detected vertical-line and horizontal-line images 50/50.

    :param verticle_lines_img: image holding the vertical lines.
    :param horizontal_lines_img: image holding the horizontal lines.
    :return: the weighted sum of both images (each weighted 0.5, no offset).
    """
    vertical_weight = 0.5
    horizontal_weight = 1.0 - vertical_weight
    return cv2.addWeighted(
        verticle_lines_img, vertical_weight,
        horizontal_lines_img, horizontal_weight,
        0.0,
    )


def project(np_arr, axis):
    """Count the zero-valued entries of ``np_arr`` along ``axis``.

    With a 2-D image, ``axis=0`` collapses the rows (one count per column,
    i.e. the vertical projection) and ``axis=1`` collapses the columns (one
    count per row, i.e. the horizontal projection).
    """
    is_zero = np_arr == 0
    zero_counts = np.count_nonzero(is_zero, axis=axis)
    return zero_counts

def get_grid_coordinate(img_bin, prominence_ratio = 0.3, height_ratio=None, distance=None, DEBUG=0):
    """
    Locate the grid lines of a table image and estimate their thickness.

    Zero-valued pixels are counted per column and per row; the peaks of those
    two profiles are taken as the x and y positions of the grid lines.

    :param img_bin: 2-D image with black (0) lines on a white background.
    :param prominence_ratio: minimum peak prominence, expressed as a fraction
        of the image height (column profile) or width (row profile).
    :param height_ratio: optional minimum peak height, using the same
        fraction convention; ``None`` disables the height filter.
    :param distance: optional minimum distance between neighbouring peaks,
        forwarded to ``scipy.signal.find_peaks``.
    :param DEBUG: when truthy, print intermediate values and plot both
        profiles with matplotlib.
    :return: ``(x_peaks, y_peaks, line_width)`` where the first two are lists
        of plain ``int`` coordinates and ``line_width`` is the estimated line
        thickness in pixels.
    """
    h, w = img_bin.shape

    # Build the zero mask once; it feeds both projections and the line-width
    # estimate below.
    zero_mask = img_bin == 0
    x_prj = np.count_nonzero(zero_mask, axis=0)  # zeros per column
    y_prj = np.count_nonzero(zero_mask, axis=1)  # zeros per row

    # A vertical line can be at most h pixels tall and a horizontal one at
    # most w pixels wide, hence the h/w scaling of the thresholds.
    height_x = height_y = None
    if height_ratio is not None:
        height_x = height_ratio * h
        height_y = height_ratio * w
    if DEBUG:
        print('height_x,height_y:', height_x, height_y)

    x_found, _ = find_peaks(x_prj, height=height_x, distance=distance,
                            prominence=(h * prominence_ratio, None))
    y_found, _ = find_peaks(y_prj, height=height_y, distance=distance,
                            prominence=(w * prominence_ratio, None))

    if DEBUG:
        # Imported lazily so matplotlib is only required when debugging.
        import matplotlib.pyplot as plt
        plt.subplot(211)
        plt.title("x")
        print('range(x_prj.shape[0]):', range(x_prj.shape[0]))
        plt.plot(range(x_prj.shape[0]), x_prj)
        plt.plot(x_found, x_prj[x_found], "x")
        plt.subplot(212)
        plt.title("y")
        plt.plot(range(y_prj.shape[0]), y_prj)
        plt.plot(y_found, y_prj[y_found], "x")
        plt.show()

    # tolist() yields built-in ints instead of NumPy scalars.
    x_peaks = x_found.tolist()
    y_peaks = y_found.tolist()

    # No peak found: fall back to the image borders as the only grid lines.
    if not x_peaks:
        x_peaks = [0, w]
        print("x_peaks is None !!!!!!!")
    if not y_peaks:
        y_peaks = [0, h]
        print("y_peaks is None !!!!!!!")

    # Line-width estimate, assuming a constant width x, m horizontal lines
    # and n vertical lines in an h-by-w table:
    #   n_zero = m*w*x + n*h*x - m*n*x^2  ~=  (m*w + n*h) * x
    m, n = len(y_peaks), len(x_peaks)
    covered = m * w + n * h
    # covered is 0 only for an empty image; avoid dividing by zero then.
    line_width = np.count_nonzero(zero_mask) / covered if covered else 0
    line_width = round(line_width) + 1
    return x_peaks, y_peaks, line_width

if __name__ == '__main__':
    path = './test_page_debug_out_debug/table_crop_fix_rm_char.jpg'
    img = cv2.imread(path)
    # cv2.imread signals failure by returning None instead of raising.
    if img is None:
        raise FileNotFoundError('cannot read image: ' + path)
    img_bin = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    # Extract the vertical and horizontal table lines separately, then blend
    # them into a single line image.
    verticle_lines_img = get_lines_from_image(img_bin, 0, kernel_len_div=40)
    horizontal_lines_img = get_lines_from_image(img_bin, 1, kernel_len_div=40)
    img_final_bin_lines = line_img_add(verticle_lines_img, horizontal_lines_img)
    cv2.imwrite('./img_final_bin_lines.jpg', img_final_bin_lines)

    # Invert the blended image and erode it (erosion grows the dark regions),
    # then binarize. With THRESH_OTSU the threshold is chosen automatically
    # and the 128 passed here is not used.
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
    img_final_bin_lines = cv2.erode(~img_final_bin_lines, kernel, iterations=2)
    (thresh, img_final_bin_lines) = cv2.threshold(img_final_bin_lines, 128, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
    cv2.imwrite('./img_final_bin_lines_fix.jpg', img_final_bin_lines)

    # Compute the grid-point coordinates from the extracted table lines.
    x_grids, y_grids, line_w = get_grid_coordinate(img_final_bin_lines)
    

Input image:

Extract vertical lines:

Extract horizontal lines:

Peak search for horizontal and vertical lines:

2. Restore the table structure

import cv2
from PIL import Image
import numpy as np
import os
import os.path as osp
from scipy.signal import find_peaks, peak_widths

debug = True

def get_lines_from_image(img_bin, axis, kernel_len_div=20, kernel_len=None, iters=3):
    """
    Extract the straight lines of one orientation from an image.

    An erosion followed by a dilation (``iters`` iterations each) is applied
    with a one-pixel-wide rectangular kernel oriented along the wanted lines.

    :param img_bin: OpenCV image to filter.
    :param axis: 0 for vertical lines, 1 for horizontal lines.
    :param kernel_len_div: divisor applied to ``shape[axis]`` to derive the
        kernel length (at least 1) when ``kernel_len`` is ``None``.
    :param kernel_len: kernel length given directly; must be positive and
        overrides ``kernel_len_div`` when supplied.
    :param iters: iteration count for both morphological operations.
    :return: image containing only the lines of the selected orientation.
    """
    DEBUG = 0

    if kernel_len is not None:
        assert kernel_len > 0
    kernel_length = (
        kernel_len
        if kernel_len is not None
        else max(np.array(img_bin).shape[axis] // kernel_len_div, 1)
    )

    # Kernel size is (width, height): tall-and-thin for vertical lines,
    # wide-and-flat for horizontal ones.
    want_vertical = axis == 0
    kernel_size = (1, kernel_length) if want_vertical else (kernel_length, 1)
    structuring = cv2.getStructuringElement(cv2.MORPH_RECT, kernel_size)

    result = cv2.dilate(
        cv2.erode(img_bin, structuring, iterations=iters),
        structuring,
        iterations=iters,
    )

    if DEBUG:
        dump_name = "verticle_lines.jpg" if want_vertical else "horizontal_lines.jpg"
        cv2.imwrite(dump_name, result)
    return result

def line_img_add(verticle_lines_img, horizontal_lines_img):
    """Combine the vertical-line and horizontal-line images into one.

    Both inputs are weighted equally (0.5 each) and summed with no offset.

    :param verticle_lines_img: image containing the vertical lines.
    :param horizontal_lines_img: image containing the horizontal lines.
    :return: the blended image.
    """
    # The two weights sum to 1.
    w_vertical = 0.5
    w_horizontal = 1.0 - w_vertical
    blended = cv2.addWeighted(verticle_lines_img, w_vertical, horizontal_lines_img, w_horizontal, 0.0)
    return blended

def project(np_arr, axis):
    """Return, along ``axis``, how many entries of ``np_arr`` equal zero.

    For a 2-D image ``axis=0`` gives one count per column (vertical
    projection) and ``axis=1`` one count per row (horizontal projection).
    """
    zero_mask = np_arr == 0
    return np.count_nonzero(zero_mask, axis=axis)

def get_grid_coordinate(img_bin, prominence_ratio=0.3, height_ratio=None, distance=None):
    """
    Compute the x/y coordinates of the table grid lines and the line width.

    Zero-valued pixels are counted per column and per row; the peaks of those
    two profiles are taken as the positions of the grid lines.

    :param img_bin: 2-D image with black (0) lines on a white background.
    :param prominence_ratio: minimum peak prominence, as a fraction of the
        image height (column profile) or width (row profile).
    :param height_ratio: optional minimum peak height with the same fraction
        convention; ``None`` disables the height filter.
    :param distance: optional minimum distance between neighbouring peaks,
        forwarded to ``scipy.signal.find_peaks``.
    :return: ``(x_peaks, y_peaks, line_width)``: two lists of plain ``int``
        coordinates and the estimated line thickness in pixels (at least 1).
    """
    h, w = img_bin.shape
    DEBUG = False
    if DEBUG:
        cv2.imwrite('table_crop.jpg', img_bin)

    # Build the zero mask once; it feeds both projections and the line-width
    # estimate below.
    zero_mask = img_bin == 0
    x_prj = np.count_nonzero(zero_mask, axis=0)  # zeros per column
    y_prj = np.count_nonzero(zero_mask, axis=1)  # zeros per row

    # A vertical line can be at most h pixels tall and a horizontal one at
    # most w pixels wide, hence the h/w scaling of the thresholds.
    height_x = height_y = None
    if height_ratio is not None:
        height_x = height_ratio * h
        height_y = height_ratio * w
    x_found, _ = find_peaks(x_prj, height=height_x, distance=distance,
                            prominence=(h * prominence_ratio, None))
    y_found, _ = find_peaks(y_prj, height=height_y, distance=distance,
                            prominence=(w * prominence_ratio, None))

    if DEBUG:
        # Imported lazily so matplotlib is only required when debugging.
        import matplotlib.pyplot as plt
        plt.subplot(211)
        plt.title("x")
        plt.plot(range(x_prj.shape[0]), x_prj)
        plt.plot(x_found, x_prj[x_found], "x")
        plt.subplot(212)
        plt.title("y")
        plt.plot(range(y_prj.shape[0]), y_prj)
        plt.plot(y_found, y_prj[y_found], "x")
        plt.show()

    # tolist() yields built-in ints instead of NumPy scalars, so the
    # coordinates stay usable where NumPy integer types are not accepted
    # (e.g. JSON serialisation of the resulting cells).
    x_peaks = x_found.tolist()
    y_peaks = y_found.tolist()

    # No peak found: fall back to the image borders as the only grid lines.
    if not x_peaks:
        x_peaks = [0, w]
    if not y_peaks:
        y_peaks = [0, h]

    # Line-width estimate, assuming a constant width x, m horizontal lines
    # and n vertical lines in an h-by-w table:
    #   n_zero = m*w*x + n*h*x - m*n*x^2  ~=  (m*w + n*h) * x
    m, n = len(y_peaks), len(x_peaks)
    covered = m * w + n * h
    # covered is 0 only for an empty image; avoid dividing by zero then.
    line_width = np.count_nonzero(zero_mask) / covered if covered else 0
    line_width = max(round(line_width), 1)
    return x_peaks, y_peaks, line_width
def check_line_exist(img_bin, pt1, pt2, width, threshold=0.5, DEBUG=0):
    """
    Decide whether a line is actually drawn between two grid points.

    An ideal line of thickness ``width`` is rasterised between the two points
    and compared with the (inverted, slightly dilated) image; the line is
    considered present when more than ``threshold`` of the ideal pixels are
    covered.

    :param img_bin: 2-D image; the inversion below assumes dark lines on a
        light background.
    :param pt1: ``(x, y)`` of the first end point.
    :param pt2: ``(x, y)`` of the second end point.
    :param width: line thickness in pixels used for the ideal line.
    :param threshold: minimum covered fraction for the line to count.
    :param DEBUG: when truthy, dump the intermediate images to disk.
    :return: ``True`` if the line exists, ``False`` otherwise (including when
        the region around the points is empty).
    """
    # Crop to the neighbourhood of the segment to speed up the pixel work.
    h, w = img_bin.shape
    margin = width * 2
    x1 = max(0, min(pt1[0], pt2[0]) - margin)
    y1 = max(0, min(pt1[1], pt2[1]) - margin)
    x2 = min(w - 1, max(pt1[0], pt2[0]) + margin)
    y2 = min(h - 1, max(pt1[1], pt2[1]) + margin)
    crop = img_bin[y1: y2, x1: x2].copy()
    if crop.size == 0:
        # Nothing to inspect; also avoids a division by zero further down.
        return False

    # Shift the end points into crop coordinates. int() turns NumPy scalars
    # into the built-in ints that cv2.line expects for point tuples.
    pt1 = (int(pt1[0] - x1), int(pt1[1] - y1))
    pt2 = (int(pt2[0] - x1), int(pt2[1] - y1))
    if DEBUG:
        cv2.imwrite('./img_bin_after_crop.jpg', crop)

    # Rasterise the ideal line.
    line_mask = np.zeros_like(crop)
    cv2.line(line_mask, pt1, pt2, color=(255, 255, 255), thickness=width)
    mask_cnt = np.count_nonzero(line_mask)
    if mask_cnt == 0:
        # The ideal line fell entirely outside the crop.
        return False

    # Invert so drawn pixels become non-zero, then dilate so a line that is
    # a few pixels off the ideal position still overlaps the mask.
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))
    img_bin_tmp = cv2.dilate(~crop, kernel, iterations=1)
    img_after_mask = cv2.bitwise_and(line_mask, img_bin_tmp)
    and_cnt = np.count_nonzero(img_after_mask)
    if DEBUG:
        cv2.imwrite("./line_mask.jpg", line_mask)
        cv2.imwrite('./img_after_mask.jpg', img_after_mask)
        cv2.imwrite("./img_bin_tmp.jpg", img_bin_tmp)
    return (and_cnt / mask_cnt) > threshold

def get_table_structure(img_final_bin_lines, x_grids, y_grids, line_w):
    """
    Infer the cell structure (including merged cells) of a ruled table.

    Every interior grid segment is tested with ``check_line_exist``; when a
    segment is missing, the two cells on either side of it are joined into
    the same region. Each resulting region is reported as one cell.

    :param img_final_bin_lines: binary image of the table lines.
    :param x_grids: x coordinates of the vertical grid lines, left to right.
    :param y_grids: y coordinates of the horizontal grid lines, top to bottom.
    :param line_w: line thickness in pixels, forwarded to ``check_line_exist``.
    :return: ``(cell_id_mark, rst)``. ``cell_id_mark`` is an int array of
        shape ``(len(y_grids) - 1, len(x_grids) - 1)`` holding the region id
        of every grid cell. ``rst`` is a list of dicts, one per region, with
        the keys ``id``, ``row_start``, ``col_start``, ``row_end``,
        ``col_end`` (grid indices, end exclusive), ``x1``, ``y1``, ``x2``,
        ``y2`` (pixel coordinates) and the placeholders ``crnn`` / ``text``.
        A region is reported by its bounding box, i.e. merged regions are
        assumed to be rectangular.
    """
    DEBUG = 0
    n_x = len(x_grids)
    n_y = len(y_grids)
    if DEBUG:
        print("n_x, n_y", n_x, n_y)

    # max(..., 0) keeps the shape valid when a grid list is empty.
    cell_id_mark = np.full((max(n_y - 1, 0), max(n_x - 1, 0)), -1, dtype=int)  # region id per cell, -1 = none yet
    cell_id_sets = [set() for _ in range(n_x * n_y)]  # cells belonging to each region id
    next_id = 0

    def join_cells(first, second):
        """Put two (row, col) cells, and the regions they belong to, under one id."""
        nonlocal next_id
        first_id = int(cell_id_mark[first])
        second_id = int(cell_id_mark[second])
        if first_id == -1 and second_id == -1:
            keep_id = next_id
            next_id += 1
        elif first_id == -1:
            keep_id = second_id
        else:
            keep_id = first_id

        for cell, old_id in ((first, first_id), (second, second_id)):
            if old_id == keep_id:
                continue
            if old_id == -1:
                moved = {cell}
            else:
                # Absorb the whole region, not just this one cell, so that
                # cell_id_mark and cell_id_sets stay consistent.
                moved = cell_id_sets[old_id]
                cell_id_sets[old_id] = set()
            for moved_cell in moved:
                cell_id_mark[moved_cell] = keep_id
            cell_id_sets[keep_id] |= moved

    # Interior vertical segments: a missing one joins the left and right cells.
    for x_id in range(1, n_x - 1):
        x = x_grids[x_id]
        for y_id in range(n_y - 1):
            if not check_line_exist(img_final_bin_lines, (x, y_grids[y_id]), (x, y_grids[y_id + 1]),
                                    width=line_w, threshold=0.5, DEBUG=False):
                if DEBUG:
                    print("没有发现竖直线:x_id,y_id", x_id, y_id)
                join_cells((y_id, x_id - 1), (y_id, x_id))

    # Interior horizontal segments: a missing one joins the upper and lower cells.
    for y_id in range(1, n_y - 1):
        y = y_grids[y_id]
        for x_id in range(n_x - 1):
            if not check_line_exist(img_final_bin_lines, (x_grids[x_id + 1], y), (x_grids[x_id], y),
                                    width=line_w, threshold=0.5, DEBUG=False):
                if DEBUG:
                    print("======没有发现水平线,x_id, y_id", x_id, y_id)
                join_cells((y_id - 1, x_id), (y_id, x_id))

    # Every cell that was never joined becomes its own region
    # (ids are handed out column by column).
    for x_id in range(n_x - 1):
        for y_id in range(n_y - 1):
            if cell_id_mark[y_id, x_id] == -1:
                cell_id_mark[y_id, x_id] = next_id
                cell_id_sets[next_id].add((y_id, x_id))
                next_id += 1

    if DEBUG:
        print('==x_grids:', x_grids)
        print('==y_grids:', y_grids)
        print('==cell_id_mark:', cell_id_mark)
        print('==cell_id_sets:', cell_id_sets)
        print('==id:', next_id)

    # Build the output: one entry per non-empty region (ids whose region was
    # absorbed by another one are skipped).
    rst = []
    for cell_id in range(next_id):
        members = cell_id_sets[cell_id]
        if not members:
            continue
        # Take the extent per axis; min/max over the (row, col) tuples
        # themselves would order by row first and could pick a wrong column.
        rows = [row for row, _ in members]
        cols = [col for _, col in members]
        row_min, row_max = min(rows), max(rows)
        col_min, col_max = min(cols), max(cols)
        rst.append({
            "id": cell_id,
            "row_start": row_min,  # grid (structural) coordinates
            "col_start": col_min,
            "row_end": row_max + 1,
            "col_end": col_max + 1,
            "x1": x_grids[col_min],  # pixel coordinates
            "y1": y_grids[row_min],
            "x2": x_grids[col_max + 1],
            "y2": y_grids[row_max + 1],
            "crnn": [],  # filled in later
            "text": "",  # filled in later
        })
    return cell_id_mark, rst



def box_extraction(cv_img):
    """
    Extract the structure of a table that has ruled borders.

    :param cv_img: OpenCV image of the table; a 3-channel image is converted
        to grayscale first.
    :return: ``(x_grids, y_grids, cell_id_mark, rst)`` - the grid-line
        coordinates from ``get_grid_coordinate`` followed by the two values
        returned by ``get_table_structure``.
    """
    if len(cv_img.shape) == 3:
        cv_img = cv2.cvtColor(cv_img, cv2.COLOR_BGR2GRAY)

    # Binarize with an adaptive threshold, then invert so that dark strokes
    # become the non-zero foreground.
    img_bin = cv2.adaptiveThreshold(cv_img, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                    cv2.THRESH_BINARY, 11, 2)
    img_bin = 255 - img_bin

    # Noise removal, pass 1: erase contours whose bounding box is smaller
    # than 1/30 of the image in both directions.
    # findContours returns (image, contours, hierarchy) on OpenCV 3.x and
    # (contours, hierarchy) on 4.x; [-2] selects the contours on either.
    contours = cv2.findContours(img_bin, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)[-2]
    mask = np.ones(img_bin.shape[:2], dtype="uint8") * 255
    th_w = img_bin.shape[1] / 30
    th_h = img_bin.shape[0] / 30
    for c in contours:
        x, y, w, h = cv2.boundingRect(c)
        if w < th_w and h < th_h:
            cv2.drawContours(mask, [c], -1, 0, -1)
    img_bin = cv2.bitwise_and(img_bin, img_bin, mask=mask)

    if debug:
        cv2.imwrite('./img_bin_no_noise.jpg', img_bin)

    # Noise removal, pass 2: after a light dilation, erase every contour
    # whose area is below (width / 5) * (height / 5).
    kernel = cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3))
    img_bin = cv2.dilate(img_bin, kernel, iterations=1)
    contours = cv2.findContours(img_bin, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)[-2]
    mask = np.ones(img_bin.shape[:2], dtype="uint8") * 255
    th_w = img_bin.shape[1] / 5
    th_h = img_bin.shape[0] / 5
    for c in contours:
        if cv2.contourArea(c) < th_w * th_h:
            cv2.drawContours(mask, [c], -1, 0, -1)
    img_bin = cv2.bitwise_and(img_bin, img_bin, mask=mask)

    if debug:
        cv2.imwrite("img_remove_noise2.jpg", img_bin)

    # Extract the vertical and horizontal table lines and blend them.
    verticle_lines_img = get_lines_from_image(img_bin, 0, kernel_len_div=40)
    horizontal_lines_img = get_lines_from_image(img_bin, 1, kernel_len_div=40)
    img_final_bin_lines = line_img_add(verticle_lines_img, horizontal_lines_img)

    # Invert back to dark lines on a light background; eroding that image
    # thickens the dark lines. Then binarize (with THRESH_OTSU the threshold
    # is chosen automatically and the 128 passed here is not used).
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
    img_final_bin_lines = cv2.erode(~img_final_bin_lines, kernel, iterations=2)
    (thresh, img_final_bin_lines) = cv2.threshold(img_final_bin_lines, 128, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
    if debug:
        cv2.imwrite("img_final_bin_lines.jpg", img_final_bin_lines)

    # Grid-line coordinates and line width from the line image.
    x_grids, y_grids, line_w = get_grid_coordinate(img_final_bin_lines)
    if debug:
        print('===x_grids, y_grids, line_w:', x_grids, y_grids, line_w)

    cell_id_mark, rst = get_table_structure(img_final_bin_lines, x_grids, y_grids, line_w)

    return x_grids, y_grids, cell_id_mark, rst

def debug_single_img():
    """Run ``box_extraction`` on one sample image and print what it returns.

    :raises FileNotFoundError: if the sample image cannot be read.
    """
    # img_path = './table_crop.jpg'
    img_path = './table_crop2.png'
    img = cv2.imread(img_path, 0)  # flag 0: load as grayscale
    # cv2.imread signals failure by returning None instead of raising.
    if img is None:
        raise FileNotFoundError('cannot read image: ' + img_path)
    x_grids, y_grids, cell_id_mark, rst = box_extraction(img)
    print('==x_grids:', x_grids)
    print('==y_grids:', y_grids)
    print('==cell_id_mark:', cell_id_mark)
    print('==rst:', rst)

# Run the single-image demo only when this file is executed as a script.
if __name__ == '__main__':
    debug_single_img()

After feeding `rst` into the code from this blog post, the corresponding Excel file can be restored.

 

Guess you like

Origin blog.csdn.net/fanzonghao/article/details/104135755