| import os |
| from typing import Dict, Tuple, List, Set |
|
|
| import cv2 |
| import gradio as gr |
| import huggingface_hub |
| import numpy as np |
| import onnxruntime as rt |
| import pandas as pd |
| import time |
| from PIL import Image |
|
|
# Title used for the browser tab (gr.Blocks(title=...)) and the page heading.
TITLE = "AI Video Auto-Tagger & Captioner"
# Markdown rendered under the heading, above the Tagging / Tag Control tabs.
DESCRIPTION = """
Upload a .mp4 or .mov video, choose how often to sample frames, and generate
combined (deduplicated) tags using a selected **tagging/captioning model**.

- Extract every N-th frame (e.g., every 10th frame).
- Control thresholds for **General Tags** and **Character Tags**.
- All tags from all sampled frames are merged into **one unique, comma-separated string**.
- Use the **Tag Control** tab to define tag substitutions and exclusions for the final output.

**This space is running on the free CPU tier so it can be slow. If you want better speeds, clone the space and host it on more capable hardware.**
"""
|
|
# Repo id preselected in the "Tagging Model" dropdown.
DEFAULT_MODEL_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"


# Hugging Face repo ids offered in the "Tagging Model" dropdown. Every entry is
# handled identically by VideoTagger: MODEL_FILENAME and LABEL_FILENAME are
# downloaded from the repo and the model is run with onnxruntime.
# NOTE(review): the Florence-2 captioner repos (and possibly
# deepghs/deepgelbooru_onnx) may not ship a WD-style model.onnx +
# selected_tags.csv pair -- confirm, otherwise selecting them fails at
# download/inference time.
MODEL_OPTIONS = [
    "SmilingWolf/wd-eva02-large-tagger-v3",
    "SmilingWolf/wd-vit-large-tagger-v3",
    "SmilingWolf/wd-vit-tagger-v3",
    "SmilingWolf/wd-convnext-tagger-v3",
    "SmilingWolf/wd-swinv2-tagger-v3",
    "deepghs/idolsankaku-eva02-large-tagger-v1",
    "deepghs/idolsankaku-swinv2-tagger-v1",
    "gokaygokay/Florence-2-SD3-Captioner",
    "gokaygokay/Florence-2-Flux",
    "gokaygokay/Florence-2-Flux-Large",
    "MiaoshouAI/Florence-2-large-PromptGen-v2.0",
    "thwri/CogFlorence-2.2-Large",
    "deepghs/deepgelbooru_onnx",
]


# Files fetched from each model repo: the ONNX weights and the tag table
# (read with pandas; load_labels uses its "name" and "category" columns).
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"


# Optional Hugging Face token passed as token= to hf_hub_download; None if unset.
HF_TOKEN = os.environ.get("HF_TOKEN")
|
|
| |
# Tag names that load_labels keeps verbatim. Every other tag has "_" replaced
# by " ", which would mangle these emoticons.
kaomojis = [
    "0_0",
    "(o)_(o)",
    "+_+",
    "+_-",
    "._.",
    "<o>_<o>",
    "<|>_<|>",
    "=_=",
    ">_<",
    "3_3",
    "6_9",
    ">_o",
    "@_@",
    "^_^",
    "o_o",
    "u_u",
    "x_x",
    "|_|",
    "||_||",
]
|
|
# Extra CSS passed to launch(): bold tab buttons that dim slightly on hover.
# NOTE(review): the selectors presumably match the buttons Gradio generates for
# the Tabs with elem_id "tagging-tab" / "tag-control-tab" (elem_id + "-button")
# -- confirm against the installed Gradio version.
css = """
#tagging-tab-button,
#tag-control-tab-button {
    font-weight: 900 !important;
}
#tagging-tab-button:hover,
#tag-control-tab-button:hover {
    filter: brightness(0.9);
}
"""
|
|
| def _format_duration(seconds: float) -> str: |
| """ |
| Format a duration in seconds as MM:SS or HH:MM:SS. |
| """ |
| total_seconds = int(round(seconds)) |
| hours, rem = divmod(total_seconds, 3600) |
| minutes, secs = divmod(rem, 60) |
|
|
| if hours > 0: |
| return f"{hours:02d}:{minutes:02d}:{secs:02d}" |
| else: |
| return f"{minutes:02d}:{secs:02d}" |
|
|
|
|
def load_labels(df: pd.DataFrame):
    """
    Split a tag table into display names and per-category row positions.

    Args:
        df: table with a "name" column (raw tag names) and a numeric
            "category" column. Row positions are used directly as indexes into
            the model's output vector by VideoTagger.

    Returns:
        (tag_names, rating_indexes, general_indexes, character_indexes):
        tag_names is a list of display names ("_" replaced by " ", except for
        entries in ``kaomojis``); each *_indexes list holds the row positions
        whose category is 9 (rating), 0 (general) or 4 (character).
    """

    def display_name(raw_name):
        # Emoticons like "^_^" keep their underscores.
        if raw_name in kaomojis:
            return raw_name
        return raw_name.replace("_", " ")

    tag_names = [display_name(raw) for raw in df["name"]]

    category_codes = df["category"].to_numpy()

    def positions_of(code):
        return list(np.nonzero(category_codes == code)[0])

    rating_indexes = positions_of(9)
    general_indexes = positions_of(0)
    character_indexes = positions_of(4)
    return tag_names, rating_indexes, general_indexes, character_indexes
