"""Gradio demo: Swahili <-> Kalenjin translation plus Kalenjin text-to-speech.

The module loads an NLLB base model with two PEFT adapters at import time,
optionally logs usage/feedback rows to a Google Sheet, and calls an external
HTTP service for speech synthesis.

NOTE(review): this file was reconstructed from a whitespace-collapsed source.
Statement order is preserved, but the nesting of UI containers was inferred
and should be confirmed against the intended layout.
"""

import json
import os
import tempfile
from datetime import datetime

import gradio as gr
import gspread
import requests
from oauth2client.service_account import ServiceAccountCredentials
from peft import PeftModel
from transformers import AutoModelForSeq2SeqLM, NllbTokenizerFast


# --- GOOGLE SHEETS SETUP ---
def get_sheet():
    """Return the first worksheet of the spreadsheet named by ``SHEET_NAME``.

    Credentials are read from the ``G_CREDS`` environment variable, which must
    hold a service-account key as a JSON string.

    Raises:
        RuntimeError: if ``G_CREDS`` or ``SHEET_NAME`` is unset or empty.
    """
    # Load credentials from Hugging Face Secrets
    raw_creds = os.getenv("G_CREDS")
    sheet_name = os.getenv("SHEET_NAME")
    if not raw_creds or not sheet_name:
        raise RuntimeError("Missing G_CREDS or SHEET_NAME secret.")

    creds_json = json.loads(raw_creds)
    scope = [
        "https://spreadsheets.google.com/feeds",
        "https://www.googleapis.com/auth/drive",
    ]
    creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_json, scope)
    client = gspread.authorize(creds)

    # Open the sheet
    return client.open(sheet_name).sheet1


# --- Configuration ---
BASE_MODEL_ID = "facebook/nllb-200-distilled-600M"
ADAPTER_SWA_KLN = "mutaician/nllb-swahili-kalenjin-v3"
ADAPTER_KLN_SWA = "mutaician/nllb-kalenjin-swahili-v1"
TTS_API_URL = os.getenv("TTS_API_URL", "").strip()
TTS_API_KEY = os.getenv("TTS_API_KEY", "").strip()

# Direction labels shown in the UI; they are also the values compared in the
# handlers and written to the sheet, so they are defined once here.
DIRECTION_SWA_KLN = "Swahili -> Kalenjin"
DIRECTION_KLN_SWA = "Kalenjin -> Swahili"

# NLLB language tags used for routing.
# NOTE(review): "luo_Latn" is used as the tag on the Kalenjin side; presumably
# the adapters were trained with that tag as a stand-in -- confirm.
SWAHILI_LANG_TAG = "swh_Latn"
KALENJIN_LANG_TAG = "luo_Latn"

CUSTOM_CSS = """
:root, body, .gradio-container {
    font-family: Inter, ui-sans-serif, system-ui, -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif !important;
    letter-spacing: 0 !important;
}
.gradio-container {
    max-width: 1180px !important;
    margin: 0 auto !important;
}
.app-title h1 {
    font-size: 2rem !important;
    line-height: 1.15 !important;
    margin-bottom: 0.25rem !important;
}
.app-subtitle {
    max-width: 850px;
    color: var(--body-text-color-subdued);
}
.section-note {
    color: var(--body-text-color-subdued);
    margin-bottom: 0.75rem;
}
.examples-grid {
    gap: 0.65rem !important;
}
.example-button {
    min-height: 46px !important;
    text-align: left !important;
    justify-content: flex-start !important;
    white-space: normal !important;
    line-height: 1.35 !important;
    font-size: 0.95rem !important;
    font-weight: 500 !important;
    border: 1px solid var(--border-color-primary) !important;
    background: var(--background-fill-secondary) !important;
}
.example-button:hover {
    border-color: var(--color-accent) !important;
}
.speech-box {
    border: 1px solid var(--border-color-primary);
    border-radius: 8px;
    padding: 0.75rem 0.85rem;
    background: var(--background-fill-secondary);
}
.speech-heading {
    margin: 0 0 0.55rem 0 !important;
    color: var(--body-text-color);
    font-weight: 650;
}
.speech-box .label-wrap {
    display: none !important;
}
.speech-box .wrap,
.speech-box .wrap.svelte-1cl284s,
.speech-box .audio-container,
.speech-box .component-wrapper,
.speech-box [data-testid="waveform"] {
    border: 0 !important;
    box-shadow: none !important;
}
.speech-box .block,
.speech-box .gradio-audio {
    border: 0 !important;
    box-shadow: none !important;
    background: transparent !important;
}
.speech-box audio {
    outline: 0 !important;
}
.status-line textarea {
    min-height: 42px !important;
    font-size: 0.95rem !important;
    text-align: left !important;
}
"""

