Replay Buffers
Overview
The replay_buffers/ module contains the ReplayBuffer class, which provides an efficient and extensible experience replay mechanism for reinforcement learning agents.
This buffer:
Stores experience tuples using TorchRL’s TensorDictReplayBuffer
Uses LazyMemmapStorage for memory-mapped storage on the CPU (and lazy tensor storage on the GPU)
Supports random, index-based, and full-buffer sampling
Allows metadata saving and loading for checkpointing
Enables epoch-based training via internal iterators
Note
This replay buffer supports both CPU and GPU workflows: at initialization you specify a device (where sampled batches are placed for training) and a storing_device (where the experience data is kept).
ReplayBuffer Class
- class objectrl.replay_buffers.experience_memory.ReplayBuffer(device: device, storing_device: device, buffer_size: int, print_gc_warning: bool = False)
Bases: object
A fixed-size replay buffer to store and sample experience tuples for reinforcement learning.
This buffer uses TorchRL's TensorDictReplayBuffer for efficient storage and retrieval, backed by lazy memory-mapped storage on the CPU or lazy tensor storage on the GPU.
- __init__(device: device, storing_device: device, buffer_size: int, print_gc_warning: bool = False) → None
Initialize the ReplayBuffer.
- Parameters:
device (torch.device) – The device to move sampled batches to (usually the training device).
storing_device (torch.device) – The device on which the experience data is stored (e.g., CPU or GPU). Unless memory is a concern, use a CUDA device to improve runtime.
buffer_size (int) – Maximum number of experience tuples the buffer can hold.
print_gc_warning (bool) – If True, print a warning when the garbage collector is run.
- Returns:
None
- reset(buffer_size: int | None = None) → None
Reset the buffer, optionally resizing it.
- Parameters:
buffer_size (int, optional) – If provided, sets a new maximum buffer size.
- Returns:
None
- add(experience: TensorDict) → None
Add a single experience to the buffer.
- Parameters:
experience (TensorDict) – A single experience entry, usually a dictionary of tensors.
- Returns:
None
- add_batch(batch: TensorDict) → None
Add a batch of experiences to memory.
- Parameters:
batch (TensorDict) – A TensorDict containing the batch of experiences to add.
- sample_batch(batch_size: int) → TensorDict
Randomly sample a batch of experiences from the buffer.
- Parameters:
batch_size (int) – The number of samples to draw.
- Returns:
A batch of randomly sampled experiences moved to the working device.
- Return type:
TensorDict
- sample_random(batch_size: int) → TensorDict
Alias for sample_batch.
- Parameters:
batch_size (int) – The number of samples to draw.
- Returns:
A batch of randomly sampled experiences.
- Return type:
TensorDict
- sample_by_index(indices: list | Tensor | range) → TensorDict
Sample specific experiences by index.
- Parameters:
indices (Union[list, torch.Tensor, range]) – Indices of experiences to sample.
- Returns:
A batch of selected experiences moved to the working device.
- Return type:
TensorDict
- sample_by_index_fields(indices: list | Tensor | range, fields: list) → TensorDict
Sample specific fields of selected experiences by index.
- Parameters:
indices (Union[list, torch.Tensor, range]) – Indices of experiences to sample.
fields (List[str]) – Names of fields to retrieve (e.g., ['obs', 'action']).
- Returns:
A single tensor if one field is requested; a tuple of tensors if multiple fields are requested.
- Return type:
Union[torch.Tensor, Tuple[torch.Tensor]]
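The one-field versus multi-field return convention can be illustrated with a plain-Python sketch. A list of dicts stands in for the stored TensorDict and lists stand in for tensors; the helper select_fields is hypothetical and not part of objectrl.

```python
from typing import Any, Dict, List, Sequence, Tuple, Union


def select_fields(
    storage: List[Dict[str, Any]],
    indices: Sequence[int],
    fields: List[str],
) -> Union[List[Any], Tuple[List[Any], ...]]:
    """Hypothetical stand-in for sample_by_index_fields on plain Python data."""
    # Build one column (list of values) per requested field, in index order.
    columns = tuple([storage[i][field] for i in indices] for field in fields)
    # One field: return the bare column; several fields: return the tuple of columns.
    return columns[0] if len(fields) == 1 else columns


storage = [{"obs": i, "action": 10 * i} for i in range(10)]
only_obs = select_fields(storage, [0, 5], ["obs"])
obs, act = select_fields(storage, [0, 5], ["obs", "action"])
print(only_obs, obs, act)  # [0, 5] [0, 5] [0, 50]
```

Because the return type changes with the number of fields, callers requesting several fields typically unpack the result directly, as in the last call above.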
- sample_all() → TensorDict
Sample all experiences currently stored in the buffer.
- Returns:
All stored experiences up to data_size.
- Return type:
TensorDict
- property size: int
Property alias for the current number of stored experiences.
- Parameters:
None
- Returns:
Current buffer size (number of items stored).
- Return type:
int
- save(path: str) → None
Save buffer metadata to a file.
- Parameters:
path (str) – Path (excluding file extension) at which to save the metadata file; the .metadata extension is appended.
- Returns:
None
- load(path: str) → None
Load buffer metadata from disk (does not restore experience data).
- Parameters:
path (str) – The file path (excluding file extension) to load metadata from.
- Returns:
None
- create_epoch_iterator(batch_size: int, n_epochs: int = 1) → Iterator
Create an iterator that yields batches from the buffer for multiple epochs.
- Parameters:
batch_size (int) – Number of samples per batch.
n_epochs (int) – Number of epochs to iterate through the data.
- Returns:
An iterator that yields batches of experience.
- Return type:
Iterator
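A minimal sketch of epoch-based iteration over a buffer holding data_size items, written in plain Python over indices. It assumes that each epoch visits every stored index once in a fresh random order and that a trailing partial batch is dropped; neither detail is confirmed by this documentation, and epoch_index_batches is a hypothetical name.

```python
import random
from typing import Iterator, List, Optional


def epoch_index_batches(
    data_size: int, batch_size: int, n_epochs: int = 1, seed: Optional[int] = None
) -> Iterator[List[int]]:
    """Hypothetical sketch: yield full batches of buffer indices for n_epochs epochs."""
    rng = random.Random(seed)
    for _ in range(n_epochs):
        # Assumption: each epoch visits every stored index once, in a fresh random order.
        order = list(range(data_size))
        rng.shuffle(order)
        # Assumption: a trailing partial batch is dropped.
        for start in range(0, data_size - batch_size + 1, batch_size):
            yield order[start:start + batch_size]


for batch in epoch_index_batches(data_size=10, batch_size=4, n_epochs=2, seed=0):
    print(batch)  # four batches of 4 indices: two per epoch, since 10 // 4 == 2
```

In the actual buffer the iterator yields batches of experiences rather than index lists; the index view is only meant to make the epoch structure visible.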
- get_next_batch(batch_size: int) → TensorDict
Retrieve the next batch from the current epoch iterator. Falls back to random sampling if the iterator is not initialized.
- Parameters:
batch_size (int) – Batch size to use for fallback random sampling.
- Returns:
A batch of experience data.
- Return type:
TensorDict
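The fallback described above can be sketched as follows. NextBatchSketch is a hypothetical class that stores plain Python items; sampling with replacement in the fallback path is an assumption, and the behaviour once an initialized iterator is exhausted is not specified here, so the sketch simply lets StopIteration propagate.

```python
import random
from typing import Iterator, List, Optional


class NextBatchSketch:
    """Hypothetical illustration of get_next_batch's fallback on plain Python data."""

    def __init__(self, items: List[int], seed: Optional[int] = None) -> None:
        self.items = items
        self.rng = random.Random(seed)
        # Stays None until an epoch iterator is created.
        self.iterator: Optional[Iterator[List[int]]] = None

    def get_next_batch(self, batch_size: int) -> List[int]:
        if self.iterator is None:
            # No iterator: fall back to random sampling (with replacement, an assumption).
            return self.rng.choices(self.items, k=batch_size)
        # Iterator present: batch_size is ignored and the next prepared batch is returned.
        return next(self.iterator)


buf = NextBatchSketch(list(range(100)), seed=0)
print(len(buf.get_next_batch(8)))  # 8, drawn at random
buf.iterator = iter([[0, 1], [2, 3]])
print(buf.get_next_batch(8))  # [0, 1]
```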
- calculate_num_batches(batch_size: int) → int
Calculate how many full batches can be drawn from the current buffer content.
- Parameters:
batch_size (int) – Number of samples per batch.
- Returns:
Total number of batches possible with current data size.
- Return type:
int
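Counting only full batches amounts to floor division of the number of stored transitions by the batch size. A sketch under that assumption, written as a hypothetical free function that takes data_size explicitly instead of reading it from the buffer:

```python
def calculate_num_batches(data_size: int, batch_size: int) -> int:
    """Hypothetical sketch: number of full batches that fit in data_size items."""
    # Floor division discards the leftover items that cannot fill a whole batch.
    return data_size // batch_size


print(calculate_num_batches(9500, 64))  # 148 (148 * 64 = 9472; 28 items left over)
```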
- get_steps_and_iterator(n_epochs: int, max_iter: int, batch_size: int) → int
Compute total training steps and initialize an internal batch iterator.
- Parameters:
n_epochs (int) – Number of training epochs. If > 0, an epoch-based iterator is used.
max_iter (int) – Number of learning updates to perform (used if n_epochs = 0).
batch_size (int) – Number of samples per training step.
- Returns:
Total number of training steps.
- Return type:
int
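Based on the parameter descriptions above, the returned step count plausibly reduces to the sketch below: with n_epochs > 0 every epoch contributes one step per full batch, otherwise max_iter is used directly. This is an assumption about the implementation; count_training_steps is a hypothetical free function, and the iterator initialization performed by the real method is omitted.

```python
def count_training_steps(n_epochs: int, max_iter: int, data_size: int, batch_size: int) -> int:
    """Hypothetical sketch of the step count returned by get_steps_and_iterator."""
    if n_epochs > 0:
        # Epoch mode: one step per full batch, repeated for every epoch.
        return n_epochs * (data_size // batch_size)
    # Iteration mode (n_epochs == 0): a fixed number of learning updates.
    return max_iter


print(count_training_steps(n_epochs=10, max_iter=0, data_size=9500, batch_size=64))  # 1480
print(count_training_steps(n_epochs=0, max_iter=1, data_size=9500, batch_size=64))   # 1
```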
ReplayBuffer Management
Initialization
Initialize the replay buffer by explicitly specifying both devices and the buffer size:
import torch
from objectrl.replay_buffers.experience_memory import ReplayBuffer
buffer = ReplayBuffer(
device=torch.device("cuda"), # Training device
storing_device=torch.device("cpu"), # Storage device
buffer_size=100_000 # Maximum number of transitions
)
Adding Experience
You can add either individual experiences or batches, both as TensorDict objects (from the tensordict library that TorchRL builds on):
buffer.add(single_experience) # Single transition
buffer.add_batch(batch_experience) # Batch of transitions
Sampling
You can sample in multiple ways:
Random batch:
batch = buffer.sample_batch(64)
By index:
batch = buffer.sample_by_index([0, 5, 9])
Specific fields:
obs_act = buffer.sample_by_index_fields([0, 5], fields=["obs", "action"])
Entire buffer:
all_data = buffer.sample_all()
Note
All sampled experiences are automatically moved to the training device.
Epoch-Based Training
The buffer supports epoch-based batch iteration using an internal iterator:
iterator = buffer.create_epoch_iterator(batch_size=64, n_epochs=5)
for batch in iterator:
    train_step(batch)  # train_step is your own (user-defined) update function
You can also:
Use ReplayBuffer.get_next_batch() to fetch the next batch
Use ReplayBuffer.calculate_num_batches() to determine how many full batches you can draw
Use ReplayBuffer.get_steps_and_iterator() to compute the number of training steps and initialize the iterator in one call:
n_steps = buffer.get_steps_and_iterator(n_epochs=10, max_iter=0, batch_size=64)
Iteration-Based Training
When you call get_steps_and_iterator() with n_epochs=0, the replay buffer does not initialize an internal epoch-based iterator. Instead, training runs for a fixed number of update iterations, and batches are sampled randomly at each step using sample_batch() internally.
This mode is useful when:
You do not want to cycle through the buffer in epochs
You prefer fully stochastic sampling per training step
Your training logic depends on a fixed number of updates (e.g., max_steps=1000)
n_steps = buffer.get_steps_and_iterator(
n_epochs=0, # Disables epoch-based iterator
max_iter=1, # Number of learning updates at each step
batch_size=64
)
Note
When n_epochs=0, get_next_batch() uses random sampling instead of drawing from a predefined iterator.
Saving and Loading Metadata
You can persist only the buffer metadata (not its contents) for checkpointing:
buffer.save("checkpoints/buffer") # Saves to checkpoints/buffer.metadata
buffer.load("checkpoints/buffer") # Loads metadata only
This includes:
buffer_size: Total buffer capacity
data_size: Current number of stored transitions
pointer: Index at which the next transition will be inserted
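To make these three fields concrete, here is a plain-Python sketch of a fixed-size buffer that tracks them. It assumes first-in, first-out overwriting once the buffer is full (the usual circular-buffer behaviour); RingBufferSketch is hypothetical and does not reproduce objectrl's internals.

```python
from typing import Any, List, Optional


class RingBufferSketch:
    """Hypothetical fixed-size buffer tracking buffer_size, data_size and pointer."""

    def __init__(self, buffer_size: int) -> None:
        self.buffer_size = buffer_size  # total capacity
        self.data_size = 0              # number of stored transitions
        self.pointer = 0                # index where the next transition is written
        self._storage: List[Optional[Any]] = [None] * buffer_size

    def add(self, transition: Any) -> None:
        self._storage[self.pointer] = transition
        # Advance circularly; once full, the oldest entry is the next one overwritten.
        self.pointer = (self.pointer + 1) % self.buffer_size
        self.data_size = min(self.data_size + 1, self.buffer_size)

    def __len__(self) -> int:
        return self.data_size


ring = RingBufferSketch(buffer_size=3)
for t in range(5):
    ring.add(t)
print(len(ring), ring.pointer)  # 3 2
```

After five insertions into a buffer of capacity 3, data_size saturates at 3 and the two oldest transitions (0 and 1) have been overwritten by 3 and 4.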
Attention
The load() method does not restore experience data; only the metadata is recovered. Full-content serialization is not currently supported, so to restore the contents you need to log transitions externally and replay them into the buffer.
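A minimal sketch of metadata-only checkpointing. It assumes a JSON file and follows the .metadata suffix shown in the buffer.save() example; objectrl's actual file format is not documented here, and save_metadata and load_metadata are hypothetical helpers.

```python
import json
import os
import tempfile
from typing import Dict


def save_metadata(path: str, buffer_size: int, data_size: int, pointer: int) -> None:
    """Hypothetical sketch: write only the buffer counters to <path>.metadata as JSON."""
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    with open(path + ".metadata", "w") as f:
        json.dump({"buffer_size": buffer_size, "data_size": data_size, "pointer": pointer}, f)


def load_metadata(path: str) -> Dict[str, int]:
    """Hypothetical sketch: read the counters back; no transitions are restored."""
    with open(path + ".metadata") as f:
        return json.load(f)


with tempfile.TemporaryDirectory() as tmp:
    prefix = os.path.join(tmp, "checkpoints", "buffer")
    save_metadata(prefix, buffer_size=100_000, data_size=9500, pointer=9500)
    meta = load_metadata(prefix)
print(meta)  # {'buffer_size': 100000, 'data_size': 9500, 'pointer': 9500}
```

Restoring these counters alone into an empty buffer would make data_size disagree with the actual stored contents, which is why the experience data has to be replayed separately.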
Buffer Size and Status
You can inspect buffer usage via:
len(buffer) — number of stored transitions
buffer.size — same as above, for convenience
print(len(buffer)) # e.g., 9500
print(buffer.size) # e.g., 9500