Sharing the embedding matrix in an Encoder-Decoder model, and how its parameters get updated

Recently I have been working on generative question answering and tried a framework that uses BERT as the encoder and a Transformer decoder as the decoder. I ran into a question: I want the decoder to share BERT's embedding matrix, but since the encoder and decoder are given different learning rates, I did not know how the embedding matrix parameters would be updated. Would they be affected by the decoder side? So I ran the following experiment.
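
For context, in the real setup the weight sharing would look roughly like the sketch below. This assumes HuggingFace transformers and a hypothetical TinyDecoder class; none of it appears in the original code, it just shows where the shared nn.Embedding would come from. The toy experiment that follows strips this down to two tiny modules.

import torch.nn as nn
from transformers import BertModel


class TinyDecoder(nn.Module):
    # hypothetical placeholder; a real decoder would wrap nn.TransformerDecoder
    def __init__(self, hidden_size, vocab_size):
        super(TinyDecoder, self).__init__()
        self.embeddings = nn.Embedding(vocab_size, hidden_size)
        self.out = nn.Linear(hidden_size, vocab_size)

    def forward(self, dec_input_ids):
        return self.out(self.embeddings(dec_input_ids))


bert = BertModel.from_pretrained('bert-base-chinese')
decoder = TinyDecoder(bert.config.hidden_size, bert.config.vocab_size)

# point the decoder at the very same nn.Embedding object BERT uses, so both
# sides read from, and push gradients into, one weight matrix
decoder.embeddings = bert.get_input_embeddings()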

import torch
import torch.nn as nn


class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.embeddings = nn.Embedding(100, 50)
        self.fc = nn.Linear(50, 1)

    def forward(self, input):

        feature = self.embeddings(input)
        feature = self.fc(feature)

        return feature


class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.embeddings = None  # placeholder; the encoder's embedding is assigned in myModel
        self.fc = nn.Linear(50, 1)

    def forward(self, input):
        feature = self.embeddings(input)
        feature = self.fc(feature)

        return feature


class myModel(nn.Module):
    def __init__(self):
        super(myModel, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

        self.decoder.embeddings = self.encoder.embeddings  # both modules now hold the same nn.Embedding object

    def forward(self, enc_input, dec_input):
        enc_ = self.encoder(enc_input)
        dec_ = self.decoder(dec_input)

        # a scalar "loss" that routes gradients through both branches
        return enc_.sum() + dec_.sum()


model = myModel()

# split parameters into encoder and decoder groups by name prefix
enc_param = []
dec_param = []
for n, p in model.named_parameters():
    if n.split('.')[0] == 'encoder':
        enc_param.append((n, p))
    else:
        dec_param.append((n, p))

optimizer_grouped_parameters = [
    # encoder parameters (the "BERT" side) get the larger learning rate
    {"params": [p for n, p in enc_param], "lr": 0.01},
    # decoder parameters get the smaller learning rate
    {"params": [p for n, p in dec_param], "lr": 0.001},
]


# every group defines its own lr, so SGD needs no default learning rate here
optim = torch.optim.SGD(optimizer_grouped_parameters)

# dummy token-id batches (indices must be < 100, the embedding's vocab size)
enc_input = torch.arange(0, 10).unsqueeze(0)
dec_input = torch.arange(5, 15).unsqueeze(0)

loss = model(enc_input, dec_input)

optim.zero_grad()
loss.backward()
optim.step()


# identical ids: both attributes reference one and the same nn.Embedding object
print(id(model.encoder.embeddings))
print(id(model.decoder.embeddings))

# which parameter names ended up in each group
print([n for (n, p) in dec_param])
print([n for (n, p) in enc_param])

''' output:
140206391178048
140206391178048
['decoder.fc.weight', 'decoder.fc.bias']
['encoder.embeddings.weight', 'encoder.fc.weight', 'encoder.fc.bias']

'''
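
One detail behind this output: named_parameters() yields each Parameter object only once, so a weight shared by two modules is listed only under the first module that registered it, which is why embeddings.weight shows up on the encoder side but not the decoder side. In recent PyTorch versions you can make both registrations visible with remove_duplicate=False; this check is my addition, not part of the original experiment:

# the shared weight is registered under both prefixes, but deduplicated by default
print([n for n, _ in model.named_parameters(remove_duplicate=False)
       if n.endswith('embeddings.weight')])
# expected: ['encoder.embeddings.weight', 'decoder.embeddings.weight']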


From the printed results, the embedding weight appears only in the encoder's parameter group, and model.decoder.embeddings has the same memory address as model.encoder.embeddings, which confirms that the two modules share a single embedding matrix. Gradients from both the encoder and decoder branches accumulate on that one weight, and the optimizer then updates it once, using the encoder group's learning rate of 0.01. So my worry was unnecessary.
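
If you do want explicit control over the shared embedding's learning rate instead of letting it ride with the encoder group, the usual fix is to give it its own parameter group. A minimal sketch on top of the toy model above; the 0.005 value is just an illustrative choice:

# pull the shared weight out into its own group with its own learning rate
shared_emb = model.encoder.embeddings.weight
optimizer_grouped_parameters = [
    {"params": [shared_emb], "lr": 0.005},
    {"params": [p for n, p in enc_param if n != 'encoder.embeddings.weight'], "lr": 0.01},
    {"params": [p for n, p in dec_param], "lr": 0.001},
]
optim = torch.optim.SGD(optimizer_grouped_parameters)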

Origin blog.csdn.net/mch2869253130/article/details/123832565