Distributional Random Network Distillation (DRND)

exploration distributional RL

Paper: Exploration and Anti-Exploration with Distributional Random Network Distillation

Pseudocode

Configuration

Specific configuration for the DRND algorithm (in config/model_configs/).
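from dataclasses import dataclass, field
from typing import Literal

# ActorNetProbabilistic, CriticNet, DRNDActor and DRNDCritics are ObjectRL
# classes (DRNDActor and DRNDCritics are documented below under
# objectrl.models.drnd); their imports are omitted from this excerpt.

# [start-bonus-config]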
@dataclass
class DRNDBonusConfig:
    """
    Configuration for the DRND exploration bonus network.
    Implements the randomized target ensemble used for exploration.
    Default values follow Yang et al. (2024).

    Attributes:
        depth (int): Number of hidden layers.
        width (int): Width (number of units) in each hidden layer.
        norm (bool): Whether to apply layer normalization.
        activation (str): Activation function to use ('relu' or 'crelu'). Users can add other activation functions if needed.
        dim_out (int): Output dimensionality of the bonus network.
        scale_factor (float): Weight that balances the two bonus terms.
        n_members (int): Size of the target ensemble.
        learning_rate (float): Learning rate for training the predictor network.
    """

    depth: int = 4
    width: int = 256
    norm: bool = True
    activation: Literal["relu", "crelu"] = "relu"
    dim_out: int = 32
    scale_factor: float = 0.9
    n_members: int = 10
    learning_rate: float = 1e-4


# [end-bonus-config]


@dataclass
class DRNDActorConfig:
    """
    Configuration for the actor network used in DRND.

    Attributes:
        arch (type): The neural network architecture to use.
        actor_type (type): The actor class (typically DRNDActor).
        lambda_actor (float): Scaling coefficient for the exploration bonus in the actor loss.
    """

    arch: type = ActorNetProbabilistic
    actor_type: type = DRNDActor
    lambda_actor: float = 1.0


@dataclass
class DRNDCriticConfig:
    """
    Configuration for the critic network used in DRND.

    Attributes:
        arch (type): The neural network architecture to use.
        critic_type (type): The critic class (typically DRNDCritics).
        lambda_critic (float): Scaling coefficient for the exploration bonus in the critic target.
    """

    arch: type = CriticNet
    critic_type: type = DRNDCritics
    lambda_critic: float = 1.0


@dataclass
class DRNDConfig:
    """
    Full configuration for the DRND algorithm.

    Attributes:
        name (str): Name of the algorithm.
        bonus_conf (DRNDBonusConfig): Configuration for bonus (RND) component.
        target_entropy (float | None): Target entropy for entropy regularization.
        alpha (float): Entropy regularization coefficient.
        loss (str): Type of loss function ('MSELoss').
        policy_delay (int): Number of critic updates per actor update.
        tau (float): Soft update coefficient for Polyak averaging.
        actor (DRNDActorConfig): Configuration for actor.
        critic (DRNDCriticConfig): Configuration for critic.
    """

    name: str = "drnd"
    bonus_conf: DRNDBonusConfig = field(default_factory=DRNDBonusConfig)
    target_entropy: float | None = None
    alpha: float = 1.0
    loss: str = "MSELoss"
    policy_delay: int = 1
    tau: float = 0.005
    actor: DRNDActorConfig = field(default_factory=DRNDActorConfig)
    critic: DRNDCriticConfig = field(default_factory=DRNDCriticConfig)

    def __post_init__(self):
        if isinstance(self.bonus_conf, dict):
            self.bonus_conf = DRNDBonusConfig(**self.bonus_conf)
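Because of `__post_init__`, `bonus_conf` may be passed either as a `DRNDBonusConfig` instance or as a plain dict, which is convenient when values come from a parsed config file or command-line overrides (typical uses we assume; this page does not document them). The sketch below reproduces that behaviour with trimmed-down, standalone copies of the two classes; only a few fields are kept, for illustration.

```python
from dataclasses import dataclass, field


# Trimmed-down replicas of the config classes above: only the fields needed
# to show the dict-to-dataclass conversion are kept.
@dataclass
class DRNDBonusConfig:
    n_members: int = 10
    scale_factor: float = 0.9
    learning_rate: float = 1e-4


@dataclass
class DRNDConfig:
    name: str = "drnd"
    bonus_conf: DRNDBonusConfig = field(default_factory=DRNDBonusConfig)

    def __post_init__(self):
        # A plain dict is promoted to a DRNDBonusConfig; fields that are not
        # given in the dict keep their defaults.
        if isinstance(self.bonus_conf, dict):
            self.bonus_conf = DRNDBonusConfig(**self.bonus_conf)


cfg = DRNDConfig(bonus_conf={"n_members": 5})
print(type(cfg.bonus_conf).__name__, cfg.bonus_conf.n_members, cfg.bonus_conf.scale_factor)
# prints: DRNDBonusConfig 5 0.9
```

Only `bonus_conf` is converted this way (the `actor` and `critic` fields are not), and a dict key that `DRNDBonusConfig` does not declare raises a `TypeError`.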


UML Diagram

UML diagram for the DRND algorithm.

We use the UML diagram to illustrate the relationships between the classes in our DRND implementation.

The diagram shows how DRNDActor and DRNDCritics inherit from the base classes Actor (through SACActor) and CriticEnsemble, respectively. The agent class (shown as DistributionalRandomNetworkDistillation and documented below as DRND) inherits from ActorCritic, which in turn inherits from Agent.

For each class, the diagram lists the attributes and methods that are central to DRND:

The get_bellman_target() method of DRNDCritics computes the Bellman target, incorporating the bonus from the DRND module as a penalty.

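The DRNDBonus class computes the exploration signal from the discrepancy between a trained predictor network and an ensemble of fixed, randomly initialized target networks. It is similar to RND, but the predictor is regressed onto a target drawn from this ensemble, i.e., a distribution of targets, rather than onto a single fixed network.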

DRNDActor samples actions from the SAC-style stochastic policy and is trained with an entropy-regularized objective that also includes the DRND bonus.

Classes

class objectrl.models.drnd.DRNDBonus(config: MainConfig, dim_state: int, dim_act: int)

Bases: Module

Distributional Random Network Distillation (DRND) bonus module. Provides an exploration bonus based on disagreement between an ensemble of target networks and a learned predictor network. Based on Yang et al. (2024).

Parameters:
  • config (MainConfig) – Main experiment/configuration object.

  • dim_state (int) – Observation space dimension.

  • dim_act (int) – Action space dimension.

target_ensemble

Ensemble of target networks.

Type:

Ensemble

predictor

Predictor network for state-action pairs.

Type:

nn.Module

optim_pred

Optimizer for the predictor.

Type:

torch.optim.Optimizer

n_members

Number of ensemble members.

Type:

int

device

Device for computations.

Type:

torch.device

bonus_conf

Configuration for the bonus module.

Type:

DRNDBonusConfig

__init__(config: MainConfig, dim_state: int, dim_act: int) -> None

Initializes the target ensemble, the predictor network, and the predictor's optimizer.

reset() -> None

bonus(state: Tensor, action: Tensor) -> Tensor

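Compute the DRND exploration bonus for a given (state, action) pair. The bonus combines two terms:
  • the squared difference between the predictor output and the mean of the target ensemble;
  • a distributional term that compares the predictor's squared output with the ensemble's first and second moments, normalized by the ensemble variance.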
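For reference, the bonus of Yang et al. (2024), as we understand it, can be written as follows for an input x = (s, a), with fixed target networks f̄_1 to f̄_N and predictor f_θ. Squares and ratios are element-wise over the d output dimensions. Reading α as scale_factor, N as n_members and d as dim_out is our interpretation of the configuration docstrings, not something this page states; implementations may average rather than sum the second term and usually clip its radicand.

```latex
\begin{aligned}
\mu(x) &= \frac{1}{N}\sum_{i=1}^{N} \bar f_i(x), \qquad
B_2(x) = \frac{1}{N}\sum_{i=1}^{N} \bar f_i(x)^{2}, \\
b(x) &= \alpha \,\lVert f_\theta(x) - \mu(x) \rVert_2^2
  + (1-\alpha) \sum_{k=1}^{d}
  \sqrt{\frac{f_{\theta,k}(x)^2 - \mu_k(x)^2}{B_{2,k}(x) - \mu_k(x)^2}}
\end{aligned}
```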

Parameters:
  • state (torch.Tensor) – Input state tensor.

  • action (torch.Tensor) – Input action tensor.

Returns:

Exploration bonus.

Return type:

torch.Tensor

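mu(x: Tensor) -> Tensor

Returns the element-wise mean of the target-ensemble outputs for input x.

B2(x: Tensor) -> Tensor

Returns the element-wise second moment of the target-ensemble outputs for input x, i.e., the mean of their squares.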
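The bonus, mu and B2 methods can be sketched functionally as follows. This is a minimal illustration under stated assumptions, not ObjectRL's implementation: mu and B2 here take a stacked tensor of target outputs of shape (N, batch, d) rather than the raw input x, single linear layers replace the configured MLPs, both terms are summed over output dimensions, `alpha` plays the role we assume scale_factor plays, the ensemble variance is floored at a small `eps`, and the ratio is clamped at zero before the square root. `drnd_bonus` is a hypothetical helper name.

```python
import torch
import torch.nn as nn


def mu(targets: torch.Tensor) -> torch.Tensor:
    """Mean over the ensemble axis; targets has shape (N, batch, d)."""
    return targets.mean(dim=0)


def B2(targets: torch.Tensor) -> torch.Tensor:
    """Element-wise second moment over the ensemble axis."""
    return (targets ** 2).mean(dim=0)


def drnd_bonus(pred: torch.Tensor, targets: torch.Tensor,
               alpha: float = 0.9, eps: float = 1e-6) -> torch.Tensor:
    m, b2 = mu(targets), B2(targets)
    # Term 1: squared distance between predictor output and ensemble mean.
    first = ((pred - m) ** 2).sum(dim=-1)
    # Term 2: (f^2 - mu^2) / (B2 - mu^2). The denominator is the ensemble
    # variance, floored at eps; the ratio is clamped at 0 before the sqrt.
    var = (b2 - m ** 2).clamp(min=eps)
    ratio = ((pred ** 2 - m ** 2) / var).clamp(min=0.0)
    second = ratio.sqrt().sum(dim=-1)
    return alpha * first + (1.0 - alpha) * second


# Toy setup: single linear layers stand in for the configured MLPs.
torch.manual_seed(0)
n_members, dim_state, dim_act, dim_out = 10, 4, 2, 32
target_nets = [nn.Linear(dim_state + dim_act, dim_out).requires_grad_(False)
               for _ in range(n_members)]
predictor = nn.Linear(dim_state + dim_act, dim_out)

state, action = torch.randn(8, dim_state), torch.randn(8, dim_act)
x = torch.cat([state, action], dim=-1)
with torch.no_grad():
    targets = torch.stack([net(x) for net in target_nets])  # (N, batch, d)
    bonus = drnd_bonus(predictor(x), targets)
print(bonus.shape)  # torch.Size([8])
```

When the predictor output coincides with the ensemble mean, both terms vanish and the bonus is zero; a well-trained predictor approaches this regime on frequently seen (state, action) pairs.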
update_predictor(state: Tensor, action: Tensor) -> None

Updates the predictor network using a randomly selected ensemble member as the regression target.

Parameters:
  • state (torch.Tensor) – Input state tensor.

  • action (torch.Tensor) – Input action tensor.

Returns:

None
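A minimal sketch of the predictor update described above, assuming the network input is the concatenation of state and action, one ensemble member is drawn per call (as the description says), and the loss is a squared error summed over output dimensions. The function and argument names are illustrative, not ObjectRL's signature.

```python
import torch
import torch.nn as nn


def update_predictor(predictor: nn.Module, target_nets: list[nn.Module],
                     optim: torch.optim.Optimizer,
                     state: torch.Tensor, action: torch.Tensor) -> float:
    x = torch.cat([state, action], dim=-1)
    # Draw one ensemble member uniformly at random as this step's target.
    idx = int(torch.randint(len(target_nets), (1,)).item())
    with torch.no_grad():
        y = target_nets[idx](x)
    # Squared regression error, summed over output dims, averaged over batch.
    loss = ((predictor(x) - y) ** 2).sum(dim=-1).mean()
    optim.zero_grad()
    loss.backward()
    optim.step()
    return loss.item()


torch.manual_seed(0)
target_nets = [nn.Linear(6, 32).requires_grad_(False) for _ in range(10)]
predictor = nn.Linear(6, 32)
optim = torch.optim.Adam(predictor.parameters(), lr=1e-4)  # learning_rate default
state, action = torch.randn(64, 4), torch.randn(64, 2)
print(update_predictor(predictor, target_nets, optim, state, action))
```

Since the expected loss over a uniformly drawn member, the mean over i of ||f_θ(x) - f̄_i(x)||², is minimized by the ensemble mean μ(x), repeated updates pull the predictor toward μ on inputs it sees often.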

class objectrl.models.drnd.DRNDActor(config: MainConfig, dim_state: tuple[int, ...], dim_act: tuple[int, ...])

Bases: SACActor

Actor network for DRND, based on SAC but augmented with an exploration bonus.

Parameters:
  • config (MainConfig) – Experiment configuration.

  • dim_state (tuple) – Observation space dimension.

  • dim_act (tuple) – Action space dimension.

lambda_actor

Scaling coefficient for the DRND bonus term in the actor loss.

Type:

float

__init__(config: MainConfig, dim_state: tuple[int, ...], dim_act: tuple[int, ...]) -> None

Initializes the Actor.

Parameters:
  • config (MainConfig) – Configuration dataclass instance.

  • dim_state (tuple) – Dimension of observation space.

  • dim_act (tuple) – Dimension of action space.

Returns:

None

loss(state: Tensor, critics: DRNDCritics, bonus_ensemble: DRNDBonus) -> tuple[Tensor, dict]

Compute actor loss including the entropy term and the DRND exploration bonus.

Parameters:
  • state (torch.Tensor) – Batch of input states.

  • critics (DRNDCritics) – Critic network(s).

  • bonus_ensemble (DRNDBonus) – Bonus ensemble for exploration.

Returns:

A tuple (loss, act_dict): the total actor loss and the output dictionary of the actor network.

Return type:

tuple[Tensor, dict]
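As a sketch, assuming the usual SAC form with a pessimistic minimum over the critic ensemble, the actor objective with the DRND term is mean(alpha * log pi(a~|s) - min_j Q_j(s, a~) + lambda_actor * b(s, a~)), where a~ is a reparameterized sample from the current policy. Here alpha corresponds to the alpha field of DRNDConfig and lambda_actor to DRNDActorConfig.lambda_actor; the min reduction over critics, the tensor layout and the helper name `drnd_actor_loss` are assumptions for illustration.

```python
import torch


def drnd_actor_loss(log_prob: torch.Tensor, q_values: torch.Tensor,
                    bonus: torch.Tensor, alpha: float = 1.0,
                    lambda_actor: float = 1.0) -> torch.Tensor:
    """SAC-style actor loss with a DRND anti-exploration term.

    log_prob: (batch,) log pi(a~|s) for reparameterized actions a~
    q_values: (n_critics, batch) critic estimates Q_j(s, a~)
    bonus:    (batch,) DRND bonus b(s, a~)
    """
    q_min = q_values.min(dim=0).values  # pessimistic ensemble reduction (assumed)
    return (alpha * log_prob - q_min + lambda_actor * bonus).mean()


loss = drnd_actor_loss(
    log_prob=torch.tensor([0.0, 0.0]),
    q_values=torch.tensor([[1.0, 2.0], [3.0, 0.0]]),
    bonus=torch.tensor([1.0, 1.0]),
)
print(loss.item())  # 0.5
```

In a real update, log_prob, q_values and bonus would all be computed from the same reparameterized action, so gradients reach the policy through each term, while the actor optimizer steps only the policy parameters.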

update(state: Tensor, critics: DRNDCritics, bonus_ensemble: DRNDBonus) -> None

Perform a gradient step for the actor.

Parameters:
  • state (torch.Tensor) – Batch of input states.

  • critics (DRNDCritics) – Critic network(s).

  • bonus_ensemble (DRNDBonus) – Bonus ensemble for exploration.

Returns:

None

class objectrl.models.drnd.DRNDCritics(config: MainConfig, dim_state: int, dim_act: int)

Bases: CriticEnsemble

Critic module for DRND that incorporates exploration bonus into target computation.

Parameters:
  • config (MainConfig) – Experiment configuration.

  • dim_state (int) – Observation space dimension.

  • dim_act (int) – Action space dimension.

lambda_critic

Scaling coefficient for the DRND bonus penalty in the critic's Bellman target.

Type:

float

_gamma

Discount factor for future rewards.

Type:

float

_agent_name

Name of the agent.

Type:

str

__init__(config: MainConfig, dim_state: int, dim_act: int)

Initialize the critic ensemble.

Parameters:
  • config (MainConfig) – Configuration object with model parameters.

  • dim_state (int) – Dimension of the state space.

  • dim_act (int) – Dimension of the action space.

Returns:

None

get_bellman_target(reward: Tensor, next_state: Tensor, done: Tensor, actor: DRNDActor, bonus_ensemble: DRNDBonus) -> Tensor

Computes the Bellman target, including the entropy regularization term and the DRND bonus applied as a penalty.

Parameters:
  • reward (torch.Tensor) – Reward signal.

  • next_state (torch.Tensor) – Next state.

  • done (torch.Tensor) – Done flag (1 if terminal, else 0).

  • actor (DRNDActor) – Actor network (used for target action).

  • bonus_ensemble (DRNDBonus) – Bonus ensemble for exploration.

Returns:

The Bellman target y.

Return type:

Tensor
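Under the same assumptions as for the actor, the bonus-penalized soft Bellman target can be sketched as y = r + gamma * (1 - done) * (min_j Q'_j(s', a') - alpha * log pi(a'|s') - lambda_critic * b(s', a')), with a' drawn from the current policy at s' and Q'_j the target critics. `gamma` stands in for _gamma; its default of 0.99, the min reduction and the helper name `drnd_bellman_target` are illustrative assumptions.

```python
import torch


@torch.no_grad()
def drnd_bellman_target(reward: torch.Tensor, done: torch.Tensor,
                        next_q: torch.Tensor, next_log_prob: torch.Tensor,
                        next_bonus: torch.Tensor, gamma: float = 0.99,
                        alpha: float = 1.0,
                        lambda_critic: float = 1.0) -> torch.Tensor:
    """Soft Bellman target with the DRND bonus subtracted as a penalty.

    next_q: (n_critics, batch) target-critic values at (s', a'), a' ~ pi(.|s')
    """
    q_min = next_q.min(dim=0).values
    soft_value = q_min - alpha * next_log_prob - lambda_critic * next_bonus
    # Terminal transitions (done = 1) keep only the immediate reward.
    return reward + gamma * (1.0 - done) * soft_value


y = drnd_bellman_target(
    reward=torch.tensor([1.0, 1.0]),
    done=torch.tensor([0.0, 1.0]),
    next_q=torch.tensor([[2.0, 2.0], [3.0, 3.0]]),
    next_log_prob=torch.tensor([0.0, 0.0]),
    next_bonus=torch.tensor([1.0, 1.0]),
)
print(y)  # tensor([1.9900, 1.0000])
```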

class objectrl.models.drnd.DRND(config: MainConfig, critic_type: type = DRNDCritics, actor_type: type = DRNDActor, bonus_type: type = DRNDBonus)

Bases: ActorCritic

DRND agent integrating exploration through Distributional Random Network Distillation.

Implements actor-critic logic where:
  • the actor loss is regularized by the exploration bonus;
  • the critic targets include bonus penalties;
  • the bonus predictor is trained online.

Yang et al. (2024): Exploration and Anti-Exploration with Distributional Random Network Distillation
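The actor-critic logic described above, together with policy_delay and tau from DRNDConfig, suggests an update schedule like the one sketched below. The ordering of the updates, the per-iteration soft update and the callable-based structure (`drnd_step` and its arguments are hypothetical) are assumptions for illustration; this is not the body of DRND.learn(). `polyak_update` implements the standard soft target update controlled by tau.

```python
from typing import Callable

import torch
import torch.nn as nn


@torch.no_grad()
def polyak_update(target: nn.Module, online: nn.Module, tau: float) -> None:
    """Soft update: theta_target <- tau * theta + (1 - tau) * theta_target."""
    for p_target, p in zip(target.parameters(), online.parameters()):
        p_target.mul_(1.0 - tau).add_(p, alpha=tau)


def drnd_step(it: int, batch,
              update_bonus: Callable, update_critic: Callable,
              update_actor: Callable, soft_update: Callable,
              policy_delay: int = 1) -> None:
    """One hypothetical training iteration; the callables are placeholders."""
    update_bonus(batch)        # predictor trained online on replayed (s, a)
    update_critic(batch)       # regression onto the bonus-penalized target
    if it % policy_delay == 0:
        update_actor(batch)    # bonus-regularized actor step
    soft_update()              # e.g. polyak_update on the target critics


# Usage: count how often each update runs over 4 iterations.
calls = {"bonus": 0, "critic": 0, "actor": 0, "soft": 0}


def counter(key: str) -> Callable:
    def _f(*args) -> None:
        calls[key] += 1
    return _f


for it in range(4):
    drnd_step(it, None, counter("bonus"), counter("critic"),
              counter("actor"), counter("soft"), policy_delay=2)
print(calls)  # {'bonus': 4, 'critic': 4, 'actor': 2, 'soft': 4}
```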

_agent_name = 'DRND'
__init__(config: MainConfig, critic_type: type = DRNDCritics, actor_type: type = DRNDActor, bonus_type: type = DRNDBonus) -> None

Initializes the DRND agent.

Parameters:
  • config (MainConfig) – Configuration dataclass instance.

  • critic_type (type) – Critic class type.

  • actor_type (type) – Actor class type.

  • bonus_type (type) – Bonus module class type.

Returns:

None

learn(max_iter: int = 1, n_epochs: int = 0) -> None

Perform the learning process for the agent.

Parameters:
  • max_iter (int) – Maximum number of iterations for learning.

  • n_epochs (int) – Number of epochs for training. If 0, random sampling is used.

Returns:

None