Source code for pina._src.condition.base_condition

"""Module for the Base Condition class."""

from functools import partial
import torch
from torch_geometric.data import Batch
from torch.utils.data import DataLoader
from pina._src.condition.condition_interface import ConditionInterface
from pina._src.core.graph import LabelBatch
from pina._src.core.label_tensor import LabelTensor
from pina._src.core.utils import check_consistency
from pina._src.data.single_batch_data_loader import _SingleBatchDataLoader
from pina._src.problem.problem_interface import ProblemInterface


class BaseCondition(ConditionInterface):
    """
    Base class for all conditions, implementing common functionality.

    All specific condition types should inherit from this class and implement
    the abstract methods of
    :class:`~pina.condition.condition_interface.ConditionInterface`.

    This class is not meant to be instantiated directly.
    """

    # Available collate functions for automatic batching
    collate_fn_dict = {
        "tensor": torch.stack,
        "label_tensor": LabelTensor.stack,
        "graph": LabelBatch.from_data_list,
        "data": Batch.from_data_list,
    }

    def __init__(self, **kwargs):
        """
        Initialization of the :class:`BaseCondition` class.

        :param dict kwargs: The keyword arguments representing the data to be
            stored in the condition. They are forwarded unchanged to
            :meth:`store_data`, whose result is kept in ``self.data``.
        """
        super().__init__()
        self.data = self.store_data(**kwargs)
        # Set to True by ``switch_dataloader_fn`` once a custom factory is
        # installed in place of the default ``create_dataloader``.
        self.has_custom_dataloader_fn = False

    def __len__(self):
        """
        Return the number of data points in the condition.

        :return: The number of data points, i.e. ``len(self.data)``.
        :rtype: int
        """
        return len(self.data)

    def __getitem__(self, idx):
        """
        Return the data point at the specified index.

        :param int idx: The index of the data point to retrieve.
        :return: The data point at the specified index, i.e.
            ``self.data[idx]``.
        :rtype: Any
        """
        return self.data[idx]

    def create_dataloader(
        self, dataset, batch_size, automatic_batching, **kwargs
    ):
        """
        Create the DataLoader for the condition.

        When ``batch_size`` equals ``len(dataset)``, a
        :class:`_SingleBatchDataLoader` wrapping the whole dataset is returned
        instead of a :class:`torch.utils.data.DataLoader`. In that case the
        additional ``kwargs`` are NOT forwarded and are ignored.

        :param _ConditionSubset dataset: The dataset for the DataLoader.
        :param int batch_size: The batch size for the DataLoader.
        :param bool automatic_batching: Whether to use automatic batching. If
            ``True``, :meth:`automatic_batching_collate_fn` is used as collate
            function; otherwise :meth:`collate_fn` bound to this condition.
        :param dict kwargs: Additional keyword arguments for the DataLoader.
        :return: The DataLoader for the condition.
        :rtype: torch.utils.data.DataLoader | _SingleBatchDataLoader
        """
        # If batching the entire dataset, return a _SingleBatchDataLoader
        if batch_size == len(dataset):
            return _SingleBatchDataLoader(dataset)

        # Otherwise, return a regular DataLoader with the appropriate collate
        if automatic_batching:
            collate = self.automatic_batching_collate_fn
        else:
            # ``collate_fn`` needs the condition to index its stored data.
            collate = partial(self.collate_fn, condition=self)

        return DataLoader(
            dataset=dataset,
            collate_fn=collate,
            batch_size=batch_size,
            **kwargs,
        )

    def switch_dataloader_fn(self, create_dataloader_fn):
        """
        Switch the dataloader function for the condition.

        The new function is stored as a plain instance attribute that shadows
        the :meth:`create_dataloader` method. Because it is set on the instance
        rather than the class, Python does not bind it. It is therefore called
        with the same arguments as :meth:`create_dataloader`, but without
        ``self``.

        :param Callable create_dataloader_fn: The new dataloader function to
            use for the condition.
        :return: The new dataloader function for the condition.
        :rtype: Callable
        """
        self.has_custom_dataloader_fn = True
        self.create_dataloader = create_dataloader_fn
        # Return the installed function, as documented (previously None).
        return create_dataloader_fn

    @classmethod
    def automatic_batching_collate_fn(cls, batch):
        """
        Collate function for automatic batching to be used in the DataLoader.

        The batch is collated by the ``create_batch`` method of the class of
        its first item. This presumably assumes all items in ``batch`` share
        the same class; TODO confirm.

        :param list batch: A list of items from the dataset.
        :return: A collated batch, or an empty dictionary if ``batch`` is
            empty.
        :rtype: dict
        """
        # If the batch is empty, return an empty dictionary
        if not batch:
            return {}
        # Otherwise, delegate collation to the items' own class
        instance_class = batch[0].__class__
        return instance_class.create_batch(batch)

    @staticmethod
    def collate_fn(batch, condition):
        """
        Collate function for custom batching to be used in the DataLoader.

        The stored data is indexed directly with ``batch`` and the result is
        converted with its ``to_batch`` method. This presumably expects
        ``batch`` to be a list of indices into ``condition.data``; TODO
        confirm.

        :param list batch: A list of items from the dataset.
        :param BaseCondition condition: The condition instance.
        :return: A collated batch.
        :rtype: dict
        """
        return condition.data[batch].to_batch()

    @property
    def problem(self):
        """
        The problem associated with this condition.

        :return: The problem associated with this condition.
        :rtype: BaseProblem
        """
        return self._problem

    @problem.setter
    def problem(self, value):
        """
        Set the problem associated with this condition.

        :param BaseProblem value: The problem to associate with this condition.
        :raises ValueError: If the problem is not an instance of BaseProblem.
        """
        # Validate the type before storing it.
        check_consistency(value, ProblemInterface)
        self._problem = value