Self-supervised learning - contrastive learning SimCLR framework (principle + code)

1 Principle

Contrastive learning

Contrastive learning learns by comparing the similarity and dissimilarity between different instances. The input data are organized into positive and negative sample pairs, and features are extracted (or classification is performed) by comparing the differences between the samples.

Sample similarity

There are several approaches to contrastive learning, the most common being distance-metric-based methods. These methods use a distance function to measure the similarity between two instances, such as Euclidean distance or cosine similarity. By computing the distance between instances, we can find the most or least similar instances for feature selection, similarity matching, or classification tasks. (The more similar a positive pair and the less similar a negative pair, the better.)
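A minimal sketch of the two distance metrics mentioned above (the embedding values are made up purely for illustration):

import torch
import torch.nn.functional as F

# two toy embedding vectors (illustrative values only)
z_i = torch.tensor([1.0, 0.5, -0.2])
z_j = torch.tensor([0.9, 0.6, -0.1])

euclidean = torch.dist(z_i, z_j)                 # Euclidean distance (L2 norm of the difference)
cosine = F.cosine_similarity(z_i, z_j, dim=0)    # cosine similarity in [-1, 1]
print(euclidean.item(), cosine.item())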

SimCLR - Contrastive learning provides feature extraction capabilities through metric learning

  1. Take an input image : apply two data augmentations to the same image to form a positive sample pair; views from different images form negative sample pairs.
  2. Prepare 2 random image augmentations : rotation, color/saturation/brightness changes, scaling, cropping, etc. The paper discusses the range of augmentations in detail and analyzes which ones work best. ( Constructing positives: images - SimCLR data augmentation; text - SimCSE dropout; image-text - CLIP image-text pairs )
  3. Feature extraction : run a deep neural network (ResNet-50 in the original SimCLR; ViT or BERT are common backbones for other modalities) to obtain feature representations (embeddings) of the augmented images.
  4. Feature projection : run a small fully connected linear network (the projection head) to project the embeddings into another vector space. (A minimal sketch of steps 2-4 follows this list.)
  5. Compute loss : compute the contrastive loss and backpropagate through both networks. The contrastive loss decreases when projections from the same image are similar. The similarity between projections can be arbitrary; cosine similarity is used here, as in the paper.
  6. Downstream tasks : the encoder obtained from contrastive learning is used as a feature extractor and fine-tuned (Finetune) on the downstream task's dataset.
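A minimal sketch of steps 2-4: two random augmentations, a backbone encoder, and a projection head. The transform choices, the ResNet-50 backbone, and the layer sizes are illustrative assumptions, not the exact recipe from the paper.

import torch
from torch import nn
from torchvision import transforms, models

# two random augmentations of the same image (a small subset of SimCLR's augmentations)
augment = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
    transforms.ToTensor(),
])

backbone = models.resnet50(weights=None)
backbone.fc = nn.Identity()                      # keep the pooled 2048-d features

projection_head = nn.Sequential(                 # small MLP projecting into the contrastive space
    nn.Linear(2048, 512), nn.ReLU(), nn.Linear(512, 128)
)

def embed_pair(pil_image):
    x1, x2 = augment(pil_image), augment(pil_image)        # positive pair: two views of one image
    h1, h2 = backbone(x1.unsqueeze(0)), backbone(x2.unsqueeze(0))
    z1, z2 = projection_head(h1), projection_head(h2)      # these go into the contrastive loss
    return z1, z2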

SimCLR relies on a lot of data and a large batch (batch size = 8192), so each batch supplies many negative samples.
Constructing positive and negative sample pairs requires no labels.

How is the loss function designed?

$$ l_{i,j} = -\log \frac{\exp(\mathrm{sim}(z_i, z_j)/t)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp(\mathrm{sim}(z_i, z_k)/t)} $$
Here the numerator contains the similarity within a positive pair (two views of the same image), and the denominator sums the similarities between z_i and all other samples in the batch (dominated by the negative pairs). t is the temperature parameter (a scale < 1), used to adjust how sharply the similarities are weighted.
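A minimal sketch computing l_{i,j} directly from the formula above for a toy batch (sim is cosine similarity; the projections are random, so the printed value itself is meaningless):

import torch
import torch.nn.functional as F

def nt_xent_term(z, i, j, t=0.1):
    # z: (2N, d) projections; returns l_{i,j} as defined above
    z = F.normalize(z, dim=1)                    # so that dot products are cosine similarities
    sim = z @ z.t() / t
    numerator = sim[i, j].exp()
    mask = torch.ones(z.size(0), dtype=torch.bool)
    mask[i] = False                              # the indicator 1_[k != i]
    denominator = sim[i][mask].exp().sum()
    return -(numerator / denominator).log()

z = torch.randn(4, 8)                            # 2N = 4 toy projections
print(nt_xent_term(z, 0, 2))                     # loss term for the pair (z_0, z_2)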

2 Code

Taking CLIP from DALLE2-pytorch as an example to study the contrastive-learning workflow. The losses are: text (MLM, Masked Language Model), image (SimCLR contrastive loss), and image-text (contrastive loss over image-text pairs).
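A minimal sketch of the image-text pair contrastive loss (CLIP-style symmetric cross-entropy); this is a generic illustration under common conventions, not the exact code from DALLE2-pytorch:

import torch
import torch.nn.functional as F

def image_text_contrastive_loss(image_emb, text_emb, temperature=0.07):
    # image_emb, text_emb: (B, d) embeddings of matched image-text pairs
    image_emb = F.normalize(image_emb, dim=1)
    text_emb = F.normalize(text_emb, dim=1)
    logits = image_emb @ text_emb.t() / temperature    # (B, B) similarity matrix
    labels = torch.arange(image_emb.size(0), device=logits.device)
    # matched pairs sit on the diagonal; all off-diagonal entries act as negatives
    loss_i = F.cross_entropy(logits, labels)           # image -> text direction
    loss_t = F.cross_entropy(logits.t(), labels)       # text -> image direction
    return (loss_i + loss_t) / 2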

import torch
from torch import nn
import torch.nn.functional as F

# Helpers such as NetWrapper, default, get_default_aug, get_module_device, noop and flatten
# are defined elsewhere in the library this excerpt comes from.

class SimCLR(nn.Module):
    def __init__(
        self,
        net,
        image_size,
        channels = 3,
        hidden_layer = -2,
        project_hidden = True,
        project_dim = 128,
        augment_both = True,
        use_nt_xent_loss = False,
        augment_fn = None,
        temperature = 0.1
    ):
        super().__init__()
        self.net = NetWrapper(net, project_dim, layer = hidden_layer)
        self.augment = default(augment_fn, get_default_aug(image_size, channels))
        self.augment_both = augment_both
        self.temperature = temperature

        # get device of network and make wrapper same device
        device = get_module_device(net)
        self.to(device)

        # send a mock image tensor to instantiate parameters
        self.forward(torch.randn(1, channels, image_size, image_size))

    def forward(self, x):
        b, c, h, w, device = *x.shape, x.device
        transform_fn = self.augment if self.augment_both else noop
        # Apply a different augmentation to each copy of the original image and run both through the
        # encoder, producing two feature representations of the same image (the positive pair: queries, keys)
        queries, _ = self.net(transform_fn(x))  
        keys, _    = self.net(self.augment(x))

        queries, keys = map(flatten, (queries, keys))
        # compute the contrastive loss
        loss = nt_xent_loss(queries, keys, temperature = self.temperature) 
        return loss
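A minimal usage sketch of the SimCLR wrapper above (the ResNet-50 backbone, image size, and optimizer settings are illustrative assumptions):

import torch
from torchvision import models

resnet = models.resnet50(weights=None)               # any backbone with a usable hidden layer
learner = SimCLR(resnet, image_size=256)             # defaults: project to 128-d, temperature 0.1
optimizer = torch.optim.Adam(learner.parameters(), lr=3e-4)

images = torch.randn(16, 3, 256, 256)                # stand-in for a batch of unlabeled images
loss = learner(images)                               # two augmented views + NT-Xent loss internally
loss.backward()
optimizer.step()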

The NT-Xent loss function:

def nt_xent_loss(queries, keys, temperature = 0.1):
    b, device = queries.shape[0], queries.device

    n = b * 2  # 2N views in total (N query views + N key views)
    projs = torch.cat((queries, keys))
    logits = projs @ projs.t()

    mask = torch.eye(n, device=device).bool()
    logits = logits[~mask].reshape(n, n - 1)  # drop self-similarity (the diagonal); every other view in the batch is kept, positives and negatives alike
    logits /= temperature

    # positives: after removing the diagonal, key_i for query row i lands at column (b + i - 1),
    # while query_i for key row (b + i) stays at column i
    labels = torch.cat(((torch.arange(b, device = device) + b - 1), torch.arange(b, device=device)), dim=0)
    loss = F.cross_entropy(logits, labels, reduction = 'sum')
    loss /= n
    return loss
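A quick sanity check of the function above (random projections, so the printed loss value is meaningless; only the shapes matter):

queries = torch.randn(8, 128)   # N = 8 projected query views
keys = torch.randn(8, 128)      # N = 8 projected key views (same images, different augmentation)
print(nt_xent_loss(queries, keys, temperature=0.1))   # scalar loss over 2N = 16 views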


Origin blog.csdn.net/weixin_54338498/article/details/131986565