multimodalart HF Staff commited on
Commit
2f54371
Β·
verified Β·
1 Parent(s): bda7caf

Upload folder using huggingface_hub

Browse files
Files changed (3) hide show
  1. README.md +39 -7
  2. app.py +574 -0
  3. requirements.txt +5 -0
README.md CHANGED
@@ -1,13 +1,45 @@
1
  ---
2
- title: Tasker Keyframe Extractor
3
- emoji: πŸš€
4
  colorFrom: blue
5
- colorTo: pink
6
  sdk: gradio
7
- sdk_version: 6.19.0
8
- python_version: '3.12'
9
  app_file: app.py
10
- pinned: false
 
 
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: TASKER Keyframe Extractor
3
+ emoji: πŸ”
4
  colorFrom: blue
5
+ colorTo: indigo
6
  sdk: gradio
7
+ sdk_version: 6.15.1
 
8
  app_file: app.py
9
+ short_description: VLM-guided tree-search keyframe extraction from videos
10
+ python_version: "3.12"
11
+ startup_duration_timeout: 30m
12
  ---
13
 
14
+ ## TASKER Keyframe Extractor
15
+
16
+ This Space demonstrates **TASKER** (**Ta**sk-driven **a**nd **S**cene-aware **Ke**yframe sea**r**cher), a keyframe extraction algorithm from the ECCV 2026 paper [Bridging VideoQA and Video-Guided Agentic Tasks via Generalized Keyframe Extraction](https://arxiv.org/abs/2606.29445).
17
+
18
+ ### How it works
19
+
20
+ TASKER reformulates keyframe extraction as a **generalized graph-search problem**:
21
+
22
+ 1. The input video is segmented into a tree of segments.
23
+ 2. A Vision-Language Model (Qwen2.5-VL-7B) evaluates which segments likely contain crucial missing actions.
24
+ 3. The selected segments are expanded (split at visual change points).
25
+ 4. Visual deduplication filters near-identical frames.
26
+ 5. The search terminates when the VLM is confident enough (confidence β‰₯ 3) or a frame limit is reached.
27
+
28
+ Four search strategies are available:
29
+ - **A\*** (default): balances goal-relevance and visual state changes
30
+ - **BFS**: broad exploration, can select multiple segments per step
31
+ - **GBFS**: greedy best-first, focuses on goal-critical actions
32
+ - **Dijkstra**: focuses on maximum visual state transitions
33
+
34
+ ### Usage
35
+
36
+ 1. Upload a video file
37
+ 2. Enter a task query (e.g., "How to send an email with an attachment?")
38
+ 3. Select a search strategy
39
+ 4. Click "Extract Keyframes"
40
+
41
+ The model returns a gallery of keyframes with timestamps and frame indices.
42
+
43
+ ### Model
44
+
45
+ Uses [Qwen/Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) as the VLM for segment evaluation, running on ZeroGPU.
app.py ADDED
@@ -0,0 +1,574 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ # Expandable segments to avoid allocator fragmentation under memory spikes
4
+ os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
5
+
6
+ import spaces # MUST be before any torch/CUDA import
7
+
8
+ import cv2
9
+ import re
10
+ import json
11
+ import torch
12
+ import numpy as np
13
+ from PIL import Image
14
+ from typing import List, Optional, Tuple
15
+ import tempfile
16
+ import gradio as gr
17
+ from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
18
+
19
+ MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct"
20
+
21
+ # ── Load model at module scope (ZeroGPU rule 2) ──────────────────────────────
22
+ processor = AutoProcessor.from_pretrained(MODEL_ID)
23
+ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
24
+ MODEL_ID,
25
+ torch_dtype=torch.bfloat16,
26
+ attn_implementation="sdpa",
27
+ ).to("cuda")
28
+
29
+
30
+ # ── VLM call helper ──────────────────────────────────────────────────────────
31
+
32
+ def vlm_call(images: List[Image.Image], question: str, system_prompt: str = "You are a highly strict UI navigation assistant designed to output JSON.") -> str:
33
+ """Call the local VLM with images and a question, return text response."""
34
+ content = []
35
+ for img in images:
36
+ content.append({"type": "image", "image": img})
37
+ content.append({"type": "text", "text": question})
38
+
39
+ messages = [
40
+ {"role": "system", "content": [{"type": "text", "text": system_prompt}]},
41
+ {"role": "user", "content": content},
42
+ ]
43
+
44
+ text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
45
+ inputs = processor(
46
+ text=[text],
47
+ images=[images] if images else None,
48
+ padding=True,
49
+ return_tensors="pt",
50
+ ).to("cuda")
51
+
52
+ with torch.no_grad():
53
+ output_ids = model.generate(**inputs, max_new_tokens=8192, do_sample=False, temperature=1.0)
54
+
55
+ # Trim the input tokens from output
56
+ input_len = inputs["input_ids"].shape[1]
57
+ output_text = processor.batch_decode(
58
+ output_ids[:, input_len:], skip_special_tokens=True
59
+ )[0]
60
+ return output_text
61
+
62
+
63
+ def parse_json_response(text: str):
64
+ """Extract a JSON object from a text response."""
65
+ try:
66
+ match = re.search(r'\{.*\}', text, re.DOTALL)
67
+ if match:
68
+ return json.loads(match.group(0))
69
+ except Exception:
70
+ pass
71
+ return None
72
+
73
+
74
+ # ── Video utilities ──────────────────────────────────────────────────────────
75
+
76
+ def extract_frame(video_path: str, frame_idx: int) -> Optional[Image.Image]:
77
+ """Extract a single frame from the video as PIL Image."""
78
+ cap = cv2.VideoCapture(video_path)
79
+ cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
80
+ ret, frame = cap.read()
81
+ cap.release()
82
+ if not ret:
83
+ return None
84
+ return Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
85
+
86
+
87
+ def compute_color_histogram(img: Image.Image) -> np.ndarray:
88
+ """Compute a normalized 3-channel color histogram."""
89
+ arr = np.array(img)
90
+ hist = cv2.calcHist([arr], [0, 1, 2], None, [50, 50, 50], [0, 256, 0, 256, 0, 256])
91
+ cv2.normalize(hist, hist)
92
+ return hist
93
+
94
+
95
+ def frame_similarity(hist1: np.ndarray, hist2: np.ndarray) -> float:
96
+ """Compare two color histograms using correlation."""
97
+ return float(cv2.compareHist(hist1, hist2, cv2.HISTCMP_CORREL))
98
+
99
+
100
+ def is_frame_redundant(new_hist: np.ndarray, existing_hists: List[np.ndarray], threshold: float = 0.985) -> bool:
101
+ """Check if a new frame is too similar to existing ones."""
102
+ for h in existing_hists:
103
+ if frame_similarity(new_hist, h) >= threshold:
104
+ return True
105
+ return False
106
+
107
+
108
+ # ── TASKER core: A* tree search keyframe extraction ─────────────────────────
109
+
110
+ class VideoSeg:
111
+ """A video segment (tree node)."""
112
+ def __init__(self, start: int, end: int):
113
+ self.start = start
114
+ self.end = end
115
+
116
+
117
+ def find_visual_change_split_point(video_path: str, seg_start: int, seg_end: int) -> int:
118
+ """Find the frame with the largest visual change in a segment."""
119
+ midpoint = (seg_start + seg_end) // 2
120
+ try:
121
+ seg_length = seg_end - seg_start
122
+ if seg_length <= 2:
123
+ return midpoint
124
+
125
+ cap = cv2.VideoCapture(video_path)
126
+ num_samples = min(seg_length, 10)
127
+ step = max(1, seg_length // num_samples)
128
+ sample_indices = list(range(seg_start, seg_end, step))
129
+ if sample_indices[-1] != seg_end:
130
+ sample_indices.append(seg_end)
131
+
132
+ frames = {}
133
+ hists = {}
134
+ for idx in sample_indices:
135
+ cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
136
+ ret, frame = cap.read()
137
+ if ret:
138
+ frames[idx] = frame
139
+ hist = cv2.calcHist([frame], [0, 1, 2], None, [50, 50, 50], [0, 256, 0, 256, 0, 256])
140
+ cv2.normalize(hist, hist)
141
+ hists[idx] = hist
142
+
143
+ if len(frames) < 2:
144
+ cap.release()
145
+ return midpoint
146
+
147
+ sorted_indices = sorted(frames.keys())
148
+ max_diff = -1
149
+ best_a, best_b = sorted_indices[0], sorted_indices[-1]
150
+ for i in range(len(sorted_indices) - 1):
151
+ idx_a, idx_b = sorted_indices[i], sorted_indices[i + 1]
152
+ if idx_a in hists and idx_b in hists:
153
+ diff = 1.0 - cv2.compareHist(hists[idx_a], hists[idx_b], cv2.HISTCMP_CORREL)
154
+ if diff > max_diff:
155
+ max_diff = diff
156
+ best_a, best_b = idx_a, idx_b
157
+
158
+ candidate = best_b
159
+ cap.release()
160
+
161
+ # Clamp to valid range
162
+ min_pos = seg_start + int(seg_length * 0.15)
163
+ max_pos = seg_start + int(seg_length * 0.85)
164
+ if candidate < min_pos or candidate > max_pos:
165
+ return midpoint
166
+ return candidate
167
+ except Exception:
168
+ return midpoint
169
+
170
+
171
+ def a_star_select_segment(images: List[Image.Image], goal: str, segment_des: str) -> str:
172
+ """A* strategy: balance goal-relevance and UI state changes."""
173
+ prompt = f"""You are provided with sequential images sampled from a video.
174
+ Each image is labeled with its frame index. The images are shown in chronological order.
175
+ Goal: {goal}
176
+
177
+ Candidate segments (gaps between current frames):
178
+ {segment_des}
179
+
180
+ (A* Strategy - Balance missing goal-relevant info and visual state changes)
181
+ Identify ONE single candidate segment that BEST satisfies BOTH conditions simultaneously:
182
+ 1. GOAL PROXIMITY: The segment likely contains crucial missing actions that are necessary steps toward achieving the Goal.
183
+ 2. STATE CHANGE MAGNITUDE: The segment whose boundary frames show the MOST different visual states is more likely to contain important operations.
184
+
185
+ Return JSON format: {{"frame_descriptions": [{{"segment_id": "1", "description": "Best A* candidate: missing goal step + visual state change"}}]}}
186
+ """
187
+ return vlm_call(images, prompt)
188
+
189
+
190
+ def qa_and_reflect(images: List[Image.Image], goal: str) -> Tuple[str, int]:
191
+ """Evaluate whether current frames are sufficient."""
192
+ prompt_qa = f"Task Goal: {goal}\nLook at these sequential frames. Describe the EXACT step-by-step actions that happen transitioning from one frame to the next."
193
+ answer = vlm_call(images, prompt_qa, system_prompt="You are a helpful video analysis assistant.")
194
+
195
+ prompt_eval = f"""Task Goal: {goal}
196
+ Your sequential analysis: {answer}
197
+
198
+ Evaluate your confidence level strictly:
199
+ 1: Severe Jumps (There are completely missing screens or sudden state changes. MUST expand.)
200
+ 2: Minor Disconnects (The flow makes sense, but some intermediate actions are missing. Should expand.)
201
+ 3: Strong Continuity (The frames capture all important actions and transitions. No key step is skipped.)
202
+
203
+ Output JSON exactly like this: {{"confidence": 3}}
204
+ """
205
+ conf_str = vlm_call(images, prompt_eval)
206
+ conf_json = parse_json_response(conf_str)
207
+ confidence = conf_json.get("confidence", 1) if conf_json else 1
208
+ return answer, int(confidence)
209
+
210
+
211
+ @spaces.GPU(duration=240)
212
+ def extract_keyframes(video_path: str, goal: str, search_strategy: str = "a_star", max_frames: int = 10, min_frames: int = 6, min_steps: int = 3, conf_lower: int = 3, progress=gr.Progress()):
213
+ """
214
+ TASKER keyframe extraction: tree-search with VLM-guided segment selection.
215
+
216
+ Args:
217
+ video_path: Path to the input video.
218
+ goal: Task query describing what the user wants to see.
219
+ search_strategy: One of "a_star", "bfs", "gbfs", "dijkstra".
220
+ max_frames: Maximum number of keyframes to extract.
221
+ min_frames: Minimum number of frames before confidence check can stop.
222
+ min_steps: Minimum expansion steps before confidence check can stop.
223
+ conf_lower: Confidence threshold (1-3) to stop searching.
224
+ Returns:
225
+ List of (PIL Image, caption) tuples for gallery display, plus a summary string.
226
+ """
227
+ cap = cv2.VideoCapture(video_path)
228
+ num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
229
+ fps = cap.get(cv2.CAP_PROP_FPS)
230
+ cap.release()
231
+
232
+ if num_frames <= 0 or fps <= 0:
233
+ return [], "Error: Could not read video file. Please upload a valid video."
234
+
235
+ # ── Initial uniform sampling ─────────────────────────────────────────────
236
+ init_frames = 4
237
+ content_start = 0
238
+ content_end = num_frames - 1
239
+
240
+ if content_end - content_start + 1 <= init_frames:
241
+ sample_idx = list(range(content_start, content_end + 1))
242
+ else:
243
+ interval = max(1, (content_end - content_start + 1) // (init_frames - 1))
244
+ sample_idx = list(range(content_start, content_end + 1, interval))
245
+ if sample_idx[-1] != content_end:
246
+ sample_idx.append(content_end)
247
+
248
+ progress(0.1, desc=f"Initial sampling: {len(sample_idx)} frames from {num_frames} total")
249
+
250
+ video_segments = [VideoSeg(sample_idx[i-1], sample_idx[i]) for i in range(1, len(sample_idx))]
251
+
252
+ # Histogram cache for dedup
253
+ hist_cache = {}
254
+ frozen_segments = set()
255
+ effective_step = 0
256
+ last_confidence = 0
257
+
258
+ max_total_attempts = max_frames + 10
259
+
260
+ for attempt in range(1, max_total_attempts + 1):
261
+ current_frames = len(sample_idx)
262
+ if current_frames >= max_frames:
263
+ break
264
+
265
+ # Extract current frames as images
266
+ images = []
267
+ for idx in sample_idx:
268
+ img = extract_frame(video_path, idx)
269
+ if img is not None:
270
+ images.append(img)
271
+
272
+ if not images:
273
+ break
274
+
275
+ progress(
276
+ 0.1 + 0.6 * (attempt / max_total_attempts),
277
+ desc=f"Step {attempt}: {current_frames} frames, evaluating..."
278
+ )
279
+
280
+ # Confidence check
281
+ if current_frames >= min_frames and effective_step > min_steps:
282
+ _, confidence = qa_and_reflect(images, goal)
283
+ last_confidence = confidence
284
+ if confidence >= conf_lower:
285
+ break
286
+ else:
287
+ if current_frames < min_frames:
288
+ pass # forced expansion
289
+
290
+ # Build segment descriptions
291
+ frame_to_img_idx = {frame: i + 1 for i, frame in enumerate(sample_idx)}
292
+ segment_des_lines = []
293
+ for i, seg in enumerate(video_segments):
294
+ seg_id = i + 1
295
+ if (seg.start, seg.end) in frozen_segments:
296
+ continue
297
+ start_img = frame_to_img_idx.get(seg.start, "?")
298
+ end_img = frame_to_img_idx.get(seg.end, "?")
299
+ segment_des_lines.append(
300
+ f" Segment {seg_id}: frames {seg.start}-{seg.end} (Image #{start_img} -> Image #{end_img})"
301
+ )
302
+ segment_des_str = "\n".join(segment_des_lines)
303
+
304
+ if not segment_des_str:
305
+ break
306
+
307
+ # VLM segment selection
308
+ try:
309
+ if search_strategy == "bfs":
310
+ response = vlm_call(images, f"""You are provided with sequential images sampled from a video.
311
+ Goal: {goal}
312
+ Candidate segments:
313
+ {segment_des_str}
314
+ Select MULTIPLE segments that likely contain crucial missing actions.
315
+ Return JSON: {{"frame_descriptions": [{{"segment_id": "1", "description": "..."}}]}}""")
316
+ elif search_strategy == "gbfs":
317
+ response = vlm_call(images, f"""You are provided with sequential images sampled from a video.
318
+ Goal: {goal}
319
+ Candidate segments:
320
+ {segment_des_str}
321
+ Select the SINGLE segment MOST LIKELY to contain crucial missing actions.
322
+ Return JSON: {{"frame_descriptions": [{{"segment_id": "1", "description": "..."}}]}}""")
323
+ elif search_strategy == "dijkstra":
324
+ response = vlm_call(images, f"""You are provided with sequential images sampled from a video.
325
+ Candidate segments:
326
+ {segment_des_str}
327
+ Select the SINGLE segment with the MOST significant visual state transition.
328
+ Return JSON: {{"frame_descriptions": [{{"segment_id": "1", "description": "..."}}]}}""")
329
+ else: # a_star
330
+ response = a_star_select_segment(images, goal, segment_des_str)
331
+
332
+ parsed = parse_json_response(response)
333
+ except Exception as e:
334
+ print(f"VLM call error at step {attempt}: {e}")
335
+ parsed = None
336
+
337
+ # Determine selected segment IDs
338
+ selected_seg_ids = set()
339
+ if parsed and "frame_descriptions" in parsed:
340
+ for desc in parsed["frame_descriptions"]:
341
+ for key in desc:
342
+ if key.lower() == "segment_id":
343
+ val = str(desc[key]).strip()
344
+ nums = re.findall(r'\d+', val)
345
+ if nums:
346
+ seg_id = int(nums[0])
347
+ if 1 <= seg_id <= len(video_segments):
348
+ selected_seg_ids.add(seg_id)
349
+ break
350
+
351
+ # Fallback: pick longest segment
352
+ if not selected_seg_ids:
353
+ longest_seg_id = None
354
+ longest_len = 0
355
+ for i, seg in enumerate(video_segments):
356
+ seg_len = seg.end - seg.start
357
+ if seg_len > longest_len and seg_len > 1 and (seg.start, seg.end) not in frozen_segments:
358
+ longest_len = seg_len
359
+ longest_seg_id = i + 1
360
+ if longest_seg_id is not None:
361
+ selected_seg_ids.add(longest_seg_id)
362
+
363
+ if not selected_seg_ids:
364
+ break
365
+
366
+ # BFS quota limit
367
+ if search_strategy == "bfs" and len(selected_seg_ids) > 1:
368
+ remaining_quota = max_frames - len(sample_idx)
369
+ if remaining_quota <= 0:
370
+ break
371
+ if len(selected_seg_ids) > remaining_quota:
372
+ sorted_seg_ids = sorted(selected_seg_ids,
373
+ key=lambda sid: video_segments[sid-1].end - video_segments[sid-1].start,
374
+ reverse=True)
375
+ selected_seg_ids = set(sorted_seg_ids[:remaining_quota])
376
+
377
+ # Split selected segments
378
+ split_origin = {}
379
+ new_segments = []
380
+ seg_counter = 0
381
+ for i, seg in enumerate(video_segments):
382
+ seg_id = i + 1
383
+ if seg_id in selected_seg_ids:
384
+ if seg.end - seg.start <= 1:
385
+ seg_counter += 1
386
+ new_segments.append(VideoSeg(seg.start, seg.end))
387
+ else:
388
+ sp = find_visual_change_split_point(video_path, seg.start, seg.end)
389
+ split_origin[sp] = (seg.start, seg.end)
390
+ seg_counter += 1
391
+ new_segments.append(VideoSeg(seg.start, sp))
392
+ seg_counter += 1
393
+ new_segments.append(VideoSeg(sp, seg.end))
394
+ else:
395
+ seg_counter += 1
396
+ new_segments.append(VideoSeg(seg.start, seg.end))
397
+ video_segments = new_segments
398
+
399
+ # Rebuild sample_idx
400
+ sample_idx_set = set()
401
+ for seg in video_segments:
402
+ sample_idx_set.add(seg.start)
403
+ sample_idx_set.add(seg.end)
404
+ new_sample_idx = sorted(list(sample_idx_set))
405
+
406
+ # Visual deduplication
407
+ new_frames = [idx for idx in new_sample_idx if idx not in set(sample_idx)]
408
+ old_sample_set = set(sample_idx)
409
+
410
+ # Compute histograms for old frames
411
+ old_hists = []
412
+ for idx in sample_idx:
413
+ img = extract_frame(video_path, idx)
414
+ if img is not None:
415
+ old_hists.append(compute_color_histogram(img))
416
+
417
+ frames_to_remove = []
418
+ accepted_new_hists = []
419
+ for new_idx in new_frames:
420
+ new_img = extract_frame(video_path, new_idx)
421
+ if new_img is None:
422
+ continue
423
+ new_hist = compute_color_histogram(new_img)
424
+ all_compare_hists = old_hists + accepted_new_hists
425
+
426
+ if is_frame_redundant(new_hist, all_compare_hists, threshold=0.985):
427
+ frames_to_remove.append(new_idx)
428
+ if new_idx in split_origin:
429
+ frozen_segments.add(split_origin[new_idx])
430
+ else:
431
+ accepted_new_hists.append(new_hist)
432
+
433
+ if frames_to_remove:
434
+ new_sample_idx = [idx for idx in new_sample_idx if idx not in frames_to_remove]
435
+ new_sample_idx = sorted(new_sample_idx)
436
+ video_segments = [VideoSeg(new_sample_idx[i-1], new_sample_idx[i])
437
+ for i in range(1, len(new_sample_idx))]
438
+
439
+ actually_added = len(new_sample_idx) > len(sample_idx)
440
+ sample_idx = new_sample_idx
441
+
442
+ if actually_added:
443
+ effective_step += 1
444
+
445
+ progress(0.85, desc="Finalizing keyframes...")
446
+
447
+ # Force-fill if too few frames
448
+ if len(sample_idx) < min_frames and last_confidence < conf_lower:
449
+ max_force = min_frames + 5
450
+ for _ in range(max_force):
451
+ if len(sample_idx) >= min_frames:
452
+ break
453
+ max_gap = 0
454
+ max_gap_idx = 0
455
+ for i in range(len(sample_idx) - 1):
456
+ if (sample_idx[i], sample_idx[i+1]) in frozen_segments:
457
+ continue
458
+ gap = sample_idx[i+1] - sample_idx[i]
459
+ if gap > max_gap:
460
+ max_gap = gap
461
+ max_gap_idx = i
462
+ if max_gap <= 1:
463
+ break
464
+ sp = find_visual_change_split_point(video_path, sample_idx[max_gap_idx], sample_idx[max_gap_idx + 1])
465
+ sp_img = extract_frame(video_path, sp)
466
+ if sp_img is None:
467
+ break
468
+ sp_hist = compute_color_histogram(sp_img)
469
+ existing_hists = []
470
+ for idx in sample_idx:
471
+ img = extract_frame(video_path, idx)
472
+ if img is not None:
473
+ existing_hists.append(compute_color_histogram(img))
474
+ if is_frame_redundant(sp_hist, existing_hists, threshold=0.985):
475
+ frozen_segments.add((sample_idx[max_gap_idx], sample_idx[max_gap_idx + 1]))
476
+ continue
477
+ sample_idx.insert(max_gap_idx + 1, sp)
478
+
479
+ # Extract final keyframes
480
+ progress(0.95, desc="Extracting final keyframes...")
481
+
482
+ gallery = []
483
+ for i, idx in enumerate(sample_idx):
484
+ img = extract_frame(video_path, idx)
485
+ if img is not None:
486
+ timestamp = idx / fps if fps > 0 else 0
487
+ mins = int(timestamp // 60)
488
+ secs = int(timestamp % 60)
489
+ percent = (idx / max(1, num_frames - 1)) * 100
490
+ caption = f"Frame {i+1}/{len(sample_idx)} | idx={idx} | {mins:02d}:{secs:02d} | {percent:.1f}%"
491
+ gallery.append((img, caption))
492
+
493
+ summary = (
494
+ f"**TASKER {search_strategy.upper()}** extracted **{len(gallery)}** keyframes "
495
+ f"from {num_frames} total frames ({num_frames/fps:.1f}s video).\n\n"
496
+ f"Search stats: {effective_step} effective expansion steps, "
497
+ f"confidence={last_confidence}/3, "
498
+ f"target range {min_frames}-{max_frames} frames."
499
+ )
500
+
501
+ progress(1.0, desc="Done!")
502
+ return gallery, summary
503
+
504
+
505
+ # ── Gradio UI ───────────────────────────────────────────────────────────────
506
+
507
+ CUSTOM_CSS = """
508
+ #header { text-align: center; margin-bottom: 20px; }
509
+ #header h1 { font-size: 2em; margin-bottom: 5px; }
510
+ #header p { color: #666; font-size: 1.1em; }
511
+ """
512
+
513
+ with gr.Blocks(css=CUSTOM_CSS, title="TASKER Keyframe Extractor") as demo:
514
+ gr.HTML("""
515
+ <div id="header">
516
+ <h1>TASKER: Task-driven and Scene-aware Keyframe Search</h1>
517
+ <p>Extract task-relevant keyframes from a video using VLM-guided tree search (A* / BFS / GBFS / Dijkstra)</p>
518
+ </div>
519
+ """)
520
+
521
+ with gr.Row():
522
+ with gr.Column(scale=1):
523
+ video_input = gr.Video(label="Upload Video", sources=["upload"])
524
+ goal_input = gr.Textbox(
525
+ label="Task Query / Goal",
526
+ placeholder="e.g., How to send an email with an attachment?",
527
+ lines=2,
528
+ )
529
+ strategy_input = gr.Dropdown(
530
+ choices=["a_star", "bfs", "gbfs", "dijkstra"],
531
+ value="a_star",
532
+ label="Search Strategy",
533
+ info="A* balances goal-relevance and visual changes. BFS explores broadly. GBFS focuses on goal. Dijkstra focuses on visual changes.",
534
+ )
535
+ with gr.Accordion("Advanced Settings", open=False):
536
+ max_frames_slider = gr.Slider(4, 16, value=10, step=1, label="Max Keyframes")
537
+ min_frames_slider = gr.Slider(2, 8, value=6, step=1, label="Min Keyframes (before confidence check)")
538
+ min_steps_slider = gr.Slider(1, 8, value=3, step=1, label="Min Search Steps")
539
+ conf_slider = gr.Slider(1, 3, value=3, step=1, label="Confidence Threshold (3=strictest)")
540
+
541
+ extract_btn = gr.Button("Extract Keyframes", variant="primary")
542
+
543
+ with gr.Column(scale=2):
544
+ summary_output = gr.Markdown(label="Summary")
545
+ gallery_output = gr.Gallery(
546
+ label="Extracted Keyframes",
547
+ columns=3,
548
+ height=600,
549
+ object_fit="contain",
550
+ )
551
+
552
+ gr.Examples(
553
+ examples=[
554
+ ["https://huggingface.co/datasets/hugging-apps/tasker-demo-videos/resolve/main/cooking_demo.mp4",
555
+ "Show the steps to cook pasta"],
556
+ ],
557
+ inputs=[video_input, goal_input],
558
+ )
559
+
560
+ extract_btn.click(
561
+ fn=extract_keyframes,
562
+ inputs=[
563
+ video_input,
564
+ goal_input,
565
+ strategy_input,
566
+ max_frames_slider,
567
+ min_frames_slider,
568
+ min_steps_slider,
569
+ conf_slider,
570
+ ],
571
+ outputs=[gallery_output, summary_output],
572
+ )
573
+
574
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torchvision
2
+ transformers
3
+ opencv-python-headless
4
+ pillow
5
+ numpy