# Source code for hpfracc.ml.losses

"""
Loss Functions with Fractional Calculus Integration

This module provides loss functions that incorporate fractional derivatives,
enabling enhanced training dynamics and potentially better convergence.
Supports multiple backends: PyTorch, JAX, and NUMBA.
"""

import warnings
import numpy as np
from abc import ABC, abstractmethod
from typing import Optional, Any

from hpfracc.core.definitions import FractionalOrder
from .fractional_autograd import fractional_derivative
from .backends import get_backend_manager, BackendType
from .fractional_derivative_native import fractional_feature_map_native
from .tensor_ops import get_tensor_ops


class FractionalLossFunction(ABC):
    """
    Base class for loss functions with fractional calculus integration.

    Subclasses implement :meth:`compute_loss`. :meth:`forward` (also reached
    through ``__call__``) can first apply a fractional derivative to the
    predictions via :meth:`fractional_forward` before computing the loss.
    Supports multiple backends: PyTorch, JAX, and NUMBA.
    """

    def __init__(
            self,
            fractional_order: float = 0.5,
            method: str = "RL",
            backend: Optional[BackendType] = None):
        """
        Initialize the shared fractional-loss state.

        Args:
            fractional_order: Order of the fractional derivative; wrapped in
                ``FractionalOrder``.
            method: Fractional derivative method name (default ``"RL"``).
            backend: Backend to use; when not given (falsy), the backend
                manager's currently active backend is used.
        """
        self.fractional_order = FractionalOrder(fractional_order)
        self.method = method
        # Fall back to the globally active backend when none is given.
        self.backend = backend or get_backend_manager().active_backend
        self.tensor_ops = get_tensor_ops(self.backend)

    def fractional_forward(self, x: Any) -> Any:
        """
        Apply a fractional derivative to the input tensor.

        ``BackendType.AUTO`` is resolved to the active backend first. TORCH
        delegates to ``fractional_derivative``. JAX and NUMBA try the native
        discrete operator along the last axis
        (``fractional_feature_map_native``). If that raises
        ``NotImplementedError``, a ``RuntimeWarning`` is emitted and the
        elementwise proxy ``(|x| + 1e-8) ** alpha * sign(x)`` is returned
        instead. The NUMBA lane first converts ``x`` to a float32 NumPy array.

        Args:
            x: Input tensor

        Returns:
            Tensor with fractional derivative applied
        """
        resolved = self.backend
        if resolved == BackendType.AUTO:
            resolved = get_backend_manager().active_backend

        if resolved == BackendType.TORCH:
            return fractional_derivative(
                x, self.fractional_order.alpha, self.method)

        if resolved == BackendType.JAX:
            import jax.numpy as jnp
            return self._native_or_proxy(
                x, BackendType.JAX, jnp, "JAX",
                "(not a Grünwald–Letnikov / L1 Caputo "
                "derivative along the last axis).")

        # NUMBA lane: NumPy arrays; native discrete GL / L1 Caputo on the
        # last axis is tried first.
        x_arr = np.asarray(x, dtype=np.float32)
        return self._native_or_proxy(
            x_arr, BackendType.NUMBA, np, "NUMBA", "proxy.")

    def _native_or_proxy(self, x, backend, xp, label, proxy_note):
        """Run the native operator on the last axis, or warn and use the elementwise proxy.

        Args:
            x: Input array for ``backend``.
            backend: ``BackendType`` passed to the native operator.
            xp: Array module (``numpy`` or ``jax.numpy``) used by the proxy.
            label: Backend name used as the warning prefix.
            proxy_note: Trailing text of the warning message.
        """
        method = (self.method or "RL").upper()
        alpha = float(self.fractional_order.alpha)
        try:
            return fractional_feature_map_native(
                x,
                backend=backend,
                alpha=alpha,
                method=method,
                axis=-1,
                h=None,
                t_grid=None,
            )
        except NotImplementedError as e:
            warnings.warn(
                f"{label} fractional_forward: native discrete operator "
                f"unavailable ({e!r}); using elementwise |x|^alpha * sign(x) "
                f"{proxy_note}",
                RuntimeWarning,
                # Skip this helper and fractional_forward so the warning
                # points at fractional_forward's caller.
                stacklevel=3,
            )
            # 1e-8 keeps the power well-defined at x == 0.
            return xp.power(xp.abs(x) + 1e-8, alpha) * xp.sign(x)

    @abstractmethod
    def compute_loss(self, predictions: Any, targets: Any) -> Any:
        """
        Compute the base loss function

        Args:
            predictions: Model predictions
            targets: Ground truth targets

        Returns:
            Loss value
        """

    def forward(
            self,
            predictions: Any,
            targets: Any,
            use_fractional: bool = True
    ) -> Any:
        """
        Forward pass for loss computation

        Args:
            predictions: Model predictions
            targets: Ground truth targets
            use_fractional: Whether to apply fractional derivatives to the
                predictions before computing the loss

        Returns:
            Loss value
        """
        if use_fractional:
            predictions = self.fractional_forward(predictions)
        return self.compute_loss(predictions, targets)

    def __call__(self, predictions: Any, targets: Any,
                 use_fractional: bool = True) -> Any:
        """Make the loss function callable (delegates to :meth:`forward`)."""
        return self.forward(predictions, targets, use_fractional)
