taste-engine / scripts /backfill_color_contrast.py
mucahitkantepe's picture
add LAB color extraction, contrast axes, scandi scorer
83c6b69
Raw
History Blame Contribute Delete
4.54 kB
"""Backfill lab_color, palette, contrast_axes, scandi_score for existing rows.
Reads each scored product's image_path, extracts the LAB palette, derives
contrast axes from text fields, scores Scandi-fit, and persists everything to
the new columns added in this branch.
Idempotent: skips rows that already have a non-NULL `scandi_score` unless
`--force` is passed. Pass `--limit N` to cap the number of rows processed
(0, the default, means all rows).
Usage:
uv run --no-project --with numpy --with pillow --with scikit-learn \\
--with "scikit-image>=0.22,<0.25" \\
scripts/backfill_color_contrast.py [--force] [--limit N]
"""
# /// script
# requires-python = ">=3.11"
# dependencies = ["numpy", "pillow", "scikit-learn", "scikit-image>=0.22,<0.25"]
# ///
import argparse
import json
import sqlite3
import sys
import time
from pathlib import Path
import numpy as np
# Put the directory one level above this script's folder at the front of the
# import path so the `taste` imports below resolve when the file is run
# directly (presumably that directory is the repo root -- confirm). The
# imports therefore have to come after this line, hence the E402 suppressions.
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from taste.colors import ( # noqa: E402
extract_palette,
dominant,
palette_to_json,
)
from taste.contrast import tag as tag_axes # noqa: E402
from taste.scandi import score_from_palette_and_meta # noqa: E402
# SQLite database file that main() opens and updates; resolved under the
# current user's home directory.
DB_PATH = Path.home() / ".taste/taste.db"
# Single source of truth for the per-row UPDATE; the parameter order must match
# the tuples appended to the batch in _backfill().
_UPDATE_SQL = (
    "UPDATE products SET lab_color = ?, palette = ?, contrast_axes = ?, "
    "scandi_score = ? WHERE url = ?"
)

# (column name, SQLite type/default clause) pairs added to `products` if absent.
_NEW_COLUMNS = (
    ("lab_color", "BLOB"),
    ("palette", "TEXT DEFAULT '[]'"),
    ("contrast_axes", "TEXT DEFAULT '{}'"),
    ("scandi_score", "REAL"),
)

# Number of pending row updates that triggers a write + commit + progress line.
_BATCH_SIZE = 200


def _ensure_columns(conn: sqlite3.Connection) -> None:
    """Add any of the backfill columns that `products` does not have yet."""
    existing = {r[1] for r in conn.execute("PRAGMA table_info(products)").fetchall()}
    for col, ddl in _NEW_COLUMNS:
        if col not in existing:
            # Identifiers cannot be bound as parameters; `col`/`ddl` come only
            # from the hard-coded _NEW_COLUMNS table above.
            conn.execute(f"ALTER TABLE products ADD COLUMN {col} {ddl}")
            conn.commit()
            print(f" added column: {col}", flush=True)


def _flush(conn: sqlite3.Connection, batch: list[tuple]) -> None:
    """Write the pending updates, commit them, and empty the batch."""
    conn.executemany(_UPDATE_SQL, batch)
    conn.commit()
    batch.clear()


def _backfill(conn: sqlite3.Connection, *, force: bool, limit: int) -> int:
    """Compute and persist the new columns for the selected products.

    Args:
        conn: Open connection whose row factory is `sqlite3.Row`.
        force: When false, only rows whose `scandi_score` is NULL are selected.
        limit: Row cap appended as a LIMIT clause; 0 means no LIMIT clause.

    Returns:
        0 (per-row failures are counted and reported, not raised).
    """
    _ensure_columns(conn)

    where = "image_path != ''"
    if not force:
        where += " AND (scandi_score IS NULL)"
    sql = f"SELECT url, brand, name, category, material, color, image_path FROM products WHERE {where}"
    params: tuple[int, ...] = ()
    if limit:
        # Bind the limit instead of formatting it into the statement text.
        sql += " LIMIT ?"
        params = (limit,)
    rows = conn.execute(sql, params).fetchall()
    total = len(rows)
    print(f"Backfilling color/contrast/scandi for {total} products...", flush=True)

    ok = miss = err = 0
    batch: list[tuple] = []
    t0 = time.time()
    for i, r in enumerate(rows, 1):
        try:
            pal = extract_palette(r["image_path"])
            dom = dominant(pal)
            # Each helper gets its own dict copy of the row.
            axes = tag_axes(dict(r))
            scandi = score_from_palette_and_meta(pal, dict(r))
            lab_blob = (
                np.asarray(dom, dtype=np.float32).tobytes() if dom is not None else None
            )
            batch.append(
                (
                    lab_blob,
                    json.dumps(palette_to_json(pal)),
                    json.dumps(axes),
                    float(scandi),
                    r["url"],
                )
            )
            # A row with an empty palette is still written, but counted as a miss.
            if pal:
                ok += 1
            else:
                miss += 1
        except Exception as e:
            # Best-effort backfill: a failing row is skipped (its scandi_score
            # is left untouched) and the run continues. Only the first four
            # failures are printed to keep stderr readable.
            err += 1
            if err < 5:
                print(f" err on {r['url']}: {e}", file=sys.stderr)
        if len(batch) >= _BATCH_SIZE:
            _flush(conn, batch)
            elapsed = time.time() - t0
            rate = i / elapsed if elapsed > 0 else 0
            eta = (total - i) / rate if rate > 0 else 0
            print(
                f" {i}/{total} ok={ok} miss={miss} err={err} "
                f"rate={rate:.1f}/s eta={eta:.0f}s",
                flush=True,
            )
    if batch:
        _flush(conn, batch)

    n_with = conn.execute(
        "SELECT COUNT(*) FROM products WHERE scandi_score IS NOT NULL"
    ).fetchone()[0]
    print(f"\nDone in {time.time() - t0:.0f}s. ok={ok} miss={miss} err={err}", flush=True)
    print(f"products.scandi_score populated for {n_with} rows", flush=True)
    return 0


def main() -> int:
    """Parse CLI flags, open the database at DB_PATH, and run the backfill.

    Flags:
        --force: recompute rows that already have a `scandi_score`.
        --limit N: cap the number of rows processed (0 = all).

    Returns:
        Process exit status: 1 if the database file does not exist, else 0.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--force", action="store_true", help="Recompute even rows already populated")
    ap.add_argument("--limit", type=int, default=0, help="Cap number of rows (0 = all)")
    args = ap.parse_args()

    if not DB_PATH.exists():
        print(f"DB not found: {DB_PATH}", file=sys.stderr)
        return 1

    conn = sqlite3.connect(str(DB_PATH))
    try:
        conn.row_factory = sqlite3.Row
        return _backfill(conn, force=args.force, limit=args.limit)
    finally:
        # sqlite3 connections are not closed automatically; release the file
        # handle on every exit path, including errors and Ctrl-C.
        conn.close()
# Script entry point: run the backfill and hand main()'s return value to the
# interpreter as the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())