[MetaLearning] Basic usage of higher, a meta-learning library for PyTorch
1. Basic introduction
higher.innerloop_ctx is the context manager of the higher library that creates the context for an inner loop. Inner loops are typical of meta-learning: the model's parameters are updated for a few steps in an inner loop, and an outer (meta) loop then optimizes something, such as the initial weights or the learning rate, through those updates.
The signature of the context manager is shown below; its five most important parameters are described after it (for details, please refer to the official library documentation):
higher.innerloop_ctx(model, opt, device=None, copy_initial_weights=True, override=None, track_higher_grads=True)
- model is the model to run the inner loop on, usually your meta-model.
- opt is the optimizer you would normally use to update the model parameters; inside the context, higher provides a differentiable counterpart of it.
- copy_initial_weights is a Boolean that decides how the starting weights of the inner loop are obtained. If True (the default), the model's weights are copied and detached from the computation graph, so gradients of an outer (meta) loss do not flow back into the original model's parameters. If False, the copies stay connected to the original parameters, so the meta-gradient does reach them; this is what MAML needs.
- override is a dictionary such as override={'lr': lr_tensor, 'momentum': momentum_tensor} that overrides optimizer settings during the inner loop. In this example lr_tensor and momentum_tensor are tensors giving the learning rate and momentum used in the inner loop; if they have requires_grad=True, the meta loss can also be differentiated with respect to them.
- track_higher_grads is a Boolean. If True (the default), the computation graph of every inner-loop update is kept, so gradients can be propagated back through the optimization process (higher-order gradients). If False, higher-order gradients are not tracked, which saves memory when the inner loop only needs to be run, for example at test time.
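To make override and track_higher_grads more concrete, here is a minimal sketch in plain PyTorch (no higher) of a single differentiable SGD step. The 1-D linear model, the data x = [1, 2], y = 3x, and the helper name one_step_meta_grads are all hypothetical. The learning rate is a tensor with requires_grad=True, playing the role of an override value, and create_graph is the plain-PyTorch switch that keeps the graph of the gradient itself, which is the kind of tracking track_higher_grads=True is there to support. This sketch does not reproduce what higher does internally when track_higher_grads=False.

```python
import torch


def one_step_meta_grads(create_graph: bool):
    """Hypothetical helper: one SGD step on w0, then a meta-loss on the updated weight.

    Returns (d meta_loss / d w0, d meta_loss / d lr).
    """
    x = torch.tensor([1.0, 2.0])                  # made-up inputs
    y = 3.0 * x                                   # targets produced by the "true" weight 3.0
    w0 = torch.tensor(0.0, requires_grad=True)    # initial weight
    lr = torch.tensor(0.125, requires_grad=True)  # learning rate as a tensor, like override={'lr': ...}

    inner_loss = ((w0 * x - y) ** 2).mean()
    # create_graph=True keeps the graph of g, so the update can be differentiated
    # a second time; with False, g is treated as a constant.
    (g,) = torch.autograd.grad(inner_loss, w0, create_graph=create_graph)
    w1 = w0 - lr * g                              # differentiable SGD update (not in place)

    meta_loss = (w1 - 3.0) ** 2
    meta_loss.backward()
    return w0.grad.item(), lr.grad.item()


print(one_step_meta_grads(create_graph=True))   # (-0.84375, -33.75)
print(one_step_meta_grads(create_graph=False))  # (-2.25, -33.75)
```

The learning-rate gradient is the same in both cases, but the gradient with respect to the initial weight loses its second-order term (the factor 1 - lr * 5) when the graph of g is not kept.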
Inside the with statement block, the context manager yields the pair (fmodel, diffopt). fmodel is a stateless (functional) version of the model used in the inner loop, and diffopt is the differentiable optimizer used in the inner loop. Within this block you perform the inner-loop computations and parameter updates.
The following basic example shows how to use higher.innerloop_ctx. Using the higher library mainly requires getting used to the following change, from the usual PyTorch pattern:
import torch

model = MyModel()  # MyModel, data and loss_function are placeholders
opt = torch.optim.Adam(model.parameters())
for xs, ys in data:
    opt.zero_grad()
    logits = model(xs)
    loss = loss_function(logits, ys)
    loss.backward()
    opt.step()
to the following:
import torch
import higher

model = MyModel()  # MyModel, data, loss_function and meta_loss_fn are placeholders
opt = torch.optim.Adam(model.parameters())
with higher.innerloop_ctx(model, opt) as (fmodel, diffopt):
    for xs, ys in data:
        logits = fmodel(xs)  # modified `params` can also be passed as a kwarg
        loss = loss_function(logits, ys)  # no need to call loss.backward()
        diffopt.step(loss)  # note that `step` must take `loss` as an argument! It plays the role of loss.backward() followed by opt.step()
    # At the end of your inner loop you can obtain higher-order gradients, e.g.:
    grad_of_grads = torch.autograd.grad(
        meta_loss_fn(fmodel.parameters()), fmodel.parameters(time=0))
The difference between updating fmodel with diffopt.step and ordinary training is that diffopt.step does not update the parameters in place the way opt.step() does. Instead, each call to diffopt.step creates a new version of the parameters: fmodel uses the new version from the next step on, but all previous versions are kept.
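This behavior can be imitated in plain PyTorch with a minimal sketch. It is not how higher is implemented internally; the helper unroll_sgd and the toy data are hypothetical. Each SGD step appends a new tensor to a list instead of modifying the weight in place, so params[t] plays the role of fmodel.parameters(time=t): params[0] never changes, the forward pass always uses the newest version, and because the graph between versions is kept (create_graph=True), the final weight can be differentiated with respect to the initial one.

```python
import torch


def unroll_sgd(w_init, x, y, lr, steps):
    """Hypothetical helper: plain-SGD inner loop that keeps every parameter version."""
    params = [w_init]                        # params[t] ~ fmodel.parameters(time=t)
    for _ in range(steps):
        w = params[-1]                       # always use the latest version
        loss = ((w * x - y) ** 2).mean()
        (g,) = torch.autograd.grad(loss, w, create_graph=True)
        params.append(w - lr * g)            # new tensor; older versions are left untouched
    return params


x = torch.tensor([1.0, 2.0])
y = 3.0 * x
w0 = torch.tensor(0.0, requires_grad=True)
params = unroll_sgd(w0, x, y, lr=0.125, steps=3)

print([round(p.item(), 6) for p in params])  # [0.0, 1.875, 2.578125, 2.841797]
meta_loss = (params[-1] - 3.0) ** 2
(meta_grad,) = torch.autograd.grad(meta_loss, params[0])  # gradient w.r.t. the time=0 weight
print(params[0] is w0, round(meta_grad.item(), 6))        # True -0.016685
```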
How does this work? fmodel starts from fmodel.parameters(time=0) (time=0 denotes the 0th iteration). After diffopt.step has been called N times, fmodel.parameters(time=i) gives the parameters after step i for any i from 1 to N, and fmodel.parameters(time=0) is still available, unchanged from before the iterations. What exactly is fmodel.parameters(time=0)? That depends on the copy_initial_weights argument used when fmodel is created. If copy_initial_weights=True, fmodel.parameters(time=0) is clone'd and detach'ed from the original model's parameters (copied from the original model and cut off from its computation graph); if copy_initial_weights=False, it is only clone'd, not detach'ed.
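The effect of detaching can be checked with a small plain-PyTorch sketch (again without higher; the helper name and data are hypothetical). It copies a stand-in "model weight" either with clone().detach().requires_grad_(), mirroring the copy_initial_weights=True case described above, or with clone() only, mirroring copy_initial_weights=False. It then runs one differentiable SGD step and backpropagates a meta-loss. Only in the clone-only case does the gradient reach the original weight.

```python
import torch


def meta_grad_on_original(detach_initial: bool):
    """Hypothetical helper: returns original.grad after backpropagating a meta-loss."""
    original = torch.tensor(0.0, requires_grad=True)        # stands in for a parameter of `model`
    if detach_initial:
        w0 = original.clone().detach().requires_grad_(True)  # like copy_initial_weights=True
    else:
        w0 = original.clone()                                # like copy_initial_weights=False
    x = torch.tensor([1.0, 2.0])
    y = 3.0 * x
    inner_loss = ((w0 * x - y) ** 2).mean()
    (g,) = torch.autograd.grad(inner_loss, w0, create_graph=True)
    w1 = w0 - 0.125 * g                                      # one differentiable SGD step
    meta_loss = (w1 - 3.0) ** 2
    meta_loss.backward()
    return original.grad


print(meta_grad_on_original(detach_initial=True))           # None: the graph stops at the detached copy
print(meta_grad_on_original(detach_initial=False).item())   # -0.84375
```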
The passage this explanation is based on reads as follows:
I.e. fmodel starts with only fmodel.parameters(time=0) available, but after you called diffopt.step N times you can ask fmodel to give you fmodel.parameters(time=i) for any i up to N inclusive. Notice that fmodel.parameters(time=0) doesn’t change in this process at all, just every time fmodel is applied to some input it will use the latest version of parameters it currently has.

Now, what exactly is fmodel.parameters(time=0)? It is created here and depends on copy_initial_weights. If copy_initial_weights==True then fmodel.parameters(time=0) are clone’d and detach’ed parameters of model. Otherwise they are only clone’d, but not detach’ed!

That means that when we do meta-optimization step, the original model’s parameters will actually accumulate gradients if and only if copy_initial_weights==False. And in MAML we want to optimize model’s starting weights so we actually do need to get gradients from meta-optimization step.
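Why the starting weights receive a useful gradient can be written out for the simplest case, plain SGD without momentum (an assumption made to keep the formula short; the toy example below also uses momentum). Write the inner-loop updates with step size alpha, inner loss L and Hessians H_t; the chain rule then gives:

```latex
\begin{aligned}
\theta_{t+1} &= \theta_t - \alpha \nabla \mathcal{L}(\theta_t), \qquad H_t = \nabla^2 \mathcal{L}(\theta_t),\\
\frac{\partial \theta_{t+1}}{\partial \theta_t} &= I - \alpha H_t,\\
\nabla_{\theta_0} \mathcal{L}_{\text{meta}}(\theta_N) &= (I - \alpha H_0)(I - \alpha H_1)\cdots(I - \alpha H_{N-1})\,\nabla_{\theta_N} \mathcal{L}_{\text{meta}}(\theta_N).
\end{aligned}
```

Here theta_0 corresponds to fmodel.parameters(time=0) and theta_N to the final fmodel.parameters(); the factor order in the last line uses the symmetry of the Hessians. The H_t terms require the graph of each update to be kept, which is what track_higher_grads=True provides, and approximating every factor by I gives the usual first-order approximation. The resulting gradient lands on model's own parameters only if theta_0 was not detached, i.e. copy_initial_weights=False.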
2. Toy Example
import torch
import torch.nn as nn
import torch.optim as optim
import higher
import numpy as np

np.random.seed(1)
torch.manual_seed(3)
N = 100
actual_multiplier = 3.5
meta_lr = 0.00001
loops = 5  # how many iterations in the inner loop we want to do

x = torch.tensor(np.random.random((N, 1)), dtype=torch.float64)  # features for inner training loop
y = x * actual_multiplier  # target for inner training loop
model = nn.Linear(1, 1, bias=False).double()  # simplest possible model - multiply input x by weight w without bias
meta_opt = optim.SGD(model.parameters(), lr=meta_lr, momentum=0.)


def run_inner_loop_once(model, verbose, copy_initial_weights):
    lr_tensor = torch.tensor([0.3], requires_grad=True)
    momentum_tensor = torch.tensor([0.5], requires_grad=True)
    opt = optim.SGD(model.parameters(), lr=0.3, momentum=0.5)
    with higher.innerloop_ctx(model, opt, copy_initial_weights=copy_initial_weights, override={
            'lr': lr_tensor, 'momentum': momentum_tensor}) as (fmodel, diffopt):
        for j in range(loops):
            if verbose:
                print('Starting inner loop step j=={0}'.format(j))
                print(' Representation of fmodel.parameters(time={0}): {1}'.format(j, str(list(fmodel.parameters(time=j)))))
                print(' Notice that fmodel.parameters() is same as fmodel.parameters(time={0}): {1}'.format(j, (list(fmodel.parameters())[0] is list(fmodel.parameters(time=j))[0])))
            out = fmodel(x)
            if verbose:
                print(' Notice how `out` is `x` multiplied by the latest version of weight: {0:.4} * {1:.4} == {2:.4}'.format(x[0, 0].item(), list(fmodel.parameters())[0].item(), out[0].item()))
            loss = ((out - y)**2).mean()
            diffopt.step(loss)

        if verbose:
            # after all inner training let's see all steps' parameter tensors
            print()
            print("Let's print all intermediate parameters versions after inner loop is done:")
            for j in range(loops + 1):
                print(' For j=={0} parameter is: {1}'.format(j, str(list(fmodel.parameters(time=j)))))
            print()

        # let's imagine now that our meta-learning optimization is trying to check how far we got in the end from the actual_multiplier
        weight_learned_after_full_inner_loop = list(fmodel.parameters())[0]
        meta_loss = (weight_learned_after_full_inner_loop - actual_multiplier)**2
        print(' Final meta-loss: {0}'.format(meta_loss.item()))
        meta_loss.backward()  # will only propagate gradient to original model parameter's `grad` if copy_initial_weights=False
        if verbose:
            print(' Gradient of final loss we got for lr and momentum: {0} and {1}'.format(lr_tensor.grad, momentum_tensor.grad))
            print(' If you change number of iterations "loops" to much larger number final loss will be stable and the values above will be smaller')
        return meta_loss.item()


print('=================== Run Inner Loop First Time (copy_initial_weights=True) =================\n')
meta_loss_val1 = run_inner_loop_once(model, verbose=True, copy_initial_weights=True)
print("\nLet's see if we got any gradient for initial model parameters: {0}\n".format(list(model.parameters())[0].grad))

print('=================== Run Inner Loop Second Time (copy_initial_weights=False) =================\n')
meta_loss_val2 = run_inner_loop_once(model, verbose=False, copy_initial_weights=False)
print("\nLet's see if we got any gradient for initial model parameters: {0}\n".format(list(model.parameters())[0].grad))

print('=================== Run Inner Loop Third Time (copy_initial_weights=False) =================\n')
final_meta_gradient = list(model.parameters())[0].grad.item()
# Now let's double-check `higher` library is actually doing what it promised to do, not just giving us
# a bunch of hand-wavy statements and difficult to read code.
# We will do a simple SGD step using meta_opt changing initial weight for the training and see how meta loss changed
meta_opt.step()
meta_opt.zero_grad()
meta_step = - meta_lr * final_meta_gradient  # how much meta_opt actually shifted initial weight value
# The meta parameter has now been updated, so run the inner loop a third time from the new initial weight.
meta_loss_val3 = run_inner_loop_once(model, verbose=False, copy_initial_weights=False)
meta_loss_gradient_approximation = (meta_loss_val3 - meta_loss_val2) / meta_step

print()
print('Side-by-side meta_loss_gradient_approximation and gradient computed by `higher` lib: {0:.4} VS {1:.4}'.format(meta_loss_gradient_approximation, final_meta_gradient))
The results are as follows
=================== Run Inner Loop First Time (copy_initial_weights=True) =================
Starting inner loop step j==0
Representation of fmodel.parameters(time=0): [tensor([[-0.9915]], dtype=torch.float64, requires_grad=True)]
Notice that fmodel.parameters() is same as fmodel.parameters(time=0): True
Notice how `out` is `x` multiplied by the latest version of weight: 0.417 * -0.9915 == -0.4135
Starting inner loop step j==1
Representation of fmodel.parameters(time=1): [tensor([[-0.1217]], dtype=torch.float64, grad_fn=<AddBackward0>)]
Notice that fmodel.parameters() is same as fmodel.parameters(time=1): True
Notice how `out` is `x` multiplied by the latest version of weight: 0.417 * -0.1217 == -0.05075
Starting inner loop step j==2
Representation of fmodel.parameters(time=2): [tensor([[1.0145]], dtype=torch.float64, grad_fn=<AddBackward0>)]
Notice that fmodel.parameters() is same as fmodel.parameters(time=2): True
Notice how `out` is `x` multiplied by the latest version of weight: 0.417 * 1.015 == 0.4231
Starting inner loop step j==3
Representation of fmodel.parameters(time=3): [tensor([[2.0640]], dtype=torch.float64, grad_fn=<AddBackward0>)]
Notice that fmodel.parameters() is same as fmodel.parameters(time=3): True
Notice how `out` is `x` multiplied by the latest version of weight: 0.417 * 2.064 == 0.8607
Starting inner loop step j==4
Representation of fmodel.parameters(time=4): [tensor([[2.8668]], dtype=torch.float64, grad_fn=<AddBackward0>)]
Notice that fmodel.parameters() is same as fmodel.parameters(time=4): True
Notice how `out` is `x` multiplied by the latest version of weight: 0.417 * 2.867 == 1.196
Let's print all intermediate parameters versions after inner loop is done:
For j==0 parameter is: [tensor([[-0.9915]], dtype=torch.float64, requires_grad=True)]
For j==1 parameter is: [tensor([[-0.1217]], dtype=torch.float64, grad_fn=<AddBackward0>)]
For j==2 parameter is: [tensor([[1.0145]], dtype=torch.float64, grad_fn=<AddBackward0>)]
For j==3 parameter is: [tensor([[2.0640]], dtype=torch.float64, grad_fn=<AddBackward0>)]
For j==4 parameter is: [tensor([[2.8668]], dtype=torch.float64, grad_fn=<AddBackward0>)]
For j==5 parameter is: [tensor([[3.3908]], dtype=torch.float64, grad_fn=<AddBackward0>)]
Final meta-loss: 0.011927987982895929
Gradient of final loss we got for lr and momentum: tensor([-1.6295]) and tensor([-0.9496])
If you change number of iterations "loops" to much larger number final loss will be stable and the values above will be smaller
Let's see if we got any gradient for initial model parameters: None
=================== Run Inner Loop Second Time (copy_initial_weights=False) =================
Final meta-loss: 0.011927987982895929
Let's see if we got any gradient for initial model parameters: tensor([[-0.0053]], dtype=torch.float64)
=================== Run Inner Loop Third Time (copy_initial_weights=False) =================
Final meta-loss: 0.01192798770078706
Side-by-side meta_loss_gradient_approximation and gradient computed by `higher` lib: -0.005311 VS -0.005311
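The finite-difference check above can also be reproduced without PyTorch or higher. The sketch below is a hypothetical, dependency-free re-derivation, not part of the original example. It runs the same kind of inner loop (SGD with momentum on mean((w*x - a*x)^2), using the torch.optim.SGD momentum rule in which the buffer starts as the first gradient), carries the derivative of the weight with respect to its initial value alongside the weight (forward mode), and compares the resulting meta-gradient with a central finite difference. The data and starting weight are made up, so the numbers will not match the output above.

```python
def meta_loss_and_grad(w0, xs, a, lr, momentum, steps):
    """Hypothetical helper: inner loop of SGD with momentum on mean((w*x - a*x)**2).

    Returns (meta_loss, d meta_loss / d w0), where meta_loss = (w_final - a)**2.
    """
    m2 = sum(x * x for x in xs) / len(xs)  # mean(x^2); the inner gradient is 2*m2*(w - a)
    w, dw = w0, 1.0                        # weight and its derivative w.r.t. w0
    buf = dbuf = None                      # momentum buffer and its derivative
    for _ in range(steps):
        g, dg = 2 * m2 * (w - a), 2 * m2 * dw
        if buf is None:                    # first step: the buffer is the gradient itself
            buf, dbuf = g, dg
        else:                              # later steps: buf = momentum * buf + g
            buf, dbuf = momentum * buf + g, momentum * dbuf + dg
        w, dw = w - lr * buf, dw - lr * dbuf
    return (w - a) ** 2, 2 * (w - a) * dw


xs = [0.1 * i for i in range(1, 11)]       # made-up features in (0, 1]
a, w_init, eps = 3.5, -1.0, 1e-6
_, analytic = meta_loss_and_grad(w_init, xs, a, lr=0.3, momentum=0.5, steps=5)
plus, _ = meta_loss_and_grad(w_init + eps, xs, a, lr=0.3, momentum=0.5, steps=5)
minus, _ = meta_loss_and_grad(w_init - eps, xs, a, lr=0.3, momentum=0.5, steps=5)
finite_diff = (plus - minus) / (2 * eps)
print(analytic, finite_diff)               # the two values agree closely
```

Because the inner loss is quadratic, the final weight is an affine function of the initial weight, so the central difference matches the forward-mode derivative up to floating-point rounding.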
Reference
Paper: Generalized Inner Loop Meta-Learning
Stack Overflow: What does the copy_initial_weights documentation mean in the higher library for Pytorch?