# Source code for pina._src.condition.condition_interface
"""Module for the Condition interface."""
from abc import ABCMeta, abstractmethod
class ConditionInterface(metaclass=ABCMeta):
    """
    Abstract interface for all conditions.

    Refer to :class:`pina.condition.condition.Condition` for a thorough
    description of all available conditions and how to instantiate them.
    """

    @abstractmethod
    def __len__(self):
        """
        Return the number of data points in the condition.

        :return: The number of data points.
        :rtype: int
        """

    @abstractmethod
    def __getitem__(self, idx):
        """
        Return the data point at the specified index.

        :param int idx: The index of the data point to retrieve.
        :return: The data point at the specified index.
        :rtype: Any
        """

    @abstractmethod
    def store_data(self, **kwargs):
        """
        Store the data for the condition in a suitable format.

        :param dict kwargs: The keyword arguments containing the data to be
            stored.
        :return: The stored data in a suitable format.
        :rtype: Any
        """

    @abstractmethod
    def create_dataloader(
        self, dataset, batch_size, automatic_batching, **kwargs
    ):
        """
        Create the DataLoader for the condition.

        :param _ConditionSubset dataset: The dataset for the DataLoader.
        :param int batch_size: The batch size for the DataLoader.
        :param bool automatic_batching: Whether to use automatic batching.
        :param dict kwargs: Additional keyword arguments for the DataLoader.
        :return: The DataLoader for the condition.
        :rtype: torch.utils.data.DataLoader
        """

    @abstractmethod
    def evaluate(self, batch, solver):
        """
        Evaluate the residual of the condition on the given batch using the
        solver.

        This method computes the non-aggregated, element-wise residual of the
        condition. A forward pass of the solver's model is performed on the
        input samples, and the condition residual is evaluated accordingly.
        The returned tensor is not reduced, preserving the per-sample residual
        values.

        :param dict batch: The batch containing the data required by the
            condition evaluation.
        :param BaseSolver solver: The solver used to perform the forward pass
            and compute the residual. The solver provides access to the model
            and its parameters, which may be necessary for evaluating the
            condition residual.
        :return: The non-aggregated residual tensor.
        :rtype: torch.Tensor | LabelTensor
        """

    @abstractmethod
    def switch_dataloader_fn(self, create_dataloader_fn):
        """
        Switch the dataloader function for the condition.

        :param Callable create_dataloader_fn: The new dataloader function to
            use for the condition.
        :return: The new dataloader function for the condition.
        :rtype: Callable
        """

    # ``classmethod`` must be the outermost decorator so that the
    # abstract flag set by ``abstractmethod`` is visible to ``ABCMeta``.
    @classmethod
    @abstractmethod
    def automatic_batching_collate_fn(cls, batch):
        """
        Collate function for automatic batching to be used in the DataLoader.

        :param list batch: A list of items from the dataset.
        :return: A collated batch.
        :rtype: dict
        """

    # Same decorator ordering rule as above: ``staticmethod`` outermost.
    @staticmethod
    @abstractmethod
    def collate_fn(batch, condition):
        """
        Collate function for custom batching to be used in the DataLoader.

        :param list batch: A list of items from the dataset.
        :param BaseCondition condition: The condition instance.
        :return: A collated batch.
        :rtype: dict
        """

    @property
    @abstractmethod
    def problem(self):
        """
        The problem associated with this condition.

        :return: The problem associated with this condition.
        :rtype: BaseProblem
        """

    @problem.setter
    @abstractmethod
    def problem(self, value):
        """
        Set the problem associated with this condition.

        :param BaseProblem value: The problem to associate with this
            condition.
        """