PyTorch 实现 chatbot 聊天机器人

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/DarrenXf/article/details/86750447

涉及的论文

Neural Conversational Model https://arxiv.org/abs/1506.05869
Luong attention mechanism(s) https://arxiv.org/abs/1508.04025
Sutskever et al. https://arxiv.org/abs/1409.3215
GRU Cho et al. https://arxiv.org/pdf/1406.1078v3.pdf
Bahdanau et al. https://arxiv.org/abs/1409.0473

使用的数据集

Corpus web https://www.cs.cornell.edu/~cristian/Cornell_Movie-Dialogs_Corpus.html
Corpus link http://www.cs.cornell.edu/~cristian/data/cornell_movie_dialogs_corpus.zip

代码列表

chatbot_test.py
chatbot_train.py
corpus_dataset.py
vocabulary.py
graph.py  
model.py
etc.py             
main.py   

chatbot_test.py

# -*- coding: utf-8 -*-

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import corpus_dataset
import graph
import etc

def run_test():
    """Build the model in "test" mode and start the interactive chat loop.

    The vocabulary is loaded from the corpus; the sentence pairs returned
    alongside it are not used here.
    """
    cfg = etc.config
    voc, _unused_pairs = corpus_dataset.load_vocabulary_and_pairs(cfg)
    corpus_graph = graph.CorpusGraph(cfg)
    # "test" selects the test-time checkpoint loading path in create_train_model.
    chat_model = corpus_graph.create_train_model(voc, "test")
    corpus_graph.evaluate_input(voc, chat_model)

chatbot_train.py

# -*- coding: utf-8 -*-

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import corpus_dataset
import graph
import etc

def run_train():
    """Load the corpus, build the model and run the training iterations."""
    cfg = etc.config
    voc, pairs = corpus_dataset.load_vocabulary_and_pairs(cfg)
    corpus_graph = graph.CorpusGraph(cfg)

    print("Create model")
    train_model = corpus_graph.create_train_model(voc)

    print("Starting Training!")
    corpus_graph.trainIters(voc, pairs, train_model)
    # Interactive evaluation after training is intentionally disabled here:
    #    print("Starting evaluate!")
    #    corpus_graph.evaluate_input(voc, train_model)

corpus_dataset.py

#!/usr/bin/python
# -*- coding: utf-8 -*-
#####################################
# File name : corpus_dataset.py
# Create date : 2019-01-16 11:16
# Modified date : 2019-02-02 14:55
# Author : DARREN
# Describe : not set
# Email : [email protected]
#####################################
from __future__ import division
from __future__ import print_function

import os
import re
import csv
import codecs
import unicodedata
import vocabulary

def _check_is_have_file(file_name):
    """Return True when ``file_name`` refers to an existing path."""
    path_exists = os.path.exists(file_name)
    return path_exists

def _filter_pair(p, max_length):
    """Return True when both sentences of pair ``p`` are shorter than ``max_length``.

    Length is the number of chunks obtained by splitting on a single space;
    a sentence with exactly ``max_length`` chunks is rejected.
    """
    for position in (0, 1):
        word_count = len(p[position].split(' '))
        if word_count >= max_length:
            return False
    return True

def _filter_pairs(pairs, max_length):
    """Return the pairs accepted by ``_filter_pair``, in their original order."""
    accepted = []
    for candidate in pairs:
        if _filter_pair(candidate, max_length):
            accepted.append(candidate)
    return accepted

def _read_vocabulary(datafile, corpus_name):
    """Read the formatted pair file and create an empty vocabulary.

    Args:
        datafile: path of a UTF-8 text file with one tab-separated pair per line.
        corpus_name: name given to the new ``vocabulary.Voc``.

    Returns:
        ``(voc, pairs)`` where ``voc`` is a fresh, still empty ``Voc`` and
        ``pairs`` is a list with one list of normalized fields per line.
    """
    print("Reading lines...")
    # Use a context manager so the file handle is released deterministically.
    with open(datafile, encoding='utf-8') as f:
        lines = f.read().strip().split('\n')
    # NOTE: the field separator is hard-coded to a tab here, independently of
    # the configured delimiter used when the file is written.
    pairs = [[normalize_string(s) for s in l.split('\t')] for l in lines]
    voc = vocabulary.Voc(corpus_name)
    return voc, pairs

def _unicode_to_ascii(s):
    """Strip combining marks from ``s``.

    The string is NFD-decomposed and every character whose Unicode category
    is 'Mn' is dropped.  Characters outside that category are kept as they
    are, so the result is not guaranteed to be pure ASCII.
    """
    decomposed = unicodedata.normalize('NFD', s)
    kept_chars = [ch for ch in decomposed if unicodedata.category(ch) != 'Mn']
    return ''.join(kept_chars)

def _get_delimiter(config):
    """Return ``config["delimiter"]`` with backslash escape sequences decoded.

    A configured value such as the two characters ``\\t`` becomes a real tab.
    """
    raw_delimiter = config["delimiter"]
    return str(codecs.decode(raw_delimiter, "unicode_escape"))

def _get_object(line, fields):
    """Split one corpus line on " +++$+++ " and label the parts with ``fields``.

    Returns a dict mapping each field name to the value at the same position.
    Raises IndexError when the line has fewer parts than there are fields;
    surplus parts are ignored.
    """
    parts = line.split(" +++$+++ ")
    return {field: parts[position] for position, field in enumerate(fields)}

