Resumen de aprendizaje profundo: ajuste fino de CLIP con su propio conjunto de datos

Resumen de CLIP

CLIP (Contrastive Language-Image Pretraining) es un modelo de aprendizaje profundo desarrollado por OpenAI para la codificación conjunta de imágenes y texto en lenguaje natural. Emplea un enfoque de aprendizaje multimodal que permite que el modelo comprenda la relación semántica entre imágenes y texto.

Su idea central es tratar las imágenes y los textos como insumos igualmente importantes y aprender las conexiones entre ellos a través del entrenamiento conjunto. El modelo CLIP utiliza un codificador compartido que mapea imágenes y texto por separado en un espacio de características compartidas. Al comparar los vectores codificados de imágenes y texto, el modelo puede juzgar la similitud y la relación entre ellos.

Utiliza una función de pérdida de contraste durante el entrenamiento para alentar al modelo a codificar pares de imágenes y texto relacionados más juntos y pares de imágenes y texto irrelevantes más separados. Esto permite que el modelo CLIP tenga una buena capacidad de generalización y aprenda habilidades generales de comprensión de imágenes y textos durante el entrenamiento.

Su proceso general es el siguiente:
inserte la descripción de la imagen aquí

Demuestra una sólida capacidad de disparo cero y se desempeña bien en muchas tareas de visión y lenguaje, como clasificación de imágenes, descripción de generación de imágenes, respuesta a preguntas de imágenes, etc. Su capacidad multimodal permite que el modelo CLIP establezca una fuerte conexión semántica entre las imágenes y el texto, lo que brinda una capacidad de comprensión y análisis más integral para diversos escenarios de aplicación.

Precisamente debido a su excelente capacidad de disparo cero, el modelo entrenado contiene una gran cantidad de conocimiento que se puede utilizar. Por lo tanto, en algunas tareas, como las tareas de clasificación y las tareas de subtítulos, puede intentar ajustar CLIP en su propio conjunto de datos, tal vez a través de Esta operación puede lograr un buen rendimiento. Sin embargo, no hay una introducción detallada sobre cómo ajustar CLIP en Internet, así que clasifiqué el conocimiento relevante y lo registré aquí.
Link de referencia

Afinar el código

biblioteca de terceros

  • clip-by-openai
  • antorcha

Tomemos la tarea de clasificación de imágenes que hice como ejemplo para presentar los pasos relevantes.

paso introducción

1. Cree el conjunto de datos

Cree su propio conjunto de datos, los datos devueltos por cada iteración incluyen: Imagen RGB y etiqueta de imagen (una foto de {etiqueta})
El ejemplo de código es el siguiente:

import os
from PIL import Image
import numpy as np
import clip
class YourDataset(Dataset):
    def __init__(self,img_root,meta_root,is_train,preprocess):
        # 1.根目录(根据自己的情况更改)
        self.img_root = img_root
        self.meta_root = meta_root
        # 2.训练图片和测试图片地址(根据自己的情况更改)
        self.train_set_file = os.path.join(meta_root,'train.txt')
        self.test_set_file = os.path.join(meta_root,'test.txt')
        # 3.训练 or 测试(根据自己的情况更改)
        self.is_train = is_train
        # 4.处理图像
        self.img_process = preprocess
        # 5.获得数据(根据自己的情况更改)
        self.samples = []
        self.sam_labels = []
        # 5.1 训练还是测试数据集
        self.read_file = ""
        if is_train:
            self.read_file = self.train_set_file
        else:
            self.read_file = self.test_set_file
		# 5.2 获得所有的样本(根据自己的情况更改)
        with open(self.read_file,'r') as f:
            for line in f:
                img_path = os.path.join(self.img_root,line.strip() + '.jpg')
                label = line.strip().split('/')[0]
                label = label.replace("_"," ")
                label = "a photo of " + label
                self.samples.append(img_path)
                self.sam_labels.append(label)
        # 转换为token
        self.tokens = clip.tokenize(self.sam_labels)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path = self.samples[idx]
        token = self.tokens[idx]
        # 加载图像
        image = Image.open(img_path).convert('RGB')
        # 对图像进行转换
        image = self.img_process(image)
        return image,token

