""" Training and evaluation script for TopoHyper and baselines. Usage: pip install torch torchvision scikit-learn scipy medmnist python train_eval.py Trains 6 models (TopoHyper, GCN, GAT, HGNN, Simplicial, SimpleHybrid) and runs ablation study on TopoHyper variants. """ import torch import torch.nn as nn import torch.optim as optim import numpy as np import json import os import time from sklearn.metrics import f1_score, classification_report # Import from topohyper package from topohyper.data import load_medmnist_data from topohyper.models import get_model def count_parameters(model): return sum(p.numel() for p in model.parameters() if p.requires_grad) def train_epoch(model, dataset, optimizer, criterion): model.train() total_loss = 0 correct = 0 total = 0 indices = list(range(len(dataset))) np.random.shuffle(indices) for idx in indices: sample = dataset[idx] features = sample['features'] sc = sample['sc'] hg = sample['hg'] edge_index = sample['edge_index'] label = torch.tensor(sample['label'], dtype=torch.long) optimizer.zero_grad() logits = model(features, sc, hg, edge_index) loss = criterion(logits.unsqueeze(0), label.unsqueeze(0)) loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) optimizer.step() total_loss += loss.item() pred = logits.argmax().item() correct += (pred == sample['label']) total += 1 return total_loss / total, correct / total @torch.no_grad() def evaluate(model, dataset): model.eval() correct = 0 total = 0 all_preds = [] all_labels = [] for idx in range(len(dataset)): sample = dataset[idx] features = sample['features'] sc = sample['sc'] hg = sample['hg'] edge_index = sample['edge_index'] logits = model(features, sc, hg, edge_index) pred = logits.argmax().item() all_preds.append(pred) all_labels.append(sample['label']) correct += (pred == sample['label']) total += 1 acc = correct / total f1 = f1_score(all_labels, all_preds, average='macro', zero_division=0) return acc, f1, all_preds, all_labels def train_model(model, train_ds, val_ds, test_ds, model_name, epochs=25, lr=0.001): """Train a model and return results.""" print(f"\n{'='*60}") print(f"Training: {model_name} ({count_parameters(model)} params)") print(f"{'='*60}") optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4) scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5) criterion = nn.CrossEntropyLoss() best_val_acc = -1 best_state = None for epoch in range(epochs): train_loss, train_acc = train_epoch(model, train_ds, optimizer, criterion) val_acc, val_f1, _, _ = evaluate(model, val_ds) scheduler.step(1 - val_acc) if val_acc > best_val_acc: best_val_acc = val_acc best_state = {k: v.clone() for k, v in model.state_dict().items()} if (epoch + 1) % 5 == 0 or epoch == 0: print(f" Epoch {epoch+1:3d}: loss={train_loss:.4f} train_acc={train_acc:.3f} " f"val_acc={val_acc:.3f} val_f1={val_f1:.4f}") # Load best model and evaluate on test if best_state is not None: model.load_state_dict(best_state) test_acc, test_f1, test_preds, test_labels = evaluate(model, test_ds) print(f"\n BEST val_acc={best_val_acc:.3f}") print(f" TEST acc={test_acc:.3f} f1_macro={test_f1:.4f}") return { 'model_name': model_name, 'params': count_parameters(model), 'best_val_acc': best_val_acc, 'test_acc': test_acc, 'test_f1_macro': test_f1, } def main(): # Configuration MAX_TRAIN = 800 MAX_VAL = 200 MAX_TEST = 200 EPOCHS = 25 HIDDEN_DIM = 64 NUM_CLASSES = 9 IN_DIM = 38 LR = 0.001 print("Loading PathMNIST data...") train_ds, val_ds, test_ds, num_classes = load_medmnist_data( dataset_name='pathmnist', size=64, max_train=MAX_TRAIN, max_val=MAX_VAL, max_test=MAX_TEST ) print(f"Train: {len(train_ds)}, Val: {len(val_ds)}, Test: {len(test_ds)}") os.makedirs('results', exist_ok=True) all_results = {} # ==================== Main Comparison ==================== print("\n" + "=" * 70) print("MAIN MODEL COMPARISON") print("=" * 70) model_configs = [ ('topohyper', {'use_bridge': True, 'use_attention': True}), ('gcn', {}), ('gat', {}), ('hgnn', {}), ('simplicial', {}), ('simple_hybrid', {}), ] for model_name, kwargs in model_configs: torch.manual_seed(42) np.random.seed(42) model = get_model(model_name, IN_DIM, HIDDEN_DIM, NUM_CLASSES, **kwargs) result = train_model(model, train_ds, val_ds, test_ds, model_name, EPOCHS, LR) all_results[model_name] = result if model_name == 'topohyper': torch.save(model.state_dict(), 'results/topohyper_model.pt') # ==================== Ablation Study ==================== print("\n" + "=" * 70) print("ABLATION STUDY (TopoHyper variants)") print("=" * 70) ablation_configs = [ ('Full (Bridge+Attn)', True, True), ('No Attention (Bridge only)', True, False), ('No Bridge (Attn only)', False, True), ('Neither', False, False), ] ablation_results = {} for name, use_bridge, use_attention in ablation_configs: torch.manual_seed(42) np.random.seed(42) model = get_model('topohyper', IN_DIM, HIDDEN_DIM, NUM_CLASSES, use_bridge=use_bridge, use_attention=use_attention) result = train_model(model, train_ds, val_ds, test_ds, name, EPOCHS, LR) ablation_results[name] = result # ==================== Print Summary ==================== print("\n" + "=" * 70) print("FINAL RESULTS SUMMARY") print("=" * 70) print(f"\n{'Model':<20} {'Test Acc':>10} {'F1-macro':>10} {'Val Acc':>10} {'Params':>10}") print("-" * 62) for name, r in sorted(all_results.items(), key=lambda x: x[1]['test_acc'], reverse=True): print(f"{name:<20} {r['test_acc']:>10.3f} {r['test_f1_macro']:>10.4f} " f"{r['best_val_acc']:>10.3f} {r['params']:>10d}") print(f"\nAblation Study:") print(f"{'Config':<35} {'Test Acc':>10} {'Val Acc':>10}") print("-" * 57) for name, r in ablation_results.items(): print(f"{name:<35} {r['test_acc']:>10.3f} {r['best_val_acc']:>10.3f}") # Save results save_data = { 'main_results': all_results, 'ablation_results': ablation_results, 'config': { 'max_train': MAX_TRAIN, 'max_val': MAX_VAL, 'max_test': MAX_TEST, 'epochs': EPOCHS, 'hidden_dim': HIDDEN_DIM, 'lr': LR, } } with open('results/results.json', 'w') as f: json.dump(save_data, f, indent=2) print(f"\nResults saved to results/results.json") print("Done!") if __name__ == '__main__': main()