class FractionalMSELoss(FractionalLossFunction):
    """Mean squared error loss with optional fractional preprocessing of predictions."""

    def __init__(
            self,
            fractional_order: float = 0.5,
            method: str = "RL",
            reduction: str = "mean",
            backend: Optional[BackendType] = None):
        """
        Args:
            fractional_order: Order of the fractional derivative.
            method: Fractional derivative method name.
            reduction: ``"mean"``, ``"sum"``; any other value returns the
                unreduced squared errors.
            backend: Computation backend (``None`` = active backend).
        """
        super().__init__(fractional_order, method, backend)
        self.reduction = reduction

    def compute_loss(self, predictions: Any, targets: Any) -> Any:
        """Return the (reduced) squared error between predictions and targets."""
        if self.backend != BackendType.TORCH:
            # JAX / NUMBA lane: explicit squared error plus manual reduction.
            sq_err = (predictions - targets) ** 2
            if self.reduction == "mean":
                return self.tensor_ops.mean(sq_err)
            if self.reduction == "sum":
                return self.tensor_ops.sum(sq_err)
            return sq_err

        import torch
        import torch.nn.functional as F

        def _as_tensor(value):
            # Non-tensor inputs are promoted to float32 tensors.
            if isinstance(value, torch.Tensor):
                return value
            return torch.tensor(value, dtype=torch.float32)

        pred_t = _as_tensor(predictions)
        tgt_t = _as_tensor(targets)
        return F.mse_loss(pred_t, tgt_t, reduction=self.reduction)
class FractionalCrossEntropyLoss(FractionalLossFunction):
    """Cross-entropy loss with optional fractional preprocessing of predictions."""

    def __init__(
            self,
            fractional_order: float = 0.5,
            method: str = "RL",
            reduction: str = "mean",
            backend: Optional[BackendType] = None):
        """
        Args:
            fractional_order: Order of the fractional derivative.
            method: Fractional derivative method name.
            reduction: ``"mean"``, ``"sum"``; any other value returns the
                per-sample losses.
            backend: Computation backend (``None`` = active backend).
        """
        super().__init__(fractional_order, method, backend)
        self.reduction = reduction

    def compute_loss(self, predictions: Any, targets: Any) -> Any:
        """
        Cross-entropy of raw scores ``predictions`` against ``targets``.

        On the JAX/NUMBA lane, 2-D ``targets`` are treated as one-hot (or
        soft) label rows; otherwise they are class indices into the last
        axis of ``predictions``.
        """
        if self.backend == BackendType.TORCH:
            import torch.nn.functional as F
            return F.cross_entropy(
                predictions, targets, reduction=self.reduction)

        ops = self.tensor_ops
        probs = ops.softmax(predictions, dim=-1)
        # Epsilon keeps log() finite if a probability underflows to zero.
        log_probs = ops.log(probs + 1e-8)

        if len(targets.shape) == 2:
            per_sample = -ops.sum(targets * log_probs, dim=-1)
        else:
            rows = np.arange(predictions.shape[0])
            per_sample = -log_probs[rows, targets]

        if self.reduction == "mean":
            return ops.mean(per_sample)
        if self.reduction == "sum":
            return ops.sum(per_sample)
        return per_sample
class FractionalHuberLoss(FractionalLossFunction):
    """Huber loss with optional fractional preprocessing of predictions."""

    def __init__(
            self,
            fractional_order: float = 0.5,
            method: str = "RL",
            delta: float = 1.0,
            reduction: str = "mean",
            backend: Optional[BackendType] = None):
        """
        Args:
            fractional_order: Order of the fractional derivative.
            method: Fractional derivative method name.
            delta: Threshold between the quadratic and linear regimes.
            reduction: ``"mean"``, ``"sum"``; any other value returns the
                elementwise losses.
            backend: Computation backend (``None`` = active backend).
        """
        super().__init__(fractional_order, method, backend)
        self.delta = delta
        self.reduction = reduction

    def compute_loss(self, predictions: Any, targets: Any) -> Any:
        """Return the (reduced) Huber loss between predictions and targets."""
        if self.backend == BackendType.TORCH:
            import torch.nn.functional as F
            return F.huber_loss(
                predictions, targets, delta=self.delta,
                reduction=self.reduction)

        ops = self.tensor_ops
        abs_err = ops.abs(predictions - targets)
        # Part of |err| inside delta is penalised quadratically, the excess linearly.
        inside = ops.minimum(abs_err, self.delta)
        per_elem = 0.5 * inside ** 2 + self.delta * (abs_err - inside)

        if self.reduction == "mean":
            return ops.mean(per_elem)
        if self.reduction == "sum":
            return ops.sum(per_elem)
        return per_elem
