「BLIP 微调指南」以 Image-Text Captioning 任务为例

前言:近日需要用到 BLIP 微调下游任务,搜索发觉如今并无 BLIP 微调教程,下面就以 Image-Text Captioning 任务为例,演示如何完成 BLIP 模型在自己数据集上的微调。


1. BLIP 介绍

相关论文BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation (ICML, 2022)

演示地址https://huggingface.co/spaces/Salesforce/BLIP

开源代码https://github.com/salesforce/BLIP

在开源代码的 README.md 介绍中,可以看到

在这里插入图片描述
可知 BLIP 可以完成 Image-Text Captioning、VQA 以及 NLVR2 这几个下游任务。


2. 关键代码定位

下图是 github 仓库中的文件构成:

在这里插入图片描述
首先通过 https://github.com/salesforce/BLIP/blob/main/train_caption.py 文件,了解 BLIP 如何用于 captioning 场景中。

相关参数定义在 https://github.com/salesforce/BLIP/blob/main/configs/caption_coco.yaml 文件中。

发现模型通过如下方式定义:

在这里插入图片描述

根据 from models.blip import blip_decoder 的得知,blip_decoder 函数定义于 models.blip 文件中,于是转到https://github.com/salesforce/BLIP/blob/main/models/blip.py 文件。微调过程中,主要使用的是该文件中的 blip_decoder() 函数以及 BLIP_Decoder 类。

  • blip_decoder() 函数定义如下,参数列表包括 pretrained 模型的地址以及 BLIP_Decoder 类的参数列表。

    def blip_decoder(pretrained='',**kwargs):
        model = BLIP_Decoder(**kwargs)
        if pretrained:
            model,msg = load_checkpoint(model,pretrained)
            assert(len(msg.missing_keys)==0)
        return model    
    
  • BLIP_Decoder 类的初始化函数参数列表如下:

    class BLIP_Decoder(nn.Module):
        def __init__(self,                 
                     med_config = 'configs/med_config.json',  
                     image_size = 384,
                     vit = 'base',
                     vit_grad_ckpt = False,
                     vit_ckpt_layer = 0,
                     prompt = 'a picture of ',
                     ):
    

    其中,med_config 对应的 json 文件路径为 https://github.com/salesforce/BLIP/blob/main/configs/med_config.json
    image_size 为模型接收到的图像尺寸;
    vit 为 image encoder 的规模,可选 base 或 large;
    vit_grad_ckpt 是梯度检查点,如果设置为 True,训练时就会节省显存;反之,占用显存更多。详见此篇文章
    vit_ckpt_layer 为使用梯度检查点的层(个人理解);
    prompt 为 BLIP 使用的提示文本,以字符串形式的自然语言文本给出。

PyTorch的 gradient checkpoint 是通过torch.utils.checkpoint.checkpoint(function, *args, **kwargs)函数实现的。Gradient Checkpoint是通过以更长的计算时间为代价,换取更少的显存占用。相比于原本需要存储所有中间变量以供反向传播使用,使用 checkpoint 的部分不存储中间变量,而是在反向传播过程中重新计算这些中间变量。模型中的任何部分都可以使用 gradient checkpoint.

3. 关键参数赋值

  • blip_decoder() 的 pretrained 这一参数的值是在 https://github.com/salesforce/BLIP/blob/main/configs/caption_coco.yaml 文件中找到的;
  • image_size 的值要与自己的数据集中图像尺寸对应,我这里将其修改为 224;
  • prompt 修改为自己需要的自然语言提示文本,注意以字符串形式给出;
  • 其他参数保持不变即可。

4. 模型定义&使用

https://github.com/salesforce/BLIP/blob/main/models/ 目录下的 blip.py, vit.py 以及 med.py 文件下载到自己项目的同一目录下。

在主文件中,通过使用如下命令使用 BLIP model:

BLIPModel = blip_decoder(pretrained=args['pretrained'], image_size=args['image_size'], vit=args['vit'], vit_grad_ckpt=args['vit_grad_ckpt'], vit_ckpt_layer=args['vit_ckpt_layer'], prompt=args['prompt']).to(device)
  • 训练时,使用代码 loss = BLIPModel(imgs, texts) ,调用 Blip_Decoder 类中的 forward() 函数,会得到当前 batch 数据对应的 loss,然后按照关惯常操作进行反向传播;

  • 测试时,使用下述代码
    generated_texts = BLIPModel.generate(imgs, sample=True, num_beams=3, max_length=30, min_length=5, top_p=0.95, repetition_penalty=1.0)
    调用 Blip_Decoder 类中的 generate() 函数,会得到模型对当前 batch 数据生成的自然语言文本。


参考资料

  1. https://github.com/salesforce/BLIP
  2. Pytorch Gradient Checkpoint使用示例_森尼嫩豆腐的博客-CSDN博客
  3. https://pytorch.org/docs/stable/checkpoint.html

猜你喜欢

转载自blog.csdn.net/qq_36332660/article/details/131980723