Example 2: Adapting SAC to REDQ

This example demonstrates how to convert Soft Actor-Critic (SAC) [1] to Randomized Ensembled Double Q-learning (REDQ) [2] with minimal code changes.

REDQ extends SAC to improve the sample efficiency of model-free methods in continuous control tasks, where environment interactions are costly. The core idea is to increase the update-to-data (UTD) ratio, i.e., to perform multiple gradient updates per environment step (\(\text{UTD} \gg 1\)). However, naively increasing the UTD ratio in SAC destabilizes learning. REDQ counters this with three key components: a large critic ensemble, a randomized pessimistic critic target, and an ensemble-averaged critic for policy learning.

The conversion from SAC to REDQ requires four modifications:

  1. Increase the UTD ratio (\(\text{UTD} \gg 1\)),

  2. Maintain an ensemble of \(N\) critics,

  3. Compute the critic target using the minimum over a random subset of size \(M < N\) from the critic ensemble,

  4. Use the mean of the critic ensemble for actor updates.
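The four modifications can be placed in a training loop as follows. This is a minimal, hypothetical skeleton. The function run_redq_schedule and its callback arguments are illustrative and not part of the library. It assumes the schedule of Algorithm 1 in the REDQ paper: utd_ratio critic updates per environment step, followed by a single actor update.

```python
from typing import Callable, Tuple


def run_redq_schedule(
    n_env_steps: int,
    utd_ratio: int,
    critic_update: Callable[[], None],
    actor_update: Callable[[], None],
) -> Tuple[int, int]:
    """Hypothetical skeleton of the REDQ update schedule.

    Returns the number of critic and actor updates that were performed.
    """
    if utd_ratio < 1:
        raise ValueError("utd_ratio must be >= 1")
    n_critic = n_actor = 0
    for _ in range(n_env_steps):
        # Environment interaction and replay-buffer insertion are omitted here.
        for _ in range(utd_ratio):
            # Modifications 1-3: many critic updates per environment step, each
            # updating all N critics toward a target built from M random critics.
            critic_update()
            n_critic += 1
        # Modification 4: one actor update using the mean over all N critics.
        actor_update()
        n_actor += 1
    return n_critic, n_actor
```

For example, 1,000 environment steps with utd_ratio=20 yield 20,000 critic updates and 1,000 actor updates.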

Each component builds upon the existing SAC implementation with minimal modifications.

See Randomized Ensembled Double Q-Learning (REDQ) for the full model API. Here we focus only on the changes relative to SAC.

REDQ Critics

REDQ maintains \(N\) critics. To form the Bellman target for the critic update, REDQ samples a set \(\mathcal{M}\) of \(M\) distinct indices from \(\{1, 2, \ldots, N\}\). It then computes the Q target \(y\) as the minimum over the corresponding \(M\) target critics. This target is shared by all \(N\) Q-functions:

\[y = r + \gamma \left( \min_{i \in \mathcal{M}} Q_{\phi_{\text{targ}, i}} \left( s', \tilde{a}' \right) - \alpha \log \pi_{\theta} \left( \tilde{a}' \mid s' \right) \right), \quad \tilde{a}' \sim \pi_{\theta} ( \cdot \mid s')\]
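The target above can be written in a few lines of PyTorch. The sketch below is illustrative rather than the library's code. The name redq_target and the tensor layout (target Q-values stacked as an \(N \times B\) tensor) are assumptions. Like the equation, it omits the terminal-state mask that many implementations apply to the bootstrap term.

```python
from typing import Optional

import torch


def redq_target(
    q_targ: torch.Tensor,
    reward: torch.Tensor,
    log_prob: torch.Tensor,
    gamma: float,
    alpha: float,
    m: int,
    generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
    """Hypothetical sketch of the REDQ Bellman target y.

    q_targ:   (N, B) values Q_{phi_targ,i}(s', a~') for all N target critics
    reward:   (B,) rewards r
    log_prob: (B,) log pi_theta(a~' | s'), with a~' sampled from the current policy
    """
    n = q_targ.shape[0]
    if not 1 <= m <= n:
        raise ValueError(f"m must be in [1, {n}], got {m}")
    # Sample the index set M: m distinct critics, drawn without replacement.
    idx = torch.randperm(n, generator=generator)[:m]
    # Pessimistic estimate: element-wise minimum over the sampled critics.
    min_q = q_targ[idx].min(dim=0).values
    # Soft Bellman target shared by all N critics; detached so no gradient flows into it.
    return (reward + gamma * (min_q - alpha * log_prob)).detach()
```

A fresh subset \(\mathcal{M}\) is drawn on every call. Passing a seeded torch.Generator makes that choice reproducible.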

The policy parameters \(\theta\), by contrast, are updated by gradient ascent using the mean of all \(N\) critics in the ensemble:

\[\nabla_{\theta} \frac{1}{|B|} \sum_{s \in B} \left( \frac{1}{N} \sum_{i=1}^{N} Q_{\phi_i} \left( s, \tilde{a}_{\theta}(s) \right) - \alpha \log \pi_{\theta} \left( \tilde{a}_{\theta}(s) \mid s \right) \right), \quad \tilde{a}_{\theta}(s) \sim \pi_{\theta}(\cdot \mid s)\]
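Correspondingly, the actor objective changes only in how SAC's critics are reduced. In the hypothetical helper redq_actor_loss below (not a library function), the Q-values of all critics are assumed to be stacked as an \(N \times B\) tensor. They are also assumed to be evaluated at a reparameterized action, so that minimizing the returned loss performs gradient ascent on the objective above.

```python
import torch


def redq_actor_loss(q_pi: torch.Tensor, log_prob: torch.Tensor, alpha: float) -> torch.Tensor:
    """Hypothetical sketch of the REDQ actor loss.

    q_pi:     (N, B) values Q_{phi_i}(s, a~_theta(s)) for all N critics, where
              a~_theta(s) is a reparameterized sample so gradients reach theta
    log_prob: (B,) log pi_theta(a~_theta(s) | s)
    """
    # Average over the ensemble (dim 0) instead of the minimum over two
    # critics used in common SAC implementations.
    q_mean = q_pi.mean(dim=0)
    # Gradient ascent on the objective equals gradient descent on its negative.
    return -(q_mean - alpha * log_prob).mean()
```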

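To support both reductions, we derive a REDQCritic class from SACCritic and override its reduce method. The "min" reduction computes the randomized target over \(M\) sampled critics, and "mean" averages all \(N\) critics for the actor update: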

import torch


class REDQCritic(SACCritic):  # SACCritic: the existing SAC critic class
    def reduce(self, q_val_list: torch.Tensor, reduce_type: str = "min") -> torch.Tensor:
        """
        Reduces the Q-values of the critic ensemble.

        Args:
            q_val_list (torch.Tensor): Q-values of all N critics, indexed along dim 0.
            reduce_type (str): "min" for the randomized pessimistic target (minimum
                over M randomly sampled critics), "mean" for the ensemble average
                used in the actor update.

        Returns:
            torch.Tensor: Reduced Q-values.
        """
        if reduce_type == "min":
            n_in_target = self.config.model.n_in_target  # M
            n_members = len(q_val_list)  # N, taken from the input itself
            if n_members < n_in_target:
                raise ValueError(
                    f"Expected at least {n_in_target} critics, but got {n_members}."
                )

            # Sample M distinct critic indices uniformly without replacement.
            i_targets = torch.randperm(n_members)[:n_in_target]

            # Element-wise minimum over the sampled critics.
            return torch.stack([q_val_list[i] for i in i_targets], dim=-1).min(dim=-1).values
        elif reduce_type == "mean":
            # Average over all N critics (dim 0).
            return q_val_list.mean(dim=0)
        else:
            raise ValueError(
                f"Unsupported reduce type: {reduce_type}. Use 'min' or 'mean'."
            )
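The reduce method consumes Q-values from all \(N\) critics. For illustration, the hypothetical CriticEnsemble below produces them as a single tensor with the ensemble along the first dimension. It assumes a two-hidden-layer MLP per member and is not the library's CriticNet.

```python
import torch
from torch import nn


class CriticEnsemble(nn.Module):
    """Hypothetical ensemble of N independently initialized Q-networks."""

    def __init__(self, obs_dim: int, act_dim: int, n_members: int = 10, hidden: int = 256) -> None:
        super().__init__()
        # Each member has its own parameters and its own random initialization.
        self.members = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(obs_dim + act_dim, hidden),
                    nn.ReLU(),
                    nn.Linear(hidden, hidden),
                    nn.ReLU(),
                    nn.Linear(hidden, 1),
                )
                for _ in range(n_members)
            ]
        )

    def forward(self, obs: torch.Tensor, act: torch.Tensor) -> torch.Tensor:
        x = torch.cat([obs, act], dim=-1)
        # One Q-value per critic and sample, stacked along dim 0: shape (N, B).
        return torch.stack([member(x).squeeze(-1) for member in self.members], dim=0)


ensemble = CriticEnsemble(obs_dim=3, act_dim=1, n_members=10)
q_values = ensemble(torch.randn(4, 3), torch.randn(4, 1))  # shape (10, 4)
```

Keeping the members as separate networks is the simplest choice. Vectorized ensembles with batched weight tensors can be faster, but they are not needed to follow the algorithm.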

REDQ ActorCritic and REDQ Actor

The REDQ actor is SAC's actor (SACActor), reused without modification. The REDQ agent therefore only needs to pair REDQCritic with SACActor:

class RandomizedEnsembledDoubleQLearning(ActorCritic):
    """
    REDQ agent combining REDQCritic and SACActor.
    Chen et al. (2021): Randomized Ensembled Double Q-Learning: Learning Fast Without a Model
    """

    _agent_name = "REDQ"

    def __init__(
        self,
        config: "MainConfig",
        critic_type: type = REDQCritic,
        actor_type: type = SACActor,
    ) -> None:
        """
        Initializes the REDQ agent.

        Args:
            config (MainConfig): Configuration dataclass instance.
            critic_type (type): Critic class type.
            actor_type (type): Actor class type.
        Returns:
            None
        """
        super().__init__(config, critic_type, actor_type)


REDQ Config

Following the original REDQ implementation, we set the UTD ratio (corresponding to the hyperparameter policy_delay) to 20. REDQ maintains \(N = 10\) critics and computes the Bellman target from \(M = 2\) randomly sampled target critics. As described above, critic learning uses the minimum over these \(M = 2\) critics, while actor training uses the mean over all \(N\) critics. The relevant hyperparameters in the critic configuration dataclass can be set as follows:

from dataclasses import dataclass

# CriticNet and REDQCritic are provided by the library.


@dataclass
class REDQCriticConfig:
    """
    Configuration class for the REDQ critic ensemble.

    Attributes:
        arch (type): Neural network architecture of each critic.
        critic_type (type): Critic class type.
        n_members (int): Number of critics in the ensemble (N).
        reduce (str): Reduction over the ensemble for the actor update ("mean").
        target_reduce (str): Reduction for the Bellman target ("min" over M random critics).
    """

    arch: type = CriticNet
    critic_type: type = REDQCritic
    n_members: int = 10
    reduce: str = "mean"
    target_reduce: str = "min"
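For reference, the numbers above can be collected in one place. The dataclass below is a hypothetical summary, not the library's configuration. In the library, \(N\) is REDQCriticConfig.n_members, \(M\) is read from config.model.n_in_target (as in the reduce method), and the UTD ratio corresponds to policy_delay. The sketch also checks the constraint \(M < N\).

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class REDQHyperparams:
    """Hypothetical flat summary of the REDQ settings; field names are illustrative."""

    utd_ratio: int = 20   # critic update rounds per environment step
    n_members: int = 10   # N: critics in the ensemble
    n_in_target: int = 2  # M: critics sampled for the Bellman target

    def __post_init__(self) -> None:
        # Validate the constraints stated for REDQ: UTD >= 1 and 1 <= M < N.
        if self.utd_ratio < 1:
            raise ValueError("utd_ratio must be >= 1")
        if not 1 <= self.n_in_target < self.n_members:
            raise ValueError("REDQ requires 1 <= M < N")

    def critic_updates(self, n_env_steps: int) -> int:
        """Number of critic update rounds performed over n_env_steps."""
        return self.utd_ratio * n_env_steps
```

With these defaults, 1,000 environment steps correspond to 20,000 critic update rounds, each of which updates all 10 critics.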


References