【实战项目】LSTM实现车牌识别

项目地址:License-Recognition

关于RNN和LSTM的原理,这里就不多赘述了,网上有许多文章对它们进行了详细的介绍,后面有时间我也会整理一下相关知识。

听说使用LSTM实现车牌识别很简单,这就让我们尝试下吧~


一、数据集的生成

要想做车牌识别,当然得先有相应的数据集,但网上很难找到合适的数据集,那就让我们自己生成吧!
(此处原有配图)
下面就是生成的图片啦
(此处原有配图:生成的车牌样例)
没错,我们可以生成黑底、绿底、蓝底、黄底4种颜色的车牌图片。(生成车牌图片这部分,我采用的是 GitHub 上的 fake_chs_lp 项目。其实自己生成也很简单,但我比较懒,而且手头也没有实际的底色模板。)下面是生成蓝底车牌图片的代码(车牌号须为7个字符):

import os
import cv2
import numpy as np
from PIL import Image, ImageFont, ImageDraw


class Draw:
    """Render a 7-character blue license plate as a 440x140 RGB ``numpy`` image.

    Characters are drawn white-on-black onto a foreground canvas, which is then
    bitwise-OR-ed onto the blue background template; since the glyph pixels are
    (255, 255, 255), they stay white after the OR.
    """

    # Index 0: font for upper-case letters and digits; index 1: font for every
    # other character (e.g. the Chinese prefix in the default plate "京A12345").
    _font = [
        ImageFont.truetype(os.path.join(os.path.dirname(__file__), "res/eng_92.ttf"), 126),
        ImageFont.truetype(os.path.join(os.path.dirname(__file__), "res/zh_cn_92.ttf"), 95)
    ]
    # Background template, read once when the class is defined and resized to 440x140.
    _bg = cv2.resize(
        cv2.imread(os.path.join(os.path.dirname(__file__), "res/blue_bg.png")),
        (440, 140),
    )

    def __call__(self, plate):
        """Return the rendered plate as an RGB array.

        Prints an error and returns ``None`` when ``plate`` is not exactly
        7 characters long.
        """
        if len(plate) == 7:
            composed = cv2.bitwise_or(self._draw_fg(plate), self._bg)
            return cv2.cvtColor(composed, cv2.COLOR_BGR2RGB)
        print("ERROR: Invalid length")
        return None

    def _draw_char(self, ch):
        """Rasterise one character into a 140x45 white-on-black RGB array."""
        latin = ch.isupper() or ch.isdigit()
        if latin:
            width, y_offset, face = 45, -11, self._font[0]
        else:
            width, y_offset, face = 95, 3, self._font[1]
        glyph = Image.new("RGB", (width, 140), (0, 0, 0))
        ImageDraw.Draw(glyph).text((0, y_offset), ch, fill=(255, 255, 255), font=face)
        # Wide (95px) glyphs are squeezed into the common 45px slot.
        if glyph.width > 45:
            glyph = glyph.resize((45, 140))
        return np.array(glyph)

    def _draw_fg(self, plate):
        """Lay out the glyphs left to right on a black 440x140 canvas.

        Each glyph occupies a 45px slot. The first slot starts 15px from the
        left edge, consecutive slots are 12px apart, and there is a wider 34px
        gap after the second character.
        """
        canvas = np.array(Image.new("RGB", (440, 140), (0, 0, 0)))
        layout = [(plate[0], 12), (plate[1], 34)] + [(ch, 12) for ch in plate[2:]]
        left = 15
        for ch, gap in layout:
            canvas[0:140, left:left + 45] = self._draw_char(ch)
            left += 45 + gap
        return canvas


if __name__ == "__main__":
    # Demo: render one plate (given on the command line) and display it.
    import argparse
    import matplotlib.pyplot as plt

    parser = argparse.ArgumentParser(description="Generate a blue plate.")
    parser.add_argument("plate", help="license plate number (default: 京A12345)", type=str, nargs="?", default="京A12345")
    args = parser.parse_args()

    draw = Draw()
    plate = draw(args.plate)
    # Draw returns None for plates that are not exactly 7 characters long;
    # report a usage error instead of letting plt.imshow fail on None.
    if plate is None:
        parser.error(f"plate must be exactly 7 characters, got {args.plate!r}")
    plt.imshow(plate)
    plt.show()

二、数据处理

既然已经有了数据,那接下来就是进行数据的处理啦

import os
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from utils import str_to_label

# Image preprocessing: convert to a CHW float tensor (ToTensor scales 8-bit
# pixels to [0, 1]), then apply (v - 0.5) / 0.5 per channel, mapping [0, 1] to [-1, 1].
data_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)


