[ディープラーニング] 分類関連損失解析
記事ディレクトリ
1 はじめに
分類タスクでは、通常、さまざまな損失関数を使用して、モデルの出力と真のラベルの差を測定します。何を使用すればよいかわからない場合もありますが、分類に関連する一般的な損失関数、その分析、およびコード例をいくつか示します。
2. 分析
-
バイナリ クロス エントロピー ロス (BCELoss):
torch.nn.BCELoss() は、バイナリ分類に使用される損失関数です。モデル出力の確率と真のラベルのバイナリ値を比較し、バイナリのクロスエントロピー損失を計算します。BCELoss は、各サンプルが複数のカテゴリに属する場合を処理できます。BCELoss を使用する場合は、モデルの出力がシグモイド活性化関数によって [0, 1] の確率形式に変換されることに注意する必要があります。 -
ロジッツ損失のあるバイナリ クロス エントロピー (BCEWithLogitsLoss) with logits:
torch.nn.BCEWithLogitsLoss() は、BCELoss に似た損失関数で、シグモイド関数とバイナリ クロス エントロピー損失の両方を適用します。BCEWithLogitsLoss を使用する場合、シグモイド関数はすでに内部で自動的にこの操作を実行しているため、シグモイド関数をモデル出力に手動で適用する必要はありません。 -
マルチクラス相互エントロピー損失 (CrossEntropyLoss):
torch.nn.CrossEntropyLoss() は、マルチクラス分類タスクに使用される損失関数です。モデルによって出力された各カテゴリのスコアを真のラベルと比較し、クロスエントロピー損失を計算します。CrossEntropyLoss は、各サンプルが 1 つのカテゴリにのみ属する場合に適しています。CrossEntropyLoss を使用する前に、通常、モデル出力が Softmax 関数または log Softmax 関数を通過することを確認する必要があることに注意してください。 -
マルチラベル バイナリ クロス エントロピー ロス:
各サンプルが複数のカテゴリに属することができる場合、バイナリ クロス エントロピー ロスを使用してマルチラベル分類タスクを処理できます。各サンプルについて、モデルによって出力された確率が真のラベルと比較され、ラベルごとにバイナリのクロスエントロピー損失が計算されます。BCELoss は、ラベルごとに各ラベルに適用するか、torch.nn.BCEWithLogitsLoss() を使用してモデル出力の最後の次元をラベルの数に設定することで適用できます。
3. コード例
1) バイナリクロスエントロピー損失 (BCELoss):
import torch
import torch.nn as nn
# 模型输出经过 sigmoid 函数处理
model_output = torch.sigmoid(model(input))
# 真实标签
target = torch.Tensor([0, 1, 1, 0])
# 创建损失函数对象
loss_fn = nn.BCELoss()
# 计算损失
loss = loss_fn(model_output, target)
2) ロジットを伴うバイナリクロスエントロピー損失 (BCEWithLogitsLoss):
import torch
import torch.nn as nn
# 模型输出未经过 sigmoid 函数处理
model_output = model(input)
# 真实标签
target = torch.Tensor([0, 1, 1, 0])
# 创建损失函数对象
loss_fn = nn.BCEWithLogitsLoss()
# 计算损失
loss = loss_fn(model_output, target)
3) マルチカテゴリクロスエントロピー損失 (CrossEntropyLoss):
import torch
import torch.nn as nn
# 模型输出经过 softmax 函数处理
model_output = nn.functional.softmax(model(input), dim=1)
# 真实标签(每个样本只能属于一个类别)
target = torch.LongTensor([2, 1, 0])
# 创建损失函数对象
loss_fn = nn.CrossEntropyLoss()
# 计算损失
loss = loss_fn(model_output, target)
4) マルチラベルバイナリクロスエントロピー損失:
import torch
import torch.nn as nn
# 模型输出未经过 sigmoid 函数处理
model_output = model(input)
# 真实标签
target = torch.Tensor([[0, 1], [1, 1], [1, 0], [0, 1]])
# 创建损失函数对象
loss_fn = nn.BCEWithLogitsLoss()
# 计算损失,将模型输出的最后一个维度设置为标签数量
loss = loss_fn(model_output, target)