Spaces:
Sleeping
Sleeping
"""Backfill lab_color, palette, contrast_axes and scandi_score for existing rows.

For each product with a non-empty image_path, this extracts the LAB palette
from the image, derives contrast axes from the text fields, scores Scandi fit,
and persists everything to the new columns added in this branch (adding any of
those columns that the table is still missing).

Idempotent: skips rows that already have a non-NULL `scandi_score` unless
`--force` is passed. Rows whose palette extraction or scoring raises are not
written; if their `scandi_score` is still NULL, a later run retries them.

Usage:
    uv run --no-project --with numpy --with pillow --with scikit-learn \\
        --with "scikit-image>=0.22,<0.25" \\
        scripts/backfill_color_contrast.py [--force] [--limit N]
"""
| # /// script | |
| # requires-python = ">=3.11" | |
| # dependencies = ["numpy", "pillow", "scikit-learn", "scikit-image>=0.22,<0.25"] | |
| # /// | |
| import argparse | |
| import json | |
| import sqlite3 | |
| import sys | |
| import time | |
| from pathlib import Path | |
| import numpy as np | |
| sys.path.insert(0, str(Path(__file__).resolve().parents[1])) | |
| from taste.colors import ( # noqa: E402 | |
| extract_palette, | |
| dominant, | |
| palette_to_json, | |
| ) | |
| from taste.contrast import tag as tag_axes # noqa: E402 | |
| from taste.scandi import score_from_palette_and_meta # noqa: E402 | |
# Default location of the SQLite database this script reads and updates.
DB_PATH = Path.home() / ".taste" / "taste.db"
# Columns this script backfills, with the DDL used to add any that are missing.
_NEW_COLUMNS: tuple[tuple[str, str], ...] = (
    ("lab_color", "BLOB"),
    ("palette", "TEXT DEFAULT '[]'"),
    ("contrast_axes", "TEXT DEFAULT '{}'"),
    ("scandi_score", "REAL"),
)

# Single parameterised UPDATE shared by the periodic and the final flush.
_UPDATE_SQL = (
    "UPDATE products SET lab_color = ?, palette = ?, contrast_axes = ?, scandi_score = ? "
    "WHERE url = ?"
)

_BATCH_SIZE = 200  # rows written per executemany + commit
_MAX_ERRORS_SHOWN = 5  # per-row errors echoed to stderr before going quiet


def _ensure_columns(conn: sqlite3.Connection) -> None:
    """Add any column from ``_NEW_COLUMNS`` that ``products`` does not have yet.

    Commits after each ALTER and reports every column it adds.
    """
    # PRAGMA table_info rows are (cid, name, type, ...); index 1 is the name.
    existing = {info[1] for info in conn.execute("PRAGMA table_info(products)")}
    for col, ddl in _NEW_COLUMNS:
        if col in existing:
            continue
        # Identifiers/DDL cannot be bound as parameters; col and ddl come only
        # from the constant table above, never from user input.
        conn.execute(f"ALTER TABLE products ADD COLUMN {col} {ddl}")
        conn.commit()
        print(f" added column: {col}", flush=True)


def _select_rows(
    conn: sqlite3.Connection, *, force: bool, limit: int
) -> list[sqlite3.Row]:
    """Fetch the products to process.

    Only rows whose ``image_path`` is non-empty are selected (a NULL image_path
    also fails ``!= ''`` in SQL). Unless *force* is set, rows that already have
    a ``scandi_score`` are skipped. ``limit <= 0`` means "no limit".
    """
    where = "image_path != ''"
    if not force:
        where += " AND (scandi_score IS NULL)"
    sql = (
        "SELECT url, brand, name, category, material, color, image_path "
        f"FROM products WHERE {where}"
    )
    params: tuple[int, ...] = ()
    if limit > 0:
        # Bind the limit instead of interpolating it into the SQL text.
        sql += " LIMIT ?"
        params = (limit,)
    return conn.execute(sql, params).fetchall()


