# Source code for hpfracc.solvers.coupled_solvers

"""
Coupled System Solvers for Graph-SDE Dynamics

This module provides numerical solvers for systems of coupled spatial-temporal
dynamics, integrating graph-based spatial evolution with fractional SDE temporal evolution.

**Modeling scope (read before interpreting results)**

- **Spatial track:** explicit Euler on ``graph_dynamics(state, adjacency)`` (one step per
  substep). No fractional derivative in space unless ``graph_dynamics`` implements it.
- **Temporal track:** history convolutions via :class:`FastHistoryConvolution` with
  ``1/gamma(alpha+1)`` scaling and a single ``dt**alpha`` factor per step — an engineering
  approximation to fractional SDE memory, not a full rigorous discretization proof.
- **Coupling:** both solvers add ``coupling_strength * (spatial - temporal)`` into the
  temporal drift (operator splitting: after the spatial half-step; monolithic: inside
  the combined drift). The spatial ODE does **not** receive a symmetric pull from the
  temporal track in :class:`OperatorSplittingSolver` (only :class:`MonolithicSolver` adds
  ``k*(V-U)`` to ``dspatial``).
- **API:** ``coupling_type`` in :func:`solve_coupled_graph_sde` is reserved and **not**
  used; only ``coupling_strength`` and the solver choice affect dynamics.
- **``multiscale``:** not implemented; raises ``ValueError``.

For production experiments, validate against a reference integrator or reduce step size
when coupling is strong; splitting error can dominate.
"""

import warnings

import numpy as np
from typing import Callable, Tuple, Optional, Dict, Any, Union, List
from dataclasses import dataclass
from abc import ABC, abstractmethod

from ..core.definitions import FractionalOrder
from .ode_solvers import gamma
from .sde_solvers import FastHistoryConvolution


@dataclass
class CoupledSolution:
    """Solution object for coupled graph-SDE systems.

    Attributes:
        t: Time grid. The solvers in this module produce ``num_steps + 1`` points.
        spatial: Spatial (graph) state trajectory. As produced by the solvers
            here, the leading axis indexes ``t``.
        temporal: Temporal (SDE) state trajectory, indexed like ``spatial``.
        coupling: Coupling diagnostic trajectory (one value per time point as
            produced by the solvers here).
        metadata: Free-form solver information. ``None`` (the default) is
            replaced by a fresh empty dict per instance.
    """

    t: np.ndarray
    spatial: np.ndarray  # Spatial (graph) state trajectory
    temporal: np.ndarray  # Temporal (SDE) state trajectory
    coupling: np.ndarray  # Coupling strength trajectory
    metadata: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        # Normalize None to a per-instance dict so instances never share a
        # mutable default.
        if self.metadata is None:
            self.metadata = {}
class CoupledSystemSolver(ABC):
    """Base class for coupled system solvers.

    Normalizes the fractional-order specification into a dict keyed by track
    name (subclasses read the ``'temporal'`` entry) and stores the
    spatial-temporal coupling strength.
    """

    def __init__(
        self,
        fractional_orders: Union[float, FractionalOrder, Dict[str, float]],
        coupling_strength: float = 1.0
    ):
        """
        Initialize coupled system solver.

        Args:
            fractional_orders: Either one order applied to both tracks (a
                :class:`FractionalOrder` or any real number, e.g. ``0.5``,
                ``1`` or a NumPy scalar), or a dict mapping track names
                (``'spatial'``, ``'temporal'``) to orders. A dict is copied so
                later mutation by the caller does not affect the solver.
            coupling_strength: Strength of spatial-temporal coupling.

        Raises:
            ValueError: If ``fractional_orders`` is not a dict, a real number
                or a :class:`FractionalOrder`.
        """
        if isinstance(fractional_orders, dict):
            self.fractional_orders = dict(fractional_orders)
        elif isinstance(fractional_orders, FractionalOrder):
            # Checked before plain numbers so a FractionalOrder is stored
            # unchanged (subclasses unwrap it via ``.alpha``).
            self.fractional_orders = {
                'spatial': fractional_orders,
                'temporal': fractional_orders
            }
        elif (
            isinstance(fractional_orders, (int, float, np.integer, np.floating))
            and not isinstance(fractional_orders, (bool, np.bool_))
        ):
            order = float(fractional_orders)
            self.fractional_orders = {'spatial': order, 'temporal': order}
        else:
            raise ValueError(
                f"Invalid fractional_orders type: {type(fractional_orders).__name__}"
            )
        self.coupling_strength = coupling_strength

    @abstractmethod
    def solve(
        self,
        graph_dynamics: Callable,
        sde_drift: Callable,
        sde_diffusion: Callable,
        adjacency: np.ndarray,
        node_features: np.ndarray,
        t_span: Tuple[float, float],
        **kwargs
    ) -> CoupledSolution:
        """Solve the coupled system over ``t_span`` and return a :class:`CoupledSolution`."""
        pass
