Replay Buffers

Overview

The replay_buffers/ module contains the ReplayBuffer class, which provides an efficient and extensible experience replay mechanism for reinforcement learning agents.

This buffer:

  • Stores experience tuples using TorchRL’s TensorDictReplayBuffer

  • Uses LazyMemmapStorage for efficient CPU memory handling

  • Supports random, index-based, and full-batch sampling

  • Allows metadata saving and loading for checkpointing

  • Enables epoch-based training via internal iterators

Note

This replay buffer supports both CPU and GPU workflows. At initialization, device sets where sampled batches are moved for training, and storing_device sets where the experiences are stored.

ReplayBuffer Class

class objectrl.replay_buffers.experience_memory.ReplayBuffer(device: device, storing_device: device, buffer_size: int, print_gc_warning: bool = False)

Bases: object

A fixed-size replay buffer to store and sample experience tuples for reinforcement learning.

Experiences are stored in a LazyMemmapStorage (lazy memory mapping) on the CPU or a LazyTensorStorage (lazy tensor) on the GPU, and are accessed through TorchRL’s TensorDictReplayBuffer for efficient storage and retrieval.

__init__(device: device, storing_device: device, buffer_size: int, print_gc_warning: bool = False) → None

Initialize the ReplayBuffer.

Parameters:
  • device (torch.device) – The device to move sampled batches to (usually the training device).

  • storing_device (torch.device) – The device on which the experience buffer is stored (e.g., CPU or GPU). Unless GPU memory is a concern, use a CUDA device to improve runtime.

  • buffer_size (int) – Maximum number of experience tuples the buffer can hold.

  • print_gc_warning (bool) – Whether to print a warning when the garbage collector is run. Defaults to False.

Returns:

None

_get_storage(buffer_size: int, device: device) → LazyMemmapStorage | LazyTensorStorage

Select the storage backend for the given device: LazyMemmapStorage on the CPU or LazyTensorStorage on the GPU.

reset(buffer_size: int | None = None) → None

Reset the buffer, optionally resizing it.

Parameters:

buffer_size (int, optional) – If provided, sets a new maximum buffer size.

Returns:

None

add(experience: TensorDict) → None

Add a single experience to the buffer.

Parameters:

experience (TensorDict) – A single experience entry, usually a dictionary of tensors.

Returns:

None

add_batch(batch: TensorDict) → None

Add a batch of experiences to memory.

Parameters:

batch (TensorDict) – A TensorDict containing the batch of experiences to be added.

sample_batch(batch_size: int) → TensorDict

Randomly sample a batch of experiences from the buffer.

Parameters:

batch_size (int) – The number of samples to draw.

Returns:

A batch of randomly sampled experiences moved to the working device.

Return type:

TensorDict

sample_random(batch_size: int) → TensorDict

Alias for sample_batch.

Parameters:

batch_size (int) – The number of samples to draw.

Returns:

A batch of randomly sampled experiences.

Return type:

TensorDict

sample_by_index(indices: list | Tensor | range) → TensorDict

Sample specific experiences by index.

Parameters:

indices (Union[list, torch.Tensor, range]) – Indices of experiences to sample.

Returns:

A batch of selected experiences moved to the working device.

Return type:

TensorDict

sample_by_index_fields(indices: list | Tensor | range, fields: list) → TensorDict

Sample specific fields of selected experiences by index.

Parameters:
  • indices (Union[list, torch.Tensor, range]) – Indices of experiences to sample.

  • fields (List[str]) – Names of fields to retrieve (e.g., ['obs', 'action']).

Returns:

A single tensor if exactly one field is requested; otherwise a tuple of tensors, one per requested field.

Return type:

Union[torch.Tensor, Tuple[torch.Tensor, ...]]
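The return convention described here can be illustrated with a small pure-Python sketch. It is not the ObjectRL implementation: `select_fields` and the list-backed `storage` are hypothetical stand-ins for the buffer and its tensors, used only to show how a single field is unwrapped while several fields come back as a tuple.

```python
from typing import Sequence


def select_fields(storage: dict, indices: Sequence[int], fields: list):
    """Hypothetical helper mirroring the documented return convention."""
    # Gather the requested rows for each requested field, in field order.
    picked = tuple([storage[name][i] for i in indices] for name in fields)
    # One field: return the values directly. Several fields: return a tuple.
    return picked[0] if len(picked) == 1 else picked


# Plain lists stand in for tensors; the field names follow the example above.
storage = {"obs": [10, 11, 12, 13], "action": [0, 1, 0, 1]}

single = select_fields(storage, [0, 2], ["obs"])
obs, action = select_fields(storage, [0, 2], ["obs", "action"])
```

Unpacking the multi-field result into one name per field keeps call sites readable and makes a mismatch between the number of fields and the number of targets fail loudly.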

sample_all() → TensorDict

Sample all experiences currently stored in the buffer.

Returns:

All stored experiences up to data_size.

Return type:

TensorDict

property size: int

Property alias for the current number of stored experiences.

Parameters:

None

Returns:

Current buffer size (number of items stored).

Return type:

int

save(path: str) → None

Save memory metadata to a file.

Parameters:

path (str) – The file path (excluding file extension) to save the metadata to.

Returns:

None

load(path: str) → None

Load buffer metadata from disk (does not restore experience data).

Parameters:

path (str) – The file path (excluding file extension) to load metadata from.

Returns:

None

create_epoch_iterator(batch_size: int, n_epochs: int = 1) → Iterator

Create an iterator that yields batches from the buffer for multiple epochs.

Parameters:
  • batch_size (int) – Number of samples per batch.

  • n_epochs (int) – Number of epochs to iterate through the data.

Returns:

An iterator that yields batches of experience.

Return type:

Iterator
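To make the idea of an epoch iterator concrete, here is a minimal pure-Python sketch over sample indices. The function name `epoch_index_batches` is hypothetical, and two behaviours are assumptions rather than documented facts: indices are reshuffled at the start of every epoch, and a trailing incomplete batch is dropped.

```python
import random
from typing import Iterator, List, Optional


def epoch_index_batches(
    data_size: int, batch_size: int, n_epochs: int = 1, seed: Optional[int] = None
) -> Iterator[List[int]]:
    """Yield index batches that cover the stored data once per epoch."""
    rng = random.Random(seed)
    for _ in range(n_epochs):
        order = list(range(data_size))
        rng.shuffle(order)  # assumption: new random order every epoch
        # Step through the shuffled order, keeping full batches only (assumption).
        for start in range(0, data_size - batch_size + 1, batch_size):
            yield order[start:start + batch_size]


# 10 stored items, batches of 4, 2 epochs: 2 full batches per epoch.
batches = list(epoch_index_batches(10, 4, n_epochs=2, seed=0))
```

Within one epoch no index is drawn twice, which is what distinguishes epoch-based iteration from independent random sampling at every step.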

get_next_batch(batch_size: int) → TensorDict

Retrieve the next batch from the current epoch iterator. Falls back to random sampling if the iterator is not initialized.

Parameters:

batch_size (int) – Batch size to use for fallback random sampling.

Returns:

A batch of experience data.

Return type:

TensorDict

calculate_num_batches(batch_size: int) → int

Calculate how many full batches can be drawn from the current buffer content.

Parameters:

batch_size (int) – Number of samples per batch.

Returns:

Total number of batches possible with current data size.

Return type:

int
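As a worked example of the arithmetic, the sketch below reads "full batches" as floor division of the stored item count by the batch size. That reading is an assumption, and `num_full_batches` is a hypothetical helper rather than the library method.

```python
def num_full_batches(data_size: int, batch_size: int) -> int:
    """Number of complete batches available, assuming floor division."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    return data_size // batch_size


# With 9500 stored transitions and batches of 64:
full = num_full_batches(9500, 64)  # 148 full batches
leftover = 9500 - full * 64        # 28 transitions do not fill a batch
```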

get_steps_and_iterator(n_epochs: int, max_iter: int, batch_size: int) → int

Compute total training steps and initialize an internal batch iterator.

Parameters:
  • n_epochs (int) – Number of training epochs. If greater than 0, the internal epoch iterator is used.

  • max_iter (int) – Number of learning updates to perform (used when n_epochs is 0).

  • batch_size (int) – Number of samples per training step.

Returns:

Total number of training steps.

Return type:

int
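One plausible way the step count follows from these parameters is sketched below. This is an assumption about the behaviour, not the library's code: with epochs enabled, the total is taken to be epochs times full batches per epoch, and otherwise `max_iter` is used directly. The helper `total_training_steps` and its `data_size` argument are hypothetical.

```python
def total_training_steps(
    n_epochs: int, max_iter: int, batch_size: int, data_size: int
) -> int:
    """Sketch of the step count under the assumptions stated above."""
    if n_epochs > 0:
        # Epoch mode: every epoch visits each full batch once (assumed).
        return n_epochs * (data_size // batch_size)
    # Iteration mode: a fixed number of updates with random batches.
    return max_iter


epoch_steps = total_training_steps(n_epochs=10, max_iter=0, batch_size=64, data_size=9500)
iter_steps = total_training_steps(n_epochs=0, max_iter=1, batch_size=64, data_size=9500)
```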

ReplayBuffer Management

Initialization

Initialize the replay buffer by explicitly specifying both devices and the buffer size:

import torch

from objectrl.replay_buffers.experience_memory import ReplayBuffer

buffer = ReplayBuffer(
    device=torch.device("cuda"),             # Training device (sampled batches are moved here)
    storing_device=torch.device("cpu"),      # Storage device
    buffer_size=100_000,                     # Maximum number of transitions
)

Adding Experience

You can add either individual experiences or batches in the form of TorchRL TensorDict objects:

buffer.add(single_experience)       # Single transition
buffer.add_batch(batch_experience)  # Batch of transitions

Sampling

You can sample in multiple ways:

  • Random batch:

    batch = buffer.sample_batch(64)
    
  • By index:

    batch = buffer.sample_by_index([0, 5, 9])
    
  • Specific fields:

    obs, action = buffer.sample_by_index_fields([0, 5], fields=["obs", "action"])
    
  • Entire buffer:

    all_data = buffer.sample_all()
    

Note

All sampled experiences are automatically moved to the training device.

Epoch-Based Training

The buffer supports epoch-based batch iteration using an internal iterator:

iterator = buffer.create_epoch_iterator(batch_size=64, n_epochs=5)

for batch in iterator:
    train_step(batch)

You can also:

  • Use ReplayBuffer.get_next_batch() to fetch the next batch

  • Use ReplayBuffer.calculate_num_batches() to determine how many full batches you can draw

  • Use ReplayBuffer.get_steps_and_iterator() to prepare steps + iterator simultaneously:

    n_steps = buffer.get_steps_and_iterator(n_epochs=10, max_iter=0, batch_size=64)
    

Iteration-Based Training

When you set n_epochs=0, the replay buffer will not initialize an internal epoch-based iterator. Instead, training proceeds using a fixed number of update iterations, and batches are sampled randomly at each step using sample_batch() internally.

This mode is useful when:

  • You do not want to cycle through the buffer in epochs

  • You prefer fully stochastic sampling per training step

  • Your training logic depends on a fixed number of updates (e.g., max_steps=1000)

n_steps = buffer.get_steps_and_iterator(
    n_epochs=0,         # Disables epoch-based iterator
    max_iter=1,         # Number of learning updates at each step
    batch_size=64
)

Note

When n_epochs=0, get_next_batch() uses random sampling instead of drawing from a predefined iterator.
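The fallback described in this note can be sketched as a small dispatch function. This is an illustration only: `next_batch`, the `draw` callback, and the index lists standing in for experience batches are all hypothetical, and the real method keeps its iterator as internal state rather than taking it as an argument.

```python
import random
from typing import Callable, Iterator, List, Optional


def next_batch(
    iterator: Optional[Iterator[List[int]]],
    sample_batch: Callable[[int], List[int]],
    batch_size: int,
) -> List[int]:
    """Use the epoch iterator when one exists, else sample randomly."""
    if iterator is None:
        # No epoch iterator was initialized (the n_epochs=0 case).
        return sample_batch(batch_size)
    return next(iterator)


rng = random.Random(0)


def draw(n: int) -> List[int]:
    # Stand-in for random sampling from a buffer holding 100 items.
    return [rng.randrange(100) for _ in range(n)]


random_batch = next_batch(None, draw, 4)          # iteration-based mode
epoch_batches = iter([[0, 1, 2, 3], [4, 5, 6, 7]])
first = next_batch(epoch_batches, draw, 4)        # epoch-based mode
```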

Saving and Loading Metadata

You can persist buffer metadata only (not full contents) for checkpointing:

buffer.save("checkpoints/buffer")   # Saves to checkpoints/buffer.metadata
buffer.load("checkpoints/buffer")   # Loads metadata only

This includes:

  • buffer_size: Total buffer capacity

  • data_size: Current number of stored transitions

  • pointer: Insertion index for next sample
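To show how these three values relate, here is a hedged pure-Python sketch of a fixed-size buffer's bookkeeping. `BufferMetadata` is a hypothetical stand-in, the wrap-around update is the usual ring-buffer convention assumed here, and JSON is an assumed file format chosen for the illustration; only the `.metadata` suffix comes from the example above.

```python
import json
import os
import tempfile


class BufferMetadata:
    """Hypothetical stand-in that tracks only the three documented fields."""

    def __init__(self, buffer_size: int) -> None:
        self.buffer_size = buffer_size  # total capacity
        self.data_size = 0              # transitions currently stored
        self.pointer = 0                # index where the next sample is written

    def record_add(self, n: int = 1) -> None:
        # Assumed ring-buffer behaviour: the pointer wraps around at
        # capacity, while data_size saturates at buffer_size.
        self.pointer = (self.pointer + n) % self.buffer_size
        self.data_size = min(self.data_size + n, self.buffer_size)

    def save(self, path: str) -> None:
        # JSON is an assumption; the real on-disk format is not documented here.
        with open(path + ".metadata", "w") as f:
            json.dump(vars(self), f)

    def load(self, path: str) -> None:
        with open(path + ".metadata") as f:
            vars(self).update(json.load(f))


meta = BufferMetadata(buffer_size=5)
meta.record_add(7)  # 7 writes into 5 slots: the two oldest entries are overwritten

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "buffer")
    meta.save(path)
    restored = BufferMetadata(buffer_size=1)
    restored.load(path)
```

After loading, `restored` knows the capacity, fill level, and write position, but it holds no experience data, which matches the metadata-only behaviour described in this section.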

Attention

The load() method does not restore experience data; only the metadata is recovered. Restoring the full contents is not currently supported, so persist or replay the experience data externally if you need it.

Buffer Size and Status

You can inspect buffer usage via:

  • len(buffer): number of stored transitions

  • buffer.size: same as above, for convenience

print(len(buffer))        # e.g., 9500
print(buffer.size)        # e.g., 9500