# Source code for hpfracc.special.mittag_leffler

"""
Optimized Mittag-Leffler function for fractional calculus.

This module provides high-performance implementations of the Mittag-Leffler function,
specifically optimized for fractional calculus applications including Atangana-Baleanu
derivatives and other high-performance use cases.
"""

import math
from typing import Union, Optional

import numpy as np

# Optional JAX import.
# The configuration object is taken from the top-level ``jax`` package:
# ``from jax.config import config`` raises ImportError on recent JAX releases,
# which would make an installed JAX look unavailable.
try:
    import jax
    import jax.numpy as jnp
    config = jax.config  # kept as a module-level name for existing importers
    config.update("jax_enable_x64", True)
    JAX_AVAILABLE = True
except ImportError:
    jax = None
    jnp = None
    JAX_AVAILABLE = False

# Optional numba import: when numba is missing, ``jit`` and ``prange`` are
# replaced by pure-Python stand-ins so decorated code still runs.
try:
    from numba import jit, prange
except ImportError:
    NUMBA_AVAILABLE = False

    def jit(*args, **kwargs):
        """Return a decorator that hands the function back unchanged."""
        return lambda func: func

    def prange(*args, **kwargs):
        """Forward all arguments to the built-in ``range``."""
        return range(*args, **kwargs)
else:
    NUMBA_AVAILABLE = True

from .gamma_beta import gamma

from scipy.special import gammaln


