Machine Learning HW15: Meta Learning


1. Introduction

The target task is few-shot classification on the Omniglot dataset; the goal is to use meta learning to find good initialization parameters.

Task: Few-shot Classification

The Omniglot dataset
Omniglot dataset - background set: 30 alphabets; evaluation set: 20 alphabets
Problem setup: 5-way 1-shot classification
Training MAML on Omniglot classification task.
Training / validation set: 30 alphabets

  • multiple characters in one alphabet
  • 20 images for one character

Testing set: 640 support and query pairs

  • 5 support images
  • 5 query images
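To make the 5-way 1-shot data layout concrete, here is a minimal sketch of how a single task tensor could be assembled (the helper name sample_task and the character_images input are assumptions for illustration; the notebook's own Omniglot Dataset class differs in detail). The key point is that the support block comes first and the query block second, which is exactly how the solvers below slice each meta_batch:

import torch

def sample_task(character_images, n_way=5, k_shot=1, q_query=1):
    # character_images: a list with one [20, 1, 28, 28] tensor per character
    chosen = torch.randperm(len(character_images))[:n_way]
    support, query = [], []
    for c in chosen:
        imgs = character_images[c]
        idx = torch.randperm(imgs.shape[0])[: k_shot + q_query]
        support.append(imgs[idx[:k_shot]])   # k_shot support images per class
        query.append(imgs[idx[k_shot:]])     # q_query query images per class
    # Support block first, then query block, as the solvers expect
    return torch.cat(support + query, dim=0)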

2. Experiments

1. simple

A simple transfer-learning baseline.
Training: run normal classification training on 5 randomly selected tasks.
Validation and testing: fine-tune on the 5 support images, then run inference on the query images.

2. medium

Complete the TODO blocks for the inner and outer loops of meta learning, using FO-MAML. Set solver = 'meta' and raise the number of epochs to 120. FO-MAML (first-order MAML) is a simplified version of MAML that saves training time by ignoring the second-order effect of the inner-loop gradients on the meta-gradient.
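In symbols (a standard formulation of MAML, not taken verbatim from the homework), the inner loop adapts θ to each task on its support set, and the outer loop updates θ on the query sets. Full MAML differentiates through the inner update, while FO-MAML only uses the gradient at the adapted weights:

$$\theta_i' = \theta - \alpha \nabla_\theta \mathcal{L}^{\text{support}}_{\mathcal{T}_i}(\theta)$$

$$\text{MAML: } \theta \leftarrow \theta - \beta \, \nabla_\theta \sum_i \mathcal{L}^{\text{query}}_{\mathcal{T}_i}(\theta_i') \qquad \text{FO-MAML: } \theta \leftarrow \theta - \beta \sum_i \nabla_{\theta_i'} \mathcal{L}^{\text{query}}_{\mathcal{T}_i}(\theta_i')$$

Dropping create_graph in the inner update (as in the code below) is what yields the first-order version: the inner gradients are treated as constants, so $\partial \theta_i' / \partial \theta$ reduces to the identity.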

# TODO: Finish the inner loop update rule
# FO-MAML: no create_graph, so second-order terms are discarded
grads = torch.autograd.grad(loss, fast_weights.values())
fast_weights = OrderedDict(
    (name, param - inner_lr * grad)
    for ((name, param), grad) in zip(fast_weights.items(), grads)
)


# TODO: Finish the outer loop update
meta_batch_loss.backward()
optimizer.step()


3. strong

Use full MAML, which computes higher-order gradients: with create_graph=True, MAML can take gradients of the inner-loop gradients.

# TODO: Finish the inner loop update rule
# MAML: create_graph=True keeps the graph of the inner update,
# so the outer loop can take gradients through it
grads = torch.autograd.grad(loss, fast_weights.values(), create_graph=True)
fast_weights = OrderedDict(
    (name, param - inner_lr * grad)
    for ((name, param), grad) in zip(fast_weights.items(), grads)
)


4. boss

Task augmentation (for meta learning): what is a reasonable way to create new tasks?
Task augmentation is used here to increase the diversity of the training tasks: with 40% probability, a meta-batch is rotated by 90 or 270 degrees. Because the rotation is applied to the whole batch before it is split into support and query sets, the labels stay consistent, and each rotated character effectively becomes a new class.

# Modification inside the MetaSolver function
for meta_batch in x:
    # Task augmentation: with probability 0.4, rotate the whole batch
    if torch.rand(1).item() > 0.6:
        # rotate by 90 degrees (times=1) or 270 degrees (times=3), equally likely
        times = 1 if torch.rand(1).item() > 0.5 else 3
        meta_batch = torch.rot90(meta_batch, times, [-1, -2])


3. Code

Preparation for building the model

Since our task is image classification, we need to build a CNN-based model. However, to implement the MAML algorithm we have to adjust some of the code in nn.Module: in the update step (line 10 of the original notebook code), the gradients are taken with respect to θ, the parameters of the original model (the outer loop), rather than the inner-loop θ'. We therefore use functional_forward to compute the output logits from the input images with an explicit parameter dict, instead of nn.Module's forward. These functions are defined below.

def functional_forward(self, x, params):
    """Forward pass that uses an explicit parameter dict instead of self.parameters()."""
    for block in [1, 2, 3, 4]:
        x = ConvBlockFunction(
            x,
            params[f"conv{block}.0.weight"],
            params[f"conv{block}.0.bias"],
            params.get(f"conv{block}.1.weight"),
            params.get(f"conv{block}.1.bias"),
        )
    x = x.view(x.shape[0], -1)
    x = F.linear(x, params["logits.weight"], params["logits.bias"])
    return x
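functional_forward calls a ConvBlockFunction helper that is not reproduced in this post. A typical functional conv block consistent with the parameter names conv{block}.0.* (conv) and conv{block}.1.* (batch norm) above might look like the following; treat it as a sketch rather than the exact notebook definition:

import torch.nn.functional as F

def ConvBlockFunction(x, w, b, w_bn, b_bn):
    # Conv -> functional BatchNorm (no running stats) -> ReLU -> MaxPool
    x = F.conv2d(x, w, b, padding=1)
    x = F.batch_norm(
        x, running_mean=None, running_var=None,
        weight=w_bn, bias=b_bn, training=True
    )
    x = F.relu(x)
    x = F.max_pool2d(x, kernel_size=2, stride=2)
    return x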

Create labels for 5-way 2-shot

def create_label(n_way, k_shot):
    return torch.arange(n_way).repeat_interleave(k_shot).long()


# Try to create labels for 5-way 2-shot setting
create_label(5, 2)
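For the 5-way 2-shot call above, the returned labels are

tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4])

i.e. k_shot consecutive copies of each of the n_way class indices, matching the support-first layout of every meta-batch.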

Calculate accuracy

def calculate_accuracy(logits, labels):
    """Utility function for accuracy calculation."""
    # Fraction of images whose argmax prediction matches the label
    preds = torch.argmax(logits, -1)
    return (preds.cpu().numpy() == labels.cpu().numpy()).mean()
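A quick sanity check of the helper: an identity logits matrix predicts class i for row i, so against labels 0..4 the accuracy is 1.0.

logits = torch.eye(5)              # argmax along -1 is [0, 1, 2, 3, 4]
labels = torch.arange(5)
print(calculate_accuracy(logits, labels))  # 1.0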

The base solver first chooses five tasks from the training set and runs normal classification training on them. At inference time, the model is fine-tuned on the support-set images for inner_train_step steps and then run on the query-set images. To stay consistent with the meta-learning solver, the base solver takes exactly the same input and output format.

def BaseSolver(
    model,
    optimizer,
    x,
    n_way,
    k_shot,
    q_query,
    loss_fn,
    inner_train_step=1,
    inner_lr=0.4,
    train=True,
    return_labels=False,
):
    criterion, task_loss, task_acc = loss_fn, [], []
    labels = []

    for meta_batch in x:
        # Get data
        support_set = meta_batch[: n_way * k_shot]
        query_set = meta_batch[n_way * k_shot :]

        if train:
            """ training loop """
            # Use the support set to calculate loss
            labels = create_label(n_way, k_shot).to(device)
            logits = model.forward(support_set)
            loss = criterion(logits, labels)

            task_loss.append(loss)
            task_acc.append(calculate_accuracy(logits, labels))
        else:
            """ validation / testing loop """
            # First update model with support set images for `inner_train_step` steps
            fast_weights = OrderedDict(model.named_parameters())

            for inner_step in range(inner_train_step):
                # Simply training
                train_label = create_label(n_way, k_shot).to(device)
                logits = model.functional_forward(support_set, fast_weights)
                loss = criterion(logits, train_label)

                grads = torch.autograd.grad(loss, fast_weights.values(), create_graph=True)
                # Perform SGD on the copied weights (create_graph is not strictly
                # needed here, since no outer update follows at test time)
                fast_weights = OrderedDict(
                    (name, param - inner_lr * grad)
                    for ((name, param), grad) in zip(fast_weights.items(), grads)
                )

            if not return_labels:
                """ validation """
                val_label = create_label(n_way, q_query).to(device)

                logits = model.functional_forward(query_set, fast_weights)
                loss = criterion(logits, val_label)
                task_loss.append(loss)
                task_acc.append(calculate_accuracy(logits, val_label))
            else:
                """ testing """
                logits = model.functional_forward(query_set, fast_weights)
                labels.extend(torch.argmax(logits, -1).cpu().numpy())

    if return_labels:
        return labels

    batch_loss = torch.stack(task_loss).mean()
    task_acc = np.mean(task_acc)

    if train:
        # Update model
        model.train()
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

    return batch_loss, task_acc

Meta learning

def MetaSolver(
    model,
    optimizer,
    x,
    n_way,
    k_shot,
    q_query,
    loss_fn,
    inner_train_step=1,
    inner_lr=0.4,
    train=True,
    return_labels=False
):
    criterion, task_loss, task_acc = loss_fn, [], []
    labels = []

    for meta_batch in x:
        # Task augmentation: with probability 0.4, rotate the whole meta-batch
        # (torch.rot90 rotates counter-clockwise by times*90 degrees in the
        # plane spanned by the last two dims)
        if torch.rand(1).item() > 0.6:
            times = 1 if torch.rand(1).item() > 0.5 else 3
            meta_batch = torch.rot90(meta_batch, times, [-1, -2])
        support_set = meta_batch[: n_way * k_shot]
        query_set = meta_batch[n_way * k_shot :]

        # Copy the params for inner loop
        fast_weights = OrderedDict(model.named_parameters())

        ### ---------- INNER TRAIN LOOP ---------- ###
        for inner_step in range(inner_train_step):
            # Simply training
            train_label = create_label(n_way, k_shot).to(device)
            logits = model.functional_forward(support_set, fast_weights)
            loss = criterion(logits, train_label)
            # Inner gradients update! vvvvvvvvvvvvvvvvvvvv #
            """ Inner Loop Update """
            # TODO: Finish the inner loop update rule
            # create_graph=True lets the outer loop differentiate through
            # this inner update (drop it to get FO-MAML instead)
            grads = torch.autograd.grad(loss, fast_weights.values(), create_graph=True)
            fast_weights = OrderedDict(
                (name, param - inner_lr * grad)
                for ((name, param), grad) in zip(fast_weights.items(), grads)
            )
            # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ #

        ### ---------- INNER VALID LOOP ---------- ###
        if not return_labels:
            """ training / validation """
            val_label = create_label(n_way, q_query).to(device)

            # Collect gradients for outer loop
            logits = model.functional_forward(query_set, fast_weights)
            loss = criterion(logits, val_label)
            task_loss.append(loss)
            task_acc.append(calculate_accuracy(logits, val_label))
        else:
            """ testing """
            logits = model.functional_forward(query_set, fast_weights)
            labels.extend(torch.argmax(logits, -1).cpu().numpy())

    if return_labels:
        return labels

    # Update outer loop
    model.train()
    optimizer.zero_grad()

    meta_batch_loss = torch.stack(task_loss).mean()
    if train:
        """ Outer Loop Update """
        # TODO: Finish the outer loop update
        meta_batch_loss.backward()
        optimizer.step()

    task_acc = np.mean(task_acc)
    return meta_batch_loss, task_acc
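For reference, here is a minimal sketch of how either solver might be driven, matching the "solver = 'meta', 120 epochs" setting mentioned above. Names such as train_loader, train_iter, get_meta_batch, and meta_batch_size are assumptions based on the homework template, not code reproduced in this post:

# Hypothetical driver loop; helper names are assumptions for illustration
solver = 'meta'   # 'base' selects the transfer-learning baseline
max_epoch = 120

for epoch in range(max_epoch):
    train_meta_loss, train_acc = [], []
    for step in range(max(1, len(train_loader) // meta_batch_size)):
        # get_meta_batch is assumed to return a batch of tasks plus the iterator
        x, train_iter = get_meta_batch(
            meta_batch_size, k_shot, q_query, train_loader, train_iter
        )
        if solver == 'meta':
            loss, acc = MetaSolver(meta_model, optimizer, x, n_way, k_shot, q_query, loss_fn)
        else:
            loss, acc = BaseSolver(meta_model, optimizer, x, n_way, k_shot, q_query, loss_fn)
        train_meta_loss.append(loss.item())
        train_acc.append(acc)
    print(f"epoch {epoch}: loss = {np.mean(train_meta_loss):.4f}, acc = {np.mean(train_acc):.4f}")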

Reposted from blog.csdn.net/Raphael9900/article/details/128646394