"""
Native (differentiable) fractional derivatives for neural-network feature maps.
Operators are **discrete** convolutions along a user-chosen axis (default: last).
Grid spacing ``h`` is uniform; use ``fractional_t_grid`` to imply ``h`` from
coordinates, or pass ``fractional_step`` explicitly.
See ``FractionalNeuralNetwork`` docstring for tensor layout (batch vs. sampled axis).
"""
from __future__ import annotations
import math
from typing import Any, Optional, Sequence, Union
import numpy as np
from hpfracc.ml.backends import BackendType
from hpfracc.special.binomial_coeffs import binomial_sequence_fast
from hpfracc.special.gamma_beta import gamma as hpfracc_gamma
def _normalize_axis(axis: int, ndim: int) -> int:
if ndim < 1:
raise ValueError("fractional operator expects ndim >= 1")
if axis < 0:
axis += ndim
if not 0 <= axis < ndim:
raise ValueError(f"fractional axis {axis} out of bounds for ndim={ndim}")
return axis
def _step_size_uniform_01(length: int) -> float:
if length < 2:
return 1.0
return 1.0 / float(length - 1)
[docs]
def _t_grid_and_h(
length: int,
*,
h: Optional[float],
t_grid: Optional[Union[np.ndarray, Sequence[float]]],
) -> tuple[np.ndarray, float]:
"""
Build a 1D uniform time/sample grid of ``length`` and scalar step ``h``.
Exactly one of: ``h`` is set, ``t_grid`` is set, or both None (default [0,1]).
"""
if h is not None and t_grid is not None:
raise ValueError("Pass at most one of fractional_step (h) and fractional_t_grid.")
if length < 1:
raise ValueError("Sampled axis length must be >= 1")
if t_grid is not None:
t = np.asarray(t_grid, dtype=np.float64).ravel()
if t.size != length:
raise ValueError(
f"fractional_t_grid length {t.size} != sampled axis length {length}"
)
if t.size < 2:
return t, 1.0
diffs = np.diff(t)
if np.any(diffs <= 0):
raise ValueError("fractional_t_grid must be strictly increasing.")
d0 = float(diffs[0])
if not np.allclose(diffs, d0, rtol=1e-5, atol=1e-8):
raise ValueError(
"fractional_t_grid must be uniformly spaced (native Caputo L1 / GL path)."
)
return t, d0
if h is not None:
hf = float(h)
if hf <= 0:
raise ValueError("fractional_step (h) must be positive.")
t = np.arange(length, dtype=np.float64) * hf
return t, hf
t = np.linspace(0.0, 1.0, length, dtype=np.float64)
return t, _step_size_uniform_01(length)
def _next_fft_len(n: int) -> int:
return int(2 ** math.ceil(math.log2(2 * n - 1)))
[docs]
def _gl_weights_numpy(alpha: float, n: int) -> np.ndarray:
    """Length-``n`` Grünwald–Letnikov mask (same construction as numpy_backend).

    Multiplies the output of ``binomial_sequence_fast(alpha, n - 1)`` by the
    alternating signs ``+1, -1, +1, ...`` and returns a float64 array.
    NOTE(review): assumes ``binomial_sequence_fast`` returns the ``n``
    coefficients ``binom(alpha, 0..n-1)``; confirm in hpfracc.special.
    """
    binoms = binomial_sequence_fast(alpha, n - 1, use_numba=False)
    # Even indices get +1.0 and odd indices -1.0 (exactly (-1.0) ** k).
    alternating = np.where(np.arange(n) % 2 == 0, 1.0, -1.0)
    return (alternating * binoms).astype(np.float64)
[docs]
def grunwald_letnikov_last_dim_numpy(x: np.ndarray, alpha: float, h: float) -> np.ndarray:
    """Reference NumPy implementation along last axis (for tests / NUMBA backend).

    Computes ``h**(-alpha) * sum_{k=0}^{j} w_k * x[..., j - k]`` for every
    index ``j`` of the last axis. The weights ``w`` come from
    ``_gl_weights_numpy``. The convolution is a zero-padded FFT product
    evaluated in float64.

    Parameters
    ----------
    x : array_like
        Input samples; the operator acts along the last axis.
    alpha : float
        Order. ``alpha == 0`` returns a copy of ``x``.
    h : float
        Uniform grid spacing.

    Returns
    -------
    np.ndarray
        Same shape as ``x``. Floating inputs keep their dtype. Any other dtype
        (e.g. integers) yields float64 instead of silently truncating.
    """
    x = np.asarray(x)
    if alpha == 0.0:
        return np.array(x, copy=True)
    n = x.shape[-1]
    out_dtype = x.dtype if np.issubdtype(x.dtype, np.floating) else np.dtype(np.float64)
    if n == 0:
        # Nothing to convolve; reshape(-1, 0) would also be ambiguous.
        return np.zeros(x.shape, dtype=out_dtype)
    gl = _gl_weights_numpy(alpha, n)
    size = _next_fft_len(n)
    flat = x.reshape(-1, n).astype(np.float64, copy=False)
    # The kernel spectrum is row-invariant: compute once and broadcast over rows.
    kernel_hat = np.fft.fft(gl, n=size)
    conv = np.fft.ifft(np.fft.fft(flat, n=size, axis=-1) * kernel_hat, axis=-1).real[:, :n]
    return ((h ** (-alpha)) * conv).reshape(x.shape).astype(out_dtype, copy=False)
