30分钟吃掉CRNN-CTC验证码识别

本范例使用经典的 CRNN + CTC Loss OCR 模型来识别验证码。

我们通过导入一个叫 captcha 的库来生成验证码。

我们生成的验证码字符由数字和大写字母组成。

项目参考:https://github.com/ypwhs/captcha_break

#!pip install captcha torchkeras
import torch 
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms.functional import to_tensor, to_pil_image

from tqdm import tqdm
import random
import numpy as np

import torchkeras 
from pathlib import Path
from collections import OrderedDict 


# Alphabet used for CTC: index 0 ('-') is reserved for the [blank] symbol, followed by
# the 10 digits and 26 upper-case letters that captcha text is drawn from.
characters = '-' + '0123456789' + 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
n_classes = len(characters)  # number of output classes, blank included

# Size of the rendered captcha images, in pixels.
width = 192
height = 64

txt_length = 4   # number of characters in every generated captcha
seq_length = 12  # length of the CRNN output sequence; generally required:
                 # seq_length >= 2 * txt_length + 1

一,准备数据

from captcha.image import ImageCaptcha
# Demo only: render a wide 320x64 captcha using the SimHei font bundled in torchkeras'
# assets, so the Chinese sample text below can be drawn. The training data uses its own
# ImageCaptcha instance created inside CaptchaDataset.
generator = ImageCaptcha(width=320, height=64, 
        fonts=[str(Path(torchkeras.__file__).parent/'assets'/'SimHei.ttf') ],
        font_sizes=[40,45])
generator.generate_image('中国人民很行')

06f8b29f508d971c08590aad8661f82c.png

class CaptchaDataset(Dataset):
    """Synthetic captcha dataset that renders a fresh random sample on every access.

    ``index`` is ignored. Each ``__getitem__`` call does two things:

    - draws ``label_length`` symbols uniformly (with replacement) from
      ``characters[1:]``; position 0, the CTC blank, is never sampled;
    - renders them with ``ImageCaptcha``.

    Args:
        characters: alphabet string whose first character is the blank symbol.
        length: value reported by ``__len__``.
        width, height: size of the rendered image in pixels.
        input_length: value reported as the CTC input length of every sample.
        label_length: number of symbols in each generated captcha.

    Each item is ``(image, target, input_length, target_length)``:

    - ``image``: the tensor produced by ``to_tensor``;
    - ``target``: a ``torch.long`` tensor of alphabet indices;
    - ``input_length`` and ``target_length``: one-element ``torch.long`` tensors.
    """

    def __init__(self, characters, length,
                 width, height, input_length, label_length):
        super().__init__()
        self.characters = characters
        self.length = length
        self.width, self.height = width, height
        self.input_length = input_length
        self.label_length = label_length
        self.n_class = len(characters)
        self.generator = ImageCaptcha(width=width, height=height)

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        symbols = self.characters[1:]  # never sample the blank at index 0
        text = ''.join(random.choice(symbols) for _ in range(self.label_length))
        picture = self.generator.generate_image(text)
        label = torch.tensor([self.characters.find(ch) for ch in text], dtype=torch.long)
        seq_len = torch.tensor([self.input_length], dtype=torch.long)
        text_len = torch.tensor([self.label_length], dtype=torch.long)
        return to_tensor(picture), label, seq_len, text_len
batch_size = 128
# Synthetic datasets: the lengths are nominal (100 and 20 batches per epoch).
# CaptchaDataset ignores the index and renders a new random captcha on every access.
ds_train= CaptchaDataset(characters, 100 * batch_size, 
                         width, height, seq_length, txt_length)

ds_val = CaptchaDataset(characters, 20 * batch_size, 
                        width, height, seq_length, txt_length)

# Each DataLoader worker renders captchas independently.
dl_train = DataLoader(ds_train, batch_size=batch_size, num_workers=4)
dl_val = DataLoader(ds_val, batch_size=batch_size, num_workers=4)


# Single-item dataset used to eyeball one sample (and later by display_fn).
ds_test = CaptchaDataset(characters, 1, width, height, seq_length, txt_length)

image, target, input_length, label_length = ds_test[0]
# Map label indices back to text and show the CTC input/label lengths.
print(''.join([characters[x] for x in target]), input_length, label_length)
to_pil_image(image)  # convert the image tensor back to a PIL image for display

二,定义模型

