SupervisedSolver
- class SupervisedSolver(problem, model, loss=None, optimizer=None, scheduler=None, weighting=None, use_lt=True)
Bases: SupervisedSolverInterface, SingleSolverInterface
Supervised Solver class. This class implements a Supervised Solver, using a user-specified model to solve a specific problem. The Supervised Solver class aims to find a map between the input \(\mathbf{s}:\Omega\rightarrow\mathbb{R}^m\) and the output \(\mathbf{u}:\Omega\rightarrow\mathbb{R}^m\).
Given a model \(\mathcal{M}\), the following loss function is minimized during training:
\[\mathcal{L}_{\rm{problem}} = \frac{1}{N}\sum_{i=1}^N \mathcal{L}(\mathbf{u}_i - \mathcal{M}(\mathbf{s}_i)),\]
where \(\mathcal{L}\) is a specific loss function, typically the MSE:
\[\mathcal{L}(v) = \| v \|^2_2.\]
In this context, the pairs \(\mathbf{u}_i\) and \(\mathbf{s}_i\) indicate that the goal is to approximate multiple (discretised) output functions given multiple (discretised) input functions.
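The loss above can be checked numerically with plain PyTorch. The sketch below is illustrative only: the sizes N and m, the random data, and the linear layer standing in for \(\mathcal{M}\) are all hypothetical. It also shows that torch.nn.MSELoss with its default reduction="mean" averages over all N·m entries, so it equals the formula above divided by m.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy setup: N samples of discretised input/output functions,
# each flattened to m values, and a linear layer standing in for the model M.
N, m = 8, 4
s = torch.randn(N, m)          # inputs s_i
u = torch.randn(N, m)          # targets u_i
model = torch.nn.Linear(m, m)  # stand-in for M (assumption, not the solver's model)

with torch.no_grad():
    residual = u - model(s)  # u_i - M(s_i), shape (N, m)
    # L_problem = (1/N) * sum_i ||u_i - M(s_i)||_2^2
    loss_formula = residual.pow(2).sum(dim=1).mean()
    # torch.nn.MSELoss (reduction="mean") averages over all N*m entries
    loss_mse = torch.nn.MSELoss()(model(s), u)

# The two agree up to the factor m (and floating-point rounding)
print(loss_formula.item(), loss_mse.item(), torch.allclose(loss_formula, m * loss_mse))
```

The factor m is a constant, so it rescales the loss (and effectively the learning rate) without changing the minimiser.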
Initialization of the SupervisedSolver class.
- Parameters:
problem (AbstractProblem) – The problem to be solved.
model (torch.nn.Module) – The neural network model to be used.
loss (torch.nn.Module) – The loss function to be minimized. If None, the torch.nn.MSELoss loss is used. Default is None.
optimizer (Optimizer) – The optimizer to be used. If None, the torch.optim.Adam optimizer is used. Default is None.
scheduler (Scheduler) – Learning rate scheduler. If None, the torch.optim.lr_scheduler.ConstantLR scheduler is used. Default is None.
weighting (WeightingInterface) – The weighting schema to be used. If None, no weighting schema is used. Default is None.
use_lt (bool) – If True, the solver uses LabelTensors as input. Default is True.
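The documented defaults for loss, optimizer and scheduler can be mimicked with plain PyTorch objects. This is a minimal sketch under stated assumptions, not the solver's implementation: resolve_defaults is a hypothetical helper, the solver wraps optimizers and schedulers in its own Optimizer and Scheduler types, and the arguments it passes to Adam and ConstantLR are not stated here, so PyTorch's own defaults are used (Adam with lr=1e-3; ConstantLR with factor=1/3 and total_iters=5, which scales the learning rate by 1/3 for the first five scheduler steps and then restores it).

```python
import torch
from torch.optim.lr_scheduler import ConstantLR


def resolve_defaults(model, loss=None, optimizer=None, scheduler=None):
    """Hypothetical helper mirroring the documented None-defaults with plain PyTorch objects."""
    if loss is None:
        loss = torch.nn.MSELoss()  # loss=None -> torch.nn.MSELoss
    if optimizer is None:
        # optimizer=None -> torch.optim.Adam (PyTorch default lr=1e-3; an assumption)
        optimizer = torch.optim.Adam(model.parameters())
    if scheduler is None:
        # scheduler=None -> ConstantLR (PyTorch default arguments; an assumption)
        scheduler = ConstantLR(optimizer)
    return loss, optimizer, scheduler


torch.manual_seed(0)
model = torch.nn.Sequential(torch.nn.Linear(2, 16), torch.nn.Tanh(), torch.nn.Linear(16, 1))
loss_fn, optimizer, scheduler = resolve_defaults(model)

# Hypothetical toy data: learn y = x_1 + x_2
x = torch.rand(32, 2)
y = x.sum(dim=1, keepdim=True)

for _ in range(10):
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()
    scheduler.step()  # after 5 steps ConstantLR restores the base learning rate

print(type(loss_fn).__name__, type(optimizer).__name__, type(scheduler).__name__)
print(optimizer.param_groups[0]["lr"])
```

Passing explicit objects instead of None simply skips the corresponding default in this sketch.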
- loss_data(input, target)
Compute the data loss for the Supervised solver by evaluating the loss between the network’s output and the true solution. This method should not be overridden unless intentionally.
- Parameters:
input (LabelTensor | torch.Tensor | Graph | Data) – The input to the neural network.
target (LabelTensor | torch.Tensor | Graph | Data) – The target to compare with the network’s output.
- Returns:
The supervised loss, averaged over the number of observations.
- Return type:
LabelTensor | torch.Tensor | Graph | Data
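For plain torch.Tensor inputs, the computation described for loss_data can be sketched as "evaluate the model, then apply the loss". The function below is a hypothetical standalone analogue that takes the model and loss explicitly; it does not reproduce the solver's handling of LabelTensor, Graph or Data inputs, nor any weighting schema. The parameter name input mirrors the documented signature even though it shadows a Python builtin.

```python
import torch


def loss_data(model, loss_fn, input, target):
    """Hypothetical standalone analogue of SupervisedSolver.loss_data for plain tensors."""
    output = model(input)           # network prediction M(s)
    return loss_fn(output, target)  # loss between prediction and true solution


torch.manual_seed(0)
model = torch.nn.Linear(3, 2)
loss_fn = torch.nn.MSELoss()  # the documented default loss
inp = torch.randn(5, 3)
tgt = torch.randn(5, 2)

value = loss_data(model, loss_fn, inp, tgt)
value.backward()  # gradients flow back to the model parameters
print(value.item(), tuple(model.weight.grad.shape))
```

Because MSELoss uses reduction="mean" by default, the returned value is already averaged, matching the "averaged over the number of observations" description up to the per-entry averaging noted for MSELoss.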