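# Scraped from the Hugging Face file viewer: uploader "mutaician", commit 9d5cf50
# ("remove white labels"), file size 16.7 kB. The viewer's navigation labels
# (Raw, History, Blame, Contribute, Delete) are page chrome, not part of the program.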
import gradio as gr
from transformers import NllbTokenizerFast, AutoModelForSeq2SeqLM
from peft import PeftModel
import os
import json
import tempfile
from datetime import datetime
import gspread
import requests
from oauth2client.service_account import ServiceAccountCredentials
# --- GOOGLE SHEETS SETUP ---
def get_sheet():
    """Return the first worksheet of the logging spreadsheet.

    The service-account JSON is read from the ``G_CREDS`` environment variable
    and the spreadsheet title from ``SHEET_NAME``; a gspread client is then
    authorized and the spreadsheet is opened by name.

    Raises:
        RuntimeError: if either environment variable is missing or empty.
    """
    # Both values are expected to be provided as Hugging Face Secrets.
    credentials_blob = os.getenv("G_CREDS")
    spreadsheet_title = os.getenv("SHEET_NAME")
    if not (credentials_blob and spreadsheet_title):
        raise RuntimeError("Missing G_CREDS or SHEET_NAME secret.")

    key_info = json.loads(credentials_blob)
    scopes = [
        "https://spreadsheets.google.com/feeds",
        "https://www.googleapis.com/auth/drive",
    ]
    credentials = ServiceAccountCredentials.from_json_keyfile_dict(key_info, scopes)
    spreadsheet = gspread.authorize(credentials).open(spreadsheet_title)
    # Only the first worksheet is used.
    return spreadsheet.sheet1
# --- Configuration ---
# Identifiers passed to from_pretrained()/load_adapter() at startup: one shared base
# model plus one adapter per translation direction (presumably Hugging Face Hub repo ids).
BASE_MODEL_ID = "facebook/nllb-200-distilled-600M"
ADAPTER_SWA_KLN = "mutaician/nllb-swahili-kalenjin-v3"
ADAPTER_KLN_SWA = "mutaician/nllb-kalenjin-swahili-v1"
# TTS service settings are read from the environment; each falls back to "" when unset
# and has surrounding whitespace stripped. An empty URL means TTS is not configured.
TTS_API_URL = os.getenv("TTS_API_URL", "").strip()
TTS_API_KEY = os.getenv("TTS_API_KEY", "").strip()
# Stylesheet handed to demo.launch(css=...). The selectors target the elem_classes used
# by the layout: app-title, app-subtitle, section-note, examples-grid, example-button,
# speech-box, speech-heading and status-line.
# NOTE(review): ".svelte-1cl284s" looks like a build-generated class name and may not
# survive a Gradio upgrade - confirm it still matches.
CUSTOM_CSS = """
:root, body, .gradio-container {
font-family: Inter, ui-sans-serif, system-ui, -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif !important;
letter-spacing: 0 !important;
}
.gradio-container {
max-width: 1180px !important;
margin: 0 auto !important;
}
.app-title h1 {
font-size: 2rem !important;
line-height: 1.15 !important;
margin-bottom: 0.25rem !important;
}
.app-subtitle {
max-width: 850px;
color: var(--body-text-color-subdued);
}
.section-note {
color: var(--body-text-color-subdued);
margin-bottom: 0.75rem;
}
.examples-grid {
gap: 0.65rem !important;
}
.example-button {
min-height: 46px !important;
text-align: left !important;
justify-content: flex-start !important;
white-space: normal !important;
line-height: 1.35 !important;
font-size: 0.95rem !important;
font-weight: 500 !important;
border: 1px solid var(--border-color-primary) !important;
background: var(--background-fill-secondary) !important;
}
.example-button:hover {
border-color: var(--color-accent) !important;
}
.speech-box {
border: 1px solid var(--border-color-primary);
border-radius: 8px;
padding: 0.75rem 0.85rem;
background: var(--background-fill-secondary);
}
.speech-heading {
margin: 0 0 0.55rem 0 !important;
color: var(--body-text-color);
font-weight: 650;
}
.speech-box .label-wrap {
display: none !important;
}
.speech-box .wrap,
.speech-box .wrap.svelte-1cl284s,
.speech-box .audio-container,
.speech-box .component-wrapper,
.speech-box [data-testid="waveform"] {
border: 0 !important;
box-shadow: none !important;
}
.speech-box .block,
.speech-box .gradio-audio {
border: 0 !important;
box-shadow: none !important;
background: transparent !important;
}
.speech-box audio {
outline: 0 !important;
}
.status-line textarea {
min-height: 42px !important;
font-size: 0.95rem !important;
text-align: left !important;
}
"""
# --- Load Model (Global / At Startup) ---
# Executed at import time: the base model, the tokenizer and both adapters are loaded
# once into module-level globals that translate() reads on every call.
print(f"Loading Base Model: {BASE_MODEL_ID}...")
base_model = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL_ID)
tokenizer = NllbTokenizerFast.from_pretrained(BASE_MODEL_ID)
print("Loading Adapters...")
# Wrap the base model with the Swahili -> Kalenjin adapter, registered as "swa_kln".
model = PeftModel.from_pretrained(base_model, ADAPTER_SWA_KLN, adapter_name="swa_kln")
# Attach the Kalenjin -> Swahili adapter to the same wrapper under the name "kln_swa".
model.load_adapter(ADAPTER_KLN_SWA, adapter_name="kln_swa")
print("Models loaded successfully!")
# --- Helper Functions ---
def save_to_sheet(input_text, output_text, type_val, correction="", direction=""):
    """Append one log row to the Google Sheet on a best-effort basis.

    Row layout: Timestamp, Input, Output, Type, Correction, Direction.
    Direction is kept as the last column so the earlier five-column layout
    stays intact.

    Returns:
        True when the row was appended; False when anything went wrong
        (the error is printed and never propagated to the caller).
    """
    try:
        worksheet = get_sheet()
        stamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        worksheet.append_row(
            [stamp, input_text, output_text, type_val, correction, direction]
        )
    except Exception as e:
        # Logging must never break the user-facing request.
        print(f"Google Sheet Error: {e}")
        return False
    return True
def translate(text, direction, allow_save):
    """Translate *text* between Swahili and Kalenjin using the matching adapter.

    Args:
        text: Source sentence. Empty, ``None`` or whitespace-only input returns
            ``""`` without running the model or logging anything.
        direction: ``"Swahili -> Kalenjin"`` selects the ``swa_kln`` adapter; any
            other value is treated as Kalenjin -> Swahili and selects ``kln_swa``.
        allow_save: When true, the input/output pair is logged through
            ``save_to_sheet`` with type ``"inference"``.

    Returns:
        The decoded translation string.
    """
    # Reject blank input up front, in line with synthesize_kalenjin_speech: a
    # whitespace-only string would otherwise run beam search and be logged.
    if not text or not text.strip():
        return ""

    # 1. Prediction & Routing
    # "luo_Latn" is the language code used for the Kalenjin side in both directions.
    # NOTE(review): presumably a stand-in because the tokenizer has no Kalenjin
    # code - confirm against how the adapters were trained.
    if direction == "Swahili -> Kalenjin":
        adapter_name, source_lang, target_lang = "swa_kln", "swh_Latn", "luo_Latn"
    else:  # Kalenjin -> Swahili
        adapter_name, source_lang, target_lang = "kln_swa", "luo_Latn", "swh_Latn"

    # NOTE(review): the active adapter and tokenizer.src_lang are shared module
    # state; confirm requests for this handler are serialized by the server.
    model.set_adapter(adapter_name)
    tokenizer.src_lang = source_lang
    target_lang_id = tokenizer.convert_tokens_to_ids(target_lang)

    inputs = tokenizer(text, return_tensors="pt")
    generated_tokens = model.generate(
        **inputs,
        forced_bos_token_id=target_lang_id,  # force decoding to start in the target language
        max_length=256,
        num_beams=5,
        early_stopping=True,
    )
    result = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]

    # 2. Logging
    if allow_save:
        save_to_sheet(text, result, "inference", "", direction)
    return result
def submit_feedback(input_text, model_output, direction, user_correction):
    """Log a user-supplied correction and report the outcome.

    Returns:
        ``(status_message, correction_box_value)``. The correction box is emptied
        only when the row was saved; otherwise the user's text is handed back.
    """
    # Both the original input and the correction are required.
    if not (user_correction and input_text):
        return "Please ensure there is a translation and a correction.", user_correction

    saved = save_to_sheet(input_text, model_output, "feedback", user_correction, direction)
    if not saved:
        return "Error saving. Try again", user_correction
    return "Asante / Kongoi! Feedback saved.", ""
def tts_endpoint_url():
    """Return the URL of the TTS ``/synthesize`` endpoint, or ``""`` if unconfigured.

    A configured ``TTS_API_URL`` that already ends in ``/synthesize`` (ignoring
    trailing slashes) is returned unchanged; otherwise ``/synthesize`` is appended
    to the URL with its trailing slashes removed.
    """
    if not TTS_API_URL:
        return ""
    base = TTS_API_URL.rstrip("/")
    return TTS_API_URL if base.endswith("/synthesize") else f"{base}/synthesize"
def synthesize_kalenjin_speech(text, allow_save=False):
    """Generate Kalenjin speech for *text* through the external TTS service.

    Args:
        text: Kalenjin text. Runs of whitespace are collapsed to single spaces
            before the request; blank input is rejected.
        allow_save: When true, the text is logged through ``save_to_sheet`` with
            type ``"tts"`` after a successful synthesis.

    Returns:
        ``(audio_path, status_message)``. ``audio_path`` is the path of an MP3
        temp file, or ``None`` when nothing was generated. ``status_message`` is
        user-facing text; technical details are only printed.
    """
    text = " ".join((text or "").split())
    if not text:
        return None, "Enter Kalenjin text first."

    url = tts_endpoint_url()
    if not url:
        return None, "Kalenjin-TTS is not configured yet."

    # The API key header is only sent when a key is configured.
    headers = {"X-TTS-Key": TTS_API_KEY} if TTS_API_KEY else {}
    payload = {
        "text": text,
        "nfe_step": 32,
        "cfg_strength": 2.0,
        "speed": 0.8,
        "remove_silence": False,
        "output_format": "mp3",
    }

    try:
        response = requests.post(url, json=payload, headers=headers, timeout=300)

        if response.status_code != 200:
            # Prefer the structured error body; fall back to a short text excerpt.
            try:
                detail = response.json()
            except ValueError:
                detail = response.text[:200]
            print(f"TTS API Error: {response.status_code} {detail}")
            return None, "Speech generation failed. Please try again."

        audio_bytes = response.content
        if not audio_bytes:
            # A 200 with no payload would otherwise produce an unplayable empty file.
            print("TTS API Error: 200 response with an empty body")
            return None, "Speech generation failed. Please try again."

        # delete=False because the returned path is consumed after this function
        # returns; the with-block guarantees the handle is closed even if the
        # write fails.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as audio_file:
            audio_file.write(audio_bytes)
        audio_path = audio_file.name

        if allow_save:
            save_to_sheet(text, "", "tts")

        elapsed = response.headers.get("X-TTS-Elapsed-Seconds")
        if elapsed:
            return audio_path, f"Speech generated in {elapsed}s."
        return audio_path, "Speech generated."
    except requests.Timeout:
        return None, "Speech generation timed out. Please try a shorter sentence."
    except Exception as e:
        # Boundary handler: any other failure is logged and shown as a generic message.
        print(f"TTS Request Error: {e}")
        return None, "Speech service is unavailable right now."
def speak_translation(model_output, direction, allow_save):
    """Synthesize speech for a translation, but only when its output is Kalenjin.

    Returns the ``(audio_path, status_message)`` pair from
    ``synthesize_kalenjin_speech`` for the Swahili -> Kalenjin direction, and
    ``(None, explanation)`` for any other direction.
    """
    if direction == "Swahili -> Kalenjin":
        return synthesize_kalenjin_speech(model_output, allow_save)
    return None, "Speech is available when the output is Kalenjin."
def translation_speech_visibility(direction):
    """Build the UI updates to apply when the translation direction changes.

    Returns a 3-tuple: a visibility update that is visible only for
    ``"Swahili -> Kalenjin"``, then ``None`` and ``""`` to reset the audio
    player and the status text.
    """
    output_is_kalenjin = direction == "Swahili -> Kalenjin"
    return gr.update(visible=output_is_kalenjin), None, ""
# --- UI Construction ---
# Top-level Gradio app object; `demo` is launched under the __main__ guard.
# NOTE(review): the nesting of the layout containers below cannot be recovered from
# this flattened text - confirm it against the original indentation.
with gr.Blocks() as demo:
# Page title and introductory disclaimer (styled through app-title / app-subtitle).
gr.Markdown("# 🇰🇪 Kalenjin AI Translator & Speech (Research)", elem_classes=["app-title"])
gr.Markdown(
"""
Experimental Swahili <-> Kalenjin translation and Kalenjin text-to-speech.
This is a research demo, so translations and speech may contain mistakes.
Your optional saved examples help improve future models.
""",
elem_classes=["app-subtitle"],
)
# Two tabs follow: "Translate" and "Kalenjin Text-to-Speech".
with gr.Tabs():
with gr.Tab("Translate"):
gr.Markdown(
"Translate between Swahili and Kalenjin. If the output is Kalenjin, you can also generate speech for it.",
elem_classes=["section-note"],
)
with gr.Row():
with gr.Column():
# Inputs: direction selector (defaults to Swahili -> Kalenjin), source text,
# logging-consent checkbox and the action buttons.
direction_radio = gr.Radio(
choices=["Swahili -> Kalenjin", "Kalenjin -> Swahili"],
value="Swahili -> Kalenjin",
label="Choose translation direction",
)
input_text = gr.Textbox(
label="Enter text",
placeholder="Type the sentence you want to translate...",
lines=4,
)
# Consent to logging is ticked by default (value=True).
allow_save = gr.Checkbox(
label="Allow us to save this use to help improve the AI",
value=True,
)
with gr.Row():
translate_btn = gr.Button("Translate", variant="primary")
clear_btn = gr.Button("Clear")
with gr.Column():
# Read-only translation output with a copy button.
output_text = gr.Textbox(
label="Translation",
interactive=False,
lines=4,
buttons=["copy"],
)
# Speech controls for the translation output. The group's visibility is updated by
# translation_speech_visibility() and by the translation example buttons.
with gr.Group(visible=True) as translation_speech_group:
speak_output_btn = gr.Button("Listen to Kalenjin")
with gr.Group(elem_classes=["speech-box"]):
gr.Markdown("Generated Kalenjin Speech", elem_classes=["speech-heading"])
translation_audio = gr.Audio(show_label=False, type="filepath")
# One-line, read-only status message for speech generation.
tts_status = gr.Textbox(
show_label=False,
value="",
interactive=False,
lines=1,
elem_classes=["status-line"],
)
# Example buttons; their click handlers are attached in the event-wiring section.
gr.Markdown("### Try translation examples")
gr.Markdown("Swahili to Kalenjin")
with gr.Row(elem_classes=["examples-grid"]):
ex_swa_1 = gr.Button("Habari yako", elem_classes=["example-button"])
ex_swa_2 = gr.Button("Ninaenda nyumbani", elem_classes=["example-button"])
ex_swa_3 = gr.Button("Leo mvua imenyesha sana", elem_classes=["example-button"])
gr.Markdown("Kalenjin to Swahili")
with gr.Row(elem_classes=["examples-grid"]):
ex_kln_1 = gr.Button("Iyamunee", elem_classes=["example-button"])
ex_kln_2 = gr.Button("Awendi gaa", elem_classes=["example-button"])
# Collapsed-by-default section for submitting a corrected translation.
with gr.Accordion("Help improve the translation", open=False):
gr.Markdown(
"If the translation is wrong and you know the correct one, you can submit a correction."
)
correction_input = gr.Textbox(
label="Correct translation",
placeholder="Type what the translation should have been...",
)
feedback_btn = gr.Button("Submit Correction")
feedback_msg = gr.Label(label="Status", value="")
# Second tab: synthesize speech directly from user-entered Kalenjin text.
# NOTE(review): container nesting is not visible in this flattened text - confirm.
with gr.Tab("Kalenjin Text-to-Speech"):
gr.Markdown(
"""
Enter Kalenjin text and generate speech using the experimental Kalenjin-TTS model.
First generation may take longer while the speech model starts up.
Short text may sound less natural than longer sentences.
""",
elem_classes=["section-note"],
)
direct_tts_text = gr.Textbox(
label="Enter Kalenjin text",
placeholder="Type Kalenjin text to generate speech...",
lines=4,
)
# Consent to logging is ticked by default (value=True).
direct_tts_allow_save = gr.Checkbox(
label="Allow us to save this text to help improve the AI",
value=True,
)
direct_tts_btn = gr.Button("Generate Speech", variant="primary")
# Audio player plus a one-line, read-only status message for the direct TTS result.
with gr.Group(elem_classes=["speech-box"]):
gr.Markdown("Generated Kalenjin Speech", elem_classes=["speech-heading"])
direct_tts_audio = gr.Audio(show_label=False, type="filepath")
direct_tts_status = gr.Textbox(
show_label=False,
value="",
interactive=False,
lines=1,
elem_classes=["status-line"],
)
# Example sentences; each button's click handler copies its sentence into direct_tts_text.
gr.Markdown("### Try speech examples")
with gr.Column(elem_classes=["examples-grid"]):
tts_ex_1 = gr.Button("Kenya ko emet ne kong'asis ne bo Afrika.", elem_classes=["example-button"])
tts_ex_2 = gr.Button(
"Tuitos koriik tugul ak kosherekean kamunget ne bo arawet ab taman ak oeng.",
elem_classes=["example-button"],
)
tts_ex_3 = gr.Button("Somani lagok en sugul.", elem_classes=["example-button"])
tts_ex_4 = gr.Button("Bandab sobet kotindoi kaimutik.", elem_classes=["example-button"])
# --- Event Wiring ---
# Translate: run translate() on (text, direction, consent) and show the result.
translate_btn.click(
translate,
inputs=[input_text, direction_radio, allow_save],
outputs=output_text,
)
# Clear: reset the input text, translation output, audio player and TTS status.
clear_btn.click(
lambda: ("", "", None, ""),
outputs=[input_text, output_text, translation_audio, tts_status],
)
# Submit correction: submit_feedback() returns the status message and the new value
# of the correction box (emptied on success, unchanged otherwise).
feedback_btn.click(
submit_feedback,
inputs=[input_text, output_text, direction_radio, correction_input],
outputs=[feedback_msg, correction_input],
)
# Listen: synthesize speech for the current translation output.
speak_output_btn.click(
speak_translation,
inputs=[output_text, direction_radio, allow_save],
outputs=[translation_audio, tts_status],
)
# Direct TTS tab: synthesize speech for the entered Kalenjin text.
direct_tts_btn.click(
synthesize_kalenjin_speech,
inputs=[direct_tts_text, direct_tts_allow_save],
outputs=[direct_tts_audio, direct_tts_status],
)
# Changing direction updates the speech group's visibility and clears audio and status.
direction_radio.change(
translation_speech_visibility,
inputs=[direction_radio],
outputs=[translation_speech_group, translation_audio, tts_status],
)
# Translation examples: set direction and input text, set the speech group's visibility
# to match the direction (visible for Swahili -> Kalenjin), and clear audio and status.
ex_swa_1.click(
lambda: ("Swahili -> Kalenjin", "Habari yako", gr.update(visible=True), None, ""),
outputs=[direction_radio, input_text, translation_speech_group, translation_audio, tts_status],
)
ex_swa_2.click(
lambda: ("Swahili -> Kalenjin", "Ninaenda nyumbani", gr.update(visible=True), None, ""),
outputs=[direction_radio, input_text, translation_speech_group, translation_audio, tts_status],
)
ex_swa_3.click(
lambda: ("Swahili -> Kalenjin", "Leo mvua imenyesha sana", gr.update(visible=True), None, ""),
outputs=[direction_radio, input_text, translation_speech_group, translation_audio, tts_status],
)
ex_kln_1.click(
lambda: ("Kalenjin -> Swahili", "Iyamunee", gr.update(visible=False), None, ""),
outputs=[direction_radio, input_text, translation_speech_group, translation_audio, tts_status],
)
ex_kln_2.click(
lambda: ("Kalenjin -> Swahili", "Awendi gaa", gr.update(visible=False), None, ""),
outputs=[direction_radio, input_text, translation_speech_group, translation_audio, tts_status],
)
# Speech examples: copy the example sentence into the direct TTS text box.
tts_ex_1.click(lambda: "Kenya ko emet ne kong'asis ne bo Afrika.", outputs=[direct_tts_text])
tts_ex_2.click(
lambda: "Tuitos koriik tugul ak kosherekean kamunget ne bo arawet ab taman ak oeng.",
outputs=[direct_tts_text],
)
tts_ex_3.click(lambda: "Somani lagok en sugul.", outputs=[direct_tts_text])
tts_ex_4.click(lambda: "Bandab sobet kotindoi kaimutik.", outputs=[direct_tts_text])
if __name__ == "__main__":
    # Listen on all interfaces; the port comes from PORT and defaults to 7860.
    listen_port = int(os.getenv("PORT", "7860"))
    # ssr_mode=False fixes the blank white screen issue
    demo.launch(
        css=CUSTOM_CSS,
        theme=gr.themes.Soft(),
        ssr_mode=False,
        server_port=listen_port,
        server_name="0.0.0.0",
    )