Directorio de artículos
Introducción del modelo
SimCSE
El modelo se divide principalmente en dos partes, una es la parte no supervisada y la otra es la parte supervisada. La estructura general se muestra en la siguiente figura:
Dirección del artículo: https://arxiv.org/pdf/2104.08821.pdf
SimCSE no supervisado
datos
Para la parte no supervisada, lo más ingenioso es utilizar Dropout
la mejora de datos para construir ejemplos positivos, construyendo así un par de muestras positivas, mientras que los pares de muestras negativas son batch
otras oraciones en la misma oración.
Entonces alguien puede preguntar, ¿por qué una oración obtiene dos vectores diferentes cuando se ingresa al modelo dos veces?
Esto se debe a que hay dropout
capas en el modelo y la desactivación aleatoria de neuronas hará que se ingrese la misma oración en el modelo durante la fase de entrenamiento para obtener diferentes resultados.
Mirando el código, es más intuitivo:
class TrainDataset(Dataset):
def __init__(self, data, tokenizer, model_type="unsup"):
self.data = data
self.tokenizer = tokenizer
self.model_type = model_type
def text2id(self, text):
if self.model_type == "unsup":
text_ids = self.tokenizer([text, text], max_length=MAXLEN, truncation=True, padding='max_length', return_tensors='pt')
elif self.model_type == "sup":
text_ids = self.tokenizer([text[0], text[1], text[2]], max_length=MAXLEN, truncation=True, padding='max_length', return_tensors='pt')
return text_ids
def __len__(self):
return len(self.data)
def __getitem__(self, index):
return self.text2id(self.data[index])
Se puede ver que si la misma oración se repite dos veces Bert
, Encoder
se generarán dos vectores de oraciones similares, que se consideran ejemplos positivos.
Modelo
class SimcseUnsupModel(nn.Module):
def __init__(self, pretrained_bert_path, drop_out) -> None:
super(SimcseUnsupModel, self).__init__()
self.pretrained_bert_path = pretrained_bert_path
config = BertConfig.from_pretrained(self.pretrained_bert_path)
config.attention_probs_dropout_prob = drop_out
config.hidden_dropout_prob = drop_out
self.bert = BertModel.from_pretrained(self.pretrained_bert_path, config=config)
def forward(self, input_ids, attention_mask, token_type_ids, pooling="cls"):
out = self.bert(input_ids, attention_mask, token_type_ids, output_hidden_states=True)
if pooling == "cls":
return out.last_hidden_state[:, 0]
if pooling == "pooler":
return out.pooler_output
if pooling == 'last-avg':
last = out.last_hidden_state.transpose(1, 2)
return torch.avg_pool1d(last, kernel_size=last.shape[-1]).squeeze(-1)
if self.pooling == 'first-last-avg':
first = out.hidden_states[1].transpose(1, 2)
last = out.hidden_states[-1].transpose(1, 2)
first_avg = torch.avg_pool1d(first, kernel_size=last.shape[-1]).squeeze(-1)
last_avg = torch.avg_pool1d(last, kernel_size=last.shape[-1]).squeeze(-1)
avg = torch.cat((first_avg.unsqueeze(1), last_avg.unsqueeze(1)), dim=1)
return torch.avg_pool1d(avg.transpose(1, 2), kernel_size=2).squeeze(-1)
# 有实验表明cls的pooling方式效果最好
Los estudiantes cuidadosos han descubierto que, obviamente simcse
, es lo bert
mismo.
Sí, en comparación con él Bert
solo Simcse
cambió drop_out
, se usó Bert
para mejorar los datos, pero en el cálculo Loss
, Simcse
se introdujo el contraste.Loss
def train(self, train_dataloader, dev_dataloader):
self.model.train()
for batch_idx, source in enumerate(tqdm(train_dataloader), start=1):
real_batch_num = source.get('input_ids').shape[0] # source.get('input_ids').shape [64, 2, 64]
input_ids = source.get('input_ids').view(real_batch_num * 2, -1).to(self.device) # shape[128, 64]
attention_mask = source.get('attention_mask').view(real_batch_num * 2, -1).to(self.device) # shape[128, 64]
token_type_ids = source.get('token_type_ids').view(real_batch_num * 2, -1).to(self.device) # shape[128, 64]
out = self.model(input_ids, attention_mask, token_type_ids) # out.shape [128, 768]
loss = self.simcse_unsup_loss(out)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
if batch_idx % 10 == 0:
logger.info(f'loss: {
loss.item():.4f}')
corrcoef = self.eval(dev_dataloader)
self.model.train()
if self.best_loss > corrcoef:
self.best_loss = corrcoef
torch.save(self.model.state_dict(), self.model_save_path)
logger.info(f"higher corrcoef: {
self.best_loss:.4f} in batch: {
batch_idx}, save model")
def simcse_unsup_loss(self, y_pred):
y_true = torch.arange(y_pred.shape[0], device=self.device)
y_true = (y_true - y_true % 2 * 2) + 1
sim = F.cosine_similarity(y_pred.unsqueeze(1), y_pred.unsqueeze(0), dim=-1)
sim = sim - torch.eye(y_pred.shape[0], device=self.device) * 1e12
sim = sim / 0.05
loss = F.cross_entropy(sim, y_true)
return loss
train
La salida de source
for en la función contiene tres segmentos cortos: , , como se muestra en la siguiente figura: la primera dimensión de es , la segunda dimensión es el número de oraciones de entrada y se ingresan dos oraciones (la misma oración se ingresa dos veces) , entonces la primera Las dos dimensiones son 2 y la tercera dimensión son oraciones.Bert
tokenizer
input_ids
token_type_ids
attention_mask
input_ids
batch_size
bert
max_length
A continuación, veremos loss
el proceso de cálculo y analizaremos cada paso:
1. Dadas 128 oraciones, genere un índice de 0-127
y_true = torch.arange(y_pred.shape[0], device=self.device)
2. Genera etiquetas reales correspondientes a cada frase.
y_true = (y_true - y_true % 2 * 2) + 1
Preste atención a la diferencia entre este paso y_true
y el primer paso y_true
.
Aquí y_true
está en realidad el índice del ejemplo positivo correspondiente a cada frase de ésta batch
, por ejemplo:
与第0个句子相似的句子索引为1
与第1个句子相似的句子索引为0
与第2个句子相似的句子索引为3
与第2个句子相似的句子索引为2
Tenga en cuenta que comencé a contar desde la oración 0
3. Calcula la similitud entre dos pares.
sim = F.cosine_similarity(y_pred.unsqueeze(1), y_pred.unsqueeze(0), dim=-1)
y_pred
Las dimensiones son [128, 768]
sim
Las dimensiones son [128, 128]
Cada línea representa la similitud entre la oración actual y otras oraciones, en este momento el valor en la diagonal debe ser 1
4. Amplifique el valor de la diagonal a un número mayor para eliminar loss
la influencia de la diagonal misma (el infinito negativo es casi 0 al calcular la entropía cruzada)
sim = sim - torch.eye(y_pred.shape[0], device=self.device) * 1e12
5. Multiplique por el coeficiente de temperatura del hiperparámetro. En cuanto a por qué es así 0.05
, solo se puede decir que el experimento muestra que 0.05
el efecto es bueno.
sim = sim / 0.05
6. Utilice la pérdida de entropía cruzada para representar la pérdida de comparación, trate oraciones similares como categorías, acorte la distancia de los ejemplos positivos y acorte la distancia de los ejemplos negativos. En la misma oración, excepto las oraciones ingresadas dos veces como ejemplos positivos, batch
otros bert
negativos ejemplos
loss = F.cross_entropy(sim, y_true)
Efecto
SimCSE supervisado
datos
A diferencia de lo no supervisado, la entrada de lo no supervisado es una sola text
oración, mientras que el conjunto de datos supervisados es [text, text+, text-]
un triplete de
Modelo
La parte del modelo es la misma que la supervisada y también se utiliza bert
para encode
codificación y cls
extracción de vectores de oraciones.
Centrémonos en las diferentes partes y loss
calculemos:
def simcse_sup_loss(self, y_pred):
y_true = torch.arange(y_pred.shape[0], device=self.device)
use_row = torch.where((y_true + 1) % 3 != 0)[0]
y_true = (use_row - use_row % 3 * 2) + 1
sim = F.cosine_similarity(y_pred.unsqueeze(1), y_pred.unsqueeze(0), dim=-1)
sim = sim - torch.eye(y_pred.shape[0], device=self.device) * 1e12
sim = torch.index_select(sim, 0, use_row)
sim = sim / 0.05
loss = F.cross_entropy(sim, y_true)
return loss
1. Generar un índice de 0-191
y_true = torch.arange(y_pred.shape[0], device=self.device)
2. Seleccione el índice a usar. Si no hay una tercera oración label
, la tercera oración será un ejemplo negativo. Si no se usa la tercera oración, batch
otras oraciones en la misma oración se considerarán ejemplos negativos.
use_row = torch.where((y_true + 1) % 3 != 0)[0]
3. Deseche la etiqueta real después de la tercera oración.
y_true = (use_row - use_row % 3 * 2) + 1
4. Calcule la similitud en pares, y sim
la dimensión en este momento es [192, 192]
, incluido el ejemplo negativo de la tercera oración.
sim = F.cosine_similarity(y_pred.unsqueeze(1), y_pred.unsqueeze(0), dim=-1)
5. Eliminar la influencia de la dimensión en la diagonal.
sim = sim - torch.eye(y_pred.shape[0], device=self.device) * 1e12
6. Elija filas útiles
sim = torch.index_select(sim, 0, use_row)
7. Calcule la pérdida de entropía cruzada, de acuerdo con el método no supervisado.
loss = F.cross_entropy(sim, y_true)
Efecto
Resumir
Del gran camino a lo simple
Todos los códigos se han subido al Github
enlace: https://github.com/seanzhang-zhichen/simcse-pytorch
Conjunto de datos: Código de extracción: hlva