I've recently been working on generative question answering, using BERT as the encoder and a Transformer decoder as the decoder. One question came up: I want the decoder to share BERT's embedding matrix, but the encoder and decoder are given different learning rates, so how does the shared embedding matrix get updated? Would it also be affected by the decoder side? To find out, I ran the following experiment.
import torch
import torch.nn as nn
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        # stand-in for BERT's embedding matrix
        self.embeddings = nn.Embedding(100, 50)
        self.fc = nn.Linear(50, 1)

    def forward(self, input):
        feature = self.embeddings(input)
        feature = self.fc(feature)
        return feature

class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        # placeholder; will be pointed at the encoder's embedding module
        self.embeddings = None
        self.fc = nn.Linear(50, 1)

    def forward(self, input):
        feature = self.embeddings(input)
        feature = self.fc(feature)
        return feature

class myModel(nn.Module):
    def __init__(self):
        super(myModel, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()
        # share the encoder's embedding module with the decoder
        self.decoder.embeddings = self.encoder.embeddings

    def forward(self, enc_input, dec_input):
        enc_ = self.encoder(enc_input)
        dec_ = self.decoder(dec_input)
        return enc_.sum() + dec_.sum()
model = myModel()

# split the parameters into an encoder group and a decoder group
enc_param = []
dec_param = []
for n, p in model.named_parameters():
    if n.split('.')[0] == 'encoder':
        enc_param.append((n, p))
    else:
        dec_param.append((n, p))

optimizer_grouped_parameters = [
    # encoder (BERT) parameters
    {"params": [p for n, p in enc_param], 'lr': 0.01},
    # decoder parameters
    {"params": [p for n, p in dec_param], 'lr': 0.001},
]
optim = torch.optim.SGD(optimizer_grouped_parameters)

enc_input = torch.arange(0, 10).unsqueeze(0)
dec_input = torch.arange(5, 15).unsqueeze(0)

# one forward / backward / update step
loss = model(enc_input, dec_input)
optim.zero_grad()
loss.backward()
optim.step()

# check whether the two embedding attributes are the same object,
# and which parameter group each named parameter ended up in
print(id(model.encoder.embeddings))
print(id(model.decoder.embeddings))
print([n for (n, p) in dec_param])
print([n for (n, p) in enc_param])
'''Output
140206391178048
140206391178048
['decoder.fc.weight', 'decoder.fc.bias']
['encoder.embeddings.weight', 'encoder.fc.weight', 'encoder.fc.bias']
'''
From the printed output, the embedding weight appears only in the encoder's parameter group, and model.decoder.embeddings has the same memory address as model.encoder.embeddings, so the two really are one shared module. Because a shared parameter is listed only once by named_parameters() (here as encoder.embeddings.weight), it is updated with the encoder group's learning rate, while still accumulating gradients from both the encoder and decoder forward passes. So my worry was unnecessary.
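To make that last point concrete, here is a small follow-up sketch of my own (not part of the original experiment; it reuses the myModel class defined above). It checks two things after a single step: rows of the shared embedding that only the decoder input touches (indices 10-14) still receive gradient, and the actual weight update matches -0.01 * grad, i.e. the encoder group's learning rate is the one applied.

# Follow-up check, assuming the myModel class from the experiment above is defined.
model = myModel()

enc_params = [p for n, p in model.named_parameters() if n.split('.')[0] == 'encoder']
dec_params = [p for n, p in model.named_parameters() if n.split('.')[0] != 'encoder']
optim = torch.optim.SGD([
    {"params": enc_params, 'lr': 0.01},
    {"params": dec_params, 'lr': 0.001},
])

enc_input = torch.arange(0, 10).unsqueeze(0)   # encoder sees tokens 0-9
dec_input = torch.arange(5, 15).unsqueeze(0)   # decoder sees tokens 5-14

before = model.encoder.embeddings.weight.detach().clone()

loss = model(enc_input, dec_input)
optim.zero_grad()
loss.backward()
grad = model.encoder.embeddings.weight.grad.clone()
optim.step()
after = model.encoder.embeddings.weight.detach()

# rows 10-14 are seen only by the decoder input, yet they still get gradient
print(grad[10:15].abs().sum().item() > 0)           # True
# the update equals before - 0.01 * grad, i.e. the encoder group's lr is used
print(torch.allclose(after, before - 0.01 * grad))  # True

In other words, the decoder does influence the shared embedding, but only through the gradient it contributes; the step size applied to that gradient is the one attached to the encoder's parameter group.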