"""Gradio explorer for translations in an aligned multilingual embedding space.

The app loads a stage-6 artifact directory (``aligned_all.faiss``,
``all_metadata.jsonl``, ``config.json``) from a local path or from S3,
reconstructs L2-normalized per-language vectors from the FAISS index, and
ranks translation candidates with plain cosine similarity or CSLS, optionally
keeping only bidirectionally consistent pairs.
"""

from __future__ import annotations

import gc
import json
import os
import re
import sys
import unicodedata
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Any
from urllib.parse import urlparse

import boto3
import gradio as gr
import numpy as np
from botocore.config import Config

DEFAULT_ARTIFACT_PREFIX = (
    "s3://131-component-staging/"
    "multilingual-static-word-embeddings/stage-6/"
)
DEFAULT_LOCAL_SPACE = Path("multilingual_dict_20260603_122323")
DEFAULT_LANGS = ["de", "en", "fr", "lb"]
REQUIRED_FILES = ("aligned_all.faiss", "all_metadata.jsonl", "config.json")
CACHE_DIR = Path(os.getenv("ARTIFACT_CACHE_DIR", "/tmp/multilingual_space_artifacts"))

# Name of a local artifact directory; group(1) is the YYYYMMDD_HHMMSS run id.
# Mirrors the directory part of the S3 key pattern in latest_artifact_uri.
_RUN_DIR_PATTERN = re.compile(r"multilingual_(?:dict|space)_(\d{8}_\d{6})(?:\.json)?$")


@dataclass
class LangVectors:
    """Vectors and metadata for the rows of a single language."""

    lang: str
    ids: np.ndarray  # int64 global row ids, aligned with ``metas`` and ``vecs``
    metas: list[dict[str, Any]]  # metadata rows from all_metadata.jsonl
    vecs: np.ndarray  # float32 rows, L2-normalized, one per id


@dataclass
class RuntimeOptions:
    """User-tunable retrieval settings for one lookup."""

    top_k: int  # final translations shown per target language
    min_score: float  # candidates scoring below this are dropped
    csls_k: int  # neighbourhood size for the CSLS hubness correction
    candidate_retrieval_k: int  # best-scored candidates inspected before filtering
    csls_prefetch_k: int  # nearest neighbours fetched before (re)scoring
    bidirectional: bool  # keep only candidates that retrieve the query back
    score_method: str  # "csls" or anything else for plain cosine
    filter_stopwords: bool
    filter_bad_tokens: bool
    use_surface: bool  # display surface forms instead of tokens


@dataclass
class Space:
    """A fully loaded aligned space ready for querying."""

    root: Path
    artifact_uri: str
    config: dict[str, Any]
    languages: list[str]  # languages that actually have rows, in config order
    by_lang: dict[str, LangVectors]
    lookup: dict[str, dict[str, list[int]]]  # lang -> lookup_key -> global ids
    id_to_location: dict[int, tuple[str, int]]  # global id -> (lang, local index)
    has_surface_forms: bool


def parse_s3_uri(uri: str) -> tuple[str, str]:
    """Split ``s3://bucket/key`` into ``(bucket, key)``.

    Raises:
        ValueError: if *uri* is not an ``s3://`` URI with a bucket.
    """
    parsed = urlparse(uri)
    if parsed.scheme != "s3" or not parsed.netloc:
        raise ValueError(f"Expected s3://bucket/key URI, got {uri!r}")
    return parsed.netloc, parsed.path.lstrip("/")


def make_s3_client():
    """Create an S3 client configured from the environment.

    ``SE_*`` variables take precedence over their ``AWS_*`` equivalents. An
    endpoint without a scheme gets ``https://`` prepended. Explicit
    credentials are only passed when both key and secret are set; otherwise
    boto3 resolves credentials on its own.
    """
    access_key = os.getenv("SE_ACCESS_KEY") or os.getenv("AWS_ACCESS_KEY_ID")
    secret_key = os.getenv("SE_SECRET_KEY") or os.getenv("AWS_SECRET_ACCESS_KEY")
    endpoint_url = os.getenv("SE_HOST_URL") or os.getenv("AWS_ENDPOINT_URL")
    region = os.getenv("AWS_DEFAULT_REGION", "us-east-1")
    if endpoint_url and not endpoint_url.startswith(("http://", "https://")):
        endpoint_url = f"https://{endpoint_url}"

    kwargs: dict[str, Any] = {
        "service_name": "s3",
        "region_name": region,
        "config": Config(
            signature_version="s3v4",
            s3={"addressing_style": "path"},
            retries={"max_attempts": 3, "mode": "standard"},
        ),
    }
    if endpoint_url:
        kwargs["endpoint_url"] = endpoint_url
    if access_key and secret_key:
        kwargs["aws_access_key_id"] = access_key
        kwargs["aws_secret_access_key"] = secret_key
    return boto3.client(**kwargs)


