Loss Function Design and Implementation

1 Custom loss functions in PyTorch

PyTorch ships with a number of commonly used loss functions (for example nn.MSELoss and nn.CrossEntropyLoss), all of which are subclasses of torch.nn.Module. A custom loss function should therefore inherit from this class as well.

Define any required hyperparameters in the __init__ method, and define how the loss is computed in the forward method. The forward method is where the loss function is actually defined. Its return value should be a scalar (in PyTorch, a 0-dimensional tensor) rather than a vector or a higher-dimensional tensor, because loss.backward() can only create the initial gradient implicitly for a scalar output. If the computed loss is a vector or tensor, reduce it to a scalar with a function such as torch.sum (summation) or torch.mean (averaging).
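The following minimal sketch shows why the reduction step matters. It computes a per-element squared error (a vector), shows that calling backward() on that vector raises a RuntimeError, and then reduces it with torch.sum to a 0-dimensional tensor that autograd accepts. The tensors pred and target are illustrative values, not from the original post.

```python
import torch

# Illustrative inputs (assumed values)
pred = torch.tensor([0.5, 1.5, 2.0], requires_grad=True)
target = torch.tensor([1.0, 1.0, 1.0])

# Element-wise squared error: shape (3,), i.e. a vector, not a scalar
per_element = (pred - target) ** 2

try:
    # Fails: autograd can only create the initial gradient implicitly for scalar outputs
    per_element.backward()
    vector_backward_ok = True
except RuntimeError:
    vector_backward_ok = False

# Reduce to a scalar (0-dim tensor) with torch.sum, then backpropagate
loss = torch.sum(per_element)
loss.backward()

print(vector_backward_ok, loss.dim(), loss.item(), pred.grad)
```

Using torch.mean instead of torch.sum also yields a scalar; the mean keeps the loss magnitude independent of the batch size, while the sum grows with it.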

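import torch
import torch.nn as nn
import torch.nn.functional as func

class myLoss(nn.Module):
    def __init__(self):
        super().__init__()  # define the required hyperparameters here
    def forward(self, pred, target):
        # placeholder loss (assumed, not from the original); reduce the result to a scalar with torch.sum
        return torch.sum((pred - target) ** 2)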
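As a fuller illustration of the same pattern (hyperparameters stored in __init__, the loss computed in forward, a scalar returned), here is a sketch of a hypothetical MySmoothL1Loss class with a beta hyperparameter. The class name and the choice of Smooth L1 as the example loss are assumptions made for illustration; the result is checked against PyTorch's built-in func.smooth_l1_loss.

```python
import torch
import torch.nn as nn
import torch.nn.functional as func


class MySmoothL1Loss(nn.Module):
    """Hypothetical example: a hand-written Smooth L1 loss with a `beta` hyperparameter."""

    def __init__(self, beta=1.0):
        super().__init__()
        # Hyperparameters are defined in __init__
        self.beta = beta

    def forward(self, pred, target):
        diff = torch.abs(pred - target)
        # Quadratic for small errors (diff < beta), linear for large ones
        per_element = torch.where(
            diff < self.beta,
            0.5 * diff ** 2 / self.beta,
            diff - 0.5 * self.beta,
        )
        # Reduce the element-wise losses to a single scalar (mean here; torch.sum also works)
        return torch.mean(per_element)


torch.manual_seed(0)
pred = torch.randn(4, 3, requires_grad=True)
target = torch.randn(4, 3)

criterion = MySmoothL1Loss(beta=0.5)
loss = criterion(pred, target)  # nn.Module.__call__ dispatches to forward
loss.backward()

# Compare with the built-in implementation
reference = func.smooth_l1_loss(pred.detach(), target, beta=0.5)
print(loss.item(), reference.item())
```

Because the class is an nn.Module, it is called like any built-in loss (criterion(pred, target)), and autograd derives the gradients from the operations in forward, so no backward method needs to be written by hand.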


Source: blog.csdn.net/weixin_43838785/article/details/127160742