BLIP2代码阅读

Q-Former核心

Image-text Contrastive

# 将image_feats扩展到所有GPU上,image_feats是图像特征,是一个tensor,维度为[batch_size*num_gpu, num_query_tokens, embed_dim]
image_feats_all = concat_all_gather(image_feats)  

# 将text_feat扩展到所有GPU上,text_feat是文本特征,是一个tensor,维度为[batch_size*num_gpu, embed_dim]
text_feat_all = concat_all_gather(text_feat)  

# 对每个查询标记计算图像到文本的相似度
sim_q2t = torch.matmul(
    image_feats.unsqueeze(1), text_feat_all.unsqueeze(-1)
).squeeze()

# 对image-text相似度进行聚合,取所有查询标记的最大值
sim_i2t, _ = sim_q2t.max(-1)  
sim_i2t = sim_i2t / self.temp  # 对相似度进行缩放

# 对每个查询标记计算文本到图像的相似度
sim_t2q = torch.matmul(
    text_feat.unsqueeze(1).unsqueeze(1), image_feats_all.permute(0, 2, 1)
).squeeze()

# 对text-image相似度进行聚合,取所有查询标记的最大值
sim_t2i, _ = sim_t2q.max(-1)  
sim_t2i = sim_t2i / self.temp  # 对相似度进行缩放

# 获取进程的rank和batch的大小
rank = dist.get_rank()  
bs = image.size(0)  

# 生成targets,用于计算损失
targets = torch.linspace(rank * bs, rank * bs + bs - 1, bs, dtype=int).to(
    image.device
)

# 如果samples中包含image_id,则表示在COCO检索微调训练中
if "image_id" in samples.keys():  
    # 获取image_ids
    image_ids = samples["image_id"].view(-1,1)
    # 将所有图片的image_id扩展到所有GPU上
    image_ids_all = concat_all_gather(image_ids)
    # 计算相似度目标,对匹配的图像进行惩罚
    pos_idx = torch.eq(image_ids, image_ids_all.t()).float()       
    sim_targets = pos_idx / pos_idx.sum(1,keepdim=True)   
    sim_targets = 0.9 * sim_targets + 0.1 * torch.ones_like(sim_targets) / sim_targets.size(1)

    # 计算损失,在COCO检索微调训练中
    loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_targets,dim=1).mean()
    loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_targets,dim=1).mean()     
    loss_itc = (loss_t2i+loss_i2t)/2  
# 如果不是COCO检索微调训练
else:                     
    # 普通的图像-文本对比损失计算
    loss_itc = (
        F.cross_entropy(sim_i2t, targets, label_smoothing=0.1)
        + F.cross_entropy(sim_t2i, targets, label_smoothing=0.1)
    ) / 2

Text-Image Match

# 将text_tokens中的输入ID和注意力掩码扩展到所有GPU上
text_input_ids_world = concat_all_gather(text_tokens.input_ids)  
text_attention_mask_world = concat_all_gather(text_tokens.attention_mask)

# 将image_embeds扩展到所有GPU上,并且允许梯度传播
image_embeds_world = all_gather_with_grad(image_embeds)  

# 禁止梯度传播,用于下面的操作
with torch.no_grad():
    # 如果samples中包含image_id,则在image_ids与image_ids_all.t()相等的位置填充-10000
    if "image_id" in samples.keys():
        mask = torch.eq(image_ids, image_ids_all.t())
        sim_t2i.masked_fill_(mask, -10000)
        sim_i2t.masked_fill_(mask, -10000)
    else:    
        # 否则在sim_t2i和sim_i2t的对角线位置填充-10000
        sim_t2i[:, rank * bs : rank * bs + bs].fill_diagonal_(-10000)
        sim_i2t[:, rank * bs : rank * bs + bs].fill_diagonal_(-10000)            
        
    # 计算softmax权重,对文本到图像和图像到文本的相似性进行归一化
    weights_t2i = F.softmax(sim_t2i, dim=1)
    weights_i2t = F.softmax(sim_i2t, dim=1)

