MultiSolverInterface
- class MultiSolverInterface(problem, models, optimizers=None, schedulers=None, weighting=None, use_lt=True)
Bases: SolverInterface
Base class for PINA solvers that use multiple torch.nn.Module models.
Initialization of the MultiSolverInterface class.
- Parameters:
problem (AbstractProblem) – The problem to be solved.
models – The neural network models to be used, given as a list or tuple containing more than one model.
optimizers (list[Optimizer]) – The optimizers to be used. If None, the torch.optim.Adam optimizer is used for all models. Default is None.
schedulers (list[Scheduler]) – The schedulers to be used. If None, the torch.optim.lr_scheduler.ConstantLR scheduler is used for all models. Default is None.
weighting (WeightingInterface) – The weighting schema to be used. If None, no weighting schema is used. Default is None.
use_lt (bool) – If True, the solver uses LabelTensors as input. Default is True.
- Raises:
ValueError – If the models are not a list or tuple with length greater than one.
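The argument handling described above (the models check and the per-model defaults) can be sketched with plain PyTorch objects. This is a minimal illustration, not PINA's code: the helper name build_multi_solver_defaults is hypothetical, PINA's own optimizer and scheduler wrapper classes are not shown, and ConstantLR is given factor=1.0 (an assumption) so the learning rate stays unchanged, since torch's default factor is 1/3.

```python
import torch
from torch import nn


def build_multi_solver_defaults(models, optimizers=None, schedulers=None):
    """Hypothetical helper mirroring the documented defaults; not PINA's API."""
    # Documented constraint: models must be a list or tuple with more than one entry.
    if not isinstance(models, (list, tuple)) or len(models) < 2:
        raise ValueError(
            "models must be a list or tuple with length greater than one"
        )
    # Documented default: torch.optim.Adam for every model.
    if optimizers is None:
        optimizers = [torch.optim.Adam(m.parameters()) for m in models]
    # Documented default: ConstantLR for every model.
    # Assumption: factor=1.0 keeps each optimizer's learning rate unchanged.
    if schedulers is None:
        schedulers = [
            torch.optim.lr_scheduler.ConstantLR(opt, factor=1.0)
            for opt in optimizers
        ]
    return list(models), optimizers, schedulers
```

For example, passing two small nn.Linear models with no optimizers or schedulers yields one Adam optimizer and one ConstantLR scheduler per model, while passing a single model raises ValueError.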
Warning
MultiSolverInterface uses manual optimization by setting automatic_optimization=False in LightningModule. For more information on manual optimization, see the PyTorch Lightning documentation on manual optimization.
- on_train_batch_end(outputs, batch, batch_idx)
This method is called at the end of each training batch and overrides the PyTorch Lightning implementation to log checkpoints.
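Because the warning above notes that MultiSolverInterface relies on manual optimization, the solver itself must zero gradients, backpropagate, and step each optimizer. The sketch below shows that pattern in plain PyTorch for several models. It is illustrative only: the function name manual_step and the per-model MSE objective are hypothetical, and stepping each scheduler once per batch is an assumption. Inside a LightningModule, loss.backward() would typically be replaced by self.manual_backward(loss).

```python
import torch
from torch import nn


def manual_step(models, optimizers, schedulers, x, y):
    """One manual optimization step per model (illustrative, not PINA's code)."""
    losses = []
    for model, opt, sched in zip(models, optimizers, schedulers):
        opt.zero_grad()
        # Hypothetical per-model objective: mean squared error on (x, y).
        loss = nn.functional.mse_loss(model(x), y)
        # In a LightningModule with automatic_optimization=False this would be
        # self.manual_backward(loss).
        loss.backward()
        opt.step()
        # Assumption: schedulers are stepped once per batch.
        sched.step()
        losses.append(loss.item())
    return losses
```

Each model gets its own optimizer and scheduler, which is why automatic optimization (a single backward and step per batch) does not fit a multi-model solver.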
- property models
The models used for training.
- Returns:
The models used for training.
- Return type:
- property optimizers
The optimizers used for training.
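The models and optimizers properties expose the solver's training components read-only. A minimal sketch of that pattern follows. The class TinyMultiSolver is hypothetical, and storing the models in an nn.ModuleList is an assumption (the return type is not stated here). It is chosen because it registers every model's parameters with the parent module.

```python
import torch
from torch import nn


class TinyMultiSolver(nn.Module):
    """Hypothetical container illustrating read-only models/optimizers properties."""

    def __init__(self, models, lr=1e-3):
        super().__init__()
        # ModuleList registers every model, so .parameters() and .to(device) see them all.
        self._models = nn.ModuleList(models)
        # One optimizer per model, matching the one-to-one pairing described above.
        self._optimizers = [
            torch.optim.Adam(m.parameters(), lr=lr) for m in self._models
        ]

    @property
    def models(self):
        """The models used for training."""
        return self._models

    @property
    def optimizers(self):
        """The optimizers used for training, one per model."""
        return self._optimizers
```

Using a plain Python list instead of nn.ModuleList would hide the models' parameters from the parent module. Device moves and state_dict would then silently skip them.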