"""
Complete Fractional Graph Neural Network Architectures
This module provides complete GNN architectures with fractional calculus integration,
including various model types and configurations for different graph learning tasks.
"""
import numbers
from abc import ABC, abstractmethod
from typing import Optional, Union, Any, List, Dict

from .backends import get_backend_manager, BackendType
from .gnn_layers import FractionalGraphConv, FractionalGraphAttention, FractionalGraphPooling
from .tensor_ops import get_tensor_ops
from ..core.definitions import FractionalOrder
class BaseFractionalGNN(ABC):
    """
    Base class for fractional Graph Neural Networks

    This abstract class defines the interface for all fractional GNN models,
    ensuring consistency across different architectures and backends.

    Subclasses implement ``_build_network`` and ``forward``. ``__init__``
    calls ``_build_network`` as its final step, so a subclass that needs
    extra configuration inside ``_build_network`` must set those attributes
    *before* calling ``super().__init__``.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int = 3,
        fractional_order: Union[float, FractionalOrder] = 0.5,
        method: str = "RL",
        use_fractional: bool = True,
        activation: str = "relu",
        dropout: float = 0.1,
        backend: Optional[BackendType] = None
    ):
        """
        Store the shared configuration and build the network.

        Args:
            input_dim: Dimension of the input node features.
            hidden_dim: Width of the hidden layers.
            output_dim: Dimension of the model output.
            num_layers: Depth setting; its exact interpretation is defined
                by each subclass's ``_build_network``.
            fractional_order: A plain real number or a ``FractionalOrder``.
                Real numbers are converted to ``float`` and wrapped in
                ``FractionalOrder``; other values are stored as given.
            method: Fractional method identifier forwarded to the layers.
            use_fractional: Forwarded to the layers.
            activation: Activation name forwarded to the layers.
            dropout: Dropout rate forwarded to the layers.
            backend: Backend to use; defaults to the backend manager's
                currently active backend.
        """
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.num_layers = num_layers
        # Wrap any plain real number (int, float, numpy scalar, ...) so that
        # ``self.fractional_order.alpha`` is always available; a float-only
        # check would leave e.g. ``fractional_order=1`` unwrapped.
        if isinstance(fractional_order, numbers.Real):
            fractional_order = FractionalOrder(float(fractional_order))
        self.fractional_order = fractional_order
        self.method = method
        self.use_fractional = use_fractional
        self.activation = activation
        self.dropout = dropout
        self.backend = backend or get_backend_manager().active_backend
        # Initialize tensor operations
        self.tensor_ops = get_tensor_ops(self.backend)
        # Build the network (must come last: subclasses read the fields above)
        self._build_network()

    @abstractmethod
    def _build_network(self):
        """Build the specific network architecture"""

    @abstractmethod
    def forward(
            self,
            x: Any,
            edge_index: Any,
            batch: Optional[Any] = None,
            **kwargs) -> Any:
        """Forward pass through the network"""

    def get_backend_info(self) -> Dict[str, Any]:
        """
        Get information about the current backend.

        Returns:
            Dict with the backend's ``value``, the string form of the tensor
            library, the fractional order's ``alpha``, the method, and the
            ``use_fractional`` flag.
        """
        return {
            'backend': self.backend.value,
            'tensor_lib': str(self.tensor_ops.tensor_lib),
            'fractional_order': self.fractional_order.alpha,
            'method': self.method,
            'use_fractional': self.use_fractional
        }

    def __call__(
            self,
            x: Any,
            edge_index: Any,
            batch: Optional[Any] = None,
            **kwargs) -> Any:
        """Make models callable like torch modules (delegates to ``forward``)."""
        return self.forward(x, edge_index, batch, **kwargs)

    def parameters(self) -> List[Any]:
        """
        Collect learnable parameters from sub-layers for testing/optimizers.

        Inspects the model attributes ``layers``, ``encoder_layers``,
        ``decoder_layers`` and ``output_layer`` (each may be a single layer or
        a list of layers). For every layer, the known parameter attributes
        are gathered in a fixed order. Attributes that are missing or set to
        ``None`` (e.g. a layer built without a bias) are skipped, because
        optimizers such as ``torch.optim`` reject non-tensor entries.

        Returns:
            A flat list of the collected parameter objects.
        """
        # NOTE(review): ``pooling_layers`` is not inspected, so a pooling
        # layer's ``score_network`` is never collected. Confirm whether that
        # is intentional before adding it (its type is not visible here).
        layer_attrs = ('layers', 'encoder_layers', 'decoder_layers', 'output_layer')
        param_names = (
            'weight',
            'bias',
            'query_weight',
            'key_weight',
            'value_weight',
            'output_weight',
            'score_network',
        )
        params: List[Any] = []
        for attr in layer_attrs:
            entry = getattr(self, attr, None)
            if entry is None:
                continue
            layers = entry if isinstance(entry, list) else [entry]
            for layer in layers:
                for name in param_names:
                    value = getattr(layer, name, None)
                    if value is not None:
                        params.append(value)
        return params
class FractionalGCN(BaseFractionalGNN):
    """
    Fractional Graph Convolutional Network

    Stacks ``FractionalGraphConv`` layers: one input layer
    (input_dim -> hidden_dim), ``num_layers - 2`` hidden layers
    (hidden_dim -> hidden_dim) and one output layer
    (hidden_dim -> output_dim, identity activation, no dropout).
    With ``num_layers < 2`` no hidden layers are added, so the stack still
    contains the input and output layers.
    """

    def _build_network(self):
        """Populate ``self.layers`` with the fractional convolution stack."""

        def conv(in_dim, out_dim, activation, dropout):
            # Every layer shares the model-wide fractional/backend settings.
            return FractionalGraphConv(
                in_channels=in_dim,
                out_channels=out_dim,
                fractional_order=self.fractional_order,
                method=self.method,
                use_fractional=self.use_fractional,
                activation=activation,
                dropout=dropout,
                backend=self.backend,
            )

        self.layers = [
            conv(self.input_dim, self.hidden_dim, self.activation, self.dropout)
        ]
        for _ in range(self.num_layers - 2):
            self.layers.append(
                conv(self.hidden_dim, self.hidden_dim, self.activation, self.dropout)
            )
        # The final projection is left linear and undropped.
        self.layers.append(conv(self.hidden_dim, self.output_dim, "identity", 0.0))

    def forward(
            self,
            x: Any,
            edge_index: Any,
            batch: Optional[Any] = None,
            **kwargs) -> Any:
        """
        Run the input through every layer in order.

        Args:
            x: Node feature matrix [num_nodes, input_dim]
            edge_index: Graph connectivity [2, num_edges]
            batch: Accepted for interface compatibility; not used here.
            **kwargs: Forwarded to each layer's ``forward``.

        Returns:
            The output of the last layer.
        """
        hidden = x
        for conv_layer in self.layers:
            hidden = conv_layer.forward(hidden, edge_index, **kwargs)
        return hidden
class FractionalGAT(BaseFractionalGNN):
    """
    Fractional Graph Attention Network

    Stacks ``FractionalGraphAttention`` layers: an input layer and
    ``num_layers - 2`` hidden layers using ``num_heads`` heads, followed by
    a single-head output layer with identity activation and no dropout.
    With ``num_layers < 2`` only the input and output layers are built.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int = 3,
        num_heads: int = 8,
        fractional_order: Union[float, FractionalOrder] = 0.5,
        method: str = "RL",
        use_fractional: bool = True,
        activation: str = "relu",
        dropout: float = 0.1,
        backend: Optional[BackendType] = None
    ):
        """
        Configure the attention network.

        ``num_heads`` must be stored before delegating to the base class,
        because the base ``__init__`` calls ``_build_network``, which reads it.
        """
        self.num_heads = num_heads
        super().__init__(
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            output_dim=output_dim,
            num_layers=num_layers,
            fractional_order=fractional_order,
            method=method,
            use_fractional=use_fractional,
            activation=activation,
            dropout=dropout,
            backend=backend,
        )

    def _build_network(self):
        """Populate ``self.layers`` with the attention stack."""
        shared = {
            'fractional_order': self.fractional_order,
            'method': self.method,
            'use_fractional': self.use_fractional,
            'backend': self.backend,
        }
        self.layers = [
            FractionalGraphAttention(
                in_channels=self.input_dim,
                out_channels=self.hidden_dim,
                heads=self.num_heads,
                activation=self.activation,
                dropout=self.dropout,
                **shared
            )
        ]
        hidden_count = self.num_layers - 2
        while hidden_count > 0:
            self.layers.append(
                FractionalGraphAttention(
                    in_channels=self.hidden_dim,
                    out_channels=self.hidden_dim,
                    heads=self.num_heads,
                    activation=self.activation,
                    dropout=self.dropout,
                    **shared
                )
            )
            hidden_count -= 1
        # Output: a single head, linear activation, no dropout.
        self.layers.append(
            FractionalGraphAttention(
                in_channels=self.hidden_dim,
                out_channels=self.output_dim,
                heads=1,
                activation="identity",
                dropout=0.0,
                **shared
            )
        )

    def forward(
            self,
            x: Any,
            edge_index: Any,
            batch: Optional[Any] = None,
            **kwargs) -> Any:
        """
        Run the input through every attention layer in order.

        Args:
            x: Node feature matrix [num_nodes, input_dim]
            edge_index: Graph connectivity [2, num_edges]
            batch: Accepted for interface compatibility; not used here.
            **kwargs: Forwarded to each layer's ``forward``.

        Returns:
            The output of the last layer.
        """
        features = x
        for attention in self.layers:
            features = attention.forward(features, edge_index, **kwargs)
        return features
