自监督学习——对比学习SimCLR框架(原理+代码)

1原理

对比学习

通过比较不同实例之间的相似性和差异性来进行学习。在对比学习中,我们将输入数据分为不同的类别或组(正负样本对),并通过比较样本之间的差异来提取特征或进行分类。
在这里插入图片描述

样本相似度

对比学习有几种不同的方法,其中最常见的是基于距离度量的方法。这些方法使用距离函数来度量两个实例之间的相似性,例如欧氏距离余弦相似度。通过计算实例之间的距离,我们可以找到最相似或最不相似的实例,从而进行特征选择、相似性匹配或分类任务。(正样本对相似度越近越好,负样本对相似度越远越好)
在这里插入图片描述

SimCLR——对比学习通过度量学习,提供特征提取的能力
在这里插入图片描述

  1. 取一个输入图像:同1张图像进行2种数据增强,形成一个正样本对儿;不同图像之间是负样本对儿。
  2. 准备2个随机的图像增强:旋转,颜色/饱和度/亮度变化,缩放,裁剪等。文中详细讨论了增强的范围,并分析了哪些增广效果最好。(构造正样本:图像SimCLR-数据增强、文本SimCSE-Dropout、图文CLIP-图像文本对
  3. 特征提取:运行一个深度神经网络(最好是卷积神经网络,如ViT、Bert、ResNet50)来获得那些增强图像的图像特征表示(嵌入)
  4. 特征投影:运行一个小的全连接线性神经网络,将嵌入投影到另一个向量空间。
  5. 计算loss:计算对比损失并通过两个网络进行反向传播。当来自同一图像的投影相似时,对比损失减少。投影之间的相似度可以是任意的,这里我使用余弦相似度,和论文中一样。
  6. 下游任务:对比学习得到Encoder做为特征提取器,根据下游任务的数据集进行微调Finetuin。

在这里插入图片描述
数据要多,batch要大(batchsize=8192)
在这里插入图片描述
正负样本对的构建,不需要标注
在这里插入图片描述

损失loss函数怎么设计?

l i , j = − l o g e x p ( s i m ( z i , z j ) / t ) ∑ k = 1 2 N 1 [ k ! = i ] e x p ( s i m ( z i , z k ) / t l_{i,j}=-log{\frac{exp(sim(z_i,z_j)/t)}{\sum_{k=1}^{2N}1_{[k!=i]}exp(sim(z_i,z_k)/t}} li,j=logk=12N1[k!=i]exp(sim(zi,zk)/texp(sim(zi,zj)/t)
其中,分子是同类之间的相似度(正样本之间的距离),分母是不同类之间的相似度(负样本对之间的距离)。 t t t是temperature(尺度<1)参数,用于调整比列。
在这里插入图片描述

2 代码

DALLE2-pytorch 以 CLIP 为例,学习对比学习的过程,loss:=文本(MLM,Mask Language Model),图像(SimCLR对比损失),图文(图像文本对儿对比损失)
在这里插入图片描述

class SimCLR(nn.Module):
    def __init__(
        self,
        net,
        image_size,
        channels = 3,
        hidden_layer = -2,
        project_hidden = True,
        project_dim = 128,
        augment_both = True,
        use_nt_xent_loss = False,
        augment_fn = None,
        temperature = 0.1
    ):
        super().__init__()
        self.net = NetWrapper(net, project_dim, layer = hidden_layer)
        self.augment = default(augment_fn, get_default_aug(image_size, channels))
        self.augment_both = augment_both
        self.temperature = temperature

        # get device of network and make wrapper same device
        device = get_module_device(net)
        self.to(device)

        # send a mock image tensor to instantiate parameters
        self.forward(torch.randn(1, channels, image_size, image_size))

    def forward(self, x):
        b, c, h, w, device = *x.shape, x.device
        transform_fn = self.augment if self.augment_both else noop
		# 把原图使用不同数据增强和ViT提取成两个不同的图像特征(正样本对queries、keys)
        queries, _ = self.net(transform_fn(x))  
        keys, _    = self.net(self.augment(x))

        queries, keys = map(flatten, (queries, keys))
        # 计算loss
        loss = nt_xent_loss(queries, keys, temperature = self.temperature) 
        return loss

loss

def nt_xent_loss(queries, keys, temperature = 0.1):
    b, device = queries.shape[0], queries.device

    n = b * 2  # 同一图片内部不同patch也是负样本
    projs = torch.cat((queries, keys))
    logits = projs @ projs.t()

    mask = torch.eye(n, device=device).bool()
    logits = logits[~mask].reshape(n, n - 1)  # 同一图片内部不同patch也是负样本,除了自己和自己
    logits /= temperature

    labels = torch.cat(((torch.arange(b, device = device) + b - 1), torch.arange(b, device=device)), dim=0)
    loss = F.cross_entropy(logits, labels, reduction = 'sum')
    loss /= n
    return loss

猜你喜欢

转载自blog.csdn.net/weixin_54338498/article/details/131986565