[Practical project] License plate recognition with LSTM

Project address: License-Recognition

I won't go into the principles of RNN and LSTM here; many articles online cover them in detail. I'll write up my own notes on them later when I have time.

I heard that license plate recognition with an LSTM is quite simple, so let's give it a try~


1. Dataset generation

To do license plate recognition you of course need a matching dataset, but relevant datasets are hard to find online, so let's generate one ourselves!
Below are some of the generated images:
[Figure: sample generated license plate images]
That's right, we can generate plates with four background colors: black, green, blue, and yellow. (The plate images were generated with the fake_chs_lp project on GitHub. Generating them yourself is actually quite simple, but I'm lazy and don't have the real background color templates. The following code generates a blue-background plate image.)

import os
import cv2
import numpy as np
from PIL import Image, ImageFont, ImageDraw


class Draw:
    _font = [
        ImageFont.truetype(os.path.join(os.path.dirname(__file__), "res/eng_92.ttf"), 126),
        ImageFont.truetype(os.path.join(os.path.dirname(__file__), "res/zh_cn_92.ttf"), 95)
    ]
    _bg = cv2.resize(cv2.imread(os.path.join(os.path.dirname(__file__), "res/blue_bg.png")), (440, 140))

    def __call__(self, plate):
        if len(plate) != 7:
            print("ERROR: Invalid length")
            return None
        fg = self._draw_fg(plate)
        return cv2.cvtColor(cv2.bitwise_or(fg, self._bg), cv2.COLOR_BGR2RGB)

    def _draw_char(self, ch):
        img = Image.new("RGB", (45 if ch.isupper() or ch.isdigit() else 95, 140), (0, 0, 0))
        draw = ImageDraw.Draw(img)
        draw.text(
            (0, -11 if ch.isupper() or ch.isdigit() else 3), ch,
            fill=(255, 255, 255),
            font=self._font[0 if ch.isupper() or ch.isdigit() else 1]
        )
        if img.width > 45:
            img = img.resize((45, 140))
        return np.array(img)

    def _draw_fg(self, plate):
        img = np.array(Image.new("RGB", (440, 140), (0, 0, 0)))
        offset = 15  # left margin
        img[0:140, offset:offset + 45] = self._draw_char(plate[0])  # province character
        offset = offset + 45 + 12
        img[0:140, offset:offset + 45] = self._draw_char(plate[1])  # letter
        offset = offset + 45 + 34  # wider gap between the 2nd and 3rd characters
        for i in range(2, len(plate)):  # remaining five letters/digits
            img[0:140, offset:offset + 45] = self._draw_char(plate[i])
            offset = offset + 45 + 12
        return img


if __name__ == "__main__":
    import argparse
    import matplotlib.pyplot as plt

    parser = argparse.ArgumentParser(description="Generate a blue plate.")
    parser.add_argument("plate", help="license plate number (default: 京A12345)", type=str, nargs="?", default="京A12345")
    args = parser.parse_args()

    draw = Draw()
    plate = draw(args.plate)
    plt.imshow(plate)
    plt.show()
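The script above draws a single plate. Sampling in the next section reads the label from the file name (filename.split(".")[0]), so one way to build a training set is to draw random plate strings and save each image as <plate>.jpg. This is only a sketch. The character tables, random_plate and generate_dataset are my own assumptions (31 province abbreviations, 24 letters without I and O, 10 digits). draw can be any callable that returns an RGB uint8 array, such as an instance of the Draw class above.

```python
import os
import random

import numpy as np
from PIL import Image

# Hypothetical character tables: 31 provinces, 24 letters (no I/O), 10 digits.
PROVINCES = "京沪津渝冀晋蒙辽吉黑苏浙皖闽赣鲁豫鄂湘粤桂琼川贵云藏陕甘青宁新"
LETTERS = "ABCDEFGHJKLMNPQRSTUVWXYZ"
DIGITS = "0123456789"


def random_plate(rng: random.Random) -> str:
    """Province + letter + five letters/digits (an assumed sampling scheme)."""
    tail = "".join(rng.choice(LETTERS + DIGITS) for _ in range(5))
    return rng.choice(PROVINCES) + rng.choice(LETTERS) + tail


def generate_dataset(draw, out_dir: str, count: int, seed: int = 0) -> int:
    """Render `count` random plates with `draw` and save them as <plate>.jpg.

    Returns the number of distinct files written; a repeated plate string
    simply overwrites the earlier image.
    """
    os.makedirs(out_dir, exist_ok=True)
    rng = random.Random(seed)
    written = set()
    for _ in range(count):
        plate = random_plate(rng)
        rgb = np.asarray(draw(plate), dtype=np.uint8)  # [140, 440, 3] RGB
        # PIL handles non-ASCII file names such as "京A12345.jpg"
        Image.fromarray(rgb).save(os.path.join(out_dir, plate + ".jpg"))
        written.add(plate)
    return len(written)
```

With the real renderer this would be called as generate_dataset(Draw(), "train_plate", 10000). Because the label is stored only in the file name, no separate annotation file is needed.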

2. Data processing

Now that we have the data, the next step is to process the data.

import os
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from utils import str_to_label

data_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])


class Sampling(Dataset):
    def __init__(self, root):
        self.transform = data_transforms
        self.images = []
        self.labels = []

        for filename in os.listdir(root):
            x = os.path.join(root, filename)
            y = filename.split(".")[0]
            self.images.append(x)
            self.labels.append(y)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        image_path = self.images[index]
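        # convert("RGB") guarantees 3 channels even if an image was saved with alpha
        image = self.transform(Image.open(image_path).convert("RGB"))
        label = self.labels[index]
        label = str_to_label(label)  # map each character to an integer index, ready for one-hot encoding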
        label = self.one_hot(label)
        # label = torch.Tensor(label)

        return image, label

    @staticmethod
    def one_hot(x):
        z = np.zeros((7, 65), dtype=np.float32)  # float32 to match the model's output dtype
        for i in range(7):
            index = int(x[i])
            z[i][index] = 1
        return z


if __name__ == '__main__':
    sampling = Sampling("G:/DL_Data/Plate/train_plate")
    dataloader = DataLoader(sampling, 10, shuffle=True)
    for j, (img, labels) in enumerate(dataloader):
        # print(img.shape)
        print(labels)
        print(labels.shape)
        break  # only inspect the first batch
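utils.str_to_label is imported but not shown. The commented-out heads in the Decoder split the 65 outputs into 31 provinces, 24 letters and 10 digits, so one plausible version looks like the following. The vocabulary ordering is my assumption and may not match the project's utils module.

```python
# Hypothetical vocabulary; the project's utils.str_to_label may order symbols differently.
PROVINCES = "京沪津渝冀晋蒙辽吉黑苏浙皖闽赣鲁豫鄂湘粤桂琼川贵云藏陕甘青宁新"  # 31
LETTERS = "ABCDEFGHJKLMNPQRSTUVWXYZ"  # 24, without I and O
DIGITS = "0123456789"  # 10
VOCAB = PROVINCES + LETTERS + DIGITS  # 65 classes, matching nn.Linear(128, 65)
CHAR_TO_INDEX = {ch: i for i, ch in enumerate(VOCAB)}


def str_to_label(plate: str) -> list:
    """'京A12345' -> list of 7 class indices in [0, 65)."""
    return [CHAR_TO_INDEX[ch] for ch in plate]


def label_to_str(label) -> str:
    """Inverse of str_to_label; accepts ints or integer-valued tensors/arrays."""
    return "".join(VOCAB[int(i)] for i in label)
```

Sampling.one_hot then turns these 7 indices into the 7 × 65 target matrix.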

3. Model training

The next step is training the model. License plate recognition can be done with a CNN alone, but this time we try an encoder-decoder structure built from recurrent networks, i.e. a Seq2Seq model. In PyTorch, LSTM can be called directly, but the shape transformations are a bit cumbersome; I have commented them in the code.
The LSTM input format (with batch_first=True) is $(N, S, V)$. You can think of it as scanning the image from left to right and feeding the vector obtained at each step into the recurrent network in order. $N$ is the batch size. $S$ is the number of scan steps, i.e. the image width of 440, which is the number of inputs $x$. $V$ is the vector obtained at each step, i.e. each $x$, whose length is image height × number of channels = 140 × 3 = 420.
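The following small sketch makes the $(N, S, V)$ layout concrete. The function name image_to_sequence is my own and does not come from the project. It turns a batch of plate images of shape [N, 3, 140, 440] into a [N, 440, 420] sequence with one 420-dim vector per pixel column, then feeds it to an nn.LSTM with batch_first=True. Unlike the project's Encoder, it skips the Linear(420, 128) projection and feeds the 420-dim vectors in directly.

```python
import torch
from torch import nn


def image_to_sequence(x: torch.Tensor) -> torch.Tensor:
    """[N, C, H, W] -> [N, W, C*H]: one vector per pixel column, scanned left to right."""
    n, c, h, w = x.shape
    # Merge channels and height into one axis: [N, C*H, W]
    x = x.reshape(n, c * h, w)
    # Swap axes so that width becomes the sequence (time) axis: [N, W, C*H]
    return x.permute(0, 2, 1)


x = torch.randn(2, 3, 140, 440)           # a batch of 2 plate images
seq = image_to_sequence(x)                # [2, 440, 420] = (N, S, V)
lstm = nn.LSTM(input_size=420, hidden_size=128, batch_first=True)
out, (h_n, c_n) = lstm(seq)               # out: [2, 440, 128], one output per column
last_step = out[:, -1, :]                 # [2, 128], state after the last column
```

The permute is the important step. A plain reshape from [N, 420, 440] to [N, 440, 420] would reinterpret memory instead of moving the width axis, so each "step" would mix pixels from different columns.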

The network structure code is as follows:

from torch import nn


class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.fc1 = nn.Sequential(
            nn.Linear(420, 128),  # 420 = length of each column vector (140 × 3)
            nn.BatchNorm1d(128),
            nn.ReLU()
        )
        self.lstm = nn.LSTM(input_size=128,
                            hidden_size=128,
                            num_layers=1,
                            batch_first=True)

    def forward(self, x):
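        # [N,3,140,440] --> [N,420,440] --> [N,440,420]
        # permute is required: without it the reshape below would not yield one vector per pixel column
        x = x.reshape(-1, 420, 440).permute(0, 2, 1)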
        # [N,440,420] --> [N*440,420]
        x = x.reshape(-1, 420)
        # [N*440,420].[420,128]=[N*440,128]
        fc1 = self.fc1(x)
        # [N*440,128] --> [N,440,128]
        fc1 = fc1.reshape(-1, 440, 128)
        lstm, *_ = self.lstm(fc1)
        # [N,440,128] --> [N,128]
        out = lstm[:, -1, :]
        return out


class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.lstm = nn.LSTM(input_size=128,
                            hidden_size=128,
                            num_layers=1,
                            batch_first=True)
        self.out = nn.Linear(128, 65)
        # self.out_province = nn.Linear(128, 31)
        # self.out_upper = nn.Linear(128, 24)
        # self.out_digits = nn.Linear(128, 10)

    def forward(self, x):
        # [N,128] --> [N,1,128]
        x = x.reshape(-1, 1, 128)
        # [N,1,128] --> [N,7,128]
        x = x.expand(-1, 7, 128)
        lstm, *_ = self.lstm(x)
        # [N,7,128] --> [N*7,128]
        y1 = lstm.reshape(-1, 128)
        # [N*7,128].[128,65]=[N*7,65]
        out = self.out(y1)
        # out_province = self.out_province(y1)
        # out_upper = self.out_upper(y1)
        # out_digits = self.out_digits(y1)

        # [N*7,65] --> [N,7,65]
        output = out.reshape(-1, 7, 65)
        return output
        # return out_province, out_upper, out_digits


class MainNet(nn.Module):
    def __init__(self):
        super(MainNet, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        encoder = self.encoder(x)
        decoder = self.decoder(encoder)

        return decoder

In actual tests, my friend and I found that when both the encoder and the decoder are LSTMs, training takes relatively long and the results fall short of expectations. We therefore replaced the encoder with a CNN and kept the LSTM decoder.
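Where does 32 * 5 * 14 in the CNN encoder below come from? Each Conv2d and MaxPool2d changes the spatial size by out = floor((in + 2 * padding - kernel) / stride) + 1 (with dilation 1). MaxPool2d(2) uses kernel 2, stride 2 and padding 0. For a 140 × 440 input, this gives 140 → 70 → 35 → 18 → 9 → 5 in height and 440 → 220 → 110 → 55 → 27 → 14 in width. The small helper below is hypothetical and only for checking; it reproduces the calculation.

```python
def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Output length of Conv2d / MaxPool2d along one axis (dilation 1, floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1


def encoder_feature_shape(height: int = 140, width: int = 440) -> tuple:
    h, w = height, width
    # (kernel, stride, padding) for Conv, MaxPool, Conv, MaxPool, Conv in cnn_layer
    layers = [(3, 2, 1), (2, 2, 0), (3, 2, 1), (2, 2, 0), (3, 2, 1)]
    for k, s, p in layers:
        h, w = conv_out(h, k, s, p), conv_out(w, k, s, p)
    return 32, h, w  # 32 channels come from the last Conv2d(16, 32, ...)


print(encoder_feature_shape())  # (32, 5, 14)
```

If the input size changes, the in_features of nn.Linear(32 * 5 * 14, 128) have to change with it.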

from torch import nn


class Encoder(nn.Module):

    def __init__(self):
        super(Encoder, self).__init__()
        self.cnn_layer = nn.Sequential(
            nn.Conv2d(3, 8, 3, 2, 1),
            nn.BatchNorm2d(8),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(8, 16, 3, 2, 1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, 2, 1),
            nn.BatchNorm2d(32),
            nn.ReLU()
        )
        self.fc = nn.Sequential(
            nn.Linear(32 * 5 * 14, 128),
        )

    def forward(self, x):
        out = self.cnn_layer(x)
        out = out.reshape(x.size(0), -1)
        out = self.fc(out)
        return out


class Decoder(nn.Module):

    def __init__(self):
        super(Decoder, self).__init__()
        self.lstm = nn.LSTM(
            input_size=128, hidden_size=128, num_layers=1, batch_first=True
        )
        self.out = nn.Linear(128, 65)

    def forward(self, x):
        # [N, 128]-->[N, 1, 128]-->[N, 7, 128]
        x = x.reshape(-1, 1, 128).expand(-1, 7, 128)
        lstm, (_, _) = self.lstm(x)
        # [N, 7, 128]-->[N*7, 128]
        y = lstm.reshape(-1, 128)
        # [N*7, 128]-->[N*7, 65]
        out = self.out(y)
        # [N*7, 65]-->[N, 7, 65]
        out = out.reshape(-1, 7, 65)
        return out


class MainNet(nn.Module):

    def __init__(self):
        super(MainNet, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        encoder = self.encoder(x)
        decoder = self.decoder(encoder)
        return decoder
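The article does not show the training loop or which loss was used. One possible sketch assumes cross-entropy, which is my choice and not necessarily the author's. Sampling returns one-hot labels of shape [N, 7, 65], so argmax recovers the class indices, and the 7 positions are flattened into the batch dimension. train_step is a hypothetical helper name.

```python
import torch
from torch import nn


def train_step(model: nn.Module, optimizer: torch.optim.Optimizer,
               images: torch.Tensor, one_hot_labels: torch.Tensor) -> float:
    """One optimisation step. images: [N, 3, 140, 440]; one_hot_labels: [N, 7, 65]."""
    model.train()
    logits = model(images)                      # [N, 7, 65]
    targets = one_hot_labels.argmax(dim=-1)     # [N, 7] class indices
    # cross_entropy expects [batch, classes]; fold the 7 positions into the batch
    loss = nn.functional.cross_entropy(logits.reshape(-1, 65), targets.reshape(-1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```

In the project, this would be called in a loop over the DataLoader from section 2 with model = MainNet() and an optimizer such as torch.optim.Adam(model.parameters(), lr=1e-3). The learning rate is a placeholder, not the author's setting. Using argmax on the targets also means the dtype of the one-hot labels does not matter.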

4. Test

After training, we load the saved weights and test the model; the results are shown in the figure below.
[Figure: recognition results on test images]
Overall, the results are good.
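The test code is not shown either. The following minimal sketch turns the [N, 7, 65] output back into plate strings and scores it. It uses greedy argmax per position, then computes character-level and whole-plate accuracy. decode and accuracy are hypothetical helpers. vocab must be the same 65-symbol ordering that utils.str_to_label uses, which the article does not show.

```python
import torch


def decode(logits: torch.Tensor, vocab: str) -> list:
    """[N, 7, 65] scores -> list of N plate strings (greedy argmax per position)."""
    indices = logits.argmax(dim=-1)  # [N, 7]
    return ["".join(vocab[i] for i in row.tolist()) for row in indices]


def accuracy(logits: torch.Tensor, one_hot_labels: torch.Tensor) -> tuple:
    """Return (per-character accuracy, whole-plate accuracy)."""
    pred = logits.argmax(dim=-1)
    target = one_hot_labels.argmax(dim=-1)
    correct = pred == target                        # [N, 7] booleans
    char_acc = correct.float().mean().item()
    plate_acc = correct.all(dim=-1).float().mean().item()  # all 7 must match
    return char_acc, plate_acc
```

Whole-plate accuracy is the stricter and more practical number, because a single wrong character makes the plate reading wrong.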

5. Improvements

Although the overall results are very good, a real license plate recognition project still needs real datasets. A model trained on our generated data can only be used in constrained scenarios, such as a fixed toll gate at a parking lot.

In practice, the plate usually has to be located first: detect the vehicle, then detect the license plate within the vehicle region, and finally recognize the plate number.
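The three-stage pipeline described above can be sketched as plain function composition. detect_vehicles, detect_plates and read_plate are hypothetical callables: for example, a vehicle detector, a plate detector, and the model from this article applied after resizing the crop to 440 × 140. Boxes are assumed to be (x1, y1, x2, y2) pixel coordinates.

```python
from typing import Callable, List, Tuple

import numpy as np

Box = Tuple[int, int, int, int]  # (x1, y1, x2, y2) in pixels, an assumed convention


def crop(image: np.ndarray, box: Box) -> np.ndarray:
    x1, y1, x2, y2 = box
    return image[y1:y2, x1:x2]


def recognize_plates(image: np.ndarray,
                     detect_vehicles: Callable[[np.ndarray], List[Box]],
                     detect_plates: Callable[[np.ndarray], List[Box]],
                     read_plate: Callable[[np.ndarray], str]) -> List[str]:
    """Vehicle detection -> plate detection inside each vehicle -> recognition."""
    results = []
    for vehicle_box in detect_vehicles(image):
        vehicle = crop(image, vehicle_box)
        # plate boxes are relative to the vehicle crop, not the full image
        for plate_box in detect_plates(vehicle):
            results.append(read_plate(crop(vehicle, plate_box)))
    return results
```

Searching for plates only inside vehicle crops keeps the plate detector's search area small and cuts false positives from signs or text elsewhere in the scene.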

Interested readers can try experimenting with real data. One license plate dataset is CCPD, a large-scale Chinese dataset for license plate recognition built by researchers at the University of Science and Technology of China.
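If you try CCPD, note that its annotations are encoded in the file names. Based on my reading of the CCPD README (please verify against the repository), a name splits on "-" into seven fields: area ratio, tilt degrees, bounding box, four vertices, plate character indices, brightness and blurriness. parse_ccpd_name is a hypothetical helper written from that description.

```python
import os


def parse_ccpd_name(path: str) -> dict:
    """Split a CCPD-style file name into its annotation fields (assumed layout)."""
    stem = os.path.splitext(os.path.basename(path))[0]
    area, tilt, bbox, vertices, plate, brightness, blur = stem.split("-")

    def points(field: str) -> list:
        # "154&383_386&473" -> [(154, 383), (386, 473)]
        return [tuple(int(v) for v in p.split("&")) for p in field.split("_")]

    return {
        "area": int(area),
        "tilt": tuple(int(v) for v in tilt.split("_")),
        "bbox": points(bbox),          # [(x1, y1), (x2, y2)]
        "vertices": points(vertices),  # four corner points
        "plate_indices": [int(v) for v in plate.split("_")],
        "brightness": int(brightness),
        "blurriness": int(blur),
    }
```

The plate indices refer to CCPD's own character tables, which differ from the 65-class vocabulary used in this article. They must be remapped before training this model, and the bounding box can be used to crop the plate region first.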


Origin: blog.csdn.net/weixin_42166222/article/details/118300944