class FractionalGraphSAGE(BaseFractionalGNN):
    """
    Fractional GraphSAGE Network

    Builds the same ``FractionalGraphConv`` stack layout as the GCN variant:
    input layer, ``num_layers - 2`` hidden layers, and an output layer with
    identity activation and no dropout.

    ``num_samples`` is stored on the instance, but the code in this class
    does not pass it to the layers or perform neighbor sampling itself.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int = 3,
        num_samples: int = 25,
        fractional_order: Union[float, FractionalOrder] = 0.5,
        method: str = "RL",
        use_fractional: bool = True,
        activation: str = "relu",
        dropout: float = 0.1,
        backend: Optional[BackendType] = None
    ):
        """Record ``num_samples`` and delegate the rest to the base class."""
        self.num_samples = num_samples
        super().__init__(
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            output_dim=output_dim,
            num_layers=num_layers,
            fractional_order=fractional_order,
            method=method,
            use_fractional=use_fractional,
            activation=activation,
            dropout=dropout,
            backend=backend,
        )

    def _build_network(self):
        """Populate ``self.layers`` from a (in, out, activation, dropout) table."""
        body = (self.hidden_dim, self.hidden_dim, self.activation, self.dropout)
        # A negative repeat count yields an empty list, so num_layers < 2
        # produces no hidden layers.
        table = (
            [(self.input_dim, self.hidden_dim, self.activation, self.dropout)]
            + [body] * (self.num_layers - 2)
            + [(self.hidden_dim, self.output_dim, "identity", 0.0)]
        )
        self.layers = [
            FractionalGraphConv(
                in_channels=c_in,
                out_channels=c_out,
                fractional_order=self.fractional_order,
                method=self.method,
                use_fractional=self.use_fractional,
                activation=act,
                dropout=drop,
                backend=self.backend,
            )
            for c_in, c_out, act, drop in table
        ]

    def forward(
            self,
            x: Any,
            edge_index: Any,
            batch: Optional[Any] = None,
            **kwargs) -> Any:
        """
        Apply each layer in sequence.

        Args:
            x: Node feature matrix [num_nodes, input_dim]
            edge_index: Graph connectivity [2, num_edges]
            batch: Accepted for interface compatibility; not used here.
            **kwargs: Forwarded to each layer's ``forward``.

        Returns:
            The output of the last layer.
        """
        out = x
        for sage_layer in self.layers:
            out = sage_layer.forward(out, edge_index, **kwargs)
        return out
class FractionalGraphUNet(BaseFractionalGNN):
    """
    Fractional Graph U-Net

    A hierarchical GNN architecture that uses fractional calculus
    for multi-scale graph representation learning.

    Layout:
      * Encoder: ``num_layers`` ``FractionalGraphConv`` layers (the first maps
        input_dim -> hidden_dim, the rest hidden_dim -> hidden_dim).
      * Pooling: when ``num_layers > 2``, a ``FractionalGraphPooling`` step
        after every encoder layer except the last.
      * Decoder: when ``num_layers > 2``, ``num_layers - 1`` convolutions,
        each consuming the current features concatenated with an encoder
        skip connection (hence ``2 * hidden_dim`` input channels).
      * Output: one convolution hidden_dim -> output_dim, identity
        activation, no dropout.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int = 4,
        pooling_ratio: float = 0.5,
        fractional_order: Union[float, FractionalOrder] = 0.5,
        method: str = "RL",
        use_fractional: bool = True,
        activation: str = "relu",
        dropout: float = 0.1,
        backend: Optional[BackendType] = None
    ):
        """Store ``pooling_ratio`` (read by ``_build_network``) and delegate to the base class."""
        self.pooling_ratio = pooling_ratio
        super().__init__(
            input_dim,
            hidden_dim,
            output_dim,
            num_layers,
            fractional_order,
            method,
            use_fractional,
            activation,
            dropout,
            backend)

    def _build_network(self):
        """Build the encoder, optional pooling/decoder stages, and output layer."""
        # Encoder layers
        self.encoder_layers = []
        current_dim = self.input_dim
        for _ in range(self.num_layers):
            self.encoder_layers.append(
                FractionalGraphConv(
                    in_channels=current_dim,
                    out_channels=self.hidden_dim,
                    fractional_order=self.fractional_order,
                    method=self.method,
                    use_fractional=self.use_fractional,
                    activation=self.activation,
                    dropout=self.dropout,
                    backend=self.backend
                )
            )
            current_dim = self.hidden_dim
        # Pooling layers (skipped for shallow networks to preserve node count)
        self.pooling_layers = []
        if self.num_layers > 2:
            for _ in range(self.num_layers - 1):
                self.pooling_layers.append(
                    FractionalGraphPooling(
                        in_channels=self.hidden_dim,
                        pooling_ratio=self.pooling_ratio,
                        fractional_order=self.fractional_order,
                        method=self.method,
                        use_fractional=self.use_fractional,
                        backend=self.backend
                    )
                )
        # Decoder layers (skipped for shallow networks)
        self.decoder_layers = []
        if self.num_layers > 2:
            for _ in range(self.num_layers - 1):
                self.decoder_layers.append(
                    FractionalGraphConv(
                        in_channels=self.hidden_dim * 2,  # current + skip features
                        out_channels=self.hidden_dim,
                        fractional_order=self.fractional_order,
                        method=self.method,
                        use_fractional=self.use_fractional,
                        activation=self.activation,
                        dropout=self.dropout,
                        backend=self.backend
                    )
                )
        # Output layer
        self.output_layer = FractionalGraphConv(
            in_channels=self.hidden_dim,
            out_channels=self.output_dim,
            fractional_order=self.fractional_order,
            method=self.method,
            use_fractional=self.use_fractional,
            activation="identity",
            dropout=0.0,
            backend=self.backend
        )

    def forward(
            self,
            x: Any,
            edge_index: Any,
            batch: Optional[Any] = None,
            **kwargs) -> Any:
        """
        Forward pass through the fractional Graph U-Net

        Args:
            x: Node feature matrix [num_nodes, input_dim]
            edge_index: Graph connectivity [2, num_edges]
            batch: Batch assignment vector [num_nodes]; passed to the
                pooling layers.
            **kwargs: Forwarded to every layer's ``forward``.

        Returns:
            Output features with ``output_dim`` columns. Without pooling
            (``num_layers <= 2``) the rows correspond to the input nodes.
            With pooling, the rows follow the graph returned by the last
            pooling step. The decoder does not unpool: it keeps the pooled
            ``edge_index`` and truncates/zero-pads each skip connection to
            the current row count.
        """
        encoder_outputs, current_x, current_edge_index = self._encode(
            x, edge_index, batch, **kwargs)
        current_x = self._decode(
            current_x, current_edge_index, encoder_outputs, **kwargs)
        return self.output_layer.forward(current_x, current_edge_index, **kwargs)

    def _encode(self, x, edge_index, batch, **kwargs):
        """Run the encoder; return (pre-pooling outputs incl. ``x``, features, edge_index)."""
        encoder_outputs = [x]
        current_x, current_edge_index, current_batch = x, edge_index, batch
        num_pools = len(self.pooling_layers)
        for i, layer in enumerate(self.encoder_layers):
            current_x = layer.forward(current_x, current_edge_index, **kwargs)
            # Skips are taken before pooling, so they keep the larger graph.
            encoder_outputs.append(current_x)
            # There is one pooling layer fewer than encoder layers, so the
            # last encoder output is never pooled.
            if i < num_pools:
                current_x, current_edge_index, current_batch = self.pooling_layers[i].forward(
                    current_x, current_edge_index, current_batch, **kwargs)
        return encoder_outputs, current_x, current_edge_index

    def _decode(self, current_x, edge_index, encoder_outputs, **kwargs):
        """Apply decoder layers, each on [current features | aligned skip]."""
        for i, layer in enumerate(self.decoder_layers):
            # encoder_outputs[-1] is the deepest output (already the decoder
            # input), so the first skip is encoder_outputs[-2], and so on.
            skip_x = self._align_skip(encoder_outputs[-(i + 2)], current_x)
            current_x = self.tensor_ops.cat([current_x, skip_x], dim=-1)
            current_x = layer.forward(current_x, edge_index, **kwargs)
        return current_x

    def _align_skip(self, skip_x, current_x):
        """Truncate or zero-pad ``skip_x`` rows, then its last dim, to match ``current_x``."""
        # Row alignment (assumes 2-D [rows, features] tensors).
        num_rows = current_x.shape[0]
        if skip_x.shape[0] > num_rows:
            skip_x = skip_x[:num_rows, :]
        elif skip_x.shape[0] < num_rows:
            padding_tensor = self.tensor_ops.zeros(
                (num_rows - skip_x.shape[0], skip_x.shape[1]))
            skip_x = self.tensor_ops.cat([skip_x, padding_tensor], dim=0)
        # Feature alignment.
        num_features = current_x.shape[-1]
        if skip_x.shape[-1] > num_features:
            skip_x = skip_x[..., :num_features]
        elif skip_x.shape[-1] < num_features:
            padding_tensor = self.tensor_ops.zeros(
                (skip_x.shape[0], num_features - skip_x.shape[-1]))
            skip_x = self.tensor_ops.cat([skip_x, padding_tensor], dim=-1)
        return skip_x
