"""Module for the GAROM solver."""
import torch
from torch.nn.modules.loss import _Loss
from .solver import MultiSolverInterface
from ..condition import InputTargetCondition
from ..utils import check_consistency
from ..loss import LossInterface, PowerLoss


class GAROM(MultiSolverInterface):
"""
    GAROM solver class. This class implements the Generative Adversarial
    Reduced Order Model solver, which uses the user-specified ``models`` to
    solve a reduced order modelling ``problem``.

    .. seealso::
        **Original reference**: Coscia, D., Demo, N., & Rozza, G. (2023).
        *Generative Adversarial Reduced Order Modelling*.
        DOI: `arXiv:2305.15881
        <https://doi.org/10.48550/arXiv.2305.15881>`_.
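
    :Example:
        A minimal usage sketch; ``problem``, ``Generator`` and
        ``Discriminator`` are placeholders for a user-defined
        :class:`~pina.problem.AbstractProblem` and two
        :class:`torch.nn.Module` instances.

        >>> solver = GAROM(
        ...     problem=problem,
        ...     generator=Generator(),
        ...     discriminator=Discriminator(),
        ... )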
"""
    accepted_conditions_types = InputTargetCondition

    def __init__(
self,
problem,
generator,
discriminator,
loss=None,
optimizer_generator=None,
optimizer_discriminator=None,
scheduler_generator=None,
scheduler_discriminator=None,
gamma=0.3,
lambda_k=0.001,
regularizer=False,
):
"""
        Initialization of the :class:`GAROM` class.

:param AbstractProblem problem: The formulation of the problem.
:param torch.nn.Module generator: The generator model.
:param torch.nn.Module discriminator: The discriminator model.
:param torch.nn.Module loss: The loss function to be minimized.
If ``None``, :class:`~pina.loss.power_loss.PowerLoss` with ``p=1``
is used. Default is ``None``.
        :param Optimizer optimizer_generator: The optimizer for the generator.
            If ``None``, the :class:`torch.optim.Adam` optimizer is used.
            Default is ``None``.
        :param Optimizer optimizer_discriminator: The optimizer for the
            discriminator. If ``None``, the :class:`torch.optim.Adam`
            optimizer is used. Default is ``None``.
        :param Scheduler scheduler_generator: The learning rate scheduler for
            the generator. If ``None``, the
            :class:`torch.optim.lr_scheduler.ConstantLR` scheduler is used.
            Default is ``None``.
        :param Scheduler scheduler_discriminator: The learning rate scheduler
            for the discriminator. If ``None``, the
            :class:`torch.optim.lr_scheduler.ConstantLR` scheduler is used.
            Default is ``None``.
        :param float gamma: Equilibrium ratio between the expected
            discriminator losses on generated and real samples.
            Default is ``0.3``.
        :param float lambda_k: Learning rate for the control-theoretic
            update of the balancing parameter ``k``. Default is ``0.001``.
:param bool regularizer: If ``True``, uses a regularization term in the
GAROM loss. Default is ``False``.
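
        .. note::
            The balancing parameter ``k`` is adjusted during training with
            the BEGAN-style control update implemented in
            :meth:`_update_weights`, namely
            ``k <- k + lambda_k * (gamma * d_loss_real - d_loss_fake)``,
            and is then clipped to the interval ``[0, 1]``.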
"""
# set loss
if loss is None:
loss = PowerLoss(p=1)
super().__init__(
models=[generator, discriminator],
problem=problem,
optimizers=[optimizer_generator, optimizer_discriminator],
schedulers=[
scheduler_generator,
scheduler_discriminator,
],
use_lt=False,
)
# check consistency
check_consistency(
loss, (LossInterface, _Loss, torch.nn.Module), subclass=False
)
self._loss = loss
        # disable automatic optimization: GANs require manual optimization
self.automatic_optimization = False
# check consistency
check_consistency(gamma, float)
check_consistency(lambda_k, float)
check_consistency(regularizer, bool)
        # BEGAN (Boundary Equilibrium GAN) hyperparameters
self.k = 0
self.gamma = gamma
self.lambda_k = lambda_k
self.regularizer = float(regularizer)

    def forward(self, x, mc_steps=20, variance=False):
"""
        Forward pass implementation. The expected value of the generator
        distribution is approximated by Monte Carlo sampling.

        :param torch.Tensor x: The input tensor.
        :param int mc_steps: Number of Monte Carlo samples used to
            approximate the expected value. Default is ``20``.
        :param bool variance: If ``True``, the method also returns the
            variance of the solution. Default is ``False``.
        :return: The expected value of the generator distribution. If
            ``variance=True``, the variance is returned as well.
        :rtype: torch.Tensor | tuple[torch.Tensor, torch.Tensor]
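
        :Example:
            An illustrative call, assuming ``solver`` is a trained
            :class:`GAROM` instance and ``x`` a batch of input parameters:

            >>> mean = solver.forward(x, mc_steps=50)
            >>> mean, var = solver.forward(x, mc_steps=50, variance=True)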
"""
# sampling
field_sample = [self.sample(x) for _ in range(mc_steps)]
field_sample = torch.stack(field_sample)
# extract mean
mean = field_sample.mean(dim=0)
if variance:
var = field_sample.var(dim=0)
return mean, var
return mean

    def sample(self, x):
"""
        Sample from the generator distribution.

:param torch.Tensor x: The input tensor.
:return: The generated sample.
:rtype: torch.Tensor
"""
# sampling
return self.generator(x)

    def _train_generator(self, parameters, snapshots):
"""
        Train the generator model.

:param torch.Tensor parameters: The input tensor.
:param torch.Tensor snapshots: The target tensor.
:return: The residual loss and the generator loss.
:rtype: tuple[torch.Tensor, torch.Tensor]
"""
optimizer = self.optimizer_generator
optimizer.zero_grad()
generated_snapshots = self.sample(parameters)
# generator loss
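        # BEGAN-style objective: the generator minimizes the discriminator's
        # reconstruction error on the generated samples, plus the residual
        # r_loss when the regularizer is enabled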
r_loss = self._loss(snapshots, generated_snapshots)
d_fake = self.discriminator([generated_snapshots, parameters])
g_loss = (
self._loss(d_fake, generated_snapshots) + self.regularizer * r_loss
)
# backward step
g_loss.backward()
optimizer.step()
return r_loss, g_loss

    def on_train_batch_end(self, outputs, batch, batch_idx):
"""
This method is called at the end of each training batch and overrides
        the PyTorch Lightning implementation to log checkpoints.

:param torch.Tensor outputs: The ``model``'s output for the current
batch.
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
tuple containing a condition name and a dictionary of points.
:param int batch_idx: The index of the current batch.
"""
        # advance the manual-optimization step counter so that loggers and
        # checkpoints are saved at the correct intervals
(
self.trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.total.completed
) += 1
return super().on_train_batch_end(outputs, batch, batch_idx)

    def _train_discriminator(self, parameters, snapshots):
"""
        Train the discriminator model.

        :param torch.Tensor parameters: The input tensor.
        :param torch.Tensor snapshots: The target tensor.
        :return: The discriminator loss on real samples, the discriminator
            loss on generated samples, and the total discriminator loss.
        :rtype: tuple[torch.Tensor, torch.Tensor, torch.Tensor]
"""
optimizer = self.optimizer_discriminator
optimizer.zero_grad()
        # generate a batch of snapshots
generated_snapshots = self.sample(parameters)
# Discriminator pass
d_real = self.discriminator([snapshots, parameters])
d_fake = self.discriminator([generated_snapshots, parameters])
# evaluate loss
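        # BEGAN-style discriminator objective: reconstruct real snapshots
        # accurately while penalizing accurate reconstruction of generated
        # ones, with the trade-off weighted by k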
d_loss_real = self._loss(d_real, snapshots)
d_loss_fake = self._loss(d_fake, generated_snapshots.detach())
d_loss = d_loss_real - self.k * d_loss_fake
# backward step
d_loss.backward()
optimizer.step()
return d_loss_real, d_loss_fake, d_loss

    def _update_weights(self, d_loss_real, d_loss_fake):
"""
        Update the BEGAN balancing parameter ``k``, which weights the
        discriminator loss on generated samples.

:param torch.Tensor d_loss_real: The discriminator loss computed on
dataset samples.
:param torch.Tensor d_loss_fake: The discriminator loss computed on
generated samples.
:return: The difference between the loss computed on the dataset samples
and the loss computed on the generated samples.
:rtype: torch.Tensor
"""
diff = torch.mean(self.gamma * d_loss_real - d_loss_fake)
# Update weight term for fake samples
self.k += self.lambda_k * diff.item()
self.k = min(max(self.k, 0), 1) # Constraint to interval [0, 1]
return diff

    def optimization_cycle(self, batch):
"""
        The optimization cycle for the GAROM solver.

:param list[tuple[str, dict]] batch: A batch of data. Each element is a
tuple containing a condition name and a dictionary of points.
        :return: The losses computed for all conditions in the batch, cast
to a subclass of :class:`torch.Tensor`. It should return a dict
containing the condition name and the associated scalar loss.
:rtype: dict
"""
condition_loss = {}
for condition_name, points in batch:
parameters, snapshots = (
points["input"],
points["target"],
)
d_loss_real, d_loss_fake, d_loss = self._train_discriminator(
parameters, snapshots
)
r_loss, g_loss = self._train_generator(parameters, snapshots)
diff = self._update_weights(d_loss_real, d_loss_fake)
condition_loss[condition_name] = r_loss
# some extra logging
self.store_log("d_loss", float(d_loss), self.get_batch_size(batch))
self.store_log("g_loss", float(g_loss), self.get_batch_size(batch))
self.store_log(
"stability_metric",
float(d_loss_real + torch.abs(diff)),
self.get_batch_size(batch),
)
return condition_loss

    def validation_step(self, batch):
"""
        The validation step for the GAROM solver.

:param list[tuple[str, dict]] batch: A batch of data. Each element is a
tuple containing a condition name and a dictionary of points.
:return: The loss of the validation step.
:rtype: torch.Tensor
"""
condition_loss = {}
for condition_name, points in batch:
parameters, snapshots = (
points["input"],
points["target"],
)
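            # a single generator sample is compared to the snapshots here;
            # use forward() for a Monte Carlo estimate of the expected value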
snapshots_gen = self.generator(parameters)
condition_loss[condition_name] = self._loss(
snapshots, snapshots_gen
)
loss = self.weighting.aggregate(condition_loss)
self.store_log("val_loss", loss, self.get_batch_size(batch))
return loss

    def test_step(self, batch):
"""
        The test step for the GAROM solver.

:param list[tuple[str, dict]] batch: A batch of data. Each element is a
tuple containing a condition name and a dictionary of points.
:return: The loss of the test step.
:rtype: torch.Tensor
"""
condition_loss = {}
for condition_name, points in batch:
parameters, snapshots = (
points["input"],
points["target"],
)
snapshots_gen = self.generator(parameters)
condition_loss[condition_name] = self._loss(
snapshots, snapshots_gen
)
loss = self.weighting.aggregate(condition_loss)
self.store_log("test_loss", loss, self.get_batch_size(batch))
return loss

    @property
    def generator(self):
        """
        The generator model.

:return: The generator model.
:rtype: torch.nn.Module
"""
return self.models[0]

    @property
    def discriminator(self):
        """
        The discriminator model.

:return: The discriminator model.
:rtype: torch.nn.Module
"""
return self.models[1]

    @property
    def optimizer_generator(self):
        """
        The optimizer for the generator.

:return: The optimizer for the generator.
:rtype: Optimizer
"""
return self.optimizers[0].instance

    @property
    def optimizer_discriminator(self):
        """
        The optimizer for the discriminator.

:return: The optimizer for the discriminator.
:rtype: Optimizer
"""
return self.optimizers[1].instance

    @property
    def scheduler_generator(self):
        """
        The scheduler for the generator.

:return: The scheduler for the generator.
:rtype: Scheduler
"""
return self.schedulers[0].instance

    @property
    def scheduler_discriminator(self):
        """
        The scheduler for the discriminator.

:return: The scheduler for the discriminator.
:rtype: Scheduler
"""
return self.schedulers[1].instance