Beam search 实现

Beam search 的代码，根据自己的理解添加了注释。
from __future__ import division
from __future__ import print_function
import numpy as np


class BeamEntry:
    """State of one single beam (a candidate labeling) at a given time-step."""

    def __init__(self):
        # The candidate labeling itself: a tuple of class indices.
        self.labeling = ()
        # Probability mass of the paths producing this labeling, split by the
        # symbol the path ends with.
        self.prBlank = 0  # paths ending with a blank
        self.prNonBlank = 0  # paths ending with a non-blank character
        self.prTotal = 0  # blank and non-blank together
        # Language-model bookkeeping.
        self.prText = 1  # LM score of the labeling
        self.lmApplied = False  # set once the LM score has been computed for this entry


class BeamState:
    """Information about all beams alive at one time-step, keyed by labeling."""

    def __init__(self):
        self.entries = {}  # labeling (tuple of class indices) -> beam entry

    def norm(self):
        """Length-normalise the LM score of every beam, in place.

        Each entry's ``prText`` is replaced by its n-th root, where n is the
        length of the entry's labeling. The empty labeling uses n = 1, so its
        score is left unchanged. Because the LM score is built up as a product
        of one factor per appended character, this turns it into a
        per-character geometric mean.
        """
        for entry in self.entries.values():
            labelingLen = len(entry.labeling) or 1
            entry.prText = entry.prText ** (1.0 / labelingLen)

    def sort(self):
        """Return the beam labelings, most probable first.

        Beams are ranked by ``prTotal * prText``. Beams with equal scores keep
        their insertion order, because ``sorted`` is stable even with
        ``reverse=True``.
        """
        sortedBeams = sorted(
            self.entries.values(),
            key=lambda beam: beam.prTotal * beam.prText,
            reverse=True,
        )
        return [beam.labeling for beam in sortedBeams]


def applyLM(parentBeam, childBeam, classes, lm, lmFactor=0.01):
    """Calculate the LM score of a child beam from its parent beam.

    The child's ``prText`` becomes the parent's ``prText`` multiplied by the
    character-bigram probability of (last char of parent, last char of child),
    raised to the power ``lmFactor``. When the parent labeling is empty, a
    space is used as the preceding character. The score is computed at most
    once per child entry (guarded by ``childBeam.lmApplied``). Nothing happens
    when ``lm`` is falsy.

    Args:
        parentBeam: beam whose labeling the child extends by one character.
        childBeam: beam to score; its ``prText`` and ``lmApplied`` are updated
            in place.
        classes: indexable mapping class index -> character.
        lm: object providing ``getCharBigram(first, second)``, or None to
            disable the language model.
        lmFactor: exponent controlling the influence of the language model
            (default 0.01, the value previously hard-coded here).
    """
    if not lm or childBeam.lmApplied:
        return

    # Use the space literal directly. The previous
    # classes[classes.index(' ')] evaluated to the same ' ', but raised
    # ValueError whenever the class set contained no space.
    c1 = classes[parentBeam.labeling[-1]] if parentBeam.labeling else ' '  # first char
    c2 = classes[childBeam.labeling[-1]]  # second char
    # Probability of seeing the first and second char next to each other.
    bigramProb = lm.getCharBigram(c1, c2) ** lmFactor
    childBeam.prText = parentBeam.prText * bigramProb  # probability of the char sequence
    childBeam.lmApplied = True  # only apply the LM once per beam entry


def addBeam(beamState, labeling):
    """Ensure ``beamState`` holds an entry for ``labeling``.

    An existing entry is left untouched. Otherwise a fresh ``BeamEntry`` is
    created for the labeling.
    """
    entries = beamState.entries
    if labeling in entries:
        return
    entries[labeling] = BeamEntry()


