Insightface人脸识别loss解读

(二)marginface的loss解读

margin face其实是arcface论文中融合了一下几个margin做了这么个实验,下面是截取的margin损失函数的代码,做了部分解读。

在这里插入图片描述

  elif args.loss_type==5:#margin face
    s = args.margin_s
    m = args.margin_m
    assert s>0.0
    _weight = mx.symbol.L2Normalization(_weight, mode='instance')
    nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s#注意这里乘以s
    fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')
    if args.margin_a!=1.0 or args.margin_m!=0.0 or args.margin_b!=0.0:
      if args.margin_a==1.0 and args.margin_m==0.0:
      #简单的cosface不需要求arcsin
        s_m = s*args.margin_b
        gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = s_m, off_value = 0.0)
        fc7 = fc7-gt_one_hot
      else:
		zy = mx.sym.pick(fc7, gt_label, axis=1)
		#pick 
		#[[1 2 3],[3,4,5]] 特征向量,按照label[0,1,0]进行pick,axis=1代表最后输出为一个列,也就是每一行pick一下,得到[1,4,3]
		#
		cos_t = zy/s
		#得到cos值
		t = mx.sym.arccos(cos_t)
		#反三角函数值
		if args.margin_a!=1.0:
		  t = t*args.margin_a#m1  sphereface
		if args.margin_m>0.0:
		  t = t+args.margin_m #m2  arcface
		body = mx.sym.cos(t)
		if args.margin_b>0.0:  #m3  cosface
		  body = body - args.margin_b
		new_zy = body*s
		diff = new_zy - zy
		#先剪一下fc7后面会加回来。这里只有label对应的变化。方便后面加回来。
		diff = mx.sym.expand_dims(diff, 1)
		#扩展维度,原来是[1,3,5]变为[[1],[3],[5]]
		gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0)
		#one hot  原[2,1,0]
		#[[0,0,1],[0,1,0],[1,0,0]]
		body = mx.sym.broadcast_mul(gt_one_hot, diff)
		#gt_one_hot[[0,0,1],[0,1,0],[1,0,0]]
		#diff[[1],[3],[5]]
		#结果:[[0,0,1],[0,3,0],[5,0,0]]
		fc7 = fc7+body
		#加上原fc7

(三)ArcFace解读

这里单独他有个arcface比较不好懂。简单说他在考虑cos函数是否单调的问题,我们给角度加一个margin,加完了得保持单调。他做了个margin
1、0<t+m<pi也就是cos-m<cost<cospi-m;完事他就搞了个threshold阈值,搞了个阈值,然后下面如果是简单的margin就不做了,如果是复杂的,我们就做一下处理,让cos_t - threshold会得到正负,后面会用他判断单调否。
2、接下来正常计算cost+m = costcost - sinmsinm
3、记下来到new_zy = mx.sym.where(cond, new_zy, zy_keep)关键语句;如果cond为真,就zy_keep否则new_zy
4、这个zy_keep不太懂,也不是cosface;??

elif args.loss_type==4:# arc face有一些小技巧...不太好懂
    s = args.margin_s
    m = args.margin_m
    assert s>0.0
    assert m>=0.0
    assert m<(math.pi/2)
    _weight = mx.symbol.L2Normalization(_weight, mode='instance')
    nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s
    fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')
    zy = mx.sym.pick(fc7, gt_label, axis=1)
    cos_t = zy/s
    cos_m = math.cos(m)
    sin_m = math.sin(m)
    mm = math.sin(math.pi-m)*m
    #threshold = 0.0
    threshold = math.cos(math.pi-m)
    #搞了个阈值,然后下面如果是简单的margin就不做了,如果是复杂的,我们就做一下处理,让cos_t - threshold会得到正负,后面会用他判断单调否
    if args.easy_margin:
      cond = mx.symbol.Activation(data=cos_t, act_type='relu')
    else:
      cond_v = cos_t - threshold
      cond = mx.symbol.Activation(data=cond_v, act_type='relu')
    #正常计算cost+m = costcost - sinmsinm
    body = cos_t*cos_t
    body = 1.0-body
    sin_t = mx.sym.sqrt(body)
    new_zy = cos_t*cos_m
    b = sin_t*sin_m
    new_zy = new_zy - b
    new_zy = new_zy*s
    if args.easy_margin:
      zy_keep = zy
    else:
      zy_keep = zy - s*mm
    #关键语句在下面,如果cond为真,就zy_keep否则new_zy
    new_zy = mx.sym.where(cond, new_zy, zy_keep)

    diff = new_zy - zy
    diff = mx.sym.expand_dims(diff, 1)
    gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0)
    body = mx.sym.broadcast_mul(gt_one_hot, diff)
    fc7 = fc7+body
发布了140 篇原创文章 · 获赞 26 · 访问量 3万+

猜你喜欢

转载自blog.csdn.net/CLOUD_J/article/details/99717508