# Source code for objectrl.models.basic.ac

# -----------------------------------------------------------------------------------
# ObjectRL: An Object-Oriented Reinforcement Learning Codebase
# Copyright (C) 2025 ADIN Lab

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
# -----------------------------------------------------------------------------------

import typing

import torch

from objectrl.agents.base_agent import Agent
from objectrl.models.basic.actor import Actor
from objectrl.models.basic.critic import CriticEnsemble

if typing.TYPE_CHECKING:
    from objectrl.config.config import MainConfig


class ActorCritic(Agent):
    """
    Base Actor-Critic agent combining an actor policy and a critic ensemble.

    This class serves as a foundation for various Actor-Critic algorithms.

    Args:
        config (MainConfig): Configuration object containing model hyperparameters.
        critic_type (type[CriticEnsemble]): Type of critic to use.
        actor_type (type[Actor]): Type of actor to use.

    Attributes:
        critic (CriticEnsemble): Critic network ensemble instance.
        actor (Actor): Actor network instance.
        policy_delay (int): Number of critic updates per actor update.
        n_iter (int): Iteration counter for training steps.
    """

    _agent_name = "AC"

    def __init__(
        self,
        config: "MainConfig",
        critic_type: type[CriticEnsemble],
        actor_type: type[Actor],
    ) -> None:
        """
        Build the actor and critic networks and initialise the update counters.

        Args:
            config (MainConfig): Configuration dataclass instance. It is passed to
                the base ``Agent`` and to both network constructors;
                ``config.model.policy_delay`` is read here.
            critic_type (type[CriticEnsemble]): Critic class to instantiate.
            actor_type (type[Actor]): Actor class to instantiate.

        Returns:
            None
        """
        # The base class must run first: self.dim_state / self.dim_act are not
        # assigned in this class, so they are expected to come from Agent.
        super().__init__(config)

        self.critic = critic_type(config, self.dim_state, self.dim_act)
        self.actor = actor_type(config, self.dim_state, self.dim_act)

        # The actor is updated on iterations where n_iter % policy_delay == 0.
        self.policy_delay: int = config.model.policy_delay
        self.n_iter: int = 0

    def learn(self, max_iter: int = 1, n_epochs: int = 0) -> None:
        """
        Run critic (and periodically actor) updates on batches from memory.

        Does nothing when the experience memory holds fewer transitions than
        one batch.

        Args:
            max_iter (int): Maximum number of iterations for learning. Forwarded
                to ``experience_memory.get_steps_and_iterator``.
            n_epochs (int): Number of epochs for training. If 0, random sampling
                is used (the decision is made by the experience memory).

        Returns:
            None
        """
        batch_size = self.config_train.batch_size

        # Not enough stored transitions to draw a full batch yet.
        if batch_size > len(self.experience_memory):
            return

        # Ask the memory how many steps to run; this call also prepares the
        # internal iterator consumed by get_next_batch below.
        n_steps = self.experience_memory.get_steps_and_iterator(
            n_epochs, max_iter, batch_size
        )

        for _ in range(n_steps):
            batch = self.experience_memory.get_next_batch(batch_size)

            # Critic step: compute the Bellman target, then fit the critic to it.
            bellman_target = self.critic.get_bellman_target(
                batch["reward"], batch["next_state"], batch["terminated"], self.actor
            )
            self.critic.update(batch["state"], batch["action"], bellman_target)

            # Actor step, only on every policy_delay-th iteration.
            # NOTE(review): the actor-target update is placed inside this
            # branch (following the delayed schedule) -- confirm against the
            # upstream file.
            if self.n_iter % self.policy_delay == 0:
                self.actor.update(batch["state"], self.critic)
                if self.actor.has_target:
                    self.actor.update_target()

            # The critic target, when present, is refreshed on every iteration.
            if self.critic.has_target:
                self.critic.update_target()

            self.n_iter += 1

    @torch.no_grad()
    def select_action(
        self, state: torch.Tensor, is_training: bool = True
    ) -> torch.Tensor:
        """
        Select an action for the given state without tracking gradients.

        Args:
            state (torch.Tensor): The current state.
            is_training (bool): Whether the agent is in training mode; forwarded
                to ``self.actor.act``.

        Returns:
            torch.Tensor: The value returned by ``self.actor.act``.
        """
        # NOTE(review): the local name suggests Actor.act may return a dict
        # rather than a bare tensor -- confirm and adjust the annotation if so.
        act_dict = self.actor.act(state, is_training=is_training)
        return act_dict

    def reset(self) -> None:
        """
        Reset the actor and the critic, each only if its ``_reset`` flag is set.

        Args:
            None

        Returns:
            None
        """
        if self.actor._reset:
            self.actor.reset()
        if self.critic._reset:
            self.critic.reset()