"""Trainer utilities built on top of the PyTorch Lightning Trainer class."""
import warnings
import torch
import lightning
from pina._src.solver.mixin.physics_informed_mixin import PhysicsInformedMixin
from pina._src.solver.base_solver import BaseSolver
from pina._src.data.data_module import DataModule
from pina._src.core.utils import (
check_consistency,
custom_warning_format,
check_positive_integer,
)
# Module-level warning configuration (applied once, at import time):
# - every warning message is rendered through PINA's custom formatter;
# - the "always" action makes each UserWarning visible on every occurrence
#   rather than only the first time it is triggered from a given location.
warnings.formatwarning = custom_warning_format
warnings.filterwarnings("always", category=UserWarning)
class Trainer(lightning.pytorch.Trainer):
    """
    PINA-specific extension of :class:`lightning.pytorch.Trainer`.

    The trainer configures solver execution, dataset splitting, batching,
    logging, device placement for unknown parameters, and gradient tracking
    requirements for physics-informed solvers.
    """

    # Available batching modes
    _AVAIL_BATCHING_MODES = {
        "common_batch_size",
        "proportional",
        "separate_conditions",
    }

    def __init__(
        self,
        solver,
        batch_size=None,
        train_size=1.0,
        test_size=0.0,
        val_size=0.0,
        batching_mode="common_batch_size",
        automatic_batching=False,
        num_workers=0,
        pin_memory=False,
        shuffle=True,
        **kwargs,
    ):
        """
        Initialization of the :class:`Trainer` class.

        :param BaseSolver solver: The solver used to train, validate, and test
            the associated problem.
        :param int batch_size: The number of samples per batch. If ``None``,
            the entire dataset is processed as a single batch.
            Default is ``None``.
        :param float train_size: The fraction of samples assigned to the
            training split. Must belong to the interval ``[0, 1]``.
            Default is ``1.0``.
        :param float test_size: The fraction of samples assigned to the test
            split. Must belong to the interval ``[0, 1]``. Default is ``0.0``.
        :param float val_size: The fraction of samples assigned to the
            validation split. Must belong to the interval ``[0, 1]``.
            Default is ``0.0``.
        :param str batching_mode: The strategy used to aggregate batches across
            dataloaders. Available options are ``"common_batch_size"`` for
            uniform batch sizes across conditions, ``"proportional"`` for batch
            sizes proportional to dataset sizes, and ``"separate_conditions"``
            for iterating through each condition separately.
            Default is ``"common_batch_size"``.
        :param bool automatic_batching: Whether PyTorch automatic batching
            should be enabled. If ``True``, dataset elements are retrieved
            individually and collated into batches by the dataloader.
            If ``False``, entire subsets are retrieved directly from the
            condition object. Default is ``False``.
        :param int num_workers: The number of worker processes used by
            dataloaders. Default is ``0`` for sequential loading.
        :param bool pin_memory: Whether pinned memory should be enabled during
            data loading. Default is ``False``.
        :param bool shuffle: Whether condition samples should be shuffled
            before splitting. Default is ``True``.
        :param dict kwargs: Additional keyword arguments forwarded to the
            Lightning trainer. Some entries are adjusted before forwarding:
            ``compile`` is deprecated and dropped (with a
            :class:`DeprecationWarning`); ``inference_mode`` is forced to
            ``False`` when ``solver`` is a :class:`PhysicsInformedMixin`;
            ``log_every_n_steps`` is forced to ``0`` when ``batch_size`` is
            ``None`` and defaults to ``50`` otherwise;
            ``enable_progress_bar`` defaults to ``True``.
        :raises ValueError: If ``solver`` is not a PINA solver.
        :raises ValueError: If ``train_size``, ``val_size``, or ``test_size``
            is not a float in the interval ``[0, 1]``.
        :raises ValueError: If the sum of ``train_size``, ``val_size``, and
            ``test_size`` is not equal to 1.
        :raises ValueError: If ``automatic_batching``, ``pin_memory``, or
            ``shuffle`` is not a boolean.
        :raises AssertionError: If ``num_workers`` is a negative integer.
        :raises ValueError: If ``batch_size``, when provided, is not a positive
            integer.
        :raises ValueError: If ``batching_mode`` is not one of the available
            options.
        :raises RuntimeError: If any domain in the problem has not been
            discretised.

        .. note::
            An incompatible ``batching_mode`` does not raise. A
            :class:`UserWarning` is emitted and ``batching_mode`` falls back
            to ``"common_batch_size"`` when ``batch_size`` is ``None`` with a
            different mode, or when ``"proportional"`` is requested with a
            ``batch_size`` not larger than the number of conditions.
        """
        # Backward compatibility: compile has been removed
        if "compile" in kwargs:
            warnings.warn(
                "`compile` is deprecated and no longer used. Compilation is "
                "now disabled and the argument will be ignored.",
                DeprecationWarning,
                stacklevel=2,
            )
            kwargs.pop("compile")

        # Check the types of the input arguments
        check_consistency(solver, BaseSolver)
        check_consistency(train_size, float)
        check_consistency(test_size, float)
        check_consistency(val_size, float)
        check_consistency(automatic_batching, bool)
        check_consistency(pin_memory, bool)
        check_consistency(shuffle, bool)
        check_positive_integer(num_workers, strict=False)
        if batch_size is not None:
            check_positive_integer(batch_size, strict=True)

        # Enforce the documented [0, 1] range of every split fraction: the
        # sum check below alone would accept e.g. (1.5, -0.5, 0.0). The
        # chained comparison is also False for NaN, which is rejected too.
        for name, value in (
            ("train_size", train_size),
            ("val_size", val_size),
            ("test_size", test_size),
        ):
            if not 0.0 <= value <= 1.0:
                raise ValueError(
                    f"`{name}` must belong to the interval [0, 1], "
                    f"got {value}."
                )

        # Check that train_size, test_size and val_size sum to 1 (up to
        # floating point tolerance)
        total = train_size + val_size + test_size
        if not torch.isclose(torch.tensor(total), torch.tensor(1.0)):
            raise ValueError(
                "`train_size`, `val_size`, and `test_size` must sum to 1."
            )

        # Check that the batching mode is one of the supported options
        if batching_mode not in self._AVAIL_BATCHING_MODES:
            raise ValueError(
                f"Invalid batching mode '{batching_mode}'. "
                f"Expected one of: {sorted(self._AVAIL_BATCHING_MODES)}."
            )

        # Set inference mode to false when using physics-informed mixin
        if isinstance(solver, PhysicsInformedMixin):
            kwargs["inference_mode"] = False

        # Set log_every_n_steps to 0 if batch_size is None, otherwise keep
        # the user-provided value (default 50)
        kwargs["log_every_n_steps"] = (
            0 if batch_size is None else kwargs.get("log_every_n_steps", 50)
        )

        # Set default value for enable_progress_bar to True if not provided
        kwargs.setdefault("enable_progress_bar", True)

        # Initialize the parent class with the provided keyword arguments
        super().__init__(**kwargs)

        # Warn and fall back if a batching mode is given without a batch size
        if batch_size is None and batching_mode != "common_batch_size":
            warnings.warn(
                f"Batching mode '{batching_mode}' is ignored when the batch "
                "size is None. Setting batching_mode to 'common_batch_size'.",
                UserWarning,
            )
            batching_mode = "common_batch_size"

        # Warn and fall back if the batch size is too small for the
        # proportional mode
        if (
            batch_size is not None
            and batching_mode == "proportional"
            and batch_size <= len(solver.problem.conditions)
        ):
            warnings.warn(
                "Batching mode 'proportional' requires the batch size to be "
                "larger than the number of conditions. Setting batching_mode "
                "to 'common_batch_size'.",
                UserWarning,
            )
            batching_mode = "common_batch_size"

        # Initialize the class attributes
        self.solver = solver
        self.batch_size = batch_size

        # Move the unknown parameters to the correct device
        self._move_to_device()

        # Check that all domains are discretised, otherwise raise an error
        if not self.solver.problem.are_all_domains_discretised:
            # Get the list of sampled domains from the problem
            sampled_domains = self.solver.problem.discretised_domains

            # Create a status message for each domain
            status = "\n".join(
                f"  - Domain '{name}': "
                f"{'sampled' if name in sampled_domains else 'not sampled'}"
                for name in self.solver.problem.domains
            )

            # Raise an error with the status of each domain
            raise RuntimeError(
                "Cannot create the Trainer because some domains have not been "
                f"sampled. Domain status:\n{status}"
            )

        # Create the data module
        self.data_module = DataModule(
            problem=self.solver.problem,
            train_size=train_size,
            test_size=test_size,
            val_size=val_size,
            batch_size=self.batch_size,
            batching_mode=batching_mode,
            automatic_batching=automatic_batching,
            num_workers=num_workers,
            pin_memory=pin_memory,
            shuffle=shuffle,
        )

        # Set logging kwargs: synchronise across devices only when more than
        # one device is in use, and log per step only if step logging is on
        self.logging_kwargs = {
            "sync_dist": bool(
                len(self._accelerator_connector._parallel_devices) > 1
            ),
            "on_step": bool(kwargs["log_every_n_steps"] > 0),
            "prog_bar": bool(kwargs["enable_progress_bar"]),
            "on_epoch": True,
        }

    def _move_to_device(self):
        """
        Move problem unknown parameters to the trainer device.

        If the associated problem defines ``unknown_parameters``, each
        parameter is moved to the first device configured by the Lightning
        accelerator connector and re-wrapped in a new
        :class:`torch.nn.Parameter`.
        """
        # Get the device from the accelerator connector
        device = self._accelerator_connector._parallel_devices[0]

        # Get the problem instance from the solver
        problem = self.solver.problem

        # Move the unknown parameters to the correct device if they exist
        if hasattr(problem, "unknown_parameters"):
            for key in problem.unknown_parameters:
                problem.unknown_parameters[key] = torch.nn.Parameter(
                    problem.unknown_parameters[key].data.to(device)
                )

    def train(self, **kwargs):
        """
        Fit the solver using the trainer data module.

        :param dict kwargs: Additional keyword arguments forwarded to the
            Lightning trainer ``fit`` method.
        :return: Result returned by Lightning's ``fit`` method.
        :rtype: Any
        """
        return super().fit(self.solver, datamodule=self.data_module, **kwargs)

    def test(self, **kwargs):
        """
        Test the solver using the trainer data module.

        :param dict kwargs: Additional keyword arguments forwarded to the
            Lightning trainer ``test`` method.
        :return: Result returned by Lightning's ``test`` method.
        :rtype: Any
        """
        return super().test(self.solver, datamodule=self.data_module, **kwargs)

    @property
    def solver(self):
        """
        Return the solver attached to the trainer.

        :return: The solver used by the trainer.
        :rtype: BaseSolver
        """
        return self._solver

    @solver.setter
    def solver(self, solver):
        """
        Set the solver attached to the trainer.

        :param BaseSolver solver: The solver instance to attach.
        """
        self._solver = solver