def latest_artifact_uri(client) -> str:
    """Return the S3 URI of the artifact directory to load.

    ``SPACE_ARTIFACT_S3_URI`` wins if set. Otherwise every object under
    ``SPACE_ARTIFACT_S3_PREFIX`` (or :data:`DEFAULT_ARTIFACT_PREFIX`) is
    listed, and the run directory with the newest timestamp that contains a
    ``config.json`` is selected.

    Raises:
        FileNotFoundError: if no artifact directory exists under the prefix.
    """
    explicit = os.getenv("SPACE_ARTIFACT_S3_URI", "").strip().rstrip("/")
    if explicit:
        return explicit

    prefix_override = os.getenv("SPACE_ARTIFACT_S3_PREFIX", "").strip()
    prefix_uri = prefix_override or DEFAULT_ARTIFACT_PREFIX
    bucket, prefix = parse_s3_uri(prefix_uri)
    prefix = prefix.rstrip("/") + "/"
    pattern = re.compile(
        r"(.*multilingual_(?:dict|space)_(\d{8}_\d{6})(?:\.json)?)/config\.json$"
    )

    candidates: list[tuple[str, str]] = []
    paginator = client.get_paginator("list_objects_v2")
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
        for obj in page.get("Contents", []):
            match = pattern.match(obj["Key"])
            if match:
                candidates.append((match.group(2), match.group(1)))

    if not candidates:
        raise FileNotFoundError(
            f"No multilingual_dict_*/config.json or multilingual_space_*.json/config.json found under {prefix_uri}"
        )

    # Run ids are timestamps: YYYYMMDD_HHMMSS. Lexicographic sort gives newest run.
    run_id, key = sorted(candidates)[-1]
    uri = f"s3://{bucket}/{key}"
    print(f"Selected latest stage 6 artifact {run_id}: {uri}", file=sys.stderr)
    return uri


def local_cache_for_uri(uri: str) -> Path:
    """Return the cache directory for an artifact URI, named after its last path component."""
    _, key = parse_s3_uri(uri)
    return CACHE_DIR / Path(key.rstrip("/")).name


def download_space_from_s3() -> tuple[Path, str]:
    """Download the selected artifact's required files into the local cache.

    Files already present with a non-zero size are reused.

    Returns:
        ``(local_dir, artifact_uri)``.
    """
    client = make_s3_client()
    uri = latest_artifact_uri(client)
    local_dir = local_cache_for_uri(uri)
    local_dir.mkdir(parents=True, exist_ok=True)
    bucket, prefix = parse_s3_uri(uri)
    prefix = prefix.rstrip("/")
    for filename in REQUIRED_FILES:
        dst = local_dir / filename
        if dst.exists() and dst.stat().st_size > 0:
            continue
        key = f"{prefix}/{filename}"
        print(f"Downloading s3://{bucket}/{key}", file=sys.stderr)
        client.download_file(bucket, key, str(dst))
    return local_dir, uri


def _latest_local_space(root: Path) -> Path | None:
    """Return the artifact directory under *root* with the newest run id, or None.

    Ordering is by the embedded YYYYMMDD_HHMMSS run id (as in
    latest_artifact_uri), not by the full name. Sorting by name would rank
    every ``multilingual_space_*`` above every ``multilingual_dict_*``
    regardless of age. Plain files are ignored.
    """
    candidates: list[tuple[str, str, Path]] = []
    for path in (*root.glob("multilingual_dict_*"), *root.glob("multilingual_space_*.json")):
        match = _RUN_DIR_PATTERN.match(path.name)
        if match and path.is_dir():
            candidates.append((match.group(1), path.name, path))
    if not candidates:
        return None
    return max(candidates, key=lambda item: (item[0], item[1]))[2]


