"""
Coupled System Solvers for Graph-SDE Dynamics
This module provides numerical solvers for systems of coupled spatial-temporal
dynamics, integrating graph-based spatial evolution with fractional SDE temporal evolution.
**Modeling scope (read before interpreting results)**
- **Spatial track:** explicit Euler on ``graph_dynamics(state, adjacency)`` (one step per
substep). No fractional derivative in space unless ``graph_dynamics`` implements it.
- **Temporal track:** history convolutions via :class:`FastHistoryConvolution` with
``gamma(alpha+1)`` scaling and a single ``dt**alpha`` factor per step. This is an engineering
approximation to fractional SDE memory, not a rigorously derived discretization; note that
the same factor multiplies the Brownian (diffusion) increments as well as the drift.
- **Coupling:** both solvers add ``coupling_strength * (spatial - temporal)`` into the
temporal drift (operator splitting: after the first spatial substep of each step; monolithic:
inside the combined drift). The spatial ODE does **not** receive the symmetric pull from the
temporal track in :class:`OperatorSplittingSolver` (only :class:`MonolithicSolver` adds
``k*(V-U)`` to ``dspatial``).
- **API:** ``coupling_type`` in :func:`solve_coupled_graph_sde` is reserved and **not**
used; only ``coupling_strength`` and the solver choice affect dynamics.
- **``multiscale``:** not implemented; raises ``ValueError``.
For production experiments, validate against a reference integrator or reduce step size
when coupling is strong; splitting error can dominate.
"""
import warnings
import numpy as np
from typing import Callable, Tuple, Optional, Dict, Any, Union, List
from dataclasses import dataclass
from abc import ABC, abstractmethod
from ..core.definitions import FractionalOrder
from .ode_solvers import gamma
from .sde_solvers import FastHistoryConvolution
@dataclass
class CoupledSolution:
    """Solution object for coupled graph-SDE systems.

    Attributes:
        t: Time grid. The solvers in this module produce ``num_steps + 1`` points.
        spatial: Spatial (graph) state trajectory; the solvers here index time on axis 0.
        temporal: Temporal (SDE) state trajectory; the solvers here index time on axis 0.
        coupling: One coupling diagnostic value per time point.
        metadata: Free-form solver information. ``None`` (the default, or passed
            explicitly) is replaced by a fresh empty dict, so instances never share
            one mapping.
    """
    t: np.ndarray
    spatial: np.ndarray  # Spatial (graph) state trajectory
    temporal: np.ndarray  # Temporal (SDE) state trajectory
    coupling: np.ndarray  # Coupling diagnostic trajectory
    metadata: Optional[Dict[str, Any]] = None

    def __post_init__(self) -> None:
        # dataclasses reject a mutable ``{}`` default; normalizing None here gives
        # each instance its own dict and also covers an explicit ``metadata=None``.
        if self.metadata is None:
            self.metadata = {}
