GPT2-NewsTitle

GPT2 news headline generation project with hyper-detailed annotations

Update 02.19.2022

run code

streamlit run app.py
or
streamlit run app.py --server.port your_port

Specifically as shown in the figure below:

Update 01.02.2021

  • Collected news datasets from the Internet, such as the Tsinghua news data, Sogou news data, and some open-source summarization data, then organized and cleaned them to build a relatively complete Chinese summarization dataset.
  • Only simple rule-based cleaning is applied to the datasets, for example: stripping HTML tags, removing redundant whitespace characters, removing image tags, etc.
  • For details of the processed datasets, see the dataset description below.
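The simple rule-based cleaning described above can be sketched in a few lines; `clean_text` is a hypothetical helper for illustration, not the project's actual `data_helper.py` code:

```python
import re

def clean_text(text: str) -> str:
    """Simple rule-based cleaning: strip HTML/image tags and collapse
    redundant whitespace (hypothetical helper, for illustration only)."""
    text = re.sub(r"<[^>]+>", "", text)        # drop HTML tags, incl. <img .../>
    text = re.sub(r"\s+", " ", text).strip()   # collapse redundant whitespace
    return text
```

For example, `clean_text('<p>Hello  <img src="a.png"/> world</p>')` returns `"Hello world"`.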
| data | raw data / project address | processed file download |
| --- | --- | --- |
| Tsinghua news data | address | Baidu cloud disk, extraction code: vhol |
| Sogou news data | address | Baidu cloud disk, extraction code: ode6 |
| nlpcc2017 summary data | address | Baidu cloud disk, extraction code: e0zq |
| csl summary data | address | Baidu cloud disk, extraction code: 0qot |
| education and training industry summary data | address | Baidu cloud disk, extraction code: kjz3 |
| lcsts summary data | address | Baidu cloud disk, extraction code: bzov |
| Sensors Cup 2018 summary data | address | Baidu cloud disk, extraction code: 6f4f |
| Wanfang summary data | address | Baidu cloud disk, extraction code: p69g |
| WeChat official account summary data | address | Baidu cloud disk, extraction code: 5has |
| Weibo data | address | Baidu cloud disk, extraction code: 85t5 |
| news2016zh news data | address | Baidu cloud disk, extraction code: qsj1 |

Complete dataset collection: Baidu cloud disk, extraction code: 7am8

project description

  • This project is a news headline generation project based on the GPT2 model, with very detailed Chinese annotations.
  • It draws on several open-source GPT2 projects such as GPT2-Chinese, GPT2-chitchat, CDial-GPT, and GPT2; I refactored the code according to my own understanding and added detailed comments, hoping to help those in need.
  • The GPT2 model is implemented, trained, and tested with HuggingFace's transformers library.
  • A web service built with the Flask framework turns the headline generation model into an application, so the generation results can be explored visually in a web page.
  • The code is explained in detail; you can read it yourself or check the code comment introduction.
  • The headline model provided by this project is a small 6-layer model (only small models could be trained). During training, the parameters were randomly initialized rather than loaded from a pre-trained GPT2 model, and the number of training epochs is small (5 epochs, not yet converged), so the results are mediocre. If you want a better model, train one according to your own needs.
  • The purpose of this project is to walk you through the entire process of training, testing, and deploying a GPT2 generation model.

file structure

  • config
    • config.json model configuration, including n_ctx, n_embd, n_head, n_layer, etc.
  • vocab
    • vocab.txt vocabulary file with 13317 entries; the "##"-prefixed Chinese pieces in the original vocabulary were deleted, and tags such as "[Content]", "[Title]", and "[Space]" were added.
  • data_dir folder where the data is stored
  • templates folder for the HTML pages
  • data_helper.py data preprocessing script; performs simple cleaning of the data
  • data_set.py dataset class file; defines the data classes the model needs, for convenient training
  • model.py GPT2 model file; mainly rewrites GPT2LMHeadModel from the transformers package, modifying the loss computation so that the loss is calculated only over the predicted title tokens
  • train.py training script for the GPT2 model that generates headlines from news text
  • generate_title.py generates news titles with the trained model and writes prediction files
  • http_server.py script that builds the web service
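The loss modification in model.py can be illustrated with a simplified sketch (not the project's actual code): cross-entropy is computed per token and then averaged only over positions that belong to the title, so content tokens contribute nothing.

```python
import torch
import torch.nn.functional as F

def title_only_loss(logits, labels, title_mask):
    """Cross-entropy over title tokens only (a simplified illustration
    of the idea behind model.py's modified loss).
    logits: [batch, seq, vocab]; labels, title_mask: [batch, seq];
    title_mask is 1 where the token belongs to the title."""
    shift_logits = logits[:, :-1, :].reshape(-1, logits.size(-1))  # predict token t+1 from t
    shift_labels = labels[:, 1:].reshape(-1)
    shift_mask = title_mask[:, 1:].reshape(-1).float()
    per_token = F.cross_entropy(shift_logits, shift_labels, reduction="none")
    # average only over title positions; content tokens are zeroed out
    return (per_token * shift_mask).sum() / shift_mask.sum()
```

This way the gradients only encourage the model to predict the title given the content, rather than to model the content text itself.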

operating environment

  • gevent == 1.3a1
  • flask == 0.12.2
  • transformers == 3.0.2

See the requirements.txt file for details

data set

Data comes from Sina Weibo, data link: https://www.jianshu.com/p/8f52352f0748?tdsourcetag=s_pcqq_aiomsg

| data description | download link |
| --- | --- |
| raw data | Baidu network disk, extraction code: nqzi |
| processed data | Baidu network disk, extraction code: duba |

The raw data is news data downloaded directly from the Internet; the processed data was produced with data_helper.py and can be used directly for training.
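The exact schema of the processed file is defined by data_helper.py; assuming each sample is a JSON object pairing a news body with its headline (the "content"/"title" field names here are an assumption, not verified against the project's code), loading it might look like:

```python
import json

# Hypothetical processed-data sample; the "content"/"title" field names
# are an assumption about the schema, not verified against data_helper.py.
raw = '[{"content": "news body text", "title": "news headline"}]'
samples = json.loads(raw)

for s in samples:
    # each sample pairs a news body with its reference headline
    print(s["content"], "->", s["title"])
```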

Model parameters

See the config/config.json file for details

| parameter | value |
| --- | --- |
| initializer_range | 0.02 |
| layer_norm_epsilon | 1e-05 |
| n_ctx | 512 |
| n_embd | 768 |
| n_head | 12 |
| n_layer | 6 |
| n_positions | 512 |
| vocab_size | 13317 |

Note: in addition to the word (token) vector of each position, the model input also includes a paragraph (segment) vector and a position vector.
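The three input representations are summed element-wise. The sketch below uses the sizes from the table above; the two-valued segment embedding (content vs. title) is an assumption about this project's setup, not verified code:

```python
import torch
import torch.nn as nn

vocab_size, n_positions, n_embd = 13317, 512, 768

tok_emb = nn.Embedding(vocab_size, n_embd)   # word (token) vectors
seg_emb = nn.Embedding(2, n_embd)            # paragraph/segment vectors (assumed: content vs. title)
pos_emb = nn.Embedding(n_positions, n_embd)  # position vectors

input_ids = torch.randint(0, vocab_size, (1, 6))
segment_ids = torch.tensor([[0, 0, 0, 1, 1, 1]])  # 0 = content, 1 = title
position_ids = torch.arange(6).unsqueeze(0)

# the model input is the element-wise sum of the three representations
hidden = tok_emb(input_ids) + seg_emb(segment_ids) + pos_emb(position_ids)
```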

Model file sharing

Model download link
GPT2 model Baidu network disk , extraction code: 165b

model training

python3 train.py
or
python3 train.py --output_dir output_dir/ (custom path for saving the model)

Training parameters can be supplied on the command line; they are as follows:

| parameter | type | default | description |
| --- | --- | --- | --- |
| device | str | "0" | GPU to use for training or testing |
| config_path | str | "config/config.json" | model configuration file |
| vocab_path | str | "vocab/vocab.txt" | vocabulary file; a reduced vocabulary with some new tokens added |
| train_file_path | str | "data_dir/train_data.json" | training data for headline generation |
| test_file_path | str | "data_dir/test_data.json" | test data for headline generation |
| pretrained_model_path | str | None | path to a pre-trained GPT2 model |
| data_dir | str | "data_dir" | storage path for generated cached data |
| num_train_epochs | int | 5 | number of training epochs |
| train_batch_size | int | 16 | batch size during training |
| test_batch_size | int | 8 | batch size during testing |
| learning_rate | float | 1e-4 | learning rate during training |
| warmup_proportion | float | 0.1 | fraction of total training steps used for learning-rate warm-up |
| adam_epsilon | float | 1e-8 | epsilon value of the Adam optimizer |
| logging_steps | int | 20 | log training metrics every this many steps |
| eval_steps | int | 4000 | run evaluation on the test set every this many steps |
| gradient_accumulation_steps | int | 1 | number of gradient accumulation steps |
| max_grad_norm | float | 1.0 | maximum gradient norm (gradient clipping) |
| output_dir | str | "output_dir/" | model output path |
| seed | int | 2020 | random seed |
| max_len | int | 512 | maximum input length; must not exceed n_ctx in the config |

Alternatively, modify the default values in the set_args function in train.py.
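A set_args-style function typically wraps argparse; the minimal sketch below covers only a few of the parameters in the table above (defaults mirror the table; the real train.py defines the full list):

```python
import argparse

def set_args(argv=None):
    """Minimal sketch of a set_args-style function; defaults mirror the
    parameter table above. The real train.py defines many more options."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", type=str, default="0")
    parser.add_argument("--num_train_epochs", type=int, default=5)
    parser.add_argument("--train_batch_size", type=int, default=16)
    parser.add_argument("--learning_rate", type=float, default=1e-4)
    parser.add_argument("--output_dir", type=str, default="output_dir/")
    return parser.parse_args(argv)

args = set_args([])  # pass [] here for illustration; omit it to read sys.argv
```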

The model provided by this project was trained for 5 epochs. The training loss and test-set loss are as follows:

In fact, the model has not fully converged; judging from the loss trend, training could continue.

model testing

python3 generate_title.py
or
python3 generate_title.py --top_k 3 --top_p 0.9999 --generate_max_len 32

Parameters can be supplied on the command line; they are as follows:

| parameter | type | default | description |
| --- | --- | --- | --- |
| device | str | "0" | GPU to use for training or testing |
| model_path | str | "output_dir/checkpoint-139805" | model file path |
| vocab_path | str | "vocab/vocab.txt" | vocabulary file; a reduced vocabulary with some new tokens added |
| batch_size | int | 3 | number of titles to generate |
| generate_max_len | int | 32 | maximum length of generated titles |
| repetition_penalty | float | 1.2 | repetition penalty rate |
| top_k | int | 5 | number of highest-probability tokens kept when decoding |
| top_p | float | 0.95 | keep the smallest set of tokens whose cumulative probability exceeds top_p when decoding |
| max_len | int | 512 | maximum input length; must not exceed n_ctx in the config |

The test results are as follows:

A sample drawn from the test set:
content:
今日,中国三条重要高铁干线——兰新高铁、贵广铁路和南广铁路将开通运营。其中兰新高铁是中国首条高原高铁,全长1776公里,最高票价658元。贵广铁路最贵车票320元,南广铁路最贵车票206.5元,这两条线路大大缩短西南与各地的时空距离。出行更方便了!中国“高铁版图”再扩容 三条重要高铁今日开通
title:
Generated title 1: 中国“高铁版图”再扩容 三条重要高铁今日开通
Generated title 2: 贵广铁路最高铁版图
Generated title 3: 出行更方便了!中国“高铁版图”再扩容三条重要高铁今日开通


Decoding adopts the top_k and top_p strategies, which introduce some randomness, so repeated runs can produce different titles.
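A common implementation of the top_k/top_p filtering described above looks like the sketch below (an illustration, not necessarily identical to generate_title.py): tokens outside the top k, or outside the smallest set whose cumulative probability exceeds top_p, have their logits set to -inf before sampling.

```python
import torch

def top_k_top_p_filtering(logits, top_k=5, top_p=0.95):
    """Filter a logits tensor so sampling only sees the top-k tokens and
    the nucleus whose cumulative probability exceeds top_p (a sketch)."""
    if top_k > 0:
        kth = torch.topk(logits, top_k).values[..., -1, None]
        logits = logits.masked_fill(logits < kth, float("-inf"))
    if top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        cum_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        remove = cum_probs > top_p
        remove[..., 1:] = remove[..., :-1].clone()  # always keep the best token
        remove[..., 0] = False
        logits = logits.masked_fill(remove.scatter(-1, sorted_idx, remove), float("-inf"))
    return logits
```

Sampling then draws from the softmax of the filtered logits, which is where the randomness comes from.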

Start the Flask service

python3 http_server.py
or
python3 http_server.py --http_id "0.0.0.0" --port 5555

For local testing, access "127.0.0.1:5555/news-title-generate" directly. For access from other machines, replace "127.0.0.1" with your computer's IP address.
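The service's route might look roughly like the Flask sketch below; the form field name ("content") and the response shape are assumptions, so check http_server.py and templates/ for the real interface. The model call is stubbed out here.

```python
from flask import Flask, jsonify, request  # pip install flask

app = Flask(__name__)

@app.route("/news-title-generate", methods=["POST"])
def news_title_generate():
    # the "content" field name is an assumption about the real form
    content = request.form.get("content", "")
    # the real service would run the trained GPT2 model here; this
    # stub just truncates the content to stand in for a title
    return jsonify({"title": content[:10]})

# To serve: app.run(host="0.0.0.0", port=5555)
```

During development you can exercise the route without starting a server via app.test_client().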

Specifically as shown in the figure below:

future work

  • Organize news datasets such as the Tsinghua news data and Sogou news data to build a relatively complete news headline dataset.
  • Use news data to train a small GPT2 pre-trained model.
  • Update the uploaded news headline model with a better-performing one.

acknowledgements

  • Thanks to @JunkRoy for providing the web interface

Origin blog.csdn.net/sinat_37574187/article/details/131735317