def _load_lines(config):
    """Load the movie lines file into a dict keyed by line id.

    Args:
        config: mapping providing "lines_file_name", "corpus_path" and
            "movie_lines_fields".

    Returns:
        dict mapping each record's 'lineID' to the record built by
        ``_get_object``.
    """
    lines_file_name = config["lines_file_name"]
    corpus_path = config["corpus_path"]
    lines_file_full_path = "%s/%s" % (corpus_path, lines_file_name)
    fields = config["movie_lines_fields"]

    lines = {}
    # The context manager closes the file even if a malformed line raises.
    with open(lines_file_full_path, 'r', encoding='iso-8859-1') as f:
        for line in f:
            line_obj = _get_object(line, fields)
            lines[line_obj['lineID']] = line_obj
    return lines

def _cellect_lines(conv_obj, lines):
    """Attach the referenced line records to a conversation record.

    Args:
        conv_obj: conversation dict whose "utteranceIDs" value is the textual
            form of a Python list, e.g. "['L598485', 'L598486', ...]".
        lines: dict mapping line id to line record.

    Returns:
        The same ``conv_obj``, with a new "lines" key holding the records in
        utterance order.

    Raises:
        ValueError / SyntaxError: if "utteranceIDs" is not a plain literal.
        KeyError: if an id is missing from ``lines``.
    """
    # Local import keeps this block self-contained.
    import ast

    # literal_eval only accepts Python literals, so text read from the corpus
    # file can never execute code (unlike eval).
    line_ids = ast.literal_eval(conv_obj["utteranceIDs"].strip())
    # Reassemble lines
    conv_obj["lines"] = [lines[line_id] for line_id in line_ids]
    return conv_obj

def _load_conversations(lines, config):
    """Load the conversations file and attach the line records to each entry.

    Args:
        lines: dict mapping line id to line record (see ``_load_lines``).
        config: mapping providing "corpus_path", "conversation_file_name"
            and "movie_conversations_fields".

    Returns:
        list of conversation dicts, each processed by ``_cellect_lines``.
    """
    corpus_path = config["corpus_path"]
    conversation_file_name = config["conversation_file_name"]
    conversation_file_full_path = "%s/%s" % (corpus_path, conversation_file_name)
    fields = config["movie_conversations_fields"]

    conversations = []
    # The context manager closes the file even if a record fails to parse.
    with open(conversation_file_full_path, 'r', encoding='iso-8859-1') as f:
        for line in f:
            conv_obj = _get_object(line, fields)
            conversations.append(_cellect_lines(conv_obj, lines))
    return conversations

def _get_conversations(config):
    """Load all lines, then all conversations, printing the size of each."""
    all_lines = _load_lines(config)
    print("lines count:", len(all_lines))

    all_conversations = _load_conversations(all_lines, config)
    print("conversations count:", len(all_conversations))

    return all_conversations

def _extract_sentence_pairs(conversations):
    """Build [input, target] pairs from consecutive lines of each conversation.

    Every line is paired with the line that follows it, so the last line of a
    conversation never appears as an input.  Texts are stripped, and a pair
    is dropped when either side is empty after stripping.
    """
    pairs = []
    for conversation in conversations:
        utterances = conversation["lines"]
        for current, following in zip(utterances, utterances[1:]):
            input_line = current["text"].strip()
            target_line = following["text"].strip()
            if not (input_line and target_line):
                continue
            pairs.append([input_line, target_line])
    return pairs

def _load_formatted_data(config):
    """Read the formatted pair file, filter by length and fill the vocabulary.

    Returns ``(voc, pairs)`` where ``pairs`` only contains pairs accepted by
    ``_filter_pairs`` and ``voc`` has counted every word of those pairs.
    """
    corpus_name = config["corpus_name"]
    length_limit = config["max_length"]
    data_path = get_formatted_file_full_path(config)

    print("Start preparing training data ...")
    voc, pairs = _read_vocabulary(data_path, corpus_name)
    print("Read {!s} sentence pairs".format(len(pairs)))

    pairs = _filter_pairs(pairs, length_limit)
    print("Trimmed to {!s} sentence pairs".format(len(pairs)))

    # Register both sides of every remaining pair in the vocabulary.
    for pair in pairs:
        for sentence in (pair[0], pair[1]):
            voc.addSentence(sentence)
    print("Counted words:", voc.num_words)
    return voc, pairs

def _trim_rare_words(voc, pairs, min_count):
    """Trim rare words from ``voc`` and drop pairs that use a removed word.

    Args:
        voc: vocabulary object providing ``trim(min_count)`` and ``word2index``.
        pairs: list of [input_sentence, output_sentence] pairs.
        min_count: threshold passed on to ``voc.trim``.

    Returns:
        The pairs whose input and output sentences only contain words still
        present in ``voc.word2index`` after trimming.
    """
    voc.trim(min_count)
    # Look the mapping up after trim(): trimming rebuilds it.
    known_words = voc.word2index

    def _all_known(sentence):
        """Tell whether every space-separated word of ``sentence`` is known."""
        return all(word in known_words for word in sentence.split(' '))

    keep_pairs = [pair for pair in pairs if _all_known(pair[0]) and _all_known(pair[1])]

    # Guard the ratio so an empty pair list is reported instead of crashing.
    kept_ratio = len(keep_pairs) / len(pairs) if pairs else 0.0
    print("Trimmed from {} pairs to {}, {:.4f} of total".format(len(pairs), len(keep_pairs), kept_ratio))
    return keep_pairs

def _write_newly_formatted_file(config):
    """Create the formatted pair file from the raw corpus unless it exists.

    Each sentence pair is written as one row using the configured delimiter
    and a '\\n' line terminator.  Nothing is written when the target file is
    already present.
    """
    formatted_file_full_path = get_formatted_file_full_path(config)
    if _check_is_have_file(formatted_file_full_path):
        print("%s already has the formatted file,so we do not write" % formatted_file_full_path)
        return

    delimiter = _get_delimiter(config)
    conversations = _get_conversations(config)
    # Build the pairs before creating the output file, so a failure while
    # reading the corpus cannot leave an empty file behind that later runs
    # would treat as already formatted.
    pairs = _extract_sentence_pairs(conversations)
    print("pairs count:", len(pairs))
    print("\nWriting newly formatted file...")
    # The context manager flushes and closes the file once all rows are written.
    with open(formatted_file_full_path, 'w', encoding='utf-8') as outputfile:
        writer = csv.writer(outputfile, delimiter=delimiter, lineterminator='\n')
        writer.writerows(pairs)

