메모리 오버플로 문제를 해결하기 위한 BERT-pytorch 소스 코드 구현
BERT 모델을 다루는 분들이 많을 것입니다. transformers 라이브러리에서 모델을 바로 import해서 쓰는 분들도 있지만, 이 방식은 모델 구조를 수정하기가 불편합니다. 그래서 PyTorch로 BERT를 세부까지 직접 구현한 코드도 여럿 공개되어 있습니다. 그런데 필자는 이러한 세부 구현 코드에서 중간 결과가 완전히 해제되지 않아 메모리 오버플로가 발생한다는 것을 발견하고 이를 개선했습니다. 아래 코드를 사용하면 메모리 오버플로 문제를 해결할 수 있습니다.
참고: 메모리 오버플로 문제를 어떻게 해결했는지는 코드의 del 문에 주목하세요.
'''
code by Tae Hwan Jung(Jeff Jung) @graykode, modified by wmathor
Reference : https://github.com/jadore801120/attention-is-all-you-need-pytorch
https://github.com/JayParks/transformer, https://github.com/dhlee347/pytorchic-bert
'''
import re
import math
import torch
import numpy as np
from random import *
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as Data
import matplotlib.pyplot as plt
from data_process import get_data
# Load the training and test splits: sentences plus the class label that is
# later used as the sentence-classification target.
setences, label, setences_test, label_test = get_data()

# Every module and tensor in this script lives on the CPU.
device = torch.device('cpu')

# Conventionally spelled alias; the rest of the script keeps using "setences".
sentences = setences
#text = (
# 'Hello, how are you? I am Romeo.\n' # R
# 'Hello, Romeo My name is Juliet. Nice to meet you.\n' # J
# 'Nice meet you too. How are you today?\n' # R
# 'Great. My baseball team won the competition.\n' # J
# 'Oh Congratulations, Juliet\n' # R
# 'Thank you Romeo\n' # J
# 'Where are you going today?\n' # R
# 'I am going shopping. What about you?\n' # J
# 'I am going to visit my grandmother. she is not very well' # R
#)
#sentences = re.sub("[.,!?\\-]", '', text.lower()).split('\n') # filt
#print(sentences)
# Ids 0-3 are reserved for the special tokens; real words start at 4.
word2idx = {'[PAD]': 0, '[CLS]': 1, '[SEP]': 2, '[MASK]': 3}

# Vocabulary = every whitespace-separated token seen in either split.
# * sorted(): iterating a set of strings yields an order that depends on
#   Python's per-process hash randomization, which would give each run a
#   different word -> id mapping.
# * Reserved token strings are excluded so that a corpus word spelled like
#   a special token cannot overwrite its reserved id (which would also leave
#   the largest id equal to vocab_size, i.e. out of range).
word_list = sorted(
    (set(" ".join(setences).split()) | set(" ".join(setences_test).split()))
    - set(word2idx)
)  # ['hello', 'how', 'are', 'you', ...]
for i, w in enumerate(word_list, start=4):
    word2idx[w] = i

# Inverse mapping, used to turn ids back into readable tokens.
idx2word = {idx: word for word, idx in word2idx.items()}
vocab_size = len(word2idx)

