| import gradio as gr |
| from transformers import NllbTokenizerFast, AutoModelForSeq2SeqLM |
| from peft import PeftModel |
| import os |
| import json |
| import tempfile |
| from datetime import datetime |
| import gspread |
| import requests |
| from oauth2client.service_account import ServiceAccountCredentials |
|
|
| |
def get_sheet():
    """Open the first worksheet of the logging spreadsheet.

    The service-account key is read as JSON text from the ``G_CREDS``
    environment variable and the spreadsheet title from ``SHEET_NAME``. A
    gspread client is authorized with that key on every call.

    Returns:
        The ``sheet1`` worksheet of the spreadsheet named by ``SHEET_NAME``.

    Raises:
        RuntimeError: If ``G_CREDS`` or ``SHEET_NAME`` is unset or empty.
    """
    service_account_json = os.getenv("G_CREDS")
    spreadsheet_title = os.getenv("SHEET_NAME")
    if not (service_account_json and spreadsheet_title):
        raise RuntimeError("Missing G_CREDS or SHEET_NAME secret.")

    key_data = json.loads(service_account_json)
    oauth_scopes = [
        "https://spreadsheets.google.com/feeds",
        "https://www.googleapis.com/auth/drive",
    ]
    credentials = ServiceAccountCredentials.from_json_keyfile_dict(key_data, oauth_scopes)

    spreadsheet = gspread.authorize(credentials).open(spreadsheet_title)
    return spreadsheet.sheet1
|
|
| |
# Identifiers handed to ``from_pretrained`` / ``load_adapter`` below: the base
# NLLB checkpoint and one adapter per translation direction.
BASE_MODEL_ID = "facebook/nllb-200-distilled-600M"
ADAPTER_SWA_KLN = "mutaician/nllb-swahili-kalenjin-v3"
ADAPTER_KLN_SWA = "mutaician/nllb-kalenjin-swahili-v1"
# Text-to-speech service settings, read once at import time. Both fall back to
# "" when the environment variable is unset; surrounding whitespace is stripped.
TTS_API_URL = os.getenv("TTS_API_URL", "").strip()
TTS_API_KEY = os.getenv("TTS_API_KEY", "").strip()
|
|
# Stylesheet passed to ``demo.launch(css=...)``. It sets the font stack and page
# width, and styles the title/subtitle, the example buttons, the speech player
# box (``.speech-box``) and the single-line status textboxes (``.status-line``).
CUSTOM_CSS = """
:root, body, .gradio-container {
font-family: Inter, ui-sans-serif, system-ui, -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif !important;
letter-spacing: 0 !important;
}
.gradio-container {
max-width: 1180px !important;
margin: 0 auto !important;
}
.app-title h1 {
font-size: 2rem !important;
line-height: 1.15 !important;
margin-bottom: 0.25rem !important;
}
.app-subtitle {
max-width: 850px;
color: var(--body-text-color-subdued);
}
.section-note {
color: var(--body-text-color-subdued);
margin-bottom: 0.75rem;
}
.examples-grid {
gap: 0.65rem !important;
}
.example-button {
min-height: 46px !important;
text-align: left !important;
justify-content: flex-start !important;
white-space: normal !important;
line-height: 1.35 !important;
font-size: 0.95rem !important;
font-weight: 500 !important;
border: 1px solid var(--border-color-primary) !important;
background: var(--background-fill-secondary) !important;
}
.example-button:hover {
border-color: var(--color-accent) !important;
}
.speech-box {
border: 1px solid var(--border-color-primary);
border-radius: 8px;
padding: 0.75rem 0.85rem;
background: var(--background-fill-secondary);
}
.speech-heading {
margin: 0 0 0.55rem 0 !important;
color: var(--body-text-color);
font-weight: 650;
}
.speech-box .label-wrap {
display: none !important;
}
.speech-box .wrap,
.speech-box .wrap.svelte-1cl284s,
.speech-box .audio-container,
.speech-box .component-wrapper,
.speech-box [data-testid="waveform"] {
border: 0 !important;
box-shadow: none !important;
}
.speech-box .block,
.speech-box .gradio-audio {
border: 0 !important;
box-shadow: none !important;
background: transparent !important;
}
.speech-box audio {
outline: 0 !important;
}
.status-line textarea {
min-height: 42px !important;
font-size: 0.95rem !important;
text-align: left !important;
}
"""
|
|
| |
# Load the base seq2seq model and its tokenizer once, at import time.
print(f"Loading Base Model: {BASE_MODEL_ID}...")
base_model = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL_ID)


