Segment a table image into its cells, recognize each cell's content with deep learning, and write the results to Excel files

I wrote this group assignment for an artificial intelligence course in less than two days.
With guidance from more experienced classmates, I ended up with a result that is barely presentable.
Both a DNN and a CNN were used, but the recognition results were not very good, possibly because the input images are not similar enough to the training set.
Because the code was written in a hurry, its second half still contains many redundant parts that have not been removed. I will polish it when I have time.
Features implemented so far:

  1. Split the table into cells, binarize each cell, and save it locally
  2. Read the saved pictures and recognize the digits (limited to a single character per cell).
    The following is my crude code; please bear with it.
from numpy.lib.function_base import average
from openpyxl import Workbook
import cv2
import numpy as np
import io
import tensorflow as tf
import csv
import openpyxl

# Load the table picture as a 3-channel BGR image (flag 1 == cv2.IMREAD_COLOR).
img=cv2.imread('./4.png',1)
if img is None:
    # cv2.imread reports an unreadable/missing file by returning None instead
    # of raising, which would otherwise surface later as an AttributeError.
    raise FileNotFoundError("could not read image './4.png'")
def showing(img):
    """Show ``img`` in an OpenCV window, wait for a key press, then close it."""
    window_title = 'showing.jpg'
    cv2.imshow(window_title, img)
    # A delay of 0 makes waitKey wait for a key event with no timeout.
    cv2.waitKey(0)
    cv2.destroyAllWindows()

def horizon(image=None):
    """Return the row indices where horizontal table lines begin.

    A line is reported at row ``i`` when the mean pixel value of row ``i``
    and row ``i + 1`` differ by more than 60% of full scale (0.6 * 255).
    After a hit, the scan skips forward to the next row whose mean is above
    that same threshold, so a thick line is reported only once.

    Args:
        image: Array indexed as ``image[row]``. Defaults to the module-level
            ``img`` when omitted.

    Returns:
        list[int]: Row indices, in increasing order.
    """
    if image is None:
        image = img  # fall back to the module-level table picture

    height = image.shape[0]
    threshold = 0.6 * 255
    # Compute every row mean once instead of re-averaging inside the loop.
    row_means = [np.average(row) for row in image]

    horizontal_lines = []
    i = 0
    # Same stop position as the original (row height-2), but expressed as a
    # loop bound so it cannot be stepped over when the skip below jumps ahead.
    while i < height - 2:
        if abs(row_means[i] - row_means[i + 1]) > threshold:
            horizontal_lines.append(i)
            i += 1
            # Skip the dark rows of the line; bounded so a line that touches
            # the bottom edge cannot index past the image.
            while i < height and row_means[i] <= threshold:
                i += 1
        else:
            i += 1
    return horizontal_lines

def vertical(image=None):
    """Return the column indices where vertical table lines begin.

    A line is reported at column ``i`` when the mean pixel value of column
    ``i`` and column ``i + 1`` differ by more than 60% of full scale
    (0.6 * 255). After a hit, the scan skips forward to the next column whose
    mean is above that same threshold, so a thick line is reported only once.

    Args:
        image: Array indexed as ``image[:, column]``. Defaults to the
            module-level ``img`` when omitted.

    Returns:
        list[int]: Column indices, in increasing order.
    """
    if image is None:
        image = img  # fall back to the module-level table picture

    width = image.shape[1]
    threshold = 0.6 * 255
    # Compute every column mean once instead of re-averaging inside the loop.
    col_means = [np.average(image[:, j]) for j in range(width)]

    vertical_lines = []
    i = 0
    # Same stop position as the original (column width-2), but expressed as a
    # loop bound so it cannot be stepped over when the skip below jumps ahead.
    while i < width - 2:
        if abs(col_means[i] - col_means[i + 1]) > threshold:
            vertical_lines.append(i)
            i += 1
            # Skip the dark columns of the line; bounded so a line that
            # touches the right edge cannot index past the image.
            while i < width and col_means[i] <= threshold:
                i += 1
        else:
            i += 1
    return vertical_lines

