Source code for pina._src.solver.mixin.condition_aggregator_mixin
"""Module for the condition aggregator mixin class."""
import torch
[docs]
class ConditionAggregatorMixin:
    """
    Mixin that computes, logs, and aggregates per-condition scalar losses.

    For every condition in a batch, the scalar loss is computed by the host
    solver and logged under ``"<condition_name>_loss"``. All condition losses
    are then combined into a single scalar loss through the configured
    weighting scheme.

    Designed to be used in combination with any solver inheriting from
    :class:`~pina._src.solver.base_solver.BaseSolver`. The host class is
    expected to provide ``problem`` (with a ``conditions`` mapping),
    ``weighting`` (with an ``aggregate`` method), ``trainer`` (with a
    ``logging_kwargs`` mapping), ``log``, ``get_batch_size``,
    ``_compute_condition_loss`` and ``_clamp_params``.
    """

    def batch_evaluation_step(self, batch, batch_idx):
        """
        Evaluate and aggregate the losses for all conditions in a batch.

        For each condition in the batch, this method computes the
        corresponding scalar loss, logs it using the condition name, and
        combines all condition losses into a single scalar loss through the
        configured weighting scheme.

        The ``batch_size`` reported to the logger is the size of the whole
        batch (as returned by ``get_batch_size``), not of a single condition.

        :param list[tuple[str, dict]] batch: A batch of data. Each element is a
            tuple containing a condition name and a dictionary of points.
        :param int batch_idx: The index of the current batch.
        :return: The aggregated scalar loss for the batch.
        :rtype: torch.Tensor
        """
        # Compute the scalar loss of every condition in the batch
        condition_losses = {}
        for condition_name, data in batch:
            condition_losses[condition_name] = self._compute_condition_loss(
                condition=self.problem.conditions[condition_name],
                data=dict(data),
                batch_idx=batch_idx,
            )

        # Clamp parameters - null operation if problem is not InverseProblem
        self._clamp_params()

        # The batch size and the logging options are identical for every
        # condition: evaluate them once rather than once per logged loss.
        batch_size = self.get_batch_size(batch)
        logging_kwargs = self.trainer.logging_kwargs

        # Log the individual condition losses
        for name, value in condition_losses.items():
            self.log(
                name=f"{name}_loss",
                value=value.item(),
                batch_size=batch_size,
                **logging_kwargs,
            )

        # Aggregate into the total loss using the weighting scheme, and view
        # the result as a plain torch.Tensor (dropping any tensor subclass).
        aggregated_loss = self.weighting.aggregate(condition_losses)
        return aggregated_loss.as_subclass(torch.Tensor)