Distributional Random Network Distillation (DRND)#
Tags: exploration, distributional RL
Paper: Exploration and Anti-Exploration with Distributional Random Network Distillation
Pseudocode#
Configuration#
from dataclasses import dataclass, field
from typing import Literal

# [start-bonus-config]
@dataclass
class DRNDBonusConfig:
    """
    Configuration for the DRND exploration bonus network.
    Implements the randomized target ensemble used for exploration.
    The default values follow Yang et al. (2024).

    Attributes:
        depth (int): Number of hidden layers.
        width (int): Width (number of units) of each hidden layer.
        norm (bool): Whether to apply layer normalization.
        activation (str): Activation function to use ('relu' or 'crelu').
            Other activation functions must be added by the user if needed.
        dim_out (int): Output dimensionality of the bonus network.
        scale_factor (float): Scaling factor between the two bonus terms.
        n_members (int): Size of the target ensemble.
        learning_rate (float): Learning rate for training the predictor network.
    """

    depth: int = 4
    width: int = 256
    norm: bool = True
    activation: Literal["relu", "crelu"] = "relu"
    dim_out: int = 32
    scale_factor: float = 0.9
    n_members: int = 10
    learning_rate: float = 1e-4
# [end-bonus-config]
@dataclass
class DRNDActorConfig:
    """
    Configuration for the actor network used in DRND.

    Attributes:
        arch (type): The neural network architecture to use.
        actor_type (type): The actor class (typically DRNDActor).
        lambda_actor (float): Scaling coefficient for exploration bonus in the actor loss.
    """

    arch: type = ActorNetProbabilistic
    actor_type: type = DRNDActor
    lambda_actor: float = 1.0
@dataclass
class DRNDCriticConfig:
    """
    Configuration for the critic network used in DRND.

    Attributes:
        arch (type): The neural network architecture to use.
        critic_type (type): The critic class (typically DRNDCritics).
        lambda_critic (float): Scaling coefficient for exploration bonus in the critic target.
    """

    arch: type = CriticNet
    critic_type: type = DRNDCritics
    lambda_critic: float = 1.0
@dataclass
class DRNDConfig:
    """
    Full configuration for the DRND algorithm.

    Attributes:
        name (str): Name of the algorithm.
        bonus_conf (DRNDBonusConfig): Configuration for the DRND bonus component.
        target_entropy (float | None): Target entropy for entropy regularization.
        alpha (float): Entropy regularization coefficient.
        loss (str): Name of the loss function (e.g. 'MSELoss').
        policy_delay (int): Number of critic updates per actor update.
        tau (float): Soft update coefficient for Polyak averaging.
        actor (DRNDActorConfig): Configuration for the actor.
        critic (DRNDCriticConfig): Configuration for the critic.
    """

    name: str = "drnd"
    bonus_conf: DRNDBonusConfig = field(default_factory=DRNDBonusConfig)
    target_entropy: float | None = None
    alpha: float = 1.0
    loss: str = "MSELoss"
    policy_delay: int = 1
    tau: float = 0.005
    actor: DRNDActorConfig = field(default_factory=DRNDActorConfig)
    critic: DRNDCriticConfig = field(default_factory=DRNDCriticConfig)

    def __post_init__(self):
        # Allow bonus_conf to be passed as a plain dict and convert it
        # into a typed DRNDBonusConfig instance.
        if isinstance(self.bonus_conf, dict):
            self.bonus_conf = DRNDBonusConfig(**self.bonus_conf)
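The dict handling in `__post_init__` can be illustrated with a minimal, self-contained sketch. The two dataclasses below are trimmed stand-ins that keep only a few of the documented fields, not the real objectrl classes; the override value `n_members=5` is an arbitrary example.

```python
from dataclasses import dataclass, field


@dataclass
class DRNDBonusConfig:
    # Trimmed stand-in: only two of the documented fields are reproduced.
    n_members: int = 10
    scale_factor: float = 0.9


@dataclass
class DRNDConfig:
    # Trimmed stand-in for the full DRNDConfig.
    name: str = "drnd"
    bonus_conf: DRNDBonusConfig = field(default_factory=DRNDBonusConfig)

    def __post_init__(self):
        # Same coercion as in the documented __post_init__.
        if isinstance(self.bonus_conf, dict):
            self.bonus_conf = DRNDBonusConfig(**self.bonus_conf)


# A plain dict (for example, parsed from a config file) becomes a typed config;
# fields that are not overridden keep their defaults.
cfg = DRNDConfig(bonus_conf={"n_members": 5})
default_cfg = DRNDConfig()
print(type(cfg.bonus_conf).__name__, cfg.bonus_conf.n_members, cfg.bonus_conf.scale_factor)
```

Using `field(default_factory=...)` gives every `DRNDConfig` instance its own nested config object, so changing one agent's bonus settings does not leak into another.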
UML Diagram#
UML diagram for the DRND algorithm.#
We use the UML diagram to illustrate the relationships between the classes in our DRND implementation.
The diagram shows how the DRNDActor and DRNDCritics classes inherit from the base classes Actor and CriticEnsemble, respectively. The DistributionalRandomNetworkDistillation class inherits from ActorCritic, which in turn inherits from Agent.
The diagram highlights the attributes and methods of each class that matter most for DRND. Specifically:
- The get_bellman_target() method of the DRNDCritics class computes the Bellman target with exploration bonuses derived from the DRND module.
- The DRNDBonus class computes exploration signals by measuring the discrepancy between a fixed ensemble of random target networks and a trained predictor network, similar to RND but with a distribution of targets instead of a single one.
- The DRNDActor class samples actions from the regular policy and is trained with entropy-regularized policy gradients.
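The inheritance relationships described above can be sketched with empty placeholder classes. This is only an illustration of the hierarchy: the bodies are stubs, the agent is named DRND as in the class reference, and the link from SACActor to Actor is an assumption made to reconcile the diagram description with the documented base class of DRNDActor.

```python
class Agent:
    """Placeholder for the base agent class."""


class ActorCritic(Agent):
    """Placeholder for the generic actor-critic agent."""


class Actor:
    """Placeholder for the base actor class."""


class SACActor(Actor):
    """Placeholder; assumed here to derive from Actor."""


class CriticEnsemble:
    """Placeholder for the base critic ensemble."""


class DRNDActor(SACActor):
    """Actor augmented with the DRND bonus (stub)."""


class DRNDCritics(CriticEnsemble):
    """Critic ensemble whose targets include the DRND bonus (stub)."""


class DRND(ActorCritic):
    """Agent tying the actor, the critics and the bonus module together (stub)."""


# Print the method resolution order of the agent class.
print([cls.__name__ for cls in DRND.__mro__])
```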
Classes#
- class objectrl.models.drnd.DRNDBonus(config: MainConfig, dim_state: int, dim_act: int)[source]#
Bases: Module
Distributional Random Network Distillation (DRND) bonus module. Provides an exploration bonus based on disagreement between an ensemble of target networks and a learned predictor network. Based on Yang et al. (2024).
- Parameters:
config (MainConfig) – Main experiment/configuration object.
dim_state (int) – Observation space dimension.
dim_act (int) – Action space dimension.
- predictor#
Predictor network for state-action pairs.
- Type:
nn.Module
- optim_pred#
Optimizer for the predictor.
- Type:
torch.optim.Optimizer
- n_members#
Number of ensemble members.
- Type:
int
- device#
Device for computations.
- Type:
torch.device
- bonus_conf#
Configuration for the bonus module.
- Type:
DRNDBonusConfig
- __init__(config: MainConfig, dim_state: int, dim_act: int) None[source]#
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- bonus(state: Tensor, action: Tensor) Tensor[source]#
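Compute the DRND exploration bonus for a given (state, action) pair. The bonus combines two terms, weighted against each other by scale_factor from the bonus configuration:
- Squared difference between the predictor output and the mean of the target ensemble
- Normalized difference in variances (distributional bonus)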
- Parameters:
state (torch.Tensor) – Input state tensor.
action (torch.Tensor) – Input action tensor.
- Returns:
Exploration bonus.
- Return type:
torch.Tensor
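As a rough illustration of the two terms, the following NumPy sketch loosely follows the formulation in Yang et al. (2024): a squared distance to the ensemble mean, plus a second-moment discrepancy normalized by the ensemble variance. It is not the objectrl implementation. The function name `drnd_bonus`, the reductions over the output dimension (sum for the first term, mean for the second), the absolute value, the clipping range and `eps` are all assumptions made for this example.

```python
import numpy as np


def drnd_bonus(pred, targets, scale_factor=0.9, eps=1e-6):
    """Sketch of a two-term DRND bonus (hypothetical helper).

    pred:    (batch, dim_out) predictor outputs for (state, action) pairs
    targets: (n_members, batch, dim_out) outputs of the fixed random targets
    """
    mu = targets.mean(axis=0)          # first moment of the target ensemble
    b2 = (targets ** 2).mean(axis=0)   # second moment of the target ensemble
    var = b2 - mu ** 2                 # per-dimension ensemble variance

    # Term 1: squared distance between the predictor and the ensemble mean.
    first = ((pred - mu) ** 2).sum(axis=1)

    # Term 2: second-moment discrepancy normalized by the ensemble variance.
    # The abs, eps and clip to [0, 1] are numerical-safety assumptions.
    ratio = np.abs(pred ** 2 - mu ** 2) / (var + eps)
    second = np.sqrt(np.clip(ratio, 0.0, 1.0)).mean(axis=1)

    return scale_factor * first + (1.0 - scale_factor) * second


rng = np.random.default_rng(0)
targets = rng.normal(size=(10, 4, 32))   # n_members=10, batch=4, dim_out=32
mu = targets.mean(axis=0)

bonus_known = drnd_bonus(mu, targets)        # predictor matches the ensemble mean
bonus_novel = drnd_bonus(mu + 1.0, targets)  # predictor is far from the mean
```

When the predictor reproduces the ensemble mean exactly, both terms vanish, which mirrors the intuition that frequently visited inputs should receive little or no bonus.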
- class objectrl.models.drnd.DRNDActor(config: MainConfig, dim_state: tuple[int, ...], dim_act: tuple[int, ...])[source]#
Bases: SACActor
Actor network for DRND, based on SAC but augmented with an exploration bonus.
- Parameters:
config (MainConfig) – Experiment configuration.
dim_state (tuple) – Observation space dimension.
dim_act (tuple) – Action space dimension.
- lambda_actor#
Scaling coefficient for the exploration bonus in the actor loss.
- Type:
float
- __init__(config: MainConfig, dim_state: tuple[int, ...], dim_act: tuple[int, ...]) None[source]#
Initializes the Actor.
- Parameters:
config (MainConfig) – Configuration dataclass instance.
dim_state (tuple) – Dimension of observation space.
dim_act (tuple) – Dimension of action space.
- Returns:
None
- loss(state: Tensor, critics: DRNDCritics, bonus_ensemble: DRNDBonus) tuple[Tensor, dict][source]#
Compute actor loss including the entropy term and the DRND exploration bonus.
- Parameters:
state (torch.Tensor) – Batch of input states.
critics (DRNDCritics) – Critic network(s).
bonus_ensemble (DRNDBonus) – Bonus ensemble for exploration.
- Returns:
loss (Tensor) – Total actor loss. act_dict (dict) – Output of the actor network.
- Return type:
tuple[Tensor, dict]
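One plausible form of such a loss is sketched below in NumPy. It is an assumption-laden illustration, not the objectrl code: the function name `drnd_actor_loss`, the minimum over critics, and the sign of the bonus term (added as a penalty, in the anti-exploration style) are all choices made for this example.

```python
import numpy as np


def drnd_actor_loss(q_values, log_prob, bonus, alpha=1.0, lambda_actor=1.0):
    """Hypothetical SAC-style actor loss with a DRND bonus term.

    q_values: (n_critics, batch) critic estimates for actions sampled from the policy
    log_prob: (batch,) log-probabilities of those actions
    bonus:    (batch,) DRND bonus for the same (state, action) pairs
    """
    q_min = q_values.min(axis=0)  # assumption: pessimistic minimum over critics
    # Entropy term minus value estimate plus scaled bonus, averaged over the batch.
    return float(np.mean(alpha * log_prob - q_min + lambda_actor * bonus))


q_values = np.array([[1.0, 2.0], [0.5, 3.0]])
log_prob = np.array([-1.0, -2.0])
bonus = np.array([0.1, 0.3])

loss = drnd_actor_loss(q_values, log_prob, bonus)
loss_without_bonus = drnd_actor_loss(q_values, log_prob, bonus, lambda_actor=0.0)
```

Setting `lambda_actor=0.0` recovers a plain SAC-style objective in this sketch, which makes the contribution of the bonus easy to isolate.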
- update(state: Tensor, critics: DRNDCritics, bonus_ensemble: DRNDBonus) None[source]#
Perform a gradient step for the actor.
- Parameters:
state (torch.Tensor) – Batch of input states.
critics (DRNDCritics) – Critic network(s).
bonus_ensemble (DRNDBonus) – Bonus ensemble for exploration.
- Returns:
None
- class objectrl.models.drnd.DRNDCritics(config: MainConfig, dim_state: int, dim_act: int)[source]#
Bases: CriticEnsemble
Critic module for DRND that incorporates exploration bonus into target computation.
- Parameters:
config (MainConfig) – Experiment configuration.
dim_state (int) – Observation space dimension.
dim_act (int) – Action space dimension.
- lambda_critic#
Scaling coefficient for the exploration bonus in the critic target.
- Type:
float
- _gamma#
Discount factor for future rewards.
- Type:
float
- _agent_name#
Name of the agent.
- Type:
str
- __init__(config: MainConfig, dim_state: int, dim_act: int)[source]#
Initialize the critic ensemble.
- Parameters:
config (MainConfig) – Configuration object with model parameters.
dim_state (int) – Dimension of the state space.
dim_act (int) – Dimension of the action space.
- Returns:
None
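How lambda_critic and _gamma could enter the target computation is sketched below in NumPy. This is a hypothetical form, not the objectrl implementation: the function name `drnd_bellman_target`, the minimum over target critics, and the subtraction of the bonus (a penalty on the next state-action value) are assumptions made for illustration.

```python
import numpy as np


def drnd_bellman_target(reward, done, next_q, next_log_prob, next_bonus,
                        gamma=0.99, alpha=1.0, lambda_critic=1.0):
    """Hypothetical soft Bellman target with a DRND bonus penalty.

    reward, done:  (batch,) rewards and terminal flags (1.0 if terminal)
    next_q:        (n_critics, batch) target-critic values at the next state
    next_log_prob: (batch,) log-probabilities of the next actions
    next_bonus:    (batch,) DRND bonus at the next (state, action) pairs
    """
    next_q_min = next_q.min(axis=0)  # assumption: minimum over target critics
    soft_value = next_q_min - alpha * next_log_prob - lambda_critic * next_bonus
    # No bootstrapping from terminal transitions.
    return reward + gamma * (1.0 - done) * soft_value


target = drnd_bellman_target(
    reward=np.array([1.0, 0.0]),
    done=np.array([0.0, 1.0]),
    next_q=np.array([[2.0, 5.0], [3.0, 4.0]]),
    next_log_prob=np.array([-0.5, -1.0]),
    next_bonus=np.array([0.2, 0.4]),
    gamma=0.9,
)
```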
- class objectrl.models.drnd.DRND(config: MainConfig, critic_type: type = <class 'objectrl.models.drnd.DRNDCritics'>, actor_type: type = <class 'objectrl.models.drnd.DRNDActor'>, bonus_type: type = <class 'objectrl.models.drnd.DRNDBonus'>)[source]#
Bases: ActorCritic
DRND agent integrating exploration through Distributional Random Network Distillation.
Implements actor-critic logic where:
- Actor loss is regularized by an exploration bonus
- Critic targets include bonus penalties
- Bonus predictor is trained online
Yang et al. (2024): Exploration and Anti-Exploration with Distributional Random Network Distillation
- _agent_name = 'DRND'#
- __init__(config: MainConfig, critic_type: type = <class 'objectrl.models.drnd.DRNDCritics'>, actor_type: type = <class 'objectrl.models.drnd.DRNDActor'>, bonus_type: type = <class 'objectrl.models.drnd.DRNDBonus'>) None[source]#
Initializes DRND agent.
- Parameters:
config (MainConfig) – Configuration dataclass instance.
critic_type (type) – Critic class type.
actor_type (type) – Actor class type.
bonus_type (type) – Bonus module class type.
- Returns:
None
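The DRND class description notes that the bonus predictor is trained online. A toy NumPy sketch of that distillation step is given below: the predictor regresses onto the output of a target member drawn at random for each sample. Linear maps stand in for the networks, and the dimensions, learning rate, batch size and step count are arbitrary assumptions, so this shows the idea only and does not reflect the objectrl training code.

```python
import numpy as np

rng = np.random.default_rng(0)
n_members, dim_in, dim_out = 10, 6, 4

# Fixed random linear maps stand in for the frozen target networks.
targets = rng.normal(size=(n_members, dim_out, dim_in))
predictor = np.zeros((dim_out, dim_in))  # trainable predictor weights
learning_rate = 1e-2


def gap_to_ensemble_mean(weights):
    # Distance between the predictor and the mean of the target ensemble.
    return float(np.linalg.norm(weights - targets.mean(axis=0)))


gap_before = gap_to_ensemble_mean(predictor)
for _ in range(2000):
    x = rng.normal(size=(32, dim_in))             # batch of (state, action) features
    idx = rng.integers(n_members, size=32)        # one random target per sample
    y = np.einsum("boi,bi->bo", targets[idx], x)  # outputs of the sampled targets
    err = x @ predictor.T - y                     # predictor output minus target output
    grad = 2.0 * err.T @ x / len(x)               # gradient of the mean squared error
    predictor -= learning_rate * grad
gap_after = gap_to_ensemble_mean(predictor)

print(gap_before, gap_after)
```

Because each sample is matched to a randomly chosen target, the predictor in this toy setting drifts toward the ensemble mean on inputs it sees often, which is why the distance to that mean can serve as a novelty signal.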