class CoupledSystemSolver(ABC):
    """Base class for coupled system solvers.

    Normalizes the fractional order(s) into a dict keyed by ``'spatial'`` /
    ``'temporal'`` and stores the coupling strength; subclasses implement
    :meth:`solve`.
    """

    def __init__(
        self,
        fractional_orders: Union[float, FractionalOrder, Dict[str, float]],
        coupling_strength: float = 1.0
    ):
        """
        Initialize coupled system solver.

        Args:
            fractional_orders: Either a dict (copied as-is; keys ``'spatial'`` and
                ``'temporal'`` are conventional), a :class:`FractionalOrder`, or a
                real number (``int``/``float`` or NumPy scalar, converted to
                ``float``). A single order is used for both ``'spatial'`` and
                ``'temporal'``.
            coupling_strength: Strength of spatial-temporal coupling.

        Raises:
            ValueError: If ``fractional_orders`` has an unsupported type
                (including ``bool``).
        """
        if isinstance(fractional_orders, dict):
            # Copy so later mutation of the caller's dict cannot alter this solver.
            self.fractional_orders = dict(fractional_orders)
        else:
            if isinstance(fractional_orders, FractionalOrder):
                order = fractional_orders
            elif (
                isinstance(fractional_orders, (int, float, np.integer, np.floating))
                and not isinstance(fractional_orders, bool)
            ):
                # Accept ints and NumPy scalars too, not only Python floats.
                order = float(fractional_orders)
            else:
                raise ValueError("Invalid fractional_orders type")
            self.fractional_orders = {
                'spatial': order,
                'temporal': order
            }
        self.coupling_strength = coupling_strength

    @abstractmethod
    def solve(
        self,
        graph_dynamics: Callable,
        sde_drift: Callable,
        sde_diffusion: Callable,
        adjacency: np.ndarray,
        node_features: np.ndarray,
        t_span: Tuple[float, float],
        **kwargs
    ) -> CoupledSolution:
        """Solve the coupled system.

        Args:
            graph_dynamics: Spatial dynamics ``f(spatial, adjacency)``.
            sde_drift: Temporal drift ``f(t, temporal)``.
            sde_diffusion: Temporal diffusion ``g(t, temporal)``.
            adjacency: Graph adjacency matrix passed to ``graph_dynamics``.
            node_features: Initial state shared by both tracks.
            t_span: ``(t0, tf)`` integration interval.
            **kwargs: Solver-specific options.

        Returns:
            CoupledSolution holding both trajectories.
        """
