Example 2: Adapting SAC to REDQ
This example demonstrates how to convert Soft Actor-Critic (SAC) [1] to Randomized Ensembled Double Q-learning (REDQ) [2] with minimal code changes.
REDQ extends SAC to improve the sample efficiency of model-free methods in continuous control tasks, where environment interactions are costly. The core idea is to increase the update-to-data (UTD) ratio, i.e., perform multiple gradient updates per environment step (\(\text{UTD} \gg 1\)). However, naively increasing UTD in SAC leads to instability. REDQ overcomes this by introducing three key components: a large critic ensemble, a randomized pessimistic critic target, and an averaged critic ensemble for policy learning.
The conversion from SAC to REDQ requires four modifications:
1. Increase the UTD ratio (\(\text{UTD} \gg 1\)).
2. Maintain an ensemble of \(N\) critics.
3. Compute the critic target using the minimum over a random subset of size \(M < N\) from the critic ensemble.
4. Use the mean of the critic ensemble for actor updates.
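How the four modifications fit into one training loop can be sketched in plain Python. This is a hedged toy, not the library's trainer: the helper `redq_update_schedule` and its parameters are hypothetical, no networks are trained, and we assume that the actor is updated once after each block of `utd_ratio` critic updates. The function only records which critic indices would form the target at each critic update.

```python
import random
from typing import List, Tuple


def redq_update_schedule(
    env_steps: int,
    utd_ratio: int,
    n_critics: int,
    m_subset: int,
    seed: int = 0,
) -> Tuple[List[List[int]], int]:
    """Toy REDQ update schedule (hypothetical helper; nothing is learned).

    Returns the critic indices drawn for each critic update and the
    number of actor updates.
    """
    if not 0 < m_subset < n_critics:
        raise ValueError("REDQ requires 0 < M < N.")
    rng = random.Random(seed)
    target_subsets: List[List[int]] = []
    actor_updates = 0
    for _ in range(env_steps):
        # A single environment transition would be collected here.
        for _ in range(utd_ratio):
            # Modification 1: utd_ratio critic updates per environment step.
            # Modifications 2 and 3: draw M distinct indices from the N critics;
            # the target y would be the minimum over those M target critics.
            target_subsets.append(rng.sample(range(n_critics), m_subset))
        # Modification 4: one actor update on the ensemble mean
        # (assumed here to happen once per block of critic updates).
        actor_updates += 1
    return target_subsets, actor_updates


subsets, n_actor = redq_update_schedule(env_steps=3, utd_ratio=20, n_critics=10, m_subset=2)
print(len(subsets), n_actor)  # prints: 60 3
```

With \(\text{UTD} = 20\), three environment steps already trigger 60 critic updates, which is where both the sample-efficiency gain and the extra compute come from.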
Each component builds upon the existing SAC implementation with minimal modifications.
See Randomized Ensembled Double Q-Learning (REDQ) for the full model API. Here we focus only on the changes relative to SAC.
REDQ Critics
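REDQ contains \(N\) critics. To define the Bellman target for updating the critics, REDQ samples a set \(\mathcal{M}\) of \(M\) distinct indices from \(\{1, 2, \ldots, N\}\) and computes the Q target \(y\), shared across all \(N\) Q-functions, as the minimum over these \(M\) target critics:
\[
y = r + \gamma \left( \min_{i \in \mathcal{M}} Q_{\phi_{\text{targ},i}}(s', \tilde{a}') - \alpha \log \pi_\theta(\tilde{a}' \mid s') \right), \quad \tilde{a}' \sim \pi_\theta(\cdot \mid s'),
\]
where \(r\) is the reward, \(s'\) the next state, \(\gamma\) the discount factor, \(\alpha\) the entropy coefficient inherited from SAC, and \(\phi_{\text{targ},i}\) the parameters of the \(i\)-th target critic. Each critic \(Q_{\phi_i}\) is then regressed toward this shared \(y\) by minimizing \(\big(Q_{\phi_i}(s, a) - y\big)^2\) over the minibatch.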
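The REDQ target for one transition can be computed with a few lines of plain Python. In this hedged sketch, `redq_target` is a hypothetical name and scalar floats stand in for network outputs; it follows the randomized-minimum target term by term and omits the terminal-state masking that practical implementations usually add.

```python
import random
from typing import Sequence


def redq_target(
    reward: float,
    next_q_targ: Sequence[float],  # Q_{phi_targ,i}(s', a~') for i = 1..N
    next_log_prob: float,          # log pi_theta(a~' | s')
    gamma: float,
    alpha: float,
    m_subset: int,
    rng: random.Random,
) -> float:
    """Randomized-min REDQ target for a single transition (toy, scalar version)."""
    # Sample the index set M: m_subset distinct critics out of N.
    subset = rng.sample(range(len(next_q_targ)), m_subset)
    min_q = min(next_q_targ[i] for i in subset)
    # Soft Bellman backup with the entropy bonus, as in SAC.
    return reward + gamma * (min_q - alpha * next_log_prob)


# Sanity check: with M = N the subset is the whole ensemble, so the minimum is 1.0
# and the target is 0.5 + 0.99 * (1.0 + 0.2) = 1.688 (up to float rounding).
print(redq_target(0.5, [1.0, 2.0, 3.0], -1.0, gamma=0.99, alpha=0.2, m_subset=3, rng=random.Random(0)))
```

Because the minimum is taken over only \(M < N\) critics, the target is less pessimistic than a minimum over the whole ensemble, yet it is never smaller than that full-ensemble minimum.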
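The policy parameters \(\theta\), on the other hand, are updated by gradient ascent on the mean of all \(N\) critics in the ensemble:
\[
\nabla_\theta \frac{1}{|B|} \sum_{s \in B} \left( \frac{1}{N} \sum_{i=1}^{N} Q_{\phi_i}\big(s, \tilde{a}_\theta(s)\big) - \alpha \log \pi_\theta\big(\tilde{a}_\theta(s) \mid s\big) \right), \quad \tilde{a}_\theta(s) \sim \pi_\theta(\cdot \mid s),
\]
where \(B\) is a minibatch of states sampled from the replay buffer.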
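The actor objective differs from the critic target in two ways: it averages over all \(N\) critics instead of taking a minimum, and it uses the online critics \(Q_{\phi_i}\) rather than the target critics. The minimal sketch below computes the minibatch quantity the actor maximizes; `redq_actor_objective` is a hypothetical name, lists of floats stand in for network outputs, and no gradients are taken.

```python
from statistics import fmean
from typing import Sequence


def redq_actor_objective(
    q_values: Sequence[Sequence[float]],  # q_values[b][i] = Q_{phi_i}(s_b, a~_theta(s_b))
    log_probs: Sequence[float],           # log pi_theta(a~_theta(s_b) | s_b)
    alpha: float,
) -> float:
    """Minibatch objective maximized by the REDQ actor (toy, no gradients)."""
    per_state = [
        fmean(q_row) - alpha * log_prob  # ensemble mean minus entropy term
        for q_row, log_prob in zip(q_values, log_probs)
    ]
    return fmean(per_state)  # average over the minibatch B


# Two states, two critics: ensemble means are 2.0 and 3.0, entropy term is zero.
print(redq_actor_objective([[1.0, 3.0], [2.0, 4.0]], [0.0, 0.0], alpha=0.2))  # 2.5
```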
REDQCritic therefore subclasses SACCritic and overrides only its reduce method, which implements both reductions (randomized minimum for the target, ensemble mean for the actor):
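```python
import torch


class REDQCritic(SACCritic):  # SACCritic: the SAC critic class this example builds on
    def reduce(self, q_val_list: torch.Tensor, reduce_type: str = "min") -> torch.Tensor:
        """
        Reduces the Q-values of the critic ensemble.

        With reduce_type="min", a random subset of n_in_target critics is sampled
        and the element-wise minimum over that subset is returned (Bellman target).
        With reduce_type="mean", the average over all critics is returned (actor update).

        Args:
            q_val_list (torch.Tensor): Q-values of all critics, ensemble dimension first.
            reduce_type (str): Reduction method, either "min" or "mean".

        Returns:
            torch.Tensor: Reduced Q-values with the ensemble dimension removed.
        """
        if reduce_type == "min":
            n_critics = len(q_val_list)
            if n_critics < self.config.model.n_in_target:
                raise ValueError(
                    f"Expected at least {self.config.model.n_in_target} critics, but got {n_critics}."
                )
            # Sample M distinct critic indices; one subset is shared by the whole batch.
            i_targets = torch.randperm(n_critics)[: self.config.model.n_in_target]
            return torch.stack([q_val_list[i] for i in i_targets], dim=-1).min(dim=-1).values
        elif reduce_type == "mean":
            # Average over the ensemble dimension.
            return q_val_list.mean(0)
        else:
            raise ValueError(
                f"Unsupported reduce type: {reduce_type}. Use 'min' or 'mean'."
            )
```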
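One detail of `reduce` is easy to miss: `torch.randperm` runs once per call, so a single index subset is shared by every entry of the batch, consistent with sampling one set \(\mathcal{M}\) per update. The hypothetical, torch-free helper `reduce_toy` below mirrors both branches on nested lists (outer index = critic, inner index = batch entry) to make that visible; it is an illustration, not a drop-in replacement.

```python
import random
from statistics import fmean
from typing import List, Sequence


def reduce_toy(
    q_vals: Sequence[Sequence[float]],  # q_vals[i][b]: critic i, batch entry b
    reduce_type: str,
    n_in_target: int,
    rng: random.Random,
) -> List[float]:
    """Plain-Python mirror of REDQCritic.reduce (hypothetical helper)."""
    if reduce_type == "min":
        if len(q_vals) < n_in_target:
            raise ValueError(f"Expected at least {n_in_target} critics, got {len(q_vals)}.")
        # One subset for the whole batch, like a single torch.randperm call.
        subset = rng.sample(range(len(q_vals)), n_in_target)
        return [min(q_vals[i][b] for i in subset) for b in range(len(q_vals[0]))]
    if reduce_type == "mean":
        # Average over critics for every batch entry.
        return [fmean(column) for column in zip(*q_vals)]
    raise ValueError(f"Unsupported reduce type: {reduce_type}. Use 'min' or 'mean'.")


# Critic 0 is lowest on the first batch entry, critic 2 on the second.
q = [[0.0, 9.0], [5.0, 5.0], [9.0, 0.0]]
print(reduce_toy(q, "mean", 2, random.Random(0)))  # both entries are about 4.667
print(reduce_toy(q, "min", 2, random.Random(0)))
```

With a shared subset, the "min" branch can only return `[0.0, 5.0]`, `[0.0, 0.0]` or `[5.0, 0.0]`; sampling a fresh subset per batch entry could also yield `[5.0, 5.0]`.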
REDQ ActorCritic and REDQ Actor
REDQ reuses the SAC actor (SACActor) without modification; the REDQ agent class simply pairs it with REDQCritic:
```python
class RandomizedEnsembledDoubleQLearning(ActorCritic):
    """
    REDQ agent combining REDQCritic and SACActor.

    Chen et al. (2021): Randomized Ensembled Double Q-Learning: Learning Fast Without a Model
    """

    _agent_name = "REDQ"

    def __init__(
        self,
        config: "MainConfig",
        critic_type: type = REDQCritic,
        actor_type: type = SACActor,
    ) -> None:
        """
        Initializes the REDQ agent.

        Args:
            config (MainConfig): Configuration dataclass instance.
            critic_type (type): Critic class type.
            actor_type (type): Actor class type.

        Returns:
            None
        """
        super().__init__(config, critic_type, actor_type)
```
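The agent class changes nothing but constructor defaults. The self-contained toy below shows why that is enough under a simple dependency-injection pattern. Every `Toy*` class is hypothetical, and we assume, without relying on the library source, that `ActorCritic.__init__` essentially instantiates the critic and actor types it receives.

```python
from typing import Any


class ToyActorCritic:
    """Hypothetical stand-in for ActorCritic: builds one critic and one actor."""

    def __init__(self, config: Any, critic_type: type, actor_type: type) -> None:
        # Assumption: the base class only instantiates the injected types.
        self.critic = critic_type(config)
        self.actor = actor_type(config)


class ToySACCritic:
    def __init__(self, config: Any) -> None:
        self.config = config


class ToyREDQCritic(ToySACCritic):
    """Would override reduce(); everything else is inherited."""


class ToySACActor:
    def __init__(self, config: Any) -> None:
        self.config = config


class ToyREDQ(ToyActorCritic):
    def __init__(
        self,
        config: Any,
        critic_type: type = ToyREDQCritic,
        actor_type: type = ToySACActor,
    ) -> None:
        super().__init__(config, critic_type, actor_type)


agent = ToyREDQ(config={})
print(type(agent.critic).__name__, type(agent.actor).__name__)  # ToyREDQCritic ToySACActor
```

Passing a different `critic_type` at construction time swaps the critic without touching the agent class.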
REDQ Config
Following the original REDQ implementation, we set the UTD ratio (corresponding to the hyperparameter policy_delay) to 20.
REDQ maintains \(N = 10\) critics and computes the Bellman target from \(M = 2\) randomly sampled target critics. As described above, the minimum over these \(M\) critics is used for critic learning, and the mean over all \(N\) critics for actor training.
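These values fix how much critic training REDQ performs per environment step. The back-of-the-envelope helper below (hypothetical name `redq_update_counts`) counts critic updates and per-critic regressions, where each of the \(N\) critics is counted separately regardless of whether an implementation batches the ensemble into a single backward pass. For comparison, it also evaluates a two-critic, UTD-1 setup typical of SAC.

```python
from typing import Dict


def redq_update_counts(env_steps: int, utd_ratio: int, n_critics: int) -> Dict[str, int]:
    """Counts critic work over a run (hypothetical helper; ignores the actor)."""
    critic_updates = env_steps * utd_ratio  # gradient steps on the critic ensemble
    return {
        "critic_updates": critic_updates,
        "critic_regressions": critic_updates * n_critics,  # each critic regresses toward y
    }


print(redq_update_counts(env_steps=1_000, utd_ratio=20, n_critics=10))
# {'critic_updates': 20000, 'critic_regressions': 200000}
print(redq_update_counts(env_steps=1_000, utd_ratio=1, n_critics=2))
# {'critic_updates': 1000, 'critic_regressions': 2000}
```

For the same number of environment steps, the REDQ configuration performs 100 times as many per-critic regressions, which is the compute traded for sample efficiency.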
The relevant hyperparameters in the configuration dataclass can be modified as follows:
```python
from dataclasses import dataclass

# CriticNet and REDQCritic are provided by the library.


@dataclass
class REDQCriticConfig:
    """
    Configuration class for the REDQ critic ensemble.

    Attributes:
        arch (type): Neural network architecture for critics.
        critic_type (type): Critic class type.
        n_members (int): Number of critics in the ensemble (N).
        reduce (str): Reduction over the ensemble used for the actor update.
        target_reduce (str): Reduction over the ensemble used for the Bellman target.
    """

    arch: type = CriticNet
    critic_type: type = REDQCritic
    n_members: int = 10
    reduce: str = "mean"
    target_reduce: str = "min"
```
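Because the configuration is a plain dataclass, individual fields can be overridden without editing the class, for example with `dataclasses.replace`. The sketch below uses a hypothetical mirror, `ToyREDQCriticConfig`, in which strings stand in for `CriticNet` and `REDQCritic` so that it runs on its own. The `check_redq_sizes` helper, and passing `n_in_target` to it as an argument, are also assumptions; the text above only shows that `n_in_target` is read from `config.model`.

```python
from dataclasses import dataclass, replace


@dataclass
class ToyREDQCriticConfig:
    """Hypothetical mirror of REDQCriticConfig; strings replace the real classes."""

    arch: str = "CriticNet"
    critic_type: str = "REDQCritic"
    n_members: int = 10
    reduce: str = "mean"
    target_reduce: str = "min"


def check_redq_sizes(critic_cfg: ToyREDQCriticConfig, n_in_target: int) -> None:
    """Hypothetical sanity check for the REDQ requirement 0 < M < N."""
    if not 0 < n_in_target < critic_cfg.n_members:
        raise ValueError(
            f"Need 0 < M < N, got M={n_in_target}, N={critic_cfg.n_members}."
        )


default_cfg = ToyREDQCriticConfig()
check_redq_sizes(default_cfg, n_in_target=2)    # N = 10, M = 2: passes
larger_cfg = replace(default_cfg, n_members=20)  # new instance, other fields kept
print(larger_cfg.n_members, larger_cfg.reduce)  # 20 mean
```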
References