class FractionalSmoothL1Loss(FractionalLossFunction):
    """Smooth L1 loss with fractional calculus integration"""

    def __init__(
            self,
            fractional_order: float = 0.5,
            method: str = "RL",
            beta: float = 1.0,
            reduction: str = "mean",
            backend: Optional[BackendType] = None):
        """
        Args:
            fractional_order: Order of the fractional derivative.
            method: Fractional derivative method name.
            beta: Threshold at which the loss switches from quadratic to
                linear; ``0`` makes it a plain L1 loss.
            reduction: ``"mean"``, ``"sum"``; any other value returns the
                elementwise losses.
            backend: Computation backend (``None`` = active backend).
        """
        super().__init__(fractional_order, method, backend)
        self.beta = beta
        self.reduction = reduction

    def compute_loss(self, predictions: Any, targets: Any) -> Any:
        """Return the (reduced) smooth L1 loss between predictions and targets."""
        if self.backend == BackendType.TORCH:
            import torch.nn.functional as F
            return F.smooth_l1_loss(
                predictions, targets, beta=self.beta,
                reduction=self.reduction)

        # JAX/NUMBA implementation
        ops = self.tensor_ops
        abs_diff = ops.abs(predictions - targets)
        if self.beta == 0:
            # beta == 0 degenerates to L1 (as in PyTorch). The quadratic
            # branch would divide by zero here, producing NaN/inf values
            # (and NaN gradients under JAX) even though where() discards them.
            loss = abs_diff
        else:
            quadratic = 0.5 * (abs_diff ** 2) / self.beta
            linear = abs_diff - 0.5 * self.beta
            loss = ops.where(abs_diff < self.beta, quadratic, linear)

        if self.reduction == "mean":
            return ops.mean(loss)
        if self.reduction == "sum":
            return ops.sum(loss)
        return loss
class FractionalKLDivLoss(FractionalLossFunction):
    """KL Divergence loss with fractional calculus integration"""

    def __init__(
            self,
            fractional_order: float = 0.5,
            method: str = "RL",
            reduction: str = "mean",
            backend: Optional[BackendType] = None):
        """
        Args:
            fractional_order: Order of the fractional derivative.
            method: Fractional derivative method name.
            reduction: ``"mean"``, ``"sum"``, ``"batchmean"``; any other
                value returns the pointwise losses.
            backend: Computation backend (``None`` = active backend).
        """
        super().__init__(fractional_order, method, backend)
        self.reduction = reduction

    def compute_loss(self, predictions: Any, targets: Any) -> Any:
        """
        Pointwise KL divergence, following ``torch.nn.functional.kl_div``.

        Args:
            predictions: Log-probabilities (the ``input`` of ``kl_div``).
            targets: Probabilities (``log_target=False`` convention).

        Returns:
            Reduced loss, or pointwise ``targets * (log(targets) - predictions)``.
        """
        if self.backend == BackendType.TORCH:
            import torch.nn.functional as F
            return F.kl_div(predictions, targets, reduction=self.reduction)

        # JAX/NUMBA implementation of target * (log(target) - input).
        ops = self.tensor_ops
        positive = targets > 0
        # log() is only evaluated on positive targets (1 substituted
        # elsewhere) so that 0 * log(0) contributes 0, as in PyTorch.
        safe_targets = ops.where(positive, targets, 1.0)
        loss = ops.where(
            positive, targets * (ops.log(safe_targets) - predictions), 0.0)

        if self.reduction == "mean":
            return ops.mean(loss)
        if self.reduction == "sum":
            return ops.sum(loss)
        if self.reduction == "batchmean":
            # Sum divided by the batch (leading) dimension, as in PyTorch.
            return ops.sum(loss) / loss.shape[0]
        return loss
class FractionalBCELoss(FractionalLossFunction):
    """Binary cross-entropy loss with optional fractional preprocessing of predictions."""

    def __init__(
            self,
            fractional_order: float = 0.5,
            method: str = "RL",
            reduction: str = "mean",
            backend: Optional[BackendType] = None):
        """
        Args:
            fractional_order: Order of the fractional derivative.
            method: Fractional derivative method name.
            reduction: ``"mean"``, ``"sum"``; any other value returns the
                elementwise losses.
            backend: Computation backend (``None`` = active backend).
        """
        super().__init__(fractional_order, method, backend)
        self.reduction = reduction

    def compute_loss(self, predictions: Any, targets: Any) -> Any:
        """
        Binary cross-entropy between predictions and targets.

        TORCH: inputs are promoted to float32 tensors. If any prediction lies
        outside [0, 1], a sigmoid is applied to all of them first.
        JAX/NUMBA: predictions are only clipped to [1e-8, 1 - 1e-8] (no
        sigmoid).
        """
        if self.backend != BackendType.TORCH:
            ops = self.tensor_ops
            eps = 1e-8
            probs = ops.clip(predictions, eps, 1 - eps)
            per_elem = -targets * ops.log(probs) - (
                1 - targets) * ops.log(1 - probs)
            if self.reduction == "mean":
                return ops.mean(per_elem)
            if self.reduction == "sum":
                return ops.sum(per_elem)
            return per_elem

        import torch
        import torch.nn.functional as F

        probs = predictions
        if not isinstance(probs, torch.Tensor):
            probs = torch.tensor(probs, dtype=torch.float32)
        labels = targets
        if not isinstance(labels, torch.Tensor):
            labels = torch.tensor(labels, dtype=torch.float32)

        out_of_range = probs.min() < 0 or probs.max() > 1
        if out_of_range:
            probs = torch.sigmoid(probs)
        return F.binary_cross_entropy(probs, labels, reduction=self.reduction)
