SolverInterface

class SolverInterface(problem, weighting, use_lt)

Bases: LightningModule

Abstract base class for PINA solvers. All concrete solvers should inherit from this interface. The class extends LightningModule, so every solver is also a Lightning module.

Initialization of the SolverInterface class.

Parameters:
  • problem (AbstractProblem) – The problem to be solved.

  • weighting (WeightingInterface) – The weighting scheme to use. If None, no weighting scheme is applied. Default is None.

  • use_lt (bool) – If True, the solver uses LabelTensors as input.

training_step(batch)

Solver training step.

Parameters:

batch (list[tuple[str, dict]]) – A batch of data. Each element is a tuple containing a condition name and a dictionary of points.

Returns:

The loss of the training step.

Return type:

LabelTensor
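Conceptually, the training loss is a single scalar obtained by reducing per-condition losses. The sketch below assumes the reduction is a weighted sum, with a hypothetical `weights` dict standing in for the weighting scheme (conditions without an entry get weight 1.0). Plain floats replace tensors, and this is an illustration of the idea, not PINA's actual implementation.

```python
def aggregate_losses(losses, weights=None):
    """Reduce a {condition_name: scalar_loss} dict to one scalar.

    Hypothetical stand-in for a weighting scheme: a weighted sum in which
    conditions missing from ``weights`` get weight 1.0.
    """
    weights = weights or {}
    return sum(weights.get(name, 1.0) * loss for name, loss in losses.items())


# Per-condition losses, as plain floats for illustration.
losses = {"boundary": 0.5, "physics": 2.0}
print(aggregate_losses(losses))                      # 2.5 (unweighted sum)
print(aggregate_losses(losses, {"boundary": 10.0}))  # 7.0 (boundary term up-weighted)
```

Up-weighting one condition shifts the optimizer's attention toward it without changing the per-condition losses themselves.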

validation_step(batch)

Solver validation step.

Parameters:

batch (list[tuple[str, dict]]) – A batch of data. Each element is a tuple containing a condition name and a dictionary of points.

test_step(batch)

Solver test step.

Parameters:

batch (list[tuple[str, dict]]) – A batch of data. Each element is a tuple containing a condition name and a dictionary of points.

store_log(name, value, batch_size)

Log a named value (for example, a loss) computed by the solver.

Parameters:
  • name (str) – The name of the log.

  • value (torch.Tensor) – The value of the log.

  • batch_size (int) – The size of the batch.
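The batch_size argument matters when a value is accumulated over an epoch: Lightning's `LightningModule.log` uses the batch size to weight each batch when reducing logged values to an epoch-level mean. Assuming store_log forwards batch_size to `log`, the plain-Python sketch below reproduces that weighted mean; the function name `epoch_mean` is hypothetical.

```python
def epoch_mean(logged):
    """Batch-size-weighted mean of (value, batch_size) pairs logged during an epoch."""
    total = sum(value * size for value, size in logged)
    count = sum(size for _, size in logged)
    return total / count


# One full batch of 30 points, then a smaller final batch of 10 points.
logged = [(1.0, 30), (4.0, 10)]
print(epoch_mean(logged))  # 1.75 (an unweighted mean would give 2.5)
```

Without the correct batch size, a small trailing batch would count as much as a full one and bias the epoch-level value.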

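abstract forward(*args, **kwargs)

Abstract method for the forward pass implementation.

Parameters:
  • args – Positional arguments for the forward pass.

  • kwargs – Keyword arguments for the forward pass.

abstract optimization_cycle(batch)

The optimization cycle of the solver, which computes the loss of every condition in the batch.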

Parameters:

batch (list[tuple[str, dict]]) – A batch of data. Each element is a tuple containing a condition name and a dictionary of points.

Returns:

A dictionary mapping each condition name in the batch to its scalar loss, with each loss cast to a subclass of torch.Tensor.

Return type:

dict
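To make the subclassing contract concrete, the sketch below mirrors the two abstract methods with Python's abc module. The stand-in class ToySolverInterface does not inherit from LightningModule. The subclass ToyLinearSolver, its mean-squared-error loss, and the "input"/"target" dictionary keys are all hypothetical, and plain floats replace tensors. The sketch shows the shape optimization_cycle is expected to return: one scalar loss per condition name.

```python
from abc import ABC, abstractmethod


class ToySolverInterface(ABC):
    """Stand-in for SolverInterface that mirrors only its two abstract methods."""

    @abstractmethod
    def forward(self, *args, **kwargs):
        ...

    @abstractmethod
    def optimization_cycle(self, batch):
        ...


class ToyLinearSolver(ToySolverInterface):
    """Hypothetical solver modelling y = w * x, scored by mean squared error."""

    def __init__(self, w=1.0):
        self.w = w

    def forward(self, x):
        # Element-wise prediction for a list of scalar inputs.
        return [self.w * xi for xi in x]

    def optimization_cycle(self, batch):
        # batch is a list of (condition_name, points_dict) tuples.
        losses = {}
        for condition_name, points in batch:
            pred = self.forward(points["input"])
            target = points["target"]
            squared_errors = [(p - t) ** 2 for p, t in zip(pred, target)]
            losses[condition_name] = sum(squared_errors) / len(target)
        return losses


batch = [
    ("data", {"input": [1.0, 2.0], "target": [2.0, 4.0]}),
    ("anchor", {"input": [0.0], "target": [1.0]}),
]
print(ToyLinearSolver(w=2.0).optimization_cycle(batch))  # {'data': 0.0, 'anchor': 1.0}
```

Instantiating ToySolverInterface directly raises TypeError. This is how abc enforces that every subclass implements both methods.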

property problem

The problem instance.

Returns:

The problem instance.

Return type:

AbstractProblem

property use_lt

Whether LabelTensors are used as input during training.

Returns:

The use_lt attribute.

Return type:

bool

property weighting

The weighting scheme.

Returns:

The weighting scheme.

Return type:

WeightingInterface

static get_batch_size(batch)

Get the batch size.

Parameters:

batch (list[tuple[str, dict]]) – A batch of data. Each element is a tuple containing a condition name and a dictionary of points.

Returns:

The size of the batch.

Return type:

int
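A minimal sketch of how a batch size can be computed for this batch format. It assumes the size is the total number of input points summed over all conditions, and that each dictionary stores its points under an "input" key. Both are assumptions, not documented behaviour, and plain lists stand in for tensors.

```python
def get_batch_size(batch):
    """Count the points in a batch of (condition_name, points_dict) tuples.

    Assumes every points_dict stores its samples under an "input" key.
    """
    return sum(len(points["input"]) for _, points in batch)


batch = [
    ("boundary", {"input": [[0.0], [1.0]], "target": [[0.0], [0.0]]}),
    ("physics", {"input": [[0.25], [0.5], [0.75]]}),
]
print(get_batch_size(batch))  # 5
```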

static default_torch_optimizer()

Return the default optimizer, which uses torch.optim.Adam.

Returns:

The default optimizer.

Return type:

Optimizer

static default_torch_scheduler()

Return the default scheduler, which uses torch.optim.lr_scheduler.ConstantLR.

Returns:

The default scheduler.

Return type:

Scheduler
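Despite its name, PyTorch's `ConstantLR` does not keep the learning rate constant with its default arguments. It multiplies the base learning rate by `factor` for the first `total_iters` scheduler steps and then restores the base value. The plain-Python sketch below reproduces that rule using PyTorch's defaults (`factor=1/3`, `total_iters=5`). This page does not state whether the default scheduler passes other arguments, so treat these values as assumptions.

```python
def constant_lr(base_lr, step, factor=1.0 / 3, total_iters=5):
    """Learning rate after `step` scheduler steps, following the ConstantLR rule."""
    return base_lr * factor if step < total_iters else base_lr


print([round(constant_lr(0.003, t), 6) for t in range(7)])
# [0.001, 0.001, 0.001, 0.001, 0.001, 0.003, 0.003]
```

Passing `factor=1.0` makes the schedule truly constant.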

on_train_start()

Called at the start of training. If compile is enabled on the Trainer, the model is compiled here.

on_test_start()

Called at the start of testing. If compile is enabled on the Trainer, the model is compiled here.