|
|
|
|
def add_substitute_row(current):
    """
    Return a copy of the substitutes table with one blank [original, substitute] row appended.

    ``current`` is the table as a list of rows (Gradio Dataframe with
    ``type="array"``) or None. The input is not modified.
    """
    existing_rows = [] if current is None else current
    return [*existing_rows, ["", ""]]
|
|
|
|
def add_exclusion_row(current):
    """
    Return a copy of the exclusions table with one blank [tag] row appended.

    ``current`` is the table as a list of single-cell rows (Gradio Dataframe
    with ``type="array"``) or None. The input is not modified.
    """
    existing_rows = [] if current is None else current
    return [*existing_rows, [""]]
|
|
def compute_recommended_batch_size(sampled_frames: int) -> int:
    """
    Suggest a batch size from the number of frames that will be processed.

    Empirical tiers: up to 20 frames -> 8, up to 40 -> 16, up to 80 -> 24,
    anything larger -> 32 (the cap). Non-positive counts get 8.
    """
    if sampled_frames <= 0:
        return 8

    # (inclusive upper bound on sampled frames, recommended batch size)
    tiers = ((20, 8), (40, 16), (80, 24))
    for upper_bound, size in tiers:
        if sampled_frames <= upper_bound:
            return size
    return 32
|
|
def update_batch_recommendation(video_path: str, frame_interval: int) -> str:
    """
    Build the HTML snippet that shows a recommended batch size for a video.

    The number of sampled frames is estimated as
    ceil(total_frames / frame_interval), using the frame count OpenCV reports
    for the file, and passed to compute_recommended_batch_size().

    Args:
        video_path: path of the uploaded video (None/empty before upload).
        frame_interval: sample every N-th frame; unparsable values fall back to
            1 and values below 1 are clamped to 1.

    Returns:
        A ``<span>`` HTML string for the ``gr.HTML`` component. Problems are
        reported in the returned text rather than raised.
    """
    import html  # only needed to escape error text rendered as HTML

    if not video_path or not os.path.exists(video_path):
        return "<span>Upload a video to see a recommended batch size.</span>"

    try:
        frame_interval = max(int(frame_interval), 1)
    except (TypeError, ValueError, OverflowError):
        frame_interval = 1

    cap = None
    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            return "<span>Could not read video to estimate batch size.</span>"

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames <= 0:
            return "<span>Could not determine video length to recommend batch size.</span>"

        # Ceiling division: frames 0, N, 2N, ... are sampled.
        sampled_frames = max(1, (total_frames + frame_interval - 1) // frame_interval)
        rec = compute_recommended_batch_size(sampled_frames)

        return (
            f"<span>Recommended batch size: <b>{rec}</b> "
            f"(based on ~{sampled_frames} sampled frames).</span>"
        )
    except Exception as e:
        # Best-effort UI hint: report instead of failing the event. The message
        # is rendered as HTML, so escape it.
        return f"<span>Could not compute recommendation: {html.escape(str(e))}</span>"
    finally:
        # Release on every path, including the "not opened" and error returns.
        if cap is not None:
            cap.release()
|
|
def show_batch_loading() -> str:
    """
    Return placeholder HTML shown while the batch-size recommendation is computed.

    The ``batch-loading`` class is styled (pulsing text colour) by the
    <style> block injected at the top of the Blocks layout.
    """
    message = "Calculating recommended batch size..."
    return f"<span class='batch-loading'>{message}</span>"
|
|
|
|
class VideoTagger:
    """
    Wrap a WD-style ONNX tagger and its tag metadata.

    The ONNX model (MODEL_FILENAME) and label table (LABEL_FILENAME) are
    downloaded from ``model_repo`` lazily on first use. Helpers tag a single
    PIL image (tag_image) or a whole video by sampling every N-th frame and
    running the sampled frames through the model in batches (tag_video).
    """

    def __init__(self, model_repo: str, batch_size: int = 16):
        """
        Args:
            model_repo: Hugging Face repo id to download the model/labels from.
            batch_size: number of sampled frames per inference call in tag_video.
        """
        self.model_repo = model_repo
        # onnxruntime InferenceSession; None until _load_model_if_needed() runs.
        self.model = None
        # Side length in pixels of the square model input.
        self.model_target_size = None
        # Display name for each position of the model's output vector.
        self.tag_names = None
        # Output positions belonging to each tag category (see load_labels).
        self.rating_indexes = None
        self.general_indexes = None
        self.character_indexes = None
        self.batch_size = batch_size

    def _download_model_files(self) -> Tuple[str, str]:
        """
        Download the label CSV and ONNX model via huggingface_hub.hf_hub_download.

        Returns:
            (csv_path, model_path) as local file paths.
        """
        csv_path = huggingface_hub.hf_hub_download(
            repo_id=self.model_repo,
            filename=LABEL_FILENAME,
            token=HF_TOKEN,
        )
        model_path = huggingface_hub.hf_hub_download(
            repo_id=self.model_repo,
            filename=MODEL_FILENAME,
            token=HF_TOKEN,
        )
        return csv_path, model_path

    def _load_model_if_needed(self):
        """
        Load labels and create the ONNX session on first call; no-op afterwards.

        Raises:
            ValueError: if the model's input height and width differ.
        """
        if self.model is not None:
            return

        csv_path, model_path = self._download_model_files()

        tags_df = pd.read_csv(csv_path)
        (
            self.tag_names,
            self.rating_indexes,
            self.general_indexes,
            self.character_indexes,
        ) = load_labels(tags_df)

        session = rt.InferenceSession(model_path)

        # Input is laid out as (batch, height, width, channels).
        _, height, width, _ = session.get_inputs()[0].shape
        if height != width:
            raise ValueError(
                f"Model expects square inputs (got height={height}, width={width})"
            )
        self.model_target_size = int(height)
        # Publish the session only after validation, so a failed check is
        # re-run on the next call instead of leaving a half-initialised tagger.
        self.model = session

    def _prepare_image(self, image: "Image.Image") -> np.ndarray:
        """
        Convert a PIL image into the model's input tensor.

        Steps: composite onto white (transparency -> white), pad to a centred
        square with white, resize to model_target_size (bicubic), convert
        RGB -> BGR.

        Returns:
            float32 array of shape (1, H, W, 3).
        """
        target_size = self.model_target_size

        # alpha_composite requires RGBA on both sides; convert other modes
        # (RGB, L, P, ...) so plain RGB inputs do not raise.
        if image.mode != "RGBA":
            image = image.convert("RGBA")
        canvas = Image.new("RGBA", image.size, (255, 255, 255, 255))
        canvas.alpha_composite(image)
        image_rgb = canvas.convert("RGB")

        w, h = image_rgb.size
        max_dim = max(w, h)
        pad_left = (max_dim - w) // 2
        pad_top = (max_dim - h) // 2

        padded = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
        padded.paste(image_rgb, (pad_left, pad_top))

        if max_dim != target_size:
            padded = padded.resize((target_size, target_size), Image.BICUBIC)

        arr = np.asarray(padded, dtype=np.float32)
        # Reverse the channel axis (RGB -> BGR); make it contiguous for onnxruntime.
        arr = np.ascontiguousarray(arr[:, :, ::-1])
        return np.expand_dims(arr, axis=0)

    def _prepare_frame_bgr(self, frame_bgr: np.ndarray) -> np.ndarray:
        """
        Preprocess an OpenCV frame (BGR uint8, shape (h, w, 3)).

        Pads to a centred square with white, resizes to model_target_size
        (INTER_AREA) and converts to float32.

        Returns:
            (H, W, 3) float32 array in BGR order, without a batch dimension.
        """
        target_size = self.model_target_size

        h, w, _ = frame_bgr.shape
        max_dim = max(h, w)

        pad_vert = max_dim - h
        pad_horiz = max_dim - w
        top = pad_vert // 2
        bottom = pad_vert - top
        left = pad_horiz // 2
        right = pad_horiz - left

        frame_square = cv2.copyMakeBorder(
            frame_bgr,
            top, bottom, left, right,
            borderType=cv2.BORDER_CONSTANT,
            value=(255, 255, 255),
        )

        if max_dim != target_size:
            frame_square = cv2.resize(
                frame_square,
                (target_size, target_size),
                interpolation=cv2.INTER_AREA,
            )

        return frame_square.astype(np.float32)

    def _run_batch_and_aggregate(
        self,
        batch_tensors: List[np.ndarray],
        general_thresh: float,
        character_thresh: float,
        aggregated_general: Dict[str, float],
        aggregated_character: Dict[str, float],
    ) -> int:
        """
        Run inference on preprocessed frames and merge their tags in place.

        For every tag above its threshold, aggregated_general /
        aggregated_character keep the maximum score seen so far.

        Args:
            batch_tensors: (H, W, 3) float32 frames from _prepare_frame_bgr.

        Returns:
            Number of frames processed (0 for an empty batch).
        """
        if not batch_tensors:
            return 0

        input_name = self.model.get_inputs()[0].name
        output_name = self.model.get_outputs()[0].name

        input_tensor = np.stack(batch_tensors, axis=0)
        preds_batch = self.model.run([output_name], {input_name: input_tensor})[0]

        for preds in preds_batch:
            general_res, character_res = self._extract_tags_from_scores(
                preds,
                general_thresh=general_thresh,
                character_thresh=character_thresh,
            )

            for tag, score in general_res.items():
                if tag not in aggregated_general or score > aggregated_general[tag]:
                    aggregated_general[tag] = score

            for tag, score in character_res.items():
                if tag not in aggregated_character or score > aggregated_character[tag]:
                    aggregated_character[tag] = score

        return len(batch_tensors)

    def tag_image(
        self,
        image: "Image.Image",
        general_thresh: float,
        character_thresh: float,
    ) -> Tuple[Dict[str, float], Dict[str, float]]:
        """
        Tag a single PIL image.

        Returns:
            (general_res, character_res): {tag: score} for tags whose score is
            strictly above the respective threshold.
        """
        self._load_model_if_needed()

        input_tensor = self._prepare_image(image)
        input_name = self.model.get_inputs()[0].name
        output_name = self.model.get_outputs()[0].name

        preds = self.model.run([output_name], {input_name: input_tensor})[0]
        # Batch of one: share the thresholding with the video path.
        return self._extract_tags_from_scores(
            preds[0],
            general_thresh=general_thresh,
            character_thresh=character_thresh,
        )

    def _select_tags(
        self,
        scores: np.ndarray,
        indexes: List[int],
        threshold: float,
    ) -> Dict[str, float]:
        """Return {tag_name: score} for positions in *indexes* whose score exceeds *threshold*."""
        idx = np.asarray(indexes, dtype=np.intp)
        selected = scores[idx]
        keep = selected > threshold
        return {
            self.tag_names[i]: float(score)
            for i, score in zip(idx[keep], selected[keep])
        }

    def _extract_tags_from_scores(
        self,
        preds: np.ndarray,
        general_thresh: float,
        character_thresh: float,
    ) -> Tuple[Dict[str, float], Dict[str, float]]:
        """
        Threshold one score vector of shape (num_tags,).

        Returns:
            (general_res, character_res): {tag: score} for general/character
            tags whose score is strictly above the respective threshold.
        """
        scores = preds.astype(float)
        general_res = self._select_tags(scores, self.general_indexes, general_thresh)
        character_res = self._select_tags(scores, self.character_indexes, character_thresh)
        return general_res, character_res

    @staticmethod
    def _progress_fraction(done: int, total: int) -> float:
        """Return done/total clamped to [0, 1]; 0.0 when total is not positive."""
        if total <= 0:
            return 0.0
        return min(1.0, max(0.0, done / total))

    @staticmethod
    def _apply_tag_control(
        tags_with_scores: Dict[str, float],
        tag_substitutes: Dict[str, str],
        tag_exclusions: Set[str],
    ) -> Tuple[List[str], Dict[str, str], Set[str]]:
        """
        Apply substitutions/exclusions and order tags by descending score.

        A tag is dropped if it (or its substitute) is excluded. When several
        tags map to the same substitute, the highest score is kept. Rules whose
        key or value is empty after stripping are ignored.

        Returns:
            (ordered tag names, normalised substitutes, normalised exclusions).
        """
        normalized_subs: Dict[str, str] = {}
        for original, replacement in tag_substitutes.items():
            key = (original or "").strip()
            value = (replacement or "").strip()
            # Skip whitespace-only rules: they would emit empty tags.
            if key and value:
                normalized_subs[key] = value

        normalized_exclusions = {t.strip() for t in tag_exclusions if t and t.strip()}

        best: Dict[str, float] = {}
        for tag, score in tags_with_scores.items():
            original_tag = tag.strip()
            if original_tag in normalized_exclusions:
                continue

            new_tag = normalized_subs.get(original_tag, original_tag)
            if new_tag in normalized_exclusions:
                continue

            if new_tag not in best or score > best[new_tag]:
                best[new_tag] = score

        # sorted() is stable, so equal scores keep their insertion order.
        ordered = [tag for tag, _ in sorted(best.items(), key=lambda kv: kv[1], reverse=True)]
        return ordered, normalized_subs, normalized_exclusions

    def tag_video(
        self,
        video_path: str,
        frame_interval: int,
        general_thresh: float,
        character_thresh: float,
        tag_substitutes: Dict[str, str],
        tag_exclusions: Set[str],
        progress=None,
    ) -> Tuple[str, Dict]:
        """
        Tag a video by sampling every N-th frame and aggregating tags.

        Sampled frames are queued and run through the model in batches of
        ``self.batch_size`` (plus one final partial batch). The maximum score
        per tag across all frames is kept. Tag control is then applied, and the
        remaining tags are joined with ", " in descending score order.

        Args:
            video_path: path to a video file readable by OpenCV.
            frame_interval: frames whose 0-based index is a multiple of this
                value are sampled (values < 1 are treated as 1).
            general_thresh / character_thresh: exclusive score thresholds.
            tag_substitutes: {original tag: replacement tag}.
            tag_exclusions: tags to drop (checked before and after substitution).
            progress: optional callable ``progress(fraction, desc=...)``.

        Returns:
            (comma-separated tag string, debug info dict).

        Raises:
            FileNotFoundError: if video_path is empty or does not exist.
            RuntimeError: if OpenCV cannot open the video.
        """
        if not video_path or not os.path.exists(video_path):
            raise FileNotFoundError("Video file not found.")

        frame_interval = max(int(frame_interval), 1)
        # batch_size may come from UI input. Normalise it so the batch math
        # never divides by zero or prints float batch counts.
        batch_size = max(int(self.batch_size), 1)
        is_first_load = self.model is None

        def report(fraction: float, desc: str) -> None:
            if progress is not None:
                progress(fraction, desc=desc)

        report(0.0, "Loading model..." if is_first_load else "Opening video...")
        self._load_model_if_needed()
        if is_first_load:
            report(0.0, "Model loaded. Opening video...")

        aggregated_general: Dict[str, float] = {}
        aggregated_character: Dict[str, float] = {}
        batch_tensors: List[np.ndarray] = []
        frame_idx = 0
        processed_frames = 0
        batches_done = 0

        cap = cv2.VideoCapture(video_path)
        try:
            if not cap.isOpened():
                raise RuntimeError("Unable to open video file.")

            # The reported frame count only drives progress text and estimates;
            # frames are read until cap.read() fails. A missing/zero count is
            # treated as 1.
            total_frames = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 1)
            sampled_frames = max(1, (total_frames + frame_interval - 1) // frame_interval)
            total_batches = max(1, (sampled_frames + batch_size - 1) // batch_size)
            recommended_batch = compute_recommended_batch_size(sampled_frames)

            def run_pending_batch(is_final: bool) -> None:
                """Run the queued frames through the model and report progress."""
                nonlocal processed_frames, batches_done
                batch_no = batches_done + 1
                # The frame-count estimate can be low; never display "5/4".
                shown_total = max(total_batches, batch_no)
                label = "final batch" if is_final else "batch"
                report(
                    self._progress_fraction(processed_frames, sampled_frames),
                    f"Processing {label} {batch_no}/{shown_total} "
                    f"(frames {processed_frames + 1}-"
                    f"{processed_frames + len(batch_tensors)}/{sampled_frames})",
                )
                processed_frames += self._run_batch_and_aggregate(
                    batch_tensors,
                    general_thresh,
                    character_thresh,
                    aggregated_general,
                    aggregated_character,
                )
                batch_tensors.clear()
                batches_done = batch_no
                report(
                    self._progress_fraction(processed_frames, sampled_frames),
                    f"Completed batch {batch_no}/{shown_total} "
                    f"({processed_frames}/{sampled_frames} frames processed)",
                )

            while True:
                ok, frame = cap.read()
                if not ok:
                    break

                if frame_idx % frame_interval == 0:
                    batch_tensors.append(self._prepare_frame_bgr(frame))

                    batch_no = batches_done + 1
                    # Target size of the batch being filled. It is never below
                    # the number already queued, even if the estimate was low.
                    expected = max(
                        min(batch_size, sampled_frames - processed_frames),
                        len(batch_tensors),
                    )
                    report(
                        self._progress_fraction(processed_frames, sampled_frames),
                        f"Preparing batch {batch_no}/{max(total_batches, batch_no)} "
                        f"({len(batch_tensors)}/{expected} frames)",
                    )

                    if len(batch_tensors) >= batch_size:
                        run_pending_batch(is_final=False)

                frame_idx += 1
        finally:
            cap.release()

        if batch_tensors:
            run_pending_batch(is_final=True)

        report(1.0, "Finalizing tags...")

        unique_tags, normalized_subs, normalized_exclusions = self._apply_tag_control(
            {**aggregated_general, **aggregated_character},
            tag_substitutes,
            tag_exclusions,
        )
        combined_tags_str = ", ".join(unique_tags)

        debug_info = {
            "model_repo": self.model_repo,
            "frames_read": int(frame_idx),
            "frames_processed": int(processed_frames),
            "sampled_frames": int(sampled_frames),
            "total_batches": int(total_batches),
            "batch_size": int(batch_size),
            "recommended_batch_size": int(recommended_batch),
            "frame_interval": int(frame_interval),
            "general_threshold": float(general_thresh),
            "character_threshold": float(character_thresh),
            "num_general_tags_raw": len(aggregated_general),
            "num_character_tags_raw": len(aggregated_character),
            "total_unique_tags_after_control": len(unique_tags),
            "num_substitution_rules": len(normalized_subs),
            "num_exclusions": len(normalized_exclusions),
        }

        return combined_tags_str, debug_info
|
|
|
|
| |
# Process-wide cache: one VideoTagger per model repo id (filled by get_tagger),
# so each model is downloaded and its ONNX session created only once.
# NOTE(review): instances are shared across all requests and get_tagger
# overwrites batch_size on them; concurrent runs with different batch sizes may
# interfere -- confirm whether the queue serialises handlers.
_tagger_cache: Dict[str, VideoTagger] = {}
|
|
|
|
def get_tagger(model_repo: str, batch_size: int | None = None) -> VideoTagger:
    """
    Return the cached VideoTagger for ``model_repo``, creating it on first use.

    Args:
        model_repo: Hugging Face repo id; one tagger is kept per id in
            ``_tagger_cache``.
        batch_size: optional batch size, converted to ``int``. Values below 1
            are ignored. Without a usable value, a new tagger uses 8 and an
            existing tagger keeps its current batch size.

    Returns:
        The shared VideoTagger instance for ``model_repo``.
    """
    requested = None
    if batch_size is not None:
        requested = int(batch_size)
        if requested < 1:
            # 0 or negative would break the batch arithmetic in tag_video.
            requested = None

    tagger = _tagger_cache.get(model_repo)
    if tagger is None:
        tagger = VideoTagger(model_repo=model_repo, batch_size=requested or 8)
        _tagger_cache[model_repo] = tagger
    elif requested is not None:
        tagger.batch_size = requested

    return tagger
|
|
|
|
| def _normalize_tag_substitutes(data) -> Dict[str, str]: |
| """ |
| Convert Dataframe (as array: list[list]) into {original: substitute}. |
| """ |
| mapping: Dict[str, str] = {} |
| if data is None: |
| return mapping |
|
|
| |
| for row in data: |
| if not row or len(row) < 2: |
| continue |
| orig = (row[0] or "").strip() |
| sub = (row[1] or "").strip() |
| if orig and sub: |
| mapping[orig] = sub |
| return mapping |
|
|
|
|
| def _normalize_tag_exclusions(data) -> Set[str]: |
| """ |
| Convert Dataframe (as array: list[list]) into set of tags to exclude. |
| """ |
| exclusions: Set[str] = set() |
| if data is None: |
| return exclusions |
|
|
| |
| for row in data: |
| if row is None: |
| continue |
| if isinstance(row, (list, tuple)): |
| if not row: |
| continue |
| val = row[0] |
| else: |
| val = row |
| val = (val or "").strip() |
| if val: |
| exclusions.add(val) |
| return exclusions |
|
|
|
|
def tag_video_interface(
    video_path: str,
    frame_interval: int,
    general_thresh: float,
    character_thresh: float,
    model_repo: str,
    tag_substitutes_df,
    tag_exclusions_df,
    batch_size: int,
    progress=gr.Progress(track_tqdm=False),
):
    """
    Gradio handler for the "Generate Tags" button.

    Normalises the Tag Control tables, runs the cached tagger for
    ``model_repo`` over the video, and adds the elapsed time to the debug
    info.

    Args:
        tag_substitutes_df / tag_exclusions_df: Tag Control tables as lists of
            rows (Dataframe ``type="array"``).
        progress: Gradio progress tracker, forwarded to VideoTagger.tag_video.

    Returns:
        (combined tag string, debug info dict). On failure the string is empty
        and the dict is ``{"error": <message>}``. Exceptions are logged, not
        propagated to Gradio.
    """
    import logging  # keeps this handler self-contained

    if not video_path:
        return "", {"error": "Please upload a video file."}

    # perf_counter is monotonic, so wall-clock changes cannot skew the duration.
    start_time = time.perf_counter()

    try:
        tagger = get_tagger(model_repo, batch_size=batch_size)
        tag_substitutes = _normalize_tag_substitutes(tag_substitutes_df)
        tag_exclusions = _normalize_tag_exclusions(tag_exclusions_df)

        combined_tags_str, debug_info = tagger.tag_video(
            video_path=video_path,
            frame_interval=frame_interval,
            general_thresh=general_thresh,
            character_thresh=character_thresh,
            tag_substitutes=tag_substitutes,
            tag_exclusions=tag_exclusions,
            progress=progress,
        )
    except Exception as e:
        # UI boundary: show the error in the debug panel instead of failing
        # the event, but keep the full traceback in the server log.
        logging.getLogger(__name__).exception("Video tagging failed")
        # Some exceptions stringify to "", which would leave the panel blank.
        return "", {"error": str(e) or type(e).__name__}

    elapsed = time.perf_counter() - start_time
    debug_info["session_duration_seconds"] = round(elapsed, 3)
    debug_info["session_duration_hms"] = _format_duration(elapsed)

    return combined_tags_str, debug_info
|
|
|
|
# UI layout. "Tagging" holds the inputs (left column) and results (right
# column); "Tag Control" holds the substitution/exclusion tables. Event wiring
# is at the bottom, inside the Blocks context.
with gr.Blocks(title=TITLE) as demo:
    # Pulsing style for the "calculating" placeholder from show_batch_loading().
    gr.HTML(
        """
        <style>
        .batch-loading {
            animation: batchPulse 1.2s ease-in-out infinite;
            color: #888888;
        }
        @keyframes batchPulse {
            0% { color: #666666; }
            50% { color: #bbbbbb; }
            100% { color: #666666; }
        }
        </style>
        """
    )

    gr.Markdown(f"## {TITLE}")
    gr.Markdown(DESCRIPTION)

    with gr.Tabs():
        with gr.Tab("Tagging", elem_id="tagging-tab"):
            with gr.Row():
                with gr.Column():
                    video_input = gr.Video(
                        label="Video (.mp4 or .mov)",
                        sources=["upload"],
                        format="mp4",
                    )

                    # Repo id passed to get_tagger(); options come from MODEL_OPTIONS.
                    model_choice = gr.Dropdown(
                        choices=MODEL_OPTIONS,
                        value=DEFAULT_MODEL_REPO,
                        label="Tagging Model",
                    )

                    general_thresh = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        step=0.01,
                        value=0.35,
                        label="General Tags Threshold",
                    )

                    character_thresh = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        step=0.01,
                        value=0.85,
                        label="Character Tags Threshold",
                    )

                    gr.Markdown("### Processing")

                    frame_interval = gr.Slider(
                        minimum=1,
                        maximum=60,
                        step=1,
                        value=10,
                        label="Extract Every N Frames",
                        info="For example, 10 = use every 10th frame.",
                    )

                    # Only the value of this slider is used; the recommendation
                    # below is informational and is not applied to it.
                    batch_size = gr.Slider(
                        minimum=4,
                        maximum=64,
                        step=4,
                        value=12,
                        label="Batch Size",
                        info=(
                            "Larger batch sizes may increase initial loading time but can significantly "
                            "improve total processing speed, especially for longer videos or high frame counts."
                        ),
                    )

                    # Filled by show_batch_loading / update_batch_recommendation.
                    batch_recommendation = gr.HTML(
                        "<span>Upload a video to see a recommended batch size.</span>"
                    )

                    run_button = gr.Button("Generate Tags", variant="primary")

                with gr.Column():
                    combined_tags = gr.Textbox(
                        label="Combined Unique Tags (All Frames)",
                        lines=6,
                        buttons=["copy"],
                    )
                    # Shows the debug dict from tag_video, or {"error": ...}.
                    debug_info = gr.JSON(
                        label="Details / Debug Info",
                    )

        with gr.Tab("Tag Control", elem_id="tag-control-tab"):
            gr.Markdown("### Tag Substitutes")
            gr.Markdown(
                "Add rows where **Original Tag** will be replaced by **Substitute Tag** "
                "in the final combined output (after all frames are processed)."
            )

            with gr.Column():
                # type="array": the handler receives a list of [original, substitute] rows.
                tag_substitutes_df = gr.Dataframe(
                    headers=["Original Tag", "Substitute Tag"],
                    datatype=["str", "str"],
                    row_count=1,
                    column_count=2,
                    type="array",
                    label="Tag Substitutes",
                    interactive=True,
                )
                add_sub_row_btn = gr.Button("➕ Add substitute")

            gr.Markdown("### Tag Exclusions")
            gr.Markdown(
                "Add tags that should be **removed entirely** from the final combined output."
            )

            with gr.Column():
                # type="array": the handler receives a list of [tag] rows.
                tag_exclusions_df = gr.Dataframe(
                    headers=["Tag to Exclude"],
                    datatype=["str"],
                    row_count=1,
                    column_count=1,
                    type="array",
                    label="Tag Exclusions",
                    interactive=True,
                )
                add_ex_row_btn = gr.Button("➕ Add exclusion")

    # "Add row" buttons: round-trip the table through a helper that appends a blank row.
    add_sub_row_btn.click(
        fn=add_substitute_row,
        inputs=tag_substitutes_df,
        outputs=tag_substitutes_df,
    )

    add_ex_row_btn.click(
        fn=add_exclusion_row,
        inputs=tag_exclusions_df,
        outputs=tag_exclusions_df,
    )

    # When the video or interval changes, first show the "calculating"
    # placeholder, then replace it with the computed recommendation.
    video_input.change(
        fn=show_batch_loading,
        inputs=[],
        outputs=batch_recommendation,
    ).then(
        fn=update_batch_recommendation,
        inputs=[video_input, frame_interval],
        outputs=batch_recommendation,
    )

    frame_interval.change(
        fn=show_batch_loading,
        inputs=[],
        outputs=batch_recommendation,
    ).then(
        fn=update_batch_recommendation,
        inputs=[video_input, frame_interval],
        outputs=batch_recommendation,
    )

    # Main action; the input order must match tag_video_interface's parameters.
    run_button.click(
        fn=tag_video_interface,
        inputs=[
            video_input,
            frame_interval,
            general_thresh,
            character_thresh,
            model_choice,
            tag_substitutes_df,
            tag_exclusions_df,
            batch_size,
        ],
        outputs=[combined_tags, debug_info],
    )
|
|
# App theme: blue primary, slate secondary, extra-rounded corners, Raleway font.
custom_theme = gr.themes.Default(
    primary_hue=gr.themes.colors.blue,
    secondary_hue=gr.themes.colors.slate,
    radius_size=gr.themes.sizes.radius_xxl,
    font=[gr.themes.GoogleFont("Raleway")],
)

# Launch at import time (no __main__ guard). queue(max_size=4) bounds the
# number of queued events.
# NOTE(review): theme/css are passed to launch() rather than gr.Blocks(), which
# presumably targets a recent Gradio release -- confirm against the installed version.
demo.queue(max_size=4).launch(
    theme=custom_theme,
    css=css,
)
|
|