class MittagLefflerFunction:
    """
    High-performance Mittag-Leffler function implementation.

    Features:
    - Fast evaluation for negative arguments (common in fractional calculus)
    - Vectorized operations for array inputs
    - Adaptive convergence criteria
    - Caching for repeated evaluations
    - Specialized optimizations for Atangana-Baleanu derivatives
    """

    def __init__(
        self,
        use_jax: bool = False,
        use_numba: bool = False,  # Disabled by default due to compilation issues
        cache_size: int = 1000,
        adaptive_convergence: bool = True
    ):
        """
        Initialize optimized Mittag-Leffler function.

        Args:
            use_jax: Use JAX acceleration if available
            use_numba: Use Numba JIT compilation (disabled by default due to issues)
            cache_size: Size of LRU cache for repeated evaluations
            adaptive_convergence: Use adaptive convergence criteria
        """
        self.use_jax = use_jax and JAX_AVAILABLE
        self.use_numba = False  # Force disable Numba due to compilation issues
        self.adaptive_convergence = adaptive_convergence

        # Initialize cache
        self._cache = {}
        self._cache_size = cache_size

        # Precompute common gamma values for caching
        self._gamma_cache = {}

    def compute(
        self,
        z: Union[float, np.ndarray],
        alpha: float,
        beta: float = 1.0,
        max_terms: Optional[int] = None,
        tolerance: float = 1e-12
    ) -> Union[float, np.ndarray]:
        """
        Compute the Mittag-Leffler function E_α,β(z).

        Args:
            z: Input value(s)
            alpha: First parameter
            beta: Second parameter  
            max_terms: Maximum number of terms (auto if None)
            tolerance: Convergence tolerance

        Returns:
            Mittag-Leffler function value(s)
        """
        # Handle special cases first
        if alpha == 1.0 and beta == 1.0:
            return np.exp(z)
        elif alpha == 2.0 and beta == 1.0:
            if np.isscalar(z):
                return np.cos(np.sqrt(-z)) if z <= 0 else np.cosh(np.sqrt(z))
            else:
                return np.where(z <= 0, np.cos(np.sqrt(-z)), np.cosh(np.sqrt(z)))
        elif alpha == 2.0 and beta == 2.0:
            if np.isscalar(z):
                return 1.0 if z == 0 else np.sin(np.sqrt(z)) / np.sqrt(z)
            else:
                return np.where(z == 0, 1.0, np.sin(np.sqrt(z)) / np.sqrt(z))

        # Determine optimal method based on input type and size
        if np.isscalar(z):
            return self._compute_scalar(z, alpha, beta, max_terms, tolerance)
        else:
            return self._compute_array(z, alpha, beta, max_terms, tolerance)

    def _compute_scalar(
        self,
        z: float,
        alpha: float,
        beta: float,
        max_terms: Optional[int],
        tolerance: float
    ) -> float:
        """Compute Mittag-Leffler function for scalar input."""
        # Check cache first
        cache_key = (z, alpha, beta, tolerance)
        if cache_key in self._cache:
            return self._cache[cache_key]

        # Determine optimal max_terms if not provided
        if max_terms is None:
            max_terms = self._get_optimal_max_terms(z, alpha, beta)

        # Ensure max_terms is not None
        if max_terms is None:
            max_terms = 100

        # Choose computation method
        if self.use_jax and JAX_AVAILABLE:
            try:
                result = self._compute_jax_scalar(
                    z, alpha, beta, max_terms, tolerance)
            except Exception:
                result = self._compute_numba_scalar(
                    z, alpha, beta, max_terms, tolerance)
        elif self.use_numba:
            result = self._compute_numba_scalar(
                z, alpha, beta, max_terms, tolerance)
        else:
            result = self._compute_python_scalar(
                z, alpha, beta, max_terms, tolerance)

        # Cache result
        if len(self._cache) < self._cache_size:
            self._cache[cache_key] = result

        return result

    def _compute_array(
        self,
        z: np.ndarray,
        alpha: float,
        beta: float,
        max_terms: Optional[int],
        tolerance: float
    ) -> np.ndarray:
        """
        Compute the Mittag-Leffler function for array input.

        When ``max_terms`` is None the series length is chosen from the
        largest |z| in the array, so every element gets enough terms. This
        also makes the choice well defined for multi-dimensional and empty
        arrays.

        Args:
            z: Array of arguments (any shape)
            alpha: First parameter
            beta: Second parameter
            max_terms: Maximum number of series terms (auto if None)
            tolerance: Convergence tolerance

        Returns:
            Array of function values with the shape of ``z``
        """
        z = np.asarray(z)

        # Determine optimal max_terms if not provided
        if max_terms is None:
            largest = float(np.max(np.abs(z))) if z.size else 0.0
            max_terms = self._get_optimal_max_terms(largest, alpha, beta)

        # Try the JAX backend first when it was requested
        if self.use_jax and JAX_AVAILABLE:
            try:
                return self._compute_jax_array(z, alpha, beta, max_terms, tolerance)
            except Exception:
                # Best-effort backend: fall through to the NumPy path below
                pass

        # Fallback to the NumPy implementation
        return self._compute_numpy_array(z, alpha, beta, max_terms, tolerance)

    def _get_optimal_max_terms(self, z: float, alpha: float, beta: float) -> int:
        """Determine optimal number of terms for convergence."""
        if not self.adaptive_convergence:
            return 100

        # Adaptive convergence based on argument magnitude and parameters
        abs_z = abs(z)

        if abs_z < 0.1:
            return 20
        elif abs_z < 1.0:
            return 50
        elif abs_z < 10.0:
            return 100
        else:
            return 200

    def _compute_python_scalar(
        self,
        z: float,
        alpha: float,
        beta: float,
        max_terms: int,
        tolerance: float
    ) -> float:
        """Python implementation with optimizations for fractional calculus."""
        if alpha <= 0 or beta <= 0:
            return np.nan

        # Handle large z where exp(z) would overflow
        if not np.iscomplexobj(z) and z > 700:
            return np.inf

        if abs(z) < 1e-15:
            return 1.0 / gamma(beta)

        term = 1.0 / gamma(beta)
        result = term
        k = 1

        while k < max_terms:
            log_gamma_ratio = gammaln(
                alpha * (k - 1) + beta) - gammaln(alpha * k + beta)
            term = term * z * np.exp(log_gamma_ratio)

            if abs(term) < tolerance:
                break

            result += term
            k += 1

        return result

    def _compute_numba_scalar(
        self,
        z: float,
        alpha: float,
        beta: float,
        max_terms: int,
        tolerance: float
    ) -> float:
        """Numba-optimized scalar computation."""
        return self._ml_numba_scalar(z, alpha, beta, max_terms, tolerance)

    def _compute_numpy_array(
        self,
        z: np.ndarray,
        alpha: float,
        beta: float,
        max_terms: int,
        tolerance: float
    ) -> np.ndarray:
        """Optimized NumPy array computation."""
        result = np.zeros_like(z)

        # Vectorized computation for better performance
        for i in prange(len(z.flat)):
            result.flat[i] = self._compute_python_scalar(
                z.flat[i], alpha, beta, max_terms, tolerance
            )

        return result

    def _compute_jax_scalar(
        self,
        z: float,
        alpha: float,
        beta: float,
        max_terms: int,
        tolerance: float
    ) -> float:
        """
        Scalar entry point for the JAX backend.

        Raises:
            RuntimeError: If JAX could not be imported.
        """
        if JAX_AVAILABLE:
            # No dedicated JAX kernel yet; reuse the pure-Python series
            return self._compute_python_scalar(z, alpha, beta, max_terms, tolerance)

        raise RuntimeError("JAX not available")

    def _compute_jax_array(
        self,
        z: np.ndarray,
        alpha: float,
        beta: float,
        max_terms: int,
        tolerance: float
    ) -> np.ndarray:
        """
        Array entry point for the JAX backend.

        Raises:
            RuntimeError: If JAX could not be imported.
        """
        if JAX_AVAILABLE:
            # No dedicated JAX kernel yet; reuse the NumPy implementation
            return self._compute_numpy_array(z, alpha, beta, max_terms, tolerance)

        raise RuntimeError("JAX not available")

    @staticmethod
    @jit(nopython=True)
    def _ml_numba_scalar(
        z: float,
        alpha: float,
        beta: float,
        max_terms: int,
        tolerance: float
    ) -> float:
        """
        Numba-compatible series for E_α,β(z) = Σ_k z^k / Γ(αk + β).

        Each term is derived from the previous one through
        Γ(α(k-1) + β) / Γ(αk + β), evaluated with ``math.lgamma``.

        Args:
            z: Scalar argument
            alpha: First parameter; non-positive values yield NaN
            beta: Second parameter; non-positive values yield NaN
            max_terms: Upper bound on the number of terms summed
            tolerance: Absolute size below which a term is dropped and the
                summation stops

        Returns:
            The partial sum, or 0.0 when the sum is not finite.
        """
        if alpha <= 0 or beta <= 0:
            return np.nan

        # k = 0 term: 1 / Gamma(beta)
        term = np.exp(-math.lgamma(beta))
        if abs(z) < 1e-15:
            return term

        result = term
        previous_size = abs(term)
        k = 1

        while k < max_terms:
            log_gamma_ratio = math.lgamma(
                alpha * (k - 1) + beta) - math.lgamma(alpha * k + beta)
            term = term * z * np.exp(log_gamma_ratio)
            size = abs(term)

            # Stop only once the terms are shrinking and below tolerance
            if size < tolerance and size <= previous_size:
                break

            result += term
            previous_size = size
            k += 1

        return result if np.isfinite(result) else 0.0

    def compute_fast(
        self,
        z: Union[float, np.ndarray],
        alpha: float,
        beta: float = 1.0
    ) -> Union[float, np.ndarray]:
        """
        Fast computation optimized for Atangana-Baleanu derivatives.

        This method is specifically optimized for the common use case
        E_α(-α(t-τ)^α/(1-α)) in Atangana-Baleanu derivatives.
        """
        # Special optimizations for negative arguments
        if np.isscalar(z) and z < 0:
            return self._compute_negative_fast(z, alpha, beta)
        elif not np.isscalar(z) and np.all(z < 0):
            return self._compute_negative_array_fast(z, alpha, beta)
        else:
            return self.compute(z, alpha, beta)

    def _compute_negative_fast(self, z: float, alpha: float, beta: float) -> float:
        """Fast computation for negative arguments."""
        if abs(z) < 1e-15:
            return 1.0

        # Optimized series for negative arguments
        result = 0.0
        term = 1.0
        k = 0

        while k < 50 and abs(term) > 1e-12:
            result += term
            k += 1

            if k > 0:
                denominator = alpha * k + beta - alpha
                if abs(denominator) < 1e-15:
                    break
                term = term * z / denominator

        return result if np.isfinite(result) else 0.0

    def _compute_negative_array_fast(self, z: np.ndarray, alpha: float, beta: float) -> np.ndarray:
        """Fast computation for negative argument arrays."""
        result = np.zeros_like(z)

        for i in prange(len(z.flat)):
            result.flat[i] = self._compute_negative_fast(
                z.flat[i], alpha, beta)

        return result


