Meta-Learning: An Analysis of the Prototypical Network

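As the name suggests, meta-learning aims to give a network the ability to learn how to learn. Like a person, it should be able to recognize and distinguish many more samples after seeing only a few. This is an advantage in the many scenarios where samples are hard to obtain.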

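Training and evaluation work as follows. The network is first trained on a larger dataset such as miniImageNet in the 5-way 5-shot and 5-way 1-shot settings. The trained model is then loaded and evaluated on new classes. In my view this is very similar to transfer learning. During training, the learning rate is adjusted according to how the loss changes, and training stops once the learning rate no longer changes.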
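
To make the "5-way 5-shot" and "5-way 1-shot" settings concrete, here is a minimal sketch of how one training episode could be sampled. The function name `sample_episode`, the toy label list, and the choice of 15 query samples per class are assumptions for illustration only; they are not taken from the original training code.

```python
import random
from collections import defaultdict


def sample_episode(labels, n_way=5, n_shot=5, n_query=15, rng=None):
    """Return (support_idx, query_idx): dataset indices for one N-way K-shot episode."""
    rng = rng or random.Random()
    # Group sample indices by class label.
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    # Only classes with enough samples can fill both the support and query sets.
    eligible = [c for c, idxs in by_class.items() if len(idxs) >= n_shot + n_query]
    if len(eligible) < n_way:
        raise ValueError("not enough classes with n_shot + n_query samples")
    support, query = [], []
    for c in rng.sample(eligible, n_way):
        picked = rng.sample(by_class[c], n_shot + n_query)
        support.extend(picked[:n_shot])   # K labelled examples per class
        query.extend(picked[n_shot:])     # examples to classify in this episode
    return support, query


# Toy label list: 10 classes with 30 samples each (hypothetical data).
labels = [c for c in range(10) for _ in range(30)]
support, query = sample_episode(labels, n_way=5, n_shot=1, n_query=15,
                                rng=random.Random(0))
```

In the 5-way 1-shot call above, the support set holds one index per sampled class and the query set holds the remaining 15 per class, with no overlap between the two.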

Concretely, the metric-based meta-learning models discussed here consist of a feature encoder and a distance computation.

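The prototypical network is one of the simpler meta-learning networks. It consists mainly of an encoder (Protonet), whose structure is printed below, and a distance metric.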

Protonet(
  (encoder): Sequential(
    (0): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (1): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (2): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (3): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (4): Flatten()
  )
)
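
The printed structure can be reproduced with a few lines of PyTorch. The sketch below is a reconstruction from the printout, not the original source file: the helper name `conv_block`, the constructor arguments, and the 84 x 84 input size (a common miniImageNet resolution) are assumptions.

```python
import torch
import torch.nn as nn


def conv_block(in_channels, out_channels):
    # Conv -> BatchNorm -> ReLU -> 2x2 max pooling, as in each printed sub-block.
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.MaxPool2d(2),
    )


class Protonet(nn.Module):
    def __init__(self, in_channels=3, hidden=64):
        super().__init__()
        self.encoder = nn.Sequential(
            conv_block(in_channels, hidden),
            conv_block(hidden, hidden),
            conv_block(hidden, hidden),
            conv_block(hidden, hidden),
            nn.Flatten(),
        )

    def forward(self, x):
        return self.encoder(x)


model = Protonet().eval()
with torch.no_grad():
    # Two RGB images of assumed size 84 x 84.
    embeddings = model(torch.randn(2, 3, 84, 84))
```

Each block halves the spatial size (84, 42, 21, 10, 5), so under the assumed input size every image is mapped to a 64 * 5 * 5 = 1600-dimensional embedding.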

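The second component is the distance metric. The function below returns the squared Euclidean distance between every pair of rows: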

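import torch


def euclidean(x, y):
    '''
    Compute the pairwise squared Euclidean distance between two tensors
    '''
    # x: N x D
    # y: M x D
    n = x.size(0)
    m = y.size(0)
    d = x.size(1)
    if d != y.size(1):
        raise ValueError("x and y must have the same feature dimension")

    # Broadcast both tensors to N x M x D so every row of x meets every row of y
    x = x.unsqueeze(1).expand(n, m, d)
    y = y.unsqueeze(0).expand(n, m, d)

    # Result: N x M matrix of squared distances (no square root is taken)
    return torch.pow(x - y, 2).sum(2)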
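
As a quick sanity check, the function can be run on a small hand-made example and compared with `torch.cdist`. The tensors below are made-up values chosen so the distances are easy to verify by hand. Note that `torch.cdist` returns the non-squared distance, so it is squared before comparing.

```python
import torch


def euclidean(x, y):
    # Same computation as above: squared distances, shape N x M.
    n, m, d = x.size(0), y.size(0), x.size(1)
    x = x.unsqueeze(1).expand(n, m, d)
    y = y.unsqueeze(0).expand(n, m, d)
    return torch.pow(x - y, 2).sum(2)


queries = torch.tensor([[0.0, 0.0], [3.0, 4.0]])                  # N = 2, D = 2
prototypes = torch.tensor([[0.0, 0.0], [0.0, 4.0], [3.0, 0.0]])   # M = 3, D = 2

dists = euclidean(queries, prototypes)                   # shape 2 x 3
reference = torch.cdist(queries, prototypes, p=2) ** 2   # squared for comparison
nearest = dists.argmin(dim=1)                            # closest prototype per query
```

The first query coincides with prototype 0, and the second query (3, 4) is closest to prototype 1 at (0, 4), with squared distance 9.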

Loss computation of the network

import torch
import torch.nn.functional as F
from torch.nn import Module


class PrototypicalLoss(Module):
    '''
    Loss class deriving from Module for the prototypical loss function defined below
    '''
    def __init__(self, n_support, dist_func, reg):
        super(PrototypicalLoss, self).__init__()
        self.n_support = n_support
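        if dist_func == "cosine":
            self.dist_func = cosine  # `cosine` is not shown in this post; it must be defined elsewhere
        elif dist_func == "euclidean":
            self.dist_func = euclidean
        else:
            # Fail early instead of storing None, which would only crash later in forward()
            raise ValueError("dist_func must be 'cosine' or 'euclidean'")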
        self.reg = reg

    def forward(self, input, target, weights):
        return prototypical_loss(input, target, self.n_support, weights=weights, dist_func=self.dist_func, lambda_reg=self.reg)


def prototypical_loss(input, target, n_support, weights, dist_func=euclidean, lambda_reg=0.05):
    '''
    Inspired by https://github.com/jakesnell/prototypical-networks/blob/master/protonets/models/few_shot.py

    Compute the barycentres by averaging the features of n_support
    samples for each class in target, computes then the distances from each
    samples' features to each one of the barycentres, computes the
    log_probability for each n_query samples for each one of the current
    classes, of appertaining to a class c, loss and accuracy are then computed
    and returned
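    Args:
    - input: the model output for a batch of samples
    - target: ground truth for the above batch of samples
    - n_support: number of samples to keep in account when computing
      barycentres, for each one of the current classes
    - weights: iterable of parameter tensors included in the L2 penalty
    - dist_func: function returning the N x M distances between queries and prototypes
    - lambda_reg: coefficient of the L2 regularization term
    '''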
    target_cpu = target.to('cpu')
    input_cpu = input.to('cpu')

    def supp_idxs(c):
        # FIXME when torch will support where as np
        return target_cpu.eq(c).nonzero()[:n_support].squeeze(1)

    # FIXME when torch.unique will be available on cuda too
    classes = torch.unique(target_cpu)  # non-repeated classes (i.e. types of ground truth)
    n_classes = len(classes)
    # FIXME when torch will support where as np
    # assuming n_query, n_target constants
    n_query = target_cpu.eq(classes[0].item()).sum().item() - n_support

    support_idxs = list(map(supp_idxs, classes))

    prototypes = torch.stack([input_cpu[idx_list].mean(0) for idx_list in support_idxs])  # one prototype per class: the mean (centroid) of its support embeddings
    # FIXME when torch will support where as np
    query_idxs = torch.stack(list(map(lambda c: target_cpu.eq(c).nonzero()[n_support:], classes))).view(-1)

    query_samples = input_cpu[query_idxs]
    dists = dist_func(query_samples, prototypes)

    log_p_y = F.log_softmax(-dists, dim=1).view(n_classes, n_query, -1)

    target_inds = torch.arange(0, n_classes)
    target_inds = target_inds.view(n_classes, 1, 1)
    target_inds = target_inds.expand(n_classes, n_query, 1).long()

    # loss_val = -log_p_y.gather(2, target_inds).squeeze().view(-1).mean()
    # --------------------------
    reg = 0
    for param in weights:
        param = param.to('cpu')
        reg += torch.sum(0.5*(param**2))  # L2 regularization
        # reg += torch.sum(torch.abs(param))  # L1 regularization
        
    loss_val = -log_p_y.gather(2, target_inds).squeeze().view(-1).mean() + lambda_reg*reg
    # --------------------------

    _, y_hat = log_p_y.max(2)
    # squeeze only the last dim: a bare squeeze() would also drop n_query == 1 and broadcast wrongly
    acc_val = y_hat.eq(target_inds.squeeze(2)).float().mean()

    return loss_val,  acc_val
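
The computation in `prototypical_loss` can be summarized in formulas. The notation is mine and not from the original code: f_θ is the encoder, S_k and Q_k are the support and query sets of class k, d is the chosen distance function, N_c is the number of classes in the episode, N_q is the number of query samples per class, and w ranges over the tensors passed in `weights`.

```latex
\begin{aligned}
c_k &= \frac{1}{|S_k|} \sum_{x_i \in S_k} f_\theta(x_i), \\
p(y = k \mid x) &= \frac{\exp\bigl(-d(f_\theta(x), c_k)\bigr)}
                        {\sum_{k'} \exp\bigl(-d(f_\theta(x), c_{k'})\bigr)}, \\
\mathcal{L} &= -\frac{1}{N_c N_q} \sum_{k=1}^{N_c} \sum_{x \in Q_k} \log p(y = k \mid x)
             \;+\; \lambda_{\text{reg}} \sum_{w} \tfrac{1}{2} \lVert w \rVert_2^2 .
\end{aligned}
```

The first line corresponds to `prototypes`, the second to `F.log_softmax(-dists, dim=1)`, and the third to `loss_val`, where the last term is the L2 penalty accumulated in `reg`.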

Afterword

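Among other meta-learning methods, the Matching Network uses cosine distance, while DeepEMD uses the Earth Mover's Distance (EMD). The metrics differ, and in theory DeepEMD should give better results. I will keep updating this post when I have time.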
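
The loss class above refers to a `cosine` distance function that is not shown in this post. A minimal sketch of what such a function might look like is given below. This definition is hypothetical and is not the original implementation. It returns 1 minus the cosine similarity so that, like `euclidean`, smaller values mean "closer", which is what `F.log_softmax(-dists, dim=1)` expects.

```python
import torch
import torch.nn.functional as F


def cosine(x, y, eps=1e-8):
    """Pairwise cosine distance between rows of x (N x D) and y (M x D).

    Hypothetical stand-in for the undefined `cosine` used by PrototypicalLoss.
    """
    if x.size(1) != y.size(1):
        raise ValueError("x and y must have the same feature dimension")
    # L2-normalize each row so the dot product equals the cosine similarity.
    x_norm = F.normalize(x, p=2, dim=1, eps=eps)
    y_norm = F.normalize(y, p=2, dim=1, eps=eps)
    return 1.0 - x_norm @ y_norm.t()   # N x M, values in [0, 2]


x = torch.tensor([[1.0, 0.0], [0.0, 2.0]])
y = torch.tensor([[2.0, 0.0], [0.0, -1.0], [1.0, 1.0]])
dists = cosine(x, y)
```

Unlike the squared Euclidean distance, this metric ignores the length of the embeddings and depends only on their direction.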


Reposted from blog.csdn.net/weixin_45994963/article/details/127675687