"""Backfill lab_color, palette, contrast_axes and scandi_score for existing rows.

For every product with a non-empty, non-NULL ``image_path`` this script
extracts the LAB palette from the image, derives contrast axes from the row's
text fields, scores Scandi-fit, and persists all four values to the columns
added in this branch (any missing column is added first).

Idempotent: rows that already have a non-NULL ``scandi_score`` are skipped
unless ``--force`` is passed. A row whose processing raises is left unchanged,
so (without ``--force``) it is simply retried on the next run. Interrupting
with Ctrl-C saves every row computed so far before exiting.

Usage:
    uv run --no-project --with numpy --with pillow --with scikit-learn \\
        --with "scikit-image>=0.22,<0.25" \\
        scripts/backfill_color_contrast.py [--force] [--limit N] [--db PATH]
"""

# /// script
# requires-python = ">=3.11"
# dependencies = ["numpy", "pillow", "scikit-learn", "scikit-image>=0.22,<0.25"]
# ///

import argparse
import json
import sqlite3
import sys
import time
from pathlib import Path

import numpy as np

sys.path.insert(0, str(Path(__file__).resolve().parents[1]))

from taste.colors import (  # noqa: E402
    extract_palette,
    dominant,
    palette_to_json,
)
from taste.contrast import tag as tag_axes  # noqa: E402
from taste.scandi import score_from_palette_and_meta  # noqa: E402

DB_PATH = Path.home() / ".taste/taste.db"

# Columns this branch adds to `products`, with the DDL used when one is missing.
NEW_COLUMNS: tuple[tuple[str, str], ...] = (
    ("lab_color", "BLOB"),
    ("palette", "TEXT DEFAULT '[]'"),
    ("contrast_axes", "TEXT DEFAULT '{}'"),
    ("scandi_score", "REAL"),
)

UPDATE_SQL = (
    "UPDATE products SET lab_color = ?, palette = ?, contrast_axes = ?, "
    "scandi_score = ? WHERE url = ?"
)

BATCH = 200  # rows per executemany/commit, and rows between progress lines
MAX_LOGGED_ERRORS = 5  # only the first few per-row errors are printed in full


def _ensure_columns(conn: sqlite3.Connection) -> None:
    """Add every column of NEW_COLUMNS that `products` does not have yet."""
    existing = {r[1] for r in conn.execute("PRAGMA table_info(products)")}
    for col, ddl in NEW_COLUMNS:
        if col not in existing:
            # Identifiers/DDL cannot be bound as parameters; both come from the
            # fixed NEW_COLUMNS table above, never from user input.
            conn.execute(f"ALTER TABLE products ADD COLUMN {col} {ddl}")
            conn.commit()
            print(f" added column: {col}", flush=True)


def _select_rows(
    conn: sqlite3.Connection, *, force: bool, limit: int
) -> list[sqlite3.Row]:
    """Return the product rows to (re)process.

    Args:
        conn: open connection whose row_factory yields sqlite3.Row.
        force: when False, only rows with a NULL scandi_score are selected.
        limit: maximum number of rows; 0 (or negative) means no cap.
    """
    where = "image_path != ''"  # also excludes NULL image_path (NULL != '' is NULL)
    if not force:
        where += " AND (scandi_score IS NULL)"
    sql = (
        "SELECT url, brand, name, category, material, color, image_path "
        f"FROM products WHERE {where}"
    )
    params: tuple[int, ...] = ()
    if limit > 0:
        sql += " LIMIT ?"
        params = (limit,)
    return conn.execute(sql, params).fetchall()


def _flush(conn: sqlite3.Connection, batch: list[tuple]) -> None:
    """Write and commit pending UPDATE parameter tuples, then empty `batch`."""
    if not batch:
        return
    conn.executemany(UPDATE_SQL, batch)
    conn.commit()
    batch.clear()


