[Li Hongyi] Deep Learning——HW5-Machine Translation

Machine Translation

1. Goal

Given a piece of English text, translate it into Traditional Chinese.

2. Introduction

2.1 Dataset

training dataset

  • TED2020: TED talks with transcriptions translated by a global community of volunteers into more than 100 languages.
  • We will use the (en, zh-tw) aligned pairs.

Monolingual data

  • More TED talks in Traditional Chinese.

2.2. Evaluation

How do we evaluate the performance of our model?
We use BLEU, which measures the n-gram overlap between the model's translation and the reference, together with a brevity penalty that penalizes short hypotheses:

BP = 1 if c > r, otherwise exp(1 - r/c)

where c is the hypothesis length and r is the reference length.
In other words, we compare the sentences translated by our model against the reference sentences; the more overlapping words (n-grams) between the two, the higher the translation score.
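As a quick illustration, here is a minimal sketch using the sacrebleu package (the same package is used for validation later in this post; the sentences are made up):

import sacrebleu

hyps = ['今天 天氣 很 好']            # model outputs (hypotheses)
refs = [['今天 天氣 真 好']]          # one list per set of references
bleu = sacrebleu.corpus_bleu(hyps, refs, tokenize='zh')
print(bleu.score)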

2.3. Workflow

(figure: overall workflow: preprocessing, training, testing)
Preprocessing

  • download raw data
  • clean and normalize
  • remove bad data(too long/short)

Training

  • initialize a model
  • train it with training data

Testing

  • generate translation of data
  • evaluate the performance

2.4. Training tips

  • Tokenize data with subword units
    • For one, it reduces the vocabulary size
    • For another, it alleviates the open-vocabulary problem
    • example: transportation => trans port ation (see the sketch after this list)
  • Label smoothing regularization
    • When calculating the loss, reserve some probability mass for incorrect labels
    • Avoids overfitting
  • Learning rate scheduling
    • Linearly increase the lr and then decay it by the inverse square root of the number of steps
    • Stabilizes training of Transformers in the early stages
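As a concrete illustration of subword tokenization, the sketch below encodes a word with a trained sentencepiece model (the model path is an assumption; it refers to the spm8000.model trained in section 3.3):

import sentencepiece as spm

# path assumed: wherever spm8000.model from section 3.3 was saved
sp = spm.SentencePieceProcessor(model_file='./data/prefix/spm8000.model')
print(sp.encode('transportation', out_type=str))
# e.g. ['▁trans', 'port', 'ation'] -- the exact split depends on the trained model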

2.5 Back-translation (BT)

Monolingual data is very easy to obtain: if you want Chinese text, you can simply crawl it from the web, but such sentences have no English counterparts. Back-translation therefore takes the monolingual Chinese data provided in the dataset, translates it into English with a reverse (zh→en) model, and pairs the synthetic English with the original Chinese to obtain an additional training set. With a larger training set, the model is likely to perform better. (Note that in the given dataset the monolingual data, test/test.zh, consists only of periods.)
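A rough sketch of the back-translation workflow (this is not part of the provided code; translate_fn and the reverse model are assumptions):

# 1. Train a reverse (zh -> en) model on the parallel TED2020 data.
# 2. Use it to translate the monolingual Chinese corpus into synthetic English.
# 3. Pair (synthetic en, real zh), then clean/binarize them like the parallel data
#    and load them together with the real data (combine=True in load_dataset).
def back_translate(reverse_model, mono_zh_sentences, translate_fn):
    """translate_fn(model, sentence) -> hypothesis string (assumed helper)."""
    synthetic_pairs = []
    for zh in mono_zh_sentences:
        en = translate_fn(reverse_model, zh)   # synthetic English source
        synthetic_pairs.append((en, zh))       # keep the real Chinese as the target
    return synthetic_pairs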

3. Code

3.1 Data preprocessing

Dataset
Bilingual parallel corpus (TED2020):
Raw: 398,066 sentence pairs
Processed: 393,980 sentence pairs
Test data:
Size: 4,000 sentences
The Chinese translations are not released; each line is a placeholder "."

  1. Processing steps
  • Download and extract the archives
  • Rename the files
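(The notebook's import cell is not shown in this post; the snippets below assume roughly the following imports, reconstructed here so that the code is self-contained.)

import sys
import os
import re
import random
import logging
import pprint
from pathlib import Path
from argparse import Namespace

import numpy as np
import matplotlib.pyplot as plt
import tqdm.auto as tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast

from fairseq import utils
from fairseq.data import iterators
from fairseq.models import (
    FairseqEncoder,
    FairseqIncrementalDecoder,
    FairseqEncoderDecoderModel,
)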
# Download the archives and extract them
data_dir = './DATA/rawdata'
dataset_name = 'ted2020'
urls = (
    '"https://onedrive.live.com/download?cid=3E549F3B24B238B4&resid=3E549F3B24B238B4%214989&authkey=AGgQ-DaR8eFSl1A"', 
    '"https://onedrive.live.com/download?cid=3E549F3B24B238B4&resid=3E549F3B24B238B4%214987&authkey=AA4qP_azsicwZZM"',
# # If the above links die, use the following instead. 
#     "https://www.csie.ntu.edu.tw/~r09922057/ML2021-hw5/ted2020.tgz",
#     "https://www.csie.ntu.edu.tw/~r09922057/ML2021-hw5/test.tgz",
# # If the above links die, use the following instead. 
#     "https://mega.nz/#!vEcTCISJ!3Rw0eHTZWPpdHBTbQEqBDikDEdFPr7fI8WxaXK9yZ9U",
#     "https://mega.nz/#!zNcnGIoJ!oPJX9AvVVs11jc0SaK6vxP_lFUNTkEcK2WbxJpvjU5Y",
)
file_names = (
    'ted2020.tgz', # train & dev
    'test.tgz', # test
)
prefix = Path(data_dir).absolute() / dataset_name

prefix.mkdir(parents=True, exist_ok=True)
for u, f in zip(urls, file_names):
    path = prefix/f
    if not path.exists():
        if 'mega' in u:
            !megadl {u} --path {path}
        else:
            !wget {u} -O {path}
    if path.suffix == ".tgz":
        !tar -xvf {path} -C {prefix}
    elif path.suffix == ".zip":
        !unzip -o {path} -d {prefix}

# Rename the files, adding the train_dev/test prefix
!mv {prefix/'raw.en'} {prefix/'train_dev.raw.en'}
!mv {prefix/'raw.zh'} {prefix/'train_dev.raw.zh'}
!mv {prefix/'test.en'} {prefix/'test.raw.en'}
!mv {prefix/'test.zh'} {prefix/'test.raw.zh'}

# Set source and target languages
src_lang = 'en'
tgt_lang = 'zh'

data_prefix = f'{prefix}/train_dev.raw'
test_prefix = f'{prefix}/test.raw'

!head {data_prefix+'.'+src_lang} -n 5
!head {data_prefix+'.'+tgt_lang} -n 5

Thank you so much, Chris.
And it's truly a great honor to have the opportunity to come to this stage twice; I'm extremely grateful.
I have been blown away by this conference, and I want to thank all of you for the many nice comments about what I had to say the other night.
And I say that sincerely, partly because I need that.
Put yourselves in my position.
Thank you very much, Chris. It is a great honor to have this opportunity to step onto this podium for the second time
. I am grateful.
I was extremely impressed by this seminar and I would like to thank everyone for their favorable comments on my previous presentations.
I sincerely want to say this, partly because - I really need it!
Please put yourself in my shoes!

  2. Processing steps
  • Convert strings from full-width to half-width characters
  • Separate special characters from the surrounding content with spaces
  • Remove or replace some special characters
# Remove or replace some special characters
def clean_s(s, lang):
    if lang == 'en':
        s = re.sub(r"\([^()]*\)", "", s)  # remove ([text])
        s = s.replace('-', '')  # remove '-'
        s = re.sub('([.,;!?()\"])', r' \1 ', s)  # keep punctuation
    elif lang == 'zh':
        s = strQ2B(s)  # full-width to half-width (strQ2B is a helper defined in the original notebook)
        s = re.sub(r"\([^()]*\)", "", s)  # remove ([text])
        s = s.replace(' ', '')
        s = s.replace('—', '')
        s = s.replace('“', '"')
        s = s.replace('”', '"')
        s = s.replace('_', '')
        s = re.sub('([。,;!?()\"~「」])', r' \1 ', s)  # keep punctuation
    s = ' '.join(s.strip().split()) # collapse whitespace to single spaces
    return s


def len_s(s, lang):
    if lang == 'zh':
        return len(s)
    return len(s.split())


# After cleaning, the file names are: train_dev.raw.clean.en, train_dev.raw.clean.zh, test.raw.clean.en, test.raw.clean.zh
def clean_corpus(prefix, l1, l2, ratio=9, max_len=1000, min_len=1):
    if Path(f'{prefix}.clean.{l1}').exists() and Path(f'{prefix}.clean.{l2}').exists():
        print(f'{prefix}.clean.{l1} & {l2} exists. skipping clean.')
        return
    with open(f'{prefix}.{l1}', 'r', encoding='utf-8') as l1_in_f:
        with open(f'{prefix}.{l2}', 'r', encoding='utf-8') as l2_in_f:
            with open(f'{prefix}.clean.{l1}', 'w', encoding='utf-8') as l1_out_f:
                with open(f'{prefix}.clean.{l2}', 'w', encoding='utf-8') as l2_out_f:
                    for s1 in l1_in_f:
                        s1 = s1.strip()
                        s2 = l2_in_f.readline().strip()
                        s1 = clean_s(s1, l1)
                        s2 = clean_s(s2, l2)
                        s1_len = len_s(s1, l1)
                        s2_len = len_s(s2, l2)
                        if min_len > 0:  # remove short sentence
                            if s1_len < min_len or s2_len < min_len:
                                continue
                        if max_len > 0:  # remove long sentence
                            if s1_len > max_len or s2_len > max_len:
                                continue
                        if ratio > 0:  # remove by ratio of length
                            if s1_len / s2_len > ratio or s2_len / s1_len > ratio:
                                continue
                        print(s1, file=l1_out_f)
                        print(s2, file=l2_out_f)


clean_corpus(data_prefix, src_lang, tgt_lang)
clean_corpus(test_prefix, src_lang, tgt_lang, ratio=-1, min_len=-1, max_len=-1)

Thank you so much, Chris .
And it's truly a great honor to have the opportunity to come to this stage twice; I'm extremely grateful .
I have been blown away by this conference , and I want to thank all of you for the many nice comments about what I had to say the other night .
And I say that sincerely , partly because I need that .
Put yourselves in my position .
Thank you very much, Chris. It is really a great honor to have this opportunity to step onto this podium for the second time
. I am grateful .
This seminar left a very deep impression on me, and I would like to thank everyone for their favorable comments on my previous lectures.
I really want to say this, partly because I really need it!
Please put yourself in my shoes!

3.2 Split into training and validation sets

The validation set only needs 3000~4000 sentences. Split the cleaned training data from above into the files train.clean.en/.zh and valid.clean.en/.zh.

# Split into training and validation sets
# A few thousand sentences are enough for validation
valid_ratio = 0.01
train_ratio = 1 - valid_ratio
data_dir = './data'
dataset_name = 'prefix'

# The resulting files are train.clean.en, train.clean.zh, valid.clean.en, valid.clean.zh
if Path(f'{prefix}/train.clean.{src_lang}').exists() \
and Path(f'{prefix}/train.clean.{tgt_lang}').exists() \
and Path(f'{prefix}/valid.clean.{src_lang}').exists() \
and Path(f'{prefix}/valid.clean.{tgt_lang}').exists():
    print(f'train/valid splits exists. skipping split.')
else:
    line_num = sum(1 for line in open(f'{data_prefix}.clean.{src_lang}', encoding='utf-8'))
    labels = list(range(line_num))
    random.shuffle(labels)
    for lang in [src_lang, tgt_lang]:
        train_f = open(os.path.join(data_dir, dataset_name, f'train.clean.{lang}'), 'w', encoding='utf-8')
        valid_f = open(os.path.join(data_dir, dataset_name, f'valid.clean.{lang}'), 'w', encoding='utf-8')
        count = 0
        for line in open(f'{data_prefix}.clean.{lang}', 'r', encoding='utf-8'):
            if labels[count]/line_num < train_ratio:
                train_f.write(line)
            else:
                valid_f.write(line)
            count += 1
        train_f.close()
        valid_f.close()

(figure: the resulting train.clean / valid.clean files after the split)

3.3 Subword Units

A big problem in translation is out-of-vocabulary words, which can be mitigated by using subword units.

  • Use the sentencepiece package
  • with unigram or byte-pair encoding (BPE)
# Subword Units
# Tokenization
# Use sentencepiece (spm) to train a subword model on the training and validation sets; this produces the model spm8000.model and the vocabulary spm8000.vocab
# Then use the model to tokenize the training, validation and test sets, producing train.en, train.zh, valid.en, valid.zh, test.en, test.zh
import sentencepiece as spm
vocab_size = 8000
if Path(f'{prefix}/spm{vocab_size}.model').exists():
    print(f'{prefix}/spm{vocab_size}.model exists. skipping spm_train')
else:
    spm.SentencePieceTrainer.train(
        input=','.join([f'{prefix}/train.clean.{src_lang}',
                        f'{prefix}/valid.clean.{src_lang}',
                        f'{prefix}/train.clean.{tgt_lang}',
                        f'{prefix}/valid.clean.{tgt_lang}']),
        model_prefix=f'{prefix}/spm{vocab_size}',
        vocab_size=vocab_size,
        character_coverage=1,
        model_type='unigram', # 'bpe' also works
        input_sentence_size=1e6,
        shuffle_input_sentence=True,
        normalization_rule_name='nmt_nfkc_cf',
    )
spm_model = spm.SentencePieceProcessor(model_file=str(f'{prefix}/spm{vocab_size}.model'))
in_tag = {
    'train': 'train.clean',
    'valid': 'valid.clean',
    'test': 'test.raw.clean',
}
for split in ['train', 'valid', 'test']:
    for lang in [src_lang, tgt_lang]:
        out_path = Path(f'{prefix}/{split}.{lang}')
        if out_path.exists():
            print(f"{out_path} exists. skipping spm_encode.")
        else:
            with open(f'{prefix}/{split}.{lang}', 'w', encoding='utf-8') as out_f:
                with open(f'{prefix}/{in_tag[split]}.{lang}', 'r', encoding='utf-8') as in_f:
                    for line in in_f:
                        line = line.strip()
                        tok = spm_model.encode(line, out_type=str)
                        print(' '.join(tok), file=out_f)

(figure: a sample of the vocabulary spm8000.vocab produced by the subword model)
The content of train.en and the corresponding train.zh after subword tokenization:

thank you so much , chris .
and it ' s t ru ly a great ho n or to have the op port un ity to come to this st age t wi ce ; i ' m ex t re me ly gr ate ful .
i have been bl ow n away by this con fer ence , and i want to thank all of you for the many ni ce com ment s about what i had to say the other night .
and i say that since re ly , part ly because i need that .
put your s el ve s in my po sition .
Thank you very much, g Rees . It is really a great honor to have this opportunity to step onto this podium for the second time. I'm very grateful. This seminar left a very deep impression on me. I would like to thank everyone for their favorable comments on my previous speeches. I sincerely want to say this, partly because I really need it! Please put yourself in my shoes. Think!



3.4 Use fairseq to convert data into binary

The following cell can be run in Jupyter; the same fairseq_cli.preprocess command can also be invoked directly from a shell.

# Use fairseq to binarize the data; the resulting files go under ./data/data_bin
binpath = Path('./data/data_bin')
if binpath.exists():
    print(binpath, "exists, will not overwrite!")
else:
    !python -m fairseq_cli.preprocess \
        --source-lang en\
        --target-lang zh\
        --trainpref ./data/prefix/train\
        --validpref ./data/prefix/valid\
        --testpref ./data/prefix/test\
        --destdir ./data/data_bin\
        --joined-dictionary\
        --workers 2

A series of files are generated in the data_bin directory

4. Experimental preparation

4.1 Experimental parameter setting

config = Namespace(
    datadir = "./data/data_bin",
    savedir = "./checkpoints/rnn",
    source_lang = "en",
    target_lang = "zh",
    
    # cpu threads when fetching & processing data.
    num_workers=2,  
    # batch size in terms of tokens. gradient accumulation increases the effective batchsize.
    max_tokens=8192,
    accum_steps=2,
    
    # the lr is calculated by the Noam lr scheduler; you can tune the maximum lr via this factor.
    lr_factor=2.,
    lr_warmup=4000,
    
    # clipping gradient norm helps alleviate gradient exploding
    clip_norm=1.0,
    
    # maximum epochs for training
    max_epoch=30,
    start_epoch=1,
    
    # beam size for beam search
    beam=5, 
    # generate sequences of maximum length ax + b, where x is the source length
    max_len_a=1.2, 
    max_len_b=10,
    # when decoding, post process sentence by removing sentencepiece symbols.
    post_process = "sentencepiece",
    
    # checkpoints
    keep_last_epochs=5,
    resume=None, # if resume from checkpoint name (under config.savedir)
    
    # logging
    use_wandb=False,
)

4.2 logging

# The logging package records general messages; wandb records the training loss, BLEU, model weights, etc.
logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level="INFO", # "DEBUG" "WARNING" "ERROR"
    stream=sys.stdout,
)
proj = "hw5.seq2seq"
logger = logging.getLogger(proj)
if config.use_wandb:
    import wandb
    wandb.init(project=proj, name=Path(config.savedir).stem, config=config)

4.3 cuda environment

cuda_env = utils.CudaEnvironment()
utils.CudaEnvironment.pretty_print_cuda_env_list([cuda_env])
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

4.4 Loading the dataset

We borrow fairseq's TranslationTask:

  • It loads the binarized data created above
  • It provides a well-implemented data iterator (dataloader)
  • Its dictionaries task.source_dictionary and task.target_dictionary are also useful
  • It comes with a usable beam-search implementation
from fairseq.tasks.translation import TranslationConfig, TranslationTask

## setup task
task_cfg = TranslationConfig(
    data=config.datadir,
    source_lang=config.source_lang,
    target_lang=config.target_lang,
    train_subset="train",
    required_seq_len_multiple=8,
    dataset_impl="mmap",
    upsample_primary=1,
)
task = TranslationTask.setup_task(task_cfg)
logger.info("loading data for epoch 1")
task.load_dataset(split="train", epoch=1, combine=True) # combine if you have back-translation data.
task.load_dataset(split="valid", epoch=1)

sample = task.dataset("valid")[1]
pprint.pprint(sample)
pprint.pprint(
    "Source: " + \
    task.source_dictionary.string(
        sample['source'],
        config.post_process,
    )
)
pprint.pprint(
    "Target: " + \
    task.target_dictionary.string(
        sample['target'],
        config.post_process,
    )
)

Output:

{'id': 1,
'source': tensor([ 18, 14, 6, 2234, 60, 19, 80, 5, 256, 16, 405, 1407,
1706, 7, 2]),
'target': tensor([ 140, 690, 28, 270, 45, 151, 1142, 660 , 606, 369, 3114, 2434,
1434, 192, 2])}
"Source: that's exactly what i do optical mind control ."
'Target: This is actually what I do – optical manipulation ideas'

4.5 Dataset iterator

  • Controls each batch to contain no more than N tokens, so that GPU memory is used efficiently
  • Shuffles the training set differently every epoch
  • Filters out sentences that are too long
  • Pads all sentences in a batch to the same length, so that the GPU can process them in parallel
  • Adds eos and shifts the target by one position
    • teacher forcing: to train the model to generate the next word from the prefix, the decoder input is the output target sequence shifted one position to the right.
    • Usually a bos token would be prepended to the input; fairseq instead moves the eos to the beginning, and the training result is almost the same. For example:

eos = 2
target = 419, 711, 238, 888, 792, 60, 968, 8, 2
prev_output_tokens = 2, 419, 711, 238, 888, 792, 60, 968, 8
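The shift can be reproduced with a couple of lines of PyTorch (a toy sketch, not part of the fairseq code):

import torch

eos = 2
target = torch.tensor([419, 711, 238, 888, 792, 60, 968, 8, eos])
prev_output_tokens = torch.roll(target, shifts=1)  # moves the trailing eos to the front
print(prev_output_tokens.tolist())  # [2, 419, 711, 238, 888, 792, 60, 968, 8]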

def load_data_iterator(task, split, epoch=1, max_tokens=4000, num_workers=1, cached=True):
    batch_iterator = task.get_batch_iterator(
        dataset=task.dataset(split),
        max_tokens=max_tokens,
        max_sentences=None,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            max_tokens,
        ),
        ignore_invalid_inputs=True,
        seed=seed,  # global random seed (defined earlier in the original notebook)
        num_workers=num_workers,
        epoch=epoch,
        disable_iterator_cache=not cached,
        # Set this to False to speed up. However, if set to False, changing max_tokens beyond
        # first call of this method has no effect.
    )
    return batch_iterator

if __name__=='__main__':
    demo_epoch_obj = load_data_iterator(task, "valid", epoch=1, max_tokens=20, num_workers=1, cached=False)
    demo_iter = demo_epoch_obj.next_epoch_itr(shuffle=True)
    sample = next(demo_iter)
    print(sample)

The output and the meaning of each field:

{'id': tensor([723]),   # id of each example
 'nsentences': 1,       # number of sentences in the batch (batch size)
 'ntokens': 18,         # number of target tokens in the batch
 'net_input': {
     # the source-language sequence, left-padded with the pad index 1
     'src_tokens': tensor([[   1,    1,    1,    1,    1,   18,   26,   82,    8,  480,   15,  651,
                             1361,   38,    6,  176, 2696,   39,    5,  822,   92,  260,    7,    2]]),
     # the length of each sentence without padding
     'src_lengths': tensor([19]),
     # the shifted target sequence described above (decoder input)
     'prev_output_tokens': tensor([[   2,  140,  296,  318, 1560,   51,  568,  316,  225, 1952,  254,   78,
                                      151, 2691,    9,  215, 1680,   10,    1,    1,    1,    1,    1,    1]]),
 },
 # the target sequence
 'target': tensor([[ 140,  296,  318, 1560,   51,  568,  316,  225, 1952,  254,   78,  151,
                    2691,    9,  215, 1680,   10,    2,    1,    1,    1,    1,    1,    1]]),
}







5. Define model architecture

  • We inherit fairseq's Encoder, Decoder and Model classes, so that fairseq's beam-search utilities can be used directly at test time

5.1 Encoder

The encoder of a seq2seq model is an RNN or Transformer encoder; the explanation below uses an RNN as an example.
For every input token, the Encoder produces an output vector and a hidden state, and carries the hidden state over to the next timestep. In other words, the Encoder reads the input sequence step by step, outputting a vector at each timestep and the hidden state (content vector) of the final timestep.

Let's look at the GRU used in this experiment. Its inputs and outputs (following the PyTorch nn.GRU documentation) are:

Inputs: input, h_0
  • input has shape (seq_len, batch, input_size): a tensor containing the features of the input sequence. The input can also be a packed variable-length sequence; see torch.nn.utils.rnn.pack_padded_sequence for details.
  • h_0 has shape (num_layers * num_directions, batch, hidden_size): a tensor containing the initial hidden state for each element in the batch. It defaults to zeros if not provided. If the RNN is bidirectional, num_directions is 2, otherwise 1.

Outputs: output, h_n
  • output has shape (seq_len, batch, num_directions * hidden_size): the output features h_t from the last layer of the GRU, for each timestep t. If a torch.nn.utils.rnn.PackedSequence was given as the input, the output will also be a packed sequence. For the unpacked case, the directions can be separated using output.view(seq_len, batch, num_directions, hidden_size), with forward and backward being directions 0 and 1 respectively.
  • h_n has shape (num_layers * num_directions, batch, hidden_size): the hidden state for the last timestep. The layers can be separated using h_n.view(num_layers, num_directions, batch, hidden_size). When bidirectional=True, h_n contains the concatenation of the final forward and backward hidden states.
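A quick shape check of a bidirectional nn.GRU (a toy sketch with made-up dimensions, matching the shapes described above):

import torch
import torch.nn as nn

seq_len, batch, input_size, hidden_size, num_layers = 7, 3, 256, 512, 1
gru = nn.GRU(input_size, hidden_size, num_layers, batch_first=False, bidirectional=True)

x = torch.randn(seq_len, batch, input_size)
h0 = torch.zeros(2 * num_layers, batch, hidden_size)  # num_directions = 2
output, h_n = gru(x, h0)
print(output.shape)  # (seq_len, batch, num_directions * hidden_size) = (7, 3, 1024)
print(h_n.shape)     # (num_layers * num_directions, batch, hidden_size) = (2, 3, 512)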

# Define the model architecture
# We use fairseq's Encoder, Decoder and Model base classes
class RNNEncoder(FairseqEncoder):
    def __init__(self, args, dictionary, embed_tokens):
        '''
        :param args:
            encoder_embed_dim: the embedding dimension, compressing the one-hot word vectors to this size
            encoder_ffn_embed_dim: the dimension of the RNN output and hidden state (hidden dimension)
            encoder_layers: how many RNN layers to stack
            dropout: the probability of zeroing out a unit, mainly to prevent overfitting; generally applied only during training
        :param dictionary: the dictionary fairseq built for us; used here to get the padding index and, from it, the encoder padding mask
        :param embed_tokens: the word embedding built beforehand (nn.Embedding)
        '''
        super().__init__(dictionary)
        self.embed_tokens = embed_tokens
        self.embed_dim = args.encoder_embed_dim
        self.hidden_dim = args.encoder_ffn_embed_dim
        self.num_layers = args.encoder_layers

        self.dropout_in_module = nn.Dropout(args.dropout)
        self.rnn = nn.GRU(
            self.embed_dim,
            self.hidden_dim,
            self.num_layers,
            dropout=args.dropout,
            batch_first=False,
            bidirectional=True,
        )
        self.dropout_out_module = nn.Dropout(args.dropout)

        self.padding_idx = dictionary.pad()

    def combine_bidir(self, outs, bsz:int):
        out = outs.view(self.num_layers, 2, bsz, -1).transpose(1, 2).contiguous()
        return out.view(self.num_layers, bsz, -1)

    def forward(self, src_tokens, **unused):
        '''
        :param src_tokens: the English token-id sequence
        :param unused:
        :return:
            outputs: the output of the top RNN layer at every timestep, which can later be processed with attention
            final_hiddens: the final hidden state of every layer, passed to the Decoder for decoding
            encoder_padding_mask: tells us which positions are padding and carry no information
        '''
        bsz, seqlen = src_tokens.size()

        # get embeddings
        x = self.embed_tokens(src_tokens)
        x = self.dropout_in_module(x)

        # B x T x C => T x B x C
        x = x.transpose(0, 1)

        # run through the bidirectional RNN
        h0 = x.new_zeros(2 * self.num_layers, bsz, self.hidden_dim)
        x, final_hiddens = self.rnn(x, h0)
        outputs = self.dropout_out_module(x)
        # outputs = [sequence len, batch size, hid dim * directions]: the output of the top RNN layer
        # hidden =  [num_layers * directions, batch size  , hid dim]

        # the Encoder is bidirectional, so we concatenate the hidden states of the two directions
        final_hiddens = self.combine_bidir(final_hiddens, bsz)
        # hidden =  [num_layers x batch x num_directions*hidden]

        encoder_padding_mask = src_tokens.eq(self.padding_idx).t()
        return tuple(
            (
                outputs,  # seq_len x batch x hidden
                final_hiddens,  # num_layers x batch x num_directions*hidden
                encoder_padding_mask,  # seq_len x batch
            )
        )

    def reorder_encoder_out(self, encoder_out, new_order):
        return tuple(
            (
                encoder_out[0].index_select(1, new_order),
                encoder_out[1].index_select(1, new_order),
                encoder_out[2].index_select(1, new_order),
            )
        )

5.2 Attention

  • When the input is too long, or the meaning of the whole input cannot be captured by the content vector alone, the attention mechanism provides the Decoder with more information
  • Based on the current Decoder embeddings, it computes how strongly each Encoder output is related to them, and uses these relation scores to take a weighted average of the Encoder outputs as input to the Decoder RNN
  • A common formulation uses a neural network or a dot product to compute the relation between the query (decoder embeddings) and the keys (Encoder outputs), applies a softmax over all the scores to get a distribution, and finally takes a weighted sum of the values (Encoder outputs) according to that distribution; a minimal sketch of the dot-product form follows
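A minimal numerical sketch of dot-product attention (toy shapes, independent of the AttentionLayer class below):

import torch
import torch.nn.functional as F

B, T, S, dim = 2, 3, 5, 8                 # batch, decoder steps, encoder steps, feature dim
query = torch.randn(B, T, dim)            # decoder embeddings
keys = values = torch.randn(B, S, dim)    # encoder outputs

scores = torch.bmm(query, keys.transpose(1, 2))  # (B, T, S) relation scores
weights = F.softmax(scores, dim=-1)              # distribution over source positions
context = torch.bmm(weights, values)             # (B, T, dim) weighted sum of encoder outputs
print(context.shape)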
# Attention
class AttentionLayer(nn.Module):
    def __init__(self, input_embed_dim, source_embed_dim, output_embed_dim, bias=False):
        '''
        :param input_embed_dim: dimension of the query, i.e. the decoder vectors doing the attending
        :param source_embed_dim: dimension of the keys/values, i.e. the vectors being attended to (encoder outputs)
        :param output_embed_dim: dimension of the vector after attention, i.e. what the next layer expects
        :param bias:
        '''
        super().__init__()
        self.input_proj = nn.Linear(input_embed_dim, source_embed_dim, bias=bias)
        self.output_proj = nn.Linear(
            input_embed_dim + source_embed_dim, output_embed_dim, bias=bias
        )

    def forward(self, inputs, encoder_outputs, encoder_padding_mask):
        '''
        :param inputs: the query, i.e. the vectors that attend to the encoder
        :param encoder_outputs: the keys/values, i.e. the vectors being attended to
        :param encoder_padding_mask: tells us which positions are padding and carry no information
        :return:
            output: the context vector after attention
            attention score: the attention distribution
        '''
        # inputs: T, B, dim
        # encoder_outputs: S x B x dim
        # padding mask: S x B

        # convert all to batch first
        inputs = inputs.transpose(1, 0) # B, T, dim
        encoder_outputs = encoder_outputs.transpose(1, 0) # B, S, dim
        encoder_padding_mask = encoder_padding_mask.transpose(1, 0) # B, S

        # project the query to the dimension of encoder_outputs
        x = self.input_proj(inputs)

        # compute the attention scores
        # (B, T, dim) x (B, dim, S) = (B, T, S)
        attn_scores = torch.bmm(x, encoder_outputs.transpose(1, 2))

        # mask out the attention at padding positions
        if encoder_padding_mask is not None:
            # use broadcasting: B, S -> (B, 1, S)
            encoder_padding_mask = encoder_padding_mask.unsqueeze(1)
            attn_scores = (
                attn_scores.float()
                .masked_fill_(encoder_padding_mask, float("-inf"))  # set padded positions to -inf
                .type_as(attn_scores)  # convert back to the original dtype
            )

        # softmax over the source dimension
        attn_scores = F.softmax(attn_scores, dim=-1)

        # (B, T, S) x (B, S, dim) = (B, T, dim): weighted average
        x = torch.bmm(attn_scores, encoder_outputs)

        # (B, T, dim)
        x = torch.cat((x, inputs), dim=-1)
        x = torch.tanh(self.output_proj(x)) # concat + linear + tanh

        # restore the shape (B, T, dim) -> (T, B, dim)
        return x.transpose(1, 0), attn_scores

5.3 Decoder

  • The decoder's hidden states are initialized with the final hidden states of the encoder
  • The decoder also updates its hidden states based on the input at the current timestep (i.e. the output of the previous timestep) and produces an output
  • Adding attention improves performance further
  • We write the seq2seq steps inside the decoder, so that the Seq2Seq class below can be used with both RNN and Transformer models without being rewritten
# Decoder
class RNNDecoder(FairseqIncrementalDecoder):
    def __init__(self, args, dictionary, embed_tokens):
        super().__init__(dictionary)
        self.embed_tokens = embed_tokens

        assert args.decoder_layers == args.encoder_layers, f"""seq2seq rnn requires that encoder 
                and decoder have same layers of rnn. got: {args.encoder_layers, args.decoder_layers}"""
        assert args.decoder_ffn_embed_dim == args.encoder_ffn_embed_dim * 2, f"""seq2seq-rnn requires 
                that decoder hidden to be 2*encoder hidden dim. got: {args.decoder_ffn_embed_dim, args.encoder_ffn_embed_dim * 2}"""

        self.embed_dim = args.decoder_embed_dim
        self.hidden_dim = args.decoder_ffn_embed_dim
        self.num_layers = args.decoder_layers

        self.dropout_in_module = nn.Dropout(args.dropout)
        self.rnn = nn.GRU(
            self.embed_dim,
            self.hidden_dim,
            self.num_layers,
            dropout=args.dropout,
            batch_first=False,
            bidirectional=False,
        )
        self.attention = AttentionLayer(
            self.embed_dim, self.hidden_dim, self.embed_dim, bias=False
        )
        # self.attention = None
        self.dropout_out_module = nn.Dropout(args.dropout)

        if self.hidden_dim != self.embed_dim:
            self.project_out_dim = nn.Linear(self.hidden_dim, self.embed_dim)
        else:
            self.project_out_dim = None

        if args.share_decoder_input_output_embed:
            self.output_projection = nn.Linear(
                self.embed_tokens.weight.shape[1],
                self.embed_tokens.weight.shape[0],
                bias=False,
            )
            self.output_projection.weight = self.embed_tokens.weight
        else:
            self.output_projection = nn.Linear(
                self.output_embed_dim, len(dictionary), bias=False
            )
            nn.init.normal_(
                self.output_projection.weight, mean=0, std=self.output_embed_dim ** -0.5
            )

    def forward(self, prev_output_tokens, encoder_out, incremental_state=None, **unused):
        # unpack the encoder output
        encoder_outputs, encoder_hiddens, encoder_padding_mask = encoder_out
        # outputs:          seq_len x batch x num_directions*hidden
        # encoder_hiddens:  num_layers x batch x num_directions*encoder_hidden
        # padding_mask:     seq_len x batch

        if incremental_state is not None and len(incremental_state) > 0:
            # if the information from the previous timestep was kept, we can continue from there
            # instead of starting from bos
            prev_output_tokens = prev_output_tokens[:, -1:]
            cache_state = self.get_incremental_state(incremental_state, "cached_state")
            prev_hiddens = cache_state["prev_hiddens"]
        else:
            # no incremental state means this is training, or the first step at test time
            # prepare seq2seq: pass the encoder_hiddens into the decoder's hidden states
            prev_hiddens = encoder_hiddens

        bsz, seqlen = prev_output_tokens.size()

        # embed tokens
        x = self.embed_tokens(prev_output_tokens)
        x = self.dropout_in_module(x)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # decoder-to-encoder attention
        if self.attention is not None:
            x, attn = self.attention(x, encoder_outputs, encoder_padding_mask)

        # run through the unidirectional RNN
        x, final_hiddens = self.rnn(x, prev_hiddens)
        # outputs = [sequence len, batch size, hid dim]
        # hidden =  [num_layers * directions, batch size  , hid dim]
        x = self.dropout_out_module(x)

        # project back to the embedding size (if hidden and embed size differ and share_embedding is enabled, an extra projection is needed)
        if self.project_out_dim != None:
            x = self.project_out_dim(x)

        # project to the distribution over the vocabulary
        x = self.output_projection(x)

        # T x B x C -> B x T x C
        x = x.transpose(1, 0)

        # if incremental, save this timestep's hidden states to be read back at the next timestep
        cache_state = {
            "prev_hiddens": final_hiddens,
        }
        self.set_incremental_state(incremental_state, "cached_state", cache_state)

        return x, None

    def reorder_incremental_state(
            self,
            incremental_state,
            new_order,
    ):
        # used during beam search: reorder the cached hidden states to match the new beam order
        cache_state = self.get_incremental_state(incremental_state, "cached_state")
        prev_hiddens = cache_state["prev_hiddens"]
        prev_hiddens = [p.index_select(0, new_order) for p in prev_hiddens]
        cache_state = {
            "prev_hiddens": torch.stack(prev_hiddens),
        }
        self.set_incremental_state(incremental_state, "cached_state", cache_state)
        return

5.4 Seq2Seq

  • Composed of an Encoder and a Decoder
  • Receives the input and passes it to the Encoder
  • Passes the Encoder's output to the Decoder
  • The Decoder decodes based on the output of the previous timesteps and the Encoder output
  • When decoding is finished, returns the Decoder output
# Seq2Seq
class Seq2Seq(FairseqEncoderDecoderModel):
    def __init__(self, args, encoder, decoder):
        super().__init__(encoder, decoder)
        self.args = args

    def forward(self, src_tokens, src_lengths, prev_output_tokens, return_all_hiddens: bool = True):
        encoder_out = self.encoder(
            src_tokens, src_lengths=src_lengths, return_all_hiddens=return_all_hiddens
        )
        logits, extra = self.decoder(
            prev_output_tokens,
            encoder_out=encoder_out,
            src_lengths=src_lengths,
            return_all_hiddens=return_all_hiddens,
        )
        return logits, extra

5.5 Model initialization

# Build the model
def build_model(args, task):
    src_dict, tgt_dict = task.source_dictionary, task.target_dictionary

    # word embeddings
    encoder_embed_tokens = nn.Embedding(len(src_dict), args.encoder_embed_dim, src_dict.pad())
    decoder_embed_tokens = nn.Embedding(len(tgt_dict), args.decoder_embed_dim, tgt_dict.pad())

    # encoder and decoder
    encoder = RNNEncoder(args, src_dict, encoder_embed_tokens)
    decoder = RNNDecoder(args, tgt_dict, decoder_embed_tokens)

    # sequence-to-sequence model
    model = Seq2Seq(args, encoder, decoder)

    # initialization matters a lot for seq2seq models and needs special handling
    def init_params(module):
        from fairseq.modules import MultiheadAttention
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        if isinstance(module, MultiheadAttention):
            module.q_proj.weight.data.normal_(mean=0.0, std=0.02)
            module.k_proj.weight.data.normal_(mean=0.0, std=0.02)
            module.v_proj.weight.data.normal_(mean=0.0, std=0.02)
        if isinstance(module, nn.RNNBase):
            for name, param in module.named_parameters():
                if "weight" in name or "bias" in name:
                    param.data.uniform_(-0.1, 0.1)

    # initialize the model
    model.apply(init_params)
    return model

5.6 Set model related parameters

arch_args = Namespace(
    encoder_embed_dim=256,
    encoder_ffn_embed_dim=512,
    encoder_layers=1,
    decoder_embed_dim=256,
    decoder_ffn_embed_dim=1024,
    decoder_layers=1,
    share_decoder_input_output_embed=True,
    dropout=0.3,
)

model = build_model(arch_args, task)
logger.info(model)

Seq2Seq(
(encoder): RNNEncoder(
(embed_tokens): Embedding(8000, 256, padding_idx=1)
(dropout_in_module): Dropout(p=0.3, inplace=False)
(rnn): GRU(256, 512, dropout=0.3, bidirectional=True)
(dropout_out_module): Dropout(p=0.3, inplace=False)
)
(decoder): RNNDecoder(
(embed_tokens): Embedding(8000, 256, padding_idx=1)
(dropout_in_module): Dropout(p=0.3, inplace=False)
(rnn): GRU(256, 1024, dropout=0.3)
(attention): AttentionLayer(
(input_proj): Linear(in_features=256, out_features=1024, bias=False)
(output_proj): Linear(in_features=1280, out_features=256, bias=False)
)
(dropout_out_module): Dropout(p=0.3, inplace=False)
(project_out_dim): Linear(in_features=1024, out_features=256, bias=True)
(output_projection): Linear(in_features=256, out_features=8000, bias=False)
)
)

5.7 Optimization

Loss: Label Smoothing Regularization

  • Lets the model learn to output a less concentrated distribution, preventing it from becoming over-confident
  • Sometimes the ground truth is not the only acceptable answer, so when computing the loss we reserve part of the probability mass for labels other than the correct one
  • Effectively prevents overfitting
class LabelSmoothedCrossEntropyCriterion(nn.Module):
    def __init__(self, smoothing, ignore_index=None, reduce=True):
        super().__init__()
        self.smoothing = smoothing
        self.ignore_index = ignore_index
        self.reduce = reduce
    
    def forward(self, lprobs, target):
        if target.dim() == lprobs.dim() - 1:
            target = target.unsqueeze(-1)
        # nll: negative log likelihood, i.e. the cross-entropy loss when the target is one-hot (same as F.nll_loss below)
        nll_loss = -lprobs.gather(dim=-1, index=target)
        # reserve part of the correct label's probability for the other labels, so when computing
        # the cross-entropy we sum the log probs over all labels
        smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
        if self.ignore_index is not None:
            pad_mask = target.eq(self.ignore_index)
            nll_loss.masked_fill_(pad_mask, 0.0)
            smooth_loss.masked_fill_(pad_mask, 0.0)
        else:
            nll_loss = nll_loss.squeeze(-1)
            smooth_loss = smooth_loss.squeeze(-1)
        if self.reduce:
            nll_loss = nll_loss.sum()
            smooth_loss = smooth_loss.sum()
        # when computing the cross-entropy, add the loss assigned to the other labels
        eps_i = self.smoothing / lprobs.size(-1)
        loss = (1.0 - self.smoothing) * nll_loss + eps_i * smooth_loss
        return loss

# smoothing=0.1 usually works well
criterion = LabelSmoothedCrossEntropyCriterion(
    smoothing=0.1,
    ignore_index=task.target_dictionary.pad(),
)
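As a quick sanity check, the criterion can be applied to dummy log-probabilities (a toy sketch; shapes and values are made up):

import torch
import torch.nn.functional as F

lprobs = F.log_softmax(torch.randn(6, 8000), dim=-1)  # (num_tokens, vocab_size)
target = torch.randint(0, 8000, (6,))                 # gold token ids
print(criterion(lprobs, target))                      # summed label-smoothed loss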

5.8 Optimizer: Adam + lr scheduling

Inverse square root scheduling is important for stable Transformer training, and was later also applied to RNNs. The learning rate increases linearly during warmup and then decays in proportion to the inverse square root of the update step:

lrate = factor * d_model^(-0.5) * min(step^(-0.5), step * warmup_steps^(-1.5))

class NoamOpt:
    "Optim wrapper that implements rate."
    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0
    
    @property
    def param_groups(self):
        return self.optimizer.param_groups
        
    def multiply_grads(self, c):
        """Multiplies grads by a constant *c*."""                
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is not None:
                    p.grad.data.mul_(c)
        
    def step(self):
        "Update parameters and rate"
        self._step += 1
        rate = self.rate()
        for p in self.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()
        
    def rate(self, step = None):
        "Implement `lrate` above"
        if step is None:
            step = self._step
        return 0 if not step else self.factor * \
            (self.model_size ** (-0.5) *
            min(step ** (-0.5), step * self.warmup ** (-1.5)))

Scheduling visualization

optimizer = NoamOpt(
    model_size=arch_args.encoder_embed_dim, 
    factor=config.lr_factor, 
    warmup=config.lr_warmup, 
    optimizer=torch.optim.AdamW(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=0.0001))
plt.plot(np.arange(1, 100000), [optimizer.rate(i) for i in range(1, 100000)])
plt.legend([f"{optimizer.model_size}:{optimizer.warmup}"])
None

6. Training

def train_one_epoch(epoch_itr, model, task, criterion, optimizer, accum_steps=1):
    itr = epoch_itr.next_epoch_itr(shuffle=True)
    itr = iterators.GroupedIterator(itr, accum_steps) # gradient accumulation: update once every accum_steps samples

    stats = {"loss": []}
    scaler = GradScaler() # mixed precision training

    model.train()
    progress = tqdm.tqdm(itr, desc=f"train epoch {epoch_itr.epoch}", leave=False)
    for samples in progress:
        model.zero_grad()
        accum_loss = 0
        sample_size = 0
        # gradient accumulation: accumulate over accum_steps samples, then update once
        for i, sample in enumerate(samples):
            if i == 1:
                # emptying the CUDA cache after the first step can reduce the chance of OOM
                torch.cuda.empty_cache()

            sample = utils.move_to_cuda(sample, device=device)
            target = sample["target"]
            sample_size_i = sample["ntokens"]
            sample_size += sample_size_i

            # mixed precision training
            with autocast():
                net_output = model.forward(**sample["net_input"])
                lprobs = F.log_softmax(net_output[0], -1)
                loss = criterion(lprobs.view(-1, lprobs.size(-1)), target.view(-1))

                # logging
                accum_loss += loss.item()
                # back-prop
                scaler.scale(loss).backward()

        scaler.unscale_(optimizer)
        optimizer.multiply_grads(
            1 / (sample_size or 1.0))  # (sample_size or 1.0) handles the case of a zero gradient
        gnorm = nn.utils.clip_grad_norm_(model.parameters(), config.clip_norm)  # gradient clipping prevents gradient explosion

        scaler.step(optimizer)
        scaler.update()

        # logging
        loss_print = accum_loss / sample_size
        stats["loss"].append(loss_print)
        progress.set_postfix(loss=loss_print)
        # if config.use_wandb:
        #     wandb.log({
        #         "train/loss": loss_print,
        #         "train/grad_norm": gnorm.item(),
        #         "train/lr": optimizer.rate(),
        #         "train/sample_size": sample_size,
        #     })

    loss_print = np.mean(stats["loss"])
    logger.info(f"training loss: {loss_print:.4f}")
    return stats

7. Validation & Inference

In order to prevent overfitting, we run validation every epoch to measure the model's performance on unseen data.

  • The procedure is basically the same as training, with inference added
  • After validation we can save the model weights
  • Validation loss alone does not describe the true performance of the model
  • So we directly generate translation results (hypotheses) with the current model and compute the BLEU score against the correct answers (references)
  • We use the sequence generator provided by fairseq to perform beam search and generate the translations
# fairseq's beam search generator
# given the model and an input sequence, use beam search to generate the translation
sequence_generator = task.build_generator([model], config)

def decode(toks, dictionary):
    # convert a Tensor back into a human-readable sentence
    s = dictionary.string(
        toks.int().cpu(),
        config.post_process,
    )
    return s if s else "<unk>"

def inference_step(sample, model):
    gen_out = sequence_generator.generate([model], sample)
    srcs = []
    hyps = []
    refs = []
    for i in range(len(gen_out)):
        # for each sample, collect the input, the hypothesis, and the reference, for computing BLEU later
        srcs.append(decode(
            utils.strip_pad(sample["net_input"]["src_tokens"][i], task.source_dictionary.pad()), 
            task.source_dictionary,
        ))
        hyps.append(decode(
            gen_out[i][0]["tokens"], # 0 means taking the top-scoring hypothesis in the beam
            task.target_dictionary,
        ))
        refs.append(decode(
            utils.strip_pad(sample["target"][i], task.target_dictionary.pad()), 
            task.target_dictionary,
        ))
    return srcs, hyps, refs
import shutil
import sacrebleu

def validate(model, task, criterion, log_to_wandb=True):
    logger.info('begin validation')
    itr = load_data_iterator(task, "valid", 1, config.max_tokens, config.num_workers).next_epoch_itr(shuffle=False)
    
    stats = {"loss": [], "bleu": 0, "srcs": [], "hyps": [], "refs": []}
    srcs = []
    hyps = []
    refs = []
    
    model.eval()
    progress = tqdm.tqdm(itr, desc=f"validation", leave=False)
    with torch.no_grad():
        for i, sample in enumerate(progress):
            # validation loss
            sample = utils.move_to_cuda(sample, device=device)
            net_output = model.forward(**sample["net_input"])

            lprobs = F.log_softmax(net_output[0], -1)
            target = sample["target"]
            sample_size = sample["ntokens"]
            loss = criterion(lprobs.view(-1, lprobs.size(-1)), target.view(-1)) / sample_size
            progress.set_postfix(valid_loss=loss.item())
            stats["loss"].append(loss)
            
            # run inference
            s, h, r = inference_step(sample, model)
            srcs.extend(s)
            hyps.extend(h)
            refs.extend(r)
            
    tok = 'zh' if task.cfg.target_lang == 'zh' else '13a'
    stats["loss"] = torch.stack(stats["loss"]).mean().item()
    stats["bleu"] = sacrebleu.corpus_bleu(hyps, [refs], tokenize=tok) # compute the BLEU score
    stats["srcs"] = srcs
    stats["hyps"] = hyps
    stats["refs"] = refs
    
    if config.use_wandb and log_to_wandb:
        wandb.log({
            "valid/loss": stats["loss"],
            "valid/bleu": stats["bleu"].score,
        }, commit=False)
    
    showid = np.random.randint(len(hyps))
    logger.info("example source: " + srcs[showid])
    logger.info("example hypothesis: " + hyps[showid])
    logger.info("example reference: " + refs[showid])
    
    # show bleu results
    logger.info(f"validation loss:\t{stats['loss']:.4f}")
    logger.info(stats["bleu"].format())
    return stats

8. Save and load model parameters

def validate_and_save(model, task, criterion, optimizer, epoch, save=True):   
    stats = validate(model, task, criterion)
    bleu = stats['bleu']
    loss = stats['loss']
    if save:
        # save epoch checkpoints
        savedir = Path(config.savedir).absolute()
        savedir.mkdir(parents=True, exist_ok=True)
        
        check = {
            "model": model.state_dict(),
            "stats": {"bleu": bleu.score, "loss": loss},
            "optim": {"step": optimizer._step}
        }
        torch.save(check, savedir/f"checkpoint{epoch}.pt")
        shutil.copy(savedir/f"checkpoint{epoch}.pt", savedir/f"checkpoint_last.pt")
        logger.info(f"saved epoch checkpoint: {savedir}/checkpoint{epoch}.pt")

        # save epoch samples
        with open(savedir/f"samples{epoch}.{config.source_lang}-{config.target_lang}.txt", "w") as f:
            for s, h in zip(stats["srcs"], stats["hyps"]):
                f.write(f"{s}\t{h}\n")

        # get best valid bleu
        if getattr(validate_and_save, "best_bleu", 0) < bleu.score:
            validate_and_save.best_bleu = bleu.score
            torch.save(check, savedir/f"checkpoint_best.pt")

        del_file = savedir / f"checkpoint{epoch - config.keep_last_epochs}.pt"
        if del_file.exists():
            del_file.unlink()
    
    return stats

def try_load_checkpoint(model, optimizer=None, name=None):
    name = name if name else "checkpoint_last.pt"
    checkpath = Path(config.savedir)/name
    if checkpath.exists():
        check = torch.load(checkpath)
        model.load_state_dict(check["model"])
        stats = check["stats"]
        step = "unknown"
        if optimizer != None:
            optimizer._step = step = check["optim"]["step"]
        logger.info(f"loaded checkpoint {checkpath}: step={step} loss={stats['loss']} bleu={stats['bleu']}")
    else:
        logger.info(f"no checkpoints found at {checkpath}!")

9. Start training

model = model.to(device=device)
criterion = criterion.to(device=device)

logger.info("task: {}".format(task.__class__.__name__))
logger.info("encoder: {}".format(model.encoder.__class__.__name__))
logger.info("decoder: {}".format(model.decoder.__class__.__name__))
logger.info("criterion: {}".format(criterion.__class__.__name__))
logger.info("optimizer: {}".format(optimizer.__class__.__name__))
logger.info(
    "num. model params: {:,} (num. trained: {:,})".format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    )
)
logger.info(f"max tokens per batch = {config.max_tokens}, accumulate steps = {config.accum_steps}")

epoch_itr = load_data_iterator(task, "train", config.start_epoch, config.max_tokens, config.num_workers)
try_load_checkpoint(model, optimizer, name=config.resume)
while epoch_itr.next_epoch_idx <= config.max_epoch:
    # train for one epoch
    train_one_epoch(epoch_itr, model, task, criterion, optimizer, config.accum_steps)
    stats = validate_and_save(model, task, criterion, optimizer, epoch=epoch_itr.epoch)
    logger.info("end of epoch {}".format(epoch_itr.epoch))    
    epoch_itr = load_data_iterator(task, "train", epoch_itr.next_epoch_idx, config.max_tokens, config.num_workers)

It finally runs!

Although the translations are still a bit rough, haha.
(figure: sample translation output)


Origin blog.csdn.net/m0_51474171/article/details/130018611