"""
Rebuild (pre-warm) embeddings for ALL organisations in platform.db.

Run this after any bulk import or whenever you want to pre-compute
embeddings so the first chatbot query is instant.

Usage:
    python rebuild_embeddings.py
    python rebuild_embeddings.py --org-id 2      (single org)
"""
import argparse
import json
import os
import re
import sqlite3
import time
import unicodedata

import torch
from sentence_transformers import SentenceTransformer

BASE_DIR = os.path.abspath(os.path.dirname(__file__))
DB_PATH = os.path.join(BASE_DIR, "platform.db")
EMB_DIR = os.path.join(BASE_DIR, "embeddings")
os.makedirs(EMB_DIR, exist_ok=True)


def _parse_args() -> argparse.Namespace:
    """Parse the command line (``--org-id`` restricts the rebuild to one org)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--org-id", type=int, default=None,
                        help="Rebuild only this org (default: all)")
    return parser.parse_args()


args = _parse_args()


# ── Text helpers (must match app.py) ─────────────────────────────────────────
# These mirror app.py so the indexed text is identical to what the app embeds;
# do not change their behaviour here without changing app.py as well.

def normalise(text: str) -> str:
    """Lower-case, strip accents and punctuation, and collapse whitespace.

    Not used by this script itself; kept in sync with app.py.
    """
    text = text.lower().strip()
    # NFD splits accented letters into base letter + combining mark ("Mn"),
    # which are then dropped.
    text = unicodedata.normalize("NFD", text)
    text = "".join(c for c in text if unicodedata.category(c) != "Mn")
    text = re.sub(r"[^\w\s]", " ", text)
    text = re.sub(r"\s+", " ", text).strip()
    return text


def clean_question(q: str) -> str:
    """Strip surrounding whitespace and drop a trailing whitespace-separated number."""
    return re.sub(r'\s+\d+\s*$', '', (q or '').strip())


def build_index_text(row) -> str:
    """Build the text embedded for one ``entries`` row.

    The result is the cleaned question, followed by the non-empty
    ``processus`` and ``procedure`` fields, joined with " | ".
    The question is always included, even when empty.
    """
    q = clean_question(row["question"] or "")
    proc = (row["processus"] or "").strip()
    procd = (row["procedure"] or "").strip()
    parts = [q]
    if proc:
        parts.append(proc)
    if procd:
        parts.append(procd)
    return " | ".join(parts)


# ── Load model ───────────────────────────────────────────────────────────────
print("🔄 Chargement du modèle…")
try:
    model = SentenceTransformer(
        "OrdalieTech/Solon-embeddings-mini-beta-1.1",
        device="cpu",
        trust_remote_code=True,
    )
    print("✓ Modèle principal (Solon) chargé")
except Exception as e:  # deliberate best-effort fallback; reason is printed
    # NOTE(review): the fallback model presumably produces vectors of a different
    # dimension/space than Solon, and the meta file does not record which model
    # was used — confirm app.py loads the same model, or embeddings may mismatch.
    print(f"⚠ Modèle principal échoué ({e}), bascule sur secours…")
    model = SentenceTransformer("paraphrase-multilingual-MiniLM-L12-v2", device="cpu")
    print("✓ Modèle de secours chargé")

# ── DB helpers ────────────────────────────────────────────────────────────────
def _remove_files(*paths: str) -> None:
    """Delete each existing path in *paths*; paths that do not exist are ignored."""
    for path in paths:
        try:
            os.remove(path)
        except FileNotFoundError:
            pass


def _save_org_files(emb, row_ids, oid, name, emb_file, meta_file) -> None:
    """Write the embedding tensor and its JSON metadata for one org.

    Both files are first written to ``*.tmp`` siblings and then moved into
    place with ``os.replace``. Readers therefore never see a half-written file.
    The meta file is published last, so a ``.pt`` without a meta file means
    the write did not finish. On any error or interrupt, the temp files are
    deleted and the exception is re-raised.
    """
    emb_tmp = emb_file + ".tmp"
    meta_tmp = meta_file + ".tmp"
    try:
        torch.save(emb, emb_tmp)
        with open(meta_tmp, "w", encoding="utf-8") as mf:
            json.dump({"row_ids": row_ids, "org_id": oid, "org_name": name, "v": 4}, mf)
        os.replace(emb_tmp, emb_file)
        os.replace(meta_tmp, meta_file)
    except BaseException:
        _remove_files(emb_tmp, meta_tmp)
        raise


def main() -> None:
    """Recompute and save the embeddings of every org (or only ``--org-id``).

    For each organisation, this reads its ``entries`` rows ordered by id and
    encodes ``build_index_text(row)`` with the loaded model. It then writes
    ``org_<id>.pt`` (tensor) and ``org_<id>_meta.json`` (row ids in the same
    order) into EMB_DIR. Orgs without entries get their old files removed.
    """
    conn = sqlite3.connect(DB_PATH)
    conn.row_factory = sqlite3.Row
    try:
        # `is not None` so that an explicit `--org-id 0` is not mistaken for "all".
        if args.org_id is not None:
            orgs = conn.execute(
                "SELECT id, name FROM organisations WHERE id=?", (args.org_id,)
            ).fetchall()
        else:
            orgs = conn.execute("SELECT id, name FROM organisations").fetchall()

        if not orgs:
            print("Aucune organisation trouvée dans la base de données.")
            raise SystemExit(0)

        total_built = 0
        for org in orgs:
            oid = org["id"]
            name = org["name"]
            emb_file = os.path.join(EMB_DIR, f"org_{oid}.pt")
            meta_file = os.path.join(EMB_DIR, f"org_{oid}_meta.json")

            rows = conn.execute(
                "SELECT id, question, processus, procedure, intent, sub_intent, response "
                "FROM entries WHERE org_id=? ORDER BY id",
                (oid,),
            ).fetchall()

            # Remove stale files first, including for orgs that now have no
            # entries. Otherwise an emptied org keeps embeddings whose
            # row_ids no longer exist.
            _remove_files(emb_file, meta_file)

            if not rows:
                print(f" [{name}] 0 entrées — ignoré")
                continue

            row_ids = [r["id"] for r in rows]
            texts = [build_index_text(r) for r in rows]

            print(f" [{name}] {len(texts)} questions → calcul embeddings…", end=" ", flush=True)
            t0 = time.time()
            emb = model.encode(texts, convert_to_tensor=True, normalize_embeddings=True,
                               batch_size=64, show_progress_bar=False)
            _save_org_files(emb, row_ids, oid, name, emb_file, meta_file)
            elapsed = time.time() - t0
            print(f"✓ ({elapsed:.1f}s) → org_{oid}.pt")
            total_built += 1
    finally:
        # Always release the DB handle, including on SystemExit or encode errors.
        conn.close()

    print(f"\n✅ {total_built} organisation(s) reconstruite(s). Redémarrez Flask pour appliquer.")


if __name__ == "__main__":
    main()