| from __future__ import annotations |
|
|
| import gc |
| import json |
| import os |
| import re |
| import sys |
| import unicodedata |
| from dataclasses import dataclass |
| from functools import lru_cache |
| from pathlib import Path |
| from typing import Any |
| from urllib.parse import urlparse |
|
|
| import boto3 |
| import gradio as gr |
| import numpy as np |
| from botocore.config import Config |
|
|
|
|
# S3 prefix searched for stage-6 artifacts when SPACE_ARTIFACT_S3_PREFIX is unset.
DEFAULT_ARTIFACT_PREFIX = "s3://131-component-staging/multilingual-static-word-embeddings/stage-6/"
# Local artifact directory preferred over S3 when it exists in the working directory.
DEFAULT_LOCAL_SPACE = Path("multilingual_dict_20260603_122323")
# Language choices shown in the UI before (or if) a space fails to load.
DEFAULT_LANGS = ["de", "en", "fr", "lb"]
# Files that make up one artifact directory.
REQUIRED_FILES = ("aligned_all.faiss", "all_metadata.jsonl", "config.json")
# Download cache for S3 artifacts; override with ARTIFACT_CACHE_DIR.
CACHE_DIR = Path(os.environ.get("ARTIFACT_CACHE_DIR", "/tmp/multilingual_space_artifacts"))
|
|
|
|
@dataclass
class LangVectors:
    """Vectors and metadata for one language of the aligned space.

    Row ``i`` of ``vecs``, ``ids[i]`` and ``metas[i]`` all describe the same
    vocabulary entry; ``i`` is the "local id" used elsewhere in this module.
    """

    # Language code, e.g. "de".
    lang: str
    # Global row ids (the metadata "id" field), stored as int64.
    ids: np.ndarray
    # Metadata dicts from all_metadata.jsonl, aligned with ``ids``.
    metas: list[dict[str, Any]]
    # Row-normalised float32 vectors reconstructed from the FAISS index.
    vecs: np.ndarray
|
|
|
|
@dataclass
class RuntimeOptions:
    """Search parameters for one translation request (UI values or config defaults)."""

    # Maximum number of translations shown per target language.
    top_k: int
    # Candidates scoring below this are dropped.
    min_score: float
    # Neighbourhood size for the CSLS correction terms.
    csls_k: int
    # How many top-scored candidates are inspected before filtering.
    candidate_retrieval_k: int
    # Lower bound on the cosine shortlist that gets re-scored
    # (the shortlist is also at least candidate_retrieval_k and top_k).
    csls_prefetch_k: int
    # Keep a candidate only if it retrieves the query word back.
    bidirectional: bool
    # "csls"; any other value scores candidates by plain cosine.
    score_method: str
    # Drop candidates listed in config["stopwords"][lang].
    filter_stopwords: bool
    # Drop candidates failing is_good_token.
    filter_bad_tokens: bool
    # Display surface forms instead of normalised tokens when available.
    use_surface: bool
|
|
|
|
@dataclass
class Space:
    """A loaded aligned multilingual space together with its lookup tables."""

    # Local artifact directory the space was read from.
    root: Path
    # Origin of the artifact: an s3:// URI or the local path as a string.
    artifact_uri: str
    # Parsed config.json.
    config: dict[str, Any]
    # Languages that have vectors, in config order (sorted when config lists none).
    languages: list[str]
    # Per-language vectors and metadata.
    by_lang: dict[str, LangVectors]
    # lang -> lookup_key(token or surface) -> global ids sharing that key.
    lookup: dict[str, dict[str, list[int]]]
    # global id -> (lang, row index within by_lang[lang]).
    id_to_location: dict[int, tuple[str, int]]
    # True if any entry has a surface form different from its token
    # (and config does not disable surface forms).
    has_surface_forms: bool
|
|
|
|
def parse_s3_uri(uri: str) -> tuple[str, str]:
    """Split ``s3://bucket/key`` into ``(bucket, key)``.

    The key is returned without a leading slash and may be empty.

    Raises:
        ValueError: if the scheme is not ``s3`` or the bucket is missing.
    """
    parts = urlparse(uri)
    looks_valid = parts.scheme == "s3" and bool(parts.netloc)
    if not looks_valid:
        raise ValueError(f"Expected s3://bucket/key URI, got {uri!r}")
    bucket = parts.netloc
    key = parts.path.lstrip("/")
    return bucket, key
|
|
|
|
def make_s3_client():
    """Create a boto3 S3 client configured from environment variables.

    * Credentials: ``SE_ACCESS_KEY``/``SE_SECRET_KEY``, falling back to
      ``AWS_ACCESS_KEY_ID``/``AWS_SECRET_ACCESS_KEY``. They are passed only
      when both parts are present; otherwise credential resolution is left
      to boto3.
    * Endpoint: ``SE_HOST_URL`` or ``AWS_ENDPOINT_URL``; ``https://`` is
      prepended when no scheme is given.
    * Region: ``AWS_DEFAULT_REGION`` (default ``us-east-1``).
    * SigV4 signing, path-style addressing, 3 attempts in "standard" retry mode.
    """
    endpoint_url = os.getenv("SE_HOST_URL") or os.getenv("AWS_ENDPOINT_URL")
    if endpoint_url and not endpoint_url.startswith(("http://", "https://")):
        endpoint_url = "https://" + endpoint_url

    client_config = Config(
        signature_version="s3v4",
        s3={"addressing_style": "path"},
        retries={"max_attempts": 3, "mode": "standard"},
    )
    client_kwargs: dict[str, Any] = {
        "service_name": "s3",
        "region_name": os.getenv("AWS_DEFAULT_REGION", "us-east-1"),
        "config": client_config,
    }
    if endpoint_url:
        client_kwargs["endpoint_url"] = endpoint_url

    access_key = os.getenv("SE_ACCESS_KEY") or os.getenv("AWS_ACCESS_KEY_ID")
    secret_key = os.getenv("SE_SECRET_KEY") or os.getenv("AWS_SECRET_ACCESS_KEY")
    if access_key and secret_key:
        client_kwargs.update(
            aws_access_key_id=access_key,
            aws_secret_access_key=secret_key,
        )

    return boto3.client(**client_kwargs)
