ps:このステップでは正と負のサンプルの数に大きなギャップがあるためです。1,500以上の陽性サンプルと750,000以上の陰性サンプルがあります。FocalLossを使用してこの問題を解決します。
まず、理論を要約してくれたCode_Martのブログhttps://blog.csdn.net/Code_Mart/article/details/89736187に感謝します。そして、フォーカルロスの2つのカテゴリーとマルチカテゴリーのコードを実現し、説明しました。同時に、xwmwanjy666との彼の議論はいくつかの詳細を明らかにしました。
しかし、コードがpytorch 0.4.1バージョンに準拠していないように感じ、それらの間の会話で見つかりましたhttps://github.com/ronghuaiyang/arcface-pytorch/blob/master/models/focal_loss.py変更しました私の考えによると少し。
import torch
import torch.nn as nn
class FocalLoss(nn.Module):
def __init__(self, gamma=0,alpha=1):
super(FocalLoss, self).__init__()
self.gamma = gamma
self.ce = nn.CrossEntropyLoss()
self.alpha=alpha
def forward(self, input, target):
logp = self.ce(input, target)
p = torch.exp(-logp)
loss = (1 - p) ** self.gamma * logp
loss = self.alpha*loss
return loss.mean()
それはとても簡単です、これは私が望むものですが、それから私は彼らの議論を読み続けてそれを比較し、それが行われていないことを発見しました(グラウンドトゥルースは1、alpha = a;グラウンドトゥルースが0の場合、alpha = 1-a) 。それは私を不快にさせました、そしてそれから私はgithubで検索するのに4時間以上費やしました、そして基本的にこの問題を考慮しなかったか、コードが非常に複雑であった(理解できない、入力が私の要件を満たさなかった)ことを発見しました分類には適用されません問題については、使用される関数pytorchのバージョンは通常0.4未満です。
https://github.com/louis-she/focal-loss.pytorch/blob/master/focal_loss.pyが見つかるまで
import torch
import torch.nn.functional as F
class BCEFocalLoss(torch.nn.Module):
def __init__(self, gamma=2, alpha=None, reduction='elementwise_mean'):
super().__init__()
self.gamma = gamma
self.alpha = alpha
self.reduction = reduction
def forward(self, _input, target):
pt = torch.sigmoid(_input)
loss = - (1 - pt) ** self.gamma * target * torch.log(pt) - \
pt ** self.gamma * (1 - target) * torch.log(1 - pt)
if self.alpha:
loss = loss * self.alpha
if self.reduction == 'elementwise_mean':
loss = torch.mean(loss)
elif self.reduction == 'sum':
loss = torch.sum(loss)
return loss
損失=-(1-pt)** self.gamma * target * torch.log(pt)-pt ** self.gamma *(1-target)* torch.log(1-pt)。これにより、上記の問題も解決できます。著者が対応するものを変更しなかったことを私は完全には理解していません(グラウンドトゥルースは1、alpha = a、グラウンドトゥルースが0の場合、alpha = 1-a)。ただし、このアイデアを参照用に使用すると、最初のコードは次のように変更されます。
import torch
import torch.nn as nn
class FocalLoss(nn.Module):
def __init__(self, gamma=2,alpha=0.25):
super(FocalLoss, self).__init__()
self.gamma = gamma
self.ce = nn.CrossEntropyLoss()
self.alpha=alpha
def forward(self, input, target):
logp = self.ce(input, target)
p = torch.exp(-logp)
loss = self.alpha*(1 - p) ** self.gamma * logp * target.long() + \
(1-self.alpha)*(p) ** self.gamma * logp * (1-target.long())
return loss.mean()
このように書くことに問題があるかどうかはわかりません。
追加されたps:2019年6月14日:
Focal Loss分類問題pytorchを使用して、コードのテストコードテストを実装します。このアイデアを利用して、最初のコード(上記の最後のコード)のコードを変更します。
中間結果は上記の損失です:
input=torch.Tensor([[ 0.0543, 0.5641],[ 1.2221, -0.5496],[-0.7951, -0.1546],[-0.4557, 1.4724]])
target= torch.Tensor([1,0,1,1])
tensor([[0.3752, 0.6248],
[0.8547, 0.1453],
[0.3451, 0.6549],
[0.1270, 0.8730]])
tensor(0.0080)
tensor(0.0344)
tensor(0.2966)
target= torch.Tensor([0,1,0,0])
tensor([[0.3752, 0.6248],
[0.8547, 0.1453],
[0.3451, 0.6549],
[0.1270, 0.8730]])
tensor(0.5403)
tensor(0.0987)
tensor(1.5092)
上記の結果から、効果が良くないことがわかります。1番目のラベルは予測確率に対応しているため損失は小さく、2番目のラベルは予測確率と反対であるため損失が大きいはずです。傾向は一貫していますが。しかし、相対的な倍数はそれぞれ約70倍、3倍、5倍です。組み込みの損失関数を使用することもできます。したがって、最後に、Focal Loss分類問題pytorchを使用して、コード継続3の結論コードを実現することを選択します。
import torch
import torch.nn as nn
#二分类
class FocalLoss(nn.Module):
def __init__(self, gamma=2,alpha=0.25):
super(FocalLoss, self).__init__()
self.gamma = gamma
self.alpha=alpha
def forward(self, input, target):
# input:size is M*2. M is the batch number
# target:size is M.
pt=torch.softmax(input,dim=1)
p=pt[:,1]
loss = -self.alpha*(1-p)**self.gamma*(target*torch.log(p))-\
(1-self.alpha)*p**self.gamma*((1-target)*torch.log(1-p))
return loss.mean()
ps2020.11.12アップデート
以前の2クラスモデルで使用されたフォーカルロス:
class FocalLossV1(nn.Module):
def __init__(self,
alpha=0.25,
gamma=2,
reduction='mean',):
super(FocalLossV1, self).__init__()
self.alpha = alpha
self.gamma = gamma
self.reduction = reduction
self.crit = nn.BCEWithLogitsLoss(reduction='none')
self.celoss = torch.nn.CrossEntropyLoss(reduction='none')
def forward(self, logits, label):
'''
args:
logits: tensor of shape (N, ...)
label: tensor of shape(N, ...)
'''
# compute loss
logits = logits.float() # use fp32 if logits is fp16
with torch.no_grad():
alpha = torch.empty_like(logits).fill_(1 - self.alpha)
alpha[label == 1] = self.alpha
ce_loss=(-(label * torch.log(logits)) - (
(1 - label) * torch.log(1 - logits)))
# ce_loss=(-(label * torch.log(torch.softmax(logits, dim=1))) - (
# (1 - label) * torch.log(1 - torch.softmax(logits, dim=1))))
pt = torch.where(label == 1, logits, 1 - logits)
# ce_loss = self.crit(logits, label)
loss = (alpha * torch.pow(1 - pt, self.gamma) * ce_loss)
if self.reduction == 'mean':
loss = loss.mean()
if self.reduction == 'sum':
loss = loss.sum()
return loss