Implementación de OpenAI CLIP en conjuntos de datos personalizados

En enero de 2021, OpenAI anunció dos nuevos modelos: DALL-E y CLIP, ambos modelos multimodales que conectan texto e imágenes de alguna manera. El nombre completo de CLIP es Preentrenamiento de imágenes y lenguaje contrastivo, un método de preentrenamiento basado en pares contrastantes de texto e imagen. ¿Por qué introducir CLIP? Porque la difusión estable, que es tan popular ahora, no es un modelo único, sino que se compone de varios modelos. Se utilizará un codificador de texto para codificar la entrada de texto del usuario. Este codificador de texto es el codificador de texto en el modelo CLIP.

Cuando se entrena el modelo CLIP, se le puede dar una oración de entrada y extraer las imágenes más relevantes para que coincida. CLIP aprende la relación entre una frase completa y la imagen que describe. Es decir, se entrena con oraciones completas, en lugar de categorías discretas como "coche", "perro", etc. Esto es crucial para la aplicación. Cuando se entrena en frases completas, el modelo puede aprender más y reconocer patrones entre fotografías y texto. También demostraron que el modelo funciona como clasificador cuando se entrena con un conjunto de datos considerable de fotografías y oraciones correspondientes. Cuando se lanzó CLIP, su rendimiento de clasificación en el conjunto de datos ImageNet excedió el de ResNets-50 después de un ajuste fino sin ningún ajuste fino (disparo cero), lo que significa que es muy útil.

Entonces, en este artículo, implementaremos el modelo CLIP desde cero usando PyTorch para que podamos comprender mejor CLIP.

Aquí se necesitan dos bibliotecas: timm y transformadores. Importemos el código primero.

 import os
 import cv2
 import gc
 import numpy as np
 import pandas as pd
 import itertools
 from tqdm.autonotebook import tqdm
 import albumentations as A
 import matplotlib.pyplot as plt
 
 import torch
 from torch import nn
 import torch.nn.functional as F
 import timm
 from transformers import DistilBertModel, DistilBertConfig, DistilBertTokenizer

El siguiente paso es preprocesar los datos y la configuración general. config es un archivo de Python normal en el que colocamos todos los hiperparámetros. Si usamos Jupyter Notebook, es una clase definida al principio de Notebook.

 class CFG:
     debug = False
     image_path = "../input/flickr-image-dataset/flickr30k_images/flickr30k_images"
     captions_path = "."
     batch_size = 32
     num_workers = 4
     head_lr = 1e-3
     image_encoder_lr = 1e-4
     text_encoder_lr = 1e-5
     weight_decay = 1e-3
     patience = 1
     factor = 0.8
     epochs = 2
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
     model_name = 'resnet50'
     image_embedding = 2048
     text_encoder_model = "distilbert-base-uncased"
     text_embedding = 768
     text_tokenizer = "distilbert-base-uncased"
     max_length = 200
 
     pretrained = True # for both image encoder and text encoder
     trainable = True # for both image encoder and text encoder
     temperature = 1.0
 
     # image size
     size = 224
 
     # for projection head; used for both image and text encoders
     num_projection_layers = 1
     projection_dim = 256 
     dropout = 0.1

También hay algunas clases auxiliares para nuestros indicadores personalizados.

 class AvgMeter:
     def __init__(self, name="Metric"):
         self.name = name
         self.reset()
 
     def reset(self):
         self.avg, self.sum, self.count = [0] * 3
 
     def update(self, val, count=1):
         self.count += count
         self.sum += val * count
         self.avg = self.sum / self.count
 
     def __repr__(self):
         text = f"{self.name}: {self.avg:.4f}"
         return text
 
 def get_lr(optimizer):
     for param_group in optimizer.param_groups:
         return param_group["lr"]

Nuestro objetivo es describir imágenes y frases. Por tanto, el conjunto de datos debe devolver tanto frases como imágenes. Por lo tanto, debe usar el etiquetador DistilBERT para etiquetar la oración (título) y luego proporcionar la identificación de la etiqueta (input_ids) y la máscara de atención a DistilBERT. DistilBERT es más pequeño que el modelo BERT, pero los resultados de los modelos son similares, por lo que optamos por usarlo.

