# Maslionok's picture
# added a bit more description about each parameter
# 88774ef
# Raw
# History Blame Contribute Delete
# 28.2 kB
from __future__ import annotations
import gc
import json
import os
import re
import sys
import unicodedata
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Any
from urllib.parse import urlparse
import boto3
import gradio as gr
import numpy as np
from botocore.config import Config
# S3 prefix searched for artifact runs when no override is set in the environment.
DEFAULT_ARTIFACT_PREFIX = "s3://131-component-staging/multilingual-static-word-embeddings/stage-6/"

# Local artifact directory that is preferred over S3 when it exists.
DEFAULT_LOCAL_SPACE = Path("multilingual_dict_20260603_122323")

# Language choices shown in the UI before (or instead of) a loaded space.
DEFAULT_LANGS = ["de", "en", "fr", "lb"]

# Files every artifact directory must provide.
REQUIRED_FILES = (
    "aligned_all.faiss",
    "all_metadata.jsonl",
    "config.json",
)

# Download cache root; overridable through ARTIFACT_CACHE_DIR.
CACHE_DIR = Path(os.environ.get("ARTIFACT_CACHE_DIR", "/tmp/multilingual_space_artifacts"))
@dataclass
class LangVectors:
    """Vectors and metadata of a single language inside the aligned space."""

    # Language code of this slice.
    lang: str
    # Global row ids, positionally aligned with ``metas`` and ``vecs``.
    ids: np.ndarray
    # One metadata record per row.
    metas: list[dict[str, Any]]
    # One vector per row.
    vecs: np.ndarray
@dataclass
class RuntimeOptions:
    """Per-request lookup settings collected from the UI controls."""

    # Number of translations kept per target language.
    top_k: int
    # Candidates scoring below this value are dropped.
    min_score: float
    # Neighbourhood size used for the CSLS mean-similarity terms.
    csls_k: int
    # Number of best-scoring candidates inspected before filtering.
    candidate_retrieval_k: int
    # Number of cosine neighbours fetched before scoring.
    csls_prefetch_k: int
    # Keep a candidate only if it retrieves the query word back.
    bidirectional: bool
    # "csls" selects CSLS scoring; any other value uses plain cosine.
    score_method: str
    # Drop candidates listed as stopwords for the target language.
    filter_stopwords: bool
    # Drop candidates rejected by the token quality check.
    filter_bad_tokens: bool
    # Prefer the surface form over the token when displaying words.
    use_surface: bool
@dataclass
class Space:
    """A fully loaded aligned space plus the indexes used to query it."""

    # Directory the artifact files were read from.
    root: Path
    # Where the artifact came from (S3 URI or local path string).
    artifact_uri: str
    # Parsed contents of config.json.
    config: dict[str, Any]
    # Languages that actually have vectors, in configured order.
    languages: list[str]
    # Per-language vectors and metadata.
    by_lang: dict[str, LangVectors]
    # language -> normalized lookup key -> global row ids.
    lookup: dict[str, dict[str, list[int]]]
    # global row id -> (language, position inside that language's arrays).
    id_to_location: dict[int, tuple[str, int]]
    # True when some entry has a surface form that differs from its token.
    has_surface_forms: bool
def parse_s3_uri(uri: str) -> tuple[str, str]:
    """Split an ``s3://bucket/key`` URI into ``(bucket, key)``.

    Leading slashes are removed from the key.

    Raises:
        ValueError: if the scheme is not ``s3`` or the bucket is empty.
    """
    parts = urlparse(uri)
    bucket = parts.netloc
    if parts.scheme == "s3" and bucket:
        return bucket, parts.path.lstrip("/")
    raise ValueError(f"Expected s3://bucket/key URI, got {uri!r}")
def make_s3_client():
    """Create a boto3 S3 client configured from environment variables.

    ``SE_*`` variables take precedence over their ``AWS_*`` counterparts.
    An endpoint given without a scheme is prefixed with ``https://``.
    Credentials are passed explicitly only when both halves are present.
    """
    env = os.environ.get

    endpoint = env("SE_HOST_URL") or env("AWS_ENDPOINT_URL")
    if endpoint and not endpoint.startswith(("http://", "https://")):
        endpoint = f"https://{endpoint}"

    client_config = Config(
        signature_version="s3v4",
        s3={"addressing_style": "path"},
        retries={"max_attempts": 3, "mode": "standard"},
    )
    client_kwargs: dict[str, Any] = {
        "service_name": "s3",
        "region_name": env("AWS_DEFAULT_REGION", "us-east-1"),
        "config": client_config,
    }
    if endpoint:
        client_kwargs["endpoint_url"] = endpoint

    key_id = env("SE_ACCESS_KEY") or env("AWS_ACCESS_KEY_ID")
    secret = env("SE_SECRET_KEY") or env("AWS_SECRET_ACCESS_KEY")
    if key_id and secret:
        client_kwargs["aws_access_key_id"] = key_id
        client_kwargs["aws_secret_access_key"] = secret

    return boto3.client(**client_kwargs)
