ScalarWeighting

Module for the Scalar Weighting.

class ScalarWeighting(weights)

Bases: WeightingInterface

Weighting scheme that assigns a scalar weight to each loss term.

Initialization of the ScalarWeighting class.

Parameters:

weights (float | int | dict) – The weights to assign to the loss terms. If a single scalar is provided, it is applied to every loss term. If a dictionary is provided, its keys are condition names and its values are the corresponding weights; any condition missing from the dictionary is assigned the default weight.
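The weight-resolution rule above can be sketched as follows. This is a minimal, hypothetical illustration and not the library's code. The function `resolve_weights`, the condition names `"physics"` and `"boundary"`, and the fallback value `ASSUMED_DEFAULT = 1.0` are all assumptions. The actual default used for missing conditions is set by the library.

```python
from typing import Dict, List, Union

Number = Union[int, float]

# ASSUMPTION: fallback weight for conditions missing from a dict.
# The library defines its own default, which may differ.
ASSUMED_DEFAULT = 1.0


def resolve_weights(weights: Union[Number, Dict[str, Number]],
                    conditions: List[str]) -> Dict[str, float]:
    """Hypothetical helper: map every condition name to its scalar weight."""
    if isinstance(weights, (int, float)):
        # A single scalar is applied to every loss term.
        return {name: float(weights) for name in conditions}
    if isinstance(weights, dict):
        # Per-condition weights; missing conditions fall back to the default.
        return {name: float(weights.get(name, ASSUMED_DEFAULT))
                for name in conditions}
    raise TypeError("weights must be a float, int or dict")


print(resolve_weights(0.5, ["physics", "boundary"]))
print(resolve_weights({"boundary": 10}, ["physics", "boundary"]))
```

Passing a scalar weights every term equally. A dictionary only needs to list the conditions whose weight should differ from the default.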

aggregate(losses)

Aggregate the losses, weighting each loss term by its scalar weight.

Parameters:

losses (dict) – The dictionary of losses, keyed by condition name.

Returns:

The aggregated loss.

Return type:

torch.Tensor
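The aggregation can be sketched as a weighted sum over the per-condition losses. This is a minimal sketch under two assumptions: the combination is a plain sum of weighted terms, and missing conditions fall back to 1.0. `ScalarWeightingSketch` is a hypothetical stand-in, not the library class. Plain floats are used here so the example runs without PyTorch.

```python
from typing import Dict, Union

Number = Union[int, float]


class ScalarWeightingSketch:
    """Hypothetical stand-in for ScalarWeighting (not the library class)."""

    def __init__(self, weights: Union[Number, Dict[str, Number]],
                 default: float = 1.0):
        if isinstance(weights, (int, float)):
            # Scalar: every condition uses this weight.
            self.default = float(weights)
            self.weights: Dict[str, float] = {}
        elif isinstance(weights, dict):
            # ASSUMPTION: fallback weight for conditions not in the dict.
            self.default = default
            self.weights = {k: float(v) for k, v in weights.items()}
        else:
            raise TypeError("weights must be a float, int or dict")

    def aggregate(self, losses):
        """Return the weighted sum of the per-condition losses (assumed rule)."""
        return sum(self.weights.get(name, self.default) * loss
                   for name, loss in losses.items())


weighting = ScalarWeightingSketch({"boundary": 10.0})
total = weighting.aggregate({"physics": 0.2, "boundary": 0.03})
print(total)
```

Because `sum` starts from `0` and only uses `*` and `+`, the same code also works if the loss values are scalar `torch.Tensor` objects. In that case it returns a tensor, matching the documented return type.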