class CRNN(nn.Module):
    """CNN + bidirectional-LSTM recogniser that emits per-timestep class scores for CTC.

    Convolutional trunk:

    - five blocks of ``2 x (Conv3x3 -> BatchNorm -> ReLU)``, each followed by max pooling;
    - the first four pools halve height and width;
    - the last pool is ``(2, 1)``, which halves the height only;
    - dropout is applied at the end.

    The CNN feature map ``(B, C, H, W)`` is reshaped to ``(B, C*H, W)`` and read as a
    sequence along the width axis. With this pooling the output length is
    ``W_in // 16`` (e.g. 192 -> 12, 128 -> 8).

    Args:
        n_classes: number of output classes, including the CTC blank.
        input_shape: ``(channels, height, width)`` of the inputs. It is used once, at
            construction, to infer the LSTM input size, so later inputs must have the
            same channel count and height. The width may differ; it only changes the
            output sequence length.

    ``forward`` returns raw (unnormalised) scores shaped ``(T, B, n_classes)``. This is
    time-major, the layout ``F.ctc_loss`` expects after ``log_softmax``.
    """

    def __init__(self, n_classes, input_shape=(3, 64, 128)):
        super().__init__()
        self.input_shape = input_shape
        channels = [32, 64, 128, 256, 256]
        layers = [2, 2, 2, 2, 2]
        kernels = [3, 3, 3, 3, 3]
        pools = [2, 2, 2, 2, (2, 1)]
        hidden_size = 128
        modules = OrderedDict()

        def cba(name, in_channels, out_channels, kernel_size):
            # Conv -> BN -> ReLU; 3x3 convolutions are padded to keep the spatial size.
            modules[f'conv{name}'] = nn.Conv2d(in_channels, out_channels, kernel_size,
                                               padding=(1, 1) if kernel_size == 3 else 0)
            modules[f'bn{name}'] = nn.BatchNorm2d(out_channels)
            modules[f'relu{name}'] = nn.ReLU(inplace=True)

        # Take the input channel count from input_shape instead of hard-coding 3 (RGB),
        # so e.g. grayscale inputs (1, H, W) also work.
        last_channel = input_shape[0]
        for block, (n_channel, n_layer, n_kernel, k_pool) in enumerate(zip(channels, layers, kernels, pools)):
            for layer in range(1, n_layer + 1):
                cba(f'{block+1}{layer}', last_channel, n_channel, n_kernel)
                last_channel = n_channel
            modules[f'pool{block + 1}'] = nn.MaxPool2d(k_pool)
        modules['dropout'] = nn.Dropout(0.25, inplace=True)

        # Submodule names (conv11, bn11, ..., pool5, dropout, cnn, lstm, fc) are unchanged
        # so previously saved state_dicts / checkpoints still load.
        self.cnn = nn.Sequential(modules)
        self.lstm = nn.LSTM(input_size=self.infer_features(), hidden_size=hidden_size,
                            num_layers=2, bidirectional=True)
        # A bidirectional LSTM concatenates both directions -> 2 * hidden_size features.
        self.fc = nn.Linear(in_features=2 * hidden_size, out_features=n_classes)

    def infer_features(self):
        """Return the per-timestep feature size (``C * H``) that ``self.cnn`` emits for ``input_shape``.

        The probe is a single zero batch. It runs with the CNN in eval mode and under
        ``torch.no_grad()``, and the previous train/eval mode is restored afterwards.
        As a result:

        - BatchNorm running statistics and ``num_batches_tracked`` are not updated;
        - Dropout does not act on the dummy input;
        - no autograd graph is built.
        """
        was_training = self.cnn.training
        self.cnn.eval()
        try:
            with torch.no_grad():
                x = self.cnn(torch.zeros((1,) + tuple(self.input_shape)))
        finally:
            self.cnn.train(was_training)
        x = x.reshape(x.shape[0], -1, x.shape[-1])
        return x.shape[1]

    def forward(self, x):
        """Map images ``(B, C, H, W)`` to time-major class scores ``(T, B, n_classes)``."""
        x = self.cnn(x)                              # (B, C', H', W')
        x = x.reshape(x.shape[0], -1, x.shape[-1])   # (B, C'*H', W'): width becomes time
        x = x.permute(2, 0, 1)                       # (W', B, C'*H'); the LSTM is not batch_first
        x, _ = self.lstm(x)                          # (W', B, 2 * hidden_size)
        x = self.fc(x)                               # (W', B, n_classes)
        return x
# Build the model for the training image size and check the output shape on a dummy batch.
net = CRNN(n_classes, input_shape=(3, height, width))
inputs = torch.zeros((32, 3, height, width))
outputs = net(inputs)
print(outputs.shape) # the LSTM output is length-first by default: (seq_length, batch, n_classes)
net.cuda();  # move the model to the GPU (requires CUDA)
torch.Size([12, 32, 37])

