Source code for pina._src.loss.dual_loss_interface
"""Module for the Loss Interface."""
from abc import ABCMeta, abstractmethod
from torch.nn.modules.loss import _Loss
[docs]
class DualLossInterface(_Loss, metaclass=ABCMeta):
    """
    Abstract base class for losses that compare an input tensor against a
    target tensor.

    Both :meth:`forward` and :meth:`_reduction` are abstract, so the class
    cannot be instantiated until a subclass provides implementations for
    each of them.
    """

    @abstractmethod
    def forward(self, input, target):
        """
        Compute the loss between ``input`` and ``target``.

        The parameter name ``input`` mirrors the signature of PyTorch's
        built-in losses, which is why it intentionally shadows the builtin.

        :param torch.Tensor input: The input tensor.
        :param torch.Tensor target: The target tensor.
        :return: The computed loss.
        :rtype: torch.Tensor
        """
        ...

    @abstractmethod
    def _reduction(self, loss):
        """
        Collapse a tensor of pointwise loss values using the reduction
        operation configured on this loss.

        :param torch.Tensor loss: The pointwise loss values to reduce.
        :return: The reduced loss tensor.
        :rtype: torch.Tensor
        """
        ...