def load_vocabulary_and_pairs(config):
    """Return ``(voc, pairs)`` ready for training.

    Steps: make sure the formatted pair file exists, load and length-filter
    the pairs while filling the vocabulary, then remove rare words and the
    pairs that contain them.
    """
    _write_newly_formatted_file(config)
    voc, all_pairs = _load_formatted_data(config)
    min_count = config["min_count"]
    kept_pairs = _trim_rare_words(voc, all_pairs, min_count)
    return voc, kept_pairs

def get_formatted_file_full_path(config):
    """Return "<corpus_path>/<formatted_file_name>" built from ``config``."""
    file_name = config["formatted_file_name"]
    directory = config["corpus_path"]
    return "%s/%s" % (directory, file_name)

def normalize_string(s):
    """Lowercase and simplify ``s`` for vocabulary lookup.

    After lowercasing, stripping and removing combining marks, a space is
    inserted before each of ``.``, ``!`` and ``?``, every run of characters
    other than letters and those three marks becomes a single space, and
    whitespace is collapsed and trimmed.
    """
    text = _unicode_to_ascii(s.lower().strip())
    substitutions = (
        (r"([.!?])", r" \1"),
        (r"[^a-zA-Z.!?]+", r" "),
        (r"\s+", r" "),
    )
    for pattern, replacement in substitutions:
        text = re.sub(pattern, replacement, text)
    return text.strip()

vocabulary.py

#!/usr/bin/python
# -*- coding: utf-8 -*-
#####################################
# File name : vocabulary.py
# Create date : 2019-01-16 11:21
# Modified date : 2019-02-02 13:37
# Author : DARREN
# Describe : not set
# Email : [email protected]
#####################################
from __future__ import division
from __future__ import print_function

PAD_token = 0  # Used for padding short sentences
SOS_token = 1  # Start-of-sentence token
EOS_token = 2  # End-of-sentence token


class Voc:
    """Word <-> index vocabulary with per-word occurrence counts.

    Indexes 0-2 are reserved for the PAD, SOS and EOS tokens; real words
    receive consecutive indexes starting at 3 in order of first appearance.
    """

    def __init__(self, name):
        """Create an empty vocabulary called ``name``."""
        self.name = name
        self.trimmed = False
        self._reset()

    def _reset(self):
        """Drop every word, leaving only the three reserved tokens."""
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
        self.num_words = 3  # Count SOS, EOS, PAD

    def addSentence(self, sentence):
        """Add every space-separated word of ``sentence``."""
        for word in sentence.split(' '):
            self.addWord(word)

    def addWord(self, word):
        """Register ``word``, or bump its count when it is already known."""
        if word not in self.word2index:
            self.word2index[word] = self.num_words
            self.word2count[word] = 1
            self.index2word[self.num_words] = word
            self.num_words += 1
        else:
            self.word2count[word] += 1

    def trim(self, min_count):
        """Keep only the words seen at least ``min_count`` times.

        Runs at most once per instance; later calls return immediately.  The
        kept words are re-registered from scratch, so they get new indexes
        and their counts restart at 1.
        """
        if self.trimmed:
            return
        self.trimmed = True

        keep_words = [word for word, count in self.word2count.items() if count >= min_count]

        total_words = len(self.word2index)
        # An empty vocabulary is reported with ratio 0 instead of dividing by zero.
        kept_ratio = len(keep_words) / total_words if total_words else 0.0
        print('keep_words {} / {} = {:.4f}'.format(
            len(keep_words), total_words, kept_ratio
        ))

        self._reset()

        for word in keep_words:
            self.addWord(word)

graph.py

#!/usr/bin/python
# -*- coding: utf-8 -*-
#####################################
# File name : graph.py
# Create date : 2019-01-16 11:44
# Modified date : 2019-02-02 14:55
# Author : DARREN
# Describe : not set
# Email : [email protected]
#####################################
from __future__ import division
from __future__ import print_function

import os
import itertools
import random
import torch
import torch.nn as nn
from torch import optim

import vocabulary
import model
import corpus_dataset

def _get_training_batches(voc, pairs, batch_size, n_iteration):
    """Pre-build ``n_iteration`` training batches.

    Each batch is made of ``batch_size`` pairs drawn from ``pairs`` with
    ``random.choice`` (so with replacement) and converted by
    ``_batch2TrainData``.
    """
    return [
        _batch2TrainData(voc, [random.choice(pairs) for _ in range(batch_size)])
        for _ in range(n_iteration)
    ]

def _zero_padding(l, fillvalue=vocabulary.PAD_token):
    """Transpose a list of sequences, padding short ones with ``fillvalue``.

    Returns a list of tuples: tuple ``t`` holds item ``t`` of every input
    sequence, or ``fillvalue`` where a sequence is too short.
    """
    columns = itertools.zip_longest(*l, fillvalue=fillvalue)
    return [column for column in columns]

def _binary_matrix(lt):
    """Return a nested list shaped like ``lt`` with 0 for PAD tokens and 1 elsewhere."""
    return [
        [0 if token == vocabulary.PAD_token else 1 for token in seq]
        for seq in lt
    ]

def _get_indexes_batch(lt, voc):
    """Convert every sentence in ``lt`` to its list of token indexes."""
    indexes_batch = []
    for sentence in lt:
        indexes_batch.append(_indexes_from_sentence(voc, sentence))
    return indexes_batch

