""" Model architectures for TopoHyper and baselines. Implements: - TopoHyperNet: Full hybrid architecture - GCNNet, GATNet: Standard GNN baselines - HGNNNet: Hypergraph-only baseline - SimplicialNet: Simplicial-only baseline - SimpleHybrid: Naive concatenation hybrid - get_model(): Factory function """ import torch import torch.nn as nn import torch.nn.functional as F from .layers import (TopoHyperConv, TopoHyperPool, SimplicialConv, HypergraphConv, GCNLayer, GATLayer) class TopoHyperNet(nn.Module): """ Full TopoHyper architecture. Architecture: 1. Input projection 2. N x TopoHyperConv layers (3-phase: simplicial + hypergraph + fusion) 3. TopoHyperPool (mean + max + attention readout) 4. Classification head Args: in_dim: Input feature dimension hidden_dim: Hidden layer dimension num_classes: Number of output classes num_layers: Number of TopoHyperConv layers dropout: Dropout rate use_bridge: Whether to use bridge matrix in fusion use_attention: Whether to use attention-gated fusion """ def __init__(self, in_dim, hidden_dim, num_classes, num_layers=2, dropout=0.3, use_bridge=True, use_attention=True): super().__init__() self.input_proj = nn.Linear(in_dim, hidden_dim) self.convs = nn.ModuleList() for i in range(num_layers): self.convs.append(TopoHyperConv( hidden_dim, hidden_dim, use_bridge=use_bridge, use_attention=use_attention )) self.pool = TopoHyperPool(hidden_dim) self.classifier = nn.Sequential( nn.Linear(hidden_dim * 3, hidden_dim), nn.ReLU(), nn.Dropout(dropout), nn.Linear(hidden_dim, num_classes) ) self.dropout = dropout def forward(self, x, sc, hg, edge_index=None): x = F.relu(self.input_proj(x)) x = F.dropout(x, p=self.dropout, training=self.training) for conv in self.convs: x = conv(x, sc, hg) x = F.dropout(x, p=self.dropout, training=self.training) g = self.pool(x) return self.classifier(g) class GCNNet(nn.Module): """GCN baseline.""" def __init__(self, in_dim, hidden_dim, num_classes, num_layers=2, dropout=0.3): super().__init__() self.layers = nn.ModuleList() self.layers.append(GCNLayer(in_dim, hidden_dim)) for _ in range(num_layers - 1): self.layers.append(GCNLayer(hidden_dim, hidden_dim)) self.classifier = nn.Sequential( nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Dropout(dropout), nn.Linear(hidden_dim, num_classes) ) self.dropout = dropout def forward(self, x, sc=None, hg=None, edge_index=None): for layer in self.layers: x = F.relu(layer(x, edge_index)) x = F.dropout(x, p=self.dropout, training=self.training) g = x.mean(dim=0) return self.classifier(g) class GATNet(nn.Module): """GAT baseline.""" def __init__(self, in_dim, hidden_dim, num_classes, num_layers=2, dropout=0.3): super().__init__() self.layers = nn.ModuleList() self.layers.append(GATLayer(in_dim, hidden_dim)) for _ in range(num_layers - 1): self.layers.append(GATLayer(hidden_dim, hidden_dim)) self.classifier = nn.Sequential( nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Dropout(dropout), nn.Linear(hidden_dim, num_classes) ) self.dropout = dropout def forward(self, x, sc=None, hg=None, edge_index=None): for layer in self.layers: x = F.relu(layer(x, edge_index)) x = F.dropout(x, p=self.dropout, training=self.training) g = x.mean(dim=0) return self.classifier(g) class HGNNNet(nn.Module): """Hypergraph-only baseline.""" def __init__(self, in_dim, hidden_dim, num_classes, num_layers=2, dropout=0.3): super().__init__() self.layers = nn.ModuleList() self.layers.append(HypergraphConv(in_dim, hidden_dim)) for _ in range(num_layers - 1): self.layers.append(HypergraphConv(hidden_dim, hidden_dim)) self.classifier = nn.Sequential( nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Dropout(dropout), nn.Linear(hidden_dim, num_classes) ) self.dropout = dropout def forward(self, x, sc=None, hg=None, edge_index=None): for layer in self.layers: x = F.relu(layer(x, hg)) x = F.dropout(x, p=self.dropout, training=self.training) g = x.mean(dim=0) return self.classifier(g) class SimplicialNet(nn.Module): """Simplicial-only baseline.""" def __init__(self, in_dim, hidden_dim, num_classes, num_layers=2, dropout=0.3): super().__init__() self.layers = nn.ModuleList() self.layers.append(SimplicialConv(in_dim, hidden_dim)) for _ in range(num_layers - 1): self.layers.append(SimplicialConv(hidden_dim, hidden_dim)) self.classifier = nn.Sequential( nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Dropout(dropout), nn.Linear(hidden_dim, num_classes) ) self.dropout = dropout def forward(self, x, sc=None, hg=None, edge_index=None): for layer in self.layers: x = F.relu(layer(x, sc)) x = F.dropout(x, p=self.dropout, training=self.training) g = x.mean(dim=0) return self.classifier(g) class SimpleHybrid(nn.Module): """ Naive hybrid baseline: separate simplicial + hypergraph branches concatenated (no bridge, no attention). """ def __init__(self, in_dim, hidden_dim, num_classes, num_layers=2, dropout=0.3): super().__init__() half_dim = hidden_dim // 2 self.sc_layers = nn.ModuleList() self.sc_layers.append(SimplicialConv(in_dim, half_dim)) for _ in range(num_layers - 1): self.sc_layers.append(SimplicialConv(half_dim, half_dim)) self.hg_layers = nn.ModuleList() self.hg_layers.append(HypergraphConv(in_dim, half_dim)) for _ in range(num_layers - 1): self.hg_layers.append(HypergraphConv(half_dim, half_dim)) self.classifier = nn.Sequential( nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Dropout(dropout), nn.Linear(hidden_dim, num_classes) ) self.dropout = dropout def forward(self, x, sc=None, hg=None, edge_index=None): h_sc = x for layer in self.sc_layers: h_sc = F.relu(layer(h_sc, sc)) h_sc = F.dropout(h_sc, p=self.dropout, training=self.training) h_hg = x for layer in self.hg_layers: h_hg = F.relu(layer(h_hg, hg)) h_hg = F.dropout(h_hg, p=self.dropout, training=self.training) h = torch.cat([h_sc, h_hg], dim=-1) g = h.mean(dim=0) return self.classifier(g) def get_model(name, in_dim, hidden_dim, num_classes, **kwargs): """ Factory function to create models by name. Args: name: One of 'topohyper', 'gcn', 'gat', 'hgnn', 'simplicial', 'simple_hybrid' in_dim: Input feature dimension hidden_dim: Hidden dimension num_classes: Number of classes **kwargs: Additional model arguments (use_bridge, use_attention, etc.) """ models = { 'topohyper': TopoHyperNet, 'gcn': GCNNet, 'gat': GATNet, 'hgnn': HGNNNet, 'simplicial': SimplicialNet, 'simple_hybrid': SimpleHybrid, } if name not in models: raise ValueError(f"Unknown model: {name}. Choose from {list(models.keys())}") return models[name](in_dim, hidden_dim, num_classes, **kwargs)