Trainer
Public API for Trainer.
- class Trainer(solver, batch_size=None, train_size=1.0, test_size=0.0, val_size=0.0, batching_mode='common_batch_size', automatic_batching=False, num_workers=0, pin_memory=False, shuffle=True, **kwargs)
Bases: Trainer
PINA-specific extension of lightning.pytorch.Trainer. The trainer configures solver execution, dataset splitting, batching, logging, device placement for unknown parameters, and gradient tracking requirements for physics-informed solvers.
Initialization of the Trainer class.
- Parameters:
solver (BaseSolver) – The solver used to train, validate, and test the associated problem.
batch_size (int) – The number of samples per batch. If None, the entire dataset is processed as a single batch. Default is None.
train_size (float) – The fraction of samples assigned to the training split. Must belong to the interval [0, 1]. Default is 1.0.
val_size (float) – The fraction of samples assigned to the validation split. Must belong to the interval [0, 1]. Default is 0.0.
test_size (float) – The fraction of samples assigned to the test split. Must belong to the interval [0, 1]. Default is 0.0.
batching_mode (str) – The strategy used to aggregate batches across dataloaders. Available options are "common_batch_size" for uniform batch sizes across conditions, "proportional" for batch sizes proportional to dataset sizes, and "separate_conditions" for iterating through each condition separately. Default is "common_batch_size".
automatic_batching (bool) – Whether PyTorch automatic batching should be enabled. If True, dataset elements are retrieved individually and collated into batches by the dataloader. If False, entire subsets are retrieved directly from the condition object. Default is False.
num_workers (int) – The number of worker processes used by the dataloaders. Default is 0, meaning data is loaded sequentially in the main process.
pin_memory (bool) – Whether pinned memory should be used during data loading. Default is False.
shuffle (bool) – Whether condition samples should be shuffled before splitting. Default is True.
kwargs (dict) – Additional keyword arguments forwarded to the Lightning trainer.
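The split and batching parameters above can be illustrated with a small, framework-free sketch. It is not PINA's implementation: the helper names `split_indices` and `condition_batch_sizes` are hypothetical, the floor-and-remainder rounding rule is an assumption, and the mapping from each `batching_mode` to per-condition batch sizes is inferred only from the one-line descriptions above.

```python
import random


def split_indices(n_samples, val_size=0.0, test_size=0.0, shuffle=True, seed=None):
    """Hypothetical helper: split range(n_samples) into train/val/test indices.

    train_size is implied: since the three fractions must sum to 1, the
    training split receives whatever is left after validation and test.
    """
    indices = list(range(n_samples))
    if shuffle:
        # Shuffle before splitting, as the `shuffle` parameter describes.
        random.Random(seed).shuffle(indices)
    # Assumed rounding rule: floor the val/test counts, give the rest to train.
    n_val = int(val_size * n_samples)
    n_test = int(test_size * n_samples)
    n_train = n_samples - n_val - n_test
    return (indices[:n_train],
            indices[n_train:n_train + n_val],
            indices[n_train + n_val:])


def condition_batch_sizes(condition_sizes, batch_size,
                          batching_mode="common_batch_size"):
    """Hypothetical helper: per-condition batch sizes for one training step."""
    if batch_size is None:
        # No batch size: every condition is processed in full as one batch.
        return dict(condition_sizes)
    if batching_mode in ("common_batch_size", "separate_conditions"):
        # Same batch size for every condition, capped by the condition's size;
        # "separate_conditions" is assumed to differ in iteration order only.
        return {name: min(batch_size, n) for name, n in condition_sizes.items()}
    if batching_mode == "proportional":
        # Share batch_size in proportion to each condition's dataset size.
        total = sum(condition_sizes.values())
        return {name: max(1, round(batch_size * n / total))
                for name, n in condition_sizes.items()}
    raise ValueError(f"Unknown batching_mode: {batching_mode!r}")


train, val, test = split_indices(10, val_size=0.2, test_size=0.1, seed=0)
print(len(train), len(val), len(test))  # 7 2 1
print(condition_batch_sizes({"physics": 300, "data": 100}, 40, "proportional"))
# {'physics': 30, 'data': 10}
```

Flooring the validation and test counts guarantees the three splits never exceed the number of samples, at the cost of the training split absorbing any rounding remainder.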
- Raises:
ValueError – If solver is not a PINA solver.
ValueError – If train_size, val_size, or test_size is not a float in the interval [0, 1].
ValueError – If the sum of train_size, val_size, and test_size is not equal to 1.
ValueError – If automatic_batching, pin_memory, or shuffle is not a boolean.
AssertionError – If num_workers is a negative integer.
ValueError – If batch_size, when provided, is not a positive integer.
ValueError – If batching_mode is not one of the available options.
UserWarning – Issued if the provided batching_mode is incompatible with batch_size.
RuntimeError – If any domain in the problem has not been discretised.
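The argument checks listed under Raises can be mirrored in plain Python. The sketch below is an assumption about how such validation might look, not PINA's code: `validate_trainer_args` is a hypothetical name, the solver-type and discretisation checks are omitted because they need PINA objects, and the condition that triggers the `UserWarning` (a non-default `batching_mode` combined with `batch_size=None`) is a guess.

```python
import warnings

# Option names taken from the batching_mode parameter description.
BATCHING_MODES = ("common_batch_size", "proportional", "separate_conditions")


def validate_trainer_args(train_size=1.0, val_size=0.0, test_size=0.0,
                          batch_size=None, batching_mode="common_batch_size",
                          automatic_batching=False, num_workers=0,
                          pin_memory=False, shuffle=True):
    """Hypothetical checks mirroring the documented Raises section."""
    sizes = {"train_size": train_size, "val_size": val_size,
             "test_size": test_size}
    for name, value in sizes.items():
        if not isinstance(value, float) or not 0.0 <= value <= 1.0:
            raise ValueError(f"{name} must be a float in [0, 1], got {value!r}")
    # Use a tolerance: fractions such as 0.7 + 0.2 + 0.1 may not sum to
    # exactly 1.0 in floating point.
    if abs(sum(sizes.values()) - 1.0) > 1e-9:
        raise ValueError("train_size + val_size + test_size must equal 1")
    flags = {"automatic_batching": automatic_batching,
             "pin_memory": pin_memory, "shuffle": shuffle}
    for name, value in flags.items():
        if not isinstance(value, bool):
            raise ValueError(f"{name} must be a boolean, got {value!r}")
    # Raised explicitly so the check survives `python -O`, which strips asserts.
    if num_workers < 0:
        raise AssertionError("num_workers must be a non-negative integer")
    if batch_size is not None and (
        isinstance(batch_size, bool)
        or not isinstance(batch_size, int)
        or batch_size <= 0
    ):
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")
    if batching_mode not in BATCHING_MODES:
        raise ValueError(f"batching_mode must be one of {BATCHING_MODES}")
    # Assumed incompatibility: without a batch size there is nothing to share out.
    if batch_size is None and batching_mode != "common_batch_size":
        warnings.warn("batching_mode has no effect when batch_size is None",
                      UserWarning)
```

Note that a `UserWarning` is issued through `warnings.warn` rather than raised, so execution continues after it.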
- train(**kwargs)
Fit the solver using the trainer data module.
- Parameters:
kwargs (dict) – Additional keyword arguments forwarded to the Lightning trainer fit method.
- Returns:
Result returned by Lightning’s fit method.
- Return type:
Any
- test(**kwargs)
Test the solver using the trainer data module.
- Parameters:
kwargs (dict) – Additional keyword arguments forwarded to the Lightning trainer test method.
- Returns:
Result returned by Lightning’s test method.
- Return type:
Any
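`train` and `test` delegate to Lightning's `fit` and `test` methods, forwarding any keyword arguments unchanged. The sketch below illustrates that delegation with a hypothetical stand-in base class (`StandInLightningTrainer`) in place of `lightning.pytorch.Trainer`, so it runs without Lightning installed; `SketchTrainer`, the `datamodule` keyword, and the `data_module` attribute are assumptions about how the trainer's data module is passed, not PINA's confirmed internals.

```python
class StandInLightningTrainer:
    """Hypothetical stand-in for lightning.pytorch.Trainer; it only records calls."""

    def fit(self, model, datamodule=None, **kwargs):
        return ("fit", model, datamodule, kwargs)

    def test(self, model=None, datamodule=None, **kwargs):
        return ("test", model, datamodule, kwargs)


class SketchTrainer(StandInLightningTrainer):
    """Hypothetical wrapper showing the train/test delegation pattern."""

    def __init__(self, solver, data_module):
        self._solver = solver
        self.data_module = data_module

    @property
    def solver(self):
        # Read-only access to the attached solver, like the `solver` property.
        return self._solver

    def train(self, **kwargs):
        # Forward extra keyword arguments unchanged to the parent's fit method.
        return self.fit(self.solver, datamodule=self.data_module, **kwargs)

    def test(self, **kwargs):
        # `test` shadows the parent method, so call the base version via super().
        return super().test(self.solver, datamodule=self.data_module, **kwargs)


trainer = SketchTrainer("solver", "data module")
print(trainer.train(ckpt_path="last.ckpt"))
# ('fit', 'solver', 'data module', {'ckpt_path': 'last.ckpt'})
```

Because `test` reuses the parent's method name, the wrapper must call `super().test(...)` to reach the base implementation; `train` has a distinct name, so `self.fit(...)` is enough.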
- property solver
Return the solver attached to the trainer.
- Returns:
The solver used by the trainer.
- Return type:
BaseSolver