Source code for hpfracc.ml.tensor_ops

"""
Unified Tensor Operations for Multi-Backend Support

This module provides consistent tensor operations across PyTorch, JAX, and a
NumPy-backed "NUMBA lane" (arrays are NumPy; numba is a compiler elsewhere),
enabling seamless switching between frameworks while maintaining the same API.
"""

from typing import Optional, Union, Any, List, Tuple
from contextlib import nullcontext
import warnings
import importlib
import os

from .backends import get_backend_manager, BackendType
from .adapters import get_optimal_adapter, HighPerformanceAdapter

import numpy as _np  # used as a safe NumPy namespace at construction


class TensorOps:
    """
    Unified tensor operations across different backends.

    Notes:
        - AUTO is resolved to a concrete, installed backend during __init__.
        - NUMBA lane uses NumPy arrays (numba itself is not a tensor library).
        - JAX random ops require a PRNG key; pass via kwargs (key=...).
        - Some reductions keep each library's own conventions: torch ``std``
          applies Bessel's correction by default while NumPy/JAX use ddof=0,
          and torch ``median`` returns the lower of the two middle values for
          even-length input while NumPy/JAX average them.
    """

    def __init__(self, backend: Optional[Union[BackendType, str]] = None):
        """
        Resolve a concrete backend and build the high-performance adapter.

        Args:
            backend: Requested backend, as a ``BackendType`` or its string
                value. ``None`` / ``BackendType.AUTO`` defer to the backend
                manager's active backend and then to the fallback order.

        Raises:
            ValueError: If ``backend`` is a string that is not a BackendType value.
        """
        # Backend manager might import lazily; fall back to NumPy-only if unavailable
        try:
            backend_manager = get_backend_manager()
        except Exception:
            backend_manager = None

        # Convert string backend to enum if needed
        if isinstance(backend, str):
            try:
                backend = BackendType(backend)
            except ValueError:
                raise ValueError(
                    f"Unknown backend: {backend}. Available backends: {[b.value for b in BackendType]}") from None

        # Resolve requested (or active) backend into an installed concrete choice,
        # with sensible fallbacks.
        self.backend, self.tensor_lib = self._resolve_backend(
            backend, backend_manager)

        # Construct optimized adapter for future delegation
        self._adapter: HighPerformanceAdapter
        # Whether the adapter was built for self.backend. When it was not,
        # _lib() uses self.tensor_lib so array ops never mix libraries.
        self._adapter_matches_backend = True
        try:
            self._adapter = HighPerformanceAdapter(self.backend)
            # Ensure underlying lib is loaded
            _ = self._adapter.get_lib()
        except Exception:
            # Fallback: create adapter with its default backend
            self._adapter = HighPerformanceAdapter()
            self._adapter_matches_backend = False

    # ------------------------ Internal helpers ------------------------

    def _lib(self) -> Any:
        """Return the array namespace used for delegated ops.

        This is the adapter's library when the adapter was built for
        ``self.backend``; otherwise it is ``self.tensor_lib``.
        """
        if getattr(self, "_adapter_matches_backend", True):
            return self._adapter.get_lib()
        return self.tensor_lib

    def _unknown_backend(self, exc_type: type = ValueError) -> Exception:
        """Build the 'Unknown backend' error using the given exception type."""
        return exc_type(f"Unknown backend: {self.backend}")

    @staticmethod
    def _backend_disabled(backend: Any) -> bool:
        """True when the HPFRACC_DISABLE_<NAME> env var for ``backend`` is "1"."""
        env_var = {
            BackendType.TORCH: "HPFRACC_DISABLE_TORCH",
            BackendType.JAX: "HPFRACC_DISABLE_JAX",
            BackendType.NUMBA: "HPFRACC_DISABLE_NUMBA",
        }.get(backend)
        return env_var is not None and os.getenv(env_var, "0") == "1"

    # ------------------------ Backend resolution ------------------------

    def _resolve_backend(self, backend: Optional[BackendType], backend_manager):
        """
        Pick a concrete, installed backend with sensible fallbacks.

        Priority:
          1) explicit `backend` (if not AUTO) when installed
          2) backend_manager.active_backend (if not AUTO and not env-disabled)
             when installed
          3) fallback order: TORCH -> JAX -> NUMBA (NumPy), skipping env-disabled

        Returns:
            ``(backend, tensor_lib)``. Falls back to ``(NUMBA, numpy)`` if no
            candidate's library can be imported.
        """
        explicit = backend if backend is not None and backend != BackendType.AUTO else None
        candidates: List[BackendType] = []

        # 1) explicit request
        if explicit is not None:
            candidates.append(explicit)

        # 2) manager's active backend (getattr on None simply yields None)
        active = getattr(backend_manager, "active_backend", None)
        if (active is not None and active != BackendType.AUTO
                and active not in candidates
                and not self._backend_disabled(active)):
            candidates.append(active)

        # 3) standard fallbacks (honor env disables)
        for fallback in (BackendType.TORCH, BackendType.JAX, BackendType.NUMBA):
            if fallback not in candidates and not self._backend_disabled(fallback):
                candidates.append(fallback)

        for candidate in candidates:
            try:
                lib = self._get_tensor_lib_for_backend(candidate)
            except Exception as err:
                # torch/jax may raise RuntimeError/AttributeError on import.
                # Silently replacing an explicit request would surprise the caller.
                if candidate == explicit:
                    warnings.warn(
                        f"Requested backend {candidate} is unavailable ({err!r}); "
                        "falling back to another backend.",
                        RuntimeWarning,
                        stacklevel=3,
                    )
                continue
            return candidate, lib

        # If nothing worked, default to NumPy lane unconditionally
        return BackendType.NUMBA, _np

    def _get_tensor_lib_for_backend(self, backend: BackendType) -> Any:
        """Get tensor library for a specific backend (imports guarded)."""
        if backend == BackendType.TORCH:
            return importlib.import_module("torch")
        if backend == BackendType.JAX:
            return importlib.import_module("jax.numpy")
        if backend == BackendType.NUMBA:
            # Use NumPy namespace for arrays/ops; numba is a compiler elsewhere.
            return _np
        # For constructor edge-cases, fall back to TORCH
        return importlib.import_module("torch")

    # ------------------------ Creation / conversion ------------------------

    def create_tensor(self, data: Any, **kwargs) -> Any:
        """
        Create a tensor in the current backend.

        On torch, string ``dtype`` aliases (e.g. 'float32', 'fp16', 'long')
        and string ``device`` values are converted to torch objects. An
        explicit ``requires_grad=False`` is dropped. On JAX/NUMBA,
        ``requires_grad`` is removed because those arrays do not carry it.

        Raises:
            RuntimeError: If the backend is unknown.
        """
        if self.backend == BackendType.TORCH:
            torch = self.tensor_lib
            torch_kwargs = dict(kwargs)
            # requires_grad=False is torch's default; drop it when given explicitly
            if 'requires_grad' in torch_kwargs and not torch_kwargs['requires_grad']:
                del torch_kwargs['requires_grad']

            dtype = torch_kwargs.get('dtype', None)
            if isinstance(dtype, str):
                bf16 = getattr(torch, 'bfloat16', None)
                dtype_map = {
                    'float32': torch.float32, 'float': torch.float32, 'fp32': torch.float32,
                    'float64': torch.float64, 'double': torch.float64, 'fp64': torch.float64,
                    'float16': torch.float16, 'half': torch.float16, 'fp16': torch.float16,
                    'bfloat16': bf16, 'bf16': bf16,
                    'int64': torch.int64, 'long': torch.int64,
                    'int32': torch.int32, 'int': torch.int32,
                    'int16': torch.int16, 'short': torch.int16,
                    'int8': torch.int8, 'uint8': torch.uint8,
                    'bool': torch.bool,
                }
                mapped = dtype_map.get(dtype.lower())
                if mapped is not None:
                    torch_kwargs['dtype'] = mapped

            device = torch_kwargs.get('device', None)
            if isinstance(device, str):
                torch_kwargs['device'] = torch.device(device)
            return torch.tensor(data, **torch_kwargs)

        if self.backend in (BackendType.JAX, BackendType.NUMBA):
            # JAX and NumPy arrays have no requires_grad
            array_kwargs = {k: v for k, v in kwargs.items() if k != 'requires_grad'}
            return self.tensor_lib.array(data, **array_kwargs)

        raise self._unknown_backend(RuntimeError)

    def tensor(self, data: Any, **kwargs) -> Any:
        """Alias for create_tensor."""
        return self.create_tensor(data, **kwargs)

    def from_numpy(self, array: Any) -> Any:
        """Convert a NumPy array into the current backend's array type."""
        if self.backend == BackendType.TORCH:
            return self.tensor_lib.from_numpy(array)
        if self.backend == BackendType.JAX:
            return self.tensor_lib.array(array)
        if self.backend == BackendType.NUMBA:
            return array
        raise self._unknown_backend(ValueError)

    def to_numpy(self, tensor: Any) -> Any:
        """Convert a backend tensor to a NumPy array.

        On torch the tensor is detached and moved to CPU first.
        """
        if self.backend == BackendType.TORCH:
            return tensor.detach().cpu().numpy()
        if self.backend == BackendType.JAX:
            import jax
            return _np.asarray(jax.device_get(tensor))
        if self.backend == BackendType.NUMBA:
            return tensor
        raise self._unknown_backend(ValueError)

    def no_grad(self):
        """
        Context manager for disabling gradient computation.

        - PyTorch: torch.no_grad()
        - JAX: there is no true 'no_grad' context; we return a nullcontext().
          Use jax.lax.stop_gradient at call sites if you need it.
        - NUMBA lane: nullcontext()
        """
        if self.backend == BackendType.TORCH:
            return self.tensor_lib.no_grad()
        if self.backend in (BackendType.JAX, BackendType.NUMBA):
            return nullcontext()
        raise RuntimeError("Unknown backend")

    # ------------------------ Array constructors ------------------------

    def zeros(self, shape: Tuple[int, ...], **kwargs) -> Any:
        """Zero-filled array of ``shape``; kwargs go to the backend constructor."""
        if self.backend in (BackendType.TORCH, BackendType.JAX):
            return self.tensor_lib.zeros(shape, **kwargs)
        if self.backend == BackendType.NUMBA:
            return _np.zeros(shape, **kwargs)
        raise self._unknown_backend(ValueError)

    def ones(self, shape: Tuple[int, ...], **kwargs) -> Any:
        """One-filled array of ``shape``; kwargs go to the backend constructor."""
        if self.backend in (BackendType.TORCH, BackendType.JAX):
            return self.tensor_lib.ones(shape, **kwargs)
        if self.backend == BackendType.NUMBA:
            return _np.ones(shape, **kwargs)
        raise self._unknown_backend(RuntimeError)

    def eye(self, n: int, **kwargs) -> Any:
        """``n`` x ``n`` identity matrix."""
        if self.backend in (BackendType.TORCH, BackendType.JAX):
            return self.tensor_lib.eye(n, **kwargs)
        if self.backend == BackendType.NUMBA:
            return _np.eye(n, **kwargs)
        raise self._unknown_backend(RuntimeError)

    def arange(self, start: int, end: int, step: int = 1, **kwargs) -> Any:
        """Evenly spaced values in [start, end). On torch, dtype defaults to float32."""
        if self.backend == BackendType.TORCH:
            # Default dtype to float32 unless the caller provided one
            kwargs.setdefault('dtype', self.tensor_lib.float32)
            return self.tensor_lib.arange(start, end, step, **kwargs)
        if self.backend == BackendType.JAX:
            return self.tensor_lib.arange(start, end, step, **kwargs)
        if self.backend == BackendType.NUMBA:
            return _np.arange(start, end, step, **kwargs)
        raise self._unknown_backend(ValueError)

    def linspace(self, start: float, end: float, num: int, **kwargs) -> Any:
        """``num`` evenly spaced values from ``start`` to ``end`` inclusive."""
        if self.backend in (BackendType.TORCH, BackendType.JAX):
            return self.tensor_lib.linspace(start, end, num, **kwargs)
        if self.backend == BackendType.NUMBA:
            return _np.linspace(start, end, num, **kwargs)
        raise self._unknown_backend(ValueError)

    def zeros_like(self, tensor: Any, **kwargs) -> Any:
        """Zeros shaped like ``tensor``.

        On the NUMBA lane, an input without ``.shape`` yields a length-1 array.
        """
        if self.backend in (BackendType.TORCH, BackendType.JAX):
            return self.tensor_lib.zeros_like(tensor, **kwargs)
        if self.backend == BackendType.NUMBA:
            if hasattr(tensor, 'shape'):
                return _np.zeros_like(tensor, **kwargs)
            return _np.zeros(1, **kwargs)
        raise self._unknown_backend(ValueError)

    def ones_like(self, tensor: Any, **kwargs) -> Any:
        """Create a tensor of ones with the same shape as input tensor.

        On the NUMBA lane, an input without ``.shape`` yields a length-1 array.
        """
        if self.backend in (BackendType.TORCH, BackendType.JAX):
            return self.tensor_lib.ones_like(tensor, **kwargs)
        if self.backend == BackendType.NUMBA:
            if hasattr(tensor, 'shape'):
                return _np.ones_like(tensor, **kwargs)
            return _np.ones(1, **kwargs)
        raise self._unknown_backend(ValueError)

    # ------------------------ Basic transforms ------------------------

    def sqrt(self, tensor: Any) -> Any:
        """Element-wise square root."""
        if self.backend in (BackendType.TORCH, BackendType.JAX):
            return self.tensor_lib.sqrt(tensor)
        if self.backend == BackendType.NUMBA:
            return _np.sqrt(tensor)
        raise self._unknown_backend(ValueError)

    def stack(self, tensors: List[Any], dim: int = 0) -> Any:
        """Stack tensors along a new axis ``dim``."""
        if self.backend == BackendType.TORCH:
            return self.tensor_lib.stack(tensors, dim=dim)
        if self.backend == BackendType.JAX:
            return self.tensor_lib.stack(tensors, axis=dim)
        if self.backend == BackendType.NUMBA:
            return _np.stack(tensors, axis=dim)
        raise self._unknown_backend(ValueError)

    def cat(self, tensors: List[Any], dim: int = 0) -> Any:
        """Concatenate tensors along existing axis ``dim``."""
        if self.backend == BackendType.TORCH:
            return self.tensor_lib.cat(tensors, dim=dim)
        if self.backend == BackendType.JAX:
            return self.tensor_lib.concatenate(tensors, axis=dim)
        if self.backend == BackendType.NUMBA:
            return _np.concatenate(tensors, axis=dim)
        raise self._unknown_backend(ValueError)

    def reshape(self, tensor: Any, shape: Tuple[int, ...]) -> Any:
        """Reshape via the tensor's own ``.reshape`` method."""
        return tensor.reshape(shape)

    def repeat(self, tensor: Any, repeats: Union[int, Tuple[int, ...]], dim: int = 0) -> Any:
        """
        Repeat elements along a specified axis (element-wise repeat).
        For tiling the whole array shape, use `tile(...)` helper below.

        Torch-specific behavior retained for existing callers:
          - a tuple ``repeats`` tiles via ``Tensor.repeat``;
          - an int ``repeats`` with ``dim >= rank`` tiles instead of interleaving;
          - other dims are clamped into ``[-rank, rank - 1]``.
        On JAX/NUMBA, the arguments are passed straight to ``repeat(..., axis=dim)``.
        """
        if self.backend == BackendType.TORCH:
            torch = self.tensor_lib
            if not isinstance(repeats, int):
                # tuple/sequence: tile
                return tensor.repeat(*repeats)
            if not hasattr(tensor, 'dim'):
                # Fallback: interleave along dim 0
                return torch.repeat_interleave(tensor, repeats, dim=0)
            rank = tensor.dim()
            if rank == 0:
                return torch.repeat_interleave(tensor, repeats, dim=0)
            if dim >= rank:
                # Treat as tiling across axes when dim is out-of-range
                if rank == 1:
                    # 2-D tile of shape (repeats * L, dim * L)
                    length = tensor.shape[0]
                    row = tensor.repeat(dim)
                    return row.unsqueeze(0).repeat(repeats * length, 1)
                if rank == 2:
                    return tensor.repeat(repeats, repeats)
                reps = [1] * (rank - 1) + [repeats]
                return tensor.repeat(*reps)
            valid_dim = max(min(dim, rank - 1), -rank)
            return torch.repeat_interleave(tensor, repeats, dim=valid_dim)
        if self.backend == BackendType.JAX:
            return self.tensor_lib.repeat(tensor, repeats, axis=dim)
        if self.backend == BackendType.NUMBA:
            return _np.repeat(tensor, repeats, axis=dim)
        raise self._unknown_backend(RuntimeError)

    def tile(self, tensor: Any, reps: Union[int, Tuple[int, ...]]) -> Any:
        """Tile (broadcast repeat) tensor like np.tile / torch.repeat(*reps).

        On torch, ``reps`` shorter than ``tensor.dim()`` is left-padded with
        1s, matching ``np.tile``. ``Tensor.repeat`` alone would raise here.
        """
        if self.backend == BackendType.TORCH:
            reps_t = (reps,) if isinstance(reps, int) else tuple(reps)
            missing = tensor.dim() - len(reps_t)
            if missing > 0:
                reps_t = (1,) * missing + reps_t
            return tensor.repeat(*reps_t)
        if self.backend == BackendType.JAX:
            # Modern JAX provides jnp.tile; older versions go through NumPy
            try:
                return self.tensor_lib.tile(tensor, reps)
            except AttributeError:
                return self.tensor_lib.array(_np.tile(self.to_numpy(tensor), reps))
        if self.backend == BackendType.NUMBA:
            return _np.tile(tensor, reps)
        raise self._unknown_backend(RuntimeError)

    def clip(self, tensor: Any, min_val: float, max_val: float) -> Any:
        """Clamp values into [min_val, max_val]."""
        if self.backend == BackendType.TORCH:
            return tensor.clamp(min_val, max_val)
        if self.backend == BackendType.JAX:
            return self.tensor_lib.clip(tensor, min_val, max_val)
        if self.backend == BackendType.NUMBA:
            return _np.clip(tensor, min_val, max_val)
        raise self._unknown_backend(RuntimeError)

    def unsqueeze(self, tensor: Any, dim: int) -> Any:
        """Insert a size-1 axis at ``dim``."""
        if self.backend == BackendType.TORCH:
            return tensor.unsqueeze(dim)
        if self.backend == BackendType.JAX:
            return self.tensor_lib.expand_dims(tensor, dim)
        if self.backend == BackendType.NUMBA:
            return _np.expand_dims(tensor, dim)
        raise self._unknown_backend(RuntimeError)

    def expand(self, tensor: Any, *sizes: int) -> Any:
        """Broadcast ``tensor`` to ``sizes``.

        Sizes may be given as separate ints or as a single tuple/list. On
        torch, -1 keeps a dimension; NumPy/JAX do not interpret -1 that way.
        """
        if len(sizes) == 1 and isinstance(sizes[0], (tuple, list)):
            sizes = tuple(sizes[0])
        if self.backend == BackendType.TORCH:
            return tensor.expand(*sizes)
        if self.backend == BackendType.JAX:
            return self.tensor_lib.broadcast_to(tensor, sizes)
        if self.backend == BackendType.NUMBA:
            return _np.broadcast_to(tensor, sizes)
        raise self._unknown_backend(RuntimeError)

    def gather(self, tensor: Any, dim: int, index: Any) -> Any:
        """Gather values along ``dim`` using ``index`` (take_along_axis semantics)."""
        if self.backend == BackendType.TORCH:
            return tensor.gather(dim, index)
        if self.backend == BackendType.JAX:
            return self.tensor_lib.take_along_axis(tensor, index, axis=dim)
        if self.backend == BackendType.NUMBA:
            return _np.take_along_axis(tensor, index, axis=dim)
        raise self._unknown_backend(RuntimeError)

    def squeeze(self, tensor: Any, dim: Optional[int] = None) -> Any:
        """Remove size-1 axes: all of them when ``dim`` is None, else just ``dim``."""
        if self.backend == BackendType.TORCH:
            # Call the no-argument overload instead of passing None as the dim
            return tensor.squeeze() if dim is None else tensor.squeeze(dim)
        if self.backend == BackendType.JAX:
            # jnp.squeeze uses 'axis'; None removes all size-1 dimensions
            return self.tensor_lib.squeeze(tensor, axis=dim)
        if self.backend == BackendType.NUMBA:
            return _np.squeeze(tensor, axis=dim)
        raise self._unknown_backend(RuntimeError)

    def transpose(self, tensor: Any, *args, **kwargs) -> Any:
        """
        Transpose a tensor.

        Supports signatures:
          - transpose(tensor) for 2D: matrix transpose; otherwise reverse axes
          - transpose(tensor, dim0, dim1) : swap two axes (positional)
          - transpose(tensor, dim0=..., dim1=...) : swap two axes (keyword)
          - transpose(tensor, dims=(...)) : permute by dims

        Raises:
            ValueError: on a positional-arg count other than 0 or 2, or an
                unknown backend.
        """
        dims = kwargs.get('dims', None)
        dim0 = kwargs.get('dim0', None)
        dim1 = kwargs.get('dim1', None)

        # Handle positional args (dim0, dim1)
        if len(args) == 2:
            dim0, dim1 = args
        elif len(args) > 0:
            raise ValueError(
                f"transpose expects 0 or 2 positional args, got {len(args)}")

        if self.backend == BackendType.TORCH:
            if dims is not None:
                return tensor.permute(dims)
            if dim0 is not None and dim1 is not None:
                return tensor.transpose(dim0, dim1)
            if tensor.dim() == 2:
                return tensor.t()
            return tensor.permute(tuple(reversed(range(tensor.dim()))))

        if self.backend == BackendType.JAX:
            jnp = self.tensor_lib
            if dims is not None:
                # Prefer the array method; fall back to jnp.transpose
                try:
                    if isinstance(dims, tuple):
                        return tensor.transpose(*dims)
                    return tensor.transpose(dims)
                except (TypeError, AttributeError):
                    return jnp.transpose(tensor, dims)
            if dim0 is not None and dim1 is not None:
                axes = list(range(tensor.ndim))
                axes[dim0], axes[dim1] = axes[dim1], axes[dim0]
                try:
                    return tensor.transpose(*axes)
                except (TypeError, AttributeError):
                    return jnp.transpose(tensor, axes)
            # default: reverse axes (matrix transpose if 2D)
            try:
                return tensor.transpose()
            except (TypeError, AttributeError):
                return jnp.transpose(tensor, tuple(reversed(range(tensor.ndim))))

        if self.backend == BackendType.NUMBA:
            if dims is not None:
                return _np.transpose(tensor, axes=dims)
            if dim0 is not None and dim1 is not None:
                axes = list(range(tensor.ndim))
                axes[dim0], axes[dim1] = axes[dim1], axes[dim0]
                return _np.transpose(tensor, axes=axes)
            return tensor.T

        raise self._unknown_backend(ValueError)

    # ------------------------ Linear algebra & reductions ------------------------

    def matmul(self, a: Any, b: Any) -> Any:
        """Matrix product of ``a`` and ``b``."""
        if self.backend in (BackendType.TORCH, BackendType.JAX, BackendType.NUMBA):
            return self._lib().matmul(a, b)
        raise self._unknown_backend(ValueError)

    def einsum(self, equation: str, *operands) -> Any:
        """Einstein summation; the NUMBA lane warns and uses NumPy."""
        if self.backend in (BackendType.TORCH, BackendType.JAX):
            return self._lib().einsum(equation, *operands)
        if self.backend == BackendType.NUMBA:
            warnings.warn(
                "NUMBA lane doesn't support einsum fully; using fallback")
            return self._numba_einsum_fallback(equation, *operands)
        raise self._unknown_backend(ValueError)

    def _numba_einsum_fallback(self, equation: str, *operands) -> Any:
        """Evaluate ``equation`` with NumPy for the NUMBA lane.

        Plain matrix product and vector dot keep their direct NumPy forms.
        Every other subscript string goes to ``numpy.einsum``, which
        implements the full einsum grammar.
        """
        if equation == "ij,jk->ik":
            return _np.matmul(operands[0], operands[1])
        if equation == "i,i->":
            return _np.sum(operands[0] * operands[1])
        return _np.einsum(equation, *operands)

    def sum(self, tensor: Any, dim: Optional[int] = None, keepdim: Optional[bool] = None,
            keepdims: Optional[bool] = False) -> Any:
        """Sum over ``dim`` (all axes when None).

        ``keepdim`` (torch spelling) takes precedence over ``keepdims``.
        """
        if keepdim is None:
            keepdim = bool(keepdims) if keepdims is not None else False
        if self.backend == BackendType.TORCH:
            return tensor.sum(dim=dim, keepdim=keepdim)
        if self.backend in (BackendType.JAX, BackendType.NUMBA):
            return self._lib().sum(tensor, axis=dim, keepdims=keepdim)
        raise self._unknown_backend(ValueError)

    def mean(self, tensor: Any, dim: Optional[int] = None, keepdim: Optional[bool] = None,
             keepdims: Optional[bool] = False) -> Any:
        """Mean over ``dim`` (all axes when None).

        ``keepdim`` (torch spelling) takes precedence over ``keepdims``.
        """
        if keepdim is None:
            keepdim = bool(keepdims) if keepdims is not None else False
        if self.backend == BackendType.TORCH:
            return tensor.mean(dim=dim, keepdim=keepdim)
        if self.backend in (BackendType.JAX, BackendType.NUMBA):
            return self._lib().mean(tensor, axis=dim, keepdims=keepdim)
        raise self._unknown_backend(ValueError)

    def std(self, tensor: Any, dim: Optional[int] = None, keepdims: bool = False) -> Any:
        """Standard deviation over ``dim``.

        NOTE: torch's default is the unbiased (N-1) estimator while NumPy/JAX
        default to ddof=0, so results differ between backends.
        """
        if self.backend == BackendType.TORCH:
            return tensor.std(dim=dim, keepdim=keepdims)
        if self.backend in (BackendType.JAX, BackendType.NUMBA):
            return self._lib().std(tensor, axis=dim, keepdims=keepdims)
        raise self._unknown_backend(ValueError)

    def median(self, tensor: Any, dim: Optional[int] = None, keepdims: bool = False) -> Any:
        """Median over ``dim``.

        On torch with ``dim=None``, ``keepdims`` is ignored. Torch returns the
        lower middle value for even counts.
        """
        if self.backend == BackendType.TORCH:
            if dim is None:
                return tensor.median()
            return tensor.median(dim=dim, keepdim=keepdims).values
        if self.backend == BackendType.JAX:
            return self.tensor_lib.median(tensor, axis=dim, keepdims=keepdims)
        if self.backend == BackendType.NUMBA:
            return _np.median(tensor, axis=dim, keepdims=keepdims)
        raise self._unknown_backend(ValueError)

    def quantile(self, tensor: Any, q: Union[float, List[float]], dim: Optional[int] = None,
                 keepdims: bool = False) -> Any:
        """Quantile(s) ``q`` (in [0, 1]) of ``tensor`` over ``dim``."""
        if self.backend == BackendType.TORCH:
            if isinstance(q, (int, float)):
                q_arg = float(q)
            else:
                # torch.quantile requires q to share the input's dtype and device
                q_arg = self.tensor_lib.as_tensor(q, dtype=tensor.dtype, device=tensor.device)
            return tensor.quantile(q_arg, dim=dim, keepdim=keepdims)
        if self.backend == BackendType.JAX:
            jnp = self.tensor_lib
            # jnp rejects Python lists as array arguments
            return jnp.quantile(tensor, jnp.asarray(q), axis=dim, keepdims=keepdims)
        if self.backend == BackendType.NUMBA:
            return _np.quantile(tensor, q, axis=dim, keepdims=keepdims)
        raise self._unknown_backend(RuntimeError)

    def max(self, tensor: Any, dim: Optional[int] = None, keepdims: bool = False) -> Any:
        """Maximum over ``dim``, or the global maximum when ``dim`` is None."""
        if self.backend == BackendType.TORCH:
            if dim is None:
                return tensor.max()
            return tensor.max(dim=dim, keepdim=keepdims).values
        if self.backend in (BackendType.JAX, BackendType.NUMBA):
            lib = self._lib() if self.backend == BackendType.JAX else _np
            if dim is None:
                return lib.max(tensor)
            return lib.max(tensor, axis=dim, keepdims=keepdims)
        raise self._unknown_backend(RuntimeError)

    def min(self, tensor: Any, dim: Optional[int] = None, keepdims: bool = False) -> Any:
        """Minimum over ``dim``, or the global minimum when ``dim`` is None."""
        if self.backend == BackendType.TORCH:
            if dim is None:
                return tensor.min()
            return tensor.min(dim=dim, keepdim=keepdims).values
        if self.backend in (BackendType.JAX, BackendType.NUMBA):
            lib = self._lib() if self.backend == BackendType.JAX else _np
            if dim is None:
                return lib.min(tensor)
            return lib.min(tensor, axis=dim, keepdims=keepdims)
        raise self._unknown_backend(RuntimeError)

    def norm(self, tensor: Any, p: float = 2, dim: Optional[int] = None) -> Any:
        """p-norm of ``tensor`` over ``dim``.

        When ``dim`` is None and ``p`` is numeric, every backend treats the
        input as one flattened vector, as torch does. Without this, 2-D inputs
        would give the spectral norm in NumPy/JAX, and N-D inputs would raise.
        """
        if self.backend == BackendType.TORCH:
            return tensor.norm(p=p, dim=dim)
        flatten = dim is None and not isinstance(p, str)
        if self.backend in (BackendType.JAX, BackendType.NUMBA):
            lib = self._lib() if self.backend == BackendType.JAX else _np
            if flatten:
                return lib.linalg.norm(lib.ravel(tensor), ord=p)
            return lib.linalg.norm(tensor, ord=p, axis=dim)
        raise self._unknown_backend(RuntimeError)

    # ------------------------ Non-linearities ------------------------

    def softmax(self, tensor: Any, dim: int = -1) -> Any:
        """Softmax along ``dim``; the NUMBA lane subtracts the max for stability."""
        if self.backend == BackendType.TORCH:
            return self._lib().softmax(tensor, dim=dim)
        if self.backend == BackendType.JAX:
            import jax.nn as jnn
            return jnn.softmax(tensor, axis=dim)
        if self.backend == BackendType.NUMBA:
            shifted = _np.exp(tensor - _np.max(tensor, axis=dim, keepdims=True))
            return shifted / _np.sum(shifted, axis=dim, keepdims=True)
        raise self._unknown_backend(RuntimeError)

    def relu(self, tensor: Any) -> Any:
        """Element-wise max(x, 0)."""
        if self.backend == BackendType.TORCH:
            return self._lib().relu(tensor)
        if self.backend == BackendType.JAX:
            return self._lib().maximum(tensor, 0)
        if self.backend == BackendType.NUMBA:
            return _np.maximum(tensor, 0)
        raise self._unknown_backend(RuntimeError)

    def sigmoid(self, tensor: Any) -> Any:
        """Element-wise logistic sigmoid, evaluated in a numerically stable form.

        JAX uses ``jax.nn.sigmoid``; ``1/(1+exp(-x))`` would produce NaN
        gradients for large negative x. The NUMBA lane uses
        ``exp(-logaddexp(0, -x))``, which avoids exp overflow.
        """
        if self.backend == BackendType.TORCH:
            return self._lib().sigmoid(tensor)
        if self.backend == BackendType.JAX:
            import jax.nn as jnn
            return jnn.sigmoid(tensor)
        if self.backend == BackendType.NUMBA:
            return _np.exp(-_np.logaddexp(0, -tensor))
        raise self._unknown_backend(RuntimeError)

    def tanh(self, tensor: Any) -> Any:
        """Element-wise hyperbolic tangent."""
        if self.backend in (BackendType.TORCH, BackendType.JAX):
            return self._lib().tanh(tensor)
        if self.backend == BackendType.NUMBA:
            return _np.tanh(tensor)
        raise self._unknown_backend(RuntimeError)

    def log(self, tensor: Any) -> Any:
        """Element-wise natural logarithm."""
        if self.backend in (BackendType.TORCH, BackendType.JAX):
            return self._lib().log(tensor)
        if self.backend == BackendType.NUMBA:
            return _np.log(tensor)
        raise self._unknown_backend(RuntimeError)

    # ------------------------ Elementwise arithmetic ------------------------

    def add(self, tensor1: Any, tensor2: Any) -> Any:
        """Element-wise ``tensor1 + tensor2``."""
        if self.backend in (BackendType.TORCH, BackendType.JAX):
            return self.tensor_lib.add(tensor1, tensor2)
        if self.backend == BackendType.NUMBA:
            return _np.add(tensor1, tensor2)
        raise self._unknown_backend(ValueError)

    def subtract(self, tensor1: Any, tensor2: Any) -> Any:
        """Element-wise ``tensor1 - tensor2``."""
        if self.backend in (BackendType.TORCH, BackendType.JAX):
            return self.tensor_lib.subtract(tensor1, tensor2)
        if self.backend == BackendType.NUMBA:
            return _np.subtract(tensor1, tensor2)
        raise self._unknown_backend(ValueError)

    def multiply(self, tensor1: Any, tensor2: Any) -> Any:
        """Element-wise ``tensor1 * tensor2``."""
        if self.backend in (BackendType.TORCH, BackendType.JAX):
            return self.tensor_lib.multiply(tensor1, tensor2)
        if self.backend == BackendType.NUMBA:
            return _np.multiply(tensor1, tensor2)
        raise self._unknown_backend(ValueError)

    def divide(self, tensor1: Any, tensor2: Any) -> Any:
        """Element-wise ``tensor1 / tensor2``."""
        if self.backend in (BackendType.TORCH, BackendType.JAX):
            return self.tensor_lib.divide(tensor1, tensor2)
        if self.backend == BackendType.NUMBA:
            return _np.divide(tensor1, tensor2)
        raise self._unknown_backend(ValueError)

    def power(self, tensor: Any, exponent: Union[int, float]) -> Any:
        """Element-wise ``tensor ** exponent``."""
        if self.backend == BackendType.TORCH:
            return self.tensor_lib.pow(tensor, exponent)
        if self.backend == BackendType.JAX:
            return self.tensor_lib.power(tensor, exponent)
        if self.backend == BackendType.NUMBA:
            return _np.power(tensor, exponent)
        raise self._unknown_backend(ValueError)

    def sin(self, tensor: Any) -> Any:
        """Element-wise sine."""
        if self.backend in (BackendType.TORCH, BackendType.JAX):
            return self._lib().sin(tensor)
        if self.backend == BackendType.NUMBA:
            return _np.sin(tensor)
        raise self._unknown_backend(ValueError)

    def cos(self, tensor: Any) -> Any:
        """Element-wise cosine."""
        if self.backend in (BackendType.TORCH, BackendType.JAX):
            return self._lib().cos(tensor)
        if self.backend == BackendType.NUMBA:
            return _np.cos(tensor)
        raise self._unknown_backend(ValueError)

    def exp(self, tensor: Any) -> Any:
        """Element-wise exponential."""
        if self.backend in (BackendType.TORCH, BackendType.JAX):
            return self._lib().exp(tensor)
        if self.backend == BackendType.NUMBA:
            return _np.exp(tensor)
        raise self._unknown_backend(ValueError)

    def abs(self, tensor: Any) -> Any:
        """Element-wise absolute value."""
        if self.backend in (BackendType.TORCH, BackendType.JAX):
            return self._lib().abs(tensor)
        if self.backend == BackendType.NUMBA:
            return _np.abs(tensor)
        raise self._unknown_backend(ValueError)

    # ------------------------ Randomness ------------------------

    def randn(self, shape: Tuple[int, ...], **kwargs) -> Any:
        """Standard-normal samples with the given ``shape`` (a tuple or a bare int).

        JAX requires ``key=``. The NUMBA lane honours an optional ``dtype=``
        and ignores other kwargs.

        Raises:
            ValueError: missing JAX key, or unknown backend.
        """
        if isinstance(shape, int):
            shape = (shape,)
        if self.backend == BackendType.TORCH:
            return self.tensor_lib.randn(*shape, **kwargs)
        if self.backend == BackendType.JAX:
            import jax.random as random
            key = kwargs.pop("key", None)
            if key is None:
                raise ValueError(
                    "JAX randn requires a PRNG key passed as key=...")
            return random.normal(key, shape, **kwargs)
        if self.backend == BackendType.NUMBA:
            samples = self._lib().random.randn(*shape)
            dtype = kwargs.get("dtype")
            return samples if dtype is None else samples.astype(dtype)
        raise self._unknown_backend(ValueError)

    def randn_like(self, tensor: Any, **kwargs) -> Any:
        """Standard-normal samples shaped like ``tensor`` (JAX requires ``key=``)."""
        if self.backend == BackendType.TORCH:
            return self.tensor_lib.randn_like(tensor, **kwargs)
        if self.backend == BackendType.JAX:
            import jax.random as random
            key = kwargs.pop("key", None)
            if key is None:
                raise ValueError(
                    "JAX randn_like requires a PRNG key passed as key=...")
            return random.normal(key, tensor.shape, **kwargs)
        if self.backend == BackendType.NUMBA:
            samples = self._lib().random.randn(*tensor.shape)
            return samples.astype(getattr(tensor, "dtype", _np.float64))
        raise self._unknown_backend(RuntimeError)

    def dropout(self, tensor: Any, p: float = 0.5, training: bool = True, **kwargs) -> Any:
        """Inverted dropout: zero elements with probability ``p`` and scale the
        survivors by 1/(1-p).

        ``p == 1`` returns zeros on every backend, with no division by zero.
        JAX requires ``key=``.

        Raises:
            ValueError: ``p`` outside [0, 1] (JAX/NUMBA), or a missing JAX key.
        """
        if not training or p == 0:
            return tensor
        if self.backend == BackendType.TORCH:
            return self.tensor_lib.dropout(tensor, p=p, train=training)
        if self.backend in (BackendType.JAX, BackendType.NUMBA) and not 0 <= p <= 1:
            raise ValueError(
                f"dropout probability has to be between 0 and 1, but got {p}")
        if self.backend == BackendType.JAX:
            import jax.random as random
            key = kwargs.pop("key", None)
            if key is None:
                raise ValueError(
                    "JAX dropout requires a PRNG key passed as key=...")
            if p >= 1:
                return self.tensor_lib.zeros_like(tensor)
            keep_prob = 1.0 - p
            mask = random.bernoulli(key, keep_prob, tensor.shape)
            return tensor * mask / keep_prob
        if self.backend == BackendType.NUMBA:
            if p >= 1:
                return _np.zeros_like(tensor)
            keep_prob = 1.0 - p
            mask_dtype = tensor.dtype if hasattr(tensor, "dtype") else _np.float64
            mask = (self._lib().random.random(tensor.shape) < keep_prob).astype(mask_dtype)
            return tensor * mask / keep_prob
        raise self._unknown_backend(RuntimeError)

    # ------------------------ FFT ------------------------

    def fft(self, tensor: Any) -> Any:
        """1-D discrete Fourier transform over the last axis."""
        if self.backend == BackendType.TORCH:
            import torch
            return torch.fft.fft(tensor)
        if self.backend == BackendType.JAX:
            from jax.numpy import fft as jfft
            return jfft.fft(tensor)
        if self.backend == BackendType.NUMBA:
            return self._lib().fft.fft(tensor)
        raise self._unknown_backend(ValueError)

    def ifft(self, tensor: Any) -> Any:
        """Inverse 1-D discrete Fourier transform over the last axis."""
        if self.backend == BackendType.TORCH:
            import torch
            return torch.fft.ifft(tensor)
        if self.backend == BackendType.JAX:
            from jax.numpy import fft as jfft
            return jfft.ifft(tensor)
        if self.backend == BackendType.NUMBA:
            return self._lib().fft.ifft(tensor)
        raise self._unknown_backend(ValueError)

    # ------------------------ Misc ------------------------

    def clone(self, tensor: Any) -> Any:
        """Return a copy of ``tensor`` (``.clone()`` on torch, ``.copy()`` otherwise)."""
        if self.backend == BackendType.TORCH:
            return tensor.clone()
        # JAX arrays are immutable and NumPy arrays are mutable; .copy() gives
        # an independent buffer in both cases.
        return tensor.copy()

    def concatenate(self, tensors: List[Any], dim: int = 0) -> Any:
        """Alias for ``cat``."""
        return self.cat(tensors, dim=dim)