class OperatorSplittingSolver(CoupledSystemSolver):
    """
    Operator splitting solver for graph-SDE dynamics.

    Spatial evolution applies explicit Euler to ``graph_dynamics`` on the
    spatial track. The temporal (fractional SDE) track adds an explicit
    coupling term to the drift::

        drift_total = sde_drift(t, temporal) + coupling_strength * (spatial - temporal)

    using the current spatial state after the preceding spatial substeps in the
    split. This matches the spirit of :class:`MonolithicSolver` (coupling in
    the temporal drift). The spatial track receives no coupling term here.
    """

    def __init__(
        self,
        fractional_orders: Union[float, FractionalOrder, Dict[str, float]],
        coupling_strength: float = 1.0,
        split_order: int = 2
    ):
        """
        Initialize operator splitting solver.

        Args:
            fractional_orders: Fractional order(s); see :class:`CoupledSystemSolver`.
            coupling_strength: Coupling strength.
            split_order: Splitting order (1=Lie-Trotter, 2=Strang).

        Raises:
            ValueError: If ``split_order`` is not 1 or 2.
        """
        super().__init__(fractional_orders, coupling_strength)
        if split_order not in (1, 2):
            raise ValueError(
                f"split_order must be 1 (Lie-Trotter) or 2 (Strang), got {split_order!r}"
            )
        self.split_order = split_order

    def solve(
        self,
        graph_dynamics: Callable,
        sde_drift: Callable,
        sde_diffusion: Callable,
        adjacency: np.ndarray,
        node_features: np.ndarray,
        t_span: Tuple[float, float],
        num_steps: int = 100,
        seed: Optional[int] = None,
        **kwargs
    ) -> CoupledSolution:
        """
        Solve using operator splitting.

        Strang splitting (``split_order=2``) performs, per step:
            - a half step of spatial dynamics;
            - a full step of temporal dynamics, coupled to the half-stepped
              spatial state;
            - a half step of spatial dynamics.

        Lie-Trotter (``split_order=1``) performs a full spatial step followed
        by a full temporal step.

        Args:
            graph_dynamics: ``f(spatial, adjacency)`` giving the spatial rate.
            sde_drift: Temporal drift ``f(t, temporal)``.
            sde_diffusion: Temporal diffusion ``g(t, temporal)``; see
                :meth:`_diffusion_noise` for the accepted output shapes.
            adjacency: Passed unchanged to ``graph_dynamics``.
            node_features: Initial state for both tracks. It is copied, never
                mutated.
            t_span: ``(t0, tf)``.
            num_steps: Number of uniform steps; must be at least 1.
            seed: Seed for :func:`numpy.random.default_rng`.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            CoupledSolution whose trajectories have shape
            ``(num_steps + 1, *node_features.shape)``. ``coupling[i]`` is
            ``coupling_strength * mean(|spatial - temporal|)`` at ``t[i]``.

        Raises:
            ValueError: If ``num_steps < 1`` or the diffusion output shape is
                incompatible with the state.
        """
        if num_steps < 1:
            raise ValueError(f"num_steps must be at least 1, got {num_steps}")

        t0, tf = t_span
        dt = (tf - t0) / num_steps
        t = np.linspace(t0, tf, num_steps + 1)

        # np.array copies, so the evolving states never alias the caller's
        # array. Lists and other array-likes are accepted too.
        initial_state = np.array(node_features, dtype=float)
        spatial_state = initial_state.copy()
        temporal_state = initial_state.copy()

        spatial_traj = np.zeros((num_steps + 1, *initial_state.shape))
        temporal_traj = np.zeros_like(spatial_traj)
        coupling_traj = np.zeros(num_steps + 1)  # both tracks start equal -> 0
        spatial_traj[0] = spatial_state
        temporal_traj[0] = temporal_state

        rng = np.random.default_rng(seed)

        # History buffers for the temporal fractional dynamics.
        alpha_t = self.fractional_orders.get('temporal', 0.5)
        if isinstance(alpha_t, FractionalOrder):
            alpha_t = alpha_t.alpha
        dim = int(initial_state.size)
        drift_conv = FastHistoryConvolution(alpha_t, num_steps, dim)
        diffusion_conv = FastHistoryConvolution(alpha_t, num_steps, dim)
        gamma_factor = 1.0 / gamma(alpha_t + 1)

        # Strang brackets the temporal step with two half spatial substeps;
        # Lie-Trotter uses one full spatial substep before it.
        strang = self.split_order == 2
        spatial_dt = dt / 2 if strang else dt

        for i in range(num_steps):
            spatial_state = self._spatial_step(
                graph_dynamics, adjacency, spatial_state, spatial_dt
            )
            temporal_state = self._temporal_step(
                sde_drift, sde_diffusion, temporal_state, t[i], dt,
                drift_conv, diffusion_conv, gamma_factor, alpha_t,
                initial_state,
                spatial_for_coupling=spatial_state,
                rng=rng,
            )
            if strang:
                spatial_state = self._spatial_step(
                    graph_dynamics, adjacency, spatial_state, dt / 2
                )

            spatial_traj[i + 1] = spatial_state
            temporal_traj[i + 1] = temporal_state
            coupling_traj[i + 1] = float(
                self.coupling_strength * np.mean(np.abs(spatial_state - temporal_state))
            )

        return CoupledSolution(
            t=t,
            spatial=spatial_traj,
            temporal=temporal_traj,
            coupling=coupling_traj,
            metadata={
                'solver': 'operator_splitting',
                'split_order': self.split_order,
                'num_steps': num_steps,
                'coupling_strength': self.coupling_strength,
                'coupling_in_temporal_drift': True,
            }
        )

    def _spatial_step(
        self,
        graph_dynamics: Callable,
        adjacency: np.ndarray,
        state: np.ndarray,
        dt: float
    ) -> np.ndarray:
        """One explicit Euler step of the graph dynamics: ``state + dt * f(state, adjacency)``."""
        return state + dt * graph_dynamics(state, adjacency)

    @staticmethod
    def _diffusion_noise(
        diffusion_val: Any,
        state: np.ndarray,
        dt: float,
        rng: np.random.Generator,
    ) -> np.ndarray:
        """Return the diffusion increment ``g * dW`` with ``state.shape``.

        Wiener increments are drawn as ``N(0, sqrt(dt))``. ``diffusion_val``
        is interpreted as follows:

        - Scalar / 0-d: one amplitude for every component, with independent
          noise per component.
        - 1-D of length ``state.size``: per-component amplitude.
        - 2-D with ``shape[0] == state.size``: matrix ``G`` of shape
          ``(dim, m)`` driven by ``m`` independent increments.
        - 2-D with ``shape == state.shape`` (e.g. nodes x features):
          element-wise amplitude, with independent noise per component.
        - Higher rank: multiplied element-wise with noise of ``state.shape``.

        Raises:
            ValueError: On a 1-D length mismatch, or on a 2-D array that fits
                neither the matrix nor the element-wise interpretation.
        """
        dim = int(np.asarray(state).size)
        g = np.asarray(diffusion_val, dtype=float)

        if g.ndim == 0:
            dW = rng.normal(0, np.sqrt(dt), size=(dim,))
            return (np.full(dim, float(g), dtype=float) * dW).reshape(state.shape)

        if g.ndim == 1:
            if g.size != dim:
                raise ValueError(
                    f"Diffusion vector length {g.size} does not match state size {dim}"
                )
            dW = rng.normal(0, np.sqrt(dt), size=(dim,))
            return (g * dW).reshape(state.shape)

        if g.ndim == 2:
            if g.shape[0] == dim:
                # Matrix interpretation takes precedence. NOTE: for an (N, 1)
                # state an (N, 1) diffusion therefore shares a single dW.
                dW = rng.normal(0, np.sqrt(dt), size=(g.shape[1],))
                return (g @ dW).reshape(state.shape)
            if g.shape == np.shape(state):
                dW = rng.normal(0, np.sqrt(dt), size=state.shape)
                return g * dW
            raise ValueError(
                f"Diffusion matrix shape {g.shape} is incompatible with state "
                f"shape {np.shape(state)} (expected ({dim}, m) or {np.shape(state)})"
            )

        dW = rng.normal(0, np.sqrt(dt), size=state.shape)
        return g * dW

    def _temporal_step(
        self,
        drift: Callable,
        diffusion: Callable,
        state: np.ndarray,
        t: float,
        dt: float,
        drift_conv: FastHistoryConvolution,
        diffusion_conv: FastHistoryConvolution,
        gamma_factor: float,
        alpha: float,
        initial_state: np.ndarray,
        spatial_for_coupling: np.ndarray,
        rng: np.random.Generator,
    ) -> np.ndarray:
        """Temporal fractional SDE substep with explicit drift coupling to the spatial track.

        The coupled drift and the sampled noise increment are pushed into their
        history convolutions. The new state is then rebuilt from the initial
        state as ``V(0) + gamma_factor * dt**alpha * (drift_hist + noise_hist)``.

        Returns:
            The new temporal state with ``state.shape``.
        """
        coupling = self.coupling_strength * (spatial_for_coupling - state)
        drift_val = drift(t, state) + coupling
        noise_term = self._diffusion_noise(diffusion(t, state), state, dt, rng)

        drift_conv.update(np.asarray(drift_val, dtype=float).reshape(-1))
        diffusion_conv.update(np.asarray(noise_term, dtype=float).reshape(-1))
        history = drift_conv.convolve() + diffusion_conv.convolve()

        initial_flat = np.asarray(initial_state, dtype=float).reshape(-1)
        out_flat = initial_flat + gamma_factor * (dt ** alpha) * history
        return out_flat.reshape(state.shape)
