Source code for objectrl.models.dsac

# -----------------------------------------------------------------------------------
# ObjectRL: An Object-Oriented Reinforcement Learning Codebase
# Copyright (C) 2025 ADIN Lab

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
# -----------------------------------------------------------------------------------

import math
import typing

import torch

from objectrl.models.basic.ac import ActorCritic
from objectrl.models.basic.critic import CriticEnsemble
from objectrl.models.sac import SACActor, SACCritic

if typing.TYPE_CHECKING:
    from objectrl.config.config import MainConfig


class DSACActor(SACActor):
    """
    Actor for Distributional Soft Actor-Critic (DSAC).

    Extends ``SACActor`` with an optional fixed (non-learnable) entropy
    temperature and a loss that is computed from the critic's quantile
    estimates.

    Args:
        config (MainConfig): Configuration object containing model specifications.
        dim_state (int): Dimension of the state space.
        dim_act (int): Dimension of the action space.

    Attributes:
        learnable_alpha (bool): Whether the temperature parameter alpha is
            learnable (read from ``config.model.learnable_alpha``).
    """

    def __init__(self, config, dim_state, dim_act):
        """
        Build the actor and, if alpha is not learnable, freeze it.

        When ``config.model.learnable_alpha`` is falsy, ``log_alpha`` is
        replaced by a constant tensor holding ``log(config.model.alpha)`` and
        ``optim_alpha`` is set to ``None``.
        """
        super().__init__(config, dim_state, dim_act)
        self.learnable_alpha = config.model.learnable_alpha

        if self.learnable_alpha:
            # Keep whatever log_alpha / optimizer the parent class created.
            return

        # Constant temperature: store log(alpha) as a tensor without gradient.
        fixed_log_alpha = math.log(config.model.alpha)
        self.log_alpha = torch.tensor(
            fixed_log_alpha,
            requires_grad=False,
            device=self.device,
        )

        # A fixed alpha has nothing to optimize.
        del self.optim_alpha
        self.optim_alpha = None

    def update_alpha(self, act_dict: dict) -> None:
        """
        Update alpha, but only when it is learnable.

        Args:
            act_dict (dict): Dictionary containing action information; it is
                forwarded unchanged to the parent implementation.

        Returns:
            None
        """
        if not self.learnable_alpha:
            return
        super().update_alpha(act_dict)

    def loss(
        self, state: torch.Tensor, critics: CriticEnsemble
    ) -> tuple[torch.Tensor, dict]:
        """
        Compute the DSAC actor loss.

        Quantile values returned by the critic are weighted by the quantile
        fractions and summed over the last dimension to obtain Q-values; the
        minimum over dimension 0 of the critic output is then combined with
        the entropy term ``alpha * log_prob``.

        Args:
            state (Tensor): Batch of states.
            critics (CriticEnsemble): Critic networks for Q-value estimation;
                must provide ``get_tau`` and ``Q``.

        Returns:
            tuple: Actor loss and the action dictionary containing the action
            and its log probability.
        """
        n_samples = state.shape[0]

        # Sample actions first, then quantile fractions (both may draw
        # random numbers, so the order is kept fixed).
        act_dict = self.act(state)
        action = act_dict["action"]
        log_prob = act_dict["action_logprob"]
        _, tau_hat, presum_tau = critics.get_tau(batch_size=n_samples)

        # Collapse quantile estimates into Q-values using the fraction weights.
        z_values = critics.Q(state, action, tau_hat)
        q_values = (z_values * presum_tau).sum(dim=-1, keepdim=True)
        q = q_values.min(dim=0).values

        entropy_term = self.log_alpha.exp() * log_prob
        loss = (entropy_term - q).mean()
        return loss, act_dict