# --- Load Model (Global / At Startup) ---
print(f"Loading Base Model: {BASE_MODEL_ID}...")
base_model = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL_ID)
tokenizer = NllbTokenizerFast.from_pretrained(BASE_MODEL_ID)

print("Loading Adapters...")
model = PeftModel.from_pretrained(base_model, ADAPTER_SWA_KLN, adapter_name="swa_kln")
model.load_adapter(ADAPTER_KLN_SWA, adapter_name="kln_swa")
print("Models loaded successfully!")


# --- Helper Functions ---
def save_to_sheet(input_text, output_text, type_val, correction="", direction=""):
    """Append one log row to the Google Sheet.

    Args:
        input_text: text the user entered.
        output_text: text the model produced (may be empty).
        type_val: row type written to the "Type" column; callers in this file
            pass "inference", "feedback" or "tts".
        correction: user-supplied corrected translation, if any.
        direction: translation direction label, if any.

    Returns:
        True if the row was appended, False if anything went wrong. Errors are
        printed rather than raised so that logging never breaks a request.
    """
    try:
        sheet = get_sheet()
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        # Append row to Google Sheet
        # Columns: Timestamp, Input, Output, Type, Correction, Direction
        # Leaving existing format intact and adding direction at the very end
        row = [timestamp, input_text, output_text, type_val, correction, direction]
        sheet.append_row(row)
        return True
    except Exception as e:
        print(f"Google Sheet Error: {e}")
        return False


def translate(text, direction, allow_save):
    """Translate ``text`` in the given direction and optionally log the pair.

    Args:
        text: source sentence.
        direction: ``DIRECTION_SWA_KLN`` selects the Swahili -> Kalenjin
            adapter; any other value selects Kalenjin -> Swahili.
        allow_save: when truthy, the input/output pair is appended to the sheet.

    Returns:
        The decoded translation, or "" when ``text`` is empty or only
        whitespace (no model call and no logging happen in that case).
    """
    # Whitespace-only input carries nothing to translate; treat it like empty
    # input instead of running beam search and logging a blank row.
    if not text or not text.strip():
        return ""

    # 1. Prediction & Routing
    if direction == DIRECTION_SWA_KLN:
        model.set_adapter("swa_kln")
        tokenizer.src_lang = SWAHILI_LANG_TAG
        target_lang_id = tokenizer.convert_tokens_to_ids(KALENJIN_LANG_TAG)
    else:  # Kalenjin -> Swahili
        model.set_adapter("kln_swa")
        tokenizer.src_lang = KALENJIN_LANG_TAG
        target_lang_id = tokenizer.convert_tokens_to_ids(SWAHILI_LANG_TAG)

    inputs = tokenizer(text, return_tensors="pt")
    generated_tokens = model.generate(
        **inputs,
        # Forcing the first generated token to the target language tag tells
        # the model which language to produce.
        forced_bos_token_id=target_lang_id,
        max_length=256,
        num_beams=5,
        early_stopping=True,
    )
    result = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]

    # 2. Logging
    if allow_save:
        save_to_sheet(text, result, "inference", "", direction)

    return result


