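# dynacell / predict_runner.py
# (Hugging Face file-view residue: author avatar, last commit message, commit hash, UI links, file size.)
# shalinmehta's picture
# Drop upload; organelle-first; equal-height panels; ODE-step slider replaces GIF; titles with colon
# c919cd1 verified
# Raw
# History Blame Contribute Delete
# 9.13 kB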
"""Download checkpoints from HF Hub, generate configs, run dynacell predict, and generate trajectories."""
from __future__ import annotations
import json
import shutil
import subprocess
import tempfile
import uuid
from pathlib import Path
import spaces
import zarr
from huggingface_hub import hf_hub_download
# Hugging Face Hub repository that hosts every model checkpoint.
CHECKPOINT_REPO = "biohub/dynacell-checkpoints"
# Per-model YAML config templates, stored next to this module.
TEMPLATE_DIR = Path(__file__).parent / "config_templates"
# (model, organelle) -> checkpoint filename inside CHECKPOINT_REPO.
# Every filename follows "<model>_<organelle lowercased>.ckpt"; keys are
# ordered model-major (celldiff, fnet3d, vscyto3d) then by organelle.
CHECKPOINT_FILES: dict[tuple[str, str], str] = {
    (model_name, marker): f"{model_name}_{marker.lower()}.ckpt"
    for model_name in ("celldiff", "fnet3d", "vscyto3d")
    for marker in ("CAAX", "H2B", "SEC61B", "TOMM20")
}
# Organelle marker -> target channel name substituted into the predict config.
TARGET_CHANNELS: dict[str, str] = dict(
    CAAX="Membrane",
    H2B="Nuclei",
    SEC61B="Structure",
    TOMM20="Structure",
)
# Organelle marker -> display label of the form "<structure> (<marker>)".
ORGANELLE_LABELS: dict[str, str] = dict(
    CAAX="Membrane (CAAX)",
    H2B="Chromatin (H2B)",
    SEC61B="ER (SEC61B)",
    TOMM20="Mitochondria (TOMM20)",
)
# Channel index of the fluorescence (experimental ground-truth) channel in the
# input zarr; read by index, unlike phase which is looked up by name.
FLUOR_CH = 2
# Spectral PCC settings (volumetric; shared with app.py).
# NOTE(review): SPACING presumably is the (Z, Y, X) voxel size — confirm order/units.
SPACING = [0.174, 0.1494, 0.1494]
SPECTRAL_KWARGS = dict(bin_delta=1.0, tail_fraction=0.2, apodization="tukey", nbins_low=3)
# In-process memo: checkpoint filename -> local path returned by hf_hub_download.
# It lets repeat calls in the same process skip the Hub request entirely. The
# checkpoint files themselves are stored by hf_hub_download's own cache, not by
# this dict, and the memo is empty again after a process restart.
_ckpt_cache: dict[str, str] = {}
def get_checkpoint(model: str, organelle: str) -> str:
    """Return a local path to the checkpoint for ``model`` + ``organelle``.

    The first request for a given checkpoint file fetches it with
    ``hf_hub_download`` from ``CHECKPOINT_REPO``; later requests in the same
    process are answered from the ``_ckpt_cache`` memo.

    Raises:
        KeyError: If ``(model, organelle)`` is not in ``CHECKPOINT_FILES``.
    """
    ckpt_name = CHECKPOINT_FILES[(model, organelle)]
    if ckpt_name in _ckpt_cache:
        return _ckpt_cache[ckpt_name]

    print(f"Downloading {ckpt_name} from {CHECKPOINT_REPO} ...")
    _ckpt_cache[ckpt_name] = hf_hub_download(repo_id=CHECKPOINT_REPO, filename=ckpt_name)
    return _ckpt_cache[ckpt_name]
def preprocess_zarr(data_path: str) -> None:
    """Run ``viscy preprocess`` on ``data_path`` to compute normalization statistics.

    Uses a single worker and a block size of 32.

    Raises:
        subprocess.CalledProcessError: If the command exits non-zero.
    """
    cmd = [
        "viscy",
        "preprocess",
        f"--data_path={data_path}",
        "--num_workers=1",
        "--block_size=32",
    ]
    subprocess.run(cmd, check=True)
def create_single_timepoint_zarr(source_path: str, timepoint: int) -> str:
    """Copy an HCS zarr plate, keeping only the selected timepoint.

    The whole plate is copied into the system temp dir. Then, in every array
    named ``"0"``, the selected timepoint is written to index 0 and the array
    is resized to length 1 along its first axis. Each FOV's ``.zattrs`` has
    ``normalization.<channel>.timepoint_statistics`` remapped so key ``"0"``
    carries the selected timepoint's stats (needed by celldiff's MinMaxSampled).

    Args:
        source_path: Path to the source HCS zarr plate.
        timepoint: Non-negative index of the timepoint to keep.

    Returns:
        Path of the new plate as a string. The caller owns it and should
        delete it when done.

    Raises:
        ValueError: If ``timepoint`` is negative.
    """
    # A negative index would still slice the arrays (from the end), but the
    # str(timepoint) stats lookup below would miss, silently pairing the data
    # with the wrong normalization stats.
    if timepoint < 0:
        raise ValueError(f"timepoint must be non-negative, got {timepoint}")

    t_key = str(timepoint)
    out_path = Path(tempfile.gettempdir()) / f"dynacell_t{timepoint}_{uuid.uuid4().hex[:8]}.zarr"

    def _trim(src_grp: zarr.Group, dst_grp: zarr.Group) -> None:
        # Recurse through the hierarchy; shrink every array named "0" to T=1.
        for key in list(src_grp.keys()):
            item = src_grp[key]
            if isinstance(item, zarr.Array) and key == "0":
                dst_arr = dst_grp[key]
                dst_arr[0] = item[timepoint]
                dst_arr.resize((1,) + item.shape[1:])
            elif isinstance(item, zarr.Group):
                _trim(item, dst_grp[key])

    def _remap_tp_stats(zattrs_path: Path) -> None:
        # Rewrite timepoint_statistics[t_key] as timepoint_statistics["0"].
        zattrs = json.loads(zattrs_path.read_text(encoding="utf-8"))
        changed = False
        for ch_data in zattrs.get("normalization", {}).values():
            if "timepoint_statistics" not in ch_data:
                continue
            tp_stats = ch_data["timepoint_statistics"]
            # NOTE(review): when t_key is missing the stats are left untouched,
            # so a "0" entry (if any) still describes source timepoint 0 —
            # confirm this fallback is intended.
            if t_key in tp_stats:
                ch_data["timepoint_statistics"] = {"0": tp_stats[t_key]}
                changed = True
        if changed:
            zattrs_path.write_text(json.dumps(zattrs), encoding="utf-8")

    try:
        shutil.copytree(source_path, str(out_path))
        _trim(zarr.open(source_path, mode="r"), zarr.open(str(out_path), mode="r+"))
        # HCS layout: <plate>/<row>/<col>/<fov>/.zattrs
        for zattrs_path in out_path.glob("*/*/*/.zattrs"):
            _remap_tp_stats(zattrs_path)
    except BaseException:
        # Don't leave a partial (possibly full-size) plate copy behind.
        shutil.rmtree(out_path, ignore_errors=True)
        raise
    return str(out_path)
@spaces.GPU(duration=120)
def run_prediction(model: str, organelle: str, data_path: str, timepoint: int) -> str:
    """Run ``dynacell predict`` on a single timepoint; return the output zarr path.

    A single-timepoint copy of ``data_path`` is made, a config is rendered from
    ``TEMPLATE_DIR / "<model>.yaml"``, and ``dynacell predict`` runs on it,
    writing an output zarr (T=1) under a fresh directory in the system temp
    dir. The `dynacell predict` subprocess inherits the ZeroGPU allocation from
    this decorated frame.

    The temporary config file and the single-timepoint input copy are deleted
    afterwards, whether or not prediction succeeds; only the output is kept.

    Raises:
        KeyError: If ``(model, organelle)`` has no entry in ``CHECKPOINT_FILES``.
        subprocess.CalledProcessError: If ``dynacell predict`` exits non-zero.
    """
    # Resolve checkpoint, template and target channel before copying the plate
    # so bad inputs fail fast without leaving a temporary copy behind.
    ckpt_path = get_checkpoint(model, organelle)
    template = (TEMPLATE_DIR / f"{model}.yaml").read_text()
    target_channel = TARGET_CHANNELS[organelle]

    tmp_root = Path(tempfile.gettempdir())
    output_dir = tmp_root / f"dynacell_pred_{uuid.uuid4().hex[:8]}"
    output_store = str(output_dir / f"{organelle}_{model}.zarr")
    config_path = tmp_root / f"dynacell_cfg_{uuid.uuid4().hex[:8]}.yaml"

    subset_path = create_single_timepoint_zarr(data_path, timepoint)
    try:
        config_text = template.format(
            ckpt_path=ckpt_path,
            data_path=subset_path,
            output_store=output_store,
            target_channel=target_channel,
        )
        config_path.write_text(config_text)
        print(f"Running dynacell predict: {model} / {organelle} / t={timepoint}")
        subprocess.run(["dynacell", "predict", "-c", str(config_path)], check=True)
    finally:
        # Per-call temporaries: the config and the copied input plate. Neither
        # is returned, so nothing else can clean them up later.
        config_path.unlink(missing_ok=True)
        shutil.rmtree(subset_path, ignore_errors=True)
    return output_store
@spaces.GPU(duration=120)
def compute_trajectory(
    organelle: str,
    data_path: str,
    timepoint: int = 0,
    num_steps: int = 50,
    progress=None,
) -> dict:
    """Run the CELL-Diff ODE for one timepoint and save the result as an ``.npz``.

    Reads the first position of the plate at ``data_path``, normalizes the
    ``Phase3D`` volume to [-1, 1] using that timepoint's p1/p99 stats, takes a
    Z window centered in the stack and the leading Y/X region matching the
    model's input size, and generates a ``num_steps``-step trajectory.

    Args:
        organelle: Organelle key selecting the CELL-Diff checkpoint.
        data_path: Path to the HCS OME-Zarr plate.
        timepoint: Timepoint index to read.
        num_steps: Number of ODE steps passed to ``generate_trajectory``.
        progress: Optional callable invoked as ``progress(fraction, desc=...)``
            for UI progress reporting; ignored when ``None``.

    Returns:
        Metadata dict with ``traj_path`` (an ``.npz`` in the system temp dir
        holding ``traj``, ``phase`` and ``gt`` arrays), ``z_start``,
        ``patch_d``, ``organelle``, ``timepoint`` and ``num_steps`` — enough to
        display the trajectory later without re-running the ODE.

    Raises:
        ValueError: If the volume has fewer Z slices than the model's input depth.
    """
    import numpy as np
    import torch
    from iohub.ngff import open_ome_zarr
    from dynacell.engine import DynacellFlowMatching
    from viscy_data._utils import _read_norm_meta

    def _report(fraction: float, desc: str) -> None:
        # Forward progress updates only when a callback was supplied.
        if progress is not None:
            progress(fraction, desc=desc)

    _report(0.05, "Downloading CELL-Diff checkpoint...")
    ckpt_path = get_checkpoint("celldiff", organelle)

    _report(0.15, "Loading model...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = DynacellFlowMatching.load_from_checkpoint(ckpt_path, map_location=device)
    model.eval()
    # Model's expected (D, H, W) input size; originally noted as (8, 512, 512).
    patch_d, patch_h, patch_w = model.model.net.input_spatial_size

    _report(0.25, "Reading phase data...")
    with open_ome_zarr(data_path, mode="r") as plate:
        # Only the first position in the plate is used.
        _, pos = next(plate.positions())
        phase_ch = pos.get_channel_index("Phase3D")
        phase_raw = np.array(pos.data[timepoint, phase_ch])
        fluor_raw = np.array(pos.data[timepoint, FLUOR_CH])
        norm_meta = _read_norm_meta(pos)
        tp_stats = norm_meta["Phase3D"]["timepoint_statistics"][str(timepoint)]
        lo = tp_stats["p1"].item()
        hi = tp_stats["p99"].item()

    z_total = phase_raw.shape[0]
    # With too few slices z_start would go negative and the slice below would
    # silently select a short, wrong window.
    if z_total < patch_d:
        raise ValueError(
            f"Volume has {z_total} Z slices but the model needs at least {patch_d}"
        )
    z_start = (z_total - patch_d) // 2
    # Centered Z window, leading Y/X region. NOTE(review): if H or W is smaller
    # than the patch, the crop is smaller than the model input — confirm the
    # model tolerates that.
    window = (
        slice(z_start, z_start + patch_d),
        slice(None, patch_h),
        slice(None, patch_w),
    )

    # Clip/rescale is element-wise, so normalizing only the crop yields the
    # same values as normalizing the full volume first.
    phase_crop = np.clip(phase_raw[window].astype(np.float32), lo, hi)
    phase_crop = 2.0 * (phase_crop - lo) / (hi - lo + 1e-8) - 1.0
    # Raw phase + experimental fluorescence over the same window, for the
    # display panels and the per-step Spectral PCC.
    phase_disp = phase_raw[window].astype(np.float32)
    gt_crop = fluor_raw[window].astype(np.float32)

    _report(0.35, f"Generating {num_steps}-step ODE trajectory...")
    # Add batch and channel axes: (D, H, W) -> (1, 1, D, H, W).
    phase_tensor = torch.from_numpy(phase_crop).float()[None, None].to(device)
    with torch.no_grad():
        trajectory = model.model.generate_trajectory(phase_tensor, num_steps=num_steps)
    # NOTE(review): assumes the trajectory is (num_steps, batch, ...) so [:, 0]
    # drops the batch axis, giving (num_steps, 1, D, H, W) — confirm.
    traj_np = trajectory[:, 0].cpu().numpy().astype(np.float32)

    _report(0.90, "Saving trajectory to disk...")
    traj_path = str(Path(tempfile.gettempdir()) / f"traj_{uuid.uuid4().hex[:8]}.npz")
    np.savez(traj_path, traj=traj_np, phase=phase_disp, gt=gt_crop)
    _report(1.0, "Done.")

    return {
        "traj_path": traj_path,
        "z_start": z_start,
        "patch_d": patch_d,
        "organelle": organelle,
        "timepoint": timepoint,
        "num_steps": num_steps,
    }