[Proyecto práctico] LSTM realiza el reconocimiento de matrículas

Dirección del proyecto: Licencia-Reconocimiento

Con respecto a los principios de RNN y LSTM, no entraré en detalles aquí. Hay muchos artículos en Internet que los presentan en detalle. Clasificaré el conocimiento relevante más adelante cuando tenga tiempo.

Escuché que usar LSTM para realizar el reconocimiento de matrículas es muy simple, así que intentémoslo ~


1. Generación de conjuntos de datos

Si desea realizar el reconocimiento de matrículas, por supuesto debe tener el conjunto de datos correspondiente, pero es difícil encontrar conjuntos de datos relevantes en Internet, así que ¡permítanos generarlo nosotros mismos! ! ! inserte la descripción de la imagen aquí
La siguiente es la imagen generada,
inserte la descripción de la imagen aquí
así es, podemos generar imágenes en cuatro colores: fondo negro, fondo verde, fondo azul y fondo amarillo (para generar la imagen de la matrícula, utilicé este proyecto en Github para generarla. De hecho , Todos pueden generarlo por sí mismos fake_chs_lp. La generación también es muy simple, pero soy vago y no tengo la plantilla de color de fondo real, el siguiente es el código para generar la imagen de la placa de fondo azul)

import os
import cv2
import numpy as np
from PIL import Image, ImageFont, ImageDraw


class Draw:
    _font = [
        ImageFont.truetype(os.path.join(os.path.dirname(__file__), "res/eng_92.ttf"), 126),
        ImageFont.truetype(os.path.join(os.path.dirname(__file__), "res/zh_cn_92.ttf"), 95)
    ]
    _bg = cv2.resize(cv2.imread(os.path.join(os.path.dirname(__file__), "res/blue_bg.png")), (440, 140))

    def __call__(self, plate):
        if len(plate) != 7:
            print("ERROR: Invalid length")
            return None
        fg = self._draw_fg(plate)
        return cv2.cvtColor(cv2.bitwise_or(fg, self._bg), cv2.COLOR_BGR2RGB)

    def _draw_char(self, ch):
        img = Image.new("RGB", (45 if ch.isupper() or ch.isdigit() else 95, 140), (0, 0, 0))
        draw = ImageDraw.Draw(img)
        draw.text(
            (0, -11 if ch.isupper() or ch.isdigit() else 3), ch,
            fill=(255, 255, 255),
            font=self._font[0 if ch.isupper() or ch.isdigit() else 1]
        )
        if img.width > 45:
            img = img.resize((45, 140))
        return np.array(img)

    def _draw_fg(self, plate):
        img = np.array(Image.new("RGB", (440, 140), (0, 0, 0)))
        offset = 15
        img[0:140, offset:offset + 45] = self._draw_char(plate[0])
        offset = offset + 45 + 12
        img[0:140, offset:offset + 45] = self._draw_char(plate[1])
        offset = offset + 45 + 34
        for i in range(2, len(plate)):
            img[0:140, offset:offset + 45] = self._draw_char(plate[i])
            offset = offset + 45 + 12
        return img


if __name__ == "__main__":
    import argparse
    import matplotlib.pyplot as plt

    parser = argparse.ArgumentParser(description="Generate a blue plate.")
    parser.add_argument("plate", help="license plate number (default: 京A12345)", type=str, nargs="?", default="京A12345")
    args = parser.parse_args()

    draw = Draw()
    plate = draw(args.plate)
    plt.imshow(plate)
    plt.show()

2. Procesamiento de datos

Ahora que tenemos los datos, el siguiente paso es procesarlos.

import os
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from utils import str_to_label

data_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])


class Sampling(Dataset):
    def __init__(self, root):
        self.transform = data_transforms
        self.images = []
        self.labels = []

        for filename in os.listdir(root):
            x = os.path.join(root, filename)
            y = filename.split(".")[0]
            self.images.append(x)
            self.labels.append(y)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        image_path = self.images[index]
        image = self.transform(Image.open(image_path))
        label = self.labels[index]
        label = str_to_label(label)  # 将字母转成数字表示,方便做one-hot
        label = self.one_hot(label)
        # label = torch.Tensor(label)

        return image, label

    @staticmethod
    def one_hot(x):
        z = np.zeros((7, 65))
        for i in range(7):
            index = int(x[i])
            z[i][index] = 1
        return z


if __name__ == '__main__':
    sampling = Sampling("G:/DL_Data/Plate/train_plate")
    dataloader = DataLoader(sampling, 10, shuffle=True)
    for j, (img, labels) in enumerate(dataloader):
        # print(img.shape)
        print(labels)
        print(labels.shape)
        exit()

3. Entrenamiento modelo

El siguiente paso es el entrenamiento del modelo, aunque el reconocimiento de matrículas se puede realizar con CNN, esta vez intentamos utilizar la estructura codificador-decodificador (Encoder-Decoder) de la red recurrente, y el modelo Seq2Seq puede reconocerlo. En Pytorch, se puede llamar directamente a LSTM, pero la
inserte la descripción de la imagen aquí
forma La transformación es un poco engorrosa. He comentado en el código que el
formato de entrada de LSTM es ( N , S , V ) (N,S,V)( norte ,S ,V ) puede entenderse como escanear la imagen de izquierda a derecha, y el vector obtenido por cada escaneo se pasa secuencialmente a la red de bucle. SSS es cuántos pasos escanear, es decir, el ancho de la imagen es 440, es decir, ingresexxEl número de x ;VVV es el vector obtenido al escanear cada paso, que es cadaxxx , es la altura de la imagen × número de canales = 140×3