def latest_artifact_uri(client) -> str:
    """Return the S3 URI of the artifact directory to load.

    ``SPACE_ARTIFACT_S3_URI`` pins an exact artifact and is returned as-is
    (trimmed, without trailing slashes). Otherwise the objects under
    ``SPACE_ARTIFACT_S3_PREFIX`` (or the default prefix) are listed and the
    run directory with the greatest timestamp that holds a ``config.json``
    is chosen.

    Raises:
        FileNotFoundError: if no matching run directory exists under the prefix.
    """
    pinned = os.getenv("SPACE_ARTIFACT_S3_URI", "").strip().rstrip("/")
    if pinned:
        return pinned

    prefix_uri = os.getenv("SPACE_ARTIFACT_S3_PREFIX", "").strip() or DEFAULT_ARTIFACT_PREFIX
    bucket, prefix = parse_s3_uri(prefix_uri)
    prefix = prefix.rstrip("/") + "/"

    run_dir_re = re.compile(
        r"(.*multilingual_(?:dict|space)_(\d{8}_\d{6})(?:\.json)?)/config\.json$"
    )
    pages = client.get_paginator("list_objects_v2").paginate(Bucket=bucket, Prefix=prefix)
    hits = (
        run_dir_re.match(obj["Key"])
        for page in pages
        for obj in page.get("Contents", [])
    )
    # Each candidate is (run timestamp, run directory key).
    candidates: list[tuple[str, str]] = [(hit.group(2), hit.group(1)) for hit in hits if hit]
    if not candidates:
        raise FileNotFoundError(
            f"No multilingual_dict_*/config.json or multilingual_space_*.json/config.json found under {prefix_uri}"
        )

    # Run ids are YYYYMMDD_HHMMSS timestamps, so the lexicographic maximum is the newest run.
    run_id, key = max(candidates)
    uri = f"s3://{bucket}/{key}"
    print(f"Selected latest stage 6 artifact {run_id}: {uri}", file=sys.stderr)
    return uri
def local_cache_for_uri(uri: str) -> Path:
    """Map an artifact URI to its cache directory: ``CACHE_DIR`` plus the key's last component."""
    key = parse_s3_uri(uri)[1]
    leaf = Path(key.rstrip("/")).name
    return CACHE_DIR / leaf
def download_space_from_s3() -> tuple[Path, str]:
    """Download the selected artifact's required files into the local cache.

    Files that already exist locally with a non-zero size are not fetched again.

    Returns:
        ``(local directory, artifact S3 URI)``.
    """
    s3 = make_s3_client()
    artifact_uri = latest_artifact_uri(s3)

    target_dir = local_cache_for_uri(artifact_uri)
    target_dir.mkdir(parents=True, exist_ok=True)

    bucket, key_prefix = parse_s3_uri(artifact_uri)
    key_prefix = key_prefix.rstrip("/")

    for name in REQUIRED_FILES:
        local_path = target_dir / name
        already_cached = local_path.exists() and local_path.stat().st_size > 0
        if not already_cached:
            object_key = f"{key_prefix}/{name}"
            print(f"Downloading s3://{bucket}/{object_key}", file=sys.stderr)
            s3.download_file(bucket, object_key, str(local_path))

    return target_dir, artifact_uri
def find_space_dir() -> tuple[Path, str]:
    """Locate the artifact directory, preferring local copies over S3.

    Search order: an existing ``SPACE_DIR`` path, the default local
    directory, the last (sorted) ``multilingual_dict_*`` or
    ``multilingual_space_*.json`` entry in the working directory, and
    finally a download from S3.

    Returns:
        ``(directory, label)``. The label is the path string for local
        hits, or whatever the S3 download reports.
    """
    override = os.getenv("SPACE_DIR", "").strip()
    # An override that does not exist is ignored and the search continues.
    if override and Path(override).exists():
        chosen = Path(override)
        return chosen, str(chosen)

    if DEFAULT_LOCAL_SPACE.exists():
        return DEFAULT_LOCAL_SPACE, str(DEFAULT_LOCAL_SPACE)

    cwd = Path(".")
    found = sorted(
        [*cwd.glob("multilingual_dict_*"), *cwd.glob("multilingual_space_*.json")]
    )
    if found:
        newest = found[-1]
        return newest, str(newest)

    return download_space_from_s3()
def strip_diacritics(text: str) -> str:
    """Return *text* NFKD-decomposed with all combining marks removed."""
    decomposed = unicodedata.normalize("NFKD", text)
    kept = [ch for ch in decomposed if unicodedata.combining(ch) == 0]
    return "".join(kept)
