"""Base class for fractional GNN layers (module ``hpfracc.ml.gnn_layers.base``)."""

from abc import ABC, abstractmethod
from typing import Optional, Union, Any
from hpfracc.ml.backends import get_backend_manager, BackendType
from hpfracc.core.definitions import FractionalOrder
from hpfracc.core.fractional_implementations import _AlphaCompatibilityWrapper
from hpfracc.ml.tensor_ops import get_tensor_ops

class BaseFractionalGNNLayer(ABC):
    """Base class for fractional GNN layers.

    Stores the configuration shared by fractional graph layers, resolves the
    computational backend, and dispatches fractional-derivative computation to
    backend-specific hooks.

    Subclass contract:

    * implement :meth:`forward`;
    * provide an ``_initialize_layer()`` method. :meth:`__init__` calls it as
      its final step, and it is NOT defined on this base class, so omitting it
      raises ``AttributeError`` at construction time;
    * supply the ``_torch_/_jax_/_numba_fractional_derivative`` hooks (e.g. via
      backend mixins) for whichever backends are actually used.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        fractional_order: Union[float, FractionalOrder] = 0.5,
        method: str = "RL",
        use_fractional: bool = True,
        activation: str = "relu",
        dropout: float = 0.1,
        bias: bool = True,
        backend: Optional[BackendType] = None,
    ):
        """Store layer configuration, resolve the backend and build parameters.

        Args:
            in_channels: Number of input channels (features).
            out_channels: Number of output channels (features).
            fractional_order: Fractional order alpha. A plain ``int``/``float``
                is converted to ``FractionalOrder`` and wrapped in
                ``_AlphaCompatibilityWrapper``. A ``FractionalOrder`` instance
                is stored as-is. Any other object is handed to
                ``_AlphaCompatibilityWrapper`` unchanged.
            method: Fractional derivative method name. It is stored, not
                interpreted here; ``"RL"`` is presumably Riemann-Liouville.
            use_fractional: When False, :meth:`apply_fractional_derivative`
                returns its input unchanged.
            activation: Activation name; stored for subclasses.
            dropout: Dropout rate; stored for subclasses.
            bias: Whether a bias is requested. Stored as ``True`` or ``None``.
                Subclasses presumably replace it with a real parameter in
                ``_initialize_layer``; confirm in the concrete layers.
            backend: Backend to use. When ``None``, the backend manager's
                currently active backend is used.
        """
        self.in_channels = in_channels
        self.out_channels = out_channels

        # Alpha handling. Previously only exact floats went through
        # FractionalOrder, so an int such as ``1`` bypassed it and was wrapped
        # raw. All plain real numbers are now treated alike. ``bool`` is
        # excluded: it subclasses int but is not a meaningful order.
        is_plain_number = isinstance(fractional_order, (int, float)) and not isinstance(
            fractional_order, bool
        )
        if is_plain_number:
            self.fractional_order = _AlphaCompatibilityWrapper(
                FractionalOrder(float(fractional_order))
            )
        elif isinstance(fractional_order, FractionalOrder):
            self.fractional_order = fractional_order
        else:
            self.fractional_order = _AlphaCompatibilityWrapper(fractional_order)

        self.method = method
        self.use_fractional = use_fractional
        self.activation = activation
        self.dropout = dropout
        # Preserves the original contract: True when a bias is wanted, else None.
        self.bias = bias if bias else None

        # Backend handling: an explicit backend wins, otherwise use the active one.
        self.backend = (
            backend if backend is not None else get_backend_manager().active_backend
        )
        self.tensor_ops = get_tensor_ops(self.backend)

        # Subclass hook that creates weights, biases, etc. It is not defined on
        # this base class.
        self._initialize_layer()

    @abstractmethod
    def forward(
        self,
        x: Any,
        edge_index: Any,
        edge_weight: Optional[Any] = None,
        **kwargs,
    ) -> Any:
        """Forward pass through the layer.

        Args:
            x: Input features (presumably node features; backend tensor type).
            edge_index: Graph connectivity (presumably a 2 x num_edges index
                tensor; confirm in concrete layers).
            edge_weight: Optional edge weights.
            **kwargs: Extra layer-specific arguments.

        Returns:
            The layer output.
        """
        pass

    def apply_fractional_derivative(self, x: Any) -> Any:
        """Apply the fractional derivative to input features.

        Args:
            x: Input features in the active backend's tensor type.

        Returns:
            ``x`` unchanged when ``use_fractional`` is False; otherwise the
            result of the backend-specific ``_*_fractional_derivative`` hook,
            called with ``self.fractional_order.alpha``.

        Raises:
            RuntimeError: If ``self.backend`` is not TORCH, AUTO, JAX or NUMBA.
            NotImplementedError: If the selected backend hook was not supplied
                by a mixin or subclass.
        """
        if not self.use_fractional:
            return x

        alpha = self.fractional_order.alpha

        # AUTO is routed to the PyTorch implementation.
        if self.backend in (BackendType.TORCH, BackendType.AUTO):
            return self._torch_fractional_derivative(x, alpha)
        if self.backend == BackendType.JAX:
            return self._jax_fractional_derivative(x, alpha)
        if self.backend == BackendType.NUMBA:
            return self._numba_fractional_derivative(x, alpha)
        raise RuntimeError(f"Unknown backend: {self.backend}")

    # Backend hooks. Implementations are expected to be injected via mixins or
    # subclasses. They are deliberately NOT abstract, so a class only has to
    # provide the backends it actually uses and mixin composition still works.
    # Any mixin supplying them must come before this class in the MRO.
    def _torch_fractional_derivative(self, x: Any, alpha: float) -> Any:
        """PyTorch fractional-derivative hook; must be supplied by a mixin."""
        raise NotImplementedError("Subclass must provide _torch_fractional_derivative via mixin")

    def _jax_fractional_derivative(self, x: Any, alpha: float) -> Any:
        """JAX fractional-derivative hook; must be supplied by a mixin."""
        raise NotImplementedError("Subclass must provide _jax_fractional_derivative via mixin")

    def _numba_fractional_derivative(self, x: Any, alpha: float) -> Any:
        """Numba fractional-derivative hook; must be supplied by a mixin."""
        raise NotImplementedError("Subclass must provide _numba_fractional_derivative via mixin")

    def __call__(self, *args, **kwargs):
        """Callable layer wrapper: forwards all arguments to :meth:`forward`."""
        return self.forward(*args, **kwargs)