def _compute_row(row: sqlite3.Row) -> tuple[tuple, bool]:
    """Compute the UPDATE parameters for one product row.

    Returns ``(params, has_palette)``:

    - *params* is ordered to match ``_UPDATE_SQL``.
    - *has_palette* is True when ``extract_palette`` returned a non-empty result.

    Exceptions from the palette/contrast/scandi helpers propagate to the caller.
    """
    pal = extract_palette(row["image_path"])
    dom = dominant(pal)
    # Each helper gets its own dict copy, so neither can see the other's mutations.
    axes = tag_axes(dict(row))
    scandi = score_from_palette_and_meta(pal, dict(row))
    # dominant(pal) packed as raw float32 bytes; stored as NULL when it is None.
    lab_blob = np.asarray(dom, dtype=np.float32).tobytes() if dom is not None else None
    params = (
        lab_blob,
        json.dumps(palette_to_json(pal)),
        json.dumps(axes),
        float(scandi),
        row["url"],
    )
    return params, bool(pal)


def _flush(conn: sqlite3.Connection, batch: list[tuple]) -> None:
    """Apply *batch* with a single executemany, commit, then clear it in place."""
    if not batch:
        return
    conn.executemany(_UPDATE_SQL, batch)
    conn.commit()
    batch.clear()


def _backfill(conn: sqlite3.Connection, *, force: bool, limit: int) -> None:
    """Process every selected row, writing results in batches with progress output."""
    rows = _select_rows(conn, force=force, limit=limit)
    total = len(rows)
    print(f"Backfilling color/contrast/scandi for {total} products...", flush=True)

    ok = miss = err = 0
    batch: list[tuple] = []
    t0 = time.monotonic()  # monotonic: durations unaffected by wall-clock changes
    for i, r in enumerate(rows, 1):
        try:
            params, has_palette = _compute_row(r)
        except Exception as e:
            # Best effort per row: one bad image must not abort the whole run.
            # The row is not updated; if its scandi_score is still NULL, a later
            # run will pick it up again.
            err += 1
            if err <= _MAX_ERRORS_SHOWN:
                print(f" err on {r['url']}: {e}", file=sys.stderr)
        else:
            # Only rows actually queued for writing count as ok/miss.
            batch.append(params)
            if has_palette:
                ok += 1
            else:
                miss += 1

        if len(batch) >= _BATCH_SIZE:
            _flush(conn, batch)
            elapsed = time.monotonic() - t0
            rate = i / elapsed if elapsed > 0 else 0
            eta = (total - i) / rate if rate > 0 else 0
            print(
                f" {i}/{total} ok={ok} miss={miss} err={err} "
                f"rate={rate:.1f}/s eta={eta:.0f}s",
                flush=True,
            )

    _flush(conn, batch)  # remainder smaller than one full batch

    n_with = conn.execute(
        "SELECT COUNT(*) FROM products WHERE scandi_score IS NOT NULL"
    ).fetchone()[0]
    print(f"\nDone in {time.monotonic() - t0:.0f}s. ok={ok} miss={miss} err={err}", flush=True)
    print(f"products.scandi_score populated for {n_with} rows", flush=True)


def main(argv: list[str] | None = None) -> int:
    """Command-line entry point for the backfill.

    Args:
        argv: Arguments to parse instead of ``sys.argv[1:]``. ``None`` (the
            default) keeps the normal command-line behaviour.

    Returns:
        0 on success, 1 if the database file does not exist.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--force", action="store_true", help="Recompute even rows already populated")
    ap.add_argument("--limit", type=int, default=0, help="Cap number of rows (0 = all)")
    ap.add_argument(
        "--db",
        type=Path,
        default=DB_PATH,
        help="SQLite database to update (default: %(default)s)",
    )
    args = ap.parse_args(argv)

    db_path: Path = args.db
    if not db_path.exists():
        print(f"DB not found: {db_path}", file=sys.stderr)
        return 1

    # sqlite3's own context manager only commits/rolls back and never closes,
    # so close explicitly even if the backfill raises.
    conn = sqlite3.connect(str(db_path))
    try:
        conn.row_factory = sqlite3.Row
        _ensure_columns(conn)
        _backfill(conn, force=args.force, limit=args.limit)
    finally:
        conn.close()
    return 0
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())