Arcface 解析

简单来说,

1.feature转换为512维特征,

2.再用矩阵乘法转换为分类数。

https://github.com/TreB1eN/InsightFace_Pytorch

先把InsightFace中ArcFace代码贴出来: 

class Arcface(Module):
    # implementation of additive margin softmax loss in https://arxiv.org/abs/1801.05599    
    def __init__(self, embedding_size=512, classnum=51332,  s=64., m = 0.5):
        super(Arcface, self).__init__()
        self.classnum = classnum
        self.kernel = Parameter(torch.Tensor(embedding_size, classnum))
        # initial kernel
        self.kernel.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5) #uniform_(-1, 1)服从均匀分布,mul_对应点相乘
        self.m = m # the margin value, default is 0.5
        self.s = s # scalar value default is 64, see normface https://arxiv.org/abs/1704.06369
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.mm = self.sin_m * m  # issue 1
        self.threshold = math.cos(math.pi - m)
    def forward(self, embbedings, label):
        # weights norm
        nB = len(embbedings)
        kernel_norm = l2_norm(self.kernel, axis=0)
        # cos(theta+m)
        cos_theta = torch.mm(embbedings, kernel_norm)#进行矩阵乘法
#         output = torch.mm(embbedings,kernel_norm)
        cos_theta = cos_theta.clamp(-1,1) # for numerical stability
        cos_theta_2 = torch.pow(cos_theta, 2)
        sin_theta_2 = 1 - cos_theta_2
        sin_theta = torch.sqrt(sin_theta_2)
        cos_theta_m = (cos_theta * self.cos_m - sin_theta * self.sin_m)
        # this condition controls the theta+m should in range [0, pi]
        #      0<=theta+m<=pi
        #     -m<=theta<=pi-m
        cond_v = cos_theta - self.threshold
        cond_mask = cond_v <= 0
        keep_val = (cos_theta - self.mm) # when theta not in [0,pi], use cosface instead
        cos_theta_m[cond_mask] = keep_val[cond_mask]
        output = cos_theta * 1.0 # a little bit hacky way to prevent in_place operation on cos_theta
        idx_ = torch.arange(0, nB, dtype=torch.long)
        output[idx_, label] = cos_theta_m[idx_, label]
        output *= self.s # scale up in order to make softmax work, first introduced in normface
        return output


对于torch中一些函数的理解

1)对于self.kernel = Parameter(torch.Tensor(embedding_size, classnum))中,Parameter的作用:

首先可以把Parameter理解为类型转换函数,将一个不可训练的类型Tensor转换成可以训练的类型parameter并将这个parameter绑定到这个module里面(net.parameter()中就有这个绑定的parameter,所以在参数优化的时候可以进行优化的),所以经过类型转换这个self.kernel变成了模型的一部分,成为了模型中根据训练可以改动的参数了。使用这个函数的目的也是想让某些变量在学习的过程中不断的修改其值以达到最优化。(摘自:原文链接:https://blog.csdn.net/qq_36955294/article/details/88117170)

看了torch官网的解释:
"Variable的一种,常被用于模块参数(module parameter)。

Parameters 是 Variable 的子类。Paramenters和Modules一起使用的时候会有一些特殊的属性,即:当Paramenters赋值给Module的属性的时候,他会自动的被加到 Module的 参数列表中(即:会出现在 parameters() 迭代器中)。将Varibale赋值给Module属性则不会有这样的影响。 这样做的原因是:我们有时候会需要缓存一些临时的状态(state)"

这句话中,embedding_size = 512,classnum是人脸识别的ID数,先使用orch.Tensor,生成一个512×classnum的张量,然后通过Parameter将这个张量转化为可以训练的模型;

2)对于self.kernel.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)  的理解:

#uniform_(from=-1, to=1) → Tensor    将tensor用从均匀分布中抽样得到的值填充。

# renorm_返回一个张量,包含规范化后的各个子张量,使得沿着2维划分的各子张量的1范数小于1e-5

# mul_用标量值1e5乘以输入input的每个元素,并返回一个新的结果张量;

以上是对pytorch中一些函数的理解;

对于arcface公式的代码实现

对于arcFace的实现实际应该是包括两部分,第一部分是cosin函数部分;第二部分就是常规的softmax部分;

在pytorch代码中,第二部分直接有函数实现,是可以直接使用的;所以重点是cosin函数部分的实现;

下面就重点讲解记录一下怎样一步步的实现第一部分代码:

1)对Feature进行了l2 norm,对参数也进行了l2 norm.所以权值参数×feature = cos theta

2)将cos theta夹逼到【-1, 1】之间,因为cos theta的定义域在【0,pi】值域实在【-1,1】之间;

3)计算cos(theta + m)使用到余弦定理;

4)计算完成后,要判断theta是否超出范围,进行数据调整,这一块的判读原理在下图:

(不知道这样理解是否有错?望大佬赐教)

判断后得出一个值为0或1的mask,通过使用cos_theta_m[cond_mask] = keep_val[cond_mask],将超出范围的值使用keep_val表示,加入[cond_mask],是将mask为1(True)位置的元素取出,进行修改;

https://github.com/foamliu/InsightFace-v2

https://github.com/foamliu/InsightFace-PyTorch

ArcMarginmodel:

这个使用线性回归来替代矩阵相乘,区别就是看看好不好收敛吧

class ArcMarginModel(nn.Module):
    def __init__(self, args):
        super(ArcMarginModel, self).__init__()

        self.weight = Parameter(torch.FloatTensor(num_classes, args.emb_size))
        nn.init.xavier_uniform_(self.weight)

        self.easy_margin = args.easy_margin
        self.m = args.margin_m
        self.s = args.margin_s

        self.cos_m = math.cos(self.m)
        self.sin_m = math.sin(self.m)
        self.th = math.cos(math.pi - self.m)
        self.mm = math.sin(math.pi - self.m) * self.m

    def forward(self, input, label):
        x = F.normalize(input)
        W = F.normalize(self.weight)
        cosine = F.linear(x, W)
        sine = torch.sqrt(1.0 - torch.pow(cosine, 2))
        phi = cosine * self.cos_m - sine * self.sin_m  # cos(theta + m)
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        one_hot = torch.zeros(cosine.size(), device=device)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s

        return output


 

发布了2732 篇原创文章 · 获赞 1011 · 访问量 538万+

猜你喜欢

转载自blog.csdn.net/jacke121/article/details/104790999
今日推荐