Source code for objectrl.models.sac

# -----------------------------------------------------------------------------------
# ObjectRL: An Object-Oriented Reinforcement Learning Codebase
# Copyright (C) 2025 ADIN Lab

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
# -----------------------------------------------------------------------------------

import math
import typing

import torch

from objectrl.models.basic.ac import ActorCritic
from objectrl.models.basic.actor import Actor
from objectrl.models.basic.critic import CriticEnsemble
from objectrl.utils.net_utils import create_optimizer

if typing.TYPE_CHECKING:
    from objectrl.config.config import MainConfig


class SACActor(Actor):
    """
    Soft Actor network with automatic temperature tuning.

    Args:
        config (MainConfig): Configuration object with hyperparameters.
        dim_state (int): Observation space dimensions.
        dim_act (int): Action space dimensions.

    Attributes:
        target_entropy (float): Target entropy for temperature tuning. Equals
            ``config.model.target_entropy`` when set, otherwise ``-dim_act``.
        log_alpha (Tensor): Learnable log temperature parameter, initialized to
            ``log(config.model.alpha)``.
        optim_alpha (Optimizer): Optimizer that updates only ``log_alpha``.
    """

    def __init__(self, config: "MainConfig", dim_state: int, dim_act: int) -> None:
        super().__init__(config, dim_state, dim_act)
        # Fall back to -dim_act when no explicit entropy target is configured.
        self.target_entropy = (
            -dim_act
            if config.model.target_entropy is None
            else config.model.target_entropy
        )
        # The temperature is optimized in log space; every consumer uses
        # log_alpha.exp(), which keeps alpha strictly positive.
        self.log_alpha = torch.tensor(
            math.log(config.model.alpha),
            requires_grad=True,
            device=self.device,
        )
        self.optim_alpha = create_optimizer(config.training)([self.log_alpha])

    def update_alpha(self, act_dict: dict) -> None:
        """
        Updates the temperature parameter alpha based on current policy entropy.

        Minimizes the batch mean of ``-alpha * (log_prob + target_entropy)``.
        The entropy term is detached, so only ``log_alpha`` receives gradients.

        Args:
            act_dict (dict): Dictionary whose 'action_logprob' entry holds the
                log probabilities of the sampled actions.

        Returns:
            None
        """
        log_prob = act_dict["action_logprob"]
        loss = -self.log_alpha.exp() * (log_prob + self.target_entropy).detach()
        self.optim_alpha.zero_grad()
        loss.mean().backward()
        self.optim_alpha.step()

    def loss(
        self, state: torch.Tensor, critics: CriticEnsemble
    ) -> tuple[torch.Tensor, dict]:
        """
        Computes the SAC actor loss ``mean(alpha * log_prob - Q)``.

        The Q estimate comes from ``critics.Q`` evaluated on freshly sampled
        actions. It is reduced across the ensemble with
        ``config.model.critic.reduce``. The temperature is treated as a
        constant, so this loss produces no gradient for ``log_alpha``.

        Args:
            state (Tensor): Batch of states.
            critics (CriticEnsemble): Critic networks for Q-value estimation.

        Returns:
            tuple: Actor loss and the action dictionary returned by ``self.act``
            (containing 'action' and 'action_logprob').
        """
        act_dict = self.act(state)
        action, log_prob = act_dict["action"], act_dict["action_logprob"]
        q_values = critics.Q(state, action)
        q = critics.reduce(q_values, reduce_type=self.config.model.critic.reduce)
        # log_alpha is trained exclusively by update_alpha. Detaching here keeps
        # the actor objective from writing gradients into log_alpha.grad, which
        # would otherwise be stepped by any caller that does not zero it first.
        alpha = self.log_alpha.exp().detach()
        # NOTE(review): assumes q and log_prob have matching shapes (e.g. both
        # (batch, 1)) so the sum does not broadcast unexpectedly — confirm.
        loss = (-q + alpha * log_prob).mean()
        return loss, act_dict

    def update(self, state: torch.Tensor, critics: CriticEnsemble) -> None:
        """
        Performs one gradient step on the actor network, then one on alpha.

        Args:
            state (Tensor): Batch of states.
            critics (CriticEnsemble): Critic ensemble for Q-value estimates.

        Returns:
            None
        """
        self.optim.zero_grad()
        loss, act_dict = self.loss(state, critics)
        loss.backward()
        self.optim.step()
        # Reuse the log-probabilities sampled for the actor step.
        self.update_alpha(act_dict)
        self.iter += 1  # Increment iteration counter
