#!/usr/bin/env python3
"""
Targeted Optimized Neural Fractional Ordinary Differential Equations (Neural fODE)
This module provides targeted optimizations for neural networks that can learn
to represent fractional differential equations, focusing on high-impact improvements
without adding unnecessary complexity.
Key Improvements:
- Optimized fractional ODE implementation (proper fractional calculus)
- Advanced solver options with better performance
- Memory optimization for large inputs
- Improved training efficiency
- Performance monitoring without overhead
Author: Davian R. Chin, Department of Biomedical Engineering, University of Reading
Targeted Optimization: September 2025
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
from typing import Union, Optional
from abc import ABC, abstractmethod
from dataclasses import dataclass
import warnings
import math
# Import from relative paths
from hpfracc.core.definitions import FractionalOrder
from hpfracc.core.utilities import validate_fractional_order
# ============================================================================
# TARGETED CONFIGURATION
# ============================================================================
[docs]
@dataclass
class NeuralODEConfig:
"""Targeted configuration for neural ODE models"""
input_dim: int = 2
hidden_dim: int = 64
output_dim: int = 2
num_layers: int = 3
activation: str = "tanh"
use_adjoint: bool = True
solver: str = "dopri5"
rtol: float = 1e-5
atol: float = 1e-5
fractional_order: Optional[Union[float, FractionalOrder]] = None
device: Optional[torch.device] = None
dtype: torch.dtype = torch.float32
enable_performance_monitoring: bool = False
memory_optimization: bool = True
use_advanced_solvers: bool = True
def __post_init__(self):
if self.fractional_order is None:
self.fractional_order = FractionalOrder(0.5)
elif isinstance(self.fractional_order, float):
self.fractional_order = FractionalOrder(self.fractional_order)
# ============================================================================
# TARGETED BASE CLASS
# ============================================================================
[docs]
class BaseNeuralODE(nn.Module, ABC):
"""Targeted optimized base class for Neural ODE implementations"""
[docs]
def __init__(self, config: NeuralODEConfig):
super().__init__()
self.config = config
self._setup_layer()
self.performance_stats = {} if config.enable_performance_monitoring else None
[docs]
def _setup_layer(self):
"""Setup layer-specific components"""
self.input_dim = self.config.input_dim
self.hidden_dim = self.config.hidden_dim
self.output_dim = self.config.output_dim
self.num_layers = self.config.num_layers
self.activation = self.config.activation
self.use_adjoint = self.config.use_adjoint
# Build network
self._build_network()
[docs]
def _build_network(self):
"""Build neural network architecture with optimizations"""
layers = []
# Input layer: time + input_dim -> hidden_dim
layers.append(nn.Linear(self.input_dim + 1, self.hidden_dim))
# Hidden layers with optimized initialization
for _ in range(self.num_layers - 1):
layers.append(nn.Linear(self.hidden_dim, self.hidden_dim))
# Output layer
layers.append(nn.Linear(self.hidden_dim, self.output_dim))
self.network = nn.Sequential(*layers)
self._initialize_weights()
[docs]
def _initialize_weights(self):
"""Optimized weight initialization"""
for module in self.modules():
if isinstance(module, nn.Linear):
# Use Xavier initialization for better gradient flow
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
[docs]
def _get_activation(self, x: torch.Tensor) -> torch.Tensor:
"""Apply activation function"""
if self.activation == "tanh":
return torch.tanh(x)
elif self.activation == "relu":
return F.relu(x)
elif self.activation == "sigmoid":
return torch.sigmoid(x)
else:
return torch.tanh(x) # Default to tanh
[docs]
def ode_func(self, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
"""Optimized ODE function with improved tensor handling"""
# Optimized shape handling - minimize operations
was_single_input = x.dim() == 1
if was_single_input:
x = x.unsqueeze(0)
if t.dim() == 0:
t = t.unsqueeze(0)
# Efficient tensor operations
batch_size = x.shape[0]
if t.numel() == 1:
t = t.expand(batch_size)
# Vectorized concatenation
t_expanded = t.unsqueeze(-1)
input_tensor = torch.cat([t_expanded, x], dim=-1)
# Forward pass
output = self.network(input_tensor)
output = self._get_activation(output)
# Handle output shape
if was_single_input and output.shape[0] == 1:
output = output.squeeze(0)
return output
[docs]
@abstractmethod
def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
"""Forward pass - must be implemented by subclasses"""
pass
# ============================================================================
# TARGETED IMPLEMENTATIONS
# ============================================================================
[docs]
class NeuralODE(BaseNeuralODE):
"""Targeted optimized Neural ODE implementation"""
[docs]
def __init__(self, input_dim: int, hidden_dim: int, output_dim: int,
num_layers: int = 3, activation: str = "tanh",
use_adjoint: bool = True, solver: str = "dopri5",
rtol: float = 1e-5, atol: float = 1e-5):
# Create config from parameters
config = NeuralODEConfig(
input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim,
num_layers=num_layers, activation=activation, use_adjoint=use_adjoint,
solver=solver, rtol=rtol, atol=atol
)
super().__init__(config)
self.solver_name = config.solver
self.has_torchdiffeq = self._check_torchdiffeq()
[docs]
def _check_torchdiffeq(self) -> bool:
"""Check if torchdiffeq is available"""
try:
import torchdiffeq
return True
except ImportError:
return False
[docs]
def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
"""Optimized forward pass"""
batch_size = x.shape[0]
if t.dim() == 1:
t = t.unsqueeze(0).expand(batch_size, -1)
# Solve ODE using optimized solver
if self.has_torchdiffeq and self.solver_name == "dopri5" and self.config.use_advanced_solvers:
solution = self._solve_torchdiffeq(x, t)
else:
solution = self._solve_optimized_euler(x, t)
return solution
[docs]
def _solve_torchdiffeq(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
"""Solve using torchdiffeq with optimizations"""
try:
import torchdiffeq as tde
# Ensure t is 1D time vector
if t.dim() > 1:
t_vec = t[0]
else:
t_vec = t
# Initial state for integration
if x.dim() == 1:
y0 = x[:self.output_dim].unsqueeze(0)
else:
y0 = x[:, :self.output_dim]
# Wrap the ODE function
class _ODEFunc(nn.Module):
def __init__(self, parent):
super().__init__()
self.parent = parent
def forward(self, time, state):
if state.dim() == 1:
state = state.unsqueeze(0)
batch_size, out_dim = state.shape
if self.parent.input_dim <= out_dim:
ode_input = state[:, :self.parent.input_dim]
else:
ode_input = torch.zeros(batch_size, self.parent.input_dim,
device=state.device, dtype=state.dtype)
ode_input[:, :out_dim] = state
deriv = self.parent.ode_func(time, ode_input)
if deriv.dim() == 1:
deriv = deriv.unsqueeze(0)
if deriv.shape[1] > self.parent.output_dim:
deriv = deriv[:, :self.parent.output_dim]
elif deriv.shape[1] < self.parent.output_dim:
padded = torch.zeros(batch_size, self.parent.output_dim,
device=deriv.device, dtype=deriv.dtype)
padded[:, :deriv.shape[1]] = deriv
deriv = padded
return deriv
func_module = _ODEFunc(self)
solution = tde.odeint_adjoint(func_module, y0, t_vec,
rtol=self.config.rtol, atol=self.config.atol)
# Convert from (time, batch, dim) to (batch, time, dim)
if solution.dim() == 3:
solution = solution.permute(1, 0, 2).contiguous()
return solution
except Exception as e:
warnings.warn(f"torchdiffeq failed, falling back to Euler: {e}")
return self._solve_optimized_euler(x, t)
[docs]
def _solve_optimized_euler(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
"""Optimized Euler solver with memory efficiency"""
batch_size, time_steps = t.shape
solution = torch.zeros(batch_size, time_steps, self.output_dim,
device=x.device, dtype=x.dtype)
# Initialize
if x.shape[1] >= self.output_dim:
solution[:, 0, :] = x[:, :self.output_dim]
else:
solution[:, 0, :x.shape[1]] = x
solution[:, 0, x.shape[1]:] = 0.0
# Optimized Euler method with vectorized operations
for i in range(1, time_steps):
dt = t[:, i] - t[:, i-1]
current_state = solution[:, i-1, :]
# Map to input dimension efficiently
if current_state.shape[1] > self.input_dim:
ode_input = current_state[:, :self.input_dim]
else:
ode_input = torch.zeros(
batch_size, self.input_dim, device=x.device)
ode_input[:, :current_state.shape[1]] = current_state
# Get derivative
derivative = self.ode_func(t[:, i-1], ode_input)
# Ensure derivative has correct shape
if derivative.dim() == 1:
derivative = derivative.unsqueeze(0)
# Update solution efficiently
if derivative.shape[1] == self.output_dim:
solution[:, i, :] = current_state + \
dt.unsqueeze(-1) * derivative
else:
if derivative.shape[1] > self.output_dim:
solution[:, i, :] = current_state + \
dt.unsqueeze(-1) * derivative[:, :self.output_dim]
else:
solution[:, i, :derivative.shape[1]] = current_state[:,
:derivative.shape[1]] + dt.unsqueeze(-1) * derivative
solution[:, i, derivative.shape[1] :] = current_state[:, derivative.shape[1]:]
return solution
[docs]
class NeuralFODE(BaseNeuralODE):
"""Targeted optimized Neural Fractional ODE implementation"""
[docs]
def __init__(self, input_dim: int, hidden_dim: int, output_dim: int,
fractional_order: Union[float, FractionalOrder] = 0.5,
num_layers: int = 3, activation: str = "tanh",
use_adjoint: bool = True, solver: str = "fractional_euler",
rtol: float = 1e-5, atol: float = 1e-5):
# Create config from parameters
config = NeuralODEConfig(
input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim,
fractional_order=fractional_order, num_layers=num_layers,
activation=activation, use_adjoint=use_adjoint,
solver=solver, rtol=rtol, atol=atol
)
super().__init__(config)
# Validate but assign from config
validate_fractional_order(config.fractional_order)
self.alpha = config.fractional_order
self.solver = config.solver # Expose solver attribute
self.solver_name = config.solver
self.has_torchdiffeq = self._check_torchdiffeq()
[docs]
def get_fractional_order(self) -> float:
"""Get the fractional order"""
return float(self.alpha.alpha)
[docs]
def _check_torchdiffeq(self) -> bool:
"""Check if torchdiffeq is available"""
try:
import torchdiffeq
return True
except ImportError:
return False
[docs]
def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
"""Optimized fractional forward pass"""
batch_size = x.shape[0]
if t.dim() == 1:
# If t is 1D, verify it's not empty
if t.numel() == 0:
raise IndexError("Time tensor cannot be empty")
t = t.unsqueeze(0).expand(batch_size, -1)
elif t.dim() == 0:
# Scalar time not supported for trajectory
raise ValueError("Time tensor must be at least 1D")
# Solve fractional ODE using optimized solver
solution = self._solve_fractional_ode_optimized(x, t)
return solution
[docs]
def _solve_fractional_ode_optimized(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
"""
Optimized fractional ODE solver with proper fractional calculus.
Uses the L1 scheme (Grünwald-Letnikov weights) for correct memory handling.
"""
batch_size, time_steps = x.shape[0], t.shape[1]
solution = torch.zeros(batch_size, time_steps, self.output_dim,
device=x.device, dtype=x.dtype)
# Initialize
if x.shape[1] >= self.output_dim:
solution[:, 0, :] = x[:, :self.output_dim]
else:
solution[:, 0, :x.shape[1]] = x
solution[:, 0, x.shape[1]:] = 0.0
# Precompute weights for GL method / L1 scheme
# weights[k] = (k+1)^alpha - k^alpha
alpha_val = self.alpha.alpha
# Ensure alpha is a float or cpu tensor for weight computation to save memory?
# Better to compute on device
device = x.device
dtype = x.dtype
k_vals = torch.arange(time_steps, device=device, dtype=dtype)
weights = (k_vals + 1).pow(alpha_val) - k_vals.pow(alpha_val)
# Gamma factor: 1 / Gamma(alpha + 1)
gamma_val = math.gamma(alpha_val + 1)
gamma_factor = 1.0 / gamma_val
# History storage for drifts
# We need to store f(x_j, t_j) for j=0..i-1
drift_history = []
# Initial drift
t0 = t[:, 0]
x0_in = solution[:, 0, :] # (batch, out_dim)
# Map to input dim if needed
if x0_in.shape[1] > self.input_dim:
ode_input0 = x0_in[:, :self.input_dim]
else:
ode_input0 = torch.zeros(batch_size, self.input_dim, device=device, dtype=dtype)
ode_input0[:, :x0_in.shape[1]] = x0_in
drift0 = self.ode_func(t0, ode_input0)
# Ensure drift shape
if drift0.dim() == 1:
drift0 = drift0.unsqueeze(0)
if drift0.shape[1] > self.output_dim:
drift0 = drift0[:, :self.output_dim]
elif drift0.shape[1] < self.output_dim:
# Pad
padded = torch.zeros(batch_size, self.output_dim, device=device, dtype=dtype)
padded[:, :drift0.shape[1]] = drift0
drift0 = padded
drift_history.append(drift0)
# Optimized fractional Euler (L1 scheme implementation)
# x_{n+1} = x_0 + (h^alpha / Gamma(alpha+1)) * Sum_{j=0}^{n} w_{n-j} * f(x_j)
# Note: indices here: n goes from 0 to time_steps-2 to produce x_1 to x_{time_steps-1}
# loop i corresponds to calculating solution at index i (time t[i])
for i in range(1, time_steps):
dt = t[:, i] - t[:, i-1]
# Assume uniform dt for standard GL/L1 scheme accuracy,
# but usually code handles non-uniform roughly or assumes uniform.
# L1 scheme specifically assumes uniform grid h.
# We take mean dt if not uniform (simplified) or just use current dt (local scaling).
# For strict correctness we need uniform grid. We use t[1]-t[0] as h.
h = t[0, 1] - t[0, 0]
# Convolution
# stack history: (i, batch, output_dim)
hist_stack = torch.stack(drift_history)
# weights needed: w_{i-1}, w_{i-2}, ..., w_0
# i.e. weights[0] ... weights[i-1] in reverse order
# weights has length time_steps.
# We need i weights.
current_weights = weights[:i].flip(0) # (i,)
w_reshaped = current_weights.view(-1, 1, 1) # (i, 1, 1)
# Weighted sum
integral_term = (w_reshaped * hist_stack).sum(dim=0)
# Update
# x_i = x_0 + factor * integral
update = gamma_factor * torch.pow(h, alpha_val) * integral_term
solution[:, i, :] = solution[:, 0, :] + update
if i < time_steps - 1:
# Compute drift for next step
ti = t[:, i]
xi = solution[:, i, :]
# Input mapping
if xi.shape[1] > self.input_dim:
input_i = xi[:, :self.input_dim]
else:
input_i = torch.zeros(batch_size, self.input_dim, device=device, dtype=dtype)
input_i[:, :xi.shape[1]] = xi
drift_i = self.ode_func(ti, input_i)
# Output mapping
if drift_i.dim() == 1: drift_i = drift_i.unsqueeze(0)
if drift_i.shape[1] > self.output_dim:
drift_i = drift_i[:, :self.output_dim]
elif drift_i.shape[1] < self.output_dim:
padded = torch.zeros(batch_size, self.output_dim, device=device, dtype=dtype)
padded[:, :drift_i.shape[1]] = drift_i
drift_i = padded
drift_history.append(drift_i)
return solution
# ============================================================================
# TARGETED TRAINER
# ============================================================================
[docs]
class NeuralODETrainer:
"""Targeted optimized trainer for Neural ODE models"""
[docs]
def __init__(self, model: Union[NeuralODE, NeuralFODE],
optimizer: str = "adam", learning_rate: float = 1e-3,
loss_function: str = "mse"):
self.model = model
self.learning_rate = learning_rate
self.loss_function = loss_function
# Set up optimizer
self.optimizer = self._setup_optimizer(optimizer)
self.criterion = self._setup_loss_function(loss_function)
# Performance tracking
self.performance_stats = {
"training_time": [],
"loss_history": [],
"memory_usage": []
}
[docs]
def _setup_optimizer(self, optimizer_type: str) -> torch.optim.Optimizer:
"""Set up optimizer"""
if optimizer_type == "adam":
return torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
elif optimizer_type == "sgd":
return torch.optim.SGD(self.model.parameters(), lr=self.learning_rate)
elif optimizer_type == "rmsprop":
return torch.optim.RMSprop(self.model.parameters(), lr=self.learning_rate)
else:
return torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
[docs]
def _setup_loss_function(self, loss_type: str) -> nn.Module:
"""Set up loss function"""
if loss_type == "mse":
return nn.MSELoss()
elif loss_type == "mae":
return nn.L1Loss()
elif loss_type == "huber":
return nn.SmoothL1Loss()
else:
return nn.MSELoss()
[docs]
def train_step(self, x: torch.Tensor, y_target: torch.Tensor,
t: torch.Tensor) -> float:
"""Optimized training step"""
start_time = time.time()
self.optimizer.zero_grad()
# Forward pass
y_pred = self.model(x, t)
# Compute loss
loss = self.criterion(y_pred, y_target)
# Backward pass
loss.backward()
# Update parameters
self.optimizer.step()
# Track performance
step_time = time.time() - start_time
self.performance_stats["training_time"].append(step_time)
self.performance_stats["loss_history"].append(loss.item())
return loss.item()
# Minimal validate method expected by tests
[docs]
def _validate(self, data_loader) -> float:
"""Compute average validation loss over a data loader."""
device = next(self.model.parameters()).device if any(
True for _ in self.model.parameters()) else torch.device('cpu')
total_loss = 0.0
count = 0
with torch.no_grad():
for batch in data_loader:
if isinstance(batch, (list, tuple)) and len(batch) == 3:
xb, yb, tb = batch
else:
# Fallback: assume (x, y)
xb, yb = batch
tb = torch.linspace(0, 1, yb.shape[1], device=yb.device)
tb = tb.unsqueeze(0).expand(xb.shape[0], -1)
xb = xb.to(device)
yb = yb.to(device)
tb = tb.to(device)
yp = self.model(xb, tb)
loss = self.criterion(yp, yb)
total_loss += float(loss.detach().cpu())
count += 1
return total_loss / max(count, 1)
# Minimal training loop expected by tests
[docs]
def train(self, data_loader, num_epochs: int = 1, verbose: bool = False):
history = {"loss": [], "epochs": []}
for epoch in range(num_epochs):
epoch_loss = 0.0
batches = 0
for batch in data_loader:
if isinstance(batch, (list, tuple)) and len(batch) == 3:
xb, yb, tb = batch
else:
# Fallback: assume (x, y)
xb, yb = batch
tb = torch.linspace(0, 1, yb.shape[1], device=yb.device)
tb = tb.unsqueeze(0).expand(xb.shape[0], -1)
loss = self.train_step(xb, yb, tb)
epoch_loss += loss
batches += 1
avg_loss = epoch_loss / max(batches, 1)
history["loss"].append(avg_loss)
history["epochs"].append(epoch + 1)
if verbose:
print(f"Epoch {epoch+1}/{num_epochs} - loss: {avg_loss:.6f}")
return history
# ============================================================================
# FACTORY FUNCTIONS
# ============================================================================
[docs]
def create_neural_ode(model_type: str = "standard", **kwargs) -> Union[NeuralODE, NeuralFODE]:
"""Factory function to create neural ODE models"""
if model_type == "standard":
return NeuralODE(**kwargs)
elif model_type == "fractional":
return NeuralFODE(**kwargs)
else:
raise ValueError(
f"Unknown model type: {model_type}. Must be one of: standard, fractional")
if model_type == "standard":
return NeuralODE(config)
elif model_type == "fractional":
return NeuralFODE(config)
else:
raise ValueError(
f"Unknown model type: {model_type}. Must be one of: standard, fractional")
[docs]
def create_neural_ode_trainer(model: Union[NeuralODE, NeuralFODE],
**kwargs) -> NeuralODETrainer:
"""Factory function to create targeted neural ODE trainer"""
return NeuralODETrainer(model, **kwargs)
if __name__ == "__main__":
print("TARGETED OPTIMIZED NEURAL ODE IMPLEMENTATION")
print("Focused on high-impact improvements")
print("=" * 60)
# Test basic functionality
config = NeuralODEConfig(input_dim=2, hidden_dim=64, output_dim=2)
model = create_neural_ode("standard", **config.__dict__)
x = torch.randn(32, 2)
t = torch.linspace(0, 1, 10)
result = model(x, t)
print(f"✅ Targeted Neural ODE: Input: {x.shape}, Output: {result.shape}")