El siguiente paso es tokenizar utilizando el tokenizador HuggingFace. El objeto tokenizador obtenido en __init__ se cargará cuando se ejecute el modelo. El título se rellena y se trunca hasta una longitud máxima predeterminada. Antes de cargar la imagen relevante, cargaremos un título codificado en getitem, que es un diccionario con las claves input_ids ycare_mask, transformándolo y aumentándolo (si corresponde) . Luego conviértalo en un tensor y guárdelo en un diccionario con "imagen" como clave. Finalmente introducimos el texto original del título en el diccionario junto con la palabra clave "título".

 class CLIPDataset(torch.utils.data.Dataset):
     def __init__(self, image_filenames, captions, tokenizer, transforms):
         """
         image_filenames and cpations must have the same length; so, if there are
         multiple captions for each image, the image_filenames must have repetitive
         file names 
         """
 
         self.image_filenames = image_filenames
         self.captions = list(captions)
         self.encoded_captions = tokenizer(
             list(captions), padding=True, truncation=True, max_length=CFG.max_length
         )
         self.transforms = transforms
 
     def __getitem__(self, idx):
         item = {
             key: torch.tensor(values[idx])
             for key, values in self.encoded_captions.items()
         }
 
         image = cv2.imread(f"{CFG.image_path}/{self.image_filenames[idx]}")
         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
         image = self.transforms(image=image)['image']
         item['image'] = torch.tensor(image).permute(2, 0, 1).float()
         item['caption'] = self.captions[idx]
 
         return item
 
 
     def __len__(self):
         return len(self.captions)
 
 
 
 def get_transforms(mode="train"):
     if mode == "train":
         return A.Compose(
             [
                 A.Resize(CFG.size, CFG.size, always_apply=True),
                 A.Normalize(max_pixel_value=255.0, always_apply=True),
             ]
         )
     else:
         return A.Compose(
             [
                 A.Resize(CFG.size, CFG.size, always_apply=True),
                 A.Normalize(max_pixel_value=255.0, always_apply=True),
             ]
         )

Codificador de imágenes y texto: utilizaremos ResNet50 como codificador de imágenes.

 class ImageEncoder(nn.Module):
     """
     Encode images to a fixed size vector
     """
 
     def __init__(
         self, model_name=CFG.model_name, pretrained=CFG.pretrained, trainable=CFG.trainable
     ):
         super().__init__()
         self.model = timm.create_model(
             model_name, pretrained, num_classes=0, global_pool="avg"
         )
         for p in self.model.parameters():
             p.requires_grad = trainable
 
     def forward(self, x):
         return self.model(x)

Utilice DistilBERT como codificador de texto. Utilice la representación final de tokens CLS para obtener la representación completa de la oración.

 class TextEncoder(nn.Module):
     def __init__(self, model_name=CFG.text_encoder_model, pretrained=CFG.pretrained, trainable=CFG.trainable):
         super().__init__()
         if pretrained:
             self.model = DistilBertModel.from_pretrained(model_name)
         else:
             self.model = DistilBertModel(config=DistilBertConfig())
             
         for p in self.model.parameters():
             p.requires_grad = trainable
 
         # we are using the CLS token hidden representation as the sentence's embedding
         self.target_token_idx = 0
 
     def forward(self, input_ids, attention_mask):
         output = self.model(input_ids=input_ids, attention_mask=attention_mask)
         last_hidden_state = output.last_hidden_state
         return last_hidden_state[:, self.target_token_idx, :]

El código anterior ha codificado la imagen y el texto en vectores de tamaño fijo (imagen 2048, texto 768), necesitamos que la imagen y el texto tengan dimensiones similares para poder compararlos, por lo que proyectamos los vectores de 2048 dimensiones y 768 dimensiones a 256 dimensional (proyección_dim), podemos compararlos solo si las dimensiones son las mismas.

 class ProjectionHead(nn.Module):
     def __init__(
         self,
         embedding_dim,
         projection_dim=CFG.projection_dim,
         dropout=CFG.dropout
     ):
         super().__init__()
         self.projection = nn.Linear(embedding_dim, projection_dim)
         self.gelu = nn.GELU()
         self.fc = nn.Linear(projection_dim, projection_dim)
         self.dropout = nn.Dropout(dropout)
         self.layer_norm = nn.LayerNorm(projection_dim)
     
     def forward(self, x):
         projected = self.projection(x)
         x = self.gelu(projected)
         x = self.fc(x)
         x = self.dropout(x)
         x = x + projected
         x = self.layer_norm(x)
         return x

