"""
Unified Tensor Operations for Multi-Backend Support
This module provides consistent tensor operations across PyTorch, JAX, and a
NumPy-backed "NUMBA lane" (arrays are NumPy; numba is a compiler elsewhere),
enabling seamless switching between frameworks while maintaining the same API.
"""
from typing import Optional, Union, Any, List, Tuple
from contextlib import nullcontext
import warnings
import importlib
import os
from .backends import get_backend_manager, BackendType
from .adapters import get_optimal_adapter, HighPerformanceAdapter
import numpy as _np # used as a safe NumPy namespace at construction
class TensorOps:
    """
    Unified tensor operations across different backends.

    Notes:
    - AUTO is resolved to a concrete, installed backend during __init__.
    - NUMBA lane uses NumPy arrays (numba itself is not a tensor library).
    - JAX random ops require a PRNG key; pass via kwargs (key=...).
    - Every dispatching method raises (ValueError or RuntimeError, depending
      on the method) if ``self.backend`` is not TORCH, JAX or NUMBA.

    Attributes:
        backend: The resolved, concrete BackendType.
        tensor_lib: Array namespace for ``backend`` (``torch``, ``jax.numpy``
            or ``numpy``), as returned by ``_get_tensor_lib_for_backend``.
    """

    def __init__(self, backend: Optional[Union["BackendType", str]] = None):
        """
        Resolve a concrete backend and build the helper adapter.

        Args:
            backend: A BackendType, its string value, or None. None/AUTO
                selects the first usable backend (see ``_resolve_backend``).

        Raises:
            ValueError: If ``backend`` is a string that is not a BackendType value.
        """
        # The backend manager may import lazily / fail; resolution then relies
        # only on the built-in fallback order.
        try:
            backend_manager = get_backend_manager()
        except Exception:
            backend_manager = None

        # Convert string backend to enum if needed.
        if isinstance(backend, str):
            try:
                backend = BackendType(backend)
            except ValueError:
                raise ValueError(
                    f"Unknown backend: {backend}. Available backends: {[b.value for b in BackendType]}") from None

        # Resolve requested (or active) backend into an installed concrete choice.
        self.backend, self.tensor_lib = self._resolve_backend(
            backend, backend_manager)

        # Adapter used by several ops (matmul, einsum, reductions, ...).
        self._adapter: "HighPerformanceAdapter"
        try:
            self._adapter = HighPerformanceAdapter(self.backend)
            # Ensure the underlying lib is actually loadable.
            _ = self._adapter.get_lib()
        except Exception:
            # NOTE(review): the default adapter may target a different backend
            # than self.backend; confirm HighPerformanceAdapter() semantics.
            self._adapter = HighPerformanceAdapter()

    # ------------------------ Backend resolution ------------------------

    @staticmethod
    def _is_backend_disabled(backend: "BackendType") -> bool:
        """Return True if ``HPFRACC_DISABLE_<NAME>=1`` is set for ``backend``."""
        env_var = {
            BackendType.TORCH: "HPFRACC_DISABLE_TORCH",
            BackendType.JAX: "HPFRACC_DISABLE_JAX",
            BackendType.NUMBA: "HPFRACC_DISABLE_NUMBA",
        }.get(backend)
        return env_var is not None and os.getenv(env_var, "0") == "1"

    def _resolve_backend(self, backend: Optional["BackendType"],
                         backend_manager) -> Tuple["BackendType", Any]:
        """
        Pick a concrete, installed backend with sensible fallbacks.

        Priority:
        1) explicit `backend` (if not AUTO) when installed; env disables are
           NOT applied to an explicit request
        2) backend_manager.active_backend (if not AUTO and not env-disabled)
        3) fallback order: TORCH -> JAX -> NUMBA (NumPy), skipping disabled ones

        A warning is emitted if an explicit request could not be honored, or
        if nothing could be loaded and the NumPy lane is used as a last resort.

        Returns:
            ``(backend, tensor_lib)`` for the first candidate that imports.
        """
        candidates: List["BackendType"] = []

        # 1) Explicit request (if provided and not AUTO).
        explicit = backend if (
            backend is not None and backend != BackendType.AUTO) else None
        if explicit is not None:
            candidates.append(explicit)

        # 2) Manager's active backend, unless disabled by env.
        active = getattr(backend_manager, "active_backend",
                         None) if backend_manager is not None else None
        if (active is not None and active != BackendType.AUTO
                and active not in candidates
                and not self._is_backend_disabled(active)):
            candidates.append(active)

        # 3) Standard fallbacks (honor env disables).
        for fallback in (BackendType.TORCH, BackendType.JAX, BackendType.NUMBA):
            if fallback not in candidates and not self._is_backend_disabled(fallback):
                candidates.append(fallback)

        last_err: Optional[Exception] = None
        explicit_err: Optional[Exception] = None
        for candidate in candidates:
            try:
                lib = self._get_tensor_lib_for_backend(candidate)
            except Exception as err:  # torch/jax may raise RuntimeError/AttributeError on import
                last_err = err
                if candidate == explicit:
                    explicit_err = err
                continue
            if explicit is not None and candidate != explicit:
                warnings.warn(
                    f"Requested backend '{explicit.value}' could not be loaded "
                    f"({explicit_err!r}); falling back to '{candidate.value}'.",
                    stacklevel=3,
                )
            return candidate, lib

        # Nothing worked (or everything was disabled): NumPy lane always works.
        warnings.warn(
            f"No tensor backend could be loaded (last error: {last_err!r}); "
            "defaulting to the NumPy lane.",
            stacklevel=3,
        )
        return BackendType.NUMBA, _np

    def _get_tensor_lib_for_backend(self, backend: "BackendType") -> Any:
        """Get tensor library for a specific backend (imports guarded).

        TORCH -> ``torch``, JAX -> ``jax.numpy``, NUMBA -> ``numpy``; any other
        value (e.g. AUTO) falls back to importing ``torch``. Import errors
        propagate to the caller.
        """
        if backend == BackendType.TORCH:
            return importlib.import_module("torch")
        elif backend == BackendType.JAX:
            return importlib.import_module("jax.numpy")
        elif backend == BackendType.NUMBA:
            # Use NumPy namespace for arrays/ops; numba is a compiler elsewhere.
            return _np
        else:
            # For constructor edge-cases, fall back to TORCH.
            return importlib.import_module("torch")

    # ------------------------ Creation / conversion ------------------------

    @staticmethod
    def _torch_dtype_from_str(torch: Any, name: str) -> Any:
        """Map a dtype alias such as ``"float32"``/``"fp16"`` to a torch dtype.

        Returns None for unknown aliases (and for bfloat16 if the installed
        torch does not provide it).
        """
        bfloat16 = getattr(torch, 'bfloat16', None)
        dtype_map = {
            'float32': torch.float32,
            'float': torch.float32,
            'fp32': torch.float32,
            'float64': torch.float64,
            'double': torch.float64,
            'fp64': torch.float64,
            'float16': torch.float16,
            'half': torch.float16,
            'fp16': torch.float16,
            'bfloat16': bfloat16,
            'bf16': bfloat16,
            'int64': torch.int64,
            'long': torch.int64,
            'int32': torch.int32,
            'int': torch.int32,
            'int16': torch.int16,
            'short': torch.int16,
            'int8': torch.int8,
            'uint8': torch.uint8,
            'bool': torch.bool,
        }
        return dtype_map.get(name.lower())

    def create_tensor(self, data: Any, **kwargs) -> Any:
        """Create a tensor in the current backend.

        ``requires_grad`` is dropped for JAX/NUMBA (and for torch when falsy).
        For torch, string ``dtype`` aliases and string ``device`` values are
        converted to torch objects; unknown dtype strings are passed through.
        """
        if self.backend == BackendType.TORCH:
            torch = self.tensor_lib
            torch_kwargs = dict(kwargs)
            # requires_grad=False is torch's default; drop it.
            if 'requires_grad' in torch_kwargs and not torch_kwargs['requires_grad']:
                del torch_kwargs['requires_grad']
            dtype = torch_kwargs.get('dtype', None)
            if isinstance(dtype, str):
                mapped = self._torch_dtype_from_str(torch, dtype)
                if mapped is not None:
                    torch_kwargs['dtype'] = mapped
            device = torch_kwargs.get('device', None)
            if isinstance(device, str):
                torch_kwargs['device'] = torch.device(device)
            return torch.tensor(data, **torch_kwargs)
        elif self.backend == BackendType.JAX:
            # JAX doesn't support requires_grad.
            jax_kwargs = {k: v for k, v in kwargs.items() if k != 'requires_grad'}
            return self.tensor_lib.array(data, **jax_kwargs)
        elif self.backend == BackendType.NUMBA:
            # NUMBA lane: remove requires_grad; arrays are NumPy.
            nb_kwargs = {k: v for k, v in kwargs.items() if k != 'requires_grad'}
            return self.tensor_lib.array(data, **nb_kwargs)
        else:
            raise RuntimeError(f"Unknown backend: {self.backend}")

    def tensor(self, data: Any, **kwargs) -> Any:
        """Alias for create_tensor."""
        return self.create_tensor(data, **kwargs)

    def from_numpy(self, array: Any) -> Any:
        """Convert a NumPy array (or array-like) into a backend tensor.

        Torch shares memory with an ndarray input (``torch.from_numpy``);
        non-ndarray inputs are first converted with ``np.asarray``. JAX copies
        via ``jnp.array``; the NUMBA lane returns the input unchanged.
        """
        if self.backend == BackendType.TORCH:
            if not isinstance(array, _np.ndarray):
                array = _np.asarray(array)
            return self.tensor_lib.from_numpy(array)
        elif self.backend == BackendType.JAX:
            return self.tensor_lib.array(array)
        elif self.backend == BackendType.NUMBA:
            return array
        else:
            raise ValueError(f"Unknown backend: {self.backend}")

    def to_numpy(self, tensor: Any) -> Any:
        """Return a NumPy view/copy of ``tensor`` (NUMBA lane: unchanged)."""
        if self.backend == BackendType.TORCH:
            return tensor.detach().cpu().numpy()
        elif self.backend == BackendType.JAX:
            import jax
            return _np.asarray(jax.device_get(tensor))
        elif self.backend == BackendType.NUMBA:
            return tensor
        else:
            raise ValueError(f"Unknown backend: {self.backend}")

    def no_grad(self):
        """
        Context manager for disabling gradient computation.
        - PyTorch: torch.no_grad()
        - JAX: there is no true 'no_grad' context; we return a nullcontext().
          Use jax.lax.stop_gradient at call sites if you need it.
        - NUMBA lane: nullcontext()
        """
        if self.backend == BackendType.TORCH:
            return self.tensor_lib.no_grad()
        elif self.backend in (BackendType.JAX, BackendType.NUMBA):
            return nullcontext()
        else:
            raise RuntimeError("Unknown backend")

    # ------------------------ Array constructors ------------------------

    def zeros(self, shape: Tuple[int, ...], **kwargs) -> Any:
        """Array of zeros with ``shape``; kwargs go to the backend's ``zeros``."""
        if self.backend in (BackendType.TORCH, BackendType.JAX):
            return self.tensor_lib.zeros(shape, **kwargs)
        elif self.backend == BackendType.NUMBA:
            return _np.zeros(shape, **kwargs)
        else:
            raise ValueError(f"Unknown backend: {self.backend}")

    def ones(self, shape: Tuple[int, ...], **kwargs) -> Any:
        """Array of ones with ``shape``; kwargs go to the backend's ``ones``."""
        if self.backend in (BackendType.TORCH, BackendType.JAX):
            return self.tensor_lib.ones(shape, **kwargs)
        elif self.backend == BackendType.NUMBA:
            return _np.ones(shape, **kwargs)
        else:
            raise RuntimeError(f"Unknown backend: {self.backend}")

    def eye(self, n: int, **kwargs) -> Any:
        """``n x n`` identity matrix."""
        if self.backend in (BackendType.TORCH, BackendType.JAX):
            return self.tensor_lib.eye(n, **kwargs)
        elif self.backend == BackendType.NUMBA:
            return _np.eye(n, **kwargs)
        else:
            raise RuntimeError(f"Unknown backend: {self.backend}")

    def arange(self, start: int, end: int, step: int = 1, **kwargs) -> Any:
        """Values in ``[start, end)`` spaced by ``step``.

        For torch, ``dtype`` defaults to float32 unless provided.
        """
        if self.backend == BackendType.TORCH:
            # Default dtype to float32 to satisfy tests unless provided.
            if 'dtype' not in kwargs:
                kwargs['dtype'] = self.tensor_lib.float32
            return self.tensor_lib.arange(start, end, step, **kwargs)
        elif self.backend == BackendType.JAX:
            return self.tensor_lib.arange(start, end, step, **kwargs)
        elif self.backend == BackendType.NUMBA:
            return _np.arange(start, end, step, **kwargs)
        else:
            raise ValueError(f"Unknown backend: {self.backend}")

    def linspace(self, start: float, end: float, num: int, **kwargs) -> Any:
        """``num`` evenly spaced values from ``start`` to ``end`` inclusive."""
        if self.backend in (BackendType.TORCH, BackendType.JAX):
            return self.tensor_lib.linspace(start, end, num, **kwargs)
        elif self.backend == BackendType.NUMBA:
            return _np.linspace(start, end, num, **kwargs)
        else:
            raise ValueError(f"Unknown backend: {self.backend}")

    def zeros_like(self, tensor: Any, **kwargs) -> Any:
        """Zeros shaped like ``tensor`` (NUMBA lane: ``zeros(1)`` if no ``.shape``)."""
        if self.backend in (BackendType.TORCH, BackendType.JAX):
            return self.tensor_lib.zeros_like(tensor, **kwargs)
        elif self.backend == BackendType.NUMBA:
            if hasattr(tensor, 'shape'):
                return _np.zeros_like(tensor, **kwargs)
            return _np.zeros(1, **kwargs)
        else:
            raise ValueError(f"Unknown backend: {self.backend}")

    def ones_like(self, tensor: Any, **kwargs) -> Any:
        """Create a tensor of ones with the same shape as input tensor.

        NUMBA lane: inputs without ``.shape`` yield ``ones(1)``.
        """
        if self.backend in (BackendType.TORCH, BackendType.JAX):
            return self.tensor_lib.ones_like(tensor, **kwargs)
        elif self.backend == BackendType.NUMBA:
            if hasattr(tensor, 'shape'):
                return _np.ones_like(tensor, **kwargs)
            return _np.ones(1, **kwargs)
        else:
            raise ValueError(f"Unknown backend: {self.backend}")

    # ------------------------ Basic transforms ------------------------

    def sqrt(self, tensor: Any) -> Any:
        """Elementwise square root."""
        if self.backend in (BackendType.TORCH, BackendType.JAX):
            return self.tensor_lib.sqrt(tensor)
        elif self.backend == BackendType.NUMBA:
            return _np.sqrt(tensor)
        else:
            raise ValueError(f"Unknown backend: {self.backend}")

    def stack(self, tensors: List[Any], dim: int = 0) -> Any:
        """Stack tensors along a new axis ``dim``."""
        if self.backend == BackendType.TORCH:
            return self.tensor_lib.stack(tensors, dim=dim)
        elif self.backend == BackendType.JAX:
            return self.tensor_lib.stack(tensors, axis=dim)
        elif self.backend == BackendType.NUMBA:
            return _np.stack(tensors, axis=dim)
        else:
            raise ValueError(f"Unknown backend: {self.backend}")

    def cat(self, tensors: List[Any], dim: int = 0) -> Any:
        """Concatenate tensors along existing axis ``dim``."""
        if self.backend == BackendType.TORCH:
            return self.tensor_lib.cat(tensors, dim=dim)
        elif self.backend == BackendType.JAX:
            return self.tensor_lib.concatenate(tensors, axis=dim)
        elif self.backend == BackendType.NUMBA:
            return _np.concatenate(tensors, axis=dim)
        else:
            raise ValueError(f"Unknown backend: {self.backend}")

    def reshape(self, tensor: Any, shape: Tuple[int, ...]) -> Any:
        """Reshape via the tensor's own ``.reshape`` method."""
        return tensor.reshape(shape)

    def repeat(self, tensor: Any,
               repeats: Union[int, Tuple[int, ...]], dim: int = 0) -> Any:
        """
        Repeat elements along a specified axis (element-wise repeat).
        For tiling the whole array shape, use `tile(...)` helper below.

        NOTE(review): the torch branch differs from JAX/NumPy in two ways:
        a tuple ``repeats`` tiles (``tensor.repeat(*repeats)``), and an
        int ``repeats`` with ``dim >= rank`` switches to ad-hoc tiling (for
        1-D input ``dim`` is used as a repeat count). Preserved as-is.
        """
        if self.backend == BackendType.TORCH:
            torch = self.tensor_lib
            if not isinstance(repeats, int):
                # tuple/sequence: tile
                return tensor.repeat(*repeats)
            if not hasattr(tensor, 'dim'):
                # Fallback: interleave along dim 0
                return torch.repeat_interleave(tensor, repeats, dim=0)
            rank = tensor.dim()
            if rank == 0:
                return torch.repeat_interleave(tensor, repeats, dim=0)
            if dim >= rank:
                # Treat as tiling across axes when dim is out-of-range.
                if rank == 1:
                    # Build a 2D tile with shape (repeats*L, dim*L).
                    length = tensor.shape[0]
                    row = tensor.repeat(dim)
                    return row.unsqueeze(0).repeat(repeats * length, 1)
                if rank == 2:
                    return tensor.repeat(repeats, repeats)
                reps = [1] * (rank - 1) + [repeats]
                return tensor.repeat(*reps)
            valid_dim = max(min(dim, rank - 1), -rank)
            return torch.repeat_interleave(tensor, repeats, dim=valid_dim)
        elif self.backend == BackendType.JAX:
            return self.tensor_lib.repeat(tensor, repeats, axis=dim)
        elif self.backend == BackendType.NUMBA:
            return _np.repeat(tensor, repeats, axis=dim)
        else:
            raise RuntimeError(f"Unknown backend: {self.backend}")

    def tile(self, tensor: Any, reps: Union[int, Tuple[int, ...]]) -> Any:
        """Tile (broadcast repeat) tensor like np.tile / torch.repeat(*reps)."""
        if self.backend == BackendType.TORCH:
            return tensor.repeat(*((reps,) if isinstance(reps, int) else reps))
        elif self.backend == BackendType.JAX:
            # Modern JAX versions provide jnp.tile.
            try:
                return self.tensor_lib.tile(tensor, reps)
            except AttributeError:
                # Fallback for older JAX versions: round-trip through NumPy.
                return self.tensor_lib.array(_np.tile(self.to_numpy(tensor), reps))
        elif self.backend == BackendType.NUMBA:
            return _np.tile(tensor, reps)
        else:
            raise RuntimeError(f"Unknown backend: {self.backend}")

    def clip(self, tensor: Any, min_val: float, max_val: float) -> Any:
        """Clamp values into ``[min_val, max_val]``."""
        if self.backend == BackendType.TORCH:
            return tensor.clamp(min_val, max_val)
        elif self.backend == BackendType.JAX:
            return self.tensor_lib.clip(tensor, min_val, max_val)
        elif self.backend == BackendType.NUMBA:
            return _np.clip(tensor, min_val, max_val)
        else:
            raise RuntimeError(f"Unknown backend: {self.backend}")

    def unsqueeze(self, tensor: Any, dim: int) -> Any:
        """Insert a size-1 axis at ``dim``."""
        if self.backend == BackendType.TORCH:
            return tensor.unsqueeze(dim)
        elif self.backend == BackendType.JAX:
            return self.tensor_lib.expand_dims(tensor, dim)
        elif self.backend == BackendType.NUMBA:
            return _np.expand_dims(tensor, dim)
        else:
            raise RuntimeError(f"Unknown backend: {self.backend}")

    @staticmethod
    def _normalize_expand_sizes(tensor: Any, sizes: Tuple[Any, ...]) -> Tuple[Any, ...]:
        """Translate torch-style ``expand`` sizes into a broadcast_to shape.

        Accepts either ``expand(t, a, b)`` or ``expand(t, (a, b))``. A ``-1``
        entry aligned (from the right) with an existing axis keeps that axis'
        size, as ``torch.Tensor.expand`` does; other entries pass through.
        """
        if len(sizes) == 1 and isinstance(sizes[0], (tuple, list)):
            sizes = tuple(sizes[0])
        shape = tuple(getattr(tensor, "shape", ()))
        lead = len(sizes) - len(shape)  # number of new leading axes
        return tuple(
            shape[i - lead] if (size == -1 and i >= lead) else size
            for i, size in enumerate(sizes)
        )

    def expand(self, tensor: Any, *sizes: int) -> Any:
        """Broadcast ``tensor`` to ``sizes`` (torch ``expand`` semantics).

        JAX/NumPy lanes accept the same ``-1`` and single-tuple forms as torch.
        """
        if self.backend == BackendType.TORCH:
            return tensor.expand(*sizes)
        elif self.backend == BackendType.JAX:
            return self.tensor_lib.broadcast_to(
                tensor, self._normalize_expand_sizes(tensor, sizes))
        elif self.backend == BackendType.NUMBA:
            return _np.broadcast_to(
                tensor, self._normalize_expand_sizes(tensor, sizes))
        else:
            raise RuntimeError(f"Unknown backend: {self.backend}")

    def gather(self, tensor: Any, dim: int, index: Any) -> Any:
        """Gather values along ``dim`` (torch.gather / take_along_axis)."""
        if self.backend == BackendType.TORCH:
            return tensor.gather(dim, index)
        elif self.backend == BackendType.JAX:
            return self.tensor_lib.take_along_axis(tensor, index, axis=dim)
        elif self.backend == BackendType.NUMBA:
            return _np.take_along_axis(tensor, index, axis=dim)
        else:
            raise RuntimeError(f"Unknown backend: {self.backend}")

    def squeeze(self, tensor: Any, dim: Optional[int] = None) -> Any:
        """Remove size-1 axes; ``dim=None`` removes all of them."""
        if self.backend == BackendType.TORCH:
            # Tensor.squeeze's overloads take no dim or an explicit dim;
            # use the no-argument form rather than passing None through.
            return tensor.squeeze() if dim is None else tensor.squeeze(dim)
        elif self.backend == BackendType.JAX:
            # jnp.squeeze uses 'axis'; None removes all size-1 dimensions.
            return self.tensor_lib.squeeze(tensor, axis=dim)
        elif self.backend == BackendType.NUMBA:
            return _np.squeeze(tensor, axis=dim)
        else:
            raise RuntimeError(f"Unknown backend: {self.backend}")

    def transpose(self, tensor: Any, *args, **kwargs) -> Any:
        """
        Transpose a tensor. Supports signatures:
        - transpose(tensor) for 2D: matrix transpose; otherwise reverse axes
        - transpose(tensor, dim0, dim1) : swap two axes (positional)
        - transpose(tensor, dim0=..., dim1=...) : swap two axes (keyword)
        - transpose(tensor, dims=(...)) : permute by dims

        Raises:
            ValueError: If a number of positional args other than 0 or 2 is given.
        """
        dims = kwargs.get('dims', None)
        dim0 = kwargs.get('dim0', None)
        dim1 = kwargs.get('dim1', None)
        # Positional (dim0, dim1) take precedence over keywords.
        if len(args) == 2:
            dim0, dim1 = args[0], args[1]
        elif len(args) > 0:
            raise ValueError(
                f"transpose expects 0 or 2 positional args, got {len(args)}")
        if self.backend == BackendType.TORCH:
            if dims is not None:
                return tensor.permute(dims)
            if dim0 is not None and dim1 is not None:
                return tensor.transpose(dim0, dim1)
            if tensor.dim() == 2:
                return tensor.t()
            return tensor.permute(tuple(reversed(range(tensor.dim()))))
        elif self.backend == BackendType.JAX:
            if dims is not None:
                # Try the method form first, fall back to jnp.transpose.
                try:
                    if isinstance(dims, tuple):
                        return tensor.transpose(*dims)
                    else:
                        return tensor.transpose(dims)
                except (TypeError, AttributeError):
                    return self.tensor_lib.transpose(tensor, dims)
            if dim0 is not None and dim1 is not None:
                axes = list(range(tensor.ndim))
                axes[dim0], axes[dim1] = axes[dim1], axes[dim0]
                try:
                    return tensor.transpose(*axes)
                except (TypeError, AttributeError):
                    return self.tensor_lib.transpose(tensor, axes)
            # Default: reverse axes (matrix transpose if 2D).
            try:
                return tensor.transpose()
            except (TypeError, AttributeError):
                axes = tuple(reversed(range(tensor.ndim)))
                return self.tensor_lib.transpose(tensor, axes)
        elif self.backend == BackendType.NUMBA:
            if dims is not None:
                return _np.transpose(tensor, axes=dims)
            if dim0 is not None and dim1 is not None:
                axes = list(range(tensor.ndim))
                axes[dim0], axes[dim1] = axes[dim1], axes[dim0]
                return _np.transpose(tensor, axes=axes)
            return tensor.T
        else:
            raise ValueError(f"Unknown backend: {self.backend}")

    # ------------------------ Linear algebra & reductions ------------------------

    def matmul(self, a: Any, b: Any) -> Any:
        """Matrix product via the adapter's library."""
        if self.backend in (BackendType.TORCH, BackendType.JAX, BackendType.NUMBA):
            return self._adapter.get_lib().matmul(a, b)
        raise ValueError(f"Unknown backend: {self.backend}")

    def einsum(self, equation: str, *operands) -> Any:
        """Einstein summation.

        TORCH/JAX use the adapter's ``einsum``; the NUMBA lane (NumPy arrays)
        uses ``numpy.einsum``, which supports the full subscript grammar.
        """
        if self.backend in (BackendType.TORCH, BackendType.JAX):
            return self._adapter.get_lib().einsum(equation, *operands)
        elif self.backend == BackendType.NUMBA:
            return self._numba_einsum_fallback(equation, *operands)
        else:
            raise ValueError(f"Unknown backend: {self.backend}")

    def _numba_einsum_fallback(self, equation: str, *operands) -> Any:
        """Evaluate ``equation`` on NumPy operands with ``numpy.einsum``."""
        return _np.einsum(equation, *operands)

    def sum(self, tensor: Any, dim: Optional[int] = None,
            keepdim: Optional[bool] = None, keepdims: Optional[bool] = False) -> Any:
        """Sum over ``dim`` (all axes if None).

        ``keepdim`` (torch spelling) wins over ``keepdims`` when both are given.
        """
        if keepdim is None:
            keepdim = bool(keepdims) if keepdims is not None else False
        if self.backend == BackendType.TORCH:
            return tensor.sum(dim=dim, keepdim=keepdim)
        elif self.backend in (BackendType.JAX, BackendType.NUMBA):
            return self._adapter.get_lib().sum(tensor, axis=dim, keepdims=keepdim)
        else:
            raise ValueError(f"Unknown backend: {self.backend}")

    def mean(self, tensor: Any, dim: Optional[int] = None,
             keepdim: Optional[bool] = None, keepdims: Optional[bool] = False) -> Any:
        """Mean over ``dim`` (all axes if None); ``keepdim`` wins over ``keepdims``."""
        if keepdim is None:
            keepdim = bool(keepdims) if keepdims is not None else False
        if self.backend == BackendType.TORCH:
            return tensor.mean(dim=dim, keepdim=keepdim)
        elif self.backend in (BackendType.JAX, BackendType.NUMBA):
            return self._adapter.get_lib().mean(tensor, axis=dim, keepdims=keepdim)
        else:
            raise ValueError(f"Unknown backend: {self.backend}")

    def std(self, tensor: Any, dim: Optional[int] = None,
            keepdims: bool = False) -> Any:
        """Standard deviation over ``dim``.

        NOTE: each backend's default estimator is used. ``torch.std`` applies
        Bessel's correction by default, while NumPy/JAX default to ``ddof=0``,
        so results differ slightly between lanes.
        """
        if self.backend == BackendType.TORCH:
            return tensor.std(dim=dim, keepdim=keepdims)
        elif self.backend in (BackendType.JAX, BackendType.NUMBA):
            return self._adapter.get_lib().std(tensor, axis=dim, keepdims=keepdims)
        else:
            raise ValueError(f"Unknown backend: {self.backend}")

    def quantile(self, tensor: Any, q: Union[float, List[float]],
                 dim: Optional[int] = None, keepdims: bool = False) -> Any:
        """Quantile(s) ``q`` (in [0, 1]) of ``tensor`` along ``dim``."""
        if self.backend == BackendType.TORCH:
            # torch.quantile requires q to match the input's dtype and device.
            q_tensor = self.tensor_lib.as_tensor(
                q, dtype=tensor.dtype, device=tensor.device)
            return tensor.quantile(q_tensor, dim=dim, keepdim=keepdims)
        elif self.backend == BackendType.JAX:
            return self.tensor_lib.quantile(tensor, q, axis=dim, keepdims=keepdims)
        elif self.backend == BackendType.NUMBA:
            return _np.quantile(tensor, q, axis=dim, keepdims=keepdims)
        else:
            raise RuntimeError(f"Unknown backend: {self.backend}")

    def max(self, tensor: Any, dim: Optional[int] = None,
            keepdims: bool = False) -> Any:
        """Maximum values (torch: ``.values`` only, no indices)."""
        if self.backend == BackendType.TORCH:
            if dim is None:
                return tensor.max()
            return tensor.max(dim=dim, keepdim=keepdims).values
        elif self.backend == BackendType.JAX:
            lib = self._adapter.get_lib()
            if dim is None:
                return lib.max(tensor)
            return lib.max(tensor, axis=dim, keepdims=keepdims)
        elif self.backend == BackendType.NUMBA:
            if dim is None:
                return _np.max(tensor)
            return _np.max(tensor, axis=dim, keepdims=keepdims)
        else:
            raise RuntimeError(f"Unknown backend: {self.backend}")

    def min(self, tensor: Any, dim: Optional[int] = None,
            keepdims: bool = False) -> Any:
        """Minimum values (torch: ``.values`` only, no indices)."""
        if self.backend == BackendType.TORCH:
            if dim is None:
                return tensor.min()
            return tensor.min(dim=dim, keepdim=keepdims).values
        elif self.backend == BackendType.JAX:
            lib = self._adapter.get_lib()
            if dim is None:
                return lib.min(tensor)
            return lib.min(tensor, axis=dim, keepdims=keepdims)
        elif self.backend == BackendType.NUMBA:
            if dim is None:
                return _np.min(tensor)
            return _np.min(tensor, axis=dim, keepdims=keepdims)
        else:
            raise RuntimeError(f"Unknown backend: {self.backend}")

    def norm(self, tensor: Any, p: float = 2,
             dim: Optional[int] = None) -> Any:
        """p-norm over ``dim`` (all elements if None)."""
        if self.backend == BackendType.TORCH:
            return tensor.norm(p=p, dim=dim)
        elif self.backend == BackendType.JAX:
            return self._adapter.get_lib().linalg.norm(tensor, ord=p, axis=dim)
        elif self.backend == BackendType.NUMBA:
            return _np.linalg.norm(tensor, ord=p, axis=dim)
        else:
            raise RuntimeError(f"Unknown backend: {self.backend}")

    # ------------------------ Non-linearities ------------------------

    def softmax(self, tensor: Any, dim: int = -1) -> Any:
        """Softmax along ``dim``."""
        if self.backend == BackendType.TORCH:
            return self._adapter.get_lib().softmax(tensor, dim=dim)
        elif self.backend == BackendType.JAX:
            import jax.nn as jnn
            return jnn.softmax(tensor, axis=dim)
        elif self.backend == BackendType.NUMBA:
            # Subtract the max before exponentiating to avoid overflow.
            ex = _np.exp(tensor - _np.max(tensor, axis=dim, keepdims=True))
            return ex / _np.sum(ex, axis=dim, keepdims=True)
        else:
            raise RuntimeError(f"Unknown backend: {self.backend}")

    def relu(self, tensor: Any) -> Any:
        """Elementwise ``max(x, 0)``."""
        if self.backend == BackendType.TORCH:
            return self._adapter.get_lib().relu(tensor)
        elif self.backend == BackendType.JAX:
            return self._adapter.get_lib().maximum(tensor, 0)
        elif self.backend == BackendType.NUMBA:
            return _np.maximum(tensor, 0)
        else:
            raise RuntimeError(f"Unknown backend: {self.backend}")

    def sigmoid(self, tensor: Any) -> Any:
        """Elementwise logistic function ``1 / (1 + exp(-x))``."""
        if self.backend == BackendType.TORCH:
            return self._adapter.get_lib().sigmoid(tensor)
        elif self.backend == BackendType.JAX:
            return 1 / (1 + self._adapter.get_lib().exp(-tensor))
        elif self.backend == BackendType.NUMBA:
            return 1 / (1 + _np.exp(-tensor))
        else:
            raise RuntimeError(f"Unknown backend: {self.backend}")

    def tanh(self, tensor: Any) -> Any:
        """Elementwise hyperbolic tangent."""
        if self.backend in (BackendType.TORCH, BackendType.JAX):
            return self._adapter.get_lib().tanh(tensor)
        elif self.backend == BackendType.NUMBA:
            return _np.tanh(tensor)
        else:
            raise RuntimeError(f"Unknown backend: {self.backend}")

    def log(self, tensor: Any) -> Any:
        """Elementwise natural logarithm."""
        if self.backend in (BackendType.TORCH, BackendType.JAX):
            return self._adapter.get_lib().log(tensor)
        elif self.backend == BackendType.NUMBA:
            return _np.log(tensor)
        else:
            raise RuntimeError(f"Unknown backend: {self.backend}")

    # ------------------------ Elementwise arithmetic ------------------------

    def add(self, tensor1: Any, tensor2: Any) -> Any:
        """Elementwise ``tensor1 + tensor2``."""
        if self.backend in (BackendType.TORCH, BackendType.JAX):
            return self.tensor_lib.add(tensor1, tensor2)
        elif self.backend == BackendType.NUMBA:
            return _np.add(tensor1, tensor2)
        else:
            raise ValueError(f"Unknown backend: {self.backend}")

    def subtract(self, tensor1: Any, tensor2: Any) -> Any:
        """Elementwise ``tensor1 - tensor2``."""
        if self.backend in (BackendType.TORCH, BackendType.JAX):
            return self.tensor_lib.subtract(tensor1, tensor2)
        elif self.backend == BackendType.NUMBA:
            return _np.subtract(tensor1, tensor2)
        else:
            raise ValueError(f"Unknown backend: {self.backend}")

    def multiply(self, tensor1: Any, tensor2: Any) -> Any:
        """Elementwise ``tensor1 * tensor2``."""
        if self.backend in (BackendType.TORCH, BackendType.JAX):
            return self.tensor_lib.multiply(tensor1, tensor2)
        elif self.backend == BackendType.NUMBA:
            return _np.multiply(tensor1, tensor2)
        else:
            raise ValueError(f"Unknown backend: {self.backend}")

    def divide(self, tensor1: Any, tensor2: Any) -> Any:
        """Elementwise ``tensor1 / tensor2``."""
        if self.backend in (BackendType.TORCH, BackendType.JAX):
            return self.tensor_lib.divide(tensor1, tensor2)
        elif self.backend == BackendType.NUMBA:
            return _np.divide(tensor1, tensor2)
        else:
            raise ValueError(f"Unknown backend: {self.backend}")

    def power(self, tensor: Any, exponent: Union[int, float]) -> Any:
        """Elementwise ``tensor ** exponent``."""
        if self.backend == BackendType.TORCH:
            return self.tensor_lib.pow(tensor, exponent)
        elif self.backend == BackendType.JAX:
            return self.tensor_lib.power(tensor, exponent)
        elif self.backend == BackendType.NUMBA:
            return _np.power(tensor, exponent)
        else:
            raise ValueError(f"Unknown backend: {self.backend}")

    def sin(self, tensor: Any) -> Any:
        """Elementwise sine."""
        if self.backend in (BackendType.TORCH, BackendType.JAX):
            return self._adapter.get_lib().sin(tensor)
        elif self.backend == BackendType.NUMBA:
            return _np.sin(tensor)
        else:
            raise ValueError(f"Unknown backend: {self.backend}")

    def cos(self, tensor: Any) -> Any:
        """Elementwise cosine."""
        if self.backend in (BackendType.TORCH, BackendType.JAX):
            return self._adapter.get_lib().cos(tensor)
        elif self.backend == BackendType.NUMBA:
            return _np.cos(tensor)
        else:
            raise ValueError(f"Unknown backend: {self.backend}")

    def exp(self, tensor: Any) -> Any:
        """Elementwise exponential."""
        if self.backend in (BackendType.TORCH, BackendType.JAX):
            return self._adapter.get_lib().exp(tensor)
        elif self.backend == BackendType.NUMBA:
            return _np.exp(tensor)
        else:
            raise ValueError(f"Unknown backend: {self.backend}")

    def abs(self, tensor: Any) -> Any:
        """Elementwise absolute value."""
        if self.backend in (BackendType.TORCH, BackendType.JAX):
            return self._adapter.get_lib().abs(tensor)
        elif self.backend == BackendType.NUMBA:
            return _np.abs(tensor)
        else:
            raise ValueError(f"Unknown backend: {self.backend}")

    # ------------------------ Randomness ------------------------

    def randn(self, shape: Tuple[int, ...], **kwargs) -> Any:
        """Standard-normal samples of ``shape`` (an int is treated as ``(n,)``).

        JAX requires ``key=...``. The NUMBA lane uses the adapter lib's
        global RNG and honors only the ``dtype`` keyword.
        """
        if isinstance(shape, (int, _np.integer)):
            shape = (int(shape),)
        if self.backend == BackendType.TORCH:
            return self.tensor_lib.randn(*shape, **kwargs)
        elif self.backend == BackendType.JAX:
            import jax.random as random
            key = kwargs.pop("key", None)
            if key is None:
                raise ValueError(
                    "JAX randn requires a PRNG key passed as key=...")
            return random.normal(key, shape, **kwargs)
        elif self.backend == BackendType.NUMBA:
            samples = self._adapter.get_lib().random.randn(*shape)
            dtype = kwargs.get("dtype")
            if dtype is not None:
                samples = _np.asarray(samples, dtype=dtype)
            return samples
        else:
            raise ValueError(f"Unknown backend: {self.backend}")

    def randn_like(self, tensor: Any, **kwargs) -> Any:
        """Standard-normal samples shaped like ``tensor`` (JAX needs ``key=...``)."""
        if self.backend == BackendType.TORCH:
            return self.tensor_lib.randn_like(tensor, **kwargs)
        elif self.backend == BackendType.JAX:
            import jax.random as random
            key = kwargs.pop("key", None)
            if key is None:
                raise ValueError(
                    "JAX randn_like requires a PRNG key passed as key=...")
            return random.normal(key, tensor.shape, **kwargs)
        elif self.backend == BackendType.NUMBA:
            lib = self._adapter.get_lib()
            return lib.random.randn(*tensor.shape).astype(
                getattr(tensor, "dtype", _np.float64))
        else:
            raise RuntimeError(f"Unknown backend: {self.backend}")

    def dropout(self, tensor: Any, p: float = 0.5, training: bool = True, **kwargs) -> Any:
        """Inverted dropout: zero elements with probability ``p``, scale the rest.

        Returns ``tensor`` unchanged when not training or ``p == 0``. With
        ``p == 1`` every lane returns ``tensor * 0`` (torch's behavior) instead
        of dividing by zero. JAX requires ``key=...``.

        Raises:
            ValueError: If ``p`` is outside ``[0, 1]`` or (JAX) no key is given.
        """
        if not training or p == 0:
            return tensor
        if not 0.0 <= p <= 1.0:
            raise ValueError(
                f"dropout probability has to be between 0 and 1, but got {p}")
        if self.backend == BackendType.TORCH:
            return self.tensor_lib.dropout(tensor, p=p, train=training)
        elif self.backend == BackendType.JAX:
            import jax.random as random
            key = kwargs.pop("key", None)
            if key is None:
                raise ValueError(
                    "JAX dropout requires a PRNG key passed as key=...")
            if p == 1:
                return tensor * 0
            keep_prob = 1.0 - p
            mask = random.bernoulli(key, keep_prob, tensor.shape)
            return tensor * mask / keep_prob
        elif self.backend == BackendType.NUMBA:
            if p == 1:
                return tensor * 0
            lib = self._adapter.get_lib()
            keep_prob = 1.0 - p
            mask = (lib.random.random(tensor.shape) < keep_prob).astype(
                tensor.dtype if hasattr(tensor, "dtype") else _np.float64)
            return tensor * mask / keep_prob
        else:
            raise RuntimeError(f"Unknown backend: {self.backend}")

    # ------------------------ FFT ------------------------

    def fft(self, tensor: Any) -> Any:
        """1-D discrete Fourier transform over the last axis."""
        if self.backend == BackendType.TORCH:
            return self.tensor_lib.fft.fft(tensor)
        elif self.backend == BackendType.JAX:
            from jax.numpy import fft as jfft
            return jfft.fft(tensor)
        elif self.backend == BackendType.NUMBA:
            return self._adapter.get_lib().fft.fft(tensor)
        else:
            raise ValueError(f"Unknown backend: {self.backend}")

    def ifft(self, tensor: Any) -> Any:
        """1-D inverse discrete Fourier transform over the last axis."""
        if self.backend == BackendType.TORCH:
            return self.tensor_lib.fft.ifft(tensor)
        elif self.backend == BackendType.JAX:
            from jax.numpy import fft as jfft
            return jfft.ifft(tensor)
        elif self.backend == BackendType.NUMBA:
            return self._adapter.get_lib().fft.ifft(tensor)
        else:
            raise ValueError(f"Unknown backend: {self.backend}")

    # ------------------------ Misc ------------------------

    def clone(self, tensor: Any) -> Any:
        """Return an independent copy of ``tensor``."""
        if self.backend == BackendType.TORCH:
            return tensor.clone()
        # NumPy arrays are mutable, so .copy() duplicates the buffer; JAX
        # arrays are immutable and .copy() simply returns a new array.
        return tensor.copy()

    def concatenate(self, tensors: List[Any], dim: int = 0) -> Any:
        """Alias for ``cat``."""
        return self.cat(tensors, dim=dim)