2. Cargue el modelo CLIP previamente entrenado y la configuración relacionada

En primer lugar, utilice una biblioteca de terceros para cargar el modelo CLIP previamente entrenado, que devolverá un modelo CLIP y un preprocesamiento de la función de preprocesamiento de imágenes, que se utilizará en el proceso de carga de datos posterior.

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net, preprocess = clip.load("RN50",device=device,jit=False)

Luego inicialice el optimizador y la función de pérdida. Cabe señalar que si su pérdida es muy grande o anormal al principio, puede ajustar la tasa de aprendizaje y otros parámetros del optimizador para ajustar. Por lo general, un ajuste más pequeño tendrá un efecto.

optimizer = optim.Adam(net.parameters(), lr=1e-6,betas=(0.9,0.98),eps=1e-6,weight_decay=0.001)
scheduler = lr_scheduler.StepLR(
        optimizer, step_size=10, gamma=0.1)

# 创建损失函数
loss_img = nn.CrossEntropyLoss()
loss_txt = nn.CrossEntropyLoss()

3. Cargar datos

Este paso es principalmente para llamar a la clase creada en el primer paso y luego usar la función DataLoader para cargar su propio conjunto de datos.
el código se muestra a continuación:

your_dataset = YourDataset(img_root= '/images',
                                          meta_root= '/meta',
                                          is_train=True,preprocess=preprocess)
dataset_size_your = len(your_dataset)
your_dataloader = DataLoader(your_dataset,batch_size=4,shuffle=True,num_workers=4,pin_memory=False)

4. Empieza a entrenar

El código de entrenamiento se puede escribir de acuerdo con la plantilla. Se necesita entrenar un total de épocas. Cada vez, todos los datos en un conjunto de datos se deben entrenar una vez, y luego el modelo se guarda cuando se completa cada entrenamiento. Hay dos tipos:

  • Guardar los parámetros del modelo.
  • Guardar parámetros del modelo, optimizador, número de iteraciones

El código de esta parte es el siguiente:

phase = "train"
model_name = "your model name"
ckt_gap = 4
epoches = 30
for epoch in range(epoches):
    scheduler.step()
    total_loss = 0
    batch_num = 0
    # 使用混合精度,占用显存更小
    with torch.cuda.amp.autocast(enabled=True):
        for images,label_tokens in your_dataloader:
            # 将图片和标签token转移到device设备
            images = images.to(device)
            label_tokens = label_tokens.to(device)
            batch_num += 1
            # 优化器梯度清零
            optimizer.zero_grad()
            with torch.set_grad_enabled(phase == "train"):
                logits_per_image, logits_per_text = net(images, label_tokens)
                ground_truth = torch.arange(len(images),dtype=torch.long,device=device)
                cur_loss = (loss_img(logits_per_image,ground_truth) + loss_txt(logits_per_text,ground_truth))/2
                total_loss += cur_loss
                if phase == "train":
                    cur_loss.backward()
                    if device == "cpu":
                        optimizer.step()
                    else:
                        optimizer.step()
                        clip.model.convert_weights(net) 
            if batch_num % 4 == 0:
                logger.info('{} epoch:{} loss:{}'.format(phase,epoch,cur_loss))
        epoch_loss = total_loss / dataset_size_your
        torch.save(net.state_dict(),f"{model_name}_epoch_{epoch}.pth")
        logger.info(f"weights_{epoch} saved")
        if epoch % ckt_gap == 0:
            checkpoint_path = f"{model_name}_ckt.pth"
            checkpoint = {
    
    
                'it': epoch,
                'network': net.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict()}
            torch.save(checkpoint, checkpoint_path)
            logger.info(f"checkpoint_{epoch} saved")
        logger.info('{} Loss: {:.4f}'.format(
            phase, epoch_loss))

todos los códigos

import os
from PIL import Image
import numpy as np
import clip
from loguru import logger
from torch.utils.data import Dataset, DataLoader, ConcatDataset
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.nn as nn