class SACCritic(CriticEnsemble):
    """
    SAC critic ensemble handling Bellman target computation and updates.

    Args:
        config (MainConfig): Configuration object.
        dim_state (int): State space dimensions.
        dim_act (int): Action space dimensions.

    Attributes:
        _gamma (float): Discount factor for future rewards. It is read by
            ``get_bellman_target`` but not assigned in this class; presumably
            it is set by ``CriticEnsemble``.
    """

    def __init__(self, config: "MainConfig", dim_state: int, dim_act: int) -> None:
        super().__init__(config, dim_state, dim_act)

    @torch.no_grad()
    def get_bellman_target(
        self,
        reward: torch.Tensor,
        next_state: torch.Tensor,
        done: torch.Tensor,
        actor: SACActor,
    ) -> torch.Tensor:
        """
        Computes target Q-values using the entropy-regularized Bellman backup.

        ``y = r + gamma * (1 - done) * (reduce(Q_t(s', a')) - alpha * log pi(a'|s'))``.
        Here ``a'`` is sampled from ``actor`` at ``next_state``. The reduction
        uses ``config.model.critic.target_reduce``.

        Args:
            reward (Tensor): Reward batch.
            next_state (Tensor): Next state batch.
            done (Tensor): Done flags batch; float (0/1) or bool.
            actor (SACActor): Actor network for next action sampling.

        Returns:
            Tensor: Target Q-values for critic training, with a trailing
            singleton dimension added to ``reward``.
        """
        alpha = actor.log_alpha.exp().detach()
        act_dict = actor.act(next_state)
        next_action = act_dict["action"]
        action_logprob = act_dict["action_logprob"]

        target_values = self.Q_t(next_state, next_action)
        reduced_q = self.reduce(
            target_values, reduce_type=self.config.model.critic.target_reduce
        )
        # NOTE(review): assumes action_logprob broadcasts against reduced_q
        # (e.g. both (batch, 1)) — confirm against Actor.act.
        target_value = reduced_q - alpha * action_logprob

        # Cast before subtracting: torch rejects `1 - t` for bool tensors, and
        # for float masks of the same dtype this cast is a no-op.
        not_done = 1 - done.unsqueeze(-1).to(target_value.dtype)
        y = reward.unsqueeze(-1) + self._gamma * target_value * not_done
        return y
class SoftActorCritic(ActorCritic):
    """
    Soft Actor-Critic (SAC) agent that pairs a SACActor with a SACCritic.

    Reference:
        Haarnoja et al. (2018): Soft Actor-Critic: Off-Policy Maximum Entropy
        Deep Reinforcement Learning with a Stochastic Actor
    """

    _agent_name = "SAC"

    def __init__(
        self,
        config: "MainConfig",
        critic_type: type = SACCritic,
        actor_type: type = SACActor,
    ) -> None:
        """
        Build the SAC agent by delegating construction to ActorCritic.

        Args:
            config (MainConfig): Configuration dataclass instance.
            critic_type (type): Critic class to instantiate (SACCritic by default).
            actor_type (type): Actor class to instantiate (SACActor by default).

        Returns:
            None
        """
        # All wiring happens in the base class; SAC only selects its defaults.
        super().__init__(config, critic_type, actor_type)