| """ |
| Neural network layers for TopoHyper. |
| |
| Implements: |
| - SimplicialConv: Hodge Laplacian-based convolution on simplicial complexes |
| - HypergraphConv: Spectral convolution on hypergraphs |
| - CrossStructureAttention: Attention-gated fusion between simplicial and hypergraph views |
| - TopoHyperConv: Three-phase integrated convolution layer |
| - TopoHyperPool: Graph-level readout with both structures |
| - GCNLayer, GATLayer: Baselines |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import math |
|
|
|
|
class SimplicialConv(nn.Module):
    """
    Simplicial convolution built from the node-to-edge incidence matrix.

    Computes (no activation is applied here; callers apply their own):

        X' = D^{-1} (|B1| |B1|^T) X W + X W_self + b_self

    where |B1| is the element-wise absolute value of ``sc.B1`` and D is the
    diagonal matrix of row sums of |B1| |B1|^T.

    Note: assuming ``sc.B1`` is an (N, E) signed node-to-edge incidence
    matrix, the unsigned product |B1| |B1|^T is the *signless* Laplacian
    (degree + adjacency), not the signed 0-th Hodge Laplacian B1 B1^T.
    The unsigned form keeps every propagation weight non-negative, which is
    compatible with the non-negative hypergraph propagation.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.W = nn.Linear(in_dim, out_dim, bias=False)
        self.W_self = nn.Linear(in_dim, out_dim, bias=True)

    def forward(self, x, sc):
        """
        Args:
            x: (N, in_dim) node features
            sc: SimplicialComplex object exposing ``B1`` (presumably an
                (N, E) incidence matrix - confirm against the class)
        Returns:
            (N, out_dim) updated features (pre-activation)
        """
        # Cast to x's dtype as well as its device: an incidence matrix stored
        # as float64 (or an integer type) would otherwise make the matmul
        # against float32 features raise a dtype mismatch.
        B1_abs = sc.B1.to(device=x.device, dtype=x.dtype).abs()

        # Unsigned |B1| |B1|^T: all entries are non-negative.
        L0 = B1_abs @ B1_abs.t()

        # Row-normalize. The clamp keeps all-zero rows (isolated nodes) at 0
        # instead of producing NaNs from 0 / 0.
        D = L0.sum(dim=1, keepdim=True).clamp(min=1e-7)
        L0_norm = L0 / D

        h_neigh = L0_norm @ x
        return self.W(h_neigh) + self.W_self(x)
