# Source code for pina._src.weighting.weighting_interface
"""Module for the Weighting Interface."""
from abc import ABCMeta, abstractmethod
class WeightingInterface(metaclass=ABCMeta):
    """
    Abstract interface for all weighting schemas.

    A concrete weighting schema must implement every abstract member declared
    here: :meth:`aggregate`, :meth:`update_weights`,
    :meth:`last_saved_weights` and the :attr:`solver` property. The class
    cannot be instantiated until all of them are overridden.
    """

    @abstractmethod
    def aggregate(self, losses):
        """
        Aggregate a collection of loss terms into a single scalar.

        This method applies the current weighting scheme to the provided
        losses and returns the aggregated result. Implementations may
        internally update the weights (e.g., based on training state) before
        performing the aggregation.

        :param dict losses: The mapping from loss names to loss tensors.
        :return: The aggregated loss value.
        :rtype: torch.Tensor
        """

    @abstractmethod
    def update_weights(self, losses):
        """
        Update the weights based on the current losses.

        This method defines how the weighting strategy adapts over time. It
        is responsible for computing and storing updated weights that will be
        used during aggregation.

        :param dict losses: The mapping from loss names to loss tensors.
        :return: The updated weights.
        :rtype: dict
        """

    @abstractmethod
    def last_saved_weights(self):
        """
        Get the most recently computed weights.

        :return: The mapping from loss names to their corresponding weights.
        :rtype: dict
        """

    # ``property`` must be the outermost decorator so that the abstract
    # marker set by ``abstractmethod`` is carried by the property object.
    @property
    @abstractmethod
    def solver(self):
        """
        Solver associated with this weighting strategy.

        Provides access to the solver instance that uses this weighting
        scheme, enabling strategies that depend on training state or model
        information.

        :return: The solver instance.
        :rtype: :class:`~pina.solver.base_solver.BaseSolver`
        """