项目地址:License-Recognition
关于RNN和LSTM的原理,这里就不多赘述了,网上有许多文章对他们进行详细的介绍,后面有时间我也会去整理下相关知识。
听说使用LSTM实现车牌识别很简单,这就让我们尝试下吧~
一、数据集的生成
要想做出车牌识别,当然得有相应的数据集啦,但是网上很难找到相关数据集,那么就让我们自己生成吧!!!
下面就是生成的图片啦
没错,我们可以生成黑底、绿底、蓝底、黄底4种颜色的车牌图片。生成车牌图片这块,我采用的是 GitHub 上的 fake_chs_lp 项目;其实大家自己实现也很简单,但我比较懒,而且手头没有实际的底色模板。下面是生成蓝底车牌图的代码:
import os
import cv2
import numpy as np
from PIL import Image, ImageFont, ImageDraw
class Draw:
    """Render a 7-character plate string onto the blue background template.

    Calling an instance with a plate string returns an RGB image array
    (440 px wide, 140 px high), or ``None`` when the string is not exactly
    seven characters long.
    """

    # _font[0] is used for upper-case letters and digits, _font[1] for every
    # other character (e.g. the leading "京" of the default demo plate).
    _font = [
        ImageFont.truetype(os.path.join(os.path.dirname(__file__), "res/eng_92.ttf"), 126),
        ImageFont.truetype(os.path.join(os.path.dirname(__file__), "res/zh_cn_92.ttf"), 95)
    ]
    # Background template, loaded by OpenCV and resized to the plate canvas.
    _bg = cv2.resize(cv2.imread(os.path.join(os.path.dirname(__file__), "res/blue_bg.png")), (440, 140))

    def __call__(self, plate):
        """Return the rendered plate as an RGB array, or None on bad length."""
        if len(plate) == 7:
            foreground = self._draw_fg(plate)
            # White glyphs on black OR-ed with the background keep the glyphs
            # white and leave the background untouched everywhere else.
            merged = cv2.bitwise_or(foreground, self._bg)
            return cv2.cvtColor(merged, cv2.COLOR_BGR2RGB)
        print("ERROR: Invalid length")
        return None

    def _draw_char(self, ch):
        """Draw one white glyph on a black 45x140 tile and return it as an array."""
        narrow = ch.isupper() or ch.isdigit()
        if narrow:
            tile_width, top, font = 45, -11, self._font[0]
        else:
            tile_width, top, font = 95, 3, self._font[1]
        tile = Image.new("RGB", (tile_width, 140), (0, 0, 0))
        ImageDraw.Draw(tile).text((0, top), ch, fill=(255, 255, 255), font=font)
        # Wide glyphs are drawn on a 95 px tile and squeezed to the 45 px slot.
        if tile.width > 45:
            tile = tile.resize((45, 140))
        return np.array(tile)

    def _draw_fg(self, plate):
        """Paste the glyph tiles of *plate* onto a black 440x140 canvas."""
        canvas = np.zeros((140, 440, 3), dtype=np.uint8)
        left = 15
        canvas[:, left:left + 45] = self._draw_char(plate[0])
        left += 45 + 12
        canvas[:, left:left + 45] = self._draw_char(plate[1])
        # Wider gap (34 px instead of 12 px) after the second character.
        left += 45 + 34
        for ch in plate[2:]:
            canvas[:, left:left + 45] = self._draw_char(ch)
            left += 45 + 12
        return canvas
if __name__ == "__main__":
    # Command-line demo: render one plate and display it with matplotlib.
    import argparse
    import matplotlib.pyplot as plt

    parser = argparse.ArgumentParser(description="Generate a blue plate.")
    parser.add_argument("plate", help="license plate number (default: 京A12345)", type=str, nargs="?", default="京A12345")
    args = parser.parse_args()

    draw = Draw()
    plate = draw(args.plate)
    if plate is None:
        # Draw.__call__ has already printed the reason (wrong length).
        # Stop here instead of handing None to plt.imshow, which would fail
        # with an unrelated error.
        raise SystemExit(1)
    plt.imshow(plate)
    plt.show()