# Convenience functions for backward compatibility
def mittag_leffler_function(
    alpha: float,
    beta: float,
    z: Union[float, np.ndarray],
    use_jax: bool = False,
    use_numba: bool = False  # Disabled by default due to compilation issues
) -> Union[float, np.ndarray]:
    """
    Evaluate E_α,β(z) through a freshly built ``MittagLefflerFunction``.

    Note the argument order: parameters first, then the argument ``z``.

    Args:
        alpha: First parameter
        beta: Second parameter
        z: Input value(s)
        use_jax: Use JAX acceleration
        use_numba: Use Numba JIT compilation

    Returns:
        Mittag-Leffler function value(s)
    """
    evaluator = MittagLefflerFunction(use_jax=use_jax, use_numba=use_numba)
    return evaluator.compute(z, alpha, beta)


def mittag_leffler_derivative(
    alpha: float,
    beta: float,
    z: Union[float, np.ndarray],
    order: int = 1
) -> Union[float, np.ndarray]:
    """
    Compute the n-th derivative of the Mittag-Leffler function in z.

    Differentiating E_α,β(z) = Σ_k z^k / Γ(αk + β) term by term gives

        d^n/dz^n E_α,β(z) = Σ_{j>=0} (j+n)!/j! · z^j / Γ(α(j+n) + β)

    which is summed directly. E_α,α+β(z)/α is not the derivative: for
    α = β = 1 it yields (e^z - 1)/z instead of e^z.

    Args:
        alpha: First parameter
        beta: Second parameter
        z: Input value(s)
        order: Order of derivative (default: 1); 0 returns the function

    Returns:
        Derivative value(s); NaN when alpha <= 0 or alpha*order + beta <= 0

    Raises:
        ValueError: If ``order`` is negative or not an integer.
    """
    if order != int(order) or order < 0:
        raise ValueError("order must be a non-negative integer")
    order = int(order)

    if order == 0:
        return mittag_leffler_function(alpha, beta, z)

    scalar_input = np.isscalar(z)
    z_arr = np.asarray(z)
    if not np.iscomplexobj(z_arr):
        z_arr = z_arr.astype(float)

    if alpha <= 0 or alpha * order + beta <= 0:
        return np.nan if scalar_input else np.full(z_arr.shape, np.nan)

    max_terms = 500
    eps = np.finfo(float).eps
    tiny = np.finfo(float).tiny

    # j = 0 term: order! / Gamma(alpha*order + beta)
    leading = np.exp(gammaln(order + 1) - gammaln(alpha * order + beta))
    term = np.full(z_arr.shape, leading, dtype=z_arr.dtype)
    total = term.copy()

    for j in range(max_terms):
        k = j + order
        # log of coefficient(j+1) / coefficient(j)
        log_ratio = (
            np.log(k + 1) - np.log(j + 1)
            + gammaln(alpha * k + beta) - gammaln(alpha * (k + 1) + beta)
        )
        term = term * z_arr * np.exp(log_ratio)
        total = total + term

        # While terms are still growing, |term| >= |total| / (j + 2), so
        # this relative test cannot fire before the series has peaked.
        if np.all(np.abs(term) <= eps * np.maximum(np.abs(total), tiny)):
            break

    return total.item() if scalar_input else total