class OperatorSplittingSolver(CoupledSystemSolver):
    """
    Operator splitting solver for graph-SDE dynamics.

    Spatial evolution applies explicit Euler to ``graph_dynamics`` on the spatial
    track. The temporal (fractional SDE) track adds an explicit coupling term to
    its drift:

        drift_total = sde_drift(t, temporal) + coupling_strength * (spatial - temporal)

    evaluated with the spatial state produced by the preceding spatial substep of
    the split. The spatial track receives no reverse coupling term here (unlike
    :class:`MonolithicSolver`, which also adds ``k*(V-U)`` to the spatial rate).
    """

    def __init__(
        self,
        fractional_orders: Union[float, FractionalOrder, Dict[str, float]],
        coupling_strength: float = 1.0,
        split_order: int = 2
    ):
        """
        Initialize operator splitting solver.

        Args:
            fractional_orders: Fractional order(s)
            coupling_strength: Coupling strength
            split_order: Splitting order (1=Lie-Trotter, 2=Strang)

        Raises:
            ValueError: If ``split_order`` is not 1 or 2, or ``fractional_orders``
                has an unsupported type.
        """
        super().__init__(fractional_orders, coupling_strength)
        if split_order not in (1, 2):
            raise ValueError(
                f"split_order must be 1 (Lie-Trotter) or 2 (Strang), got {split_order!r}"
            )
        self.split_order = split_order

    def solve(
        self,
        graph_dynamics: Callable,
        sde_drift: Callable,
        sde_diffusion: Callable,
        adjacency: np.ndarray,
        node_features: np.ndarray,
        t_span: Tuple[float, float],
        num_steps: int = 100,
        seed: Optional[int] = None,
        **kwargs
    ) -> CoupledSolution:
        """
        Solve using operator splitting.

        For Strang splitting (order 2), each step is:
            - Half step of spatial dynamics
            - Full step of temporal dynamics (coupled to the current spatial state)
            - Half step of spatial dynamics
        For Lie-Trotter (order 1): one full spatial step, then one full temporal step.

        Args:
            graph_dynamics: ``f(spatial, adjacency)`` returning the spatial rate.
            sde_drift: ``f(t, temporal)``.
            sde_diffusion: ``g(t, temporal)``. May be a scalar, a 1-D array of length
                ``state.size``, a 2-D ``(state.size, m)`` matrix acting on ``m``
                Brownian increments, or an array broadcasting elementwise to the state.
            adjacency: Passed unchanged to ``graph_dynamics``.
            node_features: Initial state for both tracks (copied and cast to float).
            t_span: ``(t0, tf)`` with ``tf >= t0``.
            num_steps: Number of uniform steps (must be >= 1).
            seed: Seed for ``np.random.default_rng``.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            CoupledSolution with trajectories of shape
            ``(num_steps + 1, *node_features.shape)`` and a coupling diagnostic
            ``coupling_strength * mean(|spatial - temporal|)`` per time point.

        Raises:
            ValueError: On invalid ``num_steps`` / ``t_span`` or an incompatible
                diffusion shape.
        """
        if num_steps < 1:
            raise ValueError(f"num_steps must be >= 1, got {num_steps!r}")
        t0, tf = t_span
        if tf < t0:
            # A negative dt would make dt**alpha complex and sqrt(dt) NaN.
            raise ValueError(f"t_span must satisfy t_span[1] >= t_span[0], got {t_span!r}")
        dt = (tf - t0) / num_steps
        t = np.linspace(t0, tf, num_steps + 1)

        # Both tracks start from the same (copied) initial features.
        spatial_state = np.array(node_features, dtype=float)
        temporal_state = spatial_state.copy()
        initial_temporal = temporal_state.copy()

        spatial_traj = np.zeros((num_steps + 1, *spatial_state.shape))
        temporal_traj = np.zeros((num_steps + 1, *temporal_state.shape))
        coupling_traj = np.zeros(num_steps + 1)  # zero at t0: both tracks are equal
        spatial_traj[0] = spatial_state
        temporal_traj[0] = temporal_state

        rng = np.random.default_rng(seed)

        # History buffers for the temporal fractional dynamics.
        alpha_t = self.fractional_orders.get('temporal', 0.5)
        if isinstance(alpha_t, FractionalOrder):
            alpha_t = alpha_t.alpha
        dim = int(temporal_state.size)
        drift_conv = FastHistoryConvolution(alpha_t, num_steps, dim)
        diffusion_conv = FastHistoryConvolution(alpha_t, num_steps, dim)
        gamma_factor = 1.0 / gamma(alpha_t + 1)

        # Strang: half spatial step before and after the temporal step.
        # Lie-Trotter: a single full spatial step before it.
        strang = self.split_order == 2
        pre_dt = dt / 2 if strang else dt

        for i in range(num_steps):
            spatial_state = self._spatial_step(
                graph_dynamics, adjacency, spatial_state, pre_dt
            )
            temporal_state = self._temporal_step(
                sde_drift,
                sde_diffusion,
                temporal_state,
                t[i],
                dt,
                drift_conv,
                diffusion_conv,
                gamma_factor,
                alpha_t,
                initial_temporal,
                spatial_for_coupling=spatial_state,
                rng=rng,
            )
            if strang:
                spatial_state = self._spatial_step(
                    graph_dynamics, adjacency, spatial_state, dt / 2
                )

            spatial_traj[i + 1] = spatial_state
            temporal_traj[i + 1] = temporal_state
            coupling_traj[i + 1] = float(
                self.coupling_strength * np.mean(np.abs(spatial_state - temporal_state))
            )

        return CoupledSolution(
            t=t,
            spatial=spatial_traj,
            temporal=temporal_traj,
            coupling=coupling_traj,
            metadata={
                'solver': 'operator_splitting',
                'split_order': self.split_order,
                'num_steps': num_steps,
                'coupling_strength': self.coupling_strength,
                'coupling_in_temporal_drift': True,
            }
        )

    def _spatial_step(
        self,
        graph_dynamics: Callable,
        adjacency: np.ndarray,
        state: np.ndarray,
        dt: float
    ) -> np.ndarray:
        """Single explicit-Euler spatial (graph) step: ``state + dt * f(state, adjacency)``."""
        # asarray lets graph_dynamics return lists as well as arrays.
        dspatial = np.asarray(graph_dynamics(state, adjacency), dtype=float)
        return state + dt * dspatial

    @staticmethod
    def _noise_increment(
        diffusion_val: Any,
        state: np.ndarray,
        dt: float,
        rng: np.random.Generator,
    ) -> np.ndarray:
        """Return the diffusion-times-Brownian-increment term, shaped like ``state``.

        ``diffusion_val`` is interpreted in this order:
          * 0-d: the same scalar on every component (``state.size`` increments).
          * 1-D of length ``state.size``: elementwise (``state.size`` increments).
          * 2-D with ``shape[0] == state.size``: matrix ``G`` applied to
            ``shape[1]`` increments (``G @ dW``).
          * otherwise: elementwise after broadcasting against ``state``
            (e.g. a 2-D state with a same-shaped diffusion array).

        Raises:
            ValueError: If the diffusion cannot produce a term of ``state``'s shape.
        """
        state = np.asarray(state)
        dim = int(state.size)
        g = np.asarray(diffusion_val, dtype=float)
        sqrt_dt = np.sqrt(dt)

        if g.ndim == 0:
            dW = rng.normal(0, sqrt_dt, size=(dim,))
            return (float(g) * dW).reshape(state.shape)

        if g.ndim == 1:
            if g.size != dim:
                raise ValueError(
                    f"Diffusion vector length {g.size} does not match state size {dim}"
                )
            dW = rng.normal(0, sqrt_dt, size=(dim,))
            return (g * dW).reshape(state.shape)

        if g.ndim == 2 and g.shape[0] == dim:
            dW = rng.normal(0, sqrt_dt, size=(g.shape[1],))
            return (g @ dW).reshape(state.shape)

        dW = rng.normal(0, sqrt_dt, size=state.shape)
        noise = g * dW
        if noise.shape != state.shape:
            raise ValueError(
                f"Diffusion of shape {g.shape} is not compatible with state shape {state.shape}"
            )
        return noise

    def _temporal_step(
        self,
        drift: Callable,
        diffusion: Callable,
        state: np.ndarray,
        t: float,
        dt: float,
        drift_conv: FastHistoryConvolution,
        diffusion_conv: FastHistoryConvolution,
        gamma_factor: float,
        alpha: float,
        initial_state: np.ndarray,
        spatial_for_coupling: np.ndarray,
        rng: np.random.Generator,
    ) -> np.ndarray:
        """Temporal fractional SDE substep with explicit drift coupling to the spatial track.

        Pushes this step's total drift and noise into the history buffers, then
        returns ``V(0) + gamma_factor * dt**alpha * (drift_hist + noise_hist)``
        reshaped like ``state``.
        """
        coupling = self.coupling_strength * (spatial_for_coupling - state)
        drift_val = drift(t, state) + coupling
        diffusion_val = diffusion(t, state)
        noise_term = self._noise_increment(diffusion_val, state, dt, rng)

        drift_conv.update(np.asarray(drift_val, dtype=float).reshape(-1))
        diffusion_conv.update(np.asarray(noise_term, dtype=float).reshape(-1))
        drift_integral = drift_conv.convolve()
        diffusion_integral = diffusion_conv.convolve()

        # NOTE(review): the Brownian increments (already ~sqrt(dt)) get the same
        # dt**alpha factor as the drift; unless FastHistoryConvolution compensates,
        # this differs from g*dW at alpha=1. Kept as the module's documented approximation.
        initial_flat = np.asarray(initial_state, dtype=float).reshape(-1)
        update = gamma_factor * (dt**alpha) * (drift_integral + diffusion_integral)
        return (initial_flat + update).reshape(state.shape)
