# beam_search

import torch
import torch.nn.functional as F
from src.utils import generate_square_subsequent_mask
import math


class Translator:
    """Turn encoded source sequences into target-language text.

    Two decoding strategies are provided: :meth:`beam_translate` (beam search)
    and :meth:`greedy_translate` (highest-scoring token at every step). Both
    call ``model.encode(**enc_inputs)`` and
    ``model.decode(tgt=..., src=..., memory=..., tgt_mask=...)``.
    """

    def __init__(self, bos_idx, eos_idx, device, max_steps=64, beam_size=3, length_norm_coefficient=0.6):
        """
        :param bos_idx: token id every hypothesis starts with
        :param eos_idx: token id that marks a hypothesis as finished
        :param device: device on which decoding tensors are created
        :param max_steps: maximum number of tokens generated after <BOS>
        :param beam_size: number of hypotheses kept by beam search
        :param length_norm_coefficient: co-efficient for normalizing decoded
            sequences' scores by their lengths; stored, but length
            normalization is currently not applied by ``beam_translate``
        """
        self.bos_idx = bos_idx
        self.eos_idx = eos_idx
        self.beam_size = beam_size
        self.device = device
        self.max_steps = max_steps
        self.length_norm_coefficient = length_norm_coefficient

    @staticmethod
    def _unwrap_model(model):
        """Return ``model.module`` when the object has one (e.g. a DataParallel wrapper), else ``model``."""
        return model.module if hasattr(model, "module") else model

    @staticmethod
    def _ids_to_text(tokenizer, token_ids):
        """Decode ``token_ids`` with ``tokenizer`` and strip leftover sentence tags."""
        text = tokenizer.decode(token_ids,
                                skip_special_tokens=True,
                                clean_up_tokenization_spaces=True)
        # "<\\S>" is the literal string <\S> (backslash, not slash); the backslash
        # is escaped to avoid Python's invalid-escape-sequence warning.
        # NOTE(review): "</S>" (forward slash) may have been intended — confirm
        # against the tokenizer's vocabulary.
        return text.replace("<S>", "").replace("<\\S>", "").strip()

    def beam_translate(self, old_model, enc_inputs, tokenizer):
        """
        Translate a source-language sequence into the target language with beam search.

        :param old_model: model exposing ``encode``/``decode``, optionally wrapped in
            an object with a ``.module`` attribute
        :param enc_inputs: dict of encoder inputs, passed as ``model.encode(**enc_inputs)``;
            must contain ``'input_ids'`` (presumably a single sequence, [1, src_len] — TODO confirm)
        :param tokenizer: tokenizer used to turn token ids back into text
        :return: the best hypothesis (str), and all candidate hypotheses as a list of
            ``{"hypothesis": str, "score": float}`` dicts
        """
        model = self._unwrap_model(old_model)

        with torch.no_grad():
            k = self.beam_size
            # Decoding stops once this many hypotheses have produced <EOS>.
            n_completed_hypotheses = min(k, 10)

            memory = model.encode(**enc_inputs)  # presumably (1, src_len, d_model)

            # Start from k identical <BOS>-only hypotheses, each with score 0.
            hypotheses = torch.full((k, 1), self.bos_idx, dtype=torch.long, device=self.device)  # (k, 1)
            hypotheses_scores = torch.zeros(k, device=self.device)  # (k)

            completed_hypotheses = []
            completed_hypotheses_scores = []

            # Each iteration appends one token, so at most max_steps tokens follow
            # <BOS> — the same limit greedy_translate uses.
            for step in range(1, self.max_steps + 1):
                num_hyp = hypotheses.size(0)  # unfinished hypotheses; acts as the batch size
                hyp_mask = generate_square_subsequent_mask(hypotheses.size(1)).bool().to(self.device)
                # NOTE(review): src keeps its original batch size while tgt and memory
                # have num_hyp rows — confirm decode() tolerates that.
                decoder_out = model.decode(tgt=hypotheses,
                                           src=enc_inputs['input_ids'],
                                           memory=memory.repeat(num_hyp, 1, 1),
                                           tgt_mask=hyp_mask)

                # Next-token log-scores. torch.log assumes decode() returns
                # probabilities rather than logits — TODO confirm.
                log_probs = torch.log(decoder_out[:, -1, :])  # (num_hyp, vocab)
                # Take the vocabulary size from the model output, not len(tokenizer):
                # the two can differ, and the flat-index arithmetic below must use
                # the actual row width of the score matrix.
                vocab_size = log_probs.size(-1)

                # Cumulative score of every possible one-token extension.
                scores = hypotheses_scores.unsqueeze(1) + log_probs  # (num_hyp, vocab)

                if step == 1:
                    # All rows are the same <BOS> hypothesis; ranking only the first
                    # row keeps the beam from holding k copies of one token.
                    top_scores, flat_indices = scores[0].topk(num_hyp, dim=0, largest=True, sorted=True)
                else:
                    top_scores, flat_indices = scores.view(-1).topk(num_hyp, dim=0, largest=True, sorted=True)

                # Recover (source row, token id) from the flattened indices.
                prev_word_indices = torch.div(flat_indices, vocab_size, rounding_mode="floor")  # (num_hyp)
                next_word_indices = flat_indices % vocab_size  # (num_hyp)

                top_k_hypotheses = torch.cat([hypotheses[prev_word_indices],
                                              next_word_indices.unsqueeze(1)],
                                             dim=1)  # (num_hyp, step + 1)

                complete = next_word_indices == self.eos_idx  # (num_hyp), bool

                # Set aside finished hypotheses with their raw cumulative scores.
                # GNMT-style length normalization ((5 + len) / 6) ** alpha is not
                # applied, so self.length_norm_coefficient has no effect here.
                completed_hypotheses.extend(top_k_hypotheses[complete].tolist())
                completed_hypotheses_scores.extend(top_scores[complete].tolist())

                if len(completed_hypotheses) >= n_completed_hypotheses:
                    break

                # Continue with the unfinished hypotheses only. The beam shrinks as
                # hypotheses finish, because the next topk takes num_hyp candidates.
                hypotheses = top_k_hypotheses[~complete]  # (s, step + 1)
                hypotheses_scores = top_scores[~complete]  # (s)

            # Nothing reached <EOS> within max_steps: fall back to the partial hypotheses.
            if not completed_hypotheses:
                completed_hypotheses = hypotheses.tolist()
                completed_hypotheses_scores = hypotheses_scores.tolist()

            all_hypotheses = [
                {"hypothesis": self._ids_to_text(tokenizer, ids), "score": score}
                for ids, score in zip(completed_hypotheses, completed_hypotheses_scores)
            ]

            # max() returns the first of several equally-scored candidates.
            best_hypothesis = max(all_hypotheses, key=lambda h: h["score"])["hypothesis"]

            return best_hypothesis, all_hypotheses

    def greedy_translate(self, old_model, enc_inputs, tokenizer):
        """
        Translate a source-language sequence by always taking the highest-scoring next token.

        :param old_model: model exposing ``encode``/``decode``, optionally wrapped in
            an object with a ``.module`` attribute
        :param enc_inputs: dict of encoder inputs, passed as ``model.encode(**enc_inputs)``;
            must contain ``'input_ids'``
        :param tokenizer: tokenizer used to turn token ids back into text
        :return: the decoded target-language string
        """
        model = self._unwrap_model(old_model)

        with torch.no_grad():
            memory = model.encode(**enc_inputs)
            ys = torch.full((1, 1), self.bos_idx, dtype=torch.long, device=self.device)  # (1, 1)
            for _ in range(self.max_steps):
                tgt_mask = generate_square_subsequent_mask(ys.size(1)).bool().to(self.device)
                out = model.decode(tgt=ys, src=enc_inputs['input_ids'], memory=memory, tgt_mask=tgt_mask)

                # Highest-scoring token at the last position (first index on ties).
                next_word = out[:, -1].argmax(dim=1).item()
                ys = torch.cat([ys, ys.new_full((1, 1), next_word)], dim=1)
                if next_word == self.eos_idx:
                    break

            # ys[0] (rather than squeeze()) stays a list even if no token was generated.
            return self._ids_to_text(tokenizer, ys[0].tolist())

# "猜你喜欢" ("You may also like") — leftover text from the source web page.

# Reposted from blog.csdn.net/mch2869253130/article/details/123871309