|
|
|
|
class HypergraphConv(nn.Module):
    """
    Hypergraph convolution using spectral propagation.

    Computes (no activation is applied here; callers apply their own):

        X' = P X Theta + b

    where P is the matrix returned by ``hg.propagation_matrix()``. It is
    presumably the HGNN operator D_v^{-1/2} H W D_e^{-1} H^T D_v^{-1/2};
    confirm this against the Hypergraph class.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.theta = nn.Linear(in_dim, out_dim, bias=True)

    def forward(self, x, hg):
        """
        Args:
            x: (N, in_dim) node features
            hg: Hypergraph object exposing ``propagation_matrix()``
                (expected to be (N, N))
        Returns:
            (N, out_dim) updated features (pre-activation)
        """
        # Match x's dtype as well as its device, so that a float64 (or
        # integer) propagation matrix does not break the matmul.
        P = hg.propagation_matrix().to(device=x.device, dtype=x.dtype)
        return self.theta(P @ x)
|
|
|
|
class CrossStructureAttention(nn.Module):
    """
    Per-node gated blend of two equally sized feature views.

    A small MLP reads the concatenation [h_sc || h_hg] and produces one
    scalar gate alpha_i in [0, 1] per node. The output is
    alpha_i * h_sc_i + (1 - alpha_i) * h_hg_i.
    """

    def __init__(self, dim):
        super().__init__()
        # The layer order is fixed so that parameter names (gate.0.*,
        # gate.2.*) match existing checkpoints.
        self.gate = nn.Sequential(
            nn.Linear(dim * 2, dim),
            nn.ReLU(),
            nn.Linear(dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, h_sc, h_hg):
        """
        Args:
            h_sc: (N, dim) simplicial-view features
            h_hg: (N, dim) hypergraph-view features
        Returns:
            (N, dim) gated combination of the two views
        """
        weight_sc = self.gate(torch.cat((h_sc, h_hg), dim=-1))  # (N, 1)
        weight_hg = 1 - weight_sc
        return weight_sc * h_sc + weight_hg * h_hg
|
|
|
|
class TopoHyperConv(nn.Module):
    """
    Three-phase TopoHyper convolution layer.

    Phase 1: SimplicialConv on the simplicial complex, followed by ReLU.
    Phase 2: HypergraphConv on the hypergraph, followed by ReLU.
    Phase 3: fuse the two views, either with a learned per-node gate
             (CrossStructureAttention) or a plain average. The fused result
             is optionally refined by bridge propagation and then passed
             through LayerNorm.

    The bridge matrix is the element-wise (Hadamard) product
    B = A_sc * A_hg. It is nonzero only for node pairs that are adjacent in
    BOTH views. The fused features are propagated over row-normalized B and
    added back through ``bridge_proj``.
    """

    def __init__(self, in_dim, out_dim, use_bridge=True, use_attention=True):
        super().__init__()
        self.use_bridge = use_bridge
        self.use_attention = use_attention

        self.sc_conv = SimplicialConv(in_dim, out_dim)
        self.hg_conv = HypergraphConv(in_dim, out_dim)

        if use_attention:
            self.attention = CrossStructureAttention(out_dim)

        if use_bridge:
            self.bridge_proj = nn.Linear(out_dim, out_dim, bias=False)

        self.norm = nn.LayerNorm(out_dim)

    def forward(self, x, sc, hg):
        """
        Args:
            x: (N, in_dim) node features
            sc: SimplicialComplex (fed to SimplicialConv; its
                ``adjacency_matrix()`` is used for the bridge)
            hg: Hypergraph (fed to HypergraphConv; its
                ``adjacency_matrix()`` is used for the bridge)
        Returns:
            (N, out_dim) layer-normalized fused features
        """
        # Phases 1 and 2: independent per-view convolutions.
        h_sc = F.relu(self.sc_conv(x, sc))
        h_hg = F.relu(self.hg_conv(x, hg))

        # Phase 3a: fuse the two views.
        if self.use_attention:
            h_fused = self.attention(h_sc, h_hg)
        else:
            h_fused = (h_sc + h_hg) / 2.0

        # Phase 3b: propagate over pairs adjacent in both views.
        if self.use_bridge:
            # Match the features' dtype as well as their device. Float64,
            # integer or bool adjacency matrices would otherwise break the
            # matmul against the fused features.
            A_sc = sc.adjacency_matrix().to(device=x.device, dtype=h_fused.dtype)
            A_hg = hg.adjacency_matrix().to(device=x.device, dtype=h_fused.dtype)
            bridge = A_sc * A_hg  # element-wise: nonzero only where both are

            # Row-normalize. The clamp leaves nodes with no bridge neighbors
            # at zero instead of producing NaNs.
            D_bridge = bridge.sum(dim=1, keepdim=True).clamp(min=1e-7)
            bridge_norm = bridge / D_bridge

            h_fused = h_fused + self.bridge_proj(bridge_norm @ h_fused)

        return self.norm(h_fused)
|
|
|
|
class TopoHyperPool(nn.Module):
    """
    Graph-level readout: concatenation of mean, max and attention pooling.
    """

    def __init__(self, dim):
        super().__init__()
        # Produces one attention logit per node.
        self.att = nn.Linear(dim, 1)

    def forward(self, x):
        """
        Args:
            x: (N, dim) node features
        Returns:
            (3*dim,) vector laid out as [mean | max | attention-weighted sum]
        """
        scores = self.att(x)                    # (N, 1)
        weights = torch.softmax(scores, dim=0)  # normalized across the N nodes
        pooled = (
            torch.mean(x, dim=0),
            torch.max(x, dim=0).values,
            torch.sum(weights * x, dim=0),
        )
        return torch.cat(pooled)
|
|
|
|
| |
|
|
class GCNLayer(nn.Module):
    """
    Standard GCN layer (dense implementation).

    Computes X' = D^{-1/2} (A + I) D^{-1/2} X W + b. Here A is the 0/1
    adjacency built from ``edge_index`` and D is the degree matrix of A + I.
    Self-loops already present in ``edge_index`` are not double counted:
    every diagonal entry of A + I is exactly 1.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.W = nn.Linear(in_dim, out_dim, bias=True)

    def forward(self, x, edge_index):
        """
        Args:
            x: (N, in_dim) node features
            edge_index: (2, E) index tensor; each column (u, v) sets A[u, v] = 1
        Returns:
            (N, out_dim) updated features
        """
        N = x.shape[0]

        # Build A in x's dtype so float64 inputs work as well.
        A = torch.zeros(N, N, device=x.device, dtype=x.dtype)
        A[edge_index[0], edge_index[1]] = 1.0
        # Set self-loops rather than adding I, so that an explicit (i, i)
        # edge does not give node i a self-weight of 2.
        A.fill_diagonal_(1.0)

        # Symmetric normalization by broadcasting, instead of two dense
        # N x N matmuls with a diagonal matrix (O(N^2) rather than O(N^3)).
        # Every degree is >= 1 because of the self-loops, so rsqrt is safe.
        deg_inv_sqrt = A.sum(dim=1).rsqrt()
        A_norm = deg_inv_sqrt.unsqueeze(1) * A * deg_inv_sqrt.unsqueeze(0)

        return self.W(A_norm @ x)