Entonces, al final nuestro modelo CLIP se ve así:

 class CLIPModel(nn.Module):
     def __init__(
         self,
         temperature=CFG.temperature,
         image_embedding=CFG.image_embedding,
         text_embedding=CFG.text_embedding,
     ):
         super().__init__()
         self.image_encoder = ImageEncoder()
         self.text_encoder = TextEncoder()
         self.image_projection = ProjectionHead(embedding_dim=image_embedding)
         self.text_projection = ProjectionHead(embedding_dim=text_embedding)
         self.temperature = temperature
 
     def forward(self, batch):
         # Getting Image and Text Features
         image_features = self.image_encoder(batch["image"])
         text_features = self.text_encoder(
             input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
         )
         # Getting Image and Text Embeddings (with same dimension)
         image_embeddings = self.image_projection(image_features)
         text_embeddings = self.text_projection(text_features)
 
         # Calculating the Loss
         logits = (text_embeddings @ image_embeddings.T) / self.temperature
         images_similarity = image_embeddings @ image_embeddings.T
         texts_similarity = text_embeddings @ text_embeddings.T
         targets = F.softmax(
             (images_similarity + texts_similarity) / 2 * self.temperature, dim=-1
         )
         texts_loss = cross_entropy(logits, targets, reduction='none')
         images_loss = cross_entropy(logits.T, targets.T, reduction='none')
         loss =  (images_loss + texts_loss) / 2.0 # shape: (batch_size)
         return loss.mean()
 
 #这里还加了一个交叉熵函数
 def cross_entropy(preds, targets, reduction='none'):
     log_softmax = nn.LogSoftmax(dim=-1)
     loss = (-targets * log_softmax(preds)).sum(1)
     if reduction == "none":
         return loss
     elif reduction == "mean":
         return loss.mean()

Cabe señalar aquí que CLIP utiliza entropía cruzada simétrica como función de pérdida, lo que puede reducir el impacto del ruido y mejorar la robustez del modelo. Aquí solo utilizamos entropía cruzada por simplicidad.

Podemos probar:

 # A simple Example
 
 batch_size = 4
 dim = 256
 embeddings = torch.randn(batch_size, dim)
 out = embeddings @ embeddings.T
 print(F.softmax(out, dim=-1))

El siguiente paso es el entrenamiento, hay algunas funciones que pueden ayudarnos a cargar el cargador de datos de entrenamiento y verificación.

 def make_train_valid_dfs():
     dataframe = pd.read_csv(f"{CFG.captions_path}/captions.csv")
     max_id = dataframe["id"].max() + 1 if not CFG.debug else 100
     image_ids = np.arange(0, max_id)
     np.random.seed(42)
     valid_ids = np.random.choice(
         image_ids, size=int(0.2 * len(image_ids)), replace=False
     )
     train_ids = [id_ for id_ in image_ids if id_ not in valid_ids]
     train_dataframe = dataframe[dataframe["id"].isin(train_ids)].reset_index(drop=True)
     valid_dataframe = dataframe[dataframe["id"].isin(valid_ids)].reset_index(drop=True)
     return train_dataframe, valid_dataframe
 
 
 def build_loaders(dataframe, tokenizer, mode):
     transforms = get_transforms(mode=mode)
     dataset = CLIPDataset(
         dataframe["image"].values,
         dataframe["caption"].values,
         tokenizer=tokenizer,
         transforms=transforms,
     )
     dataloader = torch.utils.data.DataLoader(
         dataset,
         batch_size=CFG.batch_size,
         num_workers=CFG.num_workers,
         shuffle=True if mode == "train" else False,
     )
     return dataloader

Luego viene la formación y la evaluación.

 def train_epoch(model, train_loader, optimizer, lr_scheduler, step):
     loss_meter = AvgMeter()
     tqdm_object = tqdm(train_loader, total=len(train_loader))
     for batch in tqdm_object:
         batch = {k: v.to(CFG.device) for k, v in batch.items() if k != "caption"}
         loss = model(batch)
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()
         if step == "batch":
             lr_scheduler.step()
 
         count = batch["image"].size(0)
         loss_meter.update(loss.item(), count)
 
         tqdm_object.set_postfix(train_loss=loss_meter.avg, lr=get_lr(optimizer))
     return loss_meter
 
 
 def valid_epoch(model, valid_loader):
     loss_meter = AvgMeter()
 
     tqdm_object = tqdm(valid_loader, total=len(valid_loader))
     for batch in tqdm_object:
         batch = {k: v.to(CFG.device) for k, v in batch.items() if k != "caption"}
         loss = model(batch)
 
         count = batch["image"].size(0)
         loss_meter.update(loss.item(), count)
 
         tqdm_object.set_postfix(valid_loss=loss_meter.avg)
     return loss_meter