def submit_feedback(input_text, model_output, direction, user_correction):
    """Save a user correction and clear the correction box on success.

    Returns:
        A ``(status_message, correction_box_value)`` pair. The correction text
        is returned unchanged when nothing was saved, so the user does not
        lose what they typed.
    """
    if not user_correction or not input_text:
        return "Please ensure there is a translation and a correction.", user_correction

    success = save_to_sheet(input_text, model_output, "feedback", user_correction, direction)
    if success:
        return "Asante / Kongoi! Feedback saved.", ""
    return "Error saving. Try again", user_correction


def tts_endpoint_url():
    """Return the TTS synthesis URL, or "" when ``TTS_API_URL`` is not set.

    ``/synthesize`` is appended unless the configured URL already ends with it
    (ignoring trailing slashes), in which case the URL is returned as given.
    """
    if not TTS_API_URL:
        return ""
    base = TTS_API_URL.rstrip("/")
    if base.endswith("/synthesize"):
        return TTS_API_URL
    return f"{base}/synthesize"


def synthesize_kalenjin_speech(text, allow_save=False):
    """Request speech audio for ``text`` from the configured TTS service.

    Args:
        text: Kalenjin text; ``None`` is treated as empty and runs of
            whitespace are collapsed to single spaces before sending.
        allow_save: when truthy, the text is logged to the sheet as a "tts" row
            after a successful synthesis.

    Returns:
        ``(audio_path, status_message)``. ``audio_path`` is the path of an MP3
        temp file on success and ``None`` on any failure.
    """
    text = " ".join((text or "").strip().split())
    if not text:
        return None, "Enter Kalenjin text first."

    url = tts_endpoint_url()
    if not url:
        return None, "Kalenjin-TTS is not configured yet."

    headers = {}
    if TTS_API_KEY:
        headers["X-TTS-Key"] = TTS_API_KEY

    try:
        response = requests.post(
            url,
            json={
                "text": text,
                "nfe_step": 32,
                "cfg_strength": 2.0,
                "speed": 0.8,
                "remove_silence": False,
                "output_format": "mp3",
            },
            headers=headers,
            timeout=300,
        )
        if response.status_code != 200:
            try:
                detail = response.json()
            except ValueError:
                # Body was not JSON; log a short excerpt of the raw text instead.
                detail = response.text[:200]
            print(f"TTS API Error: {response.status_code} {detail}")
            return None, "Speech generation failed. Please try again."

        # delete=False: the path is handed back to the caller after the handle
        # is closed. The with-block guarantees the handle is closed even if the
        # write fails.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as audio_file:
            audio_file.write(response.content)

        if allow_save:
            save_to_sheet(text, "", "tts")

        elapsed = response.headers.get("X-TTS-Elapsed-Seconds")
        if elapsed:
            return audio_file.name, f"Speech generated in {elapsed}s."
        return audio_file.name, "Speech generated."
    except requests.Timeout:
        return None, "Speech generation timed out. Please try a shorter sentence."
    except Exception as e:
        # UI boundary: report a generic message and keep the details in the log.
        print(f"TTS Request Error: {e}")
        return None, "Speech service is unavailable right now."


def speak_translation(model_output, direction, allow_save):
    """Synthesize speech for a translation, but only when its output is Kalenjin."""
    if direction != DIRECTION_SWA_KLN:
        return None, "Speech is available when the output is Kalenjin."
    return synthesize_kalenjin_speech(model_output, allow_save)


def translation_speech_visibility(direction):
    """Return updates for (speech group, audio, status) after a direction change.

    The speech group is visible only for Swahili -> Kalenjin; the audio and
    status outputs are always reset.
    """
    return gr.update(visible=direction == DIRECTION_SWA_KLN), None, ""


def _translation_example_filler(direction, text):
    """Build a no-argument click handler that loads one translation example.

    The handler returns values for (direction radio, input box, speech group,
    audio, status); the last three come from ``translation_speech_visibility``
    so the show/hide rule lives in one place.
    """

    def fill():
        return (direction, text, *translation_speech_visibility(direction))

    return fill