class FractionalNLLLoss(FractionalLossFunction):
    """Negative log-likelihood loss with optional fractional preprocessing of predictions."""

    def __init__(
            self,
            fractional_order: float = 0.5,
            method: str = "RL",
            reduction: str = "mean",
            backend: Optional[BackendType] = None):
        """
        Args:
            fractional_order: Order of the fractional derivative.
            method: Fractional derivative method name.
            reduction: ``"mean"``, ``"sum"``; any other value returns the
                per-sample losses.
            backend: Computation backend (``None`` = active backend).
        """
        super().__init__(fractional_order, method, backend)
        self.reduction = reduction

    def compute_loss(self, predictions: Any, targets: Any) -> Any:
        """Negated entry of ``predictions`` at each sample's target index."""
        if self.backend == BackendType.TORCH:
            import torch.nn.functional as F
            return F.nll_loss(predictions, targets, reduction=self.reduction)

        ops = self.tensor_ops
        # One row per sample; pick the column named by that sample's target.
        sample_idx = np.arange(predictions.shape[0])
        per_sample = -predictions[sample_idx, targets]

        if self.reduction == "mean":
            return ops.mean(per_sample)
        if self.reduction == "sum":
            return ops.sum(per_sample)
        return per_sample
class FractionalPoissonNLLLoss(FractionalLossFunction):
    """Poisson Negative Log Likelihood loss with fractional calculus integration"""

    def __init__(
            self,
            fractional_order: float = 0.5,
            method: str = "RL",
            log_input: bool = True,
            full: bool = False,
            reduction: str = "mean",
            backend: Optional[BackendType] = None):
        """
        Args:
            fractional_order: Order of the fractional derivative.
            method: Fractional derivative method name.
            log_input: If True, predictions are log-rates; otherwise rates.
            full: If True, add the Stirling approximation term.
            reduction: ``"mean"``, ``"sum"``; any other value returns the
                elementwise losses.
            backend: Computation backend (``None`` = active backend).
        """
        super().__init__(fractional_order, method, backend)
        self.log_input = log_input
        self.full = full
        self.reduction = reduction

    def compute_loss(self, predictions: Any, targets: Any) -> Any:
        """Return the (reduced) Poisson negative log-likelihood."""
        if self.backend == BackendType.TORCH:
            import torch.nn.functional as F
            return F.poisson_nll_loss(
                predictions, targets, log_input=self.log_input,
                full=self.full, reduction=self.reduction)

        # JAX/NUMBA implementation
        ops = self.tensor_ops
        if self.log_input:
            # predictions are log(rate)
            loss = ops.exp(predictions) - targets * predictions
        else:
            # predictions are rates; 1e-8 guards log(0)
            loss = predictions - targets * ops.log(predictions + 1e-8)

        if self.full:
            stirling = (targets * ops.log(targets + 1e-8) - targets
                        + 0.5 * ops.log(2 * np.pi * targets + 1e-8))
            # PyTorch only adds the Stirling term where target > 1.
            loss = loss + ops.where(targets > 1, stirling, 0.0)

        if self.reduction == "mean":
            return ops.mean(loss)
        if self.reduction == "sum":
            return ops.sum(loss)
        return loss
class FractionalCosineEmbeddingLoss(FractionalLossFunction):
    """Cosine Embedding loss with fractional calculus integration"""

    def __init__(
            self,
            fractional_order: float = 0.5,
            method: str = "RL",
            margin: float = 0.0,
            reduction: str = "mean",
            backend: Optional[BackendType] = None):
        """
        Args:
            fractional_order: Order of the fractional derivative.
            method: Fractional derivative method name.
            margin: Margin applied to dissimilar pairs.
            reduction: ``"mean"``, ``"sum"``; any other value returns the
                per-pair losses.
            backend: Computation backend (``None`` = active backend).
        """
        super().__init__(fractional_order, method, backend)
        self.margin = margin
        self.reduction = reduction

    def compute_loss(self, predictions: Any, targets: Any) -> Any:
        """
        Cosine embedding loss.

        Args:
            predictions: Pair ``(input1, input2)`` of embedding batches
                (tuple or list of length 2).
            targets: Labels: ``1`` for similar pairs, anything else is treated
                as dissimilar on the JAX/NUMBA lane.

        Raises:
            TypeError: On the TORCH backend when ``predictions`` is not a pair.
        """
        paired = isinstance(predictions, (tuple, list)) and len(predictions) == 2

        if self.backend == BackendType.TORCH:
            import torch.nn.functional as F
            if not paired:
                raise TypeError(
                    "FractionalCosineEmbeddingLoss expects predictions to be "
                    "a pair (input1, input2) and targets to hold the labels")
            input1, input2 = predictions
            return F.cosine_embedding_loss(
                input1, input2, targets, margin=self.margin,
                reduction=self.reduction)

        # JAX/NUMBA implementation
        ops = self.tensor_ops
        if paired:
            x1, x2 = predictions
            labels = targets
        else:
            # NOTE(review): legacy single-input form - targets is used both as
            # the second embedding and as the labels. Kept for backward
            # compatibility; prefer passing (input1, input2).
            x1, x2, labels = predictions, targets, targets

        cos_sim = ops.sum(x1 * x2, dim=-1) / (
            ops.norm(x1, dim=-1) * ops.norm(x2, dim=-1) + 1e-8
        )
        loss = ops.where(
            labels == 1,
            1 - cos_sim,
            ops.maximum(0, cos_sim - self.margin),
        )

        if self.reduction == "mean":
            return ops.mean(loss)
        if self.reduction == "sum":
            return ops.sum(loss)
        return loss
