Condition Aggregator Mixin

Module for the condition aggregator mixin class.

class ConditionAggregatorMixin

Bases: object

Mixin that logs the scalar loss of each condition, weights these losses according to the configured weighting scheme, and aggregates them into the total loss.

Designed to be used in combination with any solver inheriting from BaseSolver.
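To make the weighting step concrete, the sketch below shows one possible weighting scheme: a fixed-weight sum over per-condition losses. The class name `ScalarWeightingSketch`, its `aggregate` method, and the default weight of 1.0 for unlisted conditions are hypothetical choices for this illustration, not the library's API. Plain floats stand in for scalar tensors.

```python
from typing import Dict, Optional


class ScalarWeightingSketch:
    """Hypothetical weighting scheme: one fixed multiplicative weight per condition."""

    def __init__(self, weights: Optional[Dict[str, float]] = None, default: float = 1.0):
        # Weights keyed by condition name; conditions not listed get `default`.
        self.weights = dict(weights or {})
        self.default = default

    def aggregate(self, losses: Dict[str, float]) -> float:
        # Multiply each condition's loss by its weight and sum the results.
        return sum(
            self.weights.get(name, self.default) * loss
            for name, loss in losses.items()
        )


# Down-weight the (hypothetical) "boundary" condition; "physics" keeps weight 1.0.
scheme = ScalarWeightingSketch(weights={"boundary": 0.5})
total = scheme.aggregate({"boundary": 2.0, "physics": 4.0})
print(total)
```

Keeping the weights in a separate object means the same per-condition losses can be combined under different schemes without touching the code that computes them.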

batch_evaluation_step(batch, batch_idx)

Evaluate and aggregate the losses for all conditions in a batch.

For each condition in the batch, this method computes the corresponding scalar loss, logs it using the condition name, and combines all condition losses into a single scalar loss through the configured weighting scheme.
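The evaluate, log, and aggregate loop described above can be sketched in plain Python as follows. The function `batch_evaluation_step_sketch` and its callables `compute_loss`, `log`, and `aggregate` are hypothetical, introduced only for illustration. In an actual solver they would presumably correspond to the solver's own loss evaluation, its logger, and the configured weighting scheme. Floats stand in for scalar `torch.Tensor` losses, and `batch_idx` is accepted only to mirror the documented signature.

```python
from typing import Callable, Dict, List, Tuple

Batch = List[Tuple[str, dict]]


def batch_evaluation_step_sketch(
    batch: Batch,
    batch_idx: int,  # unused here; kept to mirror the documented signature
    compute_loss: Callable[[str, dict], float],
    log: Callable[[str, float], None],
    aggregate: Callable[[Dict[str, float]], float],
) -> float:
    """Hypothetical re-creation of the per-condition evaluate/log/aggregate loop."""
    losses: Dict[str, float] = {}
    for condition_name, points in batch:
        # 1. Compute the scalar loss for this condition.
        loss = compute_loss(condition_name, points)
        # 2. Log it under the condition's name.
        log(condition_name, loss)
        losses[condition_name] = loss
    # 3. Combine all condition losses into a single scalar.
    return aggregate(losses)


# Usage with a hypothetical "residual" key and a mean-squared-residual loss.
def mean_squared_residual(name: str, points: dict) -> float:
    r = points["residual"]
    return sum(x * x for x in r) / len(r)


logged: Dict[str, float] = {}
batch = [("boundary", {"residual": [1.0, -1.0]}), ("physics", {"residual": [2.0, 2.0]})]
total = batch_evaluation_step_sketch(
    batch, 0, mean_squared_residual, logged.__setitem__, lambda l: sum(l.values())
)
print(logged, total)
```

Here `aggregate` is a plain unweighted sum; any callable that maps the dictionary of per-condition losses to one scalar could be substituted.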

Parameters:
  • batch (list[tuple[str, dict]]) – A batch of data. Each element is a tuple containing a condition name and a dictionary of points.

  • batch_idx (int) – The index of the current batch.
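For concreteness, a batch matching the documented type `list[tuple[str, dict]]` might look like the example below. The condition names (`"boundary"`, `"physics"`) and the dictionary keys (`"input"`, `"target"`) are hypothetical placeholders; the keys actually present depend on how each condition is defined. Nested lists stand in for tensors.

```python
from typing import Dict, List, Tuple

# Each element pairs a condition name with that condition's dictionary of points.
batch: List[Tuple[str, Dict[str, list]]] = [
    ("boundary", {"input": [[0.0], [1.0]], "target": [[0.0], [0.0]]}),
    ("physics", {"input": [[0.25], [0.5], [0.75]]}),
]
batch_idx = 0

for condition_name, points in batch:
    # Unpack each element into its name and its points dictionary.
    print(batch_idx, condition_name, sorted(points))
```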

Returns:

The aggregated scalar loss for the batch.

Return type:

torch.Tensor