Ensemble

A generic neural network ensemble container that holds multiple model instances with a shared architecture but independent parameters.

Key Points:

  • Stores member parameters as stacked (batched) tensors, enabling efficient parallel forward passes with torch.func.

  • Supports expanding input tensors to match the ensemble size.

  • Allows extraction of single ensemble members as standalone models.

  • Useful as a building block for model ensembles such as the CriticEnsemble.
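
The batched-storage idea in the points above can be illustrated with the standard torch.func recipe: stack every member's parameters along a new leading dimension, then vmap a functional_call of one shared "skeleton" model over that dimension. This is a minimal sketch under that assumption, not objectrl's actual code; make_member, call_single, and the layer sizes are hypothetical.

```python
import copy

import torch
from torch import nn
from torch.func import functional_call, stack_module_state, vmap


def make_member() -> nn.Module:
    # Hypothetical member: every instance shares this architecture,
    # but each one is initialized with its own random weights.
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))


n_members = 3
models = [make_member() for _ in range(n_members)]

# Stack each parameter and buffer along a new leading dimension of size n_members.
params, buffers = stack_module_state(models)

# A copy on the "meta" device carries only the architecture, no real weights.
base_model = copy.deepcopy(models[0]).to("meta")


def call_single(p, b, x):
    # Run the shared architecture with one member's parameters and buffers.
    return functional_call(base_model, (p, b), (x,))


# vmap maps over the leading (member) dimension of params, buffers, and inputs.
forward_model = vmap(call_single)

x = torch.randn(n_members, 5, 4)  # one batch of 5 inputs per member
out = forward_model(params, buffers, x)
print(out.shape)  # torch.Size([3, 5, 1])
```

All members are evaluated in a single vectorized call instead of a Python loop, which is where the efficiency gain comes from.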

Note

Increasing the number of ensemble members can improve robustness, but it also increases computational and memory cost.

Warning

Incorrect parameter extraction when retrieving single ensemble members can lead to inconsistent behavior.

Warning

Set sequential=True if the underlying ensemble members are stateful, e.g., if they use batch normalization. The vectorized version, which uses parallel forward passes, will not update the state of such layers.
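
To see why stateful layers matter, the sketch below assumes that sequential evaluation amounts to calling each member as an ordinary module and stacking the outputs; in that case BatchNorm1d updates its running statistics in place as usual. The layer sizes and variable names are illustrative and not taken from objectrl.

```python
import torch
from torch import nn

n_members = 2
# Members with a stateful layer: BatchNorm1d tracks running statistics.
models = nn.ModuleList(
    nn.Sequential(nn.Linear(4, 3), nn.BatchNorm1d(3)) for _ in range(n_members)
)
models.train()

x = torch.randn(n_members, 16, 4)  # one batch per member

# Snapshot the running means before the forward pass.
before = [m[1].running_mean.clone() for m in models]

# Sequential evaluation: each member is called as a regular module, so its
# buffers (running_mean, running_var, num_batches_tracked) are updated in place.
out = torch.stack([m(x_i) for m, x_i in zip(models, x)])

print(out.shape)  # torch.Size([2, 16, 3])
for m, b in zip(models, before):
    print(torch.equal(m[1].running_mean, b))  # False: statistics were updated
```

The loop costs one forward call per member, which is the price paid for correct state updates.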

Here are the detailed methods and attributes.

class objectrl.models.basic.ensemble.Ensemble(n_members: int, models: list[T], device: Literal['cpu', 'cuda'] = 'cpu', sequential: bool = False, compile: bool = False)

Bases: Module, ABC, Generic

A generic ensemble of neural networks. This class parallelizes the forward pass of multiple models while maintaining a consistent interface.

n_members

Number of members in the ensemble.

Type:

int

prototype

Prototype model used to create new members.

Type:

nn.Module

device

Device type for the ensemble (e.g., “cpu”, “cuda”).

Type:

str

params

Stacked parameters of the ensemble members.

Type:

dict[str, torch.Tensor]

buffers

Stacked buffers of the ensemble members.

Type:

dict[str, torch.Tensor]
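
As an illustration of what "stacked" means for the params and buffers dictionaries, the sketch below uses torch.func.stack_module_state on a few small models. It assumes objectrl builds these dictionaries the same way; the example architecture is hypothetical. Every entry keeps the member module's key (e.g. "0.weight") and gains a leading dimension of size n_members.

```python
import torch
from torch import nn
from torch.func import stack_module_state

n_members = 4
# Hypothetical member with both parameters (Linear, BatchNorm) and buffers (BatchNorm).
models = [nn.Sequential(nn.Linear(3, 5), nn.BatchNorm1d(5)) for _ in range(n_members)]

params, buffers = stack_module_state(models)

# Each tensor is the per-member tensor with an extra leading member dimension.
for name, t in params.items():
    print("param ", name, tuple(t.shape))
for name, t in buffers.items():
    print("buffer", name, tuple(t.shape))
```

Stacked parameters are fresh leaf tensors that require gradients, so a single optimizer can train all members at once; buffers are stacked without gradients.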

base_model

Base model structure for functional calls.

Type:

nn.Module

forward_model

Vectorized function to call the model.

Type:

Callable

sequential

Whether members are evaluated one at a time instead of with a vectorized forward pass (necessary for stateful layers).

Type:

bool

__init__(n_members: int, models: list[T], device: Literal['cpu', 'cuda'] = 'cpu', sequential: bool = False, compile: bool = False) → None

Initialize the ensemble.

Parameters:
  • n_members (int) – The number of members in the ensemble.

  • models (list[nn.Module]) – List of models to parallelize.

  • device (str) – The device to use.

  • sequential (bool) – Whether members are evaluated one at a time instead of with a vectorized forward pass (necessary for stateful layers).

  • compile (bool) – Whether to compile the ensemble.

Returns:

None

forward(input: Tensor) → Tensor

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

expand(x: Tensor | tuple | list, force: bool = False) → Tensor | tuple | list

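
The expand signature accepts a tensor or a tuple/list of tensors. The sketch below shows one plausible way to add a leading member dimension so that inputs line up with the stacked parameters. The helper name expand_to_members and the meaning given to force are hypothetical assumptions, not objectrl's documented behavior: here, without force, a tensor whose leading dimension already equals n_members is treated as already expanded and returned unchanged.

```python
from typing import Union

import torch

Nested = Union[torch.Tensor, tuple, list]


def expand_to_members(x: Nested, n_members: int, force: bool = False) -> Nested:
    """Hypothetical stand-in for Ensemble.expand; `force` semantics are assumed."""
    if isinstance(x, (tuple, list)):
        # Recurse into containers and preserve their type (tuple or list).
        return type(x)(expand_to_members(v, n_members, force) for v in x)
    if not force and x.dim() > 0 and x.shape[0] == n_members:
        # Assumption: a matching leading dimension means "already per-member".
        return x
    # Add a leading member dimension; expand() returns a view without copying.
    return x.unsqueeze(0).expand(n_members, *x.shape)


obs = torch.randn(32, 4)
act = torch.randn(32, 2)
o, a = expand_to_members((obs, act), n_members=5)
print(o.shape, a.shape)  # torch.Size([5, 32, 4]) torch.Size([5, 32, 2])
```

A shape heuristic like this is ambiguous when the batch size happens to equal n_members, which is one reason an explicit override flag is useful.
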
_get_single_member(index: int = 0) → T

Extract a single member from the ensemble.

Parameters:

index – Index of the member to extract (default: 0).

Returns:

A single member of the ensemble with the specified index.

Return type:

T

_get_all_members() → ModuleList

Extract all members from the ensemble.
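
One way to extract standalone members from stacked storage is to slice index i out of every stacked tensor and load the result into a copy of a regular (non-meta) prototype. This is a sketch under that assumption; get_single_member, get_all_members, and the use of copy.deepcopy plus load_state_dict are hypothetical stand-ins for the private methods above, not their actual implementation.

```python
import copy

import torch
from torch import nn
from torch.func import stack_module_state

n_members = 3
models = [nn.Linear(4, 2) for _ in range(n_members)]
params, buffers = stack_module_state(models)

# A regular prototype provides the architecture for standalone copies.
prototype = copy.deepcopy(models[0])


def get_single_member(index: int = 0) -> nn.Module:
    # Slice member `index` out of every stacked parameter and buffer.
    state = {k: v[index] for k, v in {**params, **buffers}.items()}
    member = copy.deepcopy(prototype)
    # load_state_dict copies values into the member's own tensors, so the
    # extracted model does not share storage with the stacked ensemble.
    member.load_state_dict(state)
    return member


def get_all_members() -> nn.ModuleList:
    return nn.ModuleList(get_single_member(i) for i in range(n_members))


x = torch.randn(6, 4)
members = get_all_members()
print(all(torch.allclose(m(x), ref(x)) for m, ref in zip(members, models)))  # True
```

Slicing with the wrong index, or mixing entries from different indices, silently produces a model that matches no ensemble member, which is the failure mode the warning at the top of this page describes.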