"""
Backend Management System for Multi-Framework Support
This module provides unified interfaces for PyTorch, JAX, and NUMBA backends,
enabling seamless switching between frameworks and automatic backend selection
based on data type, hardware availability, and performance requirements.
"""
from typing import Optional, Any, Dict, List, Callable
from enum import Enum
import warnings
import importlib
# Backend availability checking (lazy; do not import heavy libs at module import)
def _spec_available(module_name: str) -> bool:
try:
return importlib.util.find_spec(module_name) is not None
except Exception:
return False
# Availability flags are derived from import specs only; the frameworks
# themselves are imported later, on demand.
TORCH_AVAILABLE = _spec_available("torch")
# NOTE(review): probing the dotted name "jax.numpy" makes find_spec import the
# parent "jax" package, so this check is not fully lazy.
JAX_AVAILABLE = all(_spec_available(name) for name in ("jax", "jax.numpy"))
NUMPY_AVAILABLE = _spec_available("numpy")
# The NUMBA lane works on NumPy arrays, so it counts as available when either
# numba itself or plain NumPy can be found.
NUMBA_AVAILABLE = any((_spec_available("numba"), NUMPY_AVAILABLE))
class BackendType(Enum):
    """Identifiers for the supported computation backends.

    ``AUTO`` is not a backend itself; it asks the manager to pick one of the
    concrete backends.
    """

    TORCH = "torch"
    JAX = "jax"
    NUMBA = "numba"
    AUTO = "auto"