|
|
|
|
def latest_artifact_uri(client) -> str:
    """Return the S3 URI of the artifact directory to load.

    ``SPACE_ARTIFACT_S3_URI`` wins when set (trailing slashes removed).
    Otherwise every key under ``SPACE_ARTIFACT_S3_PREFIX`` (default
    ``DEFAULT_ARTIFACT_PREFIX``) is listed. From the ``multilingual_dict_<run>/``
    or ``multilingual_space_<run>.json/`` directories that contain a
    ``config.json``, the one with the greatest ``YYYYMMDD_HHMMSS`` run id is
    chosen.

    Args:
        client: a boto3 S3 client (only ``get_paginator`` is used).

    Raises:
        FileNotFoundError: no matching ``config.json`` exists under the prefix.
    """
    explicit = os.getenv("SPACE_ARTIFACT_S3_URI", "").strip().rstrip("/")
    if explicit:
        return explicit

    prefix_override = os.getenv("SPACE_ARTIFACT_S3_PREFIX", "").strip()
    prefix_uri = prefix_override or DEFAULT_ARTIFACT_PREFIX
    bucket, prefix = parse_s3_uri(prefix_uri)
    prefix = prefix.rstrip("/")
    # A bucket-root prefix must stay empty: "/" would match no S3 key at all.
    if prefix:
        prefix += "/"
    pattern = re.compile(
        r"(.*multilingual_(?:dict|space)_(\d{8}_\d{6})(?:\.json)?)/config\.json$"
    )
    candidates: list[tuple[str, str]] = []

    paginator = client.get_paginator("list_objects_v2")
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
        for obj in page.get("Contents", []):
            match = pattern.match(obj["Key"])
            if match:
                # (run id, artifact directory key)
                candidates.append((match.group(2), match.group(1)))

    if not candidates:
        raise FileNotFoundError(
            f"No multilingual_dict_*/config.json or multilingual_space_*.json/config.json found under {prefix_uri}"
        )

    # Run ids are zero-padded timestamps, so string order is chronological;
    # equal run ids fall back to comparing the directory key.
    run_id, key = max(candidates)
    uri = f"s3://{bucket}/{key}"
    print(f"Selected latest stage 6 artifact {run_id}: {uri}", file=sys.stderr)
    return uri
|
|
|
|
def local_cache_for_uri(uri: str) -> Path:
    """Map an artifact URI to its cache directory: ``CACHE_DIR / <last key segment>``."""
    key = parse_s3_uri(uri)[1]
    last_segment = Path(key.rstrip("/")).name
    return CACHE_DIR.joinpath(last_segment)
|
|
|
|
def download_space_from_s3() -> tuple[Path, str]:
    """Fetch the latest artifact's REQUIRED_FILES into the local cache.

    Files already present with a non-zero size are not downloaded again.

    Returns:
        ``(local_dir, uri)``: the cache directory and the chosen S3 URI.
    """
    s3 = make_s3_client()
    artifact_uri = latest_artifact_uri(s3)
    cache_dir = local_cache_for_uri(artifact_uri)
    cache_dir.mkdir(parents=True, exist_ok=True)

    bucket, key_prefix = parse_s3_uri(artifact_uri)
    key_prefix = key_prefix.rstrip("/")
    for name in REQUIRED_FILES:
        target = cache_dir / name
        already_cached = target.exists() and target.stat().st_size > 0
        if not already_cached:
            object_key = f"{key_prefix}/{name}"
            print(f"Downloading s3://{bucket}/{object_key}", file=sys.stderr)
            s3.download_file(bucket, object_key, str(target))

    return cache_dir, artifact_uri
|
|
|
|
def find_space_dir() -> tuple[Path, str]:
    """Locate the aligned-space artifact directory, downloading it if needed.

    Search order:
      1. ``SPACE_DIR`` environment variable, if it points at an existing path
         (a warning is printed when it is set but missing);
      2. ``DEFAULT_LOCAL_SPACE`` in the working directory;
      3. the newest local ``multilingual_dict_*`` / ``multilingual_space_*.json``
         directory, ordered by the ``YYYYMMDD_HHMMSS`` run id in its name
         (the same ordering used for S3 artifacts);
      4. the latest artifact on S3 via ``download_space_from_s3``.

    Returns:
        ``(directory, description)``; description is the local path string
        or, for S3 downloads, the artifact URI.
    """
    local_override = os.getenv("SPACE_DIR", "").strip()
    if local_override:
        path = Path(local_override)
        if path.exists():
            return path, str(path)
        print(
            f"SPACE_DIR={local_override!r} does not exist; searching for other artifacts",
            file=sys.stderr,
        )

    if DEFAULT_LOCAL_SPACE.exists():
        return DEFAULT_LOCAL_SPACE, str(DEFAULT_LOCAL_SPACE)

    run_id_re = re.compile(r"multilingual_(?:dict|space)_(\d{8}_\d{6})")

    def run_order(path: Path) -> tuple[str, str]:
        # Names without a run id sort before every timestamped directory.
        match = run_id_re.match(path.name)
        return (match.group(1) if match else "", path.name)

    here = Path(".")
    # Only directories can hold config.json / metadata / index files.
    local_candidates = [
        p
        for p in (*here.glob("multilingual_dict_*"), *here.glob("multilingual_space_*.json"))
        if p.is_dir()
    ]
    if local_candidates:
        newest = max(local_candidates, key=run_order)
        return newest, str(newest)

    return download_space_from_s3()
