Maslionok committed on
Commit
245b51c
·
1 Parent(s): e746bd7
Files changed (4)
  1. .gitattributes +0 -35
  2. README.md +0 -55
  3. app.py +0 -1697
  4. requirements.txt +0 -8
.gitattributes DELETED
@@ -1,35 +0,0 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md DELETED
@@ -1,55 +0,0 @@
1
- ---
2
- title: Multilingual Static Word Embeddings Demo
3
- emoji: 🧭
4
- colorFrom: green
5
- colorTo: indigo
6
- sdk: gradio
7
- sdk_version: 6.14.0
8
- python_version: "3.11"
9
- app_file: app.py
10
- pinned: false
11
- ---
12
-
13
- # Multilingual Static Word Embeddings Demo
14
-
15
- Gradio Space for exploring an aligned multilingual static word embedding artifact produced by Stage 6 of `build_multilingual_dictionary.py` when `SAVE_ALLIGNED_SPACE=true`.
16
-
17
- The app loads the newest S3 folder matching:
18
-
19
- ```text
20
- s3://131-component-staging/multilingual-static-word-embeddings/stage-6/multilingual_space_*.json/
21
- ```
22
-
23
- Required files inside the artifact folder:
24
-
25
- - `aligned_all.faiss`
26
- - `all_metadata.jsonl`
27
- - `config.json`
28
-
29
- `aligned_all.vec` is downloaded only if vectors cannot be reconstructed from the FAISS index.
30
-
31
- The selected artifact's S3 `config.json` is used for live UI defaults, stopwords, and metadata display. Changing the artifact dropdown reloads that folder's corresponding config.
32
-
33
- ## Space Secrets
34
-
35
- Set these Hugging Face Space secrets if the S3 bucket is private:
36
-
37
- - `SE_ACCESS_KEY`
38
- - `SE_SECRET_KEY`
39
- - `SE_HOST`
40
-
41
- `SE_HOST` may be either a hostname or a full `https://...` endpoint URL. The app also still supports `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` as fallback names.
42
-
43
- The app lists existing aligned-space folders in a dropdown using the timestamp in names like `multilingual_space_20260521_133953.json/`. The newest timestamp is selected by default.
44
-
45
- To preselect a specific artifact instead of the newest folder, set:
46
-
47
- ```text
48
- SPACE_ARTIFACT_S3_URI=s3://131-component-staging/multilingual-static-word-embeddings/stage-6/multilingual_space_<TIMESTAMP>.json/
49
- ```
50
-
51
- ## What Can Be Changed Live
52
-
53
- The Translate tab allows live changes to retrieval and filtering parameters such as `top_k`, `min_score`, `csls_k`, candidate multiplier, FAISS prefetch, score method, stopword filtering, minimum token length, fuzzy lookup, and bidirectional consistency.
54
-
55
- Alignment/build parameters such as `pivot_lang`, `top_n_vocab`, `out_top`, `align_iters`, `init_pairs`, and `max_pairs` are shown as read-only artifact metadata because changing them requires rebuilding the aligned vector space.
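
Note on the deleted README: the artifact-selection rule it describes (newest timestamp in names like `multilingual_space_20260521_133953.json/`, and `SE_HOST` given as either a hostname or a full URL) can be sketched as below. This is a minimal illustration under stated assumptions, not the code removed in this commit: the helper names `timestamp_key`, `newest_artifact`, and `normalize_endpoint` are hypothetical, and the sketch assumes the timestamp always has the fixed-width `YYYYMMDD_HHMMSS` form, so that plain string comparison orders folders chronologically.

```python
import re
from typing import List, Optional

# Assumed folder naming, taken from the README example:
# multilingual_space_<YYYYMMDD>_<HHMMSS>.json/
_NAME_RE = re.compile(r"multilingual_space_(\d{8}_\d{6})\.json$")


def timestamp_key(prefix: str) -> str:
    """Return the timestamp part of a folder name, or "" if the name does not match."""
    name = prefix.rstrip("/").split("/")[-1]
    match = _NAME_RE.search(name)
    return match.group(1) if match else ""


def newest_artifact(prefixes: List[str]) -> str:
    """Pick the folder with the largest timestamp.

    Fixed-width digit strings sort lexicographically in chronological order,
    so no date parsing is needed.
    """
    candidates = [p for p in prefixes if timestamp_key(p)]
    if not candidates:
        raise ValueError("no multilingual_space_*.json folder found")
    return max(candidates, key=timestamp_key)


def normalize_endpoint(host: str) -> Optional[str]:
    """Accept a bare hostname or a full URL; default to https for bare hostnames."""
    host = host.strip()
    if not host:
        return None
    if host.startswith(("http://", "https://")):
        return host
    return f"https://{host}"


if __name__ == "__main__":
    folders = [
        "stage-6/multilingual_space_20260101_000000.json/",
        "stage-6/multilingual_space_20260521_133953.json/",
        "stage-6/unrelated/",
    ]
    print(newest_artifact(folders))
    print(normalize_endpoint("s3.example.com"))
```

Folders whose names do not match the pattern are ignored rather than sorted last, which keeps an unrelated prefix from ever being selected by default.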
app.py DELETED
@@ -1,1697 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import fnmatch
4
- import hashlib
5
- import json
6
- import math
7
- import os
8
- import random
9
- import re
10
- import sys
11
- import threading
12
- import unicodedata
13
- from collections import defaultdict
14
- from dataclasses import dataclass
15
- from pathlib import Path
16
- from typing import Any
17
- from urllib.parse import urlparse
18
-
19
- os.environ.setdefault("MPLCONFIGDIR", "/tmp/matplotlib")
20
- os.environ.setdefault("XDG_CACHE_HOME", "/tmp/.cache")
21
- os.environ.setdefault("GRADIO_ANALYTICS_ENABLED", "False")
22
- os.environ.setdefault("HF_HUB_DISABLE_TELEMETRY", "1")
23
-
24
- _ORIGINAL_UNRAISABLEHOOK = sys.unraisablehook
25
-
26
-
27
- def _quiet_asyncio_invalid_fd(unraisable):
28
- if (
29
- isinstance(unraisable.exc_value, ValueError)
30
- and "Invalid file descriptor: -1" in str(unraisable.exc_value)
31
- and "BaseEventLoop.__del__" in repr(unraisable.object)
32
- ):
33
- return
34
- _ORIGINAL_UNRAISABLEHOOK(unraisable)
35
-
36
-
37
- sys.unraisablehook = _quiet_asyncio_invalid_fd
38
-
39
- import boto3
40
- import gradio as gr
41
- import numpy as np
42
- import pandas as pd
43
- from botocore.exceptions import ClientError, NoCredentialsError, PartialCredentialsError
44
- from dotenv import load_dotenv
45
-
46
- try:
47
- import faiss
48
- except Exception as exc: # pragma: no cover - shown as a startup error in the UI.
49
- faiss = None
50
- FAISS_IMPORT_ERROR = exc
51
- else:
52
- FAISS_IMPORT_ERROR = None
53
-
54
- try:
55
- from rapidfuzz import fuzz, process
56
- except Exception: # pragma: no cover - rapidfuzz is optional at runtime.
57
- fuzz = None
58
- process = None
59
-
60
-
61
- load_dotenv()
62
-
63
- BASE_ARTIFACT_S3_URI = "s3://131-component-staging/multilingual-static-word-embeddings/stage-6/"
64
- ARTIFACT_ENV_VAR = "SPACE_ARTIFACT_S3_URI"
65
- SE_ACCESS_KEY_ENV = "SE_ACCESS_KEY"
66
- SE_SECRET_KEY_ENV = "SE_SECRET_KEY"
67
- SE_HOST_ENV = "SE_HOST"
68
- CACHE_ROOT = Path("/tmp/multilingual_space_artifacts")
69
- REQUIRED_FILES = ("aligned_all.faiss", "all_metadata.jsonl", "config.json")
70
- OPTIONAL_VEC_FILE = "aligned_all.vec"
71
-
72
-
73
- def _config_get_raw(config: dict[str, Any], keys: tuple[str, ...], default: Any = "") -> Any:
74
- for key in keys:
75
- if key in config:
76
- return config[key]
77
- for section_name in ("config", "params", "args", "stage_6", "alignment", "dictionary", "preprocessing", "filters"):
78
- section = config.get(section_name)
79
- if isinstance(section, dict):
80
- found = _config_get_raw(section, keys, None)
81
- if found is not None:
82
- return found
83
- return default
84
-
85
-
86
- def _as_bool(value: Any, default: bool) -> bool:
87
- if isinstance(value, dict) and "enabled" in value:
88
- return _as_bool(value["enabled"], default)
89
- if isinstance(value, bool):
90
- return value
91
- if isinstance(value, str):
92
- normalized = value.strip().casefold()
93
- if normalized in {"1", "true", "yes", "y", "on"}:
94
- return True
95
- if normalized in {"0", "false", "no", "n", "off"}:
96
- return False
97
- if value is None:
98
- return default
99
- return bool(value)
100
-
101
-
102
- def _as_int(value: Any, default: int) -> int:
103
- try:
104
- return int(value)
105
- except (TypeError, ValueError):
106
- return default
107
-
108
-
109
- def _as_float(value: Any, default: float) -> float:
110
- try:
111
- return float(value)
112
- except (TypeError, ValueError):
113
- return default
114
-
115
-
116
- BASE_DEFAULTS = {
117
- "pivot_lang": "de",
118
- "top_n_vocab": 150000,
119
- "out_top": 50000,
120
- "top_k": 3,
121
- "min_score": 0.15,
122
- "align_iters": 5,
123
- "init_pairs": 5000,
124
- "max_pairs": 15000,
125
- "csls_k": 10,
126
- "candidate_retrieval_k_multiplier": 3,
127
- "csls_prefetch_k": 50,
128
- "bidirectional_consistency": True,
129
- "use_surface_forms": True,
130
- "hide_stopwords": True,
131
- "min_token_length": 4,
132
- }
133
-
134
- INTERACTIVE_DEFAULTS = {
135
- "candidate_retrieval_k_multiplier": 3,
136
- "csls_prefetch_k": 20,
137
- "bidirectional_consistency": False,
138
- "fuzzy_fallback": False,
139
- }
140
-
141
-
142
- def _defaults_from_config(config: dict[str, Any], fallback: dict[str, Any] | None = None) -> dict[str, Any]:
143
- defaults = dict(fallback or BASE_DEFAULTS)
144
- top_k = _as_int(_config_get_raw(config, ("top_k",), defaults["top_k"]), defaults["top_k"])
145
-
146
- candidate_retrieval_k = _config_get_raw(config, ("candidate_retrieval_k",), None)
147
- if candidate_retrieval_k is not None and top_k > 0:
148
- candidate_multiplier = max(1, math.ceil(_as_int(candidate_retrieval_k, top_k * 3) / top_k))
149
- else:
150
- candidate_multiplier = _as_int(
151
- _config_get_raw(config, ("candidate_retrieval_k_multiplier",), defaults["candidate_retrieval_k_multiplier"]),
152
- defaults["candidate_retrieval_k_multiplier"],
153
- )
154
-
155
- defaults.update(
156
- {
157
- "pivot_lang": str(_config_get_raw(config, ("pivot_lang", "pivot_language"), defaults["pivot_lang"])),
158
- "top_n_vocab": _as_int(_config_get_raw(config, ("top_n_vocab",), defaults["top_n_vocab"]), defaults["top_n_vocab"]),
159
- "out_top": _as_int(_config_get_raw(config, ("out_top",), defaults["out_top"]), defaults["out_top"]),
160
- "top_k": top_k,
161
- "min_score": _as_float(_config_get_raw(config, ("min_score",), defaults["min_score"]), defaults["min_score"]),
162
- "align_iters": _as_int(_config_get_raw(config, ("align_iters",), defaults["align_iters"]), defaults["align_iters"]),
163
- "init_pairs": _as_int(_config_get_raw(config, ("init_pairs",), defaults["init_pairs"]), defaults["init_pairs"]),
164
- "max_pairs": _as_int(_config_get_raw(config, ("max_pairs",), defaults["max_pairs"]), defaults["max_pairs"]),
165
- "csls_k": _as_int(_config_get_raw(config, ("csls_k",), defaults["csls_k"]), defaults["csls_k"]),
166
- "candidate_retrieval_k_multiplier": candidate_multiplier,
167
- "csls_prefetch_k": _as_int(
168
- _config_get_raw(config, ("csls_prefetch_k",), defaults["csls_prefetch_k"]),
169
- defaults["csls_prefetch_k"],
170
- ),
171
- "bidirectional_consistency": _as_bool(
172
- _config_get_raw(config, ("bidirectional_consistency", "bidirectional"), defaults["bidirectional_consistency"]),
173
- defaults["bidirectional_consistency"],
174
- ),
175
- "use_surface_forms": _as_bool(
176
- _config_get_raw(config, ("surface_forms_enabled", "use_surface_forms"), defaults["use_surface_forms"]),
177
- defaults["use_surface_forms"],
178
- ),
179
- "hide_stopwords": _as_bool(
180
- _config_get_raw(
181
- config,
182
- ("target_stopwords_filtered_in_translation_candidates", "hide_stopwords"),
183
- defaults["hide_stopwords"],
184
- ),
185
- defaults["hide_stopwords"],
186
- ),
187
- "min_token_length": _as_int(
188
- _config_get_raw(config, ("target_is_good_token_min_len", "min_token_length"), defaults["min_token_length"]),
189
- defaults["min_token_length"],
190
- ),
191
- }
192
- )
193
- return defaults
194
-
195
-
196
- DEFAULTS = dict(BASE_DEFAULTS)
197
-
198
- STOPWORDS = {
199
- "a",
200
- "an",
201
- "and",
202
- "are",
203
- "as",
204
- "at",
205
- "be",
206
- "by",
207
- "das",
208
- "de",
209
- "del",
210
- "der",
211
- "des",
212
- "die",
213
- "du",
214
- "e",
215
- "el",
216
- "en",
217
- "es",
218
- "et",
219
- "for",
220
- "from",
221
- "he",
222
- "het",
223
- "i",
224
- "ich",
225
- "il",
226
- "in",
227
- "is",
228
- "it",
229
- "la",
230
- "las",
231
- "le",
232
- "les",
233
- "lo",
234
- "los",
235
- "of",
236
- "on",
237
- "or",
238
- "que",
239
- "she",
240
- "the",
241
- "to",
242
- "un",
243
- "una",
244
- "und",
245
- "une",
246
- "von",
247
- "was",
248
- "we",
249
- "you",
250
- }
251
-
252
-
253
- class ArtifactError(RuntimeError):
254
- """Raised when the Space cannot resolve, download, or load artifacts."""
255
-
256
-
257
- @dataclass(frozen=True)
258
- class ArtifactPaths:
259
- s3_uri: str
260
- local_dir: Path
261
- faiss_path: Path
262
- metadata_path: Path
263
- config_path: Path
264
-
265
-
266
- @dataclass
267
- class SpaceData:
268
- artifact_uri: str
269
- local_dir: Path
270
- config: dict[str, Any]
271
- id_to_meta: dict[int, dict[str, Any]]
272
- languages: list[str]
273
- lang_to_ids: dict[str, np.ndarray]
274
- lang_to_matrix: dict[str, np.ndarray]
275
- lang_to_index: dict[str, Any]
276
- id_to_lang_local: dict[int, tuple[str, int]]
277
- lookup: dict[str, dict[str, dict[str, list[int]]]]
278
- fuzzy_choices: dict[str, list[str]]
279
- vector_dim: int
280
- vector_source: str
281
- vocab_sizes: dict[str, int]
282
- stopwords: dict[str, set[str]]
283
- csls_avg_cache: dict[tuple[int, str, int], float]
284
-
285
- def vector_for_id(self, vector_id: int) -> np.ndarray:
286
- lang, local_idx = self.id_to_lang_local[int(vector_id)]
287
- return self.lang_to_matrix[lang][local_idx]
288
-
289
-
290
- _SPACE_CACHE: dict[str, SpaceData] = {}
291
- _LOAD_LOCK = threading.Lock()
292
-
293
-
294
- def _progress(progress: gr.Progress | None, value: float, message: str) -> None:
295
- if progress is not None:
296
- progress(value, desc=message)
297
-
298
-
299
- def _normalize_text(text: Any) -> str:
300
- if text is None:
301
- return ""
302
- normalized = unicodedata.normalize("NFKC", str(text))
303
- return " ".join(normalized.strip().casefold().split())
304
-
305
-
306
- def _display_value(value: Any) -> str:
307
- if value is None:
308
- return ""
309
- if isinstance(value, float) and math.isnan(value):
310
- return ""
311
- return str(value)
312
-
313
-
314
- def _parse_s3_uri(uri: str) -> tuple[str, str]:
315
- parsed = urlparse(uri)
316
- if parsed.scheme != "s3" or not parsed.netloc:
317
- raise ArtifactError(f"Expected an S3 URI like s3://bucket/prefix/, got: {uri}")
318
- prefix = parsed.path.lstrip("/")
319
- return parsed.netloc, prefix
320
-
321
-
322
- def _join_s3(base_uri: str, filename: str) -> str:
323
- return f"{base_uri.rstrip('/')}/{filename}"
324
-
325
-
326
- def _normalize_endpoint_url(host: str) -> str | None:
327
- host = host.strip()
328
- if not host:
329
- return None
330
- if host.startswith(("http://", "https://")):
331
- return host
332
- return f"https://{host}"
333
-
334
-
335
- def _s3_client():
336
- region = os.getenv("AWS_DEFAULT_REGION") or "us-east-1"
337
- access_key = os.getenv(SE_ACCESS_KEY_ENV) or os.getenv("AWS_ACCESS_KEY_ID")
338
- secret_key = os.getenv(SE_SECRET_KEY_ENV) or os.getenv("AWS_SECRET_ACCESS_KEY")
339
- endpoint_url = _normalize_endpoint_url(os.getenv(SE_HOST_ENV, ""))
340
-
341
- kwargs: dict[str, Any] = {"region_name": region}
342
- if access_key and secret_key:
343
- kwargs["aws_access_key_id"] = access_key
344
- kwargs["aws_secret_access_key"] = secret_key
345
- if endpoint_url:
346
- kwargs["endpoint_url"] = endpoint_url
347
-
348
- return boto3.session.Session(region_name=region).client("s3", **kwargs)
349
-
350
-
351
- def _credential_hint() -> str:
352
- return (
353
- "Set SE_ACCESS_KEY, SE_SECRET_KEY, and SE_HOST as Hugging Face Space secrets. "
354
- "AWS_ACCESS_KEY_ID/AWS_SECRET_ACCESS_KEY are also supported as a fallback."
355
- )
356
-
357
-
358
- def _is_multilingual_space_prefix(prefix: str) -> bool:
359
- name = prefix.rstrip("/").split("/")[-1]
360
- return fnmatch.fnmatch(name, "multilingual_space_*.json")
361
-
362
-
363
- def _timestamp_key_from_prefix(prefix: str) -> str:
364
- name = prefix.rstrip("/").split("/")[-1]
365
- match = re.search(r"multilingual_space_(.+)\.json$", name)
366
- return match.group(1) if match else name
367
-
368
-
369
- def _normalize_artifact_uri(uri: str) -> str:
370
- uri = uri.strip()
371
- if uri and not uri.endswith("/"):
372
- uri += "/"
373
- return uri
374
-
375
-
376
- def _artifact_timestamp(uri: str) -> str:
377
- return _timestamp_key_from_prefix(urlparse(uri).path.rstrip("/").split("/")[-1])
378
-
379
-
380
- def _artifact_label(uri: str) -> str:
381
- timestamp = _artifact_timestamp(uri)
382
- match = re.fullmatch(r"(\d{8})_(\d{6})", timestamp)
383
- if not match:
384
- return uri.rstrip("/").split("/")[-1]
385
- date_part, time_part = match.groups()
386
- return (
387
- f"{date_part[:4]}-{date_part[4:6]}-{date_part[6:8]} "
388
- f"{time_part[:2]}:{time_part[2:4]}:{time_part[4:6]}"
389
- )
390
-
391
-
392
- def _artifact_dropdown_choices(uris: list[str]) -> list[tuple[str, str]]:
393
- return [(f"{_artifact_label(uri)} | {uri.rstrip('/').split('/')[-1]}", uri) for uri in uris]
394
-
395
-
396
- def _list_artifact_uris(progress: gr.Progress | None = None) -> list[str]:
397
- bucket, base_prefix = _parse_s3_uri(BASE_ARTIFACT_S3_URI)
398
- if base_prefix and not base_prefix.endswith("/"):
399
- base_prefix += "/"
400
-
401
- _progress(progress, 0.05, "Listing multilingual_space_*.json artifacts")
402
- client = _s3_client()
403
- candidates: dict[str, Any] = {}
404
-
405
- try:
406
- paginator = client.get_paginator("list_objects_v2")
407
- for page in paginator.paginate(Bucket=bucket, Prefix=base_prefix, Delimiter="/"):
408
- for item in page.get("CommonPrefixes", []):
409
- prefix = item.get("Prefix", "")
410
- if _is_multilingual_space_prefix(prefix):
411
- candidates[prefix] = None
412
-
413
- if not candidates:
414
- for page in paginator.paginate(Bucket=bucket, Prefix=base_prefix):
415
- for obj in page.get("Contents", []):
416
- key = obj.get("Key", "")
417
- parts = key.split("/")
418
- for idx, part in enumerate(parts):
419
- if fnmatch.fnmatch(part, "multilingual_space_*.json"):
420
- prefix = "/".join(parts[: idx + 1]) + "/"
421
- last_modified = obj.get("LastModified")
422
- if prefix not in candidates or (
423
- last_modified and candidates[prefix] and last_modified > candidates[prefix]
424
- ):
425
- candidates[prefix] = last_modified
426
- elif prefix not in candidates:
427
- candidates[prefix] = last_modified
428
- break
429
- except (NoCredentialsError, PartialCredentialsError) as exc:
430
- raise ArtifactError(f"S3 credentials are missing or incomplete. {_credential_hint()}") from exc
431
- except ClientError as exc:
432
- code = exc.response.get("Error", {}).get("Code", "unknown")
433
- raise ArtifactError(f"Could not list {BASE_ARTIFACT_S3_URI} ({code}). {_credential_hint()}") from exc
434
-
435
- if not candidates:
436
- raise ArtifactError(f"No artifact folder matching multilingual_space_*.json/ was found under {BASE_ARTIFACT_S3_URI}")
437
-
438
- def sort_key(item: tuple[str, Any]) -> tuple[str, str]:
439
- prefix, last_modified = item
440
- modified_key = last_modified.isoformat() if last_modified else ""
441
- return (_timestamp_key_from_prefix(prefix), modified_key)
442
-
443
- prefixes = [prefix for prefix, _ in sorted(candidates.items(), key=sort_key, reverse=True)]
444
- return [f"s3://{bucket}/{prefix}" for prefix in prefixes]
445
-
446
-
447
- def _find_newest_artifact_uri(progress: gr.Progress | None = None) -> str:
448
- return _list_artifact_uris(progress)[0]
449
-
450
-
451
- def _resolve_artifact_options(progress: gr.Progress | None = None) -> tuple[list[str], str]:
452
- override_uri = _normalize_artifact_uri(os.getenv(ARTIFACT_ENV_VAR, "").strip())
453
- try:
454
- uris = _list_artifact_uris(progress)
455
- except ArtifactError:
456
- if not override_uri:
457
- raise
458
- return [override_uri], override_uri
459
-
460
- if override_uri and override_uri not in uris:
461
- uris.insert(0, override_uri)
462
- selected_uri = override_uri or uris[0]
463
- return uris, selected_uri
464
-
465
-
466
- def _local_cache_dir_for_uri(s3_uri: str) -> Path:
467
- digest = hashlib.sha256(s3_uri.encode("utf-8")).hexdigest()[:16]
468
- name = s3_uri.rstrip("/").split("/")[-1]
469
- safe_name = re.sub(r"[^A-Za-z0-9_.-]+", "_", name)
470
- return CACHE_ROOT / f"{safe_name}_{digest}"
471
-
472
-
473
- def _download_file_if_missing(s3_uri: str, local_path: Path) -> None:
474
- if local_path.exists() and local_path.stat().st_size > 0:
475
- return
476
-
477
- bucket, key = _parse_s3_uri(s3_uri)
478
- local_path.parent.mkdir(parents=True, exist_ok=True)
479
- client = _s3_client()
480
- try:
481
- client.download_file(bucket, key, str(local_path))
482
- except (NoCredentialsError, PartialCredentialsError) as exc:
483
- raise ArtifactError(f"S3 credentials are missing or incomplete while downloading {s3_uri}. {_credential_hint()}") from exc
484
- except ClientError as exc:
485
- code = exc.response.get("Error", {}).get("Code", "unknown")
486
- raise ArtifactError(f"Could not download {s3_uri} ({code}). {_credential_hint()}") from exc
487
-
488
-
489
- def _prepare_artifacts(artifact_uri: str | None = None, progress: gr.Progress | None = None) -> ArtifactPaths:
490
- artifact_uri = _normalize_artifact_uri(artifact_uri or "")
491
- if artifact_uri:
492
- _progress(progress, 0.03, f"Using selected artifact {_artifact_label(artifact_uri)}")
493
- else:
494
- artifact_uri = _normalize_artifact_uri(os.getenv(ARTIFACT_ENV_VAR, "").strip())
495
- if not artifact_uri:
496
- artifact_uri = _find_newest_artifact_uri(progress)
497
- artifact_uri = _normalize_artifact_uri(artifact_uri)
498
-
499
- local_dir = _local_cache_dir_for_uri(artifact_uri)
500
- local_dir.mkdir(parents=True, exist_ok=True)
501
-
502
- for idx, filename in enumerate(REQUIRED_FILES):
503
- _progress(progress, 0.10 + idx * 0.07, f"Checking {filename}")
504
- _download_file_if_missing(_join_s3(artifact_uri, filename), local_dir / filename)
505
-
506
- return ArtifactPaths(
507
- s3_uri=artifact_uri,
508
- local_dir=local_dir,
509
- faiss_path=local_dir / "aligned_all.faiss",
510
- metadata_path=local_dir / "all_metadata.jsonl",
511
- config_path=local_dir / "config.json",
512
- )
513
-
514
-
515
- def _load_json(path: Path) -> dict[str, Any]:
516
- with path.open("r", encoding="utf-8") as handle:
517
- return json.load(handle)
518
-
519
-
520
- def _load_metadata(path: Path) -> list[dict[str, Any]]:
521
- rows: list[dict[str, Any]] = []
522
- with path.open("r", encoding="utf-8") as handle:
523
- for line_no, line in enumerate(handle, start=1):
524
- if not line.strip():
525
- continue
526
- try:
527
- row = json.loads(line)
528
- except json.JSONDecodeError as exc:
529
- raise ArtifactError(f"Invalid JSON in {path.name} at line {line_no}: {exc}") from exc
530
- if "id" not in row or "lang" not in row:
531
- raise ArtifactError(f"Metadata line {line_no} is missing required fields: id/lang")
532
- row["id"] = int(row["id"])
533
- row["lang"] = str(row["lang"])
534
- rows.append(row)
535
- if not rows:
536
- raise ArtifactError(f"{path.name} is empty")
537
- return rows
538
-
539
-
540
- def _reconstruct_vectors(index: Any) -> np.ndarray:
541
- n_vectors = int(index.ntotal)
542
- dim = int(index.d)
543
- if n_vectors <= 0:
544
- raise ArtifactError("FAISS index contains no vectors")
545
-
546
- try:
547
- vectors = index.reconstruct_n(0, n_vectors)
548
- vectors = np.asarray(vectors, dtype=np.float32)
549
- if vectors.shape == (n_vectors, dim):
550
- return vectors
551
- except Exception:
552
- pass
553
-
554
- try:
555
- vectors = np.empty((n_vectors, dim), dtype=np.float32)
556
- index.reconstruct_n(0, n_vectors, vectors)
557
- return vectors
558
- except Exception:
559
- pass
560
-
561
- try:
562
- rows = [np.asarray(index.reconstruct(i), dtype=np.float32) for i in range(n_vectors)]
563
- return np.vstack(rows).astype(np.float32, copy=False)
564
- except Exception as exc:
565
- raise ArtifactError("FAISS vectors could not be reconstructed") from exc
566
-
567
-
568
- def _load_vec_fallback(paths: ArtifactPaths, expected_count: int) -> np.ndarray:
569
- vec_path = paths.local_dir / OPTIONAL_VEC_FILE
570
- _download_file_if_missing(_join_s3(paths.s3_uri, OPTIONAL_VEC_FILE), vec_path)
571
-
572
- vectors: list[np.ndarray] = []
573
- expected_dim: int | None = None
574
- with vec_path.open("r", encoding="utf-8", errors="replace") as handle:
575
- first = handle.readline()
576
- parts = first.strip().split()
577
- has_header = len(parts) == 2 and all(part.isdigit() for part in parts)
578
- if has_header:
579
- expected_dim = int(parts[1])
580
- else:
581
- handle.seek(0)
582
-
583
- for line_no, line in enumerate(handle, start=2 if has_header else 1):
584
- parts = line.rstrip("\n").split()
585
- if not parts:
586
- continue
587
- if expected_dim is None:
588
- expected_dim = len(parts) - 1
589
- values = parts[-expected_dim:]
590
- try:
591
- vectors.append(np.asarray(values, dtype=np.float32))
592
- except ValueError as exc:
593
- raise ArtifactError(f"Could not parse vector values in {vec_path.name} at line {line_no}") from exc
594
-
595
- if len(vectors) != expected_count:
596
- raise ArtifactError(
597
- f"{vec_path.name} contains {len(vectors):,} vectors, but metadata/FAISS expects {expected_count:,}"
598
- )
599
- return np.vstack(vectors).astype(np.float32, copy=False)
600
-
601
-
602
- def _l2_normalize(matrix: np.ndarray) -> np.ndarray:
603
- matrix = np.asarray(matrix, dtype=np.float32)
604
- norms = np.linalg.norm(matrix, axis=1, keepdims=True)
605
- norms[norms == 0.0] = 1.0
606
- matrix /= norms
607
- return matrix
608
-
609
-
610
- def _language_order(metadata_rows: list[dict[str, Any]], config: dict[str, Any]) -> list[str]:
611
- config_languages = _config_get(config, ("languages", "langs", "language_codes"), None)
612
- if isinstance(config_languages, dict):
613
- ordered = [str(key) for key in config_languages.keys()]
614
- elif isinstance(config_languages, list):
615
- ordered = [str(item) for item in config_languages]
616
- else:
617
- ordered = []
618
-
619
- metadata_languages = sorted({row["lang"] for row in metadata_rows})
620
- for lang in metadata_languages:
621
- if lang not in ordered:
622
- ordered.append(lang)
623
- return ordered
624
-
625
-
626
- def _build_lookup(metadata_rows: list[dict[str, Any]]) -> dict[str, dict[str, dict[str, list[int]]]]:
627
- lookup: dict[str, dict[str, dict[str, list[int]]]] = defaultdict(lambda: {"token": defaultdict(list), "surface": defaultdict(list)})
628
- for row in metadata_rows:
629
- lang = row["lang"]
630
- vector_id = int(row["id"])
631
- for field in ("token", "surface"):
632
- value = _normalize_text(row.get(field, ""))
633
- if value:
634
- lookup[lang][field][value].append(vector_id)
635
- return {
636
- lang: {
637
- field: {key: sorted(ids) for key, ids in field_map.items()}
638
- for field, field_map in maps.items()
639
- }
640
- for lang, maps in lookup.items()
641
- }
642
-
643
-
644
- def _stopwords_from_config(config: dict[str, Any], languages: list[str]) -> dict[str, set[str]]:
645
- stopwords: dict[str, set[str]] = {}
646
- raw_stopwords = _config_get(config, ("stopwords",), None)
647
- if isinstance(raw_stopwords, dict):
648
- for lang, values in raw_stopwords.items():
649
- if isinstance(values, list):
650
- stopwords.setdefault(str(lang), set()).update(
651
- _normalize_text(value) for value in values if _normalize_text(value)
652
- )
653
- for lang in languages:
654
- stopwords.setdefault(lang, set()).update(STOPWORDS)
655
- return stopwords
656
-
657
-
658
- def _build_space(paths: ArtifactPaths, progress: gr.Progress | None = None) -> SpaceData:
659
- if faiss is None:
660
- raise ArtifactError(f"faiss-cpu could not be imported: {FAISS_IMPORT_ERROR}")
661
-
662
- _progress(progress, 0.34, "Loading config and metadata")
663
- config = _load_json(paths.config_path)
664
- metadata_rows = _load_metadata(paths.metadata_path)
665
- metadata_rows.sort(key=lambda row: int(row["id"]))
666
-
667
- id_to_meta = {int(row["id"]): row for row in metadata_rows}
668
- if len(id_to_meta) != len(metadata_rows):
669
- raise ArtifactError("Metadata contains duplicate vector ids")
670
- expected_ids = sorted(id_to_meta)
671
- expected_count = len(expected_ids)
672
-
673
- _progress(progress, 0.45, "Loading FAISS index")
674
- index = faiss.read_index(str(paths.faiss_path))
675
- if int(index.ntotal) != expected_count:
676
- raise ArtifactError(
677
- f"FAISS index has {int(index.ntotal):,} vectors but metadata has {expected_count:,} rows"
678
- )
679
-
680
- _progress(progress, 0.56, "Reconstructing aligned vectors from FAISS")
681
- vector_source = "faiss"
682
- try:
683
- vectors = _reconstruct_vectors(index)
684
- except ArtifactError:
685
- _progress(progress, 0.56, "FAISS reconstruction failed; downloading aligned_all.vec fallback")
686
- vector_source = "aligned_all.vec"
687
- vectors = _load_vec_fallback(paths, expected_count)
688
-
689
- if vectors.shape[0] != expected_count:
690
- raise ArtifactError(f"Vector matrix has {vectors.shape[0]:,} rows but metadata has {expected_count:,}")
691
- if expected_ids[0] != 0 or expected_ids[-1] >= vectors.shape[0]:
692
- raise ArtifactError("Metadata ids must be contiguous FAISS vector ids starting at 0")
693
-
694
- _progress(progress, 0.70, "Normalizing vectors and building language indexes")
695
- vectors = _l2_normalize(vectors)
696
- languages = _language_order(metadata_rows, config)
697
- lang_to_ids: dict[str, np.ndarray] = {}
698
- lang_to_matrix: dict[str, np.ndarray] = {}
699
- lang_to_index: dict[str, Any] = {}
700
- id_to_lang_local: dict[int, tuple[str, int]] = {}
701
-
702
- for lang in languages:
703
- ids = np.asarray([int(row["id"]) for row in metadata_rows if row["lang"] == lang], dtype=np.int64)
704
- if ids.size == 0:
705
- continue
706
- lang_matrix = np.ascontiguousarray(vectors[ids], dtype=np.float32)
707
- lang_index = faiss.IndexFlatIP(lang_matrix.shape[1])
708
- lang_index.add(lang_matrix)
709
- lang_to_ids[lang] = ids
710
- lang_to_matrix[lang] = lang_matrix
711
- lang_to_index[lang] = lang_index
712
- for local_idx, vector_id in enumerate(ids.tolist()):
713
- id_to_lang_local[int(vector_id)] = (lang, local_idx)
714
-
715
- languages = [lang for lang in languages if lang in lang_to_ids]
716
- lookup = _build_lookup(metadata_rows)
717
- fuzzy_choices = {
718
- lang: sorted(set(lookup.get(lang, {}).get("token", {})) | set(lookup.get(lang, {}).get("surface", {})))
719
- for lang in languages
720
- }
721
- vocab_sizes = {lang: int(lang_to_ids[lang].size) for lang in languages}
722
- stopwords = _stopwords_from_config(config, languages)
723
- vector_dim = int(next(iter(lang_to_matrix.values())).shape[1])
724
-
725
- del vectors
726
- _progress(progress, 0.92, "Ready")
727
- return SpaceData(
728
- artifact_uri=paths.s3_uri,
729
- local_dir=paths.local_dir,
730
- config=config,
731
- id_to_meta=id_to_meta,
732
- languages=languages,
733
- lang_to_ids=lang_to_ids,
734
- lang_to_matrix=lang_to_matrix,
735
- lang_to_index=lang_to_index,
736
- id_to_lang_local=id_to_lang_local,
737
- lookup=lookup,
738
- fuzzy_choices=fuzzy_choices,
739
- vector_dim=vector_dim,
740
- vector_source=vector_source,
741
- vocab_sizes=vocab_sizes,
742
- stopwords=stopwords,
743
- csls_avg_cache={},
744
- )
745
-
746
-
747
- def get_space(artifact_uri: str | None = None, progress: gr.Progress | None = None) -> SpaceData:
748
- artifact_uri = _normalize_artifact_uri(artifact_uri or os.getenv(ARTIFACT_ENV_VAR, "").strip())
749
- if not artifact_uri:
750
- artifact_uri = _find_newest_artifact_uri(progress)
751
- artifact_uri = _normalize_artifact_uri(artifact_uri)
752
-
753
- if artifact_uri in _SPACE_CACHE:
754
- return _SPACE_CACHE[artifact_uri]
755
-
756
- with _LOAD_LOCK:
757
- if artifact_uri in _SPACE_CACHE:
758
- return _SPACE_CACHE[artifact_uri]
759
- _progress(progress, 0.01, "Preparing multilingual embedding artifacts")
760
- paths = _prepare_artifacts(artifact_uri, progress)
761
- _SPACE_CACHE[artifact_uri] = _build_space(paths, progress)
762
- return _SPACE_CACHE[artifact_uri]
763
-
764
-
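Review note: `get_space` above checks the cache once without the lock and once more after acquiring it, so concurrent requests for the same artifact trigger only one load. A minimal standalone sketch of that pattern follows. The names `get_cached`, `slow_build`, `_CACHE`, and `_LOCK` are hypothetical and not part of the deleted file.

```python
import threading

# Hypothetical module-level cache and lock, mirroring the roles of
# _SPACE_CACHE and _LOAD_LOCK in the diff above.
_CACHE: dict[str, object] = {}
_LOCK = threading.Lock()


def get_cached(key: str, build):
    """Return the cached value for key, building it at most once."""
    # Fast path: skip the lock when the value is already cached.
    if key in _CACHE:
        return _CACHE[key]
    with _LOCK:
        # Re-check inside the lock: another thread may have built the
        # value while this one was waiting.
        if key not in _CACHE:
            _CACHE[key] = build(key)
        return _CACHE[key]


calls = []


def slow_build(key: str) -> str:
    # Stand-in for an expensive load; records how often it runs.
    calls.append(key)
    return key.upper()


first = get_cached("s3://bucket/run-a", slow_build)
second = get_cached("s3://bucket/run-a", slow_build)
```

The second call returns the same object without invoking `slow_build` again, which is what the repeated check in `get_space` is for.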
765
- def _meta_display(meta: dict[str, Any]) -> dict[str, Any]:
766
- return {
767
- "id": int(meta.get("id", -1)),
768
- "lang": _display_value(meta.get("lang")),
769
- "token": _display_value(meta.get("token")),
770
- "surface": _display_value(meta.get("surface")),
771
- "source_vec_file": _display_value(meta.get("source_vec_file")),
772
- }
773
-
774
-
775
- def _candidate_dataframe(space: SpaceData, ids: list[int], match_type: str) -> pd.DataFrame:
776
- rows = []
777
- for vector_id in ids[:25]:
778
- meta = _meta_display(space.id_to_meta[int(vector_id)])
779
- meta["match_type"] = match_type
780
- rows.append(meta)
781
- return pd.DataFrame(rows, columns=["match_type", "id", "lang", "token", "surface", "source_vec_file"])
782
-
783
-
784
- def _resolve_query(
785
- space: SpaceData,
786
- word: str,
787
- source_lang: str,
788
- use_surface_forms: bool,
789
- fuzzy_fallback: bool,
790
- ) -> tuple[int | None, pd.DataFrame, pd.DataFrame, str]:
791
- normalized = _normalize_text(word)
792
- if not normalized:
793
- return None, pd.DataFrame(), pd.DataFrame(), "Enter a query word."
794
- if source_lang not in space.languages:
795
- return None, pd.DataFrame(), pd.DataFrame(), f"Source language '{source_lang}' is not available."
796
-
797
- exact_ids: list[int] = []
798
- match_type = ""
799
- lang_lookup = space.lookup.get(source_lang, {})
800
-
801
- if use_surface_forms:
802
- exact_ids = list(lang_lookup.get("surface", {}).get(normalized, []))
803
- match_type = "surface"
804
-
805
- if not exact_ids:
806
- exact_ids = list(lang_lookup.get("token", {}).get(normalized, []))
807
- match_type = "token"
808
-
809
- if exact_ids:
810
- candidates = _candidate_dataframe(space, exact_ids, f"exact_{match_type}")
811
- chosen_id = int(exact_ids[0])
812
- chosen = space.id_to_meta[chosen_id]
813
- if len(exact_ids) > 1:
814
- message = (
815
- f"Using exact {match_type} match `{_display_value(chosen.get(match_type, chosen.get('token')))}` "
816
- f"(id {chosen_id}); {len(exact_ids)} exact candidates found."
817
- )
818
- else:
819
- message = f"Using exact {match_type} match `{_display_value(chosen.get(match_type, chosen.get('token')))}`."
820
- return chosen_id, candidates, pd.DataFrame(), message
821
-
822
- suggestions = _fuzzy_suggestions(space, normalized, source_lang) if fuzzy_fallback else pd.DataFrame()
823
- if fuzzy_fallback and suggestions.empty:
824
- message = "No exact match found, and no fuzzy suggestions were available."
825
- elif fuzzy_fallback:
826
- message = "No exact match found. Pick or type one of the fuzzy suggestions."
827
- else:
828
- message = "No exact match found. Enable fuzzy fallback to see suggestions."
829
- return None, pd.DataFrame(), suggestions, message
830
-
831
-
832
- def _fuzzy_suggestions(space: SpaceData, normalized_word: str, lang: str, limit: int = 10) -> pd.DataFrame:
833
- if process is None or fuzz is None:
834
- return pd.DataFrame([{"suggestion": "rapidfuzz is not installed", "score": "", "token": "", "surface": "", "id": ""}])
835
-
836
- choices = space.fuzzy_choices.get(lang, [])
837
- if not choices:
838
- return pd.DataFrame()
839
-
840
- matches = process.extract(normalized_word, choices, scorer=fuzz.WRatio, limit=limit)
841
- rows = []
842
- for suggestion, score, _ in matches:
843
- ids = (
844
- space.lookup.get(lang, {}).get("surface", {}).get(suggestion)
845
- or space.lookup.get(lang, {}).get("token", {}).get(suggestion)
846
- or []
847
- )
848
- meta = space.id_to_meta[ids[0]] if ids else {}
849
- rows.append(
850
- {
851
- "suggestion": suggestion,
852
- "score": round(float(score), 2),
853
- "token": _display_value(meta.get("token")),
854
- "surface": _display_value(meta.get("surface")),
855
- "id": int(meta["id"]) if meta else "",
856
- }
857
- )
858
- return pd.DataFrame(rows, columns=["suggestion", "score", "token", "surface", "id"])
859
-
860
-
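Review note: `_fuzzy_suggestions` above ranks vocabulary entries by string similarity when no exact match exists. The sketch below shows the same idea using only the standard library. It swaps in `difflib` for rapidfuzz, so its scores are not comparable to `fuzz.WRatio`, and the function name and sample vocabulary are made up for illustration.

```python
import difflib


def fuzzy_suggestions(word, choices, limit=10, cutoff=0.0):
    """Rank choices by similarity to word (difflib stand-in, not rapidfuzz)."""
    # get_close_matches returns at most `limit` choices whose similarity
    # ratio is >= cutoff, best match first.
    close = difflib.get_close_matches(word, choices, n=limit, cutoff=cutoff)
    return [
        {
            "suggestion": choice,
            # Scale the 0..1 ratio to 0..100 so it resembles a percentage score.
            "score": round(100 * difflib.SequenceMatcher(None, word, choice).ratio(), 2),
        }
        for choice in close
    ]


choices = ["house", "horse", "mouse", "table"]
suggestions = fuzzy_suggestions("hose", choices, limit=3, cutoff=0.6)
```

With this cutoff, "table" is dropped and the remaining three words are returned with the weakest match last.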
861
- def _avg_topk(space: SpaceData, vectors: np.ndarray, lang: str, k: int) -> np.ndarray:
862
- index = space.lang_to_index.get(lang)
863
- if index is None or int(index.ntotal) == 0:
864
- return np.zeros((vectors.shape[0],), dtype=np.float32)
865
- k = max(1, min(int(k), int(index.ntotal)))
866
- distances, _ = index.search(np.ascontiguousarray(vectors, dtype=np.float32), k)
867
- return distances.mean(axis=1).astype(np.float32)
868
-
869
-
870
- def _avg_topk_for_ids(space: SpaceData, vector_ids: list[int], search_lang: str, k: int) -> np.ndarray:
871
- values = np.empty((len(vector_ids),), dtype=np.float32)
872
- missing_positions: list[int] = []
873
- missing_ids: list[int] = []
874
- k = int(k)
875
-
876
- for pos, vector_id in enumerate(vector_ids):
877
- cache_key = (int(vector_id), search_lang, k)
878
- cached = space.csls_avg_cache.get(cache_key)
879
- if cached is None:
880
- missing_positions.append(pos)
881
- missing_ids.append(int(vector_id))
882
- else:
883
- values[pos] = cached
884
-
885
- if missing_ids:
886
- vectors = np.vstack([space.vector_for_id(vector_id) for vector_id in missing_ids]).astype(np.float32, copy=False)
887
- computed = _avg_topk(space, vectors, search_lang, k)
888
- for pos, vector_id, value in zip(missing_positions, missing_ids, computed):
889
- float_value = float(value)
890
- space.csls_avg_cache[(int(vector_id), search_lang, k)] = float_value
891
- values[pos] = float_value
892
-
893
- return values
894
-
895
-
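Review note: `_avg_topk_for_ids` above fills cached values first, then computes all cache misses in a single batched call and writes the results back. A small generic sketch of that pattern is shown here. `cached_batch` and `square_many` are hypothetical names, and the squaring function only stands in for the FAISS search.

```python
def cached_batch(keys, cache, compute_many):
    """Return one value per key, computing all cache misses in one batch."""
    values = [None] * len(keys)
    missing_positions, missing_keys = [], []
    for pos, key in enumerate(keys):
        if key in cache:
            values[pos] = cache[key]
        else:
            missing_positions.append(pos)
            missing_keys.append(key)
    if missing_keys:
        # One batched call for every miss, then write back to the cache
        # and to the output slots the misses came from.
        computed = compute_many(missing_keys)
        for pos, key, value in zip(missing_positions, missing_keys, computed):
            cache[key] = value
            values[pos] = value
    return values


batches = []


def square_many(keys):
    # Stand-in for an expensive batched computation; records each batch.
    batches.append(list(keys))
    return [k * k for k in keys]


cache = {2: 4}
first = cached_batch([1, 2, 3], cache, square_many)
second = cached_batch([3, 2, 1], cache, square_many)
```

Only the first call reaches `square_many`, and only for the two keys that were not cached.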
896
- def _raw_candidates(
897
- space: SpaceData,
898
- query_vector: np.ndarray,
899
- source_lang: str,
900
- target_lang: str,
901
- retrieval_k: int,
902
- csls_k: int,
903
- score_method: str,
904
- query_id: int | None = None,
905
- ) -> list[dict[str, Any]]:
906
- if target_lang not in space.lang_to_index:
907
- return []
908
-
909
- index = space.lang_to_index[target_lang]
910
- retrieval_k = max(1, min(int(retrieval_k), int(index.ntotal)))
911
- query_matrix = np.ascontiguousarray(query_vector.reshape(1, -1), dtype=np.float32)
912
- distances, local_indices = index.search(query_matrix, retrieval_k)
913
- local_indices = local_indices[0]
914
- cosines = distances[0].astype(np.float32)
915
- valid = local_indices >= 0
916
- if not np.any(valid):
917
- return []
918
-
919
- local_indices = local_indices[valid].astype(np.int64)
920
- cosines = cosines[valid]
921
- global_ids = space.lang_to_ids[target_lang][local_indices]
922
- candidate_vectors = space.lang_to_matrix[target_lang][local_indices]
923
-
924
- if score_method.casefold() == "csls":
925
- if query_id is None:
926
- r_q = float(_avg_topk(space, query_matrix, target_lang, csls_k)[0])
927
- else:
928
- r_q = float(_avg_topk_for_ids(space, [int(query_id)], target_lang, csls_k)[0])
929
- r_x = _avg_topk_for_ids(space, global_ids.astype(int).tolist(), source_lang, csls_k)
930
- scores = (2.0 * cosines) - r_q - r_x
931
- else:
932
- scores = cosines
933
-
934
- rows = []
935
- for local_idx, vector_id, score, cosine in zip(local_indices, global_ids, scores, cosines):
936
- rows.append(
937
- {
938
- "id": int(vector_id),
939
- "local_idx": int(local_idx),
940
- "score": float(score),
941
- "cosine": float(cosine),
942
- }
943
- )
944
- rows.sort(key=lambda row: row["score"], reverse=True)
945
- return rows
946
-
947
-
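Review note: the CSLS branch of `_raw_candidates` above scores a candidate as `2 * cosine - r_q - r_x`, where `r_q` is the query's mean similarity to its `k` nearest target-language neighbors and `r_x` is the candidate's mean similarity to its `k` nearest source-language neighbors. The sketch below reproduces that arithmetic in plain Python on toy vectors. It uses brute-force search instead of FAISS, all names are hypothetical, and vectors are assumed to be non-zero.

```python
import math


def cosine(a, b):
    # Cosine similarity; assumes neither vector is all zeros.
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def avg_topk(vec, pool, k):
    # Mean similarity of vec to its k most similar vectors in pool.
    sims = sorted((cosine(vec, other) for other in pool), reverse=True)
    k = max(1, min(k, len(sims)))
    return sum(sims[:k]) / k


def csls_scores(query, source_pool, target_pool, k):
    """Return (csls, cosine, target_index) tuples, best CSLS first."""
    r_q = avg_topk(query, target_pool, k)
    scored = []
    for idx, candidate in enumerate(target_pool):
        cos = cosine(query, candidate)
        # Penalize candidates that are close to many source vectors (hubs).
        r_x = avg_topk(candidate, source_pool, k)
        scored.append((2.0 * cos - r_q - r_x, cos, idx))
    scored.sort(reverse=True)
    return scored


source_pool = [[1.0, 0.0], [0.0, 1.0]]
target_pool = [[1.0, 0.0], [0.0, 1.0]]
ranked = csls_scores([1.0, 0.0], source_pool, target_pool, k=1)
```

In this toy case the aligned target vector scores 0 and the orthogonal one scores -2, which shows why a CSLS threshold is not on the same scale as a cosine threshold.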
948
- def _is_filtered_word(space: SpaceData, meta: dict[str, Any], hide_stopwords: bool, min_token_length: int) -> bool:
949
- token = _normalize_text(meta.get("token", ""))
950
- surface = _normalize_text(meta.get("surface", ""))
951
- candidate = surface or token
952
- compact = candidate.replace(" ", "")
953
- if min_token_length and len(compact) < int(min_token_length):
954
- return True
955
- lang_stopwords = space.stopwords.get(str(meta.get("lang")), STOPWORDS)
956
- if hide_stopwords and (token in lang_stopwords or surface in lang_stopwords or candidate in lang_stopwords):
957
- return True
958
- return False
959
-
960
-
961
- def _reverse_contains_source(
962
- space: SpaceData,
963
- target_id: int,
964
- source_id: int,
965
- source_lang: str,
966
- target_lang: str,
967
- top_k: int,
968
- candidate_multiplier: int,
969
- prefetch_k: int,
970
- min_score: float,
971
- csls_k: int,
972
- score_method: str,
973
- ) -> bool:
974
- target_vector = space.vector_for_id(target_id)
975
- retrieval_k = max(1, int(top_k) * int(candidate_multiplier))
976
- reverse_rows = _raw_candidates(
977
- space=space,
978
- query_vector=target_vector,
979
- source_lang=target_lang,
980
- target_lang=source_lang,
981
- retrieval_k=retrieval_k,
982
- csls_k=csls_k,
983
- score_method=score_method,
984
- query_id=target_id,
985
- )
986
- reverse_rows = [row for row in reverse_rows if row["score"] >= float(min_score)]
987
- reverse_rows = reverse_rows[: max(1, int(top_k) * int(candidate_multiplier))]
988
- reverse_ids = {row["id"] for row in reverse_rows}
989
- if int(source_id) in reverse_ids:
990
- return True
991
-
992
- source_meta = space.id_to_meta[int(source_id)]
993
- source_token = _normalize_text(source_meta.get("token"))
994
- source_surface = _normalize_text(source_meta.get("surface"))
995
- for row in reverse_rows:
996
- meta = space.id_to_meta[row["id"]]
997
- if _normalize_text(meta.get("token")) == source_token:
998
- return True
999
- if source_surface and _normalize_text(meta.get("surface")) == source_surface:
1000
- return True
1001
- return False
1002
-
1003
-
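Review note: `_reverse_contains_source` above keeps a target candidate only if translating it back lands on the original source word. The core of that bidirectional check can be sketched with precomputed neighbor lists. The `neighbors` mapping, the ids, and both function names are invented for illustration. The deleted code also accepts a reverse hit whose normalized token or surface equals the source's, which this sketch omits.

```python
def is_mutual(source_id, target_id, neighbors, top_n):
    """True if source_id is among target_id's top_n reverse neighbors."""
    return source_id in neighbors.get(target_id, [])[:top_n]


def filter_mutual(source_id, neighbors, top_n):
    """Keep only forward candidates that point back at the source."""
    forward = neighbors.get(source_id, [])[:top_n]
    return [t for t in forward if is_mutual(source_id, t, neighbors, top_n)]


# Hypothetical ranked neighbor lists across two languages.
neighbors = {
    "en:dog": ["de:hund", "de:katze"],
    "de:hund": ["en:dog", "en:hound"],
    "de:katze": ["en:cat", "en:kitten"],
}
kept = filter_mutual("en:dog", neighbors, top_n=2)
```

Here `de:katze` is dropped because its own top neighbors do not include `en:dog`.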
1004
- def _translate_one_target(
1005
- space: SpaceData,
1006
- source_id: int,
1007
- target_lang: str,
1008
- top_k: int,
1009
- min_score: float,
1010
- csls_k: int,
1011
- candidate_multiplier: int,
1012
- prefetch_k: int,
1013
- score_method: str,
1014
- bidirectional_consistency: bool,
1015
- hide_stopwords: bool,
1016
- min_token_length: int,
1017
- ) -> list[dict[str, Any]]:
1018
- source_meta = space.id_to_meta[int(source_id)]
1019
- source_lang = source_meta["lang"]
1020
- source_vector = space.vector_for_id(source_id)
1021
- retrieval_k = max(int(prefetch_k), int(top_k) * int(candidate_multiplier))
1022
- raw_rows = _raw_candidates(
1023
- space=space,
1024
- query_vector=source_vector,
1025
- source_lang=source_lang,
1026
- target_lang=target_lang,
1027
- retrieval_k=retrieval_k,
1028
- csls_k=csls_k,
1029
- score_method=score_method,
1030
- query_id=source_id,
1031
- )
1032
-
1033
- results = []
1034
- candidate_budget = max(int(top_k), int(top_k) * int(candidate_multiplier))
1035
- for row in raw_rows[:candidate_budget]:
1036
- if row["score"] < float(min_score):
1037
- continue
1038
- target_meta = space.id_to_meta[row["id"]]
1039
- if _is_filtered_word(space, target_meta, hide_stopwords, min_token_length):
1040
- continue
1041
- if bidirectional_consistency:
1042
- passed = _reverse_contains_source(
1043
- space=space,
1044
- target_id=row["id"],
1045
- source_id=source_id,
1046
- source_lang=source_lang,
1047
- target_lang=target_lang,
1048
- top_k=top_k,
1049
- candidate_multiplier=candidate_multiplier,
1050
- prefetch_k=prefetch_k,
1051
- min_score=min_score,
1052
- csls_k=csls_k,
1053
- score_method=score_method,
1054
- )
1055
- if not passed:
1056
- continue
1057
- bidirectional_value: Any = True
1058
- else:
1059
- bidirectional_value = "not_checked"
1060
-
1061
- results.append(
1062
- {
1063
- "source word": _display_value(source_meta.get("surface") or source_meta.get("token")),
1064
- "source language": source_lang,
1065
- "target language": target_lang,
1066
- "translated token": _display_value(target_meta.get("token")),
1067
- "translated surface": _display_value(target_meta.get("surface")),
1068
- "score": round(float(row["score"]), 6),
1069
- "cosine": round(float(row["cosine"]), 6),
1070
- "rank": len(results) + 1,
1071
- "bidirectional_passed": bidirectional_value,
1072
- "target source_vec_file": _display_value(target_meta.get("source_vec_file")),
1073
- }
1074
- )
1075
- if len(results) >= int(top_k):
1076
- break
1077
- return results
1078
-
1079
-
1080
- def translate(
1081
- artifact_uri: str,
1082
- word: str,
1083
- source_lang: str,
1084
- target_langs: list[str] | None,
1085
- top_k: int,
1086
- min_score: float,
1087
- csls_k: int,
1088
- candidate_multiplier: int,
1089
- prefetch_k: int,
1090
- score_method: str,
1091
- bidirectional_consistency: bool,
1092
- use_surface_forms: bool,
1093
- hide_stopwords: bool,
1094
- min_token_length: int,
1095
- fuzzy_fallback: bool,
1096
- progress: gr.Progress = gr.Progress(),
1097
- ):
1098
- try:
1099
- space = get_space(artifact_uri, progress)
1100
- source_id, candidates, suggestions, message = _resolve_query(
1101
- space, word, source_lang, use_surface_forms, fuzzy_fallback
1102
- )
1103
- if source_id is None:
1104
- return (
1105
- pd.DataFrame(columns=_translation_columns()),
1106
- "No translation run because the query word was not found.",
1107
- candidates,
1108
- suggestions,
1109
- message,
1110
- )
1111
-
1112
- selected_targets = [lang for lang in (target_langs or []) if lang in space.languages]
1113
- if not selected_targets:
1114
- selected_targets = [lang for lang in space.languages if lang != source_lang]
1115
-
1116
- rows: list[dict[str, Any]] = []
1117
- for target_lang in selected_targets:
1118
- rows.extend(
1119
- _translate_one_target(
1120
- space=space,
1121
- source_id=source_id,
1122
- target_lang=target_lang,
1123
- top_k=top_k,
1124
- min_score=min_score,
1125
- csls_k=csls_k,
1126
- candidate_multiplier=candidate_multiplier,
1127
- prefetch_k=prefetch_k,
1128
- score_method=score_method,
1129
- bidirectional_consistency=bidirectional_consistency,
1130
- hide_stopwords=hide_stopwords,
1131
- min_token_length=min_token_length,
1132
- )
1133
- )
1134
-
1135
- table = pd.DataFrame(rows, columns=_translation_columns())
1136
- grouped = _group_translation_markdown(table)
1137
- return table, grouped, candidates, suggestions, message
1138
- except Exception as exc:
1139
- return (
1140
- pd.DataFrame(columns=_translation_columns()),
1141
- "Translation failed.",
1142
- pd.DataFrame(),
1143
- pd.DataFrame(),
1144
- f"Error: {exc}",
1145
- )
1146
-
1147
-
1148
- def _translation_columns() -> list[str]:
1149
- return [
1150
- "source word",
1151
- "source language",
1152
- "target language",
1153
- "translated token",
1154
- "translated surface",
1155
- "score",
1156
- "cosine",
1157
- "rank",
1158
- "bidirectional_passed",
1159
- "target source_vec_file",
1160
- ]
1161
-
1162
-
1163
- def _group_translation_markdown(table: pd.DataFrame) -> str:
1164
- if table.empty:
1165
- return "No candidates passed the current filters."
1166
-
1167
- lines = []
1168
- for lang, group in table.groupby("target language", sort=False):
1169
- parts = []
1170
- for _, row in group.iterrows():
1171
- label = row["translated surface"] or row["translated token"]
1172
- parts.append(f"{row['rank']}. `{label}` ({row['score']:.3f})")
1173
- lines.append(f"**{lang}**: " + " | ".join(parts))
1174
- return "\n\n".join(lines)
1175
-
1176
-
1177
- def nearest_neighbors(
1178
- artifact_uri: str,
1179
- word: str,
1180
- language: str,
1181
- neighbor_mode: str,
1182
- selected_languages: list[str] | None,
1183
- top_n: int,
1184
- score_method: str,
1185
- min_score: float,
1186
- include_same_language: bool,
1187
- use_surface_forms: bool,
1188
- fuzzy_fallback: bool,
1189
- progress: gr.Progress = gr.Progress(),
1190
- ):
1191
- columns = ["language", "token", "surface", "score", "cosine", "rank", "id", "source_vec_file"]
1192
- try:
1193
- space = get_space(artifact_uri, progress)
1194
- runtime_defaults = _defaults_from_config(space.config)
1195
- source_id, candidates, suggestions, message = _resolve_query(space, word, language, use_surface_forms, fuzzy_fallback)
1196
- if source_id is None:
1197
- hint = suggestions if not suggestions.empty else candidates
1198
- return pd.DataFrame(columns=columns), hint, message
1199
-
1200
- if neighbor_mode == "same language":
1201
- target_languages = [language]
1202
- elif neighbor_mode == "selected languages":
1203
- target_languages = [lang for lang in (selected_languages or []) if lang in space.languages]
1204
- else:
1205
- target_languages = list(space.languages)
1206
-
1207
- if not include_same_language and neighbor_mode != "same language":
1208
- target_languages = [lang for lang in target_languages if lang != language]
1209
- if not target_languages:
1210
- return pd.DataFrame(columns=columns), pd.DataFrame(), "No neighbor languages selected."
1211
-
1212
- source_vector = space.vector_for_id(source_id)
1213
- retrieval_k = max(50, int(top_n) * 3)
1214
- rows = []
1215
- for target_lang in target_languages:
1216
- raw_rows = _raw_candidates(
1217
- space=space,
1218
- query_vector=source_vector,
1219
- source_lang=language,
1220
- target_lang=target_lang,
1221
- retrieval_k=retrieval_k,
1222
- csls_k=runtime_defaults["csls_k"],
1223
- score_method=score_method,
1224
- query_id=source_id,
1225
- )
1226
- for row in raw_rows:
1227
- if row["id"] == int(source_id):
1228
- continue
1229
- if row["score"] < float(min_score):
1230
- continue
1231
- meta = space.id_to_meta[row["id"]]
1232
- rows.append(
1233
- {
1234
- "language": target_lang,
1235
- "token": _display_value(meta.get("token")),
1236
- "surface": _display_value(meta.get("surface")),
1237
- "score": round(float(row["score"]), 6),
1238
- "cosine": round(float(row["cosine"]), 6),
1239
- "rank": 0,
1240
- "id": int(row["id"]),
1241
- "source_vec_file": _display_value(meta.get("source_vec_file")),
1242
- }
1243
- )
1244
- rows.sort(key=lambda row: row["score"], reverse=True)
1245
- rows = rows[: int(top_n)]
1246
- for idx, row in enumerate(rows, start=1):
1247
- row["rank"] = idx
1248
- return pd.DataFrame(rows, columns=columns), candidates, message
1249
- except Exception as exc:
1250
- return pd.DataFrame(columns=columns), pd.DataFrame(), f"Error: {exc}"
1251
-
1252
-
1253
- def browse_vocab(artifact_uri: str, language: str, filter_text: str, limit: int, progress: gr.Progress = gr.Progress()):
1254
- try:
1255
- space = get_space(artifact_uri, progress)
1256
- rows = _browse_rows(space, language, filter_text, int(limit), randomize=False)
1257
- df = pd.DataFrame(rows, columns=_browse_columns())
1258
- return df, df
1259
- except Exception as exc:
1260
- df = pd.DataFrame([{"token": f"Error: {exc}", "surface": "", "id": "", "source_vec_file": ""}])
1261
- return df, df
1262
-
1263
-
1264
- def random_vocab(artifact_uri: str, language: str, limit: int, progress: gr.Progress = gr.Progress()):
1265
- try:
1266
- space = get_space(artifact_uri, progress)
1267
- rows = _browse_rows(space, language, "", int(limit), randomize=True)
1268
- df = pd.DataFrame(rows, columns=_browse_columns())
1269
- return df, df
1270
- except Exception as exc:
1271
- df = pd.DataFrame([{"token": f"Error: {exc}", "surface": "", "id": "", "source_vec_file": ""}])
1272
- return df, df
1273
-
1274
-
1275
- def _browse_columns() -> list[str]:
1276
- return ["token", "surface", "id", "source_vec_file"]
1277
-
1278
-
1279
- def _browse_rows(space: SpaceData, language: str, filter_text: str, limit: int, randomize: bool) -> list[dict[str, Any]]:
1280
- if language not in space.lang_to_ids:
1281
- return []
1282
-
1283
- ids = space.lang_to_ids[language].tolist()
1284
- normalized_filter = _normalize_text(filter_text)
1285
- if randomize and len(ids) > limit:
1286
- ids = random.sample(ids, limit)
1287
-
1288
- rows = []
1289
- for vector_id in ids:
1290
- meta = space.id_to_meta[int(vector_id)]
1291
- token = _display_value(meta.get("token"))
1292
- surface = _display_value(meta.get("surface"))
1293
- if normalized_filter:
1294
- haystack = f"{_normalize_text(token)} {_normalize_text(surface)}"
1295
- if normalized_filter not in haystack:
1296
- continue
1297
- rows.append(
1298
- {
1299
- "token": token,
1300
- "surface": surface,
1301
- "id": int(vector_id),
1302
- "source_vec_file": _display_value(meta.get("source_vec_file")),
1303
- }
1304
- )
1305
- if len(rows) >= limit:
1306
- break
1307
- return rows
1308
-
1309
-
1310
- def use_selected_vocab(table_data: Any, browse_language: str, evt: gr.SelectData):
1311
- try:
1312
- if isinstance(table_data, pd.DataFrame):
1313
- df = table_data
1314
- else:
1315
- df = pd.DataFrame(table_data, columns=_browse_columns())
1316
- row_idx = evt.index[0] if isinstance(evt.index, (list, tuple)) else evt.index
1317
- row = df.iloc[int(row_idx)]
1318
- word = row.get("surface") or row.get("token") or ""
1319
- return str(word), gr.update(value=browse_language)
1320
- except Exception:
1321
- return gr.update(), gr.update()
1322
-
1323
-
1324
- def _config_get(config: dict[str, Any], keys: tuple[str, ...], default: Any = "") -> Any:
1325
- return _config_get_raw(config, keys, default)
1326
-
1327
-
1328
- def _format_config_value(value: Any) -> str:
1329
- if value is None:
1330
- return ""
1331
- if isinstance(value, (dict, list)):
1332
- return json.dumps(value, ensure_ascii=False, sort_keys=True)
1333
- return str(value)
1334
-
1335
-
1336
- def artifact_info_markdown(space: SpaceData) -> str:
1337
- runtime_defaults = _defaults_from_config(space.config)
1338
- candidate_retrieval_k = int(runtime_defaults["top_k"]) * int(runtime_defaults["candidate_retrieval_k_multiplier"])
1339
- fields = [
1340
- ("artifact S3 URI", space.artifact_uri),
1341
- ("created_at", _config_get(space.config, ("created_at", "created", "timestamp"), "")),
1342
- ("languages", ", ".join(space.languages)),
1343
- ("pivot_lang", _config_get(space.config, ("pivot_lang", "pivot_language"), runtime_defaults["pivot_lang"])),
1344
- ("vector_dim", _config_get(space.config, ("vector_dim", "dim", "dimension"), space.vector_dim)),
1345
- ("vocab sizes", space.vocab_sizes),
1346
- ("top_n_vocab", runtime_defaults["top_n_vocab"]),
1347
- ("out_top", runtime_defaults["out_top"]),
1348
- ("top_k", runtime_defaults["top_k"]),
1349
- ("min_score", runtime_defaults["min_score"]),
1350
- ("csls_k", runtime_defaults["csls_k"]),
1351
- ("candidate_retrieval_k", _config_get(space.config, ("candidate_retrieval_k",), candidate_retrieval_k)),
1352
- ("candidate multiplier", runtime_defaults["candidate_retrieval_k_multiplier"]),
1353
- ("csls_prefetch_k", runtime_defaults["csls_prefetch_k"]),
1354
- ("align_iters", runtime_defaults["align_iters"]),
1355
- ("init_pairs", runtime_defaults["init_pairs"]),
1356
- ("max_pairs", runtime_defaults["max_pairs"]),
1357
- (
1358
- "bidirectional consistency",
1359
- runtime_defaults["bidirectional_consistency"],
1360
- ),
1361
- ("surface forms enabled", runtime_defaults["use_surface_forms"]),
1362
- ("hide stopwords default", runtime_defaults["hide_stopwords"]),
1363
- ("min token length default", runtime_defaults["min_token_length"]),
1364
- ("vector preprocessing", _config_get(space.config, ("vector_preprocessing", "preprocessing"), "")),
1365
- ("source vec files", _config_get(space.config, ("source_vec_files", "vec_files"), "")),
1366
- ("surface files", _config_get(space.config, ("surface_files",), "")),
1367
- ("local cache", str(space.local_dir)),
1368
- ("vector source", space.vector_source),
1369
- ]
1370
-
1371
- lines = ["| Field | Value |", "| --- | --- |"]
1372
- for field, value in fields:
1373
- lines.append(f"| {field} | {_format_config_value(value)} |")
1374
- return "\n".join(lines)
1375
-
1376
-
1377
- def _empty_artifact_updates(message: str):
1378
- return (
1379
- message,
1380
- gr.update(choices=[], value=None),
1381
- gr.update(choices=[], value=[]),
1382
- gr.update(choices=[], value=None),
1383
- gr.update(choices=[], value=[]),
1384
- gr.update(choices=[], value=None),
1385
- message,
1386
- gr.update(),
1387
- gr.update(),
1388
- gr.update(),
1389
- gr.update(),
1390
- gr.update(),
1391
- gr.update(),
1392
- gr.update(),
1393
- gr.update(),
1394
- gr.update(),
1395
- gr.update(),
1396
- gr.update(),
1397
- )
1398
-
1399
-
1400
- def load_selected_artifact(artifact_uri: str, progress: gr.Progress = gr.Progress()):
1401
- try:
1402
- space = get_space(artifact_uri, progress)
1403
- runtime_defaults = _defaults_from_config(space.config)
1404
- pivot = str(runtime_defaults["pivot_lang"])
1405
- source_default = pivot if pivot in space.languages else space.languages[0]
1406
- targets_default = [lang for lang in space.languages if lang != source_default]
1407
- if not targets_default:
1408
- targets_default = [source_default]
1409
-
1410
- status = (
1411
- f"Loaded {sum(space.vocab_sizes.values()):,} vectors across {len(space.languages)} languages "
1412
- f"from `{space.artifact_uri}`."
1413
- )
1414
- return (
1415
- status,
1416
- gr.update(choices=space.languages, value=source_default),
1417
- gr.update(choices=space.languages, value=targets_default),
1418
- gr.update(choices=space.languages, value=source_default),
1419
- gr.update(choices=space.languages, value=targets_default),
1420
- gr.update(choices=space.languages, value=source_default),
1421
- artifact_info_markdown(space),
1422
- gr.update(value=runtime_defaults["top_k"]),
1423
- gr.update(value=runtime_defaults["min_score"]),
1424
- gr.update(value=runtime_defaults["csls_k"]),
1425
- gr.update(value=min(runtime_defaults["candidate_retrieval_k_multiplier"], INTERACTIVE_DEFAULTS["candidate_retrieval_k_multiplier"])),
1426
- gr.update(value=min(runtime_defaults["csls_prefetch_k"], INTERACTIVE_DEFAULTS["csls_prefetch_k"])),
1427
- gr.update(value=INTERACTIVE_DEFAULTS["bidirectional_consistency"]),
1428
- gr.update(value=runtime_defaults["use_surface_forms"]),
1429
- gr.update(value=runtime_defaults["hide_stopwords"]),
1430
- gr.update(value=runtime_defaults["min_token_length"]),
1431
- gr.update(value=runtime_defaults["min_score"]),
1432
- gr.update(value=runtime_defaults["use_surface_forms"]),
1433
- )
1434
- except Exception as exc:
1435
- return _empty_artifact_updates(f"Startup error: {exc}")
1436
-
1437
-
1438
- def initialize_app(progress: gr.Progress = gr.Progress()):
1439
- try:
1440
- artifact_uris, selected_uri = _resolve_artifact_options(progress)
1441
- artifact_update = gr.update(
1442
- choices=_artifact_dropdown_choices(artifact_uris),
1443
- value=selected_uri,
1444
- interactive=True,
1445
- )
1446
- updates = load_selected_artifact(selected_uri, progress)
1447
- return (updates[0], artifact_update, *updates[1:])
1448
- except Exception as exc:
1449
- message = f"Startup error: {exc}"
1450
- empty_updates = _empty_artifact_updates(message)
1451
- return (empty_updates[0], gr.update(choices=[], value=None, interactive=False), *empty_updates[1:])
1452
-
1453
-
1454
- def update_default_targets(artifact_uri: str, source_language: str):
1455
- try:
1456
- space = get_space(artifact_uri, None)
1457
- targets = [lang for lang in space.languages if lang != source_language]
1458
- return gr.update(choices=space.languages, value=targets or [source_language])
1459
- except Exception:
1460
- return gr.update()
1461
-
1462
-
1463
- CSS = """
1464
- .compact-result table { font-size: 0.92rem; }
1465
- """
1466
-
1467
-
1468
- with gr.Blocks(title="Multilingual Static Word Embeddings") as demo:
1469
- gr.Markdown("## Multilingual Static Word Embeddings Explorer")
1470
- load_status = gr.Markdown("Loading artifacts from S3...")
1471
- artifact_selector = gr.Dropdown(
1472
- label="Aligned space artifact",
1473
- choices=[],
1474
- interactive=True,
1475
- )
1476
-
1477
- with gr.Tabs():
1478
- with gr.Tab("Translate"):
1479
- with gr.Row():
1480
- with gr.Column(scale=1, min_width=280):
1481
- query_word = gr.Textbox(label="Query word", placeholder="Enter a word")
1482
- source_lang = gr.Dropdown(label="Source language", choices=[], interactive=True)
1483
- target_langs = gr.Dropdown(
1484
- label="Target languages",
1485
- choices=[],
1486
- multiselect=True,
1487
- interactive=True,
1488
- )
1489
- translate_button = gr.Button("Translate", variant="primary")
1490
- with gr.Accordion("Retrieval and filters", open=False):
1491
- top_k = gr.Slider(1, 20, value=DEFAULTS["top_k"], step=1, label="top_k")
1492
- min_score = gr.Slider(-2.0, 2.0, value=DEFAULTS["min_score"], step=0.01, label="min_score")
1493
- csls_k = gr.Slider(1, 50, value=DEFAULTS["csls_k"], step=1, label="csls_k")
1494
- candidate_multiplier = gr.Slider(
1495
- 1,
1496
- 10,
1497
- value=DEFAULTS["candidate_retrieval_k_multiplier"],
1498
- step=1,
1499
- label="candidate multiplier",
1500
- )
1501
- prefetch_k = gr.Slider(
1502
- 10,
1503
- 500,
1504
- value=INTERACTIVE_DEFAULTS["csls_prefetch_k"],
1505
- step=10,
1506
- label="FAISS prefetch",
1507
- )
1508
- score_method = gr.Radio(["cosine", "CSLS"], value="cosine", label="score method")
1509
- bidirectional = gr.Checkbox(
1510
- value=INTERACTIVE_DEFAULTS["bidirectional_consistency"],
1511
- label="bidirectional consistency",
1512
- )
1513
- use_surface_forms = gr.Checkbox(value=DEFAULTS["use_surface_forms"], label="use surface forms")
1514
- hide_stopwords = gr.Checkbox(value=DEFAULTS["hide_stopwords"], label="hide stopwords")
1515
- min_token_length = gr.Slider(
1516
- 1,
1517
- 20,
1518
- value=DEFAULTS["min_token_length"],
1519
- step=1,
1520
- label="min token length",
1521
- )
1522
- fuzzy_fallback = gr.Checkbox(value=INTERACTIVE_DEFAULTS["fuzzy_fallback"], label="fuzzy match fallback")
1523
-
1524
- with gr.Column(scale=2, min_width=520):
1525
- translate_message = gr.Markdown()
1526
- grouped_results = gr.Markdown()
1527
- translation_table = gr.Dataframe(
1528
- label="Translations",
1529
- headers=_translation_columns(),
1530
- datatype=["str", "str", "str", "str", "str", "number", "number", "number", "str", "str"],
1531
- wrap=True,
1532
- elem_classes=["compact-result"],
1533
- )
1534
- with gr.Accordion("Source matches and suggestions", open=False):
1535
- match_candidates = gr.Dataframe(label="Exact candidates", wrap=True)
1536
- fuzzy_suggestions = gr.Dataframe(label="Fuzzy suggestions", wrap=True)
1537
-
1538
- with gr.Tab("Nearest Neighbors"):
1539
- with gr.Row():
1540
- with gr.Column(scale=1, min_width=280):
1541
- nn_word = gr.Textbox(label="Word", placeholder="Enter a word")
1542
- nn_language = gr.Dropdown(label="Language", choices=[], interactive=True)
1543
- neighbor_mode = gr.Radio(
1544
- ["same language", "all languages", "selected languages"],
1545
- value="all languages",
1546
- label="Neighbor languages",
1547
- )
1548
- nn_selected_languages = gr.Dropdown(
1549
- label="Selected languages",
1550
- choices=[],
1551
- multiselect=True,
1552
- interactive=True,
1553
- )
1554
- nn_top_n = gr.Slider(1, 100, value=20, step=1, label="top_n")
1555
- nn_score_method = gr.Radio(["cosine", "CSLS"], value="cosine", label="score method")
1556
- nn_min_score = gr.Slider(-2.0, 2.0, value=DEFAULTS["min_score"], step=0.01, label="min score")
1557
- nn_include_same = gr.Checkbox(value=False, label="include same language")
1558
- nn_surface = gr.Checkbox(value=DEFAULTS["use_surface_forms"], label="use surface forms")
1559
- nn_fuzzy = gr.Checkbox(value=INTERACTIVE_DEFAULTS["fuzzy_fallback"], label="fuzzy match fallback")
1560
- nn_button = gr.Button("Find Neighbors", variant="primary")
1561
- with gr.Column(scale=2, min_width=520):
1562
- nn_message = gr.Markdown()
1563
- nn_table = gr.Dataframe(label="Nearest words", wrap=True)
1564
- nn_matches = gr.Dataframe(label="Source match / suggestions", wrap=True)
1565
-
1566
- with gr.Tab("Browse Vocabulary"):
1567
- with gr.Row():
1568
- with gr.Column(scale=1, min_width=280):
1569
- browse_language = gr.Dropdown(label="Language", choices=[], interactive=True)
1570
- browse_filter = gr.Textbox(label="Search/filter", placeholder="token or surface substring")
1571
- browse_limit = gr.Slider(10, 1000, value=100, step=10, label="limit")
1572
- with gr.Row():
1573
- browse_button = gr.Button("Browse", variant="primary")
1574
- random_button = gr.Button("Random Sample")
1575
- with gr.Column(scale=2, min_width=520):
1576
- browse_table = gr.Dataframe(label="Vocabulary", wrap=True)
1577
- browse_state = gr.State(pd.DataFrame(columns=_browse_columns()))
1578
-
1579
- with gr.Tab("Artifact Info"):
1580
- artifact_info = gr.Markdown("Artifact metadata will appear after loading.")
1581
-
1582
-     demo.load(
-         initialize_app,
-         outputs=[
-             load_status,
-             artifact_selector,
-             source_lang,
-             target_langs,
-             nn_language,
-             nn_selected_languages,
-             browse_language,
-             artifact_info,
-             top_k,
-             min_score,
-             csls_k,
-             candidate_multiplier,
-             prefetch_k,
-             bidirectional,
-             use_surface_forms,
-             hide_stopwords,
-             min_token_length,
-             nn_min_score,
-             nn_surface,
-         ],
-     )
-
-     artifact_selector.change(
-         load_selected_artifact,
-         inputs=[artifact_selector],
-         outputs=[
-             load_status,
-             source_lang,
-             target_langs,
-             nn_language,
-             nn_selected_languages,
-             browse_language,
-             artifact_info,
-             top_k,
-             min_score,
-             csls_k,
-             candidate_multiplier,
-             prefetch_k,
-             bidirectional,
-             use_surface_forms,
-             hide_stopwords,
-             min_token_length,
-             nn_min_score,
-             nn_surface,
-         ],
-     )
-
-     source_lang.change(update_default_targets, inputs=[artifact_selector, source_lang], outputs=[target_langs])
-
-     translate_button.click(
-         translate,
-         inputs=[
-             artifact_selector,
-             query_word,
-             source_lang,
-             target_langs,
-             top_k,
-             min_score,
-             csls_k,
-             candidate_multiplier,
-             prefetch_k,
-             score_method,
-             bidirectional,
-             use_surface_forms,
-             hide_stopwords,
-             min_token_length,
-             fuzzy_fallback,
-         ],
-         outputs=[translation_table, grouped_results, match_candidates, fuzzy_suggestions, translate_message],
-     )
-
-     nn_button.click(
-         nearest_neighbors,
-         inputs=[
-             artifact_selector,
-             nn_word,
-             nn_language,
-             neighbor_mode,
-             nn_selected_languages,
-             nn_top_n,
-             nn_score_method,
-             nn_min_score,
-             nn_include_same,
-             nn_surface,
-             nn_fuzzy,
-         ],
-         outputs=[nn_table, nn_matches, nn_message],
-     )
-
-     browse_button.click(
-         browse_vocab,
-         inputs=[artifact_selector, browse_language, browse_filter, browse_limit],
-         outputs=[browse_table, browse_state],
-     )
-     browse_filter.submit(
-         browse_vocab,
-         inputs=[artifact_selector, browse_language, browse_filter, browse_limit],
-         outputs=[browse_table, browse_state],
-     )
-     random_button.click(
-         random_vocab,
-         inputs=[artifact_selector, browse_language, browse_limit],
-         outputs=[browse_table, browse_state],
-     )
-     browse_table.select(
-         use_selected_vocab,
-         inputs=[browse_state, browse_language],
-         outputs=[query_word, source_lang],
-     )
-
-
- if __name__ == "__main__":
-     demo.queue(default_concurrency_limit=4).launch(css=CSS, ssr_mode=False)
requirements.txt DELETED
@@ -1,8 +0,0 @@
- gradio
- faiss-cpu
- numpy
- pandas
- boto3
- smart_open
- python-dotenv
- rapidfuzz