class MonolithicSolver(CoupledSystemSolver):
    """
    Monolithic solver for strongly coupled graph-SDE systems.

    Both coupling directions are evaluated from the same previous state each step
    (no splitting):

        dU/dt     = graph_dynamics(U, A) + k (V - U)      (explicit Euler)
        D^alpha V = sde_drift(t, V) + k (U - V) + noise   (history convolution)

    The state is stored as ``[spatial; temporal]`` concatenated along the last axis.
    """

    def solve(
        self,
        graph_dynamics: Callable,
        sde_drift: Callable,
        sde_diffusion: Callable,
        adjacency: np.ndarray,
        node_features: np.ndarray,
        t_span: Tuple[float, float],
        num_steps: int = 100,
        seed: Optional[int] = None,
        **kwargs
    ) -> CoupledSolution:
        """Solve monolithic coupled system.

        Args:
            graph_dynamics: ``f(spatial, adjacency)`` returning the spatial rate.
            sde_drift: ``f(t, temporal)``.
            sde_diffusion: ``g(t, temporal)``; multiplied elementwise (with
                broadcasting) by a Brownian increment shaped like the state.
            adjacency: Passed unchanged to ``graph_dynamics``.
            node_features: Initial state for both tracks (cast to float, at least 1-D).
            t_span: ``(t0, tf)`` with ``tf >= t0``.
            num_steps: Number of uniform steps (must be >= 1).
            seed: Seed for ``np.random.default_rng``.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            CoupledSolution whose ``spatial`` / ``temporal`` are views into one
            combined trajectory array. ``coupling`` holds
            ``coupling_strength * mean(|spatial - temporal|)`` per time point.

        Raises:
            ValueError: On invalid ``num_steps`` / ``t_span`` or when drift/diffusion
                do not produce one value per state component.
        """
        if num_steps < 1:
            raise ValueError(f"num_steps must be >= 1, got {num_steps!r}")
        t0, tf = t_span
        if tf < t0:
            # A negative dt would make dt**alpha complex and sqrt(dt) NaN.
            raise ValueError(f"t_span must satisfy t_span[1] >= t_span[0], got {t_span!r}")
        dt = (tf - t0) / num_steps
        t = np.linspace(t0, tf, num_steps + 1)

        # Combined state: [spatial; temporal] along the last axis.
        features = np.asarray(node_features, dtype=float)
        half = features.shape[-1]
        combined_state = np.concatenate([features, features], axis=-1)
        combined_traj = np.zeros((num_steps + 1, *combined_state.shape))
        combined_traj[0] = combined_state

        rng = np.random.default_rng(seed)

        # Temporal memory setup.
        alpha = self.fractional_orders.get('temporal', 0.5)
        if isinstance(alpha, FractionalOrder):
            alpha = alpha.alpha
        dim = int(features.size)
        temporal_drift_conv = FastHistoryConvolution(alpha, num_steps, dim)
        temporal_diffusion_conv = FastHistoryConvolution(alpha, num_steps, dim)
        gamma_factor = 1.0 / gamma(alpha + 1)
        step_scale = gamma_factor * (dt**alpha)
        # V(0) anchors the fractional integral; it never changes, so compute it once.
        initial_temporal_flat = combined_traj[0, ..., half:].reshape(-1)

        for i in range(num_steps):
            spatial_state = combined_state[..., :half]
            temporal_state = combined_state[..., half:]

            # Spatial rate dU/dt = GraphDyn(U) + k(V - U). Built out-of-place so a
            # graph_dynamics returning its input view (or a cached buffer) is not mutated.
            dspatial = (
                np.asarray(graph_dynamics(spatial_state, adjacency), dtype=float)
                + self.coupling_strength * (temporal_state - spatial_state)
            )

            # Temporal: D^alpha V = f(V) + k(U - V) + g(V) dW. The coupling term is part
            # of the drift fed to the memory integral.
            drift_val = sde_drift(t[i], temporal_state)
            diffusion_val = sde_diffusion(t[i], temporal_state)
            coupling_val = self.coupling_strength * (spatial_state - temporal_state)
            total_drift = drift_val + coupling_val

            dW = rng.normal(0, np.sqrt(dt), size=temporal_state.shape)
            noise_term = np.asarray(diffusion_val, dtype=float) * dW
            total_drift_flat = np.asarray(total_drift, dtype=float).reshape(-1)
            noise_flat = np.asarray(noise_term, dtype=float).reshape(-1)
            if total_drift_flat.size != dim or noise_flat.size != dim:
                raise ValueError(
                    f"Drift/diffusion terms must have {dim} components to match the "
                    f"state, got {total_drift_flat.size} and {noise_flat.size}"
                )
            temporal_drift_conv.update(total_drift_flat)
            temporal_diffusion_conv.update(noise_flat)
            drift_integral = temporal_drift_conv.convolve()
            diffusion_integral = temporal_diffusion_conv.convolve()

            # Spatial: explicit Euler. Temporal: V(0) + 1/Gamma(alpha+1) * dt**alpha * history.
            spatial_new = spatial_state + dspatial * dt
            temporal_new = (
                initial_temporal_flat + step_scale * (drift_integral + diffusion_integral)
            ).reshape(temporal_state.shape)

            combined_state = np.concatenate([spatial_new, temporal_new], axis=-1)
            combined_traj[i + 1] = combined_state

        spatial_traj = combined_traj[..., :half]
        temporal_traj = combined_traj[..., half:]
        # Reduce over every non-time axis so 1-D and N-D node features both yield
        # one value per time point (axis=(-2, -1) collapsed time for 1-D features).
        state_axes = tuple(range(1, spatial_traj.ndim))
        coupling_traj = self.coupling_strength * np.mean(
            np.abs(spatial_traj - temporal_traj), axis=state_axes
        )

        return CoupledSolution(
            t=t,
            spatial=spatial_traj,
            temporal=temporal_traj,
            coupling=coupling_traj,
            metadata={
                'solver': 'monolithic',
                'num_steps': num_steps,
                'coupling_strength': self.coupling_strength,
            }
        )
