shalinmehta commited on
Commit
87d1bc9
·
verified ·
1 Parent(s): 88f8ce1

ZeroGPU-adapted DynaCell demo, repointed to biohub repos

Browse files
README.md CHANGED
@@ -1,13 +1,41 @@
1
  ---
2
- title: Dynacell
3
- emoji: 💻
4
- colorFrom: yellow
5
- colorTo: blue
6
  sdk: gradio
7
- sdk_version: 6.16.0
8
- python_version: '3.12'
9
  app_file: app.py
10
  pinned: false
 
 
 
 
 
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: DynaCell Virtual Staining Demo
3
+ emoji: 🔬
4
+ colorFrom: blue
5
+ colorTo: indigo
6
  sdk: gradio
7
+ sdk_version: "5.29.0"
 
8
  app_file: app.py
9
  pinned: false
10
+ suggested_hardware: zero-a10g
11
+ models:
12
+ - biohub/dynacell-checkpoints
13
+ datasets:
14
+ - biohub/dynacell-demo-data
15
  ---
16
 
17
+ # DynaCell Virtual Staining Demo
18
+
19
+ Predict fluorescence channels (membrane, nuclei, or organelle structure) from phase-contrast OME-Zarr using three models:
20
+
21
+ - **CELL-Diff** — flow-matching diffusion model
22
+ - **FNet3D** — 3-D U-Net (FNet architecture)
23
+ - **VSCyto3D** — masked-autoencoder pretrained U-Net
24
+
25
+ ## Quick start
26
+
27
+ 1. Select an organelle from the dropdown.
28
+ 2. Click **Load Demo Data** to fetch the matching A549-cell demo dataset directly into the Space — no download/upload needed.
29
+ 3. Run predictions in **Tab 1** or generate the CELL-Diff ODE trajectory in **Tab 2**.
30
+
31
+ ## Using your own data
32
+
33
+ The input must be an OME-Zarr HCS store zipped into a single `.zip` file, with layout:
34
+
35
+ ```
36
+ your_data.zarr/
37
+ 0/0/fov0000/0 # array shape (T, C, Z, Y, X)
38
+ # C[0] = Phase3D, Z = 16, YX = 512×512
39
+ ```
40
+
41
+ Use [iohub](https://github.com/czbiohub-sf/iohub) to create compatible zarr stores.
app.py ADDED
@@ -0,0 +1,522 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """DynaCell Virtual Staining Demo — Gradio Space.
2
+
3
+ Upload a zipped OME-Zarr HCS store once; then:
4
+ Tab 1 — run CELL-Diff / FNet3D / VSCyto3D predictions on a selected timepoint,
5
+ view a chosen Z slice, and see Spectral PCC metrics.
6
+ Tab 2 — visualize the CELL-Diff ODE denoising trajectory as an animated GIF,
7
+ with a Phase | Exp reference panel at the selected timepoint and Z slice.
8
+ Changing the Z-slice slider re-renders the GIF instantly from cached data.
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import sys
14
+ import tempfile
15
+ import zipfile
16
+ from pathlib import Path
17
+
18
+ import gradio as gr
19
+ import matplotlib.pyplot as plt
20
+ import numpy as np
21
+ from iohub.ngff import open_ome_zarr
22
+
23
+ sys.path.insert(0, str(Path(__file__).parent))
24
+ from predict_runner import (
25
+ ORGANELLE_LABELS, TARGET_CHANNELS,
26
+ preprocess_zarr, run_prediction,
27
+ compute_trajectory, render_trajectory_gif,
28
+ )
29
+
30
+ from cubic.metrics.bandlimited import spectral_pcc
31
+
32
+ ORGANELLES = ["CAAX", "H2B", "SEC61B", "TOMM20"]
33
+ MODEL_KEYS = ["celldiff", "fnet3d", "vscyto3d"]
34
+ MODEL_LABELS = {"celldiff": "CELL-Diff", "fnet3d": "FNet3D", "vscyto3d": "VSCyto3D"}
35
+ PHASE_CH = 0
36
+ FLUOR_CH = 2
37
+ _DEMO_REPO = "biohub/dynacell-demo-data"
38
+
39
+ PATCH_D = 8 # fixed Z window used by all trajectory models
40
+
41
+ SPACING = [0.174, 0.1494, 0.1494]
42
+ SPECTRAL_KWARGS = dict(bin_delta=1.0, tail_fraction=0.2, apodization="tukey", nbins_low=3)
43
+
44
+
45
+ # ---------------------------------------------------------------------------
46
+ # Helpers
47
+ # ---------------------------------------------------------------------------
48
+
49
+ def extract_zarr_zip(zip_path: str) -> str:
50
+ """Extract uploaded zip to a fresh temp dir; return the HCS zarr root path."""
51
+ import json
52
+ tmpdir = Path(tempfile.mkdtemp())
53
+ with zipfile.ZipFile(zip_path, "r") as z:
54
+ z.extractall(tmpdir)
55
+ for candidate in sorted(tmpdir.rglob(".zattrs")):
56
+ root = candidate.parent
57
+ try:
58
+ zattrs = json.loads((root / ".zattrs").read_text())
59
+ if "plate" in zattrs:
60
+ return str(root)
61
+ except Exception:
62
+ pass
63
+ for d in sorted(tmpdir.iterdir()):
64
+ if d.is_dir():
65
+ return str(d)
66
+ raise ValueError("No zarr store found in zip.")
67
+
68
+
69
+ def get_data_shape(data_path: str) -> tuple[int, int]:
70
+ """Return (n_timepoints, n_z_slices) from the first position in the plate."""
71
+ with open_ome_zarr(data_path, mode="r") as plate:
72
+ _, pos = next(plate.positions())
73
+ return pos.data.shape[0], pos.data.shape[2]
74
+
75
+
76
+ def percentile_norm(img: np.ndarray, lo: float = 0.5, hi: float = 99.5) -> np.ndarray:
77
+ p_lo, p_hi = np.percentile(img, [lo, hi])
78
+ if p_hi == p_lo:
79
+ return np.zeros_like(img, dtype=np.float32)
80
+ return np.clip((img - p_lo) / (p_hi - p_lo), 0, 1).astype(np.float32)
81
+
82
+
83
+ def compute_spectral_pcc(pred_zarr_path: str, gt_fluor_vol: np.ndarray) -> float | None:
84
+ """Spectral PCC between the prediction (t=0) and the GT fluorescence volume."""
85
+ try:
86
+ with open_ome_zarr(pred_zarr_path, mode="r") as pred_plate:
87
+ _, pred_pos = next(pred_plate.positions())
88
+ pred_vol = np.array(pred_pos.data[0, 0], dtype=np.float32)
89
+ return float(spectral_pcc(pred_vol, gt_fluor_vol, spacing=SPACING, **SPECTRAL_KWARGS))
90
+ except Exception as e:
91
+ print(f"spectral_pcc failed: {e}")
92
+ return None
93
+
94
+
95
+ # ---------------------------------------------------------------------------
96
+ # Data loaders
97
+ # ---------------------------------------------------------------------------
98
+
99
+ def _make_slider_updates(data_path: str, organelle: str) -> tuple:
100
+ """Read data shape and return slider updates + Phase|Exp figure."""
101
+ n_tp, n_z = get_data_shape(data_path)
102
+ z_mid = n_z // 2
103
+ fig = render_phase_exp_traj(data_path, 0, PATCH_D // 2, organelle)
104
+ return (
105
+ gr.Slider(minimum=0, maximum=n_tp - 1, step=1, value=0), # timepoint_slider
106
+ gr.Slider(minimum=0, maximum=n_z - 1, step=1, value=z_mid), # z_slice_slider
107
+ gr.Slider(minimum=0, maximum=n_tp - 1, step=1, value=0), # traj_timepoint
108
+ fig, # traj_static
109
+ n_tp, n_z,
110
+ )
111
+
112
+
113
+ def load_demo_data(organelle: str, progress=gr.Progress()) -> tuple:
114
+ """Download the demo zarr, extract it, and return updated UI state."""
115
+ from huggingface_hub import hf_hub_download
116
+ filename = f"{organelle}_mock.zarr.zip"
117
+ progress(0.1, desc=f"Downloading {organelle} demo data...")
118
+ zip_path = hf_hub_download(repo_id=_DEMO_REPO, filename=filename, repo_type="dataset")
119
+ progress(0.8, desc="Extracting zarr...")
120
+ data_path = extract_zarr_zip(zip_path)
121
+ tp_sl, z_sl, traj_tp, fig, n_tp, n_z = _make_slider_updates(data_path, organelle)
122
+ progress(1.0, desc="Ready.")
123
+ status = f"**Loaded:** {filename} (A549 cells, {n_tp} timepoints, {n_z} Z slices)"
124
+ return data_path, status, tp_sl, z_sl, traj_tp, fig
125
+
126
+
127
+ def on_upload(file, organelle: str) -> tuple:
128
+ """Handle zarr zip upload: extract, read shape, update UI state."""
129
+ if file is None:
130
+ raise gr.Error("No file uploaded.")
131
+ zip_path = file if isinstance(file, str) else file.name
132
+ data_path = extract_zarr_zip(zip_path)
133
+ tp_sl, z_sl, traj_tp, fig, n_tp, n_z = _make_slider_updates(data_path, organelle)
134
+ status = f"**Uploaded:** {Path(zip_path).name} ({n_tp} timepoints, {n_z} Z slices)"
135
+ return data_path, status, tp_sl, z_sl, traj_tp, fig
136
+
137
+
138
+ # ---------------------------------------------------------------------------
139
+ # Tab 2: Phase | Exp reference panel
140
+ # ---------------------------------------------------------------------------
141
+
142
+ def render_phase_exp(
143
+ zarr_state: str | None,
144
+ timepoint: int,
145
+ z_slice: int,
146
+ organelle: str,
147
+ ) -> plt.Figure | None:
148
+ """Render Phase and Experimental fluorescence side by side at (timepoint, z_slice)."""
149
+ if zarr_state is None:
150
+ return None
151
+ with open_ome_zarr(zarr_state, mode="r") as plate:
152
+ _, pos = next(plate.positions())
153
+ n_tp = pos.data.shape[0]
154
+ n_z = pos.data.shape[2]
155
+ tp = min(timepoint, n_tp - 1)
156
+ z = min(z_slice, n_z - 1)
157
+ phase_img = np.array(pos.data[tp, PHASE_CH, z])
158
+ fluor_img = np.array(pos.data[tp, FLUOR_CH, z])
159
+
160
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(6.4, 3.2))
161
+ ax1.imshow(percentile_norm(phase_img), cmap="gray")
162
+ ax1.set_title("Phase", fontsize=10)
163
+ ax1.axis("off")
164
+ ax2.imshow(percentile_norm(fluor_img), cmap="gray")
165
+ ax2.set_title(f"Exp ({TARGET_CHANNELS[organelle]})", fontsize=10)
166
+ ax2.axis("off")
167
+ fig.suptitle(
168
+ f"{ORGANELLE_LABELS[organelle]} | t={tp} | z={z}",
169
+ fontsize=11, y=1.01,
170
+ )
171
+ fig.tight_layout()
172
+ return fig
173
+
174
+
175
+ def render_phase_exp_traj(
176
+ zarr_state: str | None,
177
+ timepoint: int,
178
+ z_patch: int,
179
+ organelle: str,
180
+ ) -> plt.Figure | None:
181
+ """Phase | Exp panel for the trajectory tab.
182
+
183
+ z_patch is a patch-relative index (0 … PATCH_D-1); converted to the
184
+ absolute Z using z_start = (n_z - PATCH_D) // 2.
185
+ """
186
+ if zarr_state is None:
187
+ return None
188
+ with open_ome_zarr(zarr_state, mode="r") as plate:
189
+ _, pos = next(plate.positions())
190
+ n_tp = pos.data.shape[0]
191
+ n_z = pos.data.shape[2]
192
+ tp = min(timepoint, n_tp - 1)
193
+ z_start = (n_z - PATCH_D) // 2
194
+ z_abs = z_start + max(0, min(z_patch, PATCH_D - 1))
195
+ phase_img = np.array(pos.data[tp, PHASE_CH, z_abs])
196
+ fluor_img = np.array(pos.data[tp, FLUOR_CH, z_abs])
197
+
198
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(6.4, 3.2))
199
+ ax1.imshow(percentile_norm(phase_img), cmap="gray")
200
+ ax1.set_title("Phase", fontsize=10)
201
+ ax1.axis("off")
202
+ ax2.imshow(percentile_norm(fluor_img), cmap="gray")
203
+ ax2.set_title(f"Exp ({TARGET_CHANNELS[organelle]})", fontsize=10)
204
+ ax2.axis("off")
205
+ fig.suptitle(
206
+ f"{ORGANELLE_LABELS[organelle]} | t={tp} | z={z_abs} (patch slice {z_patch})",
207
+ fontsize=11, y=1.01,
208
+ )
209
+ fig.tight_layout()
210
+ return fig
211
+
212
+
213
+ # ---------------------------------------------------------------------------
214
+ # Tab 1: Virtual Staining
215
+ # ---------------------------------------------------------------------------
216
+
217
+ def render_from_z(
218
+ pred_info: dict | None,
219
+ z_slice: int,
220
+ zarr_state: str | None,
221
+ ) -> plt.Figure | None:
222
+ """Re-render the prediction comparison at a different Z slice."""
223
+ if pred_info is None or zarr_state is None:
224
+ return None
225
+
226
+ organelle = pred_info["organelle"]
227
+ timepoint = pred_info["timepoint"]
228
+ selected_models = pred_info["selected_models"]
229
+ pred_paths = pred_info["paths"]
230
+ pred_pccs = pred_info["pccs"]
231
+ n_z = pred_info["n_z"]
232
+ z = min(z_slice, n_z - 1)
233
+
234
+ with open_ome_zarr(zarr_state, mode="r") as gt_plate:
235
+ _, gt_pos = next(gt_plate.positions())
236
+ phase_img = np.array(gt_pos.data[timepoint, PHASE_CH, z])
237
+ fluor_img = np.array(gt_pos.data[timepoint, FLUOR_CH, z])
238
+
239
+ cols = ["Phase", f"Exp ({TARGET_CHANNELS[organelle]})"] + [MODEL_LABELS[m] for m in selected_models]
240
+ fig, axes = plt.subplots(1, len(cols), figsize=(3.0 * len(cols), 3.2))
241
+ if len(cols) == 1:
242
+ axes = [axes]
243
+
244
+ axes[0].imshow(percentile_norm(phase_img), cmap="gray")
245
+ axes[0].set_title("Phase", fontsize=10)
246
+ axes[1].imshow(percentile_norm(fluor_img), cmap="gray")
247
+ axes[1].set_title(f"Exp ({TARGET_CHANNELS[organelle]})", fontsize=10)
248
+
249
+ for col_idx, model_key in enumerate(selected_models, start=2):
250
+ label = MODEL_LABELS[model_key]
251
+ pred_path = pred_paths.get(model_key)
252
+ pcc = pred_pccs.get(model_key)
253
+ if pred_path is not None:
254
+ try:
255
+ with open_ome_zarr(pred_path, mode="r") as pred_plate:
256
+ _, pred_pos = next(pred_plate.positions())
257
+ img = percentile_norm(np.array(pred_pos.data[0, 0, z]))
258
+ title = f"{label}\nSpectral PCC={pcc:.3f}" if pcc is not None else label
259
+ except Exception as e:
260
+ img = np.zeros_like(phase_img, dtype=np.float32)
261
+ title = f"{label}\n(failed)"
262
+ print(f"Render failed for {model_key}: {e}")
263
+ else:
264
+ img = np.zeros_like(phase_img, dtype=np.float32)
265
+ title = f"{label}\n(failed)"
266
+
267
+ axes[col_idx].imshow(img, cmap="gray")
268
+ axes[col_idx].set_title(title, fontsize=9)
269
+
270
+ for ax in axes:
271
+ ax.axis("off")
272
+ fig.suptitle(
273
+ f"{ORGANELLE_LABELS[organelle]} | t={timepoint} | z={z}",
274
+ fontsize=11, y=1.01,
275
+ )
276
+ fig.tight_layout()
277
+ return fig
278
+
279
+
280
+ def run_demo(
281
+ zarr_zip,
282
+ organelle: str,
283
+ selected_models: list[str],
284
+ timepoint: int,
285
+ z_slice: int,
286
+ zarr_state: str | None,
287
+ progress=gr.Progress(),
288
+ ) -> tuple[plt.Figure | None, list[list], str, dict]:
289
+ if zarr_zip is None and not zarr_state:
290
+ raise gr.Error("Please load demo data or upload a zarr zip file.")
291
+ if not selected_models:
292
+ raise gr.Error("Select at least one model.")
293
+
294
+ if zarr_state:
295
+ data_path = zarr_state
296
+ else:
297
+ progress(0.05, desc="Extracting zarr...")
298
+ zip_path = zarr_zip if isinstance(zarr_zip, str) else zarr_zip.name
299
+ data_path = extract_zarr_zip(zip_path)
300
+
301
+ progress(0.10, desc="Computing normalization statistics...")
302
+ preprocess_zarr(data_path)
303
+
304
+ with open_ome_zarr(data_path, mode="r") as gt_plate:
305
+ _, gt_pos = next(gt_plate.positions())
306
+ n_tp, n_z = gt_pos.data.shape[0], gt_pos.data.shape[2]
307
+ tp = min(timepoint, n_tp - 1)
308
+ gt_fluor_vol = np.array(gt_pos.data[tp, FLUOR_CH], dtype=np.float32)
309
+
310
+ pred_paths: dict[str, str | None] = {}
311
+ pred_pccs: dict[str, float | None] = {}
312
+ n_models = len(selected_models)
313
+ for i, model_key in enumerate(selected_models):
314
+ progress(0.15 + 0.60 * i / n_models, desc=f"Running {MODEL_LABELS[model_key]}...")
315
+ try:
316
+ path = run_prediction(model_key, organelle, data_path, tp)
317
+ pred_paths[model_key] = path
318
+ pred_pccs[model_key] = compute_spectral_pcc(path, gt_fluor_vol)
319
+ except Exception as e:
320
+ pred_paths[model_key] = None
321
+ pred_pccs[model_key] = None
322
+ print(f"Prediction failed for {model_key}: {e}")
323
+
324
+ pred_info = {
325
+ "timepoint": tp, "organelle": organelle,
326
+ "selected_models": selected_models,
327
+ "paths": pred_paths, "pccs": pred_pccs, "n_z": n_z,
328
+ }
329
+
330
+ progress(0.80, desc="Rendering figure...")
331
+ fig = render_from_z(pred_info, min(z_slice, n_z - 1), data_path)
332
+
333
+ metrics_rows = [
334
+ [MODEL_LABELS[m], "failed" if pred_paths.get(m) is None
335
+ else (f"{pred_pccs[m]:.4f}" if pred_pccs.get(m) is not None else "N/A")]
336
+ for m in selected_models
337
+ ]
338
+
339
+ progress(1.0, desc="Done.")
340
+ return fig, metrics_rows, data_path, pred_info
341
+
342
+
343
+ # ---------------------------------------------------------------------------
344
+ # Tab 2: CellDiff Trajectory
345
+ # ---------------------------------------------------------------------------
346
+
347
+ def run_trajectory_demo(
348
+ zarr_zip,
349
+ organelle: str,
350
+ timepoint: int,
351
+ num_steps: int,
352
+ z_slice: int,
353
+ zarr_state: str | None,
354
+ progress=gr.Progress(),
355
+ ) -> tuple[str, str, dict]:
356
+ """Run ODE trajectory, render GIF, cache trajectory data for Z-slice re-renders."""
357
+ if zarr_zip is None and not zarr_state:
358
+ raise gr.Error("Please load demo data or upload a zarr zip file.")
359
+
360
+ if zarr_state:
361
+ data_path = zarr_state
362
+ else:
363
+ progress(0.03, desc="Extracting zarr...")
364
+ zip_path = zarr_zip if isinstance(zarr_zip, str) else zarr_zip.name
365
+ data_path = extract_zarr_zip(zip_path)
366
+
367
+ progress(0.08, desc="Computing normalization statistics...")
368
+ preprocess_zarr(data_path)
369
+
370
+ traj_info = compute_trajectory(organelle, data_path, timepoint, num_steps, progress)
371
+ gif_path = render_trajectory_gif(traj_info, z_slice)
372
+ return gif_path, data_path, traj_info
373
+
374
+
375
+ def rerender_gif(traj_info: dict | None, z_slice: int) -> str | None:
376
+ """Re-render the trajectory GIF at a new Z slice without re-running the ODE."""
377
+ if traj_info is None:
378
+ return None
379
+ return render_trajectory_gif(traj_info, z_slice)
380
+
381
+
382
+ # ---------------------------------------------------------------------------
383
+ # Gradio UI
384
+ # ---------------------------------------------------------------------------
385
+
386
+ with gr.Blocks(title="DynaCell Virtual Staining") as demo:
387
+ gr.Markdown("## DynaCell Virtual Staining Demo")
388
+ gr.Markdown(
389
+ "**Tab 1** runs virtual staining predictions (CELL-Diff / FNet3D / VSCyto3D) "
390
+ "on a phase-contrast OME-Zarr for a selected timepoint, and reports Spectral PCC. "
391
+ "**Tab 2** visualizes the CELL-Diff ODE denoising trajectory."
392
+ )
393
+
394
+ zarr_state = gr.State(value=None)
395
+ pred_info_state = gr.State(value=None)
396
+ traj_info_state = gr.State(value=None)
397
+
398
+ # ---- Data source row -------------------------------------------------
399
+ with gr.Row():
400
+ organelle = gr.Dropdown(
401
+ choices=[(ORGANELLE_LABELS[o], o) for o in ORGANELLES],
402
+ value="CAAX", label="Organelle",
403
+ info="Select the target organelle.",
404
+ )
405
+ load_demo_btn = gr.Button("Load Demo Data", variant="secondary", scale=1)
406
+ zarr_upload = gr.File(
407
+ label="Or upload your own zarr (.zip)",
408
+ file_types=[".zip"],
409
+ scale=2,
410
+ )
411
+
412
+ data_status = gr.Markdown("")
413
+
414
+ # ---- Tabs ------------------------------------------------------------
415
+ with gr.Tabs():
416
+
417
+ with gr.Tab("Virtual Staining"):
418
+ with gr.Row():
419
+ model_selector = gr.CheckboxGroup(
420
+ choices=[(MODEL_LABELS[m], m) for m in MODEL_KEYS],
421
+ value=MODEL_KEYS,
422
+ label="Models to run",
423
+ )
424
+ with gr.Row():
425
+ timepoint_slider = gr.Slider(
426
+ minimum=0, maximum=4, step=1, value=0,
427
+ label="Timepoint",
428
+ info="Range updates after loading data.",
429
+ )
430
+ z_slice_slider = gr.Slider(
431
+ minimum=0, maximum=99, step=1, value=15,
432
+ label="Z slice",
433
+ info="Range updates after loading data.",
434
+ )
435
+ run_btn = gr.Button("Run Predictions", variant="primary")
436
+ output_plot = gr.Plot(label="Predictions")
437
+ metrics_table = gr.Dataframe(
438
+ headers=["Model", "Spectral PCC"],
439
+ label="Spectral PCC (volumetric, vs experimental fluorescence)",
440
+ )
441
+
442
+ with gr.Tab("CELL-Diff Trajectory"):
443
+ gr.Markdown(
444
+ "Generate the CELL-Diff ODE denoising trajectory. "
445
+ "T=0 is pure Gaussian noise; T=N is the final predicted fluorescence. "
446
+ "After generating, change **Z slice** to instantly re-render the GIF "
447
+ "at a different slice without re-running the ODE."
448
+ )
449
+ with gr.Row():
450
+ traj_timepoint = gr.Slider(
451
+ minimum=0, maximum=4, step=1, value=0,
452
+ label="Timepoint",
453
+ info="Range updates after loading data.",
454
+ )
455
+ traj_z_slice = gr.Slider(
456
+ minimum=0, maximum=PATCH_D - 1, step=1, value=PATCH_D // 2,
457
+ label=f"Z slice (0–{PATCH_D - 1}, middle {PATCH_D} of full volume)",
458
+ )
459
+ traj_num_steps = gr.Slider(
460
+ minimum=10, maximum=100, step=10, value=50,
461
+ label="ODE steps",
462
+ )
463
+ traj_static = gr.Plot(label="Phase | Exp (reference)")
464
+ traj_btn = gr.Button("Generate Trajectory", variant="primary")
465
+ traj_gif = gr.Image(label="Animated trajectory (GIF)", type="filepath")
466
+
467
+ # ---- Event wiring ----------------------------------------------------
468
+
469
+ _data_outputs = [
470
+ zarr_state, data_status,
471
+ timepoint_slider, z_slice_slider,
472
+ traj_timepoint,
473
+ traj_static,
474
+ ]
475
+
476
+ load_demo_btn.click(
477
+ fn=load_demo_data,
478
+ inputs=[organelle],
479
+ outputs=_data_outputs,
480
+ )
481
+
482
+ zarr_upload.upload(
483
+ fn=on_upload,
484
+ inputs=[zarr_upload, organelle],
485
+ outputs=_data_outputs,
486
+ )
487
+
488
+ run_btn.click(
489
+ fn=run_demo,
490
+ inputs=[zarr_upload, organelle, model_selector, timepoint_slider, z_slice_slider, zarr_state],
491
+ outputs=[output_plot, metrics_table, zarr_state, pred_info_state],
492
+ )
493
+
494
+ z_slice_slider.change(
495
+ fn=render_from_z,
496
+ inputs=[pred_info_state, z_slice_slider, zarr_state],
497
+ outputs=[output_plot],
498
+ )
499
+
500
+ # Phase | Exp panel updates on any slider or organelle change
501
+ for _trigger in (traj_timepoint, traj_z_slice, organelle):
502
+ _trigger.change(
503
+ fn=render_phase_exp_traj,
504
+ inputs=[zarr_state, traj_timepoint, traj_z_slice, organelle],
505
+ outputs=[traj_static],
506
+ )
507
+
508
+ traj_btn.click(
509
+ fn=run_trajectory_demo,
510
+ inputs=[zarr_upload, organelle, traj_timepoint, traj_num_steps, traj_z_slice, zarr_state],
511
+ outputs=[traj_gif, zarr_state, traj_info_state],
512
+ )
513
+
514
+ # Re-render GIF from cached trajectory when Z slice changes (no ODE re-run)
515
+ traj_z_slice.change(
516
+ fn=rerender_gif,
517
+ inputs=[traj_info_state, traj_z_slice],
518
+ outputs=[traj_gif],
519
+ )
520
+
521
+ if __name__ == "__main__":
522
+ demo.launch()
config_templates/celldiff.yaml ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ class_path: dynacell.engine.DynacellFlowMatching
3
+ init_args:
4
+ net_config:
5
+ input_spatial_size: [8, 512, 512]
6
+ in_channels: 1
7
+ dims: [64, 128, 256, 256]
8
+ num_res_block: [2, 2, 2]
9
+ hidden_size: 512
10
+ num_heads: 8
11
+ dim_head: 64
12
+ num_hidden_layers: 8
13
+ patch_size: 4
14
+ transport_config:
15
+ path_type: Linear
16
+ prediction: velocity
17
+ num_generate_steps: 100
18
+ predict_method: iterative
19
+ predict_overlap: [4, 256, 256]
20
+ ckpt_path: {ckpt_path}
21
+
22
+ data:
23
+ class_path: viscy_data.hcs.HCSDataModule
24
+ init_args:
25
+ data_path: {data_path}
26
+ source_channel: Phase3D
27
+ target_channel: {target_channel}
28
+ z_window_size: 16
29
+ batch_size: 1
30
+ num_workers: 0
31
+ yx_patch_size: [512, 512]
32
+ normalizations:
33
+ - class_path: viscy_transforms.MinMaxSampled
34
+ init_args:
35
+ keys: [Phase3D]
36
+ level: timepoint_statistics
37
+ augmentations: []
38
+
39
+ trainer:
40
+ accelerator: gpu
41
+ strategy: auto
42
+ devices: 1
43
+ num_nodes: 1
44
+ precision: 32-true
45
+ callbacks:
46
+ - class_path: viscy_utils.callbacks.prediction_writer.HCSPredictionWriter
47
+ init_args:
48
+ output_store: {output_store}
49
+
50
+ return_predictions: false
config_templates/fnet3d.yaml ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ class_path: dynacell.engine.DynacellUNet
3
+ init_args:
4
+ architecture: FNet3D
5
+ model_config:
6
+ in_channels: 1
7
+ out_channels: 1
8
+ depth: 4
9
+ mult_chan: 32
10
+ in_stack_depth: 16
11
+ predict_method: full_image
12
+ predict_overlap: [4, 256, 256]
13
+ ckpt_path: {ckpt_path}
14
+
15
+ data:
16
+ class_path: viscy_data.hcs.HCSDataModule
17
+ init_args:
18
+ data_path: {data_path}
19
+ source_channel: Phase3D
20
+ target_channel: {target_channel}
21
+ z_window_size: 16
22
+ batch_size: 1
23
+ num_workers: 0
24
+ yx_patch_size: [512, 512]
25
+ normalizations:
26
+ - class_path: viscy_transforms.NormalizeSampled
27
+ init_args:
28
+ keys: [Phase3D]
29
+ level: fov_statistics
30
+ subtrahend: mean
31
+ divisor: std
32
+ augmentations: []
33
+
34
+ trainer:
35
+ accelerator: gpu
36
+ strategy: auto
37
+ devices: 1
38
+ num_nodes: 1
39
+ precision: 32-true
40
+ callbacks:
41
+ - class_path: viscy_utils.callbacks.prediction_writer.HCSPredictionWriter
42
+ init_args:
43
+ output_store: {output_store}
44
+
45
+ return_predictions: false
config_templates/vscyto3d.yaml ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ class_path: dynacell.engine.DynacellUNet
3
+ init_args:
4
+ architecture: fcmae
5
+ model_config:
6
+ in_channels: 1
7
+ out_channels: 1
8
+ in_stack_depth: 15
9
+ decoder_conv_blocks: 2
10
+ dims: [96, 192, 384, 768]
11
+ encoder_blocks: [3, 3, 9, 3]
12
+ encoder_drop_path_rate: 0.1
13
+ pretraining: false
14
+ stem_kernel_size: [5, 4, 4]
15
+ predict_method: full_image
16
+ predict_overlap: [4, 256, 256]
17
+ ckpt_path: {ckpt_path}
18
+
19
+ data:
20
+ class_path: viscy_data.hcs.HCSDataModule
21
+ init_args:
22
+ data_path: {data_path}
23
+ source_channel: Phase3D
24
+ target_channel: {target_channel}
25
+ z_window_size: 15
26
+ batch_size: 1
27
+ num_workers: 0
28
+ yx_patch_size: [512, 512]
29
+ normalizations:
30
+ - class_path: viscy_transforms.NormalizeSampled
31
+ init_args:
32
+ keys: [Phase3D]
33
+ level: fov_statistics
34
+ subtrahend: mean
35
+ divisor: std
36
+ augmentations: []
37
+
38
+ trainer:
39
+ accelerator: gpu
40
+ strategy: auto
41
+ devices: 1
42
+ num_nodes: 1
43
+ precision: 32-true
44
+ callbacks:
45
+ - class_path: viscy_utils.callbacks.prediction_writer.HCSPredictionWriter
46
+ init_args:
47
+ output_store: {output_store}
48
+
49
+ return_predictions: false
predict_runner.py ADDED
@@ -0,0 +1,286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Download checkpoints from HF Hub, generate configs, run dynacell predict, and generate trajectories."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import shutil
7
+ import subprocess
8
+ import tempfile
9
+ import uuid
10
+ from pathlib import Path
11
+
12
+ import spaces
13
+ import zarr
14
+ from huggingface_hub import hf_hub_download
15
+
16
+ CHECKPOINT_REPO = "biohub/dynacell-checkpoints"
17
+ TEMPLATE_DIR = Path(__file__).parent / "config_templates"
18
+
19
+ # (model, organelle) → filename in the HF checkpoint repo
20
+ CHECKPOINT_FILES: dict[tuple[str, str], str] = {
21
+ ("celldiff", "CAAX"): "celldiff_caax.ckpt",
22
+ ("celldiff", "H2B"): "celldiff_h2b.ckpt",
23
+ ("celldiff", "SEC61B"): "celldiff_sec61b.ckpt",
24
+ ("celldiff", "TOMM20"): "celldiff_tomm20.ckpt",
25
+ ("fnet3d", "CAAX"): "fnet3d_caax.ckpt",
26
+ ("fnet3d", "H2B"): "fnet3d_h2b.ckpt",
27
+ ("fnet3d", "SEC61B"): "fnet3d_sec61b.ckpt",
28
+ ("fnet3d", "TOMM20"): "fnet3d_tomm20.ckpt",
29
+ ("vscyto3d", "CAAX"): "vscyto3d_caax.ckpt",
30
+ ("vscyto3d", "H2B"): "vscyto3d_h2b.ckpt",
31
+ ("vscyto3d", "SEC61B"): "vscyto3d_sec61b.ckpt",
32
+ ("vscyto3d", "TOMM20"): "vscyto3d_tomm20.ckpt",
33
+ }
34
+
35
+ TARGET_CHANNELS: dict[str, str] = {
36
+ "CAAX": "Membrane",
37
+ "H2B": "Nuclei",
38
+ "SEC61B": "Structure",
39
+ "TOMM20": "Structure",
40
+ }
41
+
42
+ ORGANELLE_LABELS: dict[str, str] = {
43
+ "CAAX": "Membrane (CAAX)",
44
+ "H2B": "Chromatin (H2B)",
45
+ "SEC61B": "ER (SEC61B)",
46
+ "TOMM20": "Mitochondria (TOMM20)",
47
+ }
48
+
49
+ FLUOR_CH = 2 # channel index for fluorescence in the input zarr
50
+
51
+ # Cache downloaded checkpoints in /tmp so the Space doesn't re-download each run
52
+ _ckpt_cache: dict[str, str] = {}
53
+
54
+
55
+ def get_checkpoint(model: str, organelle: str) -> str:
56
+ """Download (or return cached) checkpoint path for a given model + organelle."""
57
+ key = (model, organelle)
58
+ filename = CHECKPOINT_FILES[key]
59
+ if filename not in _ckpt_cache:
60
+ print(f"Downloading {filename} from {CHECKPOINT_REPO} ...")
61
+ local = hf_hub_download(repo_id=CHECKPOINT_REPO, filename=filename)
62
+ _ckpt_cache[filename] = local
63
+ return _ckpt_cache[filename]
64
+
65
+
66
+ def preprocess_zarr(data_path: str) -> None:
67
+ """Compute normalization statistics for the uploaded zarr via viscy preprocess."""
68
+ subprocess.run(
69
+ ["viscy", "preprocess", f"--data_path={data_path}", "--num_workers=1", "--block_size=32"],
70
+ check=True,
71
+ )
72
+
73
+
74
+ def create_single_timepoint_zarr(source_path: str, timepoint: int) -> str:
75
+ """Copy source HCS zarr plate, keeping only the selected timepoint.
76
+
77
+ Remaps timepoint_statistics in .zattrs so index "0" carries the selected
78
+ timepoint's normalization stats (needed by celldiff's MinMaxSampled).
79
+ """
80
+ out_path = Path(tempfile.gettempdir()) / f"dynacell_t{timepoint}_{uuid.uuid4().hex[:8]}.zarr"
81
+ shutil.copytree(source_path, str(out_path))
82
+
83
+ src_store = zarr.open(source_path, mode="r")
84
+ dst_store = zarr.open(str(out_path), mode="r+")
85
+
86
+ def _trim(src_grp: zarr.Group, dst_grp: zarr.Group) -> None:
87
+ for key in list(src_grp.keys()):
88
+ item = src_grp[key]
89
+ if isinstance(item, zarr.Array) and key == "0":
90
+ # Write selected timepoint into index 0, then resize to T=1
91
+ dst_arr = dst_grp[key]
92
+ dst_arr[0] = item[timepoint]
93
+ dst_arr.resize((1,) + item.shape[1:])
94
+ elif isinstance(item, zarr.Group):
95
+ _trim(item, dst_grp[key])
96
+
97
+ _trim(src_store, dst_store)
98
+
99
+ # Remap timepoint_statistics["<timepoint>"] → ["0"] in each FOV's .zattrs
100
+ def _remap_tp_stats(zattrs_path: Path) -> None:
101
+ if not zattrs_path.exists():
102
+ return
103
+ zattrs = json.loads(zattrs_path.read_text())
104
+ norm = zattrs.get("normalization", {})
105
+ changed = False
106
+ for ch_data in norm.values():
107
+ if "timepoint_statistics" in ch_data:
108
+ tp_stats = ch_data["timepoint_statistics"]
109
+ t_key = str(timepoint)
110
+ if t_key in tp_stats:
111
+ ch_data["timepoint_statistics"] = {"0": tp_stats[t_key]}
112
+ changed = True
113
+ if changed:
114
+ zattrs_path.write_text(json.dumps(zattrs))
115
+
116
+ for row in out_path.iterdir():
117
+ if not row.is_dir():
118
+ continue
119
+ for col in row.iterdir():
120
+ if not col.is_dir():
121
+ continue
122
+ for fov in col.iterdir():
123
+ if fov.is_dir():
124
+ _remap_tp_stats(fov / ".zattrs")
125
+
126
+ return str(out_path)
127
+
128
+
129
+ @spaces.GPU(duration=120)
130
+ def run_prediction(model: str, organelle: str, data_path: str, timepoint: int) -> str:
131
+ """Run prediction for a single timepoint; return the output zarr path.
132
+
133
+ Creates a single-timepoint subset of the source zarr, runs prediction on it,
134
+ and returns the path to the output zarr (which has T=1). The `dynacell predict`
135
+ subprocess inherits the ZeroGPU allocation from this decorated frame.
136
+ """
137
+ subset_path = create_single_timepoint_zarr(data_path, timepoint)
138
+
139
+ ckpt_path = get_checkpoint(model, organelle)
140
+ output_dir = Path(tempfile.gettempdir()) / f"dynacell_pred_{uuid.uuid4().hex[:8]}"
141
+ output_store = str(output_dir / f"{organelle}_{model}.zarr")
142
+
143
+ template = (TEMPLATE_DIR / f"{model}.yaml").read_text()
144
+ config_text = template.format(
145
+ ckpt_path=ckpt_path,
146
+ data_path=subset_path,
147
+ output_store=output_store,
148
+ target_channel=TARGET_CHANNELS[organelle],
149
+ )
150
+
151
+ config_path = Path(tempfile.gettempdir()) / f"dynacell_cfg_{uuid.uuid4().hex[:8]}.yaml"
152
+ config_path.write_text(config_text)
153
+
154
+ print(f"Running dynacell predict: {model} / {organelle} / t={timepoint}")
155
+ subprocess.run(["dynacell", "predict", "-c", str(config_path)], check=True)
156
+ config_path.unlink(missing_ok=True)
157
+
158
+ return output_store
159
+
160
+
161
+ @spaces.GPU(duration=120)
162
+ def compute_trajectory(
163
+ organelle: str,
164
+ data_path: str,
165
+ timepoint: int = 0,
166
+ num_steps: int = 50,
167
+ progress=None,
168
+ ) -> dict:
169
+ """Run the CELL-Diff ODE; save trajectory to /tmp as .npy; return metadata dict.
170
+
171
+ The returned dict contains everything needed to call render_trajectory_gif
172
+ without re-running the ODE.
173
+ """
174
+ import numpy as np
175
+ import torch
176
+ from iohub.ngff import open_ome_zarr
177
+ from dynacell.engine import DynacellFlowMatching
178
+ from viscy_data._utils import _read_norm_meta
179
+
180
+ if progress is not None:
181
+ progress(0.05, desc="Downloading CELL-Diff checkpoint...")
182
+ ckpt_path = get_checkpoint("celldiff", organelle)
183
+
184
+ if progress is not None:
185
+ progress(0.15, desc="Loading model...")
186
+ device = "cuda" if torch.cuda.is_available() else "cpu"
187
+ model = DynacellFlowMatching.load_from_checkpoint(ckpt_path, map_location=device)
188
+ model.eval()
189
+ patch_d, patch_h, patch_w = model.model.net.input_spatial_size # (8, 512, 512)
190
+
191
+ if progress is not None:
192
+ progress(0.25, desc="Reading phase data...")
193
+ with open_ome_zarr(data_path, mode="r") as plate:
194
+ _, pos = next(plate.positions())
195
+ phase_ch = pos.get_channel_index("Phase3D")
196
+ phase_raw = np.array(pos.data[timepoint, phase_ch])
197
+ norm_meta = _read_norm_meta(pos)
198
+
199
+ tp_stats = norm_meta["Phase3D"]["timepoint_statistics"][str(timepoint)]
200
+ lo = tp_stats["p1"].item()
201
+ hi = tp_stats["p99"].item()
202
+ phase_norm = np.clip(phase_raw.astype(np.float32), lo, hi)
203
+ phase_norm = 2.0 * (phase_norm - lo) / (hi - lo + 1e-8) - 1.0
204
+
205
+ z_total = phase_norm.shape[0]
206
+ z_start = (z_total - patch_d) // 2
207
+ phase_crop = phase_norm[z_start:z_start + patch_d, :patch_h, :patch_w]
208
+
209
+ if progress is not None:
210
+ progress(0.35, desc=f"Generating {num_steps}-step ODE trajectory...")
211
+ phase_tensor = (
212
+ torch.from_numpy(phase_crop).float()
213
+ .unsqueeze(0).unsqueeze(0)
214
+ .to(device)
215
+ )
216
+ with torch.no_grad():
217
+ trajectory = model.model.generate_trajectory(phase_tensor, num_steps=num_steps)
218
+ traj_np = trajectory[:, 0].cpu().numpy().astype(np.float32) # (num_steps, 1, D, H, W)
219
+
220
+ if progress is not None:
221
+ progress(0.90, desc="Saving trajectory to disk...")
222
+ traj_path = str(Path(tempfile.gettempdir()) / f"traj_np_{uuid.uuid4().hex[:8]}.npy")
223
+ np.save(traj_path, traj_np)
224
+
225
+ if progress is not None:
226
+ progress(1.0, desc="Done.")
227
+
228
+ return {
229
+ "traj_path": traj_path,
230
+ "z_start": z_start,
231
+ "patch_d": patch_d,
232
+ "organelle": organelle,
233
+ "timepoint": timepoint,
234
+ "num_steps": num_steps,
235
+ }
236
+
237
+
238
+ def render_trajectory_gif(traj_info: dict, z_patch: int) -> str:
239
+ """Render a GIF from a cached trajectory at the given patch-relative Z slice.
240
+
241
+ Fast — only loads the saved .npy and runs matplotlib; does not re-run the ODE.
242
+ """
243
+ import numpy as np
244
+ import matplotlib.pyplot as plt
245
+ from matplotlib.animation import FuncAnimation, PillowWriter
246
+
247
+ traj_np = np.load(traj_info["traj_path"])
248
+ z_start = traj_info["z_start"]
249
+ patch_d = traj_info["patch_d"]
250
+ organelle = traj_info["organelle"]
251
+ timepoint = traj_info["timepoint"]
252
+ num_steps = traj_info["num_steps"]
253
+
254
+ z_patch = max(0, min(z_patch, patch_d - 1))
255
+ z_abs = z_start + z_patch
256
+
257
+ def pnorm(img: np.ndarray) -> np.ndarray:
258
+ lo_p, hi_p = np.percentile(img, [0.5, 99.5])
259
+ if hi_p == lo_p:
260
+ return np.zeros_like(img, dtype=np.float32)
261
+ return np.clip((img - lo_p) / (hi_p - lo_p), 0, 1).astype(np.float32)
262
+
263
+ frame_idx = np.linspace(0, num_steps - 1, min(50, num_steps), dtype=int)
264
+ fig_a, ax_a = plt.subplots(figsize=(4, 4))
265
+ ax_a.axis("off")
266
+ im = ax_a.imshow(
267
+ pnorm(traj_np[0, 0, z_patch]), cmap="gray", vmin=0, vmax=1, interpolation="nearest"
268
+ )
269
+ ttl = ax_a.set_title(
270
+ f"{ORGANELLE_LABELS[organelle]} t={timepoint} z={z_abs}\nStep 0 (noise → prediction)",
271
+ fontsize=9,
272
+ )
273
+
274
+ def update(frame: int):
275
+ s = frame_idx[frame]
276
+ im.set_data(pnorm(traj_np[s, 0, z_patch]))
277
+ ttl.set_text(
278
+ f"{ORGANELLE_LABELS[organelle]} t={timepoint} z={z_abs}\nStep {s} (noise → prediction)"
279
+ )
280
+ return im, ttl
281
+
282
+ anim = FuncAnimation(fig_a, update, frames=len(frame_idx), interval=80, blit=True)
283
+ gif_path = str(Path(tempfile.gettempdir()) / f"traj_{uuid.uuid4().hex[:8]}.gif")
284
+ anim.save(gif_path, writer=PillowWriter(fps=12))
285
+ plt.close(fig_a)
286
+ return gif_path
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio>=5.0
2
+ spaces>=0.30
3
+ zarr>=2.16
4
+ numpy>=1.24
5
+ matplotlib>=3.7
6
+ torch>=2.1
7
+ huggingface_hub>=0.20
8
+ iohub>=0.1
9
+ cubic>=0.7.0a2
10
+ git+https://github.com/mehta-lab/VisCy.git@dynacell-models#subdirectory=packages/viscy-data
11
+ git+https://github.com/mehta-lab/VisCy.git@dynacell-models#subdirectory=packages/viscy-models
12
+ git+https://github.com/mehta-lab/VisCy.git@dynacell-models#subdirectory=packages/viscy-transforms
13
+ git+https://github.com/mehta-lab/VisCy.git@dynacell-models#subdirectory=packages/viscy-utils
14
+ git+https://github.com/mehta-lab/VisCy.git@dynacell-models#subdirectory=applications/dynacell