回顾原型网络代码②计算模型的精度和损失

定义

few_shot.py中定义了函数模型。文件中有两个类和一个函数:

  • def load_protonet_conv(**kwargs):根据传进来的参数kwargs,建立模型
  • class Protonet(nn.Module):主要是计算episodelossacc
  • class Flatten(nn.Module):展平。但是torch不是有这个层吗,不明白为啥作者还要自己写呢

def load_protonet_conv(**kwargs)

在这里插入图片描述
当输入为(200,1,28,28)tensor时,模型的输出为(200,64),也就是每个28*28的图片样本转换为一个64维度的向量

使用

line123中使用了模型

engine.train(
   model = model,
   loader = train_loader,
   optim_method = getattr(optim, opt['train.optim_method']),  # Adam
   optim_config = {
    
     'lr': opt['train.learning_rate'], # 学习率0.001
                    'weight_decay': opt['train.weight_decay'] },  # 0.0
   max_epoch = opt['train.epochs']  # 10000
)

计算loss

10 w a y − 5 s h o t − 15 q u e r y 10way - 5shot - 15query 10way5shot15query为例,在line 40中调用Protonetloss计算了模型的lossacc

def loss(self, sample): 
	"""sample的格式为:
	{
		"class":长度为10的list,    
		"xs":torch.Size([10, 5, 1, 28, 28]),       
		"xq":torch.Size([10, 15, 1, 28, 28])
   }
	"""
    xs = Variable(sample['xs'])  # support  torch.Size([10, 5, 1, 28, 28])
    xq = Variable(sample['xq'])  # query  torch.Size([10, 15, 1, 28, 28])

    n_class = xs.size(0)  # 10
    assert xq.size(0) == n_class
    n_support = xs.size(1)   # 5
    n_query = xq.size(1)  # 15
    
    #  生成query的标签        torch.Size([10]) ->torch.Size([10, 1, 1])  -> torch.Size([10, 15, 1])
    target_inds = torch.arange(0, n_class).view(n_class, 1, 1).expand(n_class, n_query, 1).long()
    
    target_inds = Variable(target_inds, requires_grad=False)
    if xq.is_cuda:
        target_inds = target_inds.cuda()
        
	# 把 support和query合并到一起进行特征提取,maybe并行化更快	
    x = torch.cat([xs.view(n_class * n_support, *xs.size()[2:]),  # torch.Size([10, 5, 1, 28, 28]) -> torch.Size([50, 1, 28, 28])
                   xq.view(n_class * n_query, *xq.size()[2:])], 0)  # torch.Size([10, 15, 1, 28, 28]) -> torch.Size([150, 1, 28, 28])
    # x.Size([200, 1, 28, 28])
    
    z = self.encoder.forward(x)  # z.size([200, 64])
    z_dim = z.size(-1)
    
    # 求原型
    # torch.Size([50, 64]) -> torch.Size([10, 5, 64]) -> torch.Size([10, 64])
    z_proto = z[:n_class*n_support].view(n_class, n_support, z_dim).mean(1)  
    
    # query的特征 torch.Size([150, 64])
    zq = z[n_class*n_support:]  
    
    # 计算query的特征到原型的距离,等下说这个函数
    dists = euclidean_dist(zq, z_proto)  # torch.Size([150, 10])
    
    # F.log_softmax作用:在softmax的结果上再做多一次log运算,不用softmax据说是为了防止溢出
    # 注意 -dist,下面要考
    log_p_y = F.log_softmax(-dists, dim=1).view(n_class, n_query, -1)  
    #              torch.Size([150, 10]) -> torch.Size([10, 5, 10])

    # 计算损失,先在第2+1个维度上寻找对应标签的距离,例如类1的样本2标签是5,取出它距离原型1的距离,这就是这个样本的产生的loss,然后对所有样本求平均loss
    loss_val = -log_p_y.gather(2, target_inds).squeeze().view(-1).mean()    
    #             torch.Size([10, 5, 1]) -> torch.Size([10, 5]) -> torch.Size([50]) -> tenser:()
    
    # 计算预测query的标签。根据距离结果,选最小的距离,但是之前-dist,所有这里就是max。
    _, y_hat = log_p_y.max(2)  # y_hat.size() ->([10, 15])
    
    # 根据预测的结果和标签是不是相等来计算精度
    # equal的结果是bool,先转换为float,才能计算平均值
    acc_val = torch.eq(y_hat, target_inds.squeeze()).float().mean()
    # y_hat和target_inds.squeeze()都是([10, 15]),这样eq的结果也是([10, 15]),求完平均值就是tenser:()
    
    # 返回结果
    return loss_val, {
    
    
        'loss': loss_val.item(),
        'acc': acc_val.item()
    }

计算query到原型的距离

line 48调用了line 3euclidean_dist(x, y)来计算query特征到原型特征的距离。
10 w a y − 5 s h o t − 15 q u e r y 10way-5shot-15query 10way5shot15query为例:

  • 由于10个类所以有10个原型,特征均为64维度,所以line 48z_proto、同时也是line 3的形参x t o r c h . S i z e ( [ 10 , 64 ] ) torch.Size([10, 64]) torch.Size([10,64]) t e n s o r tensor tensor
  • query集有 10 ∗ 15 = 150 10*15=150 1015=150个特征,于是line 48zq、同时也是line 3的形参x t o r c h . S i z e ( [ 150 , 64 ] ) torch.Size([150, 64]) torch.Size([150,64]) t e n s o r tensor tensor
def euclidean_dist(x, y):
    # x: N x D	 torch.Size([10, 64]
    # y: M x D	 torch.Size([150, 64]
    n = x.size(0) # 10 原型的个数
    m = y.size(0) # 150 query特征的个数
    d = x.size(1) # 64 特征的维度
    assert d == y.size(1)

	# 首先将x和y扩展到相同的size,这样才能做减法
    x = x.unsqueeze(1).expand(n, m, d)  # torch.Size([10, 64]) -> torch.Size([10,150,64])
    y = y.unsqueeze(0).expand(n, m, d)  # torch.Size([150, 64]) -> torch.Size([10,150,64])

    return torch.pow(x - y, 2).sum(2)

x - y 的Size 是([10,150,64])

关于torch.pow可以看torch.pow,就是对x-y的数值做平方

最后sum(2)代表在第2+1个维度上求和,也就是64维特征变成一个数值,返回的结果的Size 是([150,10]),代表150个query到10个原型的距离。
eg. 第0个query到第2个原型的距离就是torch.pow(x - y, 2).sum(2)[0][2]

Guess you like

Origin blog.csdn.net/qq_37252519/article/details/120874594