class Sampling(Dataset):
    """Dataset of plate images whose label is encoded in the file name.

    Every entry of ``root`` is treated as an image named ``<plate>.<ext>``;
    the text before the first ``.`` is used as the plate label.
    """

    def __init__(self, root):
        self.transform = data_transforms
        self.images = []
        self.labels = []

        # Sorted so the sample order does not depend on the file system's listing order.
        for filename in sorted(os.listdir(root)):
            self.images.append(os.path.join(root, filename))
            self.labels.append(filename.split(".")[0])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        """Return ``(image_tensor, one_hot_label)`` for sample ``index``."""
        # The context manager releases the file handle promptly. convert("RGB")
        # guarantees 3 channels, which the 3-element Normalize requires
        # (grayscale or RGBA files would otherwise fail there).
        with Image.open(self.images[index]) as img:
            image = self.transform(img.convert("RGB"))
        # Map characters to class indices (project helper), then one-hot encode.
        label = self.one_hot(str_to_label(self.labels[index]))
        return image, label

    @staticmethod
    def one_hot(x):
        """One-hot encode the first 7 class indices of ``x`` into a (7, 65) float array.

        Raises IndexError if ``x`` has fewer than 7 entries or an index is >= 65.
        """
        z = np.zeros((7, 65))
        for i in range(7):
            z[i][int(x[i])] = 1
        return z


if __name__ == '__main__':
    # Smoke test: load a single shuffled batch and print its one-hot labels.
    sampling = Sampling("G:/DL_Data/Plate/train_plate")
    dataloader = DataLoader(sampling, 10, shuffle=True)
    for _, labels in dataloader:
        print(labels)
        print(labels.shape)  # each sample's label is a (7, 65) one-hot array
        break  # one batch is enough; exit() is a site helper not meant for programs

三、模型训练

接下来就是模型的训练了。车牌识别虽然用CNN也可以做,但这次我们尝试使用循环网络的编解码结构(Encoder-Decoder),即Seq2Seq模型,来进行识别。
(此处原有配图:编解码结构示意图)
PyTorch中可以直接调用LSTM,但shape变换有点繁琐,我已在代码中进行了注释。
LSTM(`batch_first=True`)的输入格式为 $(N, S, V)$,其中 $N$ 为批量大小。可以理解为将图片从左到右逐列扫描,每次扫描得到的向量依次传入循环网络:$S$ 为扫描的步数,即图片宽度 $440$,也就是上图中输入 $x$ 的个数;$V$ 为每步扫描得到的向量,即上图中的每个 $x$,其长度为图片高度×通道数 $= 140 \times 3 = 420$。

网络结构代码如下:

class Encoder(nn.Module):
    """LSTM encoder that scans the image column by column, left to right.

    Each image column (channels x height = 3 x 140 = 420 values) is one time
    step; the LSTM output at the last step is returned as the encoding.
    """

    def __init__(self):
        super(Encoder, self).__init__()
        self.fc1 = nn.Sequential(
            nn.Linear(420, 128),  # 420 = channels * height of one column
            nn.BatchNorm1d(128),
            nn.ReLU()
        )
        self.lstm = nn.LSTM(input_size=128,
                            hidden_size=128,
                            num_layers=1,
                            batch_first=True)

    def forward(self, x):
        """Encode ``x`` of shape [N, 3, 140, W] (W = 440 for the generated plates) into [N, 128]."""
        n, width = x.size(0), x.size(-1)
        # [N,3,140,W] --> [N,420,W] --> [N,W,420]. The permute is what makes each
        # time step one image column; a bare reshape would instead cut the
        # flattened buffer into 420-long chunks that mix pixels of many columns.
        x = x.reshape(n, 420, width).permute(0, 2, 1)
        # [N,W,420] --> [N*W,420] (reshape copies, as the permuted view is not contiguous)
        x = x.reshape(-1, 420)
        # [N*W,420].[420,128]=[N*W,128]
        fc1 = self.fc1(x)
        # [N*W,128] --> [N,W,128]
        fc1 = fc1.reshape(n, width, 128)
        lstm, *_ = self.lstm(fc1)
        # [N,W,128] --> [N,128]: keep the output of the last time step
        out = lstm[:, -1, :]
        return out


