Source code for pina._src.weighting.base_weighting

"""Module for the Base Weighting class."""

from typing import final, Callable
import torch
from pina._src.weighting.weighting_interface import WeightingInterface
from pina._src.core.utils import check_positive_integer, check_consistency


class BaseWeighting(WeightingInterface):
    """
    Base class for all weighting schemas, implementing common functionality.

    A weighting schema decides how the scalar loss terms produced by the
    different conditions are combined into one scalar loss.

    Every weighting schema must inherit from this class and implement the
    methods declared in
    :class:`~pina.weighting.weighting_interface.WeightingInterface`.

    This class is not meant to be instantiated directly.
    """

    # Built-in aggregation methods, selectable by name
    _AGGREGATE_METHODS = {"sum": torch.sum, "mean": torch.mean}

    def __init__(self, update_every_n_epochs=1, aggregator="sum"):
        """
        Initialization of the :class:`BaseWeighting` class.

        :param int update_every_n_epochs: How many training epochs elapse
            between two weight updates. With ``1`` the weights are refreshed
            at every epoch. Static weighting schemes ignore this parameter.
            Default is ``1``.
        :param aggregator: How the weighted losses are combined. Either the
            name of a built-in method (``"sum"`` adds the weighted losses,
            ``"mean"`` averages them) or a custom callable. A callable is
            invoked with a single tensor holding the flattened weighted
            losses stacked along the first dimension and should return a
            scalar. Default is ``"sum"``.
        :type aggregator: str | Callable
        :raises ValueError: If ``update_every_n_epochs`` is not a positive
            integer.
        :raises ValueError: If ``aggregator`` is invalid.
        """
        # Validate argument types first; order matters for which error wins
        check_positive_integer(value=update_every_n_epochs, strict=True)
        check_consistency(aggregator, (str, Callable))

        # Turn a method name into its torch function; callables pass through
        if isinstance(aggregator, str):
            resolved = self._AGGREGATE_METHODS.get(aggregator)
            if resolved is None:
                raise ValueError(
                    f"Invalid aggregator '{aggregator}'. Available options: "
                    f"{list(self._AGGREGATE_METHODS.keys())}. Got {aggregator}."
                )
            aggregator = resolved

        self.update_every_n_epochs = update_every_n_epochs
        self.aggregator_fn = aggregator
        # No solver is attached at construction time
        self._solver = None
        # Cache of the most recently computed weights, keyed by loss name
        self._saved_weights = {}
[docs] @final def aggregate(self, losses): """ Aggregate a collection of loss terms into a single scalar. This method applies the current weighting scheme to the provided losses and returns the aggregated result. Implementations may internally update the weights (e.g., based on training state) before performing the aggregation. :param dict losses: The mapping from loss names to loss tensors. :return: The aggregated loss value. :rtype: torch.Tensor """ # Update weights when required if self.solver.trainer.current_epoch % self.update_every_n_epochs == 0: self._saved_weights = self.update_weights(losses) # Apply weights to the corresponding losses weighted_losses = torch.stack( [ (self._saved_weights[condition] * loss).reshape(-1) for condition, loss in losses.items() ] ) return self.aggregator_fn(weighted_losses)
[docs] def last_saved_weights(self): """ Get the most recently computed weights. :return: The mapping from loss names to their corresponding weights. :rtype: dict """ return self._saved_weights
@property def solver(self): """ Solver associated with this weighting strategy. Provides access to the solver instance that uses this weighting scheme, enabling strategies that depend on training state or model information. :return: The solver instance. :rtype: :class:`~pina.solver.base_solver.BaseSolver` """ return self._solver