Source code for pina.trainer

"""Module for the Trainer."""

import sys
import torch
import lightning
from .utils import check_consistency
from .data import PinaDataModule
from .solver import SolverInterface, PINNInterface


class Trainer(lightning.pytorch.Trainer):
    """
    PINA custom Trainer class to extend the standard Lightning functionality.

    This class enables specific features or behaviors required by the PINA
    framework. It modifies the standard
    :class:`lightning.pytorch.Trainer
    <lightning.pytorch.trainer.trainer.Trainer>` class to better support the
    training process in PINA.
    """

    def __init__(
        self,
        solver,
        batch_size=None,
        train_size=1.0,
        test_size=0.0,
        val_size=0.0,
        compile=None,
        repeat=None,
        automatic_batching=None,
        num_workers=None,
        pin_memory=None,
        shuffle=None,
        **kwargs,
    ):
        """
        Initialization of the :class:`Trainer` class.

        :param SolverInterface solver: A
            :class:`~pina.solver.solver.SolverInterface` solver used to solve
            a :class:`~pina.problem.abstract_problem.AbstractProblem`.
        :param int batch_size: The number of samples per batch to load. If
            ``None``, all samples are loaded and the data is not batched.
            Default is ``None``.
        :param float train_size: The percentage of elements to include in the
            training dataset. Default is ``1.0``.
        :param float test_size: The percentage of elements to include in the
            test dataset. Default is ``0.0``.
        :param float val_size: The percentage of elements to include in the
            validation dataset. Default is ``0.0``.
        :param bool compile: If ``True``, the model is compiled before
            training. Default is ``False``. On Windows, compilation is always
            disabled.
        :param bool repeat: Whether to repeat the dataset data in each
            condition during training. For further details, see the
            :class:`~pina.data.data_module.PinaDataModule` class.
            Default is ``False``.
        :param bool automatic_batching: If ``True``, automatic PyTorch
            batching is performed; otherwise the items are retrieved from the
            dataset all at once. For further details, see the
            :class:`~pina.data.data_module.PinaDataModule` class.
            Default is ``False``.
        :param int num_workers: The number of worker threads for data loading.
            Default is ``0`` (serial loading).
        :param bool pin_memory: Whether to use pinned memory for faster data
            transfer to the GPU. Default is ``False``.
        :param bool shuffle: Whether to shuffle the data during training.
            Default is ``True``.
        :param dict kwargs: Additional keyword arguments that specify the
            training setup. These can be selected from the
            `pytorch-lightning Trainer API
            <https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api>`_.
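
        :Example:
            A minimal usage sketch. Here ``solver`` is assumed to be an
            already-built :class:`~pina.solver.solver.SolverInterface`
            instance whose problem has all domains sampled; ``max_epochs`` is
            a standard Lightning keyword forwarded through ``kwargs``.

            >>> trainer = Trainer(
            ...     solver=solver,
            ...     batch_size=32,
            ...     train_size=0.8,
            ...     val_size=0.2,
            ...     max_epochs=1000,
            ... )
            >>> trainer.train()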
""" # check consistency for init types self._check_input_consistency( solver=solver, train_size=train_size, test_size=test_size, val_size=val_size, repeat=repeat, automatic_batching=automatic_batching, compile=compile, ) pin_memory, num_workers, shuffle, batch_size = ( self._check_consistency_and_set_defaults( pin_memory, num_workers, shuffle, batch_size ) ) # inference mode set to false when validating/testing PINNs otherwise # gradient is not tracked and optimization_cycle fails if isinstance(solver, PINNInterface): kwargs["inference_mode"] = False # Logging depends on the batch size, when batch_size is None then # log_every_n_steps should be zero if batch_size is None: kwargs["log_every_n_steps"] = 0 else: kwargs.setdefault("log_every_n_steps", 50) # default for lightning # Setting default kwargs, overriding lightning defaults kwargs.setdefault("enable_progress_bar", True) super().__init__(**kwargs) # checking compilation and automatic batching if compile is None or sys.platform == "win32": compile = False repeat = repeat if repeat is not None else False automatic_batching = ( automatic_batching if automatic_batching is not None else False ) # set attributes self.compile = compile self.solver = solver self.batch_size = batch_size self._move_to_device() self.data_module = None self._create_datamodule( train_size=train_size, test_size=test_size, val_size=val_size, batch_size=batch_size, repeat=repeat, automatic_batching=automatic_batching, pin_memory=pin_memory, num_workers=num_workers, shuffle=shuffle, ) # logging self.logging_kwargs = { "sync_dist": bool( len(self._accelerator_connector._parallel_devices) > 1 ), "on_step": bool(kwargs["log_every_n_steps"] > 0), "prog_bar": bool(kwargs["enable_progress_bar"]), "on_epoch": True, } def _move_to_device(self): """ Moves the ``unknown_parameters`` of an instance of :class:`~pina.problem.abstract_problem.AbstractProblem` to the :class:`Trainer` device. """ device = self._accelerator_connector._parallel_devices[0] # move parameters to device pb = self.solver.problem if hasattr(pb, "unknown_parameters"): for key in pb.unknown_parameters: pb.unknown_parameters[key] = torch.nn.Parameter( pb.unknown_parameters[key].data.to(device) ) def _create_datamodule( self, train_size, test_size, val_size, batch_size, repeat, automatic_batching, pin_memory, num_workers, shuffle, ): """ This method is designed to handle the creation of a data module when resampling is needed during training. Instead of manually defining and modifying the trainer's dataloaders, this method is called to automatically configure the data module. :param float train_size: The percentage of elements to include in the training dataset. :param float test_size: The percentage of elements to include in the test dataset. :param float val_size: The percentage of elements to include in the validation dataset. :param int batch_size: The number of samples per batch to load. :param bool repeat: Whether to repeat the dataset data in each condition during training. :param bool automatic_batching: Whether to perform automatic batching with PyTorch. :param bool pin_memory: Whether to use pinned memory for faster data transfer to GPU. :param int num_workers: The number of worker threads for data loading. :param bool shuffle: Whether to shuffle the data during training. :raises RuntimeError: If not all conditions are sampled. 
""" if not self.solver.problem.are_all_domains_discretised: error_message = "\n".join( [ f"""{" " * 13} ---> Domain {key} { "sampled" if key in self.solver.problem.discretised_domains else "not sampled"}""" for key in self.solver.problem.domains.keys() ] ) raise RuntimeError( "Cannot create Trainer if not all conditions " "are sampled. The Trainer got the following:\n" f"{error_message}" ) self.data_module = PinaDataModule( self.solver.problem, train_size=train_size, test_size=test_size, val_size=val_size, batch_size=batch_size, repeat=repeat, automatic_batching=automatic_batching, num_workers=num_workers, pin_memory=pin_memory, shuffle=shuffle, )
    def train(self, **kwargs):
        """
        Manage the training process of the solver.

        :param dict kwargs: Additional keyword arguments. See the
            `pytorch-lightning Trainer API
            <https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api>`_
            for details.
        """
        return super().fit(self.solver, datamodule=self.data_module, **kwargs)
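    # Note: keyword arguments passed to ``train`` are forwarded verbatim to
    # lightning's ``fit``, so standard options such as ``ckpt_path`` go
    # through unchanged. A hedged sketch (the path below is purely
    # illustrative): ``trainer.train(ckpt_path="checkpoints/last.ckpt")``.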
    def test(self, **kwargs):
        """
        Manage the test process of the solver.

        :param dict kwargs: Additional keyword arguments. See the
            `pytorch-lightning Trainer API
            <https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api>`_
            for details.
        """
        return super().test(self.solver, datamodule=self.data_module, **kwargs)
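    # Note: ``test`` evaluates on the test split built by the data module, so
    # a non-empty test set presumably requires ``test_size > 0`` at
    # construction, e.g. ``Trainer(solver, train_size=0.8, val_size=0.1,
    # test_size=0.1)`` followed by ``trainer.test()``.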
    @property
    def solver(self):
        """
        Get the solver.

        :return: The solver.
        :rtype: SolverInterface
        """
        return self._solver

    @solver.setter
    def solver(self, solver):
        """
        Set the solver.

        :param SolverInterface solver: The solver to set.
        """
        self._solver = solver

    @staticmethod
    def _check_input_consistency(
        solver,
        train_size,
        test_size,
        val_size,
        repeat,
        automatic_batching,
        compile,
    ):
        """
        Verify the consistency of the parameters for the solver configuration.

        :param SolverInterface solver: The solver.
        :param float train_size: The percentage of elements to include in the
            training dataset.
        :param float test_size: The percentage of elements to include in the
            test dataset.
        :param float val_size: The percentage of elements to include in the
            validation dataset.
        :param bool repeat: Whether to repeat the dataset data in each
            condition during training.
        :param bool automatic_batching: Whether to perform automatic batching
            with PyTorch.
        :param bool compile: If ``True``, the model is compiled before
            training.
        """
        check_consistency(solver, SolverInterface)
        check_consistency(train_size, float)
        check_consistency(test_size, float)
        check_consistency(val_size, float)
        if repeat is not None:
            check_consistency(repeat, bool)
        if automatic_batching is not None:
            check_consistency(automatic_batching, bool)
        if compile is not None:
            check_consistency(compile, bool)

    @staticmethod
    def _check_consistency_and_set_defaults(
        pin_memory, num_workers, shuffle, batch_size
    ):
        """
        Check the consistency of the input parameters and set default values
        for the parameters that were not provided.

        :param bool pin_memory: Whether to use pinned memory for faster data
            transfer to GPU.
        :param int num_workers: The number of worker threads for data loading.
        :param bool shuffle: Whether to shuffle the data during training.
        :param int batch_size: The number of samples per batch to load.
        """
        if pin_memory is not None:
            check_consistency(pin_memory, bool)
        else:
            pin_memory = False

        if num_workers is not None:
            check_consistency(num_workers, int)
        else:
            num_workers = 0

        if shuffle is not None:
            check_consistency(shuffle, bool)
        else:
            shuffle = True

        if batch_size is not None:
            check_consistency(batch_size, int)

        return pin_memory, num_workers, shuffle, batch_size
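
# A minimal end-to-end sketch, kept as a comment. The names below are
# hypothetical: ``problem`` stands for an AbstractProblem implementation and
# ``solver`` for a SolverInterface built around it; the domain-sampling call
# is assumed to follow the usual PINA workflow and is not defined in this
# module.
#
#     problem.discretise_domain(100)
#     trainer = Trainer(
#         solver,
#         batch_size=32,
#         train_size=0.8,
#         val_size=0.1,
#         test_size=0.1,
#         max_epochs=500,
#     )
#     trainer.train()
#     trainer.test()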