tokenizer = NllbTokenizerFast.from_pretrained(BASE_MODEL_ID)


# Attach both adapters to the one base model under the names "swa_kln" and
# "kln_swa"; translate() selects between them with model.set_adapter().
print("Loading Adapters...")
model = PeftModel.from_pretrained(base_model, ADAPTER_SWA_KLN, adapter_name="swa_kln")
model.load_adapter(ADAPTER_KLN_SWA, adapter_name="kln_swa")


print("Models loaded successfully!")
|
|
|
|
| |
|
|
def save_to_sheet(input_text, output_text, type_val, correction="", direction=""):
    """Append one usage record to the Google Sheet on a best-effort basis.

    The appended row is ``[timestamp, input_text, output_text, type_val,
    correction, direction]``, where ``timestamp`` is ``datetime.now()``
    formatted as ``YYYY-MM-DD HH:MM:SS``.

    Returns:
        ``True`` if the row was appended, ``False`` if anything went wrong.
        Errors are printed instead of raised so that a logging failure never
        breaks the caller.
    """
    try:
        worksheet = get_sheet()
        logged_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        worksheet.append_row(
            [logged_at, input_text, output_text, type_val, correction, direction]
        )
    except Exception as e:
        print(f"Google Sheet Error: {e}")
        return False
    return True
|
|
def translate(text, direction, allow_save):
    """Translate ``text`` between Swahili and Kalenjin.

    Selecting a direction mutates the shared module-level ``model`` (active
    adapter) and ``tokenizer`` (``src_lang``) before generation.

    Args:
        text: Sentence to translate. ``None``, empty, or whitespace-only input
            returns ``""`` without running the model or logging anything.
        direction: ``"Swahili -> Kalenjin"`` selects the ``swa_kln`` adapter;
            any other value selects ``kln_swa``.
        allow_save: When truthy, the input/output pair is logged through
            ``save_to_sheet`` as an ``"inference"`` row.

    Returns:
        The decoded translation, or ``""`` for blank input.
    """
    # Whitespace-only text has nothing to translate: skip the beam search and
    # avoid writing a blank row to the sheet.
    if not text or not text.strip():
        return ""

    # NOTE(review): the Kalenjin side uses the "luo_Latn" language tag in both
    # directions -- presumably a stand-in tag; confirm against adapter training.
    if direction == "Swahili -> Kalenjin":
        model.set_adapter("swa_kln")
        tokenizer.src_lang = "swh_Latn"
        target_lang_id = tokenizer.convert_tokens_to_ids("luo_Latn")
    else:
        model.set_adapter("kln_swa")
        tokenizer.src_lang = "luo_Latn"
        target_lang_id = tokenizer.convert_tokens_to_ids("swh_Latn")

    inputs = tokenizer(text, return_tensors="pt")

    generated_tokens = model.generate(
        **inputs,
        # Force the first generated token to be the target-language tag.
        forced_bos_token_id=target_lang_id,
        max_length=256,
        num_beams=5,
        early_stopping=True,
    )
    result = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]

    if allow_save:
        # Best effort: save_to_sheet reports failure via its return value,
        # which is intentionally ignored so the translation is still returned.
        save_to_sheet(text, result, "inference", "", direction)

    return result
|
|
def submit_feedback(input_text, model_output, direction, user_correction):
    """Save a user-supplied correction and report the outcome.

    Returns:
        ``(status_message, correction_box_value)``. After a successful save the
        second item is ``""`` so the correction box is cleared; in every other
        case the user's correction is handed back unchanged.
    """
    if not (user_correction and input_text):
        return "Please ensure there is a translation and a correction.", user_correction

    saved = save_to_sheet(input_text, model_output, "feedback", user_correction, direction)
    if not saved:
        return "Error saving. Try again", user_correction
    return "Asante / Kongoi! Feedback saved.", ""
|
|
def tts_endpoint_url():
    """Return the TTS synthesis endpoint URL, or ``""`` when none is configured.

    ``TTS_API_URL`` may already end in ``/synthesize`` (trailing slashes
    ignored), in which case it is returned exactly as configured. Otherwise
    ``/synthesize`` is appended to the URL with trailing slashes removed.
    """
    if not TTS_API_URL:
        return ""

    trimmed = TTS_API_URL.rstrip("/")
    if not trimmed.endswith("/synthesize"):
        return f"{trimmed}/synthesize"
    return TTS_API_URL