class BackendManager:
    """
    Select a computation backend and expose a unified tensor interface.

    On construction the manager stores the caller's preferences, detects which
    backends are installed, chooses the active backend and builds a
    configuration dictionary for every detected backend.

    Attributes:
        preferred_backend: Backend requested by the caller (``AUTO`` lets the
            manager choose).
        force_cpu: When True, no backend configuration selects a GPU device.
        enable_jit: Whether JAX / NUMBA functions are JIT-compiled.
        enable_gpu: Whether JAX may be auto-selected and configured for GPU.
        available_backends: Backends detected on this system, in detection
            order (TORCH, JAX, NUMBA).
        active_backend: Backend currently used by the tensor helpers.
        backend_configs: Per-backend configuration dictionaries.
    """

    def __init__(
        self,
        preferred_backend: BackendType = BackendType.AUTO,
        force_cpu: bool = False,
        enable_jit: bool = True,
        enable_gpu: bool = True
    ):
        """Store preferences, detect backends and pick the active one.

        Args:
            preferred_backend: Backend to use; ``AUTO`` selects automatically.
            force_cpu: Never configure a GPU device, even if one is present.
            enable_jit: Allow JIT compilation for JAX and NUMBA.
            enable_gpu: Allow JAX to be auto-selected / configured for GPU.

        Raises:
            RuntimeError: If no backend at all is available.
        """
        self.preferred_backend = preferred_backend
        self.force_cpu = force_cpu
        self.enable_jit = enable_jit
        self.enable_gpu = enable_gpu

        # Order matters: selection needs the detected list, and the configs
        # are only built for detected backends.
        self.available_backends = self._detect_available_backends()
        self.active_backend = self._select_optimal_backend()
        self.backend_configs = self._initialize_backend_configs()

        print(
            f"🎯 Backend Manager initialized with {self.active_backend.value}")
        print(
            f"📊 Available backends: {[b.value for b in self.available_backends]}")

    def _detect_available_backends(self) -> List[BackendType]:
        """Return the backends whose packages can be found on this system.

        GPU probing is best-effort and only used for an informational message;
        a failing probe never removes a backend from the result.

        Raises:
            RuntimeError: If none of the backends is available.
        """
        available: List[BackendType] = []

        if TORCH_AVAILABLE:
            available.append(BackendType.TORCH)
            try:
                torch = importlib.import_module("torch")
                if hasattr(torch, "cuda") and torch.cuda.is_available() and not self.force_cpu:
                    print("🚀 PyTorch CUDA support detected")
            except Exception:
                # Best-effort probe; the backend stays listed regardless.
                pass

        if JAX_AVAILABLE:
            available.append(BackendType.JAX)
            try:
                jax = importlib.import_module("jax")
                devices = jax.devices()
                has_gpu = any('gpu' in str(d).lower() for d in devices)
                if has_gpu and not self.force_cpu:
                    print("🚀 JAX GPU support detected")
            except Exception:
                # Catch Exception (not BaseException) so Ctrl-C and
                # SystemExit still propagate during device discovery.
                pass

        if NUMBA_AVAILABLE:
            available.append(BackendType.NUMBA)
            try:
                # Fails when only NumPy is installed (NumPy-backed lane);
                # that is expected and simply skips the CUDA message.
                numba = importlib.import_module("numba")
                if hasattr(numba, 'cuda') and numba.cuda.is_available() and not self.force_cpu:
                    print("🚀 NUMBA CUDA support detected")
            except Exception:
                pass

        if not available:
            raise RuntimeError("No computation backends available!")
        return available

    def _select_optimal_backend(self) -> BackendType:
        """Choose the active backend from preferences and availability.

        An explicit preference wins when it is available; otherwise a warning
        is issued and the first detected backend is used. With ``AUTO`` the
        order is TORCH, then JAX (only when ``enable_gpu`` is set), then NUMBA,
        then whatever was detected first.
        """
        detected = self.available_backends

        if self.preferred_backend != BackendType.AUTO:
            if self.preferred_backend in detected:
                return self.preferred_backend
            warnings.warn(
                f"Preferred backend {self.preferred_backend.value} not available, using {self.available_backends[0].value}")
            return detected[0]

        # Prefer PyTorch by default to match test expectations and widest
        # API coverage
        if BackendType.TORCH in detected:
            return BackendType.TORCH
        if BackendType.JAX in detected and self.enable_gpu:
            return BackendType.JAX
        if BackendType.NUMBA in detected:
            return BackendType.NUMBA
        return detected[0]

    def _initialize_backend_configs(self) -> Dict[BackendType, Dict[str, Any]]:
        """Build the configuration dictionary for every detected backend.

        Each framework is imported lazily here; if a TORCH or JAX import fails
        a conservative CPU-only configuration is stored instead.
        """
        configs: Dict[BackendType, Dict[str, Any]] = {}

        # PyTorch configuration (lazy import)
        if BackendType.TORCH in self.available_backends:
            try:
                torch = importlib.import_module("torch")
                use_cuda = (
                    hasattr(torch, 'cuda')
                    and torch.cuda.is_available()
                    and not self.force_cpu
                )
                configs[BackendType.TORCH] = {
                    'device': 'cuda' if use_cuda else 'cpu',
                    'dtype': getattr(torch, 'float32', None),
                    'enable_amp': True,
                    'enable_compile': hasattr(torch, 'compile'),
                }
            except Exception:
                configs[BackendType.TORCH] = {
                    'device': 'cpu',
                    'dtype': None,
                    'enable_amp': False,
                    'enable_compile': False,
                }

        # JAX configuration (lazy import)
        if BackendType.JAX in self.available_backends:
            try:
                jnp = importlib.import_module("jax.numpy")
                configs[BackendType.JAX] = {
                    # NOTE(review): 'gpu' reflects the flags only; unlike the
                    # TORCH / NUMBA entries no device query is made here.
                    'device': 'gpu' if self.enable_gpu and not self.force_cpu else 'cpu',
                    'dtype': getattr(jnp, 'float32', None),
                    'enable_jit': self.enable_jit,
                    'enable_x64': False,
                    'enable_amp': True,
                }
            except Exception:
                configs[BackendType.JAX] = {
                    'device': 'cpu',
                    'dtype': None,
                    'enable_jit': False,
                    'enable_x64': False,
                    'enable_amp': False,
                }

        # NUMBA configuration (lazy import)
        if BackendType.NUMBA in self.available_backends:
            # Prefer numpy-backed lane if numba not available
            np = importlib.import_module("numpy") if NUMPY_AVAILABLE else None
            numpy_float32 = getattr(np, 'float32', None) if np else None
            gpu_available = False
            try:
                numba = importlib.import_module("numba")
                try:
                    gpu_available = hasattr(
                        numba, 'cuda') and numba.cuda.is_available()
                except Exception:
                    gpu_available = False
                dtype_val = getattr(numba, 'float32', None) or numpy_float32
            except Exception:
                dtype_val = numpy_float32

            # numba-only features are switched off in the NumPy-backed lane.
            numba_installed = _spec_available("numba")
            configs[BackendType.NUMBA] = {
                'device': 'gpu' if gpu_available and not self.force_cpu else 'cpu',
                'dtype': dtype_val,
                'enable_jit': self.enable_jit and numba_installed,
                'enable_parallel': numba_installed,
                'enable_fastmath': numba_installed,
            }

        return configs

    def get_backend_config(
            self, backend: Optional[BackendType] = None) -> Dict[str, Any]:
        """Return the configuration of *backend* (default: the active backend).

        An empty dict is returned for a backend that has no configuration.
        """
        backend = backend or self.active_backend
        return self.backend_configs.get(backend, {})

    def switch_backend(self, backend: BackendType) -> bool:
        """Make *backend* the active backend if it is available.

        Returns:
            True on success; False (after a warning) when the backend is not
            in ``available_backends``, in which case nothing changes.
        """
        if backend not in self.available_backends:
            warnings.warn(f"Backend {backend.value} not available")
            return False
        self.active_backend = backend
        print(f"🔄 Switched to {backend.value} backend")
        return True

    def get_tensor_lib(self) -> Any:
        """Import and return the array module of the active backend.

        Returns:
            ``torch`` for TORCH, ``jax.numpy`` for JAX and ``numpy`` for NUMBA.

        Raises:
            RuntimeError: If the active backend is not one of the above.
        """
        if self.active_backend == BackendType.TORCH:
            return importlib.import_module("torch")
        if self.active_backend == BackendType.JAX:
            return importlib.import_module("jax.numpy")
        if self.active_backend == BackendType.NUMBA:
            # NUMBA lane arrays are NumPy
            return importlib.import_module("numpy")
        raise RuntimeError(f"Unknown backend: {self.active_backend}")

    @staticmethod
    def _has_integer_dtype(data: Any) -> bool:
        """Return True when *data* has a ``dtype`` whose name contains 'int'."""
        return hasattr(data, 'dtype') and 'int' in str(data.dtype)

    def create_tensor(self, data: Any, **kwargs) -> Any:
        """Create a tensor / array from *data* in the active backend.

        When no ``dtype`` keyword is given, integer-typed input keeps an
        integer dtype (``torch.long`` / ``int32``) so classification targets
        stay integral; everything else becomes ``float32``.

        Args:
            data: Anything the backend's array constructor accepts.
            **kwargs: Forwarded to ``torch.tensor`` / ``jnp.array`` /
                ``np.array``.

        Raises:
            RuntimeError: If the active backend is unknown.
        """
        if self.active_backend == BackendType.TORCH:
            torch = importlib.import_module("torch")
            if 'dtype' not in kwargs:
                kwargs['dtype'] = (
                    torch.long if self._has_integer_dtype(data) else torch.float32)
            return torch.tensor(data, **kwargs)

        if self.active_backend == BackendType.JAX:
            jnp = importlib.import_module("jax.numpy")
            if 'dtype' not in kwargs:
                kwargs['dtype'] = (
                    jnp.int32 if self._has_integer_dtype(data) else jnp.float32)
            return jnp.array(data, **kwargs)

        if self.active_backend == BackendType.NUMBA:
            # NUMBA works with numpy arrays
            np = importlib.import_module("numpy")
            if 'dtype' not in kwargs:
                kwargs['dtype'] = (
                    np.int32 if self._has_integer_dtype(data) else np.float32)
            return np.array(data, **kwargs)

        raise RuntimeError(f"Unknown backend: {self.active_backend}")

    def to_device(self, tensor: Any, device: Optional[str] = None) -> Any:
        """Move *tensor* to *device* for the active backend.

        For TORCH the tensor's ``.to`` is called with *device*, falling back to
        the configured TORCH device. For JAX and NUMBA the tensor is returned
        unchanged.

        Raises:
            RuntimeError: If the active backend is unknown.
        """
        if self.active_backend == BackendType.TORCH:
            target = device or self.backend_configs[BackendType.TORCH]['device']
            return tensor.to(target)
        if self.active_backend in (BackendType.JAX, BackendType.NUMBA):
            # These backends handle device placement differently
            return tensor
        raise RuntimeError(f"Unknown backend: {self.active_backend}")

    def compile_function(self, func: Callable) -> Callable:
        """Compile *func* with the active backend's compiler when possible.

        TORCH uses ``torch.compile`` if it exists; JAX uses ``jax.jit`` and
        NUMBA uses ``numba.jit`` when ``enable_jit`` is set (and, for NUMBA,
        numba is installed). In every other case *func* is returned as is.
        """
        if self.active_backend == BackendType.TORCH:
            torch = importlib.import_module("torch")
            if hasattr(torch, 'compile'):
                return torch.compile(func)
            return func

        if self.active_backend == BackendType.JAX:
            if self.enable_jit:
                jax = importlib.import_module("jax")
                return jax.jit(func)
            return func

        if self.active_backend == BackendType.NUMBA:
            if self.enable_jit and _spec_available("numba"):
                numba = importlib.import_module("numba")
                return numba.jit(func)
            return func

        return func
# Module-level shared BackendManager instance. It stays None until
# get_backend_manager() creates one on first use or set_backend_manager()
# installs one explicitly.
_backend_manager: Optional[BackendManager] = None
def get_backend_manager() -> BackendManager:
    """Return the shared BackendManager, creating it on first use.

    The instance is built with the default constructor arguments the first
    time this function is called while no manager has been set.
    """
    global _backend_manager
    if _backend_manager is None:
        _backend_manager = BackendManager()
    return _backend_manager
def set_backend_manager(manager: BackendManager) -> None:
    """Install *manager* as the shared BackendManager instance.

    Subsequent calls to ``get_backend_manager()`` return this object.
    """
    global _backend_manager
    _backend_manager = manager
def get_active_backend() -> BackendType:
    """Return the active backend of the shared BackendManager."""
    manager = get_backend_manager()
    return manager.active_backend
def switch_backend(backend: BackendType) -> bool:
    """Switch the shared BackendManager to *backend*.

    Returns:
        The result of ``BackendManager.switch_backend``: True when the switch
        happened, False when *backend* is not available.
    """
    manager = get_backend_manager()
    return manager.switch_backend(backend)