def lookup_key(text: str) -> str:
    """Normalize *text* for lookup: trimmed, case-folded, single-spaced, without diacritics."""
    folded = text.strip().casefold()
    collapsed = " ".join(folded.split())
    return strip_diacritics(collapsed)
def is_good_token(token: str, min_len: int = 4) -> bool:
    """Return True when *token* looks like a real word.

    A token is rejected when it is empty, shorter than *min_len*, made
    only of digits, has fewer than two letters, or contains any character
    that is neither alphanumeric nor one of ``-``, ``'``, ``_``.
    """
    too_short_or_numeric = not token or len(token) < min_len or token.isdigit()
    if too_short_or_numeric:
        return False

    letter_count = len([ch for ch in token if ch.isalpha()])
    if letter_count < 2:
        return False

    allowed_punct = "-'_"
    for ch in token:
        if not (ch.isalnum() or ch in allowed_punct):
            return False
    return True
def read_config(space_dir: Path) -> dict[str, Any]:
    """Load and parse ``config.json`` from *space_dir*.

    Raises:
        FileNotFoundError: if the file does not exist.
    """
    config_path = space_dir / "config.json"
    if config_path.exists():
        return json.loads(config_path.read_text(encoding="utf-8"))
    raise FileNotFoundError(f"Missing config.json in {space_dir}")
