FraudSentinel: AML GNN (GINE Edge Classifier)
Standalone edge-classification Graph Neural Network for anti-money laundering detection on the IBM AML HI-Small transaction multigraph. Part of the FraudSentinel Tier-1 real-time triage layer.
Model Summary
| Property | Value |
|---|---|
| Architecture | GINEConv (Graph Isomorphism Network with Edge features) |
| Task | Edge classification: per-transaction laundering probability |
| Dataset | IBM AML HI-Small (5M transactions, 0.10% laundering) |
| Split | Temporal 70 / 10 / 20 (train / val / test) |
| ROC-AUC (test) | 0.584 |
| PR-AUC (test) | 0.0036 |
| Best F1 (test) | 0.0159 @ threshold 1.0 |
| Layers | 3 × GINEConv |
| Hidden dim | 96 |
| Message passing | Bidirectional (forward + reverse edges) |
| Training epochs | 80 |
| Checkpoint selection | Best val ROC-AUC + PR-AUC composite score |
Architecture
Nodes = bank accounts
Edges = individual transactions (directed, with edge features)
Node features (2): log in-degree and log out-degree, computed on training edges only to prevent temporal leakage.
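As an illustration of the leakage guard described above, here is a minimal sketch in plain Python. The helper name `degree_features`, the list-based edge layout, and the use of `log1p` (so that zero-degree nodes stay finite) are assumptions for this example; the actual implementation lives in train_gnn_aml.py and may differ.

```python
import math
from collections import Counter

def degree_features(edges, train_mask, num_nodes):
    """Return [log1p(in_degree), log1p(out_degree)] per node, using training edges only."""
    in_deg, out_deg = Counter(), Counter()
    for (src, dst), is_train in zip(edges, train_mask):
        if is_train:  # validation/test edges never contribute to the counts
            out_deg[src] += 1
            in_deg[dst] += 1
    return [[math.log1p(in_deg[n]), math.log1p(out_deg[n])] for n in range(num_nodes)]

edges = [(0, 1), (0, 2), (1, 2), (2, 0)]
train_mask = [True, True, True, False]  # the last edge falls after the training cutoff
features = degree_features(edges, train_mask, num_nodes=3)
```

Node 0 ends up with in-degree 0 here because its only incoming edge lies outside the training window.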
Edge features (8+):
- log(amount_paid), log(amount_received)
- hour / 23, day_of_week / 6
- currency_mismatch (binary)
- self_loop (binary)
- payment_format (one-hot)
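The edge features above could be assembled roughly as follows. This is a sketch under assumptions: the function name `edge_features`, its argument names, the three-entry `PAYMENT_FORMATS` vocabulary, the use of `log1p`, and the Monday=0 weekday convention are all hypothetical and must be matched to train_gnn_aml.py before reuse.

```python
import math
from datetime import datetime

# Hypothetical vocabulary; the real list and its order must match training.
PAYMENT_FORMATS = ["ACH", "Cheque", "Wire"]

def edge_features(src, dst, ts, amount_paid, amount_received,
                  pay_currency, recv_currency, payment_format):
    """Build one edge feature vector: 6 scalar features followed by the one-hot block."""
    one_hot = [1.0 if payment_format == f else 0.0 for f in PAYMENT_FORMATS]
    return [
        math.log1p(amount_paid),               # log amount paid
        math.log1p(amount_received),           # log amount received
        ts.hour / 23,                          # hour scaled to [0, 1]
        ts.weekday() / 6,                      # Monday=0 ... Sunday=6, scaled to [0, 1]
        float(pay_currency != recv_currency),  # currency_mismatch
        float(src == dst),                     # self_loop
    ] + one_hot

vec = edge_features(7, 7, datetime(2022, 9, 4, 23, 0),
                    100.0, 100.0, "USD", "EUR", "Wire")
```

The vector length is 6 plus the size of the payment-format vocabulary, which is the value `ein` has to match at inference time.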
Message passing: Reverse edges are added so receiver accounts aggregate signal from senders (bidirectional), matching the IBM Multi-GNN recipe. Supervised classification runs on the original forward edges only, keeping labels and masks aligned.
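A minimal sketch of the reverse-edge idea, in plain Python for readability. The helper `add_reverse_edges` is hypothetical, and copying the forward edge's features onto its reverse copy is an assumption; the training script may instead mark direction explicitly.

```python
def add_reverse_edges(edges, edge_attr):
    """Return (mp_edges, mp_attr): the forward edges followed by their reversed copies."""
    reversed_edges = [(dst, src) for src, dst in edges]
    mp_edges = edges + reversed_edges
    mp_attr = edge_attr + [list(a) for a in edge_attr]  # assumption: reverse copy reuses features
    return mp_edges, mp_attr

edges = [(0, 1), (1, 2)]
edge_attr = [[0.5], [0.9]]
labels = [0, 1]  # one label per forward edge

mp_edges, mp_attr = add_reverse_edges(edges, edge_attr)

# Classification still runs on the forward edges only, so labels stay aligned.
cls_edges, cls_attr = edges, edge_attr
```

With tensors, the same effect is usually obtained by concatenating `edge_index` with its row-flipped copy along the edge dimension.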
Edge classifier head: [h_src ‖ h_dst ‖ e_cls] → Linear(3h, h) → ReLU → Dropout(0.2) → Linear(h, 2), where ‖ denotes concatenation.
Training Configuration
| Hyperparameter | Value |
|---|---|
| Optimizer | Adam |
| Learning rate | 5e-3 |
| Weight decay | 1e-5 |
| Class weight (pos) | ~1,245 (neg/pos ratio) |
| Loss | Cross-entropy with class weights |
| Epochs | 80 |
| Batch | Full-graph (single forward pass per step) |
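To make the class-weight row concrete, the sketch below computes a neg/pos weight and a class-weighted cross-entropy in plain Python. The function names are hypothetical. The normalization (weighted sum divided by the sum of weights) mirrors what `torch.nn.CrossEntropyLoss` does when `weight` is given with the default mean reduction; whether the training script uses exactly that call is an assumption.

```python
import math

def positive_class_weight(labels):
    """neg/pos ratio, used as the weight of the positive (laundering) class."""
    pos = sum(labels)
    return (len(labels) - pos) / pos

def weighted_cross_entropy(logits, labels, class_weights):
    """Weighted mean of per-edge cross-entropy: sum(w_i * loss_i) / sum(w_i)."""
    total, weight_sum = 0.0, 0.0
    for (z0, z1), y in zip(logits, labels):
        m = max(z0, z1)  # subtract the max for a numerically stable log-sum-exp
        log_norm = m + math.log(math.exp(z0 - m) + math.exp(z1 - m))
        loss = log_norm - (z1 if y == 1 else z0)
        total += class_weights[y] * loss
        weight_sum += class_weights[y]
    return total / weight_sum

labels = [0] * 999 + [1]           # toy data: 1 positive per 1,000 edges
w_pos = positive_class_weight(labels)
logits = [(0.0, 0.0)] * len(labels)  # an untrained model that is maximally unsure
loss = weighted_cross_entropy(logits, labels, [1.0, w_pos])
```

With this weighting the single positive edge contributes as much to the loss as all 999 negatives combined.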
Training used a temporal split (edges sorted by timestamp, then divided 70/10/20) to prevent graph leakage, a common failure mode in AML GNN evaluation.
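The temporal split can be sketched as follows. `temporal_split` is a hypothetical stand-in for the `temporal_masks` step in train_gnn_aml.py, and integer percentages are used here only to avoid floating-point rounding at the boundaries.

```python
def temporal_split(timestamps, train_pct=70, val_pct=10):
    """Label each edge 'train', 'val' or 'test' by its rank in time order."""
    order = sorted(range(len(timestamps)), key=lambda i: timestamps[i])
    n_train = len(order) * train_pct // 100
    n_val = len(order) * val_pct // 100
    split = [""] * len(order)
    for rank, idx in enumerate(order):
        if rank < n_train:
            split[idx] = "train"
        elif rank < n_train + n_val:
            split[idx] = "val"
        else:
            split[idx] = "test"
    return split

timestamps = [50, 10, 90, 30, 70, 20, 100, 40, 60, 80]
split = temporal_split(timestamps)
```

Every training edge precedes every validation edge, which in turn precedes every test edge, so no future transaction can inform an earlier prediction.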
Evaluation Metrics
| Metric | Value |
|---|---|
| ROC-AUC | 0.584 |
| PR-AUC | 0.00357 |
| Best F1 | 0.01586 |
| Best F1 threshold | 1.0 |
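For reference, a best-F1 threshold is typically found by sweeping candidate thresholds. The sketch below is a simple, unoptimized version with a hypothetical `best_f1` helper and toy data; it is not the evaluation code from the repository. It also shows one plausible way a best threshold of exactly 1.0 can arise: when several scores saturate at 1.0, the top-scoring group is a tie and the sweep can only select it as a whole.

```python
def best_f1(probs, labels):
    """Try every distinct score as a threshold (positive when prob >= t); return (f1, t)."""
    best = (0.0, None)
    for t in sorted(set(probs)):
        tp = sum(1 for p, y in zip(probs, labels) if p >= t and y == 1)
        fp = sum(1 for p, y in zip(probs, labels) if p >= t and y == 0)
        fn = sum(1 for p, y in zip(probs, labels) if p < t and y == 1)
        f1 = 2 * tp / (2 * tp + fp + fn) if tp else 0.0
        if f1 > best[0]:
            best = (f1, t)
    return best

# Toy scores: two edges saturate at 1.0, only one of them is truly positive.
probs = [1.0, 1.0, 0.9, 0.8, 0.3]
labels = [1, 0, 0, 0, 0]
f1, threshold = best_f1(probs, labels)
```

This loop is quadratic in the number of edges; on millions of edges a sorted cumulative pass would be the practical choice.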
Routing Thresholds (Test Split)
| Recall Target | Threshold | Precision | Recall | Flagged % |
|---|---|---|---|---|
| 70% | 0.527 | 0.0019 | 73.3% | 69.4% |
| 80% | 0.377 | 0.0019 | 80.0% | 76.2% |
| 90% | 0.003 | 0.0018 | 90.0% | 89.3% |
Threshold selection depends on your false-positive budget. At the 80% recall target, 76.2% of edges are flagged for Tier-2 review. That suits a high-recall graph triage layer only if the Tier-2 LLM can absorb that volume and filter further.
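A routing threshold for a given recall target can be picked as sketched below. The helper `threshold_for_recall` and the toy data are hypothetical; the values in the table above come from aml_gnn_metrics.json, not from this code.

```python
def threshold_for_recall(probs, labels, target):
    """Highest threshold whose recall is at least `target`, with the resulting statistics."""
    n_pos = sum(labels)
    for t in sorted(set(probs), reverse=True):
        flagged = [y for p, y in zip(probs, labels) if p >= t]
        tp = sum(flagged)
        recall = tp / n_pos
        if recall >= target:
            return {
                "threshold": t,
                "precision": tp / len(flagged),
                "recall": recall,
                "flagged_frac": len(flagged) / len(probs),
            }
    return None  # only reachable when there are no scores at all

probs = [0.95, 0.80, 0.60, 0.40, 0.30, 0.10]
labels = [1, 0, 1, 0, 1, 0]
stats = threshold_for_recall(probs, labels, target=0.6)
```

Because recall moves in discrete steps, the achieved recall can overshoot the target, as in the 70% row of the table (73.3% achieved).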
Why These Metrics Are Expected
This is a graph triage model on a severely imbalanced dataset (laundering rate 0.10%). The goal at Tier-1 is high recall, that is, catching laundering patterns so they reach the Tier-2 LLM for in-depth analysis, rather than precision. The ROC-AUC of 0.584 is above random (0.5), though only modestly, and the GNN can draw on graph structure (multi-hop fan-out, gather-scatter) that tabular models cannot see.
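One way to read the routing table against a random baseline: flagging a random fraction f of edges catches a fraction f of laundering edges in expectation, so recall divided by flagged fraction gives the lift over random routing. The short calculation below uses only the recall and flagged percentages reported in the routing table; the variable names are illustrative.

```python
rows = [  # (recall, flagged fraction) taken from the routing thresholds table
    (0.733, 0.694),
    (0.800, 0.762),
    (0.900, 0.893),
]
# Lift of 1.0 means no better than flagging the same share of edges at random.
lifts = [round(recall / flagged, 3) for recall, flagged in rows]
```

The lifts are 1.056, 1.05 and 1.008, which is consistent with the modest ROC-AUC.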
For context, the IBM paper (arXiv:2306.16424) reports the full Multi-GNN reaching minority-class F1 in the 60 to 70% range with a much larger model using heterogeneous graph features, temporal GNNs, and ensemble stacking. This is a single-GNN baseline using a straightforward GINE architecture.
Inference
import torch
import torch.nn.functional as F
from torch_geometric.nn import GINEConv

# Rebuild the model class from train_gnn_aml.py
class EdgeGNN(torch.nn.Module):
    def __init__(self, nin, ein, h):
        super().__init__()
        self.lin_in = torch.nn.Linear(nin, h)  # node features -> hidden
        self.lin_e = torch.nn.Linear(ein, h)   # edge features -> hidden
        mlp = lambda: torch.nn.Sequential(
            torch.nn.Linear(h, h), torch.nn.ReLU(), torch.nn.Linear(h, h))
        self.c1 = GINEConv(mlp(), edge_dim=h)
        self.c2 = GINEConv(mlp(), edge_dim=h)
        self.c3 = GINEConv(mlp(), edge_dim=h)
        self.head = torch.nn.Sequential(
            torch.nn.Linear(3*h, h), torch.nn.ReLU(),
            torch.nn.Dropout(0.2), torch.nn.Linear(h, 2))

    def forward(self, x, mp_ei, mp_ea, cls_ei, cls_ea):
        # Message passing runs over the bidirectional edge set (mp_*)
        e_mp = self.lin_e(mp_ea)
        h = F.relu(self.lin_in(x))
        h = F.relu(self.c1(h, mp_ei, e_mp))
        h = F.relu(self.c2(h, mp_ei, e_mp))
        h = F.relu(self.c3(h, mp_ei, e_mp))
        # Classification runs over the original forward edges (cls_*)
        e_cls = self.lin_e(cls_ea)
        z = torch.cat([h[cls_ei[0]], h[cls_ei[1]], e_cls], dim=1)
        return self.head(z)

# Load weights
ea_dim = 8  # adjust to match your edge feature count
model = EdgeGNN(nin=2, ein=ea_dim, h=96)
model.load_state_dict(torch.load("aml_gnn.pt", map_location="cpu"))
model.eval()

# Score edges
# x, mp_edge_index, mp_edge_attr, edge_index and edge_attr are the graph tensors
# produced by the graph construction pipeline in train_gnn_aml.py
with torch.no_grad():
    logits = model(x, mp_edge_index, mp_edge_attr, edge_index, edge_attr)
    probs = F.softmax(logits, dim=1)[:, 1].numpy()

# Apply routing threshold (e.g., recall-80% target -> threshold 0.377)
route_to_llm = probs >= 0.377
See train_gnn_aml.py for the complete graph construction pipeline (build_graph, temporal_masks, and the full featurization logic).
Repository Contents
| File | Description |
|---|---|
| aml_gnn.pt | Trained model weights (state dict) |
| aml_gnn_metrics.json | Full evaluation metrics + routing thresholds |
| train_gnn_aml.py | Training script (graph construction, model definition, evaluation) |
Limitations
- Prototype/research use. Source data is synthetic (IBM AML HI-Small). Validate on your own transaction graph before deployment.
- The model is a high-recall graph triage tool. Low precision at the routing threshold is by design: downstream human review and the Tier-2 LLM are expected to filter further.
- Full-graph inference requires loading the entire transaction multigraph into GPU memory. For production streaming, a subgraph sampling or incremental inference approach is needed.
- No model here should be used for real customer adjudication without independent validation, bias review, and human-in-the-loop controls.
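On the full-graph inference limitation above, one possible direction is to score an edge on the k-hop neighbourhood of its endpoints, with k equal to the number of message-passing layers (3 here). The sketch below is a plain-Python illustration with a hypothetical `k_hop_nodes` helper; it is not part of this repository and it ignores neighbourhood sampling, which would be needed around high-degree accounts.

```python
from collections import defaultdict, deque

def k_hop_nodes(edges, seeds, k):
    """Nodes within k hops of any seed, ignoring direction (message passing is bidirectional)."""
    neighbours = defaultdict(set)
    for src, dst in edges:
        neighbours[src].add(dst)
        neighbours[dst].add(src)
    seen = set(seeds)
    frontier = deque((s, 0) for s in seeds)
    while frontier:
        node, depth = frontier.popleft()
        if depth == k:
            continue  # do not expand beyond k hops
        for nxt in neighbours[node]:
            if nxt not in seen:
                seen.add(nxt)
                frontier.append((nxt, depth + 1))
    return seen

edges = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6)]
# To score edge (0, 1) with 3 layers, only nodes within 3 hops of 0 or 1 matter.
nodes = k_hop_nodes(edges, seeds=[0, 1], k=3)
sub_edges = [(s, d) for s, d in edges if s in nodes and d in nodes]
```

Node features such as the log degrees would still need to come from the full training graph, since degrees recomputed on the subgraph would differ from those seen in training.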
Citation
IBM AML dataset: Altman et al., NeurIPS 2023 (arXiv:2306.16424).
Reference GNN implementation: IBM/Multi-GNN.
License
Apache-2.0.