Trainer

Public API for Trainer.

class Trainer(solver, batch_size=None, train_size=1.0, test_size=0.0, val_size=0.0, batching_mode='common_batch_size', automatic_batching=False, num_workers=0, pin_memory=False, shuffle=True, **kwargs)

Bases: lightning.pytorch.Trainer

PINA-specific extension of lightning.pytorch.Trainer.

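The trainer configures solver execution, dataset splitting, batching, logging, device placement of the problem's unknown parameters, and the gradient tracking that physics-informed solvers require, since they differentiate the model output with respect to its inputs even outside training.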

Initialize the Trainer.

Parameters:
  • solver (BaseSolver) – The solver used to train, validate, and test the associated problem.

  • batch_size (int | None) – The number of samples per batch. If None, the entire dataset is processed as a single batch. Default is None.

  • train_size (float) – The fraction of samples assigned to the training split. Must belong to the interval [0, 1]. Default is 1.0.

  • val_size (float) – The fraction of samples assigned to the validation split. Must belong to the interval [0, 1]. Default is 0.0.

  • test_size (float) – The fraction of samples assigned to the test split. Must belong to the interval [0, 1]. Default is 0.0.

  • batching_mode (str) – The strategy used to aggregate batches across dataloaders. Available options are "common_batch_size" for uniform batch sizes across conditions, "proportional" for batch sizes proportional to dataset sizes, and "separate_conditions" for iterating through each condition separately. Default is "common_batch_size".

  • automatic_batching (bool) – Whether PyTorch automatic batching should be enabled. If True, dataset elements are retrieved individually and collated into batches by the dataloader. If False, entire subsets are retrieved directly from the condition object. Default is False.

  • num_workers (int) – The number of worker processes used by the dataloaders. Default is 0, which loads data sequentially in the main process.

  • pin_memory (bool) – Whether pinned memory should be enabled during data loading. Default is False.

  • shuffle (bool) – Whether condition samples should be shuffled before splitting. Default is True.

  • kwargs (dict) – Additional keyword arguments forwarded to lightning.pytorch.Trainer, such as max_epochs or accelerator.
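The split fractions and the batching options act on each condition's samples. The sketch below is a framework-free illustration, not PINA's implementation: the helpers split_indices and per_condition_batch_sizes are hypothetical, and three details are assumptions, namely that split counts are floored with the test split taking the remainder, that batch sizes are capped at the condition size, and that "proportional" shares batch_size by each condition's fraction of the total samples.

```python
import random


def split_indices(n_samples, train_size, val_size, shuffle=True, seed=0):
    """Split range(n_samples) into train/val/test index lists (hypothetical).

    Assumption: train and val counts are floored and the test split takes
    the remainder (test_size = 1 - train_size - val_size), so every sample
    lands in exactly one split.
    """
    indices = list(range(n_samples))
    if shuffle:
        # Shuffle before splitting, as the `shuffle` parameter describes.
        random.Random(seed).shuffle(indices)
    n_train = int(n_samples * train_size)
    n_val = int(n_samples * val_size)
    train = indices[:n_train]
    val = indices[n_train:n_train + n_val]
    test = indices[n_train + n_val:]
    return train, val, test


def per_condition_batch_sizes(condition_sizes, batch_size, batching_mode):
    """Return one batch size per condition (hypothetical).

    condition_sizes maps a condition name to its number of samples.
    batch_size=None processes each condition as a single batch.
    """
    if batch_size is None:
        return dict(condition_sizes)
    if batching_mode in ("common_batch_size", "separate_conditions"):
        # Assumption: the same batch size for every condition, capped at its size.
        return {name: min(batch_size, size) for name, size in condition_sizes.items()}
    if batching_mode == "proportional":
        total = sum(condition_sizes.values())
        # Assumption: split batch_size by data fraction, at least one sample each.
        return {
            name: max(1, round(batch_size * size / total))
            for name, size in condition_sizes.items()
        }
    raise ValueError(f"Unknown batching_mode: {batching_mode!r}")
```

For example, with {"boundary": 100, "interior": 900} and batch_size=50, "common_batch_size" gives each condition 50 samples per batch, while "proportional" gives 5 and 45, so the larger condition does not dominate the number of iterations.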

Raises:
  • ValueError – If solver is not a PINA solver.

  • ValueError – If train_size, val_size, or test_size is not a float in the interval [0, 1].

  • ValueError – If the sum of train_size, val_size, and test_size is not equal to 1.

  • ValueError – If automatic_batching, pin_memory, or shuffle is not a boolean.

  • AssertionError – If num_workers is a negative integer.

  • ValueError – If batch_size, when provided, is not a positive integer.

  • ValueError – If batching_mode is not one of the available options.

  • UserWarning – Emitted (not raised) if the provided batching_mode is incompatible with the batch_size.

  • RuntimeError – If any domain in the problem has not been discretised.
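The argument checks listed above can be mirrored in a small standalone validator. This is a hypothetical sketch, not PINA's code: the name check_trainer_args is invented, the 1e-9 tolerance on the split sum is an assumption, and the exact incompatibility that triggers the UserWarning is not documented, so the condition used here (a non-default batching_mode combined with batch_size=None) is an assumption. The solver and discretisation checks are omitted because they need a real PINA problem.

```python
import warnings

AVAILABLE_BATCHING_MODES = ("common_batch_size", "proportional", "separate_conditions")


def check_trainer_args(batch_size=None, train_size=1.0, test_size=0.0, val_size=0.0,
                       batching_mode="common_batch_size", automatic_batching=False,
                       num_workers=0, pin_memory=False, shuffle=True):
    """Mirror the documented Trainer argument checks (hypothetical helper)."""
    for name, value in (("train_size", train_size), ("val_size", val_size),
                        ("test_size", test_size)):
        if not isinstance(value, float) or not 0.0 <= value <= 1.0:
            raise ValueError(f"{name} must be a float in [0, 1], got {value!r}")
    # Assumption: a small tolerance absorbs floating-point rounding in the sum.
    if abs(train_size + val_size + test_size - 1.0) > 1e-9:
        raise ValueError("train_size + val_size + test_size must equal 1")
    for name, value in (("automatic_batching", automatic_batching),
                        ("pin_memory", pin_memory), ("shuffle", shuffle)):
        if not isinstance(value, bool):
            raise ValueError(f"{name} must be a bool, got {value!r}")
    # The documentation lists an AssertionError for negative num_workers.
    assert num_workers >= 0, "num_workers must be non-negative"
    if batch_size is not None and (
        isinstance(batch_size, bool) or not isinstance(batch_size, int) or batch_size <= 0
    ):
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")
    if batching_mode not in AVAILABLE_BATCHING_MODES:
        raise ValueError(f"batching_mode must be one of {AVAILABLE_BATCHING_MODES}")
    # Assumption: full-batch training makes a non-default batching_mode moot.
    if batch_size is None and batching_mode != "common_batch_size":
        warnings.warn(
            f"batching_mode={batching_mode!r} has no effect when batch_size is None",
            UserWarning,
        )
```

A call such as check_trainer_args(batch_size=32, train_size=0.8, val_size=0.1, test_size=0.1) passes silently, while check_trainer_args(train_size=0.8) raises ValueError because the fractions sum to 0.8.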

train(**kwargs)

Fit the solver using the trainer data module.

Parameters:

kwargs (dict) – Additional keyword arguments forwarded to the Lightning trainer fit method.

Returns:

Result returned by Lightning’s fit method.

Return type:

Any

test(**kwargs)

Test the solver using the trainer data module.

Parameters:

kwargs (dict) – Additional keyword arguments forwarded to the Lightning trainer test method.

Returns:

Result returned by Lightning’s test method.

Return type:

Any

property solver

Return the solver attached to the trainer.

Returns:

The solver used by the trainer.

Return type:

BaseSolver
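The train and test methods and the solver property follow a delegation pattern: each method calls the parent Lightning method with the attached solver and the trainer's data module, forwarding any extra keyword arguments. The sketch below reproduces that pattern with a hypothetical stand-in base class, StandInLightningTrainer, so it runs without Lightning or PINA. The attribute name data_module and the exact arguments passed to fit and test are assumptions rather than PINA's verified code.

```python
class StandInLightningTrainer:
    """Hypothetical stand-in for lightning.pytorch.Trainer; performs no training."""

    def fit(self, model, datamodule=None, **kwargs):
        # Echo the call so the forwarding can be inspected.
        return {"stage": "fit", "model": model, "datamodule": datamodule, **kwargs}

    def test(self, model=None, datamodule=None, **kwargs):
        return {"stage": "test", "model": model, "datamodule": datamodule, **kwargs}


class SketchTrainer(StandInLightningTrainer):
    """Hypothetical subclass mirroring the documented train/test/solver members."""

    def __init__(self, solver, data_module):
        self._solver = solver
        self.data_module = data_module

    def train(self, **kwargs):
        # Forward extra keyword arguments to the parent fit method.
        return self.fit(self._solver, datamodule=self.data_module, **kwargs)

    def test(self, **kwargs):
        # This override shadows the parent test, so call it through super().
        return super().test(self._solver, datamodule=self.data_module, **kwargs)

    @property
    def solver(self):
        # Read-only: no setter is defined, so assignment raises AttributeError.
        return self._solver
```

With trainer = SketchTrainer("my_solver", "my_data"), calling trainer.train(ckpt_path="last") forwards ckpt_path to fit alongside the solver and data module. In the real Lightning API, ckpt_path is an accepted keyword of both Trainer.fit and Trainer.test.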