1 Principle
Contrastive learning learns by comparing the similarities and differences between instances. In contrastive learning, the input data is grouped into positive and negative sample pairs, and features are extracted (or classification is performed) by comparing the differences between samples.
Sample similarity
There are several approaches to contrastive learning; the most common is based on distance metrics. These methods use a distance or similarity function, such as Euclidean distance or cosine similarity, to measure how alike two instances are. By computing the distance between instances, we can find the most or least similar instances for feature selection, similarity matching, or classification. (The goal: positive pairs should be as similar as possible, and negative pairs as dissimilar as possible.)
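A minimal pure-Python sketch of the two measures named above. The vectors are made-up toy embeddings and the function names are illustrative, not taken from any library.

```python
import math

def euclidean_distance(a, b):
    # Straight-line (L2) distance: smaller means more similar.
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

def cosine_similarity(a, b):
    # Cosine of the angle between a and b: in [-1, 1], larger means more similar.
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

anchor = [1.0, 2.0, 3.0]
positive = [2.0, 4.0, 6.0]    # same direction as anchor (e.g. another view of the same image)
negative = [-3.0, 0.0, 1.0]   # orthogonal to anchor (e.g. a different image)

print(cosine_similarity(anchor, positive))   # approximately 1.0
print(cosine_similarity(anchor, negative))   # 0.0
print(euclidean_distance(anchor, positive))  # sqrt(14), about 3.742
```

Note that the positive is still about 3.74 away in Euclidean distance even though it points in exactly the same direction: Euclidean distance is sensitive to vector length, while cosine similarity ignores it.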
SimCLR: contrastive learning that acquires feature extraction capability through metric learning
- Take an input image: apply two data augmentations to the same image to form a positive pair; views of different images form negative pairs.
- Prepare 2 random image augmentations: rotation, color changes (saturation, brightness), scaling, cropping, etc. The SimCLR paper studies a range of augmentations in detail and analyzes which ones work best. (Ways to construct positive pairs: for images, data augmentation as in SimCLR; for text, dropout as in SimCSE; for image-text, paired images and captions as in CLIP.)
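- Feature extraction: run a deep neural network encoder on the augmented images to obtain their feature representations (embeddings). The SimCLR paper uses the convolutional ResNet-50; Transformer encoders such as ViT (or BERT for text) can play the same role, although they are not convolutional networks.
- Feature projection: run a small fully connected network (in the paper, an MLP with one hidden layer and a ReLU) to project the embeddings into another vector space.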
- Compute loss: compute the contrastive loss and backpropagate through both networks (the encoder and the projection head). The loss decreases when the projections of the same image are similar. Any similarity function can be used between projections; here, as in the paper, it is cosine similarity.
- Downstream tasks: the encoder obtained through contrastive learning is used as a feature extractor and fine-tuned on the dataset of the downstream task.
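The steps above can be sketched end to end with toy components. Everything here is an assumption for illustration: `augment` (random scaling plus Gaussian noise standing in for crops and color jitter), a single linear layer as the encoder f, the toy sizes, and random untrained weights. Only the Linear, ReLU, Linear shape of the projection head g follows the SimCLR paper.

```python
import math
import random

def augment(x, rng):
    # Hypothetical stand-in for image augmentation: random brightness-like scaling plus noise.
    scale = rng.uniform(0.8, 1.2)
    return [scale * v + rng.gauss(0.0, 0.05) for v in x]

def linear(W, x):
    # Matrix-vector product W @ x, with W stored as a list of rows.
    return [sum(w * v for w, v in zip(row, x)) for row in W]

def relu(x):
    return [max(0.0, v) for v in x]

def cosine_similarity(a, b, eps=1e-8):
    # eps guards against division by zero if a ReLU output is all zeros.
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / max(norm_a * norm_b, eps)

def random_matrix(rows, cols, rng):
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]

rng = random.Random(0)
IN_DIM, EMB_DIM, HID_DIM, PROJ_DIM = 8, 6, 6, 4   # toy sizes (assumed)
W_enc = random_matrix(EMB_DIM, IN_DIM, rng)       # encoder f: one linear layer here
W1 = random_matrix(HID_DIM, EMB_DIM, rng)         # projection head g: Linear -> ReLU -> Linear
W2 = random_matrix(PROJ_DIM, HID_DIM, rng)

def encoder(x):
    # Representation h: this is what gets reused for downstream tasks.
    return linear(W_enc, x)

def projection_head(h):
    # Projection z: used only to compute the contrastive loss.
    return linear(W2, relu(linear(W1, h)))

def two_views(x, rng):
    # Two independent augmentations of the same input, passed through the same f and g.
    z1 = projection_head(encoder(augment(x, rng)))
    z2 = projection_head(encoder(augment(x, rng)))
    return z1, z2

image = [rng.uniform(0.0, 1.0) for _ in range(IN_DIM)]   # a flattened toy "image"
z1, z2 = two_views(image, rng)
print(len(z1), round(cosine_similarity(z1, z2), 3))
```

The split between `h` and `z` mirrors the bullets: training optimizes the similarity of the projections `z`, while the downstream step keeps only the encoder output `h`.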
SimCLR trains on a lot of data with very large batches (batch size = 8192).
Positive and negative sample pairs are constructed without any labels.
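Because positives are defined only by which image a view came from, the "labels" are just batch positions. A minimal sketch, assuming the 2N views are laid out as all first views followed by all second views (the layout obtained by concatenating the two augmented batches); `positive_index` and `make_pairs` are illustrative names.

```python
def positive_index(i, n_images):
    # Views are laid out as [first views of images 0..N-1, second views of images 0..N-1],
    # so the partner of view i is N positions away (wrapping around).
    return (i + n_images) % (2 * n_images)

def make_pairs(n_images):
    # For each anchor view: (anchor, positive, negatives), derived from batch positions only.
    total = 2 * n_images
    pairs = []
    for i in range(total):
        j = positive_index(i, n_images)
        negatives = [k for k in range(total) if k not in (i, j)]
        pairs.append((i, j, negatives))
    return pairs

for anchor, positive, negatives in make_pairs(3):
    print(anchor, positive, negatives)
# first line printed: 0 3 [1, 2, 4, 5]
```

Each anchor gets 2(N-1) negatives for free, which is why large batches help: with N = 8192 images, every view is contrasted against 16382 negatives.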
How is the loss function designed?
$$l_{i,j} = -\log \frac{\exp(\mathrm{sim}(z_i, z_j)/t)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp(\mathrm{sim}(z_i, z_k)/t)}$$
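Here the numerator measures the similarity of the positive pair $(z_i, z_j)$, i.e. two views of the same image. The denominator sums the similarity of $z_i$ to every other sample in the batch of $2N$ views: the positive plus the $2(N-1)$ negatives, with the indicator $\mathbb{1}_{[k \neq i]}$ excluding $z_i$'s similarity with itself. $t$ is the temperature parameter (typically $t < 1$), which scales the similarities before the softmax; smaller values make the distribution sharper.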
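To make the formula concrete, here is a direct, unvectorized transcription in pure Python. It assumes the layout where `z[k]` and `z[k + N]` are the two views of image k, and it averages $l_{i,j}$ over all $2N$ anchors, which is how the SimCLR paper aggregates the per-pair terms. The toy vectors and the default `t=0.5` are made up for illustration.

```python
import math

def cosine_similarity(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))

def nt_xent(z, t=0.5):
    """NT-Xent over 2N projections, where z[k] and z[k + N] are two views of image k.

    Returns the mean of l_{i,j} over all 2N anchors."""
    total = len(z)
    n = total // 2
    loss = 0.0
    for i in range(total):
        j = (i + n) % total  # positive partner of anchor i
        numerator = math.exp(cosine_similarity(z[i], z[j]) / t)
        # denominator: every k != i (the indicator 1[k != i]), i.e. the positive plus 2N - 2 negatives
        denominator = sum(
            math.exp(cosine_similarity(z[i], z[k]) / t) for k in range(total) if k != i
        )
        loss += -math.log(numerator / denominator)
    return loss / total

aligned = [[1.0, 0.0], [0.0, 1.0], [0.9, 0.1], [0.1, 0.9]]   # z[0]~z[2], z[1]~z[3]
swapped = [[1.0, 0.0], [0.0, 1.0], [0.1, 0.9], [0.9, 0.1]]   # partners deliberately mismatched
print(nt_xent(aligned) < nt_xent(swapped))  # True
```

Because the numerator is also one of the terms in the denominator, every $l_{i,j}$ is non-negative and approaches zero only when the positive dominates all negatives. In this toy example, lowering the temperature to `t=0.1` also lowers the loss for the well-aligned batch, since the already-most-similar positive takes an even larger share of the softmax.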
2 Code
Taking the CLIP implementation in DALLE2-pytorch as an example, we walk through the contrastive learning process. Its loss combines three parts: text (MLM, masked language modeling), image (SimCLR contrastive loss), and image-text (contrastive loss over image-text pairs).
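For the image-text part, a CLIP-style contrastive loss scores every image against every text in the batch and treats the matching pair as the correct class, in both directions. The following is a pure-Python sketch of that symmetric cross-entropy, assuming a precomputed N x N similarity matrix with the true pairs on the diagonal. It is not DALLE2-pytorch's implementation, and `t=0.07` is only an example value.

```python
import math

def log_softmax_at(row, target):
    # log(softmax(row)[target]), using the max-subtraction trick for numerical stability
    m = max(row)
    log_norm = m + math.log(sum(math.exp(v - m) for v in row))
    return row[target] - log_norm

def clip_style_loss(sim, t=0.07):
    """Symmetric contrastive loss for an N x N matrix sim[i][j] = sim(image_i, text_j).

    Image i and text i form the matching pair, so the targets lie on the diagonal."""
    n = len(sim)
    logits = [[s / t for s in row] for row in sim]
    # image -> text: each row is a softmax over all texts
    image_to_text = -sum(log_softmax_at(logits[i], i) for i in range(n)) / n
    # text -> image: each column is a softmax over all images
    columns = [[logits[i][j] for i in range(n)] for j in range(n)]
    text_to_image = -sum(log_softmax_at(columns[j], j) for j in range(n)) / n
    return (image_to_text + text_to_image) / 2

matched = [[0.9, 0.1], [0.2, 0.8]]      # true pairs (diagonal) are the most similar
mismatched = [[0.1, 0.9], [0.8, 0.2]]   # true pairs are the least similar
print(clip_style_loss(matched) < clip_style_loss(mismatched))  # True
```

Structurally this is the same idea as the SimCLR loss, except that the candidates for each image are texts (and vice versa) rather than other augmented views.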
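```python
import torch
from torch import nn
import torch.nn.functional as F

# NetWrapper, default, get_default_aug, get_module_device, noop and flatten are
# helpers from the same library module (not reproduced here).

class SimCLR(nn.Module):
    def __init__(
        self,
        net,
        image_size,
        channels = 3,
        hidden_layer = -2,
        project_hidden = True,     # accepted but unused in this excerpt
        project_dim = 128,
        augment_both = True,
        use_nt_xent_loss = False,  # accepted but unused in this excerpt
        augment_fn = None,
        temperature = 0.1
    ):
        super().__init__()
        # wrap the backbone so it returns projected features taken at `hidden_layer`
        self.net = NetWrapper(net, project_dim, layer = hidden_layer)
        # use the supplied augmentation, else the library's default image augmentation
        self.augment = default(augment_fn, get_default_aug(image_size, channels))
        self.augment_both = augment_both
        self.temperature = temperature

        # get device of network and make wrapper same device
        device = get_module_device(net)
        self.to(device)

        # send a mock image tensor (created on the same device) to instantiate parameters
        self.forward(torch.randn(1, channels, image_size, image_size, device = device))

    def forward(self, x):
        b, c, h, w, device = *x.shape, x.device
        # if augment_both is False, the first view is the original, un-augmented image
        transform_fn = self.augment if self.augment_both else noop

        # two differently augmented views of the same images, encoded by the same network,
        # form the positive pairs (queries[i], keys[i]); only the first returned value is used
        queries, _ = self.net(transform_fn(x))
        keys, _ = self.net(self.augment(x))

        # flatten each output to shape (batch, features)
        queries, keys = map(flatten, (queries, keys))

        # compute the NT-Xent contrastive loss
        loss = nt_xent_loss(queries, keys, temperature = self.temperature)
        return loss
```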
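The loss function `nt_xent_loss`:
```python
import torch
import torch.nn.functional as F

def nt_xent_loss(queries, keys, temperature = 0.1):
    b, device = queries.shape[0], queries.device
    n = b * 2  # total number of views: two per image
    # L2-normalize so the dot products are cosine similarities, as in the paper (no-op if already normalized)
    queries, keys = F.normalize(queries, dim = -1), F.normalize(keys, dim = -1)
    projs = torch.cat((queries, keys))  # rows 0..b-1: queries, rows b..n-1: keys
    logits = projs @ projs.t()          # (n, n) pairwise similarities
    # drop the diagonal (each view compared with itself); each row keeps its n - 1 other views
    mask = torch.eye(n, device = device).bool()
    logits = logits[~mask].reshape(n, n - 1)
    logits /= temperature
    # column of each row's positive once the diagonal is removed:
    # query i -> key i at b + i - 1 (shifted left by one); key i -> query i at i (no shift)
    labels = torch.cat(((torch.arange(b, device = device) + b - 1), torch.arange(b, device = device)), dim = 0)
    # row-wise cross-entropy is the per-view term l_{i,j}; average over all n views
    loss = F.cross_entropy(logits, labels, reduction = 'sum')
    loss /= n
    return loss
```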
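The least obvious line in `nt_xent_loss` is the `labels` construction. A quick pure-Python check, using the hypothetical helpers `positive_columns_after_mask` and `labels_from_formula`, confirms that after deleting the diagonal, the positive for query i sits at column i + b - 1 and the positive for key i at column i, for several batch sizes.

```python
def positive_columns_after_mask(b):
    # Rows/columns 0..b-1 are queries, b..2b-1 are keys; query i pairs with key i.
    n = 2 * b
    labels = []
    for row in range(n):
        positive = row + b if row < b else row - b
        remaining = [c for c in range(n) if c != row]   # the row after dropping the diagonal
        labels.append(remaining.index(positive))
    return labels

def labels_from_formula(b):
    # Same values as torch.cat((arange(b) + b - 1, arange(b))) in nt_xent_loss.
    return [i + b - 1 for i in range(b)] + list(range(b))

print(positive_columns_after_mask(3))  # [2, 3, 4, 0, 1, 2]
print(all(positive_columns_after_mask(b) == labels_from_formula(b) for b in range(1, 9)))  # True
```

Query positives shift left by one because their partner lies to the right of the removed diagonal entry; key positives lie to its left, so their column index is unchanged.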