| """ |
| Model architectures for TopoHyper and baselines. |
| |
| Implements: |
| - TopoHyperNet: Full hybrid architecture |
| - GCNNet, GATNet: Standard GNN baselines |
| - HGNNNet: Hypergraph-only baseline |
| - SimplicialNet: Simplicial-only baseline |
| - SimpleHybrid: Naive concatenation hybrid |
| - get_model(): Factory function |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from .layers import (TopoHyperConv, TopoHyperPool, SimplicialConv, |
| HypergraphConv, GCNLayer, GATLayer) |
|
|
|
|
class TopoHyperNet(nn.Module):
    """
    Full TopoHyper hybrid model.

    Pipeline:
        1. Linear input projection, then ReLU and dropout
        2. ``num_layers`` TopoHyperConv layers (3-phase: simplicial +
           hypergraph + fusion), each followed by dropout
        3. TopoHyperPool readout (mean + max + attention); the head expects
           the pooled vector to carry ``3 * hidden_dim`` features
        4. Two-layer MLP classification head

    Args:
        in_dim: Input feature dimension
        hidden_dim: Hidden layer dimension
        num_classes: Number of output classes
        num_layers: Number of TopoHyperConv layers
        dropout: Dropout rate (used between layers and inside the head)
        use_bridge: Whether to use bridge matrix in fusion
        use_attention: Whether to use attention-gated fusion
    """

    def __init__(self, in_dim, hidden_dim, num_classes, num_layers=2,
                 dropout=0.3, use_bridge=True, use_attention=True):
        super().__init__()
        self.dropout = dropout

        # Submodules are created in the order projection -> convs -> pool ->
        # head so that seeded parameter initialisation is reproducible.
        self.input_proj = nn.Linear(in_dim, hidden_dim)
        self.convs = nn.ModuleList([
            TopoHyperConv(hidden_dim, hidden_dim,
                          use_bridge=use_bridge,
                          use_attention=use_attention)
            for _ in range(num_layers)
        ])
        self.pool = TopoHyperPool(hidden_dim)

        head = [
            nn.Linear(hidden_dim * 3, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes),
        ]
        self.classifier = nn.Sequential(*head)

    def _drop(self, h):
        """Apply functional dropout that is active only in training mode."""
        return F.dropout(h, p=self.dropout, training=self.training)

    def forward(self, x, sc, hg, edge_index=None):
        """
        Compute classifier outputs for the given input.

        Args:
            x: Input features, fed to the input projection
            sc: Simplicial structure passed to every TopoHyperConv
            hg: Hypergraph structure passed to every TopoHyperConv
            edge_index: Ignored by this model

        Returns:
            Output of the classification head applied to the pooled features.
        """
        h = self._drop(F.relu(self.input_proj(x)))
        for conv in self.convs:
            h = self._drop(conv(h, sc, hg))
        return self.classifier(self.pool(h))
|
|
|
|
class GCNNet(nn.Module):
    """
    GCN baseline.

    A stack of GCNLayer modules (each followed by ReLU and dropout), a mean
    over dimension 0 as readout, and a two-layer MLP classification head.

    Args:
        in_dim: Input feature dimension
        hidden_dim: Hidden layer dimension
        num_classes: Number of output classes
        num_layers: Total number of GCNLayer modules (at least one is built)
        dropout: Dropout rate
    """

    def __init__(self, in_dim, hidden_dim, num_classes, num_layers=2, dropout=0.3):
        super().__init__()
        self.dropout = dropout

        # The first layer maps in_dim -> hidden_dim; every further layer
        # keeps the width at hidden_dim.
        input_widths = [in_dim] + [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            [GCNLayer(width, hidden_dim) for width in input_widths]
        )

        head = [
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x, sc=None, hg=None, edge_index=None):
        """
        Compute classifier outputs for the given input.

        Args:
            x: Input features
            sc: Ignored by this model
            hg: Ignored by this model
            edge_index: Graph connectivity handed to every GCNLayer

        Returns:
            Output of the classification head applied to the dim-0 mean of
            the final layer features.
        """
        h = x
        for gcn in self.layers:
            h = F.dropout(F.relu(gcn(h, edge_index)),
                          p=self.dropout, training=self.training)
        # Readout: average over dim 0 (presumably the node axis -- confirm).
        return self.classifier(h.mean(dim=0))
|
|
|
|
class GATNet(nn.Module):
    """
    GAT baseline.

    A stack of GATLayer modules (each followed by ReLU and dropout), a mean
    over dimension 0 as readout, and a two-layer MLP classification head.

    Args:
        in_dim: Input feature dimension
        hidden_dim: Hidden layer dimension
        num_classes: Number of output classes
        num_layers: Total number of GATLayer modules (at least one is built)
        dropout: Dropout rate
    """

    def __init__(self, in_dim, hidden_dim, num_classes, num_layers=2, dropout=0.3):
        super().__init__()
        self.dropout = dropout

        # The first layer maps in_dim -> hidden_dim; every further layer
        # keeps the width at hidden_dim.
        input_widths = [in_dim] + [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            [GATLayer(width, hidden_dim) for width in input_widths]
        )

        head = [
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x, sc=None, hg=None, edge_index=None):
        """
        Compute classifier outputs for the given input.

        Args:
            x: Input features
            sc: Ignored by this model
            hg: Ignored by this model
            edge_index: Graph connectivity handed to every GATLayer

        Returns:
            Output of the classification head applied to the dim-0 mean of
            the final layer features.
        """
        h = x
        for gat in self.layers:
            h = F.dropout(F.relu(gat(h, edge_index)),
                          p=self.dropout, training=self.training)
        # Readout: average over dim 0 (presumably the node axis -- confirm).
        return self.classifier(h.mean(dim=0))
|
|
|
|
class HGNNNet(nn.Module):
    """
    Hypergraph-only baseline.

    A stack of HypergraphConv modules (each followed by ReLU and dropout), a
    mean over dimension 0 as readout, and a two-layer MLP classification head.

    Args:
        in_dim: Input feature dimension
        hidden_dim: Hidden layer dimension
        num_classes: Number of output classes
        num_layers: Total number of HypergraphConv modules (at least one is built)
        dropout: Dropout rate
    """

    def __init__(self, in_dim, hidden_dim, num_classes, num_layers=2, dropout=0.3):
        super().__init__()
        self.dropout = dropout

        # The first layer maps in_dim -> hidden_dim; every further layer
        # keeps the width at hidden_dim.
        input_widths = [in_dim] + [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            [HypergraphConv(width, hidden_dim) for width in input_widths]
        )

        head = [
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x, sc=None, hg=None, edge_index=None):
        """
        Compute classifier outputs for the given input.

        Args:
            x: Input features
            sc: Ignored by this model
            hg: Hypergraph structure handed to every HypergraphConv
            edge_index: Ignored by this model

        Returns:
            Output of the classification head applied to the dim-0 mean of
            the final layer features.
        """
        h = x
        for hconv in self.layers:
            h = F.dropout(F.relu(hconv(h, hg)),
                          p=self.dropout, training=self.training)
        # Readout: average over dim 0 (presumably the node axis -- confirm).
        return self.classifier(h.mean(dim=0))
|
|
|
|
class SimplicialNet(nn.Module):
    """
    Simplicial-only baseline.

    A stack of SimplicialConv modules (each followed by ReLU and dropout), a
    mean over dimension 0 as readout, and a two-layer MLP classification head.

    Args:
        in_dim: Input feature dimension
        hidden_dim: Hidden layer dimension
        num_classes: Number of output classes
        num_layers: Total number of SimplicialConv modules (at least one is built)
        dropout: Dropout rate
    """

    def __init__(self, in_dim, hidden_dim, num_classes, num_layers=2, dropout=0.3):
        super().__init__()
        self.dropout = dropout

        # The first layer maps in_dim -> hidden_dim; every further layer
        # keeps the width at hidden_dim.
        input_widths = [in_dim] + [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            [SimplicialConv(width, hidden_dim) for width in input_widths]
        )

        head = [
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x, sc=None, hg=None, edge_index=None):
        """
        Compute classifier outputs for the given input.

        Args:
            x: Input features
            sc: Simplicial structure handed to every SimplicialConv
            hg: Ignored by this model
            edge_index: Ignored by this model

        Returns:
            Output of the classification head applied to the dim-0 mean of
            the final layer features.
        """
        h = x
        for sconv in self.layers:
            h = F.dropout(F.relu(sconv(h, sc)),
                          p=self.dropout, training=self.training)
        # Readout: average over dim 0 (presumably the node axis -- confirm).
        return self.classifier(h.mean(dim=0))
|
|
|
|
class SimpleHybrid(nn.Module):
    """
    Naive hybrid baseline: separate simplicial + hypergraph branches
    concatenated (no bridge, no attention).

    The hidden width is split between the two branches so that their
    concatenation is always exactly ``hidden_dim`` wide: the simplicial
    branch gets ``hidden_dim // 2`` features and the hypergraph branch gets
    the remainder. For even ``hidden_dim`` both branches are equal; for odd
    ``hidden_dim`` the hypergraph branch is one feature wider, which keeps
    the classifier input size consistent with the concatenated features.

    Args:
        in_dim: Input feature dimension
        hidden_dim: Total hidden dimension shared by the two branches
        num_classes: Number of output classes
        num_layers: Number of layers per branch (at least one is built)
        dropout: Dropout rate
    """

    def __init__(self, in_dim, hidden_dim, num_classes, num_layers=2, dropout=0.3):
        super().__init__()
        sc_dim = hidden_dim // 2
        # Give the remainder to the hypergraph branch so sc_dim + hg_dim ==
        # hidden_dim even when hidden_dim is odd (a plain 2 * (hidden_dim // 2)
        # would be one short of what the classifier expects).
        hg_dim = hidden_dim - sc_dim

        self.sc_layers = nn.ModuleList()
        self.sc_layers.append(SimplicialConv(in_dim, sc_dim))
        for _ in range(num_layers - 1):
            self.sc_layers.append(SimplicialConv(sc_dim, sc_dim))

        self.hg_layers = nn.ModuleList()
        self.hg_layers.append(HypergraphConv(in_dim, hg_dim))
        for _ in range(num_layers - 1):
            self.hg_layers.append(HypergraphConv(hg_dim, hg_dim))

        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes)
        )
        self.dropout = dropout

    def forward(self, x, sc=None, hg=None, edge_index=None):
        """
        Compute classifier outputs for the given input.

        Args:
            x: Input features, fed independently to both branches
            sc: Simplicial structure handed to every SimplicialConv
            hg: Hypergraph structure handed to every HypergraphConv
            edge_index: Ignored by this model

        Returns:
            Output of the classification head applied to the dim-0 mean of
            the concatenated branch features.
        """
        # Simplicial branch.
        h_sc = x
        for layer in self.sc_layers:
            h_sc = F.relu(layer(h_sc, sc))
            h_sc = F.dropout(h_sc, p=self.dropout, training=self.training)

        # Hypergraph branch, starting again from the raw input.
        h_hg = x
        for layer in self.hg_layers:
            h_hg = F.relu(layer(h_hg, hg))
            h_hg = F.dropout(h_hg, p=self.dropout, training=self.training)

        # Concatenate along the feature axis, then average over dim 0
        # (presumably the node axis -- confirm) as readout.
        h = torch.cat([h_sc, h_hg], dim=-1)
        g = h.mean(dim=0)
        return self.classifier(g)
|
|
|
|
def get_model(name, in_dim, hidden_dim, num_classes, **kwargs):
    """
    Build a model selected by its registry name.

    Args:
        name: One of 'topohyper', 'gcn', 'gat', 'hgnn', 'simplicial', 'simple_hybrid'
        in_dim: Input feature dimension
        hidden_dim: Hidden dimension
        num_classes: Number of classes
        **kwargs: Extra keyword arguments forwarded unchanged to the model
            constructor (use_bridge, use_attention, etc.)

    Returns:
        A new instance of the model class registered under ``name``.

    Raises:
        ValueError: If ``name`` is not one of the registered model names.
    """
    models = {
        'topohyper': TopoHyperNet,
        'gcn': GCNNet,
        'gat': GATNet,
        'hgnn': HGNNNet,
        'simplicial': SimplicialNet,
        'simple_hybrid': SimpleHybrid,
    }

    model_cls = models.get(name)
    if model_cls is None:
        # Unknown key: report the valid choices in registration order.
        raise ValueError(f"Unknown model: {name}. Choose from {list(models.keys())}")

    return model_cls(in_dim, hidden_dim, num_classes, **kwargs)
|
|