class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
    """
    Abstract base class for PINA solvers. All specific solvers should inherit
    from this interface. This class is a wrapper of
    :class:`~lightning.pytorch.LightningModule`.
    """

    def __init__(self, problem, weighting, use_lt):
        """
        Initialization of the :class:`SolverInterface` class.

        :param AbstractProblem problem: The problem to be solved.
        :param WeightingInterface weighting: The weighting schema to be used.
            If ``None``, no weighting schema is used. Default is ``None``.
        :param bool use_lt: If ``True``, the solver uses LabelTensors as
            input.
        """
        super().__init__()

        # check consistency of the problem
        check_consistency(problem, AbstractProblem)
        self._check_solver_consistency(problem)
        self._pina_problem = problem

        # check consistency of the weighting and hook the condition names
        if weighting is None:
            weighting = _NoWeighting()
        check_consistency(weighting, WeightingInterface)
        self._pina_weighting = weighting
        weighting.condition_names = list(self._pina_problem.conditions.keys())

        # check consistency of use_lt
        check_consistency(use_lt, bool)
        self._use_lt = use_lt

        # if use_lt is True, wrap forward to extract labels from the input
        if use_lt is True:
            self.forward = labelize_forward(
                forward=self.forward,
                input_variables=problem.input_variables,
                output_variables=problem.output_variables,
            )

        # PINA private attributes (overridden by derived classes)
        self._pina_models = None
        self._pina_optimizers = None
        self._pina_schedulers = None

    def _check_solver_consistency(self, problem):
        """
        Check the consistency of the solver with the problem formulation.

        :param AbstractProblem problem: The problem to be solved.
        """
        for condition in problem.conditions.values():
            check_consistency(condition, self.accepted_conditions_types)

    def _optimization_cycle(self, batch):
        """
        Aggregate the loss for each condition in the batch, logging each
        per-condition loss along the way.

        :param list[tuple[str, dict]] batch: A batch of data. Each element is
            a tuple containing a condition name and a dictionary of points.
        :return: The aggregated loss, cast to :class:`torch.Tensor`.
        :rtype: torch.Tensor
        """
        losses = self.optimization_cycle(batch)
        for name, value in losses.items():
            self.store_log(
                f"{name}_loss", value.item(), self.get_batch_size(batch)
            )
        loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor)
        return loss
    def training_step(self, batch):
        """
        Solver training step.

        :param list[tuple[str, dict]] batch: A batch of data. Each element is
            a tuple containing a condition name and a dictionary of points.
        :return: The loss of the training step.
        :rtype: torch.Tensor
        """
        loss = self._optimization_cycle(batch=batch)
        self.store_log("train_loss", loss, self.get_batch_size(batch))
        return loss
    def validation_step(self, batch):
        """
        Solver validation step.

        :param list[tuple[str, dict]] batch: A batch of data. Each element is
            a tuple containing a condition name and a dictionary of points.
        """
        loss = self._optimization_cycle(batch=batch)
        self.store_log("val_loss", loss, self.get_batch_size(batch))
    def test_step(self, batch):
        """
        Solver test step.

        :param list[tuple[str, dict]] batch: A batch of data. Each element is
            a tuple containing a condition name and a dictionary of points.
        """
        loss = self._optimization_cycle(batch=batch)
        self.store_log("test_loss", loss, self.get_batch_size(batch))
    def store_log(self, name, value, batch_size):
        """
        Store the log of the solver.

        :param str name: The name of the log.
        :param torch.Tensor value: The value of the log.
        :param int batch_size: The size of the batch.
        """
        self.log(
            name=name,
            value=value,
            batch_size=batch_size,
            **self.trainer.logging_kwargs,
        )
    @abstractmethod
    def forward(self, *args, **kwargs):
        """
        Abstract method for the forward pass implementation.

        :param args: The input tensor(s).
        :type args: torch.Tensor | LabelTensor
        :param dict kwargs: Additional keyword arguments.
        """
    @abstractmethod
    def optimization_cycle(self, batch):
        """
        The optimization cycle for the solvers.

        :param list[tuple[str, dict]] batch: A batch of data. Each element is
            a tuple containing a condition name and a dictionary of points.
        :return: The losses computed for all conditions in the batch: a dict
            mapping each condition name to its scalar loss.
        :rtype: dict
        """
    @property
    def problem(self):
        """
        The problem instance.

        :return: The problem instance.
        :rtype: :class:`~pina.problem.abstract_problem.AbstractProblem`
        """
        return self._pina_problem

    @property
    def use_lt(self):
        """
        Whether LabelTensors are used as input during training.

        :return: The use_lt attribute.
        :rtype: bool
        """
        return self._use_lt

    @property
    def weighting(self):
        """
        The weighting schema.

        :return: The weighting schema.
        :rtype: :class:`~pina.loss.weighting_interface.WeightingInterface`
        """
        return self._pina_weighting
    @staticmethod
    def get_batch_size(batch):
        """
        Get the batch size.

        :param list[tuple[str, dict]] batch: A batch of data. Each element is
            a tuple containing a condition name and a dictionary of points.
        :return: The size of the batch.
        :rtype: int
        """
        batch_size = 0
        for data in batch:
            batch_size += len(data[1]["input"])
        return batch_size
    @staticmethod
    def default_torch_optimizer():
        """
        Set the default optimizer to :class:`torch.optim.Adam`.

        :return: The default optimizer.
        :rtype: Optimizer
        """
        return TorchOptimizer(torch.optim.Adam, lr=0.001)
    @staticmethod
    def default_torch_scheduler():
        """
        Set the default scheduler to
        :class:`torch.optim.lr_scheduler.ConstantLR`.

        :return: The default scheduler.
        :rtype: Scheduler
        """
        return TorchScheduler(torch.optim.lr_scheduler.ConstantLR)
    def on_train_start(self):
        """
        This method is called at the start of the training process to compile
        the model if the :class:`~pina.trainer.Trainer` ``compile`` option is
        ``True``.
        """
        super().on_train_start()
        if self.trainer.compile:
            self._compile_model()
    def on_test_start(self):
        """
        This method is called at the start of the test process to compile the
        model if the :class:`~pina.trainer.Trainer` ``compile`` option is
        ``True`` and the model is not already compiled.
        """
        super().on_test_start()
        if self.trainer.compile and not self._check_already_compiled():
            self._compile_model()
    def _check_already_compiled(self):
        """
        Check if the model is already compiled.

        :return: ``True`` if the model is already compiled, ``False``
            otherwise.
        :rtype: bool
        """
        models = self._pina_models
        if len(models) == 1 and isinstance(
            self._pina_models[0], torch.nn.ModuleDict
        ):
            models = list(self._pina_models[0].values())
        for model in models:
            if not isinstance(model, (OptimizedModule, torch.nn.ModuleDict)):
                return False
        return True

    @staticmethod
    def _perform_compilation(model):
        """
        Perform the compilation of the model.

        :param torch.nn.Module model: The model to compile.
        :raises Exception: If the compilation fails.
        :return: The compiled model.
        :rtype: torch.nn.Module
        """
        model_device = next(model.parameters()).device
        try:
            if model_device == torch.device("mps:0"):
                model = torch.compile(model, backend="eager")
            else:
                model = torch.compile(model, backend="inductor")
        except Exception as e:
            print("Compilation failed, running in normal mode:\n", e)
        return model
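# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the library). A concrete
# solver must provide ``accepted_conditions_types`` plus the two abstract
# methods ``forward`` and ``optimization_cycle``. The permissive condition
# type and the ``"target"`` key below are assumptions made for the example.
#
# class ExampleSolver(SolverInterface):
#     accepted_conditions_types = (object,)  # assumed: accept any condition
#
#     def forward(self, x):
#         return self._pina_models[0](x)
#
#     def optimization_cycle(self, batch):
#         # return one scalar loss per condition, keyed by condition name
#         return {
#             name: torch.nn.functional.mse_loss(
#                 self.forward(points["input"]), points["target"]
#             )
#             for name, points in batch
#         }
# ---------------------------------------------------------------------------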
class SingleSolverInterface(SolverInterface, metaclass=ABCMeta):
    """
    Base class for PINA solvers using a single :class:`torch.nn.Module`.
    """

    def __init__(
        self,
        problem,
        model,
        optimizer=None,
        scheduler=None,
        weighting=None,
        use_lt=True,
    ):
        """
        Initialization of the :class:`SingleSolverInterface` class.

        :param AbstractProblem problem: The problem to be solved.
        :param torch.nn.Module model: The neural network model to be used.
        :param Optimizer optimizer: The optimizer to be used. If ``None``,
            the :class:`torch.optim.Adam` optimizer is used.
            Default is ``None``.
        :param Scheduler scheduler: The scheduler to be used. If ``None``,
            the :class:`torch.optim.lr_scheduler.ConstantLR` scheduler is
            used. Default is ``None``.
        :param WeightingInterface weighting: The weighting schema to be used.
            If ``None``, no weighting schema is used. Default is ``None``.
        :param bool use_lt: If ``True``, the solver uses LabelTensors as
            input.
        """
        if optimizer is None:
            optimizer = self.default_torch_optimizer()
        if scheduler is None:
            scheduler = self.default_torch_scheduler()

        super().__init__(problem=problem, use_lt=use_lt, weighting=weighting)

        # check consistency of the model, scheduler, and optimizer arguments
        check_consistency(model, torch.nn.Module)
        check_consistency(scheduler, Scheduler)
        check_consistency(optimizer, Optimizer)

        # initialize the model (needed by Lightning to move across devices)
        self._pina_models = torch.nn.ModuleList([model])
        self._pina_optimizers = [optimizer]
        self._pina_schedulers = [scheduler]
    def configure_optimizers(self):
        """
        Optimizer configuration for the solver.

        :return: The optimizer and the scheduler.
        :rtype: tuple[list[Optimizer], list[Scheduler]]
        """
        self.optimizer.hook(self.model.parameters())
        self.scheduler.hook(self.optimizer)
        return ([self.optimizer.instance], [self.scheduler.instance])
    def _compile_model(self):
        """
        Compile the model.
        """
        if isinstance(self._pina_models[0], torch.nn.ModuleDict):
            self._compile_module_dict()
        else:
            self._compile_single_model()

    def _compile_module_dict(self):
        """
        Compile the model if it is a :class:`torch.nn.ModuleDict`.
        """
        for name, model in self._pina_models[0].items():
            self._pina_models[0][name] = self._perform_compilation(model)

    def _compile_single_model(self):
        """
        Compile the model if it is a single :class:`torch.nn.Module`.
        """
        self._pina_models[0] = self._perform_compilation(self._pina_models[0])

    @property
    def model(self):
        """
        The model used for training.

        :return: The model used for training.
        :rtype: torch.nn.Module
        """
        return self._pina_models[0]

    @property
    def scheduler(self):
        """
        The scheduler used for training.

        :return: The scheduler used for training.
        :rtype: Scheduler
        """
        return self._pina_schedulers[0]

    @property
    def optimizer(self):
        """
        The optimizer used for training.

        :return: The optimizer used for training.
        :rtype: Optimizer
        """
        return self._pina_optimizers[0]
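# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). Assuming a concrete subclass
# ``MySolver(SingleSolverInterface)`` and a ``problem`` instance, a solver
# with a non-default optimizer and scheduler could be built as:
#
# model = torch.nn.Sequential(
#     torch.nn.Linear(2, 64), torch.nn.Tanh(), torch.nn.Linear(64, 1)
# )
# solver = MySolver(
#     problem=problem,
#     model=model,
#     optimizer=TorchOptimizer(torch.optim.AdamW, lr=1e-3),
#     scheduler=TorchScheduler(torch.optim.lr_scheduler.ConstantLR),
# )
#
# ``configure_optimizers`` then hooks the optimizer to ``model.parameters()``
# and the scheduler to the optimizer before training starts.
# ---------------------------------------------------------------------------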
class MultiSolverInterface(SolverInterface, metaclass=ABCMeta):
    """
    Base class for PINA solvers using multiple :class:`torch.nn.Module`.
    """

    def __init__(
        self,
        problem,
        models,
        optimizers=None,
        schedulers=None,
        weighting=None,
        use_lt=True,
    ):
        """
        Initialization of the :class:`MultiSolverInterface` class.

        :param AbstractProblem problem: The problem to be solved.
        :param models: The neural network models to be used.
        :type models: list[torch.nn.Module] | tuple[torch.nn.Module]
        :param list[Optimizer] optimizers: The optimizers to be used. If
            ``None``, the :class:`torch.optim.Adam` optimizer is used for all
            models. Default is ``None``.
        :param list[Scheduler] schedulers: The schedulers to be used. If
            ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR`
            scheduler is used for all the models. Default is ``None``.
        :param WeightingInterface weighting: The weighting schema to be used.
            If ``None``, no weighting schema is used. Default is ``None``.
        :param bool use_lt: If ``True``, the solver uses LabelTensors as
            input.
        :raises ValueError: If the models are not a list or tuple with length
            greater than one.
        """
        if not isinstance(models, (list, tuple)) or len(models) < 2:
            raise ValueError(
                "models should be list[torch.nn.Module] or "
                "tuple[torch.nn.Module] with len greater than one."
            )

        # normalize optimizers/schedulers, filling ``None`` entries with the
        # defaults (``None`` for the whole argument means one default each)
        if optimizers is None:
            optimizers = [None] * len(models)
        if schedulers is None:
            schedulers = [None] * len(models)
        optimizers = [
            self.default_torch_optimizer() if opt is None else opt
            for opt in optimizers
        ]
        schedulers = [
            self.default_torch_scheduler() if sched is None else sched
            for sched in schedulers
        ]

        super().__init__(problem=problem, use_lt=use_lt, weighting=weighting)

        # check consistency of the models, schedulers, and optimizers
        check_consistency(models, torch.nn.Module)
        check_consistency(schedulers, Scheduler)
        check_consistency(optimizers, Optimizer)

        # check length consistency of the optimizers
        if len(models) != len(optimizers):
            raise ValueError(
                "You must define one optimizer for each model. "
                f"Got {len(models)} models and {len(optimizers)} optimizers."
            )

        # initialize the models
        self._pina_models = torch.nn.ModuleList(models)
        self._pina_optimizers = optimizers
        self._pina_schedulers = schedulers
    def configure_optimizers(self):
        """
        Optimizer configuration for the solver.

        :return: The optimizers and the schedulers.
        :rtype: tuple[list[Optimizer], list[Scheduler]]
        """
        for optimizer, scheduler, model in zip(
            self.optimizers, self.schedulers, self.models
        ):
            optimizer.hook(model.parameters())
            scheduler.hook(optimizer)
        return (
            [optimizer.instance for optimizer in self.optimizers],
            [scheduler.instance for scheduler in self.schedulers],
        )
    def _compile_model(self):
        """
        Compile the models.
        """
        for i, model in enumerate(self._pina_models):
            if not isinstance(model, torch.nn.ModuleDict):
                self._pina_models[i] = self._perform_compilation(model)

    @property
    def models(self):
        """
        The models used for training.

        :return: The models used for training.
        :rtype: torch.nn.ModuleList
        """
        return self._pina_models

    @property
    def optimizers(self):
        """
        The optimizers used for training.

        :return: The optimizers used for training.
        :rtype: list[Optimizer]
        """
        return self._pina_optimizers

    @property
    def schedulers(self):
        """
        The schedulers used for training.

        :return: The schedulers used for training.
        :rtype: list[Scheduler]
        """
        return self._pina_schedulers
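# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). A multi-model solver pairs one optimizer
# (and one scheduler) with each model; ``MyGANSolver`` and ``problem`` are
# hypothetical placeholders.
#
# generator = torch.nn.Linear(1, 1)
# discriminator = torch.nn.Linear(1, 1)
# solver = MyGANSolver(
#     problem=problem,
#     models=[generator, discriminator],
#     optimizers=[
#         TorchOptimizer(torch.optim.Adam, lr=1e-3),
#         None,  # ``None`` entries fall back to the default Adam optimizer
#     ],
# )
# ---------------------------------------------------------------------------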