GAROM
- class GAROM(problem, generator, discriminator, loss=None, optimizer_generator=None, optimizer_discriminator=None, scheduler_generator=None, scheduler_discriminator=None, gamma=0.3, lambda_k=0.001, regularizer=False)
Bases: MultiSolverInterface
GAROM solver class. This class implements the Generative Adversarial Reduced Order Model (GAROM) solver, which uses user-specified models to solve a specific order reduction problem.
See also
Original reference: Coscia, D., Demo, N., & Rozza, G. (2023). Generative Adversarial Reduced Order Modelling. arXiv preprint arXiv:2305.15881.
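For orientation, here is a sketch of the adversarial objective. It assumes the BEGAN-style formulation that the referenced paper builds on. D is an autoencoder-type discriminator conditioned on the problem parameters μ, and L measures its reconstruction error. G maps noise z and μ to a snapshot, u is a real snapshot, and k_t ∈ [0, 1] is the balancing term driven by gamma and lambda_k. The symbols are chosen here for illustration and may not match the paper's notation exactly.

```latex
\begin{aligned}
\mathcal{L}(\mathbf{u}, \boldsymbol{\mu}) &= \left\lVert \mathbf{u} - D(\mathbf{u}, \boldsymbol{\mu}) \right\rVert, \\
\mathcal{L}_D &= \mathcal{L}(\mathbf{u}, \boldsymbol{\mu}) - k_t \, \mathcal{L}\big(G(\mathbf{z}, \boldsymbol{\mu}), \boldsymbol{\mu}\big), \\
\mathcal{L}_G &= \mathcal{L}\big(G(\mathbf{z}, \boldsymbol{\mu}), \boldsymbol{\mu}\big).
\end{aligned}
```

The discriminator is trained to reconstruct real snapshots well and generated ones poorly. The generator is trained to produce snapshots that the discriminator reconstructs well.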
Initialization of the GAROM class.
- Parameters:
problem (AbstractProblem) – The formulation of the problem.
generator (torch.nn.Module) – The generator model.
discriminator (torch.nn.Module) – The discriminator model.
loss (torch.nn.Module) – The loss function to be minimized. If None, PowerLoss with p=1 is used. Default is None.
optimizer_generator (Optimizer) – The optimizer for the generator. If None, the torch.optim.Adam optimizer is used. Default is None.
optimizer_discriminator (Optimizer) – The optimizer for the discriminator. If None, the torch.optim.Adam optimizer is used. Default is None.
scheduler_generator (Scheduler) – The learning rate scheduler for the generator. If None, the torch.optim.lr_scheduler.ConstantLR scheduler is used. Default is None.
scheduler_discriminator (Scheduler) – The learning rate scheduler for the discriminator. If None, the torch.optim.lr_scheduler.ConstantLR scheduler is used. Default is None.
gamma (float) – Ratio of expected loss for generator and discriminator. Default is 0.3.
lambda_k (float) – Learning rate for control theory optimization. Default is 0.001.
regularizer (bool) – If True, uses a regularization term in the GAROM loss. Default is False.
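The roles of gamma and lambda_k can be illustrated with a minimal sketch. It assumes the BEGAN-style proportional control that GAROM builds on. After each step, the balancing term k moves by lambda_k * (gamma * L_real - L_fake) and is then clipped to [0, 1]. Here L_real and L_fake are the discriminator's reconstruction losses on real and generated snapshots. The function name update_k and the loss values are hypothetical, and the sketch does not reproduce the solver's own code.

```python
def update_k(k: float, loss_real: float, loss_fake: float,
             gamma: float = 0.3, lambda_k: float = 0.001) -> float:
    """One proportional-control step on the balancing term k (BEGAN-style sketch).

    loss_real: discriminator reconstruction loss on real snapshots.
    loss_fake: discriminator reconstruction loss on generated snapshots.
    """
    # Positive when generated samples are reconstructed more easily than the
    # target ratio allows (loss_fake < gamma * loss_real); k then grows, so the
    # discriminator puts more weight on penalizing generated samples.
    balance = gamma * loss_real - loss_fake
    k = k + lambda_k * balance
    # Keep k inside [0, 1].
    return min(max(k, 0.0), 1.0)


if __name__ == "__main__":
    k = 0.0
    # Hypothetical (loss_real, loss_fake) pairs, for illustration only.
    for loss_real, loss_fake in [(1.0, 0.1), (1.0, 0.1), (1.0, 0.5)]:
        k = update_k(k, loss_real, loss_fake)
        print(f"k = {k:.6f}")
```

A small lambda_k makes k change slowly, so the balance between the two networks drifts gradually rather than oscillating from one step to the next.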
- accepted_conditions_types
Alias of InputTargetCondition.
- forward(x, mc_steps=20, variance=False)
Forward pass: estimates the expected value of the generator distribution for the input x by Monte Carlo sampling.
- Parameters:
x (torch.Tensor) – The input tensor.
mc_steps (int) – Number of Monte Carlo samples used to approximate the expected value. Default is 20.
variance (bool) – If True, the method also returns the variance of the solution. Default is False.
- Returns:
The expected value of the generator distribution. If variance=True, the variance is also returned.
- Return type:
torch.Tensor | tuple[torch.Tensor, torch.Tensor]
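The Monte Carlo estimate works by averaging several independent draws from the generator for the same input. Below is a minimal sketch in plain Python. It assumes the generator is any callable that returns a fresh random sample on each call. The names mc_forward and noisy_generator are hypothetical, and plain lists stand in for torch.Tensor. The use of the unbiased sample variance is also an assumption, not a statement about the solver's implementation.

```python
import random
import statistics
from typing import Callable, List, Tuple, Union

Vector = List[float]


def mc_forward(
    sample: Callable[[Vector], Vector],
    x: Vector,
    mc_steps: int = 20,
    variance: bool = False,
) -> Union[Vector, Tuple[Vector, Vector]]:
    """Monte Carlo estimate of the mean (and optionally variance) of sample(x)."""
    if variance and mc_steps < 2:
        raise ValueError("at least two samples are needed to estimate a variance")
    # Draw mc_steps independent samples for the same input x.
    draws = [sample(x) for _ in range(mc_steps)]
    # Group the draws component-wise: columns[i] holds every draw of entry i.
    columns = list(zip(*draws))
    mean = [statistics.fmean(col) for col in columns]
    if not variance:
        return mean
    # Unbiased sample variance of each component across the draws.
    var = [statistics.variance(col) for col in columns]
    return mean, var


if __name__ == "__main__":
    rng = random.Random(0)

    def noisy_generator(x: Vector) -> Vector:
        # Hypothetical stochastic generator: x plus Gaussian noise (std 0.1).
        return [xi + rng.gauss(0.0, 0.1) for xi in x]

    mean, var = mc_forward(noisy_generator, [1.0, 2.0], mc_steps=20, variance=True)
    print(mean, var)
```

Raising mc_steps reduces the Monte Carlo error of the mean roughly in proportion to 1/sqrt(mc_steps), at the cost of one generator call per sample.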
- sample(x)
Draw a single sample from the generator distribution for the input x.
- Parameters:
x (torch.Tensor) – The input tensor.
- Returns:
The generated sample.
- Return type:
torch.Tensor
- on_train_batch_end(outputs, batch, batch_idx)
This method is called at the end of each training batch and overrides the PyTorch Lightning implementation to log checkpoints.
- optimization_cycle(batch)
The optimization cycle for the GAROM solver.
- Parameters:
batch (list[tuple[str, dict]]) – A batch of data. Each element is a tuple containing a condition name and a dictionary of points.
- Returns:
A dictionary mapping each condition name in the batch to its associated scalar loss. The loss values are cast to a subclass of torch.Tensor.
- Return type:
dict
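The batch layout and the return contract can be illustrated with a small sketch. It assumes each condition's dictionary of points stores its data under hypothetical "input" and "target" keys. It uses plain lists instead of tensors and a mean absolute error as a stand-in loss, comparable in spirit to a power loss with p=1. Only the collection of per-condition scalar losses into a dict is shown; GAROM's adversarial generator and discriminator updates are not reproduced.

```python
from typing import Callable, Dict, List, Sequence, Tuple


def mean_abs_error(pred: Sequence[float], target: Sequence[float]) -> float:
    # Stand-in loss: mean absolute difference between prediction and target.
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(target)


def optimization_cycle(
    batch: List[Tuple[str, dict]],
    model: Callable[[List[float]], List[float]],
) -> Dict[str, float]:
    """Return one scalar loss per condition name (hypothetical sketch)."""
    losses: Dict[str, float] = {}
    for condition_name, points in batch:
        # "input" and "target" are assumed keys of an input-target condition.
        prediction = model(points["input"])
        losses[condition_name] = mean_abs_error(prediction, points["target"])
    return losses


if __name__ == "__main__":
    # Hypothetical batch with a single condition named "data".
    batch = [("data", {"input": [1.0, 2.0], "target": [1.5, 2.5]})]
    print(optimization_cycle(batch, lambda v: list(v)))
```

Keying the result by condition name lets the caller log or weight each condition's loss separately before summing them.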
- validation_step(batch)
The validation step for the GAROM solver.
- test_step(batch)
The test step for the GAROM solver.
- property generator
The generator model.
- Returns:
The generator model.
- Return type:
torch.nn.Module
- property discriminator
The discriminator model.
- Returns:
The discriminator model.
- Return type:
torch.nn.Module
- property optimizer_generator
The optimizer for the generator.
- Returns:
The optimizer for the generator.
- Return type:
Optimizer
- property optimizer_discriminator
The optimizer for the discriminator.
- Returns:
The optimizer for the discriminator.
- Return type:
Optimizer
- property scheduler_generator
The scheduler for the generator.
- Returns:
The scheduler for the generator.
- Return type:
Scheduler