Model Introduction
SimCSE
The model is divided into two parts: an unsupervised part and a supervised part. The overall structure is shown in the figure below:
Paper address: https://arxiv.org/pdf/2104.08821.pdf
Unsupervised SimCSE
Data
For the unsupervised part, the cleverest idea is to use Dropout as data augmentation to construct positive pairs, while the negative examples are simply the other sentences in the same batch.
You might ask: why does the same sentence yield two different vectors when it is fed to the model twice? The reason is that the model contains dropout layers; because neurons are randomly deactivated during training, feeding the same sentence to the model twice produces two different outputs.
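A minimal sketch of this effect, using nothing from the repo, just a toy linear layer with dropout in training mode:

# Toy illustration (not from the repo): dropout in train mode gives two
# different outputs for the exact same input.
import torch
import torch.nn as nn

layer = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.1))
layer.train()  # dropout is only active in training mode

x = torch.ones(1, 8)
out1 = layer(x)
out2 = layer(x)
print(torch.allclose(out1, out2))  # almost always False: different neurons are dropped each pass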
The data-loading code of the repo makes this more concrete:
from torch.utils.data import Dataset

MAXLEN = 64  # maximum sequence length (matches the shape comments in the training code below)

class TrainDataset(Dataset):
    def __init__(self, data, tokenizer, model_type="unsup"):
        self.data = data
        self.tokenizer = tokenizer
        self.model_type = model_type

    def text2id(self, text):
        if self.model_type == "unsup":
            # unsupervised: tokenize the same sentence twice -> two views of one sentence
            text_ids = self.tokenizer([text, text], max_length=MAXLEN, truncation=True, padding='max_length', return_tensors='pt')
        elif self.model_type == "sup":
            # supervised: a (text, text+, text-) triplet
            text_ids = self.tokenizer([text[0], text[1], text[2]], max_length=MAXLEN, truncation=True, padding='max_length', return_tensors='pt')
        return text_ids

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.text2id(self.data[index])
As you can see, the same sentence is fed to the Bert encoder twice, which produces two similar sentence vectors; these are treated as a positive pair.
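As a rough usage sketch (the pretrained model path, the toy sentences and the reliance on the default collate function are assumptions here, chosen to match the shape comments in the training code below), the dataset is wrapped in an ordinary DataLoader:

# Hypothetical usage sketch: each batch field ends up with shape [batch_size, 2, MAXLEN]
from torch.utils.data import DataLoader
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")  # path is an assumption
sentences = ["今天天气不错", "我喜欢自然语言处理"]                # toy data
train_dataloader = DataLoader(TrainDataset(sentences, tokenizer, model_type="unsup"), batch_size=64)

for source in train_dataloader:
    print(source["input_ids"].shape)  # torch.Size([2, 2, 64]) for this toy data; [64, 2, 64] with a full batch
    break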
Model
import torch
import torch.nn as nn
from transformers import BertConfig, BertModel

class SimcseUnsupModel(nn.Module):
    def __init__(self, pretrained_bert_path, drop_out) -> None:
        super(SimcseUnsupModel, self).__init__()
        self.pretrained_bert_path = pretrained_bert_path
        config = BertConfig.from_pretrained(self.pretrained_bert_path)
        # override both dropout probabilities so the two forward passes of a sentence differ
        config.attention_probs_dropout_prob = drop_out
        config.hidden_dropout_prob = drop_out
        self.bert = BertModel.from_pretrained(self.pretrained_bert_path, config=config)

    def forward(self, input_ids, attention_mask, token_type_ids, pooling="cls"):
        out = self.bert(input_ids, attention_mask, token_type_ids, output_hidden_states=True)
        if pooling == "cls":
            return out.last_hidden_state[:, 0]
        if pooling == "pooler":
            return out.pooler_output
        if pooling == 'last-avg':
            last = out.last_hidden_state.transpose(1, 2)
            return torch.avg_pool1d(last, kernel_size=last.shape[-1]).squeeze(-1)
        if pooling == 'first-last-avg':
            first = out.hidden_states[1].transpose(1, 2)
            last = out.hidden_states[-1].transpose(1, 2)
            first_avg = torch.avg_pool1d(first, kernel_size=last.shape[-1]).squeeze(-1)
            last_avg = torch.avg_pool1d(last, kernel_size=last.shape[-1]).squeeze(-1)
            avg = torch.cat((first_avg.unsqueeze(1), last_avg.unsqueeze(1)), dim=1)
            return torch.avg_pool1d(avg.transpose(1, 2), kernel_size=2).squeeze(-1)

# Experiments show that cls pooling works best
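A quick sketch of how the encoder might be called (the model path, dropout value and toy sentence are assumptions); feeding one sentence twice yields two slightly different 768-dimensional vectors:

# Hypothetical usage sketch: encode the same sentence twice in one batch.
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")         # path is an assumption
model = SimcseUnsupModel(pretrained_bert_path="bert-base-chinese", drop_out=0.1)
model.train()                                                          # keep dropout active

batch = tokenizer(["今天天气不错", "今天天气不错"], max_length=64,
                  truncation=True, padding="max_length", return_tensors="pt")
vecs = model(batch["input_ids"], batch["attention_mask"], batch["token_type_ids"], pooling="cls")
print(vecs.shape)  # torch.Size([2, 768]) for a base-size Bert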
Careful readers will have noticed: isn't this SimCSE model just plain Bert? Yes. Compared with Bert, SimCSE only changes the dropout configuration and uses Bert itself for data augmentation; the real difference lies in the loss calculation, where SimCSE introduces a contrastive loss.
# (method of the training wrapper class in the repo; tqdm, logger, optimizer, etc. are set up elsewhere)
def train(self, train_dataloader, dev_dataloader):
    self.model.train()
    for batch_idx, source in enumerate(tqdm(train_dataloader), start=1):
        real_batch_num = source.get('input_ids').shape[0]  # source['input_ids'].shape: [64, 2, 64]
        # flatten the two views of every sentence into the batch dimension -> shape [128, 64]
        input_ids = source.get('input_ids').view(real_batch_num * 2, -1).to(self.device)
        attention_mask = source.get('attention_mask').view(real_batch_num * 2, -1).to(self.device)
        token_type_ids = source.get('token_type_ids').view(real_batch_num * 2, -1).to(self.device)
        out = self.model(input_ids, attention_mask, token_type_ids)  # out.shape: [128, 768]
        loss = self.simcse_unsup_loss(out)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        if batch_idx % 10 == 0:
            logger.info(f'loss: {loss.item():.4f}')
            corrcoef = self.eval(dev_dataloader)
            self.model.train()
            # keep the checkpoint with the highest Spearman correlation on the dev set
            if corrcoef > self.best_loss:
                self.best_loss = corrcoef
                torch.save(self.model.state_dict(), self.model_save_path)
                logger.info(f"higher corrcoef: {self.best_loss:.4f} in batch: {batch_idx}, save model")
def simcse_unsup_loss(self, y_pred):
    # y_pred: [batch_size * 2, 768]; rows 2i and 2i+1 are two dropout views of the same sentence
    y_true = torch.arange(y_pred.shape[0], device=self.device)
    y_true = (y_true - y_true % 2 * 2) + 1  # index of each row's positive example
    sim = F.cosine_similarity(y_pred.unsqueeze(1), y_pred.unsqueeze(0), dim=-1)
    sim = sim - torch.eye(y_pred.shape[0], device=self.device) * 1e12  # mask self-similarity
    sim = sim / 0.05  # temperature
    loss = F.cross_entropy(sim, y_true)
    return loss
In the train function, source is the tokenizer output collected by the dataloader; it contains three fields: input_ids, token_type_ids and attention_mask. Taking input_ids as an example, its first dimension is batch_size, its second dimension is the number of sentences fed to Bert, which is 2 because the same sentence is input twice, and its third dimension is max_length.
Next, let's walk through the loss calculation step by step:
1. For the 128 sentence vectors (batch_size 64 × 2 views), generate indices 0-127
y_true = torch.arange(y_pred.shape[0], device=self.device)
2. Generate the real label corresponding to each sentence
y_true = (y_true - y_true % 2 * 2) + 1
Note the difference between y_true in this step and y_true in the first step. Here y_true is actually the index, within the batch, of the positive example of each sentence, for example:
The sentence similar to sentence 0 has index 1
The sentence similar to sentence 1 has index 0
The sentence similar to sentence 2 has index 3
The sentence similar to sentence 3 has index 2
(Note that counting starts from sentence 0; see the small worked example below.)
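For a toy batch of 8 rows (4 sentences, each encoded twice), the two steps give:

import torch

y_true = torch.arange(8)                # tensor([0, 1, 2, 3, 4, 5, 6, 7])
y_true = (y_true - y_true % 2 * 2) + 1
print(y_true)                           # tensor([1, 0, 3, 2, 5, 4, 7, 6])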
3. Pairwise calculation of similarity
sim = F.cosine_similarity(y_pred.unsqueeze(1), y_pred.unsqueeze(0), dim=-1)
y_pred has shape [128, 768] and sim has shape [128, 128]. Each row holds the similarity between the current sentence and every other sentence; at this point the values on the diagonal are all 1.
4. Subtract a very large number from the diagonal, pushing the self-similarity towards negative infinity so that it has almost no influence on the loss (its contribution after the softmax inside the cross-entropy is essentially 0)
sim = sim - torch.eye(y_pred.shape[0], device=self.device) * 1e12
5. Divide by the temperature hyperparameter. As for why 0.05: we can only say that experiments show 0.05 works well (a small illustration of the effect follows the line of code below).
sim = sim / 0.05
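A small illustration of what dividing by the temperature does to the softmax (toy similarity values, not from the repo): it sharpens the distribution, so the model is pushed much harder to rank the positive example first.

import torch
import torch.nn.functional as F

sims = torch.tensor([0.9, 0.5, 0.1])        # toy cosine similarities
print(F.softmax(sims, dim=0))               # ~[0.47, 0.32, 0.21]: fairly flat
print(F.softmax(sims / 0.05, dim=0))        # ~[1.00, 0.00, 0.00]: much sharper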
6. Use cross-entropy to express the contrastive loss: finding the positive example is treated as a classification problem, which pulls positive pairs closer together and pushes negative pairs apart. Within a batch, the two Bert outputs of the same sentence are positive examples of each other, and all the other sentences in the batch are negative examples (a small sanity check follows the line of code below).
loss = F.cross_entropy(sim, y_true)
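As a sanity check (a sketch on random vectors, not from the repo): if the two views of each sentence were identical, every positive pair would have cosine similarity 1 and the loss would be close to 0.

import torch
import torch.nn.functional as F

batch = torch.randn(4, 768)                           # 4 toy "sentence vectors"
y_pred = batch.repeat_interleave(2, dim=0)            # duplicate each row -> shape [8, 768]
y_true = torch.arange(8)
y_true = (y_true - y_true % 2 * 2) + 1
sim = F.cosine_similarity(y_pred.unsqueeze(1), y_pred.unsqueeze(0), dim=-1)
sim = sim - torch.eye(8) * 1e12                       # mask self-similarity
print(F.cross_entropy(sim / 0.05, y_true))            # close to 0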
Effect
Supervised SimCSE
Data
Unlike the unsupervised setting, where the input is a single sentence, the supervised dataset consists of triplets of the form [text, text+, text-].
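A hypothetical example of one supervised sample (the sentences are invented for illustration):

# (anchor, entailment / positive, contradiction / hard negative)
sample = ("一个男人在弹吉他", "有人在演奏乐器", "一个男人在看书")
# TrainDataset(..., model_type="sup") tokenizes all three, giving tensors of shape [3, MAXLEN]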
Model
The model part is the same as in the unsupervised case: Bert is used as the encoder and the cls vector is taken as the sentence vector.
Let's focus on the part that differs: the loss calculation.
def simcse_sup_loss(self, y_pred):
    # y_pred: [batch_size * 3, 768]; rows 3i, 3i+1, 3i+2 are (text, text+, text-)
    y_true = torch.arange(y_pred.shape[0], device=self.device)
    # keep only the text and text+ rows as anchors; drop the hard-negative rows
    use_row = torch.where((y_true + 1) % 3 != 0)[0]
    # label of each kept row: text and text+ in the same triplet point to each other
    y_true = (use_row - use_row % 3 * 2) + 1
    sim = F.cosine_similarity(y_pred.unsqueeze(1), y_pred.unsqueeze(0), dim=-1)
    sim = sim - torch.eye(y_pred.shape[0], device=self.device) * 1e12  # mask self-similarity
    sim = torch.index_select(sim, 0, use_row)  # anchors only; all columns (including text-) are kept
    sim = sim / 0.05
    loss = F.cross_entropy(sim, y_true)
    return loss
1. For the 192 rows (batch_size 64 × 3 sentences per triplet), generate indices 0-191
y_true = torch.arange(y_pred.shape[0], device=self.device)
2. Select the rows to use. The third sentence in each triplet (text-) has no positive example of its own, so it is not used as an anchor row; it only serves as an extra hard negative. As in the unsupervised case, the other sentences in the batch are also treated as negative examples.
use_row = torch.where((y_true + 1) % 3 != 0)[0]
3. Generate the label (the index of the positive example) for each retained row; text and text+ in the same triplet point to each other, and the hard-negative rows get no label
y_true = (use_row - use_row % 3 * 2) + 1
4. Compute the pairwise similarity. At this point sim has shape [192, 192] and still includes the hard-negative rows.
sim = F.cosine_similarity(y_pred.unsqueeze(1), y_pred.unsqueeze(0), dim=-1)
5. Eliminate the influence of the values on the diagonal (self-similarity), as before
sim = sim - torch.eye(y_pred.shape[0], device=self.device) * 1e12
6. Pick out the useful rows (the anchors selected in step 2); sim now has shape [128, 192]
sim = torch.index_select(sim, 0, use_row)
7. Calculate cross-entropy loss, consistent with the unsupervised method
loss = F.cross_entropy(sim, y_true)
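For a toy batch of 2 triplets (6 rows), the index selection and labels from steps 2 and 3 work out to:

import torch

y_true = torch.arange(6)                        # tensor([0, 1, 2, 3, 4, 5])
use_row = torch.where((y_true + 1) % 3 != 0)[0]
print(use_row)                                  # tensor([0, 1, 3, 4]): hard-negative rows 2 and 5 dropped as anchors
y_true = (use_row - use_row % 3 * 2) + 1
print(y_true)                                   # tensor([1, 0, 4, 3]): text and text+ point to each other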
Effect
Summary
The greatest truths are the simplest.
All the code has been uploaded to Github: https://github.com/seanzhang-zhichen/simcse-pytorch
Dataset: Extraction code: hlva