Custom Loss Functions

This module implements probabilistic and PAC-Bayesian loss functions for uncertainty-aware learning.

Included Losses:

  • ProbabilisticLoss: Base class for probabilistic losses supporting different reduction modes.

  • PACBayesLoss: Combines empirical risk with complexity regularization for PAC-Bayesian learning.

  • DSACLoss: Distributional loss function for the DSAC algorithm.

Extending

Important

If you need a custom loss function, add it to this module so that it stays consistent with, and integrates into, the existing training pipelines. Subclass ProbabilisticLoss for probabilistic outputs; otherwise, subclass one of PyTorch’s standard loss classes.
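For the non-probabilistic case, a custom loss can subclass PyTorch's _Loss base class, which stores the reduction string as self.reduction (DSACLoss below uses the same base). The sketch below is a minimal illustration under that assumption; WeightedMSELoss is a hypothetical example and is not part of this module.

```python
import torch
from torch import Tensor
from torch.nn.modules.loss import _Loss


class WeightedMSELoss(_Loss):
    """Hypothetical non-probabilistic loss built on PyTorch's _Loss base."""

    def __init__(self, reduction: str = "mean") -> None:
        # _Loss stores the reduction mode as self.reduction
        super().__init__(reduction=reduction)

    def forward(self, pred: Tensor, target: Tensor, weight: Tensor) -> Tensor:
        # Elementwise weighted squared error
        loss = weight * (pred - target) ** 2
        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        if self.reduction == "none":
            return loss
        raise ValueError(f"Unknown reduction: {self.reduction!r}")


loss_fn = WeightedMSELoss()
value = loss_fn(torch.tensor([1.0, 2.0]), torch.tensor([0.0, 0.0]), torch.tensor([1.0, 0.5]))
```

Here value is the mean of [1.0, 2.0], i.e. 1.5.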

The methods and attributes of each implemented loss are documented below.

ProbabilisticLoss

class objectrl.models.basic.loss.ProbabilisticLoss(reduction: str = 'mean')

Bases: _Loss, ABC

Base class for probabilistic loss functions.

Parameters:
  • reduction (str) – Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'mean'.

reduction

Reduction method for the loss.

Type:

str

__init__(reduction: str = 'mean')

Initialize the loss and store the reduction method.

abstractmethod forward(mu_lvar_dict: dict, y: Tensor) → Tensor

Compute the loss. Must be implemented by subclasses.

Parameters:
  • mu_lvar_dict (dict) – Dictionary containing the predicted mean and log-variance tensors.

  • y (Tensor) – Target tensor.

Returns:

Computed loss.

Return type:

Tensor
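As an illustration of what a concrete subclass's forward might compute, the sketch below evaluates the elementwise Gaussian negative log-likelihood from a predicted mean and log-variance. The dictionary keys "mu" and "lvar" are assumptions, since this page does not name them, and reduction is deliberately left to a separate step.

```python
import math

import torch
from torch import Tensor


def gaussian_nll(mu_lvar_dict: dict, y: Tensor) -> Tensor:
    """Elementwise Gaussian NLL; the keys "mu" and "lvar" are hypothetical."""
    mu = mu_lvar_dict["mu"]
    lvar = mu_lvar_dict["lvar"]
    # -log N(y | mu, exp(lvar)) = 0.5 * (lvar + (y - mu)^2 / exp(lvar) + log(2*pi))
    return 0.5 * (lvar + (y - mu) ** 2 * torch.exp(-lvar) + math.log(2 * math.pi))


# Perfect mean prediction with unit variance leaves only the constant term
nll = gaussian_nll({"mu": torch.zeros(3), "lvar": torch.zeros(3)}, torch.zeros(3))
```

Predicting the log-variance rather than the variance keeps the variance positive without any clamping.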

_apply_reduction(loss: Tensor) → Tensor

Apply the specified reduction to the loss tensor.

Parameters:

loss – Tensor of loss values.

Returns:

Reduced loss tensor based on the specified reduction method.

Return type:

Tensor

Raises:

ValueError – If an unknown reduction method is specified.
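The reduction behavior described above amounts to a three-way dispatch on the reduction string. The following is a minimal standalone sketch of that logic, not the library's own code; apply_reduction is a hypothetical name.

```python
import torch
from torch import Tensor


def apply_reduction(loss: Tensor, reduction: str = "mean") -> Tensor:
    """Hypothetical standalone version of the documented reduction step."""
    if reduction == "none":
        return loss  # keep per-element losses
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    raise ValueError(f"Unknown reduction method: {reduction!r}")


per_element = torch.tensor([1.0, 2.0, 3.0])
total = apply_reduction(per_element, "sum")
```

Here total is tensor(6.). Raising on an unknown string, rather than silently falling back to a default, surfaces configuration typos immediately.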

PACBayesLoss

class objectrl.models.basic.loss.PACBayesLoss(config: PBACConfig)

Bases: ProbabilisticLoss

Implements a PAC-Bayesian loss for critic training based on uncertainty-aware estimates. It computes a Q-learning loss derived from a PAC-Bayes bound that penalizes uncertainty and uses bootstrapping to improve generalization.

__init__(config: PBACConfig)
Parameters:

config (PBACConfig) – Configuration object containing model settings.

forward(q: Tensor, y: Tensor, weights: Tensor | None = None) → Tensor

Computes the PAC-Bayes loss between predicted Q-values and targets.

Parameters:
  • q (Tensor) – Predicted ensemble Q-values, shape [ensemble, batch].

  • y (Tensor) – Target Q-values (shape: [ensemble, batch]).

  • weights (Tensor, optional) – Sample weights (unused here).

Returns:

Scalar loss.

Return type:

Tensor
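This page does not give the exact PAC-Bayes bound used by PACBayesLoss, so the sketch below only illustrates the general shape described above: a bootstrapped empirical risk over [ensemble, batch] tensors plus an uncertainty (complexity) penalty. The penalty chosen here (variance of the Q-estimates across ensemble members) and the parameters bootstrap_prob and complexity_coef are hypothetical stand-ins, not the PBAC objective.

```python
from typing import Optional

import torch
from torch import Tensor


def pac_bayes_style_loss(
    q: Tensor,
    y: Tensor,
    bootstrap_prob: float = 0.8,
    complexity_coef: float = 0.1,
    generator: Optional[torch.Generator] = None,
) -> Tensor:
    """Illustrative 'empirical risk + complexity' loss; NOT the PBAC bound.

    q, y: [ensemble, batch]. bootstrap_prob and complexity_coef are hypothetical.
    """
    # Bootstrapping: each ensemble member is trained on a random subset of the batch
    mask = torch.bernoulli(torch.full_like(q, bootstrap_prob), generator=generator)
    sq_err = (q - y) ** 2
    # Empirical risk: mean squared error over the bootstrapped entries only
    empirical_risk = (mask * sq_err).sum() / mask.sum().clamp(min=1.0)
    # Hypothetical uncertainty penalty: per-sample variance across ensemble members
    ensemble_var = ((q - q.mean(dim=0, keepdim=True)) ** 2).mean(dim=0)
    return empirical_risk + complexity_coef * ensemble_var.mean()


q = torch.tensor([[0.0, 0.0], [2.0, 2.0]])  # 2 members, batch of 2
y = torch.ones(2, 2)
loss = pac_bayes_style_loss(q, y, bootstrap_prob=1.0, complexity_coef=1.0)
```

With bootstrap_prob=1.0 every entry is kept. Here the squared error and the ensemble variance are both 1, so loss is 2.0.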

DSACLoss

class objectrl.models.basic.loss.DSACLoss(config: DSACConfig)

Bases: _Loss

Distributional Soft Actor-Critic (DSAC) Loss Function.

Parameters:

config – Configuration object holding the loss parameters.

_kappa

Huber loss threshold.

Type:

float

__init__(config: DSACConfig)

Initialize the loss from config, including the Huber loss threshold _kappa.

vec_asymmetric_huber_loss_weighted(pred: Tensor, target: Tensor, tau: Tensor, weight: Tensor) → Tensor

Vectorized, weighted asymmetric Huber (quantile Huber) loss.

Parameters:
  • pred (Tensor) – Predicted quantiles: [n_member x n_batch x n_quantiles].

  • target (Tensor) – Target quantiles: [n_member x n_batch x n_quantiles].

  • tau (Tensor) – Quantile levels: [n_quantiles] or [n_batch x n_quantiles].

  • weight (Tensor) – Weights for each quantile: [n_quantiles] or [n_batch x n_quantiles].

Returns:

Computed loss tensor.

Return type:

Tensor

forward(pred: Tensor, target: Tensor, tau: Tensor, target_tau: Tensor) → Tensor

Compute the DSAC loss.

Parameters:
  • pred (Tensor) – Predicted quantiles: [n_member x n_batch x n_quantiles].

  • target (Tensor) – Target quantiles: [n_member x n_batch x n_quantiles].

  • tau (Tensor) – Quantile levels: [n_quantiles] or [n_batch x n_quantiles].

  • target_tau (Tensor) – Target quantile levels: [n_quantiles] or [n_batch x n_quantiles].

Returns:

Computed loss tensor.

Return type:

Tensor
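forward receives tau and target_tau but no explicit weights. One plausible reading, assumed here and not confirmed by this page, is that the weights passed on to vec_asymmetric_huber_loss_weighted are the widths of the probability intervals that each target quantile level closes. This sketch further assumes that target_tau holds sorted right interval boundaries with an implicit left boundary of 0. Under those assumptions the widths take a single torch.diff call. interval_weights is a hypothetical helper.

```python
import torch
from torch import Tensor


def interval_weights(target_tau: Tensor) -> Tensor:
    """Hypothetical: widths of the intervals (tau_{j-1}, tau_j], with tau_0 = 0.

    Works for target_tau of shape [n_quantiles] or [n_batch, n_quantiles].
    """
    # Prepend the implicit left boundary 0 along the quantile axis
    zero = torch.zeros_like(target_tau[..., :1])
    return torch.diff(target_tau, dim=-1, prepend=zero)


weights = interval_weights(torch.tensor([0.25, 0.5, 1.0]))
```

For target_tau = [0.25, 0.5, 1.0] this yields weights [0.25, 0.25, 0.5], which sum to 1 because the last boundary is 1.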