class YourDataset(Dataset):
    def __init__(self,img_root,meta_root,is_train,preprocess):
        # 1.根目录(根据自己的情况更改)
        self.img_root = img_root
        self.meta_root = meta_root
        # 2.训练图片和测试图片地址(根据自己的情况更改)
        self.train_set_file = os.path.join(meta_root,'train.txt')
        self.test_set_file = os.path.join(meta_root,'test.txt')
        # 3.训练 or 测试(根据自己的情况更改)
        self.is_train = is_train
        # 4.处理图像
        self.img_process = preprocess
        # 5.获得数据(根据自己的情况更改)
        self.samples = []
        self.sam_labels = []
        # 5.1 训练还是测试数据集
        self.read_file = ""
        if is_train:
            self.read_file = self.train_set_file
        else:
            self.read_file = self.test_set_file
		# 5.2 获得所有的样本(根据自己的情况更改)
        with open(self.read_file,'r') as f:
            for line in f:
                img_path = os.path.join(self.img_root,line.strip() + '.jpg')
                label = line.strip().split('/')[0]
                label = label.replace("_"," ")
                label = "photo if " + label
                self.samples.append(img_path)
                self.sam_labels.append(label)
        # 转换为token
        self.tokens = clip.tokenize(self.sam_labels)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path = self.samples[idx]
        token = self.tokens[idx]
        # 加载图像
        image = Image.open(img_path).convert('RGB')
        # 对图像进行转换
        image = self.img_process(image)
        return image,token
# 创建模型
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net, preprocess = clip.load("RN50",device=device,jit=False)

optimizer = optim.Adam(net.parameters(), lr=1e-6,betas=(0.9,0.98),eps=1e-6,weight_decay=0.001)
scheduler = lr_scheduler.StepLR(
        optimizer, step_size=10, gamma=0.1)

# 创建损失函数
loss_img = nn.CrossEntropyLoss()
loss_txt = nn.CrossEntropyLoss()
# 加载数据集
your_dataset = YourDataset(img_root= '/images',
                                          meta_root= '/meta',
                                          is_train=True,preprocess=preprocess)
dataset_size_your = len(your_dataset)
your_dataloader = DataLoader(your_dataset,batch_size=4,shuffle=True,num_workers=4,pin_memory=False)

phase = "train"
model_name = "your model name"
ckt_gap = 4
for epoch in range(st,args.epoches):
    scheduler.step()
    total_loss = 0
    batch_num = 0
    # 使用混合精度,占用显存更小
    with torch.cuda.amp.autocast(enabled=True):
        for images,label_tokens in your_dataloader:
            # 将图片和标签token转移到device设备
            images = images.to(device)
            label_tokens = label_tokens.to(device)
            batch_num += 1
            # 优化器梯度清零
            optimizer.zero_grad()
            with torch.set_grad_enabled(phase == "train"):
                logits_per_image, logits_per_text = net(images, label_tokens)
                ground_truth = torch.arange(len(images),dtype=torch.long,device=device)
                cur_loss = (loss_img(logits_per_image,ground_truth) + loss_txt(logits_per_text,ground_truth))/2
                total_loss += cur_loss
                if phase == "train":
                    cur_loss.backward()
                    if device == "cpu":
                        optimizer.step()
                    else:
                        optimizer.step()
                        clip.model.convert_weights(net) 
            if batch_num % 4 == 0:
                logger.info('{} epoch:{} loss:{}'.format(phase,epoch,cur_loss))
        epoch_loss = total_loss / dataset_size_food101
        torch.save(net.state_dict(),f"{model_name}_epoch_{epoch}.pth")
        logger.info(f"weights_{epoch} saved")
        if epoch % ckt_gap == 0:
            checkpoint_path = f"{model_name}_ckt.pth"
            checkpoint = {
    
    
                'it': epoch,
                'network': net.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict()}
            torch.save(checkpoint, checkpoint_path)
            logger.info(f"checkpoint_{epoch} saved")
        logger.info('{} Loss: {:.4f}'.format(
            phase, epoch_loss))

Supongo que te gusta

Origin blog.csdn.net/qq_41234663/article/details/131024876
Recomendado
Clasificación