Source code for hpfracc.ml.tensor_ops

from .base import TensorOps
from .torch_ops import TorchTensorOps
from .jax_ops import JaxTensorOps
from .numpy_ops import NumpyTensorOps
from hpfracc.ml.backends import get_backend_manager, BackendType
import warnings
from typing import Optional, Any

def get_tensor_ops(backend: Optional[BackendType] = None) -> TensorOps:
    """
    Build the TensorOps implementation that matches a backend.

    Args:
        backend: Backend to build operations for. ``None`` and
            ``BackendType.AUTO`` are both replaced by the backend manager's
            ``active_backend`` (or ``BackendType.NUMBA`` when
            ``get_backend_manager()`` returns ``None``).

    Returns:
        ``TorchTensorOps`` for TORCH, ``JaxTensorOps`` for JAX and
        ``NumpyTensorOps`` for NUMBA.

    Raises:
        ValueError: If the resolved backend is none of TORCH, JAX or NUMBA.
    """

    def _active_or_default():
        # Ask the manager which backend is live; without a manager, use NUMBA.
        mgr = get_backend_manager()
        if mgr is None:
            return BackendType.NUMBA
        return mgr.active_backend

    # Two separate resolution steps: first for "nothing requested", then for
    # an AUTO value (whether passed in or produced by the first step).
    if backend is None:
        backend = _active_or_default()
    if backend == BackendType.AUTO:
        backend = _active_or_default()

    if backend == BackendType.TORCH:
        return TorchTensorOps()
    if backend == BackendType.JAX:
        return JaxTensorOps()
    if backend == BackendType.NUMBA:
        # Numba typically uses NumPy arrays for storage
        return NumpyTensorOps()

    # Anything left over is not a backend this factory can serve.
    raise ValueError(f"Unknown backend or backend not supported: {backend}")
def create_tensor(data: Any, *args, **kwargs) -> Any:
    """
    Create a tensor through the currently selected TensorOps.

    Calls ``get_tensor_ops()`` with no explicit backend and forwards ``data``
    together with every extra positional and keyword argument to that
    object's ``create_tensor`` method, returning its result unchanged.
    """
    ops = get_tensor_ops()
    return ops.create_tensor(data, *args, **kwargs)
def switch_backend(backend: BackendType) -> bool:
    """
    Ask the backend manager to switch to ``backend``.

    Thin wrapper: obtains the manager via ``get_backend_manager()`` and
    returns whatever its ``switch_backend`` method returns.
    """
    manager = get_backend_manager()
    return manager.switch_backend(backend)
# Names exported by ``from ... import *``; kept for backward compatibility
# with wildcard imports.
__all__ = [
    "get_tensor_ops",
    "TensorOps",
    "create_tensor",
    "switch_backend",
]