Agents

Overview

The agents/ module defines the Agent base class, which provides a common interface and core utilities for reinforcement learning agents.

This abstract base class:

  • Manages experience replay

  • Provides model saving and loading functionality

  • Interfaces with environment and logger

  • Requires implementation of core methods: reset, learn, and select_action

Important

All reinforcement learning agents in this repository should inherit from Agent to ensure consistency and reusability of code.

Note

The Agent class is designed to be environment-agnostic. It does not assume a specific observation or action space format. You are expected to handle environment-specific preprocessing in your agent subclasses.
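As an illustration of the kind of preprocessing meant here, the hypothetical helper below flattens a dictionary observation into a single list of floats. A subclass could call something like it inside select_action before building its network input. Sorting the keys is an assumption, chosen so each feature always lands at the same index; ObjectRL does not prescribe this helper.

```python
from collections.abc import Mapping, Sequence


def flatten_observation(obs: Mapping[str, Sequence[float]]) -> list[float]:
    """Hypothetical helper: concatenate dict-valued observations in sorted key order."""
    flat: list[float] = []
    for key in sorted(obs):
        # A fixed key order keeps each feature at the same index across steps.
        flat.extend(float(x) for x in obs[key])
    return flat


print(flatten_observation({"velocity": [0.5], "position": [1, 2]}))  # [1.0, 2.0, 0.5]
```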

Base Class: Agent

class objectrl.agents.base_agent.Agent(config: MainConfig)

Bases: Module, ABC

Abstract base class for reinforcement learning agents. Child classes must implement:

  • reset: resets agent-specific state

  • learn: the core learning loop

  • select_action: the policy used to choose actions from states
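Because Agent also derives from ABC, Python refuses to instantiate a subclass that leaves any of these three methods abstract. The sketch below reproduces that rule with a hypothetical stand-in, AgentInterface, built only on the standard abc module so that it runs without PyTorch or ObjectRL; it is not the ObjectRL class itself.

```python
from abc import ABC, abstractmethod


class AgentInterface(ABC):
    """Hypothetical stand-in mirroring the three abstract hooks of Agent."""

    @abstractmethod
    def reset(self, *args, **kwargs) -> None: ...

    @abstractmethod
    def learn(self, *args, **kwargs) -> None: ...

    @abstractmethod
    def select_action(self, *args, **kwargs): ...


class IncompleteAgent(AgentInterface):
    # Implements only reset(); learn() and select_action() are still abstract.
    def reset(self, *args, **kwargs) -> None:
        pass


class CompleteAgent(IncompleteAgent):
    def learn(self, *args, **kwargs) -> None:
        pass

    def select_action(self, *args, **kwargs):
        return 0


def can_instantiate(cls) -> bool:
    """Return True if cls() succeeds, False if ABC blocks instantiation."""
    try:
        cls()
    except TypeError:
        return False
    return True


print(can_instantiate(IncompleteAgent))  # False
print(can_instantiate(CompleteAgent))    # True
```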

config

Main configuration object with all submodules.

Type:

MainConfig

device

Device on which computations are performed.

Type:

torch.device

config_env

Environment-specific config.

config_train

Training-specific config.

dim_state

Dimension of the observation space.

Type:

int

dim_act

Dimension of the action space.

Type:

int

_gamma

Discount factor for future rewards.

Type:

float
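As a reminder of how a discount factor typically enters the objective, the discounted return from step t is G_t = Σ_k γ^k · r_{t+k}, which can be computed backwards with G_t = r_t + γ · G_{t+1}. The function name discounted_return is illustrative and not part of ObjectRL.

```python
def discounted_return(rewards: list[float], gamma: float) -> float:
    """Compute sum_k gamma**k * rewards[k] by accumulating from the last reward."""
    g = 0.0
    for r in reversed(rewards):
        # G_t = r_t + gamma * G_{t+1}
        g = r + gamma * g
    return g


print(discounted_return([1.0, 1.0, 1.0], gamma=0.5))  # 1.75
```

A γ close to 1 weights distant rewards almost as much as immediate ones; γ = 0 keeps only the first reward.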

_tau

Polyak averaging coefficient for target updates.

Type:

float
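With Polyak averaging, each target parameter moves a fraction τ toward its online counterpart: θ_target ← τ · θ + (1 − τ) · θ_target. This sketch assumes that common convention (a small τ such as 0.005 means slow tracking); the exact update rule ObjectRL applies is not shown in this section. It operates on plain lists of floats so that it runs without PyTorch; polyak_update is a hypothetical name.

```python
def polyak_update(target: list[float], online: list[float], tau: float) -> list[float]:
    """Return tau * online + (1 - tau) * target, element-wise."""
    if len(target) != len(online):
        raise ValueError("target and online must have the same length")
    return [tau * o + (1.0 - tau) * t for t, o in zip(target, online)]


# With tau = 0.5 the target moves halfway toward the online weights.
print(polyak_update([0.0, 2.0], [1.0, 4.0], tau=0.5))  # [0.5, 3.0]
```

τ = 0 leaves the target unchanged and τ = 1 copies the online weights outright (a hard update).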

experience_memory

Experience replay buffer.

logger

Logger used to track training and evaluation.

Type:

Logger

__init__(config: MainConfig) → None

Initializes the base agent class.

Parameters:

config (MainConfig) – The experiment’s configuration object.

Returns:

None

generate_transition(**kwargs)

Constructs a transition dictionary for storing in the experience memory.

The transition includes the current state, action taken, reward received, next state, and episode termination information. All tensors are moved to the appropriate storage device, and scalar values are converted to floats.

Expected kwargs:

  • state (torch.Tensor) – The current state.

  • action (torch.Tensor) – The action taken.

  • reward (float or scalar Tensor) – The reward received after taking the action.

  • next_state (torch.Tensor) – The next state after the transition.

  • terminated (bool or int) – Whether the episode terminated.

  • truncated (bool or int) – Whether the episode was truncated.

  • step (int) – The environment step count or index.

Returns:

A dictionary representing a single transition, ready to be stored.

Return type:

TensorDict
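A minimal sketch of the behavior described above, assuming a plain dict in place of TensorDict and skipping the move to the storage device so that it runs without PyTorch or tensordict. The name make_transition, the use of lists for states and actions, storing the termination flags as floats, and keeping step as an int are illustrative assumptions, not ObjectRL's implementation.

```python
from typing import Any

REQUIRED_KEYS = ("state", "action", "reward", "next_state", "terminated", "truncated", "step")


def make_transition(**kwargs: Any) -> dict[str, Any]:
    """Hypothetical stand-in for Agent.generate_transition using a plain dict."""
    missing = [k for k in REQUIRED_KEYS if k not in kwargs]
    if missing:
        raise KeyError(f"missing transition fields: {missing}")
    return {
        "state": list(kwargs["state"]),
        "action": list(kwargs["action"]),
        # Scalar fields (reward and termination flags) are converted to floats.
        "reward": float(kwargs["reward"]),
        "next_state": list(kwargs["next_state"]),
        "terminated": float(kwargs["terminated"]),
        "truncated": float(kwargs["truncated"]),
        "step": int(kwargs["step"]),
    }


t = make_transition(state=[0.1, 0.2], action=[1.0], reward=1, next_state=[0.3, 0.4],
                    terminated=False, truncated=True, step=7)
print(t["reward"], t["terminated"], t["truncated"])  # 1.0 0.0 1.0
```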

store_transition(transition: tuple[Any, ...]) → None

Stores a transition tuple into the experience replay buffer.

Parameters:

transition (tuple) – A transition (s, a, r, s’, done) to be stored.

Returns:

None
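To make the idea of an experience replay buffer concrete, here is a minimal fixed-capacity FIFO buffer that store_transition could forward (s, a, r, s', done) tuples to. The class name SimpleReplayBuffer, its capacity parameter, and uniform sampling are assumptions for illustration; ObjectRL's experience_memory may use a different storage format and sampling strategy.

```python
import random
from collections import deque
from typing import Any, Optional


class SimpleReplayBuffer:
    """Hypothetical fixed-capacity FIFO buffer with uniform sampling."""

    def __init__(self, capacity: int) -> None:
        # deque(maxlen=...) silently drops the oldest item once full.
        self._data: deque = deque(maxlen=capacity)

    def add(self, transition: Any) -> None:
        self._data.append(transition)

    def sample(self, batch_size: int, rng: Optional[random.Random] = None) -> list:
        """Draw batch_size stored transitions uniformly, without replacement."""
        if rng is None:
            rng = random.Random()
        return rng.sample(list(self._data), batch_size)

    def __len__(self) -> int:
        return len(self._data)


buffer = SimpleReplayBuffer(capacity=3)
for i in range(5):
    buffer.add((i, 0, 0.0, i + 1, False))  # (s, a, r, s', done)

print(len(buffer))  # 3  (transitions with s=0 and s=1 were evicted)
batch = buffer.sample(2, rng=random.Random(0))
```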

save() → None

Saves the agent's current state, including its model parameters, to disk at the logger's checkpoint path.

Parameters:

None

Returns:

None

load(path: str | Path) → None

Loads model weights from a given checkpoint path.

Parameters:

path (str or Path) – Path to the saved model checkpoint.

Returns:

None
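The save/load round trip can be sketched without PyTorch by pickling a parameter dictionary. In a torch.nn.Module subclass such as Agent this is usually done with torch.save(self.state_dict(), path) and self.load_state_dict(torch.load(path)); whether ObjectRL does exactly that is not stated here. The class CheckpointableAgent, the checkpoint_dir argument standing in for the logger's checkpoint path, the file name model.pt, and the use of pickle are all assumptions made so the sketch is self-contained.

```python
import pickle
import tempfile
from pathlib import Path
from typing import Union


class CheckpointableAgent:
    """Hypothetical stand-in: parameters live in a plain dict instead of an nn.Module."""

    def __init__(self, checkpoint_dir: Path) -> None:
        self.checkpoint_dir = checkpoint_dir  # plays the role of the logger's path
        self.params = {"w": [0.0, 0.0], "b": 0.0}

    def save(self) -> None:
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        with open(self.checkpoint_dir / "model.pt", "wb") as f:
            pickle.dump(self.params, f)

    def load(self, path: Union[str, Path]) -> None:
        # Only unpickle files you trust: pickle can execute arbitrary code.
        with open(Path(path), "rb") as f:
            self.params = pickle.load(f)


with tempfile.TemporaryDirectory() as tmp:
    agent = CheckpointableAgent(Path(tmp))
    agent.params["w"] = [1.5, -2.0]
    agent.save()

    restored = CheckpointableAgent(Path(tmp))
    restored.load(Path(tmp) / "model.pt")
    print(restored.params["w"])  # [1.5, -2.0]
```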

requires_discrete_actions() → bool

abstractmethod reset(*args, **kwargs) → None

Resets any internal agent state. Must be implemented in a subclass.

Parameters:
  • *args – Variable length argument list.

  • **kwargs – Arbitrary keyword arguments.

Returns:

None

abstractmethod learn(*args, **kwargs) → None

Performs learning updates from the replay buffer. Must be implemented in a subclass.

Parameters:
  • *args – Variable length argument list.

  • **kwargs – Arbitrary keyword arguments.

Returns:

None

abstractmethod select_action(*args, **kwargs) → Tensor

Selects an action given the current state. Must be implemented in a subclass.

Parameters:
  • *args – Variable length argument list.

  • **kwargs – Arbitrary keyword arguments.

Returns:

The selected action tensor.

Return type:

torch.Tensor

Extending Agent

To create a new agent, inherit from Agent and implement the required methods. See the example agents in the objectrl/models/ directory for reference, and the Build Your Own Model example for a practical walkthrough of extending the Agent class.

Attention

When implementing your own agent, you must override the reset, learn, and select_action methods. These are the fundamental hooks that ObjectRL uses to interact with your agent.
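A sketch of a subclass that satisfies this contract. To run without PyTorch or ObjectRL it inherits from a hypothetical stand-in, AgentBase, and the agent itself is an epsilon-greedy bandit learner chosen purely for brevity: it learns from a single (action, reward) pair rather than sampling the replay buffer, and returns an int rather than a torch.Tensor. A real ObjectRL agent would inherit from objectrl.agents.base_agent.Agent and pass its MainConfig to the base initializer.

```python
import random
from abc import ABC, abstractmethod
from typing import Optional


class AgentBase(ABC):
    """Hypothetical stand-in for Agent exposing the three required hooks."""

    @abstractmethod
    def reset(self, *args, **kwargs) -> None: ...

    @abstractmethod
    def learn(self, *args, **kwargs) -> None: ...

    @abstractmethod
    def select_action(self, *args, **kwargs) -> int: ...


class EpsilonGreedyBandit(AgentBase):
    """Illustrative agent: incremental action-value estimates plus epsilon-greedy choice."""

    def __init__(self, n_actions: int, epsilon: float = 0.1, seed: Optional[int] = None) -> None:
        self.n_actions = n_actions
        self.epsilon = epsilon
        self.rng = random.Random(seed)
        self.reset()

    def reset(self, *args, **kwargs) -> None:
        # Clear all learned state.
        self.values = [0.0] * self.n_actions
        self.counts = [0] * self.n_actions

    def learn(self, action: int, reward: float) -> None:
        # Incremental mean: Q <- Q + (r - Q) / n
        self.counts[action] += 1
        self.values[action] += (reward - self.values[action]) / self.counts[action]

    def select_action(self, *args, **kwargs) -> int:
        if self.rng.random() < self.epsilon:
            return self.rng.randrange(self.n_actions)  # explore
        best = max(self.values)
        return self.values.index(best)  # exploit (lowest index on ties)


agent = EpsilonGreedyBandit(n_actions=3, epsilon=0.0, seed=0)
agent.learn(action=2, reward=1.0)
print(agent.select_action())  # 2
agent.reset()
print(agent.values)  # [0.0, 0.0, 0.0]
```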

Design Philosophy

This interface was designed with flexibility in mind. It separates concerns cleanly:

  • Memory management is delegated to experience_memory

  • Logging is handled via a pluggable Logger

  • Training steps are decoupled via learn() and select_action() to support both offline and online training
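To show why keeping learn() separate from select_action() supports both regimes, the sketch below drives the same hooks in an online loop (act, store, learn) and in an offline loop (learn only, from previously collected data). ToyEnv, RandomAgent, run_online, run_offline, and the warmup parameter are hypothetical names for illustration; ObjectRL's own trainer is not shown in this section and may be organized differently.

```python
import random
from typing import Any


class ToyEnv:
    """Hypothetical 1-D environment: each episode is truncated after `horizon` steps."""

    def __init__(self, horizon: int = 5) -> None:
        self.horizon = horizon
        self.t = 0

    def reset(self) -> float:
        self.t = 0
        return 0.0

    def step(self, action: int) -> tuple[float, float, bool, bool]:
        self.t += 1
        reward = 1.0 if action == 1 else 0.0
        return float(self.t), reward, False, self.t >= self.horizon


class RandomAgent:
    """Hypothetical agent exposing select_action, store_transition and learn."""

    def __init__(self, seed: int = 0) -> None:
        self.rng = random.Random(seed)
        self.experience_memory: list[tuple[Any, ...]] = []
        self.updates = 0

    def select_action(self, state: float) -> int:
        return self.rng.randrange(2)  # ignores the state

    def store_transition(self, transition: tuple[Any, ...]) -> None:
        self.experience_memory.append(transition)

    def learn(self) -> None:
        # A real agent would sample a batch from memory and update its networks here.
        self.updates += 1


def run_online(agent: RandomAgent, env: ToyEnv, n_steps: int, warmup: int = 2) -> None:
    """Interleave acting and learning; learning starts once `warmup` transitions exist."""
    state = env.reset()
    for _ in range(n_steps):
        action = agent.select_action(state)
        next_state, reward, terminated, truncated = env.step(action)
        agent.store_transition((state, action, reward, next_state, terminated or truncated))
        if len(agent.experience_memory) >= warmup:
            agent.learn()
        state = env.reset() if (terminated or truncated) else next_state


def run_offline(agent: RandomAgent, n_updates: int) -> None:
    """Offline: only learn() is called; select_action() is never used."""
    for _ in range(n_updates):
        agent.learn()


agent = RandomAgent()
run_online(agent, ToyEnv(horizon=5), n_steps=10)
print(len(agent.experience_memory), agent.updates)  # 10 9

offline = RandomAgent()
offline.experience_memory.extend(agent.experience_memory)  # e.g. a logged dataset
run_offline(offline, n_updates=3)
print(offline.updates)  # 3
```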

Tip

If you’re building a new algorithm, start by extending Agent and reusing existing components like actors, critics, and loss functions from the models/ module. This will drastically reduce boilerplate.