def read_metadata(space_dir: Path) -> tuple[list[dict[str, Any]], dict[str, list[int]]]:
    """Read ``all_metadata.jsonl`` and index its rows by id and by language.

    Each non-blank line is a JSON object with at least ``id`` and ``lang``.
    Rows may appear in any order; they are placed at the list position
    given by their ``id``.

    Returns:
        ``(metadata, ids_by_lang)`` where ``metadata[i]`` is the row with
        id ``i`` and ``ids_by_lang`` maps each language to its row ids in
        file order.

    Raises:
        FileNotFoundError: if the file does not exist.
        ValueError: if an id is negative or the ids leave a gap.
    """
    path = space_dir / "all_metadata.jsonl"
    if not path.exists():
        raise FileNotFoundError(f"Missing all_metadata.jsonl in {space_dir}")

    metadata: list[dict[str, Any] | None] = []
    ids_by_lang: dict[str, list[int]] = {}
    with path.open("r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            meta = json.loads(line)
            row_id = int(meta["id"])
            # A negative id would index from the end of the list and silently
            # overwrite another row, so reject it explicitly.
            if row_id < 0:
                raise ValueError(f"Metadata ids must be non-negative, got {row_id}")
            if row_id >= len(metadata):
                metadata.extend([None] * (row_id + 1 - len(metadata)))
            metadata[row_id] = meta
            ids_by_lang.setdefault(str(meta["lang"]), []).append(row_id)

    rows: list[dict[str, Any]] = []
    for row_id, meta in enumerate(metadata):
        if meta is None:
            raise ValueError(f"Metadata ids are not contiguous; first missing id is {row_id}")
        rows.append(meta)
    return rows, ids_by_lang
def metadata_has_surface_forms(metadata: list[dict[str, Any]], config: dict[str, Any]) -> bool:
    """Report whether any row has a surface form that differs from its token.

    Always False when the config sets ``surface_forms_enabled`` to exactly ``False``.
    """
    if config.get("surface_forms_enabled") is False:
        return False
    for meta in metadata:
        surface = meta.get("surface")
        token = meta.get("token")
        if surface and token and str(surface) != str(token):
            return True
    return False
def reconstruct_range(index: Any, start: int, count: int) -> np.ndarray:
    """Reconstruct *count* consecutive vectors from *index*, beginning at row *start*.

    Tries the two-argument ``reconstruct_n`` first. If that call raises
    ``TypeError``, an output buffer is allocated and passed as a third argument.

    Returns:
        A C-contiguous float32 array.
    """
    try:
        block = index.reconstruct_n(start, count)
    except TypeError:
        # Fallback for bindings that fill a caller-supplied buffer.
        block = np.empty((count, index.d), dtype=np.float32)
        index.reconstruct_n(start, count, block)
    return np.ascontiguousarray(block, dtype=np.float32)
def reconstruct_ids(index: Any, ids: list[int]) -> np.ndarray:
    """Reconstruct the vectors for *ids*, in the given order.

    A run of consecutive ids is fetched with one ranged call; any other
    id list is fetched row by row.

    Returns:
        A C-contiguous float32 array with one row per id (zero rows for
        an empty list).
    """
    if not ids:
        return np.empty((0, index.d), dtype=np.float32)

    first = ids[0]
    total = len(ids)
    if ids == list(range(first, first + total)):
        return reconstruct_range(index, first, total)

    out = np.empty((total, index.d), dtype=np.float32)
    for slot, row_id in enumerate(ids):
        try:
            out[slot] = index.reconstruct(int(row_id))
        except TypeError:
            # Fallback for bindings that write into a caller-supplied row.
            index.reconstruct(int(row_id), out[slot])
    return np.ascontiguousarray(out, dtype=np.float32)
def normalize_rows(vecs: np.ndarray) -> np.ndarray:
    """Scale every row of *vecs* to unit L2 length and return float32.

    A small epsilon is added to each norm so all-zero rows stay zero
    instead of dividing by zero.
    """
    row_norms = np.linalg.norm(vecs, axis=1, keepdims=True)
    unit = vecs / (row_norms + 1e-12)
    return unit.astype(np.float32, copy=False)
def load_vectors_from_faiss(space_dir: Path, ids_by_lang: dict[str, list[int]]) -> dict[str, np.ndarray]:
    """Read ``aligned_all.faiss`` and return unit-normalized vectors per language.

    Languages are processed in sorted order. The index is released once
    all vectors have been copied out.

    Raises:
        RuntimeError: if the ``faiss`` module cannot be imported.
        FileNotFoundError: if the index file is missing.
    """
    try:
        import faiss  # type: ignore
    except ImportError as exc:
        raise RuntimeError("faiss-cpu is required to read aligned_all.faiss") from exc

    index_path = space_dir / "aligned_all.faiss"
    if not index_path.exists():
        raise FileNotFoundError(f"Missing aligned_all.faiss in {space_dir}")

    print(f"Loading FAISS index: {index_path}", file=sys.stderr)
    index = faiss.read_index(str(index_path))

    per_lang: dict[str, np.ndarray] = {}
    for lang in sorted(ids_by_lang):
        lang_ids = ids_by_lang[lang]
        print(f"Reconstructing {lang}: {len(lang_ids)} vectors", file=sys.stderr)
        per_lang[lang] = normalize_rows(reconstruct_ids(index, lang_ids))

    # Drop the index before returning so only the copied vectors stay in memory.
    del index
    gc.collect()
    return per_lang
def build_lookup(languages: dict[str, LangVectors]) -> dict[str, dict[str, list[int]]]:
    """Build the per-language query index: normalized key -> global row ids.

    Every entry is indexed under the normalized form of its token and of
    its surface form. When both normalize to the same key (for example
    they differ only in case or diacritics), the entry is listed once
    under that key, so one entry never looks like several matches.

    Returns:
        ``{lang: {lookup key: [global ids in row order]}}``.
    """
    lookup: dict[str, dict[str, list[int]]] = {}
    for lang, data in languages.items():
        lang_lookup: dict[str, list[int]] = {}
        for global_id, meta in zip(data.ids.tolist(), data.metas):
            # dict.fromkeys removes duplicate keys while keeping token-before-surface order.
            keys = dict.fromkeys(
                lookup_key(str(value))
                for value in (meta.get("token"), meta.get("surface"))
                if value
            )
            for key in keys:
                lang_lookup.setdefault(key, []).append(int(global_id))
        lookup[lang] = lang_lookup
    return lookup
@lru_cache(maxsize=1)
def load_space() -> Space:
    """Load the aligned space once and cache the result for later calls.

    Finds the artifact directory, reads config, metadata and vectors, and
    assembles the per-language structures. Languages follow the config's
    ``languages`` list when present, otherwise the sorted languages found
    in the metadata. Languages without any rows are left out.
    """
    space_dir, artifact_uri = find_space_dir()
    config = read_config(space_dir)
    metadata, ids_by_lang = read_metadata(space_dir)
    vectors_by_lang = load_vectors_from_faiss(space_dir, ids_by_lang)

    configured = list(config.get("languages") or sorted(ids_by_lang))
    by_lang: dict[str, LangVectors] = {}
    id_to_location: dict[int, tuple[str, int]] = {}
    for lang in configured:
        row_ids = ids_by_lang.get(lang)
        if not row_ids:
            continue
        by_lang[lang] = LangVectors(
            lang=lang,
            ids=np.asarray(row_ids, dtype=np.int64),
            metas=[metadata[row_id] for row_id in row_ids],
            vecs=vectors_by_lang[lang],
        )
        # Record where each global id sits inside this language's arrays.
        id_to_location.update(
            (int(row_id), (lang, position)) for position, row_id in enumerate(row_ids)
        )

    return Space(
        root=space_dir,
        artifact_uri=artifact_uri,
        config=config,
        languages=[lang for lang in configured if lang in by_lang],
        by_lang=by_lang,
        lookup=build_lookup(by_lang),
        id_to_location=id_to_location,
        has_surface_forms=metadata_has_surface_forms(metadata, config),
    )
def default_options(config: dict[str, Any]) -> RuntimeOptions:
    """Derive the initial lookup options from the artifact config.

    Missing keys fall back to built-in defaults. ``candidate_retrieval_k``
    defaults to three times ``top_k``. Scoring starts as CSLS with all
    filters and surface forms switched on.
    """
    bidi_settings = config.get("bidirectional_consistency") or {}
    top_k = int(config.get("top_k", 3))
    retrieval_k = int(config.get("candidate_retrieval_k", top_k * 3))
    return RuntimeOptions(
        top_k=top_k,
        min_score=float(config.get("min_score", 0.15)),
        csls_k=int(config.get("csls_k", 10)),
        candidate_retrieval_k=retrieval_k,
        csls_prefetch_k=int(config.get("csls_prefetch_k", 50)),
        bidirectional=bool(bidi_settings.get("enabled", True)),
        score_method="csls",
        filter_stopwords=True,
        filter_bad_tokens=True,
        use_surface=True,
    )
def make_options(
    top_k: int,
    min_score: float,
    csls_k: int,
    candidate_retrieval_k: int,
    csls_prefetch_k: int,
    bidirectional: bool,
    score_method: str,
    filter_stopwords: bool,
    filter_bad_tokens: bool,
    use_surface: bool,
) -> RuntimeOptions:
    """Coerce raw UI values into a typed ``RuntimeOptions``.

    Numbers are cast to ``int``/``float``, flags to ``bool``, and the
    score method is lower-cased.
    """
    coerced = {
        "top_k": int(top_k),
        "min_score": float(min_score),
        "csls_k": int(csls_k),
        "candidate_retrieval_k": int(candidate_retrieval_k),
        "csls_prefetch_k": int(csls_prefetch_k),
        "bidirectional": bool(bidirectional),
        "score_method": str(score_method).lower(),
        "filter_stopwords": bool(filter_stopwords),
        "filter_bad_tokens": bool(filter_bad_tokens),
        "use_surface": bool(use_surface),
    }
    return RuntimeOptions(**coerced)
def top_indices(values: np.ndarray, k: int) -> np.ndarray:
    """Return the indices of the *k* largest entries of *values*, best first.

    *k* is clamped to ``[0, len(values)]``. A clamped *k* of zero yields
    an empty int64 array.
    """
    size = values.shape[0]
    k = min(max(0, k), size)
    if k == 0:
        return np.empty((0,), dtype=np.int64)

    negated = -values
    if k >= size:
        return np.argsort(negated)

    # Partition first so only the k winners need a full sort.
    shortlist = np.argpartition(negated, k - 1)[:k]
    return shortlist[np.argsort(negated[shortlist])]
def top_mean(values: np.ndarray, k: int) -> float:
    """Return the mean of the *k* largest entries of *values* (*k* clamped to ``[1, len(values)]``)."""
    k = min(max(1, k), values.shape[0])
    return float(values[top_indices(values, k)].mean())
def candidate_allowed(meta: dict[str, Any], lang: str, space: Space, opts: RuntimeOptions) -> bool:
    """Decide whether a candidate passes the enabled filters.

    The token quality check uses the minimum length from the config's
    ``filters.target_is_good_token_min_len`` (default 4). The stopword
    check compares the lower-cased token against the config's stopword
    list for *lang*.
    """
    token = str(meta.get("token") or "")
    config = space.config

    if opts.filter_bad_tokens:
        filter_cfg = config.get("filters") or {}
        min_len = int(filter_cfg.get("target_is_good_token_min_len", 4))
        if not is_good_token(token, min_len):
            return False

    if opts.filter_stopwords:
        stopwords_by_lang = config.get("stopwords") or {}
        if token.lower() in set(stopwords_by_lang.get(lang, [])):
            return False

    return True
def rank_candidates(
    space: Space,
    query_vec: np.ndarray,
    source_lang: str,
    target_lang: str,
    opts: RuntimeOptions,
    *,
    apply_filters: bool = True,
) -> list[dict[str, Any]]:
    """Rank target-language entries as translations of *query_vec*.

    Steps:
      1. Take the target rows with the highest dot product against the
         query; the pool size is the largest of ``candidate_retrieval_k``,
         ``csls_prefetch_k`` and ``top_k``.
      2. Score the pool. For ``"csls"`` the score is
         ``2 * cos - r_query - r_target``, where each ``r`` term is the
         mean of the ``csls_k`` highest similarities into the other
         language. Any other method uses the cosine value itself.
      3. Keep the ``candidate_retrieval_k`` best scores, then drop entries
         below ``min_score``, entries rejected by the filters (when
         *apply_filters* is set), and repeated surface forms (unless the
         config disables ``filters.duplicate_target_surfaces_removed``).

    Returns:
        Result dicts in descending score order. ``rank`` is the position
        before filtering, so ranks may have gaps.
    """
    source = space.by_lang[source_lang]
    target = space.by_lang[target_lang]

    cosine_all = target.vecs @ query_vec
    pool_size = max(opts.candidate_retrieval_k, opts.csls_prefetch_k, opts.top_k)
    pool = top_indices(cosine_all, min(pool_size, len(target.metas)))
    pool_cosines = cosine_all[pool]

    if opts.score_method == "csls":
        r_query = top_mean(cosine_all, opts.csls_k)
        # Similarity of every pooled target word to every source-language word.
        back_sims = target.vecs[pool] @ source.vecs.T
        r_targets = np.asarray(
            [top_mean(row, opts.csls_k) for row in back_sims],
            dtype=np.float32,
        )
        scores = (2.0 * pool_cosines - r_query - r_targets).astype(np.float32)
    else:
        scores = pool_cosines.astype(np.float32)

    filter_cfg = space.config.get("filters") or {}
    dedupe = bool(filter_cfg.get("duplicate_target_surfaces_removed", True))
    shortlist = np.argsort(-scores)[: opts.candidate_retrieval_k]

    ranked: list[dict[str, Any]] = []
    surfaces_seen: set[str] = set()
    for rank, pos in enumerate(shortlist, 1):
        score = float(scores[pos])
        if score < opts.min_score:
            continue
        local_id = int(pool[pos])
        meta = target.metas[local_id]
        if apply_filters and not candidate_allowed(meta, target_lang, space, opts):
            continue
        surface = str(meta.get("surface") or meta.get("token") or "")
        if dedupe and surface in surfaces_seen:
            continue
        surfaces_seen.add(surface)
        ranked.append(
            {
                "rank": rank,
                "global_id": int(target.ids[local_id]),
                "local_id": local_id,
                "meta": meta,
                "score": score,
                "cosine": float(pool_cosines[pos]),
                "bidirectional": None,
            }
        )
    return ranked
def get_meta(space: Space, global_id: int) -> dict[str, Any]:
    """Return the metadata record for *global_id* (``KeyError`` if the id is unknown)."""
    location = space.id_to_location[int(global_id)]
    lang, position = location
    return space.by_lang[lang].metas[position]
def get_vec(space: Space, global_id: int) -> np.ndarray:
    """Return the stored vector for *global_id* (``KeyError`` if the id is unknown)."""
    location = space.id_to_location[int(global_id)]
    lang, position = location
    return space.by_lang[lang].vecs[position]
def format_word(meta: dict[str, Any], opts: RuntimeOptions) -> str:
    """Pick the display string for an entry.

    The surface form is preferred when ``opts.use_surface`` is set,
    otherwise the token. The other field is the fallback, and an empty
    string is returned when both are missing.
    """
    surface = meta.get("surface")
    token = meta.get("token")
    preferred, fallback = (surface, token) if opts.use_surface else (token, surface)
    return str(preferred or fallback or "")
def resolve_query(space: Space, lang: str, query: str) -> tuple[int, dict[str, Any], str]:
    """Resolve a typed query word to one entry of the aligned space.

    The query is normalized with ``lookup_key`` and looked up in the
    index for *lang*. When several distinct entries share the key, the
    first one is used and the returned message says so.

    Returns:
        ``(global id, metadata, message)``. The message is empty when
        the match is unambiguous.

    Raises:
        ValueError: if *lang* is not loaded or the query is blank.
        LookupError: if nothing matches the query.
    """
    if lang not in space.by_lang:
        raise ValueError(f"Unknown language {lang!r}. Available: {', '.join(space.languages)}")
    if not query.strip():
        raise ValueError("Enter a query word.")

    raw_matches = space.lookup.get(lang, {}).get(lookup_key(query), [])
    # The same id can be listed more than once under one key; count each entry once.
    matches = list(dict.fromkeys(int(row_id) for row_id in raw_matches))
    if not matches:
        raise LookupError(f"No exact token/surface match for {lang}:{query!r}")

    row_id = matches[0]
    meta = get_meta(space, row_id)
    message = ""
    if len(matches) > 1:
        chosen = f"{meta.get('surface') or meta.get('token')} (id {row_id})"
        message = f"Matched {len(matches)} entries; using the first: {chosen}"
    return row_id, meta, message
def translate_like_terminal(
    query: str,
    source_lang: str,
    top_k: int,
    min_score: float,
    csls_k: int,
    candidate_retrieval_k: int,
    csls_prefetch_k: int,
    bidirectional: bool,
    score_method: str,
    filter_stopwords: bool,
    filter_bad_tokens: bool,
    use_surface: bool,
) -> tuple[str, list[list[Any]]]:
    """Translate *query* into every other loaded language.

    This is the Gradio event handler. It resolves the query word, ranks
    candidates per target language, optionally keeps only candidates that
    retrieve the query back, and stops at ``top_k`` results per language.

    Returns:
        ``(report text, table rows)``. Any exception is caught and
        reported as ``("Error: ...", [])`` so the UI always gets a value.
    """
    try:
        space = load_space()
        # Surface forms can only be shown when the loaded space has them.
        opts = make_options(
            top_k,
            min_score,
            csls_k,
            candidate_retrieval_k,
            csls_prefetch_k,
            bidirectional,
            score_method,
            filter_stopwords,
            filter_bad_tokens,
            bool(use_surface and space.has_surface_forms),
        )
        source_id, source_meta, match_message = resolve_query(space, source_lang, query)
        source_vec = get_vec(space, source_id)
        source_word = format_word(source_meta, opts)

        report = [
            f"Query: {source_lang}:{source_word} "
            f"(token={source_meta.get('token')}, id={source_id})",
            f"Settings: score={opts.score_method}, top_k={opts.top_k}, "
            f"min_score={opts.min_score}, csls_k={opts.csls_k}, "
            f"candidate_retrieval_k={opts.candidate_retrieval_k}, "
            f"bidirectional={opts.bidirectional}",
        ]
        if match_message:
            report.append(match_message)

        table: list[list[Any]] = []
        for target_lang in space.languages:
            if target_lang == source_lang:
                continue

            accepted: list[dict[str, Any]] = []
            for cand in rank_candidates(space, source_vec, source_lang, target_lang, opts):
                if opts.bidirectional:
                    # NOTE(review): this reverse lookup runs with the default
                    # apply_filters=True, so a query word rejected by the filters
                    # can never be confirmed. Confirm whether that is intended.
                    back = rank_candidates(
                        space,
                        get_vec(space, int(cand["global_id"])),
                        target_lang,
                        source_lang,
                        opts,
                    )
                    cand["bidirectional"] = any(
                        int(item["global_id"]) == source_id for item in back
                    )
                    if not cand["bidirectional"]:
                        continue
                else:
                    cand["bidirectional"] = False
                accepted.append(cand)
                if len(accepted) >= opts.top_k:
                    break

            report.extend(["", f"{target_lang}:"])
            if not accepted:
                report.append(" no candidates after filters")
                continue

            for position, cand in enumerate(accepted, 1):
                meta = cand["meta"]
                word = format_word(meta, opts)
                token = meta.get("token")
                bidi = "yes" if cand["bidirectional"] else "no"
                report.append(
                    f" {position}. {word} "
                    f"(token={token}, score={cand['score']:.4f}, "
                    f"cosine={cand['cosine']:.4f}, bidi={bidi})"
                )
                table.append(
                    [
                        target_lang,
                        position,
                        word,
                        token,
                        round(float(cand["score"]), 6),
                        round(float(cand["cosine"]), 6),
                        bidi,
                    ]
                )
        return "\n".join(report), table
    except Exception as exc:
        return f"Error: {exc}", []
def initialize() -> tuple[Any, ...]:
    """Produce the initial values for the UI controls on page load.

    On success the controls take the artifact's defaults and the language
    dropdown lists the loaded languages, preselecting the config's
    ``pivot_lang`` when available. On failure the status shows the error
    and the controls fall back to fixed defaults.

    Returns:
        Nine values, in the order of the ``outputs`` list of ``demo.load``.
    """
    try:
        space = load_space()
        opts = default_options(space.config)

        pivot = space.config.get("pivot_lang", "de")
        source_lang = pivot if pivot in space.languages else space.languages[0]

        vector_count = sum(len(item.metas) for item in space.by_lang.values())
        if space.has_surface_forms:
            surface_label = "show surface forms"
        else:
            surface_label = "show surface forms (none in this aligned space)"

        return (
            f"Loaded {space.artifact_uri} with {vector_count:,} vectors.",
            gr.update(choices=space.languages, value=source_lang),
            opts.top_k,
            opts.min_score,
            opts.csls_k,
            opts.candidate_retrieval_k,
            opts.csls_prefetch_k,
            opts.bidirectional,
            gr.update(
                value=space.has_surface_forms,
                interactive=space.has_surface_forms,
                label=surface_label,
            ),
        )
    except Exception as exc:
        # Keep the page usable and show why loading failed.
        disabled_surface_toggle = gr.update(
            value=False,
            interactive=False,
            label="show surface forms (no aligned space loaded)",
        )
        return (
            f"Load error: {exc}",
            gr.update(choices=DEFAULT_LANGS, value="de"),
            3,
            0.15,
            10,
            9,
            50,
            True,
            disabled_surface_toggle,
        )
# Stylesheet passed to gr.Blocks: page background, container width, title and
# status styling, and a monospace font for textareas.
CSS = """
body { background: #f7f5ef; }
.gradio-container { max-width: 1120px !important; }
.app-title h1 { margin-bottom: 0.15rem; }
.status { color: #5f6b7a; font-size: 0.92rem; }
textarea { font-family: ui-monospace, SFMono-Regular, Menlo, Consolas, monospace; }
"""
# UI layout and event wiring. Controls are on the left, results on the right.
with gr.Blocks(title="Multilingual Dictionary Explorer", css=CSS) as demo:
    gr.Markdown(
        "# Multilingual Dictionary Explorer\n"
        "FAISS + CSLS translation lookup from the aligned multilingual space.",
        elem_classes=["app-title"],
    )
    # Replaced by the load status (or load error) once `initialize` has run.
    status = gr.Markdown("Loading artifacts...", elem_classes=["status"])

    with gr.Row():
        with gr.Column(scale=1, min_width=320):
            query = gr.Textbox(label="Query word", value="haus")
            source_lang = gr.Dropdown(label="Language", choices=DEFAULT_LANGS, value="de")
            search = gr.Button("Search", variant="primary")

            with gr.Accordion("Parameters", open=False):
                top_k = gr.Slider(
                    minimum=1,
                    maximum=20,
                    value=3,
                    step=1,
                    label="top_k",
                    info=(
                        "How many final translations to show per target language after "
                        "scoring and filters."
                    ),
                )
                min_score = gr.Slider(
                    minimum=-2.0,
                    maximum=2.0,
                    value=0.15,
                    step=0.01,
                    label="min_score",
                    info=(
                        "The minimum translation score to show. CSLS is a relative score, "
                        "so negative values are valid but usually allow weaker matches."
                    ),
                )
                csls_k = gr.Slider(
                    minimum=1,
                    maximum=50,
                    value=10,
                    step=1,
                    label="csls_k",
                    info=(
                        "How many neighbours CSLS compares against to avoid overrating "
                        "generic words in crowded vector areas."
                    ),
                )
                candidate_retrieval_k = gr.Slider(
                    minimum=1,
                    maximum=100,
                    value=9,
                    step=1,
                    label="candidate_retrieval_k",
                    info=(
                        "How many top candidates to inspect before removing bad tokens, "
                        "stopwords, low scores, or non-bidirectional matches."
                    ),
                )
                csls_prefetch_k = gr.Slider(
                    minimum=10,
                    maximum=500,
                    value=50,
                    step=1,
                    label="csls_prefetch_k",
                    info=(
                        "How many nearby candidates to fetch first so CSLS can score a "
                        "larger pool before the final shortlist."
                    ),
                )
                score_method = gr.Radio(
                    choices=["csls", "cosine"],
                    value="csls",
                    label="score",
                    info=(
                        "CSLS adjusts cosine similarity for multilingual lookup; cosine "
                        "shows plain vector closeness without that correction."
                    ),
                )
                bidirectional = gr.Checkbox(
                    value=True,
                    label="bidirectional_consistency",
                    info=(
                        "Keep a translation only when the target word also retrieves the "
                        "query word back, which is stricter but cleaner."
                    ),
                )
                filter_stopwords = gr.Checkbox(
                    value=True,
                    label="filter stopwords",
                    info=(
                        "Remove common function words such as articles, prepositions, and "
                        "pronouns from the displayed candidates."
                    ),
                )
                filter_bad_tokens = gr.Checkbox(
                    value=True,
                    label="filter bad tokens",
                    info=(
                        "Remove candidates that look like noise, for example very short, "
                        "numeric, or punctuation-heavy tokens."
                    ),
                )
                use_surface = gr.Checkbox(
                    value=True,
                    label="show surface forms",
                    info=(
                        "Show readable surface forms while keeping the normalized token "
                        "visible in the token column."
                    ),
                )

        with gr.Column(scale=2):
            output_text = gr.Textbox(label="Terminal-style output", lines=18)
            output_table = gr.Dataframe(
                headers=["target_lang", "rank", "word", "token", "score", "cosine", "bidi"],
                datatype=["str", "number", "str", "str", "number", "number", "str"],
                interactive=False,
                wrap=True,
            )

    # Order must match the parameter order of `translate_like_terminal`.
    inputs = [
        query,
        source_lang,
        top_k,
        min_score,
        csls_k,
        candidate_retrieval_k,
        csls_prefetch_k,
        bidirectional,
        score_method,
        filter_stopwords,
        filter_bad_tokens,
        use_surface,
    ]
    result_outputs = [output_text, output_table]

    # Both the button and pressing Enter in the query box run a lookup.
    search.click(translate_like_terminal, inputs=inputs, outputs=result_outputs)
    query.submit(translate_like_terminal, inputs=inputs, outputs=result_outputs)

    # On page load: fill the controls from the artifact, then run a first lookup.
    # Order must match the tuple returned by `initialize`.
    startup_outputs = [
        status,
        source_lang,
        top_k,
        min_score,
        csls_k,
        candidate_retrieval_k,
        csls_prefetch_k,
        bidirectional,
        use_surface,
    ]
    demo.load(initialize, outputs=startup_outputs).then(
        translate_like_terminal, inputs=inputs, outputs=result_outputs
    )
if __name__ == "__main__":
    # Script entry point: queue the app, then launch it.
    queued_app = demo.queue()
    queued_app.launch()