def ctcBeamSearch(mat, classes, lm, beamWidth=2):
    """Decode a CTC output matrix with beam search.

    Follows the algorithm described in the papers of Hwang et al. and
    Graves et al.

    Args:
        mat: 2-D array of shape (T, C). Row t holds the per-class scores at
            time-step t (presumably probabilities, e.g. a softmax output;
            TODO confirm). Columns 0..C-2 correspond to ``classes``; column
            C-1 (== ``len(classes)``) is the CTC blank.
        classes: indexable mapping class index -> character.
        lm: optional character language model passed to ``applyLM``, or None.
        beamWidth: number of best labelings kept after each time-step
            (default 2, the value previously hard-coded here).

    Returns:
        The most probable labeling found, mapped to a string of characters.

    Raises:
        ValueError: if ``mat`` does not have exactly ``len(classes) + 1``
            columns.
    """
    blankIdx = len(classes)
    maxT, maxC = mat.shape
    if maxC != blankIdx + 1:
        # Otherwise the blank column would be indexed out of range, or
        # treated as an ordinary character.
        raise ValueError(
            'mat has {} columns, expected len(classes) + 1 = {} '
            '(last column is the blank)'.format(maxC, blankIdx + 1))

    # Initialise: only the empty labeling, reached with probability 1 (as a blank).
    last = BeamState()
    labeling = ()
    last.entries[labeling] = BeamEntry()
    last.entries[labeling].prBlank = 1
    last.entries[labeling].prTotal = 1

    # Go over all time-steps.
    for t in range(maxT):
        curr = BeamState()

        # Advance to time-step t, extending only the beamWidth best labelings so far.
        bestLabelings = last.sort()[0:beamWidth]

        for labeling in bestLabelings:
            prev = last.entries[labeling]

            # Labeling unchanged, case A: the symbol at t is non-blank and
            # equals the labeling's last char, and the path did not end with
            # a blank. The repeat collapses, so the path still ends in a non-blank.
            prNonBlank = 0
            if labeling:
                prNonBlank = prev.prNonBlank * mat[t, labeling[-1]]

            # Labeling unchanged, case B: the symbol at t is a blank.
            prBlank = prev.prTotal * mat[t, blankIdx]

            # Add the beam at the current time-step if needed, then fill in data.
            addBeam(curr, labeling)
            entry = curr.entries[labeling]
            entry.labeling = labeling
            entry.prNonBlank += prNonBlank
            entry.prBlank += prBlank
            entry.prTotal += prBlank + prNonBlank
            # The labeling is unchanged, so its LM score carries over from the
            # previous time-step, where the LM was already applied.
            entry.prText = prev.prText
            entry.lmApplied = True

            # Labeling extended: the symbol at t must be a non-blank char c.
            for c in range(maxC - 1):
                newLabeling = labeling + (c,)

                if labeling and labeling[-1] == c:
                    # c repeats the last char: only paths ending with a blank
                    # yield a new char, since a direct repeat would collapse.
                    prNonBlank = mat[t, c] * prev.prBlank
                else:
                    # c differs from the last char: every path can be extended.
                    prNonBlank = mat[t, c] * prev.prTotal

                addBeam(curr, newLabeling)
                newEntry = curr.entries[newLabeling]
                newEntry.labeling = newLabeling
                newEntry.prNonBlank += prNonBlank
                newEntry.prTotal += prNonBlank

                applyLM(entry, newEntry, classes, lm)

        last = curr

    # Normalise LM scores according to labeling length, then pick the best beam.
    last.norm()
    bestLabeling = last.sort()[0]

    # Map labels to chars.
    return ''.join(classes[label] for label in bestLabeling)


def testBeamSearch():
    """Test the decoder on a small hand-made matrix and print the outcome.

    With the decoder's default beam width of 2, the result for this matrix is
    'bb', produced by the single path b-blank-b (0.5 * 0.6 * 0.9 = 0.27).

    The exact most probable labeling is 'ab' (about 0.369). It is pruned
    because, after t = 1, its prefix (0, 1) scores only 0.12 and is not among
    the two best beams ('b' = 0.45, 'a' = 0.28). This illustrates the
    approximation made by beam pruning; a beam width of 3 would find 'ab'.
    """
    classes = 'ab'
    # Rows are time-steps; columns are 'a', 'b', blank.
    mat = np.array([[0.4, 0.5, 0.1], [0.1, 0.3, 0.6], [0.1, 0.9, 0]])
    print('Test beam search')
    expected = 'bb'
    actual = ctcBeamSearch(mat, classes, None)
    print('Expected: "' + expected + '"')
    print('Actual: "' + actual + '"')
    print('OK' if expected == actual else 'ERROR')


if __name__ == "__main__":
    # Run the decoder self-test only when this file is executed as a script.
    testBeamSearch()

猜你喜欢

转载自blog.csdn.net/henyaoyuancc/article/details/85317072