El código de estructura de red es el siguiente:

class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.fc1 = nn.Sequential(
            nn.Linear(420, 128),  # 420数据长度
            nn.BatchNorm1d(128),
            nn.ReLU()
        )
        self.lstm = nn.LSTM(input_size=128,
                            hidden_size=128,
                            num_layers=1,
                            batch_first=True)

    def forward(self, x):
        # [N,3,140,440] --> [N,420,440] --> [N,440,420]
        x = x.reshape(-1, 420, 440)
        # [N,440,420] --> [N*440,420]
        x = x.reshape(-1, 420)
        # [N*440,420].[420,128]=[N*440,128]
        fc1 = self.fc1(x)
        # [N*440,128] --> [N,440,128]
        fc1 = fc1.reshape(-1, 440, 128)
        lstm, *_ = self.lstm(fc1)
        # [N,440,128] --> [N,128]
        out = lstm[:, -1, :]
        return out


class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.lstm = nn.LSTM(input_size=128,
                            hidden_size=128,
                            num_layers=1,
                            batch_first=True)
        self.out = nn.Linear(128, 65)
        # self.out_province = nn.Linear(128, 31)
        # self.out_upper = nn.Linear(128, 24)
        # self.out_digits = nn.Linear(128, 10)

    def forward(self, x):
        # [N,128] --> [N,1,128]
        x = x.reshape(-1, 1, 128)
        # [N,1,128] --> [N,7,128]
        x = x.expand(-1, 7, 128)
        lstm, *_ = self.lstm(x)
        # [N,7,128] --> [N*7,128]
        y1 = lstm.reshape(-1, 128)
        # [N*7,128].[128,65]=[N*7,65]
        out = self.out(y1)
        # out_province = self.out_province(y1)
        # out_upper = self.out_upper(y1)
        # out_digits = self.out_digits(y1)

        # [N*7,65] --> [N,7,65]
        output = out.reshape(-1, 7, 65)
        return output
        # return out_province, out_upper, out_digits


class MainNet(nn.Module):
    def __init__(self):
        super(MainNet, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        encoder = self.encoder(x)
        decoder = self.decoder(encoder)

        return decoder

En la prueba real, mi amigo y yo descubrimos que cuando el codificador-decodificador es LSTM, el tiempo de entrenamiento es relativamente largo y el efecto no es el esperado, por lo que reemplazamos el codificador con CNN y el decodificador continuó siendo LSTM.

class Encoder(nn.Module):

    def __init__(self):
        super(Encoder, self).__init__()
        self.cnn_layer = nn.Sequential(
            nn.Conv2d(3, 8, 3, 2, 1),
            nn.BatchNorm2d(8),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(8, 16, 3, 2, 1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, 2, 1),
            nn.BatchNorm2d(32),
            nn.ReLU()
        )
        self.fc = nn.Sequential(
            nn.Linear(32 * 5 * 14, 128),
        )

    def forward(self, x):
        out = self.cnn_layer(x)
        out = out.reshape(x.size(0), -1)
        out = self.fc(out)
        return out


class Decoder(nn.Module):

    def __init__(self):
        super(Decoder, self).__init__()
        self.lstm = nn.LSTM(
            input_size=128, hidden_size=128, num_layers=1, batch_first=True
        )
        self.out = nn.Linear(128, 65)

    def forward(self, x):
        # [N, 128]-->[N, 1, 128]-->[N, 7, 128]
        x = x.reshape(-1, 1, 128).expand(-1, 7, 128)
        lstm, (_, _) = self.lstm(x)
        # [N, 7, 128]-->[N*7, 128]
        y = lstm.reshape(-1, 128)
        # [N*4, 128]-->[N*7, 10]
        out = self.out(y)
        # [N*7, 10]-->[N, 7, 10]
        out = out.reshape(-1, 7, 65)
        return out


class MainNet(nn.Module):

    def __init__(self):
        super(MainNet, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        encoder = self.encoder(x)
        decoder = self.decoder(encoder)
        return decoder

4. Prueba

Después de entrenar el modelo, utilice los pesos guardados para realizar pruebas, como se muestra en la figura siguiente.
inserte la descripción de la imagen aquí
En términos generales, el efecto es bueno.

Cinco, mejorar

Aunque el efecto general de lo que hemos hecho es muy bueno, en el proyecto real de reconocimiento de matrículas todavía necesitamos utilizar conjuntos de datos reales. Nuestros buenos datos solo se pueden utilizar en el estacionamiento de tarifa fija.

En la vida real, a menudo es necesario localizar primero la matrícula, primero detectar la posición del vehículo, luego detectar la posición de la matrícula del vehículo y finalmente identificar el número de matrícula.

Por lo tanto, los estudiantes interesados ​​pueden intentar utilizar datos reales para realizar experimentos. Aquí hay una dirección de proyecto de un conjunto de datos de matrículas: CCPD , que es un conjunto de datos doméstico a gran escala para el reconocimiento de matrículas, construido por investigadores de la Universidad de Ciencias. y Tecnología de China de

Supongo que te gusta

Origin blog.csdn.net/weixin_42166222/article/details/118300944
Recomendado
Clasificación