# Global tensor operations instance, created lazily by get_tensor_ops()
# and reset to None by switch_backend().
_tensor_ops: Optional[TensorOps] = None
def get_tensor_ops(backend: Optional[BackendType] = None) -> TensorOps:
    """Get the global tensor operations instance (resolves AUTO safely).

    Args:
        backend: Desired backend. ``None`` reuses the cached instance
            (creating one on first use). A ``BackendType`` or its string
            value rebuilds the cached instance when it differs.

    Returns:
        The module-level ``TensorOps`` instance.

    Raises:
        ValueError: Propagated from ``TensorOps`` for an unknown backend string.
    """
    global _tensor_ops
    if isinstance(backend, str):
        # Normalize so the cache comparison below is enum-vs-enum; otherwise a
        # string never equals the cached enum and a new instance is built
        # on every call.
        try:
            backend = BackendType(backend)
        except ValueError:
            pass  # let TensorOps raise its descriptive error
    if _tensor_ops is None or (backend is not None and _tensor_ops.backend != backend):
        _tensor_ops = TensorOps(backend)
    return _tensor_ops
def create_tensor(data: Any, **kwargs) -> Any:
    """Create a tensor with the shared, module-level TensorOps instance.

    Every keyword argument is forwarded unchanged to ``TensorOps.create_tensor``.
    """
    ops = get_tensor_ops()
    return ops.create_tensor(data, **kwargs)
def switch_backend(backend: BackendType) -> None:
    """Ask the backend manager to switch backends and drop the cached ops.

    The cached instance is cleared only when the manager's switch call
    returns a truthy value. The next ``get_tensor_ops()`` call then builds a
    fresh instance.
    """
    global _tensor_ops
    from .backends import switch_backend as _switch_manager_backend

    if not _switch_manager_backend(backend):
        return
    _tensor_ops = None  # Reset tensor ops for new backend