ArcFace Loss (face recognition)

Commonly used loss functions for face recognition:

train_arc.py

import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision
import matplotlib.pyplot as plt
from ArcLoss import Arc
from Net import Net

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# MNIST training split, downloaded on first use and converted to tensors.
# The root is a raw string so the Windows backslashes are not read as
# (invalid) escape sequences such as "\d" and "\M".
train_data = torchvision.datasets.MNIST(
    root=r"D:\data\MNIST_data",
    download=True,
    train=True,
    transform=torchvision.transforms.ToTensor(),
)

# Reshuffled mini-batches of 1024 samples.
train_loader = data.DataLoader(dataset=train_data, shuffle=True, batch_size=1024)

if __name__ == "__main__":
    # Train the feature network together with the Arc head on MNIST and save
    # one scatter plot of the learned features per epoch.
    import os  # local import: only this script entry point needs it

    net = Net().to(device)
    # 2 input features, 10 classes. The plotting code below reads feature
    # columns 0 and 1, so Net is presumably a 2-D embedding -- TODO confirm.
    arc = Arc(2, 10).to(device)

    # The backbone and the Arc head each get their own Adam optimizer.
    opt_net = torch.optim.Adam(net.parameters())
    opt_arc = torch.optim.Adam(arc.parameters())

    # Arc.forward already returns log-values, so NLLLoss (not
    # CrossEntropyLoss) is the matching criterion.
    loss_nll = nn.NLLLoss()

    # One colour per class; built once instead of once per batch.
    colors = ["#ff0000", "#ffff00", "#00ff00", "#00ffff", "#0000ff",
              "#ff00ff", "#990000", "#999900", "#009900", "#009999"]
    class_labels = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]

    # savefig does not create missing directories.
    os.makedirs("img", exist_ok=True)

    plt.ion()
    epoch = 0
    # Trains until the process is interrupted; there is no stopping criterion.
    while True:
        # `images` (rather than `data`) keeps the module-level alias
        # `torch.utils.data as data` intact.
        for images, target in train_loader:
            images, target = images.to(device), target.to(device)
            layer = net(images)
            out = arc(layer)

            # Redraw the feature scatter for the current batch. The figure
            # left after the last batch is the one saved for this epoch.
            plt.clf()
            for j in range(10):
                plt.plot(layer[target == j, 0].detach().cpu().numpy(),
                         layer[target == j, 1].detach().cpu().numpy(),
                         ".", c=colors[j])
            plt.legend(class_labels, loc="upper right")

            loss = loss_nll(out, target)
            opt_net.zero_grad()
            opt_arc.zero_grad()
            loss.backward()
            opt_net.step()
            opt_arc.step()

        # Report the loss of the last batch of the epoch.
        print(loss.item())
        plt.title("epoch=%d" % epoch)
        plt.savefig("img/{0}.jpg".format(epoch))
        epoch += 1

ArcLoss.py

import torch 
from torch import nn
import torch.nn.functional as F

class Arc(nn.Module):
    """Additive angular-margin (ArcFace-style) classification head.

    Holds one learnable weight column per class and scores a feature vector
    by the angle between it and each class column.
    """

    def __init__(self, feature_num, cls_num) -> None:
        """Create the head.

        Args:
            feature_num: size of each input feature vector.
            cls_num: number of classes (output columns).
        """
        super().__init__()
        self.w = nn.Parameter(torch.randn(feature_num, cls_num))

    def forward(self, x, m=1, s=10):
        """Return per-class log-values with the angular margin applied.

        Args:
            x: features of shape (batch, feature_num).
            m: additive angular margin, in radians.
            s: scale applied to the cosines before exponentiation.

        Returns:
            Tensor of shape (batch, cls_num). Entry ``[i, j]`` is
            ``log(exp(s*cos(a_ij + m)) /
            (exp(s*cos(a_ij + m)) + sum_{k != j} exp(s*cos(a_ik))))``.
            Rows do not sum to 1 after ``exp``; the result is meant to be
            fed to ``nn.NLLLoss``, which reads only the target column.
        """
        x_norm = F.normalize(x, dim=1)
        w_norm = F.normalize(self.w, dim=0)

        # The /10 keeps the cosine inside [-0.1, 0.1]. The derivative of
        # arccos is unbounded at +/-1, so this avoids exploding gradients.
        cos = torch.matmul(x_norm, w_norm) / 10
        a = torch.arccos(cos)

        margin_logit = s * torch.cos(a + m)
        plain_logit = s * torch.cos(a)

        # Subtract a per-row constant before exponentiating so large `s`
        # cannot overflow/underflow the exponentials. The shift cancels in
        # the final expression, so it is detached from the graph.
        shift = torch.maximum(margin_logit, plain_logit)
        shift = shift.max(dim=1, keepdim=True).values.detach()

        top = torch.exp(margin_logit - shift)
        others = torch.exp(plain_logit - shift)
        # Denominator for class j: its margin term plus every other class's
        # plain term (row total minus class j's own plain term).
        down = top + torch.sum(others, dim=1, keepdim=True) - others

        # log(top / down) in the log domain; the numerator never has to be
        # representable as a raw exponential.
        arcsoftmax = margin_logit - shift - torch.log(down)
        return arcsoftmax

if __name__ == "__main__":
    # Smoke test: push one random 2-D feature vector through a 10-class head.
    arc = Arc(2, 10)
    data = torch.randn(1, 2)
    out = arc(data)
    print(data)
    print(out)
    # `out` holds log-values, so exponentiate before summing. The total is
    # no longer 1: the margin breaks softmax's normalization property.
    print(torch.sum(torch.exp(out)))

Guess you like

Origin blog.csdn.net/weixin_44659309/article/details/131170365