[Keras + Computer Vision + TensorFlow] OCR Text Recognition in Practice (with Source Code and Dataset)


1. Introduction to OCR text recognition

Using computers to recognize characters automatically is an important application of pattern recognition. In production and daily life, people handle large volumes of documents, reports and other text. To reduce manual labor and improve processing efficiency, text recognition methods have been studied since the 1950s, which led to the development of optical character recognizers.

OCR (Optical Character Recognition), or image text recognition, is an important branch of artificial intelligence. It gives the computer the equivalent of human eyes, enabling it to read the text contained in images. An image text recognition system is generally divided into four stages: image acquisition, text detection, text recognition and result output.
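To make the four stages concrete, here is a deliberately tiny sketch. It is not how a real OCR system works: the "image" is a hard-coded ASCII bitmap, detection is a simple column projection that splits on blank columns, and recognition is an exact lookup in a hypothetical template table (`TEMPLATES`). Real systems replace stages 2 and 3 with neural networks such as the CRAFT detector and CRNN recognizer used later in this article.

```python
from typing import Dict, List, Tuple

Image = List[str]  # each string is one pixel row; '#' = ink, '.' = background

# Hypothetical glyph templates: cropped bitmap (tuple of rows) -> character.
TEMPLATES: Dict[Tuple[str, ...], str] = {
    ("#.#", "###", "#.#"): "H",
    ("#", "#", "#"): "I",
}


def acquire() -> Image:
    """Stage 1: image acquisition (here a hard-coded 3-row bitmap)."""
    return ["#.#.#", "###.#", "#.#.#"]


def detect(image: Image) -> List[Tuple[int, int]]:
    """Stage 2: text detection by column projection.

    Returns (start, end) column ranges, end exclusive, for each run of
    columns that contain at least one ink pixel.
    """
    width = len(image[0])
    inked = [any(row[x] == "#" for row in image) for x in range(width)]
    boxes, start = [], None
    for x, has_ink in enumerate(inked + [False]):  # sentinel closes the last run
        if has_ink and start is None:
            start = x
        elif not has_ink and start is not None:
            boxes.append((start, x))
            start = None
    return boxes


def recognize(image: Image, boxes: List[Tuple[int, int]]) -> List[str]:
    """Stage 3: recognize each crop by exact template lookup ('?' if unknown)."""
    return [TEMPLATES.get(tuple(row[a:b] for row in image), "?") for a, b in boxes]


def output(chars: List[str]) -> str:
    """Stage 4: result output (join the characters into one string)."""
    return "".join(chars)


if __name__ == "__main__":
    img = acquire()
    print(output(recognize(img, detect(img))))  # prints "HI"
```

Keeping each stage as a separate function mirrors the pipeline structure: any stage can be swapped for a learned model without touching the others.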

2. OCR text recognition project in practice

1: Introduction to the dataset

MSRA-TD500: this dataset contains 500 natural scene images with resolutions ranging from 1296 × 864 to 1920 × 1280. The images cover scenes such as indoor shopping malls, signs, outdoor streets and billboards. The text is in Chinese and English and varies in font, size and orientation. Some sample images are shown in the figure below.
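Because the text lines in MSRA-TD500 can be tilted, each annotation is a rotated rectangle rather than an axis-aligned box. The sketch below assumes the commonly described `.gt` format, one text line per row as `index difficult x y w h theta`, where `(x, y)` is the top-left corner of the unrotated rectangle and `theta` (in radians) rotates it about its centre. It converts each row into four corner points; the sign convention of `theta` is an assumption worth checking visually on a few images.

```python
import math
from typing import List, NamedTuple, Tuple


class TextLine(NamedTuple):
    index: int
    difficult: bool
    corners: List[Tuple[float, float]]  # top-left, top-right, bottom-right, bottom-left before rotation


def parse_gt_line(line: str) -> TextLine:
    """Parse one row of an MSRA-TD500 .gt file (assumed format:
    'index difficult x y w h theta', theta in radians)."""
    fields = line.split()
    index, difficult = int(fields[0]), fields[1] == "1"
    x, y, w, h, theta = map(float, fields[2:7])
    cx, cy = x + w / 2, y + h / 2  # centre of rotation
    cos_t, sin_t = math.cos(theta), math.sin(theta)
    corners = []
    for dx, dy in ((-w / 2, -h / 2), (w / 2, -h / 2), (w / 2, h / 2), (-w / 2, h / 2)):
        # Rotate each corner offset around the centre by theta.
        corners.append((cx + dx * cos_t - dy * sin_t, cy + dx * sin_t + dy * cos_t))
    return TextLine(index, difficult, corners)
```

For an unrotated box, `parse_gt_line("0 0 10 20 100 40 0").corners` gives `[(10.0, 20.0), (110.0, 20.0), (110.0, 60.0), (10.0, 60.0)]`.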

The dataset is divided into a training set (300 images) and a test set (200 images), organized as follows:
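A small helper for walking this layout is sketched below. It assumes each split lives in its own folder (for example `MSRA-TD500/train` and `MSRA-TD500/test`) and that every image has an annotation file with the same stem and a `.gt` extension; the function name `list_split` is hypothetical, and images without a matching annotation are skipped.

```python
from pathlib import Path
from typing import List, Tuple

IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def list_split(root: str, split: str) -> List[Tuple[Path, Path]]:
    """Return sorted (image, annotation) pairs found in root/<split>."""
    split_dir = Path(root) / split
    pairs = []
    for image_path in sorted(split_dir.iterdir()):
        if image_path.suffix.lower() not in IMAGE_SUFFIXES:
            continue  # skip the .gt files themselves and anything else
        gt_path = image_path.with_suffix(".gt")
        if gt_path.is_file():
            pairs.append((image_path, gt_path))
    return pairs
```

Usage: `train_pairs = list_split("MSRA-TD500", "train")`.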

2: Project structure

The overall project structure is shown below. The upper part contains the definitions of the algorithms and models, such as CRAFT and CRNN, and the lower part contains the test code.
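CRNN, the recognition model, reads a cropped text line as a sequence of frames and predicts a probability distribution over the alphabet plus a special "blank" class for every frame; it is trained with the CTC loss. The simplest way to turn those per-frame predictions into a string is best-path (greedy) decoding: take the most likely class per frame, collapse consecutive repeats, then drop blanks. The sketch assumes the blank is the last class, which is the convention of TensorFlow/Keras CTC utilities; other implementations put it at index 0.

```python
from typing import List, Sequence


def ctc_greedy_decode(frame_probs: Sequence[Sequence[float]], alphabet: str) -> str:
    """Best-path CTC decoding.

    frame_probs[t][k] is the probability of class k at time step t.
    Classes 0..len(alphabet)-1 are the characters; class len(alphabet)
    is assumed to be the CTC blank.
    """
    blank = len(alphabet)
    best = [max(range(len(p)), key=p.__getitem__) for p in frame_probs]
    chars: List[str] = []
    previous = None
    for k in best:
        # Emit a character only when it differs from the previous frame and is not blank.
        if k != previous and k != blank:
            chars.append(alphabet[k])
        previous = k
    return "".join(chars)
```

The blank is what allows genuinely doubled letters: the frame sequence `a a <blank> a b b <blank>` decodes to `"aab"`, whereas `a a a` would collapse to a single `"a"`.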

The CRAFT algorithm detects text lines as shown in the figure below. The complete image is fed into the CRAFT text detection network, which produces two heat maps: a character-level text score map (Text Score), indicating how likely each pixel is to lie at the centre of a character, and a link score map (Link Score), indicating how likely it is to lie between two adjacent characters. The position of each text line is then obtained from the connected components of these maps.
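The last step, going from the two heat maps to boxes via connected components, can be sketched in plain Python. This is a simplified version under stated assumptions: a pixel counts as text if its text score or its link score exceeds a threshold (the values 0.7 and 0.4 are illustrative), each 4-connected region becomes one axis-aligned box, and the extra steps a full implementation uses (discarding small or weak components, fitting rotated rectangles, typically with OpenCV's `cv2.connectedComponentsWithStats` and `cv2.minAreaRect`) are omitted.

```python
from collections import deque
from typing import List, Sequence, Tuple

Box = Tuple[int, int, int, int]  # (x_min, y_min, x_max, y_max), inclusive


def boxes_from_scores(text_score: Sequence[Sequence[float]],
                      link_score: Sequence[Sequence[float]],
                      text_threshold: float = 0.7,
                      link_threshold: float = 0.4) -> List[Box]:
    """Threshold both maps, merge them, and box each 4-connected component."""
    h, w = len(text_score), len(text_score[0])
    mask = [[text_score[y][x] > text_threshold or link_score[y][x] > link_threshold
             for x in range(w)] for y in range(h)]
    seen = [[False] * w for _ in range(h)]
    boxes: List[Box] = []
    for y0 in range(h):
        for x0 in range(w):
            if not mask[y0][x0] or seen[y0][x0]:
                continue
            # Breadth-first search over one connected component.
            queue = deque([(y0, x0)])
            seen[y0][x0] = True
            xs, ys = [], []
            while queue:
                y, x = queue.popleft()
                xs.append(x)
                ys.append(y)
                for ny, nx in ((y - 1, x), (y + 1, x), (y, x - 1), (y, x + 1)):
                    if 0 <= ny < h and 0 <= nx < w and mask[ny][nx] and not seen[ny][nx]:
                        seen[ny][nx] = True
                        queue.append((ny, nx))
            boxes.append((min(xs), min(ys), max(xs), max(ys)))
    return boxes
```

The link map is what joins separate character blobs into a single word or line: with a high link score in the gap between two characters they form one component and one box, and without it they stay as two boxes.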

3: Results

Run the code. You can test it on different images to see the recognition output.
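When testing on your own images, it is often convenient to print the recognized words as text in reading order. The sketch below assumes each prediction is a `(word, box)` pair where `box` holds four `(x, y)` corner points, which is the shape keras-ocr's pipeline returns per image. It groups words into lines by their vertical centre and sorts each line left to right; `line_tolerance` is a hypothetical, image-dependent value in pixels.

```python
from typing import List, Sequence, Tuple

Point = Tuple[float, float]
Prediction = Tuple[str, Sequence[Point]]  # (word, four corner points)


def to_reading_order(predictions: List[Prediction], line_tolerance: float = 10.0) -> List[str]:
    """Return one string per text line, top to bottom, words left to right."""
    def centre(box: Sequence[Point]) -> Point:
        xs = [p[0] for p in box]
        ys = [p[1] for p in box]
        return sum(xs) / len(xs), sum(ys) / len(ys)

    # Sort all words by the vertical position of their centre.
    items = sorted(((centre(box), word) for word, box in predictions), key=lambda it: it[0][1])
    lines: List[List[Tuple[float, str]]] = []
    last_y = None
    for (cx, cy), word in items:
        # Start a new line when the vertical jump exceeds the tolerance.
        if last_y is None or cy - last_y > line_tolerance:
            lines.append([])
        lines[-1].append((cx, word))
        last_y = cy
    return [" ".join(word for _, word in sorted(line)) for line in lines]
```

A fixed pixel tolerance is the simplest choice; for images with very different text sizes, a tolerance proportional to the box height would be more robust.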


3. Code 

Part of the code is shown below: a script that trains the text recognizer on the SynthText90 dataset using multiple GPUs, followed by the script that was used to generate the background and font files.
 

"""This script demonstrates how to train the text recognizer
on the SynthText90 dataset using multiple GPUs."""
# pylint: disable=invalid-name
import datetime
import argparse
import math
import random
import string
import functools
import itertools
import os
import tarfile
import urllib.request

import numpy as np
import cv2
import imgaug
import tqdm
import tensorflow as tf

import keras_ocr


# pylint: disable=redefined-outer-name
def get_filepaths(data_path, split):
    """Get the list of filepaths for a given split (train, val, or test)."""
    with open(os.path.join(data_path, f'mnt/ramdisk/max/90kDICT32px/annotation_{split}.txt'),
              'r') as text_file:
        filepaths = [
            os.path.join(data_path, 'mnt/ramdisk/max/90kDICT32px',
                         line.split(' ')[0][2:]) for line in text_file.readlines()
        ]
    return filepaths


# pylint: disable=redefined-outer-name
def download_extract_and_process_dataset(data_path):
    """Download and extract the synthtext90 dataset."""
    archive_filepath = os.path.join(data_path, 'mjsynth.tar.gz')
    extraction_directory = os.path.join(data_path, 'mnt')
    if not os.path.isfile(archive_filepath) and not os.path.isdir(extraction_directory):
        print('Downloading the dataset.')
        urllib.request.urlretrieve("https://www.robots.ox.ac.uk/~vgg/data/text/mjsynth.tar.gz",
                                   archive_filepath)
    if not os.path.isdir(extraction_directory):
        print('Extracting files.')
        with tarfile.open(os.path.join(data_path, 'mjsynth.tar.gz')) as tfile:
            tfile.extractall(data_path)


def get_image_generator(filepaths, augmenter, width, height):
    """Yield (image, text) pairs from SynthText90 filepaths forever, reshuffling every pass."""
    filepaths = filepaths.copy()
    while True:
        # Reshuffle at the start of every pass. A plain loop is used because
        # itertools.cycle would replay the first pass's order from a saved copy.
        random.shuffle(filepaths)
        for filepath in filepaths:
            # Filenames look like <n>_<word>_<m>.jpg; the word is the label.
            text = filepath.split(os.sep)[-1].split('_')[1].lower()
            image = cv2.imread(filepath)
            if image is None:
                print(f'An error occurred reading: {filepath}')
                continue  # skip unreadable files instead of crashing in cvtColor
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            image = keras_ocr.tools.fit(image,
                                        width=width,
                                        height=height,
                                        cval=np.random.randint(low=0, high=255, size=3).astype('uint8'))
            if augmenter is not None:
                image = augmenter.augment_image(image)
            yield image, text


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train a keras-ocr text recognizer on SynthText90.')
    parser.add_argument('--model_id',
                        default='recognizer',
                        help='The name to use for saving model checkpoints.')
    parser.add_argument(
        '--data_path',
        default='.',
        help='The path to the directory containing the dataset and where we will put our logs.')
    parser.add_argument(
        '--logs_path',
        default='./logs',
        help=(
            'The path to where logs and checkpoints should be stored. '
            'If a checkpoint matching "model_id" is found, training will resume from that point.'))
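    # type=int is required; otherwise a value passed on the command line stays a string.
    parser.add_argument('--batch_size', type=int, default=16, help='The training batch size to use.')
    parser.add_argument('--no-file-verification',
                        dest='verify_files',
                        action='store_false',
                        help='Skip checking that every training and validation file exists.')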
    parser.set_defaults(verify_files=True)
    args = parser.parse_args()
    weights_path = os.path.join(args.logs_path, args.model_id + '.h5')
    csv_path = os.path.join(args.logs_path, args.model_id + '.csv')
    download_extract_and_process_dataset(args.data_path)
    with tf.distribute.MirroredStrategy().scope():
        recognizer = keras_ocr.recognition.Recognizer(alphabet=string.digits +
                                                      string.ascii_lowercase,
                                                      height=31,
                                                      width=200,
                                                      stn=False,
                                                      optimizer=tf.keras.optimizers.RMSprop(),
                                                      weights=None)
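    if os.path.isfile(weights_path):
        print('Loading saved weights and creating new version.')
        # Load the existing checkpoint first, then direct this run's checkpoints
        # and logs to new timestamped files so the old ones are not overwritten.
        recognizer.model.load_weights(weights_path)
        # A timestamp without ':' keeps the filename valid on Windows as well.
        dt_string = datetime.datetime.now().strftime('%Y%m%dT%H%M%S')
        weights_path = os.path.join(args.logs_path, args.model_id + '_' + dt_string + '.h5')
        csv_path = os.path.join(args.logs_path, args.model_id + '_' + dt_string + '.csv')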
    augmenter = imgaug.augmenters.Sequential([
        imgaug.augmenters.Multiply((0.9, 1.1)),
        imgaug.augmenters.GammaContrast(gamma=(0.5, 3.0)),
        imgaug.augmenters.Invert(0.25, per_channel=0.5)
    ])
    os.makedirs(args.logs_path, exist_ok=True)

    training_filepaths, validation_filepaths = [
        get_filepaths(data_path=args.data_path, split=split) for split in ['train', 'val']
    ]
    if args.verify_files:
        assert all(
            os.path.isfile(filepath) for
            filepath in tqdm.tqdm(training_filepaths + validation_filepaths,
                                  desc='Checking filepaths.')), 'Some files appear to be missing.'

    (training_image_generator, training_steps), (validation_image_generator, validation_steps) = [
        (get_image_generator(
            filepaths=filepaths,
            augmenter=augmenter,
            width=recognizer.model.input_shape[2],
            height=recognizer.model.input_shape[1],
        ), math.ceil(len(filepaths) / args.batch_size))
        for filepaths, augmenter in [(training_filepaths, augmenter), (validation_filepaths, None)]
    ]

    training_generator, validation_generator = [
        tf.data.Dataset.from_generator(
            functools.partial(recognizer.get_batch_generator,
                              image_generator=image_generator,
                              batch_size=args.batch_size),
            output_types=((tf.float32, tf.int64, tf.float64, tf.int64), tf.float64),
            output_shapes=(
                (
                    # Image batch: (None, 31, 200, 1) for the recognizer defined above.
                    tf.TensorShape(recognizer.model.input_shape),
                    # Padded label sequences, as long as the training model's label input.
                    tf.TensorShape([None, recognizer.training_model.input_shape[1][1]]),
                    tf.TensorShape([None, 1]),
                    tf.TensorShape([None, 1]),
                ),
                tf.TensorShape([None, 1]),
            ))
        for image_generator in [training_image_generator, validation_image_generator]
    ]
    callbacks = [
        tf.keras.callbacks.EarlyStopping(monitor='val_loss',
                                         min_delta=0,
                                         patience=10,
                                         restore_best_weights=False),
        tf.keras.callbacks.ModelCheckpoint(weights_path, monitor='val_loss', save_best_only=True),
        tf.keras.callbacks.CSVLogger(csv_path)
    ]
    recognizer.training_model.fit(
        x=training_generator,
        steps_per_epoch=training_steps,
        validation_steps=validation_steps,
        validation_data=validation_generator,
        callbacks=callbacks,
        epochs=1000,
    )
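
The training script above reads each label from the file name, lowercases it, and limits the recognizer's alphabet to digits and lowercase letters. The sketch below shows, under stated assumptions, how such labels can be turned into the padded integer sequences and label lengths that a CTC loss consumes. keras-ocr performs its own equivalent step inside `get_batch_generator`, whose details (such as the padding value, here assumed to be `-1`) may differ; `label_from_filename` and `encode_labels` are hypothetical helper names.

```python
import string
from typing import List, Tuple

ALPHABET = string.digits + string.ascii_lowercase  # same alphabet as the recognizer above


def label_from_filename(filename: str) -> str:
    """Names of the form '<n>_<word>_<m>.jpg' carry the word in the middle field."""
    return filename.split("_")[1].lower()


def encode_labels(texts: List[str], max_length: int) -> Tuple[List[List[int]], List[int]]:
    """Map each text to alphabet indices, right-padded with -1 to max_length.

    Returns the padded index lists and the true length of each label.
    Raises ValueError for labels that are too long or contain characters
    outside the alphabet.
    """
    index = {c: i for i, c in enumerate(ALPHABET)}
    padded, lengths = [], []
    for text in texts:
        if len(text) > max_length:
            raise ValueError(f"label {text!r} is longer than {max_length}")
        try:
            ids = [index[c] for c in text]
        except KeyError as exc:
            raise ValueError(f"character {exc.args[0]!r} is not in the alphabet") from exc
        padded.append(ids + [-1] * (max_length - len(ids)))
        lengths.append(len(ids))
    return padded, lengths
```

Returning the lengths separately is what lets the loss ignore the padding: CTC only looks at the first `length` entries of each label.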
"""This script is what was used to generate the
backgrounds.zip and fonts.zip files.
"""
# pylint: disable=invalid-name,redefined-outer-name
import json
import urllib.request
import urllib.parse
import concurrent.futures  # importing only 'concurrent' does not load the futures submodule
import shutil
import zipfile
import glob
import os

import numpy as np
import tqdm
import cv2

import keras_ocr

if __name__ == '__main__':
    fonts_commit = 'a0726002eab4639ee96056a38cd35f6188011a81'
    fonts_sha256 = 'e447d23d24a5bbe8488200a058cd5b75b2acde525421c2e74dbfb90ceafce7bf'
    fonts_source_zip_filepath = keras_ocr.tools.download_and_verify(
        url=f'https://github.com/google/fonts/archive/{fonts_commit}.zip',
        cache_dir='.',
        sha256=fonts_sha256)
    shutil.rmtree('fonts-raw', ignore_errors=True)
    with zipfile.ZipFile(fonts_source_zip_filepath) as zfile:
        zfile.extractall(path='fonts-raw')

    retained_fonts = []
    sha256s = []
    basenames = []
    # The blacklist includes fonts that, at least for the English alphabet, were found
    # to be illegible (e.g., thin fonts) or render in unexpected ways (e.g., mathematics
    # fonts).
    blacklist = [
        'AlmendraDisplay-Regular.ttf', 'RedactedScript-Bold.ttf', 'RedactedScript-Regular.ttf',
        'Sevillana-Regular.ttf', 'Mplus1p-Thin.ttf', 'Stalemate-Regular.ttf', 'jsMath-cmsy10.ttf',
        'Codystar-Regular.ttf', 'AdventPro-Thin.ttf', 'RoundedMplus1c-Thin.ttf',
        'EncodeSans-Thin.ttf', 'AlegreyaSans-ThinItalic.ttf', 'AlegreyaSans-Thin.ttf',
        'FiraSans-Thin.ttf', 'FiraSans-ThinItalic.ttf', 'WorkSans-Thin.ttf',
        'Tomorrow-ThinItalic.ttf', 'Tomorrow-Thin.ttf', 'Italianno-Regular.ttf',
        'IBMPlexSansCondensed-Thin.ttf', 'IBMPlexSansCondensed-ThinItalic.ttf',
        'Lato-ExtraLightItalic.ttf', 'LibreBarcode128Text-Regular.ttf',
        'LibreBarcode39-Regular.ttf', 'LibreBarcode39ExtendedText-Regular.ttf',
        'EncodeSansExpanded-ExtraLight.ttf', 'Exo-Thin.ttf', 'Exo-ThinItalic.ttf',
        'DrSugiyama-Regular.ttf', 'Taviraj-ThinItalic.ttf', 'SixCaps.ttf', 'IBMPlexSans-Thin.ttf',
        'IBMPlexSans-ThinItalic.ttf', 'AdobeBlank-Regular.ttf',
        'FiraSansExtraCondensed-ThinItalic.ttf', 'HeptaSlab[wght].ttf', 'Karla-Italic[wght].ttf',
        'Karla[wght].ttf', 'RalewayDots-Regular.ttf', 'FiraSansCondensed-ThinItalic.ttf',
        'jsMath-cmex10.ttf', 'LibreBarcode39Text-Regular.ttf', 'LibreBarcode39Extended-Regular.ttf',
        'EricaOne-Regular.ttf', 'ArimaMadurai-Thin.ttf', 'IBMPlexSerif-ExtraLight.ttf',
        'IBMPlexSerif-ExtraLightItalic.ttf', 'IBMPlexSerif-ThinItalic.ttf', 'IBMPlexSerif-Thin.ttf',
        'Exo2-Thin.ttf', 'Exo2-ThinItalic.ttf', 'BungeeOutline-Regular.ttf', 'Redacted-Regular.ttf',
        'JosefinSlab-ThinItalic.ttf', 'GothicA1-Thin.ttf', 'Kanit-ThinItalic.ttf', 'Kanit-Thin.ttf',
        'AlegreyaSansSC-ThinItalic.ttf', 'AlegreyaSansSC-Thin.ttf', 'Chathura-Thin.ttf',
        'Blinker-Thin.ttf', 'Italiana-Regular.ttf', 'Miama-Regular.ttf', 'Grenze-ThinItalic.ttf',
        'LeagueScript-Regular.ttf', 'BigShouldersDisplay-Thin.ttf', 'YanoneKaffeesatz[wght].ttf',
        'BungeeHairline-Regular.ttf', 'JosefinSans-Thin.ttf', 'JosefinSans-ThinItalic.ttf',
        'Monofett.ttf', 'Raleway-ThinItalic.ttf', 'Raleway-Thin.ttf', 'JosefinSansStd-Light.ttf',
        'LibreBarcode128-Regular.ttf'
    ]
    for filepath in tqdm.tqdm(sorted(glob.glob('fonts-raw/**/**/**/*.ttf')),
                              desc='Filtering fonts.'):
        sha256 = keras_ocr.tools.sha256sum(filepath)
        basename = os.path.basename(filepath)
        # We check the sha256 and filenames because some of the fonts
        # in the repository are duplicated (see TRIVIA.md).
        if sha256 in sha256s or basename in basenames or basename in blacklist:
            continue
        sha256s.append(sha256)
        basenames.append(basename)
        retained_fonts.append(filepath)
    retained_font_families = set([filepath.split(os.sep)[-2] for filepath in retained_fonts])
    added = []
    with zipfile.ZipFile(file='fonts.zip', mode='w') as zfile:
        for font_family in tqdm.tqdm(retained_font_families, desc='Saving ZIP file.'):
            # We want to keep all the metadata files plus
            # the retained font files. And we don't want
            # to add the same file twice.
            files = [
                input_filepath for input_filepath in glob.glob(f'fonts-raw/**/**/{font_family}/*')
                if input_filepath not in added and
                (input_filepath in retained_fonts or os.path.splitext(input_filepath)[1] != '.ttf')
            ]
            added.extend(files)
            for input_filepath in files:
                zfile.write(filename=input_filepath,
                            arcname=os.path.join(*input_filepath.split(os.sep)[-2:]))
    print('Finished saving fonts file.')

    # pylint: disable=line-too-long
    url = (
        'https://commons.wikimedia.org/w/api.php?action=query&generator=categorymembers&gcmtype=file&format=json'
        '&gcmtitle=Category:Featured_pictures_on_Wikimedia_Commons&prop=imageinfo&gcmlimit=50&iiprop=url&iiurlwidth=1024'
    )
    gcmcontinue = None
    max_responses = 300
    responses = []
    for responseCount in tqdm.tqdm(range(max_responses)):
        current_url = url
        if gcmcontinue is not None:
            current_url += f'&continue=gcmcontinue||&gcmcontinue={gcmcontinue}'
        with urllib.request.urlopen(url=current_url) as response:
            current = json.loads(response.read())
            responses.append(current)
            gcmcontinue = None if 'continue' not in current else current['continue']['gcmcontinue']
        if gcmcontinue is None:
            break
    print('Finished getting list of images.')

    # We want to avoid animated images as well as icon files.
    image_urls = []
    for response in responses:
        image_urls.extend(
            [page['imageinfo'][0]['thumburl'] for page in response['query']['pages'].values()])
    image_urls = [url for url in image_urls if url.lower().endswith('.jpg')]
    shutil.rmtree('backgrounds', ignore_errors=True)
    os.makedirs('backgrounds')
    assert len(image_urls) == len(set(image_urls)), 'Duplicates found!'
    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
        futures = [
            executor.submit(keras_ocr.tools.download_and_verify,
                            url=url,
                            cache_dir='./backgrounds',
                            verbose=False) for url in image_urls
        ]
        for _ in tqdm.tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
            pass
    for filepath in glob.glob('backgrounds/*.JPG'):
        os.rename(filepath, filepath.lower())

    print('Filtering images by aspect ratio and maximum contiguous contour.')
    image_paths = np.array(sorted(glob.glob('backgrounds/*.jpg')))

    def compute_metrics(filepath):
        image = keras_ocr.tools.read(filepath)
        aspect_ratio = image.shape[0] / image.shape[1]
        contour, _ = keras_ocr.tools.get_maximum_uniform_contour(image, fontsize=40)
        area = cv2.contourArea(contour) if contour is not None else 0
        return aspect_ratio, area

    metrics = np.array([compute_metrics(filepath) for filepath in tqdm.tqdm(image_paths)])
    filtered_paths = image_paths[(metrics[:, 0] < 3 / 2) & (metrics[:, 0] > 2 / 3) &
                                 (metrics[:, 1] > 1e6)]
    detector = keras_ocr.detection.Detector()
    paths_with_text = [
        filepath for filepath in tqdm.tqdm(filtered_paths) if len(
            detector.detect(
                images=[keras_ocr.tools.read_and_fit(filepath, width=640, height=640)])[0]) > 0
    ]
    filtered_paths = np.array([path for path in filtered_paths if path not in paths_with_text])
    filtered_basenames = list(map(os.path.basename, filtered_paths))
    basename_to_url = {
        os.path.basename(urllib.parse.urlparse(url).path).lower(): url
        for url in image_urls
    }
    filtered_urls = [basename_to_url[basename.lower()] for basename in filtered_basenames]
    assert len(filtered_urls) == len(filtered_paths)
    removed_paths = [filepath for filepath in image_paths if filepath not in filtered_paths]
    for filepath in removed_paths:
        os.remove(filepath)
    with open('backgrounds/urls.txt', 'w') as f:
        f.write('\n'.join(filtered_urls))
    with zipfile.ZipFile(file='backgrounds.zip', mode='w') as zfile:
        for filepath in tqdm.tqdm(filtered_paths.tolist() + ['backgrounds/urls.txt'],
                                  desc='Saving ZIP file.'):
            zfile.write(filename=filepath, arcname=os.path.basename(filepath.lower()))
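
The font-filtering loop in the script above keeps a font only if neither its content hash nor its file name has been seen before and it is not blacklisted. The same rule can be sketched with the standard library alone; `sha256sum` here is a local stand-in for `keras_ocr.tools.sha256sum`, and sets replace the script's lists so that each membership check is constant time.

```python
import hashlib
import os
from typing import Iterable, List


def sha256sum(path: str, chunk_size: int = 1 << 16) -> str:
    """Hex SHA-256 digest of a file, read in chunks to bound memory use."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def deduplicate(paths: Iterable[str], blacklist: Iterable[str] = ()) -> List[str]:
    """Keep the first file per content hash and per basename, skipping blacklisted names."""
    blocked = set(blacklist)
    seen_hashes, seen_names, kept = set(), set(), []
    for path in sorted(paths):
        name = os.path.basename(path)
        if name in blocked or name in seen_names:
            continue
        digest = sha256sum(path)
        if digest in seen_hashes:
            continue  # same bytes under a different name
        seen_hashes.add(digest)
        seen_names.add(name)
        kept.append(path)
    return kept
```

Sorting the paths first makes the choice of which duplicate survives deterministic across runs.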



Original post: blog.csdn.net/jiebaoshayebuhui/article/details/128258262