class Decoder(nn.Module):
    """LSTM decoder: repeats the encoded vector for 7 steps and scores 65 classes per step."""

    def __init__(self):
        super(Decoder, self).__init__()
        self.lstm = nn.LSTM(
            input_size=128,
            hidden_size=128,
            num_layers=1,
            batch_first=True,
        )
        # A single classifier shared by all 7 steps, over the 65 character classes.
        self.out = nn.Linear(128, 65)

    def forward(self, x):
        """Map encoded features [N, 128] to per-character scores [N, 7, 65]."""
        # The same context vector is the LSTM input at each of the 7 steps.
        steps = x.reshape(-1, 1, 128).expand(-1, 7, 128)
        hidden_seq = self.lstm(steps)[0]
        # Fold the steps into the batch for the linear layer, then restore [N, 7, 65].
        scores = self.out(hidden_seq.reshape(-1, 128))
        return scores.reshape(-1, 7, 65)


class MainNet(nn.Module):
    """Seq2Seq plate recogniser: ``Encoder`` features are decoded by ``Decoder``."""

    def __init__(self):
        super(MainNet, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        """Return the decoder output for the image batch ``x``."""
        return self.decoder(self.encoder(x))

在实际测试中,我和朋友发现当Encoder和Decoder都使用LSTM时,训练时间较长且效果没有达到预期,于是我们将Encoder换成了CNN,Decoder仍保持为LSTM。

class Encoder(nn.Module):
    """CNN encoder producing a 128-dim feature vector per image.

    Three stride-2 conv stages (the first two followed by 2x max-pooling) reduce
    a [N, 3, 140, 440] batch to [N, 32, 5, 14]; a linear layer projects the
    flattened map to [N, 128].
    """

    def __init__(self):
        super(Encoder, self).__init__()
        layers = []
        # (in_channels, out_channels, followed_by_max_pool)
        for c_in, c_out, pooled in ((3, 8, True), (8, 16, True), (16, 32, False)):
            layers += [nn.Conv2d(c_in, c_out, 3, 2, 1), nn.BatchNorm2d(c_out), nn.ReLU()]
            if pooled:
                layers.append(nn.MaxPool2d(2))
        self.cnn_layer = nn.Sequential(*layers)
        # Kept inside a Sequential so the parameter names stay "fc.0.*".
        self.fc = nn.Sequential(nn.Linear(32 * 5 * 14, 128))

    def forward(self, x):
        feature_map = self.cnn_layer(x)
        return self.fc(feature_map.flatten(1))


class Decoder(nn.Module):
    """LSTM decoder that unrolls one encoded vector into per-character class scores.

    Args:
        seq_len: number of characters to emit (default 7).
        num_classes: size of the character vocabulary (default 65).
    """

    def __init__(self, seq_len=7, num_classes=65):
        super(Decoder, self).__init__()
        self.seq_len = seq_len
        self.num_classes = num_classes
        self.lstm = nn.LSTM(
            input_size=128, hidden_size=128, num_layers=1, batch_first=True
        )
        self.out = nn.Linear(128, num_classes)

    def forward(self, x):
        """Map encoder features [N, 128] to scores [N, seq_len, num_classes]."""
        # [N, 128] --> [N, 1, 128] --> [N, seq_len, 128]: the same context
        # vector is the LSTM input at every step.
        x = x.reshape(-1, 1, 128).expand(-1, self.seq_len, 128)
        lstm, (_, _) = self.lstm(x)
        # [N, seq_len, 128] --> [N*seq_len, 128]
        y = lstm.reshape(-1, 128)
        # [N*seq_len, 128] --> [N*seq_len, num_classes]
        out = self.out(y)
        # [N*seq_len, num_classes] --> [N, seq_len, num_classes]
        return out.reshape(-1, self.seq_len, self.num_classes)


class MainNet(nn.Module):
    """End-to-end recogniser: the CNN ``Encoder`` followed by the LSTM ``Decoder``."""

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        """Encode the image batch ``x`` and decode it into per-character outputs."""
        features = self.encoder(x)
        return self.decoder(features)

四、测试

训练好模型后,使用保存好的权重进行测试,如下图所示
(此处原有配图:测试结果)
总的来说,效果还是不错滴

五、提升

虽然我们这个项目整体效果不错,但实际的车牌识别项目还是需要使用真实数据集。像我们这样干净、规整的合成数据,训练出的模型估计只适用于停车场收费口这类拍摄条件固定的场景。

而在实际场景中,往往需要先定位车牌:首先检测车辆位置,再在车辆区域内检测车牌位置,最后才识别车牌号。

所以,有兴趣的同学可以尝试用真实数据做实验。这里附上一个车牌数据集的项目地址:CCPD,这是一个由中国科学技术大学的研究人员构建的国内大型车牌识别数据集。

猜你喜欢

转载自blog.csdn.net/weixin_42166222/article/details/118300944