Distributional Soft Actor-Critic (DSAC)
Tags: distributional RL, quantile regression
Paper: DSAC: Distributional Soft Actor-Critic for Risk-Sensitive Reinforcement Learning
Pseudocode
Configuration
```python
from dataclasses import dataclass, field
from typing import Literal

from objectrl.models.dsac import DSACActor, DSACCritic

# ActorNetProbabilistic and QuantileCriticNet are ObjectRL network classes;
# their import path is not shown on this page.


@dataclass
class CriticLossConfig:
    """
    Configuration for the DSAC critic loss.

    Attributes:
        kappa (float): Huber loss threshold.
    """

    kappa: float = 1.0


@dataclass
class DSACActorConfig:
    """
    Configuration for the DSAC Actor network.

    Attributes:
        arch (type): Architecture class for the actor network.
        actor_type (type): Actor class type.
        has_target (bool): Whether to use a target network.
    """

    arch: type = ActorNetProbabilistic
    actor_type: type = DSACActor
    has_target: bool = True


@dataclass
class DSACCriticConfig:
    """
    Configuration for the DSAC Critic network.

    Attributes:
        arch (type): Architecture class for the critic network.
        critic_type (type): Critic class type.
        norm (bool): Normalization flag for the critic network.
        n_quantiles (int): Number of quantile atoms output by each critic.
        has_target (bool): Whether to use a target network.
        n_members (int): Number of critics in the ensemble.
        tau_type (Literal["fix", "iqn"]): How quantile fractions are generated:
            a fixed uniform grid ("fix") or IQN-style random sampling ("iqn").
    """

    arch: type = QuantileCriticNet
    critic_type: type = DSACCritic
    norm: bool = True
    n_quantiles: int = 8
    has_target: bool = True
    n_members: int = 2
    tau_type: Literal["fix", "iqn"] = "iqn"

    def __post_init__(self):
        # Each critic outputs one value per quantile atom.
        self.dim_out = self.n_quantiles


@dataclass
class DSACConfig:
    """
    Main DSAC algorithm configuration.

    Attributes:
        name (str): Name of the algorithm.
        lossparams (CriticLossConfig): Parameters of the critic loss.
        loss (str): Name of the critic loss.
        policy_delay (int): Delay for policy updates.
        tau (float): Soft update coefficient for the target networks
            (unrelated to the quantile fractions tau).
        target_entropy (float | None): Target entropy for the policy.
        learnable_alpha (bool): Whether the temperature parameter alpha is learnable.
        alpha (float): Initial value of the temperature parameter alpha.
        actor (DSACActorConfig): Configuration for the actor network.
        critic (DSACCriticConfig): Configuration for the critic network.
    """

    name: str = "dsac"
    lossparams: CriticLossConfig = field(default_factory=CriticLossConfig)
    loss: str = "DSACLoss"
    policy_delay: int = 1
    tau: float = 0.005
    target_entropy: float | None = None
    learnable_alpha: bool = True
    alpha: float = 1.0
    actor: DSACActorConfig = field(default_factory=DSACActorConfig)
    critic: DSACCriticConfig = field(default_factory=DSACCriticConfig)

    def __post_init__(self):
        # Allow lossparams to be given as a plain dict (e.g. from a config file).
        if isinstance(self.lossparams, dict):
            self.lossparams = CriticLossConfig(**self.lossparams)
```
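The `__post_init__` hooks let a plain dictionary (for example, one parsed from a YAML file or command-line overrides) stand in for the nested `CriticLossConfig`, and they tie the critic's output width to `n_quantiles`. The self-contained sketch below reproduces only those two hooks in trimmed-down, hypothetical copies of the dataclasses (the `Mini*` names are illustrative, not part of ObjectRL), so it runs without ObjectRL installed; the network and class fields are omitted.

```python
from dataclasses import dataclass, field


@dataclass
class CriticLossConfig:
    kappa: float = 1.0  # Huber threshold for the quantile loss


@dataclass
class MiniCriticConfig:
    # Hypothetical stand-in for DSACCriticConfig (network/class fields omitted).
    n_quantiles: int = 8
    n_members: int = 2

    def __post_init__(self):
        # One output unit per quantile atom.
        self.dim_out = self.n_quantiles


@dataclass
class MiniDSACConfig:
    # Hypothetical stand-in for DSACConfig, keeping only the dict coercion hook.
    lossparams: CriticLossConfig = field(default_factory=CriticLossConfig)
    tau: float = 0.005
    critic: MiniCriticConfig = field(default_factory=MiniCriticConfig)

    def __post_init__(self):
        # Accept a plain dict (e.g. from YAML) and turn it into a dataclass.
        if isinstance(self.lossparams, dict):
            self.lossparams = CriticLossConfig(**self.lossparams)


cfg = MiniDSACConfig(lossparams={"kappa": 0.5}, critic=MiniCriticConfig(n_quantiles=32))
print(cfg.lossparams.kappa, cfg.critic.dim_out)  # 0.5 32
```

Coercing in `__post_init__` keeps the field typed as a dataclass everywhere downstream, while callers remain free to pass raw dictionaries.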
UML Diagram
[Figure: UML diagram for the DSAC algorithm.]
We use the UML diagram to illustrate the relationships between the classes in our DSAC implementation.
The DSACActor and DSACCritic classes inherit from SACActor and SACCritic, respectively. The DistributionalSoftActorCritic agent inherits from ActorCritic, which in turn inherits from Agent.
The diagram also highlights the attributes and methods that each class adds or overrides for DSAC:
DSACActor adapts the SAC actor to support both a fixed and a learnable entropy temperature alpha.
When learnable_alpha=False, the temperature stays frozen and no optimizer is created for it. The actor loss is modified to use Q-values obtained by quantile-weighted averaging of the distributional critic's outputs.
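A minimal sketch of how such an actor loss could be computed, assuming the ensemble's quantile outputs are stacked into a tensor of shape `(n_members, batch, n_quantiles)` and that each Q-value is the sum of quantile estimates weighted by their bin widths (`presum_tau`). The function names `dsac_actor_loss` and `alpha_loss` are hypothetical and not part of ObjectRL's API; the temperature loss is the standard SAC one and would simply be skipped when `learnable_alpha=False`.

```python
import torch


def dsac_actor_loss(z_new, presum_tau, log_prob, alpha):
    """Hypothetical DSAC-style actor loss (sketch).

    z_new:      (n_members, batch, n_quantiles) quantiles Z_k(s, a~pi, tau_hat)
    presum_tau: (batch, n_quantiles) quantile bin widths, each row sums to 1
    log_prob:   (batch,) log pi(a|s) of the sampled actions
    alpha:      entropy temperature (float or 0-dim tensor)
    """
    # Quantile-weighted mean: Q_k(s, a) = sum_i presum_tau_i * Z_k(s, a, tau_hat_i)
    q_members = (presum_tau.unsqueeze(0) * z_new).sum(dim=-1)  # (n_members, batch)
    # Pessimistic (clipped) estimate across the critic ensemble.
    q_min = q_members.min(dim=0).values  # (batch,)
    return (alpha * log_prob - q_min).mean()


def alpha_loss(log_alpha, log_prob, target_entropy):
    """Standard SAC temperature loss; only needed when alpha is learnable."""
    return -(log_alpha * (log_prob + target_entropy).detach()).mean()


torch.manual_seed(0)
n_members, batch, n_q = 2, 4, 8
z_new = torch.randn(n_members, batch, n_q)
presum_tau = torch.full((batch, n_q), 1.0 / n_q)  # equal-width bins ('fix')
log_prob = torch.randn(batch)
log_alpha = torch.zeros((), requires_grad=True)  # alpha = exp(0) = 1

loss = dsac_actor_loss(z_new, presum_tau, log_prob, log_alpha.exp().detach())
a_loss = alpha_loss(log_alpha, log_prob, target_entropy=-2.0)
a_loss.backward()  # gradient flows only into log_alpha
```

Detaching alpha inside the actor loss and detaching the entropy term inside the temperature loss keeps the two objectives from pushing gradients into each other's parameters.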
DSACCritic implements quantile-based value estimation. The get_tau() method generates quantile fractions either on a fixed uniform grid ('fix') or with an IQN-style sampling strategy ('iqn'). The Q() and Q_t() methods evaluate the online and target critic ensembles, respectively, at the quantile midpoints using torch.vmap and return the full value distributions.
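The sketch below shows one way to produce the `(tau, tau_hat, presum_tau)` triple that get_tau() is documented to return: bin widths are either equal (`"fix"`) or drawn at random and normalized (`"iqn"`), `tau` holds the cumulative right edges, and `tau_hat` the bin midpoints. It is a standalone, hypothetical function, not ObjectRL's implementation; the extra `n_quantiles` argument and the `+ 0.1` offset that keeps random bins from collapsing are illustrative choices.

```python
import torch


def get_tau(batch_size: int, n_quantiles: int, tau_type: str = "iqn"):
    """Hypothetical quantile-fraction generator (sketch).

    Returns tau (right bin edges), tau_hat (bin midpoints) and presum_tau
    (bin widths), each of shape (batch_size, n_quantiles).
    """
    if tau_type == "fix":
        # Equal-width bins: every quantile covers 1 / n_quantiles of the mass.
        presum_tau = torch.full((batch_size, n_quantiles), 1.0 / n_quantiles)
    elif tau_type == "iqn":
        # Random bin widths; the +0.1 offset keeps any bin from shrinking to zero.
        presum_tau = torch.rand(batch_size, n_quantiles) + 0.1
        presum_tau = presum_tau / presum_tau.sum(dim=-1, keepdim=True)
    else:
        raise ValueError(f"Unknown tau_type: {tau_type}")
    tau = torch.cumsum(presum_tau, dim=-1)  # right edges; last column is 1
    # Midpoints: (tau_{i-1} + tau_i) / 2 with tau_0 = 0.
    tau_prev = torch.cat([torch.zeros_like(tau[:, :1]), tau[:, :-1]], dim=-1)
    tau_hat = (tau + tau_prev) / 2
    return tau, tau_hat, presum_tau


tau, tau_hat, presum_tau = get_tau(batch_size=2, n_quantiles=4, tau_type="fix")
print(tau_hat[0])  # tensor([0.1250, 0.3750, 0.6250, 0.8750])
```

Evaluating the critic at the midpoints `tau_hat` and weighting by `presum_tau` gives a consistent estimate of the mean return for both the fixed and the randomly sampled grids.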
The get_bellman_target() method computes entropy-regularized distributional Bellman targets, applying SAC's clipped minimum across the ensemble's quantile outputs. The update() method then performs quantile regression to align the predicted quantile distribution with the target distribution.
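A minimal sketch of the two steps described above, assuming the target critics' quantiles at `(s', a' ~ pi)` are stacked as `(n_members, batch, n_quantiles)`. The target takes the per-quantile minimum over the ensemble, subtracts the entropy term, and bootstraps with `(1 - done)`; the loss is the standard quantile Huber loss with threshold `kappa` (the `CriticLossConfig` parameter). Both functions and the `gamma` default are hypothetical illustrations, not ObjectRL's `get_bellman_target()` or `update()`.

```python
import torch


def bellman_target(reward, done, z_next, log_prob_next, alpha, gamma=0.99):
    """Entropy-regularized distributional target (sketch).

    reward, done:  (batch,)
    z_next:        (n_members, batch, n_quantiles) target quantiles at (s', a')
    log_prob_next: (batch,) log pi(a'|s')
    """
    # Clipped minimum across ensemble members, taken per quantile.
    z_min = z_next.min(dim=0).values  # (batch, n_quantiles)
    soft_z = z_min - alpha * log_prob_next.unsqueeze(-1)
    return reward.unsqueeze(-1) + gamma * (1.0 - done.unsqueeze(-1)) * soft_z


def quantile_huber_loss(pred, target, tau_hat, kappa=1.0):
    """Quantile Huber loss (sketch).

    pred:    (batch, N) predicted quantiles at fractions tau_hat
    target:  (batch, M) target quantiles (treated as constants)
    tau_hat: (batch, N) quantile fractions of the predictions
    """
    # Pairwise TD errors u_ij = target_j - pred_i, shape (batch, N, M).
    u = target.detach().unsqueeze(1) - pred.unsqueeze(2)
    abs_u = u.abs()
    huber = torch.where(abs_u <= kappa, 0.5 * u.pow(2), kappa * (abs_u - 0.5 * kappa))
    # Asymmetric weight |tau - 1{u < 0}| turns the Huber loss into a quantile loss.
    weight = (tau_hat.unsqueeze(2) - (u.detach() < 0).float()).abs()
    loss = weight * huber / kappa
    # Average over target samples, sum over predicted quantiles, average over batch.
    return loss.mean(dim=2).sum(dim=1).mean()


torch.manual_seed(0)
batch, n_q = 4, 8
y = bellman_target(
    reward=torch.ones(batch),
    done=torch.zeros(batch),
    z_next=torch.randn(2, batch, n_q),
    log_prob_next=torch.randn(batch),
    alpha=0.2,
)
pred = torch.randn(batch, n_q, requires_grad=True)
tau_hat = (torch.arange(n_q) + 0.5).div(n_q).expand(batch, n_q)
loss = quantile_huber_loss(pred, y, tau_hat, kappa=1.0)
loss.backward()
```

In an ensemble, each member's predicted quantiles would be regressed against the same target `y`; smaller `kappa` values make the loss behave more like the pure (non-smooth) quantile loss.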
Classes#
- class objectrl.models.dsac.DSACActor(config, dim_state, dim_act)
Bases: SACActor
Distributional Soft Actor-Critic (DSAC) Actor.
- Parameters:
config (MainConfig) – Configuration object containing model specifications.
dim_state (int) – Dimension of the state space.
dim_act (int) – Dimension of the action space.
- learnable_alpha
Indicates if the temperature parameter alpha is learnable.
- Type:
bool
- __init__(config, dim_state, dim_act)
Initializes the Actor.
- Parameters:
config (MainConfig) – Configuration dataclass instance.
dim_state (int) – Dimension of observation space.
dim_act (int) – Dimension of action space.
- Returns:
None
- update_alpha(act_dict: dict) → None
Updates alpha only if learnable_alpha=True.
- Parameters:
act_dict (dict) – Dictionary containing action information.
- Returns:
None
- loss(state: Tensor, critics: CriticEnsemble) → tuple[Tensor, dict]
Computes the actor loss for DSAC.
- Parameters:
state (Tensor) – Batch of states.
critics (CriticEnsemble) – Critic networks for Q-value estimation.
- Returns:
The actor loss and a dictionary containing the sampled action and its log probability.
- Return type:
tuple
- class objectrl.models.dsac.DSACCritic(config: MainConfig, dim_state: int, dim_act: int)
Bases: SACCritic
Distributional Soft Actor-Critic (DSAC) Critic.
- Parameters:
config (MainConfig) – Configuration object containing model specifications.
dim_state (int) – Dimension of the state space.
dim_act (int) – Dimension of the action space.
- num_quantiles
Number of quantile atoms.
- Type:
int
- tau_type
Type of tau generation (‘fix’ or ‘iqn’).
- Type:
str
- __init__(config: MainConfig, dim_state: int, dim_act: int) → None
Initialize the critic ensemble.
- Parameters:
config (MainConfig) – Configuration object with model parameters.
dim_state (int) – Dimension of the state space.
dim_act (int) – Dimension of the action space.
- Returns:
None
- get_tau(batch_size)
Generates tau values based on the specified tau_type.
- Parameters:
batch_size (int) – The batch size for which tau values are generated.
- Returns:
A tuple containing tau, tau_hat, and presum_tau tensors.
- Return type:
tuple
- Q(state: Tensor, action: Tensor, tau: Tensor) → Tensor
Computes the Q-values for given state, action, and tau values.
- Parameters:
state (torch.Tensor) – The state tensor.
action (torch.Tensor) – The action tensor.
tau (torch.Tensor) – The tau tensor representing quantile fractions.
- Returns:
The computed Q-values.
- Return type:
torch.Tensor
- Q_t(state: Tensor, action: Tensor, tau: Tensor) → Tensor
Computes the target Q-values for given state, action, and tau values.
- Parameters:
state (torch.Tensor) – The state tensor.
action (torch.Tensor) – The action tensor.
tau (torch.Tensor) – The tau tensor representing quantile fractions.
- Returns:
The computed target Q-values.
- Return type:
torch.Tensor
- get_bellman_target(reward: Tensor, next_state: Tensor, done: Tensor, actor: DSACActor) → Tensor
Computes the Bellman target for the given reward, next state, done flag, and actor.
- Parameters:
reward (torch.Tensor) – The reward tensor.
next_state (torch.Tensor) – The next state tensor.
done (torch.Tensor) – The done flag tensor.
actor (DSACActor) – The actor instance.
- Returns:
Bellman target values.
- Return type:
Tensor
- update(state: Tensor, action: Tensor, y: tuple[Tensor, Tensor]) → None
Updates the critic network using the given state, action, and target values.
- Parameters:
state (torch.Tensor) – The state tensor.
action (torch.Tensor) – The action tensor.
y (tuple[torch.Tensor, torch.Tensor]) – The target quantile values and their corresponding target tau values.
- Returns:
None
- class objectrl.models.dsac.DistributionalSoftActorCritic(config: MainConfig, critic_type: type = <class 'objectrl.models.dsac.DSACCritic'>, actor_type: type = <class 'objectrl.models.dsac.DSACActor'>)
Bases: ActorCritic
Distributional Soft Actor-Critic agent combining DSACActor and DSACCritic, following Ma et al. (2025), DSAC: Distributional Soft Actor-Critic for Risk-Sensitive Reinforcement Learning.
- _agent_name = 'DSAC'
- __init__(config: MainConfig, critic_type: type = <class 'objectrl.models.dsac.DSACCritic'>, actor_type: type = <class 'objectrl.models.dsac.DSACActor'>) → None
Initializes the DSAC agent.
- Parameters:
config (MainConfig) – Configuration dataclass instance.
critic_type (type) – Critic class type.
actor_type (type) – Actor class type.
- Returns:
None