FraudSentinel: AML GNN (GINE Edge Classifier)

Standalone edge-classification Graph Neural Network for anti-money laundering detection on the IBM AML HI-Small transaction multigraph. Part of the FraudSentinel Tier-1 real-time triage layer.


Model Summary

| Property | Value |
|---|---|
| Architecture | GINEConv (Graph Isomorphism Network with Edge features) |
| Task | Edge classification: per-transaction laundering probability |
| Dataset | IBM AML HI-Small (5M transactions, 0.10% laundering) |
| Split | Temporal 70 / 10 / 20 (train / val / test) |
| ROC-AUC (test) | 0.584 |
| PR-AUC (test) | 0.0036 |
| Best F1 (test) | 0.0159 @ threshold 1.0 |
| Layers | 3 × GINEConv |
| Hidden dim | 96 |
| Message passing | Bidirectional (forward + reverse edges) |
| Training epochs | 80 |
| Checkpoint selection | Best val ROC-AUC + PR-AUC composite score |

Architecture

Nodes = bank accounts
Edges = individual transactions (directed, with edge features)

Node features (2): log in-degree, log out-degree, computed on training edges only to prevent temporal leakage.
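A minimal sketch of how such degree features could be computed, assuming `log1p` as the log transform (the exact transform in train_gnn_aml.py is not stated here) and a hypothetical helper name `degree_features`. Only training edges are counted, so validation and test transactions cannot influence the node inputs.

```python
import math
from collections import Counter

def degree_features(train_edges, num_nodes):
    """Return [log1p(in_degree), log1p(out_degree)] per node, using training edges only."""
    out_deg = Counter(src for src, _ in train_edges)
    in_deg = Counter(dst for _, dst in train_edges)
    # Counter returns 0 for accounts that never appear in the training edges.
    return [[math.log1p(in_deg[n]), math.log1p(out_deg[n])] for n in range(num_nodes)]

# Toy example: 4 accounts, 3 training transactions (later edges are deliberately left out).
train_edges = [(0, 1), (0, 2), (2, 1)]
node_x = degree_features(train_edges, num_nodes=4)
```

Account 3 never appears in the training edges, so both of its features are zero even if it transacts later in the validation or test window.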

Edge features (8+):

  • log(amount_paid), log(amount_received)
  • hour / 23, day_of_week / 6
  • currency_mismatch (binary)
  • self_loop (binary)
  • payment_format (one-hot)
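The feature list above can be sketched as a single function. The dictionary field names, the `log1p` transform, and the three-entry `PAYMENT_FORMATS` list are illustrative assumptions; the real categories and column names come from the dataset and from train_gnn_aml.py.

```python
import math
from datetime import datetime

# Hypothetical category list; the real one is derived from the dataset's payment-format column.
PAYMENT_FORMATS = ["ACH", "Cheque", "Wire"]

def edge_features(tx):
    """Build one edge feature vector from a transaction dict (field names are assumed)."""
    ts = tx["timestamp"]
    feats = [
        math.log1p(tx["amount_paid"]),
        math.log1p(tx["amount_received"]),
        ts.hour / 23,                                   # scaled to [0, 1]
        ts.weekday() / 6,                               # Monday = 0 ... Sunday = 6
        float(tx["paid_currency"] != tx["received_currency"]),  # currency_mismatch
        float(tx["src"] == tx["dst"]),                  # self_loop
    ]
    # One-hot encoding of the payment format.
    feats += [float(tx["payment_format"] == f) for f in PAYMENT_FORMATS]
    return feats

tx = {
    "src": 10, "dst": 42,
    "amount_paid": 1500.0, "amount_received": 1380.0,
    "paid_currency": "USD", "received_currency": "EUR",
    "payment_format": "Wire",
    "timestamp": datetime(2022, 9, 4, 23, 0),  # a Sunday, 23:00
}
vec = edge_features(tx)
```

The length of the resulting vector (six scalar features plus one slot per payment format) is the value that `ein` must match when the model is constructed.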

Message passing: Reverse edges are added so receiver accounts aggregate signal from senders (bidirectional), matching the IBM Multi-GNN recipe. Supervised classification runs on the original forward edges only, keeping labels and masks aligned.
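The split between message-passing edges and classification edges can be illustrated with plain lists. This sketch assumes each reverse edge reuses the features of its forward edge; the training script may treat reverse-edge features differently. `add_reverse_edges` is a hypothetical helper name.

```python
def add_reverse_edges(edges, edge_attr):
    """Return (mp_edges, mp_attr): forward edges followed by their reversed copies."""
    mp_edges = list(edges) + [(dst, src) for src, dst in edges]
    # Assumption: a reverse edge carries a copy of the forward edge's features.
    mp_attr = [list(a) for a in edge_attr] + [list(a) for a in edge_attr]
    return mp_edges, mp_attr

edges = [(0, 1), (1, 2)]       # original directed transactions
edge_attr = [[0.5], [1.5]]     # one toy feature per transaction

mp_edges, mp_attr = add_reverse_edges(edges, edge_attr)  # used for message passing
cls_edges, cls_attr = edges, edge_attr                   # used for the loss and for scoring
```

Because the forward edges come first and keep their order, labels and split masks indexed over `cls_edges` also line up with the first half of `mp_edges`.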

Edge classifier head: [h_src ‖ h_dst ‖ e_cls] → Linear(3h, h) → ReLU → Dropout(0.2) → Linear(h, 2)


Training Configuration

| Hyperparameter | Value |
|---|---|
| Optimizer | Adam |
| Learning rate | 5e-3 |
| Weight decay | 1e-5 |
| Class weight (pos) | ~1,245 (neg/pos ratio) |
| Loss | Cross-entropy with class weights |
| Epochs | 80 |
| Batch | Full-graph (single forward pass per step) |
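To make the class-weight row concrete, the sketch below computes a neg/pos weight from labels and applies it in a two-class cross-entropy. It is written in plain Python for illustration; the weighted-mean normalization (dividing by the sum of the weights) mirrors what `torch.nn.CrossEntropyLoss` does when given a `weight` tensor with mean reduction. The function names are hypothetical and the toy labels are chosen only to reproduce a ratio of 1,245.

```python
import math

def pos_class_weight(labels):
    """neg/pos ratio, used as the weight of the positive (laundering) class."""
    pos = sum(labels)
    neg = len(labels) - pos
    return neg / pos

def weighted_cross_entropy(logits, labels, class_weights):
    """Weighted mean of per-edge cross-entropy, normalized by the total weight."""
    total, weight_sum = 0.0, 0.0
    for (z0, z1), y in zip(logits, labels):
        m = max(z0, z1)  # subtract the max for a numerically stable log-sum-exp
        log_norm = m + math.log(math.exp(z0 - m) + math.exp(z1 - m))
        nll = log_norm - (z1 if y == 1 else z0)
        total += class_weights[y] * nll
        weight_sum += class_weights[y]
    return total / weight_sum

labels = [0] * 1245 + [1]            # toy labels with a 1245:1 imbalance
w_pos = pos_class_weight(labels)
logits = [(0.0, 0.0)] * len(labels)  # an uninformative model
loss = weighted_cross_entropy(logits, labels, class_weights=[1.0, w_pos])
```

With this weighting, the single positive edge contributes as much to the loss as all 1,245 negatives combined, which pushes the model toward recall at the cost of precision.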

Training used a temporal split (edges sorted by timestamp, 70/10/20) to prevent graph leakage, a common failure mode in AML GNN evaluation.
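A temporal split of this kind can be sketched as follows. The helper name `temporal_split` and the integer-percentage arguments are illustrative; the repository's own `temporal_masks` function may differ in detail.

```python
def temporal_split(timestamps, train_pct=70, val_pct=10):
    """Return (train_idx, val_idx, test_idx): edge indices in chronological order."""
    order = sorted(range(len(timestamps)), key=lambda i: timestamps[i])
    n_train = len(order) * train_pct // 100
    n_val = len(order) * val_pct // 100
    train_idx = order[:n_train]
    val_idx = order[n_train:n_train + n_val]
    test_idx = order[n_train + n_val:]   # the remaining (latest) edges
    return train_idx, val_idx, test_idx

# Toy timestamps for 10 transactions, deliberately unsorted.
timestamps = [50, 10, 90, 30, 70, 20, 100, 40, 80, 60]
train_idx, val_idx, test_idx = temporal_split(timestamps)
```

Every training edge is no later than every validation edge, which in turn is no later than every test edge, so the model is never evaluated on transactions that precede its training data.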


Evaluation Metrics

| Metric | Value |
|---|---|
| ROC-AUC | 0.584 |
| PR-AUC | 0.00357 |
| Best F1 | 0.01586 |
| Best F1 threshold | 1.0 |

Routing Thresholds (Test Split)

| Recall target | Threshold | Precision | Recall | Flagged % |
|---|---|---|---|---|
| 70% | 0.527 | 0.0019 | 73.3% | 69.4% |
| 80% | 0.377 | 0.0019 | 80.0% | 76.2% |
| 90% | 0.003 | 0.0018 | 90.0% | 89.3% |

Threshold selection depends on your false-positive budget. At the 80% recall target, about 76% of test edges are flagged for Tier-2 review, which suits a high-recall graph triage layer where the LLM will filter further.
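One way to derive a routing threshold for a given recall target is to walk down the edges in order of decreasing score and stop at the first score where recall meets the target. The sketch below is an illustration of that idea with a hypothetical helper name, not necessarily the procedure used to produce aml_gnn_metrics.json.

```python
def routing_threshold(probs, labels, recall_target):
    """Return (threshold, precision, recall, flagged_fraction) for the highest
    threshold whose recall meets the target, using the rule `prob >= threshold`."""
    total_pos = sum(labels)
    if total_pos == 0:
        raise ValueError("no positive labels")
    ranked = sorted(zip(probs, labels), key=lambda pair: pair[0], reverse=True)
    tp = 0
    for i, (p, y) in enumerate(ranked):
        tp += y
        # Tied scores are flagged together, so only evaluate at the end of a tie group.
        if i + 1 < len(ranked) and ranked[i + 1][0] == p:
            continue
        flagged = i + 1
        if tp / total_pos >= recall_target:
            return p, tp / flagged, tp / total_pos, flagged / len(ranked)
    raise ValueError("recall target not reachable")

# Toy scores and labels (3 laundering edges out of 8).
probs = [0.9, 0.8, 0.7, 0.6, 0.4, 0.3, 0.2, 0.1]
labels = [1, 0, 1, 0, 0, 1, 0, 0]
thr, precision, recall, flagged = routing_threshold(probs, labels, recall_target=0.6)
```

Thresholds should be chosen on held-out data and re-derived whenever the score distribution shifts, since the flagged fraction is what determines Tier-2 load.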


Why These Metrics Are Expected

This is a graph triage model on a severely imbalanced dataset (laundering rate 0.10%). The goal at Tier-1 is high recall, that is, catching laundering patterns and passing them to the Tier-2 LLM for in-depth analysis, not precision optimization. The ROC-AUC of 0.584 is above the 0.5 of a random ranker, though only modestly, and message passing gives the model access to graph structure (multi-hop fan-out, gather-scatter) that per-transaction tabular models do not see.
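As a sanity check on the routing table, precision, recall, and flagged fraction are linked by the identity precision = recall × base_rate / flagged_fraction, so the positive rate of the evaluated split can be backed out from any row. The sketch below applies this to the published test-split figures; because those figures are rounded, the results are approximate.

```python
def implied_base_rate(precision, recall, flagged_fraction):
    """Positive rate implied by precision = recall * base_rate / flagged_fraction."""
    return precision * flagged_fraction / recall

# (precision, recall, flagged fraction) per recall target, from the routing table.
rows = {
    "70%": (0.0019, 0.733, 0.694),
    "80%": (0.0019, 0.800, 0.762),
    "90%": (0.0018, 0.900, 0.893),
}
for target, (precision, recall, flagged) in rows.items():
    base = implied_base_rate(precision, recall, flagged)
    lift = recall / flagged  # precision relative to flagging edges at random
    print(f"recall target {target}: implied positive rate {base:.4f}, lift {lift:.2f}x")
```

All three rows imply a test-split positive rate of roughly 0.18%, and a lift over random flagging of between about 1.01 and 1.06, which is worth keeping in mind when sizing Tier-2 capacity.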

For context, the IBM paper (arXiv:2306.16424) reports the full Multi-GNN reaching minority-class F1 in the 60–70% range with a much larger model using heterogeneous graph features, temporal GNNs, and ensemble stacking. This is a single-GNN baseline using a straightforward GINE architecture.


Inference

import torch
import torch.nn.functional as F
from torch_geometric.nn import GINEConv

# Rebuild the model class from train_gnn_aml.py
class EdgeGNN(torch.nn.Module):
    def __init__(self, nin, ein, h):
        super().__init__()
        self.lin_in = torch.nn.Linear(nin, h)
        self.lin_e  = torch.nn.Linear(ein, h)
        mlp = lambda: torch.nn.Sequential(
            torch.nn.Linear(h, h), torch.nn.ReLU(), torch.nn.Linear(h, h))
        self.c1 = GINEConv(mlp(), edge_dim=h)
        self.c2 = GINEConv(mlp(), edge_dim=h)
        self.c3 = GINEConv(mlp(), edge_dim=h)
        self.head = torch.nn.Sequential(
            torch.nn.Linear(3*h, h), torch.nn.ReLU(),
            torch.nn.Dropout(0.2), torch.nn.Linear(h, 2))

    def forward(self, x, mp_ei, mp_ea, cls_ei, cls_ea):
        e_mp = self.lin_e(mp_ea)
        h = F.relu(self.lin_in(x))
        h = F.relu(self.c1(h, mp_ei, e_mp))
        h = F.relu(self.c2(h, mp_ei, e_mp))
        h = F.relu(self.c3(h, mp_ei, e_mp))
        e_cls = self.lin_e(cls_ea)
        z = torch.cat([h[cls_ei[0]], h[cls_ei[1]], e_cls], dim=1)
        return self.head(z)

# Load weights
ea_dim = 8  # adjust to match your edge feature count
model = EdgeGNN(nin=2, ein=ea_dim, h=96)
model.load_state_dict(torch.load("aml_gnn.pt", map_location="cpu"))
model.eval()

# Score edges. The input tensors come from the graph construction pipeline in train_gnn_aml.py.
with torch.no_grad():
    logits = model(x, mp_edge_index, mp_edge_attr, edge_index, edge_attr)
    probs  = F.softmax(logits, dim=1)[:, 1].numpy()

# Apply routing threshold (e.g., recall-80% target → threshold 0.377)
route_to_llm = probs >= 0.377

See train_gnn_aml.py for the complete graph construction pipeline (build_graph, temporal_masks, and the full featurization logic).


Repository Contents

File Description
aml_gnn.pt Trained model weights (state dict)
aml_gnn_metrics.json Full evaluation metrics + routing thresholds
train_gnn_aml.py Training script (graph construction, model definition, evaluation)

Limitations

  • Prototype/research use. Source data is synthetic (IBM AML HI-Small). Validate on your own transaction graph before deployment.
  • The model is a high-recall graph triage tool. Low precision at the routing threshold is by design; downstream human review and the Tier-2 LLM are expected to filter further.
  • Full-graph inference requires loading the entire transaction multigraph into GPU memory. For production streaming, a subgraph sampling or incremental inference approach is needed.
  • No model here should be used for real customer adjudication without independent validation, bias review, and human-in-the-loop controls.
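For the full-graph inference limitation, one possible direction is to score a new transaction on the k-hop neighborhood of its two endpoints rather than on the whole multigraph. With three GINEConv layers, three hops cover the receptive field of the endpoint embeddings. The sketch below is an illustrative assumption, not part of this repository; `k_hop_edges` is a hypothetical helper, and edges are treated as bidirectional to match the message-passing setup.

```python
from collections import defaultdict

def k_hop_edges(edges, seeds, num_hops):
    """Return sorted indices of edges within num_hops of the seed nodes,
    following edges in both directions."""
    incident = defaultdict(list)
    for idx, (src, dst) in enumerate(edges):
        incident[src].append(idx)
        incident[dst].append(idx)
    frontier, seen, kept = set(seeds), set(seeds), set()
    for _ in range(num_hops):
        next_frontier = set()
        for node in frontier:
            for idx in incident[node]:
                kept.add(idx)
                for nb in edges[idx]:
                    if nb not in seen:       # expand only to newly reached accounts
                        seen.add(nb)
                        next_frontier.add(nb)
        frontier = next_frontier
    return sorted(kept)

# Toy chain of accounts 0 -> 1 -> 2 -> 3 -> 4 -> 5; score the edge (0, 1).
edges = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5)]
sub = k_hop_edges(edges, seeds={0, 1}, num_hops=3)
```

In this toy chain the edge (4, 5) lies outside the three-hop neighborhood and is dropped. On real transaction graphs, hub accounts can make such neighborhoods very large, so a production version would likely need neighbor sampling or a cap on fan-out.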

Citation

IBM AML dataset: Altman et al., NeurIPS 2023 (arXiv:2306.16424).
Reference GNN implementation: IBM/Multi-GNN.


License

Apache-2.0.