二、数据处理
既然已经有了数据,那接下来就是进行数据的处理啦
import os
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from utils import str_to_label
# Preprocessing shared by every sample: convert the PIL image to a tensor,
# then normalise each of the three channels with mean 0.5 and std 0.5.
data_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])
class Sampling(Dataset):
    """Dataset of plate images whose file name (before the first dot) is the label.

    Each item is ``(image, label)`` where ``image`` is the output of
    ``data_transforms`` and ``label`` is a ``(7, 65)`` one-hot NumPy array.
    """

    def __init__(self, root):
        """Index every entry of directory *root*; no image is opened here."""
        self.transform = data_transforms
        self.images = []
        self.labels = []
        for filename in os.listdir(root):
            self.images.append(os.path.join(root, filename))
            # Text before the first dot is the plate string.
            self.labels.append(filename.split(".")[0])

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.images)

    def __getitem__(self, index):
        """Load, normalise and label the sample at *index*."""
        # The context manager closes the underlying file promptly, and
        # convert("RGB") guarantees the three channels that the 3-channel
        # Normalize in data_transforms requires (RGBA / grayscale files
        # would otherwise fail).
        with Image.open(self.images[index]) as img:
            image = self.transform(img.convert("RGB"))
        # str_to_label turns the characters into numeric class indices so
        # they can be one-hot encoded.
        label = self.one_hot(str_to_label(self.labels[index]))
        return image, label

    @staticmethod
    def one_hot(x):
        """Return a (7, 65) array with a single 1 per row at column ``int(x[i])``."""
        z = np.zeros((7, 65))
        for i in range(7):
            z[i][int(x[i])] = 1
        return z
if __name__ == '__main__':
    # Smoke test: print the labels (and their shape) of the first shuffled
    # batch of ten samples, then quit.
    dataset = Sampling("G:/DL_Data/Plate/train_plate")
    loader = DataLoader(dataset, batch_size=10, shuffle=True)
    for _images, batch_labels in loader:
        print(batch_labels)
        print(batch_labels.shape)
        exit()
三、模型训练
接下来就是模型的训练了。车牌识别虽然用CNN也可以做,但这次我们尝试使用循环网络的编解码结构(Encoder-Decoder),即Seq2Seq模型来进行识别。
Pytorch中可直接调用LSTM,但是shape变换有点繁琐,我在代码中已进行注释
LSTM的输入格式为 $(N, S, V)$,可以理解为将图片从左到右进行扫描,每次扫描得到的向量依次传入循环网络。$N$ 为批大小;$S$ 为扫描多少步,就是图片宽度 440,也就是上图中输入 $x$ 的个数;$V$ 为每步扫描得到的向量,就是上图中的每个 $x$,其长度为图片高度×通道数 $= 140 \times 3 = 420$。
网络结构代码如下:
class Encoder(nn.Module):
    """LSTM encoder that reads the plate image one pixel column at a time.

    The input must be reshapeable to ``[N, 420, 440]`` (for example
    ``[N, 3, 140, 440]``). Each of the 440 columns becomes one time step
    holding a 420-value vector (channels * height). The output is the LSTM
    output at the last time step, shape ``[N, 128]``.
    """

    def __init__(self):
        super(Encoder, self).__init__()
        self.fc1 = nn.Sequential(
            nn.Linear(420, 128),  # 420 = length of one column vector
            nn.BatchNorm1d(128),
            nn.ReLU()
        )
        self.lstm = nn.LSTM(input_size=128,
                            hidden_size=128,
                            num_layers=1,
                            batch_first=True)

    def forward(self, x):
        """Encode a batch of images into ``[N, 128]`` feature vectors."""
        # [N,3,140,440] --> [N,420,440] --> [N,440,420]
        # The permute is essential: without it the following reshape would
        # cut the flat buffer into arbitrary 420-element chunks rather than
        # into image columns.
        x = x.reshape(-1, 420, 440).permute(0, 2, 1)
        # [N,440,420] --> [N*440,420]
        x = x.reshape(-1, 420)
        # [N*440,420].[420,128]=[N*440,128]
        fc1 = self.fc1(x)
        # [N*440,128] --> [N,440,128]
        fc1 = fc1.reshape(-1, 440, 128)
        lstm, *_ = self.lstm(fc1)
        # [N,440,128] --> [N,128]: keep only the last time step
        out = lstm[:, -1, :]
        return out
