"""Module for the Loss Interface."""fromabcimportABCMeta,abstractmethodfromtorch.nn.modules.lossimport_Lossimporttorch
[docs]classLossInterface(_Loss,metaclass=ABCMeta):""" Abstract base class for all losses. All classes defining a loss function should inherit from this interface. """def__init__(self,reduction="mean"):""" Initialization of the :class:`LossInterface` class. :param str reduction: The reduction method for the loss. Available options: ``none``, ``mean``, ``sum``. If ``none``, no reduction is applied. If ``mean``, the sum of the loss values is divided by the number of values. If ``sum``, the loss values are summed. Default is ``mean``. """super().__init__(reduction=reduction,size_average=None,reduce=None)
[docs]@abstractmethoddefforward(self,input,target):""" Forward method of the loss function. :param torch.Tensor input: Input tensor from real data. :param torch.Tensor target: Model tensor output. """
def_reduction(self,loss):""" Apply the reduction to the loss. :param torch.Tensor loss: The tensor containing the pointwise losses. :raises ValueError: If the reduction method is not valid. :return: Reduced loss. :rtype: torch.Tensor """ifself.reduction=="none":ret=losselifself.reduction=="mean":ret=torch.mean(loss,keepdim=True,dim=-1)elifself.reduction=="sum":ret=torch.sum(loss,keepdim=True,dim=-1)else:raiseValueError(self.reduction+" is not valid")returnret