# 为每个文本选择一个负向图像
image_embeds_neg = []
for b in range(bs):
    neg_idx = torch.multinomial(weights_t2i[b], 1).item()
    image_embeds_neg.append(image_embeds_world[neg_idx])
image_embeds_neg = torch.stack(image_embeds_neg, dim=0)

# 为每个图像选择一个负向文本
text_ids_neg = []
text_atts_neg = []
for b in range(bs):
    neg_idx = torch.multinomial(weights_i2t[b], 1).item()
    text_ids_neg.append(text_input_ids_world[neg_idx])
    text_atts_neg.append(text_attention_mask_world[neg_idx])
text_ids_neg = torch.stack(text_ids_neg, dim=0)
text_atts_neg = torch.stack(text_atts_neg, dim=0)

# 将所有的正向文本和负向文本组合成一个tensor
text_ids_all = torch.cat(
    [text_tokens.input_ids, text_tokens.input_ids, text_ids_neg], dim=0
)  # pos, pos, neg
text_atts_all = torch.cat(
    [text_tokens.attention_mask, text_tokens.attention_mask, text_atts_neg],
    dim=0,
)

# 生成用于查询任务的token和注意力掩码
query_tokens_itm = self.query_tokens.expand(text_ids_all.shape[0], -1, -1)
query_atts_itm = torch.ones(query_tokens_itm.size()[:-1], dtype=torch.long).to(
    image.device
)
attention_mask_all = torch.cat([query_atts_itm, text_atts_all], dim=1)

# 将所有的正向图像和负向图像组合成一个tensor
image_embeds_all = torch.cat(
    [image_embeds, image_embeds_neg, image_embeds], dim=0
)  # pos, neg, pos
image_atts_all = torch.ones(image_embeds_all.size()[:-1], dtype=torch.long).to(
    image.device
)

# 执行查询任务,获取输出
output_itm = self.Qformer.bert(
    text_ids_all,
    query_embeds=query_tokens_itm,
    attention_mask=attention_mask_all,
    encoder_hidden_states=image_embeds_all,
    encoder_attention_mask=image_atts_all,
    return_dict=True,
)

# 提取视觉-语言交互层的输出
vl_embeddings = output_itm.last_hidden_state[:, : query_tokens_itm.size(1), :]

# 对视觉-语言交互层的输出应用itm头
vl_output = self.itm_head(vl_embeddings)
logits = vl_output.mean(dim=1)

# 生成itm标签,对正向样本标记为1,对负向样本标记为0
itm_labels = torch.cat(
    [torch.ones(bs, dtype=torch.long), torch.zeros(2 * bs, dtype=torch.long)],
    dim=0,
).to(image.device)

# 计算itm任务的交叉熵损失
loss_itm = F.cross_entropy(logits, itm_labels)

Image Caption

# 克隆text_tokens的输入ID,并将第一个位置设置为起始符号的ID
decoder_input_ids = text_tokens.input_ids.clone()
decoder_input_ids[:, 0] = self.tokenizer.bos_token_id

# 将要预测的标签设置为decoder_input_ids,并使用-100填充pad位置
labels = decoder_input_ids.masked_fill(
    decoder_input_ids == self.tokenizer.pad_token_id, -100
)

# 生成注意力掩码,用于语言模型的输入
query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(
    image.device
)
attention_mask = torch.cat([query_atts, text_tokens.attention_mask], dim=1)

# 执行语言模型的训练,获取LM输出
lm_output = self.Qformer(
    decoder_input_ids,
    attention_mask=attention_mask,
    past_key_values=query_output.past_key_values,
    return_dict=True,
    labels=labels,
)

# 计算语言模型的损失
loss_lm = lm_output.loss

# 返回总的损失和各个部分的损失
return BlipOutput(
    loss=loss_itc + loss_itm + loss_lm,
    loss_itc=loss_itc,
    loss_itm=loss_itm,
    loss_lm=loss_lm,
)