class MonolithicSolver(CoupledSystemSolver):
    """
    Monolithic solver for strongly coupled graph-SDE systems.

    Advances both tracks in one explicit step from the same pre-step state, so
    there is no splitting between them. Coupling is symmetric:

    - spatial: ``dU/dt = graph_dynamics(U, A) + k (V - U)`` (explicit Euler);
    - temporal: the drift fed into the fractional memory is
      ``sde_drift(t, V) + k (U - V)``.

    Diffusion is applied element-wise (``g * dW`` with ``dW`` of the temporal
    state's shape). The state is stored as the concatenation
    ``[spatial; temporal]`` along the last axis.
    """

    def solve(
        self,
        graph_dynamics: Callable,
        sde_drift: Callable,
        sde_diffusion: Callable,
        adjacency: np.ndarray,
        node_features: np.ndarray,
        t_span: Tuple[float, float],
        num_steps: int = 100,
        seed: Optional[int] = None,
        **kwargs
    ) -> CoupledSolution:
        """Solve monolithic coupled system.

        Args:
            graph_dynamics: ``f(spatial, adjacency)`` giving the spatial rate.
                Its return value is never modified in place.
            sde_drift: Temporal drift ``f(t, temporal)``.
            sde_diffusion: Temporal diffusion ``g(t, temporal)``, broadcast
                element-wise against the noise.
            adjacency: Passed unchanged to ``graph_dynamics``.
            node_features: Initial state of both tracks (at least 1-D).
            t_span: ``(t0, tf)``.
            num_steps: Number of uniform steps; must be at least 1.
            seed: Seed for :func:`numpy.random.default_rng`.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            CoupledSolution with trajectories of shape
            ``(num_steps + 1, *node_features.shape)``. ``coupling[i]`` is
            ``coupling_strength * mean(|spatial - temporal|)`` at ``t[i]``,
            the same diagnostic as :class:`OperatorSplittingSolver`.

        Raises:
            ValueError: If ``num_steps < 1``.
        """
        if num_steps < 1:
            raise ValueError(f"num_steps must be at least 1, got {num_steps}")

        t0, tf = t_span
        dt = (tf - t0) / num_steps
        t = np.linspace(t0, tf, num_steps + 1)

        features = np.asarray(node_features, dtype=float)
        half = features.shape[-1]  # split point between U and V on the last axis

        # Combined state: [spatial; temporal]
        combined_state = np.concatenate([features, features], axis=-1)
        combined_traj = np.zeros((num_steps + 1, *combined_state.shape))
        combined_traj[0] = combined_state

        rng = np.random.default_rng(seed)

        # Temporal fractional memory setup.
        alpha = self.fractional_orders.get('temporal', 0.5)
        if isinstance(alpha, FractionalOrder):
            alpha = alpha.alpha
        dim = int(features.size)
        temporal_drift_conv = FastHistoryConvolution(alpha, num_steps, dim)
        temporal_diffusion_conv = FastHistoryConvolution(alpha, num_steps, dim)
        gamma_factor = 1.0 / gamma(alpha + 1)

        # V(0) anchors every temporal update; it never changes, so build it once.
        initial_temporal_flat = features.reshape(-1).copy()

        for i in range(num_steps):
            spatial_state = combined_state[..., :half]
            temporal_state = combined_state[..., half:]

            # Spatial rate. Build a new array rather than ``+=`` on the callback
            # result, which may alias ``spatial_state`` (a view into
            # combined_state) or have an integer dtype.
            dspatial = (
                graph_dynamics(spatial_state, adjacency)
                + self.coupling_strength * (temporal_state - spatial_state)
            )

            # Temporal drift includes the coupling:
            # D^alpha V = f(V) + k (U - V) + g(V) dW
            drift_val = sde_drift(t[i], temporal_state)
            diffusion_val = sde_diffusion(t[i], temporal_state)
            coupling_val = self.coupling_strength * (spatial_state - temporal_state)
            total_drift = drift_val + coupling_val

            dW = rng.normal(0, np.sqrt(dt), size=temporal_state.shape)
            noise_term = np.asarray(diffusion_val, dtype=float) * dW

            temporal_drift_conv.update(np.asarray(total_drift, dtype=float).reshape(-1))
            temporal_diffusion_conv.update(noise_term.reshape(-1))
            history = temporal_drift_conv.convolve() + temporal_diffusion_conv.convolve()

            spatial_new = spatial_state + dspatial * dt
            temporal_new = (
                initial_temporal_flat + gamma_factor * (dt ** alpha) * history
            ).reshape(temporal_state.shape)

            combined_state = np.concatenate([spatial_new, temporal_new], axis=-1)
            combined_traj[i + 1] = combined_state

        spatial_traj = combined_traj[..., :half]
        temporal_traj = combined_traj[..., half:]
        # Reduce over every non-time axis. A fixed axis=(-2, -1) would fold the
        # time axis into the mean for 1-D features.
        abs_gap = np.abs(spatial_traj - temporal_traj).reshape(num_steps + 1, -1)
        coupling_traj = self.coupling_strength * abs_gap.mean(axis=1)

        return CoupledSolution(
            t=t,
            spatial=spatial_traj,
            temporal=temporal_traj,
            coupling=coupling_traj,
            metadata={
                'solver': 'monolithic',
                'num_steps': num_steps,
                'coupling_strength': self.coupling_strength,
            }
        )