# Encode every training sentence as a list of vocabulary ids.
token_list = [[word2idx[s] for s in sentence.split()] for sentence in setences]
# print(token_list)
'''
[[12, 7, 22, 5, 39, 21, 15],
[12, 15, 13, 35, 10, 27, 34, 14, 19, 5],
[34, 19, 5, 17, 7, 22, 5, 8],
[33, 13, 37, 32, 28, 11, 16],
[30, 23, 27],
[6, 5, 15],
[36, 22, 5, 31, 8],
[39, 21, 31, 18, 9, 20, 5],
[39, 21, 31, 14, 29, 13, 4, 25, 10, 26, 38, 24]]
'''
# BERT hyper-parameters
maxlen = 30                      # every input row is padded to this length
batch_size = 6
max_pred = 5                     # upper bound on masked positions per row
n_layers = 6                     # number of stacked encoder layers
n_heads = 12                     # attention heads per layer
d_model = 768                    # hidden size
d_ff = 4 * d_model               # feed-forward inner dimension
d_k = d_v = d_model // n_heads   # per-head dimension of K(=Q) and V
n_segments = 3                   # rows in the segment embedding table
def make_data():
    """Build the training samples for masked-LM plus sentence classification.

    Each training sentence is wrapped as ``[CLS] tokens [SEP]``; roughly 15%
    of its real tokens (at least 1, at most ``max_pred``) are chosen as
    prediction targets and corrupted, and every field is zero-padded to a
    fixed length.

    Returns:
        A list with one entry per training sentence:
        ``[input_ids, segment_ids, masked_tokens, masked_pos, label]``.
        ``input_ids`` and ``segment_ids`` have length ``maxlen``;
        ``masked_tokens`` and ``masked_pos`` have length ``max_pred``.
    """
    cls_id = word2idx['[CLS]']
    sep_id = word2idx['[SEP]']
    mask_id = word2idx['[MASK]']

    batch = []
    for tokens_a_index in range(len(setences)):
        # Truncate so that [CLS] + tokens + [SEP] never exceeds maxlen.
        # Longer rows could not be stacked into one tensor and would index
        # past the position-embedding table.
        tokens_a = token_list[tokens_a_index][:maxlen - 2]
        input_ids = [cls_id] + tokens_a + [sep_id]
        segment_ids = [0] * len(input_ids)

        # MASK LM: 15 % of the tokens in one sentence
        n_pred = min(max_pred, max(1, int(len(input_ids) * 0.15)))
        # Only real tokens are candidates, never [CLS] / [SEP].
        cand_maked_pos = [pos for pos, token in enumerate(input_ids)
                          if token != cls_id and token != sep_id]
        shuffle(cand_maked_pos)

        masked_tokens, masked_pos = [], []
        for pos in cand_maked_pos[:n_pred]:
            masked_pos.append(pos)
            masked_tokens.append(input_ids[pos])
            # A single draw decides the corruption so the split really is
            # 80% [MASK] / 10% random word / 10% unchanged.  Drawing twice
            # made the random-word branch fire only ~2% of the time.
            roll = random()
            if roll < 0.8:
                input_ids[pos] = mask_id
            elif roll > 0.9 and vocab_size > 4:
                # Ids 0-3 are the special tokens, so sample real words only.
                input_ids[pos] = randint(4, vocab_size - 1)

        # Zero-pad the sequence fields up to maxlen.
        n_pad = maxlen - len(input_ids)
        input_ids.extend([0] * n_pad)
        segment_ids.extend([0] * n_pad)

        # Zero-pad the target fields up to max_pred.  The real number of
        # targets is used because it can be smaller than n_pred (e.g. an
        # empty sentence has no candidate position at all).
        n_pad = max_pred - len(masked_tokens)
        masked_tokens.extend([0] * n_pad)
        masked_pos.extend([0] * n_pad)

        batch.append([input_ids, segment_ids, masked_tokens, masked_pos,
                      label[tokens_a_index]])
    return batch
# Build all training samples, then transpose the list of per-sample records
# into one tuple per field.
batch = make_data()
columns = list(zip(*batch))
input_ids, segment_ids, masked_tokens, masked_pos, isNext = columns
'''>>> a = [1,2,3]
b = [4,5,6]
zipped = zip(a,b) # 打包为元组的列表
[(1, 4), (2, 5), (3, 6)]
zip(*zipped) # 与 zip 相反,可理解为解压,为zip的逆过程,可用于矩阵的转置
[(1, 2, 3), (4, 5, 6)]
'''
# Turn every field into an int64 tensor.
input_ids, segment_ids, masked_tokens, masked_pos, isNext = map(
    torch.LongTensor,
    (input_ids, segment_ids, masked_tokens, masked_pos, isNext),
)
class MyDataSet(Data.Dataset):
    """Map-style dataset over five parallel, index-aligned sequences."""

    def __init__(self, input_ids, segment_ids, masked_tokens, masked_pos, isNext):
        # Only references are stored; nothing is copied.
        self.input_ids = input_ids
        self.segment_ids = segment_ids
        self.masked_tokens = masked_tokens
        self.masked_pos = masked_pos
        self.isNext = isNext

    def __len__(self):
        """Return the number of samples, taken from ``input_ids``."""
        return len(self.input_ids)

    def __getitem__(self, idx):
        """Return the ``idx``-th element of every field as a 5-tuple."""
        fields = (
            self.input_ids,
            self.segment_ids,
            self.masked_tokens,
            self.masked_pos,
            self.isNext,
        )
        return tuple(field[idx] for field in fields)
