Network Utilities

This module provides utility functions and neural network classes commonly used for actor-critic algorithms in reinforcement learning, including deterministic and Bayesian MLP architectures.

Functions

objectrl.utils.net_utils.create_optimizer(config) → Callable

Creates a PyTorch optimizer based on the configuration.

Parameters:

config – Configuration object containing:

  • config.optimizer (str): Name of the optimizer (e.g., ‘Adam’, ‘SGD’).

  • config.learning_rate (float): Learning rate for the optimizer.

Returns:

A function that accepts model parameters and returns an optimizer instance.

Return type:

Callable

Raises:

NotImplementedError – If the optimizer name is not available in torch.optim.
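To make the lookup-by-name behaviour concrete, here is a minimal sketch of how such a factory could be written, assuming the config exposes `optimizer` and `learning_rate` attributes as listed above. The name `create_optimizer_sketch` and the `SimpleNamespace` config are hypothetical stand-ins, not the objectrl implementation; the sketch resolves the class with `getattr` on `torch.optim` and binds the learning rate with `functools.partial`.

```python
from functools import partial
from types import SimpleNamespace
from typing import Callable

import torch
from torch import nn


def create_optimizer_sketch(config) -> Callable:
    """Hypothetical sketch: return a factory that builds a torch.optim optimizer.

    Assumes config has `optimizer` (a class name in torch.optim)
    and `learning_rate` (a float).
    """
    # Resolve the optimizer class by name, e.g. "Adam" -> torch.optim.Adam
    optim_cls = getattr(torch.optim, config.optimizer, None)
    if optim_cls is None:
        raise NotImplementedError(
            f"Optimizer {config.optimizer!r} not found in torch.optim"
        )
    # Bind the learning rate now; model parameters are supplied later
    return partial(optim_cls, lr=config.learning_rate)


# Hypothetical config object standing in for the real configuration class
config = SimpleNamespace(optimizer="Adam", learning_rate=3e-4)
make_optim = create_optimizer_sketch(config)
model = nn.Linear(4, 2)
optimizer = make_optim(model.parameters())
```

Returning a partial rather than an optimizer lets the caller create the optimizer later, once the model's parameters exist.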

objectrl.utils.net_utils.create_loss(config, reduction: str = 'none') → Module

Creates a loss function module from either torch.nn or a custom module.

Parameters:
  • config – Configuration object containing:

    • config.loss (str): Name of the loss function.

  • reduction (str, optional) – Reduction method (‘none’, ‘mean’, or ‘sum’). Defaults to “none”.

Returns:

A PyTorch loss function module.

Return type:

nn.Module

Raises:

NotImplementedError – If the loss is not found in torch.nn or the custom module.
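The lookup order described above (torch.nn first, then a custom module) can be sketched as follows. This is an illustrative version only: `create_loss_sketch` and its `custom_module` argument are hypothetical, with `custom_module` standing in for whichever module objectrl actually searches (the Notes mention objectrl.models.basic.loss). With `reduction="none"` the loss keeps one value per element, which is handy when per-sample losses are weighted or masked afterwards.

```python
from types import ModuleType, SimpleNamespace
from typing import Optional

import torch
from torch import nn


def create_loss_sketch(config, reduction: str = "none",
                       custom_module: Optional[ModuleType] = None) -> nn.Module:
    """Hypothetical sketch: look up config.loss in torch.nn, then in custom_module."""
    loss_cls = getattr(nn, config.loss, None)
    if loss_cls is None and custom_module is not None:
        # Fall back to a user-supplied module of custom losses
        loss_cls = getattr(custom_module, config.loss, None)
    if loss_cls is None:
        raise NotImplementedError(f"Loss {config.loss!r} not found")
    return loss_cls(reduction=reduction)


loss_fn = create_loss_sketch(SimpleNamespace(loss="MSELoss"))
pred = torch.tensor([1.0, 2.0, 3.0])
target = torch.tensor([1.0, 0.0, 0.0])
# reduction="none" keeps one squared error per element
per_element = loss_fn(pred, target)
```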

Classes

FeatureExtractor

class objectrl.utils.net_utils.FeatureExtractor(dim_in: int, depth: int, width: int, act: Literal['relu', 'sigmoid'] = 'relu', has_norm: bool = True)

Bases: Module

Generic shallow MLP for feature extraction.

Creates a stack of layers with the pattern:

Linear → (LayerNorm) → Activation

Parameters:
  • dim_in (int) – Input feature dimension.

  • depth (int) – Number of hidden layers (>= 1).

  • width (int) – Width of hidden layers.

  • act (Literal["relu", "sigmoid"]) – Activation function for all hidden layers.

  • has_norm (bool) – Whether to include LayerNorm after each linear layer.

__init__(dim_in: int, depth: int, width: int, act: Literal['relu', 'sigmoid'] = 'relu', has_norm: bool = True)

Builds the Linear → (LayerNorm) → Activation stack from the given dimensions and options.

forward(x: Tensor) → Tensor

Applies the feature-extraction stack to the input tensor x.

Note

As with any nn.Module, call the module instance (e.g., extractor(x)) rather than forward() directly, so that registered hooks are run.
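A minimal sketch of the Linear → (LayerNorm) → Activation stack, written with nn.Sequential. It assumes that the first block maps dim_in to width, every later block maps width to width (so the output has `width` features), and that LayerNorm is applied to each Linear output when has_norm is True. `feature_extractor_sketch` is a hypothetical name, not part of objectrl.

```python
from typing import Literal

import torch
from torch import nn


def feature_extractor_sketch(dim_in: int, depth: int, width: int,
                             act: Literal["relu", "sigmoid"] = "relu",
                             has_norm: bool = True) -> nn.Sequential:
    """Hypothetical sketch: `depth` blocks of Linear -> (LayerNorm) -> Activation."""
    assert depth >= 1, "depth must be >= 1"
    act_cls = {"relu": nn.ReLU, "sigmoid": nn.Sigmoid}[act]
    layers = []
    in_features = dim_in
    for _ in range(depth):
        layers.append(nn.Linear(in_features, width))
        if has_norm:
            layers.append(nn.LayerNorm(width))
        layers.append(act_cls())
        in_features = width  # every block after the first maps width -> width
    return nn.Sequential(*layers)


extractor = feature_extractor_sketch(dim_in=10, depth=2, width=32)
features = extractor(torch.randn(8, 10))  # shape (8, 32)
```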

MLP

class objectrl.utils.net_utils.MLP(dim_in: int, dim_out: int, depth: int, width: int, act: str = 'relu', has_norm: bool = False)

Bases: Module

__init__(dim_in: int, dim_out: int, depth: int, width: int, act: str = 'relu', has_norm: bool = False) → None

Constructs a fully connected Multi-Layer Perceptron (MLP).

Parameters:
  • dim_in (int) – Input feature dimension.

  • dim_out (int) – Output feature dimension.

  • depth (int) – Total number of layers (>= 1).

  • width (int) – Width of the hidden layers.

  • act (str) – Activation function. Options are:

    • “relu”: Standard ReLU.

    • “crelu”: Concatenated ReLU (doubles width).

  • has_norm (bool) – If True, applies LayerNorm between layers.

Raises:
  • AssertionError – If depth <= 0.

  • NotImplementedError – If an unknown activation function is specified.

forward(x: Tensor) → Tensor

Forward pass of the standard MLP.

Parameters:

x (torch.Tensor) – Input tensor of shape (batch_size, dim_in).

Returns:

Output tensor of shape (batch_size, dim_out).

Return type:

torch.Tensor
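The sketch below illustrates one plausible layer layout for MLP, including how CReLU doubles the width seen by the next layer. Assumptions not confirmed by this page: depth counts all Linear layers (depth - 1 hidden layers plus one output layer), LayerNorm sits after each hidden Linear layer and before the activation, and `CReLU` and `mlp_sketch` are hypothetical helpers written for this example.

```python
import torch
from torch import nn


class CReLU(nn.Module):
    """Concatenated ReLU: [relu(x), relu(-x)] along the last dimension."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.cat([torch.relu(x), torch.relu(-x)], dim=-1)


def mlp_sketch(dim_in: int, dim_out: int, depth: int, width: int,
               act: str = "relu", has_norm: bool = False) -> nn.Sequential:
    """Hypothetical sketch of an MLP with depth - 1 hidden layers plus an output layer."""
    assert depth > 0, "depth must be >= 1"
    if act == "relu":
        act_factory, mult = nn.ReLU, 1
    elif act == "crelu":
        act_factory, mult = CReLU, 2  # CReLU doubles the feature dimension
    else:
        raise NotImplementedError(f"Unknown activation {act!r}")

    if depth == 1:
        return nn.Sequential(nn.Linear(dim_in, dim_out))

    layers = []
    in_features = dim_in
    for _ in range(depth - 1):  # hidden layers
        layers.append(nn.Linear(in_features, width))
        if has_norm:
            layers.append(nn.LayerNorm(width))
        layers.append(act_factory())
        in_features = width * mult
    layers.append(nn.Linear(in_features, dim_out))  # output layer
    return nn.Sequential(*layers)


net = mlp_sketch(128, 64, depth=3, width=256, act="crelu", has_norm=True)
out = net(torch.randn(32, 128))  # shape (32, 64)
```

Because CReLU outputs both relu(x) and relu(-x), each Linear layer after a CReLU receives 2 * width inputs, which is why the multiplier is tracked separately from width.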

BayesianMLP

class objectrl.utils.net_utils.BayesianMLP(dim_in: int, dim_out: int, depth: int, width: int, layer_type: Literal['bbb', 'lr', 'clt', 'cltdet'] = 'lr', act: Literal['crelu', 'relu'] = 'relu', has_norm: bool = False)

Bases: Module

__init__(dim_in: int, dim_out: int, depth: int, width: int, layer_type: Literal['bbb', 'lr', 'clt', 'cltdet'] = 'lr', act: Literal['crelu', 'relu'] = 'relu', has_norm: bool = False) → None

Constructs a Bayesian MLP using probabilistic linear layers. Supports various types of Bayesian layers for uncertainty estimation.

Parameters:
  • dim_in (int) – Input feature dimension.

  • dim_out (int) – Output feature dimension.

  • depth (int) – Number of layers (>= 1).

  • width (int) – Width of the hidden layers.

  • layer_type (str) – Type of Bayesian linear layer. One of:

    • “bbb”: Bayes by Backprop.

    • “lr”: Local Reparameterization trick.

    • “clt”: Central Limit Theorem (probabilistic forward).

    • “cltdet”: CLT with deterministic weights.

  • act (str) – Activation function. One of “relu” or “crelu”.

  • has_norm (bool) – Whether to apply LayerNorm. Not supported for CLT variants.

Raises:
  • AssertionError – If depth <= 0 or the settings are incompatible (e.g., has_norm=True with a CLT layer type).

  • NotImplementedError – For unknown layer or activation types.
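For the default "lr" layer type, the local reparameterization trick samples the Gaussian pre-activations of a layer directly instead of sampling a weight matrix. The sketch below assumes a mean-field Gaussian posterior over weights, a softplus parameterization of the standard deviation, and a deterministic bias; `LRLinearSketch` and its parameter names are hypothetical, not objectrl's layer class.

```python
import math

import torch
import torch.nn.functional as F
from torch import nn


class LRLinearSketch(nn.Module):
    """Hypothetical mean-field Gaussian linear layer using local reparameterization."""

    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__()
        self.weight_mu = nn.Parameter(
            torch.randn(out_features, in_features) / math.sqrt(in_features)
        )
        # softplus(rho) gives a positive weight standard deviation
        self.weight_rho = nn.Parameter(torch.full((out_features, in_features), -5.0))
        self.bias = nn.Parameter(torch.zeros(out_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        weight_var = F.softplus(self.weight_rho) ** 2
        # Pre-activations are Gaussian with mean x @ mu^T + b and variance x^2 @ sigma^2^T
        act_mean = F.linear(x, self.weight_mu, self.bias)
        act_var = F.linear(x ** 2, weight_var)
        # Sample pre-activations directly; the small constant keeps sqrt differentiable at 0
        eps = torch.randn_like(act_mean)
        return act_mean + torch.sqrt(act_var + 1e-8) * eps


layer = LRLinearSketch(128, 64)
y = layer(torch.randn(32, 128))  # a fresh sample on every call
```

Sampling per pre-activation gives each example in the batch its own noise, which lowers gradient variance compared with sharing one sampled weight matrix across the batch.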

forward(x: Tensor | tuple[Tensor, Tensor | None]) → Tensor | tuple[Tensor, Tensor | None]

Forward pass of the Bayesian MLP.

Parameters:

x (Union[Tensor, Tuple[Tensor, Optional[Tensor]]]) –

  • For standard use: an input tensor.

  • For CLT-based layer types: a tuple of (mean, variance), where the variance may be None.

Returns:

Output in the same format as the input.

Return type:

Union[Tensor, Tuple[Tensor, Optional[Tensor]]]
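For the CLT-based layer types, the forward pass carries a (mean, variance) pair instead of samples. The sketch below covers only the linear step of such moment propagation, assuming independent Gaussian weights with elementwise means `w_mu` and variances `w_var` and independent input features; how objectrl handles the activations between layers is not shown. `clt_linear_moments` is a hypothetical name. Passing `var=None` treats the input as deterministic, and an all-zero `w_var` corresponds to deterministic weights, as in "cltdet".

```python
from typing import Optional

import torch
import torch.nn.functional as F


def clt_linear_moments(mean: torch.Tensor, var: Optional[torch.Tensor],
                       w_mu: torch.Tensor, w_var: torch.Tensor,
                       bias: Optional[torch.Tensor] = None):
    """Hypothetical sketch: propagate (mean, variance) through y = x @ W^T + b.

    Per output j: Var[y_j] = sum_i (mu_ji^2 * v_i + s_ji^2 * (m_i^2 + v_i)).
    """
    out_mean = F.linear(mean, w_mu, bias)
    # E[x^2] = m^2 + v (just m^2 when the input is deterministic)
    second_moment = mean ** 2 if var is None else mean ** 2 + var
    out_var = F.linear(second_moment, w_var)
    if var is not None:
        out_var = out_var + F.linear(var, w_mu ** 2)
    return out_mean, out_var


mean = torch.tensor([[1.0, 2.0]])
var = torch.tensor([[0.5, 0.25]])
w_mu = torch.tensor([[1.0, -1.0]])
w_var = torch.tensor([[0.1, 0.2]])
bias = torch.tensor([0.5])
m, v = clt_linear_moments(mean, var, w_mu, w_var, bias)
```

Here m = 1 - 2 + 0.5 = -0.5 and v = (1.5 * 0.1 + 4.25 * 0.2) + (0.5 * 1 + 0.25 * 1) = 1.75.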

get_kl() → tuple[Tensor, int]

Get the KL divergence of the Bayesian MLP.
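For a mean-field Gaussian posterior and a zero-mean Gaussian prior, the KL term has a closed form, sketched below. The prior N(0, prior_sigma²) and the name `gaussian_kl_to_prior` are assumptions for illustration; this page does not say which prior BayesianMLP uses or what the int in get_kl's return value represents, so the sketch returns only the summed KL tensor.

```python
import torch


def gaussian_kl_to_prior(mu: torch.Tensor, sigma: torch.Tensor,
                         prior_sigma: float = 1.0) -> torch.Tensor:
    """Hypothetical sketch: sum of KL(N(mu, sigma^2) || N(0, prior_sigma^2)) over elements."""
    kl = (torch.log(prior_sigma / sigma)
          + (sigma ** 2 + mu ** 2) / (2 * prior_sigma ** 2)
          - 0.5)
    return kl.sum()


# A posterior identical to the prior has zero KL
kl_same = gaussian_kl_to_prior(torch.zeros(3), torch.ones(3))
```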

Notes

  • The create_optimizer function dynamically selects and configures an optimizer from torch.optim.

  • The create_loss function supports both PyTorch and custom loss modules (e.g., from objectrl.models.basic.loss).

  • The FeatureExtractor class creates a stack of layers with the pattern Linear → (LayerNorm) → Activation.

  • The MLP class supports ReLU and CReLU activations and optional LayerNorm.

  • The BayesianMLP supports multiple Bayesian linear layer types:

    • “bbb”: Bayes by Backprop

    • “lr”: Local Reparameterization

    • “clt”: Central Limit Theorem

    • “cltdet”: Deterministic CLT

Example

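import torch
from objectrl.utils.net_utils import BayesianMLP, MLP

# Deterministic MLP with CReLU activations and LayerNorm: (32, 128) -> (32, 64)
net = MLP(128, 64, depth=3, width=256, act="crelu", has_norm=True)
out = net(torch.randn(32, 128))

# Bayesian MLP with the default local-reparameterization ("lr") layers
bayesian_net = BayesianMLP(128, 64, depth=3, width=256, layer_type="lr")
out_bnn = bayesian_net(torch.randn(32, 128))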