"""
Graph-SDE Coupling for Spatio-Temporal Dynamics
This module provides coupling layers that integrate spatial dynamics (via graph
neural networks) with temporal dynamics (via fractional SDEs) for modeling
spatio-temporal phenomena.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional, Dict, Any, Union
import numpy as np
from ..core.definitions import FractionalOrder
from .gnn_layers import FractionalGraphConv
from .neural_fsde import NeuralFractionalSDE, NeuralFSDEConfig
from .backends import BackendType
[docs]
class CouplingType:
"""Types of spatial-temporal coupling."""
BIDIRECTIONAL = "bidirectional" # Space โ Time
SPATIAL_TO_TEMPORAL = "spatial_to_temporal" # Space โ Time only
TEMPORAL_TO_SPATIAL = "temporal_to_spatial" # Time โ Space only
GATED = "gated" # Learned gating mechanism
[docs]
class SpatialTemporalCoupling(nn.Module):
"""
Coupling layer between spatial (graph) and temporal (SDE) dynamics.
Computes learned coupling between spatial embeddings and temporal states.
"""
[docs]
def __init__(
self,
spatial_dim: int,
temporal_dim: int,
coupling_dim: int,
coupling_type: str = CouplingType.BIDIRECTIONAL,
use_attention: bool = True
):
"""
Initialize coupling layer.
Args:
spatial_dim: Dimension of spatial (graph) features
temporal_dim: Dimension of temporal (SDE) features
coupling_dim: Dimension of coupling embedding
coupling_type: Type of coupling mechanism
use_attention: Use attention mechanism for coupling
"""
super().__init__()
self.spatial_dim = spatial_dim
self.temporal_dim = temporal_dim
self.coupling_dim = coupling_dim
self.coupling_type = coupling_type
self.use_attention = use_attention
# Spatial to temporal projection
self.spatial_to_temporal = nn.Linear(spatial_dim, coupling_dim)
# Temporal to spatial projection
self.temporal_to_spatial = nn.Linear(temporal_dim, coupling_dim)
# Coupling function
if coupling_type == CouplingType.GATED:
self.gate = nn.Sequential(
nn.Linear(coupling_dim * 2, coupling_dim),
nn.Sigmoid()
)
# Final projection back to original dimensions
self.temporal_projection = nn.Linear(coupling_dim, temporal_dim)
self.spatial_projection = nn.Linear(coupling_dim, spatial_dim)
[docs]
def forward(
self,
spatial_features: torch.Tensor,
temporal_features: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply spatial-temporal coupling.
Args:
spatial_features: Graph node features (batch, num_nodes, spatial_dim)
temporal_features: SDE state features (batch, num_nodes, temporal_dim)
Returns:
Tuple of (coupled_spatial, coupled_temporal)
"""
# Project to coupling space
spatial_embed = self.spatial_to_temporal(spatial_features)
temporal_embed = self.temporal_to_spatial(temporal_features)
# Compute coupling based on type
if self.coupling_type == CouplingType.BIDIRECTIONAL:
# Average coupling
coupled_embed = (spatial_embed + temporal_embed) / 2
elif self.coupling_type == CouplingType.SPATIAL_TO_TEMPORAL:
# Only spatial affects temporal
coupled_embed = spatial_embed
elif self.coupling_type == CouplingType.TEMPORAL_TO_SPATIAL:
# Only temporal affects spatial
coupled_embed = temporal_embed
elif self.coupling_type == CouplingType.GATED:
# Learned gating
gate = self.gate(torch.cat([spatial_embed, temporal_embed], dim=-1))
coupled_embed = gate * spatial_embed + (1 - gate) * temporal_embed
else:
raise ValueError(f"Unknown coupling type: {self.coupling_type}")
# Project back to original spaces
coupled_temporal = self.temporal_projection(coupled_embed)
coupled_spatial = self.spatial_projection(coupled_embed)
# Combine with residuals
coupled_temporal = coupled_temporal + temporal_features
coupled_spatial = coupled_spatial + spatial_features
return coupled_spatial, coupled_temporal
[docs]
class GraphFractionalSDELayer(nn.Module):
"""
Layer that couples graph-based spatial dynamics with fractional SDE temporal evolution.
Architecture:
- Graph convolution for spatial features
- Fractional SDE for temporal dynamics at each node
- Learned coupling between spatial and temporal embeddings
"""
[docs]
def __init__(
self,
input_dim: int,
hidden_dim: int,
output_dim: int,
fractional_order: Union[float, FractionalOrder] = 0.5,
coupling_type: str = CouplingType.BIDIRECTIONAL,
num_sde_steps: int = 10,
backend: BackendType = BackendType.AUTO
):
"""
Initialize Graph-Fractional SDE layer.
Args:
input_dim: Input feature dimension
hidden_dim: Hidden dimension for both spatial and temporal
output_dim: Output feature dimension
fractional_order: Fractional order for SDE dynamics
coupling_type: Type of spatial-temporal coupling
num_sde_steps: Number of SDE integration steps
backend: Computation backend
"""
super().__init__()
self.input_dim = input_dim
self.hidden_dim = hidden_dim
self.output_dim = output_dim
self.fractional_order = FractionalOrder(fractional_order) if isinstance(
fractional_order, float) else fractional_order
self.num_sde_steps = num_sde_steps
self.backend = backend
# Spatial (graph) dynamics
self.spatial_layer = FractionalGraphConv(
in_channels=input_dim,
out_channels=hidden_dim,
fractional_order=fractional_order,
backend=backend
)
# Temporal (SDE) dynamics - using a simplified neural SDE
self.temporal_layer = nn.GRUCell(
input_dim,
hidden_dim
)
# SDE drift network
self.drift_net = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.Tanh(),
nn.Linear(hidden_dim, hidden_dim)
)
# SDE diffusion network
self.diffusion_net = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.Softplus()
)
# Coupling layer
self.coupling = SpatialTemporalCoupling(
spatial_dim=hidden_dim,
temporal_dim=hidden_dim,
coupling_dim=hidden_dim,
coupling_type=coupling_type
)
# Output projection
self.output_proj = nn.Linear(hidden_dim, output_dim)
[docs]
def drift(self, x: torch.Tensor) -> torch.Tensor:
"""Compute drift term for SDE."""
return self.drift_net(x)
[docs]
def diffusion(self, x: torch.Tensor) -> torch.Tensor:
"""Compute diffusion term for SDE."""
return self.diffusion_net(x)
[docs]
def forward(
self,
x: torch.Tensor,
edge_index: torch.Tensor,
edge_weight: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Forward pass through graph-SDE layer.
Args:
x: Node features (batch, num_nodes, input_dim)
edge_index: Edge connectivity (2, num_edges)
edge_weight: Optional edge weights (num_edges,)
Returns:
Updated node features (batch, num_nodes, output_dim)
"""
batch_size, num_nodes, _ = x.shape
# Spatial dynamics: graph convolution
x_graph = x.view(-1, self.input_dim)
spatial_features = self.spatial_layer(x_graph, edge_index, edge_weight)
spatial_features = spatial_features.view(batch_size, num_nodes, self.hidden_dim)
# Temporal dynamics: SDE evolution
# Initialize temporal state from spatial features
temporal_state = spatial_features.clone()
# Simulate SDE for a few steps
dt = 0.1 # Time step
for _ in range(self.num_sde_steps):
# Compute drift
drift = self.drift(temporal_state)
# Generate noise
noise = torch.randn_like(temporal_state)
# SDE update (Euler-Maruyama)
alpha = self.fractional_order.alpha
temporal_state = (temporal_state +
dt**alpha * drift +
np.sqrt(dt) * self.diffusion(temporal_state) * noise)
# Apply coupling
coupled_spatial, coupled_temporal = self.coupling(spatial_features, temporal_state)
# Combine spatial and temporal
combined = (coupled_spatial + coupled_temporal) / 2
# Output projection
output = self.output_proj(combined)
return output
[docs]
class MultiScaleGraphSDE(nn.Module):
"""
Multi-scale graph-SDE network with adaptive time stepping.
Handles different time scales for graph updates vs SDE evolution,
optimal for stiff coupled systems.
"""
[docs]
def __init__(
self,
input_dim: int,
hidden_dims: list,
output_dim: int,
fractional_order: Union[float, FractionalOrder] = 0.5,
spatial_time_scale: float = 1.0,
temporal_time_scale: float = 0.1
):
"""
Initialize multi-scale graph-SDE.
Args:
input_dim: Input dimension
hidden_dims: List of hidden dimensions for each layer
output_dim: Output dimension
fractional_order: Fractional order for SDE
spatial_time_scale: Time scale for spatial (graph) dynamics
temporal_time_scale: Time scale for temporal (SDE) dynamics
"""
super().__init__()
self.input_dim = input_dim
self.hidden_dims = hidden_dims
self.output_dim = output_dim
self.spatial_time_scale = spatial_time_scale
self.temporal_time_scale = temporal_time_scale
# Build layers
self.layers = nn.ModuleList()
dims = [input_dim] + hidden_dims
for i in range(len(dims) - 1):
layer = GraphFractionalSDELayer(
input_dim=dims[i],
hidden_dim=dims[i+1],
output_dim=dims[i+1],
fractional_order=fractional_order
)
self.layers.append(layer)
# Output layer
self.output_layer = nn.Linear(hidden_dims[-1], output_dim)
[docs]
def forward(
self,
x: torch.Tensor,
edge_index: torch.Tensor,
edge_weight: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Forward pass with multi-scale dynamics.
Args:
x: Node features
edge_index: Edge connectivity
edge_weight: Optional edge weights
Returns:
Output features
"""
# Propagate through layers
for layer in self.layers:
x = layer(x, edge_index, edge_weight)
# Output projection
output = self.output_layer(x)
return output