机器学习入门0005 tensorflow_NMT模型

机器学习入门0005 tensorflow_NMT模型


1.简介

nmt(Neural Machine Translation)是一个序列到序列的模型。可以用来做【聊天机器人】,【翻译】,【关键词提取】,【文章摘要】,【图像描述】等功能。用法简单,只需要安装Tensorflow1.4+ 版本即可运行。这个地址是Tensorflow 官方github https://github.com/tensorflow/nmt,里面内容很全面。




2.运行官网github的例子

下面内容是从斯坦福大学下载英语到越南语的平行语料库,然后通过nmt模型使用语料库训练一个 英语-越南语 或者 越南语-英语 的翻译模型。

2.1  下载平行语料库

在控制台输入这个命令nmt/scripts/download_iwslt15.sh /tmp/nmt_data,就可以下载小的平行语料库了。这个命令需要在nmt目录外边执行,会把数据下载到/tmp/nmt_data/下,其实是8个文件。在国内,由于网络原因,这个命令下载总是会中止,很蛋疼。可以借助其他下载工具来下载(比如某雷,某度云盘,某旋风等),这些文件的地址是:

扫描二维码关注公众号,回复: 154707 查看本文章

https://nlp.stanford.edu/projects/nmt/data/iwslt15.en-vi/train.en
https://nlp.stanford.edu/projects/nmt/data/iwslt15.en-vi/train.vi
https://nlp.stanford.edu/projects/nmt/data/iwslt15.en-vi/tst2012.en
https://nlp.stanford.edu/projects/nmt/data/iwslt15.en-vi/tst2012.vi
https://nlp.stanford.edu/projects/nmt/data/iwslt15.en-vi/tst2013.en
https://nlp.stanford.edu/projects/nmt/data/iwslt15.en-vi/tst2013.vi
https://nlp.stanford.edu/projects/nmt/data/iwslt15.en-vi/vocab.en
https://nlp.stanford.edu/projects/nmt/data/iwslt15.en-vi/vocab.vi

下载完成后放在/tmp/nmt_data/下,也可以下载到自己喜欢的位置,但是需要修改之后的命令,会造成不必要的麻烦。


2.2 训练模型(英语-越南语模型)

在控制台中输入下面命令:

mkdir /tmp/nmt_model
python -m nmt.nmt --src=en --tgt=vi --vocab_prefix=/tmp/nmt_data/vocab  --train_prefix=/tmp/nmt_data/train 
--dev_prefix=/tmp/nmt_data/tst2012  --test_prefix=/tmp/nmt_data/tst2013 --out_dir=/tmp/nmt_model --num_train_steps=12000 
--steps_per_stats=100 --num_layers=2 --num_units=128 --dropout=0.2 --metrics=bleu


命令共有两行,第一行创建一个文件夹,用于存放训练的模型(若干个矩阵)

第二行很长只有一行,使用来进行训练的。其中的参数用来指明数据的位置,模型存放的位置,训练时的参数:总共训练12000步 每次训练100组 rnn共2层128个单元等。

注意:这个训练时间和机器性能有关可能达到1周时间 python 3.X的用户执行第二条命令,要这样写python3 -m nmt.nmt --src=en --tgt=v...


2.3 使用训练好的模型

在/tmp/下创建一个文件,名字是my_infer_file.en 里面写上几行英语:


i am a student

how are you

....

i want to be a super programer


然后执行下面命令

python -m nmt.nmt --out_dir=/tmp/nmt_model --inference_input_file=/tmp/my_infer_file.en 
--inference_output_file=/tmp/nmt_model/output_infer

很快可以执行完毕,然后到/tmp/nmt_model/目录下看这个文件output_infer,里面是对应的越南语翻译。




3.怎么用nmt


3.1 数据格式

仿照着之前下载的8个文件,做好数据对应,其中三个文件是英文的一句一行单词之间通过空格分隔,还有三个是越南语,格式和英语一样。vocab.vi 是越南语的词汇表取了常用的5000个词,vocab.en 是英语词汇表取了最常用的前5000个词语,但是它们前三个词语是<unk> 代表不认识的词语 <s>开始 </s>结束,这三个词必须在词汇表中否则nmt模型不能工作,具体原因官方github上有解释。


3.2 模型参数

python -m nmt.nmt --src=en --tgt=vi --vocab_prefix=/tmp/nmt_data/vocab  --train_prefix=/tmp/nmt_data/train --dev_prefix=/tmp/nmt_data/tst2012  --test_prefix=/tmp/nmt_data/tst2013 --out_dir=/tmp/nmt_model --num_train_steps=12000 --steps_per_stats=100 --num_layers=2 --num_units=128 --dropout=0.2 --metrics=bleu

