athena.nll.NonlinearLevelSet.train

NonlinearLevelSet.train(inputs, gradients, outputs=None, interactive=False, target_loss=0.0001, optim_args=None, scheduler_args=None)

Train the whole RevNet.

Parameters
  • inputs (torch.Tensor) – DoubleTensor n_samples-by-n_params containing the points in the full input space.

  • gradients (torch.Tensor) – DoubleTensor n_samples-by-n_params containing the gradient samples with respect to the input parameters.

  • outputs (numpy.ndarray) – array n_samples-by-1 containing the corresponding function evaluations. Required only in interactive mode. Default is None.

  • interactive (bool) – if True, a plot of the loss decay and the sufficient summary plot are shown and updated every 10 epochs and at the last epoch. Default is False.

  • target_loss (float) – loss threshold used as the stopping criterion: training ends once the loss falls below this value. Default is 0.0001.

  • optim_args (dict) – dictionary of keyword arguments passed to the optimizer. Default is None.

  • scheduler_args (dict) – dictionary of keyword arguments passed to the learning rate scheduler. Default is None.
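As a hedged illustration of the input formats listed above, the sketch below builds inputs, gradients, and outputs for a hypothetical toy function f(x) = x_1^2 + ... + x_n^2 (not taken from the ATHENA documentation), whose gradient 2x is known analytically. It assumes NumPy and PyTorch are installed and only prepares the data; it does not construct a NonlinearLevelSet.

```python
import numpy as np
import torch

# Hypothetical toy problem: f(x) = sum_i x_i^2 sampled on [-1, 1]^n_params
n_samples, n_params = 200, 4
rng = np.random.default_rng(seed=0)
x = rng.uniform(-1.0, 1.0, size=(n_samples, n_params))

# Function evaluations, kept as a NumPy array of shape n_samples-by-1 (`outputs`)
outputs = np.sum(x**2, axis=1, keepdims=True)

# Analytic gradient df/dx_i = 2 x_i, one row per sample
df = 2.0 * x

# `inputs` and `gradients` must be DoubleTensors (torch.float64), n_samples-by-n_params
inputs = torch.as_tensor(x, dtype=torch.float64)
gradients = torch.as_tensor(df, dtype=torch.float64)
```

Given an already constructed NonlinearLevelSet instance, here called nll (its constructor is outside this section), the call would be nll.train(inputs, gradients, outputs=outputs, interactive=True); with interactive=False, outputs can be omitted.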

Raises
  • ValueError – if interactive is True and outputs is None, since the outputs are needed for the sufficient summary plot.
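The control flow implied by the description above (the outputs check, the plot refresh schedule, and the loss threshold) can be sketched as follows. This is a hypothetical, simplified skeleton, not ATHENA's implementation: run_training, step, and refresh are illustrative names, and treating target_loss as an early-stopping threshold is an assumption.

```python
from typing import Callable, List, Optional, Sequence


def run_training(
    step: Callable[[int], float],
    epochs: int,
    target_loss: float = 1e-4,
    interactive: bool = False,
    outputs: Optional[Sequence] = None,
    refresh: Optional[Callable[[int, List[float]], None]] = None,
) -> List[float]:
    """Hypothetical skeleton of the control flow documented for train().

    `step(epoch)` stands in for one optimisation step and returns the
    current loss; `refresh(epoch, losses)` stands in for redrawing the
    loss and sufficient summary plots.
    """
    # Documented check: interactive plots need the function evaluations.
    if interactive and outputs is None:
        raise ValueError('in interactive mode outputs must be provided '
                         'for the sufficient summary plot.')

    losses: List[float] = []
    for epoch in range(epochs):
        loss = step(epoch)
        losses.append(loss)
        # Assumption: target_loss acts as an early-stopping threshold.
        stop = loss < target_loss
        is_last = stop or epoch == epochs - 1
        # Refresh every 10 epochs and at the last epoch actually run.
        if interactive and refresh is not None and (epoch % 10 == 0 or is_last):
            refresh(epoch, losses)
        if stop:
            break
    return losses
```

Passing a stand-in step function makes the stopping and refresh logic testable without PyTorch; in practice each step would presumably perform a forward pass, a backward pass, and an optimizer update.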