def sections(vertical_lines=None, horizontal_lines=None):
    """Return the corner coordinates of every table cell.

    Also updates three module-level globals:
      * ``label_x`` - number of vertical lines minus one
      * ``label_y`` - number of horizontal lines minus one
      * ``intersection`` - every ``(x, y)`` line crossing, x-major order

    Args:
        vertical_lines: x positions of the vertical lines. Defaults to
            ``vertical()``.
        horizontal_lines: y positions of the horizontal lines. Defaults to
            ``horizon()``.

    Returns:
        list[tuple]: ``((left, top), (right, bottom))`` for each cell, ordered
        column by column and top to bottom within a column.
    """
    global label_x, label_y, intersection

    # Detect each set of lines once; the image scan is the expensive part.
    xs = vertical() if vertical_lines is None else list(vertical_lines)
    ys = horizon() if horizontal_lines is None else list(horizontal_lines)

    label_y = len(ys) - 1
    label_x = len(xs) - 1
    intersection = [(x, y) for x in xs for y in ys]

    # A cell is bounded by two neighbouring vertical lines and two
    # neighbouring horizontal lines, so pair each line with its successor.
    return [
        ((left, top), (right, bottom))
        for left, right in zip(xs, xs[1:])
        for top, bottom in zip(ys, ys[1:])
    ]

def segment1():
    """Cut the module-level ``img`` into one sub-image per table cell.

    The cell rectangles come from a single call to ``sections()``. The
    resulting list is stored in the global ``list_segmented_img`` and also
    returned.

    Returns:
        list: Image slices, in the same order as ``sections()`` reports cells.
    """
    global list_segmented_img
    # sections() rescans the whole image, so call it once rather than once
    # per cell as before.
    list_segmented_img = [
        img[int(top):int(bottom), int(left):int(right)]
        for (left, top), (right, bottom) in sections()
    ]
    return list_segmented_img