[docs]
def caputo_l1_last_dim_numpy(x: np.ndarray, alpha: float, h: float) -> np.ndarray:
    """Caputo L1 along last axis (0 < alpha < 1), matching ``_caputo_numpy``.

    For each index ``j``, computes ``c * sum_{i=0}^{j} w_{j-i} * d_i``, where:

    - ``w_k = (k + 1)**(1 - alpha) - k**(1 - alpha)``
    - ``c = 1 / (Gamma(2 - alpha) * h**alpha)``
    - ``d_0 = x_0`` and ``d_i = x_i - x_{i-1}``

    The convolution is a zero-padded FFT product in float64. ``alpha`` is not
    range-checked here.

    Returns
    -------
    np.ndarray
        Same shape as ``x``. Floating inputs keep their dtype. Any other dtype
        (e.g. integers) yields float64 instead of silently truncating.
    """
    x = np.asarray(x)
    n = x.shape[-1]
    out_dtype = x.dtype if np.issubdtype(x.dtype, np.floating) else np.dtype(np.float64)
    if n == 0:
        # Nothing to convolve; reshape(-1, 0) would also be ambiguous.
        return np.zeros(x.shape, dtype=out_dtype)
    k = np.arange(n, dtype=np.float64)
    weights = (k + 1) ** (1 - alpha) - k ** (1 - alpha)
    c = 1.0 / (hpfracc_gamma(2 - alpha) * h**alpha)
    flat = x.reshape(-1, n).astype(np.float64, copy=False)
    # The first increment is x_0 itself (difference against an implicit 0).
    df = np.concatenate([flat[:, :1], np.diff(flat, axis=-1)], axis=-1)
    size = _next_fft_len(n)
    # The weight spectrum is row-invariant: compute once and broadcast over rows.
    w_hat = np.fft.fft(weights, n=size)
    conv = np.fft.ifft(np.fft.fft(df, n=size, axis=-1) * w_hat, axis=-1).real[:, :n]
    return (c * conv).reshape(x.shape).astype(out_dtype, copy=False)
[docs]
def grunwald_letnikov_last_dim_torch(x: Any, alpha: float, h: float) -> Any:
import torch
if alpha == 0.0:
return x
n = x.shape[-1]
device, dtype = x.device, x.dtype
gl_np = _gl_weights_numpy(alpha, n)
gl = torch.as_tensor(gl_np, device=device, dtype=dtype)
size = _next_fft_len(n)
flat = x.reshape(-1, n)
f_pad = flat.new_zeros(flat.shape[0], size)
f_pad[:, :n] = flat
c_pad = flat.new_zeros(size)
c_pad[:n] = gl
prod = torch.fft.fft(f_pad, dim=-1) * torch.fft.fft(c_pad, dim=-1)
conv = torch.fft.ifft(prod, dim=-1).real[:, :n]
scale = torch.tensor(
h ** (-alpha), device=device, dtype=dtype, requires_grad=False
)
return (conv * scale).reshape(x.shape)
[docs]
def caputo_l1_last_dim_torch(x: Any, alpha: float, h: float) -> Any:
    """Differentiable Caputo L1 operator on a torch tensor along its last dim.

    For each index ``j``, computes ``c * sum_{i<=j} w_{j-i} * d_i``, where:

    - ``w_k = (k+1)**(1-alpha) - k**(1-alpha)``
    - ``c = 1 / (Gamma(2-alpha) * h**alpha)``
    - ``d_0 = x_0`` and ``d_i = x_i - x_{i-1}``

    All constants are built in ``x``'s dtype and device. ``alpha`` is not
    range-checked here.
    """
    import torch

    length = x.shape[-1]
    dev, dt = x.device, x.dtype
    idx = torch.arange(length, device=dev, dtype=dt)
    weights = (idx + 1) ** (1 - alpha) - idx ** (1 - alpha)
    gamma_2ma = torch.tensor(
        float(hpfracc_gamma(2 - alpha)), device=dev, dtype=dt, requires_grad=False
    )
    step = torch.tensor(h, device=dev, dtype=dt, requires_grad=False)
    coef = 1.0 / (gamma_2ma * step**alpha)
    rows = x.reshape(-1, length)
    # First increment is x_0 itself; the rest are forward differences.
    increments = torch.cat([rows[:, :1], torch.diff(rows, dim=-1)], dim=-1)
    fft_len = _next_fft_len(length)
    # Explicit buffer so the weights are cast into the input dtype on assignment.
    weights_padded = rows.new_zeros(fft_len)
    weights_padded[:length] = weights
    spectrum = torch.fft.fft(increments, n=fft_len, dim=-1) * torch.fft.fft(
        weights_padded, dim=-1
    )
    conv = torch.fft.ifft(spectrum, dim=-1).real[:, :length]
    return (coef * conv).reshape(x.shape)