def _input_var(batch, voc):
    """Build the padded input tensor and the per-sentence lengths.

    Returns ``(variable, lengths)``: ``variable`` is a LongTensor made from
    the padded, transposed index lists and ``lengths`` holds the unpadded
    length (EOS included) of every sentence.
    """
    indexes_batch = _get_indexes_batch(batch, voc)
    sentence_lengths = [len(indexes) for indexes in indexes_batch]
    lengths = torch.tensor(sentence_lengths)
    variable = torch.LongTensor(_zero_padding(indexes_batch))
    return variable, lengths

def _output_var(batch, voc):
    """Build the padded target tensor, its padding mask and the longest length.

    Returns:
        ``(variable, mask, max_target_len)`` where ``variable`` is a
        LongTensor made from the padded, transposed index lists, ``mask`` has
        the same shape and is true exactly where ``variable`` is not PAD, and
        ``max_target_len`` is the longest target length (EOS included).
    """
    indexes_batch = _get_indexes_batch(batch, voc)
    padList = _zero_padding(indexes_batch)
    variable = torch.LongTensor(padList)

    max_target_len = max([len(indexes) for indexes in indexes_batch])
    # Derive the mask with a comparison instead of torch.ByteTensor: the
    # result has whatever mask dtype the installed PyTorch uses natively
    # (bool on current releases), which masked_select accepts.
    mask = variable != vocabulary.PAD_token

    return variable, mask, max_target_len

def _indexes_from_sentence(voc, sentence):
    """Map each space-separated word to its index and append EOS.

    Raises KeyError when a word is not in ``voc.word2index``.
    """
    word_indexes = [voc.word2index[word] for word in sentence.split(' ')]
    return word_indexes + [vocabulary.EOS_token]

def _batch2TrainData(voc, pair_batch):
    """Turn a list of [input, output] sentence pairs into training tensors.

    ``pair_batch`` is sorted in place by decreasing input word count before
    conversion.  Returns
    ``(input_variable, lengths, target_variable, mask, max_target_len)``.
    """
    def _input_word_count(pair):
        return len(pair[0].split(" "))

    pair_batch.sort(key=_input_word_count, reverse=True)
    input_batch = [pair[0] for pair in pair_batch]
    output_batch = [pair[1] for pair in pair_batch]

    input_variable, lengths = _input_var(input_batch, voc)
    target_variable, mask, max_target_len = _output_var(output_batch, voc)
    return input_variable, lengths, target_variable, mask, max_target_len

def _maskNLLLoss(inp, target, mask, device):
    """Average negative log-likelihood over the unmasked positions.

    Args:
        inp: 2-D tensor; row ``i`` is indexed with ``target[i]`` and the
            negative log of the selected value is taken, so the values are
            treated as probabilities.
        target: 1-D tensor of class indexes, one per row of ``inp``.
        mask: 1-D tensor, non-zero / True for the rows that count.
        device: device the resulting loss is moved to.

    Returns:
        ``(loss, n_total)``: the mean loss over the selected rows and the
        number of selected rows as a Python int.
    """
    # masked_select needs a boolean mask on current PyTorch; converting here
    # also keeps uint8 masks from older callers working.
    mask = mask.bool()
    nTotal = mask.sum()
    crossEntropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)).squeeze(1))
    loss = crossEntropy.masked_select(mask).mean()
    loss = loss.to(device)
    return loss, nTotal.item()