|
|
def synthesize_kalenjin_speech(text, allow_save=False):
    """Request speech for ``text`` from the TTS service and save it as an mp3.

    Args:
        text: Kalenjin text. Runs of whitespace are collapsed to single spaces
            before sending; ``None`` or blank input is rejected.
        allow_save: When truthy, the text is logged through ``save_to_sheet``
            as a ``"tts"`` row after a successful synthesis.

    Returns:
        ``(audio_path, status_message)``. ``audio_path`` is the path of a
        temporary ``.mp3`` file on success and ``None`` on any failure; the
        message is always suitable for showing to the user.
    """
    text = " ".join((text or "").split())
    if not text:
        return None, "Enter Kalenjin text first."

    url = tts_endpoint_url()
    if not url:
        return None, "Kalenjin-TTS is not configured yet."

    # The API key header is only sent when a key is configured.
    headers = {"X-TTS-Key": TTS_API_KEY} if TTS_API_KEY else {}
    payload = {
        "text": text,
        "nfe_step": 32,
        "cfg_strength": 2.0,
        "speed": 0.8,
        "remove_silence": False,
        "output_format": "mp3",
    }

    try:
        response = requests.post(url, json=payload, headers=headers, timeout=300)
        if response.status_code != 200:
            # Log the service's own error detail, but show a generic message.
            try:
                detail = response.json()
            except ValueError:
                detail = response.text[:200]
            print(f"TTS API Error: {response.status_code} {detail}")
            return None, "Speech generation failed. Please try again."

        if not response.content:
            # A 200 with no body would otherwise be saved as an unplayable file.
            print("TTS API Error: 200 response with empty audio body")
            return None, "Speech generation failed. Please try again."

        # delete=False keeps the file on disk for the audio player; the
        # context manager guarantees the handle is closed even if write fails.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as audio_file:
            audio_file.write(response.content)
        audio_path = audio_file.name

        if allow_save:
            save_to_sheet(text, "", "tts")

        elapsed = response.headers.get("X-TTS-Elapsed-Seconds")
        if elapsed:
            return audio_path, f"Speech generated in {elapsed}s."
        return audio_path, "Speech generated."
    except requests.Timeout:
        return None, "Speech generation timed out. Please try a shorter sentence."
    except Exception as e:
        # Boundary handler for a UI callback: log and degrade gracefully.
        print(f"TTS Request Error: {e}")
        return None, "Speech service is unavailable right now."
|
|
def speak_translation(model_output, direction, allow_save):
    """Synthesize speech for a translation, but only when its output is Kalenjin.

    Returns:
        ``(audio_path, status_message)`` from ``synthesize_kalenjin_speech``
        when ``direction`` is ``"Swahili -> Kalenjin"``; otherwise ``None`` and
        an explanatory message.
    """
    output_is_kalenjin = direction == "Swahili -> Kalenjin"
    if output_is_kalenjin:
        return synthesize_kalenjin_speech(model_output, allow_save)
    return None, "Speech is available when the output is Kalenjin."
|
|
def translation_speech_visibility(direction):
    """Return UI updates for a change of translation direction.

    Returns:
        A 3-tuple: a visibility update that is visible only for
        ``"Swahili -> Kalenjin"``, then ``None`` and ``""``, which the caller
        maps onto the audio player and the status textbox to reset them.
    """
    show_speech_controls = direction == "Swahili -> Kalenjin"
    group_update = gr.update(visible=show_speech_controls)
    return group_update, None, ""
|
|
| |
# Build the Gradio UI: a "Translate" tab and a "Kalenjin Text-to-Speech" tab,
# followed by the event wiring for every button.
# NOTE(review): container nesting was reconstructed from a copy of this file
# whose indentation was lost -- confirm placement of the status textboxes and
# of the example/feedback sections against the running app.
with gr.Blocks() as demo:
    gr.Markdown("# 🇰🇪 Kalenjin AI Translator & Speech (Research)", elem_classes=["app-title"])
    gr.Markdown(
        """
        Experimental Swahili <-> Kalenjin translation and Kalenjin text-to-speech.
        This is a research demo, so translations and speech may contain mistakes.
        Your optional saved examples help improve future models.
        """,
        elem_classes=["app-subtitle"],
    )

    with gr.Tabs():
        # ---- Tab 1: translation, optional speech for Kalenjin output, feedback ----
        with gr.Tab("Translate"):
            gr.Markdown(
                "Translate between Swahili and Kalenjin. If the output is Kalenjin, you can also generate speech for it.",
                elem_classes=["section-note"],
            )

            with gr.Row():
                # Left column: direction, input text, consent checkbox, actions.
                with gr.Column():
                    direction_radio = gr.Radio(
                        choices=["Swahili -> Kalenjin", "Kalenjin -> Swahili"],
                        value="Swahili -> Kalenjin",
                        label="Choose translation direction",
                    )
                    input_text = gr.Textbox(
                        label="Enter text",
                        placeholder="Type the sentence you want to translate...",
                        lines=4,
                    )
                    allow_save = gr.Checkbox(
                        label="Allow us to save this use to help improve the AI",
                        value=True,
                    )
                    with gr.Row():
                        translate_btn = gr.Button("Translate", variant="primary")
                        clear_btn = gr.Button("Clear")

                # Right column: translation output and the speech controls.
                with gr.Column():
                    output_text = gr.Textbox(
                        label="Translation",
                        interactive=False,
                        lines=4,
                        buttons=["copy"],
                    )
                    # This group is shown/hidden by the direction handlers below.
                    with gr.Group(visible=True) as translation_speech_group:
                        speak_output_btn = gr.Button("Listen to Kalenjin")
                        with gr.Group(elem_classes=["speech-box"]):
                            gr.Markdown("Generated Kalenjin Speech", elem_classes=["speech-heading"])
                            translation_audio = gr.Audio(show_label=False, type="filepath")
                        tts_status = gr.Textbox(
                            show_label=False,
                            value="",
                            interactive=False,
                            lines=1,
                            elem_classes=["status-line"],
                        )

            # Example buttons; their click handlers are wired further down.
            gr.Markdown("### Try translation examples")
            gr.Markdown("Swahili to Kalenjin")
            with gr.Row(elem_classes=["examples-grid"]):
                ex_swa_1 = gr.Button("Habari yako", elem_classes=["example-button"])
                ex_swa_2 = gr.Button("Ninaenda nyumbani", elem_classes=["example-button"])
                ex_swa_3 = gr.Button("Leo mvua imenyesha sana", elem_classes=["example-button"])
            gr.Markdown("Kalenjin to Swahili")
            with gr.Row(elem_classes=["examples-grid"]):
                ex_kln_1 = gr.Button("Iyamunee", elem_classes=["example-button"])
                ex_kln_2 = gr.Button("Awendi gaa", elem_classes=["example-button"])

            # Collapsed-by-default section for submitting a corrected translation.
            with gr.Accordion("Help improve the translation", open=False):
                gr.Markdown(
                    "If the translation is wrong and you know the correct one, you can submit a correction."
                )
                correction_input = gr.Textbox(
                    label="Correct translation",
                    placeholder="Type what the translation should have been...",
                )
                feedback_btn = gr.Button("Submit Correction")
                feedback_msg = gr.Label(label="Status", value="")

        # ---- Tab 2: direct Kalenjin text-to-speech ----
        with gr.Tab("Kalenjin Text-to-Speech"):
            gr.Markdown(
                """
                Enter Kalenjin text and generate speech using the experimental Kalenjin-TTS model.

                First generation may take longer while the speech model starts up.
                Short text may sound less natural than longer sentences.
                """,
                elem_classes=["section-note"],
            )
            direct_tts_text = gr.Textbox(
                label="Enter Kalenjin text",
                placeholder="Type Kalenjin text to generate speech...",
                lines=4,
            )
            direct_tts_allow_save = gr.Checkbox(
                label="Allow us to save this text to help improve the AI",
                value=True,
            )
            direct_tts_btn = gr.Button("Generate Speech", variant="primary")
            with gr.Group(elem_classes=["speech-box"]):
                gr.Markdown("Generated Kalenjin Speech", elem_classes=["speech-heading"])
                direct_tts_audio = gr.Audio(show_label=False, type="filepath")
            direct_tts_status = gr.Textbox(
                show_label=False,
                value="",
                interactive=False,
                lines=1,
                elem_classes=["status-line"],
            )

            gr.Markdown("### Try speech examples")
            with gr.Column(elem_classes=["examples-grid"]):
                tts_ex_1 = gr.Button("Kenya ko emet ne kong'asis ne bo Afrika.", elem_classes=["example-button"])
                tts_ex_2 = gr.Button(
                    "Tuitos koriik tugul ak kosherekean kamunget ne bo arawet ab taman ak oeng.",
                    elem_classes=["example-button"],
                )
                tts_ex_3 = gr.Button("Somani lagok en sugul.", elem_classes=["example-button"])
                tts_ex_4 = gr.Button("Bandab sobet kotindoi kaimutik.", elem_classes=["example-button"])

    # ---- Event wiring ----
    # Translate the input text in the chosen direction.
    translate_btn.click(
        translate,
        inputs=[input_text, direction_radio, allow_save],
        outputs=output_text,
    )

    # Reset input, translation, audio player and speech status in one go.
    clear_btn.click(
        lambda: ("", "", None, ""),
        outputs=[input_text, output_text, translation_audio, tts_status],
    )

    # Save a correction; the handler returns (status message, new box value).
    feedback_btn.click(
        submit_feedback,
        inputs=[input_text, output_text, direction_radio, correction_input],
        outputs=[feedback_msg, correction_input],
    )

    # Speak the current translation (the handler refuses non-Kalenjin output).
    speak_output_btn.click(
        speak_translation,
        inputs=[output_text, direction_radio, allow_save],
        outputs=[translation_audio, tts_status],
    )

    # Speak free-form Kalenjin text from the second tab.
    direct_tts_btn.click(
        synthesize_kalenjin_speech,
        inputs=[direct_tts_text, direct_tts_allow_save],
        outputs=[direct_tts_audio, direct_tts_status],
    )

    # Show the speech group only for Swahili -> Kalenjin, and clear stale
    # audio/status whenever the direction changes.
    direction_radio.change(
        translation_speech_visibility,
        inputs=[direction_radio],
        outputs=[translation_speech_group, translation_audio, tts_status],
    )

    # Translation examples: each sets direction and input text, sets the speech
    # group's visibility to match the direction, and clears audio and status.
    ex_swa_1.click(
        lambda: ("Swahili -> Kalenjin", "Habari yako", gr.update(visible=True), None, ""),
        outputs=[direction_radio, input_text, translation_speech_group, translation_audio, tts_status],
    )
    ex_swa_2.click(
        lambda: ("Swahili -> Kalenjin", "Ninaenda nyumbani", gr.update(visible=True), None, ""),
        outputs=[direction_radio, input_text, translation_speech_group, translation_audio, tts_status],
    )
    ex_swa_3.click(
        lambda: ("Swahili -> Kalenjin", "Leo mvua imenyesha sana", gr.update(visible=True), None, ""),
        outputs=[direction_radio, input_text, translation_speech_group, translation_audio, tts_status],
    )
    ex_kln_1.click(
        lambda: ("Kalenjin -> Swahili", "Iyamunee", gr.update(visible=False), None, ""),
        outputs=[direction_radio, input_text, translation_speech_group, translation_audio, tts_status],
    )
    ex_kln_2.click(
        lambda: ("Kalenjin -> Swahili", "Awendi gaa", gr.update(visible=False), None, ""),
        outputs=[direction_radio, input_text, translation_speech_group, translation_audio, tts_status],
    )

    # Speech examples: each simply fills the TTS text box.
    tts_ex_1.click(lambda: "Kenya ko emet ne kong'asis ne bo Afrika.", outputs=[direct_tts_text])
    tts_ex_2.click(
        lambda: "Tuitos koriik tugul ak kosherekean kamunget ne bo arawet ab taman ak oeng.",
        outputs=[direct_tts_text],
    )
    tts_ex_3.click(lambda: "Somani lagok en sugul.", outputs=[direct_tts_text])
    tts_ex_4.click(lambda: "Bandab sobet kotindoi kaimutik.", outputs=[direct_tts_text])
|
|
if __name__ == "__main__":
    # The listening port comes from the PORT environment variable, defaulting
    # to 7860 when it is unset.
    listen_port = int(os.getenv("PORT", "7860"))

    # Theme and stylesheet are supplied at launch time; the server binds to
    # 0.0.0.0 so it accepts connections on every network interface.
    demo.launch(
        server_name="0.0.0.0",
        server_port=listen_port,
        ssr_mode=False,
        theme=gr.themes.Soft(),
        css=CUSTOM_CSS,
    )
|
|