"""
Neural Fractional Stochastic Differential Equations
This module provides neural network-based fractional SDEs with adjoint
training methods for efficient gradient-based learning.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Union, Optional, Tuple, Callable
from dataclasses import dataclass
import numpy as np
from hpfracc.core.definitions import FractionalOrder, validate_fractional_order
from ..ml.neural_ode import BaseNeuralODE, NeuralODEConfig
from ..ml.adjoint_optimization import AdjointConfig
from ..solvers.sde_solvers import solve_fractional_sde, FractionalSDESolver
[docs]
@dataclass
class NeuralFSDEConfig(NeuralODEConfig):
"""Configuration for neural fractional SDE models."""
diffusion_dim: int = 1 # Dimension of noise
noise_type: str = "additive" # "additive" or "multiplicative"
drift_net: Optional[nn.Module] = None
diffusion_net: Optional[nn.Module] = None
use_sde_adjoint: bool = True # Use SDE-specific adjoint method
learn_alpha: bool = False # Whether to learn fractional order
[docs]
class NeuralFractionalSDE(BaseNeuralODE):
"""
Neural network-based fractional SDE with adjoint training.
Extends neural ODE framework to fractional stochastic differential equations
for modeling stochastic dynamics with memory effects.
The model learns:
- Drift function f(t, x): deterministic dynamics
- Diffusion function g(t, x): stochastic noise magnitude
- Fractional order: memory effects in dynamics
"""
[docs]
def __init__(self, config: NeuralFSDEConfig):
"""
Initialize neural fractional SDE.
Args:
config: Configuration for the neural fSDE
"""
super().__init__(config)
self.config = config
self.diffusion_dim = config.diffusion_dim
self.noise_type = config.noise_type
# Build drift and diffusion networks
self._build_drift_network()
self._build_diffusion_network()
# Fractional order parameter
if isinstance(config.fractional_order, float):
self.fractional_order_value = config.fractional_order
else:
self.fractional_order_value = config.fractional_order.alpha
# Initialize learned fractional order if needed
self.learn_alpha = getattr(config, 'learn_alpha', False)
if self.learn_alpha:
self.alpha_param = nn.Parameter(torch.tensor(self.fractional_order_value))
[docs]
def _build_drift_network(self):
"""Build neural network for drift function."""
if self.config.drift_net is not None:
self.drift_net = self.config.drift_net
# Validate custom network input dimension
expected_input_dim = self.input_dim + 1 # time + state
actual_input_dim = None
# Find first Linear layer in the network
for module in self.drift_net.modules():
if isinstance(module, nn.Linear):
actual_input_dim = module.in_features
break
if actual_input_dim is not None and actual_input_dim != expected_input_dim:
raise ValueError(
f"Custom drift network input dimension mismatch: "
f"expected {expected_input_dim} (input_dim + 1 = {self.input_dim} + 1), "
f"but network has {actual_input_dim}. "
f"Ensure the first Linear layer has in_features={expected_input_dim}."
)
else:
# Default drift network
self.drift_net = nn.Sequential(
nn.Linear(self.input_dim + 1, self.hidden_dim),
nn.Tanh(),
nn.Linear(self.hidden_dim, self.hidden_dim),
nn.Tanh(),
nn.Linear(self.hidden_dim, self.output_dim)
)
[docs]
def _build_diffusion_network(self):
"""Build neural network for diffusion function."""
if self.config.diffusion_net is not None:
self.diffusion_net = self.config.diffusion_net
# Validate custom network input dimension
expected_input_dim = self.input_dim + 1 # time + state
actual_input_dim = None
# Find first Linear layer in the network
for module in self.diffusion_net.modules():
if isinstance(module, nn.Linear):
actual_input_dim = module.in_features
break
if actual_input_dim is not None and actual_input_dim != expected_input_dim:
raise ValueError(
f"Custom diffusion network input dimension mismatch: "
f"expected {expected_input_dim} (input_dim + 1 = {self.input_dim} + 1), "
f"but network has {actual_input_dim}. "
f"Ensure the first Linear layer has in_features={expected_input_dim}."
)
else:
# Default diffusion network
# Output shape: (batch, output_dim, diffusion_dim) for multiplicative
# or (batch, diffusion_dim) for additive
if self.noise_type == "multiplicative":
output_dim = self.output_dim * self.diffusion_dim
else:
output_dim = self.diffusion_dim
self.diffusion_net = nn.Sequential(
nn.Linear(self.input_dim + 1, self.hidden_dim),
nn.Tanh(),
nn.Linear(self.hidden_dim, self.hidden_dim),
nn.Tanh(),
nn.Linear(self.hidden_dim, output_dim)
)
# Use softplus to ensure positive diffusion
self.diffusion_activation = nn.Softplus()
[docs]
def drift_function(self, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
"""Alias for drift() for compatibility with tests."""
return self.drift(t, x)
[docs]
def diffusion_function(self, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
"""Alias for diffusion() for compatibility with tests."""
return self.diffusion(t, x)
[docs]
def drift(self, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
"""
Compute drift function f(t, x).
Args:
t: Time tensor
x: State tensor
Returns:
Drift vector
"""
# Ensure x has batch dimension
if x.dim() == 1:
x = x.unsqueeze(0)
batch_size = x.shape[0]
# Ensure t has proper shape (batch, 1)
if t.dim() == 0: # Scalar
t_expanded = t.unsqueeze(0).unsqueeze(0).expand(batch_size, 1)
elif t.dim() == 1: # 1D tensor
if t.shape[0] == 1: # Single time value
t_expanded = t.unsqueeze(0).expand(batch_size, 1)
else: # Multiple time values (should match batch)
t_expanded = t.unsqueeze(-1)
else: # Already 2D
t_expanded = t
# Concatenate [t, x]
input_tensor = torch.cat([t_expanded, x], dim=-1)
# Forward pass through drift network
drift = self.drift_net(input_tensor)
return drift
[docs]
def diffusion(self, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
"""
Compute diffusion function g(t, x).
Args:
t: Time tensor
x: State tensor
Returns:
Diffusion matrix
"""
# Ensure x has batch dimension
if x.dim() == 1:
x = x.unsqueeze(0)
batch_size = x.shape[0]
# Ensure t has proper shape (batch, 1)
if t.dim() == 0: # Scalar
t_expanded = t.unsqueeze(0).unsqueeze(0).expand(batch_size, 1)
elif t.dim() == 1: # 1D tensor
if t.shape[0] == 1: # Single time value
t_expanded = t.unsqueeze(0).expand(batch_size, 1)
else: # Multiple time values (should match batch)
t_expanded = t.unsqueeze(-1)
else: # Already 2D
t_expanded = t
# Concatenate [t, x]
input_tensor = torch.cat([t_expanded, x], dim=-1)
# Forward pass through diffusion network
diffusion = self.diffusion_net(input_tensor)
diffusion = self.diffusion_activation(diffusion)
# Reshape for multiplicative noise
if self.noise_type == "multiplicative":
diffusion = diffusion.view(batch_size, self.output_dim, self.diffusion_dim)
return diffusion
[docs]
def fractional_order(self) -> float:
"""Get current fractional order."""
if self.learn_alpha:
# Clamp to valid range (0, 2)
return torch.clamp(self.alpha_param, 0.1, 1.9).item()
return self.fractional_order_value
[docs]
def forward(
self,
x0: torch.Tensor,
t: torch.Tensor,
method: str = "euler_maruyama",
num_steps: int = 100,
seed: Optional[int] = None
) -> torch.Tensor:
"""
Forward pass through neural fractional SDE.
Args:
x0: Initial condition
t: Time points (1D tensor or 2D batch)
method: Solver method
num_steps: Number of integration steps
seed: Random seed for reproducibility
Returns:
Trajectory solution
"""
# Ensure inputs are tensors
if not isinstance(x0, torch.Tensor):
x0 = torch.tensor(x0, dtype=torch.float32)
if not isinstance(t, torch.Tensor):
t = torch.tensor(t, dtype=torch.float32)
# Handle batch dimensions
if x0.dim() == 1:
x0 = x0.unsqueeze(0) # (1, dim)
batch_size, dim = x0.shape
# Get time span
if t.dim() > 1:
t_flat = t.flatten()
else:
t_flat = t
t_start = t_flat[0]
t_end = t_flat[-1]
# Use the PyTorch solver directly
return self._solve_fractional_sde_torch(
x0, t_start, t_end, num_steps, seed
)
[docs]
def _solve_fractional_sde_torch(
self,
x0: torch.Tensor,
t_start: torch.Tensor,
t_end: torch.Tensor,
num_steps: int,
seed: Optional[int] = None
) -> torch.Tensor:
"""
PyTorch-native fractional SDE solver (Euler-Maruyama).
"""
device = x0.device
dtype = x0.dtype
batch_size, dim = x0.shape
dt = (t_end - t_start) / num_steps
alpha = self.fractional_order()
# Gamma factor: 1 / Gamma(alpha + 1)
# Use torch.lgamma for stability: exp(lgamma(x)) = Gamma(x)
if isinstance(alpha, torch.Tensor):
gamma_val = torch.exp(torch.lgamma(alpha + 1))
else:
gamma_val = torch.exp(torch.lgamma(torch.tensor(alpha + 1.0, device=device, dtype=dtype)))
gamma_factor = 1.0 / gamma_val
# Precompute weights
# weights[k] = (k+1)^alpha - k^alpha
# We need to handle alpha as tensor or float
k_vals = torch.arange(num_steps + 1, device=device, dtype=dtype)
if isinstance(alpha, torch.Tensor):
weights = (k_vals + 1).pow(alpha) - k_vals.pow(alpha)
else:
weights = (k_vals + 1).pow(alpha) - k_vals.pow(alpha)
# Initialize history
# We need to store history for convolution
# Shape: (num_steps, batch_size, dim)
drift_history = []
diffusion_history = []
# Initialize trajectory
# Shape: (num_steps + 1, batch_size, dim)
trajectory = [x0]
curr_x = x0
curr_t = t_start
if seed is not None:
torch.manual_seed(seed)
for i in range(num_steps):
# Compute drift and diffusion
# drift shape: (batch_size, dim)
drift_val = self.drift(curr_t, curr_x)
diffusion_val = self.diffusion(curr_t, curr_x)
# Generate noise
# Shape: (batch_size, diffusion_dim)
dW = torch.randn(batch_size, self.diffusion_dim, device=device, dtype=dtype) * torch.sqrt(dt)
# Store history
drift_history.append(drift_val)
# Handle noise term
if self.noise_type == "multiplicative":
# diffusion_val is (batch, dim, diffusion_dim)
# dW is (batch, diffusion_dim)
# Result should be (batch, dim)
noise_term = torch.bmm(diffusion_val, dW.unsqueeze(-1)).squeeze(-1)
else:
# Additive or diagonal
# diffusion_val is (batch, diffusion_dim)
# dW is (batch, diffusion_dim)
if dim == self.diffusion_dim:
# Diagonal noise
noise_term = diffusion_val * dW
elif self.diffusion_dim == 1:
# Scalar noise broadcasted
noise_term = diffusion_val * dW
# If result is (batch, 1), it will broadcast during addition
else:
# Mismatched dimensions for additive noise without matrix
# This might be an issue if not handled, but for now assume broadcasting or error
# If diffusion_val is (batch, m) and dW is (batch, m), we get (batch, m).
# If m != d, we can't add to drift (batch, d) unless m=1.
if diffusion_val.shape[-1] != dim:
# Try to treat as diagonal if possible or raise error?
# For now, let's assume if it's not multiplicative, it's element-wise compatible
noise_term = diffusion_val * dW
else:
noise_term = diffusion_val * dW
diffusion_history.append(noise_term)
# Compute memory terms via convolution
# We need to sum weights[i-j] * history[j] for j=0..i
# Efficient way using torch operations
# Stack history so far
# drift_hist_stack: (i+1, batch, dim)
drift_hist_stack = torch.stack(drift_history)
diff_hist_stack = torch.stack(diffusion_history)
# Get weights for this step: w_i, w_{i-1}, ..., w_0
# weights is 1D: (num_steps+1,)
# We need weights[0]...weights[i]
# And we need to reverse them to match history:
# sum_{j=0}^i w_{i-j} h_j
# w_{i} * h_0 + w_{i-1} * h_1 + ... + w_0 * h_i
current_weights = weights[:i+1].flip(0) # (i+1,)
# Reshape weights for broadcasting: (i+1, 1, 1)
w_reshaped = current_weights.view(-1, 1, 1)
# Weighted sum
drift_integral = (w_reshaped * drift_hist_stack).sum(dim=0)
diffusion_integral = (w_reshaped * diff_hist_stack).sum(dim=0)
# Update step
# X_{i+1} = X_0 + h^alpha / Gamma(alpha+1) * (drift_int + diff_int)
next_x = x0 + gamma_factor * dt.pow(alpha) * (drift_integral + diffusion_integral)
trajectory.append(next_x)
curr_x = next_x
curr_t = curr_t + dt
# Stack trajectory
# (num_steps + 1, batch_size, dim)
trajectory_tensor = torch.stack(trajectory)
# Return as (time_steps, batch_size, dim)
return trajectory_tensor
[docs]
def get_fractional_order(self) -> Union[float, torch.Tensor]:
"""Get the fractional order parameter."""
if self.learn_alpha:
return torch.clamp(self.alpha_param, 0.1, 1.9)
return self.fractional_order_value
[docs]
def adjoint_forward(self, x0: torch.Tensor, t: torch.Tensor, **kwargs) -> torch.Tensor:
"""Adjoint-compatible forward pass."""
return self.forward(x0, t, **kwargs)
[docs]
def create_neural_fsde(
input_dim: int = None,
output_dim: int = None,
hidden_dim: int = 64,
num_layers: int = 3,
fractional_order: float = 0.5,
diffusion_dim: int = 1,
noise_type: str = "additive",
learn_alpha: bool = False,
use_adjoint: bool = True,
drift_net: Optional[nn.Module] = None,
diffusion_net: Optional[nn.Module] = None,
config: Optional[NeuralFSDEConfig] = None
) -> NeuralFractionalSDE:
"""
Factory function to create a neural fractional SDE.
Args:
input_dim: Input dimension
output_dim: Output dimension
hidden_dim: Hidden layer dimension
num_layers: Number of hidden layers
fractional_order: Initial fractional order
diffusion_dim: Dimension of noise
noise_type: Type of noise ("additive" or "multiplicative")
learn_alpha: Whether to learn fractional order
use_adjoint: Use adjoint method for backpropagation
drift_net: Custom drift network
diffusion_net: Custom diffusion network
Returns:
NeuralFractionalSDE instance
"""
if config is not None:
# Use provided config
pass
else:
# Create config from parameters
if input_dim is None or output_dim is None:
raise ValueError("input_dim and output_dim must be provided when config is None")
config = NeuralFSDEConfig(
input_dim=input_dim,
hidden_dim=hidden_dim,
output_dim=output_dim,
num_layers=num_layers,
fractional_order=fractional_order,
use_adjoint=use_adjoint,
diffusion_dim=diffusion_dim,
noise_type=noise_type,
learn_alpha=learn_alpha,
drift_net=drift_net,
diffusion_net=diffusion_net
)
model = NeuralFractionalSDE(config)
model.learn_alpha = learn_alpha
return model