[docs]
def grunwald_letnikov_last_dim_jax(x: Any, alpha: float, h: float) -> Any:
import jax.numpy as jnp
from jax.scipy.special import gamma as jgamma
if alpha == 0.0:
return x
n = x.shape[-1]
k = jnp.arange(n, dtype=x.dtype)
alpha_arr = jnp.array(alpha, dtype=x.dtype)
coeffs = jgamma(alpha_arr + 1) / (jgamma(k + 1) * jgamma(alpha_arr - k + 1))
gl = ((-1.0) ** k) * coeffs
size = _next_fft_len(n)
flat = x.reshape(-1, n)
f_pad = jnp.zeros((flat.shape[0], size), dtype=x.dtype)
f_pad = f_pad.at[:, :n].set(flat)
c_pad = jnp.zeros((size,), dtype=x.dtype)
c_pad = c_pad.at[:n].set(gl.astype(x.dtype))
conv = jnp.fft.ifft(jnp.fft.fft(f_pad, axis=-1) * jnp.fft.fft(c_pad), axis=-1).real[
:, :n
]
return (h ** (-alpha)) * conv.reshape(x.shape)
[docs]
def caputo_l1_last_dim_jax(x: Any, alpha: float, h: float) -> Any:
    """Differentiable Caputo L1 operator on a JAX array along its last axis.

    For each index ``j``, computes ``c * sum_{i<=j} w_{j-i} * d_i``, where:

    - ``w_k = (k+1)**(1-alpha) - k**(1-alpha)``
    - ``c = 1 / (Gamma(2-alpha) * h**alpha)``
    - ``d_0 = x_0`` and ``d_i = x_i - x_{i-1}``

    Constants are built in ``x.dtype``. ``alpha`` is not range-checked here.
    """
    import jax.numpy as jnp
    from jax.scipy.special import gamma as jgamma

    length = x.shape[-1]
    dt = x.dtype
    idx = jnp.arange(length, dtype=dt)
    weights = (idx + 1) ** (1 - alpha) - idx ** (1 - alpha)
    gamma_2ma = jgamma(jnp.array(2 - alpha, dtype=dt))
    coef = 1.0 / (gamma_2ma * jnp.array(h, dtype=dt) ** alpha)
    rows = x.reshape(-1, length)
    # First increment is x_0 itself; the rest are forward differences.
    increments = jnp.concatenate([rows[:, :1], jnp.diff(rows, axis=-1)], axis=-1)
    fft_len = _next_fft_len(length)
    inc_padded = jnp.zeros((rows.shape[0], fft_len), dtype=dt).at[:, :length].set(
        increments
    )
    w_padded = jnp.zeros((fft_len,), dtype=dt).at[:length].set(weights)
    spectrum = jnp.fft.fft(inc_padded, axis=-1) * jnp.fft.fft(w_padded)
    conv = jnp.fft.ifft(spectrum, axis=-1).real[:, :length]
    return (coef * conv).reshape(x.shape)
def _moveaxis_to_last(x: Any, backend: BackendType, axis: int) -> Any:
    """Return ``x`` with dimension ``axis`` moved to the last position.

    TORCH and AUTO use ``torch.movedim``; JAX uses ``jnp.moveaxis``; any other
    backend value falls through to ``np.moveaxis``.
    """
    if backend in (BackendType.TORCH, BackendType.AUTO):
        import torch

        moved = torch.movedim(x, axis, -1)
    elif backend == BackendType.JAX:
        import jax.numpy as jnp

        moved = jnp.moveaxis(x, axis, -1)
    else:
        moved = np.moveaxis(x, axis, -1)
    return moved