class FractionalGNNFactory:
    """
    Factory class for creating fractional GNN models

    This class provides a convenient interface for creating different
    types of fractional GNN architectures with consistent configurations.
    The supported model-type keys are defined once in ``_model_classes``,
    so model creation, the available-model list, and the error message
    cannot drift apart.
    """

    @staticmethod
    def _model_classes() -> Dict[str, type]:
        """Return the mapping from model-type key to model class, in display order."""
        return {
            'gcn': FractionalGCN,
            'gat': FractionalGAT,
            'sage': FractionalGraphSAGE,
            'unet': FractionalGraphUNet,
        }

    @staticmethod
    def create_model(
        model_type: str,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        **kwargs
    ) -> BaseFractionalGNN:
        """
        Create a fractional GNN model of the specified type

        Args:
            model_type: Type of GNN model ('gcn', 'gat', 'sage', 'unet');
                matched case-insensitively.
            input_dim: Input feature dimension
            hidden_dim: Hidden layer dimension
            output_dim: Output dimension
            **kwargs: Additional arguments for the model constructor

        Returns:
            Configured fractional GNN model

        Raises:
            ValueError: If ``model_type`` is not a supported key.
        """
        model_type = model_type.lower()
        model_classes = FractionalGNNFactory._model_classes()
        model_cls = model_classes.get(model_type)
        if model_cls is None:
            raise ValueError(f"Unknown model type: {model_type}. "
                             f"Available types: {', '.join(model_classes)}")
        return model_cls(
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            output_dim=output_dim,
            **kwargs
        )

    @staticmethod
    def get_available_models() -> List[str]:
        """Get a new list of the available model-type keys."""
        return list(FractionalGNNFactory._model_classes())

    @staticmethod
    def get_model_info(model_type: str) -> Dict[str, Any]:
        """
        Get information about a specific model type.

        Args:
            model_type: Model-type key, matched case-insensitively.

        Returns:
            A descriptive dict for the model. For an unknown type, this method
            does not raise: it returns ``{'error': ...}`` instead.
        """
        model_type = model_type.lower()
        info = {
            'gcn': {
                'name': 'Fractional Graph Convolutional Network',
                'description': 'Standard GCN with fractional calculus integration',
                'best_for': ['Node classification', 'Graph classification', 'Link prediction'],
                'complexity': 'Low',
                'memory_efficient': True
            },
            'gat': {
                'name': 'Fractional Graph Attention Network',
                'description': 'GNN with attention mechanisms and fractional calculus',
                'best_for': ['Node classification', 'Graph classification', 'Attention analysis'],
                'complexity': 'Medium',
                'memory_efficient': False
            },
            'sage': {
                'name': 'Fractional GraphSAGE',
                'description': 'Scalable GNN with neighbor sampling and fractional calculus',
                'best_for': ['Large graphs', 'Inductive learning', 'Node classification'],
                'complexity': 'Low',
                'memory_efficient': True
            },
            'unet': {
                'name': 'Fractional Graph U-Net',
                'description': 'Hierarchical GNN with skip connections and fractional calculus',
                'best_for': ['Multi-scale learning', 'Graph segmentation', 'Hierarchical tasks'],
                'complexity': 'High',
                'memory_efficient': False
            }
        }
        return info.get(
            model_type, {
                'error': f'Unknown model type: {model_type}'})