class CorpusGraph(nn.Module):
    """Builds, trains, checkpoints and queries the seq2seq chatbot model.

    All settings are copied from the ``config`` mapping at construction time.
    The class derives from ``nn.Module`` but defines no ``forward``; it is
    used as a container for the helper methods below.  The model itself is
    passed around as the ``train_model`` dict created by
    ``create_train_model``.
    """

    def __init__(self, config):
        """Copy the settings used by this class out of ``config``.

        Raises:
            KeyError: if one of the expected keys is missing.
        """
        super(CorpusGraph, self).__init__()
        self.model_name = config["model_name"]
        self.save_dir = config["save_dir"]
        self.corpus_name = config["corpus_name"]
        self.encoder_n_layers = config["encoder_n_layers"]
        self.decoder_n_layers = config["decoder_n_layers"]
        self.hidden_size = config["hidden_size"]
        self.checkpoint_iter = config["checkpoint_iter"]
        self.learning_rate = config["learning_rate"]
        self.decoder_learning_ratio = config["decoder_learning_ratio"]
        self.dropout = config["dropout"]
        self.attn_model = config["attn_model"]
        self.device = config["device"]
        self.print_every = config["print_every"]
        self.save_every = config["save_every"]
        self.n_iteration = config["n_iteration"]
        self.batch_size = config["batch_size"]
        self.clip = config["clip"]
        self.max_length = config["max_length"]
        self.teacher_forcing_ratio = config["teacher_forcing_ratio"]
        self.train_load_checkpoint_file = config["train_load_checkpoint_file"]

    def _evaluate(self, voc, sentence, train_model):
        """Greedy-decode a reply to one normalized ``sentence``.

        Returns the decoded words (``self.max_length`` of them, EOS/PAD
        included).  Raises KeyError when a word is unknown to ``voc``.
        """
        encoder = train_model["encoder"]
        decoder = train_model["decoder"]
        # Set dropout layers to eval mode
        encoder.eval()
        decoder.eval()

        searcher = model.GreedySearchDecoder(encoder, decoder)
        indexes_batch = [_indexes_from_sentence(voc, sentence)]
        # Lengths stay on the CPU: pack_padded_sequence requires CPU lengths.
        lengths = torch.tensor([len(indexes) for indexes in indexes_batch])
        input_batch = torch.LongTensor(indexes_batch).transpose(0, 1)
        input_batch = input_batch.to(self.device)
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            tokens, scores = searcher(input_batch, lengths, self.max_length, self.device)
        decoded_words = [voc.index2word[token.item()] for token in tokens]
        return decoded_words

    def _choose_use_teacher_forcing(self):
        """Return True with probability ``self.teacher_forcing_ratio``."""
        return random.random() < self.teacher_forcing_ratio

    def _train_step(self, decoder, decoder_input, decoder_hidden, encoder_outputs, target_variable, mask, max_target_len):
        """Run the decoder over every target step and accumulate the loss.

        Teacher forcing is decided independently at each time step.

        Returns:
            ``(loss, print_losses, n_totals)``: the summed loss tensor to
            back-propagate, the per-step losses weighted by their token
            count, and the total number of unmasked target tokens.
        """
        loss = 0
        print_losses = []
        n_totals = 0

        for t in range(max_target_len):
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden, encoder_outputs)
            if self._choose_use_teacher_forcing():
                # Teacher forcing: feed the ground-truth token.
                decoder_input = target_variable[t].view(1, -1)
            else:
                # Free running: feed the decoder's own best guess.  Reshaping
                # topi keeps the tensor on its device and does not depend on
                # the configured batch size.
                _, topi = decoder_output.topk(1)
                decoder_input = topi.view(1, -1)

            mask_loss, nTotal = _maskNLLLoss(decoder_output, target_variable[t], mask[t], self.device)
            loss += mask_loss
            print_losses.append(mask_loss.item() * nTotal)
            n_totals += nTotal

        return loss, print_losses, n_totals

    def _train_init(self, input_variable, lengths, target_variable, mask, train_model):
        """Zero the gradients, run the encoder and prepare the decoder inputs.

        Returns ``(decoder, decoder_input, decoder_hidden, encoder_outputs,
        target_variable, mask)`` with the tensors moved to ``self.device``.
        """
        encoder = train_model["encoder"]
        decoder = train_model["decoder"]
        encoder_optimizer = train_model["encoder_optimizer"]
        decoder_optimizer = train_model["decoder_optimizer"]
        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()
        input_variable = input_variable.to(self.device)
        target_variable = target_variable.to(self.device)
        mask = mask.to(self.device)
        # ``lengths`` is deliberately left on the CPU: pack_padded_sequence
        # (used by the encoder) requires CPU lengths.
        encoder_outputs, encoder_hidden = encoder(input_variable, lengths)
        # Every sequence starts decoding from the SOS token.
        decoder_input = torch.LongTensor([[vocabulary.SOS_token for _ in range(self.batch_size)]])
        decoder_input = decoder_input.to(self.device)
        # Seed the decoder with the first ``decoder.n_layers`` encoder states.
        decoder_hidden = encoder_hidden[:decoder.n_layers]

        return decoder, decoder_input, decoder_hidden, encoder_outputs, target_variable, mask

    def _train_backward(self, loss, train_model):
        """Back-propagate ``loss``, clip the gradients and step both optimizers."""
        encoder = train_model["encoder"]
        decoder = train_model["decoder"]
        encoder_optimizer = train_model["encoder_optimizer"]
        decoder_optimizer = train_model["decoder_optimizer"]
        loss.backward()
        torch.nn.utils.clip_grad_norm_(encoder.parameters(), self.clip)
        torch.nn.utils.clip_grad_norm_(decoder.parameters(), self.clip)
        encoder_optimizer.step()
        decoder_optimizer.step()

    def _train(self, input_variable, lengths, target_variable, mask, max_target_len, train_model):
        """Train on one batch and return the average loss per target token."""
        decoder, decoder_input, decoder_hidden, encoder_outputs, target_variable, mask = self._train_init(input_variable, lengths, target_variable, mask, train_model)
        loss, print_losses, n_totals = self._train_step(decoder, decoder_input, decoder_hidden, encoder_outputs, target_variable, mask, max_target_len)
        self._train_backward(loss, train_model)
        return sum(print_losses) / n_totals

    def _save_model_dict(self, train_model, iteration, voc, loss):
        """Write the current training state to the checkpoint file."""
        model_dict = self._get_model_dict(train_model, iteration, voc, loss)
        checkpoint_file_full_path = self._get_checkpoint_file_full_name()
        torch.save(model_dict, checkpoint_file_full_path)

    def _show_batches(self, batches):
        """Debug helper: print every component of one training batch."""
        input_variable, lengths, target_variable, mask, max_target_len = batches
        print("input_variable:", input_variable)
        print("lengths:", lengths)
        print("target_variable:", target_variable)
        print("mask:", mask)
        print("max_target_len:", max_target_len)

    def _show_train_state(self, print_loss, iteration, run_iteration=None):
        """Print the progress line and return the reset loss accumulator (0).

        Args:
            print_loss: loss summed over the last ``self.print_every`` iterations.
            iteration: iteration number to display.
            run_iteration: iteration within the current run, used for the
                percentage; defaults to ``iteration``.  Passing it keeps the
                percentage within 100% when training resumed from a checkpoint.
        """
        if run_iteration is None:
            run_iteration = iteration
        print_loss_avg = print_loss / self.print_every
        print("Iteration: {}; Percent complete: {:.1f}%; Average loss: {:.4f}".format(iteration, run_iteration / self.n_iteration * 100, print_loss_avg))
        print_loss = 0
        return print_loss

    def _get_model_dict(self, train_model, iteration, voc, loss):
        """Collect everything that goes into a checkpoint into one dict."""
        model_dict = {}
        model_dict["en"] = train_model["encoder"].state_dict()
        model_dict["de"] = train_model["decoder"].state_dict()
        model_dict["en_opt"] = train_model["encoder_optimizer"].state_dict()
        model_dict["de_opt"] = train_model["decoder_optimizer"].state_dict()
        model_dict["embedding"] = train_model["embedding"].state_dict()
        model_dict["iteration"] = iteration
        model_dict["loss"] = loss
        model_dict["voc_dict"] = voc.__dict__
        return model_dict

    def _load_checkpoint(self, train_model, voc, checkpoint):
        """Restore the models, optimizers, vocabulary and iteration counter.

        ``voc`` is updated in place by replacing its ``__dict__``.
        """
        train_model["encoder"].load_state_dict(checkpoint['en'])
        train_model["decoder"].load_state_dict(checkpoint['de'])
        train_model["encoder_optimizer"].load_state_dict(checkpoint['en_opt'])
        train_model["decoder_optimizer"].load_state_dict(checkpoint['de_opt'])
        train_model["embedding"].load_state_dict(checkpoint['embedding'])
        voc.__dict__ = checkpoint['voc_dict']
        train_model["iteration"] = checkpoint["iteration"]
        return train_model

    def _train_load_checkpoint(self, train_model, voc):
        """Resume from the checkpoint file when it exists and loading is enabled."""
        loadFilename = self._get_checkpoint_file_full_name()
        if os.path.exists(loadFilename) and self.train_load_checkpoint_file:
            # map_location lets a checkpoint saved on another device (e.g. a
            # GPU) be restored onto the device configured for this run.
            checkpoint = torch.load(loadFilename, map_location=self.device)
            train_model = self._load_checkpoint(train_model, voc, checkpoint)
        return train_model

    def _test_load_checkpoint(self, train_model, voc):
        """Load the checkpoint for evaluation when it exists and loading is enabled."""
        loadFilename = self._get_checkpoint_file_full_name()
        if os.path.exists(loadFilename) and self.train_load_checkpoint_file:
            # Load once, mapping to the CPU, so a model trained on a GPU can
            # be restored on a machine without one.  load_state_dict then
            # copies the values onto whatever device the modules live on.
            checkpoint = torch.load(loadFilename, map_location=torch.device('cpu'))
            train_model = self._load_checkpoint(train_model, voc, checkpoint)
        return train_model

    def _get_save_directory(self):
        """Return the checkpoint directory, creating it when necessary."""
        directory = os.path.join(self.save_dir,
                                 self.model_name,
                                 self.corpus_name,
                                 '{}-{}_{}'.format(self.encoder_n_layers,
                                                   self.decoder_n_layers,
                                                   self.hidden_size))
        os.makedirs(directory, exist_ok=True)
        return directory

    def _get_checkpoint_file_full_name(self):
        """Return the full path of the single checkpoint file."""
        directory = self._get_save_directory()
        checkpoint_file_name = "checkpoint.tar"
        checkpoint_file_full_name = "%s/%s" % (directory, checkpoint_file_name)
        return checkpoint_file_full_name

    def create_train_model(self, voc, status="train"):
        """Create the encoder, decoder and optimizers, then try to load a checkpoint.

        Args:
            voc: vocabulary; ``voc.num_words`` sizes the embedding and the
                decoder output layer.
            status: "train" uses the training loader, any other value the
                test-time loader.

        Returns:
            dict with keys "encoder", "decoder", "encoder_optimizer",
            "decoder_optimizer", "embedding" and "iteration".
        """
        # The same embedding module is shared by the encoder and the decoder.
        embedding = nn.Embedding(voc.num_words, self.hidden_size)
        encoder = model.EncoderRNN(self.hidden_size, embedding, self.encoder_n_layers, self.dropout)
        encoder = encoder.to(self.device)
        decoder = model.LuongAttnDecoderRNN(self.attn_model, embedding, self.hidden_size, voc.num_words, self.decoder_n_layers, self.dropout)
        decoder = decoder.to(self.device)
        # Ensure dropout layers are in train mode
        encoder.train()
        decoder.train()
        encoder_optimizer = optim.Adam(encoder.parameters(), lr=self.learning_rate)
        decoder_optimizer = optim.Adam(decoder.parameters(), lr=self.learning_rate * self.decoder_learning_ratio)

        train_model = {}
        train_model["encoder"] = encoder
        train_model["decoder"] = decoder
        train_model["encoder_optimizer"] = encoder_optimizer
        train_model["decoder_optimizer"] = decoder_optimizer
        train_model["embedding"] = embedding
        train_model["iteration"] = 0

        if status == "train":
            train_model = self._train_load_checkpoint(train_model, voc)
        else:
            train_model = self._test_load_checkpoint(train_model, voc)
        return train_model

    def trainIters(self, voc, pairs, train_model):
        """Run ``self.n_iteration`` training iterations on random batches.

        Progress is printed every ``self.print_every`` iterations and a
        checkpoint is written every ``self.save_every`` iterations.
        """
        training_batches = _get_training_batches(voc, pairs, self.batch_size, self.n_iteration)
        print_loss = 0
        # Iterations already completed according to the loaded checkpoint
        # (0 for a fresh model).  Adding the 1-based loop counter gives the
        # global iteration number without an off-by-one.
        base_iteration = train_model['iteration']

        for iteration in range(1, self.n_iteration + 1):
            training_batch = training_batches[iteration - 1]
            #self._show_batches(training_batch)
            input_variable, lengths, target_variable, mask, max_target_len = training_batch
            loss = self._train(input_variable, lengths, target_variable, mask, max_target_len, train_model)
            print_loss += loss
            cur_iteration = base_iteration + iteration

            if iteration % self.print_every == 0:
                print_loss = self._show_train_state(print_loss, cur_iteration, iteration)

            if iteration % self.save_every == 0:
                self._save_model_dict(train_model, cur_iteration, voc, loss)

    def evaluate_input(self, voc, train_model):
        """Interactive chat loop reading sentences from standard input.

        Typing ``q`` or ``quit`` (or closing the input stream) ends the loop.
        Sentences containing a word unknown to the vocabulary are reported
        and skipped.
        """
        while True:
            try:
                input_sentence = input('> ')
            except EOFError:
                # Input stream closed: leave the loop instead of crashing.
                break
            if input_sentence == 'q' or input_sentence == 'quit':
                break
            try:
                input_sentence = corpus_dataset.normalize_string(input_sentence)
                output_words = self._evaluate(voc, input_sentence, train_model)
            except KeyError:
                print("Error: Encountered unknown word.")
                continue
            output_words = [x for x in output_words if not (x == 'EOS' or x == 'PAD')]
            print('Bot:', ' '.join(output_words))

