在CTC网络中我们可以训练出一个映射:
假如序列目标为字符串(词表大小为 n),则Nw输出为n维多项概率分布。
网络输出为:y=Nw,其中,表示t时刻输出是第k项的概率。
但是这个输出只是一组组概率,我们要由这个Nw得到我们预测的标签,这就涉及到一个解码的问题。
按照最大似然准则,最优的解码结果为:
但是上式并不存在已知的高效解法。故我们可以采用下面介绍的几种实用的近似破解码方法。
贪心搜索 (greedy search):
在实际过程中难以计算,但对于某个具体的字符串 π(去 blank 前),我们可以计算出:
因此,我们放弃寻找使 最大的字符串,而是寻找一个使最大的字符串,即:
简化后,解码过程(构造)变得非常简单(基于独立性假设): 在每个时刻t时输出概率最大的字符:
如:
假如y的分布如下图:
greedy search的结果为:
代码实现:
import numpy as np
# 求每一列(即每个时刻)中最大值对应的softmax值
def softmax(logits):
# 注意这里求e的次方时,次方数减去max_value其实不影响结果,因为最后可以化简成教科书上softmax的定义
# 次方数加入减max_value是因为e的x次方与x的极限(x趋于无穷)为无穷,很容易溢出,所以为了计算时不溢出,就加入减max_value项
# 次方数减去max_value后,e的该次方数总是在0到1范围内。
max_value = np.max(logits, axis=1, keepdims=True)
exp = np.exp(logits - max_value)
exp_sum = np.sum(exp, axis=1, keepdims=True)
dist = exp / exp_sum
return dist
def remove_blank(labels, blank=0):
new_labels = []
# 合并相同的标签
previous = None
for l in labels:
if l != previous:
new_labels.append(l)
previous = l
# 删除blank
new_labels = [l for l in new_labels if l != blank]
return new_labels
def insert_blank(labels, blank=0):
new_labels = [blank]
for l in labels:
new_labels += [l, blank]
return new_labels
def greedy_decode(y, blank=0):
# 按列取最大值,即每个时刻t上最大值对应的下标
raw_rs = np.argmax(y, axis=1)
# 移除blank,值为0的位置表示这个位置是blank
rs = remove_blank(raw_rs, blank)
return raw_rs, rs
np.random.seed(1111)
y_test = softmax(np.random.random([20, 6]))
label_have_blank, label_no_blank = greedy_decode(y_test)
print(label_have_blank)
print(label_no_blank)
运行结果如下:
[1 3 5 5 5 5 1 5 3 4 4 3 0 4 5 0 3 1 3 3]
[1, 3, 5, 1, 5, 3, 4, 3, 4, 5, 3, 1, 3]
Process finished with exit code 0
束搜索(Beam Search):
贪心搜索的性能非常受限。如它不能给出除最优路径之外的其他其优路径。很多时候,如果我们能拿到nearbest的路径,后续可以利用其他信息来进一步优化搜索的结果。束搜索能近似找出 top 最优的若干条路径。
基本原理是通过 中 个序列,每个序列分别连接中个节点,得到 个新序列及对应的score,然后按照score从大到小的顺序选出前个序列,依次推进。
如:
假设为2,。
t=1时:
这个时候只会将两个概率最大的节点放进路径集合中,即有两条路径。
t=2时:
上面的两个路径每个路径都会和下一个时间点的每一项组成新的路径,因此一共有个新路径。
然后我们还是只保留概率最大的两条路径(次大的两个路径相等,这里舍弃掉一个)。
t=3时:
和t=2时类似,又组成了新的6条路径。我们还是取概率最大的两条路径。
实际使用该算法时,往往取前20,这里前2只是为了方便举例。
代码实现:
import numpy as np
# 求每一列(即每个时刻)中最大值对应的softmax值
def softmax(logits):
# 注意这里求e的次方时,次方数减去max_value其实不影响结果,因为最后可以化简成教科书上softmax的定义
# 次方数加入减max_value是因为e的x次方与x的极限(x趋于无穷)为无穷,很容易溢出,所以为了计算时不溢出,就加入减max_value项
# 次方数减去max_value后,e的该次方数总是在0到1范围内。
max_value = np.max(logits, axis=1, keepdims=True)
exp = np.exp(logits - max_value)
exp_sum = np.sum(exp, axis=1, keepdims=True)
dist = exp / exp_sum
return dist
def remove_blank(labels, blank=0):
new_labels = []
# 合并相同的标签
previous = None
for l in labels:
if l != previous:
new_labels.append(l)
previous = l
# 删除blank
new_labels = [l for l in new_labels if l != blank]
return new_labels
def insert_blank(labels, blank=0):
new_labels = [blank]
for l in labels:
new_labels += [l, blank]
return new_labels
def beam_decode(y, beam_size=10):
# y是个二维数组,记录了所有时刻的所有项的概率
T, V = y.shape
# 将所有的y中值改为log是为了防止溢出,因为最后得到的p是y1..yn连乘,且yi都在0到1之间,可能会导致下溢出
# 改成log(y)以后就变成连加了,这样就防止了下溢出
log_y = np.log(y)
# 初始的beam
beam = [([], 0)]
# 遍历所有时刻t
for t in range(T):
# 每个时刻先初始化一个new_beam
new_beam = []
# 遍历beam
for prefix, score in beam:
# 对于一个时刻中的每一项(一共V项)
for i in range(V):
# 记录添加的新项是这个时刻的第几项,对应的概率(log形式的)加上新的这项log形式的概率(本来是乘的,改成log就是加)
new_prefix = prefix + [i]
new_score = score + log_y[t, i]
# new_beam记录了对于beam中某一项,将这个项分别加上新的时刻中的每一项后的概率
new_beam.append((new_prefix, new_score))
# 给new_beam按score排序
new_beam.sort(key=lambda x: x[1], reverse=True)
# beam即为new_beam中概率最大的beam_size个路径
beam = new_beam[:beam_size]
return beam
np.random.seed(1111)
y_test = softmax(np.random.random([20, 6]))
beam_chosen = beam_decode(y_test, beam_size=100)
for beam_string, beam_score in beam_chosen[:20]:
print(remove_blank(beam_string), beam_score)
运行结果如下:
[1, 3, 5, 1, 5, 3, 4, 3, 4, 5, 3, 1, 3] -29.261797539205567
[1, 3, 5, 1, 5, 3, 4, 3, 3, 5, 3, 1, 3] -29.279020152518033
[1, 3, 5, 1, 5, 3, 4, 2, 3, 4, 5, 3, 1, 3] -29.300726142201842
[1, 5, 1, 5, 3, 4, 3, 4, 5, 3, 1, 3] -29.310307014773972
[1, 3, 5, 1, 5, 3, 4, 2, 3, 3, 5, 3, 1, 3] -29.31794875551431
[1, 5, 1, 5, 3, 4, 3, 3, 5, 3, 1, 3] -29.327529628086438
[1, 3, 5, 1, 5, 4, 3, 4, 5, 3, 1, 3] -29.331572723457334
[1, 3, 5, 5, 1, 5, 3, 4, 3, 4, 5, 3, 1, 3] -29.33263180992451
[1, 3, 5, 4, 1, 5, 3, 4, 3, 4, 5, 3, 1, 3] -29.334649090836038
[1, 3, 5, 1, 5, 3, 4, 3, 4, 5, 3, 1, 3] -29.33969505198154
[1, 3, 5, 2, 1, 5, 3, 4, 3, 4, 5, 3, 1, 3] -29.339823066915415
[1, 3, 5, 1, 5, 4, 3, 3, 5, 3, 1, 3] -29.3487953367698
[1, 5, 1, 5, 3, 4, 2, 3, 4, 5, 3, 1, 3] -29.349235617770248
[1, 3, 5, 5, 1, 5, 3, 4, 3, 3, 5, 3, 1, 3] -29.349854423236977
[1, 3, 5, 1, 5, 3, 4, 3, 4, 5, 3, 3] -29.350803198551016
[1, 3, 5, 4, 1, 5, 3, 4, 3, 3, 5, 3, 1, 3] -29.351871704148504
[1, 3, 5, 1, 5, 3, 4, 3, 3, 5, 3, 1, 3] -29.356917665294006
[1, 3, 5, 2, 1, 5, 3, 4, 3, 3, 5, 3, 1, 3] -29.35704568022788
[1, 3, 5, 1, 5, 3, 4, 5, 4, 5, 3, 1, 3] -29.363802591012263
[1, 5, 1, 5, 3, 4, 2, 3, 3, 5, 3, 1, 3] -29.366458231082714
Process finished with exit code 0
可以看到log形式的score连加的结果都是负数,这是因为logx,当x属于0到1之间时logx为负的。
前缀束搜索(Prefix Beam Search):
束搜索(Beam Search)存在的一个问题是,在保存的 top N 条路径中,可能存在多条实际上是同一结果(经过去重复、去 blank 操作后的)。这减少了搜索结果的多样性。前缀束搜索(Prefix Beam Search)方法,可以在搜索过程中不断的合并相同的前缀。
probabilityWithBlank和probabilityNoBlank分别代表最后一个字符是空格和最后一个字符不是空格的概率。
如:
当t=2时,many-to-one map后为[1]的序列有三种:[1,0],[0,1],[1,1],其中[1,0]是尾部带blank的情况,[0,1]和[1,1]是尾部不带blank的情况,那么假设t=3时label为1,那么新的序列就有以下几种情况。
[1,0]+[1]=[1,0,1]->[1,0,1]
[0,1]+[1]=[0,1,1]->[1]
[1,1]+[1]=[1,1,1]->[1]
后两者是新的尾部不为blank的序列,可见尾部不为blank在新序列产生的时候是可以算作一种情况的,这就是为什么要分为blank和尾部不为blank的情况。
代码实现:
import numpy as np
from collections import defaultdict
ninf = float("-inf")
# 求每一列(即每个时刻)中最大值对应的softmax值
def softmax(logits):
# 注意这里求e的次方时,次方数减去max_value其实不影响结果,因为最后可以化简成教科书上softmax的定义
# 次方数加入减max_value是因为e的x次方与x的极限(x趋于无穷)为无穷,很容易溢出,所以为了计算时不溢出,就加入减max_value项
# 次方数减去max_value后,e的该次方数总是在0到1范围内。
max_value = np.max(logits, axis=1, keepdims=True)
exp = np.exp(logits - max_value)
exp_sum = np.sum(exp, axis=1, keepdims=True)
dist = exp / exp_sum
return dist
def remove_blank(labels, blank=0):
new_labels = []
# 合并相同的标签
previous = None
for l in labels:
if l != previous:
new_labels.append(l)
previous = l
# 删除blank
new_labels = [l for l in new_labels if l != blank]
return new_labels
def insert_blank(labels, blank=0):
new_labels = [blank]
for l in labels:
new_labels += [l, blank]
return new_labels
def _logsumexp(a, b):
'''
np.log(np.exp(a) + np.exp(b))
'''
if a < b:
a, b = b, a
if b == ninf:
return a
else:
return a + np.log(1 + np.exp(b - a))
def logsumexp(*args):
'''
from scipy.special import logsumexp
logsumexp(args)
'''
res = args[0]
for e in args[1:]:
res = _logsumexp(res, e)
return res
def prefix_beam_decode(y, beam_size=10, blank=0):
T, V = y.shape
log_y = np.log(y)
# 最后一个字符是blank与最后一个字符为non-blank两种情况
beam = [(tuple(), (0, ninf))]
# 对于每一个时刻t
for t in range(T):
# 当我使用普通的字典时,用法一般是dict={},添加元素的只需要dict[element] =value即可,调用的时候也是如此
# dict[element] = xxx,但前提是element字典里,如果不在字典里就会报错
# defaultdict的作用是在于,当字典里的key不存在但被查找时,返回的不是keyError而是一个默认值
# dict =defaultdict( factory_function)
# 这个factory_function可以是list、set、str等等,作用是当key不存在时,返回的是工厂函数的默认值
# 这里就是(ninf, ninf)是默认值
new_beam = defaultdict(lambda: (ninf, ninf))
# 对于beam中的每一项
for prefix, (p_b, p_nb) in beam:
for i in range(V):
# beam的每一项都加上时刻t中的每一项
p = log_y[t, i]
# 如果i中的这项是blank
if i == blank:
# 将这项直接加入路径中
new_p_b, new_p_nb = new_beam[prefix]
new_p_b = logsumexp(new_p_b, p_b + p, p_nb + p)
new_beam[prefix] = (new_p_b, new_p_nb)
continue
# 如果i中的这一项不是blank
else:
end_t = prefix[-1] if prefix else None
# 判断之前beam项中的最后一个元素和i的元素是不是一样
new_prefix = prefix + (i,)
new_p_b, new_p_nb = new_beam[new_prefix]
# 如果不一样,则将i这项加入路径中
if i != end_t:
new_p_nb = logsumexp(new_p_nb, p_b + p, p_nb + p)
else:
new_p_nb = logsumexp(new_p_nb, p_b + p)
new_beam[new_prefix] = (new_p_b, new_p_nb)
# 如果一样,保留现有的路径,但是概率上要加上新的这个i项的概率
if i == end_t:
new_p_b, new_p_nb = new_beam[prefix]
new_p_nb = logsumexp(new_p_nb, p_nb + p)
new_beam[prefix] = (new_p_b, new_p_nb)
# 给新的beam排序并取前beam_size个
beam = sorted(new_beam.items(), key=lambda x: logsumexp(*x[1]), reverse=True)
beam = beam[:beam_size]
return beam
np.random.seed(1111)
y_test = softmax(np.random.random([20, 6]))
beam_test = prefix_beam_decode(y_test, beam_size=100)
for beam_string, beam_score in beam_test[:20]:
print(remove_blank(beam_string), beam_score)
运行结果如下:
[1, 5, 4, 1, 3, 4, 5, 2, 3] (-18.189863809114193, -17.613677981426175)
[1, 5, 4, 5, 3, 4, 5, 2, 3] (-18.19636512622969, -17.621013424585406)
[1, 5, 4, 1, 3, 4, 5, 1, 3] (-18.31701896033153, -17.666629973270073)
[1, 5, 4, 5, 3, 4, 5, 1, 3] (-18.323388267369936, -17.674125139073176)
[1, 5, 4, 1, 3, 4, 3, 2, 3] (-18.415808498759556, -17.862744326248826)
[1, 5, 4, 1, 3, 4, 3, 5, 3] (-18.36642276663863, -17.898463479112884)
[1, 5, 4, 5, 3, 4, 3, 2, 3] (-18.42224294936932, -17.870025672291458)
[1, 5, 4, 5, 3, 4, 3, 5, 3] (-18.37219911390019, -17.905130493229173)
[1, 5, 4, 1, 3, 4, 5, 4, 3] (-18.457066311773847, -17.880630315602037)
[1, 5, 4, 5, 3, 4, 5, 4, 3] (-18.462614293487096, -17.88759583852546)
[1, 5, 4, 1, 3, 4, 5, 3, 2] (-18.458941701567706, -17.951422824358747)
[1, 5, 4, 5, 3, 4, 5, 3, 2] (-18.464527031120184, -17.958629487208658)
[1, 5, 4, 1, 3, 4, 3, 1, 3] (-18.540857550725587, -17.92058991009369)
[1, 5, 4, 5, 3, 4, 3, 1, 3] (-18.547146092248852, -17.928030266681613)
[1, 5, 4, 1, 3, 4, 5, 3, 2, 3] (-19.325467801462263, -17.6892032244089)
[1, 5, 4, 5, 3, 4, 5, 3, 2, 3] (-19.328748799764973, -17.694105969982637)
[1, 5, 4, 1, 3, 4, 5, 3, 4] (-18.79699026165903, -17.945090229238392)
[1, 5, 4, 5, 3, 4, 5, 3, 4] (-18.80358553427324, -17.95258394264377)
[1, 5, 4, 3, 4, 3, 5, 2, 3] (-19.18153184608281, -17.859420073785095)
[1, 5, 4, 1, 3, 4, 5, 2, 3, 2] (-19.4393492963852, -17.884502168470895)
Process finished with exit code 0