# Source code for pina._src.loss.dual_loss_interface

"""Module for the dual loss interface (losses taking an input and a target)."""

from abc import ABCMeta, abstractmethod
from torch.nn.modules.loss import _Loss


class DualLossInterface(_Loss, metaclass=ABCMeta):
    """
    Abstract base class for losses that compare an input tensor against a
    target tensor.

    Subclasses must provide two methods:

    * :meth:`forward`, which computes the loss from ``input`` and ``target``;
    * :meth:`_reduction`, which collapses a tensor of pointwise losses into
      the final reduced value.

    Both methods are declared with :func:`abc.abstractmethod`, so this class
    cannot be instantiated directly; only subclasses that override both can.
    """

    @abstractmethod
    def forward(self, input, target):
        """
        Evaluate the loss for a pair of tensors.

        ``input`` shadows the builtin of the same name. It is kept because it
        is part of the public signature, and callers may pass it by keyword.

        :param torch.Tensor input: The input tensor.
        :param torch.Tensor target: The target tensor.
        :return: The computed loss.
        :rtype: torch.Tensor
        """

    @abstractmethod
    def _reduction(self, loss):
        """
        Reduce pointwise loss values with the configured reduction operation.

        :param torch.Tensor loss: Tensor holding one loss value per point.
        :return: The reduced loss tensor.
        :rtype: torch.Tensor
        """