|
|
|
|
def strip_diacritics(text: str) -> str:
    """Return ``text`` in NFKD form with every combining mark removed."""
    decomposed = unicodedata.normalize("NFKD", text)
    return "".join(filter(lambda ch: not unicodedata.combining(ch), decomposed))
|
|
|
|
def lookup_key(text: str) -> str:
    """Normalise a word for dictionary lookup.

    Trims, casefolds, collapses internal whitespace runs to single spaces,
    then strips diacritics, so "  Café  Au lait" and "cafe au lait" share a key.
    """
    words = text.strip().casefold().split()
    collapsed = " ".join(words)
    return strip_diacritics(collapsed)
|
|
|
|
def is_good_token(token: str, min_len: int = 4) -> bool:
    """Heuristic check that a token looks like a real word.

    A token passes when it is non-empty, at least ``min_len`` characters,
    not purely digits, contains at least two letters, and consists only of
    alphanumerics plus ``-``, ``'`` and ``_``.
    """
    if not token or len(token) < min_len or token.isdigit():
        return False
    letter_count = len([ch for ch in token if ch.isalpha()])
    if letter_count < 2:
        return False
    for ch in token:
        if not ch.isalnum() and ch not in "-'_":
            return False
    return True
|
|
|
|
def read_config(space_dir: Path) -> dict[str, Any]:
    """Parse ``space_dir/config.json`` (UTF-8 JSON).

    Raises:
        FileNotFoundError: if the file does not exist.
    """
    config_path = space_dir / "config.json"
    if not config_path.exists():
        raise FileNotFoundError(f"Missing config.json in {space_dir}")
    raw = config_path.read_text(encoding="utf-8")
    return json.loads(raw)