def solve_coupled_graph_sde(
    graph_dynamics: Callable,
    sde_drift: Callable,
    sde_diffusion: Callable,
    adjacency: np.ndarray,
    node_features: np.ndarray,
    t_span: Tuple[float, float],
    fractional_orders: Union[float, FractionalOrder, Dict[str, float]] = 0.5,
    coupling_type: str = "bidirectional",
    coupling_strength: float = 1.0,
    solver: str = "operator_splitting",
    **kwargs
) -> CoupledSolution:
    """
    Solve coupled graph-SDE system.

    Limitations: ``coupling_type`` is unused, the temporal fractional path is an
    approximation, and operator splitting couples only the temporal track to the
    spatial one.

    Args:
        graph_dynamics: Spatial dynamics ``f(spatial, adjacency)``.
        sde_drift: Temporal drift ``f(t, temporal)``.
        sde_diffusion: Temporal diffusion ``g(t, temporal)``.
        adjacency: Graph adjacency matrix.
        node_features: Initial node features (shared as spatial/temporal initial shape).
        t_span: Time interval.
        fractional_orders: Scalar or dict with ``'spatial'`` / ``'temporal'`` keys; temporal
            order drives ``FastHistoryConvolution``.
        coupling_type: Reserved for future use; **currently ignored** (a non-default
            value triggers a ``UserWarning``).
        coupling_strength: Prefactor on ``(spatial - temporal)`` in the temporal drift.
        solver: ``"operator_splitting"`` or ``"monolithic"``.
        **kwargs: Passed to the solver's ``solve`` (e.g. ``num_steps``, ``seed``).
            For ``"operator_splitting"``, ``split_order`` (1 or 2, default 2) is
            forwarded to the solver constructor instead.

    Returns:
        CoupledSolution

    Raises:
        ValueError: For an unknown ``solver`` name.
    """
    if coupling_type != "bidirectional":
        warnings.warn(
            f"solve_coupled_graph_sde: coupling_type={coupling_type!r} is not implemented; "
            "only coupling_strength is applied. Ignoring coupling_type.",
            UserWarning,
            stacklevel=2,
        )

    if solver == "operator_splitting":
        # split_order configures the solver object; solve() would silently drop it
        # via **kwargs, so route it to the constructor.
        split_order = kwargs.pop("split_order", 2)
        solver_obj = OperatorSplittingSolver(
            fractional_orders, coupling_strength, split_order=split_order
        )
    elif solver == "monolithic":
        solver_obj = MonolithicSolver(fractional_orders, coupling_strength)
    else:
        raise ValueError(f"Unknown solver type: {solver}")

    return solver_obj.solve(
        graph_dynamics,
        sde_drift,
        sde_diffusion,
        adjacency,
        node_features,
        t_span,
        **kwargs
    )