Sentence Vector Model SimCSE (PyTorch)

Model Introduction

The SimCSE model comes in two parts: an unsupervised version and a supervised version. The overall structure is shown in the figure below:

[Figure: overall structure of SimCSE, unsupervised (left) and supervised (right)]

Paper address: https://arxiv.org/pdf/2104.08821.pdf

Unsupervised SimCSE

Data

For the unsupervised part, the clever trick is to use dropout as data augmentation to construct positives: the same sentence is encoded twice, the two resulting vectors form a positive pair, and all other sentences in the same batch serve as negatives.

You may ask: why does the same sentence produce two different vectors when it is fed to the model twice?

The reason is that the model contains dropout layers. During training, neurons are randomly deactivated, so the same sentence passed through the model twice yields two different outputs.
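A minimal sketch (not the project's code) of this effect: with dropout active in training mode, two forward passes over the identical input produce different outputs.

import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Sequential(nn.Linear(16, 16), nn.Dropout(p=0.3))
layer.train()                       # dropout is only active in training mode
x = torch.ones(1, 16)
out1, out2 = layer(x), layer(x)     # two passes over the same input
print(torch.allclose(out1, out2))   # almost always False: different dropout masks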

The project's data loading code makes this more concrete:

from torch.utils.data import Dataset

# MAXLEN (the maximum sequence length, 64 in the author's setup) is assumed to be defined globally
class TrainDataset(Dataset):
    def __init__(self, data, tokenizer, model_type="unsup"):
        self.data = data
        self.tokenizer = tokenizer
        self.model_type = model_type

    def text2id(self, text):
        if self.model_type == "unsup":
            # encode the same sentence twice; dropout will make the two vectors differ
            text_ids = self.tokenizer([text, text], max_length=MAXLEN, truncation=True, padding='max_length', return_tensors='pt')
        elif self.model_type == "sup":
            # text is a (text, text+, text-) triplet
            text_ids = self.tokenizer([text[0], text[1], text[2]], max_length=MAXLEN, truncation=True, padding='max_length', return_tensors='pt')

        return text_ids

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.text2id(self.data[index])
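A hypothetical usage sketch (the checkpoint name and toy sentences are assumptions, and MAXLEN, e.g. 64, must be defined):

from torch.utils.data import DataLoader
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")   # assumed checkpoint
sentences = ["今天天气不错", "我喜欢自然语言处理"]                  # toy unsupervised data
train_dataset = TrainDataset(sentences, tokenizer, model_type="unsup")
train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True)

batch = next(iter(train_dataloader))
print(batch['input_ids'].shape)  # torch.Size([2, 2, MAXLEN]): each sentence is tokenized twice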

As you can see, the same sentence is passed through the BERT encoder twice, producing two similar sentence vectors that are treated as a positive pair.

Model

import torch
import torch.nn as nn
from transformers import BertConfig, BertModel

class SimcseUnsupModel(nn.Module):
    def __init__(self, pretrained_bert_path, drop_out) -> None:
        super(SimcseUnsupModel, self).__init__()

        self.pretrained_bert_path = pretrained_bert_path
        config = BertConfig.from_pretrained(self.pretrained_bert_path)
        # override BERT's dropout rates: dropout is the only "augmentation" SimCSE needs
        config.attention_probs_dropout_prob = drop_out
        config.hidden_dropout_prob = drop_out
        self.bert = BertModel.from_pretrained(self.pretrained_bert_path, config=config)

    def forward(self, input_ids, attention_mask, token_type_ids, pooling="cls"):
        out = self.bert(input_ids, attention_mask, token_type_ids, output_hidden_states=True)

        if pooling == "cls":
            return out.last_hidden_state[:, 0]
        if pooling == "pooler":
            return out.pooler_output
        if pooling == 'last-avg':
            last = out.last_hidden_state.transpose(1, 2)
            return torch.avg_pool1d(last, kernel_size=last.shape[-1]).squeeze(-1)
        if pooling == 'first-last-avg':
            first = out.hidden_states[1].transpose(1, 2)
            last = out.hidden_states[-1].transpose(1, 2)
            first_avg = torch.avg_pool1d(first, kernel_size=last.shape[-1]).squeeze(-1)
            last_avg = torch.avg_pool1d(last, kernel_size=last.shape[-1]).squeeze(-1)
            avg = torch.cat((first_avg.unsqueeze(1), last_avg.unsqueeze(1)), dim=1)
            return torch.avg_pool1d(avg.transpose(1, 2), kernel_size=2).squeeze(-1)

        # Experiments suggest that CLS pooling works best
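A hypothetical usage sketch for encoding a sentence (checkpoint name, dropout value and max length are assumptions):

from transformers import BertTokenizer

model = SimcseUnsupModel("bert-base-chinese", drop_out=0.1)
tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
model.eval()   # disable dropout at inference time
inputs = tokenizer(["今天天气不错"], max_length=64, truncation=True,
                   padding='max_length', return_tensors='pt')
vec = model(inputs['input_ids'], inputs['attention_mask'], inputs['token_type_ids'])
print(vec.shape)  # torch.Size([1, 768]) with the default CLS pooling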

Careful readers will have noticed that this is obviously just BERT; where is the SimCSE part?
Exactly: compared with BERT, SimCSE only changes the dropout rate, using BERT's own dropout for data augmentation. The real difference is in the loss, where SimCSE introduces a contrastive loss.
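For reference, this is the unsupervised training objective from the paper: for a batch of N sentences, with h_i and h_i^+ denoting the two embeddings of sentence i obtained from two dropout-masked forward passes,

\ell_i = -\log \frac{e^{\mathrm{sim}(h_i,\, h_i^{+})/\tau}}{\sum_{j=1}^{N} e^{\mathrm{sim}(h_i,\, h_j^{+})/\tau}}

where sim(·, ·) is cosine similarity and τ is the temperature (0.05 in the code below).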

    def train(self, train_dataloader, dev_dataloader):
        self.model.train()
        for batch_idx, source in enumerate(tqdm(train_dataloader), start=1):
            # source['input_ids'] has shape [batch_size, 2, seq_len], e.g. [64, 2, 64];
            # flatten the pair dimension so BERT sees 2 * batch_size sequences
            real_batch_num = source.get('input_ids').shape[0]
            input_ids = source.get('input_ids').view(real_batch_num * 2, -1).to(self.device)            # [128, 64]
            attention_mask = source.get('attention_mask').view(real_batch_num * 2, -1).to(self.device)  # [128, 64]
            token_type_ids = source.get('token_type_ids').view(real_batch_num * 2, -1).to(self.device)  # [128, 64]

            out = self.model(input_ids, attention_mask, token_type_ids)  # [128, 768]
            loss = self.simcse_unsup_loss(out)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            if batch_idx % 10 == 0:
                logger.info(f'loss: {loss.item():.4f}')
                corrcoef = self.eval(dev_dataloader)
                self.model.train()
                # save whenever the dev Spearman correlation improves
                # (assumes self.best_loss is initialized to a low value such as 0.0)
                if corrcoef > self.best_loss:
                    self.best_loss = corrcoef
                    torch.save(self.model.state_dict(), self.model_save_path)
                    logger.info(f"higher corrcoef: {self.best_loss:.4f} in batch: {batch_idx}, save model")

    def simcse_unsup_loss(self, y_pred):
        # y_pred: [2 * batch_size, 768]; rows 2k and 2k+1 come from the same sentence
        y_true = torch.arange(y_pred.shape[0], device=self.device)
        y_true = (y_true - y_true % 2 * 2) + 1                              # index of each row's positive
        sim = F.cosine_similarity(y_pred.unsqueeze(1), y_pred.unsqueeze(0), dim=-1)
        sim = sim - torch.eye(y_pred.shape[0], device=self.device) * 1e12   # mask self-similarity
        sim = sim / 0.05                                                    # temperature
        loss = F.cross_entropy(sim, y_true)
        return loss

In train, source (the output of the tokenizer) contains three tensors: input_ids, token_type_ids and attention_mask. The first dimension is batch_size, the second dimension is the number of sentences fed to BERT (2, since the same sentence is input twice), and the third dimension is max_length.

Next, let's walk through the loss calculation step by step:

1. For the 128 sentences, generate indices 0-127

y_true = torch.arange(y_pred.shape[0], device=self.device)


2. Generate the target label for each sentence

y_true = (y_true - y_true % 2 * 2) + 1

Note the difference between y_true here and y_true from the first step.

Here, y_true is actually the index of each sentence's positive example within the batch, for example (see the quick check after the list):

The sentence similar to sentence 0 is at index 1
The sentence similar to sentence 1 is at index 0

The sentence similar to sentence 2 is at index 3
The sentence similar to sentence 3 is at index 2

Note that counting starts from sentence 0.
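A quick check of this index trick on a toy batch of 6 vectors:

import torch

y_true = torch.arange(6)
print((y_true - y_true % 2 * 2) + 1)  # tensor([1, 0, 3, 2, 5, 4])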

3. Pairwise calculation of similarity

sim = F.cosine_similarity(y_pred.unsqueeze(1), y_pred.unsqueeze(0), dim=-1)

y_pred has shape [128, 768]

sim has shape [128, 128]

Each row holds the similarities between one sentence and every sentence in the batch; at this point the values on the diagonal are all 1.
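A shape sketch of the broadcasting involved (toy sizes, not the real batch):

import torch
import torch.nn.functional as F

y_pred = torch.randn(6, 768)                    # 6 sentence vectors
sim = F.cosine_similarity(y_pred.unsqueeze(1),  # [6, 1, 768]
                          y_pred.unsqueeze(0),  # [1, 6, 768]
                          dim=-1)               # broadcasts to [6, 6]
print(sim.shape)                                # torch.Size([6, 6])
print(sim.diagonal())                           # all ~1.0 (each vector vs. itself)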

4. Push the diagonal values down to a very large negative number so that a sentence's similarity with itself does not affect the loss (the exponent of a value near negative infinity is essentially 0 in the cross-entropy softmax)

sim = sim - torch.eye(y_pred.shape[0], device=self.device) * 1e12

5. Divide by the temperature hyperparameter. As for why 0.05: experiments simply show that 0.05 works well.

sim = sim / 0.05

6. Use cross-entropy to implement the contrastive loss: finding each sentence's positive is treated as a classification over the batch, pulling positives closer and pushing negatives apart. For each sentence, the other encoding of the same sentence (fed through BERT twice) is the positive, and all other sentences in the batch are negatives.

loss = F.cross_entropy(sim, y_true)
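To see the contrastive objective in action, here is a standalone sketch of the same six steps run on toy tensors (not the project's training code): when the two encodings of each pair are nearly identical, the loss is far smaller than for unrelated vectors.

import torch
import torch.nn.functional as F

def unsup_loss(y_pred, temperature=0.05):
    # y_pred: [2 * batch_size, hidden]; rows 2k and 2k+1 encode the same sentence
    y_true = torch.arange(y_pred.shape[0])
    y_true = (y_true - y_true % 2 * 2) + 1
    sim = F.cosine_similarity(y_pred.unsqueeze(1), y_pred.unsqueeze(0), dim=-1)
    sim = sim - torch.eye(y_pred.shape[0]) * 1e12
    return F.cross_entropy(sim / temperature, y_true)

torch.manual_seed(0)
base = torch.randn(4, 768)
good = base.repeat_interleave(2, dim=0) + 0.01 * torch.randn(8, 768)  # pairs are near-duplicates
rand = torch.randn(8, 768)                                            # pairs are unrelated
print(unsup_loss(good).item(), unsup_loss(rand).item())  # the first is much smaller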

Effect

[Figure: unsupervised SimCSE evaluation results]

Supervised SimCSE

Data

Unlike the unsupervised version, whose input is a single text, the supervised dataset consists of [text, text+, text-] triplets (anchor, positive, hard negative).
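A hypothetical triplet and how the "sup" branch of TrainDataset above handles it (the sentences and checkpoint name are made up for illustration; MAXLEN must be defined):

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
sample = ("一个男人在弹吉他", "一个男人在演奏乐器", "一个男人在打篮球")  # (text, text+, text-)
sup_dataset = TrainDataset([sample], tokenizer, model_type="sup")
print(sup_dataset[0]['input_ids'].shape)  # torch.Size([3, MAXLEN]): anchor, positive, hard negative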

Model

The model part is the same as in the unsupervised case: BERT is used as the encoder and the CLS vector is taken as the sentence embedding.

Let's focus on the part that differs, the loss calculation:

    def simcse_sup_loss(self, y_pred):
        # y_pred: [3 * batch_size, 768]; rows come in (text, text+, text-) triplets
        y_true = torch.arange(y_pred.shape[0], device=self.device)
        use_row = torch.where((y_true + 1) % 3 != 0)[0]                     # drop every text- row as an anchor
        y_true = (use_row - use_row % 3 * 2) + 1                            # column index of each kept row's positive
        sim = F.cosine_similarity(y_pred.unsqueeze(1), y_pred.unsqueeze(0), dim=-1)
        sim = sim - torch.eye(y_pred.shape[0], device=self.device) * 1e12   # mask self-similarity
        sim = torch.index_select(sim, 0, use_row)                           # keep only text / text+ rows
        sim = sim / 0.05                                                    # temperature
        loss = F.cross_entropy(sim, y_true)
        return loss

1. Generate indices 0-191 (a batch of 64 triplets gives 192 rows)

y_true = torch.arange(y_pred.shape[0], device=self.device)

2. Select the row indices to use. The third sentence of each triplet (text-) has no positive of its own, so its row is not kept as an anchor; it only serves as a hard negative, while the other sentences in the batch are still treated as negatives as well.

use_row = torch.where((y_true + 1) % 3 != 0)[0]

3. Generate the real label (the column index of the positive) for each kept row; the rows of the third sentences have already been discarded.

y_true = (use_row - use_row % 3 * 2) + 1

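A quick check with a toy batch of two triplets (6 rows ordered text, text+, text-):

import torch

row_idx = torch.arange(6)
use_row = torch.where((row_idx + 1) % 3 != 0)[0]
print(use_row)                          # tensor([0, 1, 3, 4]): the text- rows 2 and 5 are dropped
print((use_row - use_row % 3 * 2) + 1)  # tensor([1, 0, 4, 3]): each kept row's positive column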

4. Compute pairwise similarities. sim has shape [192, 192] at this point, so the hard negatives (text-) are included as columns.

sim = F.cosine_similarity(y_pred.unsqueeze(1), y_pred.unsqueeze(0), dim=-1)

5. Eliminate the influence of the diagonal (self-similarity), just as in the unsupervised case

sim = sim - torch.eye(y_pred.shape[0], device=self.device) * 1e12

6. Keep only the useful rows (the text and text+ rows)

sim = torch.index_select(sim, 0, use_row)


7. Calculate cross-entropy loss, consistent with the unsupervised method

loss = F.cross_entropy(sim, y_true)
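As with the unsupervised loss, the whole calculation can be written as a standalone sketch and sanity-checked on toy tensors (temperature 0.05 as above; not part of the training code):

import torch
import torch.nn.functional as F

def simcse_sup_loss(y_pred, temperature=0.05):
    # y_pred: [3 * batch_size, hidden], rows ordered (text, text+, text-) per example
    device = y_pred.device
    row_idx = torch.arange(y_pred.shape[0], device=device)
    use_row = torch.where((row_idx + 1) % 3 != 0)[0]        # drop the text- anchor rows
    y_true = (use_row - use_row % 3 * 2) + 1                # positive column for each kept row
    sim = F.cosine_similarity(y_pred.unsqueeze(1), y_pred.unsqueeze(0), dim=-1)
    sim = sim - torch.eye(y_pred.shape[0], device=device) * 1e12   # mask self-similarity
    sim = torch.index_select(sim, 0, use_row)               # keep only text / text+ rows
    return F.cross_entropy(sim / temperature, y_true)

print(simcse_sup_loss(torch.randn(6, 768)))  # two random (text, text+, text-) triplets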

Effect

[Figure: supervised SimCSE evaluation results]

Summary

The great way is simple: SimCSE barely changes the model itself, and the gains come from the contrastive objective.

All codes have been uploaded to Github, link: https://github.com/seanzhang-zhichen/simcse-pytorch

Dataset: Extraction code: hlva


Original post: https://blog.csdn.net/qq_44193969/article/details/126981581