def find_space_dir() -> tuple[Path, str]:
    """Locate the artifact directory, downloading it from S3 as a last resort.

    Lookup order: ``SPACE_DIR`` env var, :data:`DEFAULT_LOCAL_SPACE`, the
    newest matching directory in the working directory, then S3.

    Returns:
        ``(directory, human-readable artifact location)``.
    """
    local_override = os.getenv("SPACE_DIR", "").strip()
    if local_override:
        path = Path(local_override)
        if path.exists():
            return path, str(path)
        print(f"SPACE_DIR={local_override!r} does not exist; ignoring it", file=sys.stderr)
    if DEFAULT_LOCAL_SPACE.exists():
        return DEFAULT_LOCAL_SPACE, str(DEFAULT_LOCAL_SPACE)
    local_space = _latest_local_space(Path("."))
    if local_space is not None:
        return local_space, str(local_space)
    return download_space_from_s3()


def strip_diacritics(text: str) -> str:
    """Remove combining marks after NFKD decomposition (``é`` -> ``e``)."""
    return "".join(
        ch for ch in unicodedata.normalize("NFKD", text) if not unicodedata.combining(ch)
    )


def lookup_key(text: str) -> str:
    """Normalize *text* for exact-match lookup: trim, collapse spaces, casefold, strip accents."""
    text = " ".join(text.strip().casefold().split())
    return strip_diacritics(text)


def is_good_token(token: str, min_len: int = 4) -> bool:
    """Return True if *token* looks like a real word.

    A good token has at least *min_len* characters, is not all digits,
    contains at least two letters, and consists only of alphanumerics and
    ``-'_``.
    """
    if not token or len(token) < min_len or token.isdigit():
        return False
    alpha = sum(ch.isalpha() for ch in token)
    if alpha < 2:
        return False
    return all(ch.isalnum() or ch in "-'_" for ch in token)


def read_config(space_dir: Path) -> dict[str, Any]:
    """Load ``config.json`` from *space_dir*.

    Raises:
        FileNotFoundError: if the file is missing.
    """
    path = space_dir / "config.json"
    if not path.exists():
        raise FileNotFoundError(f"Missing config.json in {space_dir}")
    with path.open("r", encoding="utf-8") as f:
        return json.load(f)


