[Self-supervised learning collection] 1: An intensive reading of the MoCo code

Preface

I have only just started with self-supervised learning, and my understanding of it is still mostly theoretical. So I want to start a series for myself: this collection of self-supervised learning code readings. On the one hand it will deepen my own understanding; on the other hand I hope it can also help beginners like me. If anything here is wrong, please point it out.

1. The main idea of MoCo

Before talking about MoCo, you need to know what contrastive learning is. Contrastive learning is an important branch of self-supervised learning. Self-supervised learning mines supervisory information from the data set itself and trains the model with the supervision it generates. For example, cut a picture into a 3x3 grid, label the cells 1-9, then shuffle the cells together with their labels; the shuffled picture becomes the input and the shuffled labels become the ground truth, so a completely unlabeled data set can generate "labels" for the model to learn from. The first key question of self-supervised learning is therefore: how do we mine supervisory information from an unlabeled data set? And since self-supervised learning trains an unsupervised model in a supervised way, once that information is mined we also have to decide how to use it. This is the second key question: how do we design a reasonable pretext (proxy) task so that the model learns the latent features in the data?
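
To make the 9-grid example concrete, below is a minimal sketch of how such a pretext-task sample could be generated. This is only an illustration of the idea, not code from the MoCo repository:

# minimal sketch of the 3x3 jigsaw pretext task (illustration only)
import torch

def make_jigsaw_sample(img):                 # img: C x H x W tensor
    c, h, w = img.shape
    ph, pw = h // 3, w // 3                  # patch size; any remainder pixels are dropped
    patches = [img[:, i*ph:(i+1)*ph, j*pw:(j+1)*pw]
               for i in range(3) for j in range(3)]
    perm = torch.randperm(9)                 # the self-generated "label"
    shuffled = torch.stack([patches[p] for p in perm.tolist()])  # shuffled input, 9 x C x ph x pw
    return shuffled, perm                    # model input and ground truth

x, y = make_jigsaw_sample(torch.rand(3, 224, 224))
print(x.shape, y)                            # torch.Size([9, 3, 74, 74]) and a permutation of 0..8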

Contrastive learning gives a solution to the first question: after a picture undergoes different augmentations (cropping, adding noise, etc.), the augmented views are treated as positive pairs of each other.
(Figure: the general flow of contrastive learning)
MoCo is one of the most classic models in contrastive learning. It is a contrastive method based on positive and negative samples. In such methods the negatives are usually other pictures drawn from some sample bank; the banks differ from model to model, but all of these models try, in different ways, to pull positive pairs close together in the embedding space and push negative pairs far apart. This is the pretext task of instance discrimination: there are as many classes as there are samples. MoCo summarizes positive/negative-based contrastive learning as a dictionary look-up problem, and the crux of this problem is how to build a dictionary that is both large and consistent.
(Figure: MoCo's dictionary look-up)

First, the dictionary look-up problem. The data in a dictionary consists of keys and values; in contrastive learning you can imagine the pictures as the keys and the latent features of the pictures as the values. When contrasting, an augmented view of some picture is used as a query into this dictionary; if it matches a key that comes from the same picture under a different augmentation, their values (features) should be as close as possible. The crux of the problem has two parts: large and consistent. In previous work the dictionary was either too small (a single mini-batch) or too large (the whole data set), so a trade-off between the two is needed. MoCo's answer to "large" is the sample queue: a fixed number of mini-batches are stored in the queue, and each time a new mini-batch is enqueued the oldest one is dequeued. As for "consistent": a contrastive model is being updated all the time, so the features a sample gets when it passes through the model at different moments are not consistent, and the features used for the same sample differ across epochs. MoCo's answer is the momentum encoder: the encoder that produces the keys in the dictionary is updated with momentum, so that most of its state comes from the previous iteration (MoCo's experiments show that keeping 99.9% works best). Together with the queue, which always evicts the oldest samples first (the ones whose features have drifted the most and are therefore the least consistent), this keeps the features in the sample queue consistent.
(Figure: MoCo's momentum ablation experiment)
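
For reference, the two formulas in the MoCo paper that correspond to these two ideas are the momentum update of the key encoder and the InfoNCE loss:

$$\theta_k \leftarrow m\,\theta_k + (1 - m)\,\theta_q$$

$$\mathcal{L}_q = -\log \frac{\exp(q \cdot k_+ / \tau)}{\sum_{i=0}^{K} \exp(q \cdot k_i / \tau)}$$

where θ_q and θ_k are the parameters of the query and key encoders, m is the momentum (0.999 by default), k_+ is the key from the same image as the query q, the other K keys come from the sample queue, and τ is the temperature (moco-t in the code).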

2. Intensive code reading

2.1 Code structure

(Figure: code directory structure)

The code consists of two folders and a few top-level files. The detection folder is for the downstream object-detection task; the moco folder is the main part of the model, containing builder.py and loader.py. main_moco.py is MoCo's self-supervised pre-training script, and main_lincls.py trains a simple linear classifier for the image-classification task.
Our reading process is:
main_moco.py -> moco folder -> main_lincls.py (-> detection folder)

2.2 main_moco.py

2.2.1 Parameter setting

model_names = sorted(name for name in models.__dict__
    if name.islower() and not name.startswith("__")
    and callable(models.__dict__[name]))

model_names collects the names of the vision backbone networks available in torchvision.models. The meanings of the different command-line parameters are as follows:

parameter       meaning
data            dataset path
arch            backbone network, one of model_names
workers         number of data-loading worker processes (DataLoader num_workers)
epochs          number of training epochs
start-epoch     starting epoch, normally 0; set it when resuming interrupted training
batch-size      mini-batch size (total across all GPUs)
learning-rate   initial learning rate
momentum        SGD momentum
weight-decay    weight decay
resume          path of the latest checkpoint to resume from
moco-dim        feature (output) dimension
moco-k          sample queue size (number of negative keys)
moco-m          momentum for the key-encoder (dictionary) update
moco-t          softmax temperature
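
With these parameters, a typical unsupervised pre-training launch looks roughly like the following (adapted from memory of the repo's README for an 8-GPU machine; check the README for the exact flags, and replace the data path with your own):

python main_moco.py \
  -a resnet50 \
  --lr 0.03 --batch-size 256 \
  --dist-url 'tcp://localhost:10001' --multiprocessing-distributed --world-size 1 --rank 0 \
  /path/to/imagenet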

def main():

main first processes the arguments passed in via argparse and finally calls the main_worker function:

main_worker(args.gpu, ngpus_per_node, args)

If everything is left at its default, args.gpu = None and ngpus_per_node equals the number of GPUs.
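
The dispatch logic inside main is roughly the following (slightly simplified from the repo; mp is torch.multiprocessing there):

ngpus_per_node = torch.cuda.device_count()
if args.multiprocessing_distributed:
    # one process per GPU; each process runs main_worker with its own gpu index
    args.world_size = ngpus_per_node * args.world_size
    mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
else:
    # single-process run
    main_worker(args.gpu, ngpus_per_node, args)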

def main_worker(gpu, ngpus_per_node, args)

The function first handles the multi-process (distributed) setup, and then builds the model from the backbone selected in the arguments plus MoCo's own hyper-parameters.

print("=> creating model '{}'".format(args.arch))
    model = moco.builder.MoCo(
        models.__dict__[args.arch],
        args.moco_dim, args.moco_k, args.moco_m, args.moco_t, args.mlp)
    print(model)

You can see that the model uses the cross-entropy loss function and is optimized with stochastic gradient descent:

# model
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
# loss function
criterion = nn.CrossEntropyLoss().cuda(args.gpu)
# optimizer
optimizer = torch.optim.SGD(model.parameters(), args.lr,
                            momentum=args.momentum,
                            weight_decay=args.weight_decay)

Then comes the data processing, which defines the data augmentation and the normalization:

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
augmentation = [
            transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
            transforms.RandomGrayscale(p=0.2),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize
        ]                                   	 

Then the dataset and dataloader are defined for the subsequent training. Each item of train_dataset is a pair of differently augmented views of the same picture.

train_dataset = datasets.ImageFolder(
        traindir,
        moco.loader.TwoCropsTransform(transforms.Compose(augmentation)))

train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True)

Finally, start the training process according to the basic components defined above:

train(train_loader, model, criterion, optimizer, epoch, args)
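
In the repo this call actually sits inside a per-epoch loop that also adjusts the learning rate and periodically saves checkpoints; slightly simplified, it looks like:

for epoch in range(args.start_epoch, args.epochs):
    if args.distributed:
        train_sampler.set_epoch(epoch)            # reshuffle data across processes
    adjust_learning_rate(optimizer, epoch, args)  # stepwise or cosine schedule
    train(train_loader, model, criterion, optimizer, epoch, args)
    # a checkpoint is saved here so that --resume / --start-epoch can continue training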

def train(train_loader, model, criterion, optimizer, epoch, args)

Here I strip out the parts that are unrelated to the model training itself, such as the distributed-training and timing code:

    for i, (images, _) in enumerate(train_loader):

        # compute output
        output, target = model(im_q=images[0], im_k=images[1])
        loss = criterion(output, target)
        losses.update(loss.item(), images[0].size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

As shown in the code, the model takes two inputs: one is the query image used to look up the dictionary, and the other is the key image that the query should match in the dictionary.

2.3 moco folder

The moco folder contains MoCo's model definition and data loading.

2.3.1 loader.py

loader.py is very simple: following the augmentation pipeline defined in main_worker (section 2.2), it applies two independent augmentations to the same image, one per branch.

class TwoCropsTransform:
    """Take two random crops of one image as the query and key."""

    def __init__(self, base_transform):
        self.base_transform = base_transform

    def __call__(self, x):
        q = self.base_transform(x)
        k = self.base_transform(x)
        return [q, k]

Here base_transform is the augmentation defined in main_worker.
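
A quick way to see what one sample looks like (an illustrative check of my own, assuming the TwoCropsTransform class above and the augmentation list from section 2.2 are in scope):

import numpy as np
from PIL import Image
import torchvision.transforms as transforms

two_crop = TwoCropsTransform(transforms.Compose(augmentation))
img = Image.fromarray(np.random.randint(0, 255, (256, 256, 3), dtype=np.uint8))
q, k = two_crop(img)
print(q.shape, k.shape)   # two independently augmented 3 x 224 x 224 views of the same image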

2.3.2 builder.py

Model initialization

    def __init__(self, base_encoder, dim=128, K=65536, m=0.999, T=0.07, mlp=False):
        """
        dim: feature dimension (default: 128)
        K: queue size; number of negative keys (default: 65536)
        m: moco momentum of updating key encoder (default: 0.999)
        T: softmax temperature (default: 0.07)
        """
        super(MoCo, self).__init__()
        
        # basic hyper-parameters
        self.K = K
        self.m = m
        self.T = T

        # create the encoders
        # num_classes is the output fc dimension
        self.encoder_q = base_encoder(num_classes=dim)
        self.encoder_k = base_encoder(num_classes=dim)

        if mlp:  # hack: brute-force replacement
            dim_mlp = self.encoder_q.fc.weight.shape[1]
            self.encoder_q.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_q.fc)
            self.encoder_k.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_k.fc)

        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data.copy_(param_q.data)  # initialize
            param_k.requires_grad = False  # not update by gradient

        # create the queue
        self.register_buffer("queue", torch.randn(dim, K))
        self.queue = nn.functional.normalize(self.queue, dim=0)

        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

This is where the MoCo model itself is defined. First of all, you can see that the query encoder and the key encoder are identical in structure, but the key encoder receives no gradients; its weights are initialized by copying from the query encoder, which is consistent with the design in the paper.

sample queue

At the end of the initialization the model also defines the sample queue. The queue is a circular buffer maintained by a pointer (queue_ptr) and a tensor; enqueue/dequeue is done by the following function:

    def _dequeue_and_enqueue(self, keys):
        # gather keys before updating queue
        keys = concat_all_gather(keys)

        batch_size = keys.shape[0]

        ptr = int(self.queue_ptr)
        assert self.K % batch_size == 0  # for simplicity

        # replace the keys at ptr (dequeue and enqueue)
        self.queue[:, ptr:ptr + batch_size] = keys.T
        ptr = (ptr + batch_size) % self.K  # move pointer

        self.queue_ptr[0] = ptr

Each time, the queue overwrites the batch-sized slot pointed to by queue_ptr with the new batch of keys and then moves the pointer to just past that batch, which is equivalent to dequeuing the oldest batch and enqueuing the new batch at the tail.
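
A toy trace of this ring-buffer behaviour, with assumed numbers (K=8, batch size 4, 2-dimensional features):

import torch

K, bs, dim = 8, 4, 2
queue, ptr = torch.zeros(dim, K), 0
for step, fill in enumerate([1., 2., 3.]):       # three fake batches of keys
    keys = torch.full((bs, dim), fill)
    queue[:, ptr:ptr + bs] = keys.T              # overwrite the slots at the pointer
    ptr = (ptr + bs) % K
    print(step, ptr, queue[0].tolist())
# step 2 overwrites the slots written at step 0: the oldest batch has left the queue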

Momentum encoder

    def _momentum_update_key_encoder(self):
        """
        Momentum update of the key encoder
        """
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)

param_q are the parameters that have just been updated via the gradients, and param_k are the key-encoder parameters from before this update; self.m is the momentum, for which 0.999 works best, i.e. 99.9% of the key encoder (and thus of the features that enter the sample dictionary) comes from its previous state and only 0.1% comes from the current update, which keeps the dictionary highly consistent.

model forward process

The forward pass is fairly conventional. First, the query image is passed through the query encoder to obtain the query features:

# compute query features
q = self.encoder_q(im_q)  # queries: NxC
q = nn.functional.normalize(q, dim=1)

Then the key encoder is updated with momentum, and the key features are computed:

self._momentum_update_key_encoder()  # update the key encoder
k = self.encoder_k(im_k)  # keys: NxC
k = nn.functional.normalize(k, dim=1)

Then the loss terms are computed from the query, the key and the queue. q and k are positive examples of each other, so their similarity is called l_pos; at this point the new samples have not yet entered the queue, so q and every sample in the sample queue are negative examples of each other, and the similarity between q and the queue is therefore called l_neg:

l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])

Then torch.cat is used to concatenate the positive and negative logits. Note that l_pos comes first, so the positive logit is always the 0th element of each row.

# logits: Nx(1+K)
logits = torch.cat([l_pos, l_neg], dim=1)
# apply temperature
logits /= self.T

Since the positive key sits at position 0 of each row, and the label passed to the cross-entropy loss indicates the position of the positive class, the labels are all 0.

# labels: positive key indicators
labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()
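
A shape check with toy tensors (N=2 queries, C=4 feature dimensions, K=8 queue entries; numbers assumed for illustration):

import torch

N, C, K = 2, 4, 8
q = torch.nn.functional.normalize(torch.randn(N, C), dim=1)
k = torch.nn.functional.normalize(torch.randn(N, C), dim=1)
queue = torch.nn.functional.normalize(torch.randn(C, K), dim=0)

l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)   # N x 1
l_neg = torch.einsum('nc,ck->nk', [q, queue])            # N x K
logits = torch.cat([l_pos, l_neg], dim=1) / 0.07         # N x (1 + K), temperature 0.07
labels = torch.zeros(N, dtype=torch.long)                # the positive logit sits in column 0
loss = torch.nn.CrossEntropyLoss()(logits, labels)
print(logits.shape, loss.item())                         # torch.Size([2, 9]) and a scalar loss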

Finally, update the sample queue, dequeue old samples, and enqueue new samples:

# dequeue and enqueue
self._dequeue_and_enqueue(k)

After the logits and labels are returned, the cross-entropy loss is computed back in the train function of main_moco.py and the parameters of the query encoder are updated.

return logits, labels

The above constitutes a complete pre-training process.

2.4 main_lincls.py

This part uses the pre-trained model and fine-tunes and evaluates it on the downstream classification task. The main flow is similar to main_moco; the differences are as follows:

The model construction is different:
main_lincls directly extracts the pre-trained query encoder

'''
First, build the backbone network instance
'''
model = models.__dict__[args.arch]()

'''
Then, load the pre-trained checkpoint and keep only the query-encoder part
'''
checkpoint = torch.load(args.pretrained, map_location="cpu")
# rename moco pre-trained keys
state_dict = checkpoint['state_dict']
for k in list(state_dict.keys()):
    # retain only encoder_q up to before the embedding layer
    if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
        state_dict[k[len("module.encoder_q."):]] = state_dict[k]
    # delete renamed or unused k
    del state_dict[k]

'''
Finally, load the retained query-encoder weights into the backbone instance
'''
msg = model.load_state_dict(state_dict, strict=False)
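
If I remember the repo correctly, right after this load it also sanity-checks that only the classifier weights are missing, roughly:

assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}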

The training part is different:
this stage only fine-tunes a linear classifier, so unlike pre-training it does not back-propagate gradients into every layer of the model; only the final linear layer is updated, and all the other layers (the CNN backbone) have to be frozen:

# freeze all layers but the last fc
for name, param in model.named_parameters():
    if name not in ['fc.weight', 'fc.bias']:
        param.requires_grad = False
# init the fc layer
model.fc.weight.data.normal_(mean=0.0, std=0.01)
model.fc.bias.data.zero_()
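
The optimizer is then built over the remaining trainable parameters only; from memory of main_lincls.py this is roughly:

# only fc.weight and fc.bias still require gradients after the freezing above
parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
assert len(parameters) == 2   # fc.weight, fc.bias
optimizer = torch.optim.SGD(parameters, args.lr,
                            momentum=args.momentum,
                            weight_decay=args.weight_decay)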

Origin blog.csdn.net/Kevinxgl/article/details/128360609