Finalmente, se integra todo el proceso.

 def main():
     train_df, valid_df = make_train_valid_dfs()
     tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
     train_loader = build_loaders(train_df, tokenizer, mode="train")
     valid_loader = build_loaders(valid_df, tokenizer, mode="valid")
 
 
     model = CLIPModel().to(CFG.device)
     params = [
         {"params": model.image_encoder.parameters(), "lr": CFG.image_encoder_lr},
         {"params": model.text_encoder.parameters(), "lr": CFG.text_encoder_lr},
         {"params": itertools.chain(
             model.image_projection.parameters(), model.text_projection.parameters()
         ), "lr": CFG.head_lr, "weight_decay": CFG.weight_decay}
     ]
     optimizer = torch.optim.AdamW(params, weight_decay=0.)
     lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
         optimizer, mode="min", patience=CFG.patience, factor=CFG.factor
     )
     step = "epoch"
 
     best_loss = float('inf')
     for epoch in range(CFG.epochs):
         print(f"Epoch: {epoch + 1}")
         model.train()
         train_loss = train_epoch(model, train_loader, optimizer, lr_scheduler, step)
         model.eval()
         with torch.no_grad():
             valid_loss = valid_epoch(model, valid_loader)
         
         if valid_loss.avg < best_loss:
             best_loss = valid_loss.avg
             torch.save(model.state_dict(), "best.pt")
             print("Saved Best Model!")
         
         lr_scheduler.step(valid_loss.avg)

Aplicación: obtenga incrustaciones de imágenes y encuentre coincidencias.

¿Cómo lo aplicamos realmente después de completar la capacitación? Necesitamos escribir una función que cargue el modelo entrenado, le proporcione imágenes del conjunto de validación y devuelva la forma (valid_set_size, 256) y las incrustaciones de imágenes del modelo en sí.

 def get_image_embeddings(valid_df, model_path):
     tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
     valid_loader = build_loaders(valid_df, tokenizer, mode="valid")
     
     model = CLIPModel().to(CFG.device)
     model.load_state_dict(torch.load(model_path, map_location=CFG.device))
     model.eval()
     
     valid_image_embeddings = []
     with torch.no_grad():
         for batch in tqdm(valid_loader):
             image_features = model.image_encoder(batch["image"].to(CFG.device))
             image_embeddings = model.image_projection(image_features)
             valid_image_embeddings.append(image_embeddings)
     return model, torch.cat(valid_image_embeddings)
 _, valid_df = make_train_valid_dfs()
 model, image_embeddings = get_image_embeddings(valid_df, "best.pt")
 
 def find_matches(model, image_embeddings, query, image_filenames, n=9):
     tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
     encoded_query = tokenizer([query])
     batch = {
         key: torch.tensor(values).to(CFG.device)
         for key, values in encoded_query.items()
     }
     with torch.no_grad():
         text_features = model.text_encoder(
             input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
         )
         text_embeddings = model.text_projection(text_features)
     
     image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
     text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
     dot_similarity = text_embeddings_n @ image_embeddings_n.T
     
     values, indices = torch.topk(dot_similarity.squeeze(0), n * 5)
     matches = [image_filenames[idx] for idx in indices[::5]]
     
     _, axes = plt.subplots(3, 3, figsize=(10, 10))
     for match, ax in zip(matches, axes.flatten()):
         image = cv2.imread(f"{CFG.image_path}/{match}")
         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
         ax.imshow(image)
         ax.axis("off")
     
     plt.show()

El método de llamada es el siguiente:

 find_matches(model, 
              image_embeddings,
              query="one dog sitting on the grass",
              image_filenames=valid_df['image'].values,
              n=9)

Puedes ver que nuestro efecto de personalización es bastante bueno (pero hay un gato en la imagen, ja). En otras palabras, el método CLIP también se puede personalizar en conjuntos de datos pequeños.

El siguiente es el código y el conjunto de datos de este artículo:

https://avoid.overfit.cn/post/25295aa8daee45fc8336b2e86a29106a

Autor: Jyoti Dabass, Ph.D.

Supongo que te gusta

Origin blog.csdn.net/m0_46510245/article/details/132800715
Recomendado
Clasificación