GPT2-NewsTitle
A GPT2-based news headline generation project with extremely detailed Chinese annotations.
Update 02.19.2022
- Added a Streamlit page, so you can deploy a clean interface without using Flask + HTML.
- Detailed documentation is provided: even if you work on algorithms and know nothing about front-end development, you can still build a good-looking interface.
Run the code:
```
streamlit run app.py
```
or
```
streamlit run app.py --server.port your_port
```
The page looks like the figure below:
Update 01.02.2021
- Collected data from the Internet, then organized and cleaned news datasets such as the Tsinghua news data, the Sogou news data, and several open-source summarization datasets, to build a relatively complete Chinese summarization dataset collection.
- Only simple rule-based cleaning was performed on these datasets, e.g. stripping HTML tags, removing redundant whitespace characters, and removing image tags.
- For details of the processed datasets, see the dataset description below.
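The rule-based cleaning described above can be sketched roughly as follows. This is a minimal illustration, not the project's actual code: the exact rules in data_helper.py, and the image-tag format, are assumptions.

```python
import re

def clean_text(text: str) -> str:
    """Simple rule-based cleaning as described above (a sketch only;
    the real rules in data_helper.py may differ)."""
    text = re.sub(r"<[^>]+>", "", text)        # strip HTML tags
    text = re.sub(r"\{IMG:\d+\}", "", text)    # drop image tags (hypothetical format)
    text = re.sub(r"\s+", " ", text).strip()   # collapse redundant whitespace
    return text

print(clean_text("<p>兰新高铁  开通{IMG:1}运营</p>"))  # 兰新高铁 开通运营
```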
Data | Raw data / project address | Processed file download |
---|---|---|
Tsinghua news data | address | Baidu Netdisk, extraction code: vhol |
Sogou news data | address | Baidu Netdisk, extraction code: ode6 |
nlpcc2017 summary data | address | Baidu Netdisk, extraction code: e0zq |
csl summary data | address | Baidu Netdisk, extraction code: 0qot |
Education and training industry summary data | address | Baidu Netdisk, extraction code: kjz3 |
lcsts summary data | address | Baidu Netdisk, extraction code: bzov |
Sensors Cup 2018 summary data | address | Baidu Netdisk, extraction code: 6f4f |
Wanfang summary data | address | Baidu Netdisk, extraction code: p69g |
WeChat official account summary data | address | Baidu Netdisk, extraction code: 5has |
Weibo data | address | Baidu Netdisk, extraction code: 85t5 |
news2016zh news data | address | Baidu Netdisk, extraction code: qsj1 |
Full dataset collection: Baidu Netdisk, extraction code: 7am8
Project description
- This project is a GPT2-based news headline generation project with extremely detailed Chinese annotations.
- It draws on several open-source GPT2 projects, such as GPT2-Chinese, GPT2-chitchat, CDial-GPT, and GPT2, refactors the code according to my own understanding, and adds detailed comments, in the hope of helping those who need it.
- The GPT2 model is implemented, trained, and tested with HuggingFace's transformers library.
- A web service is built with the Flask framework, turning the headline generation model into a deployable application whose output can be inspected visually through a web page.
- The code is explained in detail; you can read it yourself or check the code comment walkthrough.
- The released headline model is a small 6-layer model (my hardware can only train small models). During training, no pre-trained GPT2 weights were loaded; the parameters were randomly initialized, and the number of training epochs was small (5 epochs, not yet converged), so the results are mediocre. If you want a better model, train one according to your own needs.
- The purpose of this project is to walk you through the entire process of training, testing, and deploying a GPT2 generation model.
File structure
- config
  - config.json: model configuration, including n_ctx, n_embd, n_head, n_layer, etc.
- vocab
  - vocab.txt: vocabulary file of size 13,317; the "##"-prefixed Chinese subword tokens in the original vocabulary were removed, and special tokens such as "[Content]", "[Title]", and "[Space]" were added.
- data_dir: folder where the data is stored
- templates: folder for the HTML pages
- data_helper.py: data preprocessing script that performs simple cleaning of the data
- data_set.py: dataset classes required by the model, used during training
- model.py: GPT2 model file; it mainly overrides GPT2LMHeadModel from the transformers package, modifying the loss computation so that the loss is calculated only on the predicted title tokens
- train.py: training script for the GPT2 model that generates headlines from news text
- generate_title.py: generates news titles with a trained model and writes prediction files
- http_server.py: builds the web service
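The loss modification in model.py, computing cross-entropy only on the title tokens, can be sketched roughly like this. The function and mask names are illustrative, not the project's actual code:

```python
import torch
import torch.nn.functional as F

def title_only_loss(logits, labels, title_mask):
    """Cross-entropy over title positions only (a sketch of the idea).
    logits: [batch, seq, vocab]; labels: [batch, seq];
    title_mask: [batch, seq], 1 on title tokens, 0 on content tokens."""
    # standard causal-LM shift: position t predicts token t+1
    shift_logits = logits[:, :-1, :]
    shift_labels = labels[:, 1:]
    shift_mask = title_mask[:, 1:].float()
    loss = F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_labels.reshape(-1),
        reduction="none",
    )
    # average the per-token loss over title positions only
    return (loss * shift_mask.reshape(-1)).sum() / shift_mask.sum()
```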
Runtime environment
- gevent == 1.3a1
- flask == 0.12.2
- transformers == 3.0.2
See the requirements.txt file for details.
Dataset
The data comes from Sina Weibo; data link: https://www.jianshu.com/p/8f52352f0748?tdsourcetag=s_pcqq_aiomsg
Data description | Download link |
---|---|
Raw data | Baidu Netdisk, extraction code: nqzi |
Processed data | Baidu Netdisk, extraction code: duba |
The raw data is news data downloaded directly from the Internet; the processed data has been cleaned with data_helper.py and can be used directly for training.
Model parameters
See the config/config.json file for details
parameter | value |
---|---|
initializer_range | 0.02 |
layer_norm_epsilon | 1e-05 |
n_ctx | 512 |
n_embd | 768 |
n_head | 12 |
n_layer | 6 |
n_positions | 512 |
vocab_size | 13317 |
Note: in addition to the token embedding of each character, the model input also includes segment (token-type) embeddings and position embeddings.
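As a concrete illustration of this note, an input sample can be thought of as token ids paired with segment (token-type) ids that distinguish content from title, plus position ids. The ids below are made up for illustration:

```python
# Hypothetical token ids, for illustration only.
content_ids = [101, 2476, 3300, 102]   # content tokens (with separators)
title_ids = [1920, 1291, 102]          # title tokens
input_ids = content_ids + title_ids    # token sequence fed to the model

# segment ids: 0 for the content part, 1 for the title part
token_type_ids = [0] * len(content_ids) + [1] * len(title_ids)
# position ids: simply 0..len-1
position_ids = list(range(len(input_ids)))

print(token_type_ids)  # [0, 0, 0, 0, 1, 1, 1]
```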
Model file sharing
Model | download link |
---|---|
GPT2 model | Baidu Netdisk, extraction code: 165b |
Model training
```
python3 train.py
```
or
```
python3 train.py --output_dir output_dir/   # custom path for saving the model
```
Training parameters can be passed on the command line; the available parameters are:
Parameter | Type | Default | Description |
---|---|---|---|
device | str | "0" | Which GPU to use for training or testing |
config_path | str | "config/config.json" | Model configuration file |
vocab_path | str | "vocab/vocab.txt" | Vocabulary file; a reduced vocabulary with some new special tokens added |
train_file_path | str | "data_dir/train_data.json" | Training data for news headline generation |
test_file_path | str | "data_dir/test_data.json" | Test data for news headline generation |
pretrained_model_path | str | None | Path to a pre-trained GPT2 model |
data_dir | str | "data_dir" | Storage path for the generated cached data |
num_train_epochs | int | 5 | Number of training epochs |
train_batch_size | int | 16 | Batch size during training |
test_batch_size | int | 8 | Batch size during testing |
learning_rate | float | 1e-4 | Learning rate during training |
warmup_proportion | float | 0.1 | Warm-up proportion, i.e. the fraction of total training steps used for learning-rate warm-up |
adam_epsilon | float | 1e-8 | Epsilon value of the Adam optimizer |
logging_steps | int | 20 | Number of steps between training-log saves |
eval_steps | int | 4000 | Number of training steps between evaluations |
gradient_accumulation_steps | int | 1 | Number of gradient accumulation steps |
max_grad_norm | float | 1.0 | Maximum gradient norm for clipping |
output_dir | str | "output_dir/" | Model output path |
seed | int | 2020 | Random seed |
max_len | int | 512 | Maximum input length; must be no larger than n_ctx in the config |
Alternatively, modify the default values in the set_args function of train.py.
The released model was trained for 5 epochs; the training loss and test-set loss are shown below:
In fact, the model is not yet fully trained; judging from the loss trend, it would keep improving with further training.
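As a quick sanity check on warmup_proportion: the shared checkpoint is named checkpoint-139805, which suggests about 139,805 optimizer steps over the 5 epochs. With warmup_proportion = 0.1, the warm-up phase would then cover roughly:

```python
# Total optimizer steps inferred from the checkpoint name
# (checkpoint-139805); treat this as an estimate, not ground truth.
total_steps = 139805
warmup_proportion = 0.1

warmup_steps = int(total_steps * warmup_proportion)
print(warmup_steps)  # 13980
```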
Model testing
```
python3 generate_title.py
```
or
```
python3 generate_title.py --top_k 3 --top_p 0.9999 --generate_max_len 32
```
Parameters can be passed on the command line; the available parameters are:
Parameter | Type | Default | Description |
---|---|---|---|
device | str | "0" | Which GPU to use when generating |
model_path | str | "output_dir/checkpoint-139805" | Model file path |
vocab_path | str | "vocab/vocab.txt" | Vocabulary file; a reduced vocabulary with some new special tokens added |
batch_size | int | 3 | Number of titles to generate |
generate_max_len | int | 32 | Maximum length of a generated title |
repetition_penalty | float | 1.2 | Repetition penalty factor |
top_k | int | 5 | Number of highest-probability tokens to keep when decoding |
top_p | float | 0.95 | When decoding, keep the smallest set of tokens whose cumulative probability exceeds top_p |
max_len | int | 512 | Maximum input length; must be no larger than n_ctx in the config |
The test results are as follows.
A sample drawn from the test set:
content:
今日,中国三条重要高铁干线——兰新高铁、贵广铁路和南广铁路将开通运营。其中兰新高铁是中国首条高原高铁,全长1776公里,最高票价658元。贵广铁路最贵车票320元,南广铁路最贵车票206.5元,这两条线路大大缩短西南与各地的时空距离。出行更方便了!中国“高铁版图”再扩容 三条重要高铁今日开通
title:
Generated title 1: 中国“高铁版图”再扩容 三条重要高铁今日开通
Generated title 2: 贵广铁路最高铁版图
Generated title 3: 出行更方便了!中国“高铁版图”再扩容三条重要高铁今日开通
Decoding uses the top_k and top_p sampling strategies, which introduce some randomness, so repeated runs can produce different titles.
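The top_k / top_p filtering mentioned above is usually implemented along the following lines. This is a common sketch of the technique; generate_title.py may differ in its details:

```python
import torch

def top_k_top_p_filtering(logits, top_k=5, top_p=0.95, filter_value=-float("inf")):
    """Mask logits outside the top-k set and outside the top-p nucleus,
    so that sampling only considers the remaining tokens."""
    if top_k > 0:
        # keep only the top_k highest logits
        kth_value = torch.topk(logits, top_k).values[..., -1, None]
        logits = logits.masked_fill(logits < kth_value, filter_value)
    if top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        cum_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        # drop tokens once cumulative probability exceeds top_p,
        # but always keep the single most probable token
        remove = cum_probs > top_p
        remove[..., 1:] = remove[..., :-1].clone()
        remove[..., 0] = False
        mask = remove.scatter(-1, sorted_idx, remove)
        logits = logits.masked_fill(mask, filter_value)
    return logits
```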
Start the Flask service
```
python3 http_server.py
```
or
```
python3 http_server.py --http_id "0.0.0.0" --port 5555
```
For local testing, visit "127.0.0.1:5555/news-title-generate" directly. To allow access from other machines, replace "127.0.0.1" with your machine's IP address.
The page looks like the figure below:
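For reference, a minimal version of such a service might look like the sketch below. The route matches the path above, but the handler body and the model call are placeholders, not the project's actual http_server.py:

```python
from flask import Flask, jsonify, request

app = Flask(__name__)

def generate_titles(content: str):
    # placeholder for the real model call in generate_title.py
    return ["(generated title for: %s...)" % content[:10]]

@app.route("/news-title-generate", methods=["GET", "POST"])
def news_title_generate():
    # read the news content from the request and return generated titles
    content = request.values.get("content", "")
    return jsonify(titles=generate_titles(content))

if __name__ == "__main__":
    app.run(host="0.0.0.0", port=5555)
```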
Future work
- Later, news datasets such as the Tsinghua news data and Sogou news data may be organized to build a relatively complete news headline dataset.
- Later, the news data may be used to pre-train a small GPT2 model.
- Later, the uploaded headline model may be replaced with a better-trained one.
Acknowledgements
- Thanks to @JunkRoy for providing the web interface
References
- GPT2-Chinese
- GPT2-chitchat
- CDial-GPT
- GPT2