CPM: A Large-scale Generative Chinese Pre-trained Language Model

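GitHub: yangjianxin1/CPM, "Easy-to-use CPM for Chinese text generation": https://github.com/yangjianxin1/CPM

Notes on the paper "CPM: A Large-scale Generative Chinese Pre-trained Language Model" (陈欢伯's blog on CSDN): https://blog.csdn.net/mark_technology/article/details/118680728. A summary of the paper's introduction:

- GPT-3 has 175B parameters and was trained on 570GB of data. However, most of its corpus is English (93%), and its parameters are not publicly released.
- The authors therefore propose CPM (Chinese Pre-trained Language Model), which has 2.6B parameters and is trained on 100GB of Chinese data. Releasing this model is the paper's main contribution.
- CPM can be applied to downstream tasks such as dialogue, essay generation, cloze tests, and language understanding.
- As the parameter count grows, CPM performs better on some datasets. This suggests that larger models are more effective at both language generation and language understanding.

The paper itself is very brief. For the model architecture, it is worth reading the released code, which is quite easy to use. I trained a model that generates product-recommendation articles for an e-commerce scenario, and the results were good. For text generation I strongly recommend trying CPM. The implementation builds on the model code from the transformers library and is fairly concise.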

Is a Chinese GPT-3 here? BAAI and Tsinghua release Qingyuan CPM, a large-scale pre-trained model centered on Chinese

The computation times above are estimates for training on a single NVIDIA V100 GPU.

1. Approach

1.1 Chinese PLM (Pre-trained Language Model)

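The table above lists the CPM model sizes. In my experience, at least the small version can be trained on a GTX 1080 Ti. I give my specific training settings later on.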

A quick look at CPM's architecture shows that it is essentially the GPT-2 model. Printing the small version gives:

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(30000, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (1): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (2): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (3): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (4): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (5): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (6): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (7): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (8): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (9): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (10): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (11): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=30000, bias=False)
)
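Each GPT2Attention block above contains a single c_attn layer. It projects the 768-dim hidden state to 2304 = 3 × 768 features in one call. The result is then split into query, key and value, and each of those is divided into attention heads, mirroring what GPT2Attention does internally.

The sketch below reproduces that bookkeeping on plain Python lists. It assumes CPM-small uses 12 attention heads of 64 dims each, since the head count is not visible in the printout. The helper names split_qkv and split_heads are hypothetical and are not transformers functions.

```python
from typing import List, Tuple

D_MODEL = 768                 # hidden size of CPM-small (width of wte/wpe above)
N_HEAD = 12                   # assumption: CPM-small uses 12 attention heads
HEAD_DIM = D_MODEL // N_HEAD  # 64 dims per head

Vector = List[float]


def split_qkv(fused: Vector) -> Tuple[Vector, Vector, Vector]:
    """Split one token's c_attn output (length 3 * D_MODEL) into q, k, v."""
    if len(fused) != 3 * D_MODEL:
        raise ValueError(f"expected {3 * D_MODEL} features, got {len(fused)}")
    q = fused[:D_MODEL]
    k = fused[D_MODEL:2 * D_MODEL]
    v = fused[2 * D_MODEL:]
    return q, k, v


def split_heads(x: Vector) -> List[Vector]:
    """Cut a D_MODEL-length vector into N_HEAD contiguous chunks of HEAD_DIM."""
    return [x[h * HEAD_DIM:(h + 1) * HEAD_DIM] for h in range(N_HEAD)]


if __name__ == "__main__":
    row = [float(i) for i in range(3 * D_MODEL)]  # stand-in for one c_attn output row
    q, k, v = split_qkv(row)
    heads = split_heads(q)
    print(len(q), len(k), len(v))     # 768 768 768
    print(len(heads), len(heads[0]))  # 12 64
    print(k[0], v[0])                 # 768.0 1536.0
```

Fusing the three projections into one matrix multiply is why c_attn's output width is 2304 rather than 768.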

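The parameter names and shapes in the model's state_dict are listed below. The attn.bias and attn.masked_bias entries are causal-mask buffers, not trainable weights.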
transformer.wte.weight   [30000, 768]
transformer.wpe.weight   [1024, 768]
transformer.h.0.ln_1.weight   [768]
transformer.h.0.ln_1.bias   [768]
transformer.h.0.attn.bias   [1, 1, 1024, 1024]
transformer.h.0.attn.masked_bias   []
transformer.h.0.attn.c_attn.weight   [768, 2304]
transformer.h.0.attn.c_attn.bias   [2304]
transformer.h.0.attn.c_proj.weight   [768, 768]
transformer.h.0.attn.c_proj.bias   [768]
transformer.h.0.ln_2.weight   [768]
transformer.h.0.ln_2.bias   [768]
transformer.h.0.mlp.c_fc.weight   [768, 3072]
transformer.h.0.mlp.c_fc.bias   [3072]
transformer.h.0.mlp.c_proj.weight   [3072, 768]
transformer.h.0.mlp.c_proj.bias   [768]
transformer.h.1.ln_1.weight   [768]
transformer.h.1.ln_1.bias   [768]
transformer.h.1.attn.bias   [1, 1, 1024, 1024]
transformer.h.1.attn.masked_bias   []
transformer.h.1.attn.c_attn.weight   [768, 2304]
transformer.h.1.attn.c_attn.bias   [2304]
transformer.h.1.attn.c_proj.weight   [768, 768]
transformer.h.1.attn.c_proj.bias   [768]
transformer.h.1.ln_2.weight   [768]
transformer.h.1.ln_2.bias   [768]
transformer.h.1.mlp.c_fc.weight   [768, 3072]
transformer.h.1.mlp.c_fc.bias   [3072]
transformer.h.1.mlp.c_proj.weight   [3072, 768]
transformer.h.1.mlp.c_proj.bias   [768]
transformer.h.2.ln_1.weight   [768]
transformer.h.2.ln_1.bias   [768]
transformer.h.2.attn.bias   [1, 1, 1024, 1024]
transformer.h.2.attn.masked_bias   []
transformer.h.2.attn.c_attn.weight   [768, 2304]
transformer.h.2.attn.c_attn.bias   [2304]
transformer.h.2.attn.c_proj.weight   [768, 768]
transformer.h.2.attn.c_proj.bias   [768]
transformer.h.2.ln_2.weight   [768]
transformer.h.2.ln_2.bias   [768]
transformer.h.2.mlp.c_fc.weight   [768, 3072]
transformer.h.2.mlp.c_fc.bias   [3072]
transformer.h.2.mlp.c_proj.weight   [3072, 768]
transformer.h.2.mlp.c_proj.bias   [768]
transformer.h.3.ln_1.weight   [768]
transformer.h.3.ln_1.bias   [768]
transformer.h.3.attn.bias   [1, 1, 1024, 1024]
transformer.h.3.attn.masked_bias   []
transformer.h.3.attn.c_attn.weight   [768, 2304]
transformer.h.3.attn.c_attn.bias   [2304]
transformer.h.3.attn.c_proj.weight   [768, 768]
transformer.h.3.attn.c_proj.bias   [768]
transformer.h.3.ln_2.weight   [768]
transformer.h.3.ln_2.bias   [768]
transformer.h.3.mlp.c_fc.weight   [768, 3072]
transformer.h.3.mlp.c_fc.bias   [3072]
transformer.h.3.mlp.c_proj.weight   [3072, 768]
transformer.h.3.mlp.c_proj.bias   [768]
transformer.h.4.ln_1.weight   [768]
transformer.h.4.ln_1.bias   [768]
transformer.h.4.attn.bias   [1, 1, 1024, 1024]
transformer.h.4.attn.masked_bias   []
transformer.h.4.attn.c_attn.weight   [768, 2304]
transformer.h.4.attn.c_attn.bias   [2304]
transformer.h.4.attn.c_proj.weight   [768, 768]
transformer.h.4.attn.c_proj.bias   [768]
transformer.h.4.ln_2.weight   [768]
transformer.h.4.ln_2.bias   [768]
transformer.h.4.mlp.c_fc.weight   [768, 3072]
transformer.h.4.mlp.c_fc.bias   [3072]
transformer.h.4.mlp.c_proj.weight   [3072, 768]
transformer.h.4.mlp.c_proj.bias   [768]
transformer.h.5.ln_1.weight   [768]
transformer.h.5.ln_1.bias   [768]
transformer.h.5.attn.bias   [1, 1, 1024, 1024]
transformer.h.5.attn.masked_bias   []
transformer.h.5.attn.c_attn.weight   [768, 2304]
transformer.h.5.attn.c_attn.bias   [2304]
transformer.h.5.attn.c_proj.weight   [768, 768]
transformer.h.5.attn.c_proj.bias   [768]
transformer.h.5.ln_2.weight   [768]
transformer.h.5.ln_2.bias   [768]
transformer.h.5.mlp.c_fc.weight   [768, 3072]
transformer.h.5.mlp.c_fc.bias   [3072]
transformer.h.5.mlp.c_proj.weight   [3072, 768]
transformer.h.5.mlp.c_proj.bias   [768]
transformer.h.6.ln_1.weight   [768]
transformer.h.6.ln_1.bias   [768]
transformer.h.6.attn.bias   [1, 1, 1024, 1024]
transformer.h.6.attn.masked_bias   []
transformer.h.6.attn.c_attn.weight   [768, 2304]
transformer.h.6.attn.c_attn.bias   [2304]
transformer.h.6.attn.c_proj.weight   [768, 768]
transformer.h.6.attn.c_proj.bias   [768]
transformer.h.6.ln_2.weight   [768]
transformer.h.6.ln_2.bias   [768]
transformer.h.6.mlp.c_fc.weight   [768, 3072]
transformer.h.6.mlp.c_fc.bias   [3072]
transformer.h.6.mlp.c_proj.weight   [3072, 768]
transformer.h.6.mlp.c_proj.bias   [768]
transformer.h.7.ln_1.weight   [768]
transformer.h.7.ln_1.bias   [768]
transformer.h.7.attn.bias   [1, 1, 1024, 1024]
transformer.h.7.attn.masked_bias   []
transformer.h.7.attn.c_attn.weight   [768, 2304]
transformer.h.7.attn.c_attn.bias   [2304]
transformer.h.7.attn.c_proj.weight   [768, 768]
transformer.h.7.attn.c_proj.bias   [768]
transformer.h.7.ln_2.weight   [768]
transformer.h.7.ln_2.bias   [768]
transformer.h.7.mlp.c_fc.weight   [768, 3072]
transformer.h.7.mlp.c_fc.bias   [3072]
transformer.h.7.mlp.c_proj.weight   [3072, 768]
transformer.h.7.mlp.c_proj.bias   [768]
transformer.h.8.ln_1.weight   [768]
transformer.h.8.ln_1.bias   [768]
transformer.h.8.attn.bias   [1, 1, 1024, 1024]
transformer.h.8.attn.masked_bias   []
transformer.h.8.attn.c_attn.weight   [768, 2304]
transformer.h.8.attn.c_attn.bias   [2304]
transformer.h.8.attn.c_proj.weight   [768, 768]
transformer.h.8.attn.c_proj.bias   [768]
transformer.h.8.ln_2.weight   [768]
transformer.h.8.ln_2.bias   [768]
transformer.h.8.mlp.c_fc.weight   [768, 3072]
transformer.h.8.mlp.c_fc.bias   [3072]
transformer.h.8.mlp.c_proj.weight   [3072, 768]
transformer.h.8.mlp.c_proj.bias   [768]
transformer.h.9.ln_1.weight   [768]
transformer.h.9.ln_1.bias   [768]
transformer.h.9.attn.bias   [1, 1, 1024, 1024]
transformer.h.9.attn.masked_bias   []
transformer.h.9.attn.c_attn.weight   [768, 2304]
transformer.h.9.attn.c_attn.bias   [2304]
transformer.h.9.attn.c_proj.weight   [768, 768]
transformer.h.9.attn.c_proj.bias   [768]
transformer.h.9.ln_2.weight   [768]
transformer.h.9.ln_2.bias   [768]
transformer.h.9.mlp.c_fc.weight   [768, 3072]
transformer.h.9.mlp.c_fc.bias   [3072]
transformer.h.9.mlp.c_proj.weight   [3072, 768]
transformer.h.9.mlp.c_proj.bias   [768]
transformer.h.10.ln_1.weight   [768]
transformer.h.10.ln_1.bias   [768]
transformer.h.10.attn.bias   [1, 1, 1024, 1024]
transformer.h.10.attn.masked_bias   []
transformer.h.10.attn.c_attn.weight   [768, 2304]
transformer.h.10.attn.c_attn.bias   [2304]
transformer.h.10.attn.c_proj.weight   [768, 768]
transformer.h.10.attn.c_proj.bias   [768]
transformer.h.10.ln_2.weight   [768]
transformer.h.10.ln_2.bias   [768]
transformer.h.10.mlp.c_fc.weight   [768, 3072]
transformer.h.10.mlp.c_fc.bias   [3072]
transformer.h.10.mlp.c_proj.weight   [3072, 768]
transformer.h.10.mlp.c_proj.bias   [768]
transformer.h.11.ln_1.weight   [768]
transformer.h.11.ln_1.bias   [768]
transformer.h.11.attn.bias   [1, 1, 1024, 1024]
transformer.h.11.attn.masked_bias   []
transformer.h.11.attn.c_attn.weight   [768, 2304]
transformer.h.11.attn.c_attn.bias   [2304]
transformer.h.11.attn.c_proj.weight   [768, 768]
transformer.h.11.attn.c_proj.bias   [768]
transformer.h.11.ln_2.weight   [768]
transformer.h.11.ln_2.bias   [768]
transformer.h.11.mlp.c_fc.weight   [768, 3072]
transformer.h.11.mlp.c_fc.bias   [3072]
transformer.h.11.mlp.c_proj.weight   [3072, 768]
transformer.h.11.mlp.c_proj.bias   [768]
transformer.ln_f.weight   [768]
transformer.ln_f.bias   [768]
lm_head.weight   [30000, 768]
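Summing the shapes in this listing gives CPM-small's size. The sketch below rebuilds the trainable weights from the dimensions that can be read off the listing: vocabulary 30000, 1024 positions, width 768, 12 layers, and FFN width 3072.

It assumes that lm_head.weight is tied to transformer.wte.weight, which is the GPT2LMHeadModel default, so the tied matrix is counted once. The attn.bias and attn.masked_bias buffers are left out. The function names are hypothetical.

```python
import math
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def gpt2_param_shapes(vocab: int, n_pos: int, d: int, n_layer: int, d_ff: int) -> Dict[str, Shape]:
    """Rebuild the trainable-weight shapes of a GPT-2 style LM (buffers excluded)."""
    shapes: Dict[str, Shape] = {
        "transformer.wte.weight": (vocab, d),
        "transformer.wpe.weight": (n_pos, d),
    }
    for i in range(n_layer):
        p = f"transformer.h.{i}."
        shapes[p + "ln_1.weight"] = (d,)
        shapes[p + "ln_1.bias"] = (d,)
        # Conv1D stores weights as [in_features, out_features], hence [768, 2304]
        shapes[p + "attn.c_attn.weight"] = (d, 3 * d)
        shapes[p + "attn.c_attn.bias"] = (3 * d,)
        shapes[p + "attn.c_proj.weight"] = (d, d)
        shapes[p + "attn.c_proj.bias"] = (d,)
        shapes[p + "ln_2.weight"] = (d,)
        shapes[p + "ln_2.bias"] = (d,)
        shapes[p + "mlp.c_fc.weight"] = (d, d_ff)
        shapes[p + "mlp.c_fc.bias"] = (d_ff,)
        shapes[p + "mlp.c_proj.weight"] = (d_ff, d)
        shapes[p + "mlp.c_proj.bias"] = (d,)
    shapes["transformer.ln_f.weight"] = (d,)
    shapes["transformer.ln_f.bias"] = (d,)
    shapes["lm_head.weight"] = (vocab, d)
    return shapes


def count_params(shapes: Dict[str, Shape], tie_lm_head: bool = True) -> int:
    """Sum element counts; a tied lm_head shares storage with wte, so skip it."""
    return sum(
        math.prod(shape)
        for name, shape in shapes.items()
        if not (tie_lm_head and name == "lm_head.weight")
    )


if __name__ == "__main__":
    cpm_small = gpt2_param_shapes(vocab=30000, n_pos=1024, d=768, n_layer=12, d_ff=3072)
    print(count_params(cpm_small))                     # 108882432
    print(count_params(cpm_small, tie_lm_head=False))  # 131922432
```

The tied total is about 109M. Each transformer layer contributes 12·768² + 13·768 = 7,087,872 parameters, and the 30000 × 768 embedding accounts for another 23M.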

1.2 Data Processing

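CPM's vocabulary has 30,000 tokens. Chinese training data is plentiful and fairly easy to obtain: you can simply crawl it from the web. The author of the GitHub repo provides a model pre-trained on Chinese essays, and fine-tuning on top of it also works well. My own training data is roughly 70,000 to 80,000 title-text pairs.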
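Title-text pairs need to be turned into causal-LM training examples. One common way is sketched below in plain Python: concatenate the title, a separator, and the text, then mask the title positions in the labels so that the loss only covers the text.

This is an assumption for illustration. The source does not say how the yangjianxin1/CPM repo formats its data, and build_example, sep_id and eod_id are hypothetical names. The -100 label value is the default ignore_index of PyTorch's CrossEntropyLoss, which GPT2LMHeadModel uses when labels are passed. The model also shifts the labels by one position internally.

```python
from typing import List, Tuple

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def build_example(
    title_ids: List[int],
    text_ids: List[int],
    sep_id: int,  # hypothetical separator token id
    eod_id: int,  # hypothetical end-of-document token id
    max_len: int,
) -> Tuple[List[int], List[int]]:
    """Return (input_ids, labels) for one title-text pair, truncated to max_len.

    Labels stay aligned with input_ids (the model shifts them internally);
    title and separator positions are masked so only the text is learned.
    """
    prefix = title_ids + [sep_id]
    body = text_ids + [eod_id]
    input_ids = (prefix + body)[:max_len]
    labels = ([IGNORE_INDEX] * len(prefix) + body)[:max_len]
    return input_ids, labels


if __name__ == "__main__":
    ids, labels = build_example([11, 12, 13], [21, 22], sep_id=1, eod_id=2, max_len=200)
    print(ids)     # [11, 12, 13, 1, 21, 22, 2]
    print(labels)  # [-100, -100, -100, -100, 21, 22, 2]
```

Training on the full sequence without masking is also a valid choice. Masking just keeps the model from spending capacity on reproducing titles.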

1.3 Pre-training Details

Paper settings: lr = 1.5e-4, batch_size = 3072, max_len = 1024 (the maximum input length during training), steps = 2000 (the first 500 are warmup), optimizer = Adam. Training took 2 weeks on 64 V100 GPUs.
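The settings above can be turned into a quick calculation of tokens per step and a warmup schedule. The arithmetic uses the numbers as given.

The source only states that the first 500 steps are warmup. The linear decay after warmup is an assumption. It has the same shape as the transformers helper get_linear_schedule_with_warmup, which ramps linearly from 0 to the peak lr and then decays linearly to 0.

```python
BATCH_SIZE = 3072    # sequences per optimizer step
MAX_LEN = 1024       # maximum tokens per sequence
PEAK_LR = 1.5e-4
TOTAL_STEPS = 2000
WARMUP_STEPS = 500


def tokens_per_step(batch_size: int = BATCH_SIZE, max_len: int = MAX_LEN) -> int:
    """Upper bound on tokens per step: every sequence filled to max_len."""
    return batch_size * max_len


def lr_at(step: int, peak_lr: float = PEAK_LR,
          warmup: int = WARMUP_STEPS, total: int = TOTAL_STEPS) -> float:
    """Linear warmup from 0 to peak_lr, then (assumed) linear decay to 0 at `total`."""
    if step < warmup:
        return peak_lr * step / max(1, warmup)
    return peak_lr * max(0.0, (total - step) / max(1, total - warmup))


if __name__ == "__main__":
    print(tokens_per_step())                # 3145728
    print(tokens_per_step() * TOTAL_STEPS)  # 6291456000
    for s in (0, 250, 500, 1250, 2000):
        print(s, lr_at(s))
```

A batch of 3072 sequences of 1024 tokens is about 3.1M tokens per step. A large batch like this is the usual reason for a warmup phase: it keeps early Adam updates small while the moment estimates are still noisy.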

My own setup on 2× GTX 1080 Ti: CPM-small, max_len = 200, lr = 0.00015, batch_size = 16, steps = 100, optimizer = AdamW.
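The following back-of-the-envelope estimate shows why CPM-small fits on an 11 GB GTX 1080 Ti. It assumes plain fp32 training with AdamW and ignores activations, the CUDA context, and memory fragmentation.

The figure of 108,882,432 parameters is the sum of the trainable shapes in the parameter listing, with lm_head tied to wte. training_state_bytes is a hypothetical helper.

```python
CPM_SMALL_PARAMS = 108_882_432  # trainable weights in the listing, lm_head tied to wte


def training_state_bytes(n_params: int, bytes_per_value: int = 4,
                         optimizer_states: int = 2) -> int:
    """fp32 weights + fp32 gradients + AdamW's exp_avg and exp_avg_sq buffers."""
    return n_params * bytes_per_value * (2 + optimizer_states)


if __name__ == "__main__":
    b = training_state_bytes(CPM_SMALL_PARAMS)
    print(b)                       # 1742118912
    print(f"{b / 2**30:.2f} GiB")  # 1.62 GiB
```

That leaves most of each card for activations, which grow with batch_size × max_len (16 × 200 here). Under data parallelism each GPU still holds a full copy of the roughly 1.6 GiB of weights and optimizer state.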

transformers version: 4.6.0

2. The rest of the paper reports CPM's experiments on several tasks.


Reposted from blog.csdn.net/u012193416/article/details/126040727