|
|
|
|
class GATLayer(nn.Module):
    """
    Graph Attention Network layer (dense, multi-head, heads concatenated).

    For node i and head k:
        out_i^k = sum_j softmax_j(LeakyReLU(a_src^k . h_i^k + a_dst^k . h_j^k)) h_j^k
    where h = x W, split into ``heads`` chunks of ``out_dim // heads``. The
    index j ranges over nodes with adj[i, j] set by ``edge_index`` plus
    self-loops. Head outputs are concatenated and a bias is added.
    """

    def __init__(self, in_dim, out_dim, heads=4):
        super().__init__()
        assert out_dim % heads == 0, "out_dim must be divisible by heads"
        self.heads = heads
        self.head_dim = out_dim // heads

        self.W = nn.Linear(in_dim, out_dim, bias=False)
        self.a_src = nn.Parameter(torch.randn(heads, self.head_dim))
        self.a_dst = nn.Parameter(torch.randn(heads, self.head_dim))
        self.bias = nn.Parameter(torch.zeros(out_dim))

        nn.init.xavier_uniform_(self.W.weight)
        # The unsqueezed views share storage with the parameters, so the
        # in-place init writes through to a_src / a_dst.
        nn.init.xavier_normal_(self.a_src.unsqueeze(0))
        nn.init.xavier_normal_(self.a_dst.unsqueeze(0))

    def forward(self, x, edge_index):
        """
        Args:
            x: (N, in_dim) node features
            edge_index: (2, E) index tensor; each column (u, v) lets u attend to v
        Returns:
            (N, heads * head_dim) updated features
        """
        N = x.shape[0]
        h = self.W(x).view(N, self.heads, self.head_dim)

        # Boolean adjacency with self-loops: adj[i, j] is True when node i
        # attends to node j. The self-loops guarantee that no row is empty.
        adj = torch.zeros(N, N, dtype=torch.bool, device=x.device)
        adj[edge_index[0], edge_index[1]] = True
        adj.fill_diagonal_(True)

        # Per-node, per-head attention terms, each (N, heads).
        e_src = (h * self.a_src.unsqueeze(0)).sum(-1)
        e_dst = (h * self.a_dst.unsqueeze(0)).sum(-1)

        # attn[i, j, k]: raw score of node i attending to node j under head
        # k. Shape (N, N, heads); the original broadcast produced
        # (N, heads, heads).
        attn = F.leaky_relu(e_src.unsqueeze(1) + e_dst.unsqueeze(0), 0.2)

        mask = (~adj).unsqueeze(-1).expand_as(attn)
        attn = attn.masked_fill(mask, float('-inf'))
        # Normalize over neighbors j (dim=1), not over heads.
        attn = F.softmax(attn, dim=1)
        # Defensive only: self-loops mean no row is fully masked.
        attn = torch.nan_to_num(attn, nan=0.0)

        out = torch.einsum('ijh,jhd->ihd', attn, h)
        return out.reshape(N, -1) + self.bias
|
|