# -----------------------------------------------------------------------------------
# ObjectRL: An Object-Oriented Reinforcement Learning Codebase
# Copyright (C) 2025 ADIN Lab
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
# -----------------------------------------------------------------------------------
import copy
import typing
from abc import ABC, abstractmethod
from typing import Any
import torch
from torch import nn as nn
from objectrl.models.basic.ensemble import Ensemble
from objectrl.utils.net_utils import create_loss, create_optimizer
if typing.TYPE_CHECKING:
from objectrl.config.config import MainConfig
[docs]
class Critic(nn.Module):
"""
A critic network that estimates Q-values for state-action pairs.
Attributes:
device (torch.device): Device for computations.
has_target (bool): Flag for presence of target network.
_tau (float): Polyak averaging factor for target update.
_gamma (float): Discount factor for rewards.
model (nn.Module): Main critic network.
target (nn.Module, optional): Target critic network.
"""
model: nn.Module
target: nn.Module | None
[docs]
def __init__(self, config: "MainConfig", dim_state: int, dim_act: int) -> None:
"""
Initialize the critic network.
Args:
config: Configuration object containing model parameters
dim_state: Dimension of the state space
dim_act: Dimension of the action space
Returns:
None
"""
super().__init__()
critic = config.model.critic
self.device = config.system.device
self.has_target = critic.has_target
self._tau = config.model.tau
self._gamma = config.training.gamma
self.model = critic.arch(
dim_state,
dim_act,
depth=critic.depth,
width=critic.width,
act=critic.activation,
has_norm=critic.norm,
).to(self.device)
if self.has_target:
self.target = critic.arch(
dim_state,
dim_act,
depth=critic.depth,
width=critic.width,
act=critic.activation,
has_norm=critic.norm,
).to(self.device)
self.init_target()
[docs]
def reduce(self, q_val: torch.Tensor) -> torch.Tensor:
"""
Reduce Q-values if needed.
Args:
q_val (torch.Tensor): Q-values tensor
Returns:
torch.Tensor: Reduced Q-values.
"""
return q_val
[docs]
def Q(
self, state: torch.Tensor, action: torch.Tensor | None = None
) -> torch.Tensor:
"""
Compute Q-values for given state-action pairs.
Args:
state (torch.Tensor): State tensor.
action (torch.Tensor): Action tensor.
Returns:
torch.Tensor: Q-values for the state-action pairs.
"""
return self.model(self._prepare_input(state, action))
[docs]
def init_target(self) -> None:
"""
Initialize target network with weights from the main network.
Args:
None
Returns:
None
"""
assert self.has_target, "There is no target network to initialize"
for target_param, local_param in zip(
self.target.parameters(), self.model.parameters(), strict=True
):
target_param.data.copy_(local_param.data)
[docs]
def update_target(self) -> None:
"""
Update target network parameters using soft update.
Args:
None
Returns:
None
"""
assert self.has_target, "There is no target network to update"
for target_param, local_param in zip(
self.target.parameters(), self.model.parameters(), strict=True
):
# Combine x = (1 - tau) * x + tau * y into a single inplace operation
target_param.data.lerp_(local_param.data, self._tau)
[docs]
def Q_t(
self, state: torch.Tensor, action: torch.Tensor | None = None
) -> torch.Tensor:
"""
Compute Q-values using the target network.
Args:
state (torch.Tensor): State tensor.
action (torch.Tensor): Action tensor.
Returns:
torch.Tensor: Target Q-values.
"""
assert self.has_target, "There is no target network to evaluate"
return self.target(self._prepare_input(state, action))
def __getitem__(self) -> "Critic":
"""
Return self when indexed.
Args:
None
Returns:
Critic: Self instance.
"""
return self
[docs]
class CriticEnsemble(nn.Module, ABC):
"""
Ensemble of critic networks for robust Q-value estimation.
Attributes:
n_members (int): Number of critics in the ensemble.
config (MainConfig): Configuration object.
dim_state (int): Dimension of the state space.
dim_act (int): Dimension of the action space.
has_target (bool): Flag for target networks.
_reset (bool): Reset flag from config.
device (torch.device): Device for computations.
loss (callable): Loss function.
_tau (float): Polyak averaging factor.
_gamma (float): Discount factor.
model_ensemble (Ensemble[Critic]): Ensemble of critic models.
optim (torch.optim.Optimizer): Optimizer for ensemble parameters.
target_ensemble (Ensemble[Critic], optional): Target ensemble.
iter (int): Training iteration counter.
"""
[docs]
def __init__(self, config: "MainConfig", dim_state: int, dim_act: int) -> None:
"""
Initialize the critic ensemble.
Args:
config (MainConfig): Configuration object with model parameters.
dim_state (int): Dimension of the state space.
dim_act (int): Dimension of the action space.
Returns:
None
"""
super().__init__()
self.n_members = config.model.critic.n_members
self.config = config
self.dim_state = dim_state
self.dim_act = dim_act
self.has_target = config.model.critic.has_target
self._reset = config.model.critic.reset
self.device = config.system.device
self.loss = create_loss(config.model, reduction="none")
self.dim_state = dim_state
self.dim_act = dim_act
self._tau = config.model.tau
self._gamma = config.training.gamma
self.model_ensemble = Ensemble[nn.Module](
n_members=int(self.n_members),
models=[
Critic(config, dim_state, dim_act).model for _ in range(self.n_members)
],
device=self.device,
)
self.optim = create_optimizer(self.config.training)(
self.model_ensemble.parameters()
)
if self.has_target:
self.target_ensemble = Ensemble[nn.Module](
n_members=int(self.n_members),
models=[
Critic(config, dim_state, dim_act).target
for _ in range(self.n_members)
],
device=self.device,
)
self.target_ensemble.load_state_dict(self.model_ensemble.state_dict())
self.iter = 0
[docs]
def reset(self) -> None:
"""
Reset the ensemble models and optimizer.
Args:
None
Returns:
None
"""
self.model_ensemble = Ensemble[nn.Module](
n_members=int(self.n_members),
models=[
Critic(self.config, self.dim_state, self.dim_act).model
for _ in range(self.n_members)
],
device=self.device,
)
if self.has_target:
self.target_ensemble = copy.deepcopy(self.model_ensemble)
self.optim = create_optimizer(self.config.training)(
self.model_ensemble.parameters()
)
[docs]
def reduce(self, q_val: torch.Tensor, reduce_type: str) -> torch.Tensor:
"""
Reduce Q-values from multiple critics according to the configured method.
Currently supports 'min' or 'mean'. User should add more methods if needed.
Args:
q_val (torch.Tensor): Q-values tensor from all critics.
reduce_type (str): How to reduce the Q-values.
Returns:
torch.Tensor: Reduced Q-values.
"""
if reduce_type == "min":
return q_val.min(0).values
elif reduce_type == "mean":
return q_val.mean(0)
else:
raise ValueError(
f"Unknown reduction method {self.config.model.critic.reduce}. Implement it if needed."
)
[docs]
def _get_single_critic(self, index: int = 0) -> Critic:
"""
Get a single Critic instance from the ensemble.
Args:
index (int): Index of the critic to retrieve.
Returns:
Critic: Critic instance.
"""
single_critic = Critic(self.config, self.dim_state, self.dim_act)
single_critic.model.load_state_dict(self.model_ensemble[index].state_dict())
if self.has_target:
single_critic.target.load_state_dict( # type: ignore // we know that there is a state_dict
self.target_ensemble[index].state_dict()
)
return single_critic
def __getitem__(self, index: int) -> Critic:
"""
Indexing to access a single critic from the ensemble.
Args:
index (int): Index of the critic.
Returns:
Critic: Critic instance.
"""
return self._get_single_critic(index)
[docs]
def Q(
self, state: torch.Tensor, action: torch.Tensor | None = None
) -> torch.Tensor:
"""
Compute Q-values for given state-action pairs using all critics.
Args:
state (torch.Tensor): State tensor.
action (torch.Tensor): Action tensor.
Returns:
torch.Tensor: Q-values from all critics.
"""
if action is None:
sa = state
else:
sa = torch.cat((state, action), -1)
return self.model_ensemble(sa)
[docs]
@torch.no_grad()
def Q_t(
self, state: torch.Tensor, action: torch.Tensor | None = None
) -> torch.Tensor:
"""
Compute Q-values using the target networks.
Args:
state (torch.Tensor): State tensor.
action (torch.Tensor): Action tensor.
Returns:
torch.Tensor: Target Q-values from all critics.
"""
assert self.has_target, "There is no target network to evaluate"
if action is None:
sa = state
else:
sa = torch.cat((state, action), -1)
return self.target_ensemble(sa)
[docs]
def update(
self, state: torch.Tensor, action: torch.Tensor, y: torch.Tensor
) -> None:
"""
Update critic networks using the provided Bellman targets.
Args:
state (torch.Tensor): State tensor.
action (torch.Tensor): Action tensor.
y (torch.Tensor): Bellman target values.
Returns:
None
"""
self.optim.zero_grad()
loss = self.loss(self.Q(state, action), self.model_ensemble.expand(y))
# Sum over the ensemble members and average over the batches
loss = loss.sum(0).mean() if self.n_members > 1 else loss.mean()
loss.backward()
self.optim.step()
self.iter += 1
[docs]
@torch.no_grad()
def update_target(self) -> None:
"""
Update target network parameters using soft update.
Args:
None
Returns:
None
"""
assert self.has_target, "There is no target network to update"
for key in self.model_ensemble.params.keys():
# Combine x = (1 - tau) * x + tau * y into a single inplace operation
self.target_ensemble.params[key].data.lerp_(
self.model_ensemble.params[key].data, self._tau
)
[docs]
@abstractmethod
@torch.no_grad()
def get_bellman_target(self, *args: Any, **kwargs: Any) -> torch.Tensor:
"""
Calculate the Bellman target for training the critic.
Args:
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.
Returns:
torch.Tensor: Bellman target tensor.
"""
pass