BaseWeighting
Module for the Base Weighting class.
- class BaseWeighting(update_every_n_epochs=1, aggregator='sum')
Bases: WeightingInterface
Base class for all weighting schemes, implementing common functionality.
A weighting scheme defines how scalar loss terms coming from different conditions are aggregated into a single scalar loss.
All weighting schemes must inherit from this class and implement the methods defined in WeightingInterface. This class is not meant to be instantiated directly.
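A minimal plain-Python sketch of this inheritance pattern, assuming an abstract-base-class design. The names `SketchWeighting`, `ConstantWeighting`, and `compute_weights` are hypothetical stand-ins chosen for illustration; they are not the methods that `WeightingInterface` actually declares, and plain floats stand in for loss tensors.

```python
from abc import ABC, abstractmethod


class SketchWeighting(ABC):
    """Hypothetical stand-in for a weighting base class."""

    @abstractmethod
    def compute_weights(self, losses):
        """Return a dict mapping each loss name to a scalar weight."""

    def aggregate(self, losses):
        # Weight each loss term, then sum the weighted terms into one scalar.
        weights = self.compute_weights(losses)
        return sum(weights[name] * value for name, value in losses.items())


class ConstantWeighting(SketchWeighting):
    """Hypothetical static scheme: every loss term gets weight 1.0."""

    def compute_weights(self, losses):
        return {name: 1.0 for name in losses}


# The base class cannot be instantiated because compute_weights is abstract.
try:
    SketchWeighting()
except TypeError as err:
    print("cannot instantiate:", err)

print(ConstantWeighting().aggregate({"physics": 0.5, "boundary": 0.25}))  # 0.75
```

Declaring the per-scheme logic as abstract makes Python refuse to instantiate the base class itself, which is one way to enforce the "not meant to be instantiated directly" rule.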
Initialization of the BaseWeighting class.
- Parameters:
update_every_n_epochs (int) – The number of training epochs between weight updates. If set to 1, the weights are updated at every epoch. This parameter is ignored by static weighting schemes. Default is 1.
aggregator (str | Callable) – The aggregation method. Available options include "sum", which sums the weighted losses; "mean", which averages the weighted losses; or a custom callable that takes an iterable of weighted losses and returns a single scalar. Default is "sum".
- Raises:
ValueError – If update_every_n_epochs is not a positive integer.
ValueError – If aggregator is not "sum", "mean", or a callable.
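The argument checks and aggregator options documented above can be sketched in plain Python as follows. The helper names `validate_update_every_n_epochs` and `resolve_aggregator` are hypothetical, the real constructor may perform these steps differently, and rejecting `bool` values is a design choice of this sketch rather than documented behavior.

```python
def validate_update_every_n_epochs(value):
    # bool is a subclass of int in Python, so this sketch rejects it explicitly.
    if isinstance(value, bool) or not isinstance(value, int) or value < 1:
        raise ValueError("update_every_n_epochs must be a positive integer.")
    return value


def resolve_aggregator(aggregator):
    """Map the documented options to a callable over weighted losses."""
    if aggregator == "sum":
        return lambda terms: sum(terms)
    if aggregator == "mean":
        def mean(terms):
            terms = list(terms)  # accept any iterable, including generators
            return sum(terms) / len(terms)
        return mean
    if callable(aggregator):
        # Custom callables receive an iterable of weighted losses.
        return aggregator
    raise ValueError("aggregator must be 'sum', 'mean', or a callable.")


validate_update_every_n_epochs(5)
weighted = [2.0, 4.0]
print(resolve_aggregator("sum")(weighted))   # 6.0
print(resolve_aggregator("mean")(weighted))  # 3.0
print(resolve_aggregator(max)(weighted))     # 4.0
```

Resolving the option to a callable once, at construction time, means the aggregation step never has to branch on the string again.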
- final aggregate(losses)
Aggregate a collection of loss terms into a single scalar.
This method applies the current weighting scheme to the provided losses and returns the aggregated result. Implementations may internally update the weights (e.g., based on training state) before performing the aggregation.
- Parameters:
losses (dict) – The mapping from loss names to loss tensors.
- Returns:
The aggregated loss value.
- Return type:
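The weighting-then-aggregation step can be sketched as follows, assuming the current weights are held in a dict keyed by loss name and the aggregator is a callable such as those allowed by the constructor's aggregator parameter. `aggregate_losses` is a hypothetical helper, not the library method; defaulting a missing weight to 1.0 is an assumption of this sketch, and floats stand in for tensors.

```python
def aggregate_losses(losses, weights, aggregator=sum):
    """Weight each named loss and reduce the weighted terms to one scalar.

    Hypothetical sketch: a loss with no stored weight defaults to 1.0.
    """
    weighted = (weights.get(name, 1.0) * loss for name, loss in losses.items())
    return aggregator(weighted)


losses = {"physics": 0.25, "boundary": 0.125}
weights = {"physics": 1.0, "boundary": 2.0}
print(aggregate_losses(losses, weights))  # 0.5
print(aggregate_losses({"a": 1.0, "b": 3.0}, {}, aggregator=max))  # 3.0
```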
- last_saved_weights()
Get the most recently computed weights.
- Returns:
The mapping from loss names to their corresponding weights.
- Return type:
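How saved weights might interact with update_every_n_epochs can be sketched with a small hypothetical class. It assumes weights are recomputed when the epoch index is a multiple of the update interval and reused otherwise; the class name `ScheduledWeights`, the `step` method, the epoch-based gating, and the inverse-magnitude weighting rule are illustrative assumptions, not the library's behavior.

```python
class ScheduledWeights:
    """Hypothetical tracker that refreshes weights every n epochs."""

    def __init__(self, update_every_n_epochs=1):
        self.update_every_n_epochs = update_every_n_epochs
        self._saved_weights = {}

    def step(self, epoch, losses):
        # Recompute only on scheduled epochs, or if nothing is saved yet.
        if epoch % self.update_every_n_epochs == 0 or not self._saved_weights:
            # Illustrative rule: weight each term by the inverse of its magnitude.
            self._saved_weights = {
                name: 1.0 / max(abs(loss), 1e-12) for name, loss in losses.items()
            }
        return self._saved_weights

    def last_saved_weights(self):
        # Return a copy so callers cannot mutate the stored weights.
        return dict(self._saved_weights)


tracker = ScheduledWeights(update_every_n_epochs=2)
tracker.step(0, {"physics": 0.5, "boundary": 0.25})
tracker.step(1, {"physics": 4.0, "boundary": 4.0})  # epoch 1: weights are reused
print(tracker.last_saved_weights())  # {'physics': 2.0, 'boundary': 4.0}
```

Between scheduled updates, the most recently computed weights remain the ones that are applied and reported.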
- property solver
Solver associated with this weighting strategy.
Provides access to the solver instance that uses this weighting scheme, enabling strategies that depend on training state or model information.
- Returns:
The solver instance.
- Return type:
BaseSolver
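A strategy that depends on training state might read it through the solver reference, as in the hypothetical sketch below. The `DummySolver` class, its `current_epoch` attribute, the `attach_solver` helper, and the linear warm-up rule are all assumptions made for illustration; consult BaseSolver for the attributes it actually exposes.

```python
class DummySolver:
    """Hypothetical stand-in for a solver exposing the current epoch."""

    def __init__(self):
        self.current_epoch = 0


class WarmupWeighting:
    """Hypothetical scheme that ramps one loss term in over early epochs."""

    def __init__(self, ramped_loss, warmup_epochs=10):
        self.ramped_loss = ramped_loss
        self.warmup_epochs = warmup_epochs
        self._solver = None

    def attach_solver(self, solver):
        # The weighting only stores a reference; the solver owns training state.
        self._solver = solver

    @property
    def solver(self):
        if self._solver is None:
            raise RuntimeError("No solver attached to this weighting.")
        return self._solver

    def weights(self, losses):
        # Linear ramp from 0 to 1 for the chosen loss over warmup_epochs.
        ramp = min(1.0, self.solver.current_epoch / self.warmup_epochs)
        return {name: ramp if name == self.ramped_loss else 1.0 for name in losses}


solver = DummySolver()
weighting = WarmupWeighting("boundary", warmup_epochs=4)
weighting.attach_solver(solver)
solver.current_epoch = 2
print(weighting.weights({"physics": 0.1, "boundary": 0.3}))  # {'physics': 1.0, 'boundary': 0.5}
```

Reading the epoch through the property, rather than copying it at construction, keeps the weights in step with the solver as training advances.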