SupervisedSolver
- class SupervisedSolver(problem, model, loss=None, optimizer=None, scheduler=None, weighting=None, use_lt=True)
Bases: SingleSolverInterface
Supervised Solver class. This class implements a Supervised Solver, using a user-specified model to solve a specific problem.
The Supervised Solver class aims to find a map between the input \(\mathbf{s}:\Omega\rightarrow\mathbb{R}^m\) and the output \(\mathbf{u}:\Omega\rightarrow\mathbb{R}^m\).
Given a model \(\mathcal{M}\), the following loss function is minimized during training:
\[\mathcal{L}_{\rm{problem}} = \frac{1}{N}\sum_{i=1}^N \mathcal{L}(\mathbf{u}_i - \mathcal{M}(\mathbf{s}_i)),\]
where \(\mathcal{L}\) is a specific loss function, typically the MSE:
\[\mathcal{L}(v) = \| v \|^2_2.\]
In this context, the index \(i\) on \(\mathbf{u}_i\) and \(\mathbf{s}_i\) reflects that the solver approximates multiple (discretised) output functions given multiple (discretised) input functions.
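As a worked illustration of the loss above, the following is a minimal, framework-free sketch that averages the squared L2 norm of the residual over N input/target pairs. It is not the solver's implementation: the helper names `squared_l2`, `problem_loss`, and `toy_model` are hypothetical, and plain Python lists stand in for tensors.

```python
from typing import Callable, Sequence

Vector = Sequence[float]


def squared_l2(v: Vector) -> float:
    """L(v) = ||v||_2^2, the squared Euclidean norm."""
    return sum(x * x for x in v)


def problem_loss(
    model: Callable[[Vector], Vector],
    inputs: Sequence[Vector],
    targets: Sequence[Vector],
) -> float:
    """Average L(target - model(input)) over the N input/target pairs."""
    if len(inputs) != len(targets) or not inputs:
        raise ValueError("inputs and targets must be non-empty and equally long")
    total = 0.0
    for sample, target in zip(inputs, targets):
        prediction = model(sample)
        residual = [t - p for t, p in zip(target, prediction)]
        total += squared_l2(residual)
    return total / len(inputs)


def toy_model(sample: Vector) -> Vector:
    """Hypothetical stand-in for the model: doubles every component."""
    return [2.0 * x for x in sample]


inputs = [[1.0, 2.0], [0.0, 1.0]]
targets = [[2.0, 4.0], [1.0, 2.0]]
loss_value = problem_loss(toy_model, inputs, targets)
print(loss_value)
```

The first pair is reproduced exactly and the second has residual (1, 0), so the average is 0.5. Note that torch.nn.MSELoss with its default `reduction="mean"` averages over every element rather than over observations only, so its value differs from this sketch by a constant factor equal to the output dimension.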
Initialization of the SupervisedSolver class.
- Parameters:
problem (AbstractProblem) – The problem to be solved.
model (torch.nn.Module) – The neural network model to be used.
loss (torch.nn.Module) – The loss function to be minimized. If None, the torch.nn.MSELoss loss is used. Default is None.
optimizer (Optimizer) – The optimizer to be used. If None, the torch.optim.Adam optimizer is used. Default is None.
scheduler (Scheduler) – Learning rate scheduler. If None, the torch.optim.lr_scheduler.ConstantLR scheduler is used. Default is None.
weighting (WeightingInterface) – The weighting schema to be used. If None, no weighting schema is used. Default is None.
use_lt (bool) – If True, the solver uses LabelTensors as input. Default is True.
- accepted_conditions_types
alias of InputTargetCondition
- optimization_cycle(batch)
The optimization cycle for the solvers.
- Parameters:
batch (list[tuple[str, dict]]) – A batch of data. Each element is a tuple containing a condition name and a dictionary of points.
- Returns:
The losses computed for all conditions in the batch, cast to a subclass of torch.Tensor. The method returns a dict mapping each condition name to its associated scalar loss.
- Return type:
dict
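
The batch layout and return value described for optimization_cycle can be sketched in plain Python as follows. This is an illustration under stated assumptions, not the library's code: the dictionary keys "input" and "target", the function names, and the toy model are all hypothetical, and floats stand in for the scalar tensors the real method returns.

```python
from typing import Callable, Dict, List, Sequence, Tuple

Vector = Sequence[float]
# Assumed layout of the points dictionary: {"input": [...], "target": [...]}.
Points = Dict[str, List[Vector]]


def mean_squared_error(predictions: List[Vector], targets: List[Vector]) -> float:
    """Squared L2 norm of each residual, averaged over the observations."""
    total = 0.0
    for prediction, target in zip(predictions, targets):
        total += sum((t - p) ** 2 for t, p in zip(target, prediction))
    return total / len(targets)


def optimization_cycle_sketch(
    model: Callable[[Vector], Vector],
    batch: List[Tuple[str, Points]],
) -> Dict[str, float]:
    """Return one scalar loss per condition name found in the batch."""
    losses: Dict[str, float] = {}
    for condition_name, points in batch:
        predictions = [model(x) for x in points["input"]]
        losses[condition_name] = mean_squared_error(predictions, points["target"])
    return losses


def shift_model(x: Vector) -> Vector:
    """Hypothetical stand-in for the model: adds one to a scalar input."""
    return [x[0] + 1.0]


batch = [
    ("data_a", {"input": [[0.0], [1.0]], "target": [[1.0], [2.0]]}),
    ("data_b", {"input": [[2.0]], "target": [[5.0]]}),
]
losses = optimization_cycle_sketch(shift_model, batch)
print(losses)
```

Keeping one entry per condition name, rather than a single summed value, leaves room for a weighting schema to combine the per-condition losses afterwards.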
- loss_data(input_pts, output_pts)
Compute the data loss for the Supervised Solver by evaluating the loss between the network's output and the true solution. This method should not be overridden unless the override is intentional.
- Parameters:
input_pts (LabelTensor | torch.Tensor) – The input points to the neural network.
output_pts (LabelTensor | torch.Tensor) – The true solution to compare with the network’s output.
- Returns:
The supervised loss, averaged over the number of observations.
- Return type:
- property loss
The loss function to be minimized.
- Returns:
The loss function to be minimized.
- Return type: