---
language:
- en
- sw
tags:
- multi-task-learning
- text-classification
- fraud-detection
- sentiment-analysis
- call-quality
- question-answering
- jenga-ai
- nlp-for-africa
- security
- attention-fusion
base_model: distilbert-base-uncased
license: apache-2.0
pipeline_tag: text-classification
datasets:
- custom
model-index:
- name: JengaAI-multi-task-nlp
  results:
  - task:
      type: text-classification
      name: Fraud Detection
    metrics:
    - type: f1
      value: 1
      name: F1
    - type: accuracy
      value: 1
      name: Accuracy
  - task:
      type: text-classification
      name: Sentiment Analysis
    metrics:
    - type: f1
      value: 0.167
      name: F1
    - type: accuracy
      value: 0.333
      name: Accuracy
  - task:
      type: text-classification
      name: Call Quality - Listening
    metrics:
    - type: f1
      value: 0.922
      name: F1
  - task:
      type: text-classification
      name: Call Quality - Resolution
    metrics:
    - type: f1
      value: 0.908
      name: F1
widget:
- text: >-
    Suspicious M-Pesa transaction detected from unknown account requesting
    urgent transfer
  example_title: Fraud Detection
- text: >-
    The customer service was excellent, my billing issue was resolved on the
    first call
  example_title: Positive Sentiment
- text: Hello, welcome to Safaricom customer care. How can I assist you today?
  example_title: Call Quality Scoring
library_name: transformers
---

# JengaAI Multi-Task NLP (3-Task Attention Fusion)

A **multi-task NLP model** built with the [JengaAI framework](https://github.com/Rogendo/JengaAI) that performs **fraud detection**, **sentiment analysis**, and **call quality scoring** simultaneously through a shared encoder with attention-based task fusion. Designed for Kenyan national security and telecommunications applications.

## Model Capabilities

This model handles **3 tasks** with **8 prediction heads** producing **22 total output dimensions** in a single forward pass:

| Task | Type | Heads | Outputs | Best F1 |
|:-----|:-----|:------|:--------|:--------|
| **Fraud Detection** | Binary classification | 1 (fraud) | 2 classes: normal / fraud | **1.000** |
| **Sentiment Analysis** | 3-class classification | 1 (sentiment) | 3 classes: negative / neutral / positive | 0.167 |
| **Call Quality Scoring** | Multi-label QA | 6 heads, 17 sub-metrics | Binary per sub-metric | **0.647 - 0.967** |

### Call Quality Sub-Metrics (17 Binary Outputs)

The call quality task evaluates customer service transcripts across 6 quality dimensions:

| Head | Sub-Metrics | F1 |
|:-----|:-----------|:---|
| **Opening** | greeting | 0.967 |
| **Listening** | acknowledgment, empathy, clarification, active_listening, patience | 0.922 |
| **Proactiveness** | initiative, follow_up, suggestions | 0.802 |
| **Resolution** | identified_issue, provided_solution, confirmed_resolution, set_expectations, offered_alternatives | 0.908 |
| **Hold** | asked_permission, explained_reason | 0.647 |
| **Closing** | proper_farewell | 0.881 |

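
The head and output counts quoted above follow directly from the sub-metric table. The sketch below redoes that arithmetic; the name `CALL_QUALITY_HEADS` is illustrative and is not part of the JengaAI API.

```python
# Sub-metrics per call quality head, copied from the table above.
CALL_QUALITY_HEADS = {
    "opening": ["greeting"],
    "listening": ["acknowledgment", "empathy", "clarification",
                  "active_listening", "patience"],
    "proactiveness": ["initiative", "follow_up", "suggestions"],
    "resolution": ["identified_issue", "provided_solution", "confirmed_resolution",
                   "set_expectations", "offered_alternatives"],
    "hold": ["asked_permission", "explained_reason"],
    "closing": ["proper_farewell"],
}

FRAUD_CLASSES = 2      # normal / fraud
SENTIMENT_CLASSES = 3  # negative / neutral / positive

call_quality_outputs = sum(len(subs) for subs in CALL_QUALITY_HEADS.values())
total_heads = 1 + 1 + len(CALL_QUALITY_HEADS)  # fraud + sentiment + 6 QA heads
total_outputs = FRAUD_CLASSES + SENTIMENT_CLASSES + call_quality_outputs

print(call_quality_outputs, total_heads, total_outputs)  # 17 8 22
```
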
## Architecture

```
Input Text
    |
    v
[DistilBERT Encoder] ---- 6 layers, 768 hidden, 12 attention heads
    |
    v
[Attention Fusion] ------- task-conditioned attention with residual connections
    |
    +-- [Task 0: Fraud Head] ----------- Linear(768, 2) --> softmax
    +-- [Task 1: Sentiment Head] ------- Linear(768, 3) --> softmax
    +-- [Task 2: QA Scoring 6 Heads] --- 6x Linear(768, 1..5) --> sigmoid
```

**Key design choices:**

- **Shared encoder**: All 3 tasks share a single DistilBERT encoder, enabling knowledge transfer between fraud patterns, sentiment signals, and call quality indicators
- **Attention fusion**: A learned attention mechanism modulates the shared representation per task, allowing each task to attend to different parts of the encoder output while still benefiting from shared features
- **Residual connections**: Fusion output is added to the original representation (gate_init_value=0.5), ensuring stable training and allowing each task to fall back on the base representation
- **Multi-head QA**: Call quality uses 6 independent classification heads with different output sizes (1-5 binary outputs each), weighted by importance during training (resolution: 2.0x, listening: 1.5x, hold: 0.5x)

## Usage

### With JengaAI Framework (Recommended)

```bash
pip install torch transformers pydantic pyyaml huggingface_hub
```

```python
from huggingface_hub import snapshot_download
from jenga_ai.inference import InferencePipeline  # provided by the JengaAI framework

# Download model
model_path = snapshot_download(
    "Rogendo/JengaAI-multi-task-nlp",
    ignore_patterns=["checkpoints/*", "logs/*"],
)

# Load pipeline
pipeline = InferencePipeline.from_checkpoint(
    model_dir=model_path,
    config_path=f"{model_path}/experiment_config.yaml",
    device="auto",
)

# Run all 3 tasks at once
result = pipeline.predict("Suspicious M-Pesa transaction from unknown account")
print(result.to_json())

# Or run a single task
fraud_result = pipeline.predict(
    "WARNING: Your Safaricom account has been compromised. Send 5000 KES to unlock.",
    task_name="fraud_detection",
)
fraud = fraud_result.task_results["fraud_detection"].heads["fraud"]
print(f"Fraud: {fraud.prediction} (confidence: {fraud.confidence:.1%})")
# Fraud: 1 (confidence: 96.9%)
```

### Batch Inference

```python
texts = [
    "Suspicious M-Pesa notification asking me to send money.",
    "Normal airtime top-up of 100 KES via M-Pesa.",
    "WARNING: Your account has been compromised.",
]
results = pipeline.predict_batch(texts, task_name="fraud_detection", batch_size=32)
for text, result in zip(texts, results):
    fraud = result.task_results["fraud_detection"].heads["fraud"]
    label = "FRAUD" if fraud.prediction == 1 else "LEGIT"
    print(f"[{label} {fraud.confidence:.1%}] {text}")
```

### CLI

```bash
# Single text
python -m jenga_ai predict \
  --config experiment_config.yaml \
  --model-dir ./model \
  --text "Suspicious M-Pesa transaction from unknown account" \
  --format report

# Batch from file
python -m jenga_ai predict \
  --config experiment_config.yaml \
  --model-dir ./model \
  --input-file transcripts.jsonl \
  --output predictions.json \
  --batch-size 16
```

### Call Quality Scoring Example

```python
result = pipeline.predict(
    "Hello, welcome to Safaricom customer care. I understand you're having "
    "a billing issue. Let me look into that for you right away. I've found "
    "the discrepancy and corrected your balance. Is there anything else?",
    task_name="call_quality",
)
for head_name, head in result.task_results["call_quality"].heads.items():
    print(f"{head_name:16s} {head.prediction} (conf: {head.confidence:.2f})")
```

Output:

```
opening          {'greeting': True} (conf: 0.82)
listening        {'acknowledgment': True, 'empathy': True, ...} (conf: 0.75)
proactiveness    {'initiative': True, 'follow_up': True, 'suggestions': False} (conf: 0.58)
resolution       {'identified_issue': True, 'provided_solution': True, ...} (conf: 0.69)
hold             {'asked_permission': False, 'explained_reason': False} (conf: 0.02)
closing          {'proper_farewell': True} (conf: 0.52)
```

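
Each call quality head emits one logit per sub-metric, and the sigmoid of that logit is read as the probability that the behaviour is present. The sketch below shows one plausible way to turn raw logits into per-sub-metric booleans like those printed above. The 0.5 threshold, the `decode_head` helper, and the example logits are assumptions for illustration; the actual decoding and the definition of `confidence` are internal to the JengaAI framework.

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def decode_head(logits, sub_metrics, threshold=0.5):
    """Map one head's logits to {sub_metric: bool} with a fixed threshold (assumed)."""
    probs = [sigmoid(x) for x in logits]
    return {name: p >= threshold for name, p in zip(sub_metrics, probs)}


# Hypothetical logits for the 3-output "proactiveness" head.
proactiveness = decode_head(
    logits=[1.2, 0.4, -0.9],
    sub_metrics=["initiative", "follow_up", "suggestions"],
)
print(proactiveness)
# {'initiative': True, 'follow_up': True, 'suggestions': False}
```
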
### Low-Level Usage (Without JengaAI Framework)

If you only need the raw model weights and want to integrate into your own pipeline:

```python
import json

import torch
from transformers import AutoConfig, AutoModel, AutoTokenizer

# Load components
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
encoder_config = AutoConfig.from_pretrained("./model/encoder_config")
with open("./model/metadata.json") as f:
    metadata = json.load(f)

# Load full state dict
state_dict = torch.load("./model/model.pt", map_location="cpu", weights_only=True)

# Extract encoder weights: strip only the leading "encoder." prefix from each key
prefix = "encoder."
encoder_state = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
encoder = AutoModel.from_config(encoder_config)
encoder.load_state_dict(encoder_state)
encoder.eval()

# Run encoder
inputs = tokenizer("Suspicious transaction", return_tensors="pt", padding="max_length",
                   truncation=True, max_length=256)
with torch.no_grad():
    outputs = encoder(**inputs)
cls_embedding = outputs.last_hidden_state[:, 0]  # [1, 768]

# Extract fraud head weights (task 0, head "fraud").
# Note: this applies the head directly to the [CLS] embedding and skips the
# attention-fusion layer, so scores may differ from the full pipeline's output.
fraud_weight = state_dict["tasks.0.heads.fraud.1.weight"]  # [2, 768]
fraud_bias = state_dict["tasks.0.heads.fraud.1.bias"]  # [2]
logits = cls_embedding @ fraud_weight.T + fraud_bias
probs = torch.softmax(logits, dim=-1)
print(f"Fraud probability: {probs[0, 1].item():.4f}")
```

## Intended Use

### Primary Use Cases

- **M-Pesa Fraud Detection**: Classify M-Pesa transaction descriptions as fraudulent or legitimate. Designed for Safaricom and Kenyan mobile money contexts.
- **Customer Sentiment Monitoring**: Analyze customer feedback and communications for sentiment polarity (negative / neutral / positive).
- **Call Center Quality Assurance**: Score customer service call transcripts across 17 quality sub-metrics in 6 categories, to assist manual QA audits.
- **Multi-Signal Analysis**: Run all 3 tasks simultaneously on the same text to get a comprehensive analysis (is this a fraud attempt? what's the sentiment? how good was the agent's response?).

### Intended Users

- Kenyan telecommunications companies (Safaricom, Airtel Kenya)
- Financial institutions monitoring mobile money transactions
- Call center operations teams performing quality audits
- Security analysts processing incident reports
- NLP researchers working on African language and context models

### Downstream Use

The model can be integrated into:

- Real-time fraud alerting systems
- Call center dashboards with automated QA scoring
- Customer feedback analysis pipelines
- Security operations center (SOC) threat triage workflows
- Mobile money transaction monitoring platforms

## Out-of-Scope Use

- **Not for automated decision-making without human oversight.** This model should support human analysts, not replace them. High-stakes fraud decisions require human review.
- **Not for non-Kenyan contexts without retraining.** Entity names, transaction patterns, and call center norms are Kenyan-specific.
- **Not for languages other than English.** Some Swahili terms and Kenyan entity names (M-Pesa, Safaricom, KRA) appear in the training data, but the model is primarily English.
- **Not for legal evidence.** Model outputs are analytical signals, not forensic evidence.
- **Not for surveillance of individuals.** The model analyzes text content, not identity.

## Bias, Risks, and Limitations

### Known Biases

- **Training data imbalance**: Fraud detection was trained on only 20 samples (16 train / 4 eval). The model achieves 1.0 F1 on eval, but this is likely due to the tiny eval set and potential overfitting. Real-world fraud patterns are far more diverse.
- **Sentiment data**: Only 15 samples, with accuracy stuck at 33.3% (random baseline for 3 classes). The sentiment head needs significantly more training data to be production-useful.
- **Call quality data**: 4,996 synthetic transcripts. While metrics are strong (0.65-0.97 F1), the synthetic nature means real-world transcripts with noise, code-switching (Swahili-English), and non-standard grammar may perform differently.
- **Geographic bias**: All training data reflects Kenyan contexts. The model may not generalize to other East African countries without adaptation.

### Risks

- **False positives in fraud detection**: Legitimate transactions flagged as fraud can block real users. Always use this model with human review for enforcement actions.
- **False negatives in fraud detection**: Sophisticated fraud patterns not in the training data will be missed. This model is one signal among many, not a standalone detector.
- **Over-reliance on QA scores**: Call quality scores should augment, not replace, human QA reviewers. Edge cases (cultural nuances, sarcasm, escalation scenarios) may be scored incorrectly.

### Recommendations

- Use fraud detection as a **triage signal** (flag for review), not an automatic block
- Retrain with production-scale data before deploying to production
- Monitor prediction confidence: route low-confidence predictions to human review using the built-in HITL routing (`enable_hitl=True`)
- Enable PII redaction (`enable_pii=True`) when processing real customer data
- Enable audit logging (`enable_audit=True`) for compliance and accountability

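
To make the confidence-routing recommendation concrete, here is a minimal sketch of entropy-based triage, the kind of uncertainty estimate this card attributes to the HITL feature. The helper names, the normalization, and the 0.5 threshold are assumptions for illustration; they are not the framework's actual `enable_hitl` logic.

```python
import math


def normalized_entropy(probs):
    """Shannon entropy of a probability vector, scaled to the range [0, 1]."""
    entropy = -sum(p * math.log(p) for p in probs if p > 0.0)
    return entropy / math.log(len(probs))


def needs_human_review(probs, max_entropy=0.5):
    """Flag a prediction for review when it is too uncertain (threshold is assumed)."""
    return normalized_entropy(probs) > max_entropy


confident = [0.031, 0.969]  # e.g. fraud head probabilities: normal / fraud
uncertain = [0.45, 0.55]

print(needs_human_review(confident))  # False
print(needs_human_review(uncertain))  # True
```

A threshold on normalized entropy behaves the same way for 2-class and 3-class heads, which is one reason to prefer it over a raw max-probability cutoff when several heads share one routing rule.
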
## Training Details

### Training Data

| Dataset | Task | Samples | Source |
|:--------|:-----|:--------|:-------|
| `sample_classification.jsonl` | Fraud Detection | 20 | Synthetic M-Pesa transaction descriptions |
| `sample_sentiment.jsonl` | Sentiment Analysis | 15 | Synthetic customer feedback |
| `synthetic_qa_metrics_data_v01x.json` | Call Quality | 4,996 | Synthetic call center transcripts with 17 binary QA labels |

**Train/eval split**: 80/20 random split (seed=42)

All datasets are synthetic, generated to reflect linguistic patterns in Kenyan telecommunications and financial services contexts. They contain English text with occasional Swahili terms and Kenyan-specific entities (M-Pesa, Safaricom, KRA, Kenyan phone numbers).

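
The 80/20 split leaves very few eval samples for the two small tasks, which limits how finely their metrics can be resolved. The sketch below works through the sizes. Rounding the eval share down is an assumption (it reproduces the 16 / 4 fraud split stated in this card), and `split_sizes` is an illustrative helper, not a framework function.

```python
# Dataset sizes from the table above.
DATASETS = {"fraud_detection": 20, "sentiment": 15, "call_quality": 4996}


def split_sizes(n_samples: int, eval_one_in: int = 5):
    """Return (train, eval) sizes for an 80/20 split, rounding eval down (assumed)."""
    n_eval = n_samples // eval_one_in
    return n_samples - n_eval, n_eval


for task, n in DATASETS.items():
    n_train, n_eval = split_sizes(n)
    # One misclassified eval sample moves accuracy by 1 / n_eval.
    step = 100.0 / n_eval
    print(f"{task:16s} train={n_train:5d} eval={n_eval:4d} accuracy step={step:.1f}%")
```

With 4 fraud eval samples a single error costs 25 accuracy points, and with 3 sentiment eval samples it costs about 33 points, so those two headline scores carry very little statistical weight.
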
### Training Procedure

#### Preprocessing

- Tokenizer: `distilbert-base-uncased` WordPiece tokenizer
- Max sequence length: 256 tokens
- Padding: `max_length` (padded to 256)
- Truncation: enabled

#### Architecture

- **Encoder**: DistilBERT (6 layers, 768 hidden, 12 heads), 66.4M parameters
- **Fusion**: Attention fusion with residual connections, 1.2M parameters
- **Task heads**: 8 linear heads across 3 tasks, 17K parameters
- **Total**: 67.6M parameters (258MB on disk)

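
These figures can be cross-checked with a little arithmetic. The sketch assumes each head is a single `Linear(768, n)` layer with a bias, as in the architecture diagram, that weights are stored in FP32, and that the quoted disk size is measured in units of 2^20 bytes; the encoder and fusion counts are the rounded values from the list above.

```python
HIDDEN = 768
TOTAL_OUTPUTS = 22  # 2 fraud + 3 sentiment + 17 call quality

# One weight per (hidden unit, output) pair plus one bias per output.
head_params = HIDDEN * TOTAL_OUTPUTS + TOTAL_OUTPUTS

encoder_params = 66.4e6  # rounded figure from the list above
fusion_params = 1.2e6    # rounded figure from the list above
total_params = encoder_params + fusion_params + head_params

size_mib = total_params * 4 / 2**20  # FP32 uses 4 bytes per parameter

print(head_params)                   # 16918
print(f"{total_params / 1e6:.1f}M")  # 67.6M
print(f"{size_mib:.0f} MiB")         # 258 MiB
```
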
#### Training Hyperparameters

| Parameter | Value |
|:----------|:------|
| Learning rate | 2e-5 |
| Batch size | 16 |
| Epochs | 12 (best checkpoint at epoch 3) |
| Weight decay | 0.01 |
| Warmup steps | 20 |
| Max gradient norm | 1.0 |
| Optimizer | AdamW |
| Precision | FP32 |
| Task sampling | Proportional (temperature=2.0) |
| Early stopping patience | 5 epochs |
| Best model metric | eval_loss |

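
The task sampling row matters because the three datasets differ in size by more than two orders of magnitude. A common form of temperature-scaled proportional sampling draws task i with probability proportional to n_i^(1/T). Whether JengaAI uses exactly this formula is an assumption; the sketch only illustrates how a temperature of 2.0 would shift sampling toward the small tasks.

```python
def task_sampling_probs(sizes, temperature=2.0):
    """p_i proportional to n_i ** (1 / temperature); temperature=1 is plain proportional."""
    scaled = {task: n ** (1.0 / temperature) for task, n in sizes.items()}
    total = sum(scaled.values())
    return {task: s / total for task, s in scaled.items()}


# Dataset sizes from the training data table.
SIZES = {"fraud_detection": 20, "sentiment": 15, "call_quality": 4996}

plain = task_sampling_probs(SIZES, temperature=1.0)
tempered = task_sampling_probs(SIZES, temperature=2.0)
for task in SIZES:
    print(f"{task:16s} T=1: {plain[task]:.3f}   T=2: {tempered[task]:.3f}")
```

Under this formula the call quality task still dominates at T=2, but the fraud and sentiment tasks are sampled noticeably more often than their raw share of the data.
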
#### Task Loss Weights

| Head | Weight | Rationale |
|:-----|:-------|:----------|
| fraud | 1.0 | Standard |
| sentiment | 1.0 | Standard |
| opening | 1.0 | Standard |
| listening | 1.5 | Important quality dimension |
| proactiveness | 1.0 | Standard |
| resolution | 2.0 | Most critical quality dimension |
| hold | 0.5 | Less frequent in transcripts |
| closing | 1.0 | Standard |

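
As a rough illustration of how these weights act, the sketch below combines per-head losses into one scalar. A plain weighted sum is assumed (the card does not spell out the exact reduction), and the per-head loss values are made up for the example.

```python
# Weights copied from the table above.
HEAD_WEIGHTS = {
    "fraud": 1.0, "sentiment": 1.0, "opening": 1.0, "listening": 1.5,
    "proactiveness": 1.0, "resolution": 2.0, "hold": 0.5, "closing": 1.0,
}


def weighted_total(head_losses):
    """Sum per-head losses scaled by their weights (a plain weighted sum is assumed)."""
    return sum(HEAD_WEIGHTS[head] * loss for head, loss in head_losses.items())


# Hypothetical per-head losses for one batch of call quality data.
batch_losses = {"opening": 0.2, "listening": 0.4, "proactiveness": 0.5,
                "resolution": 0.3, "hold": 0.6, "closing": 0.25}
print(round(weighted_total(batch_losses), 3))  # 2.45
```

In this example the `hold` head has the largest raw loss but contributes only 0.3 to the total, while `resolution` contributes 0.6 from a raw loss half that size.
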
#### Training Loss Progression

| Epoch | Train Loss | Eval Loss | Status |
|:------|:-----------|:----------|:-------|
| 3 | 1.878 | **1.948** | Best checkpoint |
| 7 | 1.471 | 2.057 | Overfitting begins |
| 8 | 1.403 | 2.068 | Continued overfitting |

The best checkpoint was selected at epoch 3 based on eval_loss. Training continued to epoch 12 but eval loss increased after epoch 3, indicating overfitting, which is expected given the small fraud and sentiment datasets.

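
The checkpoint choice and the overfitting signal can both be read off the three reported rows. The sketch below selects the epoch with the lowest eval loss and prints the train/eval gap; it uses only the numbers in the table and is not the framework's checkpointing code.

```python
# (epoch, train_loss, eval_loss) rows from the table above.
HISTORY = [(3, 1.878, 1.948), (7, 1.471, 2.057), (8, 1.403, 2.068)]

best_epoch, _, best_eval = min(HISTORY, key=lambda row: row[2])
print(f"best checkpoint: epoch {best_epoch} (eval_loss={best_eval})")

gaps = []
for epoch, train_loss, eval_loss in HISTORY:
    gap = eval_loss - train_loss  # a widening gap is the overfitting signal
    gaps.append(gap)
    print(f"epoch {epoch}: eval - train = {gap:.3f}")
```

The gap grows from about 0.07 at epoch 3 to about 0.67 at epoch 8: train loss keeps falling while eval loss rises.
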
### Speeds, Sizes, Times

| Metric | Value |
|:-------|:------|
| Model size (disk) | 258 MB |
| Parameters | 67.6M |
| Inference latency (single task, CPU) | ~590 ms |
| Inference latency (all 3 tasks, CPU) | ~1,960 ms |
| Batch latency per sample (32 texts, single task, CPU) | ~647 ms/sample |
| Training time | ~5 minutes (CPU, 12 epochs) |

## Evaluation

### Metrics

All metrics are computed on the 20% held-out eval split.

**Fraud Detection** (binary classification):

| Metric | Value |
|:-------|:------|
| Accuracy | 1.000 |
| Precision | 1.000 |
| Recall | 1.000 |
| F1 | 1.000 |

**Sentiment Analysis** (3-class classification):

| Metric | Value |
|:-------|:------|
| Accuracy | 0.333 |
| Precision | 0.111 |
| Recall | 0.333 |
| F1 | 0.167 |

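
The four sentiment numbers are what macro-averaged metrics look like when a model predicts the same class for every input. The sketch below reproduces them for a hypothetical 3-sample eval set with one example per class. Both the eval set composition and the use of macro averaging are assumptions; the card does not confirm either.

```python
def macro_metrics(y_true, y_pred, labels):
    """Accuracy plus macro-averaged precision, recall and F1 (0 when undefined)."""
    precisions, recalls, f1s = [], [], []
    for label in labels:
        pairs = list(zip(y_true, y_pred))
        tp = sum(t == label and p == label for t, p in pairs)
        fp = sum(t != label and p == label for t, p in pairs)
        fn = sum(t == label and p != label for t, p in pairs)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        precisions.append(precision)
        recalls.append(recall)
        f1s.append(f1)
    n = len(labels)
    accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
    return accuracy, sum(precisions) / n, sum(recalls) / n, sum(f1s) / n


# Hypothetical eval set: one sample per class, model always predicts class 0.
y_true = [0, 1, 2]
y_pred = [0, 0, 0]
accuracy, precision, recall, f1 = macro_metrics(y_true, y_pred, labels=[0, 1, 2])
print(f"{accuracy:.3f} {precision:.3f} {recall:.3f} {f1:.3f}")  # 0.333 0.111 0.333 0.167
```
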
**Call Quality** (multi-label binary per head):

| Head | Precision | Recall | F1 |
|:-----|:----------|:-------|:---|
| Opening | 0.967 | 0.967 | **0.967** |
| Listening | 0.893 | 0.953 | **0.922** |
| Proactiveness | 0.746 | 0.868 | **0.802** |
| Resolution | 0.918 | 0.898 | **0.908** |
| Hold | 0.856 | 0.519 | **0.647** |
| Closing | 0.881 | 0.881 | **0.881** |

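
F1 is the harmonic mean of precision and recall, so each row can be sanity-checked from its first two columns. The sketch below recomputes F1 from the rounded table values; a difference of about 0.001 (as for the hold head) is consistent with precision and recall having been rounded to three decimals.

```python
# (precision, recall, reported F1) per head, from the table above.
CALL_QUALITY = {
    "opening": (0.967, 0.967, 0.967),
    "listening": (0.893, 0.953, 0.922),
    "proactiveness": (0.746, 0.868, 0.802),
    "resolution": (0.918, 0.898, 0.908),
    "hold": (0.856, 0.519, 0.647),
    "closing": (0.881, 0.881, 0.881),
}


def f1_score(precision, recall):
    """Harmonic mean of precision and recall."""
    return 2 * precision * recall / (precision + recall)


for head, (p, r, reported) in CALL_QUALITY.items():
    print(f"{head:14s} computed={f1_score(p, r):.3f} reported={reported:.3f}")
```

The hold row also shows why that head is weakest: precision is reasonable (0.856), but recall of 0.519 means roughly half of the true hold behaviours are missed.
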
### Results Summary

- **Fraud detection** achieves perfect metrics on the eval set, but this is a very small eval set (4 samples). Production deployment requires evaluation on a larger, more diverse dataset.
- **Sentiment analysis** performs at random baseline (33.3% accuracy for 3 classes), indicating the 15-sample dataset is insufficient. This head needs retraining with production data.
- **Call quality** shows strong performance across most heads (0.80-0.97 F1), with the "hold" category being the weakest (0.647 F1) due to fewer hold-related examples in the training data.

## Model Examination

### Attention Fusion

The attention fusion mechanism learns task-specific attention patterns over the shared encoder output. This allows:

- The fraud head to attend to transaction-related tokens (amounts, account references)
- The sentiment head to attend to opinion-bearing words
- The QA heads to attend to conversational flow patterns

The fusion uses a gated residual connection (initialized at 0.5), meaning each task's representation is a learned blend of the task-specific attended output and the original encoder output.

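
To illustrate the idea of a gated blend, here is a dependency-free sketch of single-head attention driven by a task query, followed by a gate that mixes the attended vector with the original pooled vector. The scaled dot-product scoring, the convex-combination form of the gate, and every name in the code are assumptions for illustration; the actual JengaAI fusion layer may be formulated differently.

```python
import math


def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def fuse(token_vectors, task_query, pooled, gate=0.5):
    """Blend a task-attended summary of the tokens with the pooled encoder vector."""
    dim = len(task_query)
    # Scaled dot-product score of each token against the task query.
    scores = [sum(q * t for q, t in zip(task_query, tok)) / math.sqrt(dim)
              for tok in token_vectors]
    weights = softmax(scores)
    attended = [sum(w * tok[i] for w, tok in zip(weights, token_vectors))
                for i in range(dim)]
    # Gated residual: gate=0 returns the original vector, gate=1 the attended one.
    fused = [gate * a + (1.0 - gate) * p for a, p in zip(attended, pooled)]
    return fused, weights


tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # toy 2-dimensional token vectors
pooled = tokens[0]                             # stand-in for the [CLS] vector
fraud_query = [2.0, 0.0]                       # hypothetical learned task query

fused, weights = fuse(tokens, fraud_query, pooled)
print([round(w, 3) for w in weights])
print([round(x, 3) for x in fused])
```

With the gate at its initial value of 0.5 the output is the midpoint between the attended and original vectors, so a task can move toward either extreme as the gate is learned.
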
### Security Features

When used with the JengaAI inference framework, the model supports:

- **PII Redaction**: Masks Kenyan-specific PII (phone numbers, national IDs, KRA PINs, M-Pesa transaction IDs) before inference
- **Explainability**: Token-level importance scores via attention analysis or gradient methods
- **Human-in-the-Loop**: Automatic routing of low-confidence predictions to human reviewers, using entropy-based uncertainty estimation
- **Audit Trail**: Tamper-evident logging of every inference call with SHA-256 hash chains

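
A SHA-256 hash chain makes a log tamper-evident because each record's hash covers the previous record's hash, so editing any earlier entry invalidates everything after it. The sketch below shows the general technique with the standard library; the record layout and function names are assumptions and do not describe JengaAI's actual audit log format.

```python
import hashlib
import json

GENESIS = "0" * 64  # placeholder "previous hash" for the first record


def append_entry(log, payload):
    """Append a record whose hash covers the payload and the previous record's hash."""
    prev_hash = log[-1]["hash"] if log else GENESIS
    body = json.dumps(payload, sort_keys=True)
    digest = hashlib.sha256((prev_hash + body).encode("utf-8")).hexdigest()
    log.append({"payload": payload, "prev_hash": prev_hash, "hash": digest})


def verify(log):
    """Recompute every hash; an edited or reordered record breaks the chain."""
    prev_hash = GENESIS
    for entry in log:
        body = json.dumps(entry["payload"], sort_keys=True)
        expected = hashlib.sha256((prev_hash + body).encode("utf-8")).hexdigest()
        if entry["prev_hash"] != prev_hash or entry["hash"] != expected:
            return False
        prev_hash = entry["hash"]
    return True


audit_log = []
append_entry(audit_log, {"task": "fraud_detection", "prediction": 1})
append_entry(audit_log, {"task": "sentiment", "prediction": 0})
print(verify(audit_log))  # True

audit_log[0]["payload"]["prediction"] = 0  # tamper with an earlier record
print(verify(audit_log))  # False
```
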
## Technical Specifications

### Model Architecture and Objective

- **Architecture**: DistilBERT encoder + attention fusion + multi-task heads
- **Encoder**: 6 transformer layers, 768 hidden size, 12 attention heads, 30,522 vocab
- **Fusion**: Single-head attention with residual gating
- **Objectives**: CrossEntropy (fraud, sentiment) + BCEWithLogits (call quality)

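
The two objectives differ because fraud and sentiment pick exactly one class, while each call quality sub-metric is an independent yes/no output. The sketch below implements both losses with the standard library to show what is being computed. It follows the usual definitions (and the mean reduction that PyTorch's `CrossEntropyLoss` and `BCEWithLogitsLoss` use by default); the example logits are made up.

```python
import math


def cross_entropy(logits, target):
    """Softmax cross-entropy for one example: -log softmax(logits)[target]."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]


def bce_with_logits(logits, targets):
    """Mean binary cross-entropy over independent sigmoid outputs (stable form)."""
    losses = [max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
              for x, y in zip(logits, targets)]
    return sum(losses) / len(losses)


fraud_loss = cross_entropy([0.3, 2.1], target=1)           # 2-class fraud head
hold_loss = bce_with_logits([-1.5, -2.0], targets=[0, 0])  # 2-output "hold" head
print(round(fraud_loss, 4), round(hold_loss, 4))
```
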
### Compute Infrastructure

#### Hardware

- Training: CPU (Intel/AMD, standard workstation)
- Inference: CPU or CUDA GPU

#### Software

- PyTorch 2.x
- Transformers 5.x
- JengaAI Framework V2
- Python 3.11+

## Environmental Impact

- **Hardware Type**: CPU (standard workstation)
- **Training Time**: ~5 minutes
- **Carbon Emitted**: Negligible (short training run on CPU)

## Citation

```bibtex
@software{jengaai2026,
  title = {JengaAI: Low-Code Multi-Task NLP for African Security Applications},
  author = {Rogendo},
  year = {2026},
  url = {https://huggingface.co/Rogendo/JengaAI-multi-task-nlp},
}
```

## Model Card Authors

Rogendo

## Model Card Contact

For questions, issues, or contributions: [GitHub Issues](https://github.com/Rogendo/JengaAI/issues)