class Decoder(nn.Module):
    """LSTM decoder: repeat the encoding for 7 steps and classify each step.

    Maps ``[N, 128]`` encodings to ``[N, 7, 65]`` scores (7 plate positions,
    65 classes each).
    """

    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(input_size=128, hidden_size=128, num_layers=1, batch_first=True)
        # A single 65-way head is shared by all positions. An earlier variant
        # (left commented out in the original) used three heads with
        # 31 / 24 / 10 outputs instead.
        self.out = nn.Linear(128, 65)

    def forward(self, x):
        """Return per-position class scores of shape ``[N, 7, 65]``."""
        # [N,128] --> [N,1,128] --> [N,7,128]: same vector at every step
        repeated = x.reshape(-1, 1, 128).expand(-1, 7, 128)
        hidden_states, *_ = self.lstm(repeated)
        # [N,7,128] --> [N*7,128] --> [N*7,65]
        scores = self.out(hidden_states.reshape(-1, 128))
        # [N*7,65] --> [N,7,65]
        return scores.reshape(-1, 7, 65)
class MainNet(nn.Module):
    """Full recogniser: ``Encoder`` followed by ``Decoder``."""

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        """Encode *x*, then decode the encoding and return the decoder output."""
        return self.decoder(self.encoder(x))
在实际测试中,我和朋友发现Encoder-Decoder都为LSTM时,训练时间比较久且效果没达到预期,于是我们将Encoder换成了CNN,Decoder继续保持为LSTM
class Encoder(nn.Module):
    """CNN encoder: three stride-2 conv stages, then a linear layer to 128.

    The linear layer expects 32 * 5 * 14 flattened features, which is what
    the conv stack yields for a ``[N, 3, 140, 440]`` input.
    """

    def __init__(self):
        super().__init__()
        self.cnn_layer = nn.Sequential(
            # stage 1: 3 -> 8 channels
            nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(8),
            nn.ReLU(),
            nn.MaxPool2d(2),
            # stage 2: 8 -> 16 channels
            nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(2),
            # stage 3: 16 -> 32 channels (no pooling)
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
        )
        # Kept as a one-element Sequential so the parameter names stay
        # "fc.0.weight" / "fc.0.bias".
        self.fc = nn.Sequential(nn.Linear(32 * 5 * 14, 128))

    def forward(self, x):
        """Return ``[N, 128]`` feature vectors for a batch of images."""
        feature_maps = self.cnn_layer(x)
        flat = feature_maps.reshape(x.shape[0], -1)
        return self.fc(flat)
class Decoder(nn.Module):
    """LSTM decoder producing 65-class scores for each of 7 plate positions."""

    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(input_size=128,
                            hidden_size=128,
                            num_layers=1,
                            batch_first=True)
        self.out = nn.Linear(128, 65)

    def forward(self, x):
        """Map ``[N, 128]`` encodings to ``[N, 7, 65]`` scores."""
        # [N, 128] --> [N, 1, 128]
        step_input = x.reshape(-1, 1, 128)
        # [N, 1, 128] --> [N, 7, 128]: feed the same encoding at all 7 steps
        step_input = step_input.expand(-1, 7, 128)
        outputs, _final_state = self.lstm(step_input)
        # [N, 7, 128] --> [N*7, 128]
        flat = outputs.reshape(-1, 128)
        # [N*7, 128] --> [N*7, 65]
        scores = self.out(flat)
        # [N*7, 65] --> [N, 7, 65]
        scores = scores.reshape(-1, 7, 65)
        return scores
class MainNet(nn.Module):
    """Plate recogniser that chains the ``Encoder`` and the ``Decoder``."""

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        """Run the encoder, feed its output to the decoder, return the result."""
        features = self.encoder(x)
        scores = self.decoder(features)
        return scores
四、测试
训练好模型后,使用保存好的权重进行测试,如下图所示
总的来说,效果还是不错滴
五、提升
虽然我们做的这个模型整体效果很不错,但在实际的车牌识别项目中,还是需要使用真实数据集。像我们这样质量理想的合成数据,估计只适用于停车场固定收费处这类拍摄条件固定的场景。
而实际生活中,往往要先定位到车牌,首先检测车辆位置,再从车辆上检测车牌位置,最后识别车牌号
所以,有兴趣的同学,可以尝试使用真实数据去做实验,这儿放一份车牌数据集的项目地址:CCPD,这是一个用于车牌识别的大型国内的数据集,由中科大的科研人员构建出来的