Distributional Soft Actor Critic (DSAC)

distributional rl quantile regression

Paper: DSAC: Distributional Soft Actor-Critic for Risk-Sensitive Reinforcement Learning

Pseudocode

Configuration

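Specific configuration for the DSAC algorithm (in config/model_configs/).
from dataclasses import dataclass, field
from typing import Literal

# ActorNetProbabilistic, QuantileCriticNet, DSACActor, and DSACCritic are
# assumed to be imported from the corresponding objectrl modules.


@dataclass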
class CriticLossConfig:
    """
    Configuration for the DSAC critic loss.

    Attributes:
        kappa (float): Huber loss threshold.
    """

    kappa: float = 1.0


@dataclass
class DSACActorConfig:
    """
    Configuration for the DSAC Actor network.

    Attributes:
        arch (type): Architecture class for the actor network.
        actor_type (type): Actor class type.
        has_target (bool): Whether to use a target network.
    """

    arch: type = ActorNetProbabilistic
    actor_type: type = DSACActor
    has_target: bool = True


@dataclass
class DSACCriticConfig:
    """
    Configuration for the DSAC Critic network.

    Attributes:
        arch (type): Architecture class for the critic network.
        critic_type (type): Critic class type.
        norm (bool): Whether normalization is enabled in the critic network.
        n_quantiles (int): Number of atoms for quantile regression.
        has_target (bool): Whether to use a target network.
        n_members (int): Number of critic members in the ensemble.
        tau_type (Literal["fix", "iqn"]): How quantile fractions are generated:
            "fix" for uniform fractions, "iqn" for IQN-style sampling.
    """

    arch: type = QuantileCriticNet
    critic_type: type = DSACCritic
    norm: bool = True
    n_quantiles: int = 8
    has_target: bool = True
    n_members: int = 2
    tau_type: Literal["fix", "iqn"] = "iqn"

    def __post_init__(self):
        self.dim_out = self.n_quantiles


@dataclass
class DSACConfig:
    """
    Main DSAC algorithm configuration.

    Attributes:
        name (str): Name of the algorithm.
        lossparams (CriticLossConfig): Parameters of the critic loss.
        loss (str): Name of the loss function used.
        policy_delay (int): Delay for policy updates.
        tau (float): Soft update coefficient.
        target_entropy (float | None): Target entropy for the policy.
        learnable_alpha (bool): Whether the temperature parameter alpha is learnable.
        alpha (float): Initial value of the temperature parameter alpha.
        actor (DSACActorConfig): Configuration for the actor network.
        critic (DSACCriticConfig): Configuration for the critic network.
    """

    name: str = "dsac"
    lossparams: CriticLossConfig = field(default_factory=CriticLossConfig)
    loss: str = "DSACLoss"
    policy_delay: int = 1
    tau: float = 0.005
    target_entropy: float | None = None
    learnable_alpha: bool = True
    alpha: float = 1.0

    actor: DSACActorConfig = field(default_factory=DSACActorConfig)
    critic: DSACCriticConfig = field(default_factory=DSACCriticConfig)

    def __post_init__(self):
        if isinstance(self.lossparams, dict):
            self.lossparams = CriticLossConfig(**self.lossparams)


UML Diagram

UML diagram for the DSAC algorithm.

The UML diagram illustrates the relationships between the classes in our DSAC implementation.

DSACActor and DSACCritic inherit from SACActor and SACCritic, respectively. DistributionalSoftActorCritic inherits from ActorCritic, which in turn inherits from Agent.

For each class, the diagram lists the attributes and methods that matter most for DSAC. Specifically:

DSACActor adapts the SAC actor to support both fixed and learnable entropy temperature alpha.

When learnable_alpha=False, the temperature is frozen, and no optimizer is maintained. The actor loss is modified to use quantile-weighted Q-values produced by the distributional critic.
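
One way to read "quantile-weighted Q-values" is as the expectation of each critic's quantile outputs under the quantile probability masses, followed by the usual SAC minimum over ensemble members. The sketch below follows that reading. The function name, argument names, and tensor shapes are assumptions made for illustration, and risk-sensitive weightings from the paper are not covered.

```python
import torch


def actor_loss_sketch(z, presum_tau, log_prob, alpha):
    """Hypothetical risk-neutral actor loss.

    z:          (n_members, batch, n_quantiles) quantile values for sampled actions
    presum_tau: (batch, n_quantiles) probability mass of each quantile
    log_prob:   (batch, 1) log-probability of the sampled actions
    alpha:      entropy temperature (float or 0-dim tensor)
    """
    # Weighted sum over quantiles gives one scalar Q-value per member and sample.
    q = (z * presum_tau).sum(dim=-1, keepdim=True)  # (n_members, batch, 1)
    # Element-wise minimum over ensemble members, as in SAC.
    q_min = q.min(dim=0).values  # (batch, 1)
    return (alpha * log_prob - q_min).mean()


z = torch.tensor([[[1.0, 3.0]], [[0.0, 2.0]]])  # two members, batch of one
presum_tau = torch.tensor([[0.5, 0.5]])
log_prob = torch.tensor([[-1.0]])
loss = actor_loss_sketch(z, presum_tau, log_prob, alpha=0.5)
```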

DSACCritic implements quantile-based value estimation. The get_tau() method generates quantile fractions either uniformly (tau_type="fix") or with an IQN-style sampling strategy (tau_type="iqn"). The Q() and Q_t() methods evaluate the online and target ensembles, respectively, at the supplied quantile fractions using torch.vmap, returning full value distributions.
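
The two fraction-generation modes can be sketched as follows. This is a minimal illustration rather than the library's code: the standalone function signature, the 0.1 offset in the sampled case, and the interpretation of the three returned tensors (cumulative fractions, bin midpoints, per-quantile masses) are assumptions.

```python
import torch


def get_tau_sketch(batch_size: int, n_quantiles: int, tau_type: str = "iqn"):
    """Return (tau, tau_hat, presum_tau), each of shape (batch_size, n_quantiles)."""
    if tau_type == "fix":
        # Every quantile carries the same probability mass.
        presum_tau = torch.full((batch_size, n_quantiles), 1.0 / n_quantiles)
    elif tau_type == "iqn":
        # Random positive masses, normalized per row. The 0.1 offset is an
        # assumption that keeps neighbouring fractions from collapsing.
        presum_tau = torch.rand(batch_size, n_quantiles) + 0.1
        presum_tau = presum_tau / presum_tau.sum(dim=-1, keepdim=True)
    else:
        raise ValueError(f"unknown tau_type: {tau_type!r}")

    tau = torch.cumsum(presum_tau, dim=-1)  # right edge of each quantile bin
    tau_hat = tau - presum_tau / 2  # midpoint of each quantile bin
    return tau, tau_hat, presum_tau


tau_fix, tau_hat_fix, presum_fix = get_tau_sketch(4, 8, "fix")
tau_iqn, tau_hat_iqn, presum_iqn = get_tau_sketch(4, 8, "iqn")
```

In both modes the masses in each row sum to one, so the last cumulative fraction is 1 and every midpoint lies strictly inside its bin.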

The get_bellman_target() method computes entropy-regularized distributional Bellman targets by applying SAC's clipped minimum across ensemble quantile outputs. The update() method performs quantile regression to align predicted and target quantile distributions.

Classes

class objectrl.models.dsac.DSACActor(config, dim_state, dim_act)

Bases: SACActor

Distributional Soft Actor-Critic (DSAC) Actor.

Parameters:
  • config (MainConfig) – Configuration object containing model specifications.

  • dim_state (int) – Dimension of the state space.

  • dim_act (int) – Dimension of the action space.

learnable_alpha

Indicates if the temperature parameter alpha is learnable.

Type:

bool

__init__(config, dim_state, dim_act)

Initializes the Actor.

Parameters:
  • config (MainConfig) – Configuration dataclass instance.

  • dim_state (int) – Dimension of observation space.

  • dim_act (int) – Dimension of action space.

Returns:

None

update_alpha(act_dict: dict) → None

Updates alpha only if learnable_alpha=True.

Parameters:

act_dict (dict) – Dictionary containing action information.

Returns:

None
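
The guard on learnable_alpha can be pictured with a small temperature holder that only creates an optimizer, and only takes a gradient step, when the temperature is learnable. The class below is a hypothetical sketch: its name, constructor arguments, learning rate, and the SAC-style temperature loss are assumptions, not the library's implementation.

```python
import torch


class TemperatureSketch:
    """Hypothetical holder for the entropy temperature alpha."""

    def __init__(self, alpha=1.0, learnable=True, target_entropy=-1.0, lr=3e-4):
        self.learnable = learnable
        self.target_entropy = target_entropy
        # Optimize log(alpha) so that alpha stays positive.
        self.log_alpha = torch.tensor(float(alpha)).log().requires_grad_(learnable)
        # No optimizer is kept when the temperature is frozen.
        self.optim = torch.optim.Adam([self.log_alpha], lr=lr) if learnable else None

    @property
    def alpha(self):
        return self.log_alpha.exp()

    def update(self, log_prob):
        if not self.learnable:
            return  # frozen temperature: nothing to do
        # SAC-style temperature loss.
        loss = -(self.log_alpha * (log_prob.detach() + self.target_entropy)).mean()
        self.optim.zero_grad()
        loss.backward()
        self.optim.step()


frozen = TemperatureSketch(alpha=0.2, learnable=False)
frozen.update(torch.zeros(4, 1))

learned = TemperatureSketch(alpha=1.0, learnable=True, target_entropy=-1.0)
learned.update(torch.full((4, 1), -5.0))  # entropy above target, so alpha shrinks
```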

loss(state: Tensor, critics: CriticEnsemble) → tuple[Tensor, dict]

Computes the actor loss for DSAC.

Parameters:
  • state (Tensor) – Batch of states.

  • critics (CriticEnsemble) – Critic networks for Q-value estimation.

Returns:

Actor loss and action dictionary containing action and log probability.

Return type:

tuple

class objectrl.models.dsac.DSACCritic(config: MainConfig, dim_state: int, dim_act: int)

Bases: SACCritic

Distributional Soft Actor-Critic (DSAC) Critic.

Parameters:
  • config (MainConfig) – Configuration object containing model specifications.

  • dim_state (int) – Dimension of the state space.

  • dim_act (int) – Dimension of the action space.

num_quantiles

Number of quantile atoms.

Type:

int

tau_type

Type of tau generation ('fix' or 'iqn').

Type:

str

__init__(config: MainConfig, dim_state: int, dim_act: int) → None

Initialize the critic ensemble.

Parameters:
  • config (MainConfig) – Configuration object with model parameters.

  • dim_state (int) – Dimension of the state space.

  • dim_act (int) – Dimension of the action space.

Returns:

None

get_tau(batch_size)

Generates tau values based on the specified tau_type.

Parameters:

batch_size (int) – The batch size for which tau values are generated.

Returns:

A tuple containing tau, tau_hat, and presum_tau tensors.

Return type:

tuple

Q(state: Tensor, action: Tensor, tau: Tensor) → Tensor

Computes the Q-values for given state, action, and tau values.

Parameters:
  • state (torch.Tensor) – The state tensor.

  • action (torch.Tensor) – The action tensor.

  • tau (torch.Tensor) – The tau tensor representing quantile fractions.

Returns:

The computed Q-values.

Return type:

torch.Tensor
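
Evaluating all ensemble members in one call with torch.vmap typically follows the pattern below, which stacks the members' parameters and maps a functional forward pass over the leading dimension. This is a simplified sketch: the member architecture is a hypothetical MLP that takes a concatenated state-action vector and ignores the tau input, so it only demonstrates the ensembling mechanics.

```python
import copy

import torch
from torch import nn
from torch.func import functional_call, stack_module_state


def make_member(dim_in: int, n_quantiles: int) -> nn.Module:
    # Hypothetical critic member: outputs one value per quantile.
    return nn.Sequential(nn.Linear(dim_in, 32), nn.ReLU(), nn.Linear(32, n_quantiles))


members = [make_member(6, 8) for _ in range(2)]

# Stack parameters so each tensor gains a leading "member" dimension.
params, buffers = stack_module_state(members)

# A stateless copy on the meta device supplies the module structure only.
base = copy.deepcopy(members[0]).to("meta")


def call_member(p, b, x):
    return functional_call(base, (p, b), (x,))


x = torch.randn(5, 6)  # batch of concatenated state-action inputs
# Map over members (dim 0 of params and buffers) while sharing the same input.
z = torch.vmap(call_member, in_dims=(0, 0, None))(params, buffers, x)  # (2, 5, 8)
```

The result has shape (n_members, batch, n_quantiles), so later steps such as the minimum over members reduce along the first dimension.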

Q_t(state: Tensor, action: Tensor, tau: Tensor) → Tensor

Computes the target Q-values for given state, action, and tau values.

Parameters:
  • state (torch.Tensor) – The state tensor.

  • action (torch.Tensor) – The action tensor.

  • tau (torch.Tensor) – The tau tensor representing quantile fractions.

Returns:

The computed target Q-values.

Return type:

torch.Tensor

get_bellman_target(reward: Tensor, next_state: Tensor, done: Tensor, actor: DSACActor) → Tensor

Computes the Bellman target for the given reward, next state, done flag, and actor.

Parameters:
  • reward (torch.Tensor) – The reward tensor.

  • next_state (torch.Tensor) – The next state tensor.

  • done (torch.Tensor) – The done flag tensor.

  • actor (DSACActor) – The actor instance.

Returns:

Bellman target values.

Return type:

Tensor
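
An entropy-regularized distributional target with a clipped minimum over ensemble members could be computed along these lines. The sketch assumes the next-state quantile values and log-probabilities have already been obtained from the target critics and the actor. The function name, argument names, shapes, and the discount factor gamma are illustrative assumptions.

```python
import torch


def bellman_target_sketch(reward, done, next_z, next_log_prob, alpha, gamma=0.99):
    """Hypothetical distributional soft Bellman target.

    reward, done:  (batch, 1)
    next_z:        (n_members, batch, n_quantiles) target-critic quantiles
    next_log_prob: (batch, 1) log-probability of the next actions
    """
    # Clipped minimum: element-wise minimum over members for each quantile.
    min_z = next_z.min(dim=0).values  # (batch, n_quantiles)
    # Entropy bonus is applied to every quantile of the return distribution.
    soft_z = min_z - alpha * next_log_prob
    # Terminal transitions do not bootstrap from the next state.
    return (reward + gamma * (1.0 - done) * soft_z).detach()


reward = torch.ones(2, 1)
done = torch.tensor([[0.0], [1.0]])
next_z = torch.stack([torch.full((2, 3), 2.0), torch.full((2, 3), 3.0)])
target = bellman_target_sketch(
    reward, done, next_z, next_log_prob=torch.zeros(2, 1), alpha=0.2, gamma=0.5
)
```

In this toy input the first row bootstraps from the smaller member (1 + 0.5 * 2 = 2), while the second row is terminal and keeps only the reward.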

update(state: Tensor, action: Tensor, y: tuple[Tensor, Tensor]) → None

Updates the critic network using the given state, action, and target values.

Parameters:
  • state (torch.Tensor) – The state tensor.

  • action (torch.Tensor) – The action tensor.

  • y (tuple[torch.Tensor, torch.Tensor]) – The target values and target tau.

Returns:

None

class objectrl.models.dsac.DistributionalSoftActorCritic(config: MainConfig, critic_type: type = <class 'objectrl.models.dsac.DSACCritic'>, actor_type: type = <class 'objectrl.models.dsac.DSACActor'>)

Bases: ActorCritic

Distributional Soft Actor-Critic agent combining DSACActor and DSACCritic.

Reference: Ma et al. (2025), "DSAC: Distributional Soft Actor-Critic for Risk-Sensitive Reinforcement Learning".

_agent_name = 'DSAC'

__init__(config: MainConfig, critic_type: type = <class 'objectrl.models.dsac.DSACCritic'>, actor_type: type = <class 'objectrl.models.dsac.DSACActor'>) → None

Initializes DSAC agent.

Parameters:
  • config (MainConfig) – Configuration dataclass instance.

  • critic_type (type) – Critic class type.

  • actor_type (type) – Actor class type.

Returns:

None