Model-Agnostic Meta-Learning (MAML): model introduction and detailed algorithm walkthrough

MAML is already a very influential model in academia: the paper Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks has accumulated 400+ citations since its publication in 2017. Since there are still very few Chinese-language introductions to MAML online, many readers may not know it well, so in this post I summarize what I have learned recently and share my knowledge and understanding of MAML. MAML can be used for supervised regression, classification, and reinforcement learning; since I am not very familiar with reinforcement learning, this article focuses on MAML applied to supervised regression and classification.

 

1. Introduction to some related concepts

In the original paper, the authors directly use many concepts related to meta-learning, such as meta-learning, model-agnostic, N-way K-shot, task, and so on. Some of these concepts have special meanings in MAML. Here, I try my best to introduce them in an easy-to-understand way.

(1) meta-learning

Meta-learning is also called "learning to learn". A common deep learning model aims to learn a mathematical model used directly for prediction, whereas meta-learning cares not about the result of learning but about the process of learning. It does not learn a mathematical model that is used directly for prediction; it learns "how to learn a mathematical model faster and better".

Here is a real-life example. When we teach children to read English words, we can let them directly imitate the pronunciation of apple and banana. But they will soon encounter new words, such as strawberry, and they will need to hear your pronunciation again before they can read the new word correctly. Now let's change the approach: instead of teaching the pronunciation of every word, we teach the pronunciation of the phonetic symbols. From then on, whenever the children meet a new word, they can read it correctly from the phonetic transcription alone. The process of learning the phonetic symbols is exactly a meta-learning process.

In deep learning, many meta-learning models have been proposed. They can be roughly divided into three categories: learning good weight initializations, meta-models that generate the parameters of other models, and learning transferable optimizers. MAML belongs to the first category: it learns a good initialization of the weights so as to achieve fast adaptation on new tasks, that is, to converge quickly on small-scale training samples and complete fine-tuning.

 

(2) model-agnostic

Model-agnostic means "independent of the model". MAML is not so much a deep learning model as a framework: it provides a meta-learner used to train a base-learner. The meta-learner is the essence of MAML and is what does the "learning to learn"; the base-learner is the actual mathematical model that is trained on the target dataset and used for the prediction task. Most deep learning models can be embedded seamlessly into MAML as the base-learner, and MAML can even be used in reinforcement learning. This is what model-agnostic means in MAML.

 

(3) N-way K-shot

N-way K-shot is a common experimental setting in few-shot learning. Few-shot learning refers to training a mathematical model with very little labeled data, which is one of the problems MAML is good at solving. N-way means that the training data contains N categories, and K-shot means that each category contains K labeled samples.
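To make the setting concrete, here is a minimal sketch of drawing a 5-way 5-shot support set from a labeled pool; the pool, class names, and sample names are made up purely for illustration.

    import random

    # Hypothetical labeled pool: 10 classes with 30 labeled samples each (names are illustrative).
    pool = {f"class_{c}": [f"sample_{c}_{i}" for i in range(30)] for c in range(10)}

    N, K = 5, 5  # 5-way 5-shot
    way = random.sample(sorted(pool), N)                    # pick N classes
    support = {c: random.sample(pool[c], K) for c in way}   # K labeled samples per class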

(4) task

The word task appears many times in the MAML paper, and the whole training process revolves around tasks, yet the authors never give it a clear definition. To properly understand what a task is, we need a few related concepts: D_{meta-train}, D_{meta-test}, support set, query set, meta-train classes, meta-test classes, and so on. Does that sound a bit dizzying? Don't worry, you can easily grasp these concepts with a simple example.

Suppose we have the following scenario: we need to use MAML to train a mathematical model M_{fine-tune} whose purpose is to classify images from unseen classes, where the target categories are P_{1}\sim P_{5} (5 labeled samples per category for training; in addition, each category has 15 labeled samples for testing). Besides the labeled samples of P_{1}\sim P_{5}, our training data also includes another 10 categories of images C_{1}\sim C_{10} (30 labeled samples per category), which are used to help train the meta-learning model M_{meta}. The experiment is set up as 5-way 5-shot.

The specific training process will be introduced in the next section, the detailed explanation of the MAML algorithm. Here we only need a general picture: MAML first uses the C_{1}\sim C_{10} data to train the meta-model M_{meta}, and then fine-tunes on the P_{1}\sim P_{5} data to obtain the final model M_{fine-tune}.

In this setup, C_{1}\sim C_{10} are the meta-train classes, and the 300 samples they contain form D_{meta-train}, the dataset used to train M_{meta}. Correspondingly, P_{1}\sim P_{5} are the meta-test classes, and the 100 samples they contain form D_{meta-test}, the dataset used to train and test M_{fine-tune}.

Given the 5-way 5-shot setting, in the stage of training M_{meta} we randomly select 5 categories from C_{1}\sim C_{10}, and then randomly select 20 labeled samples from each of these categories, to form a task \tau. For each category, 5 of the labeled samples form the support set of \tau, and the other 15 samples form the query set of \tau. This task \tau plays the same role as a single piece of training data in the training of an ordinary deep learning model. Since we need a batch to do stochastic gradient descent (SGD), we repeatedly sample several such tasks \tau from the training data to form a batch. In the stage of training M_{fine-tune}, the meanings of task, support set, and query set are the same as in the M_{meta} training stage.
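Here is a minimal sketch of how such tasks could be sampled into a batch; the meta-train pool, the class and sample names, and the batch size of 4 are made up for illustration, and a real implementation would return image tensors and labels rather than name strings.

    import random

    # Hypothetical meta-train pool: 10 classes (C_1..C_10) with 30 labeled samples each.
    meta_train_pool = {f"C_{c}": [f"img_{c}_{i}" for i in range(30)] for c in range(1, 11)}

    def sample_task(pool, n_way=5, k_shot=5, k_query=15):
        """Sample one task: a support set and a query set over n_way classes."""
        classes = random.sample(sorted(pool), n_way)
        support, query = [], []
        for label, cls in enumerate(classes):
            samples = random.sample(pool[cls], k_shot + k_query)
            support += [(x, label) for x in samples[:k_shot]]   # 5 per class -> 25 in total
            query += [(x, label) for x in samples[k_shot:]]     # 15 per class -> 75 in total
        return support, query

    # A batch of tasks plays the role of a batch of training examples in ordinary training.
    batch = [sample_task(meta_train_pool) for _ in range(4)]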

2. Detailed explanation of the MAML algorithm

The algorithm flow given by the authors in the paper is as follows (the eight steps discussed below correspond to the eight lines of the paper's pseudocode):

This is essentially the algorithm of MAML's pre-training phase, whose purpose is to obtain the model M_{meta}. Don't be intimidated by the mathematical symbols; the idea behind the algorithm is actually very simple. Next, let's analyze it line by line.

First, look at the two Require lines.

The first Require refers to the distribution of tasks over D_{meta-train}. Combining this with the example from the previous section, it means repeatedly and randomly sampling tasks \tau to form a task pool of several (e.g., 1000) tasks, which serves as MAML's training set. Some readers may wonder: there are only so many training samples, so if we build that many tasks, won't samples be duplicated across different tasks? Won't the query set of some tasks become the support set of other tasks? That's right! That's exactly how it is! We must remember that the goal of MAML is fast adaptation, that is, by learning from a large number of tasks the model gains enough generalization ability that, when facing a new, never-seen task, it can be fitted quickly with fine-tuning. All that is needed is that the tasks differ from one another to some degree. To repeat: MAML is trained on tasks, and each task here plays the role of a single piece of training data in the training of an ordinary deep learning model.

The second Require is easy to understand: the step sizes are simply learning rates. Anyone who has read the MAML paper will remember the phrase gradient by gradient. MAML is based on a double gradient; each iteration includes two parameter-update processes, so there are two learning rates that can be tuned.
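Concretely, the two step sizes \alpha and \beta appear in the paper's two updates, the per-task (inner) update and the meta (outer) update, written here in the \tau notation used in this post:

    \theta_{i}' = \theta - \alpha \nabla_{\theta} L_{\tau_{i}}(f_{\theta})

    \theta \leftarrow \theta - \beta \nabla_{\theta} \sum_{\tau_{i} \sim p(\tau)} L_{\tau_{i}}(f_{\theta_{i}'})

Here \theta are the initial (meta) parameters, \theta_{i}' are the parameters adapted to task \tau_{i}, and L_{\tau_{i}} is the loss on task \tau_{i}.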

Next comes the exciting part: the algorithm flow itself.

Step 1: randomly initialize the parameters of the model. There is nothing special to say here; every model training begins with this step.

Step 2 is a loop, which can be understood as one iteration or one epoch; of course, the pre-training process can run for many epochs.

Step 3 is the equivalent of a DataLoader in PyTorch: randomly sample several (e.g., 4) tasks to form a batch.

Steps 4 to 7 are the first gradient-update process. You can think of it as making a copy of the original model, computing the adapted parameters on the copy, and using them in the second round of gradient computation. As we said, MAML is gradient by gradient and has two gradient-update processes. In steps 4-7, each task in the batch updates the copied parameters separately (with 4 tasks there are 4 separate updates). Note that this inner process can also be repeated several times; the pseudocode does not show this extra level of looping, but the analysis part of the paper explicitly mentions that "using multiple gradient updates is a straightforward extension".

Step 5: use the support set of one task in the batch to compute the gradient of the loss with respect to each parameter. Under the N-way K-shot setting, the support set contains N*K samples. The algorithm says "with respect to K examples", which implicitly means K samples per class, so in fact N*K samples in total are involved in the computation. The loss here is MSE for regression problems and cross-entropy for classification problems.
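One common way to write these two losses for a task \tau_{i} (the notation below is chosen for this post, not copied from the paper) is:

    L_{\tau_{i}}^{MSE}(f_{\phi}) = \sum_{(x, y) \in \mathrm{support}(\tau_{i})} \| f_{\phi}(x) - y \|_{2}^{2}

    L_{\tau_{i}}^{CE}(f_{\phi}) = - \sum_{(x, y) \in \mathrm{support}(\tau_{i})} \log f_{\phi}(x)_{y}

where f_{\phi}(x)_{y} denotes the probability the model assigns to the true class y.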

Step 6: the first gradient update, which produces the adapted parameters of the copied model.

After steps 4 to 7 are finished, MAML has completed the first gradient update. What we do next is, on the basis of the parameters obtained from the first gradient update, compute the second gradient update, gradient by gradient. The gradient computed in the second update is applied directly to the original model via SGD; it is the gradient the model really uses to update its parameters.

Step 8 corresponds to the second gradient-update process. The loss here is computed roughly as in step 5, with two differences. First, we no longer update with the loss of each task separately; instead, as in an ordinary training procedure, we compute the total loss over the batch of tasks and perform stochastic gradient descent (SGD) with its gradient. Second, the samples involved in this computation are the query set of each task, i.e., 5-way * 15 = 75 samples in our example, rather than the support set; the purpose is to enhance the model's generalization ability on the task and avoid overfitting to the support set. After step 8, training on this batch ends, and we return to step 3 to sample the next batch.
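Below is a minimal PyTorch sketch of one meta-training iteration following the logic above. The tiny functional MLP, the dimensions, and the synthetic task tensors are placeholders for illustration, not the architecture used in the paper. The two key points are that the inner update of steps 5-6 uses torch.autograd.grad(..., create_graph=True) so that the gradient-by-gradient update of step 8 can differentiate through it, and that the step-8 update sums the query-set losses over the whole batch of tasks.

    import torch
    import torch.nn.functional as F

    def init_params(in_dim=64, hidden=32, n_way=5):
        """Leaf tensors of a tiny two-layer MLP (illustrative architecture)."""
        def p(*shape):
            return (torch.randn(*shape) * 0.01).requires_grad_()
        return {"w1": p(hidden, in_dim), "b1": p(hidden),
                "w2": p(n_way, hidden), "b2": p(n_way)}

    def forward(params, x):
        """Functional forward pass, so adapted parameters can be plugged in."""
        h = F.relu(F.linear(x, params["w1"], params["b1"]))
        return F.linear(h, params["w2"], params["b2"])

    def inner_update(params, x_spt, y_spt, alpha):
        """Steps 5-6: one gradient step on the support set, keeping the graph."""
        loss = F.cross_entropy(forward(params, x_spt), y_spt)
        grads = torch.autograd.grad(loss, list(params.values()), create_graph=True)
        return {k: v - alpha * g for (k, v), g in zip(params.items(), grads)}

    alpha, beta = 0.01, 0.001
    params = init_params()                        # step 1: random initialization
    meta_opt = torch.optim.SGD(params.values(), lr=beta)

    # One outer iteration (step 2) over a batch of 4 synthetic 5-way 5-shot tasks (step 3).
    tasks = [(torch.randn(25, 64), torch.randint(0, 5, (25,)),   # support: 5 ways * 5 shots
              torch.randn(75, 64), torch.randint(0, 5, (75,)))   # query:   5 ways * 15
             for _ in range(4)]

    meta_loss = 0.0
    for x_spt, y_spt, x_qry, y_qry in tasks:      # steps 4-7: per-task adaptation
        adapted = inner_update(params, x_spt, y_spt, alpha)
        meta_loss = meta_loss + F.cross_entropy(forward(adapted, x_qry), y_qry)

    meta_opt.zero_grad()
    meta_loss.backward()                          # step 8: gradient flows through the inner step
    meta_opt.step()

In a full training loop this block would sit inside the step-2 while loop and draw real tasks from D_{meta-train} instead of random tensors.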

That is the whole MAML pre-training process for obtaining M_{meta}. Isn't it simple? In fact, MAML became popular in the meta-learning field so quickly precisely because of its simple idea and impressive performance. The next question is how, when facing a new task, to fine-tune on the basis of M_{meta} to obtain M_{fine-tune}. There is no introduction to the fine-tuning process in the original paper, so here is a brief introduction.

The fine-tune process is roughly the same as the pre-training process. The main differences are the following:

  • In step 1, fine-tuning does not initialize the parameters randomly; it uses the trained M_{meta} to initialize them.
  • In step 3, fine-tuning only needs to draw a single task to learn from, so there is naturally no batch of tasks. Fine-tuning trains the model on the support set of this task and tests it on the query set. In practice, we randomly draw many tasks (e.g., 500) from D_{meta-test}, fine-tune M_{meta} on each of them separately, and average the final test results to avoid extreme cases.
  • Fine-tuning has no step 8, because the query set of the task is used to test the model and its labels are unknown to the model. Therefore, the fine-tuning process has no second gradient update; it directly uses the result of the first gradient computation to update the parameters (a minimal sketch follows this list).
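To make these differences concrete, here is a minimal sketch of fine-tuning and evaluating on a single task. It reuses forward and inner_update from the meta-training sketch above, and the synthetic task tensors, the learning rate, and the number of adaptation steps are placeholders.

    import torch

    # Assumes forward and inner_update from the meta-training sketch above,
    # and that params holds the meta-trained parameters M_meta.

    def finetune_and_eval(meta_params, x_spt, y_spt, x_qry, y_qry, alpha=0.01, steps=1):
        """Adapt the meta-trained parameters on one task's support set, then test on its query set."""
        # Difference 1: start from the meta-learned initialization instead of a random one;
        # detach so that fine-tuning does not modify the meta-parameters themselves.
        task_params = {k: v.detach().clone().requires_grad_() for k, v in meta_params.items()}
        for _ in range(steps):
            # Difference 3: only the first gradient update is applied; there is no meta-update here.
            task_params = inner_update(task_params, x_spt, y_spt, alpha)
        with torch.no_grad():
            pred = forward(task_params, x_qry).argmax(dim=-1)
            return (pred == y_qry).float().mean().item()

    # Difference 2: one synthetic 5-way 5-shot task drawn from D_meta-test (placeholder tensors);
    # in practice we repeat this for many tasks and average the accuracies.
    x_spt, y_spt = torch.randn(25, 64), torch.randint(0, 5, (25,))
    x_qry, y_qry = torch.randn(75, 64), torch.randint(0, 5, (75,))
    accuracy = finetune_and_eval(params, x_spt, y_spt, x_qry, y_qry)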
