Source code for objectrl.models.oac

# -----------------------------------------------------------------------------------
# ObjectRL: An Object-Oriented Reinforcement Learning Codebase
# Copyright (C) 2025 ADIN Lab

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
# -----------------------------------------------------------------------------------

import math
import typing

import torch
from torch.distributions import TransformedDistribution

from objectrl.models.basic.ac import ActorCritic
from objectrl.models.basic.actor import Actor
from objectrl.models.basic.critic import CriticEnsemble

if typing.TYPE_CHECKING:
    from objectrl.config.config import MainConfig


class OptimisticNoise:
    """
    Computes optimistic exploration noise as described in the OAC algorithm.

    The pre-tanh mean of the policy is shifted along the gradient of an upper
    confidence bound of Q (ensemble mean + ``beta_ub`` * ensemble std). The
    shift is scaled so that, for every state independently,
    ``sum(mu_c**2 / sigma**2) ≈ 2 * delta``.

    Attributes:
        beta_ub (float): Coefficient on the standard deviation of the Q-values
            in the upper confidence bound.
        delta (float): Exploration confidence parameter; controls the size of
            the mean shift.
    """

    def __init__(self, beta_ub: float, delta: float) -> None:
        self.beta_ub = beta_ub
        self.delta = delta

    def compute(
        self,
        state: torch.Tensor,
        critics: "CriticEnsemble",
        transformed_dist: TransformedDistribution,
    ) -> dict:
        """
        Computes the optimistic adjustment to the mean of the action distribution.

        Args:
            state (Tensor): Input state tensor.
            critics (CriticEnsemble): Critic ensemble to evaluate Q-values.
                Its ``Q(state, action)`` output is reduced over dim 0
                (presumably the ensemble axis — TODO confirm).
            transformed_dist (TransformedDistribution): Tanh-transformed
                distribution from the actor; its ``base_dist`` must expose
                ``loc`` and ``scale``.

        Returns:
            dict: ``'mu_e'``, the shifted pre-tanh mean, and ``'scale'``, the
            unchanged scale of the base distribution.
        """
        loc = transformed_dist.base_dist.loc  # type: ignore
        scale = transformed_dist.base_dist.scale  # type: ignore

        # The shift needs d(Q_UB)/d(mu) even if the caller runs under no_grad,
        # and differentiating a detached copy avoids mutating ``loc`` in place.
        with torch.enable_grad():
            pre_tanh_mu = loc.detach().requires_grad_()
            tanh_mu = torch.tanh(pre_tanh_mu)
            q_values = critics.Q(state, tanh_mu)
            q_mean = q_values.mean(dim=0)
            q_std = q_values.std(dim=0, unbiased=False)
            q_ub = q_mean + self.beta_ub * q_std
            # Summing over the batch is fine for the gradient: each state's
            # action only affects its own Q-value (assumes the critics treat
            # samples independently — TODO confirm).
            grad = torch.autograd.grad(q_ub.sum(), pre_tanh_mu)[0]

        sigma_t = scale.square()
        # Normalize per state (last dim, assumed to be the action dim) so each
        # state's shift meets its own KL budget, not a batch-wide one.
        denom = (
            torch.sqrt(torch.sum(grad.square() * sigma_t, dim=-1, keepdim=True))
            + 1e-6
        )
        mu_c = math.sqrt(2.0 * self.delta) * (sigma_t * grad) / denom
        mu_e = loc + mu_c
        return {"mu_e": mu_e, "scale": scale}
[docs] class GaussianNoise: """ Adds Gaussian noise to actions, used for target value perturbation in critic updates. Attributes: sigma_target (float): Standard deviation of the noise. noise_clamp (float): Value to clamp the noise between [-noise_clamp, noise_clamp]. """
[docs] def __init__(self, sigma_target=0, noise_clamp=0.15): self.sigma_target = sigma_target self.noise_clamp = noise_clamp
[docs] def add_noise(self, next_action_shape: torch.Size) -> torch.Tensor: """ Generates Gaussian noise for a given action shape. Args: next_action_shape (torch.Size): Shape of the action tensor. Returns: Tensor: Clamped noise tensor. """ noise = torch.distributions.Normal(0, self.sigma_target).sample( sample_shape=next_action_shape ) noise = noise.clamp(-self.noise_clamp, self.noise_clamp) return noise
class OACActor(Actor):
    """
    Actor used by Optimistic Actor-Critic (OAC).

    Builds on the base ``Actor`` and additionally owns an ``OptimisticNoise``
    helper, configured from ``config.model.exploration``, which the agent uses
    to compute an optimistically shifted exploration distribution.

    Args:
        config (MainConfig): Global configuration.
        dim_state (int): Dimensionality of observation space.
        dim_act (int): Dimensionality of action space.

    Attributes:
        optimistic_noise (OptimisticNoise): Computes the optimistic shift of
            the policy mean used for exploration.
    """

    def __init__(self, config: "MainConfig", dim_state: int, dim_act: int):
        super().__init__(config, dim_state, dim_act)
        exploration_cfg = config.model.exploration
        self.optimistic_noise = OptimisticNoise(
            beta_ub=exploration_cfg.beta_ub,
            delta=exploration_cfg.delta,
        )

    def loss(self, state: torch.Tensor, critics: CriticEnsemble) -> torch.Tensor:
        """
        Actor loss: the negated critic estimate of the actor's action.

        The action comes from ``self.act(state, is_training=False)``. Its
        Q-values are combined with ``critics.reduce`` using
        ``config.model.critic.reduce``, and the negation is averaged over all
        elements. No entropy term is added here.
        NOTE(review): SAC-based OAC usually includes an entropy bonus; confirm
        that omitting it is intended.

        Args:
            state (Tensor): Input states.
            critics (CriticEnsemble): Critic networks.

        Returns:
            Tensor: Scalar loss value.
        """
        chosen_action = self.act(state, is_training=False)["action"]
        combined_q = critics.reduce(
            critics.Q(state, chosen_action),
            reduce_type=self.config.model.critic.reduce,
        )
        return combined_q.neg().mean()