三, 训练模型

# 解码函数和计算准确率函数

def decode_target(sequence, charset=None):
    """Convert a ground-truth index sequence into its text.

    Args:
        sequence: iterable of integer class indices (list, numpy row or 1-D tensor).
        charset: alphabet to index into; defaults to the module-level ``characters``.
            Position 0 is the CTC blank.

    Returns:
        The label text with any blank symbols removed.
    """
    chars = characters if charset is None else charset
    blank = chars[0]
    # The alphabet contains no space, so stripping ' ' (as before) was a no-op;
    # the only filler symbol that can appear is the blank at index 0.
    return ''.join(chars[int(i)] for i in sequence).replace(blank, '')

def decode(sequence, charset=None):
    """Greedy (best-path) CTC decoding of one predicted index sequence.

    Runs of identical consecutive symbols are collapsed to one, then blanks (index 0
    of the alphabet) are dropped. A blank between two identical symbols therefore
    keeps both, which is how CTC represents doubled letters (``A-A`` -> ``AA``).

    Args:
        sequence: iterable of integer class indices, one per time step (list, numpy
            row or 1-D tensor), e.g. the argmax over classes of the network output.
        charset: alphabet to index into; defaults to the module-level ``characters``.

    Returns:
        The decoded string ('' for an empty or all-blank sequence).
    """
    chars = characters if charset is None else charset
    blank = chars[0]
    decoded = []
    previous = None
    for idx in sequence:
        symbol = chars[int(idx)]
        # Emit only at the start of a new run of a non-blank symbol.
        if symbol != previous and symbol != blank:
            decoded.append(symbol)
        previous = symbol
    return ''.join(decoded)


def eval_acc(targets, preds):
    """Fraction of samples in a batch whose greedy-decoded prediction equals the label exactly.

    Args:
        targets: label index tensor, one row per sample.
        preds: network output in time-major layout ``(seq, batch, classes)``; it is
            detached before use and not modified.

    Returns:
        The mean of the per-sample exact-match booleans (a numpy float).
    """
    best_path = preds.detach().permute(1, 0, 2).argmax(dim=-1).cpu().numpy()
    labels = targets.cpu().numpy()
    matches = np.array([decode(path) == decode_target(label)
                        for label, path in zip(labels, best_path)])
    return matches.mean()
import torch.nn.functional as F 
from torchkeras import KerasModel

#我们覆盖KerasModel的StepRunner以实现自定义训练逻辑。
#注意这里把acc指标的结果写在了step_loss中以便和loss一样在Epoch上求平均,这是一个非常灵活而且有用的写法。

class StepRunner:
    """Per-batch step for torchkeras' ``KerasModel``, specialised for CTC training.

    Differences from the default step:

    - The loss is ``F.ctc_loss`` computed directly from the network output.
      ``loss_fn`` is kept only for interface compatibility and is not used.
    - Sequence accuracy from ``eval_acc`` is returned in the same dict as the loss,
      under ``<stage>_acc``, so it is averaged over the epoch the same way.

    Args:
        net: the model; put in train mode when ``stage == "train"``, eval mode otherwise.
        loss_fn: unused.
        accelerator: object providing ``backward`` and ``gather``.
        stage: stage name, also used as the prefix of the reported keys.
        metrics_dict: stored, not used.
        optimizer, lr_scheduler: only stepped when ``stage == "train"``.
    """

    def __init__(self, net, loss_fn, accelerator, stage="train", metrics_dict=None,
                 optimizer=None, lr_scheduler=None):
        self.net = net
        self.loss_fn = loss_fn
        self.metrics_dict = metrics_dict
        self.stage = stage
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.accelerator = accelerator
        # train(False) is exactly what eval() does.
        self.net.train(self.stage == 'train')

    def __call__(self, batch):
        images, targets, input_lengths, target_lengths = batch

        # Forward pass: the network emits raw scores; CTC wants log-probabilities.
        logits = self.net(images)
        log_probs = F.log_softmax(logits, dim=-1)
        loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths)
        acc = eval_acc(targets, logits)

        # Parameter update, only in the training stage with an optimizer.
        if self.stage == "train" and self.optimizer is not None:
            self.accelerator.backward(loss)
            self.optimizer.step()
            if self.lr_scheduler is not None:
                self.lr_scheduler.step()
            self.optimizer.zero_grad()

        total_loss = self.accelerator.gather(loss).sum()

        # Accuracy rides along with the loss so it is epoch-averaged identically.
        step_losses = {
            f"{self.stage}_loss": total_loss.item(),
            f"{self.stage}_acc": acc,
        }

        step_metrics = {}
        if self.stage == "train":
            step_metrics['lr'] = (self.optimizer.state_dict()['param_groups'][0]['lr']
                                  if self.optimizer is not None else 0.0)
        return step_losses, step_metrics
    
    
