chaoshengt's picture
Add models module: TopoHyperNet + 5 baselines
8949695 verified
raw
history blame contribute delete
8.2 kB
"""
Model architectures for TopoHyper and baselines.
Implements:
- TopoHyperNet: Full hybrid architecture
- GCNNet, GATNet: Standard GNN baselines
- HGNNNet: Hypergraph-only baseline
- SimplicialNet: Simplicial-only baseline
- SimpleHybrid: Naive concatenation hybrid
- get_model(): Factory function
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from .layers import (TopoHyperConv, TopoHyperPool, SimplicialConv,
HypergraphConv, GCNLayer, GATLayer)
class TopoHyperNet(nn.Module):
    """
    Full TopoHyper architecture.

    Pipeline:
    1. Linear input projection to ``hidden_dim``, then ReLU and dropout.
    2. ``num_layers`` stacked TopoHyperConv layers (3-phase: simplicial +
       hypergraph + fusion), each followed by dropout.
    3. TopoHyperPool readout (mean + max + attention).
    4. Two-layer MLP classification head.

    Args:
        in_dim: Input feature dimension
        hidden_dim: Hidden layer dimension
        num_classes: Number of output classes
        num_layers: Number of TopoHyperConv layers
        dropout: Dropout rate, used by the functional dropouts in forward()
            and by the classifier head
        use_bridge: Whether to use bridge matrix in fusion
        use_attention: Whether to use attention-gated fusion
    """

    def __init__(self, in_dim, hidden_dim, num_classes, num_layers=2,
                 dropout=0.3, use_bridge=True, use_attention=True):
        super().__init__()
        self.input_proj = nn.Linear(in_dim, hidden_dim)
        # Every conv maps hidden_dim -> hidden_dim. Submodules are built in the
        # same order as before (proj, convs, pool, head) so parameter
        # initialisation draws from the RNG in the same sequence.
        self.convs = nn.ModuleList([
            TopoHyperConv(hidden_dim, hidden_dim,
                          use_bridge=use_bridge, use_attention=use_attention)
            for _ in range(num_layers)
        ])
        self.pool = TopoHyperPool(hidden_dim)
        # The head expects the pooled vector to be 3 * hidden_dim wide
        # (one slot per readout: mean, max, attention).
        pooled_dim = hidden_dim * 3
        self.classifier = nn.Sequential(
            nn.Linear(pooled_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes),
        )
        self.dropout = dropout

    def forward(self, x, sc, hg, edge_index=None):
        """Return class logits computed from node features ``x``.

        ``sc`` and ``hg`` are passed to every TopoHyperConv layer.
        ``edge_index`` is accepted only so that all models in this module share
        one call signature; it is not used here.
        """
        rate = self.dropout
        h = F.dropout(F.relu(self.input_proj(x)), p=rate, training=self.training)
        for conv in self.convs:
            h = F.dropout(conv(h, sc, hg), p=rate, training=self.training)
        return self.classifier(self.pool(h))
class GCNNet(nn.Module):
    """GCN baseline.

    Stacks GCNLayer modules: the first maps ``in_dim`` -> ``hidden_dim`` and
    the remaining ``num_layers - 1`` map ``hidden_dim`` -> ``hidden_dim``.
    A ``num_layers`` below 1 still yields a single layer.

    Node embeddings are mean-pooled over dim 0 and classified by a two-layer
    MLP.
    """

    def __init__(self, in_dim, hidden_dim, num_classes, num_layers=2, dropout=0.3):
        super().__init__()
        # Input width of each layer, in construction order.
        widths_in = [in_dim] + [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList([GCNLayer(w, hidden_dim) for w in widths_in])
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes),
        )
        self.dropout = dropout

    def forward(self, x, sc=None, hg=None, edge_index=None):
        """Return class logits; only ``edge_index`` is used (``sc``/``hg`` are ignored)."""
        node_h = x
        for gcn in self.layers:
            node_h = F.dropout(F.relu(gcn(node_h, edge_index)),
                               p=self.dropout, training=self.training)
        pooled = node_h.mean(dim=0)
        return self.classifier(pooled)
class GATNet(nn.Module):
    """GAT baseline.

    Stacks GATLayer modules: the first maps ``in_dim`` -> ``hidden_dim`` and
    the remaining ``num_layers - 1`` map ``hidden_dim`` -> ``hidden_dim``.
    A ``num_layers`` below 1 still yields a single layer.

    Node embeddings are mean-pooled over dim 0 and classified by a two-layer
    MLP.
    """

    def __init__(self, in_dim, hidden_dim, num_classes, num_layers=2, dropout=0.3):
        super().__init__()
        first = GATLayer(in_dim, hidden_dim)
        rest = [GATLayer(hidden_dim, hidden_dim) for _ in range(num_layers - 1)]
        self.layers = nn.ModuleList([first, *rest])
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes),
        )
        self.dropout = dropout

    def forward(self, x, sc=None, hg=None, edge_index=None):
        """Return class logits; only ``edge_index`` is used (``sc``/``hg`` are ignored)."""
        feats = x
        for attn_layer in self.layers:
            activated = F.relu(attn_layer(feats, edge_index))
            feats = F.dropout(activated, p=self.dropout, training=self.training)
        return self.classifier(feats.mean(dim=0))
class HGNNNet(nn.Module):
    """Hypergraph-only baseline.

    Stacks HypergraphConv modules: the first maps ``in_dim`` -> ``hidden_dim``
    and the remaining ``num_layers - 1`` map ``hidden_dim`` -> ``hidden_dim``.
    A ``num_layers`` below 1 still yields a single layer.

    Node embeddings are mean-pooled over dim 0 and classified by a two-layer
    MLP.
    """

    def __init__(self, in_dim, hidden_dim, num_classes, num_layers=2, dropout=0.3):
        super().__init__()
        convs = [HypergraphConv(in_dim, hidden_dim)]
        convs.extend(HypergraphConv(hidden_dim, hidden_dim)
                     for _ in range(num_layers - 1))
        self.layers = nn.ModuleList(convs)
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes),
        )
        self.dropout = dropout

    def forward(self, x, sc=None, hg=None, edge_index=None):
        """Return class logits; only ``hg`` is used (``sc``/``edge_index`` are ignored)."""
        out = x
        for hconv in self.layers:
            out = F.relu(hconv(out, hg))
            out = F.dropout(out, p=self.dropout, training=self.training)
        graph_vec = out.mean(dim=0)
        return self.classifier(graph_vec)
class SimplicialNet(nn.Module):
    """Simplicial-only baseline.

    Stacks SimplicialConv modules: the first maps ``in_dim`` -> ``hidden_dim``
    and the remaining ``num_layers - 1`` map ``hidden_dim`` -> ``hidden_dim``.
    A ``num_layers`` below 1 still yields a single layer.

    Node embeddings are mean-pooled over dim 0 and classified by a two-layer
    MLP.
    """

    def __init__(self, in_dim, hidden_dim, num_classes, num_layers=2, dropout=0.3):
        super().__init__()
        self.layers = nn.ModuleList()
        in_width = in_dim
        for _ in range(max(num_layers, 1)):
            self.layers.append(SimplicialConv(in_width, hidden_dim))
            in_width = hidden_dim  # every layer after the first is square
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes),
        )
        self.dropout = dropout

    def forward(self, x, sc=None, hg=None, edge_index=None):
        """Return class logits; only ``sc`` is used (``hg``/``edge_index`` are ignored)."""
        emb = x
        for sconv in self.layers:
            emb = F.dropout(F.relu(sconv(emb, sc)),
                            p=self.dropout, training=self.training)
        return self.classifier(emb.mean(dim=0))
class SimpleHybrid(nn.Module):
    """
    Naive hybrid baseline: separate simplicial + hypergraph branches
    concatenated (no bridge, no attention).

    Channel widths:
    - The simplicial branch gets ``hidden_dim // 2`` channels.
    - The hypergraph branch gets the remaining ``hidden_dim - hidden_dim // 2``
      channels.

    The concatenated node embedding is therefore exactly ``hidden_dim`` wide,
    matching the classifier input, for both even and odd ``hidden_dim``.
    Both branches have the same width when ``hidden_dim`` is even.

    The concatenation is mean-pooled over dim 0 and classified by a two-layer
    MLP.

    Args:
        in_dim: Input feature dimension
        hidden_dim: Total width of the concatenated branch outputs
        num_classes: Number of output classes
        num_layers: Number of conv layers per branch (values below 1 still
            build one layer per branch)
        dropout: Dropout rate

    Raises:
        ValueError: if ``hidden_dim < 2``, since one branch would have zero
            channels.
    """

    def __init__(self, in_dim, hidden_dim, num_classes, num_layers=2, dropout=0.3):
        super().__init__()
        if hidden_dim < 2:
            raise ValueError(
                f"SimpleHybrid needs hidden_dim >= 2 to split it across two "
                f"branches, got {hidden_dim}"
            )
        sc_dim = hidden_dim // 2
        # The hypergraph branch absorbs the odd channel so that
        # sc_dim + hg_dim == hidden_dim.
        hg_dim = hidden_dim - sc_dim
        self.sc_layers = nn.ModuleList()
        self.sc_layers.append(SimplicialConv(in_dim, sc_dim))
        for _ in range(num_layers - 1):
            self.sc_layers.append(SimplicialConv(sc_dim, sc_dim))
        self.hg_layers = nn.ModuleList()
        self.hg_layers.append(HypergraphConv(in_dim, hg_dim))
        for _ in range(num_layers - 1):
            self.hg_layers.append(HypergraphConv(hg_dim, hg_dim))
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes),
        )
        self.dropout = dropout

    def _run_branch(self, h, layers, structure):
        """Apply one branch's conv stack: conv -> ReLU -> dropout per layer."""
        for layer in layers:
            h = F.relu(layer(h, structure))
            h = F.dropout(h, p=self.dropout, training=self.training)
        return h

    def forward(self, x, sc=None, hg=None, edge_index=None):
        """Return class logits.

        ``sc`` feeds the simplicial branch and ``hg`` feeds the hypergraph
        branch; ``edge_index`` is ignored.
        """
        # The simplicial branch runs first, as before, so dropout RNG draws
        # happen in the same order.
        h_sc = self._run_branch(x, self.sc_layers, sc)
        h_hg = self._run_branch(x, self.hg_layers, hg)
        h = torch.cat([h_sc, h_hg], dim=-1)
        g = h.mean(dim=0)
        return self.classifier(g)
def get_model(name, in_dim, hidden_dim, num_classes, **kwargs):
    """
    Build a model instance from its registry name.

    Args:
        name: One of 'topohyper', 'gcn', 'gat', 'hgnn', 'simplicial',
            'simple_hybrid'
        in_dim: Input feature dimension
        hidden_dim: Hidden dimension
        num_classes: Number of classes
        **kwargs: Forwarded unchanged to the model constructor. Every model
            accepts ``num_layers`` and ``dropout``. Only 'topohyper' also
            accepts ``use_bridge`` and ``use_attention``; the other
            constructors have no such parameters and raise TypeError if given
            them.

    Returns:
        The constructed model (an ``nn.Module`` subclass instance).

    Raises:
        ValueError: if ``name`` is not one of the registered names.
    """
    registry = dict(
        topohyper=TopoHyperNet,
        gcn=GCNNet,
        gat=GATNet,
        hgnn=HGNNNet,
        simplicial=SimplicialNet,
        simple_hybrid=SimpleHybrid,
    )
    model_cls = registry.get(name)
    if model_cls is None:
        raise ValueError(f"Unknown model: {name}. Choose from {list(registry)}")
    return model_cls(in_dim, hidden_dim, num_classes, **kwargs)