Com base na PyTorch função de perda

prefácio

Este blog fornece uma referência para alguma função perda comum, você pode facilmente importar para o código.
define a função perda como o modelo de rede neural calculado a partir do erro global residual por rodada, que por sua vez afeta o coeficiente de ajuste na realização do modo de back-propagação do mesmo, é só selecionar a função afeta diretamente o modelo de perda de desempenho.
Segmentação e classificação para outras tarefas, a opção padrão de função de perda é um cross-entropia binária (BCE). Quando um determinado métrica, tais como dados ou coeficientes IOU, é usado para o desempenho do modelo é determinada, por vezes concorrentes derivada destas teste mede a função de perda - geralmente sob a forma de um -. F (x), onde f (x) é medida problemática. Estas funções não pode simplesmente escrever NumPy, como eles são implementados no GPU, a extremidade traseira da função é necessária a partir da respectiva biblioteca modelo, que também compreende um gradiente para o algoritmo back-propagação. Esta não imaginava complexa.
Em muitos tipos de segmentação, cada classe para calcular a perda média é calculada a partir de todo o tensor de predição perda geralmente utilizado função perda vez. Este blog vai servir como um código básico modelo de referência, mas para muitos tipos de média e modificá-lo deve ser muito simples. Por exemplo, se o tensor compreende uma aulas achatadas contínuas, você pode dividi-los em quatro classes de igual comprimento, a perda de seus respectivos calculados e média.
Esperemos que este blog para sua ajuda, você é bem-vindo quaisquer alterações propostas.

perda de dados

coeficiente de Dice ou coeficiente de Sorensen-Dice, é uma tarefa de classificação binária padrão comum, tais como a divisão pixel, ele também pode ser modificado em função do resultado:
Aqui Insert Picture Descrição

#PyTorch
class DiceLoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(DiceLoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        
        #comment out if your model contains a sigmoid or equivalent activation layer
        inputs = F.sigmoid(inputs)       
        
        #flatten label and prediction tensors
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        
        intersection = (inputs * targets).sum()                            
        dice = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)  
        
        return 1 - dice

Perda BCE-Dice

Esta perda de ligação e dados de perda de perda de binário critérios de entropia cruzada (BCE), que é geralmente o modelo de segmentação padrão. Combinar estes dois métodos pode reduzir a perda de certa forma, enquanto se beneficia de estabilidade do BCE. Qualquer pessoa que aprendeu equações de regressão logística estão familiarizados com muitos tipos de AEC:
Aqui Insert Picture Descrição

#PyTorch
class DiceBCELoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(DiceBCELoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        
        #comment out if your model contains a sigmoid or equivalent activation layer
        inputs = F.sigmoid(inputs)       
        
        #flatten label and prediction tensors
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        
        intersection = (inputs * targets).sum()                            
        dice_loss = 1 - (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)  
        BCE = F.binary_cross_entropy(inputs, targets, reduction='mean')
        Dice_BCE = BCE + dice_loss
        
        return Dice_BCE

Jaccard / Cruzamento sobre a perda União (IOU)

indicadores IOU, ou índice de Jaccard, semelhante ao dice de índice, calculado como exemplos positivos de sobreposição entre os dois conjuntos da relação entre o valor da mesma em combinação um com o outro:
Aqui Insert Picture Descrição
o mesmo que dados métricos, é também um método comum de avaliação o desempenho do modelo divisão de pixel.

#PyTorch
class IoULoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(IoULoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        
        #comment out if your model contains a sigmoid or equivalent activation layer
        inputs = F.sigmoid(inputs)       
        
        #flatten label and prediction tensors
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        
        #intersection is equivalent to True Positive count
        #union is the mutually inclusive area of all labels & predictions 
        intersection = (inputs * targets).sum()
        total = (inputs + targets).sum()
        union = total - intersection 
        
        IoU = (intersection + smooth)/(union + smooth)
                
        return 1 - IoU

Perda focal

Artificial Intelligence Research Facebook Lin, que em 2017 apresentou uma perda focal, como um hedge contra conjuntos de dados desequilíbrio extremos, relativamente poucos destes dados definir exemplos positivos. Seu papel "Perda Focal para Detecção Dense objeto" pode ser encontrada aqui: https://arxiv.org/abs/1708.02002. Na prática, os pesquisadores usaram uma versão modificada da função alfa, então eu vou incluí-lo na implementação.

#PyTorch
ALPHA = 0.8
GAMMA = 2

class FocalLoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(FocalLoss, self).__init__()

    def forward(self, inputs, targets, alpha=ALPHA, gamma=GAMMA, smooth=1):
        
        #comment out if your model contains a sigmoid or equivalent activation layer
        inputs = F.sigmoid(inputs)       
        
        #flatten label and prediction tensors
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        
        #first compute binary cross-entropy 
        BCE = F.binary_cross_entropy(inputs, targets, reduction='mean')
        BCE_EXP = torch.exp(-BCE)
        focal_loss = alpha * (1-BCE_EXP)**gamma * BCE
                       
        return focal_loss
Publicado 33 artigos originais · ganhou elogios 3 · Vistas 5546

Acho que você gosta

Origin blog.csdn.net/weixin_42990464/article/details/104260043
Recomendado
Clasificación