from __future__ import annotations import difflib import gc import json import os import re import sys import unicodedata from dataclasses import dataclass from functools import lru_cache from pathlib import Path from typing import Any from urllib.parse import urlparse import boto3 import gradio as gr import numpy as np import pandas as pd from botocore.config import Config DEFAULT_ARTIFACT_PREFIX = ( "s3://131-component-staging/" "multilingual-static-word-embeddings/stage-6/" ) ARTIFACT_URI_ENV = "SPACE_ARTIFACT_S3_URI" ARTIFACT_PREFIX_ENV = "SPACE_ARTIFACT_S3_PREFIX" CACHE_ROOT = Path(os.getenv("ARTIFACT_CACHE_DIR", "/tmp/multilingual_space_artifacts")) REQUIRED_FILES = ("aligned_all.faiss", "all_metadata.jsonl", "config.json") DEFAULT_LANGUAGES = ["de", "en", "fr", "lb"] TRANSLATION_COLUMNS = [ "target_lang", "translation", "token", "score", "cosine", "rank", "bidirectional", "id", "source_vec_file", ] NEIGHBOR_COLUMNS = [ "lang", "word", "token", "score", "cosine", "rank", "id", ] VOCAB_COLUMNS = ["id", "lang", "surface", "token", "source_vec_file"] @dataclass class LangVectors: lang: str ids: np.ndarray metas: list[dict[str, Any]] vecs: np.ndarray @dataclass class RuntimeOptions: top_k: int min_score: float csls_k: int candidate_retrieval_k: int csls_prefetch_k: int bidirectional: bool score_method: str filter_stopwords: bool filter_bad_tokens: bool use_surface: bool @dataclass class Space: root: Path artifact_uri: str config: dict[str, Any] languages: list[str] by_lang: dict[str, LangVectors] lookup: dict[str, dict[str, list[int]]] id_to_location: dict[int, tuple[str, int]] def parse_s3_uri(uri: str) -> tuple[str, str]: parsed = urlparse(uri) if parsed.scheme != "s3" or not parsed.netloc: raise ValueError(f"Expected s3://bucket/key URI, got: {uri}") return parsed.netloc, parsed.path.lstrip("/") def make_s3_client(): access_key = os.getenv("SE_ACCESS_KEY") or os.getenv("AWS_ACCESS_KEY_ID") secret_key = os.getenv("SE_SECRET_KEY") or os.getenv("AWS_SECRET_ACCESS_KEY") endpoint_url = os.getenv("SE_HOST_URL") or os.getenv("AWS_ENDPOINT_URL") region = os.getenv("AWS_DEFAULT_REGION", "us-east-1") if endpoint_url and not endpoint_url.startswith(("http://", "https://")): endpoint_url = f"https://{endpoint_url}" kwargs: dict[str, Any] = { "service_name": "s3", "region_name": region, "config": Config( signature_version="s3v4", s3={"addressing_style": "path"}, retries={"max_attempts": 5, "mode": "standard"}, ), } if endpoint_url: kwargs["endpoint_url"] = endpoint_url if access_key and secret_key: kwargs["aws_access_key_id"] = access_key kwargs["aws_secret_access_key"] = secret_key return boto3.client(**kwargs) def find_latest_artifact_uri(client) -> str: explicit_uri = os.getenv(ARTIFACT_URI_ENV, "").strip() if explicit_uri: explicit_uri = explicit_uri.rstrip("/") if "multilingual_space_" in explicit_uri: return explicit_uri bucket, prefix = parse_s3_uri(explicit_uri) return find_latest_artifact_uri_under_prefix(client, bucket, prefix) prefix_uri = os.getenv(ARTIFACT_PREFIX_ENV, DEFAULT_ARTIFACT_PREFIX).strip() bucket, prefix = parse_s3_uri(prefix_uri) return find_latest_artifact_uri_under_prefix(client, bucket, prefix) def find_latest_artifact_uri_under_prefix(client, bucket: str, prefix: str) -> str: prefix = prefix.rstrip("/") + "/" pattern = re.compile(r"(.*multilingual_space_(\d{8}_\d{6})\.json)/config\.json$") candidates: list[tuple[str, str]] = [] paginator = client.get_paginator("list_objects_v2") for page in paginator.paginate(Bucket=bucket, Prefix=prefix): for obj in page.get("Contents", []): key = obj["Key"] match = pattern.match(key) if match: candidates.append((match.group(2), match.group(1))) if not candidates: raise FileNotFoundError( f"No multilingual_space_*.json/config.json found under s3://{bucket}/{prefix}" ) _, latest_key = sorted(candidates)[-1] return f"s3://{bucket}/{latest_key}" def artifact_cache_dir(artifact_uri: str) -> Path: _, key = parse_s3_uri(artifact_uri) name = Path(key.rstrip("/")).name return CACHE_ROOT / name def download_artifact() -> tuple[Path, str]: client = make_s3_client() artifact_uri = find_latest_artifact_uri(client) local_dir = artifact_cache_dir(artifact_uri) local_dir.mkdir(parents=True, exist_ok=True) bucket, key_prefix = parse_s3_uri(artifact_uri) key_prefix = key_prefix.rstrip("/") for filename in REQUIRED_FILES: local_path = local_dir / filename if local_path.exists() and local_path.stat().st_size > 0: continue key = f"{key_prefix}/{filename}" print(f"Downloading s3://{bucket}/{key} -> {local_path}", file=sys.stderr) client.download_file(bucket, key, str(local_path)) return local_dir, artifact_uri def strip_diacritics(text: str) -> str: return "".join( ch for ch in unicodedata.normalize("NFKD", text) if not unicodedata.combining(ch) ) def lookup_key(text: str) -> str: text = " ".join(text.strip().casefold().split()) return strip_diacritics(text) def is_good_token(token: str, min_len: int = 4) -> bool: if not token or len(token) < min_len or token.isdigit(): return False alpha = sum(ch.isalpha() for ch in token) if alpha < 2: return False return all(ch.isalnum() or ch in "-'_" for ch in token) def read_config(space_dir: Path) -> dict[str, Any]: with (space_dir / "config.json").open("r", encoding="utf-8") as f: return json.load(f) def read_metadata(space_dir: Path) -> tuple[list[dict[str, Any]], dict[str, list[int]]]: metadata_path = space_dir / "all_metadata.jsonl" metadata: list[dict[str, Any] | None] = [] ids_by_lang: dict[str, list[int]] = {} with metadata_path.open("r", encoding="utf-8") as f: for line in f: if not line.strip(): continue meta = json.loads(line) row_id = int(meta["id"]) while len(metadata) <= row_id: metadata.append(None) metadata[row_id] = meta ids_by_lang.setdefault(str(meta["lang"]), []).append(row_id) missing = [i for i, meta in enumerate(metadata) if meta is None] if missing: raise ValueError(f"Metadata ids are not contiguous; first missing id is {missing[0]}") return [m for m in metadata if m is not None], ids_by_lang def reconstruct_range(index: Any, start: int, count: int) -> np.ndarray: try: vecs = index.reconstruct_n(start, count) except TypeError: vecs = np.empty((count, index.d), dtype=np.float32) index.reconstruct_n(start, count, vecs) return np.ascontiguousarray(vecs, dtype=np.float32) def reconstruct_ids(index: Any, ids: list[int]) -> np.ndarray: if not ids: return np.empty((0, index.d), dtype=np.float32) start = ids[0] if ids == list(range(start, start + len(ids))): return reconstruct_range(index, start, len(ids)) vecs = np.empty((len(ids), index.d), dtype=np.float32) for local_i, row_id in enumerate(ids): try: vecs[local_i] = index.reconstruct(int(row_id)) except TypeError: index.reconstruct(int(row_id), vecs[local_i]) return np.ascontiguousarray(vecs, dtype=np.float32) def normalize_rows(vecs: np.ndarray) -> np.ndarray: norms = np.linalg.norm(vecs, axis=1, keepdims=True) return (vecs / (norms + 1e-12)).astype(np.float32, copy=False) def load_vectors_from_faiss(space_dir: Path, ids_by_lang: dict[str, list[int]]) -> dict[str, np.ndarray]: faiss_path = space_dir / "aligned_all.faiss" try: import faiss # type: ignore except ImportError as exc: raise RuntimeError( "faiss-cpu is required. The Space must install faiss-cpu from requirements.txt." ) from exc print(f"Loading FAISS index: {faiss_path}", file=sys.stderr) index = faiss.read_index(str(faiss_path)) vectors_by_lang: dict[str, np.ndarray] = {} for lang, ids in sorted(ids_by_lang.items()): print(f"Reconstructing {lang}: {len(ids)} vectors", file=sys.stderr) vectors_by_lang[lang] = normalize_rows(reconstruct_ids(index, ids)) del index gc.collect() return vectors_by_lang def build_lookup(languages: dict[str, LangVectors]) -> dict[str, dict[str, list[int]]]: lookup: dict[str, dict[str, list[int]]] = {} for lang, data in languages.items(): lang_lookup: dict[str, list[int]] = {} for global_id, meta in zip(data.ids.tolist(), data.metas): for value in (meta.get("token"), meta.get("surface")): if not value: continue lang_lookup.setdefault(lookup_key(str(value)), []).append(int(global_id)) lookup[lang] = lang_lookup return lookup @lru_cache(maxsize=1) def load_space() -> Space: space_dir, artifact_uri = download_artifact() config = read_config(space_dir) metadata, ids_by_lang = read_metadata(space_dir) vectors_by_lang = load_vectors_from_faiss(space_dir, ids_by_lang) by_lang: dict[str, LangVectors] = {} id_to_location: dict[int, tuple[str, int]] = {} languages = list(config.get("languages") or sorted(ids_by_lang)) for lang in languages: ids = ids_by_lang.get(lang) if not ids: continue metas = [metadata[row_id] for row_id in ids] vecs = vectors_by_lang[lang] by_lang[lang] = LangVectors( lang=lang, ids=np.asarray(ids, dtype=np.int64), metas=metas, vecs=vecs, ) for local_i, row_id in enumerate(ids): id_to_location[int(row_id)] = (lang, local_i) languages = [lang for lang in languages if lang in by_lang] return Space( root=space_dir, artifact_uri=artifact_uri, config=config, languages=languages, by_lang=by_lang, lookup=build_lookup(by_lang), id_to_location=id_to_location, ) def default_options(config: dict[str, Any]) -> RuntimeOptions: bidi_config = config.get("bidirectional_consistency") or {} return RuntimeOptions( top_k=int(config.get("top_k", 3)), min_score=float(config.get("min_score", 0.15)), csls_k=int(config.get("csls_k", 10)), candidate_retrieval_k=int(config.get("candidate_retrieval_k", 9)), csls_prefetch_k=int(config.get("csls_prefetch_k", 50)), bidirectional=bool(bidi_config.get("enabled", True)), score_method="csls", filter_stopwords=True, filter_bad_tokens=True, use_surface=True, ) def make_options( top_k: int, min_score: float, csls_k: int, candidate_retrieval_k: int, csls_prefetch_k: int, bidirectional: bool, score_method: str, filter_stopwords: bool, filter_bad_tokens: bool, use_surface: bool, ) -> RuntimeOptions: return RuntimeOptions( top_k=int(top_k), min_score=float(min_score), csls_k=int(csls_k), candidate_retrieval_k=int(candidate_retrieval_k), csls_prefetch_k=int(csls_prefetch_k), bidirectional=bool(bidirectional), score_method=str(score_method).lower(), filter_stopwords=bool(filter_stopwords), filter_bad_tokens=bool(filter_bad_tokens), use_surface=bool(use_surface), ) def top_indices(values: np.ndarray, k: int) -> np.ndarray: k = min(max(0, k), values.shape[0]) if k == 0: return np.empty((0,), dtype=np.int64) if k >= values.shape[0]: return np.argsort(-values) idx = np.argpartition(-values, k - 1)[:k] return idx[np.argsort(-values[idx])] def top_mean(values: np.ndarray, k: int) -> float: k = min(max(1, k), values.shape[0]) idx = top_indices(values, k) return float(values[idx].mean()) def candidate_allowed(meta: dict[str, Any], lang: str, space: Space, opts: RuntimeOptions) -> bool: token = str(meta.get("token") or "") if opts.filter_bad_tokens: min_len = int((space.config.get("filters") or {}).get("target_is_good_token_min_len", 4)) if not is_good_token(token, min_len): return False if opts.filter_stopwords: stopwords = set((space.config.get("stopwords") or {}).get(lang, [])) if token.lower() in stopwords: return False return True def rank_candidates( space: Space, query_vec: np.ndarray, source_lang: str, target_lang: str, opts: RuntimeOptions, *, apply_filters: bool = True, ) -> list[dict[str, Any]]: source = space.by_lang[source_lang] target = space.by_lang[target_lang] cosine_all = target.vecs @ query_vec prefetch_k = max(opts.candidate_retrieval_k, opts.csls_prefetch_k, opts.top_k) prefetch_ids = top_indices(cosine_all, min(prefetch_k, len(target.metas))) candidate_cosines = cosine_all[prefetch_ids] if opts.score_method == "csls": r_query = top_mean(cosine_all, opts.csls_k) candidate_vecs = target.vecs[prefetch_ids] reverse_sims = candidate_vecs @ source.vecs.T r_targets = np.asarray( [top_mean(reverse_sims[i], opts.csls_k) for i in range(reverse_sims.shape[0])], dtype=np.float32, ) scores = (2.0 * candidate_cosines - r_query - r_targets).astype(np.float32) else: scores = candidate_cosines.astype(np.float32) order = np.argsort(-scores)[: opts.candidate_retrieval_k] results: list[dict[str, Any]] = [] seen_surfaces: set[str] = set() dedupe_surfaces = bool( (space.config.get("filters") or {}).get("duplicate_target_surfaces_removed", True) ) for rank, pos in enumerate(order, 1): local_id = int(prefetch_ids[pos]) meta = target.metas[local_id] score = float(scores[pos]) if score < opts.min_score: continue if apply_filters and not candidate_allowed(meta, target_lang, space, opts): continue surface = str(meta.get("surface") or meta.get("token") or "") if dedupe_surfaces and surface in seen_surfaces: continue seen_surfaces.add(surface) results.append( { "rank": rank, "global_id": int(target.ids[local_id]), "local_id": local_id, "meta": meta, "score": score, "cosine": float(candidate_cosines[pos]), "bidirectional": None, } ) return results def get_meta(space: Space, global_id: int) -> dict[str, Any]: lang, local_id = space.id_to_location[int(global_id)] return space.by_lang[lang].metas[local_id] def get_vec(space: Space, global_id: int) -> np.ndarray: lang, local_id = space.id_to_location[int(global_id)] return space.by_lang[lang].vecs[local_id] def format_word(meta: dict[str, Any], opts: RuntimeOptions) -> str: if opts.use_surface: return str(meta.get("surface") or meta.get("token") or "") return str(meta.get("token") or meta.get("surface") or "") def suggestions(space: Space, lang: str, query: str, limit: int = 8) -> list[str]: lang_lookup = space.lookup.get(lang, {}) key = lookup_key(query) close_keys = difflib.get_close_matches(key, lang_lookup.keys(), n=limit, cutoff=0.72) labels = [] for close_key in close_keys: row_id = lang_lookup[close_key][0] meta = get_meta(space, row_id) label = str(meta.get("surface") or meta.get("token") or "") if label and label not in labels: labels.append(label) return labels def resolve_query(space: Space, lang: str, query: str) -> tuple[int, dict[str, Any], str]: if lang not in space.by_lang: raise ValueError(f"Unknown language {lang!r}. Available: {', '.join(space.languages)}") query = query.strip() if not query: raise ValueError("Enter a query word.") matches = space.lookup.get(lang, {}).get(lookup_key(query), []) if not matches: hint = suggestions(space, lang, query) if hint: raise LookupError(f"No exact match. Close matches: {', '.join(hint)}") raise LookupError(f"No exact token/surface match for {lang}:{query!r}") row_id = int(matches[0]) message = "" if len(matches) > 1: shown = [] for match_id in matches[:5]: meta = get_meta(space, match_id) shown.append(f"{meta.get('surface') or meta.get('token')} (id {match_id})") message = f"Matched {len(matches)} entries; using {shown[0]}." return row_id, get_meta(space, row_id), message def translation_dataframe() -> pd.DataFrame: return pd.DataFrame(columns=TRANSLATION_COLUMNS) def neighbor_dataframe() -> pd.DataFrame: return pd.DataFrame(columns=NEIGHBOR_COLUMNS) def vocabulary_dataframe() -> pd.DataFrame: return pd.DataFrame(columns=VOCAB_COLUMNS) def translate_ui( query: str, source_lang: str, target_langs: list[str] | None, top_k: int, min_score: float, csls_k: int, candidate_retrieval_k: int, csls_prefetch_k: int, bidirectional: bool, score_method: str, filter_stopwords: bool, filter_bad_tokens: bool, use_surface: bool, ) -> tuple[pd.DataFrame, str]: try: space = load_space() targets = target_langs or [lang for lang in space.languages if lang != source_lang] opts = make_options( top_k, min_score, csls_k, candidate_retrieval_k, csls_prefetch_k, bidirectional, score_method, filter_stopwords, filter_bad_tokens, use_surface, ) source_id, source_meta, match_message = resolve_query(space, source_lang, query) source_vec = get_vec(space, source_id) rows: list[dict[str, Any]] = [] grouped: list[str] = [ f"Source: `{source_lang}:{format_word(source_meta, opts)}` " f"(token `{source_meta.get('token')}`, id `{source_id}`)" ] if match_message: grouped.append(match_message) for target_lang in targets: if target_lang == source_lang or target_lang not in space.by_lang: continue candidates = rank_candidates(space, source_vec, source_lang, target_lang, opts) kept: list[dict[str, Any]] = [] for cand in candidates: if opts.bidirectional: reverse = rank_candidates( space, get_vec(space, int(cand["global_id"])), target_lang, source_lang, opts, ) reverse_ids = {int(item["global_id"]) for item in reverse} cand["bidirectional"] = source_id in reverse_ids if not cand["bidirectional"]: continue else: cand["bidirectional"] = False kept.append(cand) if len(kept) >= opts.top_k: break if kept: grouped.append(f"\n{target_lang}:") for i, cand in enumerate(kept, 1): meta = cand["meta"] word = format_word(meta, opts) grouped.append(f"{i}. {word} ({cand['score']:.4f})") rows.append( { "target_lang": target_lang, "translation": word, "token": meta.get("token"), "score": round(float(cand["score"]), 6), "cosine": round(float(cand["cosine"]), 6), "rank": int(cand["rank"]), "bidirectional": bool(cand["bidirectional"]), "id": int(cand["global_id"]), "source_vec_file": meta.get("source_vec_file"), } ) else: grouped.append(f"\n{target_lang}: no candidates after filters") return pd.DataFrame(rows, columns=TRANSLATION_COLUMNS), "\n".join(grouped) except Exception as exc: return translation_dataframe(), f"Error: {exc}" def nearest_ui( query: str, source_lang: str, neighbor_langs: list[str] | None, top_n: int, min_score: float, csls_k: int, score_method: str, include_source_language: bool, use_surface: bool, ) -> tuple[pd.DataFrame, str]: try: space = load_space() opts = make_options( top_n, min_score, csls_k, max(top_n + 5, 20), max(top_n + 5, 50), False, score_method, False, False, use_surface, ) source_id, source_meta, match_message = resolve_query(space, source_lang, query) source_vec = get_vec(space, source_id) targets = neighbor_langs or space.languages if not include_source_language: targets = [lang for lang in targets if lang != source_lang] rows: list[dict[str, Any]] = [] for target_lang in targets: if target_lang not in space.by_lang: continue candidates = rank_candidates( space, source_vec, source_lang, target_lang, opts, apply_filters=False, ) for cand in candidates: if int(cand["global_id"]) == source_id: continue meta = cand["meta"] rows.append( { "lang": target_lang, "word": format_word(meta, opts), "token": meta.get("token"), "score": round(float(cand["score"]), 6), "cosine": round(float(cand["cosine"]), 6), "rank": int(cand["rank"]), "id": int(cand["global_id"]), } ) if len([row for row in rows if row["lang"] == target_lang]) >= top_n: break rows = sorted(rows, key=lambda row: row["score"], reverse=True) status = ( f"Source: `{source_lang}:{format_word(source_meta, opts)}` " f"(token `{source_meta.get('token')}`, id `{source_id}`)" ) if match_message: status += f"\n\n{match_message}" return pd.DataFrame(rows, columns=NEIGHBOR_COLUMNS), status except Exception as exc: return neighbor_dataframe(), f"Error: {exc}" def browse_ui(lang: str, filter_text: str, limit: int) -> pd.DataFrame: try: space = load_space() if lang not in space.by_lang: return vocabulary_dataframe() needle = lookup_key(filter_text or "") rows = [] for row_id, meta in zip(space.by_lang[lang].ids.tolist(), space.by_lang[lang].metas): surface = str(meta.get("surface") or "") token = str(meta.get("token") or "") if needle and needle not in lookup_key(surface) and needle not in lookup_key(token): continue rows.append( { "id": int(row_id), "lang": lang, "surface": surface, "token": token, "source_vec_file": meta.get("source_vec_file"), } ) if len(rows) >= int(limit): break return pd.DataFrame(rows, columns=VOCAB_COLUMNS) except Exception: return vocabulary_dataframe() def config_markdown(space: Space) -> str: config = space.config vocab_sizes = config.get("vocab_sizes") or { lang: len(space.by_lang[lang].metas) for lang in space.languages } bidi = config.get("bidirectional_consistency") or {} lines = [ f"Artifact: `{space.artifact_uri}`", f"Created: `{config.get('created_at', 'unknown')}`", f"Languages: `{', '.join(space.languages)}`", f"Pivot language: `{config.get('pivot_lang', 'unknown')}`", f"Vector dim: `{config.get('vector_dim', 'unknown')}`", f"Top N vocab: `{config.get('top_n_vocab', 'unknown')}`", f"Output top: `{config.get('out_top', 'unknown')}`", f"Default top_k: `{config.get('top_k', 3)}`", f"Default min_score: `{config.get('min_score', 0.15)}`", f"Default csls_k: `{config.get('csls_k', 10)}`", f"Bidirectional consistency: `{bool(bidi.get('enabled', True))}`", "", "Vocabulary sizes:", ] for lang, size in sorted(vocab_sizes.items()): lines.append(f"- `{lang}`: `{size}`") return "\n".join(lines) def initialize_ui(): try: space = load_space() opts = default_options(space.config) source = space.config.get("pivot_lang", "de") if source not in space.languages: source = space.languages[0] targets = [lang for lang in space.languages if lang != source] status = f"Loaded `{space.artifact_uri}` with `{sum(len(v.metas) for v in space.by_lang.values())}` vectors." return ( status, gr.update(choices=space.languages, value=source), gr.update(choices=space.languages, value=targets), opts.top_k, opts.min_score, opts.csls_k, opts.candidate_retrieval_k, opts.csls_prefetch_k, opts.bidirectional, gr.update(choices=space.languages, value=source), gr.update(choices=space.languages, value=space.languages), opts.csls_k, gr.update(choices=space.languages, value=source), config_markdown(space), ) except Exception as exc: status = f"Load error: {exc}" return ( status, gr.update(choices=DEFAULT_LANGUAGES, value="de"), gr.update(choices=DEFAULT_LANGUAGES, value=["en", "fr", "lb"]), 3, 0.15, 10, 9, 50, True, gr.update(choices=DEFAULT_LANGUAGES, value="de"), gr.update(choices=DEFAULT_LANGUAGES, value=DEFAULT_LANGUAGES), 10, gr.update(choices=DEFAULT_LANGUAGES, value="de"), status, ) def update_targets(source_lang: str) -> gr.CheckboxGroup: try: space = load_space() return gr.update( choices=space.languages, value=[lang for lang in space.languages if lang != source_lang], ) except Exception: return gr.update( choices=DEFAULT_LANGUAGES, value=[lang for lang in DEFAULT_LANGUAGES if lang != source_lang], ) def update_neighbor_langs(source_lang: str, include_source: bool) -> gr.CheckboxGroup: try: space = load_space() choices = space.languages except Exception: choices = DEFAULT_LANGUAGES values = choices if include_source else [lang for lang in choices if lang != source_lang] return gr.update(choices=choices, value=values) css = """ .app-title h1 { margin-bottom: 0.15rem; } .status-line { font-size: 0.9rem; color: #475569; } """ with gr.Blocks(title="Multilingual Static Word Embeddings", css=css) as demo: gr.Markdown( "# Multilingual Static Word Embeddings\n" "Search the aligned FAISS space for cross-lingual word neighbors." ) status_md = gr.Markdown("Loading artifacts...", elem_classes=["status-line"]) with gr.Tab("Translate"): with gr.Row(): with gr.Column(scale=1, min_width=320): query = gr.Textbox(label="Query word", value="haus") source_lang = gr.Dropdown( label="Source language", choices=DEFAULT_LANGUAGES, value="de", ) target_langs = gr.CheckboxGroup( label="Target languages", choices=DEFAULT_LANGUAGES, value=["en", "fr", "lb"], ) translate_btn = gr.Button("Search", variant="primary") with gr.Accordion("Retrieval parameters", open=True): top_k = gr.Slider(1, 20, value=3, step=1, label="Top K") min_score = gr.Slider(-2.0, 2.0, value=0.15, step=0.01, label="Min score") score_method = gr.Radio( ["csls", "cosine"], value="csls", label="Score method", ) csls_k = gr.Slider(1, 50, value=10, step=1, label="CSLS K") candidate_retrieval_k = gr.Slider( 1, 100, value=9, step=1, label="Candidate retrieval K", ) csls_prefetch_k = gr.Slider( 10, 500, value=50, step=1, label="CSLS prefetch K", ) bidirectional = gr.Checkbox(value=True, label="Bidirectional consistency") filter_stopwords = gr.Checkbox(value=True, label="Filter stopwords") filter_bad_tokens = gr.Checkbox(value=True, label="Filter noisy tokens") use_surface = gr.Checkbox(value=True, label="Show surface forms") with gr.Column(scale=2): translate_summary = gr.Markdown() translation_results = gr.Dataframe( headers=TRANSLATION_COLUMNS, datatype=["str", "str", "str", "number", "number", "number", "bool", "number", "str"], interactive=False, wrap=True, ) with gr.Tab("Nearest Neighbors"): with gr.Row(): with gr.Column(scale=1, min_width=320): nn_query = gr.Textbox(label="Query word", value="haus") nn_source_lang = gr.Dropdown( label="Source language", choices=DEFAULT_LANGUAGES, value="de", ) nn_langs = gr.CheckboxGroup( label="Neighbor languages", choices=DEFAULT_LANGUAGES, value=DEFAULT_LANGUAGES, ) nn_top_n = gr.Slider(1, 50, value=20, step=1, label="Top N per language") nn_min_score = gr.Slider(-2.0, 2.0, value=-2.0, step=0.01, label="Min score") nn_score_method = gr.Radio(["csls", "cosine"], value="cosine", label="Score method") nn_csls_k = gr.Slider(1, 50, value=10, step=1, label="CSLS K") nn_include_source = gr.Checkbox(value=True, label="Include source language") nn_use_surface = gr.Checkbox(value=True, label="Show surface forms") nn_btn = gr.Button("Find neighbors", variant="primary") with gr.Column(scale=2): nn_summary = gr.Markdown() nn_results = gr.Dataframe( headers=NEIGHBOR_COLUMNS, datatype=["str", "str", "str", "number", "number", "number", "number"], interactive=False, wrap=True, ) with gr.Tab("Browse Vocabulary"): with gr.Row(): vocab_lang = gr.Dropdown(label="Language", choices=DEFAULT_LANGUAGES, value="de") vocab_filter = gr.Textbox(label="Filter", placeholder="Type part of a token or surface form") vocab_limit = gr.Slider(10, 500, value=100, step=10, label="Limit") vocab_results = gr.Dataframe( headers=VOCAB_COLUMNS, datatype=["number", "str", "str", "str", "str"], interactive=False, wrap=True, ) with gr.Tab("Artifact Info"): artifact_info = gr.Markdown("Loading config...") translate_inputs = [ query, source_lang, target_langs, top_k, min_score, csls_k, candidate_retrieval_k, csls_prefetch_k, bidirectional, score_method, filter_stopwords, filter_bad_tokens, use_surface, ] translate_btn.click( translate_ui, inputs=translate_inputs, outputs=[translation_results, translate_summary], ) query.submit( translate_ui, inputs=translate_inputs, outputs=[translation_results, translate_summary], ) source_lang.change(update_targets, inputs=source_lang, outputs=target_langs) nn_btn.click( nearest_ui, inputs=[ nn_query, nn_source_lang, nn_langs, nn_top_n, nn_min_score, nn_csls_k, nn_score_method, nn_include_source, nn_use_surface, ], outputs=[nn_results, nn_summary], ) nn_query.submit( nearest_ui, inputs=[ nn_query, nn_source_lang, nn_langs, nn_top_n, nn_min_score, nn_csls_k, nn_score_method, nn_include_source, nn_use_surface, ], outputs=[nn_results, nn_summary], ) nn_source_lang.change( update_neighbor_langs, inputs=[nn_source_lang, nn_include_source], outputs=nn_langs, ) nn_include_source.change( update_neighbor_langs, inputs=[nn_source_lang, nn_include_source], outputs=nn_langs, ) vocab_lang.change(browse_ui, inputs=[vocab_lang, vocab_filter, vocab_limit], outputs=vocab_results) vocab_filter.change(browse_ui, inputs=[vocab_lang, vocab_filter, vocab_limit], outputs=vocab_results) vocab_limit.change(browse_ui, inputs=[vocab_lang, vocab_filter, vocab_limit], outputs=vocab_results) demo.load( initialize_ui, outputs=[ status_md, source_lang, target_langs, top_k, min_score, csls_k, candidate_retrieval_k, csls_prefetch_k, bidirectional, nn_source_lang, nn_langs, nn_csls_k, vocab_lang, artifact_info, ], ).then( translate_ui, inputs=translate_inputs, outputs=[translation_results, translate_summary], ).then( browse_ui, inputs=[vocab_lang, vocab_filter, vocab_limit], outputs=vocab_results, ) if __name__ == "__main__": demo.queue().launch()