这条命令中只是使用了个别的参数,还有一些其他有用的参数,如下:

forget_bias=1.0 这个是lstm的记忆力参数,取值范围在[0.0,1.0]越大代表记性越好

batch_size=128 这个代表每次训练128条数据,取值是正整数,如果太大,需要的内存会增大

learning_rate=1 学习率,正常情况下设置成小于等于1的值,默认值 1

num_gpus=1 机器中gpu个数,默认值是1

eos='</s>' 结束符配置成</s>,参考3.1 数据格式

sos='<s>' 同上,这两个参数没有配置的必要

src_max_len=50 源输入最大长度,针对我们训练的英语-越南语模型中,意思是每行最长接受50个英语单词,其余忽略

tgt_max_len=50 目标输出最大长度,默认值50.这个和上面的参数有时很有用,假设我们要做文章摘要,参数可以这样写--src_max_len=800 --tgt_max_len=150,这两个参数都会影响训练和预测速度,他们越大,模型跑的越慢。

share_vocab=False 这个意思是是否公用词汇表,假设做文章摘要,把这个设置成True。因为不是做翻译,输入和输出是同一种语言。

还有一些其他参数,不再列举,可以去源代码中nmt.py文件中查看。


3.3 训练一个聊天机器人(汉语)

3.3.1 准备好训练数据,开发数据,测试数据,汉语常用汉字表(前5000个)即可

仿照3.1中的数据,来准备训练数据。这次不是翻译数据,而是对话数据。比如:

train.src

你 好 !

很 高 兴 认 识 你 。

当 然 很 激 动 了。

....




train.tgt

你 好 呀 !

我 也 是 呢 , 你 有 没 有 很 激 动 。

激 动 你 妹 啊 。

....



vocab.src

<unk>
<s>
</s>
,
的
。
<sp>
一
0
是
1
、
在
有
不
了
2
人
中
大
国
年

...



3.3.2 接下来进行训练

python -m nmt.nmt --src=src --tgt=tgt --vocab_prefix=/tmp/chat_data/vocab  --train_prefix=/tmp/chat_data/train --dev_prefix=/tmp/chat_data/dev  --test_prefix=/tmp/chat_data/test --out_dir=/tmp/nmt_model --num_train_steps=192000 --steps_per_stats=100 --num_layers=2 --num_units=256 --dropout=0.2 --metrics=bleu --src_max_len=80 --tgt_max_len=80 --share_vocab=True

经过漫长的训练,聊天模型训练完毕


3.3.3 集成到项目

有三种方案将训练的模型集成到项目中:

(1)对nmt进行部分修改,在项目代码中调用预测,使结果以文件形式展示,然后去文件中提取结果。优点:改动少,可以快速集成。 缺点:运行速度很慢

(2)对nmt进行部分修改,在项目代码中调用预测,只是要给nmt的源代码添加参数和返回值,返回值就是结果。 优点:改动少,可以快速集成。缺点:运行速度慢

(3)把nmt重构,写成一个对象,不要释放session,这样调用的速度会快一些。优点:运行速度快。 缺点:需要对nmt进行深入了解,开发周期长

前两种速度慢的原因是,每次运行都要加载大量的参数,加载词汇。第一种方案还多进行了两次io操作。







4.其他


4.1 数据问题

数据是比较难得到的,可以用自己qq聊天的数据,把聊天数据导出,然后做成nmt需要的数据格式。至于数据量,10万条以上吧,这个还没有详细的研究过。数据质量一定要好。很多公司是自己手动标注数据,这会耗费大量的时间,数据很难得。假设要做关键词抽取,可以通过爬虫爬取某浪新闻的带标签的文章。train_prefix dev_prefix这两个参数所指定的文件的数据量100-500条即可,不要太多了


4.2 内存问题

[src_max_len]  [gt_max_len]  [num_units]  [num_layers]  [batch_size] 这几个参数越大训练速度越慢,消耗内存越多。

词汇表越大,消耗的内存也越大,训练速度也会越慢。


4.3 训练问题

如果训练好久,控制台还没有一点反应,说明参数调的不好,机器性能跟不上,可以适当的降低参数。训练时间可能会很长,要有耐心。

nmt模型可以中断训练,下次输入和上次相同的参数,会接着继续执行。






猜你喜欢

转载自blog.csdn.net/moluth/article/details/79142689
今日推荐