"""
Neural network layers for TopoHyper.

Implements:
- SimplicialConv: Hodge Laplacian-based convolution on simplicial complexes
- HypergraphConv: Spectral convolution on hypergraphs
- CrossStructureAttention: Attention-gated fusion between simplicial and
  hypergraph views
- TopoHyperConv: Three-phase integrated convolution layer
- TopoHyperPool: Graph-level readout with both structures
- GCNLayer, GATLayer: Baselines
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class SimplicialConv(nn.Module):
    """
    Simplicial convolution using the (unsigned) Hodge Laplacian.

    Computes the pre-activation
        X' = L_0 X W + X W_self
    where L_0 = |B1| |B1|^T is built from the absolute value of the boundary
    operator and then row-normalized. The unsigned |B1| is used for
    compatibility with HGNN's non-negative space.

    No nonlinearity is applied inside this layer; the caller is expected to
    apply one (TopoHyperConv wraps the output in a ReLU).
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        # Neighbourhood term has no bias; the self term carries the only bias.
        self.W = nn.Linear(in_dim, out_dim, bias=False)
        self.W_self = nn.Linear(in_dim, out_dim, bias=True)

    def forward(self, x, sc):
        """
        Args:
            x: (N, in_dim) node features
            sc: SimplicialComplex object exposing a dense tensor ``B1``
                (presumably the (N, E) node-edge boundary matrix - confirm
                against the SimplicialComplex definition)

        Returns:
            (N, out_dim) updated features (pre-activation)
        """
        # Use unsigned boundary for compatibility with HGNN
        B1_abs = sc.B1.abs().to(x.device)

        # L0_unsigned = |B1| |B1|^T
        L0 = B1_abs @ B1_abs.t()

        # Row-normalize; the clamp keeps all-zero rows from dividing by zero.
        D = L0.sum(dim=1, keepdim=True).clamp(min=1e-7)
        L0_norm = L0 / D

        # Convolution: aggregated neighbourhood term plus a self term.
        h_neigh = L0_norm @ x
        out = self.W(h_neigh) + self.W_self(x)
        return out