class FractionalMarginRankingLoss(FractionalLossFunction):
    """Margin ranking loss with optional fractional preprocessing of predictions."""

    def __init__(
            self,
            fractional_order: float = 0.5,
            method: str = "RL",
            margin: float = 0.0,
            reduction: str = "mean",
            backend: Optional[BackendType] = None):
        """
        Args:
            fractional_order: Order of the fractional derivative.
            method: Fractional derivative method name.
            margin: Required separation between the two scores.
            reduction: ``"mean"``, ``"sum"``; any other value returns the
                elementwise losses.
            backend: Computation backend (``None`` = active backend).
        """
        super().__init__(fractional_order, method, backend)
        self.margin = margin
        self.reduction = reduction

    def compute_loss(self, predictions: Any, targets: Any) -> Any:
        """
        ``max(0, -targets * (predictions[0] - predictions[1]) + margin)``.

        ``predictions[0]`` and ``predictions[1]`` are the two score sets being
        ranked; ``targets`` says which should be larger.
        """
        first, second = predictions[0], predictions[1]

        if self.backend == BackendType.TORCH:
            import torch.nn.functional as F
            return F.margin_ranking_loss(
                first, second, targets, margin=self.margin,
                reduction=self.reduction)

        ops = self.tensor_ops
        hinge = ops.maximum(0, -targets * (first - second) + self.margin)

        if self.reduction == "mean":
            return ops.mean(hinge)
        if self.reduction == "sum":
            return ops.sum(hinge)
        return hinge
class FractionalMultiMarginLoss(FractionalLossFunction):
    """Multi Margin loss with fractional calculus integration"""

    def __init__(
            self,
            fractional_order: float = 0.5,
            method: str = "RL",
            p: int = 1,
            margin: float = 1.0,
            reduction: str = "mean",
            backend: Optional[BackendType] = None):
        """
        Args:
            fractional_order: Order of the fractional derivative.
            method: Fractional derivative method name.
            p: Exponent applied to each hinge term.
            margin: Required margin between the target score and the others.
            reduction: ``"mean"``, ``"sum"``; any other value returns the
                per-sample losses.
            backend: Computation backend (``None`` = active backend).
        """
        super().__init__(fractional_order, method, backend)
        self.p = p
        self.margin = margin
        self.reduction = reduction

    def compute_loss(self, predictions: Any, targets: Any) -> Any:
        """
        Multi-class hinge loss, matching ``torch.nn.functional.multi_margin_loss``.

        Per sample: ``sum_{i != y} max(0, margin - x[y] + x[i]) ** p / C``,
        where ``C = predictions.shape[1]``.

        Args:
            predictions: Scores of shape ``(batch, num_classes)``.
            targets: Integer class index per sample.
        """
        if self.backend == BackendType.TORCH:
            import torch.nn.functional as F
            return F.multi_margin_loss(
                predictions, targets, p=self.p, margin=self.margin,
                reduction=self.reduction)

        # JAX/NUMBA implementation, vectorized. Avoids JAX-only `.at[]`
        # updates so NumPy arrays work too.
        ops = self.tensor_ops
        if isinstance(targets, (list, tuple)):
            targets = np.asarray(targets)
        num_classes = predictions.shape[1]
        rows = np.arange(predictions.shape[0])

        # Score of each sample's target class, as a column for broadcasting.
        target_scores = predictions[rows, targets][:, None]
        margins = ops.maximum(0, self.margin - target_scores + predictions)
        # The target class itself does not contribute to the loss.
        is_target = targets[:, None] == np.arange(num_classes)[None, :]
        margins = ops.where(is_target, 0, margins)
        loss = ops.sum(margins ** self.p, dim=-1) / num_classes

        if self.reduction == "mean":
            return ops.mean(loss)
        if self.reduction == "sum":
            return ops.sum(loss)
        return loss
class FractionalTripletMarginLoss(FractionalLossFunction):
    """Triplet margin loss with optional fractional preprocessing of predictions."""

    def __init__(
            self,
            fractional_order: float = 0.5,
            method: str = "RL",
            margin: float = 1.0,
            p: float = 2.0,
            reduction: str = "mean",
            backend: Optional[BackendType] = None):
        """
        Args:
            fractional_order: Order of the fractional derivative.
            method: Fractional derivative method name.
            margin: Required gap between positive and negative distances.
            p: Norm degree used for the distances.
            reduction: ``"mean"``, ``"sum"``; any other value returns the
                per-triplet losses.
            backend: Computation backend (``None`` = active backend).
        """
        super().__init__(fractional_order, method, backend)
        self.margin = margin
        self.p = p
        self.reduction = reduction

    def compute_loss(self, predictions: Any, targets: Any) -> Any:
        """
        ``max(0, d(anchor, positive) - d(anchor, negative) + margin)``.

        ``predictions`` holds ``(anchor, positive, negative)`` at indices
        0, 1, 2; ``targets`` is not used.
        """
        anchor = predictions[0]
        positive = predictions[1]
        negative = predictions[2]

        if self.backend == BackendType.TORCH:
            import torch.nn.functional as F
            return F.triplet_margin_loss(
                anchor, positive, negative, margin=self.margin, p=self.p,
                reduction=self.reduction)

        ops = self.tensor_ops
        d_pos = ops.norm(anchor - positive, p=self.p, dim=-1)
        d_neg = ops.norm(anchor - negative, p=self.p, dim=-1)
        hinge = ops.maximum(0, d_pos - d_neg + self.margin)

        if self.reduction == "mean":
            return ops.mean(hinge)
        if self.reduction == "sum":
            return ops.sum(hinge)
        return hinge
