SimCSE del modelo de vector de oración - Pytorch

Introducción del modelo

SimCSEEl modelo se divide principalmente en dos partes, una es la parte no supervisada y la otra es la parte supervisada. La estructura general se muestra en la siguiente figura:

Por favor agregue una descripción de la imagen.

Dirección del artículo: https://arxiv.org/pdf/2104.08821.pdf

SimCSE no supervisado

datos

Para la parte no supervisada, lo más ingenioso es utilizar Dropoutla mejora de datos para construir ejemplos positivos, construyendo así un par de muestras positivas, mientras que los pares de muestras negativas son batchotras oraciones en la misma oración.

Entonces alguien puede preguntar, ¿por qué una oración obtiene dos vectores diferentes cuando se ingresa al modelo dos veces?

Esto se debe a que hay dropoutcapas en el modelo y la desactivación aleatoria de neuronas hará que se ingrese la misma oración en el modelo durante la fase de entrenamiento para obtener diferentes resultados.

Mirando el código, es más intuitivo:

class TrainDataset(Dataset):
    def __init__(self, data, tokenizer, model_type="unsup"):
        self.data = data
        self.tokenizer = tokenizer
        self.model_type = model_type

    def text2id(self, text):
        if self.model_type == "unsup":
            text_ids = self.tokenizer([text, text], max_length=MAXLEN, truncation=True, padding='max_length', return_tensors='pt')
        elif self.model_type == "sup":
            text_ids = self.tokenizer([text[0], text[1], text[2]], max_length=MAXLEN, truncation=True, padding='max_length', return_tensors='pt')

        return text_ids

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        return self.text2id(self.data[index])

Se puede ver que si la misma oración se repite dos veces Bert, Encoderse generarán dos vectores de oraciones similares, que se consideran ejemplos positivos.

Modelo

class SimcseUnsupModel(nn.Module):
    def __init__(self, pretrained_bert_path, drop_out) -> None:
        super(SimcseUnsupModel, self).__init__()

        self.pretrained_bert_path = pretrained_bert_path
        config = BertConfig.from_pretrained(self.pretrained_bert_path)
        config.attention_probs_dropout_prob = drop_out
        config.hidden_dropout_prob = drop_out
        self.bert = BertModel.from_pretrained(self.pretrained_bert_path, config=config)
    
    def forward(self, input_ids, attention_mask, token_type_ids, pooling="cls"):
        out = self.bert(input_ids, attention_mask, token_type_ids, output_hidden_states=True)

        if pooling == "cls":
            return out.last_hidden_state[:, 0]
        if pooling == "pooler":
            return out.pooler_output
        if pooling == 'last-avg':
            last = out.last_hidden_state.transpose(1, 2)
            return torch.avg_pool1d(last, kernel_size=last.shape[-1]).squeeze(-1)
        if self.pooling == 'first-last-avg':
            first = out.hidden_states[1].transpose(1, 2)
            last = out.hidden_states[-1].transpose(1, 2)
            first_avg = torch.avg_pool1d(first, kernel_size=last.shape[-1]).squeeze(-1)
            last_avg = torch.avg_pool1d(last, kernel_size=last.shape[-1]).squeeze(-1)
            avg = torch.cat((first_avg.unsqueeze(1), last_avg.unsqueeze(1)), dim=1)
            return torch.avg_pool1d(avg.transpose(1, 2), kernel_size=2).squeeze(-1)
        
        # 有实验表明cls的pooling方式效果最好

Los estudiantes cuidadosos han descubierto que, obviamente simcse, es lo bertmismo.
Sí, en comparación con él Bertsolo Simcsecambió drop_out, se usó Bertpara mejorar los datos, pero en el cálculo Loss, Simcsese introdujo el contraste.Loss

    def train(self, train_dataloader, dev_dataloader):
        self.model.train()
        for batch_idx, source in enumerate(tqdm(train_dataloader), start=1):
            real_batch_num = source.get('input_ids').shape[0] # source.get('input_ids').shape [64, 2, 64]
            input_ids = source.get('input_ids').view(real_batch_num * 2, -1).to(self.device) # shape[128, 64]
            attention_mask = source.get('attention_mask').view(real_batch_num * 2, -1).to(self.device) # shape[128, 64]
            token_type_ids = source.get('token_type_ids').view(real_batch_num * 2, -1).to(self.device) # shape[128, 64]

            out = self.model(input_ids, attention_mask, token_type_ids) # out.shape [128, 768]  
            loss = self.simcse_unsup_loss(out)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            if batch_idx % 10 == 0:     
                logger.info(f'loss: {
      
      loss.item():.4f}')
                corrcoef = self.eval(dev_dataloader)
                self.model.train()
                if self.best_loss > corrcoef:
                    self.best_loss = corrcoef
                    torch.save(self.model.state_dict(), self.model_save_path)
                    logger.info(f"higher corrcoef: {
      
      self.best_loss:.4f} in batch: {
      
      batch_idx}, save model")


    def simcse_unsup_loss(self, y_pred):
        y_true = torch.arange(y_pred.shape[0], device=self.device)
        y_true = (y_true - y_true % 2 * 2) + 1
        sim = F.cosine_similarity(y_pred.unsqueeze(1), y_pred.unsqueeze(0), dim=-1)
        sim = sim - torch.eye(y_pred.shape[0], device=self.device) * 1e12
        sim = sim / 0.05
        loss = F.cross_entropy(sim, y_true)
        return loss

