"""
Advanced SDE Adjoint Optimization Utilities
Provides memory-efficient checkpointing, mixed precision training,
and sparse gradient accumulation for neural fractional SDEs.
"""
import torch
import torch.nn as nn
from typing import Callable, Optional, Tuple, List, Dict, Any
from dataclasses import dataclass
import numpy as np
[docs]
@dataclass
class CheckpointConfig:
"""Configuration for gradient checkpointing."""
checkpoint_frequency: int = 10 # Save checkpoint every N steps
max_checkpoints: int = 100 # Maximum number of checkpoints to keep
checkpoint_strategy: str = "uniform" # "uniform", "adaptive", "save_all"
enable_checkpointing: bool = True
[docs]
@dataclass
class MixedPrecisionConfig:
"""Configuration for mixed precision training."""
enable_amp: bool = False # Automatic Mixed Precision
half_precision: bool = False # Use float16 throughout
dtype_fp16: torch.dtype = torch.float16
dtype_fp32: torch.dtype = torch.float32
loss_scaling: float = 1.0
[docs]
class SDEStateCheckpoint:
"""
Checkpoint manager for SDE trajectory states.
Enables memory-efficient training by saving intermediate states.
"""
[docs]
def __init__(self, config: CheckpointConfig):
self.config = config
self.checkpoints: List[Dict[str, Any]] = []
self.checkpoint_indices: List[int] = []
[docs]
def save_checkpoint(
self,
step: int,
state: torch.Tensor,
metadata: Optional[Dict[str, Any]] = None
):
"""Save a checkpoint of the current state."""
if not self.config.enable_checkpointing:
return
# Decide if we should save based on strategy
should_save = False
if self.config.checkpoint_strategy == "save_all":
should_save = True
elif self.config.checkpoint_strategy == "uniform":
should_save = (step % self.config.checkpoint_frequency == 0)
elif self.config.checkpoint_strategy == "adaptive":
# Adaptive strategy: save more frequently in early stages
freq = max(1, self.config.checkpoint_frequency // (step // 100 + 1))
should_save = (step % freq == 0)
if should_save:
checkpoint = {
'step': step,
'state': state.detach().clone(),
'metadata': metadata or {}
}
self.checkpoints.append(checkpoint)
self.checkpoint_indices.append(step)
# Limit number of checkpoints
if len(self.checkpoints) > self.config.max_checkpoints:
# Remove oldest checkpoint
self.checkpoints.pop(0)
self.checkpoint_indices.pop(0)
[docs]
def load_checkpoint(self, step: int) -> Optional[torch.Tensor]:
"""Load checkpoint closest to the given step."""
if not self.checkpoints:
return None
# Find closest checkpoint
idx = min(range(len(self.checkpoint_indices)),
key=lambda i: abs(self.checkpoint_indices[i] - step))
if idx is not None:
return self.checkpoints[idx]['state']
return None
[docs]
def clear(self):
"""Clear all checkpoints."""
self.checkpoints.clear()
self.checkpoint_indices.clear()
[docs]
class MixedPrecisionManager:
"""
Manager for mixed precision training in SDEs.
Handles automatic mixed precision (AMP) and float16 operations.
"""
[docs]
def __init__(self, config: MixedPrecisionConfig):
self.config = config
self.scaler: Optional[torch.cuda.amp.GradScaler] = None
if self.config.enable_amp:
self.scaler = torch.cuda.amp.GradScaler()
[docs]
def autocast(self):
"""Create autocast context for mixed precision."""
if self.config.enable_amp:
return torch.cuda.amp.autocast()
elif self.config.half_precision:
return torch.cuda.amp.autocast(dtype=self.config.dtype_fp16)
else:
# Context manager that does nothing
return torch.no_grad()
[docs]
def scale_loss(self, loss: torch.Tensor) -> torch.Tensor:
"""Scale loss for mixed precision training."""
if self.scaler is not None:
return self.scaler.scale(loss)
return loss * self.config.loss_scaling
[docs]
def scale_gradients(self, loss: torch.Tensor):
"""Scale gradients after backward pass."""
if self.scaler is not None:
self.scaler.scale(loss).backward()
else:
(loss * self.config.loss_scaling).backward()
[docs]
def step_optimizer(self, optimizer: torch.optim.Optimizer):
"""Update optimizer with scaled gradients."""
if self.scaler is not None:
self.scaler.step(optimizer)
self.scaler.update()
else:
optimizer.step()
[docs]
def unscale_gradients(self, optimizer: torch.optim.Optimizer):
"""Unscale gradients for clipping."""
if self.scaler is not None:
self.scaler.unscale_(optimizer)
[docs]
class SparseGradientAccumulator:
"""
Accumulator for sparse gradients in high-dimensional SDE systems.
Reduces memory usage by only storing non-zero gradients.
"""
[docs]
def __init__(self, sparsity_threshold: float = 1e-6):
"""
Initialize sparse gradient accumulator.
Args:
sparsity_threshold: Threshold below which gradients are considered zero
"""
self.sparsity_threshold = sparsity_threshold
self.accumulated_grads: List[Tuple[torch.Tensor, torch.Tensor]] = []
[docs]
def accumulate(self, grad: torch.Tensor, param_name: str = None):
"""
Accumulate a gradient with sparsity.
Args:
grad: Gradient tensor
param_name: Optional parameter name for debugging
"""
# Identify non-zero elements
mask = torch.abs(grad) > self.sparsity_threshold
if mask.any():
# Store only non-zero elements
sparse_grad = grad[mask]
indices = torch.nonzero(mask, as_tuple=False).squeeze()
self.accumulated_grads.append((sparse_grad, indices))
[docs]
def get_dense_gradient(self, shape: Tuple[int, ...]) -> torch.Tensor:
"""
Reconstruct dense gradient from sparse representation.
Args:
shape: Original gradient shape
Returns:
Dense gradient tensor
"""
grad = torch.zeros(shape, dtype=torch.float32)
for sparse_grad, indices in self.accumulated_grads:
grad.view(-1)[indices.view(-1)] = sparse_grad
return grad
[docs]
def clear(self):
"""Clear accumulated gradients."""
self.accumulated_grads.clear()
[docs]
def get_sparsity_ratio(self) -> float:
"""Get ratio of non-zero to total elements."""
if not self.accumulated_grads:
return 0.0
total_elements = 0
nonzero_elements = 0
for sparse_grad, indices in self.accumulated_grads:
# Estimate total elements from indices
total_elements += (indices.max().item() + 1) if indices.numel() > 0 else 0
nonzero_elements += sparse_grad.numel()
if total_elements == 0:
return 0.0
return nonzero_elements / total_elements
[docs]
def checkpoint_trajectory(
func: Callable,
*args,
checkpoint_freq: int = 10,
**kwargs
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
"""
Execute function with checkpointing to save memory.
Uses gradient checkpointing to trade compute for memory.
Args:
func: Function to execute
*args: Arguments to pass to function
checkpoint_freq: Frequency of checkpoints
**kwargs: Additional keyword arguments
Returns:
Tuple of (final_output, checkpointed_states)
"""
# Use PyTorch's gradient checkpointing if available
try:
from torch.utils.checkpoint import checkpoint as torch_checkpoint
return torch_checkpoint(func, *args, **kwargs)
except ImportError:
# Fallback to manual checkpointing
pass
# Manual checkpointing implementation
# This is a simplified version
return func(*args, **kwargs)
[docs]
class SDEAdjointOptimizer:
"""
Optimizer wrapper for SDE adjoint training with advanced optimizations.
Combines checkpointing, mixed precision, and sparse gradients.
"""
[docs]
def __init__(
self,
model: nn.Module,
optimizer: torch.optim.Optimizer,
checkpoint_config: Optional[CheckpointConfig] = None,
mixed_precision_config: Optional[MixedPrecisionConfig] = None,
enable_sparse_gradients: bool = False
):
self.model = model
self.optimizer = optimizer
self.checkpoint_config = checkpoint_config or CheckpointConfig()
self.mixed_precision_config = mixed_precision_config or MixedPrecisionConfig()
self.checkpoint_manager = SDEStateCheckpoint(self.checkpoint_config)
self.mixed_precision_manager = MixedPrecisionManager(self.mixed_precision_config)
self.sparse_accumulator = SparseGradientAccumulator() if enable_sparse_gradients else None
[docs]
def step(self, loss: torch.Tensor):
"""
Optimization step with advanced features.
Args:
loss: Loss tensor to optimize
"""
self.optimizer.zero_grad()
# Scale loss for mixed precision
scaled_loss = self.mixed_precision_manager.scale_loss(loss)
# Backward pass
with self.mixed_precision_manager.autocast():
scaled_loss.backward()
# Sparse gradient accumulation
if self.sparse_accumulator is not None:
for name, param in self.model.named_parameters():
if param.grad is not None:
self.sparse_accumulator.accumulate(param.grad, name)
# Step optimizer
self.mixed_precision_manager.step_optimizer(self.optimizer)
[docs]
def save_state_checkpoint(self, step: int, state: torch.Tensor, metadata: Dict[str, Any] = None):
"""Save state checkpoint."""
self.checkpoint_manager.save_checkpoint(step, state, metadata)
[docs]
def load_state_checkpoint(self, step: int) -> Optional[torch.Tensor]:
"""Load state checkpoint."""
return self.checkpoint_manager.load_checkpoint(step)
[docs]
def clear_checkpoints(self):
"""Clear all checkpoints."""
self.checkpoint_manager.clear()
if self.sparse_accumulator is not None:
self.sparse_accumulator.clear()