# Swap in the CTC-specific step (a class-level override on KerasModel).
KerasModel.StepRunner = StepRunner
model = KerasModel(net,
                   loss_fn=None,  # the loss is computed inside StepRunner
                   optimizer = torch.optim.AdamW(net.parameters(),lr = 2e-6)
                   )
# NOTE(review): lr=2e-6 looks tuned for resuming from 'ctc_crnn.pt'; training from
# scratch probably needs a much larger learning rate. Confirm before a fresh run.

# Resume only if a checkpoint exists. On a first run there is none, and an
# unconditional load would fail.
if Path('ctc_crnn.pt').exists():
    model.load_ckpt('ctc_crnn.pt')

# NOTE: `visdis` is created in section 四 (评估模型) from `display_fn` and `model`.
# Run that cell before this one; otherwise `callbacks=[visdis]` raises NameError.
model.fit(
    train_data = dl_train,
    val_data= dl_val,
    ckpt_path='ctc_crnn.pt',
    epochs=30,
    patience=10,
    monitor="val_acc", 
    mode="max",
    plot = True,
    wandb = False,
    callbacks=[visdis],
    quiet = lambda epoch: epoch>5
)

b3660ec7195f1b7848bcd21a94d4f218.png

四,评估模型

def display_fn(model, max_tries=100):
    """Sample captchas from ``ds_test`` until the model misreads one, then display it.

    Every ground-truth / prediction pair is printed. The search stops at the first
    wrong prediction or after ``max_tries`` samples, whichever comes first. The cap
    prevents an endless loop once the model is (nearly) perfect. The last sampled
    image is displayed either way.

    Args:
        model: model to evaluate. It is switched to eval mode and left there.
            Inputs are sent to the device of its first parameter.
        max_tries: maximum number of samples to draw (at least one is always drawn).
    """
    model.eval()
    device = next(model.parameters()).device
    image = None
    with torch.no_grad():  # inference only: no autograd graph needed
        for _ in range(max(1, max_tries)):
            image, target, input_length, label_length = ds_test[0]
            output = model(image.unsqueeze(0).to(device))
            # (seq, batch, classes) -> (batch, seq): best class per time step.
            output_argmax = output.permute(1, 0, 2).argmax(dim=-1)
            gt, pred = decode_target(target), decode(output_argmax[0])
            print('gt:', gt,' ','pred:', pred)
            if gt != pred:
                break
    display(to_pil_image(image))
    
    
from torchkeras.kerascallbacks import VisDisplay
# Visualisation callback wrapping display_fn and the KerasModel; passed to model.fit via
# callbacks=[visdis] (presumably invoked by torchkeras during fit — see its docs).
visdis = VisDisplay(display_fn,model)

207e86800485cbaa61de4000aab6a625.png

五,使用模型

def predict(model,image):
    """Recognise the text in a single captcha image.

    Args:
        model: trained model (the KerasModel wrapper or the bare CRNN).
            - It is moved to the GPU when one is available, otherwise to the CPU.
            - It is switched to eval mode. In train mode, BatchNorm would normalise a
              batch of one image with that image's own statistics, and the inplace
              Dropout would randomly zero features.
        image: a PIL image. Its height must match the training images, because the
            LSTM input size was inferred from it.

    Returns:
        The greedy-CTC-decoded string predicted for the image.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.eval()
    batch = to_tensor(image).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(batch)
    # (seq, batch, classes) -> (batch, seq): best class per time step.
    output_argmax = output.permute(1, 0, 2).argmax(dim=-1)
    return decode(output_argmax[0])
# Reload the checkpoint written by fit (ckpt_path='ctc_crnn.pt') and read a fresh captcha.
model.load_ckpt('ctc_crnn.pt')
generator = ImageCaptcha(width=width, height=height)  # same size as the training images
# 'ABBD' contains a doubled letter, which exercises CTC repeat handling in decode().
image = generator.generate_image('ABBD')
image
predict(model,image)

11911d09501d490f47ee7896503488d1.png

六,保存模型

# Save only the wrapped network's weights (its state_dict), not the KerasModel wrapper.
torch.save(model.net.state_dict(),'best.pt')

公众号算法美食屋后台回复关键词:CRNN,获取本文notebook源代码和B站视频讲解。

猜你喜欢

转载自blog.csdn.net/Python_Ai_Road/article/details/131016458