Trainer

Module for the Trainer.

class Trainer(solver, batch_size=None, train_size=1.0, test_size=0.0, val_size=0.0, compile=None, repeat=None, automatic_batching=None, num_workers=None, pin_memory=None, shuffle=None, **kwargs)

Bases: lightning.pytorch.Trainer

Custom PINA Trainer class that extends the standard Lightning functionality.

This class adds the features and behaviors required by the PINA framework, extending lightning.pytorch.Trainer to better support the training process in PINA.

Initialization of the Trainer class.

Parameters:
  • solver (SolverInterface) – A SolverInterface solver used to solve an AbstractProblem.

  • batch_size (int) – The number of samples per batch to load. If None, all samples are loaded in a single batch, i.e. the data is not batched. Default is None.

  • train_size (float) – The fraction of elements (between 0 and 1) to include in the training dataset. Default is 1.0.

  • test_size (float) – The fraction of elements (between 0 and 1) to include in the test dataset. Default is 0.0.

  • val_size (float) – The fraction of elements (between 0 and 1) to include in the validation dataset. Default is 0.0.

  • compile (bool) – If True, the model is compiled before training. Default is False. Compilation is always disabled on Windows.

  • repeat (bool) – Whether to repeat the data of each condition during training. For further details, see the PinaDataModule class. Default is False.

  • automatic_batching (bool) – If True, automatic PyTorch batching is performed; otherwise, items are retrieved from the dataset all at once. For further details, see the PinaDataModule class. Default is False.

  • num_workers (int) – The number of worker processes used for data loading. Default is 0 (data is loaded serially in the main process).

  • pin_memory (bool) – Whether to use pinned memory for faster data transfer to GPU. Default is False.

  • shuffle (bool) – Whether to shuffle the data during training. Default is True.

  • kwargs (dict) – Additional keyword arguments that specify the training setup, chosen from the PyTorch Lightning Trainer API (for example max_epochs or accelerator).
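To make the split and batching parameters concrete, here is a small pure-Python sketch of how split fractions and batch_size can translate into sample and batch counts. It does not use PINA: split_counts and batches_per_epoch are hypothetical helpers, and two behaviors in it are assumptions rather than documented PINA behavior: that the three fractions must sum to 1, and that any rounding remainder is assigned to the training split.

```python
import math


def split_counts(n_samples, train_size=1.0, test_size=0.0, val_size=0.0):
    """Hypothetical helper: convert split fractions into per-split sample counts.

    Assumptions (not confirmed PINA behavior): the fractions must sum to 1,
    and the rounding remainder is assigned to the training split.
    """
    total = train_size + test_size + val_size
    if abs(total - 1.0) > 1e-6:
        raise ValueError(f"split fractions must sum to 1, got {total}")
    n_test = int(n_samples * test_size)  # truncate toward zero
    n_val = int(n_samples * val_size)
    n_train = n_samples - n_test - n_val  # remainder goes to training
    return n_train, n_test, n_val


def batches_per_epoch(n_samples, batch_size=None):
    """Hypothetical helper: batch_size=None means a single batch with all samples."""
    if n_samples == 0:
        return 0
    if batch_size is None:
        return 1
    return math.ceil(n_samples / batch_size)


n_train, n_test, n_val = split_counts(1000, train_size=0.7, test_size=0.2, val_size=0.1)
print(n_train, n_test, n_val)  # → 700 200 100
print(batches_per_epoch(n_train, batch_size=64))  # → 11
print(batches_per_epoch(n_train))  # → 1
```

With batch_size=None the whole training split forms one batch, which matches the documented meaning of None; the exact rounding rule PINA applies when splitting may differ from this sketch.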

train(**kwargs)

Manage the training process of the solver.

Parameters:
  • kwargs (dict) – Additional keyword arguments. See the PyTorch Lightning Trainer API for details.

test(**kwargs)

Manage the test process of the solver.

Parameters:
  • kwargs (dict) – Additional keyword arguments. See the PyTorch Lightning Trainer API for details.

property solver

Get the solver.

Returns:

The solver.

Return type:

SolverInterface
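The train, test, and solver members documented above describe a trainer that is bound to one solver when it is created. The pure-Python sketch below illustrates that delegation pattern. It is a hypothetical illustration, not PINA's source: BaseTrainer is a stand-in for lightning.pytorch.Trainer so the example runs without Lightning or PINA installed, and WrappedTrainer, its read-only solver property, and the string return values are assumptions made for demonstration only.

```python
class BaseTrainer:
    """Hypothetical stand-in for lightning.pytorch.Trainer (no Lightning needed)."""

    def __init__(self, **kwargs):
        self.config = kwargs  # e.g. max_epochs, accelerator

    def fit(self, model, **kwargs):
        return f"fit {model} with {sorted(kwargs)}"

    def test(self, model, **kwargs):
        return f"test {model} with {sorted(kwargs)}"


class WrappedTrainer(BaseTrainer):
    """Hypothetical trainer bound to a single solver at construction."""

    def __init__(self, solver, **kwargs):
        super().__init__(**kwargs)  # remaining kwargs configure the base trainer
        self._solver = solver

    def train(self, **kwargs):
        # Training always runs on the stored solver; extra kwargs are forwarded.
        return self.fit(self._solver, **kwargs)

    def test(self, **kwargs):
        # Overrides the base test() so callers do not pass the solver again.
        return super().test(self._solver, **kwargs)

    @property
    def solver(self):
        # Read-only access: no setter is defined, so assignment raises AttributeError.
        return self._solver


trainer = WrappedTrainer("my_solver", max_epochs=10)
print(trainer.train())  # → fit my_solver with []
print(trainer.test(verbose=False))  # → test my_solver with ['verbose']
print(trainer.solver)  # → my_solver
```

Storing the solver once and exposing it through a property keeps train() and test() free of a model argument, while any extra keyword arguments still reach the underlying trainer.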