JAX GPU Setup for HPFRACC Library

This document explains how JAX GPU support is configured in the HPFRACC library.

Current Status

  • PyTorch GPU: ✅ Fully supported - RTX 5070 detected and working with CUDA 12.8

  • JAX GPU: ✅ Fully supported - RTX 5070 detected and working with CUDA 12

  • Automatic detection: ✅ Configured - Will use GPU when available

  • CuDNN: ✅ Compatible - CuDNN 9.12.0+ recommended for JAX 0.8.0

Installation

CuDNN Compatibility

If you encounter CuDNN version mismatch errors:

  1. Upgrade CuDNN to 9.12.0+:

    pip install --upgrade "nvidia-cudnn-cu12>=9.12.0"
    
  2. Configure library paths (if a conda-installed CuDNN conflicts):

    source scripts/setup_jax_gpu_env.sh
    
  3. Verify installation:

    python -c "import jax; print(jax.devices()); print(jax.default_backend())"
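Before adjusting library paths, it can help to confirm which cuDNN wheel pip actually installed. The helper below is a hypothetical sketch and is not part of HPFRACC. It reads the installed `nvidia-cudnn-cu12` version through `importlib.metadata` and compares it with the 9.12.0 minimum this guide recommends. It assumes a dotted `major.minor.patch[.build]` version string.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Tuple

MIN_CUDNN = (9, 12, 0)  # minimum recommended for JAX 0.8.0 in this guide


def _leading_int(piece: str) -> int:
    """Return the leading digits of a version component as an int ('0rc1' -> 0)."""
    digits = ""
    for ch in piece:
        if not ch.isdigit():
            break
        digits += ch
    return int(digits) if digits else 0


def parse_version(text: str) -> Tuple[int, int, int]:
    """Turn '9.12.0.46' into (9, 12, 0); components past the third are ignored."""
    parts = [_leading_int(p) for p in text.split(".")[:3]]
    parts += [0] * (3 - len(parts))  # pad short versions such as '9.12'
    return (parts[0], parts[1], parts[2])


def cudnn_wheel_status(package: str = "nvidia-cudnn-cu12") -> str:
    """Describe whether the pip-installed cuDNN wheel meets MIN_CUDNN."""
    try:
        installed = version(package)
    except PackageNotFoundError:
        return f'{package} is not installed; try: pip install "{package}>=9.12.0"'
    if parse_version(installed) >= MIN_CUDNN:
        return f"{package} {installed}: OK"
    return f"{package} {installed}: older than 9.12.0, upgrade recommended"


if __name__ == "__main__":
    print(cudnn_wheel_status())
```

This only reports what pip installed. If the dynamic loader finds an older conda copy of libcudnn first, JAX can still report a mismatch, and step 2 exists for that case.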
    

How It Works

The HPFRACC library automatically configures JAX to use the GPU when available:

  1. Auto-detection: On import, hpfracc.jax_gpu_setup detects GPU availability

  2. Library path setup: Prioritizes pip-installed CuDNN over conda's older versions

  3. Environment setup: Configures LD_LIBRARY_PATH to find the correct CuDNN libraries

  4. Graceful fallback: Falls back to CPU when no supported GPU is found

  5. No user intervention: In the common case no manual configuration is needed; a conflicting conda CuDNN may still require the setup script described under Troubleshooting CuDNN Issues
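This guide does not show the source of hpfracc.jax_gpu_setup. The following is only a minimal sketch of how steps 1 and 4 (auto-detection with graceful CPU fallback) could work using public JAX calls. `detect_jax_gpu` is a hypothetical name, and the library-path steps are left out.

```python
import warnings


def detect_jax_gpu() -> bool:
    """Return True if JAX can enumerate at least one GPU, otherwise False (CPU fallback)."""
    try:
        import jax
    except ImportError:
        return False  # JAX is not installed, so there is nothing to accelerate
    try:
        # jax.devices("gpu") raises RuntimeError when no GPU backend can be initialized
        return len(jax.devices("gpu")) > 0
    except RuntimeError as exc:
        warnings.warn(f"JAX GPU backend unavailable, falling back to CPU: {exc}")
        return False


if __name__ == "__main__":
    gpu_ok = detect_jax_gpu()  # hypothetical name, not HPFRACC's JAX_GPU_AVAILABLE
    print("JAX GPU available:", gpu_ok)
```

Catching `RuntimeError` at device enumeration keeps the fallback quiet. A CuDNN mismatch may only surface later, when the first cuDNN-backed operation runs, so the troubleshooting steps still apply.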

Usage

Simply import HPFRACC modules that use JAX - no additional setup required:

    from hpfracc.jax_gpu_setup import JAX_GPU_AVAILABLE
    import hpfracc.ml.probabilistic_fractional_orders  # Uses JAX

    print(f"JAX GPU available: {JAX_GPU_AVAILABLE}")

GPU Support Status

RTX 5070 (Current GPU)

  • PyTorch: ✅ Fully supported with CUDA 12.8

  • JAX: ✅ Fully supported with CUDA 12

  • CUDA Compatibility: ✅ JAX CUDA 12 wheels compatible with CUDA 12.8

  • CuDNN: ✅ 9.12.0+ recommended for JAX 0.8.0

  • Status: Complete GPU acceleration available

CUDA Version Compatibility

Component   CUDA / CuDNN version             Status
---------   ------------------------------   ------------------
PyTorch     12.8                             ✅ Fully supported
JAX         12.3 (wheels) → 12.8 (runtime)   ✅ Compatible
CuDNN       9.12.0+                          ✅ Recommended

Key Point: JAX's CUDA 12 wheels are built with CUDA 12.3 but work with CUDA ≥12.1, including 12.8. As a result, JAX and a PyTorch build targeting CUDA 12.8 can share the same environment.
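The compatibility rule above can be encoded as a small check. This is a hypothetical sketch and not an HPFRACC API. It assumes a "major.minor" version string, and it uses `torch.version.cuda`, which reports the CUDA version PyTorch was built against rather than the installed driver.

```python
from typing import Optional


def jax_cuda12_compatible(runtime: str) -> bool:
    """Encode the rule above: JAX's CUDA 12 wheels need CUDA 12.x with x >= 1."""
    try:
        major, minor = (int(p) for p in runtime.split(".")[:2])
    except ValueError:
        return False  # malformed or incomplete version string such as "12"
    return major == 12 and minor >= 1


def torch_cuda_version() -> Optional[str]:
    """CUDA version PyTorch was built with (e.g. '12.8'), or None if unavailable."""
    try:
        import torch
    except ImportError:
        return None
    return torch.version.cuda  # None for CPU-only PyTorch builds


if __name__ == "__main__":
    cuda = torch_cuda_version()
    if cuda is None:
        print("PyTorch not installed or built without CUDA")
    else:
        print(f"PyTorch CUDA {cuda}: JAX CUDA 12 compatible = {jax_cuda12_compatible(cuda)}")
```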

Future GPU Support

When JAX adds support for newer GPUs, the library will automatically detect and use them without any code changes.

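Performance Impact

  • PyTorch operations: Full GPU acceleration (8GB VRAM)

  • JAX operations: Full GPU acceleration (8GB VRAM)

  • Mixed workloads: PyTorch and JAX can both run on the GPU, but they share the same 8GB of VRAM. By default, JAX preallocates 75% of total GPU memory when the first JAX operation runs, which can leave little room for PyTorch.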
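When PyTorch and JAX share one 8GB card, a common adjustment is to change JAX's memory preallocation through its documented XLA environment variables. The sketch below assumes you control process startup. This guide does not say whether HPFRACC sets these variables itself, so they are set here before JAX (or any HPFRACC module that imports JAX) is loaded.

```python
import os

# JAX reads these when its GPU backend starts, so set them before `import jax`
# (or before importing HPFRACC modules that import JAX).

# Allocate GPU memory on demand instead of preallocating 75% of total VRAM,
# leaving room for PyTorch tensors on the same card.
os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")

# Alternative: keep preallocation but cap it, e.g. at 40% of total VRAM.
# os.environ.setdefault("XLA_PYTHON_CLIENT_MEM_FRACTION", ".40")

try:
    import jax
    import jax.numpy as jnp

    x = jnp.ones((1024, 1024))  # first JAX op: the backend starts with the settings above
    print(jax.default_backend(), float(x.sum()))
except ImportError:
    print("JAX not installed; environment variables set for later use")
```

`setdefault` is used so that a value already exported in the shell takes precedence over the script.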

Troubleshooting

If you encounter issues:

  1. Check GPU detection: Run python jax_gpu_config.py

  2. Verify PyTorch GPU: Run python -c "import torch; print(torch.cuda.is_available())"

  3. Check JAX status: Run python -c "from hpfracc.jax_gpu_setup import get_jax_info; print(get_jax_info())"
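The PyTorch and JAX checks above can be combined into one diagnostic. This is a hypothetical helper that uses only public `torch` and `jax` calls. It does not call HPFRACC's `get_jax_info`, because that function's output format is not documented here. Missing packages are reported as `None` instead of raising.

```python
import json


def gpu_diagnostics() -> dict:
    """Run the PyTorch and JAX checks above, tolerating missing packages."""
    report = {
        "torch_cuda_available": None,  # None means PyTorch is not installed
        "jax_backend": None,
        "jax_devices": None,
        "jax_error": None,
    }
    try:
        import torch

        report["torch_cuda_available"] = torch.cuda.is_available()
    except ImportError:
        pass
    try:
        import jax

        report["jax_backend"] = jax.default_backend()  # e.g. "gpu" or "cpu"
        report["jax_devices"] = [str(d) for d in jax.devices()]
    except ImportError:
        pass  # JAX is not installed
    except RuntimeError as exc:
        report["jax_error"] = str(exc)  # e.g. a backend initialization failure
    return report


if __name__ == "__main__":
    print(json.dumps(gpu_diagnostics(), indent=2))
```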

Technical Details

  • JAX version: 0.8.0 (compatible with NumPy 2.3+)

  • JAXlib version: 0.8.0

  • CUDA support: CUDA 12 (wheels built with 12.3, compatible with ≥12.1 including 12.8)

  • CuDNN: 9.12.0+ required for JAX 0.8.0

  • PyTorch CUDA: 12.8 (compatible with JAX CUDA 12)

  • Environment variables: Automatically configured for optimal library resolution

  • GPU acceleration: Full RTX 5070 support with 8GB VRAM

  • Library path management: Automatic prioritization of pip-installed CuDNN

Troubleshooting CuDNN Issues

If you see CuDNN version mismatch errors:

  1. Upgrade CuDNN:

    pip install --upgrade "nvidia-cudnn-cu12>=9.12.0"
    
  2. Use setup script (if a conda-installed CuDNN conflicts):

    source scripts/setup_jax_gpu_env.sh
    
  3. Manual library path (if needed; the ${...:+...} form avoids a trailing colon when LD_LIBRARY_PATH is unset):

    export LD_LIBRARY_PATH=$(python3 -c "import site; print(site.getsitepackages()[0])")/nvidia/cudnn/lib${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}
    
  4. Verify installation:

    python -c "from hpfracc.jax_gpu_setup import get_jax_info; import json; print(json.dumps(get_jax_info(), indent=2))"
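Step 3 assumes the cuDNN wheel lives in the first site-packages directory. The hypothetical helper below is not part of HPFRACC or its setup script. It searches every site-packages directory for `nvidia/cudnn/lib` and prints an `export` line with that directory first and duplicates removed.

```python
import os
import shlex
import site
from typing import Optional


def pip_cudnn_lib_dir() -> Optional[str]:
    """Find <site-packages>/nvidia/cudnn/lib, where the nvidia-cudnn-cu12 wheel puts libcudnn."""
    for base in site.getsitepackages():
        candidate = os.path.join(base, "nvidia", "cudnn", "lib")
        if os.path.isdir(candidate):
            return candidate
    return None


def prepend_path(entry: str, current: str) -> str:
    """Put `entry` first in a colon-separated list, dropping duplicates and empty items."""
    rest = [p for p in current.split(":") if p and p != entry]
    return ":".join([entry] + rest)


if __name__ == "__main__":
    lib_dir = pip_cudnn_lib_dir()
    if lib_dir is None:
        print("# pip cuDNN wheel not found in any site-packages directory")
    else:
        new_path = prepend_path(lib_dir, os.environ.get("LD_LIBRARY_PATH", ""))
        print(f"export LD_LIBRARY_PATH={shlex.quote(new_path)}")
```

Saved under a name of your choice (for example a hypothetical `cudnn_path.py`), it can be applied with `eval "$(python3 cudnn_path.py)"`. The helper prints a shell command instead of editing `os.environ` because the Linux dynamic loader reads LD_LIBRARY_PATH when a process starts. The variable therefore has to be exported before Python launches.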