class FractionalCTCLoss(FractionalLossFunction):
    """Connectionist Temporal Classification loss with fractional calculus integration"""

    def __init__(
            self,
            fractional_order: float = 0.5,
            method: str = "RL",
            blank: int = 0,
            reduction: str = "mean",
            backend: Optional[BackendType] = None):
        """
        Args:
            fractional_order: Order of the fractional derivative.
            method: Fractional derivative method name.
            blank: Index of the blank label.
            reduction: ``"mean"``, ``"sum"``; any other value returns the
                per-sample losses.
            backend: Computation backend (``None`` = active backend).
        """
        super().__init__(fractional_order, method, backend)
        self.blank = blank
        self.reduction = reduction

    def compute_loss(self, predictions: Any, targets: Any) -> Any:
        """
        CTC loss.

        TORCH: ``predictions`` is either ``(log_probs, input_lengths,
        target_lengths)`` or just ``log_probs``. In the latter case every
        sequence is assumed to span the full time axis and ``targets`` to
        be padded (see :meth:`_full_lengths`).
        JAX/NUMBA: a simplified placeholder, not a true CTC loss.
        """
        if self.backend == BackendType.TORCH:
            import torch.nn.functional as F
            if isinstance(predictions, tuple) and len(predictions) == 3:
                log_probs, input_lengths, target_lengths = predictions
            else:
                log_probs = predictions
                input_lengths, target_lengths = self._full_lengths(
                    log_probs, targets)
            return F.ctc_loss(
                log_probs, targets, input_lengths, target_lengths,
                blank=self.blank, reduction=self.reduction)

        # JAX/NUMBA implementation: simplified placeholder (one-hot alignment
        # of target[j] to time step j), not a real CTC forward pass.
        # NOTE(review): uses JAX-style `.at[...]` updates; NumPy arrays have
        # no `.at`, so this presumably only works on the JAX backend.
        batch_size = predictions.shape[0]
        loss = self.tensor_ops.zeros(batch_size)
        for i in range(batch_size):
            pred = predictions[i]
            target = targets[i]
            pred_probs = self.tensor_ops.softmax(pred, dim=-1)
            target_probs = self.tensor_ops.zeros_like(pred_probs)
            target_probs = target_probs.at[np.arange(
                len(target)), target].set(1.0)
            loss = loss.at[i].set(-self.tensor_ops.sum(
                target_probs * self.tensor_ops.log(pred_probs + 1e-8)))

        if self.reduction == "mean":
            return self.tensor_ops.mean(loss)
        if self.reduction == "sum":
            return self.tensor_ops.sum(loss)
        return loss

    @staticmethod
    def _full_lengths(log_probs: Any, targets: Any):
        """Build full-length ``(input_lengths, target_lengths)`` for ``F.ctc_loss``.

        Supports batched ``log_probs`` of shape ``(T, N, C)`` with padded
        ``targets`` of shape ``(N, S)``, and unbatched ``(T, C)`` with
        ``targets`` of shape ``(S,)``.

        Raises:
            TypeError: If the shapes do not allow the lengths to be inferred.
        """
        import torch
        time_steps = log_probs.shape[0]
        if log_probs.dim() == 2 and targets.dim() == 1:
            # Unbatched: lengths are 0-dim tensors.
            return (torch.tensor(time_steps, dtype=torch.long),
                    torch.tensor(targets.shape[0], dtype=torch.long))
        if log_probs.dim() == 3 and targets.dim() == 2:
            batch = log_probs.shape[1]
            return (torch.full((batch,), time_steps, dtype=torch.long),
                    torch.full((batch,), targets.shape[1], dtype=torch.long))
        raise TypeError(
            "FractionalCTCLoss: pass predictions as (log_probs, input_lengths, "
            "target_lengths) unless log_probs is (T, N, C) with padded (N, S) "
            "targets or unbatched (T, C) with (S,) targets")
class FractionalCustomLoss(FractionalLossFunction):
    """Wraps an arbitrary ``loss_fn(predictions, targets)`` callable as a fractional loss."""

    def __init__(
            self,
            loss_fn,
            fractional_order: float = 0.5,
            method: str = "RL",
            backend: Optional[BackendType] = None):
        """
        Args:
            loss_fn: Callable taking ``(predictions, targets)`` and returning
                the loss.
            fractional_order: Order of the fractional derivative.
            method: Fractional derivative method name.
            backend: Computation backend (``None`` = active backend).
        """
        super().__init__(fractional_order, method, backend)
        # ``custom_loss_fn`` is an alias of ``loss_fn`` kept for backward
        # compatibility.
        self.loss_fn = self.custom_loss_fn = loss_fn

    def compute_loss(self, predictions: Any, targets: Any) -> Any:
        """Delegate to the wrapped callable."""
        fn = self.loss_fn
        return fn(predictions, targets)
