athena.nll.NonlinearLevelSet.train
NonlinearLevelSet.train(inputs, gradients, outputs=None, interactive=False, target_loss=0.0001, optim_args=None, scheduler_args=None)
Train the whole RevNet.
- Parameters
inputs (torch.Tensor) – DoubleTensor n_samples-by-n_params containing the points in the full input space.
gradients (torch.Tensor) – DoubleTensor n_samples-by-n_params containing the gradient samples with respect to the input parameters.
outputs (numpy.ndarray) – array of shape n_samples-by-1 containing the corresponding function evaluations. Needed only in interactive mode. Default is None.
interactive (bool) – if True, a plot of the loss decay and the sufficient summary plot are shown and updated every 10 epochs and at the last epoch. Default is False.
target_loss (float) – loss threshold: training stops early once the loss falls below this value. Default is 0.0001.
optim_args (dict) – dictionary of keyword arguments passed to the optimizer. Default is None.
scheduler_args (dict) – dictionary of keyword arguments passed to the learning rate scheduler. Default is None.
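A minimal sketch of preparing arguments with the shapes and types listed above, assuming NumPy and PyTorch are installed. The test function exp(sum(x)), the sample sizes, and the values in optim_args and scheduler_args are illustrative choices, not library defaults. The last lines assume the two dictionaries are unpacked as keyword arguments into torch.optim.Adam and torch.optim.lr_scheduler.StepLR; the optimizer and scheduler actually used by NonlinearLevelSet may differ.

```python
import numpy as np
import torch

# Illustrative test function f(x) = exp(sum(x)); its gradient is f(x) * ones.
n_samples, n_params = 200, 4
rng = np.random.default_rng(0)
x = rng.uniform(-1.0, 1.0, size=(n_samples, n_params))
f = np.exp(x.sum(axis=1, keepdims=True))  # n_samples-by-1 function evaluations
df = f * np.ones_like(x)                  # n_samples-by-n_params gradients

# train expects DoubleTensors; torch.from_numpy keeps the float64 dtype.
inputs = torch.from_numpy(x)
gradients = torch.from_numpy(df)
outputs = f  # numpy.ndarray, only needed when interactive=True

# Assumption: the dictionaries are forwarded as keyword arguments, e.g. to
# Adam and StepLR. The parameter list below is a stand-in, not the RevNet's.
optim_args = {"lr": 0.01}
scheduler_args = {"step_size": 100, "gamma": 0.5}
params = [torch.zeros(3, dtype=torch.float64, requires_grad=True)]
optimizer = torch.optim.Adam(params, **optim_args)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, **scheduler_args)

# With an already constructed NonlinearLevelSet instance (hypothetical name
# `nll`), the call would then look like:
# nll.train(inputs, gradients, outputs=outputs, interactive=False,
#           target_loss=1e-4, optim_args=optim_args,
#           scheduler_args=scheduler_args)
```

Passing plain dictionaries keeps the signature of train short while letting the caller tune any optimizer or scheduler option.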
- Raises
ValueError: if interactive is True and outputs is not provided, since the sufficient summary plot needs the function evaluations.
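The interplay of interactive, outputs and target_loss can be sketched in plain Python. This is not the library's code: train_control_flow is a hypothetical function, a precomputed list of losses replaces the actual RevNet training step, and both the early-stopping rule (stop once the loss falls below target_loss) and the extra plot refresh at an early-stopping epoch are assumptions.

```python
from typing import List, Optional, Sequence, Tuple


def train_control_flow(
    losses: Sequence[float],
    outputs: Optional[Sequence[float]] = None,
    interactive: bool = False,
    target_loss: float = 1e-4,
    refresh_every: int = 10,
) -> Tuple[int, List[int]]:
    """Hypothetical sketch of the control flow described for train.

    `losses` stands in for the loss computed at each epoch; no network is
    trained. Returns the last epoch that was run and the epochs at which
    the interactive plots would be redrawn.
    """
    # Interactive mode draws the sufficient summary plot, which needs outputs.
    if interactive and outputs is None:
        raise ValueError("outputs must be provided in interactive mode")

    n_epochs = len(losses)
    refreshed: List[int] = []
    last_epoch = -1
    for epoch, loss in enumerate(losses):
        last_epoch = epoch
        # Assumed early-stopping rule: stop once the loss is below target_loss.
        reached_target = loss < target_loss
        is_last = reached_target or epoch == n_epochs - 1
        # Redraw every `refresh_every` epochs and at the final epoch.
        if interactive and (epoch % refresh_every == 0 or is_last):
            refreshed.append(epoch)
        if reached_target:
            break
    return last_epoch, refreshed
```

For 25 epochs whose losses never reach the threshold, interactive mode redraws at epochs 0, 10, 20 and 24; calling it with interactive=True and no outputs raises ValueError before any epoch runs.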