chaoshengt's picture
Add layers module: SimplicialConv, HypergraphConv, TopoHyperConv, baselines
ad0507d verified
Raw
History Blame Contribute Delete
8.39 kB
"""
Neural network layers for TopoHyper.
Implements:
- SimplicialConv: Hodge Laplacian-based convolution on simplicial complexes
- HypergraphConv: Spectral convolution on hypergraphs
- CrossStructureAttention: Attention-gated fusion between simplicial and hypergraph views
- TopoHyperConv: Three-phase integrated convolution layer
- TopoHyperPool: Graph-level readout (mean, max, and attention pooling over node features)
- GCNLayer, GATLayer: Baselines
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class SimplicialConv(nn.Module):
    """
    Simplicial convolution built on the unsigned node-level Hodge operator.

    Computes ``X' = P X W + X W_self`` with ``P = D^{-1} |B1| |B1|^T``, where
    ``|B1|`` is the element-wise absolute value of ``sc.B1`` and ``D`` holds
    the row sums of ``|B1| |B1|^T``. No nonlinearity is applied here; callers
    apply their own activation.

    The unsigned boundary is used (rather than the signed Hodge Laplacian
    ``L_0 = B1 B1^T``) to stay in the same non-negative space as HGNN. For a
    standard {-1, 0, +1} node-to-edge boundary matrix, ``|B1| |B1|^T`` is the
    signless Laplacian ``D + A``, so each row of ``P`` weights the node itself
    by 1/2 and each neighbour by 1/(2 * degree). Nodes with no incident edges
    receive a zero neighbour term.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.W = nn.Linear(in_dim, out_dim, bias=False)
        self.W_self = nn.Linear(in_dim, out_dim, bias=True)

    def forward(self, x, sc):
        """
        Args:
            x: (N, in_dim) node features
            sc: SimplicialComplex object exposing a node-to-edge boundary
                matrix ``B1`` with N rows (presumably dense; TODO confirm)
        Returns:
            (N, out_dim) updated features (pre-activation)
        """
        # Match x's device AND dtype so integer / float64 boundary matrices work.
        B1_abs = sc.B1.abs().to(device=x.device, dtype=x.dtype)
        # Associate as |B1| (|B1|^T x) so the dense (N, N) operator is never
        # materialised; mathematically identical to (|B1| |B1|^T) x.
        h_neigh = B1_abs @ (B1_abs.t() @ x)
        # Row sums of |B1| |B1|^T equal |B1| times the column sums of |B1|.
        deg = (B1_abs @ B1_abs.sum(dim=0)).unsqueeze(1).clamp(min=1e-7)
        h_neigh = h_neigh / deg
        return self.W(h_neigh) + self.W_self(x)
