Source code for objectrl.models.td3

# -----------------------------------------------------------------------------------
# ObjectRL: An Object-Oriented Reinforcement Learning Codebase
# Copyright (C) 2025 ADIN Lab

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
# -----------------------------------------------------------------------------------

import typing

import torch

from objectrl.models.basic.ac import ActorCritic
from objectrl.models.basic.actor import Actor
from objectrl.models.basic.critic import CriticEnsemble
from objectrl.utils.utils import totorch

if typing.TYPE_CHECKING:
    from objectrl.config.config import MainConfig


class TD3Actor(Actor):
    """
    TD3 actor network with action noise for exploration and target policy smoothing.

    Args:
        config (MainConfig): Configuration object.
        dim_state (int): Observation space dimensions.
        dim_act (int): Action space dimensions.

    Attributes:
        policy_noise (float): Noise std for exploration.
        target_policy_noise (float): Noise std for target policy smoothing.
        target_policy_noise_clip (float): Clipping range for target noise.
        action_limit_low (Tensor): Lower bound for actions.
        action_limit_high (Tensor): Upper bound for actions.
    """

    def __init__(self, config: "MainConfig", dim_state: int, dim_act: int) -> None:
        """
        Builds the actor and caches the noise settings and action bounds.

        Args:
            config (MainConfig): Configuration object; ``config.model.noise`` and
                ``config.env.env.action_space`` are read here.
            dim_state (int): Observation space dimensions.
            dim_act (int): Action space dimensions.

        Returns:
            None
        """
        super().__init__(config, dim_state, dim_act)
        noise = config.model.noise
        self.policy_noise = noise.policy_noise
        self.target_policy_noise = noise.target_policy_noise
        self.target_policy_noise_clip = noise.target_policy_noise_clip
        # Action bounds are converted once so clipping never leaves the device.
        self.action_limit_low = totorch(
            config.env.env.action_space.low, device=self.device  # type:ignore
        )
        self.action_limit_high = totorch(
            config.env.env.action_space.high, device=self.device  # type:ignore
        )

    def act(self, state: torch.Tensor, is_training: bool = True) -> dict:
        """
        Computes actions with optional noise added for exploration.

        Args:
            state (Tensor): Batch of states.
            is_training (bool): Whether in training mode (adds noise if True).

        Returns:
            dict: The parent's action dictionary where 'action' holds the
            (possibly noisy, clipped) action and 'action_wo_noise' holds the
            action exactly as produced by the parent ``act`` call.
        """
        action_dict = super().act(state)
        action = action_dict["action"]
        action_dict["action_wo_noise"] = action
        if is_training:
            noise = torch.normal(0, self.policy_noise, action.shape).to(self.device)
            # Out-of-place add: an in-place `+=` would also modify the tensor
            # stored under 'action_wo_noise', since both names alias one tensor.
            action = action + noise
            action = torch.clip(action, self.action_limit_low, self.action_limit_high)
        action_dict["action"] = action
        return action_dict

    def act_target(self, state: torch.Tensor) -> dict:
        """
        Computes target policy action with smoothing noise added.

        Args:
            state (Tensor): Batch of next states.

        Returns:
            dict: Contains 'action' tensor with added clipped noise.
        """
        action_dict = super().act(state)
        action = action_dict["action"]
        noise = torch.normal(0, self.target_policy_noise, action.shape).to(self.device)
        noise = torch.clip(
            noise, -self.target_policy_noise_clip, self.target_policy_noise_clip
        )
        # Out-of-place add so the tensor returned by the parent is not mutated.
        action = action + noise
        action = torch.clip(action, self.action_limit_low, self.action_limit_high)
        action_dict["action"] = action
        return action_dict

    def loss(self, state: torch.Tensor, critics: CriticEnsemble) -> torch.Tensor:
        """
        Computes actor loss as negative Q-value estimate.

        Args:
            state (Tensor): Batch of states.
            critics (CriticEnsemble): Critic networks; only the first member
                (``critics[0]``) is queried.

        Returns:
            Tensor: Mean of the negated Q-values, so minimizing it maximizes Q.
        """
        # Noise-free action: the policy gradient is taken through the
        # deterministic output.
        act_dict = self.act(state, is_training=False)
        action = act_dict["action"]
        q = critics[0].Q(state, action)
        return (-q).mean()
class TD3Critic(CriticEnsemble):
    """
    Critic ensemble for TD3 that produces the Bellman backup target.

    Args:
        config (MainConfig): Configuration object.
        dim_state (int): Observation space dimensions.
        dim_act (int): Action space dimensions.
    """

    def __init__(self, config: "MainConfig", dim_state: int, dim_act: int) -> None:
        """
        Delegates construction entirely to the parent ensemble.

        Args:
            config (MainConfig): Configuration object.
            dim_state (int): Observation space dimensions.
            dim_act (int): Action space dimensions.

        Returns:
            None
        """
        super().__init__(config, dim_state, dim_act)

    @torch.no_grad()
    def get_bellman_target(
        self,
        reward: torch.Tensor,
        next_state: torch.Tensor,
        done: torch.Tensor,
        actor: TD3Actor,
    ) -> torch.Tensor:
        """
        Builds the Bellman backup target from the target critics.

        Runs under ``torch.no_grad()``, so no gradients are tracked here.

        Args:
            reward (Tensor): Rewards batch.
            next_state (Tensor): Next state batch.
            done (Tensor): Done flags batch.
            actor (TD3Actor): Actor whose ``act_target`` supplies the next action.

        Returns:
            Tensor: ``reward + gamma * reduced_target_Q * (1 - done)`` with a
            trailing dimension added to ``reward`` and ``done``.
        """
        # Smoothed next action from the actor's target-policy path.
        smoothed_next_action = actor.act_target(next_state)["action"]

        # Evaluate the target critics, then collapse the ensemble using the
        # reduction named in the config.
        ensemble_q = self.Q_t(next_state, smoothed_next_action)
        reduced_q = self.reduce(
            ensemble_q, reduce_type=self.config.model.critic.target_reduce
        )

        immediate = reward.unsqueeze(-1)
        bootstrap = self._gamma * reduced_q * (1 - done.unsqueeze(-1))
        return immediate + bootstrap
class TwinDelayedDeepDeterministicPolicyGradient(ActorCritic):
    """
    Twin Delayed Deep Deterministic Policy Gradient (TD3) agent.

    Combines delayed policy updates with clipped-noise target smoothing.

    Reference:
        Fujimoto et al. (2018): Addressing Function Approximation Error in
        Actor-Critic Methods
    """

    _agent_name = "TD3"

    def __init__(
        self,
        config: "MainConfig",
        critic_type: type = TD3Critic,
        actor_type: type = TD3Actor,
    ) -> None:
        """
        Sets up the TD3 agent by forwarding everything to ``ActorCritic``.

        Args:
            config (MainConfig): Configuration dataclass instance.
            critic_type (type): Critic class type. Defaults to ``TD3Critic``.
            actor_type (type): Actor class type. Defaults to ``TD3Actor``.

        Returns:
            None
        """
        super().__init__(config, critic_type, actor_type)