ScalarWeighting

Module for the scalar weighting strategy.

class ScalarWeighting(weights)

Bases: BaseWeighting

Weighting strategy based on fixed scalar coefficients.

This scheme assigns a constant multiplicative weight to each loss term, without adapting over time. The same weight can be applied to all terms, or distinct weights can be specified for individual conditions.
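Written as a formula, and assuming the weighted loss terms are combined by summation (this page does not state the aggregation rule), a fixed weight $w_i$ multiplies each loss term $\mathcal{L}_i$ and stays the same at every training step $t$:

```latex
\mathcal{L}^{(t)} = \sum_{i} w_i \, \mathcal{L}_i^{(t)},
\qquad
w_i^{(t)} = w_i \quad \text{for all } t.
```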

Initialization of the ScalarWeighting class.

Parameters:

weights (float | int | dict) – The scalar weights associated with the loss terms, given either as a single numeric value or as a dictionary. A single numeric value is applied to every loss term. A dictionary maps loss identifiers (e.g., condition names) to their weights; loss terms not listed in the dictionary receive a default weight of 1.

Raises:

ValueError – If the input weights are neither numeric nor a dictionary.
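The rules for resolving `weights`, including the `ValueError` case, can be sketched as a small standalone function. The function name `resolve_weights` and the condition names `physics` and `boundary` are hypothetical, chosen only for illustration; this mirrors the documented behavior and is not the library's actual implementation.

```python
def resolve_weights(weights, loss_names):
    """Hypothetical helper: map each loss name to its scalar weight.

    Mirrors the documented rules: a single number applies to every term,
    a dict overrides individual terms, and unlisted terms default to 1.
    """
    if isinstance(weights, (int, float)):
        # One weight shared by every loss term.
        return {name: weights for name in loss_names}
    if isinstance(weights, dict):
        # Per-term weights; terms missing from the dict fall back to 1.
        return {name: weights.get(name, 1) for name in loss_names}
    # Anything else (e.g. a string or a list) is rejected.
    raise ValueError("weights must be a float, an int, or a dict")


# Usage with hypothetical condition names.
print(resolve_weights(0.5, ["physics", "boundary"]))
# {'physics': 0.5, 'boundary': 0.5}
print(resolve_weights({"boundary": 10.0}, ["physics", "boundary"]))
# {'physics': 1, 'boundary': 10.0}
```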

update_weights(losses)

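Update the weights based on the current losses.

In the general weighting interface, this method defines how a weighting strategy adapts over time: it computes and stores the weights used during aggregation. Because ScalarWeighting uses fixed coefficients, the weights it returns do not depend on the loss values and remain the same across calls.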

Parameters:

losses (dict) – The mapping from loss names to loss tensors.

Returns:

The updated weights.

Return type:

dict
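A minimal sketch of how a fixed-coefficient `update_weights` might behave, assuming plain floats stand in for the loss tensors and that aggregation is a weighted sum (the aggregation rule is an assumption, not stated on this page). The class `FixedScalarWeightingSketch` and its `aggregate` method are hypothetical and do not reproduce the library's code.

```python
class FixedScalarWeightingSketch:
    """Hypothetical stand-in for ScalarWeighting (not the library's code)."""

    def __init__(self, weights):
        # Same accepted types as the documented `weights` parameter.
        if not isinstance(weights, (int, float, dict)):
            raise ValueError("weights must be a float, an int, or a dict")
        self.weights = weights

    def _weight_for(self, name):
        # A dict falls back to 1 for unlisted terms; a scalar applies to all.
        if isinstance(self.weights, dict):
            return self.weights.get(name, 1)
        return self.weights

    def update_weights(self, losses):
        # Fixed coefficients: the result depends only on the loss names,
        # never on the loss values, so repeated calls return the same dict.
        return {name: self._weight_for(name) for name in losses}

    def aggregate(self, losses):
        # Assumed aggregation: weighted sum of the individual losses.
        weights = self.update_weights(losses)
        return sum(weights[name] * loss for name, loss in losses.items())


weighting = FixedScalarWeightingSketch({"boundary": 4.0})
losses = {"physics": 0.5, "boundary": 0.25}
print(weighting.update_weights(losses))  # {'physics': 1, 'boundary': 4.0}
print(weighting.aggregate(losses))  # 1.5
```

In this sketch the returned dictionary depends only on the loss names, so calling `update_weights` with different loss values yields identical weights, matching the non-adaptive behavior of the scheme.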