class FractionalCombinedLoss(FractionalLossFunction):
    """Weighted sum of several loss functions, with optional fractional preprocessing."""

    def __init__(
            self,
            loss_functions: list,
            weights: list = None,
            fractional_order: float = 0.5,
            method: str = "RL",
            backend: Optional[BackendType] = None):
        """
        Args:
            loss_functions: Losses to combine (fractional losses or plain
                ``fn(predictions, targets)`` callables).
            weights: One weight per loss; defaults to all ``1.0``.
            fractional_order: Order of the fractional derivative.
            method: Fractional derivative method name.
            backend: Computation backend (``None`` = active backend).

        Raises:
            ValueError: If the number of weights and losses differ.
        """
        super().__init__(fractional_order, method, backend)
        self.loss_functions = loss_functions
        if not weights:
            weights = [1.0] * len(loss_functions)
        self.weights = weights
        if len(self.weights) != len(self.loss_functions):
            raise ValueError(
                "Number of weights must match number of loss functions")

    def compute_loss(self, predictions: Any, targets: Any) -> Any:
        """Return ``sum(weight_i * loss_i(predictions, targets))``."""
        total_loss = 0.0
        for member, w in zip(self.loss_functions, self.weights):
            if isinstance(member, FractionalLossFunction):
                # forward() of this combined loss already applies the
                # fractional derivative (when requested), so member
                # fractional losses must not apply it a second time.
                term = member(predictions, targets, use_fractional=False)
            else:
                term = member(predictions, targets)
            total_loss += w * term
        return total_loss
# ============================================================================
# SDE-SPECIFIC LOSS FUNCTIONS
# ============================================================================
class FractionalSDEMSELoss(FractionalLossFunction):
    """
    MSE loss for fractional SDE trajectory matching.

    Computes the L2 distance between predicted and target trajectories. When
    the predictions carry a leading stochastic-sample axis, they are first
    averaged over it.
    """

    def __init__(
            self,
            num_samples: int = 10,
            reduction: str = "mean",
            fractional_order: float = 0.5,
            method: str = "RL",
            backend: Optional[BackendType] = None
    ):
        """
        Initialize SDE MSE loss.

        Args:
            num_samples: Number of stochastic samples to average over (stored;
                not read by ``compute_loss``)
            reduction: Reduction type ("mean", "sum", "none")
            fractional_order: Fractional order for loss computation
            method: Fractional derivative method
            backend: Computation backend
        """
        super().__init__(fractional_order, method, backend)
        self.num_samples = num_samples
        self.reduction = reduction

    def compute_loss(self, predictions: Any, targets: Any) -> Any:
        """
        Compute MSE loss for SDE trajectories.

        Args:
            predictions: Predicted trajectories; a 3-D array is treated as
                ``(num_samples, batch, features)`` and averaged over axis 0,
                unless ``targets`` is itself 3-D.
            targets: Target trajectories

        Returns:
            MSE loss
        """
        # Only collapse the sample axis when targets do not already have the
        # same 3-D shape. Otherwise a batch of 3-D trajectories would be
        # wrongly averaged over its batch dimension.
        has_sample_axis = np.ndim(predictions) == 3 and np.ndim(targets) < 3

        if self.backend == BackendType.TORCH:
            import torch.nn.functional as F
            if has_sample_axis:
                predictions = predictions.mean(dim=0)
            return F.mse_loss(predictions, targets, reduction=self.reduction)

        # JAX/NUMBA implementation
        if has_sample_axis:
            predictions = self.tensor_ops.mean(predictions, axis=0)
        squared_diff = (predictions - targets) ** 2
        if self.reduction == "mean":
            return self.tensor_ops.mean(squared_diff)
        if self.reduction == "sum":
            return self.tensor_ops.sum(squared_diff)
        return squared_diff
class FractionalKLDivergenceLoss(FractionalLossFunction):
    """
    KL divergence loss for matching SDE distributions.

    Measures the divergence of the predicted distribution from the target
    distribution, summed over all elements.
    """

    def __init__(
            self,
            eps: float = 1e-8,
            fractional_order: float = 0.5,
            method: str = "RL",
            backend: Optional[BackendType] = None
    ):
        """
        Initialize KL divergence loss.

        Args:
            eps: Small constant for numerical stability (clamping bound)
            fractional_order: Fractional order for loss computation
            method: Fractional derivative method
            backend: Computation backend
        """
        super().__init__(fractional_order, method, backend)
        self.eps = eps

    def compute_loss(self, predictions: Any, targets: Any) -> Any:
        """
        Compute ``sum(targets * log(targets / predictions))``.

        Predictions are clamped to ``[eps, 1 - eps]`` and targets to
        ``[eps, 1]`` so both logarithm arguments stay strictly positive.

        Args:
            predictions: Predicted probabilities
            targets: Target probabilities

        Returns:
            Scalar KL divergence (always summed; no reduction option)
        """
        if self.backend == BackendType.TORCH:
            import torch
            pred = torch.clamp(predictions, self.eps, 1.0 - self.eps)
            tgt = torch.clamp(targets, self.eps, 1.0)
            kl = tgt * torch.log(tgt / pred)
            return kl.sum()

        # JAX/NUMBA implementation
        pred = self.tensor_ops.clip(predictions, self.eps, 1.0 - self.eps)
        tgt = self.tensor_ops.clip(targets, self.eps, 1.0)
        kl = tgt * self.tensor_ops.log(tgt / pred)
        return self.tensor_ops.sum(kl)