# Shuffled mini-batches of `batch_size` samples over the training tensors.
loader = Data.DataLoader(
    MyDataSet(input_ids, segment_ids, masked_tokens, masked_pos, isNext),
    batch_size=batch_size,
    shuffle=True,
)
def get_attn_pad_mask(seq_q, seq_k):
    """Build the attention mask that hides [PAD] keys from every query.

    The last dimension of an attention score matrix indexes the keys, so the
    padding pattern has to come from ``seq_k``.  The previous version read it
    from ``seq_q`` and ignored ``seq_k``, which only gave the right answer
    when both arguments were the same tensor.

    Args:
        seq_q: token ids of the querying sequence, ``[batch_size, len_q]``.
        seq_k: token ids of the attended sequence, ``[batch_size, len_k]``.

    Returns:
        Bool tensor ``[batch_size, len_q, len_k]`` that is True wherever the
        key token id is 0 ([PAD]).
    """
    batch_size, len_q = seq_q.size()
    len_k = seq_k.size(1)
    # eq(0) marks the PAD tokens: [batch_size, 1, len_k]
    pad_attn_mask = seq_k.eq(0).unsqueeze(1)
    # Repeat the same key mask for every query position (a view, no copy).
    return pad_attn_mask.expand(batch_size, len_q, len_k)
def gelu(x):
    """Exact (erf-based) Gaussian Error Linear Unit, applied element-wise."""
    # 1 + erf(x / sqrt(2)) is twice the standard normal CDF at x.
    twice_cdf = 1.0 + torch.erf(x / math.sqrt(2.0))
    return x * 0.5 * twice_cdf
class Embedding(nn.Module):
    """Sum of token, position and segment embeddings, then LayerNorm."""

    def __init__(self):
        super().__init__()
        self.tok_embed = nn.Embedding(vocab_size, d_model)  # token embedding
        self.pos_embed = nn.Embedding(maxlen, d_model)  # position embedding
        self.seg_embed = nn.Embedding(n_segments, d_model)  # segment (token type) embedding
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, seg):
        """Embed token ids ``x`` and segment ids ``seg`` (same 2-D shape)."""
        # Positions 0 .. seq_len-1, e.g. tensor([0, 1, ..., 29]) for seq_len 30.
        positions = torch.arange(x.size(1), dtype=torch.long)
        # [seq_len] -> [batch_size, seq_len]
        positions = positions.unsqueeze(0).expand_as(x).to(device)
        token_part = self.tok_embed(x)
        position_part = self.pos_embed(positions)
        segment_part = self.seg_embed(seg)
        embedding = token_part + position_part + segment_part
        # Drop the local references that are no longer needed.
        del positions, x, seg
        return self.norm(embedding)