model.py

#!/usr/bin/python
# -*- coding: utf-8 -*-
#####################################
# File name : model.py
# Create date : 2019-01-16 11:38
# Modified date : 2019-02-02 14:50
# Author : DARREN
# Describe : not set
# Email : [email protected]
#####################################
from __future__ import division
from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F
import vocabulary

class EncoderRNN(nn.Module):
    """Bidirectional multi-layer GRU encoder on top of a given embedding."""

    def __init__(self, hidden_size, embedding, n_layers=1, dropout=0):
        """
        Args:
            hidden_size: GRU input and hidden size; the embedding is fed
                straight into the GRU, so its dimension must match.
            embedding: embedding module applied to the input indexes.
            n_layers: number of stacked GRU layers.
            dropout: dropout between GRU layers; forced to 0 for one layer.
        """
        super(EncoderRNN, self).__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.embedding = embedding

        self.gru = nn.GRU(hidden_size,
                          hidden_size,
                          n_layers,
                          dropout=(0 if n_layers == 1 else dropout),
                          bidirectional=True)

    def forward(self, input_seq, input_lengths, hidden=None):
        """Encode a padded batch of index sequences.

        Args:
            input_seq: index tensor — assumed (max_len, batch), matching the
                GRU's default sequence-first layout.
            input_lengths: length of every sequence (tensor or list), sorted
                in decreasing order as ``pack_padded_sequence`` requires by
                default.
            hidden: optional initial GRU state.

        Returns:
            ``(outputs, hidden)``: ``outputs`` is the sum of the forward and
            backward GRU outputs, ``hidden`` the final GRU state.
        """
        embedded = self.embedding(input_seq)
        # pack_padded_sequence requires the lengths on the CPU, even when the
        # data lives on the GPU.
        if torch.is_tensor(input_lengths):
            input_lengths = input_lengths.cpu()
        packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, input_lengths)
        outputs, hidden = self.gru(packed, hidden)
        outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs)
        # Sum the two directions so the feature size is hidden_size again.
        outputs = outputs[:, :, :self.hidden_size] + outputs[:, :, self.hidden_size:]
        return outputs, hidden