class HypergraphConv(nn.Module):
    """
    Hypergraph convolution over a precomputed propagation operator.

    The operator comes from ``hg.propagation_matrix()`` and is expected to be
    the HGNN operator ``D_v^{-1/2} H W D_e^{-1} H^T D_v^{-1/2}`` (this layer
    does not check that). The output is ``theta(P @ X)``, i.e. an affine map
    of the propagated features. No activation is applied here.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.theta = nn.Linear(in_dim, out_dim, bias=True)

    def forward(self, x, hg):
        """
        Args:
            x: (N, in_dim) node features
            hg: Hypergraph object providing ``propagation_matrix()``
        Returns:
            (N, out_dim) updated features (pre-activation)
        """
        operator = hg.propagation_matrix().to(x.device)
        propagated = operator @ x
        return self.theta(propagated)
class CrossStructureAttention(nn.Module):
    """
    Per-node gated fusion of the simplicial and hypergraph views.

    A small MLP on the concatenated views yields a sigmoid gate
    ``alpha_i`` in (0, 1) per node; the output is the convex combination
    ``alpha_i * h_sc_i + (1 - alpha_i) * h_hg_i``.
    """

    def __init__(self, dim):
        super().__init__()
        hidden = dim
        self.gate = nn.Sequential(
            nn.Linear(2 * dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
            nn.Sigmoid(),
        )

    def forward(self, h_sc, h_hg):
        """
        Args:
            h_sc: (N, dim) simplicial-view features
            h_hg: (N, dim) hypergraph-view features
        Returns:
            (N, dim) fused features
        """
        weight = self.gate(torch.cat((h_sc, h_hg), dim=-1))  # (N, 1), broadcast over dim
        return weight * h_sc + (1 - weight) * h_hg
class TopoHyperConv(nn.Module):
    """
    Three-phase TopoHyper convolution layer.

    Phase 1: ReLU(SimplicialConv(x, sc)).
    Phase 2: ReLU(HypergraphConv(x, hg)).
    Phase 3: fuse the two views with a learned CrossStructureAttention gate
    (or a plain average when ``use_attention=False``). When
    ``use_bridge=True``, add a projected message propagated over the bridge
    matrix ``A_sc * A_hg`` (element-wise product, row-normalised), i.e. over
    node pairs connected in both views.
    The result is layer-normalised.
    """

    def __init__(self, in_dim, out_dim, use_bridge=True, use_attention=True):
        super().__init__()
        self.use_bridge = use_bridge
        self.use_attention = use_attention
        self.sc_conv = SimplicialConv(in_dim, out_dim)
        self.hg_conv = HypergraphConv(in_dim, out_dim)
        if use_attention:
            self.attention = CrossStructureAttention(out_dim)
        if use_bridge:
            self.bridge_proj = nn.Linear(out_dim, out_dim, bias=False)
        self.norm = nn.LayerNorm(out_dim)

    def _fuse(self, h_sc, h_hg):
        """Combine both views: learned gate if enabled, else an equal average."""
        if not self.use_attention:
            return (h_sc + h_hg) / 2.0
        return self.attention(h_sc, h_hg)

    @staticmethod
    def _bridge_messages(h, sc, hg, device):
        """Row-normalised propagation of ``h`` over edges present in both views."""
        shared = sc.adjacency_matrix().to(device) * hg.adjacency_matrix().to(device)
        # Clamp keeps rows without shared neighbours at zero instead of NaN.
        row_mass = shared.sum(dim=1, keepdim=True).clamp(min=1e-7)
        return (shared / row_mass) @ h

    def forward(self, x, sc, hg):
        """
        Args:
            x: (N, in_dim) node features
            sc: SimplicialComplex
            hg: Hypergraph
        Returns:
            (N, out_dim) updated features
        """
        simplicial = F.relu(self.sc_conv(x, sc))
        hyper = F.relu(self.hg_conv(x, hg))
        fused = self._fuse(simplicial, hyper)
        if self.use_bridge:
            bridged = self._bridge_messages(fused, sc, hg, x.device)
            fused = fused + self.bridge_proj(bridged)
        return self.norm(fused)
class TopoHyperPool(nn.Module):
    """
    Graph-level readout: concatenation of mean, max and attention pooling.

    The attention weights are a softmax over nodes of a learned scalar score
    per node.
    """

    def __init__(self, dim):
        super().__init__()
        self.att = nn.Linear(dim, 1)

    def forward(self, x):
        """
        Args:
            x: (N, dim) node features
        Returns:
            (3*dim,) vector ``[mean || max || attention-weighted sum]``
        """
        node_weights = F.softmax(self.att(x), dim=0)  # (N, 1), sums to 1 over nodes
        readouts = (
            x.mean(dim=0),
            x.max(dim=0).values,
            (node_weights * x).sum(dim=0),
        )
        return torch.cat(readouts)
# ==================== Baseline Layers ====================
class GCNLayer(nn.Module):
    """
    Standard GCN layer: X' = D^{-1/2} (A + I) D^{-1/2} X W + b.

    A is a dense 0/1 adjacency built from ``edge_index`` with
    ``A[edge_index[0], edge_index[1]] = 1``; row i of the result aggregates
    over columns j with ``A[i, j] = 1`` (supply both directions for an
    undirected graph). The diagonal is set to exactly 1, so self-loops already
    present in ``edge_index`` are not double-counted. D is the row-sum degree
    of the self-looped matrix.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.W = nn.Linear(in_dim, out_dim, bias=True)

    def forward(self, x, edge_index):
        """
        Args:
            x: (N, in_dim) node features
            edge_index: integer index tensor whose rows 0 and 1 hold the
                source and target node of each edge
        Returns:
            (N, out_dim) updated features
        """
        N = x.shape[0]
        # Build in x's dtype so float64 (or half) inputs work with the matmul.
        A = torch.zeros(N, N, device=x.device, dtype=x.dtype)
        A[edge_index[0], edge_index[1]] = 1.0
        # Set (not add) self-loops: an explicit i->i edge must not yield weight 2.
        A.fill_diagonal_(1.0)
        deg_inv_sqrt = 1.0 / (A.sum(dim=1).sqrt() + 1e-7)
        # Broadcasting equals diag(d) @ A @ diag(d) without two dense O(N^3) matmuls.
        A_norm = deg_inv_sqrt.unsqueeze(1) * A * deg_inv_sqrt.unsqueeze(0)
        return self.W(A_norm @ x)