def segment2(top_margin=20, side=28):
    """Crop every cell image towards a square and resize it.

    Reads the global ``list_segmented_img`` and stores the results in the
    global ``list_cropped_img``, which is also returned.

    For cells wider than tall, the first ``top_margin`` rows are dropped and
    a horizontally centred band ``height`` columns wide is kept. For the other
    cells a vertically centred band ``width`` rows tall is kept (never
    starting above ``top_margin``) together with all columns.

    Args:
        top_margin: Number of rows skipped at the top of each cell.
        side: Edge length, in pixels, of the resized output images.

    Returns:
        list: The resized ``side`` x ``side`` images.
    """
    global list_cropped_img
    list_cropped_img = []
    for segmented_img in list_segmented_img:
        height, width = segmented_img.shape[:2]
        if width > height:
            # Wide cell: keep a centred band of `height` columns.
            col_start = abs((height - width) // 2)
            col_end = col_start + height
            cropped_img = segmented_img[top_margin:, col_start:col_end]
        else:
            # Tall or square cell: the centring offsets apply to ROWS. The
            # previous code sliced the column axis with them, which shifted
            # (or emptied) the crop. For a square cell this is still
            # rows top_margin.. and every column, exactly as before.
            row_start = (height - width) // 2
            row_end = row_start + width
            cropped_img = segmented_img[max(row_start, top_margin):row_end, :]
        list_cropped_img.append(cv2.resize(cropped_img, (side, side)))
    return list_cropped_img

def blacknwhite(imge):
    """Convert a BGR image to an inverted black-and-white image.

    The image is converted to grayscale, binarized with the
    THRESH_BINARY | THRESH_OTSU flags, and then bitwise-inverted.
    """
    grayscale = cv2.cvtColor(imge, cv2.COLOR_BGR2GRAY)
    binarize_flags = cv2.THRESH_BINARY | cv2.THRESH_OTSU
    # threshold() returns (threshold_used, image); only the image is needed.
    _, binary = cv2.threshold(grayscale, 255, 255, binarize_flags)
    return cv2.bitwise_not(binary)

def blacknwhite_all():
    """Apply ``blacknwhite`` to every image in the global ``list_cropped_img``.

    The converted images are stored in the global
    ``list_cropped_refined_img`` and returned.
    """
    global list_cropped_refined_img
    list_cropped_refined_img = [blacknwhite(cropped) for cropped in list_cropped_img]
    return list_cropped_refined_img

def save():
    """Write each image in ``list_cropped_refined_img`` to ``<n>.jpg``.

    Files are numbered from 1 in list order and written to the current
    working directory.
    """
    # The f-string is kept on one line: a replacement field split across
    # lines is a SyntaxError before Python 3.12.
    for index, image in enumerate(list_cropped_refined_img, start=1):
        cv2.imwrite(f"{index}.jpg", image)

# Detect the table grid; this also sets label_x, label_y and intersection.
sections()
# Single-line f-string: a replacement field split across lines is a
# SyntaxError before Python 3.12.
print(f'The size of the label is {label_x} * {label_y}={label_x*label_y}')
# Cut out every cell, crop and resize it, binarize it, and write the JPEGs.
segment1()
segment2()
blacknwhite_all()
save()


#########################The steps above crop and preprocess the input image; next, a DNN is used to recognize the digit at the center of each cell image##############################

def train(epochs=15):
    """Train a small dense network on MNIST and store it in the global ``model``.

    The network is Flatten -> Dense(128, relu) -> Dense(128, relu) ->
    Dense(10, softmax). After training, the test-set loss and accuracy are
    printed.

    Args:
        epochs: Number of training epochs.
    """
    global model
    mnist = tf.keras.datasets.mnist
    (x_train, y_train), (x_test, y_test) = mnist.load_data()

    # Scale pixels by 1/255, the same preprocessing read_img() applies to the
    # images being classified. The previous tf.keras.utils.normalize call
    # L2-normalized along an axis instead, so the model was trained on a
    # different input distribution than the one it later predicts on.
    x_train = x_train / 255.0
    x_test = x_test / 255.0

    # DNN: 128 units -> 128 units -> 10 output classes.
    model = tf.keras.models.Sequential()
    model.add(tf.keras.layers.Flatten())
    model.add(tf.keras.layers.Dense(128, activation=tf.nn.relu))
    model.add(tf.keras.layers.Dense(128, activation=tf.nn.relu))
    model.add(tf.keras.layers.Dense(10, activation=tf.nn.softmax))
    model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )
    model.fit(x_train, y_train, epochs=epochs)

    val_loss, val_acc = model.evaluate(x_test, y_test)
    print('loss=',val_loss,'acc=',val_acc)

def read_img(img_path):
    """Load an image as a ``(1, 28, 28, 1)`` array scaled by 1/255.

    Args:
        img_path: Path of the image file to read.

    Returns:
        The grayscale pixels divided by 255 and reshaped to
        ``(1, 28, 28, 1)``.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``img_path``.
    """
    pixels = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    if pixels is None:
        # cv2.imread returns None on failure; without this check the division
        # below fails with an unrelated TypeError.
        raise FileNotFoundError(f"could not read image {img_path!r}")
    if pixels.shape != (28, 28):
        # reshape(-1, 28, 28, 1) would otherwise crash or silently split a
        # larger picture into several meaningless 28x28 "images".
        pixels = cv2.resize(pixels, (28, 28))
    return (pixels / 255).reshape(-1, 28, 28, 1)

def predicting():
    """Classify the saved cell images with the global ``model``.

    Reads ``./1.jpg`` .. ``./<label_x * label_y>.jpg`` through ``read_img``
    and stores the arg-max class of each prediction in the global list
    ``result1`` (the recognition result).
    """
    global result1
    count = int(label_x * label_y)
    # Single-line f-string: a replacement field split across lines is a
    # SyntaxError before Python 3.12.
    result1 = [
        np.argmax(model.predict(read_img(f'./{i + 1}.jpg')))
        for i in range(count)
    ]

# Train the DNN on MNIST, then classify the saved cell images into result1.
train()
predicting()
##################To improve the recognition results, a CNN is used for recognition############################
import re
import tensorflow as tf
# from tensorflow.python.keras.backend import reshape
from tensorflow.python.keras.models import Sequential
from tensorflow.python.keras.layers import Dense, Conv2D, Flatten, MaxPooling2D, Dropout
import numpy as np

