JAX GPU Setup for HPFRACC Library
This document explains how JAX GPU support is configured in the HPFRACC library.
Current Status
PyTorch GPU: ✅ Fully supported (RTX 5070 detected and working with CUDA 12.8)
JAX GPU: ✅ Fully supported (RTX 5070 detected and working with CUDA 12)
Automatic detection: ✅ Configured (uses the GPU when available)
CuDNN: ✅ Compatible (CuDNN 9.12.0+ recommended for JAX 0.8.0)
Installation
Recommended Installation for GPU Support
To get GPU support in both JAX and PyTorch, install the packages in this order:
# Install PyTorch with CUDA 12.8 first
pip install torch==2.9.0 --index-url https://download.pytorch.org/whl/cu128
# Then install JAX with CUDA 12 support
pip install --upgrade "jax[cuda12]"
# Install HPFRACC with GPU extras
pip install "hpfracc[gpu]"
Important Notes:
JAX's CUDA 12 wheels are built with CUDA 12.3 but are compatible with CUDA ≥ 12.1 (including CUDA 12.8)
CUDA offers minor-version compatibility within the 12.x series, so JAX works alongside PyTorch's CUDA 12.8 installation
Ensure CuDNN 9.12.0+ is installed for JAX 0.8.0 compatibility
If a conda-installed CuDNN conflicts, use scripts/setup_jax_gpu_env.sh to configure library paths
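The version requirements in these notes can also be checked from Python. The sketch below is not part of HPFRACC: it reads installed package metadata with the standard-library `importlib.metadata` and compares the `nvidia-cudnn-cu12` wheel against the 9.12.0 minimum stated above. The list of distribution names is an assumption about a typical pip-based install, and a conda-provided CuDNN will not appear here because only pip metadata is read.

```python
import re
from importlib import metadata

# Minimum CuDNN version stated on this page for JAX 0.8.0.
MIN_CUDNN = (9, 12, 0)

# pip distribution names to report (assumed typical pip install).
PACKAGES = ["jax", "jaxlib", "torch", "nvidia-cudnn-cu12"]


def parse_version(text, width=3):
    """Turn e.g. '9.12.0.46' or '2.9.0+cu128' into a tuple of the first `width` integers."""
    numbers = [int(n) for n in re.findall(r"\d+", text)[:width]]
    return tuple(numbers + [0] * (width - len(numbers)))


def installed_versions(packages=PACKAGES):
    """Map each distribution name to its installed version string, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def cudnn_meets_minimum(version, minimum=MIN_CUDNN):
    """True if a pip CuDNN version string satisfies the minimum; False if it is missing."""
    return version is not None and parse_version(version) >= minimum


if __name__ == "__main__":
    found = installed_versions()
    for name, version in found.items():
        print(f"{name:20s} {version or 'not installed'}")
    print("CuDNN >= 9.12.0:", cudnn_meets_minimum(found["nvidia-cudnn-cu12"]))
```

Comparing integer tuples rather than raw strings avoids the classic pitfall where "9.9" sorts after "9.12".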
CuDNN Compatibility
If you encounter CuDNN version mismatch errors:
Upgrade CuDNN to 9.12.0+:
pip install --upgrade "nvidia-cudnn-cu12>=9.12.0"
Configure library paths (if conda CuDNN conflicts):
source scripts/setup_jax_gpu_env.sh
Verify installation:
python -c "import jax; print(jax.devices()); print(jax.default_backend())"
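Printing `jax.devices()` shows what JAX detected; running a tiny computation additionally confirms that work actually executes on the selected backend. A minimal sketch using only public JAX calls (`jax.default_backend`, `jax.devices`, `jax.numpy`); the function name is illustrative, and it returns `None` instead of failing when JAX is not installed.

```python
import importlib.util


def jax_smoke_test():
    """Run a tiny JAX computation and report where it ran; None if JAX is not installed."""
    if importlib.util.find_spec("jax") is None:
        return None
    import jax
    import jax.numpy as jnp

    x = jnp.arange(4.0)          # [0., 1., 2., 3.] on the default device
    dot = float(jnp.dot(x, x))   # 0 + 1 + 4 + 9 = 14; float() forces the result back to host
    return {
        "backend": jax.default_backend(),  # platform name, e.g. 'gpu' or 'cpu'
        "devices": [str(d) for d in jax.devices()],
        "dot": dot,
    }


if __name__ == "__main__":
    print(jax_smoke_test())
```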
How It Works
The HPFRACC library automatically configures JAX to use GPU when available:
Auto-detection: On import, hpfracc.jax_gpu_setup automatically detects GPU availability
Library path setup: Automatically prioritizes pip-installed CuDNN over conda's older versions
Environment setup: Configures LD_LIBRARY_PATH to find the correct CuDNN libraries
Graceful fallback: Falls back to CPU when a GPU is not supported
No user intervention: Works automatically without any configuration
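The library-path step can be pictured roughly as follows. This is a hypothetical sketch, not HPFRACC's implementation: the helper names are invented, and the search mirrors the `site-packages/nvidia/cudnn/lib` location used by the manual `export` command in the CuDNN troubleshooting section. One caveat: on Linux the dynamic loader reads `LD_LIBRARY_PATH` when a process starts, so changing it from inside a running Python process mainly affects child processes, which is one reason to `source` the setup script before launching Python.

```python
import os
import site
import sysconfig
from pathlib import Path


def find_pip_cudnn_lib_dirs():
    """Return existing <site-packages>/nvidia/cudnn/lib directories, in search order."""
    roots = list(site.getsitepackages())
    roots.append(site.getusersitepackages())
    roots.append(sysconfig.get_paths()["purelib"])
    found = []
    for root in roots:
        lib_dir = Path(root) / "nvidia" / "cudnn" / "lib"
        if lib_dir.is_dir() and str(lib_dir) not in found:
            found.append(str(lib_dir))
    return found


def prepend_library_path(dirs, current=None):
    """Return an LD_LIBRARY_PATH value with `dirs` first and duplicates removed."""
    if current is None:
        current = os.environ.get("LD_LIBRARY_PATH", "")
    rest = [p for p in current.split(os.pathsep) if p and p not in dirs]
    return os.pathsep.join(list(dirs) + rest)


if __name__ == "__main__":
    cudnn_dirs = find_pip_cudnn_lib_dirs()
    # Print an export line for the shell instead of mutating this process's environment.
    print(f"export LD_LIBRARY_PATH={prepend_library_path(cudnn_dirs)}")
```

Putting the pip directory first is what makes it win over an older conda copy that appears later in the search path.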
Usage
Import HPFRACC modules that use JAX as usual; no additional setup is required:
from hpfracc.jax_gpu_setup import JAX_GPU_AVAILABLE
import hpfracc.ml.probabilistic_fractional_orders # Uses JAX
print(f"JAX GPU available: {JAX_GPU_AVAILABLE}")
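Outside HPFRACC, a comparable availability flag can be computed with public JAX calls. The sketch below is an assumption about how such a check might look, not the code behind `JAX_GPU_AVAILABLE`; the names are hypothetical. It treats both a missing JAX install and a JAX install without a GPU backend as CPU-only.

```python
import importlib.util


def detect_jax_gpu():
    """Return True if JAX is installed and exposes at least one GPU device."""
    if importlib.util.find_spec("jax") is None:
        return False  # JAX not installed: treat as CPU-only
    import jax

    try:
        return len(jax.devices("gpu")) > 0
    except RuntimeError:
        # jax.devices("gpu") raises RuntimeError when no GPU backend is available.
        return False


MY_JAX_GPU_AVAILABLE = detect_jax_gpu()  # hypothetical name, separate from HPFRACC's flag
print(f"JAX GPU available (own check): {MY_JAX_GPU_AVAILABLE}")
```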
GPU Support Status
RTX 5070 (Current GPU)
PyTorch: ✅ Fully supported with CUDA 12.8
JAX: ✅ Fully supported with CUDA 12
CUDA Compatibility: ✅ JAX CUDA 12 wheels compatible with CUDA 12.8
CuDNN: ✅ 9.12.0+ recommended for JAX 0.8.0
Status: Complete GPU acceleration available
CUDA Version Compatibility
| Component | Version | Status |
|---|---|---|
| PyTorch | CUDA 12.8 | ✅ Fully supported |
| JAX | CUDA 12.3 (wheels) → CUDA 12.8 (runtime) | ✅ Compatible |
| CuDNN | 9.12.0+ | ✅ Recommended |
Key Point: JAX's CUDA 12 wheels are built with CUDA 12.3 but work with CUDA ≥ 12.1, including 12.8. This ensures compatibility between JAX and PyTorch installations.
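The compatibility rule above reduces to a simple comparison: same CUDA major version (12) and at least 12.1. A small sketch encoding that rule and applying it to the CUDA version PyTorch reports through `torch.version.cuda` (which is `None` for CPU-only builds). The 12.1 floor is taken from this page rather than queried from JAX, and the helper names are illustrative.

```python
import importlib.util

JAX_CUDA12_MIN = (12, 1)  # floor for JAX's CUDA 12 wheels, as stated on this page


def parse_cuda_version(version):
    """Turn '12.8' (or '12') into a (major, minor) tuple."""
    parts = version.split(".")
    minor = int(parts[1]) if len(parts) > 1 else 0
    return int(parts[0]), minor


def jax_cuda12_compatible(version):
    """True if a CUDA version such as '12.8' satisfies the 12.x, >= 12.1 rule."""
    major, minor = parse_cuda_version(version)
    return major == JAX_CUDA12_MIN[0] and (major, minor) >= JAX_CUDA12_MIN


def torch_cuda_version():
    """CUDA version PyTorch was built with (e.g. '12.8'), or None."""
    if importlib.util.find_spec("torch") is None:
        return None
    import torch

    return torch.version.cuda  # None for CPU-only PyTorch builds


if __name__ == "__main__":
    cuda = torch_cuda_version()
    if cuda is None:
        print("PyTorch with CUDA not found")
    else:
        print(f"PyTorch CUDA {cuda}: compatible with JAX CUDA 12 wheels = {jax_cuda12_compatible(cuda)}")
```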
Future GPU Support
When JAX adds support for newer GPUs, the library will automatically detect and use them without any code changes.
Performance Impact
PyTorch operations: Full GPU acceleration (8GB VRAM)
JAX operations: Full GPU acceleration (8GB VRAM)
Mixed workloads: PyTorch and JAX both run on the GPU, sharing the same 8GB of VRAM
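Because PyTorch and JAX share one GPU in mixed workloads, JAX's default memory behaviour matters: JAX preallocates 75% of total GPU memory when it first initialises its GPU backend, which can leave too little VRAM for PyTorch. The documented `XLA_PYTHON_CLIENT_PREALLOCATE` and `XLA_PYTHON_CLIENT_MEM_FRACTION` environment variables change this. A minimal sketch, assuming the variables are set before anything imports JAX (including HPFRACC modules, whose on-import auto-detection may initialise the backend); the helper name is hypothetical.

```python
import os


def configure_jax_gpu_memory(mem_fraction=None):
    """Set JAX/XLA GPU memory options; call before JAX initialises its GPU backend.

    mem_fraction=None disables preallocation, so JAX allocates on demand.
    A float such as 0.5 sets the preallocated fraction of VRAM, which applies
    while preallocation is enabled (JAX's default).
    """
    if mem_fraction is None:
        os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
    else:
        if not 0.0 < mem_fraction <= 1.0:
            raise ValueError("mem_fraction must be in (0, 1]")
        os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = f"{mem_fraction:.2f}"


configure_jax_gpu_memory()  # on-demand allocation, leaving room for PyTorch
# import hpfracc.ml.probabilistic_fractional_orders  # import JAX-based modules only after this
```

Disabling preallocation trades some allocation overhead and possible fragmentation for predictable sharing; capping the fraction keeps preallocation's speed while reserving the rest of the card for PyTorch.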
Troubleshooting
If you encounter issues:
Check GPU detection: run
python jax_gpu_config.py
Verify PyTorch GPU: run
python -c "import torch; print(torch.cuda.is_available())"
Check JAX status: run
python -c "from hpfracc.jax_gpu_setup import get_jax_info; print(get_jax_info())"
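On the PyTorch side, `torch.cuda.is_available()` only answers yes or no. A slightly fuller sketch also reports the device name, total VRAM, and the CUDA version PyTorch was built with, using standard `torch.cuda` calls; it is an illustration, not part of `jax_gpu_config.py` or HPFRACC, and it returns a minimal dict instead of raising when PyTorch is absent.

```python
import importlib.util


def torch_gpu_summary():
    """Describe PyTorch's view of the first CUDA device without raising if torch is missing."""
    if importlib.util.find_spec("torch") is None:
        return {"installed": False}
    import torch

    summary = {
        "installed": True,
        "cuda_available": torch.cuda.is_available(),
        "built_with_cuda": torch.version.cuda,  # e.g. '12.8'; None for CPU-only builds
    }
    if summary["cuda_available"]:
        props = torch.cuda.get_device_properties(0)
        summary["device"] = props.name
        summary["vram_gib"] = round(props.total_memory / 2**30, 1)  # bytes -> GiB
    return summary


if __name__ == "__main__":
    print(torch_gpu_summary())
```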
Technical Details
JAX version: 0.8.0 (compatible with NumPy 2.3+)
JAXlib version: 0.8.0
CUDA support: CUDA 12 (wheels built with 12.3, compatible with โฅ12.1 including 12.8)
CuDNN: 9.12.0+ required for JAX 0.8.0
PyTorch CUDA: 12.8 (compatible with JAX CUDA 12)
Environment variables: Automatically configured for optimal library resolution
GPU acceleration: Full RTX 5070 support with 8GB VRAM
Library path management: Automatic prioritization of pip-installed CuDNN
Troubleshooting CuDNN Issues
If you see CuDNN version mismatch errors:
Upgrade CuDNN:
pip install --upgrade "nvidia-cudnn-cu12>=9.12.0"
Use setup script (if conda CuDNN conflicts):
source scripts/setup_jax_gpu_env.sh
Manual library path (if needed):
export LD_LIBRARY_PATH=$(python3 -c "import site; print(site.getsitepackages()[0])")/nvidia/cudnn/lib:$LD_LIBRARY_PATH
Verify installation:
python -c "from hpfracc.jax_gpu_setup import get_jax_info; import json; print(json.dumps(get_jax_info(), indent=2))"