class FractionalPathwiseLoss(FractionalLossFunction):
    """
    Pathwise trajectory loss with an uncertainty penalty.

    Adds a weighted spread term (variance across stochastic samples) to the
    squared trajectory error.
    """

    def __init__(
            self,
            uncertainty_weight: float = 1.0,
            fractional_order: float = 0.5,
            method: str = "RL",
            backend: Optional[BackendType] = None
    ):
        """
        Args:
            uncertainty_weight: Multiplier for the variance term.
            fractional_order: Fractional order for loss computation.
            method: Fractional derivative method.
            backend: Computation backend.
        """
        super().__init__(fractional_order, method, backend)
        self.uncertainty_weight = uncertainty_weight

    def compute_loss(self, predictions: Any, targets: Any) -> Any:
        """
        Mean of ``(predictions - targets) ** 2 + uncertainty_weight * var``.

        ``var`` is the variance of ``predictions`` over axis 0 when
        ``predictions`` is 3-D (presumably ``(num_samples, batch, features)``),
        and zero otherwise.
        """
        if self.backend == BackendType.TORCH:
            import torch
            sq_err = (predictions - targets) ** 2
            if predictions.dim() == 3:
                spread = predictions.var(dim=0)
            else:
                spread = torch.zeros_like(sq_err)
            return (sq_err + self.uncertainty_weight * spread).mean()

        ops = self.tensor_ops
        sq_err = (predictions - targets) ** 2
        if len(predictions.shape) == 3:
            spread = ops.var(predictions, axis=0)
        else:
            spread = ops.zeros_like(sq_err)
        return ops.mean(sq_err + self.uncertainty_weight * spread)
class FractionalMomentMatchingLoss(FractionalLossFunction):
    """
    Moment matching loss for SDE distributions.

    Matches statistical moments (mean, variance, higher central moments)
    between predicted and target samples.
    """

    def __init__(
            self,
            moments: list = None,
            weights: list = None,
            fractional_order: float = 0.5,
            method: str = "RL",
            backend: Optional[BackendType] = None
    ):
        """
        Initialize moment matching loss.

        Args:
            moments: List of moments to match (default: [1, 2] for mean, variance)
            weights: Weights for each moment (default: all 1.0)
            fractional_order: Fractional order for loss computation
            method: Fractional derivative method
            backend: Computation backend

        Raises:
            ValueError: If the number of weights and moments differ.
        """
        super().__init__(fractional_order, method, backend)
        self.moments = moments or [1, 2]
        self.weights = weights or [1.0] * len(self.moments)
        if len(self.weights) != len(self.moments):
            raise ValueError("Number of weights must match number of moments")

    def compute_loss(self, predictions: Any, targets: Any) -> Any:
        """
        Compute ``sum_k w_k * (moment_k(predictions) - moment_k(targets)) ** 2``.

        For inputs with more than one dimension the moments are taken over
        axis 0, so the result has one entry per remaining feature position.

        Args:
            predictions: Predicted samples
            targets: Target samples

        Returns:
            Moment matching loss
        """
        # Identical for every backend; backend specifics live in _compute_moment.
        total_loss = 0.0
        for moment, weight in zip(self.moments, self.weights):
            pred_moment = self._compute_moment(predictions, moment)
            tgt_moment = self._compute_moment(targets, moment)
            total_loss += weight * (pred_moment - tgt_moment) ** 2
        return total_loss

    def _compute_moment(self, x: Any, order: int) -> Any:
        """
        Compute the statistical moment of the given order.

        Order 1 is the mean, order 2 the variance (backend default ``var``),
        and higher orders are central moments. For inputs with more than one
        dimension every moment is taken per position over axis 0.
        """
        if self.backend == BackendType.TORCH:
            batched = x.dim() > 1
            if order == 1:
                return x.mean(dim=0) if batched else x.mean()
            if order == 2:
                return x.var(dim=0) if batched else x.var()
            # Higher moments: center each position on its own mean,
            # consistent with orders 1 and 2.
            if batched:
                return ((x - x.mean(dim=0)) ** order).mean(dim=0)
            return ((x - x.mean()) ** order).mean()

        # JAX/NUMBA implementation
        ops = self.tensor_ops
        batched = len(x.shape) > 1
        if order == 1:
            return ops.mean(x, axis=0) if batched else ops.mean(x)
        if order == 2:
            return ops.var(x, axis=0) if batched else ops.var(x)
        if batched:
            centered = x - ops.mean(x, axis=0)
            return ops.mean(centered ** order, axis=0)
        return ops.mean((x - ops.mean(x)) ** order)