Introduction to Machine Learning 0005: TensorFlow NMT model

1. Introduction

nmt (Neural Machine Translation) is a sequence-to-sequence model. It can be used for chatbots, translation, keyword extraction, article summarization, image captioning, and other tasks. It is simple to use: you only need TensorFlow 1.4 or later installed to run it. The official TensorFlow repository is at https://github.com/tensorflow/nmt, and its documentation is very comprehensive.




2. Running the example from the official GitHub

In this section we download an English-Vietnamese parallel corpus from Stanford University and then use it to train an English-to-Vietnamese (or Vietnamese-to-English) translation model with the nmt model.

2.1 Download parallel corpus

Enter the command nmt/scripts/download_iwslt15.sh /tmp/nmt_data in the console to download a small parallel corpus. This command must be executed from outside the nmt directory, and the data is downloaded to /tmp/nmt_data/ as eight files. In mainland China the download tends to be interrupted for network reasons, which is painful; you can use another download tool instead (Xunlei/Thunder, a cloud drive, QQ Xuanfeng, and so on). The addresses of the files are:

https://nlp.stanford.edu/projects/nmt/data/iwslt15.en-vi/train.en
https://nlp.stanford.edu/projects/nmt/data/iwslt15.en-vi/train.vi
https://nlp.stanford.edu/projects/nmt/data/iwslt15.en-vi/tst2012.en
https://nlp.stanford.edu/projects/nmt/data/iwslt15.en-vi/tst2012.vi
https://nlp.stanford.edu/projects/nmt/data/iwslt15.en-vi/tst2013.en
https://nlp.stanford.edu/projects/nmt/data/iwslt15.en-vi/tst2013.vi
https://nlp.stanford.edu/projects/nmt/data/iwslt15.en-vi/vocab.en
https://nlp.stanford.edu/projects/nmt/data/iwslt15.en-vi/vocab.vi

After the download is complete, put the files under /tmp/nmt_data/. You can also download them to any location you like, but then you have to adjust the paths in the commands below, which is unnecessary trouble.
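
If the script keeps getting interrupted, a minimal Python sketch along the following lines (assuming the /tmp/nmt_data target directory used above) can fetch the eight files directly:

import os
import urllib.request

# Base URL of the Stanford IWSLT'15 English-Vietnamese files listed above
BASE = "https://nlp.stanford.edu/projects/nmt/data/iwslt15.en-vi/"
FILES = ["train.en", "train.vi", "tst2012.en", "tst2012.vi",
         "tst2013.en", "tst2013.vi", "vocab.en", "vocab.vi"]

out_dir = "/tmp/nmt_data"
os.makedirs(out_dir, exist_ok=True)

for name in FILES:
    # Download each corpus file into /tmp/nmt_data/
    urllib.request.urlretrieve(BASE + name, os.path.join(out_dir, name))
    print("downloaded", name)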


2.2 Training the model (English-Vietnamese)

Enter the following command in the console:

mkdir /tmp/nmt_model
python -m nmt.nmt --src=en --tgt=vi --vocab_prefix=/tmp/nmt_data/vocab --train_prefix=/tmp/nmt_data/train --dev_prefix=/tmp/nmt_data/tst2012 --test_prefix=/tmp/nmt_data/tst2013 --out_dir=/tmp/nmt_model --num_train_steps=12000 --steps_per_stats=100 --num_layers=2 --num_units=128 --dropout=0.2 --metrics=bleu


There are two commands here. The first line creates a folder for storing the trained model (essentially a set of weight matrices).

The second command is a single very long line that starts the training. Its parameters specify where the data is located, where the model is stored, and the training settings: 12,000 training steps in total, statistics printed every 100 steps, 2 RNN layers with 128 units each, and so on.

Note: training time depends on machine performance and can take up to a week. Python 3.x users should start the second command with python3, i.e. python3 -m nmt.nmt --src=en --tgt=v...


2.3 Using the trained model

Create a file under /tmp/ with the name my_infer_file.en and write a few lines of English in it:


i am a student

how are you

....

i want to be a super programmer


Then execute the following command:

python -m nmt.nmt --out_dir=/tmp/nmt_model --inference_input_file=/tmp/my_infer_file.en --inference_output_file=/tmp/nmt_model/output_infer

It finishes quickly; afterwards, look in the /tmp/nmt_model/ directory for the file output_infer, which contains the corresponding Vietnamese translations.




3. How to use nmt


3.1 Data Format

Follow the format of the 8 files downloaded earlier and keep the data aligned line by line. Three of the files contain English sentences, one per line, with words separated by spaces; the other three contain Vietnamese in the same format. vocab.vi is a Vietnamese vocabulary with the 5,000 most common words, and vocab.en is an English vocabulary with the 5,000 most common words. Their first three entries are <unk> (unknown word), <s> (sentence start), and </s> (sentence end); these three tokens must be present in the vocabulary or the nmt model will not work. The specific reasons are explained on the official GitHub page.
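
As a quick sanity check, a small sketch like the one below (the file paths assume the download location from section 2.1) verifies that a vocabulary file begins with the three required special tokens:

# Minimal check: the first three vocabulary entries must be <unk>, <s>, </s>
def check_vocab(path):
    with open(path, encoding="utf-8") as f:
        first_three = [next(f).strip() for _ in range(3)]
    assert first_three == ["<unk>", "<s>", "</s>"], first_three
    print(path, "looks fine")

check_vocab("/tmp/nmt_data/vocab.en")
check_vocab("/tmp/nmt_data/vocab.vi")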


3.2 Model parameters

python -m nmt.nmt --src=en --tgt=vi --vocab_prefix=/tmp/nmt_data/vocab  --train_prefix=/tmp/nmt_data/train --dev_prefix=/tmp/nmt_data/tst2012  --test_prefix=/tmp/nmt_data/tst2013 --out_dir=/tmp/nmt_model --num_train_steps=12000 --steps_per_stats=100 --num_layers=2 --num_units=128 --dropout=0.2 --metrics=bleu

Only a few parameters are used in this command; there are other useful ones, listed below:

forget_bias=1.0  The forget-gate bias of the LSTM; the value range is [0.0, 1.0], and the larger the value, the better the model retains memory.

batch_size=128  The number of examples used per training step; the value is a positive integer, and if it is too large, memory usage increases.

learning_rate=1  The learning rate, normally set to a value less than or equal to 1; the default is 1.

num_gpus=1  The number of GPUs in the machine; the default is 1.

eos='</s>'  The end-of-sentence token, configured as </s>; see 3.1 Data Format.

sos='<s>'  The start-of-sentence token, same as above. These two parameters normally do not need to be configured.

src_max_len=50  The maximum length of the source input. For the English-Vietnamese model trained above, it means each line accepts at most 50 English words; anything beyond that is ignored.

tgt_max_len=50  The maximum length of the target output; the default is 50. This and the previous parameter are sometimes useful: for article summarization, for example, they could be set to --src_max_len=800 --tgt_max_len=150. Both parameters affect training and inference speed; the larger they are, the slower the model runs.

share_vocab=False  Whether the source and target share one vocabulary. For article summarization, set this to True, because the task is not translation and the input and output are in the same language.

There are other parameters as well, not listed here; they can be found in the nmt.py file in the source code.
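
Putting the summarization-related settings together, a training command could look like the sketch below. This is only an illustrative assumption: the /tmp/sum_data paths, step count, and layer sizes are placeholders, and only --src_max_len, --tgt_max_len, and --share_vocab follow the notes above.

python -m nmt.nmt --src=src --tgt=tgt --vocab_prefix=/tmp/sum_data/vocab --train_prefix=/tmp/sum_data/train --dev_prefix=/tmp/sum_data/dev --test_prefix=/tmp/sum_data/test --out_dir=/tmp/sum_model --num_train_steps=100000 --steps_per_stats=100 --num_layers=2 --num_units=256 --dropout=0.2 --metrics=bleu --src_max_len=800 --tgt_max_len=150 --share_vocab=True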


3.3 Train a chatbot (Chinese)

3.3.1 Prepare the training data, development data, test data, and a table of the 5,000 most common Chinese characters

Prepare the training data following the format in 3.1. This time it is not translation data but conversation data, for example:

train.src

Hi!

Nice to meet you.

Of course excited.

....




train.tgt

Hello!

Me too, are you excited?

Excited for your sister.

....



vocab.src

<unk>
<s>
</s>
of
<sp>
one
0
Yes
1
exist
have
Do not
span
2
people
middle
Big
country
year

...
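
The vocab.src above simply lists the three special tokens followed by the most common characters in the corpus. A minimal sketch for building such a character-level vocabulary from the training files (the /tmp/chat_data paths and the space-separated character segmentation are assumptions, not part of the official tooling) could be:

from collections import Counter

# Count token (character) frequencies over the source and target training files
counts = Counter()
for path in ["/tmp/chat_data/train.src", "/tmp/chat_data/train.tgt"]:
    with open(path, encoding="utf-8") as f:
        for line in f:
            counts.update(line.split())  # tokens are separated by spaces

# Keep the 5000 most common tokens, preceded by the three required special tokens
vocab = ["<unk>", "<s>", "</s>"] + [tok for tok, _ in counts.most_common(5000)]
with open("/tmp/chat_data/vocab.src", "w", encoding="utf-8") as f:
    f.write("\n".join(vocab) + "\n")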



3.3.2 Training the model

python -m nmt.nmt --src=src --tgt=tgt --vocab_prefix=/tmp/chat_data/vocab  --train_prefix=/tmp/chat_data/train --dev_prefix=/tmp/chat_data/dev  --test_prefix=/tmp/chat_data/test --out_dir=/tmp/nmt_model --num_train_steps=192000 --steps_per_stats=100 --num_layers=2 --num_units=256 --dropout=0.2 --metrics=bleu --src_max_len=80 --tgt_max_len=80 --share_vocab=True

After a long training run, the chat model is ready.


3.3.3 Integration into the project

There are three options for integrating the trained model into the project:

(1) Partially modify nmt and call inference from the project code; the result is written to a file, and the project then reads the result back from that file. Advantages: few changes, quick to integrate. Disadvantages: very slow.

(2) Partially modify nmt and call inference from the project code, adding parameters and a return value to the nmt source so the result is returned directly. Advantages: few changes, quick to integrate. Disadvantages: slow.

(3) Refactor nmt into an object that keeps the session alive between calls, so that each call is fast. Advantages: fast. Disadvantages: requires an in-depth understanding of nmt and a longer development cycle.

The reason the first two options are slow is that every call has to reload a large number of parameters and the vocabulary; the first option additionally performs two extra I/O operations.
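
As an illustration of option (1), a minimal sketch could shell out to the nmt command line (run from the directory containing the nmt package, as with the earlier commands) and read the result file back; the input and output paths below are just placeholders:

import subprocess

def translate_lines(lines, model_dir="/tmp/nmt_model",
                    in_path="/tmp/my_infer_file.src",
                    out_path="/tmp/nmt_model/output_infer"):
    # Write the inputs to a file, run nmt inference, and read the result back
    with open(in_path, "w", encoding="utf-8") as f:
        f.write("\n".join(lines) + "\n")
    subprocess.run(
        ["python", "-m", "nmt.nmt",
         "--out_dir=" + model_dir,
         "--inference_input_file=" + in_path,
         "--inference_output_file=" + out_path],
        check=True)
    with open(out_path, encoding="utf-8") as f:
        return [line.strip() for line in f]

print(translate_lines(["i am a student", "how are you"]))

Every call here pays the full model-loading cost, which is exactly why option (3) keeps the session resident instead.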







4. Other


4.1 Data issues

Data is relatively hard to obtain. You can use your own QQ chat history: export the chat records and then convert them into the format nmt requires (see the sketch below). As for the amount of data, more than 100,000 pairs is a reasonable starting point, although this has not been studied in detail. Data quality must be good; many companies label data by hand, which takes a lot of time, so good data is hard to come by. If you want to do keyword extraction, you can crawl tagged articles from a news portal. The files specified by the dev_prefix and test_prefix parameters only need 100-500 examples each, not too many.
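
As a rough illustration of that conversion, a sketch like the following (the pairs list and file names are hypothetical, and character-level segmentation is assumed as in 3.3.1) writes question-answer pairs into the parallel-file format nmt expects:

# Hypothetical exported chat pairs: (question, answer)
pairs = [
    ("你好！", "你好！"),
    ("很高兴认识你。", "我也是。"),
]

def segment(text):
    # Character-level segmentation: separate every character with a space
    return " ".join(text.replace(" ", ""))

with open("/tmp/chat_data/train.src", "w", encoding="utf-8") as fs, \
     open("/tmp/chat_data/train.tgt", "w", encoding="utf-8") as ft:
    for question, answer in pairs:
        fs.write(segment(question) + "\n")
        ft.write(segment(answer) + "\n")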


4.2 Memory problems

The larger the src_max_len, tgt_max_len, num_units, num_layers, and batch_size parameters, the slower training is and the more memory it consumes.

The larger the vocabulary, the more memory it consumes and the slower training is.


4.3 Training problems

If the console shows no output for a long time after training starts, the parameters are too demanding for the machine; reduce them appropriately. Training can take a long time, so be patient.

Training with nmt can be interrupted; the next time you run it with the same parameters as before, it continues from where it left off.





