长短期记忆网络LSTM
关于LSTM的介绍和认识,可以参考这篇文章
长短期记忆网络LSTM:https://blog.csdn.net/eagleuniversityeye/article/details/91345671
说明:
…entry原图 ———————— reshape展开 —————— permute换轴 ———————— 输入LSTM
一、LSTM识别验证码——一个模型
使用LSTM结合Seq2Seq结构实现验证码识别
验证码样式如下图:
代码生成42000张验证码(train:40000, test:2000),验证码有清晰的,有低度模糊的,也有中度模糊的,位置也随机。
验证码和标签采用DataLoader加载,标签采用4*10的one-hot编码,网络输出每个图片也是4*10,训练20轮即达到了正确率100%,效果不错。
下面是模型部分代码,其他部分的代码就不贴了,损失函数MSELoss,优化器Adam。
import torch
from torch import nn
class Lstm(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Sequential(
nn.Linear(180, 128),
nn.BatchNorm1d(128),
nn.LeakyReLU(),
)
self.lstm1 = nn.LSTM(128, 256, 2, batch_first=True)
self.lstm2 = nn.LSTM(256, 128, 2, batch_first=True)
self.fc2 = nn.Sequential(
nn.Linear(128, 10),
)
def forward(self, entry): # N C H W N * 3 * 60 * 120
entry = entry.reshape(-1, 3*60, 120) # N V S N * 180 * 120
entry = entry.permute(0, 2, 1) # N S V N * 120 * 180
entry = entry.reshape(-1, 180) # N V 120N * 180
fc1_out = self.fc1(entry) # N V 120N * 128
fc1_out = fc1_out.reshape(-1, 120, 128) # N S V N * 120 * 128
lstm1_out, _ = self.lstm1(fc1_out) # N S V N * 120 * 256网络会输出S次
lstm1_out = lstm1_out[:, -1, :] # N V N * 256只保留最后一次输出
lstm1_out = lstm1_out.reshape(-1, 1, 256) # N 1 V N * 1 * 256
# 下行代码:N 4 V 广播为N * 4 * 256,后面对每个256提取特征输出做损失,后面的优化使得每个V保留一个字符的特征
lstm1_out = lstm1_out.expand(lstm1_out.shape[0], 4, 256)
lstm2_out, _ = self.lstm2(lstm1_out) # N 4 V N * 4 * 128
lstm2_out = lstm2_out.reshape(-1, 128) # 4N, V 4N * 128
fc2_out = self.fc2(lstm2_out) # 4N, V 4N * 10
fc2_out = fc2_out.reshape(-1, 4, 10) # N S V N * 4 * 10
return fc2_out
二、编码器和解码器分离
import torch
from torch import nn
class Encoder(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Sequential(
nn.Linear(180, 128),
nn.BatchNorm1d(128),
nn.LeakyReLU(),
)
self.lstm = nn.LSTM(128, 256, 2, batch_first=True) # V h num_layer
def forward(self, x): # N C H W N 3 60 120
x = x.reshape(-1, 180, 120) # N V S N 180 120
x = x.permute(0, 2, 1) # N S V N 120 180
x = x.reshape(-1, 180) # N V 120N 180
fc_out = self.fc(x) # N V 120N 128
fc_out = fc_out.reshape(-1, 120, 128) # N S V N 120 128
lstm_out, _ = self.lstm(fc_out) # N S V N 120 256
lstm_out = lstm_out[:, -1, :] # N V N 256
lstm_out = lstm_out.reshape(-1, 1, 256) # N 1 V N 1 256
lstm_out = lstm_out.expand(lstm_out.shape[0], 4, 256) # N 4 256
return lstm_out
class Decoder(nn.Module):
def __init__(self):
super().__init__()
self.lstm = nn.LSTM(256, 128, 2, batch_first=True)
self.fc = nn.Sequential(
nn.Linear(128, 10),
)
def forward(self, x):
lstm_out, _ = self.lstm(x) # N S V N 4 128
lstm_out = lstm_out.reshape(-1, 128) # N V 4N 128
fc_out = self.fc(lstm_out) # N V 4N 10
fc_out = fc_out.reshape(-1, 4, 10) # N S V N 4 10
return fc_out
class Net(nn.Module):
def __init__(self):
super().__init__()
self.encoder = Encoder()
self.decoder = Decoder()
def forward(self, x):
encoder = self.encoder(x)
decoder = self.decoder(encoder)
return decoder
# 直接实例化Net()即可,优化也是直接优化Net()的权重即可
# self.net = Net().to(self.device)
# self.opt = torch.optim.Adam(self.net.parameters())
可以修改LSTM参数以改变模型识别率,代价是计算量的增减。
print('The End !')