def _moveaxis_from_last(x: Any, backend: BackendType, axis: int) -> Any:
    """Return ``x`` with its last dimension moved back to position ``axis``.

    TORCH and AUTO use ``torch.movedim``; JAX uses ``jnp.moveaxis``; any other
    backend value falls through to ``np.moveaxis``.
    """
    if backend in (BackendType.TORCH, BackendType.AUTO):
        import torch

        restored = torch.movedim(x, -1, axis)
    elif backend == BackendType.JAX:
        import jax.numpy as jnp

        restored = jnp.moveaxis(x, -1, axis)
    else:
        restored = np.moveaxis(x, -1, axis)
    return restored
[docs]
def fractional_feature_map_native(
    x: Any,
    *,
    backend: BackendType,
    alpha: float,
    method: str,
    axis: int = -1,
    h: Optional[float] = None,
    t_grid: Optional[Union[np.ndarray, Sequence[float]]] = None,
) -> Any:
    """
    Apply a discrete fractional operator along ``axis`` (moved internally to last).

    Methods (case-insensitive)
    --------------------------
    - ``"RL"`` / ``"GL"``: Grünwald–Letnikov convolution. Integer ``alpha``
      raises ``NotImplementedError`` (deferred to the legacy path).
    - ``"CAPUTO"``: L1 scheme, only for ``0 < alpha < 1``.

    Backends
    --------
    - TORCH / AUTO: torch kernels.
    - JAX: jax kernels.
    - NUMBA: NumPy reference kernels on a float32 copy of ``x``; the result
      is a float32 NumPy array.
    - Any other value raises ``RuntimeError``.

    Grid
    ----
    - Default: sample times ``0, 1/(n-1), …, 1`` and ``h = 1/(n-1)``
      (``h = 1`` when ``n == 1``).
    - ``h`` (``fractional_step``): use ``t_i = i * h`` (same ``h`` in formulas).
    - ``t_grid`` (``fractional_t_grid``): strictly increasing, **uniform**
      spacing; ``h`` is inferred. Non-uniform grids are not supported for the
      native L1/GL kernels.

    Raises
    ------
    ValueError
        0-d input, out-of-range ``axis``, invalid grid, or unknown ``method``.
    NotImplementedError
        ``alpha`` not supported by the chosen native method.
    RuntimeError
        Unknown ``backend``.
    """
    if x.ndim == 0:
        raise ValueError("fractional_feature_map_native expects at least 1D input")
    ax = _normalize_axis(axis, x.ndim)
    n = int(x.shape[ax])
    _, h_eff = _t_grid_and_h(n, h=h, t_grid=t_grid)
    m = method.upper()
    alpha_f = float(alpha)

    is_torch = backend in (BackendType.TORCH, BackendType.AUTO)
    is_jax = backend == BackendType.JAX
    is_numba = backend == BackendType.NUMBA
    # Validate before any tensor work so an unknown backend never reaches
    # np.moveaxis with a foreign tensor type.
    if not (is_torch or is_jax or is_numba):
        raise RuntimeError(f"Unknown backend: {backend}")

    if m in ("RL", "GL"):
        if alpha_f.is_integer():
            raise NotImplementedError(
                "Native GL/RL defers integer alpha to legacy path."
            )
        torch_fn = grunwald_letnikov_last_dim_torch
        jax_fn = grunwald_letnikov_last_dim_jax
        numpy_fn = grunwald_letnikov_last_dim_numpy
    elif m == "CAPUTO":
        if not (0.0 < alpha_f < 1.0):
            if is_numba:
                raise NotImplementedError(
                    "Native NumPy Caputo path here is L1 only for 0 < alpha < 1."
                )
            raise NotImplementedError(
                "Native differentiable Caputo is only implemented for 0 < alpha < 1 "
                f"(got alpha={alpha_f}). Use differentiable_fractional=False for the "
                "full NumPy Caputo path."
            )
        torch_fn = caputo_l1_last_dim_torch
        jax_fn = caputo_l1_last_dim_jax
        numpy_fn = caputo_l1_last_dim_numpy
    else:
        raise ValueError(f"Unknown method: {method}")

    last = x.ndim - 1
    x_work = x if ax == last else _moveaxis_to_last(x, backend, ax)
    if is_torch:
        out_w = torch_fn(x_work, alpha_f, h_eff)
    elif is_jax:
        out_w = jax_fn(x_work, alpha_f, h_eff)
    else:
        # NUMBA backend: run the NumPy reference kernel in float32.
        x_np = np.asarray(x_work, dtype=np.float32)
        out_w = numpy_fn(x_np, alpha_f, h_eff).astype(np.float32, copy=False)
    return out_w if ax == last else _moveaxis_from_last(out_w, backend, ax)