Multi-Model Mixin

Module for the multi-model mixin class.

class MultiModelMixin

Bases: object

Mixin that defines the forward pass and the optimizer configuration for solvers backed by multiple models. It also provides properties to access the models, optimizers, and schedulers.

It is designed to be combined with any solver that inherits from BaseSolver.
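Because the mixin derives only from object, it contributes behavior through Python's method resolution order (MRO) rather than through its own base class. The sketch below illustrates that general pattern with hypothetical stand-ins (BaseSolverStub, MultiModelMixinStub, TwoModelSolver are illustrative names, not PINA classes): listing the mixin before the base class lets its forward take precedence, while the base class still supplies the constructor.

```python
class BaseSolverStub:
    """Hypothetical stand-in for a solver base class (not PINA's BaseSolver)."""

    def __init__(self, models):
        self._models = list(models)

    def forward(self, x):
        raise NotImplementedError("the base class leaves forward() undefined")


class MultiModelMixinStub:
    """Hypothetical stand-in for a mixin that supplies forward()."""

    def forward(self, x):
        # Relies on the combined solver class to provide self._models.
        return [model(x) for model in self._models]


# Mixin first: its forward() is found before BaseSolverStub.forward().
class TwoModelSolver(MultiModelMixinStub, BaseSolverStub):
    pass


solver = TwoModelSolver([lambda x: x + 1, lambda x: x * 2])
print(solver.forward(3))  # [4, 6]
print([cls.__name__ for cls in TwoModelSolver.__mro__])
# ['TwoModelSolver', 'MultiModelMixinStub', 'BaseSolverStub', 'object']
```

The ordering of base classes in the actual library is an assumption here; the point is only that the mixin's methods win whenever it appears earlier in the MRO.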

forward(x)

Evaluate every model on the input x and return their outputs stacked into a single tensor.

Parameters:

x (torch.Tensor | LabelTensor | Data | Graph) – The input data.

Returns:

The output of all models stacked together.

Return type:

torch.Tensor | LabelTensor | Data | Graph
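A minimal sketch of "evaluate all models and stack the outputs" for plain torch.Tensor inputs. The helper name stacked_forward is hypothetical, stacking along a new leading dimension (dim=0) is an assumption, and LabelTensor, Data, and Graph inputs are not covered. torch.stack also requires every model to produce outputs of the same shape.

```python
import torch
from torch import nn


# Hypothetical helper, not PINA's implementation: run every model on the
# same input and stack the results along a new leading dimension.
def stacked_forward(models, x):
    outputs = [model(x) for model in models]
    return torch.stack(outputs, dim=0)  # shape: (num_models, *output_shape)


torch.manual_seed(0)
models = nn.ModuleList([nn.Linear(2, 1) for _ in range(3)])
x = torch.rand(5, 2)  # batch of 5 points with 2 features each
out = stacked_forward(models, x)
print(out.shape)  # torch.Size([3, 5, 1])
```

Indexing the result along the first dimension (out[i]) then recovers the output of the i-th model.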

configure_optimizers()

Configure the optimizers and schedulers for all models.

Returns:

A tuple containing the list of optimizers and the list of schedulers.

Return type:

tuple[list[TorchOptimizer], list[TorchScheduler]]
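The return type is a pair of lists. The sketch below builds such a pair with plain PyTorch objects, assuming one optimizer and one scheduler per model; it uses raw torch.optim classes instead of PINA's TorchOptimizer and TorchScheduler wrappers, and the function name, Adam, StepLR, and the hyperparameter values are all illustrative choices rather than the library's defaults.

```python
import torch
from torch import nn


# Hypothetical helper (not PINA's code): one Adam optimizer and one StepLR
# scheduler per model, returned as two parallel lists.
def configure_optimizers(models, lr=1e-3, step_size=100, gamma=0.5):
    optimizers = [torch.optim.Adam(model.parameters(), lr=lr) for model in models]
    schedulers = [
        torch.optim.lr_scheduler.StepLR(opt, step_size=step_size, gamma=gamma)
        for opt in optimizers
    ]
    return optimizers, schedulers


models = nn.ModuleList([nn.Linear(2, 1) for _ in range(2)])
optimizers, schedulers = configure_optimizers(models)
print(len(optimizers), len(schedulers))  # 2 2
```

Keeping a separate optimizer per model lets each model use its own learning rate or schedule. In PyTorch Lightning, returning two lists (optimizers, then schedulers) is one of the formats that configure_optimizers accepts.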

property models

The models used by the solver.

Returns:

The models used by the solver.

Return type:

list[torch.nn.Module]

property optimizers

The optimizers used by the solver.

Returns:

The optimizers used by the solver.

Return type:

list[TorchOptimizer]

property schedulers

The schedulers used by the solver.

Returns:

The schedulers used by the solver.

Return type:

list[TorchScheduler]

property num_models

The number of models used by the solver.

Returns:

The number of models used by the solver.

Return type:

int
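The four properties above are read-only accessors. A minimal sketch of that pattern follows, assuming the values are stored in private lists and that num_models equals the length of the models list; MultiModelHolder is a hypothetical name, not a PINA class.

```python
import torch
from torch import nn


class MultiModelHolder:
    """Hypothetical container (not PINA's class) exposing read-only accessors."""

    def __init__(self, models, optimizers, schedulers):
        self._models = list(models)
        self._optimizers = list(optimizers)
        self._schedulers = list(schedulers)

    @property
    def models(self):
        return self._models

    @property
    def optimizers(self):
        return self._optimizers

    @property
    def schedulers(self):
        return self._schedulers

    @property
    def num_models(self):
        # Assumed to equal the number of stored models.
        return len(self._models)


models = [nn.Linear(2, 1) for _ in range(3)]
optimizers = [torch.optim.SGD(m.parameters(), lr=0.01) for m in models]
schedulers = [torch.optim.lr_scheduler.StepLR(o, step_size=10) for o in optimizers]
holder = MultiModelHolder(models, optimizers, schedulers)
print(holder.num_models)  # 3
```

Since none of the properties defines a setter, an assignment such as holder.num_models = 5 raises AttributeError.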