Source code for pina._src.data.manager.data_manager
"""Module for the Data Manager factory class."""
import torch
from pina._src.core.label_tensor import LabelTensor
from pina._src.equation.base_equation import BaseEquation
from pina._src.data.manager.graph_data_manager import _GraphDataManager
from pina._src.data.manager.tensor_data_manager import _TensorDataManager
[docs]
class _DataManager:
    """
    Entry point that builds the concrete data manager matching its inputs.

    Calling ``_DataManager(**kwargs)`` does not allocate a ``_DataManager``
    itself. The call is forwarded to
    :class:`~pina.data.manager.tensor_data_manager._TensorDataManager` or to
    :class:`~pina.data.manager.graph_data_manager._GraphDataManager`, and the
    choice is driven by the types of the keyword argument values.
    """

    def __new__(cls, **kwargs):
        """
        Select and build the data manager implementation for ``kwargs``.

        When every value in ``kwargs`` is a :class:`torch.Tensor`, a
        :class:`~pina.label_tensor.LabelTensor`, or a
        :class:`~pina.equation.base_equation.BaseEquation`, the call is
        forwarded to
        :class:`~pina.data.manager.tensor_data_manager._TensorDataManager`.
        As soon as one value is of any other type, the call is forwarded to
        :class:`~pina.data.manager.graph_data_manager._GraphDataManager`.
        With no keyword arguments at all, the tensor manager is selected.

        If ``cls`` is a subclass rather than ``_DataManager`` itself, no
        dispatch takes place and a bare instance of ``cls`` is allocated.

        :param dict kwargs: The keyword arguments for the data manager.
        :return: A concrete data manager instance.
        :rtype: _TensorDataManager | _GraphDataManager
        """
        if cls is _DataManager:
            # Types that keep the data on the tensor-based implementation.
            tensor_like = (torch.Tensor, LabelTensor, BaseEquation)

            # The first value of any other type settles the decision.
            for value in kwargs.values():
                if not isinstance(value, tensor_like):
                    return _GraphDataManager(**kwargs)

            return _TensorDataManager(**kwargs)

        # A subclass is being instantiated: allocate it directly so the
        # factory dispatch above is not entered again.
        return super().__new__(cls)