Critic and CriticEnsemble
This module defines the core critic network and its ensemble variant for Q-value estimation in reinforcement learning.
Key Points:
- The Critic class estimates Q-values for state-action pairs and can optionally maintain a target network for stabilized training.
- The CriticEnsemble class manages an ensemble of critics and improves robustness by aggregating multiple Q-value estimates.
- Both classes support soft target-network updates via Polyak averaging.
- The ensemble uses a generic Ensemble container to evaluate multiple critics efficiently in parallel.
- Individual critics in the ensemble can be accessed by indexing (e.g., ensemble[0]).
- The reduction method for ensemble Q-values (min or mean) is configurable.
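The soft (Polyak) target update moves every target parameter a small step toward its online counterpart: θ_target ← τ·θ + (1 − τ)·θ_target. Below is a minimal PyTorch sketch, assuming the common convention where τ weights the online network (e.g., τ = 0.005) and that the target has the same architecture as the online network. The helper name soft_update is hypothetical, and the library's own convention for _tau may differ.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def soft_update(online: nn.Module, target: nn.Module, tau: float) -> None:
    """Hypothetical helper: Polyak-average online weights into the target.

    Assumes tau weights the online network:
    target <- tau * online + (1 - tau) * target
    """
    for p, p_targ in zip(online.parameters(), target.parameters()):
        # In-place update keeps the target's tensors (and their device) unchanged.
        p_targ.mul_(1.0 - tau).add_(p, alpha=tau)


online = nn.Linear(3, 1)
target = nn.Linear(3, 1)
target.load_state_dict(online.state_dict())  # start from identical weights
with torch.no_grad():
    online.weight.add_(1.0)  # pretend an optimizer step changed the online weights
before = target.weight.detach().clone()
soft_update(online, target, tau=0.1)  # target weights move 10% of the way (+0.1 here)
```

Because the update runs under torch.no_grad() and modifies tensors in place, the target never enters the autograd graph and needs no optimizer of its own.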
Usage Notes:
- The critic input is prepared by concatenating the state and action tensors.
- If the target network is enabled, it must be initialized before use.
- The update method trains the ensemble toward the provided Bellman targets.
- The get_bellman_target method is abstract and must be implemented in subclasses to provide the training target.
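Subclasses provide the training target through get_bellman_target. As an illustration only, the standalone function below computes the clipped double-Q target y = r + γ(1 − d)·min_i Q_target,i(s', a') used by TD3-style methods, where next_q would come from the target critics evaluated at the next state and next action. The function name and arguments are hypothetical and need not match the abstract method's actual signature.

```python
import torch


def clipped_double_q_target(
    reward: torch.Tensor,  # (batch,)
    done: torch.Tensor,    # (batch,), 1.0 where the episode terminated
    next_q: torch.Tensor,  # (n_members, batch): target critics' Q(s', a')
    gamma: float,
) -> torch.Tensor:
    """Hypothetical stand-in for a subclass's get_bellman_target."""
    # Pessimistic bootstrap: smallest estimate across ensemble members.
    next_v = next_q.min(dim=0).values
    # Terminal transitions (done = 1) do not bootstrap; detach so y is a constant label.
    return (reward + gamma * (1.0 - done) * next_v).detach()


reward = torch.tensor([1.0, 0.0])
done = torch.tensor([0.0, 1.0])
next_q = torch.tensor([[2.0, 5.0],
                       [3.0, 4.0]])  # 2 critics, batch of 2
y = clipped_double_q_target(reward, done, next_q, gamma=0.9)
```

For the first sample the minimum is 2.0, so y = 1 + 0.9·2 = 2.8; the second sample is terminal, so its target is just its reward, 0.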
Important
Always initialize the target network before training to stabilize updates.
Warning
Using an improper reduction method on Q-values from multiple critics may destabilize training.
Supported reductions are min and mean.
Here are the detailed methods and attributes for both classes.
Critic
- class objectrl.models.basic.critic.Critic(config: MainConfig, dim_state: int, dim_act: int)
Bases: Module
A critic network that estimates Q-values for state-action pairs.
- device
Device for computations.
- Type:
torch.device
- has_target
Flag for presence of target network.
- Type:
bool
- _tau
Polyak averaging factor for target update.
- Type:
float
- _gamma
Discount factor for rewards.
- Type:
float
- model
Main critic network.
- Type:
nn.Module
- target
Target critic network.
- Type:
nn.Module, optional
- __init__(config: MainConfig, dim_state: int, dim_act: int) → None
Initialize the critic network.
- Parameters:
config (MainConfig) – Configuration object containing model parameters.
dim_state (int) – Dimension of the state space.
dim_act (int) – Dimension of the action space.
- Returns:
None
- model: Module
- target: Module | None
- reduce(q_val: Tensor) → Tensor
Reduce Q-values if needed.
- Parameters:
q_val (torch.Tensor) – Q-values tensor
- Returns:
Reduced Q-values.
- Return type:
torch.Tensor
- Q(state: Tensor, action: Tensor | None = None) → Tensor
Compute Q-values for given state-action pairs.
- Parameters:
state (torch.Tensor) – State tensor.
action (torch.Tensor, optional) – Action tensor.
- Returns:
Q-values for the state-action pairs.
- Return type:
torch.Tensor
- static _prepare_input(state: Tensor, action: Tensor) → Tensor
Concatenate state and action tensors for critic input.
- Parameters:
state (torch.Tensor) – State tensor.
action (torch.Tensor) – Action tensor.
- Returns:
Prepared input tensor for the critic.
- Return type:
torch.Tensor
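To make the role of _prepare_input concrete, here is a minimal, hypothetical critic whose Q method feeds the concatenated state-action tensor to a small MLP. The class name TinyCritic, the layer sizes, and the ReLU activation are all assumptions; ObjectRL builds its network from MainConfig instead.

```python
import torch
import torch.nn as nn


class TinyCritic(nn.Module):
    """Hypothetical minimal critic; the real model is built from MainConfig."""

    def __init__(self, dim_state: int, dim_act: int, hidden: int = 64) -> None:
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(dim_state + dim_act, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),  # one scalar Q-value per state-action pair
        )

    @staticmethod
    def _prepare_input(state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        # Join along the feature (last) dimension: (batch, dim_state + dim_act).
        return torch.cat([state, action], dim=-1)

    def Q(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        return self.model(self._prepare_input(state, action))


critic = TinyCritic(dim_state=4, dim_act=2)
state, action = torch.randn(8, 4), torch.randn(8, 2)
q = critic.Q(state, action)  # shape (8, 1)
```

Concatenating along the last dimension keeps the batch dimension intact, so the same code works for any batch size.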
- init_target() → None
Initialize target network with weights from the main network.
- Parameters:
None
- Returns:
None
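A common way to implement this kind of initialization in PyTorch is to deep-copy the main network and disable gradients on the copy, so that only Polyak averaging changes its weights. The sketch below assumes that approach; init_target here is a hypothetical free function, and ObjectRL's actual method may differ (for example, by constructing a fresh network and calling load_state_dict).

```python
import copy

import torch
import torch.nn as nn


def init_target(model: nn.Module) -> nn.Module:
    """Hypothetical helper: clone the online network into a frozen target copy."""
    target = copy.deepcopy(model)  # same architecture and identical weights
    target.requires_grad_(False)   # updated only by Polyak averaging, never by the optimizer
    return target


model = nn.Sequential(nn.Linear(6, 64), nn.ReLU(), nn.Linear(64, 1))
target = init_target(model)
```

Freezing the copy also keeps its parameters out of any optimizer built from the online network's gradients.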
CriticEnsemble
- class objectrl.models.basic.critic.CriticEnsemble(config: MainConfig, dim_state: int, dim_act: int)
Bases: Module, ABC
Ensemble of critic networks for robust Q-value estimation.
- n_members
Number of critics in the ensemble.
- Type:
int
- config
Configuration object.
- Type:
MainConfig
- dim_state
Dimension of the state space.
- Type:
int
- dim_act
Dimension of the action space.
- Type:
int
- has_target
Flag for target networks.
- Type:
bool
- _reset
Reset flag from config.
- Type:
bool
- device
Device for computations.
- Type:
torch.device
- loss
Loss function.
- Type:
callable
- _tau
Polyak averaging factor.
- Type:
float
- _gamma
Discount factor.
- Type:
float
- optim
Optimizer for ensemble parameters.
- Type:
torch.optim.Optimizer
- iter
Training iteration counter.
- Type:
int
- __init__(config: MainConfig, dim_state: int, dim_act: int) → None
Initialize the critic ensemble.
- Parameters:
config (MainConfig) – Configuration object with model parameters.
dim_state (int) – Dimension of the state space.
dim_act (int) – Dimension of the action space.
- Returns:
None
- reduce(q_val: Tensor, reduce_type: str) → Tensor
Reduce the Q-values from multiple critics according to reduce_type. Currently supports ‘min’ and ‘mean’; users can add further methods if needed.
- Parameters:
q_val (torch.Tensor) – Q-values tensor from all critics.
reduce_type (str) – How to reduce the Q-values.
- Returns:
Reduced Q-values.
- Return type:
torch.Tensor
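The sketch below shows what the two supported reductions compute, assuming the stacked Q-values have the ensemble dimension first, i.e. shape (n_members, batch, 1); the function name reduce_q and that layout are assumptions. Taking the minimum yields a pessimistic estimate that counteracts overestimation bias, while the mean averages out noise across members.

```python
import torch


def reduce_q(q_val: torch.Tensor, reduce_type: str) -> torch.Tensor:
    """Hypothetical sketch of ensemble reduction.

    Assumes the ensemble dimension comes first: q_val has shape (n_members, batch, 1).
    """
    if reduce_type == "min":
        return q_val.min(dim=0).values  # pessimistic: smallest estimate per sample
    if reduce_type == "mean":
        return q_val.mean(dim=0)        # average over ensemble members
    raise ValueError(f"unsupported reduce_type: {reduce_type!r}")


q_val = torch.tensor([[[1.0], [4.0]],
                      [[3.0], [2.0]]])  # 2 critics, batch of 2
q_min = reduce_q(q_val, "min")
q_mean = reduce_q(q_val, "mean")
```

Raising on an unknown reduce_type, rather than silently falling back, guards against the destabilizing misconfiguration the warning above describes.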
- _get_single_critic(index: int = 0) → Critic
Get a single Critic instance from the ensemble.
- Parameters:
index (int) – Index of the critic to retrieve.
- Returns:
Critic instance.
- Return type:
Critic
- Q(state: Tensor, action: Tensor | None = None) → Tensor
Compute Q-values for given state-action pairs using all critics.
- Parameters:
state (torch.Tensor) – State tensor.
action (torch.Tensor, optional) – Action tensor.
- Returns:
Q-values from all critics.
- Return type:
torch.Tensor
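How the Ensemble container parallelizes the critics is not documented here. One standard PyTorch technique for evaluating several identically shaped networks in a single batched call is to stack their parameters with torch.func.stack_module_state and map over them with torch.func.vmap, as in the PyTorch model-ensembling tutorial. The sketch below uses that technique as an assumption, not as ObjectRL's implementation; the member architecture is made up for illustration.

```python
import copy

import torch
import torch.nn as nn
from torch.func import functional_call, stack_module_state, vmap

n_members, dim_state, dim_act = 5, 4, 2
members = [
    nn.Sequential(nn.Linear(dim_state + dim_act, 32), nn.ReLU(), nn.Linear(32, 1))
    for _ in range(n_members)
]
# Stack each parameter across members: every tensor gains a leading dim of size n_members.
params, buffers = stack_module_state(members)
# A weightless "skeleton" on the meta device supplies only the forward structure.
base = copy.deepcopy(members[0]).to("meta")


def single_q(p, b, x):
    # Run the skeleton with one member's parameters and buffers.
    return functional_call(base, (p, b), (x,))


def ensemble_q(state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
    x = torch.cat([state, action], dim=-1)
    # in_dims=(0, 0, None): map over members, share the same input batch.
    return vmap(single_q, in_dims=(0, 0, None))(params, buffers, x)


state, action = torch.randn(8, dim_state), torch.randn(8, dim_act)
q_all = ensemble_q(state, action)  # shape (n_members, 8, 1)
```

Each slice q_all[i] matches members[i] on the same input, so reducing over dim 0 gives the aggregated estimate. Note that the stacked tensors are new copies: training them does not update the original member modules.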
- Q_t(state: Tensor, action: Tensor | None = None) → Tensor
Compute Q-values using the target networks.
- Parameters:
state (torch.Tensor) – State tensor.
action (torch.Tensor, optional) – Action tensor.
- Returns:
Target Q-values from all critics.
- Return type:
torch.Tensor
- update(state: Tensor, action: Tensor, y: Tensor) → None
Update critic networks using the provided Bellman targets.
- Parameters:
state (torch.Tensor) – State tensor.
action (torch.Tensor) – Action tensor.
y (torch.Tensor) – Bellman target values.
- Returns:
None
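A minimal sketch of one update step, assuming an MSE loss, an Adam optimizer, and Bellman targets y of shape (batch, 1). In ObjectRL the loss comes from the configuration (the loss attribute) and the members live in the Ensemble container; the nn.ModuleList and the loop over critics used here are simplifying assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_members, dim_state, dim_act, batch = 3, 4, 2, 16
# Hypothetical ensemble: a plain ModuleList instead of ObjectRL's Ensemble container.
critics = nn.ModuleList(
    [
        nn.Sequential(nn.Linear(dim_state + dim_act, 32), nn.ReLU(), nn.Linear(32, 1))
        for _ in range(n_members)
    ]
)
optim = torch.optim.Adam(critics.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()  # assumed loss; the real one comes from the config


def update(state: torch.Tensor, action: torch.Tensor, y: torch.Tensor) -> float:
    x = torch.cat([state, action], dim=-1)
    q_all = torch.stack([c(x) for c in critics])  # (n_members, batch, 1)
    # The same Bellman target y (batch, 1) supervises every member.
    loss = loss_fn(q_all, y.expand_as(q_all))
    optim.zero_grad()
    loss.backward()
    optim.step()
    return loss.item()


state, action = torch.randn(batch, dim_state), torch.randn(batch, dim_act)
y = torch.randn(batch, 1)  # stand-in for targets from get_bellman_target
losses = [update(state, action, y) for _ in range(200)]
```

Because the members start from different random initializations, they stay diverse even though all of them regress onto the same target.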