class DSACCritic(SACCritic):
    """
    Distributional Soft Actor-Critic (DSAC) Critic.

    Args:
        config (MainConfig): Configuration object containing model specifications.
        dim_state (int): Dimension of the state space.
        dim_act (int): Dimension of the action space.

    Attributes:
        num_quantiles (int): Number of quantile atoms.
        tau_type (str): Type of tau generation ('fix' or 'iqn').
    """

    def __init__(self, config: "MainConfig", dim_state: int, dim_act: int) -> None:
        """
        Validate the quantile count and build the critic.

        Raises:
            AssertionError: If ``config.model.critic.n_quantiles`` is not
                greater than one (checked before the parent is initialised).
        """
        assert (
            config.model.critic.n_quantiles > 1
        ), "Number of quantiles must be greater than one"
        super().__init__(config, dim_state, dim_act)
        self.num_quantiles = config.model.critic.n_quantiles
        self.tau_type = config.model.critic.tau_type

    def get_tau(
        self, batch_size: int
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Generates tau values based on the specified tau_type.

        Args:
            batch_size (int): The batch size for which tau values are generated.

        Returns:
            tuple: ``(tau, tau_hat, presum_tau)``, each of shape
            ``(batch_size, num_quantiles)``. ``presum_tau`` holds the
            per-quantile fractions (each row sums to one), ``tau`` is their
            cumulative sum and ``tau_hat`` holds the midpoints between
            consecutive ``tau`` values (with 0 as the implicit lower bound).

        Raises:
            NotImplementedError: If ``tau_type`` is neither 'fix' nor 'iqn'.
        """
        if self.tau_type == "fix":
            # Uniform fractions: every quantile gets weight 1 / num_quantiles.
            presum_tau = torch.full(
                (batch_size, self.num_quantiles),
                1.0 / self.num_quantiles,
                device=self.device,
            )
        elif self.tau_type == "iqn":
            # add 0.1 to prevent tau getting too close
            presum_tau = (
                torch.rand(batch_size, self.num_quantiles, device=self.device) + 0.1
            )
            presum_tau /= presum_tau.sum(dim=-1, keepdim=True)
        else:
            raise NotImplementedError(f"tau_type {self.tau_type} not implemented")

        # (N, T), note that they are tau1...tauN in the paper
        tau = torch.cumsum(presum_tau, dim=1)

        with torch.no_grad():
            tau_hat = torch.zeros_like(tau)
            # First midpoint lies between 0 and tau_1.
            tau_hat[:, 0:1] = tau[:, 0:1] / 2.0
            tau_hat[:, 1:] = (tau[:, 1:] + tau[:, :-1]) / 2.0
        return tau, tau_hat, presum_tau

    @staticmethod
    def _quantile_forward(
        ensemble,
        state: torch.Tensor,
        action: torch.Tensor,
        tau: torch.Tensor,
    ) -> torch.Tensor:
        """Run ``ensemble`` on (state-action, tau) and return quantile values.

        Shared by ``Q`` (online ensemble) and ``Q_t`` (target ensemble) so
        that both use exactly the same forward logic.
        """
        sa = torch.cat((state, action), -1)  # [n_batch x dim_state + dim_act]
        # tau is [n_batch x n_quantiles]; vmap runs over the inserted axis.
        output = torch.vmap(
            lambda a, b: ensemble((a, b)), in_dims=(None, 1), out_dims=2
        )(
            sa, tau.unsqueeze(1)
        )  # [n_member x n_batch x 1 x n_quantiles]
        return output.squeeze(2)  # [n_member x n_batch x n_quantiles]

    def Q(
        self, state: torch.Tensor, action: torch.Tensor, tau: torch.Tensor
    ) -> torch.Tensor:
        """
        Computes the Q-values for given state, action, and tau values.

        Args:
            state (torch.Tensor): The state tensor.
            action (torch.Tensor): The action tensor.
            tau (torch.Tensor): The tau tensor representing quantile fractions.

        Returns:
            torch.Tensor: The computed Q-values
            ([n_member x n_batch x n_quantiles]).
        """
        return self._quantile_forward(self.model_ensemble, state, action, tau)

    def Q_t(
        self, state: torch.Tensor, action: torch.Tensor, tau: torch.Tensor
    ) -> torch.Tensor:
        """
        Computes the target Q-values for given state, action, and tau values.

        Args:
            state (torch.Tensor): The state tensor.
            action (torch.Tensor): The action tensor.
            tau (torch.Tensor): The tau tensor representing quantile fractions.

        Returns:
            torch.Tensor: The computed target Q-values
            ([n_member x n_batch x n_quantiles]).
        """
        return self._quantile_forward(self.target_ensemble, state, action, tau)

    @torch.no_grad()
    def get_bellman_target(
        self,
        reward: torch.Tensor,
        next_state: torch.Tensor,
        done: torch.Tensor,
        actor: DSACActor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Computes the Bellman target for the given reward, next state, done flag, and actor.

        Args:
            reward (torch.Tensor): The reward tensor.
            next_state (torch.Tensor): The next state tensor.
            done (torch.Tensor): The done flag tensor.
            actor (DSACActor): The actor instance.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: The Bellman target quantile
            values and the quantile fractions (``presum_tau``) that were
            sampled for the next state. ``update`` expects exactly this pair
            as its ``y`` argument.
        """
        batch_size = reward.shape[0]
        alpha = actor.log_alpha.exp().detach()

        # Get actions from target actor
        act_dict = actor.act_target(next_state)
        next_action = act_dict["action"]
        action_logprob = act_dict["action_logprob"]

        _, next_tau_hat, next_presum_tau = self.get_tau(batch_size=batch_size)
        target_z_values = self.Q_t(next_state, next_action, next_tau_hat)

        # Minimum over dimension 0 of the target output, then entropy term.
        min_clip_z_values = torch.min(target_z_values, dim=0).values
        z_next_values = min_clip_z_values - alpha * action_logprob

        z_target = (
            reward.unsqueeze(-1)
            + self._gamma * (1 - done.unsqueeze(-1)) * z_next_values
        )
        return z_target, next_presum_tau

    def update(
        self,
        state: torch.Tensor,
        action: torch.Tensor,
        y: tuple[torch.Tensor, torch.Tensor],
    ) -> None:
        """
        Updates the critic network using the given state, action, and target values.

        Args:
            state (torch.Tensor): The state tensor.
            action (torch.Tensor): The action tensor.
            y (tuple[torch.Tensor, torch.Tensor]): The target values and target tau,
                as returned by ``get_bellman_target``.

        Returns:
            None
        """
        self.optim.zero_grad()

        batch_size = state.shape[0]
        _, tau_hat, _ = self.get_tau(batch_size=batch_size)
        pred_quantiles = self.Q(
            state, action, tau_hat
        )  # [n_ensemble x n_batch x n_quantiles]

        # Unpack target values and target tau without rebinding the parameter.
        target_values, target_tau = y
        loss = self.loss(
            pred_quantiles,
            self.model_ensemble.expand(target_values),
            tau_hat,
            target_tau,
        )  # [n_ensemble x n_batch x n_quantiles x n_quantiles]

        if self.n_members > 1:
            # Average per member, then sum the members' losses.
            loss = loss.sum(-1).mean(dim=(1, 2)).sum(0)
        else:
            loss = loss.sum(-1).mean()

        loss.backward()
        self.optim.step()
        self.iter += 1
class DistributionalSoftActorCritic(ActorCritic):
    """
    DSAC agent built from a ``DSACActor`` and a ``DSACCritic``.

    Ma et al. (2025): DSAC: Distributional Soft Actor-Critic for
    Risk-Sensitive Reinforcement Learning
    """

    _agent_name = "DSAC"

    def __init__(
        self,
        config: "MainConfig",
        critic_type: type = DSACCritic,
        actor_type: type = DSACActor,
    ) -> None:
        """
        Initializes DSAC agent.

        The model name stored in ``config.model.name`` is extended in place
        with a suffix that reflects the critic's ``tau_type``: ``"_Q"`` for
        ``"fix"`` and ``"_IQ"`` for anything else.

        Args:
            config (MainConfig): Configuration dataclass instance.
            critic_type (type): Critic class type.
            actor_type (type): Actor class type.

        Returns:
            None
        """
        # Pick the name suffix from the tau generation scheme.
        if config.model.critic.tau_type == "fix":
            name_suffix = "_Q"
        else:
            name_suffix = "_IQ"
        config.model.name += name_suffix

        super().__init__(config, critic_type, actor_type)