def solve_coupled_graph_sde(
    graph_dynamics: Callable,
    sde_drift: Callable,
    sde_diffusion: Callable,
    adjacency: np.ndarray,
    node_features: np.ndarray,
    t_span: Tuple[float, float],
    fractional_orders: Union[float, FractionalOrder, Dict[str, float]] = 0.5,
    coupling_type: str = "bidirectional",
    coupling_strength: float = 1.0,
    solver: str = "operator_splitting",
    **kwargs
) -> CoupledSolution:
    """
    Solve coupled graph-SDE system.

    See module docstring for limitations (unused ``coupling_type``, approximate
    temporal fractional path, asymmetric coupling in operator splitting).

    Args:
        graph_dynamics: Spatial dynamics ``f(spatial, adjacency)``.
        sde_drift: Temporal drift ``f(t, temporal)``.
        sde_diffusion: Temporal diffusion ``g(t, temporal)``.
        adjacency: Graph adjacency matrix.
        node_features: Initial node features (shared as the spatial and
            temporal initial state).
        t_span: Time interval.
        fractional_orders: Scalar or dict with ``'spatial'`` / ``'temporal'``
            keys; the temporal order drives ``FastHistoryConvolution``.
        coupling_type: Reserved for future use; **currently ignored** (a
            ``UserWarning`` is emitted for values other than
            ``"bidirectional"``).
        coupling_strength: Prefactor on ``(spatial - temporal)`` in the
            temporal drift.
        solver: ``"operator_splitting"`` or ``"monolithic"``.
        **kwargs: ``split_order`` (operator splitting only) configures the
            solver. Everything else is passed to the solver's ``solve``
            (e.g. ``num_steps``, ``seed``).

    Returns:
        CoupledSolution

    Raises:
        ValueError: For an unknown ``solver`` name.
    """
    if coupling_type != "bidirectional":
        warnings.warn(
            f"solve_coupled_graph_sde: coupling_type={coupling_type!r} is not implemented; "
            "only coupling_strength is applied. Ignoring coupling_type.",
            UserWarning,
            stacklevel=2,
        )

    if solver == "operator_splitting":
        # split_order is a constructor argument. Route it there instead of
        # letting solve(**kwargs) silently swallow it.
        split_order = kwargs.pop("split_order", 2)
        solver_obj = OperatorSplittingSolver(
            fractional_orders, coupling_strength, split_order=split_order
        )
    elif solver == "monolithic":
        solver_obj = MonolithicSolver(fractional_orders, coupling_strength)
    else:
        raise ValueError(f"Unknown solver type: {solver}")

    return solver_obj.solve(
        graph_dynamics, sde_drift, sde_diffusion,
        adjacency, node_features, t_span, **kwargs
    )