MultiSolverInterface
- class MultiSolverInterface(problem, models, optimizers=None, schedulers=None, weighting=None, use_lt=True)
Bases: SolverInterface
Base class for PINA solvers using multiple torch.nn.Module instances.
Initialization of the MultiSolverInterface class.
- Parameters:
problem (AbstractProblem) – The problem to be solved.
models (list or tuple of torch.nn.Module) – The neural network models to be used. At least two models are required.
optimizers (list[Optimizer]) – The optimizers to be used. If None, the torch.optim.Adam optimizer is used for all models. Default is None.
schedulers (list[Scheduler]) – The schedulers to be used. If None, the torch.optim.lr_scheduler.ConstantLR scheduler is used for all models. Default is None.
weighting (WeightingInterface) – The weighting schema to be used. If None, no weighting schema is used. Default is None.
use_lt (bool) – If True, the solver uses LabelTensors as input. Default is True.
- Raises:
ValueError – If the models are not a list or tuple with length greater than one.
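The constructor contract above (more than one model, and a default optimizer and scheduler for every model when None is passed) can be sketched in plain Python. This is a minimal sketch, not PINA's code: `validate_models` and `fill_defaults` are hypothetical helpers, strings stand in for real models and optimizers, and the one-item-per-model length check is an assumption the documentation does not state.

```python
from typing import Any, Callable, List, Optional, Sequence


def validate_models(models: Any) -> List[Any]:
    """Hypothetical helper mirroring the documented check: `models` must be
    a list or tuple holding more than one model, otherwise ValueError."""
    if not isinstance(models, (list, tuple)) or len(models) <= 1:
        raise ValueError(
            "models must be a list or tuple with length greater than one"
        )
    return list(models)


def fill_defaults(
    items: Optional[Sequence[Any]],
    n_models: int,
    default_factory: Callable[[], Any],
) -> List[Any]:
    """Hypothetical helper: when `items` is None, build a separate default
    (e.g. one Adam optimizer or one ConstantLR scheduler) for every model."""
    if items is None:
        # Call the factory once per model so no two models share an instance.
        return [default_factory() for _ in range(n_models)]
    items = list(items)
    # Assumption: exactly one optimizer/scheduler per model (not documented).
    if len(items) != n_models:
        raise ValueError("expected exactly one item per model")
    return items
```

For example, `fill_defaults(None, 2, lambda: "adam")` yields one default per model, while `validate_models(["net_a"])` raises `ValueError` because a single model is not enough.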
Warning
MultiSolverInterface uses manual optimization by setting automatic_optimization=False in LightningModule. For more information on manual optimization, see the PyTorch Lightning documentation.
- on_train_batch_end(outputs, batch, batch_idx)
This method is called at the end of each training batch and overrides the PyTorch Lightning implementation to log checkpoints.
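Because MultiSolverInterface disables Lightning's automatic optimization, the solver itself has to drive every model's optimizer during a training step. In Lightning this is done with calls such as `self.optimizers()`, `opt.zero_grad()`, `self.manual_backward(loss)` and `opt.step()`. The sketch below reproduces only that control flow with stdlib stand-ins: `ToyParam`, `ToySGD` and `manual_training_step` are hypothetical, and the quadratic loss with its hand-written gradient is an assumption chosen for illustration.

```python
from typing import List, Sequence


class ToyParam:
    """Hypothetical scalar parameter with an accumulated gradient."""

    def __init__(self, value: float) -> None:
        self.value = value
        self.grad = 0.0


class ToySGD:
    """Hypothetical stand-in for a torch optimizer owning one parameter."""

    def __init__(self, param: ToyParam, lr: float) -> None:
        self.param = param
        self.lr = lr

    def zero_grad(self) -> None:
        self.param.grad = 0.0

    def step(self) -> None:
        self.param.value -= self.lr * self.param.grad


def manual_training_step(
    params: Sequence[ToyParam], optimizers: Sequence[ToySGD], target: float
) -> List[float]:
    """One manual-optimization step: each optimizer is stepped explicitly."""
    losses = []
    for param, opt in zip(params, optimizers):
        opt.zero_grad()
        # Assumed loss (w - target)^2, whose gradient is 2 * (w - target).
        loss = (param.value - target) ** 2
        # Stand-in for Lightning's self.manual_backward(loss).
        param.grad += 2.0 * (param.value - target)
        opt.step()
        losses.append(loss)
    return losses
```

With two parameters starting at 0.0 and 4.0, learning rates 0.25 and 0.5, and a target of 2.0, one step moves them to 1.0 and 2.0 respectively: each optimizer updates only its own model, at its own rate.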
- property models
The models used for training.
- Returns:
The models used for training.
- Return type:
- property optimizers
The optimizers used for training.
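Since `models` and `optimizers` are documented as properties, they are read as attributes (`solver.models`, not `solver.models()`). The sketch below shows that read-only property pattern with a hypothetical `ToyMultiSolver` class and placeholder strings. It assumes, without confirmation from the documentation, that the i-th optimizer belongs to the i-th model, so the two lists can be zipped together.

```python
from typing import Any, List, Sequence


class ToyMultiSolver:
    """Hypothetical stand-in exposing `models` and `optimizers` as
    read-only properties."""

    def __init__(self, models: Sequence[Any], optimizers: Sequence[Any]) -> None:
        self._models = list(models)
        self._optimizers = list(optimizers)

    @property
    def models(self) -> List[Any]:
        """The models used for training."""
        return self._models

    @property
    def optimizers(self) -> List[Any]:
        """The optimizers used for training."""
        return self._optimizers


solver = ToyMultiSolver(["net_a", "net_b"], ["adam_a", "adam_b"])
# Assumption: optimizers are stored in the same order as models.
pairs = list(zip(solver.models, solver.optimizers))
```

Because no setter is defined, assigning `solver.models = []` raises `AttributeError`, which keeps the model list fixed after construction.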