Example 1: Adapting SAC to DRND


Modifying existing reinforcement learning algorithms to incorporate new exploration mechanisms or architectural changes is a common requirement in research and development. ObjectRL's class structure keeps such adaptations short and maintainable. This example shows how to convert Soft Actor-Critic (SAC) [1] into Distributional Random Network Distillation (DRND) [2] with minimal code changes.

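DRND extends SAC with a bonus that estimates epistemic uncertainty from the disagreement between a learned predictor and a fixed ensemble of randomly initialized target networks. The key insight is that this disagreement is large on state-action pairs that are poorly covered by the data seen so far. In the variant shown here, the bonus enters both the actor loss and the critic target as a penalty, which steers the agent away from such regions (the anti-exploration use of DRND in [2]).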

The conversion from SAC to DRND requires three components: the exploration bonus mechanism, modified actor and critic losses, and an updated training loop. Each component builds on the existing SAC implementation with minimal modifications.

See Distributional Random Network Distillation (DRND) for the full model API. This example covers only the changes relative to SAC.

DRND Bonus

DRND’s exploration bonus consists of a learnable predictor network, \(f_\theta\), and a fixed ensemble of target networks, \(\bar f_1,\ldots, \bar f_N\). The ensemble provides uncertainty estimates through disagreement among its members.

Treating the output \(X\) of a uniformly drawn member \(\bar f_i(x)\) as a random variable, the ensemble’s first two moments are:

\[\begin{split}\mu(x) &= \mathbb E[X] = \frac1N \sum_{i=1}^N \bar f_i(x),\\ B_2(x) &= \mathbb E[X^2] = \frac1N \sum_{i=1}^N (\bar f_i(x))^2.\end{split}\]
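As a quick numerical check of these definitions, the sketch below stacks the outputs of a few randomly initialized, frozen linear layers (illustrative stand-ins for \(\bar f_i\), not the networks DRND actually uses) and computes \(\mu\) and \(B_2\) over the ensemble axis. The difference \(B_2-\mu^2\) is the per-dimension ensemble variance that later normalizes the second bonus term.

```python
import torch

torch.manual_seed(0)

n_members, batch, dim_in, dim_out = 5, 3, 4, 2
# Illustrative stand-ins for the target members: frozen random linear maps.
targets = [torch.nn.Linear(dim_in, dim_out).requires_grad_(False) for _ in range(n_members)]

x = torch.randn(batch, dim_in)
outs = torch.stack([f(x) for f in targets])  # shape (N, batch, dim_out)

mu = outs.mean(0)         # first moment E[X]
B2 = outs.pow(2).mean(0)  # second moment E[X^2]
var = B2 - mu.pow(2)      # population variance across members

print(outs.shape, mu.shape)
```

Reducing over dimension 0 keeps one moment per input and output dimension, which is the layout the bonus code below relies on.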

The configuration dataclass encapsulates all hyperparameters needed for the bonus network:

from dataclasses import dataclass
from typing import Literal


@dataclass
class DRNDBonusConfig:
    """
    Configuration for the DRND exploration bonus network.
    Describes the predictor and the randomized target ensemble.
    The default values follow Yang et al., 2024.

    Attributes:
        depth (int): Number of hidden layers.
        width (int): Width (number of units) in each hidden layer.
        norm (bool): Whether to apply layer normalization.
        activation (str): Activation function to use ('relu' or 'crelu'). Add other activation functions if needed.
        dim_out (int): Output dimensionality of the bonus network.
        scale_factor (float): Weight lambda of the first bonus term; the second term is weighted by 1 - lambda.
        n_members (int): Size of the target ensemble.
        learning_rate (float): Learning rate for training the predictor network.
    """

    depth: int = 4
    width: int = 256
    norm: bool = True
    activation: Literal["relu", "crelu"] = "relu"
    dim_out: int = 32
    scale_factor: float = 0.9
    n_members: int = 10
    learning_rate: float = 1e-4


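The networks are built with ObjectRL's existing utilities, so the predictor and all ensemble members share the same architecture:

import torch

from utils.net_utils import MLP
from models.basic.ensemble import Ensemble

# Assumed to be in scope: dim_obs and dim_act (observation/action shape tuples)
# and bonus_conf (a DRNDBonusConfig instance).
gen_net = lambda: MLP(
    dim_obs[0] + dim_act[0],  # input: concatenated state-action vector
    bonus_conf.dim_out,
    bonus_conf.depth,
    bonus_conf.width,
    act=bonus_conf.activation,
    has_norm=bonus_conf.norm
)

predictor = gen_net()  # trainable predictor f_theta
target_ensemble = Ensemble(  # fixed random targets \bar f_1, ..., \bar f_N
    bonus_conf.n_members,
    gen_net(),
    [gen_net() for _ in range(bonus_conf.n_members)]
)
# Only the predictor is optimized; the target ensemble is never updated.
optim_pred = torch.optim.Adam(predictor.parameters(), lr=bonus_conf.learning_rate)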
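`MLP` and `Ensemble` are ObjectRL utilities. For experimenting outside the library, the sketch below is a rough plain-PyTorch stand-in: a hypothetical `make_mlp` helper (ReLU only, no layer normalization or CReLU) and a hypothetical `TargetEnsemble` module whose forward pass stacks member outputs into shape `(n_members, batch, dim_out)`, the layout the bonus code reduces over with `.mean(0)`. It is an assumption for illustration and does not reflect how ObjectRL's `Ensemble` is implemented.

```python
import torch
from torch import nn


def make_mlp(dim_in: int, dim_out: int, depth: int, width: int) -> nn.Sequential:
    """Hypothetical stand-in for ObjectRL's MLP: `depth` hidden ReLU layers."""
    layers, d = [], dim_in
    for _ in range(depth):
        layers += [nn.Linear(d, width), nn.ReLU()]
        d = width
    layers.append(nn.Linear(d, dim_out))
    return nn.Sequential(*layers)


class TargetEnsemble(nn.Module):
    """Hypothetical frozen ensemble; forward returns (n_members, batch, dim_out)."""

    def __init__(self, members: list[nn.Module]) -> None:
        super().__init__()
        self.members = nn.ModuleList(members)
        self.requires_grad_(False)  # targets stay fixed; only the predictor learns

    def __getitem__(self, i: int) -> nn.Module:
        return self.members[i]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.stack([m(x) for m in self.members])


dim_obs, dim_act = (3,), (1,)  # illustrative shapes
dim_in = dim_obs[0] + dim_act[0]
predictor = make_mlp(dim_in, 32, depth=4, width=256)
target_ensemble = TargetEnsemble([make_mlp(dim_in, 32, 4, 256) for _ in range(10)])
optim_pred = torch.optim.Adam(predictor.parameters(), lr=1e-4)

sa = torch.randn(8, dim_in)
print(target_ensemble(sa).shape)  # torch.Size([10, 8, 32])
```

Freezing the targets with `requires_grad_(False)` means no gradients accumulate on them even when the bonus is differentiated in the actor loss.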

The predictor network is trained to match a randomly selected ensemble member at each update, which yields a learning signal that captures epistemic uncertainty. The training objective is the mean squared error

\[L(\theta) = \|f_\theta(x) - \bar f_c(x)\|^2, \qquad c \sim \mathcal U\{1,\ldots,N\},\]

where \(x=(s,a)\) is the concatenated state-action input and the member index \(c\) is resampled at every update.

def update_predictor(state: torch.Tensor, action: torch.Tensor) -> None:
    sa = torch.cat((state, action), -1)
    optim_pred.zero_grad()
    # Sample one target member c uniformly for this batch
    c = int(torch.randint(bonus_conf.n_members, ()))
    with torch.no_grad():  # targets are fixed, so no graph is needed for them
        c_target = target_ensemble[c](sa)
    pred = predictor(sa)
    loss = (pred - c_target).pow(2).mean()
    loss.backward()
    optim_pred.step()
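Why this objective is useful: for a fixed input, the minimizer of \(\mathbb E_c\|f - \bar f_c(x)\|^2\) over \(f\) is the ensemble mean \(\mu(x)\), so on frequently seen inputs the predictor drifts toward \(\mu(x)\) and the first bonus term shrinks. The toy sketch below checks this with a free vector `p` in place of \(f_\theta(x)\) and plain gradient steps of size \(1/(t+1)\), an illustrative schedule rather than DRND's Adam optimizer. With that schedule `p` is exactly the running average of the sampled targets.

```python
import torch

torch.manual_seed(0)

n_members, dim_out, steps = 10, 8, 10_000
targets = torch.randn(n_members, dim_out)  # fixed target outputs for one input x
mu = targets.mean(0)                       # ensemble mean mu(x)

p = torch.zeros(dim_out, requires_grad=True)  # stands in for f_theta(x)
for t in range(steps):
    c = torch.randint(n_members, (1,)).item()  # random member, as in update_predictor
    loss = 0.5 * (p - targets[c]).pow(2).sum()
    loss.backward()
    with torch.no_grad():
        p -= p.grad / (t + 1)  # decaying step: p becomes a running average
    p.grad.zero_()

print((p.detach() - mu).abs().max())
```

The printed deviation shrinks roughly like one over the square root of the number of steps, which is the behavior the first bonus term exploits: it stays large only where the predictor has seen few samples.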

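The exploration bonus combines two terms: the squared distance between the predictor and the ensemble mean, and a term that compares \(f_\theta(x)^2\) with \(\mu(x)^2\), normalized by the ensemble variance \(B_2(x) - \mu(x)^2\):

\[b(x) = \lambda \|f_\theta(x) - \mu(x)\|^2 + (1 - \lambda)\sqrt{\frac{\left|f_\theta(x)^2 - \mu(x)^2\right|}{B_2(x) - \mu(x)^2}},\]

where \(\lambda\) is the `scale_factor` in the configuration. The squares, absolute value, division, and square root in the second term act elementwise, and the code averages that term over the output dimensions: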

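def bonus(state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
    # No @torch.no_grad() here: the actor loss needs gradients w.r.t. the action.
    # The critic calls this inside its own no-grad Bellman target computation.
    sa = torch.cat((state, action), -1)
    target_pred = target_ensemble(sa)  # (n_members, batch, dim_out)
    mu = target_pred.mean(0)           # ensemble mean, (batch, dim_out)
    mu2 = mu.pow(2)
    B2 = target_pred.pow(2).mean(0)    # ensemble second moment
    pred = predictor(sa)
    dim_check(pred, mu)                # checks that pred and mu have matching shapes
    fst = (pred - mu).pow(2).sum(1, keepdim=True)
    snd = torch.sqrt(((pred.pow(2) - mu2).abs() / (B2 - mu2))).mean(1, keepdim=True)
    return bonus_conf.scale_factor * fst + (1 - bonus_conf.scale_factor) * snd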
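The bonus arithmetic can be exercised without ObjectRL by passing raw tensors. The sketch below is a hypothetical standalone variant, `drnd_bonus`, that takes the stacked target outputs `(N, B, D)` and the predictor output `(B, D)` directly. The small `eps` in the denominator is an added safeguard against a zero ensemble variance and is not part of the original code. When the predictor output equals the ensemble mean both terms vanish; moving it away from the mean increases the bonus.

```python
import torch


def drnd_bonus(target_pred: torch.Tensor, pred: torch.Tensor,
               scale_factor: float = 0.9, eps: float = 1e-8) -> torch.Tensor:
    """Hypothetical standalone DRND bonus; target_pred (N, B, D), pred (B, D) -> (B, 1)."""
    mu = target_pred.mean(0)         # ensemble mean, (B, D)
    mu2 = mu.pow(2)
    B2 = target_pred.pow(2).mean(0)  # second moment, (B, D)
    fst = (pred - mu).pow(2).sum(1, keepdim=True)
    # eps is an assumption added here to avoid dividing by a zero variance
    snd = torch.sqrt((pred.pow(2) - mu2).abs() / (B2 - mu2 + eps)).mean(1, keepdim=True)
    return scale_factor * fst + (1 - scale_factor) * snd


torch.manual_seed(0)
target_pred = torch.randn(10, 4, 32)
mu = target_pred.mean(0)

print(drnd_bonus(target_pred, mu).squeeze(1))        # zeros: predictor matches the mean
print(drnd_bonus(target_pred, mu + 1.0).squeeze(1))  # strictly positive
```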

DRND Actor

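The DRND actor extends the SAC actor by adding the bonus of its sampled actions, weighted by `lambda_actor`, to the SAC actor loss. Because this loss is minimized, the term penalizes actions with a high bonus and keeps the policy in well-covered regions of the state-action space. The changes are confined to reading `lambda_actor` in `__init__`, adding the term in `loss`, and passing `bonus_ensemble` through `update`: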

class DRNDActor(SACActor):
    """
    Actor network for DRND, based on SAC but augmented with an exploration bonus.

    Args:
        config (MainConfig): Experiment configuration.
        dim_state (tuple): Observation space dimension.
        dim_act (tuple): Action space dimension.
    Attributes:
        lambda_actor (float): Weight of the bonus term in the actor loss.
    """

    def __init__(
        self, config: "MainConfig", dim_state: tuple[int, ...], dim_act: tuple[int, ...]
    ) -> None:
        super().__init__(config, dim_state, dim_act)

        self.lambda_actor = config.model.actor.lambda_actor

    def loss(
        self, state: torch.Tensor, critics: "DRNDCritics", bonus_ensemble: DRNDBonus
    ) -> tuple[torch.Tensor, dict]:
        """
        Compute actor loss including the entropy term and the DRND exploration bonus.

        Args:
            state (torch.Tensor): Batch of input states.
            critics (DRNDCritics): Critic network(s).
            bonus_ensemble (DRNDBonus): Bonus ensemble for exploration.
        Returns:
            loss (Tensor): Total actor loss.
            act_dict (dict): Output of actor network.
        """
        loss, act_dict = super().loss(state, critics)
        bonus = bonus_ensemble.bonus(state, act_dict["action"]).mean()
        return loss + self.lambda_actor * bonus, act_dict

    def update(
        self, state: torch.Tensor, critics: "DRNDCritics", bonus_ensemble: DRNDBonus
    ) -> None:
        """
        Perform a gradient step for the actor.

        Args:
            state (torch.Tensor): Batch of input states.
            critics (DRNDCritics): Critic network(s).
            bonus_ensemble (DRNDBonus): Bonus ensemble for exploration.
        Returns:
            None
        """
        self.optim.zero_grad()
        loss, act_dict = self.loss(state, critics, bonus_ensemble)
        loss.backward()
        self.optim.step()
        self.update_alpha(act_dict)

        self.iter += 1  # Increment iteration counter
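The bonus term in `loss` only changes the actor update if gradients can flow from the bonus back to the sampled action. The toy check below uses a hypothetical scalar "policy" parameter `w` and a quadratic stand-in for the bonus, and compares the gradient with the bonus computed normally and under `torch.no_grad()`. In the second case the term is a constant and contributes nothing to the gradient, so a bonus used in the actor loss must not be computed under `torch.no_grad()`.

```python
import torch


def actor_grad(bonus_without_grad: bool) -> float:
    w = torch.tensor(0.5, requires_grad=True)  # toy policy parameter
    state = torch.tensor(2.0)
    action = w * state                         # toy deterministic "policy"
    base_loss = -action                        # stand-in for the SAC actor loss
    if bonus_without_grad:
        with torch.no_grad():
            bonus = (action - 3.0).pow(2)      # detached: no path back to w
    else:
        bonus = (action - 3.0).pow(2)          # differentiable w.r.t. the action
    loss = base_loss + 0.5 * bonus             # 0.5 plays the role of lambda_actor
    loss.backward()
    return w.grad.item()


print(actor_grad(False), actor_grad(True))  # -6.0 -2.0
```

With the differentiable bonus, the gradient picks up the extra `0.5 * 2 * (action - 3) * state = -4` from the penalty; under `no_grad` only the base term's `-2` remains.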


DRND Critics

The DRND critics subtract the bonus of the next state-action pair, weighted by `lambda_critic`, from the soft Bellman target. The value function is thus penalized on poorly covered state-action pairs, mirroring the actor term:

class DRNDCritics(CriticEnsemble):
    """
    Critic module for DRND that incorporates exploration bonus into target computation.

    Args:
        config (MainConfig): Experiment configuration.
        dim_state (tuple): Observation space dimension.
        dim_act (tuple): Action space dimension.
    Attributes:
        lambda_critic (float): Regularization coefficient for the critic loss.
        _gamma (float): Discount factor for future rewards.
        _agent_name (str): Name of the agent.
    """

    def __init__(
        self, config: "MainConfig", dim_state: tuple[int, ...], dim_act: tuple[int, ...]
    ) -> None:
        super().__init__(config, dim_state, dim_act)
        self.lambda_critic = config.model.critic.lambda_critic

    @torch.no_grad()
    def get_bellman_target(
        self,
        reward: torch.Tensor,
        next_state: torch.Tensor,
        done: torch.Tensor,
        actor: DRNDActor,
        bonus_ensemble: DRNDBonus,
    ) -> torch.Tensor:
        """
        Computes the Bellman target including entropy regularization and exploration penalty.

        Args:
            reward (torch.Tensor): Reward signal.
            next_state (torch.Tensor): Next state.
            done (torch.Tensor): Done flag (1 if terminal, else 0).
            actor (DRNDActor): Actor network (used for target action).
            bonus_ensemble (DRNDBonus): Bonus ensemble for exploration.
        Returns:
            y (Tensor): Bellman target.
        """
        alpha = actor.log_alpha.exp().detach()

        act_dict = actor.act(next_state)
        next_action, log_prob = act_dict["action"], act_dict["action_logprob"]

        target_values = self.Q_t(next_state, next_action)
        target_reduced = self.reduce(
            target_values, reduce_type=self.config.model.critic.target_reduce
        )
        bonus = bonus_ensemble.bonus(next_state, next_action)
        q_target = target_reduced - alpha * log_prob - self.lambda_critic * bonus
        reward = reward.unsqueeze(-1)
        dim_check(q_target, reward)
        y = reward + (self._gamma * q_target * (1 - done.unsqueeze(-1)))
        return y
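Written out, `get_bellman_target` computes the soft Bellman target with a bonus penalty,

\[y = r + \gamma\,(1-d)\left(\bar Q(s',a') - \alpha\log\pi(a'\mid s') - \lambda_{\text{critic}}\, b(s',a')\right), \qquad a' \sim \pi(\cdot\mid s'),\]

where \(\bar Q\) is the reduced target-critic value and \(d\) the termination flag; setting \(\lambda_{\text{critic}} = 0\) recovers the SAC target. The sketch below reproduces this arithmetic on hand-picked tensors (all values illustrative, no ObjectRL objects involved). It assumes `reward` and `done` arrive with shape `(B,)` and the other terms with shape `(B, 1)`, which is why the code unsqueezes the former.

```python
import torch

gamma, alpha, lambda_critic = 0.99, 0.2, 0.5
reward = torch.tensor([1.0, 1.0])                # (B,)
done = torch.tensor([0.0, 1.0])                  # (B,); second transition is terminal
target_reduced = torch.tensor([[10.0], [10.0]])  # reduced target critics, (B, 1)
log_prob = torch.tensor([[-1.0], [-1.0]])        # log pi(a'|s'), (B, 1)
bonus = torch.tensor([[2.0], [2.0]])             # b(s', a'), (B, 1)

# 10 - 0.2 * (-1) - 0.5 * 2 = 9.2
q_target = target_reduced - alpha * log_prob - lambda_critic * bonus
y = reward.unsqueeze(-1) + gamma * q_target * (1 - done.unsqueeze(-1))
print(y)  # approximately [[10.108], [1.0]]
```

The terminal transition keeps only its reward, while the non-terminal one gets `1 + 0.99 * 9.2 = 10.108`.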


DRND ActorCritic

The training loop needs only small changes: the bonus ensemble is passed to the critic's Bellman target and to the actor update, and the predictor is trained on every batch:

class DRND(ActorCritic):
    """
    DRND agent integrating exploration through Distributional Random Network Distillation.

    Implements actor-critic logic where:
    - Actor loss is regularized by an exploration bonus
    - Critic targets include bonus penalties
    - Bonus predictor is trained online

    Yang et al. (2024): Exploration and Anti-Exploration with Distributional Random Network Distillation
    """

    _agent_name = "DRND"

    def __init__(
        self,
        config: "MainConfig",
        critic_type: type = DRNDCritics,
        actor_type: type = DRNDActor,
        bonus_type: type = DRNDBonus,
    ) -> None:
        """
        Initializes DRND agent.

        Args:
            config (MainConfig): Configuration dataclass instance.
            critic_type (type): Critic class type.
            actor_type (type): Actor class type.
            bonus_type (type): Exploration bonus class type.
        Returns:
            None
        """
        super().__init__(config, critic_type, actor_type)

        self.bonus_ensemble = bonus_type(config, self.dim_state, self.dim_act)

    def learn(self, max_iter: int = 1, n_epochs: int = 0) -> None:
        """
        Perform the learning process for the agent.

        Args:
            max_iter (int): Maximum number of iterations for learning.
            n_epochs (int): Number of epochs for training. If 0, random sampling is used.
        Returns:
            None
        """
        # Check if there is enough data in memory to sample a batch
        if self.config_train.batch_size > len(self.experience_memory):
            return None

        # Determine the number of steps and initialize the iterator
        n_steps = self.experience_memory.get_steps_and_iterator(
            n_epochs, max_iter, self.config_train.batch_size
        )

        for _ in range(n_steps):
            # Get batch using the internal iterator
            batch = self.experience_memory.get_next_batch(self.config_train.batch_size)

            bellman_target = self.critic.get_bellman_target(
                batch["reward"],
                batch["next_state"],
                batch["terminated"],
                self.actor,
                self.bonus_ensemble,
            )
            self.critic.update(batch["state"], batch["action"], bellman_target)

            # Update the actor network periodically
            if self.n_iter % self.policy_delay == 0:
                self.actor.update(batch["state"], self.critic, self.bonus_ensemble)
                if self.actor.has_target:
                    self.actor.update_target()

            # Update target networks
            if self.critic.has_target:
                self.critic.update_target()

            # Train the DRND predictor on the current batch
            self.bonus_ensemble.update_predictor(batch["state"], batch["action"])
            self.n_iter += 1
        return None


Summary

Converting SAC to DRND demonstrates the flexibility of ObjectRL’s architecture. The transformation required only:

  1. Adding the exploration bonus mechanism with ensemble-based uncertainty estimation

  2. Extending actor and critic classes to incorporate bonus terms in their loss functions

  3. Minimal training loop modifications to update the predictor network

The inheritance-based design allows us to reuse the majority of SAC’s implementation while cleanly extending functionality. This pattern generalizes to other algorithmic modifications, making it straightforward to experiment with novel exploration strategies, different network architectures, or alternative learning objectives.

Each component remains testable and maintainable on its own, which makes it quick to prototype new ideas on top of existing implementations.

References