def mittag_leffler_fast(
    z: Union[float, np.ndarray],
    alpha: float,
    beta: float = 1.0
) -> Union[float, np.ndarray]:
    """
    Evaluate E_α,β(z) via ``MittagLefflerFunction.compute_fast``.

    Intended for the common fractional-calculus use cases, particularly
    Atangana-Baleanu derivatives. A new evaluator with JAX and Numba
    turned off and adaptive convergence turned on is built on every call.
    """
    evaluator = MittagLefflerFunction(
        use_jax=False,
        use_numba=False,  # Disabled due to compilation issues
        adaptive_convergence=True
    )
    return evaluator.compute_fast(z, alpha, beta)


def mittag_leffler(
    z: Union[float, np.ndarray],
    alpha: float,
    beta: float = 1.0,
    use_jax: bool = False,
    use_numba: bool = False  # Disabled by default due to compilation issues
) -> Union[float, np.ndarray]:
    """
    Argument-first front end to ``mittag_leffler_function``.

    Kept so that existing code expecting this name and the ``(z, alpha,
    beta)`` argument order continues to work.

    Args:
        z: Input value(s)
        alpha: First parameter
        beta: Second parameter
        use_jax: Use JAX acceleration
        use_numba: Use Numba JIT compilation

    Returns:
        Mittag-Leffler function value(s); a scalar NaN when ``alpha`` or
        ``beta`` is not positive
    """
    parameters_rejected = alpha <= 0 or beta <= 0
    if not parameters_rejected:
        return mittag_leffler_function(
            alpha, beta, z, use_jax=use_jax, use_numba=use_numba)

    return np.nan