class Attn(torch.nn.Module):
    """Luong-style attention with the 'dot', 'general' or 'concat' score."""

    def __init__(self, method, hidden_size):
        """
        Args:
            method: one of 'dot', 'general', 'concat'.
            hidden_size: feature size of the decoder state and encoder outputs.

        Raises:
            ValueError: if ``method`` is not one of the supported names.
        """
        super(Attn, self).__init__()
        self.method = method
        if self.method not in ['dot', 'general', 'concat']:
            raise ValueError("{} is not an appropriate attention method.".format(self.method))
        self.hidden_size = hidden_size
        if self.method == 'general':
            self.attn = torch.nn.Linear(self.hidden_size, hidden_size)
        elif self.method == 'concat':
            self.attn = torch.nn.Linear(self.hidden_size * 2, hidden_size)
            # Give the scoring vector defined starting values instead of
            # whatever happens to be in freshly allocated memory (which may
            # contain inf/nan).
            bound = hidden_size ** -0.5
            self.v = torch.nn.Parameter(torch.empty(hidden_size).uniform_(-bound, bound))

    def dot_score(self, hidden, encoder_output):
        """Score = sum over features of hidden * encoder_output."""
        return torch.sum(hidden * encoder_output, dim=2)

    def general_score(self, hidden, encoder_output):
        """Score = sum over features of hidden * Linear(encoder_output)."""
        energy = self.attn(encoder_output)
        return torch.sum(hidden * energy, dim=2)

    def concat_score(self, hidden, encoder_output):
        """Score = sum over features of v * tanh(Linear([hidden; encoder_output]))."""
        energy = self.attn(torch.cat((hidden.expand(encoder_output.size(0), -1, -1), encoder_output), 2)).tanh()
        return torch.sum(self.v * energy, dim=2)

    def forward(self, hidden, encoder_outputs):
        """Return attention weights over the encoder positions.

        Assumes ``hidden`` is (1, batch, hidden) and ``encoder_outputs`` is
        (max_len, batch, hidden); the result is then (batch, 1, max_len) and
        sums to 1 over the last dimension.
        """
        if self.method == 'general':
            attn_energies = self.general_score(hidden, encoder_outputs)
        elif self.method == 'concat':
            attn_energies = self.concat_score(hidden, encoder_outputs)
        elif self.method == 'dot':
            attn_energies = self.dot_score(hidden, encoder_outputs)
        # Scores come out as (max_len, batch); transpose to (batch, max_len).
        attn_energies = attn_energies.t()
        return F.softmax(attn_energies, dim=1).unsqueeze(1)


