A commonly used loss function for face recognition — ArcFace (ArcLoss) — demonstrated below by training 2-D features on MNIST:
train_arc.py
import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision
import matplotlib.pyplot as plt
from ArcLoss import Arc
from Net import Net
# Train on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Raw string for the Windows path: "D:\data\MNIST_data" contains the invalid
# escape sequences \d and \M, which raise a DeprecationWarning today and will
# become a SyntaxError in future Python versions. The runtime value is unchanged.
train_data = torchvision.datasets.MNIST(root=r"D:\data\MNIST_data",
                                        download=True, train=True,
                                        transform=torchvision.transforms.ToTensor())
# Large batches keep the per-epoch scatter plot of all classes well populated.
train_loader = data.DataLoader(dataset=train_data, shuffle=True, batch_size=1024)
if __name__ == "__main__":
    # Backbone producing 2-D embeddings, plus the ArcFace margin head.
    net = Net().to(device)
    arc = Arc(2, 10).to(device)
    opt_net = torch.optim.Adam(net.parameters())
    opt_arc = torch.optim.Adam(arc.parameters())
    # Arc.forward returns log-probabilities, so NLLLoss is the right pairing.
    loss_nll = nn.NLLLoss()
    plt.ion()
    epoch = 0
    # One color per digit class; hoisted out of the loop (loop-invariant).
    colors = ["#ff0000", "#ffff00", "#00ff00", "#00ffff", "#0000ff",
              "#ff00ff", "#990000", "#999900", "#009900", "#009999"]
    # NOTE(review): no stopping criterion — training runs until interrupted.
    while True:
        # BUG FIX: the batch variable was previously named `data`, shadowing
        # the `torch.utils.data` module imported as `data` at the top of the
        # file. Renamed to `images` to keep the module alias usable.
        for i, (images, target) in enumerate(train_loader):
            images, target = images.to(device), target.to(device)
            layer = net(images)   # 2-D embedding for visualization
            out = arc(layer)      # per-class log-probabilities
            # Scatter-plot the current batch's embeddings, colored by label.
            plt.clf()
            for j in range(10):
                plt.plot(layer[target == j, 0].detach().cpu().numpy(),
                         layer[target == j, 1].detach().cpu().numpy(),
                         ".", c=colors[j])
            plt.legend(["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"],
                       loc="upper right")
            loss = loss_nll(out, target)
            # Step both the backbone and the margin head together.
            opt_net.zero_grad()
            opt_arc.zero_grad()
            loss.backward()
            opt_net.step()
            opt_arc.step()
            print(loss.item())
            plt.title("epoch=%d" % epoch)
            # Overwrites img/<epoch>.jpg each batch; the last batch's plot
            # is what survives for the epoch.
            plt.savefig("img/{0}.jpg".format(epoch))
        epoch += 1
ArcLoss.py
import torch
from torch import nn
import torch.nn.functional as F
class Arc(nn.Module):
    """ArcFace-style margin head.

    Maps feature vectors to per-class log-probabilities after adding an
    angular margin between each feature and its class direction. Feed the
    output straight into ``nn.NLLLoss``.

    Args:
        feature_num: dimensionality of the incoming features.
        cls_num: number of target classes.
    """

    def __init__(self, feature_num, cls_num) -> None:
        super().__init__()
        # One learnable class direction per column of w.
        self.w = nn.Parameter(torch.randn(feature_num, cls_num))

    def forward(self, x, m=1, s=10):
        # Unit-normalize both sides so the matmul yields cosines.
        feat = F.normalize(x, dim=1)
        centers = F.normalize(self.w, dim=0)
        # Divide by 10 to keep the value well inside (-1, 1): arccos has an
        # unbounded gradient at the endpoints (prevents gradient explosion).
        cosine = (feat @ centers) / 10
        theta = torch.arccos(cosine)
        # Numerator: the margin-shifted, scaled logit for each class.
        numerator = torch.exp(s * torch.cos(theta + m))
        # Denominator: total of the plain logits, with each class's own
        # plain term swapped out for its margin-shifted one.
        plain = torch.exp(s * torch.cos(theta))
        denominator = numerator + torch.sum(plain, dim=1, keepdim=True) - plain
        # Log-probabilities (not a true softmax: rows no longer sum to 1).
        return torch.log(numerator / denominator)
if __name__ == "__main__":
    # Sanity check: push one random 2-D feature through the margin head.
    arc = Arc(2, 10)
    sample = torch.randn(1, 2)
    result = arc(sample)
    print(sample)
    print(result)
    # The probabilities no longer sum to 1 — the angular margin breaks the
    # normalization property of softmax.
    print(torch.sum(result))