"""Gradio demo for the HR-conversations multi-label classifier.

At import time this script downloads the pickled classifier head and its label
config from the Hugging Face Hub, loads the MiniLM sentence encoder, and builds
a Gradio Blocks UI whose "Classify" button calls ``classify``.
"""

import json
import pickle

import gradio as gr
import numpy as np
from huggingface_hub import hf_hub_download
from sentence_transformers import SentenceTransformer

# ── Load model artifacts ──────────────────────────────────────────────
print("Loading classifier...")
REPO_ID = "AurelPx/hr-conversations-classifier"
CLASSIFIER_PATH = hf_hub_download(REPO_ID, "setfit_classifier.pkl", repo_type="model")
LABEL_PATH = hf_hub_download(REPO_ID, "setfit_label_config.json", repo_type="model")

encoder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

# SECURITY: unpickling executes whatever code the file embeds. Only load this
# artifact from a repository you trust; consider pinning ``revision=`` in
# hf_hub_download so a later push to the repo cannot change it silently.
with open(CLASSIFIER_PATH, 'rb') as f:
    classifier = pickle.load(f)

with open(LABEL_PATH, encoding="utf-8") as f:
    config = json.load(f)
LABELS = config['label_names']
print(f"Loaded {len(LABELS)} labels: {LABELS}")


# ── Classification function ─────────────────────────────────────────────
def _positive_probabilities(proba) -> list:
    """Return the positive-class probability of each label for sample 0.

    Two multi-label ``predict_proba`` layouts are accepted:
      * a list with one ``(n_samples, 2)`` array per label (the layout the
        original ``p[0][1]`` indexing assumed) -> column 1 is "label present";
      * a single ``(n_samples, n_labels)`` array (one-vs-rest style).

    Raises:
        ValueError: if the output has neither layout.
    """
    arr = np.asarray(proba, dtype=float)
    if arr.ndim == 3:  # (n_labels, n_samples, 2)
        return arr[:, 0, 1].tolist()
    if arr.ndim == 2:  # (n_samples, n_labels)
        return arr[0].tolist()
    raise ValueError(f"Unexpected predict_proba output shape: {arr.shape}")


def classify(text: str, threshold: float):
    """Classify one conversation and format the results for the UI.

    Args:
        text: Raw conversation text from the input textbox.
        threshold: Minimum probability for a label to count as predicted.

    Returns:
        A ``(predicted, all_probabilities)`` pair of display strings: the
        labels at or above ``threshold`` joined by `` | ``, and every label
        with its probability, one per line, highest first. Empty or
        whitespace-only input returns a prompt message and an empty string.

    Raises:
        ValueError: if the classifier output does not line up with LABELS.
    """
    if not text or not text.strip():
        return "Please enter a conversation.", ""

    emb = encoder.encode([text])  # batch containing just this conversation
    probs = _positive_probabilities(classifier.predict_proba(emb))
    if len(probs) != len(LABELS):
        raise ValueError(
            f"Classifier returned {len(probs)} probabilities "
            f"but the label config lists {len(LABELS)} labels"
        )

    # All labels sorted by probability, highest first (ties keep label order).
    all_probs = sorted(zip(LABELS, probs), key=lambda item: item[1], reverse=True)
    predicted = [(label, prob) for label, prob in all_probs if prob >= threshold]

    if predicted:
        # Plain text: the output widget is a Textbox, which does not render Markdown.
        pred_str = " | ".join(f"{label} ({prob:.3f})" for label, prob in predicted)
    else:
        # :g trims float noise (e.g. 0.30000000000000004) should the slider produce it.
        pred_str = f"No labels above threshold {threshold:g}"

    probs_str = "\n".join(f"{label}: {prob:.3f}" for label, prob in all_probs)
    return pred_str, probs_str


# ── Gradio UI ──────────────────────────────────────────────────────────
with gr.Blocks(title="HR Conversations Classifier") as demo:
    gr.Markdown(f"""
    # 🏢 HR Conversations Multi-Label Classifier

    Classify HR support conversations into **{len(LABELS)} topic labels**.

    | Metric | Score |
    |--------|-------|
    | **F1-micro (5-fold CV)** | **0.7962 ± 0.0098** |
    | **F1-macro (5-fold CV)** | **0.7721** |

    **Model**: SETFit (MiniLM-L6-v2 + Logistic Regression)

    **Training data**: 5,100 conversations (5,000 synthetic + 100 real)
    """)

    with gr.Row():
        with gr.Column(scale=2):
            text_input = gr.Textbox(
                label="Conversation",
                placeholder=(
                    "Paste an HR conversation here...\n\nExample:\n"
                    "USER: I haven't received my payslip for March yet. "
                    "Could you please check what's going on?\n"
                    "AGENT: Good morning. I've checked the payroll system and it appears "
                    "your March payslip was generated on the 28th but there was a "
                    "distribution delay."
                ),
                lines=10,
            )
            threshold = gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="Classification Threshold")
            classify_btn = gr.Button("Classify", variant="primary")
        with gr.Column(scale=1):
            output_labels = gr.Textbox(label="Predicted Labels", lines=3)
            # One line per label plus slack (22 lines for 20 labels).
            output_probs = gr.Textbox(label="All Probabilities (sorted)", lines=len(LABELS) + 2)

    gr.Examples(
        examples=[
            [
                "USER: I haven't received my payslip for March yet. Could you please check what's going on?\n"
                "AGENT: Good morning. I've checked the payroll system and it appears your March payslip was "
                "generated on the 28th but there was a distribution delay. I've resent it to your registered email.",
                0.5,
            ],
            [
                "USER: I need to take sick leave starting today. I woke up with a terrible flu and cant come in. "
                "Whats the proceedure?\n"
                "AGENT: I'm sorry to hear that. Please rest and take care of yourself. You need to submit a sick "
                "leave request in the HR portal and upload your medical certificate within 48 hours.",
                0.5,
            ],
            [
                "USER: I would like to understand the rules around parental leave in France. My partner is "
                "expecting and I want to plan ahead.\n"
                "AGENT: Congratulations! Under French labor law, the second parent is entitled to 25 calendar "
                "days of paternity leave.",
                0.5,
            ],
            [
                "USER: I received an email asking me to complete a GDPR refresher training. Is this mandatory?\n"
                "AGENT: Yes, the GDPR refresher is mandatory for all employees and must be completed annually.",
                0.5,
            ],
            [
                "USER: I want to dispute my performance review. My manager gave me a rating that I believe is "
                "unfair and biased.\n"
                "AGENT: I'm sorry to hear that. You have the right to formally dispute your review. The first "
                "step is to submit a written appeal through the HR portal within 15 days.",
                0.5,
            ],
        ],
        inputs=[text_input, threshold],
        label="Try these examples",
    )

    classify_btn.click(
        fn=classify,
        inputs=[text_input, threshold],
        outputs=[output_labels, output_probs],
    )

if __name__ == "__main__":
    demo.launch()