Source code for objectrl.replay_buffers.experience_memory

# -----------------------------------------------------------------------------------
# ObjectRL: An Object-Oriented Reinforcement Learning Codebase
# Copyright (C) 2025 ADIN Lab

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
# -----------------------------------------------------------------------------------
import gc
import warnings
from collections.abc import Iterator

import torch
from tensordict import TensorDict
from torchrl.data import (
    LazyMemmapStorage,
    LazyTensorStorage,
    TensorDictReplayBuffer,
)


[docs] class ReplayBuffer: """ A fixed-size replay buffer to store and sample experience tuples for reinforcement learning. This buffer uses either lazy memory mapping on the cpu or a lazy tensor on the gpu and TorchRL's TensorDictReplayBuffer for efficient storage and retrieval. """
[docs] def __init__( self, device: torch.device, storing_device: torch.device, buffer_size: int, print_gc_warning: bool = False, ) -> None: """ Initialize the ReplayBuffer. Args: device (torch.device): The device to move sampled batches to (usually the training device). storing_device (torch.device): The device to store the experience buffer (e.g., CPU or gpu). Unless memory is a concern, use `cuda` to improve runtime. buffer_size (int): Maximum number of experience tuples the buffer can hold. print_gc_warning (bool): Print a warning that the garbage collector is run Returns: None """ self.device = device self.storing_device = storing_device self.buffer_size = buffer_size self.print_gc_warning = print_gc_warning self.reset()
[docs] def _get_storage( self, buffer_size: int, device: torch.device ) -> LazyMemmapStorage | LazyTensorStorage: # Use LazyMemmap on CPU and LazyTensorStorage on GPU if device.type == "cpu": return LazyMemmapStorage(buffer_size, device=device) elif device.type == "cuda": return LazyTensorStorage(buffer_size, device=device) else: raise NotImplementedError( f"No storage support for device type: {device.type}" )
[docs] def reset(self, buffer_size: int | None = None) -> None: """ Reset the buffer, optionally resizing it. Args: buffer_size (int, optional): If provided, sets a new maximum buffer size. Returns: None """ if buffer_size is not None: self.buffer_size = buffer_size self.data_size = 0 self.pointer = 0 storage = self._get_storage(self.buffer_size, self.storing_device) # Initialize TensorDictReplayBuffer with the storage self.memory = TensorDictReplayBuffer(storage=storage)
[docs] def add(self, experience: TensorDict) -> None: """ Add a single experience to the buffer. Args: experience (TensorDict): A single experience entry, usually a dictionary of tensors. Returns: None """ try: self.memory.add(experience) except OSError: warnings.warn( "Failed to add experience to replay buffer, triggering a manual garbage collection", stacklevel=2, ) gc.collect() self.memory.add(experience) self.data_size = min(self.data_size + 1, self.buffer_size) self.pointer = (self.pointer + 1) % self.buffer_size
[docs] def add_batch(self, batch: TensorDict) -> None: """ Add a batch of experiences to memory. Args: batch: A TensorDict containing the batch of experiences to be added. """ try: self.memory.extend(batch) except OSError: warnings.warn( "Failed to add experience to replay buffer, triggering a manual garbage collection", stacklevel=2, ) gc.collect() self.memory.extend(batch) self.data_size = min(self.data_size + len(batch), self.buffer_size) self.pointer = (self.pointer + len(batch)) % self.buffer_size
[docs] def sample_batch(self, batch_size: int) -> TensorDict: """ Randomly sample a batch of experiences from the buffer. Args: batch_size (int): The number of samples to draw. Returns: TensorDict: A batch of randomly sampled experiences moved to the working device. """ batch = self.memory.sample(batch_size).to(self.device).clone() return batch
[docs] def sample_random(self, batch_size: int) -> TensorDict: """ Alias for `sample_batch`. Args: batch_size (int): The number of samples to draw. Returns: TensorDict: A batch of randomly sampled experiences. """ # Same as sample_batch for the new implementation return self.sample_batch(batch_size)
[docs] def sample_by_index(self, indices: list | torch.Tensor | range) -> TensorDict: """ Sample specific experiences by index. Args: indices (Union[list, torch.Tensor, range]): Indices of experiences to sample. Returns: TensorDict: A batch of selected experiences moved to the working device. """ if isinstance(indices, range): indices = torch.tensor(list(indices), device=self.storing_device) elif isinstance(indices, list): indices = torch.tensor(indices, device=self.storing_device) return self.memory.storage[indices].to(self.device).clone()
[docs] def sample_by_index_fields( self, indices: list | torch.Tensor | range, fields: list ) -> TensorDict: """ Sample specific fields of selected experiences by index. Args: indices (Union[list, torch.Tensor, range]): Indices of experiences to sample. fields (List[str]): Names of fields to retrieve (e.g., ['obs', 'action']). Returns: Union[torch.Tensor, Tuple[torch.Tensor]]: If one field, returns a tensor. If multiple fields, returns a tuple of tensors. """ if isinstance(indices, range): indices = torch.tensor(list(indices), device=self.storing_device) elif isinstance(indices, list): indices = torch.tensor(indices, device=self.storing_device) return self.memory.storage[indices].select(fields).to(self.device).clone()
[docs] def sample_all(self) -> TensorDict: """ Sample all experiences currently stored in the buffer. Returns: TensorDict: All stored experiences up to `data_size`. """ return self.sample_by_index(range(self.data_size))
def __len__(self) -> int: """ Get the number of experiences currently stored in the buffer. Args: None Returns: int: The current number of stored experiences. """ return self.data_size @property def size(self) -> int: """ Property alias for the current number of stored experiences. Args: None Returns: int: Current buffer size (number of items stored). """ return self.data_size
[docs] def save(self, path: str) -> None: """ Save memory metadata to a file. Args: path (str): Path to save the metadata file. Returns: None """ metadata = { "buffer_size": self.buffer_size, "data_size": self.data_size, "pointer": self.pointer, } torch.save(metadata, path + ".metadata")
[docs] def load(self, path: str) -> None: """ Load buffer metadata from disk (does not restore experience data). Args: path (str): The file path (excluding file extension) to load metadata from. Returns: None """ metadata = torch.load(path + ".metadata") self.buffer_size = metadata["buffer_size"] self.data_size = metadata["data_size"] self.pointer = metadata["pointer"]
[docs] def create_epoch_iterator(self, batch_size: int, n_epochs: int = 1) -> Iterator: """ Create an iterator that yields batches from the buffer for multiple epochs. Args: batch_size (int): Number of samples per batch. n_epochs (int): Number of epochs to iterate through the data. Returns: Iterator: An iterator that yields batches of experience. """ total_samples = self.data_size def batch_generator(): for _ in range(n_epochs): # Create indices for the entire dataset indices = torch.arange(total_samples, device=self.storing_device) # Yield batches for i in range(0, total_samples, batch_size): batch_indices = indices[i : min(i + batch_size, total_samples)] yield self.sample_by_index(batch_indices) self.epoch_iterator = batch_generator() return self.epoch_iterator
[docs] def get_next_batch(self, batch_size: int) -> TensorDict: """ Retrieve the next batch from the current epoch iterator. Falls back to random sampling if the iterator is not initialized. Args: batch_size (int): Batch size to use for fallback random sampling. Returns: TensorDict: A batch of experience data. """ if self.epoch_iterator is not None: return next(self.epoch_iterator) return self.sample_batch(batch_size)
[docs] def calculate_num_batches(self, batch_size: int) -> int: """ Calculate how many full batches can be drawn from the current buffer content. Args: batch_size (int): Number of samples per batch. Returns: int: Total number of batches possible with current data size. """ return (self.data_size + batch_size - 1) // batch_size
[docs] def get_steps_and_iterator( self, n_epochs: int, max_iter: int, batch_size: int ) -> int: """ Compute total training steps and initialize an internal batch iterator. Args: n_epochs (int): Number of training epochs. If > 0, iterator will be used. max_iter (int): Number of learning updates to perform (used if n_epochs = 0). batch_size (int): Number of samples per training step. Returns: int: Total number of training steps. """ if n_epochs > 0: n_batches = self.calculate_num_batches(batch_size) n_steps = n_epochs * n_batches # Initialize the internal iterator self.create_epoch_iterator(batch_size, n_epochs) else: n_steps = max_iter self.epoch_iterator = None return n_steps