trainLa salida de sourcefor en la función contiene tres segmentos cortos: , , como se muestra en la siguiente figura: la primera dimensión de es , la segunda dimensión es el número de oraciones de entrada y se ingresan dos oraciones (la misma oración se ingresa dos veces) , entonces la primera Las dos dimensiones son 2 y la tercera dimensión son oraciones.Berttokenizerinput_idstoken_type_idsattention_mask

Por favor agregue una descripción de la imagen.
input_idsbatch_sizebertmax_length

A continuación, veremos lossel proceso de cálculo y analizaremos cada paso:

1. Dadas 128 oraciones, genere un índice de 0-127

y_true = torch.arange(y_pred.shape[0], device=self.device)

Por favor agregue una descripción de la imagen.

2. Genera etiquetas reales correspondientes a cada frase.

y_true = (y_true - y_true % 2 * 2) + 1

Por favor agregue una descripción de la imagen.
Preste atención a la diferencia entre este paso y_truey el primer paso y_true.

Aquí y_trueestá en realidad el índice del ejemplo positivo correspondiente a cada frase de ésta batch, por ejemplo:

与第0个句子相似的句子索引为1
与第1个句子相似的句子索引为0

与第2个句子相似的句子索引为3
与第2个句子相似的句子索引为2

Tenga en cuenta que comencé a contar desde la oración 0

3. Calcula la similitud entre dos pares.

sim = F.cosine_similarity(y_pred.unsqueeze(1), y_pred.unsqueeze(0), dim=-1)

y_predLas dimensiones son [128, 768]

simLas dimensiones son [128, 128]

Cada línea representa la similitud entre la oración actual y otras oraciones, en este momento el valor en la diagonal debe ser 1Por favor agregue una descripción de la imagen.

4. Amplifique el valor de la diagonal a un número mayor para eliminar lossla influencia de la diagonal misma (el infinito negativo es casi 0 al calcular la entropía cruzada)

sim = sim - torch.eye(y_pred.shape[0], device=self.device) * 1e12

5. Multiplique por el coeficiente de temperatura del hiperparámetro. En cuanto a por qué es así 0.05, solo se puede decir que el experimento muestra que 0.05el efecto es bueno.

sim = sim / 0.05

6. Utilice la pérdida de entropía cruzada para representar la pérdida de comparación, trate oraciones similares como categorías, acorte la distancia de los ejemplos positivos y acorte la distancia de los ejemplos negativos. En la misma oración, excepto las oraciones ingresadas dos veces como ejemplos positivos, batchotros bertnegativos ejemplos

loss = F.cross_entropy(sim, y_true)

Efecto

Por favor agregue una descripción de la imagen.

SimCSE supervisado

datos

A diferencia de lo no supervisado, la entrada de lo no supervisado es una sola textoración, mientras que el conjunto de datos supervisados ​​es [text, text+, text-]un triplete de
Por favor agregue una descripción de la imagen.

Modelo

La parte del modelo es la misma que la supervisada y también se utiliza bertpara encodecodificación y clsextracción de vectores de oraciones.

Centrémonos en las diferentes partes y losscalculemos:

    def simcse_sup_loss(self, y_pred):
        y_true = torch.arange(y_pred.shape[0], device=self.device)
        use_row = torch.where((y_true + 1) % 3 != 0)[0]
        y_true = (use_row - use_row % 3 * 2) + 1
        sim = F.cosine_similarity(y_pred.unsqueeze(1), y_pred.unsqueeze(0), dim=-1)
        sim = sim - torch.eye(y_pred.shape[0], device=self.device) * 1e12
        sim = torch.index_select(sim, 0, use_row)
        sim = sim / 0.05
        loss = F.cross_entropy(sim, y_true)
        return loss

1. Generar un índice de 0-191

y_true = torch.arange(y_pred.shape[0], device=self.device)

2. Seleccione el índice a usar. Si no hay una tercera oración label, la tercera oración será un ejemplo negativo. Si no se usa la tercera oración, batchotras oraciones en la misma oración se considerarán ejemplos negativos.

use_row = torch.where((y_true + 1) % 3 != 0)[0]

3. Deseche la etiqueta real después de la tercera oración.

y_true = (use_row - use_row % 3 * 2) + 1

Por favor agregue una descripción de la imagen.

4. Calcule la similitud en pares, y simla dimensión en este momento es [192, 192], incluido el ejemplo negativo de la tercera oración.

sim = F.cosine_similarity(y_pred.unsqueeze(1), y_pred.unsqueeze(0), dim=-1)

Por favor agregue una descripción de la imagen.
5. Eliminar la influencia de la dimensión en la diagonal.

sim = sim - torch.eye(y_pred.shape[0], device=self.device) * 1e12

6. Elija filas útiles

sim = torch.index_select(sim, 0, use_row)

Por favor agregue una descripción de la imagen.

7. Calcule la pérdida de entropía cruzada, de acuerdo con el método no supervisado.

loss = F.cross_entropy(sim, y_true)

Efecto

Por favor agregue una descripción de la imagen.

Resumir

Del gran camino a lo simple

Todos los códigos se han subido al Githubenlace: https://github.com/seanzhang-zhichen/simcse-pytorch

Conjunto de datos: Código de extracción: hlva

Supongo que te gusta

Origin blog.csdn.net/qq_44193969/article/details/126981581
Recomendado
Clasificación