class GATLayer(nn.Module):
    """
    Multi-head Graph Attention Network layer (dense implementation).

    A boolean adjacency is built from ``edge_index`` with
    ``A[edge_index[0], edge_index[1]] = True`` plus self-loops. For each head,
    node i attends over every j with ``A[i, j]``:

        e_ij     = LeakyReLU_0.2(a_src . (W x)_i + a_dst . (W x)_j)
        alpha_ij = softmax over j of e_ij
        x'_i     = concat_heads(sum_j alpha_ij (W x)_j) + bias
    """

    def __init__(self, in_dim, out_dim, heads=4):
        super().__init__()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if out_dim % heads != 0:
            raise ValueError(
                f"out_dim ({out_dim}) must be divisible by heads ({heads})"
            )
        self.heads = heads
        self.head_dim = out_dim // heads
        self.W = nn.Linear(in_dim, out_dim, bias=False)
        self.a_src = nn.Parameter(torch.randn(heads, self.head_dim))
        self.a_dst = nn.Parameter(torch.randn(heads, self.head_dim))
        self.bias = nn.Parameter(torch.zeros(out_dim))
        nn.init.xavier_uniform_(self.W.weight)
        # unsqueeze(0) returns a view, so the in-place init writes through to the
        # parameter; the extra leading dim only changes xavier's fan computation.
        nn.init.xavier_normal_(self.a_src.unsqueeze(0))
        nn.init.xavier_normal_(self.a_dst.unsqueeze(0))

    def forward(self, x, edge_index):
        """
        Args:
            x: (N, in_dim) node features
            edge_index: integer index tensor whose rows 0 and 1 hold the
                source and target node of each edge
        Returns:
            (N, heads * head_dim) updated features
        """
        N = x.shape[0]
        h = self.W(x).view(N, self.heads, self.head_dim)  # (N, H, D)

        adj = torch.zeros(N, N, dtype=torch.bool, device=x.device)
        adj[edge_index[0], edge_index[1]] = True
        adj |= torch.eye(N, dtype=torch.bool, device=x.device)

        e_src = (h * self.a_src.unsqueeze(0)).sum(-1)  # (N, H): term for node i
        e_dst = (h * self.a_dst.unsqueeze(0)).sum(-1)  # (N, H): term for node j
        # Pairwise logits: (N, 1, H) + (1, N, H) -> (N, N, H), indexed [i, j, head].
        attn = F.leaky_relu(e_src.unsqueeze(1) + e_dst.unsqueeze(0), 0.2)
        # Mask non-edges; the (N, N, 1) mask broadcasts over heads.
        attn = attn.masked_fill(~adj.unsqueeze(-1), float('-inf'))
        # Normalise over neighbours j (dim=1), independently for every head.
        attn = F.softmax(attn, dim=1)
        # Defensive: self-loops guarantee a finite entry per row, but keep NaN guard.
        attn = torch.nan_to_num(attn, nan=0.0)

        out = torch.einsum('ijh,jhd->ihd', attn, h)
        return out.reshape(N, -1) + self.bias