版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
一、Loss
1、Triplet Loss
代码转载自:丨github丨
from __future__ import absolute_import
import torch
from torch import nn
from torch.autograd import Variable
class TripletLoss(nn.Module):
    """Triplet loss with hardest positive/negative mining inside the batch.

    For every sample (anchor) the farthest sample carrying the same label
    and the closest sample carrying a different label are selected, and a
    margin ranking hinge loss is applied to the two Euclidean distances.

    Args:
        margin: Margin forwarded to ``nn.MarginRankingLoss``. The hinge
            term per anchor is ``max(0, dist_ap - dist_an + margin)``.
    """

    def __init__(self, margin=0):
        super(TripletLoss, self).__init__()
        self.margin = margin
        self.ranking_loss = nn.MarginRankingLoss(margin=margin)

    def forward(self, inputs, targets):
        """Compute the loss and the fraction of anchors already satisfied.

        Args:
            inputs: 2-D feature matrix; row ``i`` is the embedding of
                sample ``i``.
            targets: Labels that ``expand(n, n)`` accepts for ``n`` rows,
                e.g. a 1-D tensor of length ``n``.

        Returns:
            A ``(loss, prec)`` tuple. ``loss`` is the scalar ranking loss;
            ``prec`` is the mean (as a float tensor) of
            ``dist_an > dist_ap`` over all anchors.

        Note:
            An anchor without any negative in the batch gets an infinite
            negative distance, i.e. a zero hinge term, instead of raising
            on an empty reduction.
        """
        n = inputs.size(0)

        # Pairwise Euclidean distance via
        # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 * <a, b>.
        sq_norm = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n)
        # Keyword beta/alpha: the positional (beta, alpha, mat1, mat2)
        # overload is deprecated and rejected by recent PyTorch releases.
        dist = torch.addmm(sq_norm + sq_norm.t(), inputs, inputs.t(),
                           beta=1, alpha=-2)
        dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability

        # same_id[i][j] is true when samples i and j share a label,
        # e.g. one row may look like [True, True, False, False].
        same_id = targets.expand(n, n).eq(targets.expand(n, n).t())
        diff_id = same_id == 0

        # Hardest positive: largest distance among same-label entries.
        # Hardest negative: smallest distance among different-label entries.
        # Entries of the other kind are pushed to -inf / +inf so they can
        # never win the reduction.
        dist_ap = dist.masked_fill(diff_id, float('-inf')).max(dim=1)[0]
        dist_an = dist.masked_fill(same_id, float('inf')).min(dim=1)[0]

        # Ranking hinge loss; y = 1 asks for dist_an to exceed dist_ap.
        y = torch.ones_like(dist_an)
        loss = self.ranking_loss(dist_an, dist_ap, y)

        # Monitoring value only, so it is computed outside the graph.
        prec = (dist_an.detach() > dist_ap.detach()).float().mean()
        return loss, prec
######## Step-by-step walkthrough to understand the computation above #########
import torch

# Run on the GPU when one is available, otherwise on the CPU, and keep every
# tensor on the same device so the boolean indexing below never mixes devices.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Toy batch: 4 random feature vectors of dimension 128.
t1 = torch.rand(4, 128, device=device)
n = t1.size(0)

# Pairwise Euclidean distance via ||a - b||^2 = ||a||^2 + ||b||^2 - 2 * <a, b>.
dist = torch.pow(t1, 2).sum(dim=1, keepdim=True).expand(n, n)
dist = dist + dist.t()
# beta=1 keeps the summed squared norms and alpha=-2 subtracts twice the Gram
# matrix; keyword form because the positional (beta, alpha, mat1, mat2)
# overload is deprecated.
dist.addmm_(t1, t1.t(), beta=1, alpha=-2)
dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability
print('dist:\t', dist)

# Labels: samples 0 and 1 share an identity, samples 2 and 3 are singletons.
targets = torch.tensor([1.0, 1.0, 3.0, 4.0], device=device)
n = targets.size(0)
# mask[i][j] is True when samples i and j carry the same label.
mask = targets.expand(n, n).eq(targets.expand(n, n).t())
print('mask:\t', mask)

# For each anchor pick the farthest same-label sample (hardest positive) and
# the closest different-label sample (hardest negative).
dist_ap, dist_an = [], []
for i in range(n):
    dist_ap.append(dist[i][mask[i]].max().view(1))
    dist_an.append(dist[i][mask[i] == 0].min().view(1))
dist_ap = torch.cat(dist_ap)
dist_an = torch.cat(dist_an)
print('dist_an:\t', dist_an)
# print('print:\t',dist[0][mask[0]].max().view(1))

# Build the all-ones target for the ranking loss one step at a time, printing
# after each step: empty tensor -> resized (uninitialised values) -> ones.
y = dist_an.data.new()
print(y)
y.resize_as_(dist_an.data)
print(y)
y.fill_(1)
print(y)