class OACCritic(CriticEnsemble):
    """
    OAC-specific critic ensemble that perturbs target actions with clipped
    Gaussian noise before evaluating the target critics.

    Args:
        config (MainConfig): Global configuration.
        dim_state (int): Dimensionality of state space.
        dim_act (int): Dimensionality of action space.
    """

    def __init__(self, config: "MainConfig", dim_state: int, dim_act: int):
        super().__init__(config, dim_state, dim_act)
        noise = config.model.noise
        self.noise = GaussianNoise(
            sigma_target=noise.sigma_target,
            noise_clamp=noise.noise_clamp,
        )

    @torch.no_grad()
    def get_bellman_target(
        self,
        reward: torch.Tensor,
        next_state: torch.Tensor,
        done: torch.Tensor,
        actor: OACActor,
    ) -> torch.Tensor:
        """
        Computes the Bellman target for TD learning using noisy next actions.

        ``y = reward + gamma * reduce(Q_target(s', a')) * (1 - done)``, where
        ``a'`` is the actor's next action plus clipped Gaussian noise, kept
        inside the [-1, 1] action range.

        Args:
            reward (Tensor): Reward signal.
            next_state (Tensor): Next state input.
            done (Tensor): Episode termination flags.
            actor (OACActor): Actor used to compute next action.

        Returns:
            Tensor: Bellman target values.
        """
        next_action = actor.act(next_state)["action"]
        # Match the action's device/dtype so the addition cannot mismatch.
        noise = self.noise.add_noise(next_action.shape).to(
            device=next_action.device, dtype=next_action.dtype
        )
        # Target policy smoothing (as in TD3): the actor's actions are
        # tanh-squashed, so keep the perturbed action inside [-1, 1] rather
        # than querying the critics outside the action range.
        # NOTE(review): assumes actor.act returns actions in [-1, 1].
        # Out-of-place add so the actor's output dict is not mutated.
        next_action = (next_action + noise).clamp(-1.0, 1.0)
        target_values = self.Q_t(next_state, next_action)
        target_value = self.reduce(
            target_values, reduce_type=self.config.model.critic.target_reduce
        )
        y = reward.unsqueeze(-1) + (
            self._gamma * target_value * (1 - done.unsqueeze(-1))
        )
        return y
class OptimisticActorCritic(ActorCritic):
    """
    OAC agent class that integrates the OAC actor and critic.

    Ciosek et al. (2019): Better Exploration with Optimistic Actor-Critic
    """

    _agent_name = "OAC"

    def __init__(
        self,
        config: "MainConfig",
        critic_type: type = OACCritic,
        actor_type: type = OACActor,
    ) -> None:
        """
        Initializes the OAC agent.

        Args:
            config (MainConfig): Configuration dataclass instance.
            critic_type (type): Critic class type.
            actor_type (type): Actor class type.

        Returns:
            None
        """
        super().__init__(config, critic_type, actor_type)

    def select_action(
        self, state: torch.Tensor, is_training: bool = True
    ) -> dict:
        """
        Selects an action given a state, optionally applying optimistic exploration.

        Args:
            state (Tensor): Input state tensor.
            is_training (bool): If True, samples from the optimistically
                shifted tanh-Gaussian. If False, returns the deterministic
                action ``tanh(loc)``.

        Returns:
            dict: The actor's output dictionary with ``"action"`` replaced.
            In training mode, ``"dist"`` and ``"action_logprob"`` are also
            replaced by the optimistic distribution and the log-probability
            of the sampled action. In evaluation mode they are left as
            returned by the actor.
        """
        state = state.to(self.device)
        act_dict = self.actor.act(state)
        transformed_dist = act_dict["dist"]

        if is_training:
            # The optimistic shift differentiates the critics' upper bound
            # w.r.t. the action mean, so it needs autograd even if the caller
            # selects actions under torch.no_grad().
            with torch.enable_grad():
                result = self.actor.optimistic_noise.compute(
                    state, self.critic, transformed_dist
                )
            mu_e, scale = result["mu_e"], result["scale"]
            dist_bt = torch.distributions.Normal(mu_e, scale, validate_args=False)
            dist = torch.distributions.TransformedDistribution(
                dist_bt, torch.distributions.transforms.TanhTransform(cache_size=1)
            )
            action = dist.sample()
            act_dict["dist"] = dist
            # cache_size=1 lets log_prob reuse the pre-tanh sample exactly.
            act_dict["action_logprob"] = dist.log_prob(action).sum(
                dim=-1, keepdim=True
            )
        else:
            action = torch.tanh(transformed_dist.base_dist.loc)  # type: ignore

        act_dict["action"] = action
        return act_dict