class LuongAttnDecoderRNN(nn.Module):
    """GRU decoder that attends over the encoder outputs at every step."""

    def __init__(self, attn_model, embedding, hidden_size, output_size, n_layers=1, dropout=0.1):
        """
        Args:
            attn_model: name of the attention score passed to ``Attn``.
            embedding: embedding module applied to the input token indexes.
            hidden_size: size of the GRU state (and of the embedded input).
            output_size: number of output classes of the final linear layer.
            n_layers: number of stacked GRU layers.
            dropout: dropout applied to the embedding and, when there is more
                than one layer, between GRU layers.
        """
        super(LuongAttnDecoderRNN, self).__init__()

        # Plain settings, kept for reference by callers.
        self.attn_model = attn_model
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.dropout = dropout

        # Sub-modules.
        gru_dropout = 0 if n_layers == 1 else dropout
        self.embedding = embedding
        self.embedding_dropout = nn.Dropout(dropout)
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers, dropout=gru_dropout)
        self.concat = nn.Linear(hidden_size * 2, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)

        self.attn = Attn(attn_model, hidden_size)

    def forward(self, input_step, last_hidden, encoder_outputs):
        """Decode a single time step.

        Returns ``(output, hidden)``: ``output`` is a softmax distribution
        over the output classes for every batch element, ``hidden`` is the
        updated GRU state.
        """
        # Embed the current input tokens and run one GRU step.
        step_embedding = self.embedding_dropout(self.embedding(input_step))
        rnn_output, hidden = self.gru(step_embedding, last_hidden)

        # Weighted sum of the encoder outputs using the attention weights.
        attn_weights = self.attn(rnn_output, encoder_outputs)
        context = attn_weights.bmm(encoder_outputs.transpose(0, 1)).squeeze(1)

        # Combine GRU output and context, then project to class probabilities.
        combined = torch.cat((rnn_output.squeeze(0), context), 1)
        attentional = torch.tanh(self.concat(combined))
        output = F.softmax(self.out(attentional), dim=1)
        return output, hidden


class GreedySearchDecoder(nn.Module):
    """Decodes by always picking the most probable next token."""

    def __init__(self, encoder, decoder):
        super(GreedySearchDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, input_seq, input_length, max_length,device):
        """Encode ``input_seq`` and greedily decode ``max_length`` tokens.

        Returns ``(all_tokens, all_scores)``: the chosen token index and its
        decoder score for every decoding step, concatenated in step order.
        """
        encoder_outputs, encoder_hidden = self.encoder(input_seq, input_length)
        # Start from the first decoder.n_layers encoder states and the SOS token.
        decoder_hidden = encoder_hidden[:self.decoder.n_layers]
        decoder_input = torch.full((1, 1), vocabulary.SOS_token, dtype=torch.long, device=device)

        # Seed with empty tensors so the result is well defined for max_length == 0.
        token_chunks = [torch.zeros([0], device=device, dtype=torch.long)]
        score_chunks = [torch.zeros([0], device=device)]
        for _ in range(max_length):
            decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden, encoder_outputs)
            best_scores, best_tokens = torch.max(decoder_output, dim=1)
            token_chunks.append(best_tokens)
            score_chunks.append(best_scores)
            # The chosen token becomes the next input (with a leading step dim).
            decoder_input = best_tokens.unsqueeze(0)

        all_tokens = torch.cat(token_chunks, dim=0)
        all_scores = torch.cat(score_chunks, dim=0)
        return all_tokens, all_scores

etc.py

#!/usr/bin/python
# -*- coding: utf-8 -*-
#####################################
# File name : etc.py
# Create date : 2019-01-17 22:50
# Modified date : 2019-02-02 14:10
# Author : DARREN
# Describe : not set
# Email : [email protected]
#####################################
from __future__ import division
from __future__ import print_function

import torch

# Central configuration shared by the data, model and training code.
config = {"corpus_name": "cornell movie-dialogs corpus"}
config.update({
    # --- corpus location and file layout ---
    "corpus_path": "./data/%s" % config["corpus_name"],
    "delimiter": '\t',

    "formatted_file_name": "formatted_movie_lines.txt",
    "conversation_file_name": "movie_conversations.txt",
    "lines_file_name": "movie_lines.txt",

    "movie_lines_fields": ["lineID", "characterID", "movieID", "character", "text"],
    "movie_conversations_fields": ["character1ID", "character2ID", "movieID", "utteranceIDs"],

    # --- model ---
    "model_name": 'cb_model',
    # Attention score; alternatives are 'general' and 'concat'.
    "attn_model": 'dot',
    "hidden_size": 500,
    "encoder_n_layers": 2,
    "decoder_n_layers": 2,
    "dropout": 0.1,

    # --- training loop ---
    "print_every": 20,
    "save_every": 500,
    "n_iteration": 1000,
    "clip": 50.0,
    "learning_rate": 0.0001,
    "decoder_learning_ratio": 5.0,
    "batch_size": 64,
    "save_dir": "./data/save",
    "checkpoint_iter": 4000,
    "min_count": 3,  # Minimum word count threshold for trimming
    "max_length": 10,
    "teacher_forcing_ratio": 1.0,
    "train_load_checkpoint_file": True,
})

# Use the GPU when one is available, otherwise fall back to the CPU.
USE_CUDA = torch.cuda.is_available()
device = torch.device("cuda" if USE_CUDA else "cpu")
config["device"] = device

main.py

#!/usr/bin/python
# -*- coding: utf-8 -*-
#####################################
# File name : main.py
# Create date : 2019-02-02 13:44
# Modified date : 2019-02-02 13:45
# Author : DARREN
# Describe : not set
# Email : [email protected]
#####################################
from __future__ import division
from __future__ import print_function

from chatbot_train import run_train
from chatbot_test import run_test

def run():
    """Train the chatbot, then start the interactive test session."""
    for stage in (run_train, run_test):
        stage()

# Executed on import: this module is the program's entry script.
run()

github:

https://github.com/darr/chatbot

猜你喜欢

转载自blog.csdn.net/DarrenXf/article/details/86750447
今日推荐