机器学习入门0005 tensorflow_NMT模型
1.简介
nmt(Neural Machine Translation)是一个序列到序列的模型。可以用来做【聊天机器人】,【翻译】,【关键词提取】,【文章摘要】,【图像描述】等功能。用法简单,只需要安装Tensorflow1.4+ 版本即可运行。这个地址是Tensorflow 官方github https://github.com/tensorflow/nmt,里面内容很全面。
2.运行官网github的例子
下面内容是从斯坦福大学下载英语到越南语的平行语料库,然后通过nmt模型使用语料库训练一个 英语-越南语 或者 越南语-英语 的翻译模型。
2.1 下载平行语料库
在控制台输入这个命令nmt/scripts/download_iwslt15.sh /tmp/nmt_data,就可以下载小的平行语料库了。这个命令需要在nmt目录外边执行,会把数据下载到/tmp/nmt_data/下,其实是8个文件。在国内,由于网络原因,这个命令下载总是会中止,很蛋疼。可以借助其他下载工具来下载(比如某雷,某度云盘,某旋风等),这些文件的地址是:
https://nlp.stanford.edu/projects/nmt/data/iwslt15.en-vi/train.en
https://nlp.stanford.edu/projects/nmt/data/iwslt15.en-vi/train.vi
https://nlp.stanford.edu/projects/nmt/data/iwslt15.en-vi/tst2012.en
https://nlp.stanford.edu/projects/nmt/data/iwslt15.en-vi/tst2012.vi
https://nlp.stanford.edu/projects/nmt/data/iwslt15.en-vi/tst2013.en
https://nlp.stanford.edu/projects/nmt/data/iwslt15.en-vi/tst2013.vi
https://nlp.stanford.edu/projects/nmt/data/iwslt15.en-vi/vocab.en
https://nlp.stanford.edu/projects/nmt/data/iwslt15.en-vi/vocab.vi
下载完成后放在/tmp/nmt_data/下,也可以下载到自己喜欢的位置,但是需要修改之后的命令,会造成不必要的麻烦。
2.2 训练模型(英语-越南语模型)
在控制台中输入下面命令:
mkdir /tmp/nmt_model python -m nmt.nmt --src=en --tgt=vi --vocab_prefix=/tmp/nmt_data/vocab --train_prefix=/tmp/nmt_data/train
--dev_prefix=/tmp/nmt_data/tst2012 --test_prefix=/tmp/nmt_data/tst2013 --out_dir=/tmp/nmt_model --num_train_steps=12000
--steps_per_stats=100 --num_layers=2 --num_units=128 --dropout=0.2 --metrics=bleu
命令共有两行,第一行创建一个文件夹,用于存放训练的模型(若干个矩阵)
第二行很长只有一行,使用来进行训练的。其中的参数用来指明数据的位置,模型存放的位置,训练时的参数:总共训练12000步 每次训练100组 rnn共2层128个单元等。
注意:这个训练时间和机器性能有关可能达到1周时间 python 3.X的用户执行第二条命令,要这样写python3 -m nmt.nmt --src=en --tgt=v...
2.3 使用训练好的模型
在/tmp/下创建一个文件,名字是my_infer_file.en 里面写上几行英语:
i am a student how are you .... i want to be a super programer
然后执行下面命令
python -m nmt.nmt --out_dir=/tmp/nmt_model --inference_input_file=/tmp/my_infer_file.en --inference_output_file=/tmp/nmt_model/output_infer
很快可以执行完毕,然后到/tmp/nmt_model/目录下看这个文件output_infer,里面是对应的越南语翻译。
3.怎么用nmt
3.1 数据格式
仿照着之前下载的8个文件,做好数据对应,其中三个文件是英文的一句一行单词之间通过空格分隔,还有三个是越南语,格式和英语一样。vocab.vi 是越南语的词汇表取了常用的5000个词,vocab.en 是英语词汇表取了最常用的前5000个词语,但是它们前三个词语是<unk> 代表不认识的词语 <s>开始 </s>结束,这三个词必须在词汇表中否则nmt模型不能工作,具体原因官方github上有解释。
3.2 模型参数
python -m nmt.nmt --src=en --tgt=vi --vocab_prefix=/tmp/nmt_data/vocab --train_prefix=/tmp/nmt_data/train --dev_prefix=/tmp/nmt_data/tst2012 --test_prefix=/tmp/nmt_data/tst2013 --out_dir=/tmp/nmt_model --num_train_steps=12000 --steps_per_stats=100 --num_layers=2 --num_units=128 --dropout=0.2 --metrics=bleu
这条命令中只是使用了个别的参数,还有一些其他有用的参数,如下:
forget_bias=1.0 这个是lstm的记忆力参数,取值范围在[0.0,1.0]越大代表记性越好
batch_size=128 这个代表每次训练128条数据,取值是正整数,如果太大,需要的内存会增大
learning_rate=1 学习率,正常情况下设置成小于等于1的值,默认值 1
num_gpus=1 机器中gpu个数,默认值是1
eos='</s>' 结束符配置成</s>,参考3.1 数据格式
sos='<s>' 同上,这两个参数没有配置的必要
src_max_len=50 源输入最大长度,针对我们训练的英语-越南语模型中,意思是每行最长接受50个英语单词,其余忽略
tgt_max_len=50 目标输出最大长度,默认值50.这个和上面的参数有时很有用,假设我们要做文章摘要,参数可以这样写--src_max_len=800 --tgt_max_len=150,这两个参数都会影响训练和预测速度,他们越大,模型跑的越慢。
share_vocab=False 这个意思是是否公用词汇表,假设做文章摘要,把这个设置成True。因为不是做翻译,输入和输出是同一种语言。
还有一些其他参数,不再列举,可以去源代码中nmt.py文件中查看。
3.3.1 准备好训练数据,开发数据,测试数据,汉语常用汉字表(前5000个)即可
仿照3.1中的数据,来准备训练数据。这次不是翻译数据,而是对话数据。比如:
train.src
你 好 ! 很 高 兴 认 识 你 。 当 然 很 激 动 了。 ....
train.tgt
你 好 呀 ! 我 也 是 呢 , 你 有 没 有 很 激 动 。 激 动 你 妹 啊 。 ....
vocab.src
<unk> <s> </s> , 的 。 <sp> 一 0 是 1 、 在 有 不 了 2 人 中 大 国 年 ...
3.3.2 接下来进行训练
python -m nmt.nmt --src=src --tgt=tgt --vocab_prefix=/tmp/chat_data/vocab --train_prefix=/tmp/chat_data/train --dev_prefix=/tmp/chat_data/dev --test_prefix=/tmp/chat_data/test --out_dir=/tmp/nmt_model --num_train_steps=192000 --steps_per_stats=100 --num_layers=2 --num_units=256 --dropout=0.2 --metrics=bleu --src_max_len=80 --tgt_max_len=80 --share_vocab=True
经过漫长的训练,聊天模型训练完毕
3.3.3 集成到项目
有三种方案将训练的模型集成到项目中:
(1)对nmt进行部分修改,在项目代码中调用预测,使结果以文件形式展示,然后去文件中提取结果。优点:改动少,可以快速集成。 缺点:运行速度很慢
(2)对nmt进行部分修改,在项目代码中调用预测,只是要给nmt的源代码添加参数和返回值,返回值就是结果。 优点:改动少,可以快速集成。缺点:运行速度慢
(3)把nmt重构,写成一个对象,不要释放session,这样调用的速度会快一些。优点:运行速度快。 缺点:需要对nmt进行深入了解,开发周期长
前两种速度慢的原因是,每次运行都要加载大量的参数,加载词汇。第一种方案还多进行了两次io操作。
4.1 数据问题
数据是比较难得到的,可以用自己qq聊天的数据,把聊天数据导出,然后做成nmt需要的数据格式。至于数据量,10万条以上吧,这个还没有详细的研究过。数据质量一定要好。很多公司是自己手动标注数据,这会耗费大量的时间,数据很难得。假设要做关键词抽取,可以通过爬虫爬取某浪新闻的带标签的文章。train_prefix dev_prefix这两个参数所指定的文件的数据量100-500条即可,不要太多了
4.2 内存问题
[src_max_len] [gt_max_len] [num_units] [num_layers] [batch_size] 这几个参数越大训练速度越慢,消耗内存越多。
词汇表越大,消耗的内存也越大,训练速度也会越慢。
4.3 训练问题
如果训练好久,控制台还没有一点反应,说明参数调的不好,机器性能跟不上,可以适当的降低参数。训练时间可能会很长,要有耐心。
nmt模型可以中断训练,下次输入和上次相同的参数,会接着继续执行。