class HypergraphConv(nn.Module):
    """
    Hypergraph convolution using spectral propagation.

    Computes the pre-activation
        X' = P X Theta
    where P is whatever ``hg.propagation_matrix()`` returns (intended to be
    D_v^{-1/2} H W D_e^{-1} H^T D_v^{-1/2}). No nonlinearity is applied
    inside this layer; the caller applies it.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.theta = nn.Linear(in_dim, out_dim, bias=True)

    def forward(self, x, hg):
        """
        Args:
            x: (N, in_dim) node features
            hg: Hypergraph object exposing ``propagation_matrix()``

        Returns:
            (N, out_dim) updated features (pre-activation)
        """
        P = hg.propagation_matrix().to(x.device)
        h = P @ x
        return self.theta(h)


class CrossStructureAttention(nn.Module):
    """
    Attention-gated fusion between simplicial and hypergraph representations.

    Learns alpha_i in [0, 1] for each node indicating how much to weight each
    view, enabling adaptive combination.
    """

    def __init__(self, dim):
        super().__init__()
        # Two-layer MLP ending in a sigmoid, so the gate is a scalar in [0, 1].
        self.gate = nn.Sequential(
            nn.Linear(dim * 2, dim),
            nn.ReLU(),
            nn.Linear(dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, h_sc, h_hg):
        """
        Args:
            h_sc: (N, dim) simplicial features
            h_hg: (N, dim) hypergraph features

        Returns:
            (N, dim) fused features: alpha * h_sc + (1 - alpha) * h_hg
        """
        combined = torch.cat([h_sc, h_hg], dim=-1)
        alpha = self.gate(combined)  # (N, 1), broadcast over the feature axis
        return alpha * h_sc + (1 - alpha) * h_hg


class TopoHyperConv(nn.Module):
    """
    Three-phase TopoHyper convolution layer.

    Phase 1: Simplicial convolution via unsigned Hodge Laplacian
    Phase 2: Hypergraph convolution via spectral propagation
    Phase 3: Cross-structure fusion with attention gate + bridge matrix

    The bridge matrix B = A_sc . A_hg (element-wise product) captures nodes
    connected in BOTH views, enabling topological constraints to inform
    hypergraph propagation and vice versa.
    """

    def __init__(self, in_dim, out_dim, use_bridge=True, use_attention=True):
        super().__init__()
        self.use_bridge = use_bridge
        self.use_attention = use_attention

        self.sc_conv = SimplicialConv(in_dim, out_dim)
        self.hg_conv = HypergraphConv(in_dim, out_dim)

        # Optional sub-modules are only created when their flag is set.
        if use_attention:
            self.attention = CrossStructureAttention(out_dim)

        if use_bridge:
            self.bridge_proj = nn.Linear(out_dim, out_dim, bias=False)

        self.norm = nn.LayerNorm(out_dim)

    def forward(self, x, sc, hg):
        """
        Args:
            x: (N, in_dim) node features
            sc: SimplicialComplex (must expose ``B1`` and, when the bridge is
                enabled, ``adjacency_matrix()``)
            hg: Hypergraph (must expose ``propagation_matrix()`` and, when the
                bridge is enabled, ``adjacency_matrix()``)

        Returns:
            (N, out_dim) updated features, layer-normalized
        """
        # Phase 1: Simplicial convolution
        h_sc = F.relu(self.sc_conv(x, sc))

        # Phase 2: Hypergraph convolution
        h_hg = F.relu(self.hg_conv(x, hg))

        # Phase 3: Fusion (learned gate, or a plain average as the ablation)
        if self.use_attention:
            h_fused = self.attention(h_sc, h_hg)
        else:
            h_fused = (h_sc + h_hg) / 2.0

        # Bridge: propagate through shared structure
        if self.use_bridge:
            A_sc = sc.adjacency_matrix().to(x.device)
            A_hg = hg.adjacency_matrix().to(x.device)
            bridge = A_sc * A_hg  # element-wise: nodes connected in BOTH views

            # Row-normalize; nodes with no shared neighbours get a zero row
            # (the clamp only prevents division by zero).
            D_bridge = bridge.sum(dim=1, keepdim=True).clamp(min=1e-7)
            bridge_norm = bridge / D_bridge

            # Residual update from the shared-structure neighbourhood.
            h_bridge = bridge_norm @ h_fused
            h_fused = h_fused + self.bridge_proj(h_bridge)

        return self.norm(h_fused)


class TopoHyperPool(nn.Module):
    """
    Graph-level readout combining mean, max, and attention pooling.
    """

    def __init__(self, dim):
        super().__init__()
        # Scores each node with a single scalar for attention pooling.
        self.att = nn.Linear(dim, 1)

    def forward(self, x):
        """
        Args:
            x: (N, dim) node features

        Returns:
            (3*dim,) graph-level feature vector, laid out as
            [mean | max | attention-weighted sum]
        """
        h_mean = x.mean(dim=0)
        h_max = x.max(dim=0)[0]

        # Softmax over the node axis so the weights sum to 1 across nodes.
        att_weights = F.softmax(self.att(x), dim=0)
        h_att = (att_weights * x).sum(dim=0)

        return torch.cat([h_mean, h_max, h_att])


# ==================== Baseline Layers ====================

class GCNLayer(nn.Module):
    """Standard GCN layer: X' = D^{-1/2} (A + I) D^{-1/2} X W"""

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.W = nn.Linear(in_dim, out_dim, bias=True)

    def forward(self, x, edge_index):
        """
        Args:
            x: (N, in_dim) node features
            edge_index: (2, E) integer tensor; column k marks the entry
                A[edge_index[0, k], edge_index[1, k]] = 1

        Returns:
            (N, out_dim) updated features (pre-activation)
        """
        N = x.shape[0]
        device = x.device
        # Build the dense adjacency in x's dtype so the matmul below also
        # works for non-float32 inputs (e.g. a module converted with .double()).
        dtype = x.dtype

        # Build adjacency with self-loops
        A = torch.zeros(N, N, device=device, dtype=dtype)
        A[edge_index[0], edge_index[1]] = 1.0
        A = A + torch.eye(N, device=device, dtype=dtype)

        # Symmetric normalization; every degree is >= 1 thanks to self-loops,
        # the epsilon is only a guard.
        D_inv_sqrt = torch.diag(1.0 / (A.sum(dim=1).sqrt() + 1e-7))
        A_norm = D_inv_sqrt @ A @ D_inv_sqrt

        return self.W(A_norm @ x)


