""" Module for CompetitivePINN """
import torch
import copy
try:
from torch.optim.lr_scheduler import LRScheduler # torch >= 2.0
except ImportError:
from torch.optim.lr_scheduler import (
_LRScheduler as LRScheduler,
) # torch < 2.0
from torch.optim.lr_scheduler import ConstantLR
from .basepinn import PINNInterface
from pina.utils import check_consistency
from pina.problem import InverseProblem
class CompetitivePINN(PINNInterface):
r"""
Competitive Physics Informed Neural Network (PINN) solver class.
    This class implements the Competitive Physics Informed Neural
    Network solver, using a user-specified ``model`` to solve a specific
    ``problem``. It can be used for solving both forward and inverse problems.
The Competitive Physics Informed Network aims to find
the solution :math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m`
of the differential problem:
.. math::
\begin{cases}
\mathcal{A}[\mathbf{u}](\mathbf{x})=0\quad,\mathbf{x}\in\Omega\\
\mathcal{B}[\mathbf{u}](\mathbf{x})=0\quad,
\mathbf{x}\in\partial\Omega
\end{cases}
    with a minimization (over the ``model`` parameters) and a
    maximization (over the ``discriminator`` parameters) of the loss
    function
.. math::
\mathcal{L}_{\rm{problem}} = \frac{1}{N}\sum_{i=1}^N
\mathcal{L}(D(\mathbf{x}_i)\mathcal{A}[\mathbf{u}](\mathbf{x}_i))+
\frac{1}{N}\sum_{i=1}^N
\mathcal{L}(D(\mathbf{x}_i)\mathcal{B}[\mathbf{u}](\mathbf{x}_i))
where :math:`D` is the discriminator network, which tries to find the points
where the network performs worst, and :math:`\mathcal{L}` is a specific loss
    function, by default the mean squared error:
.. math::
\mathcal{L}(v) = \| v \|^2_2.
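    In other words, training seeks a saddle point of the objective:

    .. math::
        \min_{\theta}\max_{\psi}\,\mathcal{L}_{\rm{problem}}(\theta, \psi),

    where :math:`\theta` are the ``model`` parameters and :math:`\psi`
    the ``discriminator`` parameters.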
.. seealso::
**Original reference**: Zeng, Qi, et al.
"Competitive physics informed networks." International Conference on
Learning Representations, ICLR 2022
`OpenReview Preprint <https://openreview.net/forum?id=z9SIj-IM7tn>`_.
.. warning::
        This solver does not currently support passing
        ``extra_features``.
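
    Example, a minimal usage sketch. This is illustrative only: it assumes
    a ``problem`` instance and a ``torch.nn.Module`` ``model`` defined
    elsewhere, as well as the ``pina.solvers`` import path and the
    ``Trainer`` API assumed here:

    .. code-block:: python

        from pina import Trainer
        from pina.solvers import CompetitivePINN

        # build the competitive solver; the discriminator defaults to a
        # deep copy of the model network
        solver = CompetitivePINN(
            problem=problem,
            model=model,
            optimizer_model_kwargs={"lr": 1e-3},
            optimizer_discriminator_kwargs={"lr": 1e-3},
        )
        trainer = Trainer(solver=solver, max_epochs=1000)
        trainer.train()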
"""
def __init__(
self,
problem,
model,
discriminator=None,
loss=torch.nn.MSELoss(),
optimizer_model=torch.optim.Adam,
optimizer_model_kwargs={"lr": 0.001},
optimizer_discriminator=torch.optim.Adam,
optimizer_discriminator_kwargs={"lr": 0.001},
scheduler_model=ConstantLR,
scheduler_model_kwargs={"factor": 1, "total_iters": 0},
scheduler_discriminator=ConstantLR,
scheduler_discriminator_kwargs={"factor": 1, "total_iters": 0},
):
"""
        :param AbstractProblem problem: The formulation of the problem.
        :param torch.nn.Module model: The neural network to use as the
            model.
        :param torch.nn.Module discriminator: The neural network to use as
            the discriminator. If ``None``, the discriminator network will
            have the same architecture as the model network.
        :param torch.nn.Module loss: The loss function to be minimized,
            default :class:`torch.nn.MSELoss`.
        :param torch.optim.Optimizer optimizer_model: The optimizer to use
            for the model network, default is :class:`torch.optim.Adam`.
        :param dict optimizer_model_kwargs: Keyword arguments for the model
            optimizer constructor.
        :param torch.optim.Optimizer optimizer_discriminator: The optimizer
            to use for the discriminator network, default is
            :class:`torch.optim.Adam`.
        :param dict optimizer_discriminator_kwargs: Keyword arguments for
            the discriminator optimizer constructor.
        :param torch.optim.LRScheduler scheduler_model: Learning rate
            scheduler for the model.
        :param dict scheduler_model_kwargs: Keyword arguments for the model
            scheduler constructor.
        :param torch.optim.LRScheduler scheduler_discriminator: Learning
            rate scheduler for the discriminator.
        :param dict scheduler_discriminator_kwargs: Keyword arguments for
            the discriminator scheduler constructor.
"""
if discriminator is None:
discriminator = copy.deepcopy(model)
super().__init__(
models=[model, discriminator],
problem=problem,
optimizers=[optimizer_model, optimizer_discriminator],
optimizers_kwargs=[
optimizer_model_kwargs,
optimizer_discriminator_kwargs,
],
extra_features=None, # CompetitivePINN doesn't take extra features
loss=loss,
)
        # disable automatic optimization, as in GAN-style training
self.automatic_optimization = False
# check consistency
check_consistency(scheduler_model, LRScheduler, subclass=True)
check_consistency(scheduler_model_kwargs, dict)
check_consistency(scheduler_discriminator, LRScheduler, subclass=True)
check_consistency(scheduler_discriminator_kwargs, dict)
# assign schedulers
self._schedulers = [
scheduler_model(self.optimizers[0], **scheduler_model_kwargs),
scheduler_discriminator(
self.optimizers[1], **scheduler_discriminator_kwargs
),
]
self._model = self.models[0]
self._discriminator = self.models[1]
def forward(self, x):
r"""
Forward pass implementation for the PINN solver. It returns the function
evaluation :math:`\mathbf{u}(\mathbf{x})` at the control points
:math:`\mathbf{x}`.
        :param LabelTensor x: Input tensor for the PINN solver. It expects
            a tensor :math:`N \times D`, where :math:`N` is the number of
            points in the mesh and :math:`D` is the dimension of the problem.
        :return: PINN solution evaluated at the control points.
        :rtype: LabelTensor
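
        Example (illustrative; assumes ``solver`` is a trained
        ``CompetitivePINN`` and ``pts`` is a ``LabelTensor`` of sample
        points):

        .. code-block:: python

            u = solver(pts)  # equivalent to solver.forward(pts)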
"""
return self.neural_net(x)
def loss_phys(self, samples, equation):
"""
        Computes the physics loss for the Competitive PINN solver based on
        the given samples and equation. A single call performs one
        optimization step of the model network followed by one step of the
        discriminator.
:param LabelTensor samples: The samples to evaluate the physics loss.
:param EquationInterface equation: The governing equation
representing the physics.
:return: The physics loss calculated based on given
samples and equation.
:rtype: LabelTensor
"""
# train one step of the model
with torch.no_grad():
discriminator_bets = self.discriminator(samples)
loss_val = self._train_model(samples, equation, discriminator_bets)
self.store_log(loss_value=float(loss_val))
        # detach the samples to cut them from the model's computational
        # graph, then re-enable gradients so the discriminator step builds
        # a fresh graph. Alternatively, set ``retain_graph=True``.
samples = samples.detach()
samples.requires_grad = True
# train one step of discriminator
discriminator_bets = self.discriminator(samples)
self._train_discriminator(samples, equation, discriminator_bets)
return loss_val
def loss_data(self, input_tensor, output_tensor):
"""
        The data loss for the PINN solver. It computes the loss between the
        network output and the true solution.
        :param LabelTensor input_tensor: The input to the neural networks.
        :param LabelTensor output_tensor: The true solution to compare with
            the network output.
:return: The computed data loss.
:rtype: torch.Tensor
"""
self.optimizer_model.zero_grad()
loss_val = (
super()
.loss_data(input_tensor, output_tensor)
.as_subclass(torch.Tensor)
)
loss_val.backward()
self.optimizer_model.step()
return loss_val
def on_train_batch_end(self, outputs, batch, batch_idx):
"""
        This method is called at the end of each training batch, and
        overrides the PyTorch Lightning implementation for logging the
        checkpoints.
:param torch.Tensor outputs: The output from the model for the
current batch.
:param tuple batch: The current batch of data.
:param int batch_idx: The index of the current batch.
:return: Whatever is returned by the parent
method ``on_train_batch_end``.
:rtype: Any
"""
        # increment the optimization-step counter by one so that loggers
        # and checkpoints are saved correctly
self.trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.total.completed += (
1
)
return super().on_train_batch_end(outputs, batch, batch_idx)
def _train_discriminator(self, samples, equation, discriminator_bets):
"""
Trains the discriminator network of the Competitive PINN.
:param LabelTensor samples: Input samples to evaluate the physics loss.
:param EquationInterface equation: The governing equation representing
the physics.
:param Tensor discriminator_bets: Predictions made by the discriminator
network.
"""
# manual optimization
self.optimizer_discriminator.zero_grad()
        # compute the residual; we detach it because the weights of the
        # model (generator) are fixed during the discriminator step
residual = self.compute_residual(
samples=samples, equation=equation
).detach()
        # compute the competitive residual; the minus sign turns the
        # maximization into a minimization
competitive_residual = residual * discriminator_bets
loss_val = -self.loss(
torch.zeros_like(competitive_residual, requires_grad=True),
competitive_residual,
).as_subclass(torch.Tensor)
# backprop
self.manual_backward(loss_val)
self.optimizer_discriminator.step()
return
def _train_model(self, samples, equation, discriminator_bets):
"""
Trains the model network of the Competitive PINN.
:param LabelTensor samples: Input samples to evaluate the physics loss.
:param EquationInterface equation: The governing equation representing
the physics.
        :param Tensor discriminator_bets: Predictions made by the
            discriminator network.
        :return: The residual loss, without the discriminator weighting,
            used for logging.
        :rtype: torch.Tensor
"""
# manual optimization
self.optimizer_model.zero_grad()
        # compute the residual; the discriminator bets are detached below
residual = self.compute_residual(samples=samples, equation=equation)
# store logging
with torch.no_grad():
loss_residual = self.loss(torch.zeros_like(residual), residual)
        # compute the competitive residual; discriminator_bets are detached
        # because we optimize only the model (generator)
competitive_residual = residual * discriminator_bets.detach()
loss_val = self.loss(
torch.zeros_like(competitive_residual, requires_grad=True),
competitive_residual,
).as_subclass(torch.Tensor)
# backprop
self.manual_backward(loss_val)
self.optimizer_model.step()
return loss_residual
@property
def neural_net(self):
"""
Returns the neural network model.
:return: The neural network model.
:rtype: torch.nn.Module
"""
return self._model
@property
def discriminator(self):
"""
        Returns the discriminator network.
:return: The discriminator model.
:rtype: torch.nn.Module
"""
return self._discriminator
@property
def optimizer_model(self):
"""
Returns the optimizer associated with the neural network model.
:return: The optimizer for the neural network model.
:rtype: torch.optim.Optimizer
"""
return self.optimizers[0]
@property
def optimizer_discriminator(self):
"""
        Returns the optimizer associated with the discriminator.
:return: The optimizer for the discriminator.
:rtype: torch.optim.Optimizer
"""
return self.optimizers[1]
@property
def scheduler_model(self):
"""
Returns the scheduler associated with the neural network model.
:return: The scheduler for the neural network model.
:rtype: torch.optim.lr_scheduler._LRScheduler
"""
return self._schedulers[0]
@property
def scheduler_discriminator(self):
"""
        Returns the scheduler associated with the discriminator.
:return: The scheduler for the discriminator.
:rtype: torch.optim.lr_scheduler._LRScheduler
"""
return self._schedulers[1]