# Global tensor operations instance (lazily created; reset by switch_backend).
_tensor_ops: Optional[TensorOps] = None


def get_tensor_ops(backend: Optional[Union[BackendType, str]] = None) -> TensorOps:
    """Get the global tensor operations instance (resolves AUTO safely).

    Args:
        backend: Desired backend as a BackendType or its string value.
            None or AUTO means "no preference": the cached instance is reused
            if one exists. A concrete backend that differs from the cached
            instance's resolved backend triggers a rebuild.

    Raises:
        ValueError: If ``backend`` is a string that is not a BackendType value
            (raised by the TensorOps constructor).
    """
    global _tensor_ops
    if isinstance(backend, str):
        try:
            backend = BackendType(backend)
        except ValueError:
            pass  # TensorOps(backend) below raises a descriptive ValueError.
    # AUTO never equals a resolved (concrete) backend, so treat it like None
    # instead of rebuilding the instance on every call.
    wants_specific = backend is not None and backend != BackendType.AUTO
    if _tensor_ops is None or (wants_specific and _tensor_ops.backend != backend):
        _tensor_ops = TensorOps(backend)
    return _tensor_ops
def create_tensor(data: Any, **kwargs) -> Any:
    """Build a tensor with the process-wide TensorOps instance.

    Convenience wrapper: forwards ``data`` and every keyword argument to
    ``get_tensor_ops().create_tensor`` and returns its result.
    """
    ops = get_tensor_ops()
    return ops.create_tensor(data, **kwargs)
def switch_backend(backend: BackendType) -> None:
    """Switch the backend manager to ``backend`` and invalidate the cache.

    Delegates to ``backends.switch_backend``. Only when that call returns a
    truthy value is the cached global TensorOps instance cleared, so that
    the next ``get_tensor_ops()`` call builds a fresh one.
    """
    global _tensor_ops
    from .backends import switch_backend as _switch_manager_backend

    if not _switch_manager_backend(backend):
        return
    _tensor_ops = None  # Reset tensor ops for new backend