Critic and CriticEnsemble

This module defines the core critic network and its ensemble variant for Q-value estimation in reinforcement learning.

Key Points:

  • The Critic class estimates Q-values for state-action pairs with optional target networks for stabilized training.

  • The CriticEnsemble class manages an ensemble of critics to improve robustness by aggregating multiple Q-value estimates.

  • Both classes support soft target network updates via Polyak averaging.

  • The ensemble uses a generic Ensemble container for parallelizing multiple critics efficiently.

  • Access individual critics from the ensemble with indexing (e.g., ensemble[0]).

  • Reduction methods (min or mean) for ensemble Q-values are configurable.

Usage Notes:

  • The input to critics is prepared by concatenating state and action tensors.

  • The target network, if enabled, must be initialized before use.

  • The update method fits the ensemble to the Bellman targets passed in as y.

  • The get_bellman_target method is abstract and must be implemented in subclasses to provide the target for training.

Important

Always initialize the target network before training to stabilize updates.

Warning

Using an improper reduction method on Q-values from multiple critics may destabilize training. Supported reductions are min and mean.

Here are the detailed methods and attributes for both classes.

Critic

class objectrl.models.basic.critic.Critic(config: MainConfig, dim_state: int, dim_act: int)

Bases: Module

A critic network that estimates Q-values for state-action pairs.

device

Device for computations.

Type:

torch.device

has_target

Flag for presence of target network.

Type:

bool

_tau

Polyak averaging factor for target update.

Type:

float

_gamma

Discount factor for rewards.

Type:

float

model

Main critic network.

Type:

nn.Module

target

Target critic network.

Type:

nn.Module, optional

__init__(config: MainConfig, dim_state: int, dim_act: int) → None

Initialize the critic network.

Parameters:
  • config (MainConfig) – Configuration object containing model parameters.

  • dim_state (int) – Dimension of the state space.

  • dim_act (int) – Dimension of the action space.

Returns:

None

model: Module

target: Module | None

reduce(q_val: Tensor) → Tensor

Reduce Q-values if needed.

Parameters:

q_val (torch.Tensor) – Q-values tensor

Returns:

Reduced Q-values.

Return type:

torch.Tensor

Q(state: Tensor, action: Tensor | None = None) → Tensor

Compute Q-values for given state-action pairs.

Parameters:
  • state (torch.Tensor) – State tensor.

  • action (torch.Tensor, optional) – Action tensor. Defaults to None.

Returns:

Q-values for the state-action pairs.

Return type:

torch.Tensor

static _prepare_input(state: Tensor, action: Tensor) → Tensor

Concatenate state and action tensors for critic input.

Parameters:
  • state (torch.Tensor) – State tensor.

  • action (torch.Tensor) – Action tensor.

Returns:

Prepared input tensor for the critic.

Return type:

torch.Tensor
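The concatenation described above can be sketched without any dependencies. This is an illustration only: plain Python lists stand in for torch tensors, and the helper name prepare_input and the choice to join along the feature axis are assumptions, not the library's code.

```python
def prepare_input(state, action):
    """Join each state row with its matching action row along the feature axis.

    Hypothetical stand-in for the documented helper; lists replace tensors.
    """
    if len(state) != len(action):
        raise ValueError("state and action must have the same batch size")
    return [list(s) + list(a) for s, a in zip(state, action)]


# Batch of 2 transitions with dim_state = 3 and dim_act = 2.
state = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
action = [[1.0, -1.0], [0.5, 0.0]]

critic_input = prepare_input(state, action)
# Every row now carries dim_state + dim_act = 5 features.
```

Under this assumption the first layer of the critic would need an input width of dim_state + dim_act.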

init_target() → None

Initialize target network with weights from the main network.

Parameters:

None

Returns:

None
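As a rough sketch of what initializing a target from the main network means, the snippet below copies a parameter dictionary so that the target starts with identical values but no shared storage. The dictionary layout and the names model_params and target_params are hypothetical; the library operates on network modules rather than plain dictionaries.

```python
import copy

# Hypothetical parameters of the main critic network.
model_params = {"weight": [[0.5, -0.2], [0.1, 0.3]], "bias": [0.0, 0.1]}

# The target starts as an independent copy of the main network's weights.
target_params = copy.deepcopy(model_params)

# A later change to the main network must not leak into the target.
model_params["bias"][0] = 1.0
```

The deep copy matters: if the target merely aliased the main parameters, it would follow every gradient step and provide no stabilization.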

update_target() → None

Update target network parameters with a soft (Polyak-averaged) update controlled by _tau.

Parameters:

None

Returns:

None
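A soft update can be illustrated as follows. The sketch assumes the common convention in which tau weights the main network's parameters and (1 - tau) weights the current target; the library's exact convention is not stated here, and soft_update, online, and target are hypothetical names with lists standing in for tensors.

```python
def soft_update(target, online, tau):
    """Move each target parameter a fraction tau toward the online parameter."""
    for name, online_values in online.items():
        target[name] = [
            tau * o + (1.0 - tau) * t
            for o, t in zip(online_values, target[name])
        ]


online = {"w": [1.0, 2.0]}
target = {"w": [0.0, 0.0]}

soft_update(target, online, tau=0.25)
first_step = list(target["w"])

soft_update(target, online, tau=0.25)
second_step = list(target["w"])
```

With a small tau the target trails the main network slowly, which is what keeps the bootstrapped targets from shifting abruptly between updates.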

Q_t(state: Tensor, action: Tensor | None = None) → Tensor

Compute Q-values using the target network.

Parameters:
  • state (torch.Tensor) – State tensor.

  • action (torch.Tensor, optional) – Action tensor. Defaults to None.

Returns:

Target Q-values.

Return type:

torch.Tensor

CriticEnsemble

class objectrl.models.basic.critic.CriticEnsemble(config: MainConfig, dim_state: int, dim_act: int)

Bases: Module, ABC

Ensemble of critic networks for robust Q-value estimation.

n_members

Number of critics in the ensemble.

Type:

int

config

Configuration object.

Type:

MainConfig

dim_state

Dimension of the state space.

Type:

int

dim_act

Dimension of the action space.

Type:

int

has_target

Flag for target networks.

Type:

bool

_reset

Reset flag from config.

Type:

bool

device

Device for computations.

Type:

torch.device

loss

Loss function.

Type:

callable

_tau

Polyak averaging factor.

Type:

float

_gamma

Discount factor.

Type:

float

model_ensemble

Ensemble of critic models.

Type:

Ensemble[Critic]

optim

Optimizer for ensemble parameters.

Type:

torch.optim.Optimizer

target_ensemble

Target ensemble.

Type:

Ensemble[Critic], optional

iter

Training iteration counter.

Type:

int

__init__(config: MainConfig, dim_state: int, dim_act: int) → None

Initialize the critic ensemble.

Parameters:
  • config (MainConfig) – Configuration object with model parameters.

  • dim_state (int) – Dimension of the state space.

  • dim_act (int) – Dimension of the action space.

Returns:

None

reset() → None

Reset the ensemble models and optimizer.

Parameters:

None

Returns:

None

reduce(q_val: Tensor, reduce_type: str) → Tensor

Reduce Q-values from multiple critics according to reduce_type. Currently supports 'min' and 'mean'; add further reductions here if needed.

Parameters:
  • q_val (torch.Tensor) – Q-values tensor from all critics.

  • reduce_type (str) – Reduction to apply, either 'min' or 'mean'.

Returns:

Reduced Q-values.

Return type:

torch.Tensor
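The two supported reductions can be sketched as below. This is a dependency-free illustration, not the library's implementation: nested lists stand in for the Q-value tensor, the layout [n_members][batch] and reduction over the critic axis are assumptions, and reduce_q is a hypothetical name.

```python
from statistics import fmean


def reduce_q(q_val, reduce_type):
    """Collapse per-critic Q-values into one value per sample.

    q_val is assumed to be laid out as [n_members][batch].
    """
    per_sample = list(zip(*q_val))  # regroup as [batch][n_members]
    if reduce_type == "min":
        return [min(qs) for qs in per_sample]
    if reduce_type == "mean":
        return [fmean(qs) for qs in per_sample]
    raise ValueError(f"unsupported reduce_type: {reduce_type!r}")


# Two critics evaluated on a batch of three state-action pairs.
q_val = [
    [1.0, 4.0, 2.0],  # critic 0
    [3.0, 2.0, 2.0],  # critic 1
]

q_min = reduce_q(q_val, "min")
q_mean = reduce_q(q_val, "mean")
```

Taking the minimum gives the more pessimistic estimate, while the mean averages out disagreement between critics; rejecting unknown strings early avoids silently training on an unintended reduction.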

_get_single_critic(index: int = 0) → Critic

Get a single Critic instance from the ensemble.

Parameters:

index (int) – Index of the critic to retrieve.

Returns:

Critic instance.

Return type:

Critic

Q(state: Tensor, action: Tensor | None = None) → Tensor

Compute Q-values for given state-action pairs using all critics.

Parameters:
  • state (torch.Tensor) – State tensor.

  • action (torch.Tensor, optional) – Action tensor. Defaults to None.

Returns:

Q-values from all critics.

Return type:

torch.Tensor

Q_t(state: Tensor, action: Tensor | None = None) → Tensor

Compute Q-values using the target networks.

Parameters:
  • state (torch.Tensor) – State tensor.

  • action (torch.Tensor, optional) – Action tensor. Defaults to None.

Returns:

Target Q-values from all critics.

Return type:

torch.Tensor

update(state: Tensor, action: Tensor, y: Tensor) → None

Update critic networks using the provided Bellman targets.

Parameters:
  • state (torch.Tensor) – State tensor.

  • action (torch.Tensor) – Action tensor.

  • y (torch.Tensor) – Bellman target values.

Returns:

None
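To make the idea of fitting every critic to the same targets y concrete, here is a toy sketch. It assumes a mean-squared-error loss and a plain gradient-descent step, neither of which is confirmed by this page (the actual loss and optimizer come from the configuration), and it uses linear critics over already concatenated inputs. All function names are hypothetical.

```python
def q_value(weights, x):
    """Linear toy critic: dot product of weights and the state-action input."""
    return sum(w * xi for w, xi in zip(weights, x))


def mse(weights, inputs, targets):
    """Mean squared error between predicted Q-values and Bellman targets."""
    errors = [q_value(weights, x) - y for x, y in zip(inputs, targets)]
    return sum(e * e for e in errors) / len(errors)


def gradient_step(weights, inputs, targets, lr):
    """One gradient-descent step on the MSE loss for a single critic."""
    n = len(inputs)
    grad = [0.0] * len(weights)
    for x, y in zip(inputs, targets):
        err = q_value(weights, x) - y
        for j, xj in enumerate(x):
            grad[j] += 2.0 * err * xj / n
    return [w - lr * g for w, g in zip(weights, grad)]


inputs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # concatenated state-action rows
y = [1.0, 2.0, 3.0]                            # Bellman targets
ensemble = [[0.0, 0.0], [0.5, -0.5]]           # two critics, different inits

loss_before = [mse(w, inputs, y) for w in ensemble]
ensemble = [gradient_step(w, inputs, y, lr=0.1) for w in ensemble]
loss_after = [mse(w, inputs, y) for w in ensemble]
```

Each member is regressed toward the same y but from a different initialization, so the critics stay distinct while all moving toward the targets.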

update_target() → None

Update the target ensemble parameters with a soft (Polyak-averaged) update controlled by _tau.

Parameters:

None

Returns:

None

abstractmethod get_bellman_target(*args: Any, **kwargs: Any) → Tensor

Calculate the Bellman target for training the critic.

Parameters:
  • *args – Variable length argument list.

  • **kwargs – Arbitrary keyword arguments.

Returns:

Bellman target tensor.

Return type:

torch.Tensor
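Because this method is abstract, each subclass defines its own target. One common choice, shown purely as an assumption, is the one-step target y = r + gamma * (1 - done) * min_i Q_t,i(s', a'), which takes the minimum over the target critics. The function name and arguments below are hypothetical, and lists stand in for tensors.

```python
def bellman_target(reward, done, next_q_targets, gamma):
    """One-step Bellman target using the minimum over target critics.

    next_q_targets is assumed to be laid out as [n_members][batch] and to
    hold target-network Q-values for the next state-action pairs.
    """
    next_q = [min(qs) for qs in zip(*next_q_targets)]
    return [
        r + gamma * (1.0 - d) * q
        for r, d, q in zip(reward, done, next_q)
    ]


reward = [1.0, 0.0]
done = [0.0, 1.0]  # the second transition ends its episode
next_q_targets = [
    [10.0, 5.0],  # target critic 0
    [8.0, 7.0],   # target critic 1
]

y = bellman_target(reward, done, next_q_targets, gamma=0.5)
```

For the terminal transition the bootstrap term is masked out, so its target equals the immediate reward alone.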