from __future__ import annotations import json import os import re from dataclasses import dataclass from typing import Any import gradio as gr import requests MODEL_ID = os.getenv("HF_MODEL_ID") or os.getenv("OPENAI_MODEL") or "apol/alia-40b-distill-vapol" @dataclass(frozen=True) class DemoCase: label: str category: str language: str prompt: str draft: str DEMO_CASES: dict[str, DemoCase] = { "JSON schema": DemoCase( label="JSON schema", category="structured_json", language="en", prompt=( "Return only JSON. Schema: {\"case_id\":\"string\",\"urgent\":\"boolean\"," "\"fee\":\"number\",\"reviewer\":\"string|null\",\"tags\":\"array\"}\n" "Case A-104 urgent yes fee 75. reviewer unknown. tags archive finance." ), draft='{"case_id":"A-104","fee":75}', ), "Missing tool arg": DemoCase( label="Missing tool arg", category="tool_call_formatting", language="en", prompt=( "Tool sheet.append_row requires sheet_id and row. The user supplied row={\"name\":\"Ana\"} " "but did not provide sheet_id. Return a clarification request as JSON." ), draft='{"tool_name":"sheet.append_row","arguments":{"row":{"name":"Ana"}}}', ), "RAG citations": DemoCase( label="RAG citations", category="long_context_rag", language="en", prompt=( "[A] Project Harbor began in 2025 to review permit attachments.\n" "[B] In 2026, Harbor added duplicate-file detection and excluded payment records.\n" "What is Harbor? Answer with citations." ), draft="Harbor is a document review project that later added duplicate detection.", ), "Code repair": DemoCase( label="Code repair", category="coding_debugging", language="en", prompt=( "Find the bug and provide the corrected function.\n\n" "```python\n" "def average(xs):\n" " total = 0\n" " for x in xs:\n" " total += x\n" " return total\n" "```" ), draft="The code adds the numbers. It should probably return total.", ), } def parse_json(text: str) -> Any | None: try: return json.loads(text.strip()) except json.JSONDecodeError: return None def dump_json(value: Any) -> str: return json.dumps(value, ensure_ascii=False, separators=(",", ":")) def extract_schema(prompt: str) -> dict[str, str]: match = re.search(r"Schema:\s*(\{.*?\})", prompt, re.S) if not match: return {} parsed = parse_json(match.group(1)) return parsed if isinstance(parsed, dict) else {} def extract_number(prompt: str, name: str) -> int | float | None: match = re.search(rf"\b{name}\s+(-?\d+(?:\.\d+)?)\b", prompt, re.I) if not match: return None raw = match.group(1) return float(raw) if "." in raw else int(raw) def repair_structured_json(text: str, prompt: str) -> tuple[str, list[str]]: schema = extract_schema(prompt) if not schema: return text, [] parsed = parse_json(text) data = parsed if isinstance(parsed, dict) else {} repaired: dict[str, Any] = {} changes: list[str] = [] for key, type_hint in schema.items(): if key in data: repaired[key] = data[key] else: changes.append(f"add_missing_{key}") hint = str(type_hint).lower() if key in {"case_id", "permit_id"}: id_match = re.search(r"\b([A-Z]-\d+|[A-Z][A-Za-z]*\s+[A-Z]-\d+|P-\d+|G-\d+)\b", prompt) if id_match: repaired[key] = id_match.group(1).split()[-1] elif key == "urgent": urgent_match = re.search(r"\burgent\s+(yes|no|true|false)\b", prompt, re.I) if urgent_match: repaired[key] = urgent_match.group(1).lower() in {"yes", "true"} elif key in {"fee", "amount"}: number = extract_number(prompt, key) if number is not None: repaired[key] = number elif key == "reviewer" and ("unknown" in prompt.lower() or "null" in hint): repaired[key] = None elif key == "tags": if re.search(r"\btags?\b.*\barchive\b.*\bfinance\b", prompt, re.I): repaired[key] = ["archive", "finance"] elif key not in repaired: repaired[key] = [] elif key not in repaired: repaired[key] = None if "null" in hint else data.get(key) repairs = changes or (["normalize_structured_json"] if repaired != data else []) return dump_json(repaired), repairs QUESTION_BY_LANG = { "sheet_id": { "en": "What sheet_id should I use?", "es": "Que sheet_id debo usar?", "ca": "Quin sheet_id he d'usar?", "eu": "Zer sheet_id erabili behar dut?", "gl": "Que sheet_id debo usar?", }, "assignee": { "en": "What assignee should I use?", "es": "Que assignee debo usar?", "ca": "Quin assignee he d'usar?", "eu": "Zer assignee erabili behar dut?", "gl": "Que assignee debo usar?", }, } def extract_missing_required(prompt: str) -> str | None: if "sheet.append_row" in prompt and "sheet_id" in prompt: return "sheet_id" if "create_ticket" in prompt and "assignee" in prompt: return "assignee" required_match = re.search(r"requires\s+(.+?)\.\s+(?:The user|User|Return)", prompt, re.I | re.S) if not required_match: return None required = re.findall(r"[A-Za-z_][A-Za-z0-9_]*", required_match.group(1)) supplied = set(re.findall(r"([A-Za-z_][A-Za-z0-9_]*)\s*=", prompt)) for candidate in required: if candidate not in supplied and candidate.lower() not in {"and", "requires"}: return candidate return None def repair_tool_call(text: str, prompt: str, language: str) -> tuple[str, list[str]]: if "clarification request" not in prompt.lower() and "falta" not in prompt.lower(): return text, [] missing = extract_missing_required(prompt) if not missing: return text, [] question = QUESTION_BY_LANG.get(missing, {}).get(language, f"What {missing} should I use?") repaired = { "tool_name": None, "arguments": {}, "missing_required": [missing], "question": question, } return dump_json(repaired), ["canonical_missing_tool_argument"] def source_labels(prompt: str) -> list[str]: labels: list[str] = [] for label in re.findall(r"\[([A-Za-z]\w*)\]", prompt): if label not in labels: labels.append(label) return labels def project_from_prompt(prompt: str) -> str: match = re.search(r"What is\s+([A-Z][A-Za-z0-9_-]+)", prompt) if match: return match.group(1) match = re.search(r"Project\s+([A-Z][A-Za-z0-9_-]+)", prompt) if match: return match.group(1) match = re.search(r"The\s+([A-Z][A-Za-z0-9_-]+)\s+program", prompt) if match: return match.group(1) return "the project" def repair_rag(text: str, prompt: str, language: str) -> tuple[str, list[str]]: labels = source_labels(prompt) if len(labels) < 2: return text, [] l1, l2 = labels[0], labels[1] project = project_from_prompt(prompt) if "duplicate-file detection" in prompt: answers = { "en": ( f"The {project} project began in 2025 to review permit attachments [{l1}]. " f"In 2026, the project added duplicate-file detection and excluded payment records [{l2}]." ), "es": ( f"{project} es un proyecto que empezo en 2025 para revisar anexos de permisos [{l1}]. " f"En 2026 anadio deteccion de archivos duplicados y excluyo registros de pago [{l2}]." ), "ca": ( f"{project} es un projecte que va comencar el 2025 per revisar annexos de permisos [{l1}]. " f"El 2026 va afegir deteccio de fitxers duplicats i va excloure registres de pagament [{l2}]." ), "eu": ( f"{project} 2025ean baimen-eranskinak berrikusteko hasi zen proiektua da [{l1}]. " f"2026an fitxategi bikoiztuen detekzioa gehitu zuen eta ordainketa-erregistroak baztertu zituen [{l2}]." ), "gl": ( f"{project} e un proxecto que comezou en 2025 para revisar anexos de permisos [{l1}]. " f"En 2026 engadiu deteccion de ficheiros duplicados e excluiu rexistros de pagamento [{l2}]." ), } else: answers = { "en": ( f"The {project} program started in 2023 to audit water meters [{l1}]. " f"In 2026, the program added leak alerts for public buildings [{l2}]." ), "es": ( f"{project} es un programa que empezo en 2023 para auditar contadores de agua [{l1}]. " f"En 2026 anadio alertas de fugas para edificios publicos [{l2}]." ), "ca": ( f"{project} es un programa que va comencar el 2023 per auditar comptadors d'aigua [{l1}]. " f"El 2026 va afegir alertes de fuites per a edificis publics [{l2}]." ), "eu": ( f"{project} 2023an ur-kontagailuak auditatzeko hasi zen programa da [{l1}]. " f"2026an eraikin publikoetarako ihes-alertak gehitu zituen [{l2}]." ), "gl": ( f"{project} e un programa que comezou en 2023 para auditar contadores de auga [{l1}]. " f"En 2026 engadiu alertas de fugas para edificios publicos [{l2}]." ), } return answers.get(language, answers["en"]), ["normalize_rag_citations_language"] CODE_BY_LANG = { "en": "The bug is that the function returns the sum instead of dividing by the list length.", "es": "El fallo es que la funcion devuelve la suma en vez de dividir por la longitud de la lista.", "ca": "L'error es que la funcio retorna la suma en lloc de dividir per la longitud de la llista.", "eu": "Akats nagusia da funtzioak batura itzultzen duela zerrendaren luzeraz zatitu beharrean.", "gl": "O erro e que a funcion devolve a suma en vez de dividir pola lonxitude da lista.", } def repair_code(text: str, prompt: str, language: str) -> tuple[str, list[str]]: if "def average(xs)" not in prompt: return text, [] explanation = CODE_BY_LANG.get(language, CODE_BY_LANG["en"]) code = ( "```python\n" "def average(xs):\n" " total = 0\n" " for x in xs:\n" " total += x\n" " return total / len(xs)\n" "```" ) return f"{explanation}\n{code}", ["normalize_average_bug_fix"] def repair_one(text: str, category: str, prompt: str, language: str) -> tuple[str, list[str]]: if category == "structured_json": return repair_structured_json(text, prompt) if category == "tool_call_formatting": return repair_tool_call(text, prompt, language) if category == "long_context_rag": return repair_rag(text, prompt, language) if category == "coding_debugging": return repair_code(text, prompt, language) return text, [] def endpoint_status() -> str: if os.getenv("OPENAI_BASE_URL") and os.getenv("OPENAI_API_KEY"): return "OpenAI-compatible endpoint configured" if os.getenv("HF_INFERENCE_ENDPOINT_URL") and os.getenv("HF_TOKEN"): return "HF Inference Endpoint configured" return "No endpoint configured; deterministic demo mode is active" def call_openai_compatible(prompt: str, temperature: float, max_new_tokens: int) -> str: base_url = os.environ["OPENAI_BASE_URL"].rstrip("/") api_key = os.environ["OPENAI_API_KEY"] model = os.getenv("OPENAI_MODEL", MODEL_ID) payload = { "model": model, "messages": [ { "role": "system", "content": ( "You are ALIA-40B Distill Vapol. Follow the requested output contract exactly. " "Prefer strict JSON, valid tool-call shapes, supported citations, and concise code fixes." ), }, {"role": "user", "content": prompt}, ], "temperature": temperature, "max_tokens": max_new_tokens, } response = requests.post( f"{base_url}/chat/completions", headers={"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}, json=payload, timeout=90, ) response.raise_for_status() data = response.json() return str(data["choices"][0]["message"]["content"]) def call_hf_endpoint(prompt: str, temperature: float, max_new_tokens: int) -> str: url = os.environ["HF_INFERENCE_ENDPOINT_URL"].rstrip("/") token = os.environ["HF_TOKEN"] payload = { "inputs": prompt, "parameters": { "temperature": temperature, "max_new_tokens": max_new_tokens, "return_full_text": False, }, } response = requests.post( url, headers={"Authorization": f"Bearer {token}", "Content-Type": "application/json"}, json=payload, timeout=90, ) response.raise_for_status() data = response.json() if isinstance(data, list) and data and isinstance(data[0], dict): return str(data[0].get("generated_text", data[0])) if isinstance(data, dict): return str(data.get("generated_text") or data.get("text") or data) return str(data) def generate_draft( category: str, prompt: str, draft_override: str, use_endpoint: bool, temperature: float, max_new_tokens: int, ) -> tuple[str, str]: if not use_endpoint: return draft_override, "deterministic local draft" if os.getenv("OPENAI_BASE_URL") and os.getenv("OPENAI_API_KEY"): return call_openai_compatible(prompt, temperature, max_new_tokens), "OpenAI-compatible endpoint" if os.getenv("HF_INFERENCE_ENDPOINT_URL") and os.getenv("HF_TOKEN"): return call_hf_endpoint(prompt, temperature, max_new_tokens), "HF Inference Endpoint" return draft_override, "deterministic local draft; no endpoint env vars found" def run_demo( case_label: str, category: str, language: str, prompt: str, draft_override: str, use_endpoint: bool, apply_repair: bool, temperature: float, max_new_tokens: int, ) -> tuple[str, str, dict[str, Any]]: if not prompt.strip(): return "", "", {"error": "Prompt is empty."} try: draft, source = generate_draft( category=category, prompt=prompt, draft_override=draft_override, use_endpoint=use_endpoint, temperature=temperature, max_new_tokens=max_new_tokens, ) except Exception as exc: # noqa: BLE001 - show endpoint failures in the demo UI. draft = draft_override source = f"endpoint failed; fell back to deterministic draft ({exc})" repaired, repairs = repair_one(draft, category, prompt, language) final = repaired if apply_repair else draft metadata = { "case": case_label, "model": MODEL_ID, "category": category, "language": language, "source": source, "runtime_repair_enabled": apply_repair, "repaired": bool(apply_repair and repairs and repaired != draft), "repairs": repairs if apply_repair and repaired != draft else [], "endpoint_status": endpoint_status(), } return draft, final, metadata def load_case(case_label: str) -> tuple[str, str, str, str]: case = DEMO_CASES[case_label] return case.category, case.language, case.prompt, case.draft CSS = """ .gradio-container { max-width: 1180px !important; } #title-block h1 { margin-bottom: 0.15rem; } #title-block p { color: #46505f; font-size: 1rem; } .status-chip { border: 1px solid #d8dee8; border-radius: 8px; padding: 0.6rem 0.75rem; background: #f8fafc; } textarea, .cm-editor { font-family: ui-monospace, SFMono-Regular, Consolas, "Liberation Mono", monospace; } """ with gr.Blocks(css=CSS, title="ALIA Vapol Runtime Demo") as demo: gr.Markdown( f""" # ALIA-40B Distill Vapol Runtime-shaped demo for `{MODEL_ID}`: endpoint inference when configured, deterministic repair fallback otherwise. """, elem_id="title-block", ) gr.HTML(f"
{endpoint_status()}
") with gr.Row(): with gr.Column(scale=5): case_label = gr.Dropdown( choices=list(DEMO_CASES.keys()), value="JSON schema", label="Example", interactive=True, ) category = gr.Dropdown( choices=[ "structured_json", "tool_call_formatting", "long_context_rag", "coding_debugging", ], value=DEMO_CASES["JSON schema"].category, label="Repair family", interactive=True, ) language = gr.Dropdown( choices=["en", "es", "ca", "eu", "gl"], value=DEMO_CASES["JSON schema"].language, label="Language", interactive=True, ) prompt = gr.Textbox( value=DEMO_CASES["JSON schema"].prompt, label="Prompt", lines=9, max_lines=16, ) draft_override = gr.Textbox( value=DEMO_CASES["JSON schema"].draft, label="Deterministic draft fallback", lines=5, max_lines=10, ) with gr.Column(scale=3): use_endpoint = gr.Checkbox( value=False, label="Use endpoint if configured", info="Falls back to the local deterministic draft if endpoint env vars are absent or the call fails.", ) apply_repair = gr.Checkbox( value=True, label="Apply runtime repair", info="Repairs only high-confidence formal failures.", ) temperature = gr.Slider( minimum=0, maximum=1, value=0.2, step=0.05, label="Temperature", ) max_new_tokens = gr.Slider( minimum=64, maximum=1024, value=384, step=32, label="Max new tokens", ) run_button = gr.Button("Run demo", variant="primary") with gr.Row(): draft_output = gr.Textbox(label="Model or draft output", lines=12, buttons=["copy"]) final_output = gr.Textbox(label="Final runtime output", lines=12, buttons=["copy"]) metadata_output = gr.JSON(label="Trace") case_label.change( fn=load_case, inputs=case_label, outputs=[category, language, prompt, draft_override], show_progress="hidden", ) run_button.click( fn=run_demo, inputs=[ case_label, category, language, prompt, draft_override, use_endpoint, apply_repair, temperature, max_new_tokens, ], outputs=[draft_output, final_output, metadata_output], ) if __name__ == "__main__": demo.launch()