|
|
|
|
def read_metadata(space_dir: Path) -> tuple[list[dict[str, Any]], dict[str, list[int]]]:
    """Read ``space_dir/all_metadata.jsonl`` (one JSON object per line).

    Blank lines are ignored. Every row needs an integer-like ``id`` and a
    ``lang``; ids must be unique, non-negative and cover ``0..N-1``.

    Returns:
        ``(metadata, ids_by_lang)`` where ``metadata[i]`` is the row whose id
        is ``i`` and ``ids_by_lang`` maps each language to its ids in file order.

    Raises:
        FileNotFoundError: the file is missing.
        ValueError: a negative or duplicate id, or a gap in the id range.
    """
    path = space_dir / "all_metadata.jsonl"
    if not path.exists():
        raise FileNotFoundError(f"Missing all_metadata.jsonl in {space_dir}")

    metadata: list[dict[str, Any] | None] = []
    ids_by_lang: dict[str, list[int]] = {}
    with path.open("r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue
            meta = json.loads(line)
            row_id = int(meta["id"])
            # A negative id would silently index from the end of the list.
            if row_id < 0:
                raise ValueError(f"{path}:{line_no}: negative metadata id {row_id}")
            if row_id >= len(metadata):
                metadata.extend([None] * (row_id + 1 - len(metadata)))
            elif metadata[row_id] is not None:
                # A repeated id would be overwritten here but listed twice in
                # ids_by_lang, desynchronising vectors and metadata.
                raise ValueError(f"{path}:{line_no}: duplicate metadata id {row_id}")
            metadata[row_id] = meta
            ids_by_lang.setdefault(str(meta["lang"]), []).append(row_id)

    missing = next((i for i, meta in enumerate(metadata) if meta is None), None)
    if missing is not None:
        raise ValueError(f"Metadata ids are not contiguous; first missing id is {missing}")

    return [m for m in metadata if m is not None], ids_by_lang
|
|
|
|
def metadata_has_surface_forms(metadata: list[dict[str, Any]], config: dict[str, Any]) -> bool:
    """Return True if any entry's surface form differs from its token.

    Always False when ``config["surface_forms_enabled"]`` is exactly ``False``.
    Entries with an empty/missing surface or token are ignored.
    """
    if config.get("surface_forms_enabled") is False:
        return False
    for meta in metadata:
        surface = meta.get("surface")
        token = meta.get("token")
        if surface and token and str(surface) != str(token):
            return True
    return False
|
|
|
|
def reconstruct_range(index: Any, start: int, count: int) -> np.ndarray:
    """Return rows ``start .. start+count-1`` of ``index`` as a C-contiguous float32 array.

    Tries the ``reconstruct_n(start, count)`` call form first. On TypeError
    it falls back to ``reconstruct_n(start, count, out)`` with a preallocated
    ``(count, index.d)`` buffer (presumably for faiss bindings that only
    accept an output argument).
    """
    reconstruct_n = index.reconstruct_n
    try:
        block = reconstruct_n(start, count)
    except TypeError:
        block = np.empty((count, index.d), dtype=np.float32)
        reconstruct_n(start, count, block)
    return np.ascontiguousarray(block, dtype=np.float32)
|
|
|
|
def reconstruct_ids(index: Any, ids: list[int]) -> np.ndarray:
    """Return the vectors for ``ids`` (in that order) as a float32 array.

    A run of consecutive ids is fetched in one ``reconstruct_range`` call.
    Anything else is reconstructed row by row, trying ``reconstruct(id)`` and
    falling back to ``reconstruct(id, out_row)`` on TypeError.
    An empty ``ids`` gives a ``(0, index.d)`` array.
    """
    if not ids:
        return np.empty((0, index.d), dtype=np.float32)

    is_consecutive = all(nxt - cur == 1 for cur, nxt in zip(ids, ids[1:]))
    if is_consecutive:
        return reconstruct_range(index, ids[0], len(ids))

    out = np.empty((len(ids), index.d), dtype=np.float32)
    for row, gid in enumerate(ids):
        try:
            out[row] = index.reconstruct(int(gid))
        except TypeError:
            index.reconstruct(int(gid), out[row])
    return np.ascontiguousarray(out, dtype=np.float32)
|
|
|
|
def normalize_rows(vecs: np.ndarray) -> np.ndarray:
    """Scale each row to unit L2 norm and return float32.

    A small epsilon (1e-12) in the denominator keeps all-zero rows at zero
    instead of producing NaN.
    """
    row_norms = np.linalg.norm(vecs, axis=1, keepdims=True)
    unit_rows = np.divide(vecs, row_norms + 1e-12)
    return unit_rows.astype(np.float32, copy=False)
|
|
|
|
def load_vectors_from_faiss(space_dir: Path, ids_by_lang: dict[str, list[int]]) -> dict[str, np.ndarray]:
    """Reconstruct and L2-normalise each language's vectors from ``aligned_all.faiss``.

    Args:
        space_dir: artifact directory containing the index.
        ids_by_lang: language -> index row ids to reconstruct.

    Returns:
        language -> float32 array of row-normalised vectors, rows in ``ids`` order.

    Raises:
        RuntimeError: faiss is not installed.
        FileNotFoundError: the index file is missing.
        ValueError: a requested id is beyond the index's ``ntotal``.
    """
    try:
        import faiss
    except ImportError as exc:
        raise RuntimeError("faiss-cpu is required to read aligned_all.faiss") from exc

    faiss_path = space_dir / "aligned_all.faiss"
    if not faiss_path.exists():
        raise FileNotFoundError(f"Missing aligned_all.faiss in {space_dir}")

    print(f"Loading FAISS index: {faiss_path}", file=sys.stderr)
    index = faiss.read_index(str(faiss_path))
    try:
        # Fail with a readable message instead of an opaque faiss error when
        # metadata and index come from mismatched runs.
        ntotal = int(index.ntotal)
        max_id = max((max(ids) for ids in ids_by_lang.values() if ids), default=-1)
        if max_id >= ntotal:
            raise ValueError(
                f"Metadata references row id {max_id}, but {faiss_path} holds only {ntotal} vectors"
            )

        vectors_by_lang: dict[str, np.ndarray] = {}
        for lang, ids in sorted(ids_by_lang.items()):
            print(f"Reconstructing {lang}: {len(ids)} vectors", file=sys.stderr)
            vectors_by_lang[lang] = normalize_rows(reconstruct_ids(index, ids))
    finally:
        # Only the per-language copies are kept; release the full index even
        # when reconstruction fails.
        del index
        gc.collect()
    return vectors_by_lang
|
|
|
|
def build_lookup(languages: dict[str, LangVectors]) -> dict[str, dict[str, list[int]]]:
    """Index every entry by the ``lookup_key`` of its token and of its surface form.

    Returns:
        lang -> lookup key -> global ids carrying that key, in row order.
        Each id appears at most once per key, even when token and surface
        normalise to the same key (e.g. "haus" / "Haus").
    """
    lookup: dict[str, dict[str, list[int]]] = {}
    for lang, data in languages.items():
        lang_lookup: dict[str, list[int]] = {}
        for global_id, meta in zip(data.ids.tolist(), data.metas):
            # dict.fromkeys de-duplicates the keys while keeping token-first order.
            keys = dict.fromkeys(
                lookup_key(str(value))
                for value in (meta.get("token"), meta.get("surface"))
                if value
            )
            for key in keys:
                lang_lookup.setdefault(key, []).append(int(global_id))
        lookup[lang] = lang_lookup
    return lookup
|
|
|
|
@lru_cache(maxsize=1)
def load_space() -> Space:
    """Load the aligned space once per process and build its lookup tables.

    Languages come from ``config["languages"]`` (duplicates ignored) or, when
    absent, every language in the metadata. Languages without metadata rows
    are dropped. Only the selected languages' vectors are reconstructed from
    the FAISS index.

    Cached with ``lru_cache(maxsize=1)``; a call that raises is not cached,
    so the next call retries the load.
    """
    space_dir, artifact_uri = find_space_dir()
    config = read_config(space_dir)
    metadata, ids_by_lang = read_metadata(space_dir)

    requested = config.get("languages") or sorted(ids_by_lang)
    languages = [lang for lang in dict.fromkeys(requested) if ids_by_lang.get(lang)]
    # Avoid reconstructing (and holding) vectors for languages that are never shown.
    selected_ids = {lang: ids_by_lang[lang] for lang in languages}
    vectors_by_lang = load_vectors_from_faiss(space_dir, selected_ids)

    by_lang: dict[str, LangVectors] = {}
    id_to_location: dict[int, tuple[str, int]] = {}
    for lang, ids in selected_ids.items():
        by_lang[lang] = LangVectors(
            lang=lang,
            ids=np.asarray(ids, dtype=np.int64),
            metas=[metadata[row_id] for row_id in ids],
            vecs=vectors_by_lang[lang],
        )
        for local_i, row_id in enumerate(ids):
            id_to_location[int(row_id)] = (lang, local_i)

    # Judge surface-form availability on the entries the UI can actually show.
    loaded_metas = [meta for data in by_lang.values() for meta in data.metas]
    return Space(
        root=space_dir,
        artifact_uri=artifact_uri,
        config=config,
        languages=languages,
        by_lang=by_lang,
        lookup=build_lookup(by_lang),
        id_to_location=id_to_location,
        has_surface_forms=metadata_has_surface_forms(loaded_metas, config),
    )
|
|
|
|
def default_options(config: dict[str, Any]) -> RuntimeOptions:
    """Build the initial RuntimeOptions from config.json values.

    Missing keys fall back to top_k=3, min_score=0.15, csls_k=10,
    candidate_retrieval_k=3*top_k, csls_prefetch_k=50 and
    ``bidirectional_consistency.enabled`` = True. Scoring is always CSLS,
    both filters are on, and surface forms are requested.
    """
    get = config.get
    bidirectional_cfg = get("bidirectional_consistency") or {}
    top_k = int(get("top_k", 3))
    return RuntimeOptions(
        top_k=top_k,
        min_score=float(get("min_score", 0.15)),
        csls_k=int(get("csls_k", 10)),
        candidate_retrieval_k=int(get("candidate_retrieval_k", top_k * 3)),
        csls_prefetch_k=int(get("csls_prefetch_k", 50)),
        bidirectional=bool(bidirectional_cfg.get("enabled", True)),
        score_method="csls",
        filter_stopwords=True,
        filter_bad_tokens=True,
        use_surface=True,
    )
|
|
|
|
def make_options(
    top_k: int,
    min_score: float,
    csls_k: int,
    candidate_retrieval_k: int,
    csls_prefetch_k: int,
    bidirectional: bool,
    score_method: str,
    filter_stopwords: bool,
    filter_bad_tokens: bool,
    use_surface: bool,
) -> RuntimeOptions:
    """Coerce raw Gradio widget values into a RuntimeOptions.

    Counts go through ``int()``, ``min_score`` through ``float()``, flags
    through ``bool()``, and ``score_method`` is stringified and lower-cased.
    """
    coerced: dict[str, Any] = {
        "top_k": int(top_k),
        "min_score": float(min_score),
        "csls_k": int(csls_k),
        "candidate_retrieval_k": int(candidate_retrieval_k),
        "csls_prefetch_k": int(csls_prefetch_k),
        "bidirectional": bool(bidirectional),
        "score_method": str(score_method).lower(),
        "filter_stopwords": bool(filter_stopwords),
        "filter_bad_tokens": bool(filter_bad_tokens),
        "use_surface": bool(use_surface),
    }
    return RuntimeOptions(**coerced)
|
|
|
|
def top_indices(values: np.ndarray, k: int) -> np.ndarray:
    """Return indices of the ``k`` largest entries of a 1-D array, best first.

    ``k`` is clamped to ``[0, len(values)]``. For a partial selection the
    shortlist comes from ``argpartition`` and only that shortlist is sorted.
    """
    n = values.shape[0]
    k = min(max(0, k), n)
    if not k:
        return np.empty((0,), dtype=np.int64)
    negated = -values
    if k >= n:
        return np.argsort(negated)
    shortlist = np.argpartition(negated, k - 1)[:k]
    return shortlist[np.argsort(negated[shortlist])]
|
|
|
|
def top_mean(values: np.ndarray, k: int) -> float:
    """Mean of the ``k`` largest entries of a 1-D array.

    ``k`` is raised to at least 1 and capped at ``len(values)``.
    """
    k_eff = max(1, k)
    k_eff = min(k_eff, values.shape[0])
    best = values[top_indices(values, k_eff)]
    return float(best.mean())
|
|
|
|
def candidate_allowed(meta: dict[str, Any], lang: str, space: Space, opts: RuntimeOptions) -> bool:
    """Apply the optional display filters to one candidate's token.

    * ``opts.filter_bad_tokens``: reject tokens failing ``is_good_token`` with
      the minimum length from ``config["filters"]["target_is_good_token_min_len"]``
      (default 4).
    * ``opts.filter_stopwords``: reject tokens whose lower-cased form is in
      ``config["stopwords"][lang]``.
    """
    token = str(meta.get("token") or "")
    config = space.config
    if opts.filter_bad_tokens:
        filters = config.get("filters") or {}
        min_len = int(filters.get("target_is_good_token_min_len", 4))
        if not is_good_token(token, min_len):
            return False
    if not opts.filter_stopwords:
        return True
    lang_stopwords = set((config.get("stopwords") or {}).get(lang, []))
    return token.lower() not in lang_stopwords
|
|
|
|
def _mean_top_k_similarity(
    queries: np.ndarray,
    keys: np.ndarray,
    k: int,
    chunk_rows: int = 128,
) -> np.ndarray:
    """For each row of ``queries``, the mean of its ``k`` largest dot products with ``keys``.

    This is the CSLS neighbourhood term for a batch of candidates. Rows are
    processed ``chunk_rows`` at a time, so the full
    ``len(queries) x len(keys)`` similarity matrix is never held in memory.
    ``k`` is clamped to ``[1, len(keys)]``, like ``top_mean``.
    """
    n_queries = queries.shape[0]
    n_keys = keys.shape[0]
    if n_queries == 0 or n_keys == 0:
        return np.full(n_queries, np.nan, dtype=np.float32)
    k = min(max(1, k), n_keys)
    kth = n_keys - k
    out = np.empty(n_queries, dtype=np.float32)
    for start in range(0, n_queries, chunk_rows):
        stop = start + chunk_rows
        sims = queries[start:stop] @ keys.T
        # After partitioning, columns kth: hold the k largest values (unordered).
        out[start:stop] = np.partition(sims, kth, axis=1)[:, kth:].mean(axis=1)
    return out


def rank_candidates(
    space: Space,
    query_vec: np.ndarray,
    source_lang: str,
    target_lang: str,
    opts: RuntimeOptions,
    *,
    apply_filters: bool = True,
) -> list[dict[str, Any]]:
    """Score ``target_lang`` words against ``query_vec`` (a ``source_lang`` vector).

    A cosine shortlist of ``max(candidate_retrieval_k, csls_prefetch_k, top_k)``
    target rows is re-scored. Loaded vectors are row-normalised, so dot
    products are cosines. With ``score_method == "csls"`` the score is::

        2 * cos(q, y) - r_query - r_target(y)

    Here ``r_query`` is the mean of the ``csls_k`` best cosines between the
    query and the target vocabulary. ``r_target(y)`` is the mean of the
    ``csls_k`` best similarities between ``y`` and the source vocabulary.
    Any other method scores by plain cosine.

    The best ``candidate_retrieval_k`` are then dropped when:
      * the score is below ``min_score``;
      * ``candidate_allowed`` rejects them (only if ``apply_filters``);
      * their surface repeats an earlier one (unless
        ``config["filters"]["duplicate_target_surfaces_removed"]`` is false).

    Returns:
        Dicts with keys:
          * ``rank``: 1-based position before filtering;
          * ``global_id``, ``local_id``, ``meta``, ``score``, ``cosine``;
          * ``bidirectional``: always ``None`` here; the caller fills it in.
    """
    source = space.by_lang[source_lang]
    target = space.by_lang[target_lang]

    cosine_all = target.vecs @ query_vec
    prefetch_k = max(opts.candidate_retrieval_k, opts.csls_prefetch_k, opts.top_k)
    prefetch_ids = top_indices(cosine_all, min(prefetch_k, len(target.metas)))
    candidate_cosines = cosine_all[prefetch_ids]

    if opts.score_method == "csls":
        r_query = top_mean(cosine_all, opts.csls_k)
        r_targets = _mean_top_k_similarity(target.vecs[prefetch_ids], source.vecs, opts.csls_k)
        scores = (2.0 * candidate_cosines - r_query - r_targets).astype(np.float32)
    else:
        scores = candidate_cosines.astype(np.float32)

    order = np.argsort(-scores)[: opts.candidate_retrieval_k]
    results: list[dict[str, Any]] = []
    seen_surfaces: set[str] = set()
    dedupe_surfaces = bool(
        (space.config.get("filters") or {}).get("duplicate_target_surfaces_removed", True)
    )

    for rank, pos in enumerate(order, 1):
        local_id = int(prefetch_ids[pos])
        meta = target.metas[local_id]
        score = float(scores[pos])
        if score < opts.min_score:
            continue
        if apply_filters and not candidate_allowed(meta, target_lang, space, opts):
            continue
        surface = str(meta.get("surface") or meta.get("token") or "")
        if dedupe_surfaces and surface in seen_surfaces:
            continue
        seen_surfaces.add(surface)
        results.append(
            {
                "rank": rank,
                "global_id": int(target.ids[local_id]),
                "local_id": local_id,
                "meta": meta,
                "score": score,
                "cosine": float(candidate_cosines[pos]),
                "bidirectional": None,
            }
        )

    return results
|
|
|
|
def get_meta(space: Space, global_id: int) -> dict[str, Any]:
    """Return the metadata dict of the entry with the given global row id."""
    location = space.id_to_location[int(global_id)]
    lang_data = space.by_lang[location[0]]
    return lang_data.metas[location[1]]
|
|
|
|
def get_vec(space: Space, global_id: int) -> np.ndarray:
    """Return the (row-normalised) vector of the entry with the given global row id."""
    location = space.id_to_location[int(global_id)]
    lang_data = space.by_lang[location[0]]
    return lang_data.vecs[location[1]]
|
|
|
|
def format_word(meta: dict[str, Any], opts: RuntimeOptions) -> str:
    """Pick the display string for an entry.

    Prefers the surface form when ``opts.use_surface`` is set and the token
    otherwise, falling back to the other field and finally to "".
    """
    first, second = ("surface", "token") if opts.use_surface else ("token", "surface")
    return str(meta.get(first) or meta.get(second) or "")
|
|
|
|
def resolve_query(space: Space, lang: str, query: str) -> tuple[int, dict[str, Any], str]:
    """Find the vocabulary entry for ``query`` in language ``lang``.

    Matching is exact on ``lookup_key``, which makes it insensitive to case,
    whitespace and diacritics. It is checked against both tokens and surface
    forms.

    Returns:
        ``(global_id, metadata, message)``. ``message`` is empty for a unique
        match. Otherwise it names the entry used and previews up to four
        other matches.

    Raises:
        ValueError: unknown language or blank query.
        LookupError: no entry matches.
    """
    if lang not in space.by_lang:
        raise ValueError(f"Unknown language {lang!r}. Available: {', '.join(space.languages)}")
    if not query.strip():
        raise ValueError("Enter a query word.")

    # Keep first-seen order but drop repeated ids, so an entry indexed twice
    # under the same key is not reported as ambiguous.
    matches = list(dict.fromkeys(space.lookup.get(lang, {}).get(lookup_key(query), [])))
    if not matches:
        raise LookupError(f"No exact token/surface match for {lang}:{query!r}")

    message = ""
    if len(matches) > 1:
        preview = []
        for row_id in matches[:5]:
            meta = get_meta(space, int(row_id))
            preview.append(f"{meta.get('surface') or meta.get('token')} (id {row_id})")
        message = (
            f"Matched {len(matches)} entries; using the first: {preview[0]}; "
            f"others: {', '.join(preview[1:])}"
        )
        hidden = len(matches) - len(preview)
        if hidden:
            message += f" (+{hidden} more)"

    row_id = int(matches[0])
    return row_id, get_meta(space, row_id), message
|
|
|
|
def _select_translations(
    space: Space,
    source_id: int,
    source_vec: np.ndarray,
    source_lang: str,
    target_lang: str,
    opts: RuntimeOptions,
) -> list[dict[str, Any]]:
    """Return up to ``opts.top_k`` ranked candidates for one target language.

    With ``opts.bidirectional`` a candidate is kept only if ranking the source
    language against the candidate's vector returns ``source_id``. Each kept
    candidate's ``"bidirectional"`` entry is set accordingly.
    """
    kept: list[dict[str, Any]] = []
    for cand in rank_candidates(space, source_vec, source_lang, target_lang, opts):
        if opts.bidirectional:
            # The reverse pass only asks "does the candidate retrieve the query
            # back?". Display filters (minimum length, stopwords) must not apply
            # to it, or a short/stopword query could never be confirmed.
            reverse = rank_candidates(
                space,
                get_vec(space, int(cand["global_id"])),
                target_lang,
                source_lang,
                opts,
                apply_filters=False,
            )
            cand["bidirectional"] = any(int(item["global_id"]) == source_id for item in reverse)
            if not cand["bidirectional"]:
                continue
        else:
            cand["bidirectional"] = False

        kept.append(cand)
        if len(kept) >= opts.top_k:
            break
    return kept


def translate_like_terminal(
    query: str,
    source_lang: str,
    top_k: int,
    min_score: float,
    csls_k: int,
    candidate_retrieval_k: int,
    csls_prefetch_k: int,
    bidirectional: bool,
    score_method: str,
    filter_stopwords: bool,
    filter_bad_tokens: bool,
    use_surface: bool,
) -> tuple[str, list[list[Any]]]:
    """Gradio callback: translate ``query`` into every other loaded language.

    Returns:
        ``(text, rows)``: a terminal-style text report and table rows
        ``[target_lang, rank, word, token, score, cosine, bidi]``.

    Errors are never raised to the UI. They are returned as
    ``("Error: <message>", [])``. Unexpected errors also print a traceback
    to stderr.
    """
    import traceback

    try:
        space = load_space()
        # Surface forms can only be displayed if the loaded space has any.
        use_surface = bool(use_surface and space.has_surface_forms)
        opts = make_options(
            top_k,
            min_score,
            csls_k,
            candidate_retrieval_k,
            csls_prefetch_k,
            bidirectional,
            score_method,
            filter_stopwords,
            filter_bad_tokens,
            use_surface,
        )
        source_id, source_meta, match_message = resolve_query(space, source_lang, query)
        source_vec = get_vec(space, source_id)
        source_word = format_word(source_meta, opts)
        target_langs = [lang for lang in space.languages if lang != source_lang]

        lines = [
            f"Query: {source_lang}:{source_word} "
            f"(token={source_meta.get('token')}, id={source_id})",
            f"Settings: score={opts.score_method}, top_k={opts.top_k}, "
            f"min_score={opts.min_score}, csls_k={opts.csls_k}, "
            f"candidate_retrieval_k={opts.candidate_retrieval_k}, "
            f"bidirectional={opts.bidirectional}",
        ]
        if match_message:
            lines.append(match_message)

        rows: list[list[Any]] = []
        for target_lang in target_langs:
            kept = _select_translations(
                space, source_id, source_vec, source_lang, target_lang, opts
            )

            lines.append("")
            lines.append(f"{target_lang}:")
            if not kept:
                lines.append(" no candidates after filters")
                continue

            for i, cand in enumerate(kept, 1):
                meta = cand["meta"]
                word = format_word(meta, opts)
                token = meta.get("token")
                bidi = "yes" if cand["bidirectional"] else "no"
                lines.append(
                    f" {i}. {word} "
                    f"(token={token}, score={cand['score']:.4f}, "
                    f"cosine={cand['cosine']:.4f}, bidi={bidi})"
                )
                rows.append(
                    [
                        target_lang,
                        i,
                        word,
                        token,
                        round(float(cand["score"]), 6),
                        round(float(cand["cosine"]), 6),
                        bidi,
                    ]
                )

        return "\n".join(lines), rows
    except Exception as exc:
        # ValueError / LookupError (other than KeyError) are the expected
        # user-input errors from resolve_query. Anything else is a bug or a
        # load failure and deserves a traceback in the server log.
        expected = isinstance(exc, (ValueError, LookupError)) and not isinstance(exc, KeyError)
        if not expected:
            traceback.print_exc()
        return f"Error: {exc}", []
|
|
|
|
def initialize() -> tuple[Any, ...]:
    """Load the space and return initial values for the UI widgets.

    The tuple order matches the ``demo.load`` outputs:
      1. status text;
      2. language dropdown update;
      3. top_k, min_score, csls_k, candidate_retrieval_k, csls_prefetch_k;
      4. bidirectional;
      5. surface-form checkbox update.

    On failure the status shows the error, the traceback goes to stderr, and
    the widgets get the ``default_options({})`` values.
    """
    import traceback

    try:
        space = load_space()
        if not space.languages:
            raise ValueError(f"No language in {space.artifact_uri} has any vectors")
        opts = default_options(space.config)
        source_lang = space.config.get("pivot_lang", "de")
        if source_lang not in space.languages:
            source_lang = space.languages[0]
        vector_count = sum(len(item.metas) for item in space.by_lang.values())
        status_text = f"Loaded {space.artifact_uri} with {vector_count:,} vectors."
        return (
            status_text,
            gr.update(choices=space.languages, value=source_lang),
            opts.top_k,
            opts.min_score,
            opts.csls_k,
            opts.candidate_retrieval_k,
            opts.csls_prefetch_k,
            opts.bidirectional,
            gr.update(
                value=space.has_surface_forms,
                interactive=space.has_surface_forms,
                label=(
                    "show surface forms"
                    if space.has_surface_forms
                    else "show surface forms (none in this aligned space)"
                ),
            ),
        )
    except Exception as exc:
        traceback.print_exc()
        # Same defaults the app would use for an empty config, without
        # duplicating the numbers here.
        fallback = default_options({})
        return (
            f"Load error: {exc}",
            gr.update(choices=DEFAULT_LANGS, value="de"),
            fallback.top_k,
            fallback.min_score,
            fallback.csls_k,
            fallback.candidate_retrieval_k,
            fallback.csls_prefetch_k,
            fallback.bidirectional,
            gr.update(
                value=False,
                interactive=False,
                label="show surface forms (no aligned space loaded)",
            ),
        )
|
|
|
|
# Style overrides for the Gradio app: page background, container max width,
# title spacing, a muted status line, and monospace text areas (used by the
# terminal-style output box).
CSS = """
body { background: #f7f5ef; }
.gradio-container { max-width: 1120px !important; }
.app-title h1 { margin-bottom: 0.15rem; }
.status { color: #5f6b7a; font-size: 0.92rem; }
textarea { font-family: ui-monospace, SFMono-Regular, Menlo, Consolas, monospace; }
"""
|
|
|
|
# Gradio UI. Widget placement follows the nesting of the `with` blocks, so the
# statement order below defines the page layout.
with gr.Blocks(title="Multilingual Dictionary Explorer", css=CSS) as demo:
    gr.Markdown(
        "# Multilingual Dictionary Explorer\n"
        "FAISS + CSLS translation lookup from the aligned multilingual space.",
        elem_classes=["app-title"],
    )
    # Placeholder text; initialize() (wired to demo.load below) replaces it.
    status = gr.Markdown("Loading artifacts...", elem_classes=["status"])

    with gr.Row():
        # Left column: query inputs and tunable search parameters.
        with gr.Column(scale=1, min_width=320):
            query = gr.Textbox(label="Query word", value="haus")
            # initialize() replaces choices/value with the loaded languages.
            source_lang = gr.Dropdown(label="Language", choices=DEFAULT_LANGS, value="de")
            search = gr.Button("Search", variant="primary")

            # Initial values mirror default_options({}); initialize() overwrites
            # most of them with values from the loaded config.
            with gr.Accordion("Parameters", open=False):
                top_k = gr.Slider(
                    1,
                    20,
                    value=3,
                    step=1,
                    label="top_k",
                    info=(
                        "How many final translations to show per target language after "
                        "scoring and filters."
                    ),
                )
                min_score = gr.Slider(
                    -2.0,
                    2.0,
                    value=0.15,
                    step=0.01,
                    label="min_score",
                    info=(
                        "The minimum translation score to show. CSLS is a relative score, "
                        "so negative values are valid but usually allow weaker matches."
                    ),
                )
                csls_k = gr.Slider(
                    1,
                    50,
                    value=10,
                    step=1,
                    label="csls_k",
                    info=(
                        "How many neighbours CSLS compares against to avoid overrating "
                        "generic words in crowded vector areas."
                    ),
                )
                candidate_retrieval_k = gr.Slider(
                    1,
                    100,
                    value=9,
                    step=1,
                    label="candidate_retrieval_k",
                    info=(
                        "How many top candidates to inspect before removing bad tokens, "
                        "stopwords, low scores, or non-bidirectional matches."
                    ),
                )
                csls_prefetch_k = gr.Slider(
                    10,
                    500,
                    value=50,
                    step=1,
                    label="csls_prefetch_k",
                    info=(
                        "How many nearby candidates to fetch first so CSLS can score a "
                        "larger pool before the final shortlist."
                    ),
                )
                # Anything other than "csls" makes rank_candidates score by plain cosine.
                score_method = gr.Radio(
                    ["csls", "cosine"],
                    value="csls",
                    label="score",
                    info=(
                        "CSLS adjusts cosine similarity for multilingual lookup; cosine "
                        "shows plain vector closeness without that correction."
                    ),
                )
                bidirectional = gr.Checkbox(
                    value=True,
                    label="bidirectional_consistency",
                    info=(
                        "Keep a translation only when the target word also retrieves the "
                        "query word back, which is stricter but cleaner."
                    ),
                )
                filter_stopwords = gr.Checkbox(
                    value=True,
                    label="filter stopwords",
                    info=(
                        "Remove common function words such as articles, prepositions, and "
                        "pronouns from the displayed candidates."
                    ),
                )
                filter_bad_tokens = gr.Checkbox(
                    value=True,
                    label="filter bad tokens",
                    info=(
                        "Remove candidates that look like noise, for example very short, "
                        "numeric, or punctuation-heavy tokens."
                    ),
                )
                # initialize() disables this when the space has no surface forms.
                use_surface = gr.Checkbox(
                    value=True,
                    label="show surface forms",
                    info=(
                        "Show readable surface forms while keeping the normalized token "
                        "visible in the token column."
                    ),
                )

        # Right column: results as plain text and as a table.
        with gr.Column(scale=2):
            output_text = gr.Textbox(label="Terminal-style output", lines=18)
            # Columns match the row lists built by translate_like_terminal.
            output_table = gr.Dataframe(
                headers=["target_lang", "rank", "word", "token", "score", "cosine", "bidi"],
                datatype=["str", "number", "str", "str", "number", "number", "str"],
                interactive=False,
                wrap=True,
            )

    # Order must match the positional parameters of translate_like_terminal.
    inputs = [
        query,
        source_lang,
        top_k,
        min_score,
        csls_k,
        candidate_retrieval_k,
        csls_prefetch_k,
        bidirectional,
        score_method,
        filter_stopwords,
        filter_bad_tokens,
        use_surface,
    ]
    search.click(translate_like_terminal, inputs=inputs, outputs=[output_text, output_table])
    query.submit(translate_like_terminal, inputs=inputs, outputs=[output_text, output_table])

    # On page load: fill the widgets from the loaded space (output order matches
    # initialize's return tuple), then run the default query once.
    demo.load(
        initialize,
        outputs=[
            status,
            source_lang,
            top_k,
            min_score,
            csls_k,
            candidate_retrieval_k,
            csls_prefetch_k,
            bidirectional,
            use_surface,
        ],
    ).then(translate_like_terminal, inputs=inputs, outputs=[output_text, output_table])
|
|
|
|
if __name__ == "__main__":
    # Enable Gradio's request queue, then start the web server.
    queued_app = demo.queue()
    queued_app.launch()
|
|