def _tts_example_filler(text):
    """Build a no-argument click handler that loads one TTS example sentence."""

    def fill():
        return text

    return fill


# --- UI Construction ---
with gr.Blocks() as demo:
    gr.Markdown("# 🇰🇪 Kalenjin AI Translator & Speech (Research)", elem_classes=["app-title"])
    gr.Markdown(
        "Experimental Swahili <-> Kalenjin translation and Kalenjin text-to-speech. "
        "This is a research demo, so translations and speech may contain mistakes. "
        "Your optional saved examples help improve future models.",
        elem_classes=["app-subtitle"],
    )

    with gr.Tabs():
        with gr.Tab("Translate"):
            gr.Markdown(
                "Translate between Swahili and Kalenjin. "
                "If the output is Kalenjin, you can also generate speech for it.",
                elem_classes=["section-note"],
            )
            with gr.Row():
                with gr.Column():
                    direction_radio = gr.Radio(
                        choices=[DIRECTION_SWA_KLN, DIRECTION_KLN_SWA],
                        value=DIRECTION_SWA_KLN,
                        label="Choose translation direction",
                    )
                    input_text = gr.Textbox(
                        label="Enter text",
                        placeholder="Type the sentence you want to translate...",
                        lines=4,
                    )
                    allow_save = gr.Checkbox(
                        label="Allow us to save this use to help improve the AI",
                        value=True,
                    )
                    with gr.Row():
                        translate_btn = gr.Button("Translate", variant="primary")
                        clear_btn = gr.Button("Clear")

                with gr.Column():
                    output_text = gr.Textbox(
                        label="Translation",
                        interactive=False,
                        lines=4,
                        buttons=["copy"],
                    )
                    # NOTE(review): the status box is assumed to sit inside this
                    # group (and outside the speech box) -- confirm the layout.
                    with gr.Group(visible=True) as translation_speech_group:
                        speak_output_btn = gr.Button("Listen to Kalenjin")
                        with gr.Group(elem_classes=["speech-box"]):
                            gr.Markdown("Generated Kalenjin Speech", elem_classes=["speech-heading"])
                            translation_audio = gr.Audio(show_label=False, type="filepath")
                        tts_status = gr.Textbox(
                            show_label=False,
                            value="",
                            interactive=False,
                            lines=1,
                            elem_classes=["status-line"],
                        )

            gr.Markdown("### Try translation examples")
            gr.Markdown("Swahili to Kalenjin")
            with gr.Row(elem_classes=["examples-grid"]):
                ex_swa_1 = gr.Button("Habari yako", elem_classes=["example-button"])
                ex_swa_2 = gr.Button("Ninaenda nyumbani", elem_classes=["example-button"])
                ex_swa_3 = gr.Button("Leo mvua imenyesha sana", elem_classes=["example-button"])
            gr.Markdown("Kalenjin to Swahili")
            with gr.Row(elem_classes=["examples-grid"]):
                ex_kln_1 = gr.Button("Iyamunee", elem_classes=["example-button"])
                ex_kln_2 = gr.Button("Awendi gaa", elem_classes=["example-button"])

            with gr.Accordion("Help improve the translation", open=False):
                gr.Markdown(
                    "If the translation is wrong and you know the correct one, you can submit a correction."
                )
                correction_input = gr.Textbox(
                    label="Correct translation",
                    placeholder="Type what the translation should have been...",
                )
                feedback_btn = gr.Button("Submit Correction")
                feedback_msg = gr.Label(label="Status", value="")

        with gr.Tab("Kalenjin Text-to-Speech"):
            gr.Markdown(
                "Enter Kalenjin text and generate speech using the experimental Kalenjin-TTS model. "
                "First generation may take longer while the speech model starts up. "
                "Short text may sound less natural than longer sentences.",
                elem_classes=["section-note"],
            )
            direct_tts_text = gr.Textbox(
                label="Enter Kalenjin text",
                placeholder="Type Kalenjin text to generate speech...",
                lines=4,
            )
            direct_tts_allow_save = gr.Checkbox(
                label="Allow us to save this text to help improve the AI",
                value=True,
            )
            direct_tts_btn = gr.Button("Generate Speech", variant="primary")
            # NOTE(review): the status box is assumed to sit outside the speech
            # box -- confirm the layout.
            with gr.Group(elem_classes=["speech-box"]):
                gr.Markdown("Generated Kalenjin Speech", elem_classes=["speech-heading"])
                direct_tts_audio = gr.Audio(show_label=False, type="filepath")
            direct_tts_status = gr.Textbox(
                show_label=False,
                value="",
                interactive=False,
                lines=1,
                elem_classes=["status-line"],
            )

            gr.Markdown("### Try speech examples")
            with gr.Column(elem_classes=["examples-grid"]):
                tts_ex_1 = gr.Button("Kenya ko emet ne kong'asis ne bo Afrika.", elem_classes=["example-button"])
                tts_ex_2 = gr.Button(
                    "Tuitos koriik tugul ak kosherekean kamunget ne bo arawet ab taman ak oeng.",
                    elem_classes=["example-button"],
                )
                tts_ex_3 = gr.Button("Somani lagok en sugul.", elem_classes=["example-button"])
                tts_ex_4 = gr.Button("Bandab sobet kotindoi kaimutik.", elem_classes=["example-button"])

    # --- Event Wiring ---
    translate_btn.click(
        translate,
        inputs=[input_text, direction_radio, allow_save],
        outputs=output_text,
    )
    clear_btn.click(
        lambda: ("", "", None, ""),
        outputs=[input_text, output_text, translation_audio, tts_status],
    )
    feedback_btn.click(
        submit_feedback,
        inputs=[input_text, output_text, direction_radio, correction_input],
        outputs=[feedback_msg, correction_input],
    )
    speak_output_btn.click(
        speak_translation,
        inputs=[output_text, direction_radio, allow_save],
        outputs=[translation_audio, tts_status],
    )
    direct_tts_btn.click(
        synthesize_kalenjin_speech,
        inputs=[direct_tts_text, direct_tts_allow_save],
        outputs=[direct_tts_audio, direct_tts_status],
    )
    direction_radio.change(
        translation_speech_visibility,
        inputs=[direction_radio],
        outputs=[translation_speech_group, translation_audio, tts_status],
    )

    # Each translation example sets the direction and input text, and resets
    # the speech widgets for that direction.
    translation_example_outputs = [
        direction_radio,
        input_text,
        translation_speech_group,
        translation_audio,
        tts_status,
    ]
    for example_btn, example_direction, example_text in (
        (ex_swa_1, DIRECTION_SWA_KLN, "Habari yako"),
        (ex_swa_2, DIRECTION_SWA_KLN, "Ninaenda nyumbani"),
        (ex_swa_3, DIRECTION_SWA_KLN, "Leo mvua imenyesha sana"),
        (ex_kln_1, DIRECTION_KLN_SWA, "Iyamunee"),
        (ex_kln_2, DIRECTION_KLN_SWA, "Awendi gaa"),
    ):
        example_btn.click(
            _translation_example_filler(example_direction, example_text),
            outputs=translation_example_outputs,
        )

    # Each TTS example just fills the text box.
    for example_btn, example_text in (
        (tts_ex_1, "Kenya ko emet ne kong'asis ne bo Afrika."),
        (tts_ex_2, "Tuitos koriik tugul ak kosherekean kamunget ne bo arawet ab taman ak oeng."),
        (tts_ex_3, "Somani lagok en sugul."),
        (tts_ex_4, "Bandab sobet kotindoi kaimutik."),
    ):
        example_btn.click(_tts_example_filler(example_text), outputs=[direct_tts_text])


if __name__ == "__main__":
    # ssr_mode=False fixes the blank white screen issue
    demo.launch(
        server_name="0.0.0.0",
        server_port=int(os.getenv("PORT", "7860")),
        ssr_mode=False,
        theme=gr.themes.Soft(),
        css=CUSTOM_CSS,
    )