SolverInterface
- class SolverInterface(problem, weighting, use_lt)
Bases: LightningModule
Abstract base class for PINA solvers. All specific solvers should inherit from this interface. This class is a wrapper of LightningModule.
Initialization of the SolverInterface class.
- Parameters:
problem (AbstractProblem) – The problem to be solved.
weighting (WeightingInterface) – The weighting scheme to be used. If None, no weighting scheme is used. Default is None.
use_lt (bool) – If True, the solver uses LabelTensors as input.
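The inheritance contract described above can be pictured with a minimal, stdlib-only mock-up. This is a sketch under stated assumptions, not PINA's code: SolverInterfaceSketch and ConstantSolver are hypothetical names, abc.ABC stands in for LightningModule, plain floats and lists stand in for torch.Tensor and LabelTensor, and the "input"/"target" keys of the points dictionary are invented for the example. It shows the constructor arguments exposed as read-only properties, and that a subclass must override both abstract methods (forward and optimization_cycle) before it can be instantiated.

```python
from abc import ABC, abstractmethod


class SolverInterfaceSketch(ABC):
    """Hypothetical stand-in for SolverInterface; ABC replaces LightningModule."""

    def __init__(self, problem, weighting, use_lt):
        # Keep the constructor arguments private; expose them read-only below.
        self._problem = problem
        self._weighting = weighting  # None means no weighting scheme
        self._use_lt = use_lt

    @property
    def problem(self):
        return self._problem

    @property
    def weighting(self):
        return self._weighting

    @property
    def use_lt(self):
        return self._use_lt

    @abstractmethod
    def forward(self, *args, **kwargs):
        """Forward pass; every concrete solver must override it."""

    @abstractmethod
    def optimization_cycle(self, batch):
        """Return a dict mapping condition names to scalar losses."""


class ConstantSolver(SolverInterfaceSketch):
    """Hypothetical toy solver whose model always predicts the same value."""

    def __init__(self, problem, weighting, use_lt, value=0.0):
        super().__init__(problem, weighting, use_lt)
        self.value = value

    def forward(self, *args, **kwargs):
        # One constant prediction per input point in args[0].
        return [self.value for _ in args[0]]

    def optimization_cycle(self, batch):
        # Mean squared error per condition; "input"/"target" keys are hypothetical.
        losses = {}
        for name, points in batch:
            preds = self.forward(points["input"])
            errors = [(p - t) ** 2 for p, t in zip(preds, points["target"])]
            losses[name] = sum(errors) / len(errors)
        return losses


try:
    SolverInterfaceSketch("toy-problem", None, False)
except TypeError as err:
    # The abstract methods are not implemented, so instantiation fails.
    print("cannot instantiate:", err)

solver = ConstantSolver("toy-problem", None, False, value=1.0)
print(solver.optimization_cycle([("data", {"input": [0.0, 1.0], "target": [1.0, 3.0]})]))  # {'data': 2.0}
```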
- training_step(batch)
Solver training step.
- store_log(name, value, batch_size)
Log a named value for the solver.
- Parameters:
name (str) – The name of the log.
value (torch.Tensor) – The value of the log.
batch_size (int) – The size of the batch.
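The batch_size argument matters because Lightning's LightningModule.log also accepts a batch_size, which it uses when accumulating a metric over an epoch: with the default mean reduction, each step's value is weighted by its batch size. That store_log forwards its arguments to self.log in this way is an assumption here. The stdlib sketch below, built around a hypothetical EpochLogAccumulator class, reproduces that batch-size-weighted mean so the effect of batch_size is visible.

```python
from collections import defaultdict


class EpochLogAccumulator:
    """Hypothetical helper: batch-size-weighted epoch mean of logged values."""

    def __init__(self):
        self._weighted_sum = defaultdict(float)
        self._count = defaultdict(int)

    def store_log(self, name, value, batch_size):
        # Weight each step's value by the number of samples it covered.
        self._weighted_sum[name] += float(value) * batch_size
        self._count[name] += batch_size

    def epoch_value(self, name):
        # Weighted mean over all steps logged under this name.
        return self._weighted_sum[name] / self._count[name]


logs = EpochLogAccumulator()
logs.store_log("train_loss", 1.0, batch_size=30)
logs.store_log("train_loss", 4.0, batch_size=10)
print(logs.epoch_value("train_loss"))  # 1.75
```

A plain average of the two steps would give 2.5; weighting by batch size gives 1.75 because the first step covered three times as many samples.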
- abstract forward(*args, **kwargs)
Abstract method for the forward pass implementation.
- Parameters:
args (torch.Tensor | LabelTensor) – The input tensor.
kwargs (dict) – Additional keyword arguments.
- abstract optimization_cycle(batch)
The optimization cycle for the solvers.
- Parameters:
batch (list[tuple[str, dict]]) – A batch of data. Each element is a tuple containing a condition name and a dictionary of points.
- Returns:
The losses computed for all conditions in the batch, as a dict mapping each condition name to its associated scalar loss, cast to a subclass of torch.Tensor.
- Return type:
dict
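The input and output shapes of optimization_cycle can be illustrated without PINA. In this stdlib sketch the batch is a list of (condition name, points dictionary) tuples as documented, and the result is a dict mapping each condition name to a scalar loss, with floats standing in for scalar tensors. The function names, the "input"/"target" keys, the model predict and the condition names are hypothetical. The final reduction (a weighted sum with weight 1.0 for any condition that has no explicit weight) is also an assumption, since this page does not say how the per-condition losses are combined.

```python
def mse(preds, targets):
    # Mean squared error between two equal-length lists of floats.
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(targets)


def optimization_cycle_sketch(batch, predict):
    """Hypothetical: map each (condition_name, points) pair to a scalar loss."""
    losses = {}
    for condition_name, points in batch:
        # "input" and "target" are hypothetical keys of the points dict.
        preds = [predict(x) for x in points["input"]]
        losses[condition_name] = mse(preds, points["target"])
    return losses


def reduce_losses(losses, weights=None):
    """Assumed reduction: weighted sum, weight 1.0 for any unlisted condition."""
    weights = weights or {}
    return sum(weights.get(name, 1.0) * loss for name, loss in losses.items())


batch = [
    ("interior", {"input": [0.0, 1.0], "target": [0.0, 2.0]}),
    ("boundary", {"input": [2.0], "target": [3.0]}),
]
losses = optimization_cycle_sketch(batch, predict=lambda x: 2.0 * x)
print(losses)  # {'interior': 0.0, 'boundary': 1.0}
print(reduce_losses(losses))  # 1.0
print(reduce_losses(losses, {"boundary": 0.5}))  # 0.5
```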
- property problem
The problem instance.
- Returns:
The problem instance.
- Return type:
AbstractProblem
- property use_lt
Whether LabelTensors are used as input during training.
- Returns:
The use_lt attribute.
- Return type:
bool
- property weighting
The weighting scheme.
- Returns:
The weighting scheme.
- Return type:
WeightingInterface
- static default_torch_optimizer()
Return the default optimizer, based on torch.optim.Adam.
- Returns:
The default optimizer.
- Return type:
- static default_torch_scheduler()
Return the default scheduler, based on torch.optim.lr_scheduler.ConstantLR.
- Returns:
The default scheduler.
- Return type:
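Despite its name, torch.optim.lr_scheduler.ConstantLR is not a no-op with its default arguments: PyTorch documents it as multiplying the learning rate by factor (default 1/3) until the epoch count reaches total_iters (default 5), after which the base learning rate is restored. Which factor and total_iters values PINA passes, if any, is not stated on this page. The stdlib sketch below uses a hypothetical constant_lr function to compute ConstantLR's closed-form schedule; the base rate of 1e-3 is only an illustrative value.

```python
def constant_lr(base_lr, epoch, factor=1.0 / 3, total_iters=5):
    """Hypothetical helper: learning rate at `epoch` under ConstantLR's closed form."""
    if factor < 0.0 or factor > 1.0:
        # PyTorch likewise rejects factors outside [0, 1].
        raise ValueError("factor must lie in [0, 1]")
    # Scaled rate for the first `total_iters` epochs, base rate from then on.
    return base_lr * factor if epoch < total_iters else base_lr


# Illustrative base rate with PyTorch's default factor and total_iters.
for epoch in range(7):
    print(epoch, constant_lr(1e-3, epoch))

# With factor=1.0 the schedule really is constant.
print([constant_lr(1e-3, e, factor=1.0) for e in range(3)])
```

Passing factor=1.0 (or total_iters=0) keeps the rate fixed at its base value, so it is worth checking which arguments the default scheduler is actually built with before relying on either behavior.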