Generate

def generate(
    self,
    samples,
    use_nucleus_sampling=False,
    num_beams=3,
    max_length=30,
    min_length=10,
    top_p=0.9,
    repetition_penalty=1.0,
):
    """
    Args:
    samples (dict): 包含以下键的字典:
        - image (torch.Tensor): 形状为(batch_size, 3, H, W)的张量
    use_nucleus_sampling (bool): 是否使用核采样。如果为False,则使用top-k采样。
    num_beams (int): 用于束搜索的束的数量。1表示不使用束搜索。
    max_length (int): 要生成的序列的最大长度。
    min_length (int): 要生成的序列的最小长度。
    top_p (float): 核采样的累积概率。
    repetition_penalty (float): 重复惩罚的参数。1.0表示没有惩罚。
    num_captions (int): 每个图像要生成的字幕数。
    Returns:
    captions (list): 长度为batch_size * num_captions的字符串列表。
    """
    # 获取图像并将其编码为图像嵌入
    image = samples["image"]
    image_embeds = self.ln_vision(self.visual_encoder(image))

    # 如果不使用核采样,则扩展图像嵌入,用于束搜索
    if not use_nucleus_sampling:
        image_embeds = image_embeds.repeat_interleave(num_beams, dim=0)
    else:
        num_beams = 1
    image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
        image.device
    )

    # 设置模型参数
    model_kwargs = {
    
    
        "encoder_hidden_states": image_embeds,
        "encoder_attention_mask": image_atts,
    }

    # 生成文本描述
    input_ids = (
        torch.LongTensor(image.size(0), 1)
        .fill_(self.tokenizer.bos_token_id)
        .to(image.device)
    )
    # image token编码
    query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
	# 调用Qformer进行生成
    outputs = self.Qformer.generate(
        input_ids=input_ids,
        query_embeds=query_tokens,
        max_length=max_length,
        min_length=min_length,
        num_beams=num_beams,
        do_sample=use_nucleus_sampling,
        top_p=top_p,
        eos_token_id=self.tokenizer.sep_token_id,
        pad_token_id=self.tokenizer.pad_token_id,
        **model_kwargs
    )
    # 解码生成captions
    captions = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return captions

qformer的generate模块好复杂

generate 函数用于模型的条件生成。

首先,generate 函数通过各种参数调用生成。这些参数包括生成配置文件、标记处理列表、停止条件列表、前缀允许令牌 fn、委托人模型、流媒体设备、负提示 ids 和注意力掩码等。这些参数用来进一步控制生成的方式和结果。

generate 函数主要分为以下步骤:

  1. 处理 generation_config 和可能更新它的 kwargs,以及验证 .generate() 调用。

  2. 设置生成参数(如处理?)

  3. 定义模型输入

  4. 定义其他模型kwargs

  5. 准备自回归生成的input_ids

  6. 准备包含其他停止标准的 max_length

  7. 确定生成模式

有了这些步骤,generate 函数可以进入不同的生成模式和执行相应的生成方法,比如贪婪搜索,显示搜索等。

在选择相应的生成模式后,generate 函数根据模型的生成配置、输入和相应的参数调用相应的生成方法,包括模型的贪婪搜索、显示搜索、样本生成等。

接下来,通过选择合适的方法和参数对模型进行生成,并返回生成的输出。
最后,如果选择了与专辑模式,则可以使用 assistant_model 对象来进行助攻生成。在这种情况下,生成方法中的一些参数和逻辑都将有所不同。目的是加速生成。具体来说,在助攻生成过程中,特定的助攻模型将返回模型的生成 output 或torch.FloatTensor

总的来说,generate 函数负责执行不同的生成方法和逻辑以生成模型的输出。它允许用户根据实际需要执行不同的生成方法,并支持其他参数的进一步控制。这个函数给了人们灵活的选择,以获得满足需求的生成输出。

猜你喜欢

转载自blog.csdn.net/RandyHan/article/details/134804888
今日推荐