Manual implementation of focal loss


from torch import nn
import torch
from torch.nn import functional as F

class focal_loss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2, num_classes=5, size_average=True):

        super(focal_loss, self).__init__()
        self.size_average = size_average
        if isinstance(alpha, (float, int)):    #仅仅设置第一类别的权重
            self.alpha = torch.zeros(num_classes)
            self.alpha[0] += alpha
            self.alpha[1:] += (1 - alpha)
        if isinstance(alpha, list):  #全部权重自己设置
            self.alpha = torch.Tensor(alpha)
        self.gamma = gamma


    def forward(self, inputs, targets):
        alpha = self.alpha
        print('aaaaaaa',alpha)
        N = inputs.size(0)
        C = inputs.size(1)
        P = F.softmax(inputs,dim=1)
        print('ppppppppppppppppppppp', P)
        # ---------one hot start--------------#
        class_mask = inputs.data.new(N, C).fill_(0)  # 生成和input一样shape的tensor
        print('依照input shape制作:class_mask\n', class_mask)
        class_mask = class_mask.requires_grad_()  # 需要更新, 所以加入梯度计算
        ids = targets.view(-1, 1)  # 取得目标的索引
        print('取得targets的索引\n', ids)
        class_mask.data.scatter_(1, ids.data, 1.)  # 利用scatter将索引丢给mask
        print('targets的one_hot形式\n', class_mask)  # one-hot target生成
        # ---------one hot end-------------------#
        probs = (P * class_mask).sum(1).view(-1, 1)
        print('留下targets的概率(1的部分),0的部分消除\n', probs)
        # 将softmax * one_hot 格式,0的部分被消除 留下1的概率, shape = (5, 1), 5就是每个target的概率

        log_p = probs.log()
        print('取得对数\n', log_p)
        # 取得对数
        loss = torch.pow((1 - probs), self.gamma) * log_p
        batch_loss = -alpha *loss.t()  # 對應下面公式
        print('每一个batch的loss\n', batch_loss)
        # batch_loss就是取每一个batch的loss值

        # 最终将每一个batch的loss加总后平均
        if self.size_average:
            loss = batch_loss.mean()
        else:
            loss = batch_loss.sum()
        print('loss值为\n', loss)
        return loss

torch.manual_seed(50) #随机种子确保每次input tensor值是一样的
input = torch.randn(5, 5, dtype=torch.float32, requires_grad=True)
print('input值为\n', input)
targets = torch.randint(5, (5, ))
print('targets值为\n', targets)

criterion = focal_loss()
loss = criterion(input, targets)
loss.backward()

insert image description here

Guess you like

Origin blog.csdn.net/PETERPARKERRR/article/details/124977828