class ScaledDotProductAttention(nn.Module):
    """softmax(Q K^T / sqrt(d_k)) V with masked positions suppressed."""

    def __init__(self):
        super().__init__()

    def forward(self, Q, K, V, attn_mask):
        """Return the attention-weighted values; ``attn_mask`` is True where attention is forbidden."""
        # scores: [batch_size, n_heads, seq_len, seq_len]
        scores = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(d_k)
        # In place: masked entries get a large negative score so that their
        # softmax weight is effectively zero.
        scores.masked_fill_(attn_mask, -1e9)
        attn = torch.softmax(scores, dim=-1)
        context = torch.matmul(attn, V)
        # Drop the local references to the intermediates.
        del attn, scores, Q, K, V, attn_mask
        return context
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with output projection, residual and LayerNorm.

    The output projection and the LayerNorm are created once here and stored
    as attributes.  They used to be constructed inside ``forward``, which
    re-initialised them with fresh random weights on every call and kept them
    out of ``model.parameters()``, so the optimizer never trained them.
    """

    def __init__(self):
        super().__init__()
        self.W_Q = nn.Linear(d_model, d_k * n_heads)
        self.W_K = nn.Linear(d_model, d_k * n_heads)
        self.W_V = nn.Linear(d_model, d_v * n_heads)
        self.attention = ScaledDotProductAttention()
        self.fc = nn.Linear(n_heads * d_v, d_model)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, Q, K, V, attn_mask):
        """Attend from ``Q`` over ``K``/``V``.

        Args:
            Q, K, V: ``[batch_size, seq_len, d_model]``.
            attn_mask: bool ``[batch_size, seq_len, seq_len]``, True where
                attention is forbidden.

        Returns:
            ``[batch_size, seq_len, d_model]``.
        """
        residual, batch_size = Q, Q.size(0)
        # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W)
        q_s = self.W_Q(Q).view(batch_size, -1, n_heads, d_k).transpose(1, 2)
        k_s = self.W_K(K).view(batch_size, -1, n_heads, d_k).transpose(1, 2)
        v_s = self.W_V(V).view(batch_size, -1, n_heads, d_v).transpose(1, 2)
        # [batch_size, 1, seq_len, seq_len]; the singleton head axis is
        # broadcast by masked_fill_, so the mask is not copied per head.
        attn_mask = attn_mask.unsqueeze(1)
        # context: [batch_size, n_heads, seq_len, d_v]
        context = self.attention(q_s, k_s, v_s, attn_mask)
        # Merge the heads back: [batch_size, seq_len, n_heads * d_v]
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, n_heads * d_v)
        output = self.fc(context)
        # Drop the local references to the intermediates.
        del context, attn_mask, q_s, k_s, v_s
        return self.norm(output + residual)  # [batch_size, seq_len, d_model]
class PoswiseFeedForwardNet(nn.Module):
    """Position-wise two-layer feed-forward network with GELU in between."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        """Map ``[..., d_model]`` -> ``[..., d_ff]`` -> ``[..., d_model]``."""
        hidden = gelu(self.fc1(x))
        return self.fc2(hidden)
class EncoderLayer(nn.Module):
    """One encoder block: self-attention followed by the feed-forward net."""

    def __init__(self):
        super().__init__()
        self.enc_self_attn = MultiHeadAttention()
        self.pos_ffn = PoswiseFeedForwardNet()

    def forward(self, enc_inputs, enc_self_attn_mask):
        """Run one block; the same tensor serves as Q, K and V."""
        attended = self.enc_self_attn(
            enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask
        )
        enc_outputs = self.pos_ffn(attended)  # [batch_size, seq_len, d_model]
        # Drop the local references to the inputs.
        del enc_self_attn_mask, enc_inputs
        return enc_outputs
class BERT(nn.Module):
    """Encoder stack with a masked-LM head and a 3-way sentence classifier.

    The masked-LM output layer ``fc2`` shares its weight matrix with the
    token embedding table.
    """

    def __init__(self):
        super().__init__()
        self.embedding = Embedding()
        self.layers = nn.ModuleList(EncoderLayer() for _ in range(n_layers))
        # Pooler applied to the hidden state of the first ([CLS]) token.
        self.fc = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.Dropout(0.5),
            nn.Tanh(),
        )
        self.classifier = nn.Linear(d_model, 3)
        # Transform applied to the masked positions before decoding.
        self.linear = nn.Linear(d_model, d_model)
        self.activ2 = gelu
        # fc2 is shared with the embedding layer
        self.fc2 = nn.Linear(d_model, vocab_size, bias=False)
        self.fc2.weight = self.embedding.tok_embed.weight

    def forward(self, input_ids, segment_ids, masked_pos):
        """Return ``(logits_lm, logits_clsf)``.

        ``logits_lm`` is ``[batch_size, max_pred, vocab_size]`` (scores for
        the token at each masked position) and ``logits_clsf`` is
        ``[batch_size, 3]`` (sentence-level class scores).
        """
        output = self.embedding(input_ids, segment_ids)  # [batch_size, seq_len, d_model]
        enc_self_attn_mask = get_attn_pad_mask(input_ids, input_ids)  # [batch_size, maxlen, maxlen]
        for encoder in self.layers:
            # output: [batch_size, max_len, d_model]
            output = encoder(output, enc_self_attn_mask)

        # Sentence-level prediction is decided by the first token ([CLS]).
        h_pooled = self.fc(output[:, 0])  # [batch_size, d_model]
        logits_clsf = self.classifier(h_pooled)  # [batch_size, 3]

        # Pick the hidden state at every masked position.
        masked_pos = masked_pos.unsqueeze(-1).expand(-1, -1, d_model)  # [batch_size, max_pred, d_model]
        h_masked = output.gather(1, masked_pos)  # [batch_size, max_pred, d_model]
        h_masked = self.activ2(self.linear(h_masked))  # [batch_size, max_pred, d_model]
        # logits_lm: scores for the token at each masked position
        logits_lm = self.fc2(h_masked)  # [batch_size, max_pred, vocab_size]
        # Drop the local references to the intermediates.
        del h_masked, h_pooled, output, enc_self_attn_mask, masked_pos, input_ids, segment_ids
        return logits_lm, logits_clsf
# Model, loss and optimizer.
model = BERT().to(device)
# print(model)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-6)

# Small index tensor in the layout expected by torch.gather(input, dim, index):
# a [2, 3] index matrix expanded along a trailing dimension of size 10.
index = torch.tensor([[1, 2, 0], [2, 0, 1]], dtype=torch.long)
index = index.unsqueeze(-1).expand(-1, -1, 10)

# One entry per optimisation step, plotted at the end of the script.
loss_list = []
# Train for 10 epochs; the summed masked-LM and classification loss is
# logged after every mini-batch.
for epoch in range(10):
    model.train()  # make sure dropout is active while training
    loss_sum = 0.0
    for input_ids, segment_ids, masked_tokens, masked_pos, isNext in loader:
        logits_lm, logits_clsf = model(input_ids, segment_ids, masked_pos)
        # logits_lm: [batch_size, max_pred, vocab_size]
        #   -> [batch_size * max_pred, vocab_size], one row per masked slot.
        # CrossEntropyLoss already returns the mean as a scalar.
        loss_lm = criterion(logits_lm.view(-1, vocab_size), masked_tokens.view(-1))  # masked LM
        loss_clsf = criterion(logits_clsf, isNext)  # sentence classification
        loss = loss_lm + loss_clsf

        # Keep plain Python floats for bookkeeping.  Adding the loss *tensor*
        # to a running sum would chain every batch's autograd graph to that
        # sum and keep it referenced for the rest of the epoch.
        loss_value = loss.item()
        loss_sum += loss_value
        loss_list.append(loss_value)
        print('Epoch:', '%04d' % (epoch + 1), 'loss =', '{:.6f}'.format(loss_value))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Release this batch's tensors before the next forward pass.
        del loss, logits_clsf, input_ids, segment_ids, masked_tokens, masked_pos, logits_lm, isNext, loss_clsf, loss_lm
# Predict masked tokens and the sentence label on the test split
print('test')
# From here on token_list holds the encoded *test* sentences.
token_list = [
    [word2idx[word] for word in sentence.split()]
    for sentence in setences_test
]
def make_data_test():
    """Build the evaluation samples from the test split.

    Each test sentence is wrapped as ``[CLS] tokens [SEP]``; roughly 15% of
    its real tokens (at least 1, at most ``max_pred``) are chosen as
    prediction targets and corrupted, and every field is zero-padded to a
    fixed length.  Reads the module-level ``token_list``, which must hold
    the encoded test sentences when this is called.

    Returns:
        A list with one entry per test sentence:
        ``[input_ids, segment_ids, masked_tokens, masked_pos, label]``.
        ``input_ids`` and ``segment_ids`` have length ``maxlen``;
        ``masked_tokens`` and ``masked_pos`` have length ``max_pred``.
    """
    cls_id = word2idx['[CLS]']
    sep_id = word2idx['[SEP]']
    mask_id = word2idx['[MASK]']

    batch = []
    for tokens_a_index in range(len(setences_test)):
        # Truncate so that [CLS] + tokens + [SEP] never exceeds maxlen.
        # Longer rows could not be stacked into one tensor and would index
        # past the position-embedding table.
        tokens_a = token_list[tokens_a_index][:maxlen - 2]
        input_ids = [cls_id] + tokens_a + [sep_id]
        segment_ids = [0] * len(input_ids)

        # MASK LM: 15 % of the tokens in one sentence
        n_pred = min(max_pred, max(1, int(len(input_ids) * 0.15)))
        # Only real tokens are candidates, never [CLS] / [SEP].
        cand_maked_pos = [pos for pos, token in enumerate(input_ids)
                          if token != cls_id and token != sep_id]
        shuffle(cand_maked_pos)

        masked_tokens, masked_pos = [], []
        for pos in cand_maked_pos[:n_pred]:
            masked_pos.append(pos)
            masked_tokens.append(input_ids[pos])
            # A single draw decides the corruption so the split really is
            # 80% [MASK] / 10% random word / 10% unchanged.  Drawing twice
            # made the random-word branch fire only ~2% of the time.
            roll = random()
            if roll < 0.8:
                input_ids[pos] = mask_id
            elif roll > 0.9 and vocab_size > 4:
                # Ids 0-3 are the special tokens, so sample real words only.
                input_ids[pos] = randint(4, vocab_size - 1)

        # Zero-pad the sequence fields up to maxlen.
        n_pad = maxlen - len(input_ids)
        input_ids.extend([0] * n_pad)
        segment_ids.extend([0] * n_pad)

        # Zero-pad the target fields up to max_pred.  The real number of
        # targets is used because it can be smaller than n_pred (e.g. an
        # empty sentence has no candidate position at all).
        n_pad = max_pred - len(masked_tokens)
        masked_tokens.extend([0] * n_pad)
        masked_pos.extend([0] * n_pad)

        batch.append([input_ids, segment_ids, masked_tokens, masked_pos,
                      label_test[tokens_a_index]])
    return batch
# Preprocessing of the test split: build the samples, then transpose the
# list of per-sample records into one tuple per field.
batch = make_data_test()
columns = list(zip(*batch))
input_ids, segment_ids, masked_tokens, masked_pos, isNext = columns
'''>>> a = [1,2,3]
b = [4,5,6]
zipped = zip(a,b) # 打包为元组的列表
[(1, 4), (2, 5), (3, 6)]
zip(*zipped) # 与 zip 相反,可理解为解压,为zip的逆过程,可用于矩阵的转置
[(1, 2, 3), (4, 5, 6)]
'''
# Turn every field into an int64 tensor.
input_ids, segment_ids, masked_tokens, masked_pos, isNext = map(
    torch.LongTensor,
    (input_ids, segment_ids, masked_tokens, masked_pos, isNext),
)
# Run the trained model on every test sample, one sample at a time, and
# collect the predicted sentence label.
predict_list = []
model.eval()  # switch Dropout to inference behaviour
with torch.no_grad():  # no autograd bookkeeping is needed for inference
    # Iterate over every sample; the loop used to re-read batch[0] on each
    # pass, so only the first test sentence was ever evaluated.
    for input_ids, segment_ids, masked_tokens, masked_pos, isNext in batch:
        print([idx2word[w] for w in input_ids if idx2word[w] != '[PAD]'])
        logits_lm, logits_clsf = model(
            torch.LongTensor([input_ids]),
            torch.LongTensor([segment_ids]),
            torch.LongTensor([masked_pos]),
        )
        # Highest-scoring vocabulary id at every masked slot of this sample.
        predicted_tokens = logits_lm.argmax(dim=2)[0].tolist()
        print('masked tokens list : ', [pos for pos in masked_tokens if pos != 0])
        print('predict masked tokens list : ', [pos for pos in predicted_tokens if pos != 0])
        # Highest-scoring sentence class.
        predicted_label = int(logits_clsf.argmax(dim=1)[0])
        print('isNext : ', isNext)
        print('predict isNext : ', predicted_label)
        predict_list.append(predicted_label)
# Macro-averaged recall / precision / F1 and overall accuracy over the three
# classes (ids 0, 1, 2).
target_num = [0, 0, 0]   # ground-truth samples per class
predict_num = [0, 0, 0]  # predictions per class
acc_num = [0, 0, 0]      # correct predictions per class

for true_label in label_test:
    target_num[true_label] += 1

# Walk predictions and labels in lockstep instead of tracking a manual index.
for predicted, true_label in zip(predict_list, label_test):
    index = int(predicted)
    if index in (0, 1, 2):
        predict_num[index] += 1
        if index == true_label:
            acc_num[index] += 1

ps = 0   # sum of per-class precision
rs = 0   # sum of per-class recall
F1z = 0  # sum of per-class F1
for i in range(3):
    # A class with no targets / no predictions contributes 0 rather than
    # dividing by zero.
    recallz = acc_num[i] / target_num[i] if target_num[i] != 0 else 0
    precisionz = acc_num[i] / predict_num[i] if predict_num[i] != 0 else 0
    ps += precisionz
    rs += recallz
    if recallz + precisionz != 0:
        F1z += 2 * recallz * precisionz / (recallz + precisionz)

print()
# Guard against an empty test set, which used to raise ZeroDivisionError.
n_targets = sum(target_num)
accuracy = sum(acc_num) / n_targets if n_targets else 0
# Printed one metric per line so the values are easy to copy.
print('recall:', rs / 3)
print('precision:', ps / 3)
print('F1:', F1z / 3)
print('accuracy', accuracy)
# Plot the loss values recorded during training.
plt.plot(loss_list, label='BERT')
plt.title('loss-epoch')
plt.legend()
plt.show()