class GATLayer(nn.Module):
    """
    Graph Attention Network layer (dense implementation).

    For every head, node i attends over the nodes j with A[i, j] != 0
    (edges from ``edge_index`` plus self-loops):
        e_ij     = LeakyReLU(a_src . h_i + a_dst . h_j)
        alpha_ij = softmax_j(e_ij)
        out_i    = concat_heads(sum_j alpha_ij h_j) + bias
    """

    def __init__(self, in_dim, out_dim, heads=4):
        super().__init__()
        assert out_dim % heads == 0, "out_dim must be divisible by heads"
        self.heads = heads
        self.head_dim = out_dim // heads

        self.W = nn.Linear(in_dim, out_dim, bias=False)
        self.a_src = nn.Parameter(torch.randn(heads, self.head_dim))
        self.a_dst = nn.Parameter(torch.randn(heads, self.head_dim))
        self.bias = nn.Parameter(torch.zeros(out_dim))

        nn.init.xavier_uniform_(self.W.weight)
        # unsqueeze returns a view, so the in-place init writes through to the
        # parameters while giving the initializer a 3-D shape for its fans.
        nn.init.xavier_normal_(self.a_src.unsqueeze(0))
        nn.init.xavier_normal_(self.a_dst.unsqueeze(0))

    def forward(self, x, edge_index):
        """
        Args:
            x: (N, in_dim) node features
            edge_index: (2, E) integer tensor; column k lets node
                edge_index[0, k] attend to node edge_index[1, k]

        Returns:
            (N, out_dim) updated features (heads concatenated, pre-activation)
        """
        N = x.shape[0]
        device = x.device

        h = self.W(x).view(N, self.heads, self.head_dim)

        # Build adjacency (with self-loops); only used as an attention mask.
        A = torch.zeros(N, N, device=device)
        A[edge_index[0], edge_index[1]] = 1.0
        A = A + torch.eye(N, device=device)

        # Per-node, per-head attention logits
        e_src = (h * self.a_src.unsqueeze(0)).sum(-1)  # (N, heads)
        e_dst = (h * self.a_dst.unsqueeze(0)).sum(-1)  # (N, heads)

        # e_ij = LeakyReLU(e_src_i + e_dst_j)
        # (N, 1, heads) + (1, N, heads) broadcasts to (N, N, heads) with
        # axis 0 = attending node i and axis 1 = attended node j.
        attn = F.leaky_relu(e_src.unsqueeze(1) + e_dst.unsqueeze(0), 0.2)

        # Mask non-edges
        mask = (A == 0).unsqueeze(-1).expand_as(attn)
        attn = attn.masked_fill(mask, float('-inf'))

        # Normalize over the neighbour axis j so each node's weights sum to 1
        # within every head.
        attn = F.softmax(attn, dim=1)
        # A fully masked row would give NaN; self-loops normally prevent it.
        attn = torch.nan_to_num(attn, nan=0.0)

        # Aggregate neighbour features per head, then concatenate the heads.
        out = torch.einsum('ijh,jhd->ihd', attn, h)
        out = out.reshape(N, -1) + self.bias
        return out