# -----------------------------------------------------------------------------------
# ObjectRL: An Object-Oriented Reinforcement Learning Codebase
# Copyright (C) 2025 ADIN Lab
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
# -----------------------------------------------------------------------------------
import typing
from abc import ABC, abstractmethod

import torch
from torch import nn as nn

from objectrl.models.basic.critic import CriticEnsemble
from objectrl.utils.net_utils import create_optimizer

if typing.TYPE_CHECKING:
    # Only needed for type annotations, so it is skipped at runtime.
    from objectrl.config.config import MainConfig
[docs]
class Actor(nn.Module, ABC):
"""
Abstract base class for Actor network in Actor-Critic algorithms.
Handles policy network, optional target network, and optimization.
Attributes:
config (MainConfig): Configuration object.
device (torch.device): Device for tensor computations.
verbose (bool): Verbosity flag.
has_target (bool): Flag for using a target network.
iter (int): Training iteration counter.
dim_state (int): Observation space shape.
dim_act (int): Action space shape.
_tau (float): Polyak averaging coefficient for target updates.
_gamma (float): Discount factor for returns.
_reset (bool): Flag whether to reset model at initialization.
model (nn.Module): Main actor network.
target (nn.Module, optional): Target actor network.
optim (torch.optim.Optimizer): Optimizer for the actor parameters.
"""
[docs]
def __init__(self, config: "MainConfig", dim_state: int, dim_act: int) -> None:
"""
Initializes the Actor.
Args:
config (MainConfig): Configuration dataclass instance.
dim_state (int): Dimension of observation space.
dim_act (int): Dimension of action space.
Returns:
None
"""
super().__init__()
self.config = config
self.device = config.system.device
self.verbose = config.verbose
self.has_target = config.model.actor.has_target
self.iter = 0
self.dim_state, self.dim_act = dim_state, dim_act
self._tau = config.model.tau
self._gamma = config.training.gamma
self._reset = config.model.actor.reset
self.reset()
[docs]
def reset(self) -> None:
"""
Initializes or resets the main and target policy networks and optimizer.
Also sets the model architecture based on the configuration.
Args:
None
Returns:
None
"""
self.model = self.config.model.actor.arch(
self.dim_state,
self.dim_act,
depth=self.config.model.actor.depth,
width=self.config.model.actor.width,
act=self.config.model.actor.activation,
has_norm=self.config.model.actor.norm,
n_heads=self.config.model.actor.n_heads,
).to(self.device)
self.optim = create_optimizer(self.config.training)(self.model.parameters())
# Initialize target network if required
if self.has_target:
self.target = self.config.model.actor.arch(
self.dim_state,
self.dim_act,
depth=self.config.model.actor.depth,
width=self.config.model.actor.width,
act=self.config.model.actor.activation,
has_norm=self.config.model.actor.norm,
n_heads=self.config.model.actor.n_heads,
).to(self.device)
self.init_target()
[docs]
def init_target(self) -> None:
"""
Copies the main model parameters to the target network.
Args:
None
Returns:
None
"""
assert self.has_target, "There is no target network to initialize"
for target_param, local_param in zip(
self.target.parameters(), self.model.parameters(), strict=True
):
target_param.data.copy_(local_param.data)
[docs]
def act(self, state: torch.Tensor, is_training: bool = True) -> dict:
"""
Computes actions given input states.
Args:
state (torch.Tensor): Input state tensor.
is_training (bool): Whether in training mode.
Returns:
dict: Dictionary containing action tensor and optionally log probabilities.
"""
return_dict = self.model(state, is_training=is_training)
return return_dict
[docs]
def act_target(self, state: torch.Tensor) -> dict:
"""
Computes actions using the target policy network.
Args:
state (torch.Tensor): Input state tensor.
Returns:
dict: Dictionary containing action tensor and log probabilities.
"""
assert self.has_target, "There is no target network to evaluate"
return self.target(state)
[docs]
@torch.no_grad()
def update_target(self) -> None:
"""
Performs a soft update of the target network using Polyak averaging.
Args:
None
Returns:
None
"""
if not self.has_target:
return None
for target_param, local_param in zip(
self.target.parameters(), self.model.parameters(), strict=True
):
# Combine x = (1 - tau) * x + tau * y into a single inplace operation
target_param.data.lerp_(local_param.data, self._tau)
[docs]
@abstractmethod
def loss(self, *args, **kwargs) -> torch.Tensor:
"""
Abstract method to compute the loss for the actor.
Should be overridden in subclasses.
Args:
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.
Returns:
torch.Tensor: Computed loss tensor.
"""
pass
[docs]
def update(self, state: torch.Tensor, critics: CriticEnsemble) -> None:
"""
Performs a gradient update on the actor network.
Args:
state (Tensor): Input state batch.
critics (object): Critic networks for computing Q-values.
Returns:
None
"""
self.optim.zero_grad()
loss = self.loss(state, critics)
loss.backward()
if self.config.model.actor.max_grad_norm > 0:
nn.utils.clip_grad_norm_(
self.parameters(),
self.config.model.actor.max_grad_norm,
)
self.optim.step()
self.iter += 1 # Increment iteration counter