Ensemble
Generic neural network ensemble container designed to hold multiple model instances with shared architecture but independent parameters.
Key Points:
Provides batched parameter storage for efficient parallel forward passes using torch.func.
Supports expanding input tensors to match ensemble size.
Allows extraction of single ensemble members as standalone models.
Useful as a building block for model ensembles such as the CriticEnsemble.
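The batched storage and input expansion described in the key points can be sketched with torch.func as follows. This is a minimal illustration under stated assumptions, not necessarily how this class is implemented: the member architecture, the tensor sizes, and the helper name `call_single` are hypothetical.

```python
import copy

import torch
from torch import nn
from torch.func import functional_call, stack_module_state, vmap

torch.manual_seed(0)

n_members = 3
# Hypothetical member architecture; any stateless nn.Module would do.
members = [
    nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
    for _ in range(n_members)
]

# Stack every parameter and buffer along a new leading "member" dimension.
params, buffers = stack_module_state(members)

# A structure-only copy on the meta device acts as the template for functional calls.
base_model = copy.deepcopy(members[0]).to("meta")


def call_single(member_params, member_buffers, x):
    # Run the template with one member's parameters and buffers.
    return functional_call(base_model, (member_params, member_buffers), (x,))


# Vectorize over the member dimension of params, buffers, and inputs.
forward_model = vmap(call_single, in_dims=(0, 0, 0))

x = torch.randn(5, 4)                                   # (batch, features)
x_expanded = x.unsqueeze(0).expand(n_members, -1, -1)   # (members, batch, features)
out = forward_model(params, buffers, x_expanded)        # (members, batch, outputs)

# Reference result: call each member one after another.
reference = torch.stack([m(x) for m in members])
```

Each slice `out[i]` should match what member `i` produces on its own, up to floating-point tolerance.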
Note
Increasing the number of ensemble members typically improves robustness, but compute and memory costs grow with the member count.
Warning
Incorrect parameter extraction when retrieving single ensemble members can lead to inconsistent behavior.
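One way to extract a member consistently is to slice the same index from every stacked tensor and load the result into a copy of a prototype model. The sketch below assumes parameters stacked with torch.func.stack_module_state; the helper `extract_member` is hypothetical and is not part of this class's documented API.

```python
import copy

import torch
from torch import nn
from torch.func import stack_module_state

torch.manual_seed(0)

members = [nn.Linear(4, 2) for _ in range(3)]
params, buffers = stack_module_state(members)


def extract_member(prototype, params, buffers, index):
    """Hypothetical helper: rebuild member `index` as a standalone module."""
    member = copy.deepcopy(prototype)
    # Take the same index from every stacked tensor so that weights and
    # biases of different members are never mixed.
    state = {
        name: value[index].detach().clone()
        for name, value in {**params, **buffers}.items()
    }
    member.load_state_dict(state)
    return member


single = extract_member(members[0], params, buffers, index=1)
x = torch.randn(5, 4)
```

The extracted model should reproduce the output of the original member at that index and hold its own copy of the parameters.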
Warning
Set sequential=True if the underlying ensemble members are stateful, e.g., if they use batch normalization. The vectorized version, which runs the forward passes in parallel, will not update the state of such layers.
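To illustrate why a sequential pass suits stateful layers, the following sketch loops over members that contain batch normalization and checks that the running statistics are updated. It assumes inputs shaped (members, batch, features); the function `sequential_forward` is hypothetical and only mimics what a sequential mode might do.

```python
import torch
from torch import nn

torch.manual_seed(0)

n_members = 2
# Hypothetical stateful members: BatchNorm1d tracks running statistics.
members = nn.ModuleList(
    [
        nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU(), nn.Linear(8, 1))
        for _ in range(n_members)
    ]
)
members.train()


def sequential_forward(members, x):
    # x has shape (members, batch, features); member i only sees x[i].
    return torch.stack([m(x[i]) for i, m in enumerate(members)], dim=0)


x = torch.randn(n_members, 16, 4)
batch_norm = members[0][1]
mean_before = batch_norm.running_mean.clone()
out = sequential_forward(members, x)
mean_after = batch_norm.running_mean.clone()
```

Because every member is called as an ordinary module, its batch-normalization buffers are updated in place during training.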
Here are the detailed methods and attributes.
- class objectrl.models.basic.ensemble.Ensemble(n_members: int, models: list[T], device: Literal['cpu', 'cuda'] = 'cpu', sequential: bool = False, compile: bool = False)
Bases:
Module, ABC, Generic
A generic ensemble of neural networks. This class allows for parallelizing the forward pass of multiple models while maintaining a consistent interface.
- n_members
Number of members in the ensemble.
- Type:
int
- prototype
Prototype model used to create new members.
- Type:
nn.Module
- device
Device type for the ensemble (e.g., “cpu”, “cuda”).
- Type:
str
- params
Stacked parameters of the ensemble members.
- Type:
dict[str, torch.Tensor]
- buffers
Stacked buffers of the ensemble members.
- Type:
dict[str, torch.Tensor]
- base_model
Base model structure for functional calls.
- Type:
nn.Module
- forward_model
Vectorized function to call the model.
- Type:
torch.nn.functional
- sequential
Whether the ensemble is sequential (necessary for stateful layers).
- Type:
bool
- __init__(n_members: int, models: list[T], device: Literal['cpu', 'cuda'] = 'cpu', sequential: bool = False, compile: bool = False) -> None
Initialize the ensemble.
- Parameters:
n_members (int) – The number of members in the ensemble
models (list[nn.Module]) – List of models to parallelize
device (str) – The device to use
sequential (bool) – Whether the ensemble is sequential (necessary for stateful layers)
compile (bool) – Whether the ensemble is compiled or not
- Returns:
None
- forward(input: Tensor) -> Tensor
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
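The difference described in the note can be observed with any torch.nn.Module. The sketch below uses a plain nn.Linear as a stand-in rather than an Ensemble subclass, which is an assumption made to keep the example self-contained.

```python
import torch
from torch import nn

calls = []
layer = nn.Linear(3, 1)
# Record every time the forward hook fires.
handle = layer.register_forward_hook(
    lambda module, args, output: calls.append("hook")
)

x = torch.randn(2, 3)

layer(x)                         # calling the instance runs registered hooks
count_after_call = len(calls)

layer.forward(x)                 # calling forward() directly bypasses hooks
count_after_forward = len(calls)

handle.remove()
```

Only the first call triggers the hook, so the counter does not change after the direct `forward` call.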