# Build a CNN: two Conv2D/MaxPooling stages, dropout, Dense(100), Dense(10).
def predict_cnn(epochs=15):
    """Train a CNN on MNIST and classify the saved cell images.

    The arg-max class for ``./1.jpg`` .. ``./<label_x * label_y>.jpg`` is
    stored in the global list ``result2``.

    Args:
        epochs: Number of training epochs.
    """
    global result2

    # Use the public tf.keras API throughout; the loss below is a tf.keras
    # object, and mixing it with the private tensorflow.python.keras classes
    # is not a supported combination.
    layers = tf.keras.layers
    model = tf.keras.models.Sequential()
    model.add(layers.Conv2D(10, (5, 5), activation='relu', input_shape=(28, 28, 1)))
    model.add(layers.MaxPooling2D(pool_size=(2, 2)))

    model.add(layers.Conv2D(20, (5, 5), activation='relu'))
    model.add(layers.MaxPooling2D(pool_size=(2, 2)))
    model.add(layers.Dropout(0.25))
    model.add(layers.Flatten())
    model.add(layers.Dense(100, activation='relu'))
    model.add(layers.Dense(10, activation='softmax'))
    model.compile(
        optimizer='rmsprop',
        loss=tf.keras.losses.categorical_crossentropy,
        metrics=['accuracy'],
    )

    # Scale pixels by 1/255 so the training input matches what read_img()
    # feeds the model at prediction time (the previous L2 normalization
    # did not).
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    scaled_x_train = (x_train / 255.0).reshape(-1, 28, 28, 1)
    scaled_x_test = (x_test / 255.0).reshape(-1, 28, 28, 1)
    # One-hot encode the labels for categorical cross-entropy.
    one_hot_y_train = tf.one_hot(y_train, 10)
    one_hot_y_test = tf.one_hot(y_test, 10)

    model.fit(
        scaled_x_train,
        one_hot_y_train,
        epochs=epochs,
        validation_data=(scaled_x_test, one_hot_y_test),
    )

    # Classify every cell, not just the first 10: save_result() reshapes the
    # result list to (label_x, label_y) and needs one entry per cell.
    count = int(label_x * label_y)
    result2 = [
        np.argmax(model.predict(read_img(f'./{i + 1}.jpg')))
        for i in range(count)
    ]

# Train the CNN and classify the saved cell images into result2.
predict_cnn()

###################The code above recognizes each image; the code below writes the results to Excel files#########
def save_result(result,i):
    """Write recognition results to ``./<i>.xlsx`` laid out like the table.

    Args:
        result: One recognized value per table cell, in the column-by-column
            order in which ``sections()`` enumerates cells.
        i: Value used as the base name of the output workbook.

    Raises:
        ValueError: If ``len(result)`` is not ``label_x * label_y``.
    """
    expected = label_x * label_y
    if len(result) != expected:
        # Checked explicitly because reshape() would treat a -1 dimension
        # (no grid lines detected) as "infer this axis" and silently succeed.
        raise ValueError(
            f"expected {expected} results for a {label_x} x {label_y} table, "
            f"got {len(result)}"
        )

    result_array = np.array(result, dtype=str)
    # Restore the original layout: (columns, rows) -> transpose -> (rows, columns).
    result_array = result_array.reshape(label_x, label_y).transpose()
    print(result_array)

    workbook = Workbook()
    # Use the default sheet.
    sheet = workbook.active
    sheet.title = "默认sheet"
    for data in result_array:
        sheet.append(data.tolist())  # one spreadsheet row per table row
    # Single-line f-string: a replacement field split across lines is a
    # SyntaxError before Python 3.12.
    workbook.save(f'./{i}.xlsx')

# Write the DNN results to ./1.xlsx and the CNN results to ./2.xlsx.
save_result(result1,1)
save_result(result2,2)

Guess you like

Origin blog.csdn.net/weixin_55010198/article/details/122017835