Trainer
Module for the Trainer.
- class Trainer(solver, batch_size=None, train_size=1.0, test_size=0.0, val_size=0.0, compile=None, repeat=None, automatic_batching=None, num_workers=None, pin_memory=None, shuffle=None, **kwargs)
Bases: Trainer
PINA custom Trainer class to extend the standard Lightning functionality.
This class enables specific features or behaviors required by the PINA framework. It modifies the standard lightning.pytorch.Trainer class to better support the training process in PINA.
Initialization of the Trainer class.
- Parameters:
solver (SolverInterface) – A SolverInterface solver used to solve an AbstractProblem.
batch_size (int) – The number of samples per batch to load. If None, all samples are loaded and data is not batched. Default is None.
train_size (float) – The fraction of elements to include in the training dataset. Default is 1.0.
test_size (float) – The fraction of elements to include in the test dataset. Default is 0.0.
val_size (float) – The fraction of elements to include in the validation dataset. Default is 0.0.
compile (bool) – If True, the model is compiled before training. Default is False. For Windows users, it is always disabled.
repeat (bool) – Whether to repeat the dataset data in each condition during training. For further details, see the PinaDataModule class. Default is False.
automatic_batching (bool) – If True, automatic PyTorch batching is performed, otherwise the items are retrieved from the dataset all at once. For further details, see the PinaDataModule class. Default is False.
num_workers (int) – The number of worker threads for data loading. Default is 0 (serial loading).
pin_memory (bool) – Whether to use pinned memory for faster data transfer to GPU. Default is False.
shuffle (bool) – Whether to shuffle the data during training. Default is True.
kwargs (dict) – Additional keyword arguments that specify the training setup. These can be selected from the pytorch-lightning Trainer API.
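The train_size, test_size and val_size parameters describe how the available samples are divided. The sketch below illustrates the arithmetic only. `split_lengths` is a hypothetical helper, not part of PINA, and it assumes the three values are fractions that sum to 1 and that any rounding remainder goes to the training set.

```python
def split_lengths(n_samples, train_size=1.0, test_size=0.0, val_size=0.0):
    """Hypothetical helper: turn split fractions into sample counts."""
    sizes = {"train": train_size, "test": test_size, "val": val_size}
    if any(frac < 0 for frac in sizes.values()):
        raise ValueError("split sizes must be non-negative")
    # Assumption of this sketch: the three fractions sum to 1.
    if abs(sum(sizes.values()) - 1.0) > 1e-6:
        raise ValueError("split sizes are assumed to sum to 1")
    # Floor each share, then hand any leftover samples to the training set.
    lengths = {name: int(n_samples * frac) for name, frac in sizes.items()}
    lengths["train"] += n_samples - sum(lengths.values())
    return lengths


print(split_lengths(100, train_size=0.8, test_size=0.1, val_size=0.1))
# {'train': 80, 'test': 10, 'val': 10}
```

With the documented defaults (1.0, 0.0, 0.0), every sample ends up in the training set.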
- train(**kwargs)
Manage the training process of the solver.
- Parameters:
kwargs (dict) – Additional keyword arguments. See pytorch-lightning Trainer API for details.
- test(**kwargs)
Manage the test process of the solver.
- Parameters:
kwargs (dict) – Additional keyword arguments. See pytorch-lightning Trainer API for details.
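The calling pattern described above (a solver passed once at construction, then train and test called with optional keyword arguments) can be sketched without installing anything. `BaseTrainer` and `SketchTrainer` below are hypothetical stand-ins, not the Lightning or PINA classes, and the sketch assumes that train and test simply forward their keyword arguments to the base class together with the stored solver.

```python
class BaseTrainer:
    """Stand-in for a Lightning-style trainer (hypothetical, not the real class)."""

    def __init__(self, **kwargs):
        self.options = kwargs  # e.g. max_epochs, accelerator
        self.calls = []        # records what was invoked, for illustration

    def fit(self, model, **kwargs):
        self.calls.append(("fit", model, kwargs))

    def test(self, model, **kwargs):
        self.calls.append(("test", model, kwargs))


class SketchTrainer(BaseTrainer):
    """Stand-in subclass: keeps the solver and forwards everything else."""

    def __init__(self, solver, batch_size=None, **kwargs):
        super().__init__(**kwargs)  # unrecognised options go to the base class
        self._solver = solver
        self.batch_size = batch_size

    @property
    def solver(self):
        return self._solver

    def train(self, **kwargs):
        # Assumption: training delegates to the base class's fit().
        return super().fit(self._solver, **kwargs)

    def test(self, **kwargs):
        return super().test(self._solver, **kwargs)


trainer = SketchTrainer(solver="my-solver", batch_size=8, max_epochs=5)
trainer.train()
trainer.test(verbose=False)
```

The design point is that the caller never passes the solver to train or test; the trainer holds it and supplies it on each call.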
- property solver
Get the solver.
- Returns:
The solver.
- Return type: