SolverInterface#

class SolverInterface(problem, weighting, use_lt)[source]#

Bases: LightningModule

Abstract base class for PINA solvers. All specific solvers must inherit from this interface. The class extends LightningModule, adding functionality for defining and optimizing deep learning models.

By inheriting from this base class, solvers gain access to built-in training loops, logging utilities, and optimization techniques.

Initialization of the SolverInterface class.

Parameters:
  • problem (AbstractProblem) – The problem to be solved.

  • weighting (WeightingInterface) – The weighting schema to be used. If None, no weighting schema is used. Default is None.

  • use_lt (bool) – If True, the solver uses LabelTensors as input.
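The inheritance contract can be sketched with plain-Python stand-ins instead of the real PINA and Lightning classes (MockSolverInterface, MySolver, and the batch contents below are illustrative, not part of the library):

```python
from abc import ABC, abstractmethod

class MockSolverInterface(ABC):
    """Plain-Python stand-in for SolverInterface (illustrative only)."""

    def __init__(self, problem, weighting=None, use_lt=True):
        self._problem = problem
        self._weighting = weighting
        self._use_lt = use_lt

    @abstractmethod
    def forward(self, *args, **kwargs):
        """Subclasses must implement the model's forward pass."""

    @abstractmethod
    def optimization_cycle(self, batch):
        """Subclasses must return a dict: condition name -> scalar loss."""

class MySolver(MockSolverInterface):
    def forward(self, x):
        return x  # identity "model", for illustration only

    def optimization_cycle(self, batch):
        # One (fake) scalar loss per condition in the batch.
        return {name: float(len(points["x"])) for name, points in batch}

solver = MySolver(problem=None)
losses = solver.optimization_cycle([("physics", {"x": [0.0, 1.0]})])
print(losses)  # {'physics': 2.0}
```

A class that omits either abstract method cannot be instantiated, which is how the interface enforces that every concrete solver defines both a forward pass and an optimization cycle.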

abstract forward(*args, **kwargs)[source]#

Abstract method for the forward pass. Subclasses must implement it.

Parameters:
  • args – Positional arguments passed to the model.

  • kwargs (dict) – Keyword arguments passed to the model.
abstract optimization_cycle(batch)[source]#

The optimization cycle for the solvers.

Parameters:

batch (list[tuple[str, dict]]) – A batch of data. Each element is a tuple containing a condition name and a dictionary of points.

Returns:

The losses computed for all conditions in the batch, cast to a subclass of torch.Tensor. The method must return a dict mapping each condition name to its associated scalar loss.

Return type:

dict
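The expected return shape can be illustrated with plain floats standing in for torch.Tensor scalars (the condition names and the per-condition residual below are made up):

```python
def optimization_cycle(batch):
    # batch: list of (condition_name, points_dict) tuples.
    # Real solvers return torch.Tensor scalars; floats stand in here.
    losses = {}
    for condition_name, points in batch:
        # Hypothetical residual: mean of the sampled point values.
        values = points["x"]
        losses[condition_name] = sum(values) / len(values)
    return losses

batch = [("physics", {"x": [0.0, 2.0]}), ("boundary", {"x": [1.0]})]
print(optimization_cycle(batch))  # {'physics': 1.0, 'boundary': 1.0}
```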

training_step(batch, **kwargs)[source]#

Solver training step. It computes the optimization cycle and aggregates the losses using the weighting attribute.

Parameters:
  • batch (list[tuple[str, dict]]) – A batch of data. Each element is a tuple containing a condition name and a dictionary of points.

  • kwargs (dict) – Additional keyword arguments passed to optimization_cycle.

Returns:

The loss of the training step.

Return type:

torch.Tensor
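The aggregation step can be sketched as a weighted sum over the per-condition losses returned by optimization_cycle (the `aggregate` helper and the weight values are hypothetical, standing in for the weighting attribute):

```python
def aggregate(losses, weights=None):
    """Weighted sum of per-condition losses, mimicking how training_step
    collapses the optimization_cycle output into a single scalar.
    With no weighting schema, every condition contributes with weight 1."""
    if weights is None:
        weights = {name: 1.0 for name in losses}
    return sum(weights[name] * value for name, value in losses.items())

losses = {"physics": 0.5, "boundary": 0.25}
print(aggregate(losses))  # unweighted: 0.75
print(aggregate(losses, {"physics": 1.0, "boundary": 2.0}))  # weighted: 1.0
```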

validation_step(batch, **kwargs)[source]#

Solver validation step. It computes the optimization cycle and averages the losses. No aggregation using the weighting attribute is performed.

Parameters:
  • batch (list[tuple[str, dict]]) – A batch of data. Each element is a tuple containing a condition name and a dictionary of points.

  • kwargs (dict) – Additional keyword arguments passed to optimization_cycle.

Returns:

The loss of the validation step.

Return type:

torch.Tensor

test_step(batch, **kwargs)[source]#

Solver test step. It computes the optimization cycle and averages the losses. No aggregation using the weighting attribute is performed.

Parameters:
  • batch (list[tuple[str, dict]]) – A batch of data. Each element is a tuple containing a condition name and a dictionary of points.

  • kwargs (dict) – Additional keyword arguments passed to optimization_cycle.

Returns:

The loss of the test step.

Return type:

torch.Tensor

store_log(name, value, batch_size)[source]#

Store a named value in the solver's logs.

Parameters:
  • name (str) – The name of the log.

  • value (torch.Tensor) – The value of the log.

  • batch_size (int) – The size of the batch.

setup(stage)[source]#

This method is called at the start of the fit and test stages. It compiles the model when the Trainer's compile option is True.

static get_batch_size(batch)[source]#

Get the batch size.

Parameters:

batch (list[tuple[str, dict]]) – A batch of data. Each element is a tuple containing a condition name and a dictionary of points.

Returns:

The size of the batch.

Return type:

int
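A sketch of one plausible batch-size computation, summing the number of sampled points across conditions; this counting rule and the `"x"` key are assumptions, and the exact definition in PINA may differ:

```python
def get_batch_size(batch):
    """Sum the number of points contributed by each condition.
    Assumes every value in a condition's points dict shares a common
    length, so inspecting one entry per condition is enough."""
    size = 0
    for _condition_name, points in batch:
        first_value = next(iter(points.values()))
        size += len(first_value)
    return size

batch = [("physics", {"x": [0.0, 1.0, 2.0]}), ("boundary", {"x": [3.0]})]
print(get_batch_size(batch))  # 4
```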

static default_torch_optimizer()[source]#

Instantiate and return the default optimizer, torch.optim.Adam.

Returns:

The default optimizer.

Return type:

Optimizer

static default_torch_scheduler()[source]#

Instantiate and return the default scheduler, torch.optim.lr_scheduler.ConstantLR.

Returns:

The default scheduler.

Return type:

Scheduler

property problem#

The problem instance.

Returns:

The problem instance.

Return type:

AbstractProblem

property use_lt#

Whether LabelTensors are used as input during training.

Returns:

The use_lt attribute.

Return type:

bool

property weighting#

The weighting schema.

Returns:

The weighting schema.

Return type:

WeightingInterface