def _backfill(conn: sqlite3.Connection, rows: list[sqlite3.Row]) -> int:
    """Compute and persist color/contrast/scandi values for `rows`.

    Returns the process exit code: 0 on completion, 130 if interrupted
    (after saving every row computed so far).
    """
    total = len(rows)
    print(f"Backfilling color/contrast/scandi for {total} products...", flush=True)

    ok = miss = err = 0
    batch: list[tuple] = []
    t0 = time.monotonic()
    try:
        for i, r in enumerate(rows, 1):
            try:
                pal = extract_palette(r["image_path"])
                dom = dominant(pal)
                # Each helper gets its own dict copy of the row, as before, so
                # one helper mutating its input cannot affect the other.
                axes = tag_axes(dict(r))
                scandi = score_from_palette_and_meta(pal, dict(r))
                lab_blob = (
                    np.asarray(dom, dtype=np.float32).tobytes()
                    if dom is not None
                    else None
                )
                batch.append(
                    (
                        lab_blob,
                        json.dumps(palette_to_json(pal)),
                        json.dumps(axes),
                        float(scandi),
                        r["url"],
                    )
                )
                if pal:
                    ok += 1
                else:
                    miss += 1
            except Exception as e:
                # Per-row boundary: one bad image/row must not abort the run.
                # The row is not written, so it stays as it was in the DB.
                err += 1
                if err <= MAX_LOGGED_ERRORS:
                    print(f" err on {r['url']}: {e}", file=sys.stderr)

            if len(batch) >= BATCH:
                _flush(conn, batch)
            # Report on a fixed row cadence, independent of flushes, so
            # progress still shows when many rows fail.
            if i % BATCH == 0:
                elapsed = time.monotonic() - t0
                rate = i / elapsed if elapsed > 0 else 0
                eta = (total - i) / rate if rate > 0 else 0
                print(
                    f" {i}/{total} ok={ok} miss={miss} err={err} "
                    f"rate={rate:.1f}/s eta={eta:.0f}s",
                    flush=True,
                )
    except KeyboardInterrupt:
        # Don't throw away up to BATCH-1 already computed rows on Ctrl-C.
        _flush(conn, batch)
        print(
            f"\nInterrupted; saved completed rows. ok={ok} miss={miss} err={err}",
            file=sys.stderr,
            flush=True,
        )
        return 130

    _flush(conn, batch)

    n_with = conn.execute(
        "SELECT COUNT(*) FROM products WHERE scandi_score IS NOT NULL"
    ).fetchone()[0]
    print(
        f"\nDone in {time.monotonic() - t0:.0f}s. ok={ok} miss={miss} err={err}",
        flush=True,
    )
    if err > MAX_LOGGED_ERRORS:
        print(
            f"({err - MAX_LOGGED_ERRORS} further row errors not shown)",
            file=sys.stderr,
        )
    print(f"products.scandi_score populated for {n_with} rows", flush=True)
    return 0


def main() -> int:
    """Parse CLI arguments, open the DB, and run the backfill.

    Returns:
        0 on success, 1 if the database file does not exist, 130 if
        interrupted.
    """
    ap = argparse.ArgumentParser(
        description="Backfill color/contrast/scandi columns in products."
    )
    ap.add_argument("--force", action="store_true", help="Recompute even rows already populated")
    ap.add_argument("--limit", type=int, default=0, help="Cap number of rows (0 = all)")
    ap.add_argument(
        "--db",
        type=Path,
        default=DB_PATH,
        help=f"SQLite database to update (default: {DB_PATH})",
    )
    args = ap.parse_args()

    db_path = args.db.expanduser()
    if not db_path.exists():
        print(f"DB not found: {db_path}", file=sys.stderr)
        return 1

    conn = sqlite3.connect(str(db_path))
    try:
        conn.row_factory = sqlite3.Row
        _ensure_columns(conn)
        rows = _select_rows(conn, force=args.force, limit=args.limit)
        return _backfill(conn, rows)
    finally:
        # sqlite3.Connection's context manager only commits/rolls back; it
        # never closes, so close explicitly.
        conn.close()


if __name__ == "__main__":
    sys.exit(main())