def read_metadata(space_dir: Path) -> tuple[list[dict[str, Any]], dict[str, list[int]]]:
    """Load ``all_metadata.jsonl`` into an id-indexed list plus per-language id lists.

    Each non-blank line is a JSON object with at least ``id`` and ``lang``.

    Returns:
        ``(metadata, ids_by_lang)``, where ``metadata[i]`` is the row with
        id ``i`` and ``ids_by_lang[lang]`` lists ids in file order.

    Raises:
        FileNotFoundError: if the file is missing.
        ValueError: if ids are negative, duplicated, or not contiguous from 0.
    """
    path = space_dir / "all_metadata.jsonl"
    if not path.exists():
        raise FileNotFoundError(f"Missing all_metadata.jsonl in {space_dir}")

    metadata: list[dict[str, Any] | None] = []
    ids_by_lang: dict[str, list[int]] = {}
    with path.open("r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            meta = json.loads(line)
            row_id = int(meta["id"])
            # A negative id would silently index from the end of the list.
            if row_id < 0:
                raise ValueError(f"Metadata id must be non-negative, got {row_id}")
            if row_id >= len(metadata):
                metadata.extend([None] * (row_id + 1 - len(metadata)))
            # A duplicate would overwrite a row while its id stays listed twice.
            if metadata[row_id] is not None:
                raise ValueError(f"Duplicate metadata id {row_id}")
            metadata[row_id] = meta
            ids_by_lang.setdefault(str(meta["lang"]), []).append(row_id)

    missing = [i for i, meta in enumerate(metadata) if meta is None]
    if missing:
        raise ValueError(f"Metadata ids are not contiguous; first missing id is {missing[0]}")
    return [m for m in metadata if m is not None], ids_by_lang


def metadata_has_surface_forms(metadata: list[dict[str, Any]], config: dict[str, Any]) -> bool:
    """Return True if any row has a surface form differing from its token.

    Always False when the config explicitly sets ``surface_forms_enabled`` to False.
    """
    if config.get("surface_forms_enabled") is False:
        return False
    return any(
        meta.get("surface") and meta.get("token") and str(meta["surface"]) != str(meta["token"])
        for meta in metadata
    )


def reconstruct_range(index: Any, start: int, count: int) -> np.ndarray:
    """Reconstruct *count* consecutive vectors starting at *start* from a FAISS index.

    Falls back to the output-buffer calling convention when the returning
    form raises TypeError.
    """
    try:
        vecs = index.reconstruct_n(start, count)
    except TypeError:
        vecs = np.empty((count, index.d), dtype=np.float32)
        index.reconstruct_n(start, count, vecs)
    return np.ascontiguousarray(vecs, dtype=np.float32)


def reconstruct_ids(index: Any, ids: list[int]) -> np.ndarray:
    """Reconstruct the vectors for *ids*, using one range call when ids are consecutive."""
    if not ids:
        return np.empty((0, index.d), dtype=np.float32)
    start = ids[0]
    if ids == list(range(start, start + len(ids))):
        return reconstruct_range(index, start, len(ids))

    vecs = np.empty((len(ids), index.d), dtype=np.float32)
    for local_i, row_id in enumerate(ids):
        try:
            vecs[local_i] = index.reconstruct(int(row_id))
        except TypeError:
            index.reconstruct(int(row_id), vecs[local_i])
    return np.ascontiguousarray(vecs, dtype=np.float32)


def normalize_rows(vecs: np.ndarray) -> np.ndarray:
    """L2-normalize each row; the epsilon keeps zero rows finite."""
    norms = np.linalg.norm(vecs, axis=1, keepdims=True)
    return (vecs / (norms + 1e-12)).astype(np.float32, copy=False)


def load_vectors_from_faiss(space_dir: Path, ids_by_lang: dict[str, list[int]]) -> dict[str, np.ndarray]:
    """Read ``aligned_all.faiss`` and return normalized vectors per language.

    The index itself is released after reconstruction so only the dense
    per-language arrays stay in memory.

    Raises:
        RuntimeError: if faiss is not installed.
        FileNotFoundError: if the index file is missing.
    """
    try:
        import faiss  # type: ignore
    except ImportError as exc:
        raise RuntimeError("faiss-cpu is required to read aligned_all.faiss") from exc

    faiss_path = space_dir / "aligned_all.faiss"
    if not faiss_path.exists():
        raise FileNotFoundError(f"Missing aligned_all.faiss in {space_dir}")

    print(f"Loading FAISS index: {faiss_path}", file=sys.stderr)
    index = faiss.read_index(str(faiss_path))
    vectors_by_lang: dict[str, np.ndarray] = {}
    for lang, ids in sorted(ids_by_lang.items()):
        print(f"Reconstructing {lang}: {len(ids)} vectors", file=sys.stderr)
        vectors_by_lang[lang] = normalize_rows(reconstruct_ids(index, ids))
    del index
    gc.collect()
    return vectors_by_lang


def build_lookup(languages: dict[str, LangVectors]) -> dict[str, dict[str, list[int]]]:
    """Build, per language, a map from :func:`lookup_key` to matching global ids.

    Both ``token`` and ``surface`` of each row are indexed. A row adds its id
    at most once per key. Without that, a row whose token and surface
    normalize alike (``haus`` / ``Haus``) would show up as two matches.
    """
    lookup: dict[str, dict[str, list[int]]] = {}
    for lang, data in languages.items():
        lang_lookup: dict[str, list[int]] = {}
        for global_id, meta in zip(data.ids.tolist(), data.metas):
            keys = {
                lookup_key(str(value))
                for value in (meta.get("token"), meta.get("surface"))
                if value
            }
            for key in keys:
                lang_lookup.setdefault(key, []).append(int(global_id))
        lookup[lang] = lang_lookup
    return lookup


@lru_cache(maxsize=1)
def load_space() -> Space:
    """Locate, load, and index the aligned space (cached after the first success)."""
    space_dir, artifact_uri = find_space_dir()
    config = read_config(space_dir)
    metadata, ids_by_lang = read_metadata(space_dir)
    vectors_by_lang = load_vectors_from_faiss(space_dir, ids_by_lang)

    by_lang: dict[str, LangVectors] = {}
    id_to_location: dict[int, tuple[str, int]] = {}
    languages = list(config.get("languages") or sorted(ids_by_lang))
    for lang in languages:
        ids = ids_by_lang.get(lang)
        if not ids:
            continue
        metas = [metadata[row_id] for row_id in ids]
        vecs = vectors_by_lang[lang]
        by_lang[lang] = LangVectors(
            lang=lang,
            ids=np.asarray(ids, dtype=np.int64),
            metas=metas,
            vecs=vecs,
        )
        for local_i, row_id in enumerate(ids):
            id_to_location[int(row_id)] = (lang, local_i)

    languages = [lang for lang in languages if lang in by_lang]
    return Space(
        root=space_dir,
        artifact_uri=artifact_uri,
        config=config,
        languages=languages,
        by_lang=by_lang,
        lookup=build_lookup(by_lang),
        id_to_location=id_to_location,
        has_surface_forms=metadata_has_surface_forms(metadata, config),
    )


def default_options(config: dict[str, Any]) -> RuntimeOptions:
    """Derive initial UI settings from the artifact config, with fallbacks."""
    bidi_config = config.get("bidirectional_consistency") or {}
    top_k = int(config.get("top_k", 3))
    return RuntimeOptions(
        top_k=top_k,
        min_score=float(config.get("min_score", 0.15)),
        csls_k=int(config.get("csls_k", 10)),
        candidate_retrieval_k=int(config.get("candidate_retrieval_k", top_k * 3)),
        csls_prefetch_k=int(config.get("csls_prefetch_k", 50)),
        bidirectional=bool(bidi_config.get("enabled", True)),
        score_method="csls",
        filter_stopwords=True,
        filter_bad_tokens=True,
        use_surface=True,
    )


def make_options(
    top_k: int,
    min_score: float,
    csls_k: int,
    candidate_retrieval_k: int,
    csls_prefetch_k: int,
    bidirectional: bool,
    score_method: str,
    filter_stopwords: bool,
    filter_bad_tokens: bool,
    use_surface: bool,
) -> RuntimeOptions:
    """Coerce raw UI widget values (Gradio may pass floats for sliders) into RuntimeOptions."""
    return RuntimeOptions(
        top_k=int(top_k),
        min_score=float(min_score),
        csls_k=int(csls_k),
        candidate_retrieval_k=int(candidate_retrieval_k),
        csls_prefetch_k=int(csls_prefetch_k),
        bidirectional=bool(bidirectional),
        score_method=str(score_method).lower(),
        filter_stopwords=bool(filter_stopwords),
        filter_bad_tokens=bool(filter_bad_tokens),
        use_surface=bool(use_surface),
    )


def top_indices(values: np.ndarray, k: int) -> np.ndarray:
    """Return indices of the *k* largest entries of a 1-D array, best first."""
    k = min(max(0, k), values.shape[0])
    if k == 0:
        return np.empty((0,), dtype=np.int64)
    if k >= values.shape[0]:
        return np.argsort(-values)
    idx = np.argpartition(-values, k - 1)[:k]
    return idx[np.argsort(-values[idx])]


def top_mean(values: np.ndarray, k: int) -> float:
    """Mean of the *k* largest entries (k clamped to at least 1)."""
    k = min(max(1, k), values.shape[0])
    idx = top_indices(values, k)
    return float(values[idx].mean())


def candidate_allowed(meta: dict[str, Any], lang: str, space: Space, opts: RuntimeOptions) -> bool:
    """Return True if a candidate row passes the enabled token/stopword filters.

    Stopwords from ``config["stopwords"][lang]`` are compared case-insensitively.
    """
    token = str(meta.get("token") or "")
    if opts.filter_bad_tokens:
        min_len = int((space.config.get("filters") or {}).get("target_is_good_token_min_len", 4))
        if not is_good_token(token, min_len):
            return False
    if opts.filter_stopwords:
        # Lower-case both sides; otherwise capitalized config entries never match.
        stopwords = {
            str(word).lower() for word in (space.config.get("stopwords") or {}).get(lang, [])
        }
        if token.lower() in stopwords:
            return False
    return True


def rank_candidates(
    space: Space,
    query_vec: np.ndarray,
    source_lang: str,
    target_lang: str,
    opts: RuntimeOptions,
    *,
    apply_filters: bool = True,
) -> list[dict[str, Any]]:
    """Rank target-language rows as translations of *query_vec*.

    The nearest ``max(candidate_retrieval_k, csls_prefetch_k, top_k)`` rows
    by cosine are prefetched. With ``score_method == "csls"`` each is scored
    ``2*cos - r_query - r_target``, where ``r_query`` is the query's mean
    top-``csls_k`` cosine over the target vocabulary and ``r_target`` is the
    candidate's mean top-``csls_k`` similarity to the source vocabulary.
    Otherwise the score is the plain cosine.

    The best ``candidate_retrieval_k`` scores are then inspected. Rows below
    ``min_score``, rows failing the filters (when *apply_filters*), and
    repeated surfaces (unless the config disables dedupe) are dropped.

    Returns:
        Dicts with ``rank`` (1-based score position before dropping),
        ``global_id``, ``local_id``, ``meta``, ``score``, ``cosine`` and
        ``bidirectional`` (always None here).
    """
    source = space.by_lang[source_lang]
    target = space.by_lang[target_lang]
    cosine_all = target.vecs @ query_vec
    prefetch_k = max(opts.candidate_retrieval_k, opts.csls_prefetch_k, opts.top_k)
    prefetch_ids = top_indices(cosine_all, min(prefetch_k, len(target.metas)))
    candidate_cosines = cosine_all[prefetch_ids]

    if opts.score_method == "csls":
        r_query = top_mean(cosine_all, opts.csls_k)
        candidate_vecs = target.vecs[prefetch_ids]
        reverse_sims = candidate_vecs @ source.vecs.T
        r_targets = np.asarray(
            [top_mean(reverse_sims[i], opts.csls_k) for i in range(reverse_sims.shape[0])],
            dtype=np.float32,
        )
        scores = (2.0 * candidate_cosines - r_query - r_targets).astype(np.float32)
    else:
        scores = candidate_cosines.astype(np.float32)

    order = np.argsort(-scores)[: opts.candidate_retrieval_k]
    results: list[dict[str, Any]] = []
    seen_surfaces: set[str] = set()
    dedupe_surfaces = bool(
        (space.config.get("filters") or {}).get("duplicate_target_surfaces_removed", True)
    )
    for rank, pos in enumerate(order, 1):
        local_id = int(prefetch_ids[pos])
        meta = target.metas[local_id]
        score = float(scores[pos])
        if score < opts.min_score:
            continue
        if apply_filters and not candidate_allowed(meta, target_lang, space, opts):
            continue
        surface = str(meta.get("surface") or meta.get("token") or "")
        if dedupe_surfaces and surface in seen_surfaces:
            continue
        seen_surfaces.add(surface)
        results.append(
            {
                "rank": rank,
                "global_id": int(target.ids[local_id]),
                "local_id": local_id,
                "meta": meta,
                "score": score,
                "cosine": float(candidate_cosines[pos]),
                "bidirectional": None,
            }
        )
    return results


def get_meta(space: Space, global_id: int) -> dict[str, Any]:
    """Return the metadata row for a global id."""
    lang, local_id = space.id_to_location[int(global_id)]
    return space.by_lang[lang].metas[local_id]


def get_vec(space: Space, global_id: int) -> np.ndarray:
    """Return the normalized vector for a global id."""
    lang, local_id = space.id_to_location[int(global_id)]
    return space.by_lang[lang].vecs[local_id]


def format_word(meta: dict[str, Any], opts: RuntimeOptions) -> str:
    """Return the display form of a row: surface first if enabled, else token first."""
    if opts.use_surface:
        return str(meta.get("surface") or meta.get("token") or "")
    return str(meta.get("token") or meta.get("surface") or "")


def resolve_query(space: Space, lang: str, query: str) -> tuple[int, dict[str, Any], str]:
    """Find the row matching *query* in *lang* by normalized exact match.

    Returns:
        ``(global_id, meta, message)``. *message* is non-empty when several
        rows match; the first match is used and up to four alternatives are
        listed.

    Raises:
        ValueError: for an unknown language or an empty query.
        LookupError: when nothing matches.
    """
    if lang not in space.by_lang:
        raise ValueError(f"Unknown language {lang!r}. Available: {', '.join(space.languages)}")
    if not query.strip():
        raise ValueError("Enter a query word.")

    matches = space.lookup.get(lang, {}).get(lookup_key(query), [])
    if not matches:
        raise LookupError(f"No exact token/surface match for {lang}:{query!r}")

    message = ""
    if len(matches) > 1:
        preview = []
        for row_id in matches[:5]:
            meta = get_meta(space, int(row_id))
            preview.append(f"{meta.get('surface') or meta.get('token')} (id {row_id})")
        message = f"Matched {len(matches)} entries; using the first: {preview[0]}"
        if len(preview) > 1:
            message += f"; others: {', '.join(preview[1:])}"
            if len(matches) > len(preview):
                message += ", ..."
    row_id = int(matches[0])
    return row_id, get_meta(space, row_id), message


def translate_like_terminal(
    query: str,
    source_lang: str,
    top_k: int,
    min_score: float,
    csls_k: int,
    candidate_retrieval_k: int,
    csls_prefetch_k: int,
    bidirectional: bool,
    score_method: str,
    filter_stopwords: bool,
    filter_bad_tokens: bool,
    use_surface: bool,
) -> tuple[str, list[list[Any]]]:
    """Translate *query* into every other language; Gradio callback.

    Returns:
        ``(terminal-style text, table rows)``. Any error is reported as
        ``"Error: ..."`` with an empty table rather than raised, so the UI
        stays responsive.
    """
    try:
        space = load_space()
        use_surface = bool(use_surface and space.has_surface_forms)
        opts = make_options(
            top_k,
            min_score,
            csls_k,
            candidate_retrieval_k,
            csls_prefetch_k,
            bidirectional,
            score_method,
            filter_stopwords,
            filter_bad_tokens,
            use_surface,
        )
        source_id, source_meta, match_message = resolve_query(space, source_lang, query)
        source_vec = get_vec(space, source_id)
        source_word = format_word(source_meta, opts)
        target_langs = [lang for lang in space.languages if lang != source_lang]

        lines = [
            f"Query: {source_lang}:{source_word} "
            f"(token={source_meta.get('token')}, id={source_id})",
            f"Settings: score={opts.score_method}, top_k={opts.top_k}, "
            f"min_score={opts.min_score}, csls_k={opts.csls_k}, "
            f"candidate_retrieval_k={opts.candidate_retrieval_k}, "
            f"bidirectional={opts.bidirectional}",
        ]
        if match_message:
            lines.append(match_message)

        rows: list[list[Any]] = []
        for target_lang in target_langs:
            candidates = rank_candidates(space, source_vec, source_lang, target_lang, opts)
            kept: list[dict[str, Any]] = []
            for cand in candidates:
                if opts.bidirectional:
                    reverse = rank_candidates(
                        space,
                        get_vec(space, int(cand["global_id"])),
                        target_lang,
                        source_lang,
                        opts,
                        # The query was chosen by the user. The display filters
                        # (min length, stopwords) must not drop it from its own
                        # round trip, or every candidate of such a query fails.
                        apply_filters=False,
                    )
                    reverse_ids = {int(item["global_id"]) for item in reverse}
                    cand["bidirectional"] = source_id in reverse_ids
                    if not cand["bidirectional"]:
                        continue
                else:
                    cand["bidirectional"] = False
                kept.append(cand)
                if len(kept) >= opts.top_k:
                    break

            lines.append("")
            lines.append(f"{target_lang}:")
            if not kept:
                lines.append("  no candidates after filters")
                continue
            for i, cand in enumerate(kept, 1):
                meta = cand["meta"]
                word = format_word(meta, opts)
                token = meta.get("token")
                bidi = "yes" if cand["bidirectional"] else "no"
                lines.append(
                    f"  {i}. {word} "
                    f"(token={token}, score={cand['score']:.4f}, "
                    f"cosine={cand['cosine']:.4f}, bidi={bidi})"
                )
                rows.append(
                    [
                        target_lang,
                        i,
                        word,
                        token,
                        round(float(cand["score"]), 6),
                        round(float(cand["cosine"]), 6),
                        bidi,
                    ]
                )
        return "\n".join(lines), rows
    except Exception as exc:
        return f"Error: {exc}", []


def initialize() -> tuple[Any, ...]:
    """Load the space on page load and return initial widget values.

    The tuple order matches the ``outputs`` of ``demo.load``: status text,
    language dropdown, top_k, min_score, csls_k, candidate_retrieval_k,
    csls_prefetch_k, bidirectional, surface-forms checkbox. On failure the
    status shows the error and built-in defaults are returned.
    """
    try:
        space = load_space()
        opts = default_options(space.config)
        source_lang = space.config.get("pivot_lang", "de")
        if source_lang not in space.languages:
            source_lang = space.languages[0]
        status = (
            f"Loaded {space.artifact_uri} with "
            f"{sum(len(item.metas) for item in space.by_lang.values()):,} vectors."
        )
        return (
            status,
            gr.update(choices=space.languages, value=source_lang),
            opts.top_k,
            opts.min_score,
            opts.csls_k,
            opts.candidate_retrieval_k,
            opts.csls_prefetch_k,
            opts.bidirectional,
            gr.update(
                value=space.has_surface_forms,
                interactive=space.has_surface_forms,
                label=(
                    "show surface forms"
                    if space.has_surface_forms
                    else "show surface forms (none in this aligned space)"
                ),
            ),
        )
    except Exception as exc:
        return (
            f"Load error: {exc}",
            gr.update(choices=DEFAULT_LANGS, value="de"),
            3,
            0.15,
            10,
            9,
            50,
            True,
            gr.update(
                value=False,
                interactive=False,
                label="show surface forms (no aligned space loaded)",
            ),
        )


CSS = """
body { background: #f7f5ef; }
.gradio-container { max-width: 1120px !important; }
.app-title h1 { margin-bottom: 0.15rem; }
.status { color: #5f6b7a; font-size: 0.92rem; }
textarea { font-family: ui-monospace, SFMono-Regular, Menlo, Consolas, monospace; }
"""

with gr.Blocks(title="Multilingual Dictionary Explorer", css=CSS) as demo:
    gr.Markdown(
        "# Multilingual Dictionary Explorer\n"
        "FAISS + CSLS translation lookup from the aligned multilingual space.",
        elem_classes=["app-title"],
    )
    status = gr.Markdown("Loading artifacts...", elem_classes=["status"])

    with gr.Row():
        with gr.Column(scale=1, min_width=320):
            query = gr.Textbox(label="Query word", value="haus")
            source_lang = gr.Dropdown(label="Language", choices=DEFAULT_LANGS, value="de")
            search = gr.Button("Search", variant="primary")

            with gr.Accordion("Parameters", open=False):
                top_k = gr.Slider(
                    1,
                    20,
                    value=3,
                    step=1,
                    label="top_k",
                    info=(
                        "How many final translations to show per target language after "
                        "scoring and filters."
                    ),
                )
                min_score = gr.Slider(
                    -2.0,
                    2.0,
                    value=0.15,
                    step=0.01,
                    label="min_score",
                    info=(
                        "The minimum translation score to show. CSLS is a relative score, "
                        "so negative values are valid but usually allow weaker matches."
                    ),
                )
                csls_k = gr.Slider(
                    1,
                    50,
                    value=10,
                    step=1,
                    label="csls_k",
                    info=(
                        "How many neighbours CSLS compares against to avoid overrating "
                        "generic words in crowded vector areas."
                    ),
                )
                candidate_retrieval_k = gr.Slider(
                    1,
                    100,
                    value=9,
                    step=1,
                    label="candidate_retrieval_k",
                    info=(
                        "How many top candidates to inspect before removing bad tokens, "
                        "stopwords, low scores, or non-bidirectional matches."
                    ),
                )
                csls_prefetch_k = gr.Slider(
                    10,
                    500,
                    value=50,
                    step=1,
                    label="csls_prefetch_k",
                    info=(
                        "How many nearby candidates to fetch first so CSLS can score a "
                        "larger pool before the final shortlist."
                    ),
                )
                score_method = gr.Radio(
                    ["csls", "cosine"],
                    value="csls",
                    label="score",
                    info=(
                        "CSLS adjusts cosine similarity for multilingual lookup; cosine "
                        "shows plain vector closeness without that correction."
                    ),
                )
                bidirectional = gr.Checkbox(
                    value=True,
                    label="bidirectional_consistency",
                    info=(
                        "Keep a translation only when the target word also retrieves the "
                        "query word back, which is stricter but cleaner."
                    ),
                )
                filter_stopwords = gr.Checkbox(
                    value=True,
                    label="filter stopwords",
                    info=(
                        "Remove common function words such as articles, prepositions, and "
                        "pronouns from the displayed candidates."
                    ),
                )
                filter_bad_tokens = gr.Checkbox(
                    value=True,
                    label="filter bad tokens",
                    info=(
                        "Remove candidates that look like noise, for example very short, "
                        "numeric, or punctuation-heavy tokens."
                    ),
                )
                use_surface = gr.Checkbox(
                    value=True,
                    label="show surface forms",
                    info=(
                        "Show readable surface forms while keeping the normalized token "
                        "visible in the token column."
                    ),
                )

        with gr.Column(scale=2):
            output_text = gr.Textbox(label="Terminal-style output", lines=18)
            output_table = gr.Dataframe(
                headers=["target_lang", "rank", "word", "token", "score", "cosine", "bidi"],
                datatype=["str", "number", "str", "str", "number", "number", "str"],
                interactive=False,
                wrap=True,
            )

    # Order must match the positional parameters of translate_like_terminal.
    inputs = [
        query,
        source_lang,
        top_k,
        min_score,
        csls_k,
        candidate_retrieval_k,
        csls_prefetch_k,
        bidirectional,
        score_method,
        filter_stopwords,
        filter_bad_tokens,
        use_surface,
    ]
    search.click(translate_like_terminal, inputs=inputs, outputs=[output_text, output_table])
    query.submit(translate_like_terminal, inputs=inputs, outputs=[output_text, output_table])
    # Load the space once on page load, then run the default query.
    demo.load(
        initialize,
        outputs=[
            status,
            source_lang,
            top_k,
            min_score,
            csls_k,
            candidate_retrieval_k,
            csls_prefetch_k,
            bidirectional,
            use_surface,
        ],
    ).then(translate_like_terminal, inputs=inputs, outputs=[output_text, output_table])


if __name__ == "__main__":
    demo.queue().launch()