Kapanther commited on
Commit
4c962fa
·
1 Parent(s): 63cf941

Initial deployment with LFS for fonts

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. README.md +19 -5
  2. app.py +35 -0
  3. core/__init__.py +43 -0
  4. core/__pycache__/__init__.cpython-311.pyc +0 -0
  5. core/__pycache__/__init__.cpython-314.pyc +0 -0
  6. core/__pycache__/caching.cpython-311.pyc +0 -0
  7. core/__pycache__/caching.cpython-314.pyc +0 -0
  8. core/__pycache__/config.cpython-311.pyc +0 -0
  9. core/__pycache__/llm_defaults.cpython-311.pyc +0 -0
  10. core/__pycache__/outside_text_processor.cpython-311.pyc +0 -0
  11. core/__pycache__/pipeline.cpython-311.pyc +0 -0
  12. core/__pycache__/scaling.cpython-311.pyc +0 -0
  13. core/__pycache__/scaling.cpython-314.pyc +0 -0
  14. core/__pycache__/validation.cpython-311.pyc +0 -0
  15. core/caching.py +584 -0
  16. core/config.py +240 -0
  17. core/image/__init__.py +42 -0
  18. core/image/__pycache__/__init__.cpython-311.pyc +0 -0
  19. core/image/__pycache__/__init__.cpython-314.pyc +0 -0
  20. core/image/__pycache__/cleaning.cpython-311.pyc +0 -0
  21. core/image/__pycache__/cleaning.cpython-314.pyc +0 -0
  22. core/image/__pycache__/detection.cpython-311.pyc +0 -0
  23. core/image/__pycache__/detection.cpython-314.pyc +0 -0
  24. core/image/__pycache__/image_utils.cpython-311.pyc +0 -0
  25. core/image/__pycache__/image_utils.cpython-314.pyc +0 -0
  26. core/image/__pycache__/inpainting.cpython-311.pyc +0 -0
  27. core/image/__pycache__/ocr_detection.cpython-311.pyc +0 -0
  28. core/image/__pycache__/sorting.cpython-311.pyc +0 -0
  29. core/image/cleaning.py +849 -0
  30. core/image/detection.py +914 -0
  31. core/image/image_utils.py +779 -0
  32. core/image/inpainting.py +773 -0
  33. core/image/ocr_detection.py +730 -0
  34. core/image/sorting.py +359 -0
  35. core/llm_defaults.py +31 -0
  36. core/ml/__init__.py +14 -0
  37. core/ml/__pycache__/__init__.cpython-311.pyc +0 -0
  38. core/ml/__pycache__/__init__.cpython-314.pyc +0 -0
  39. core/ml/__pycache__/model_manager.cpython-311.pyc +0 -0
  40. core/ml/__pycache__/model_manager.cpython-314.pyc +0 -0
  41. core/ml/model_manager.py +854 -0
  42. core/outside_text_processor.py +638 -0
  43. core/pipeline.py +1295 -0
  44. core/scaling.py +109 -0
  45. core/services/__init__.py +20 -0
  46. core/services/__pycache__/__init__.cpython-311.pyc +0 -0
  47. core/services/__pycache__/translation.cpython-311.pyc +0 -0
  48. core/services/translation.py +1385 -0
  49. core/text/__init__.py +49 -0
  50. core/text/__pycache__/__init__.cpython-311.pyc +0 -0
README.md CHANGED
@@ -1,13 +1,27 @@
1
  ---
2
  title: Manga Translator
3
- emoji: 🦀
4
  colorFrom: blue
5
- colorTo: red
6
  sdk: gradio
7
- sdk_version: 6.3.0
8
  app_file: app.py
9
  pinned: false
10
- license: apache-2.0
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  title: Manga Translator
3
+ emoji: 📖
4
  colorFrom: blue
5
+ colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 5.44.1
8
  app_file: app.py
9
  pinned: false
 
10
  ---
11
 
12
+ # MangaTranslator
13
+
14
+ AI-powered manga/comic translation tool. Upload manga pages and get them translated automatically!
15
+
16
+ ## Features
17
+ - Speech bubble detection and cleaning
18
+ - LLM-powered OCR and translation (54 languages)
19
+ - Automatic text rendering with custom fonts
20
+
21
+ ## Usage
22
+ 1. Go to the **Config** tab and enter your LLM API key (Google, OpenRouter, etc.)
23
+ 2. Upload a manga image
24
+ 3. Click **Translate**
25
+
26
+ ## Note
27
+ This is running on CPU (free tier), so translations may take 1-3 minutes per page.
app.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+
4
+ # Set environment before importing torch
5
+ os.environ["PYTORCH_ALLOC_CONF"] = "max_split_size_mb:512"
6
+
7
+ import gradio as gr
8
+ import torch
9
+
10
+ import core
11
+ from ui import layout
12
+
13
+ # Directories
14
+ MODELS_DIR = Path("./models")
15
+ FONTS_BASE_DIR = Path("./fonts")
16
+
17
+ os.makedirs(MODELS_DIR, exist_ok=True)
18
+ os.makedirs(FONTS_BASE_DIR, exist_ok=True)
19
+
20
+ # Force CPU for HF Spaces free tier
21
+ target_device = torch.device("cpu")
22
+
23
+ print(f"Using device: CPU")
24
+ print(f"PyTorch version: {torch.__version__}")
25
+ print(f"MangaTranslator version: {core.__version__}")
26
+
27
+ # Create and launch the Gradio app
28
+ app = layout.create_layout(
29
+ models_dir=MODELS_DIR,
30
+ fonts_base_dir=FONTS_BASE_DIR,
31
+ target_device=target_device,
32
+ )
33
+
34
+ app.queue()
35
+ app.launch()
core/__init__.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MangaTranslator Core Package
3
+
4
+ This package contains the core functionality for translating manga/comic speech bubbles.
5
+ It uses YOLO for speech bubble detection and LLMs for text translation.
6
+ """
7
+
8
+ from .caching import UnifiedCache, get_cache
9
+ from .image.cleaning import clean_speech_bubbles
10
+ from .image.detection import detect_speech_bubbles
11
+ from .image.image_utils import cv2_to_pil, pil_to_cv2, save_image_with_compression
12
+ from .image.inpainting import FluxKontextInpainter
13
+ from .image.ocr_detection import OutsideTextDetector
14
+ from .ml.model_manager import ModelManager, get_model_manager
15
+ from .pipeline import batch_translate_images, translate_and_render
16
+ from .services.translation import call_translation_api_batch
17
+ from .image.sorting import sort_bubbles_by_reading_order
18
+ from .text.text_renderer import render_text_skia
19
+
20
+ __version__ = "1.10.5"
21
+ __version_info__ = (1, 10, 5)
22
+ __author__ = "grinnch"
23
+ __copyright__ = "Copyright 2025-present grinnch"
24
+ __license__ = "Apache-2.0"
25
+ __description__ = "A tool for translating manga pages using AI"
26
+ __all__ = [
27
+ "get_cache",
28
+ "UnifiedCache",
29
+ "translate_and_render",
30
+ "batch_translate_images",
31
+ "render_text_skia",
32
+ "detect_speech_bubbles",
33
+ "clean_speech_bubbles",
34
+ "call_translation_api_batch",
35
+ "sort_bubbles_by_reading_order",
36
+ "pil_to_cv2",
37
+ "cv2_to_pil",
38
+ "save_image_with_compression",
39
+ "get_model_manager",
40
+ "ModelManager",
41
+ "OutsideTextDetector",
42
+ "FluxKontextInpainter",
43
+ ]
core/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (1.77 kB). View file
 
core/__pycache__/__init__.cpython-314.pyc ADDED
Binary file (1.55 kB). View file
 
core/__pycache__/caching.cpython-311.pyc ADDED
Binary file (28.9 kB). View file
 
core/__pycache__/caching.cpython-314.pyc ADDED
Binary file (31.7 kB). View file
 
core/__pycache__/config.cpython-311.pyc ADDED
Binary file (13.2 kB). View file
 
core/__pycache__/llm_defaults.cpython-311.pyc ADDED
Binary file (1.53 kB). View file
 
core/__pycache__/outside_text_processor.cpython-311.pyc ADDED
Binary file (25.1 kB). View file
 
core/__pycache__/pipeline.cpython-311.pyc ADDED
Binary file (44.3 kB). View file
 
core/__pycache__/scaling.cpython-311.pyc ADDED
Binary file (4.56 kB). View file
 
core/__pycache__/scaling.cpython-314.pyc ADDED
Binary file (5.54 kB). View file
 
core/__pycache__/validation.cpython-311.pyc ADDED
Binary file (15.4 kB). View file
 
core/caching.py ADDED
@@ -0,0 +1,584 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ import pickle
3
+ from typing import Any, Dict, List, Optional
4
+
5
+ import numpy as np
6
+ from PIL import Image
7
+
8
+ from utils.logging import log_message
9
+
10
+
11
+ class UnifiedCache:
12
+ """Unified cache for various MangaTranslator operations."""
13
+
14
+ def __init__(self):
15
+ """Initialize the unified cache."""
16
+ from core.text.font_manager import LRUCache
17
+
18
+ self._yolo_cache = LRUCache(max_size=1)
19
+ self._sam_cache = LRUCache(max_size=1)
20
+ self._translation_cache = LRUCache(max_size=1)
21
+ self._manga_ocr_cache = LRUCache(max_size=20)
22
+ self._upscale_cache = LRUCache(max_size=20)
23
+ self._inpaint_cache = LRUCache(max_size=20)
24
+ self._current_image_hash = None
25
+
26
+ def _hash_image(self, image: Image.Image) -> str:
27
+ """Compute strict SHA256 hash of PIL Image pixel data.
28
+
29
+ Args:
30
+ image: PIL Image to hash
31
+
32
+ Returns:
33
+ str: Hash string (16 chars)
34
+ """
35
+ if image.mode == "RGBA":
36
+ rgb_image = Image.new("RGB", image.size, (255, 255, 255))
37
+ rgb_image.paste(image, mask=image.split()[-1])
38
+ data_image = rgb_image
39
+ elif image.mode == "L":
40
+ data_image = image
41
+ else:
42
+ data_image = image
43
+
44
+ metadata = (
45
+ f"{data_image.mode}_{data_image.size[0]}_{data_image.size[1]}".encode()
46
+ )
47
+ image_bytes = data_image.tobytes()
48
+ digest = hashlib.sha256(metadata + image_bytes).hexdigest()
49
+ return digest[:16]
50
+
51
+ def _hash_numpy(self, array: np.ndarray) -> str:
52
+ """Compute strict SHA256 hash of numpy array contents.
53
+
54
+ Args:
55
+ array: Numpy array to hash
56
+
57
+ Returns:
58
+ str: Hash string (16 chars)
59
+ """
60
+ if array.size == 0:
61
+ return hashlib.sha256(b"empty_array").hexdigest()[:16]
62
+
63
+ metadata = f"{array.shape}_{array.dtype}".encode()
64
+ combined_data = metadata + array.tobytes()
65
+ return hashlib.sha256(combined_data).hexdigest()[:16]
66
+
67
+ def _hash_dict(self, data: Dict) -> str:
68
+ """Compute hash of dictionary.
69
+
70
+ Args:
71
+ data: Dictionary to hash
72
+
73
+ Returns:
74
+ str: Hash string (16 chars)
75
+ """
76
+ data_bytes = pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL)
77
+ return hashlib.sha256(data_bytes).hexdigest()[:16]
78
+
79
+ def get_yolo_cache_key(
80
+ self, image: Image.Image, model_path: str, confidence: float
81
+ ) -> str:
82
+ """Compute cache key for YOLO detection.
83
+
84
+ Args:
85
+ image: Input image
86
+ model_path: Path to YOLO model
87
+ confidence: Confidence threshold
88
+
89
+ Returns:
90
+ str: Cache key
91
+ """
92
+ image_hash = self._hash_image(image)
93
+ model_hash = hashlib.sha256(model_path.encode()).hexdigest()[:16]
94
+ key_string = f"yolo_{image_hash}_{model_hash}_conf{confidence:.3f}"
95
+ return hashlib.sha256(key_string.encode()).hexdigest()
96
+
97
+ def get_yolo_detection(self, cache_key: str) -> Optional[Any]:
98
+ """Get cached YOLO detection result.
99
+
100
+ Args:
101
+ cache_key: Cache key
102
+
103
+ Returns:
104
+ Cached YOLO results or None if not found
105
+ """
106
+ return self._yolo_cache.get(cache_key)
107
+
108
+ def set_yolo_detection(
109
+ self, cache_key: str, results: Any, verbose: bool = False
110
+ ) -> None:
111
+ """Cache YOLO detection result.
112
+
113
+ Args:
114
+ cache_key: Cache key
115
+ results: YOLO detection results to cache
116
+ verbose: Whether to print verbose logging
117
+ """
118
+ self._yolo_cache.put(cache_key, results)
119
+ log_message(
120
+ f" - Cached YOLO detection (cache size: {len(self._yolo_cache.cache)})",
121
+ verbose=verbose,
122
+ )
123
+
124
+ def get_sam_cache_key(
125
+ self,
126
+ image: Image.Image,
127
+ yolo_boxes: Any,
128
+ use_sam2: bool = True,
129
+ conjoined_detection: bool = True,
130
+ conjoined_confidence: float = 0.35,
131
+ ) -> str:
132
+ """Compute cache key for SAM segmentation.
133
+
134
+ Args:
135
+ image: Input image
136
+ yolo_boxes: YOLO detection boxes (tensor or list)
137
+ use_sam2: Whether SAM2 is enabled
138
+ conjoined_detection: Whether conjoined detection is enabled
139
+ conjoined_confidence: Confidence threshold for conjoined detection
140
+
141
+ Returns:
142
+ str: Cache key
143
+ """
144
+ image_hash = self._hash_image(image)
145
+
146
+ if hasattr(yolo_boxes, "cpu"):
147
+ boxes_np = yolo_boxes.cpu().numpy()
148
+ else:
149
+ boxes_np = np.array(yolo_boxes)
150
+ boxes_hash = self._hash_numpy(boxes_np)
151
+
152
+ sam_model_id = "facebook/sam2.1-hiera-large"
153
+ model_hash = hashlib.sha256(sam_model_id.encode()).hexdigest()[:8]
154
+ key_string = (
155
+ f"sam_{image_hash}_{boxes_hash}_{model_hash}_sam2{int(use_sam2)}"
156
+ f"_conjoined{int(conjoined_detection)}"
157
+ f"_conf{conjoined_confidence:.3f}"
158
+ )
159
+ return hashlib.sha256(key_string.encode()).hexdigest()
160
+
161
+ def get_sam_masks(self, cache_key: str) -> Optional[Any]:
162
+ """Get cached SAM masks.
163
+
164
+ Args:
165
+ cache_key: Cache key
166
+
167
+ Returns:
168
+ Cached SAM masks or None if not found
169
+ """
170
+ return self._sam_cache.get(cache_key)
171
+
172
+ def set_sam_masks(self, cache_key: str, masks: Any, verbose: bool = False) -> None:
173
+ """Cache SAM masks.
174
+
175
+ Args:
176
+ cache_key: Cache key
177
+ masks: SAM masks to cache
178
+ verbose: Whether to print verbose logging
179
+ """
180
+ self._sam_cache.put(cache_key, masks)
181
+ log_message(
182
+ f" - Cached SAM masks (cache size: {len(self._sam_cache.cache)})",
183
+ verbose=verbose,
184
+ )
185
+
186
+ def _is_deterministic(self, config) -> bool:
187
+ """Check if translation config is deterministic.
188
+
189
+ Args:
190
+ config: TranslationConfig object
191
+
192
+ Returns:
193
+ bool: True if translation is deterministic
194
+ """
195
+ return config.temperature == 0.0 or config.top_k == 1 or config.top_p == 0.0
196
+
197
+ def get_translation_cache_key(
198
+ self,
199
+ images_b64: list,
200
+ full_image_b64: str,
201
+ config,
202
+ ) -> Optional[str]:
203
+ """Compute cache key for LLM translation.
204
+
205
+ Only returns a key if the config is deterministic.
206
+
207
+ Args:
208
+ images_b64: List of base64 encoded bubble images
209
+ full_image_b64: Base64 encoded full page image
210
+ config: TranslationConfig object
211
+
212
+ Returns:
213
+ str: Cache key, or None if not deterministic
214
+ """
215
+ if not self._is_deterministic(config):
216
+ return None
217
+
218
+ images_hash = hashlib.sha256("".join(images_b64).encode()).hexdigest()[:16]
219
+ full_hash = hashlib.sha256(full_image_b64.encode()).hexdigest()[:16]
220
+
221
+ cache_params = {
222
+ "provider": config.provider,
223
+ "model_name": config.model_name,
224
+ "input_language": config.input_language,
225
+ "output_language": config.output_language,
226
+ "reading_direction": config.reading_direction,
227
+ "translation_mode": config.translation_mode,
228
+ "send_full_page_context": config.send_full_page_context,
229
+ "temperature": config.temperature,
230
+ "top_k": config.top_k,
231
+ "top_p": config.top_p,
232
+ "ocr_method": config.ocr_method,
233
+ "special_instructions": (
234
+ config.special_instructions.strip()
235
+ if config.special_instructions
236
+ else None
237
+ ),
238
+ "max_tokens": config.max_tokens,
239
+ "reasoning_effort": config.reasoning_effort,
240
+ "effort": config.effort,
241
+ "media_resolution": getattr(config, "media_resolution", None),
242
+ "media_resolution_bubbles": getattr(
243
+ config, "media_resolution_bubbles", None
244
+ ),
245
+ "media_resolution_context": getattr(
246
+ config, "media_resolution_context", None
247
+ ),
248
+ "enable_web_search": getattr(config, "enable_web_search", None),
249
+ "upscale_method": getattr(config, "upscale_method", None),
250
+ "bubble_min_side_pixels": getattr(config, "bubble_min_side_pixels", None),
251
+ "context_image_max_side_pixels": getattr(
252
+ config, "context_image_max_side_pixels", None
253
+ ),
254
+ }
255
+ config_hash = self._hash_dict(cache_params)
256
+
257
+ key_string = f"trans_{images_hash}_{full_hash}_{config_hash}"
258
+ return hashlib.sha256(key_string.encode()).hexdigest()
259
+
260
+ def get_translation(self, cache_key: Optional[str]) -> Optional[list]:
261
+ """Get cached translation results.
262
+
263
+ Args:
264
+ cache_key: Cache key (can be None if not deterministic)
265
+
266
+ Returns:
267
+ Cached translations or None if not found
268
+ """
269
+ if cache_key is None:
270
+ return None
271
+ return self._translation_cache.get(cache_key)
272
+
273
+ def set_translation(
274
+ self, cache_key: Optional[str], translations: list, verbose: bool = False
275
+ ) -> None:
276
+ """Cache translation results.
277
+
278
+ Args:
279
+ cache_key: Cache key (can be None if not deterministic)
280
+ translations: Translation results to cache
281
+ verbose: Whether to print verbose logging
282
+ """
283
+ if cache_key is None:
284
+ return
285
+ self._translation_cache.put(cache_key, translations)
286
+ log_message(
287
+ f" - Cached translation (cache size: {len(self._translation_cache.cache)})",
288
+ verbose=verbose,
289
+ )
290
+
291
+ def get_manga_ocr_cache_key(
292
+ self, images_b64: List[str], total_elements: int
293
+ ) -> Optional[str]:
294
+ """Compute cache key for manga-ocr results.
295
+
296
+ Args:
297
+ images_b64: List of base64-encoded cropped images.
298
+ total_elements: Expected number of OCR outputs.
299
+
300
+ Returns:
301
+ str: Cache key (always deterministic)
302
+ """
303
+ images_hash = hashlib.sha256("".join(images_b64).encode()).hexdigest()[:16]
304
+ key_string = f"mocr_{images_hash}_n{total_elements}"
305
+ return hashlib.sha256(key_string.encode()).hexdigest()
306
+
307
+ def get_manga_ocr_result(self, cache_key: Optional[str]) -> Optional[list]:
308
+ """Get cached manga-ocr results."""
309
+ if cache_key is None:
310
+ return None
311
+ return self._manga_ocr_cache.get(cache_key)
312
+
313
+ def set_manga_ocr_result(
314
+ self, cache_key: Optional[str], results: list, verbose: bool = False
315
+ ) -> None:
316
+ """Cache manga-ocr results (including failure markers)."""
317
+ if cache_key is None:
318
+ return
319
+ self._manga_ocr_cache.put(cache_key, results)
320
+ log_message(
321
+ f" - Cached manga-ocr result (cache size: {len(self._manga_ocr_cache.cache)})",
322
+ verbose=verbose,
323
+ )
324
+
325
+ def get_upscale_cache_key(
326
+ self, image: Image.Image, factor: float, model_type: str = "model"
327
+ ) -> str:
328
+ """Compute cache key for image upscaling.
329
+
330
+ Args:
331
+ image: Input image
332
+ factor: Upscaling factor
333
+ model_type: Model type identifier ("model" or "model_lite")
334
+
335
+ Returns:
336
+ str: Cache key
337
+ """
338
+ image_hash = self._hash_image(image)
339
+ key_string = f"upscale_{image_hash}_factor{factor:.3f}_model{model_type}"
340
+ return hashlib.sha256(key_string.encode()).hexdigest()
341
+
342
+ def get_upscale_dimension_cache_key(
343
+ self, image: Image.Image, target: int, mode: str, model_type: str = "model"
344
+ ) -> str:
345
+ """Compute cache key for image upscaling to dimension.
346
+
347
+ Args:
348
+ image: Input image
349
+ target: Target dimension
350
+ mode: Upscaling mode ('max' or 'min')
351
+ model_type: Model type identifier ("model" or "model_lite")
352
+
353
+ Returns:
354
+ str: Cache key
355
+ """
356
+ image_hash = self._hash_image(image)
357
+ key_string = (
358
+ f"upscale_dim_{image_hash}_target{target}_mode{mode}_model{model_type}"
359
+ )
360
+ return hashlib.sha256(key_string.encode()).hexdigest()
361
+
362
+ def get_bubble_processing_cache_key(
363
+ self, image: Image.Image, target: int, mode: str, model_type: str = "model"
364
+ ) -> str:
365
+ """Compute cache key for complete bubble processing (upscale + color match).
366
+
367
+ Args:
368
+ image: Input image
369
+ target: Target dimension
370
+ mode: Upscaling mode ('max' or 'min')
371
+ model_type: Model type identifier ("model" or "model_lite")
372
+
373
+ Returns:
374
+ str: Cache key
375
+ """
376
+ image_hash = self._hash_image(image)
377
+ key_string = (
378
+ f"bubble_proc_{image_hash}_target{target}_mode{mode}_model{model_type}"
379
+ )
380
+ return hashlib.sha256(key_string.encode()).hexdigest()
381
+
382
+ def get_upscaled_image(self, cache_key: str) -> Optional[Image.Image]:
383
+ """Get cached upscaled image.
384
+
385
+ Args:
386
+ cache_key: Cache key
387
+
388
+ Returns:
389
+ Cached upscaled image or None if not found
390
+ """
391
+ return self._upscale_cache.get(cache_key)
392
+
393
+ def set_upscaled_image(
394
+ self, cache_key: str, image: Image.Image, verbose: bool = False
395
+ ) -> None:
396
+ """Cache upscaled image.
397
+
398
+ Args:
399
+ cache_key: Cache key
400
+ image: Upscaled image to cache
401
+ verbose: Whether to print verbose logging
402
+ """
403
+ self._upscale_cache.put(cache_key, image)
404
+ log_message(
405
+ f" - Cached upscaled image (cache size: {len(self._upscale_cache.cache)})",
406
+ verbose=verbose,
407
+ )
408
+
409
+ def get_inpaint_cache_key(
410
+ self,
411
+ image: Image.Image,
412
+ mask: np.ndarray,
413
+ seed: int,
414
+ num_inference_steps: int,
415
+ residual_diff_threshold: float,
416
+ guidance_scale: float,
417
+ prompt: str,
418
+ ocr_params: Optional[Dict] = None,
419
+ ) -> str:
420
+ """Compute cache key for Flux inpainting.
421
+
422
+ Args:
423
+ image: Input image
424
+ mask: Mask array
425
+ seed: Random seed
426
+ num_inference_steps: Number of inference steps
427
+ residual_diff_threshold: Residual diff threshold
428
+ guidance_scale: Guidance scale
429
+ prompt: Inpainting prompt
430
+ ocr_params: Optional OCR parameters dict (e.g., {'min_size': 200})
431
+
432
+ Returns:
433
+ str: Cache key
434
+ """
435
+ image_hash = self._hash_image(image)
436
+ mask_hash = self._hash_numpy(mask)
437
+
438
+ # Include OCR parameters in cache key if provided
439
+ ocr_params_str = ""
440
+ if ocr_params:
441
+ ocr_params_str = "_" + "_".join(
442
+ f"{k}{v}" for k, v in sorted(ocr_params.items())
443
+ )
444
+
445
+ key_string = (
446
+ f"inpaint_{image_hash}_{mask_hash}_"
447
+ f"seed{seed}_steps{num_inference_steps}_"
448
+ f"thresh{residual_diff_threshold:.3f}_"
449
+ f"guide{guidance_scale:.2f}_"
450
+ f"{prompt}{ocr_params_str}"
451
+ )
452
+ return hashlib.sha256(key_string.encode()).hexdigest()
453
+
454
+ def should_use_inpaint_cache(self, seed: int) -> bool:
455
+ """Determine if inpainting caching should be used.
456
+
457
+ Args:
458
+ seed: Random seed value
459
+
460
+ Returns:
461
+ bool: True if caching is enabled (seed != -1)
462
+ """
463
+ return seed != -1
464
+
465
+ def get_inpainted_image(self, cache_key: str) -> Optional[Image.Image]:
466
+ """Get cached inpainted image.
467
+
468
+ Args:
469
+ cache_key: Cache key
470
+
471
+ Returns:
472
+ Cached inpainted image or None if not found
473
+ """
474
+ return self._inpaint_cache.get(cache_key)
475
+
476
+ def set_inpainted_image(
477
+ self, cache_key: str, image: Image.Image, verbose: bool = False
478
+ ) -> None:
479
+ """Cache inpainted image.
480
+
481
+ Args:
482
+ cache_key: Cache key
483
+ image: Inpainted image to cache
484
+ verbose: Whether to print verbose logging
485
+ """
486
+ self._inpaint_cache.put(cache_key, image)
487
+ log_message(
488
+ f" - Cached inpainted image (cache size: {len(self._inpaint_cache.cache)})",
489
+ verbose=verbose,
490
+ )
491
+
492
+ def clear_yolo_cache(self, verbose: bool = False) -> None:
493
+ """Clear YOLO detection cache."""
494
+ self._yolo_cache.cache.clear()
495
+ log_message("YOLO cache cleared", verbose=verbose)
496
+
497
+ def clear_sam_cache(self, verbose: bool = False) -> None:
498
+ """Clear SAM masks cache."""
499
+ self._sam_cache.cache.clear()
500
+ log_message("SAM cache cleared", verbose=verbose)
501
+
502
+ def clear_translation_cache(self, verbose: bool = False) -> None:
503
+ """Clear translation cache."""
504
+ self._translation_cache.cache.clear()
505
+ log_message("Translation cache cleared", verbose=verbose)
506
+
507
+ def clear_manga_ocr_cache(self, verbose: bool = False) -> None:
508
+ """Clear manga-ocr cache."""
509
+ self._manga_ocr_cache.cache.clear()
510
+ log_message("manga-ocr cache cleared", verbose=verbose)
511
+
512
+ def clear_upscale_cache(self, verbose: bool = False) -> None:
513
+ """Clear upscaling cache."""
514
+ self._upscale_cache.cache.clear()
515
+ log_message("Upscale cache cleared", verbose=verbose)
516
+
517
+ def clear_inpaint_cache(self, verbose: bool = False) -> None:
518
+ """Clear inpainting cache."""
519
+ self._inpaint_cache.cache.clear()
520
+ log_message("Inpaint cache cleared", verbose=verbose)
521
+
522
+ def clear_all(self) -> None:
523
+ """Clear all caches."""
524
+ self.clear_yolo_cache(verbose=False)
525
+ self.clear_sam_cache(verbose=False)
526
+ self.clear_translation_cache(verbose=False)
527
+ self.clear_manga_ocr_cache(verbose=False)
528
+ self.clear_upscale_cache(verbose=False)
529
+ self.clear_inpaint_cache(verbose=False)
530
+ log_message("All caches cleared", always_print=True)
531
+
532
+ def set_current_image(self, image: Image.Image, verbose: bool = False) -> None:
533
+ """Set the current image being processed and clear caches if different.
534
+
535
+ Args:
536
+ image: The current image being processed
537
+ verbose: Whether to print verbose logging
538
+ """
539
+ image_hash = self._hash_image(image)
540
+
541
+ if self._current_image_hash is None:
542
+ # First image
543
+ self._current_image_hash = image_hash
544
+ log_message("Cache initialized for new image", verbose=verbose)
545
+ elif self._current_image_hash != image_hash:
546
+ # Different image detected - clear all caches
547
+ log_message(
548
+ "Different image detected - clearing all caches", verbose=verbose
549
+ )
550
+ self.clear_all()
551
+ self._current_image_hash = image_hash
552
+ else:
553
+ # Same image - no action needed
554
+ log_message("Same image detected - reusing caches", verbose=verbose)
555
+
556
+ def get_cache_stats(self) -> dict:
557
+ """Get statistics about cache sizes.
558
+
559
+ Returns:
560
+ dict: Cache statistics
561
+ """
562
+ return {
563
+ "yolo": len(self._yolo_cache.cache),
564
+ "sam": len(self._sam_cache.cache),
565
+ "translation": len(self._translation_cache.cache),
566
+ "manga_ocr": len(self._manga_ocr_cache.cache),
567
+ "upscale": len(self._upscale_cache.cache),
568
+ "inpaint": len(self._inpaint_cache.cache),
569
+ }
570
+
571
+
572
+ _global_cache = None
573
+
574
+
575
+ def get_cache() -> UnifiedCache:
576
+ """Get the global cache instance.
577
+
578
+ Returns:
579
+ UnifiedCache: The global cache instance
580
+ """
581
+ global _global_cache
582
+ if _global_cache is None:
583
+ _global_cache = UnifiedCache()
584
+ return _global_cache
core/config.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from typing import Optional
3
+
4
+ import torch
5
+
6
+ from core.llm_defaults import DEFAULT_LLM_PROVIDER, get_provider_sampling_defaults
7
+
8
+
9
+ @dataclass
10
+ class DetectionConfig:
11
+ """Configuration for speech bubble detection."""
12
+
13
+ confidence: float = 0.6
14
+ conjoined_confidence: float = 0.35
15
+ panel_confidence: float = 0.25
16
+ use_sam2: bool = True
17
+ conjoined_detection: bool = True
18
+ use_panel_sorting: bool = True
19
+ use_osb_text_verification: bool = True
20
+
21
+
22
+ @dataclass
23
+ class CleaningConfig:
24
+ """Configuration for speech bubble cleaning."""
25
+
26
+ thresholding_value: int = 190
27
+ use_otsu_threshold: bool = False
28
+ roi_shrink_px: int = 4
29
+ inpaint_colored_bubbles: bool = True
30
+
31
+
32
+ _DEFAULT_TRANSLATION_PROVIDER = DEFAULT_LLM_PROVIDER
33
+ _DEFAULT_SAMPLING = get_provider_sampling_defaults(_DEFAULT_TRANSLATION_PROVIDER)
34
+
35
+
36
+ @dataclass
37
+ class TranslationConfig:
38
+ """Configuration for text translation."""
39
+
40
+ provider: str = _DEFAULT_TRANSLATION_PROVIDER
41
+ google_api_key: str = ""
42
+ openai_api_key: str = ""
43
+ anthropic_api_key: str = ""
44
+ xai_api_key: str = ""
45
+ deepseek_api_key: str = ""
46
+ zai_api_key: str = ""
47
+ moonshot_api_key: str = ""
48
+ openrouter_api_key: str = ""
49
+ openai_compatible_url: str = "http://localhost:1234/v1"
50
+ openai_compatible_api_key: Optional[str] = ""
51
+ model_name: str = "gemini-2.5-flash"
52
+ provider_models: dict[str, Optional[str]] = field(default_factory=dict)
53
+ temperature: float = float(_DEFAULT_SAMPLING["temperature"])
54
+ top_p: float = float(_DEFAULT_SAMPLING["top_p"])
55
+ top_k: int = int(_DEFAULT_SAMPLING["top_k"])
56
+ max_tokens: Optional[int] = (
57
+ None # None = use default logic (16384 for reasoning, 4096 otherwise)
58
+ )
59
+ input_language: str = "Japanese"
60
+ output_language: str = "English"
61
+ reading_direction: str = "rtl"
62
+ translation_mode: str = "one-step"
63
+ reasoning_effort: Optional[str] = (
64
+ None # Default: Google uses "auto", Anthropic uses "none", others use "medium"
65
+ )
66
+ effort: Optional[str] = (
67
+ None # Claude Opus 4.5 only: Controls token spending eagerness (high/medium/low)
68
+ )
69
+ send_full_page_context: bool = True
70
+ upscale_method: str = "model_lite" # "model", "model_lite", "lanczos", or "none"
71
+ enable_web_search: bool = (
72
+ False # Enable model's built-in web search for up-to-date information. OpenRouter uses its own web search tool.
73
+ )
74
+ media_resolution: str = (
75
+ "auto" # Only available via Google provider (auto/high/medium/low)
76
+ )
77
+ media_resolution_bubbles: str = "auto" # Gemini 3 models
78
+ media_resolution_context: str = "auto" # Gemini 3 models
79
+ bubble_min_side_pixels: int = 128
80
+ context_image_max_side_pixels: int = 1024
81
+ osb_min_side_pixels: int = 128
82
+ special_instructions: Optional[str] = None
83
+ ocr_method: str = "LLM" # "LLM" or "manga-ocr"
84
+
85
+
86
+ @dataclass
87
+ class RenderingConfig:
88
+ """Configuration for rendering translated text."""
89
+
90
+ font_dir: str = "./fonts"
91
+ max_font_size: int = 16
92
+ min_font_size: int = 8
93
+ line_spacing_mult: float = 1.0
94
+ use_subpixel_rendering: bool = False
95
+ font_hinting: str = "none"
96
+ use_ligatures: bool = False
97
+ hyphenate_before_scaling: bool = True
98
+ hyphen_penalty: float = 1000.0
99
+ hyphenation_min_word_length: int = 8
100
+ badness_exponent: float = 3.0
101
+ padding_pixels: float = 5.0
102
+ outline_width: float = 0.0
103
+ supersampling_factor: int = 4
104
+
105
+
106
+ @dataclass
107
+ class OutsideTextConfig:
108
+ """Configuration for outside speech bubble text detection and removal."""
109
+
110
+ enabled: bool = False
111
+ enable_page_number_filtering: bool = False
112
+ page_filter_margin_threshold: float = 0.1
113
+ page_filter_min_area_ratio: float = 0.05
114
+ seed: int = 1 # -1 = random
115
+ huggingface_token: str = "" # Required for Flux Kontext model downloads
116
+ force_cv2_inpainting: bool = False
117
+ flux_num_inference_steps: int = 8
118
+ flux_residual_diff_threshold: float = 0.15
119
+ osb_confidence: float = 0.6
120
+ osb_font_name: Optional[str] = None # None = use main font as fallback
121
+ osb_max_font_size: int = 64
122
+ osb_min_font_size: int = 12
123
+ osb_use_ligatures: bool = False
124
+ osb_outline_width: float = 3.0
125
+ osb_line_spacing: float = 1.0
126
+ osb_use_subpixel_rendering: bool = False
127
+ osb_font_hinting: str = "none"
128
+ bbox_expansion_percent: float = 0.1
129
+ text_box_proximity_ratio: float = 0.02 # 2% of image dimension
130
+ flux_guidance_scale: float = 2.5
131
+ flux_prompt: str = "Remove all text."
132
+
133
+
134
+ @dataclass
135
+ class OutputConfig:
136
+ """Configuration for saving output images."""
137
+
138
+ jpeg_quality: int = 95
139
+ png_compression: int = 2
140
+ output_format: str = "auto"
141
+ upscale_final_image: bool = False
142
+ image_upscale_factor: float = 2.0
143
+ image_upscale_model: str = "model_lite" # "model" or "model_lite"
144
+
145
+
146
+ @dataclass
147
+ class MangaTranslatorConfig:
148
+ """Main configuration for the MangaTranslator pipeline."""
149
+
150
+ yolo_model_path: str
151
+ detection: DetectionConfig = field(default_factory=DetectionConfig)
152
+ cleaning: CleaningConfig = field(default_factory=CleaningConfig)
153
+ translation: TranslationConfig = field(default_factory=TranslationConfig)
154
+ rendering: RenderingConfig = field(default_factory=RenderingConfig)
155
+ output: OutputConfig = field(default_factory=OutputConfig)
156
+ outside_text: OutsideTextConfig = field(default_factory=OutsideTextConfig)
157
+ preprocessing: "PreprocessingConfig" = field(
158
+ default_factory=lambda: PreprocessingConfig()
159
+ )
160
+ verbose: bool = False
161
+ device: Optional[torch.device] = None
162
+ cleaning_only: bool = False
163
+ upscaling_only: bool = False
164
+ test_mode: bool = False
165
+ processing_scale: float = 1.0
166
+
167
+ def __post_init__(self):
168
+ # Load API keys from environment variables if not already set
169
+ import os
170
+
171
+ if not self.translation.google_api_key:
172
+ self.translation.google_api_key = os.environ.get("GOOGLE_API_KEY", "")
173
+ if not self.translation.openai_api_key:
174
+ self.translation.openai_api_key = os.environ.get("OPENAI_API_KEY", "")
175
+ if not self.translation.anthropic_api_key:
176
+ self.translation.anthropic_api_key = os.environ.get("ANTHROPIC_API_KEY", "")
177
+ if not self.translation.xai_api_key:
178
+ self.translation.xai_api_key = os.environ.get("XAI_API_KEY", "")
179
+ if not self.translation.deepseek_api_key:
180
+ self.translation.deepseek_api_key = os.environ.get("DEEPSEEK_API_KEY", "")
181
+ if not self.translation.moonshot_api_key:
182
+ self.translation.moonshot_api_key = os.environ.get("MOONSHOT_API_KEY", "")
183
+ if not self.translation.openrouter_api_key:
184
+ self.translation.openrouter_api_key = os.environ.get(
185
+ "OPENROUTER_API_KEY", ""
186
+ )
187
+ if (
188
+ not self.translation.openai_compatible_api_key
189
+ ): # Check if it's None or empty string
190
+ self.translation.openai_compatible_api_key = os.environ.get(
191
+ "OPENAI_COMPATIBLE_API_KEY", ""
192
+ )
193
+
194
+ # Autodetect device if not specified
195
+ if self.device is None:
196
+ if torch.cuda.is_available():
197
+ self.device = torch.device("cuda")
198
+ elif torch.backends.mps.is_available():
199
+ self.device = torch.device("mps")
200
+ else:
201
+ self.device = torch.device("cpu")
202
+ pass
203
+
204
+
205
+ @dataclass
206
+ class PreprocessingConfig:
207
+ """Configuration for image preprocessing before detection/cleaning."""
208
+
209
+ enabled: bool = False
210
+ factor: float = 2.0
211
+ auto_scale: bool = True
212
+
213
+
214
+ def calculate_reasoning_budget(total_tokens: int, effort_level: str) -> int:
215
+ """
216
+ Calculate reasoning token budget based on effort level.
217
+
218
+ Args:
219
+ total_tokens: Total available tokens (typically max_tokens)
220
+ effort_level: Reasoning effort level ("high", "medium", "low", "minimal", "auto", or "none")
221
+
222
+ Returns:
223
+ int: Calculated budget in tokens
224
+ - "high": 80% of total_tokens
225
+ - "medium": 50% of total_tokens
226
+ - "low": 20% of total_tokens
227
+ - "minimal": 10% of total_tokens
228
+ - "auto" or "none": Returns 0 (caller should handle these cases separately)
229
+ """
230
+ if effort_level == "high":
231
+ return int(total_tokens * 0.8)
232
+ elif effort_level == "medium":
233
+ return int(total_tokens * 0.5)
234
+ elif effort_level == "low":
235
+ return int(total_tokens * 0.2)
236
+ elif effort_level == "minimal":
237
+ return int(total_tokens * 0.1)
238
+ else:
239
+ # "auto" or "none" - return 0, caller should handle these cases
240
+ return 0
core/image/__init__.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Image processing and analysis modules for MangaTranslator.
3
+
4
+ This subpackage contains modules for:
5
+ - Speech bubble detection (YOLO, SAM)
6
+ - OCR text detection outside bubbles
7
+ - Image cleaning and preprocessing
8
+ - Inpainting for text removal
9
+ - General image utilities
10
+ """
11
+
12
+ from .cleaning import clean_speech_bubbles
13
+ from .detection import detect_speech_bubbles
14
+ from .image_utils import (
15
+ calculate_centroid_expansion_box,
16
+ convert_image_to_target_mode,
17
+ cv2_to_pil,
18
+ pil_to_cv2,
19
+ process_bubble_image_cached,
20
+ resize_to_max_side,
21
+ save_image_with_compression,
22
+ upscale_image,
23
+ upscale_image_to_dimension,
24
+ )
25
+ from .inpainting import FluxKontextInpainter
26
+ from .ocr_detection import OutsideTextDetector
27
+
28
+ __all__ = [
29
+ "clean_speech_bubbles",
30
+ "detect_speech_bubbles",
31
+ "calculate_centroid_expansion_box",
32
+ "convert_image_to_target_mode",
33
+ "cv2_to_pil",
34
+ "pil_to_cv2",
35
+ "process_bubble_image_cached",
36
+ "resize_to_max_side",
37
+ "save_image_with_compression",
38
+ "upscale_image",
39
+ "upscale_image_to_dimension",
40
+ "FluxKontextInpainter",
41
+ "OutsideTextDetector",
42
+ ]
core/image/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (1.29 kB). View file
 
core/image/__pycache__/__init__.cpython-314.pyc ADDED
Binary file (1.11 kB). View file
 
core/image/__pycache__/cleaning.cpython-311.pyc ADDED
Binary file (32 kB). View file
 
core/image/__pycache__/cleaning.cpython-314.pyc ADDED
Binary file (30.5 kB). View file
 
core/image/__pycache__/detection.cpython-311.pyc ADDED
Binary file (40.2 kB). View file
 
core/image/__pycache__/detection.cpython-314.pyc ADDED
Binary file (37.5 kB). View file
 
core/image/__pycache__/image_utils.cpython-311.pyc ADDED
Binary file (36.2 kB). View file
 
core/image/__pycache__/image_utils.cpython-314.pyc ADDED
Binary file (35.2 kB). View file
 
core/image/__pycache__/inpainting.cpython-311.pyc ADDED
Binary file (32.5 kB). View file
 
core/image/__pycache__/ocr_detection.cpython-311.pyc ADDED
Binary file (32.4 kB). View file
 
core/image/__pycache__/sorting.cpython-311.pyc ADDED
Binary file (17.6 kB). View file
 
core/image/cleaning.py ADDED
@@ -0,0 +1,849 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ import os
3
+ import random
4
+ import tempfile
5
+ from pathlib import Path
6
+ from typing import Optional, Union
7
+
8
+ import cv2
9
+ import numpy as np
10
+ from PIL import Image
11
+
12
+ from core.scaling import scale_area, scale_kernel, scale_scalar
13
+ from utils.exceptions import CleaningError, ImageProcessingError, ValidationError
14
+ from utils.logging import log_message
15
+
16
+ from .detection import detect_speech_bubbles
17
+ from .image_utils import pil_to_cv2
18
+ from .inpainting import FluxKontextInpainter
19
+
20
+ # Cleaning parameters
21
+ GRAYSCALE_MIDPOINT = 128 # Threshold for determining black vs white bubbles
22
+ MIN_CONTOUR_AREA = 50 # Minimum area threshold for filtering small contours
23
+ DILATION_KERNEL_SIZE = (7, 7) # Kernel size for morphological dilation
24
+ EROSION_KERNEL_SIZE = (5, 5) # Kernel size for morphological erosion
25
+ DISTANCE_TRANSFORM_MASK_SIZE = 5 # Mask size for distance transform
26
+
27
+ # Classification thresholds for colored bubbles
28
+ BRIGHT_RATIO_THRESHOLD = 0.50
29
+ DARK_RATIO_THRESHOLD = 0.50
30
+ BRIGHT_DOM_RATIO_MIN = 0.30
31
+ DARK_DOM_RATIO_MIN = 0.30
32
+ BRIGHT_DARK_RATIO_MAX = 0.10
33
+ DARK_BRIGHT_RATIO_MAX = 0.10
34
+
35
+
36
+ def _normalize_mask(mask: np.ndarray) -> np.ndarray:
37
+ """
38
+ Ensure mask is uint8 binary (0/255).
39
+ """
40
+ if mask.dtype != np.uint8:
41
+ mask = mask.astype(np.uint8)
42
+ return np.where(mask > 0, 255, 0).astype(np.uint8)
43
+
44
+
45
+ def process_single_bubble(
46
+ base_mask,
47
+ img_gray,
48
+ img_height,
49
+ img_width,
50
+ thresholding_value,
51
+ use_otsu_threshold,
52
+ roi_shrink_px,
53
+ verbose,
54
+ detection_bbox=None,
55
+ is_sam=False,
56
+ dilation_kernel=None,
57
+ constraint_erosion_kernel=None,
58
+ min_contour_area: float = MIN_CONTOUR_AREA,
59
+ classify_colored: bool = False,
60
+ ):
61
+ """
62
+ Process a single speech bubble mask to extract text regions and determine fill color.
63
+
64
+ Args:
65
+ base_mask (numpy.ndarray): The base mask (SAM or YOLO) for the bubble
66
+ img_gray (numpy.ndarray): Grayscale image
67
+ img_height (int): Image height
68
+ img_width (int): Image width
69
+ thresholding_value (int): Fixed threshold value for text detection
70
+ use_otsu_threshold (bool): Whether to use Otsu's method for thresholding
71
+ roi_shrink_px (int): Pixels to shrink ROI inwards
72
+ verbose (bool): Whether to print verbose messages
73
+ detection_bbox: Bounding box for logging (optional)
74
+ is_sam (bool): Whether this is a SAM mask (for logging)
75
+
76
+ Returns:
77
+ tuple: (final_mask, fill_color_bgr, is_colored, sample_color_bgr, text_bbox)
78
+
79
+ Raises:
80
+ CleaningError: If processing fails
81
+ """
82
+ try:
83
+ base_mask = _normalize_mask(base_mask)
84
+
85
+ if dilation_kernel is None:
86
+ dilation_kernel = cv2.getStructuringElement(
87
+ cv2.MORPH_ELLIPSE, DILATION_KERNEL_SIZE
88
+ )
89
+ if constraint_erosion_kernel is None:
90
+ constraint_erosion_kernel = cv2.getStructuringElement(
91
+ cv2.MORPH_ELLIPSE, EROSION_KERNEL_SIZE
92
+ )
93
+ masked_pixels = img_gray[base_mask == 255]
94
+ if masked_pixels.size == 0:
95
+ log_message(
96
+ f"{'[SAM]' if is_sam else ''}Skipping detection {detection_bbox}: empty mask",
97
+ verbose=verbose,
98
+ )
99
+ raise CleaningError(f"Empty mask for detection {detection_bbox}")
100
+
101
+ mean_pixel_value = np.mean(masked_pixels)
102
+ is_black_bubble = mean_pixel_value < GRAYSCALE_MIDPOINT
103
+ fill_color_bgr = (0, 0, 0) if is_black_bubble else (255, 255, 255)
104
+ is_colored_bubble = False
105
+ sample_color_bgr: tuple[int, int, int] = fill_color_bgr
106
+
107
+ log_message(
108
+ f"{'[SAM]' if is_sam else ''}Detection {detection_bbox}: "
109
+ f"{'Black' if is_black_bubble else 'White'} bubble (mean={mean_pixel_value:.1f})",
110
+ verbose=verbose,
111
+ )
112
+
113
+ roi_mask = cv2.dilate(base_mask, dilation_kernel, iterations=1)
114
+ roi_gray = np.zeros_like(img_gray)
115
+ roi_indices = roi_mask == 255
116
+ roi_gray[roi_indices] = img_gray[roi_indices]
117
+
118
+ # Invert for black bubbles to detect text properly
119
+ roi_for_thresholding = (
120
+ cv2.bitwise_not(roi_gray) if is_black_bubble else roi_gray
121
+ )
122
+ thresholded_roi = np.zeros_like(img_gray)
123
+
124
+ if use_otsu_threshold:
125
+ roi_pixels_for_otsu = roi_for_thresholding[roi_indices]
126
+ thresh_val, _ = cv2.threshold(
127
+ roi_pixels_for_otsu, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU
128
+ )
129
+ log_message(
130
+ f"{'[SAM]' if is_sam else ''} Otsu threshold: {thresh_val}",
131
+ verbose=verbose,
132
+ )
133
+ _, thresholded_roi = cv2.threshold(
134
+ roi_for_thresholding, thresh_val, 255, cv2.THRESH_BINARY
135
+ )
136
+ else:
137
+ _, thresholded_roi = cv2.threshold(
138
+ roi_for_thresholding, thresholding_value, 255, cv2.THRESH_BINARY
139
+ )
140
+
141
+ thresholded_roi = cv2.bitwise_and(thresholded_roi, roi_mask)
142
+
143
+ # Shrink ROI to avoid border artifacts
144
+ dist_map = cv2.distanceTransform(
145
+ roi_mask, cv2.DIST_L2, DISTANCE_TRANSFORM_MASK_SIZE
146
+ )
147
+ shrunk_roi_mask = np.where(dist_map >= float(roi_shrink_px), 255, 0).astype(
148
+ np.uint8
149
+ )
150
+ thresholded_roi = cv2.bitwise_and(thresholded_roi, shrunk_roi_mask)
151
+
152
+ # Use eroded mask to avoid erasing bubble outlines
153
+ eroded_constraint_mask = cv2.erode(
154
+ base_mask, constraint_erosion_kernel, iterations=1
155
+ )
156
+
157
+ contours, _ = cv2.findContours(
158
+ thresholded_roi, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
159
+ )
160
+ valid_contours = []
161
+ for cnt in contours:
162
+ area = cv2.contourArea(cnt)
163
+ if area <= min_contour_area:
164
+ continue
165
+ m = cv2.moments(cnt)
166
+ if m["m00"] == 0:
167
+ continue
168
+ cx = int(m["m10"] / m["m00"])
169
+ cy = int(m["m01"] / m["m00"])
170
+ if (
171
+ 0 <= cx < img_width
172
+ and 0 <= cy < img_height
173
+ and eroded_constraint_mask[cy, cx] == 255
174
+ ):
175
+ valid_contours.append(cnt)
176
+
177
+ log_message(
178
+ f"{'[SAM]' if is_sam else ''}Detection {detection_bbox}: {len(valid_contours)} text fragments found",
179
+ verbose=verbose,
180
+ )
181
+
182
+ text_bbox = None
183
+ if valid_contours:
184
+ validated_mask = np.zeros((img_height, img_width), dtype=np.uint8)
185
+ cv2.drawContours(
186
+ validated_mask, valid_contours, -1, 255, thickness=cv2.FILLED
187
+ )
188
+
189
+ # Re-contour to get clean boundary from validated mask
190
+ boundary_contours, _ = cv2.findContours(
191
+ validated_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
192
+ )
193
+ if boundary_contours:
194
+ largest_contour = max(boundary_contours, key=cv2.contourArea)
195
+ final_mask = np.zeros((img_height, img_width), dtype=np.uint8)
196
+ cv2.drawContours(
197
+ final_mask, [largest_contour], -1, 255, thickness=cv2.FILLED
198
+ )
199
+ x, y, w, h = cv2.boundingRect(largest_contour)
200
+ text_bbox = (x, y, x + w, y + h)
201
+
202
+ if classify_colored:
203
+ # Sample bubble interior excluding text box and outline to determine if colored
204
+ sampling_mask = cv2.erode(
205
+ base_mask, constraint_erosion_kernel, iterations=2
206
+ )
207
+ if text_bbox:
208
+ x1, y1, x2, y2 = text_bbox
209
+ x1 = max(0, x1)
210
+ y1 = max(0, y1)
211
+ x2 = min(img_width, x2)
212
+ y2 = min(img_height, y2)
213
+ sampling_mask[y1:y2, x1:x2] = 0
214
+ sample_pixels = img_gray[sampling_mask == 255]
215
+ if sample_pixels.size == 0:
216
+ sample_pixels = masked_pixels
217
+
218
+ sample_values = sample_pixels.astype(np.uint8).flatten()
219
+ hist = np.bincount(sample_values, minlength=256)
220
+ dominant_val = (
221
+ int(hist.argmax()) if hist.size > 0 else int(mean_pixel_value)
222
+ )
223
+ dominant_count = int(hist.max()) if hist.size > 0 else 0
224
+ total_count = max(int(sample_values.size), 1)
225
+ dominant_ratio = dominant_count / float(total_count)
226
+ bright_ratio = float(
227
+ np.count_nonzero(sample_values >= 245)
228
+ ) / float(total_count)
229
+ dark_ratio = float(np.count_nonzero(sample_values <= 15)) / float(
230
+ total_count
231
+ )
232
+
233
+ log_prefix = "[SAM] " if is_sam else ""
234
+ if bright_ratio >= BRIGHT_RATIO_THRESHOLD or (
235
+ dominant_val >= 245
236
+ and dominant_ratio >= BRIGHT_DOM_RATIO_MIN
237
+ and dark_ratio <= BRIGHT_DARK_RATIO_MAX
238
+ ):
239
+ is_colored_bubble = False
240
+ fill_color_bgr = (255, 255, 255)
241
+ sample_color_bgr = (255, 255, 255)
242
+ log_message(
243
+ f"{log_prefix}Detection {detection_bbox}: white "
244
+ f"(mode={dominant_val}, dom_ratio={dominant_ratio:.2f}, "
245
+ f"bright_ratio={bright_ratio:.2f}, dark_ratio={dark_ratio:.2f})",
246
+ verbose=verbose,
247
+ )
248
+ elif dark_ratio >= DARK_RATIO_THRESHOLD or (
249
+ dominant_val <= 15
250
+ and dominant_ratio >= DARK_DOM_RATIO_MIN
251
+ and bright_ratio <= DARK_BRIGHT_RATIO_MAX
252
+ ):
253
+ is_colored_bubble = False
254
+ fill_color_bgr = (0, 0, 0)
255
+ sample_color_bgr = (0, 0, 0)
256
+ log_message(
257
+ f"{log_prefix}Detection {detection_bbox}: black "
258
+ f"(mode={dominant_val}, dom_ratio={dominant_ratio:.2f}, "
259
+ f"bright_ratio={bright_ratio:.2f}, dark_ratio={dark_ratio:.2f})",
260
+ verbose=verbose,
261
+ )
262
+ else:
263
+ is_colored_bubble = True
264
+ sample_color_bgr = (dominant_val, dominant_val, dominant_val)
265
+ log_message(
266
+ f"{log_prefix}Detection {detection_bbox}: "
267
+ f"colored/gradient (mode={dominant_val}, "
268
+ f"dom_ratio={dominant_ratio:.2f}, "
269
+ f"bright_ratio={bright_ratio:.2f}, "
270
+ f"dark_ratio={dark_ratio:.2f})",
271
+ verbose=verbose,
272
+ )
273
+ return (
274
+ final_mask,
275
+ fill_color_bgr,
276
+ is_colored_bubble,
277
+ sample_color_bgr,
278
+ text_bbox,
279
+ )
280
+
281
+ raise CleaningError("Failed to process bubble mask")
282
+
283
+ except Exception as e:
284
+ log_message(
285
+ f"Failed to process {'SAM' if is_sam else 'YOLO'} mask for {detection_bbox}",
286
+ always_print=True,
287
+ )
288
+ raise CleaningError("Failed to process bubble mask") from e
289
+
290
+
291
+ def clean_speech_bubbles(
292
+ image_input: Union[str, Path, Image.Image],
293
+ model_path,
294
+ confidence=0.6,
295
+ pre_computed_detections=None,
296
+ device=None,
297
+ thresholding_value: int = 190,
298
+ use_otsu_threshold: bool = False,
299
+ roi_shrink_px: int = 4,
300
+ verbose: bool = False,
301
+ processing_scale: float = 1.0,
302
+ conjoined_confidence=0.35,
303
+ inpaint_colored_bubbles: bool = False,
304
+ flux_hf_token: str = "",
305
+ flux_num_inference_steps: int = 10,
306
+ flux_residual_diff_threshold: float = 0.15,
307
+ flux_seed: int = 1,
308
+ osb_text_verification: bool = False,
309
+ osb_text_hf_token: str = "",
310
+ force_cv2_inpainting: bool = False,
311
+ ):
312
+ """
313
+ Clean speech bubbles using YOLO/SAM masks and optional Flux inpainting for colored bubbles.
314
+
315
+ Args:
316
+ image_input (str, Path, or PIL.Image.Image): Path to input image or a PIL Image object.
317
+ model_path (str): Path to YOLO model.
318
+ confidence (float): Confidence threshold for detections.
319
+ pre_computed_detections (list, optional): Pre-computed detections from previous call.
320
+ device (torch.device, optional): The device to run detection model on if needed.
321
+ thresholding_value (int): Fixed threshold value for text detection (0-255). Lower values (e.g., 190)
322
+ are useful for uncleaned text close to bubble's edges.
323
+ use_otsu_threshold (bool): If True, use Otsu's method for thresholding instead of the fixed value.
324
+ roi_shrink_px (int): Number of pixels to shrink the ROI inwards before identification/fill.
325
+ inpaint_colored_bubbles (bool): If True, detect non-white/black bubbles and inpaint text with Flux.
326
+ flux_hf_token (str): Hugging Face token for Flux downloads (shared with outside-text removal).
327
+ flux_num_inference_steps (int): Flux denoising steps for colored bubble inpainting.
328
+ flux_residual_diff_threshold (float): Flux residual diff threshold for caching.
329
+ flux_seed (int): Seed for Flux; -1 enables random per run.
330
+ osb_text_verification (bool): When True, expand bubble boxes to fully cover OSB text detections.
331
+ osb_text_hf_token (str): Optional token for OSB text model downloads.
332
+ force_cv2_inpainting (bool): If True, skip Flux inpainting even for colored bubbles and use standard fill.
333
+
334
+ Returns:
335
+ numpy.ndarray: Cleaned image with text removed.
336
+ list[dict]: A list of dictionaries per bubble containing:
337
+ - 'mask' (np.ndarray): validated text mask (0/255)
338
+ - 'base_mask' (np.ndarray): normalized detection mask used for processing
339
+ - 'color' (tuple BGR): sampled bubble color
340
+ - 'bbox' (tuple): detection bounding box
341
+ - 'is_colored' (bool): whether bubble interior was classified colored
342
+ - 'text_bbox' (tuple|None): bounding box of detected text mask
343
+ - 'is_sam' (bool): whether detection originated from SAM
344
+ Raises:
345
+ ValueError: If the image cannot be loaded or if an image object is passed without pre-computed detections.
346
+ RuntimeError: If model loading or bubble detection fails.
347
+ """
348
+ try:
349
+ if isinstance(image_input, (str, Path)):
350
+ pil_image = Image.open(image_input)
351
+ image_path = image_input
352
+ else:
353
+ pil_image = image_input
354
+ image_path = None # In-memory image has no path
355
+
356
+ image = pil_to_cv2(pil_image)
357
+ img_height, img_width = image.shape[:2]
358
+ img_gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
359
+
360
+ cleaned_image = image.copy()
361
+
362
+ if pre_computed_detections is not None:
363
+ detections = pre_computed_detections
364
+ elif image_path is not None:
365
+ detection_result = detect_speech_bubbles(
366
+ image_path,
367
+ model_path,
368
+ confidence,
369
+ device=device,
370
+ conjoined_confidence=conjoined_confidence,
371
+ osb_text_verification=osb_text_verification,
372
+ osb_text_hf_token=osb_text_hf_token,
373
+ )
374
+ detections = (
375
+ detection_result[0]
376
+ if isinstance(detection_result, tuple)
377
+ else detection_result
378
+ )
379
+ else:
380
+ raise ValidationError(
381
+ "Bubble detection requires an image path, but an image object "
382
+ "was provided without pre-computed detections."
383
+ )
384
+
385
+ processed_bubbles = []
386
+
387
+ effective_roi_shrink_px = float(
388
+ scale_scalar(
389
+ roi_shrink_px,
390
+ processing_scale,
391
+ minimum=0.0,
392
+ maximum=64.0,
393
+ )
394
+ )
395
+ dilation_kernel = cv2.getStructuringElement(
396
+ cv2.MORPH_ELLIPSE, scale_kernel(DILATION_KERNEL_SIZE, processing_scale)
397
+ )
398
+ constraint_erosion_kernel = cv2.getStructuringElement(
399
+ cv2.MORPH_ELLIPSE, scale_kernel(EROSION_KERNEL_SIZE, processing_scale)
400
+ )
401
+ min_contour_area = scale_area(
402
+ MIN_CONTOUR_AREA,
403
+ processing_scale,
404
+ minimum=MIN_CONTOUR_AREA,
405
+ maximum=5000,
406
+ )
407
+ for detection in detections:
408
+ final_mask = None
409
+ fill_color_bgr = None
410
+ is_colored_bubble = False
411
+ sample_color_bgr: Optional[tuple[int, int, int]] = None
412
+ text_bbox: Optional[tuple[int, int, int, int]] = None
413
+ base_mask = None
414
+ is_sam_mask = False
415
+
416
+ sam_mask = detection.get("sam_mask")
417
+ if sam_mask is not None:
418
+ base_mask = _normalize_mask(sam_mask)
419
+ is_sam_mask = True
420
+ try:
421
+ (
422
+ final_mask,
423
+ fill_color_bgr,
424
+ is_colored_bubble,
425
+ sample_color_bgr,
426
+ text_bbox,
427
+ ) = process_single_bubble(
428
+ base_mask,
429
+ img_gray,
430
+ img_height,
431
+ img_width,
432
+ thresholding_value,
433
+ use_otsu_threshold,
434
+ effective_roi_shrink_px,
435
+ verbose,
436
+ detection.get("bbox"),
437
+ is_sam=True,
438
+ dilation_kernel=dilation_kernel,
439
+ constraint_erosion_kernel=constraint_erosion_kernel,
440
+ min_contour_area=min_contour_area,
441
+ classify_colored=inpaint_colored_bubbles,
442
+ )
443
+ except Exception as e:
444
+ retry_success = False
445
+ if not use_otsu_threshold and base_mask is not None:
446
+ log_message(
447
+ f"Standard cleaning failed for {detection.get('bbox')}, retrying with Otsu...",
448
+ verbose=verbose,
449
+ )
450
+ retry_res = retry_cleaning_with_otsu(
451
+ image,
452
+ {
453
+ "base_mask": base_mask,
454
+ "bbox": detection.get("bbox"),
455
+ "is_sam": True,
456
+ },
457
+ thresholding_value,
458
+ roi_shrink_px,
459
+ processing_scale,
460
+ verbose,
461
+ inpaint_colored_bubbles,
462
+ )
463
+ if retry_res:
464
+ final_mask = retry_res["mask"]
465
+ fill_color_bgr = retry_res["color"]
466
+ sample_color_bgr = retry_res["color"]
467
+ is_colored_bubble = retry_res["is_colored"]
468
+ text_bbox = retry_res["text_bbox"]
469
+ retry_success = True
470
+ log_message(
471
+ f"Otsu retry successful for {detection.get('bbox')}",
472
+ verbose=verbose,
473
+ )
474
+ else:
475
+ log_message(
476
+ f"Otsu retry failed for {detection.get('bbox')}",
477
+ verbose=verbose,
478
+ )
479
+
480
+ if not retry_success:
481
+ error_msg = f"Error processing SAM mask for detection {detection.get('bbox')}: {e}"
482
+ log_message(error_msg, always_print=True)
483
+ continue
484
+ else:
485
+ if "mask_points" not in detection or not detection["mask_points"]:
486
+ log_message(
487
+ f"Skipping detection {detection.get('bbox')}: no mask points",
488
+ verbose=verbose,
489
+ )
490
+ continue
491
+
492
+ try:
493
+ points_list = detection["mask_points"]
494
+ points = np.array(points_list, dtype=np.float32)
495
+
496
+ if len(points.shape) == 3 and points.shape[1] == 1:
497
+ points_int = np.round(points).astype(int)
498
+ elif len(points.shape) == 2 and points.shape[1] == 2:
499
+ points_int = np.round(points).astype(int).reshape((-1, 1, 2))
500
+ else:
501
+ log_message(
502
+ f"Skipping detection {detection.get('bbox')}: invalid mask format",
503
+ verbose=verbose,
504
+ )
505
+ continue
506
+
507
+ yolo_mask = np.zeros((img_height, img_width), dtype=np.uint8)
508
+ cv2.fillPoly(yolo_mask, [points_int], 255)
509
+ base_mask = _normalize_mask(yolo_mask)
510
+
511
+ (
512
+ final_mask,
513
+ fill_color_bgr,
514
+ is_colored_bubble,
515
+ sample_color_bgr,
516
+ text_bbox,
517
+ ) = process_single_bubble(
518
+ base_mask,
519
+ img_gray,
520
+ img_height,
521
+ img_width,
522
+ thresholding_value,
523
+ use_otsu_threshold,
524
+ effective_roi_shrink_px,
525
+ verbose,
526
+ detection.get("bbox"),
527
+ is_sam=False,
528
+ dilation_kernel=dilation_kernel,
529
+ constraint_erosion_kernel=constraint_erosion_kernel,
530
+ min_contour_area=min_contour_area,
531
+ classify_colored=inpaint_colored_bubbles,
532
+ )
533
+
534
+ except Exception as e:
535
+ retry_success = False
536
+ if not use_otsu_threshold and base_mask is not None:
537
+ log_message(
538
+ f"Standard cleaning failed for {detection.get('bbox')}, retrying with Otsu...",
539
+ verbose=verbose,
540
+ )
541
+ retry_res = retry_cleaning_with_otsu(
542
+ image,
543
+ {
544
+ "base_mask": base_mask,
545
+ "bbox": detection.get("bbox"),
546
+ "is_sam": False,
547
+ },
548
+ thresholding_value,
549
+ roi_shrink_px,
550
+ processing_scale,
551
+ verbose,
552
+ inpaint_colored_bubbles,
553
+ )
554
+ if retry_res:
555
+ final_mask = retry_res["mask"]
556
+ fill_color_bgr = retry_res["color"]
557
+ sample_color_bgr = retry_res["color"]
558
+ is_colored_bubble = retry_res["is_colored"]
559
+ text_bbox = retry_res["text_bbox"]
560
+ retry_success = True
561
+ log_message(
562
+ f"Otsu retry successful for {detection.get('bbox')}",
563
+ verbose=verbose,
564
+ )
565
+ else:
566
+ log_message(
567
+ f"Otsu retry failed for {detection.get('bbox')}",
568
+ verbose=verbose,
569
+ )
570
+
571
+ if not retry_success:
572
+ error_msg = f"Error processing YOLO mask for detection {detection.get('bbox')}: {e}"
573
+ log_message(error_msg, always_print=True)
574
+ continue
575
+
576
+ if final_mask is not None and fill_color_bgr is not None:
577
+ processed_bubbles.append(
578
+ {
579
+ "mask": final_mask,
580
+ "base_mask": base_mask,
581
+ "color": (
582
+ sample_color_bgr if sample_color_bgr else fill_color_bgr
583
+ ),
584
+ "bbox": detection.get("bbox"),
585
+ "is_colored": is_colored_bubble,
586
+ "text_bbox": text_bbox,
587
+ "is_sam": is_sam_mask,
588
+ "inpainted": False,
589
+ }
590
+ )
591
+ log_message(
592
+ f"Detection {detection.get('bbox')}: processed successfully",
593
+ verbose=verbose,
594
+ )
595
+
596
+ # Optional Flux inpainting for colored bubbles (text-only mask)
597
+ if inpaint_colored_bubbles:
598
+ colored_bubbles = [
599
+ b for b in processed_bubbles if b.get("is_colored", False)
600
+ ]
601
+ if colored_bubbles and flux_hf_token and not force_cv2_inpainting:
602
+ log_message(
603
+ f"Inpainting {len(colored_bubbles)} colored bubbles with Flux",
604
+ always_print=True,
605
+ )
606
+ pil_working = Image.fromarray(
607
+ cv2.cvtColor(cleaned_image, cv2.COLOR_BGR2RGB)
608
+ )
609
+ base_seed = (
610
+ random.randint(1, 999999)
611
+ if flux_seed == -1
612
+ else max(0, int(flux_seed))
613
+ )
614
+ temp_files = []
615
+ try:
616
+ inpainter = FluxKontextInpainter(
617
+ device=device,
618
+ huggingface_token=flux_hf_token,
619
+ num_inference_steps=int(flux_num_inference_steps),
620
+ residual_diff_threshold=float(flux_residual_diff_threshold),
621
+ )
622
+ for idx, bubble_info in enumerate(colored_bubbles):
623
+ mask_np = bubble_info["mask"]
624
+ mask_bool = mask_np.astype(bool)
625
+ region_seed = base_seed + idx if base_seed > 0 else base_seed
626
+ bbox_tuple = bubble_info.get("bbox")
627
+ ocr_params = {"type": "colored_bubble", "bbox": bbox_tuple}
628
+ try:
629
+ pil_working = inpainter.inpaint_mask(
630
+ pil_working,
631
+ mask_bool,
632
+ seed=region_seed,
633
+ verbose=verbose,
634
+ ocr_params=ocr_params,
635
+ )
636
+ bubble_info["inpainted"] = True
637
+ # Re-sample background brightness after inpaint for accurate text contrast
638
+ cv_after = cv2.cvtColor(
639
+ np.array(pil_working.convert("RGB")), cv2.COLOR_RGB2BGR
640
+ )
641
+ masked_after = cv_after[mask_bool]
642
+ if masked_after.size > 0:
643
+ mean_val = int(np.clip(np.mean(masked_after), 0, 255))
644
+ bubble_info["color"] = (mean_val, mean_val, mean_val)
645
+ except Exception as e:
646
+ log_message(
647
+ f"Flux inpainting failed for bubble {bbox_tuple}: {e}; falling back to standard fill",
648
+ always_print=True,
649
+ )
650
+ continue
651
+
652
+ # Save intermediate result to disk to free memory when multiple regions
653
+ if idx < len(colored_bubbles) - 1:
654
+ temp_file = None
655
+ try:
656
+ temp_fd, temp_file = tempfile.mkstemp(suffix=".png")
657
+ os.close(temp_fd)
658
+ pil_working.save(temp_file, format="PNG")
659
+ log_message(
660
+ "Saved intermediate inpainting result to disk",
661
+ verbose=verbose,
662
+ )
663
+ temp_files.append(temp_file)
664
+ with Image.open(temp_file) as img_tmp:
665
+ img_tmp.load()
666
+ pil_working = img_tmp.copy()
667
+ gc.collect()
668
+ except Exception as e:
669
+ log_message(
670
+ f"Warning: Failed to save intermediate inpainting result: {e}",
671
+ verbose=verbose,
672
+ )
673
+ if temp_file and temp_file in temp_files:
674
+ temp_files.remove(temp_file)
675
+ # fall through with in-memory image
676
+
677
+ cleaned_image = cv2.cvtColor(
678
+ np.array(pil_working.convert("RGB")), cv2.COLOR_RGB2BGR
679
+ )
680
+ except Exception as e:
681
+ log_message(
682
+ f"Flux inpainting aborted; falling back to standard fill: {e}",
683
+ always_print=True,
684
+ )
685
+ finally:
686
+ for temp_file in temp_files:
687
+ if temp_file and os.path.exists(temp_file):
688
+ try:
689
+ os.remove(temp_file)
690
+ except Exception:
691
+ pass
692
+ elif colored_bubbles:
693
+ reason = (
694
+ "forced CV2 inpainting"
695
+ if force_cv2_inpainting
696
+ else "missing Hugging Face token"
697
+ )
698
+ log_message(
699
+ f"Colored bubbles detected but Flux inpainting skipped ({reason}); "
700
+ "falling back to standard fill",
701
+ always_print=True,
702
+ )
703
+
704
+ # Group masks by color for efficient batch processing (skip already inpainted regions)
705
+ if processed_bubbles:
706
+ color_groups = {}
707
+ for bubble_info in processed_bubbles:
708
+ if bubble_info.get("inpainted", False):
709
+ continue
710
+ color_key = bubble_info["color"]
711
+ if color_key not in color_groups:
712
+ color_groups[color_key] = []
713
+ color_groups[color_key].append(bubble_info["mask"])
714
+
715
+ for color_bgr, masks in color_groups.items():
716
+ combined_mask = np.bitwise_or.reduce(masks)
717
+
718
+ if cleaned_image.shape[2] == 4:
719
+ cleaned_image[combined_mask == 255, :3] = (
720
+ color_bgr # Preserve alpha channel
721
+ )
722
+ else:
723
+ cleaned_image[combined_mask == 255] = color_bgr
724
+
725
+ log_message(
726
+ f"Cleaned {len(processed_bubbles)} speech bubbles", always_print=True
727
+ )
728
+ return cleaned_image, processed_bubbles
729
+ except IOError as e:
730
+ raise ImageProcessingError(f"Error loading image {image_input}: {str(e)}")
731
+ except Exception as e:
732
+ raise CleaningError(f"Error cleaning speech bubbles: {str(e)}")
733
+
734
+
735
+ def retry_cleaning_with_otsu(
736
+ image_bgr: np.ndarray,
737
+ bubble_info: dict,
738
+ thresholding_value: int,
739
+ roi_shrink_px: int,
740
+ processing_scale: float = 1.0,
741
+ verbose: bool = False,
742
+ classify_colored: bool = False,
743
+ ) -> Optional[dict]:
744
+ """
745
+ Retry cleaning for a single bubble using Otsu thresholding.
746
+
747
+ Returns a bubble-info dict compatible with clean_speech_bubbles output,
748
+ or None if retry fails.
749
+ """
750
+ base_mask = bubble_info.get("base_mask")
751
+ if base_mask is None:
752
+ log_message(
753
+ f"Otsu retry skipped for {bubble_info.get('bbox')}: missing base_mask",
754
+ verbose=verbose,
755
+ )
756
+ return None
757
+
758
+ try:
759
+ if len(image_bgr.shape) == 3 and image_bgr.shape[2] == 4:
760
+ img_gray = cv2.cvtColor(image_bgr, cv2.COLOR_BGRA2GRAY)
761
+ else:
762
+ img_gray = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2GRAY)
763
+ except Exception as e:
764
+ log_message(
765
+ f"Otsu retry failed to convert image to grayscale: {e}",
766
+ always_print=True,
767
+ )
768
+ return None
769
+
770
+ img_height, img_width = img_gray.shape[:2]
771
+
772
+ effective_roi_shrink_px = float(
773
+ scale_scalar(
774
+ roi_shrink_px,
775
+ processing_scale,
776
+ minimum=0.0,
777
+ maximum=64.0,
778
+ )
779
+ )
780
+ dilation_kernel = cv2.getStructuringElement(
781
+ cv2.MORPH_ELLIPSE, scale_kernel(DILATION_KERNEL_SIZE, processing_scale)
782
+ )
783
+ constraint_erosion_kernel = cv2.getStructuringElement(
784
+ cv2.MORPH_ELLIPSE, scale_kernel(EROSION_KERNEL_SIZE, processing_scale)
785
+ )
786
+ min_contour_area = scale_area(
787
+ MIN_CONTOUR_AREA,
788
+ processing_scale,
789
+ minimum=MIN_CONTOUR_AREA,
790
+ maximum=5000,
791
+ )
792
+
793
+ try:
794
+ result = process_single_bubble(
795
+ base_mask,
796
+ img_gray,
797
+ img_height,
798
+ img_width,
799
+ thresholding_value,
800
+ True, # force Otsu
801
+ effective_roi_shrink_px,
802
+ verbose,
803
+ bubble_info.get("bbox"),
804
+ bubble_info.get("is_sam", False),
805
+ dilation_kernel=dilation_kernel,
806
+ constraint_erosion_kernel=constraint_erosion_kernel,
807
+ min_contour_area=min_contour_area,
808
+ classify_colored=classify_colored,
809
+ )
810
+ except CleaningError as e:
811
+ log_message(
812
+ f"Otsu retry cleaning failed for {bubble_info.get('bbox')}: {e}",
813
+ always_print=True,
814
+ )
815
+ return None
816
+ except Exception as e:
817
+ log_message(
818
+ f"Otsu retry cleaning unexpected error for {bubble_info.get('bbox')}: {e}",
819
+ always_print=True,
820
+ )
821
+ return None
822
+
823
+ if not result:
824
+ return None
825
+
826
+ (
827
+ final_mask,
828
+ fill_color_bgr,
829
+ is_colored_bubble,
830
+ sample_color_bgr,
831
+ text_bbox,
832
+ ) = result
833
+
834
+ bubble_color = sample_color_bgr if sample_color_bgr else fill_color_bgr
835
+
836
+ log_message(
837
+ f"Otsu retry succeeded for {bubble_info.get('bbox')}",
838
+ verbose=verbose,
839
+ )
840
+
841
+ return {
842
+ "mask": final_mask,
843
+ "base_mask": _normalize_mask(base_mask),
844
+ "color": bubble_color,
845
+ "bbox": bubble_info.get("bbox"),
846
+ "is_colored": is_colored_bubble,
847
+ "text_bbox": text_bbox,
848
+ "is_sam": bubble_info.get("is_sam", False),
849
+ }
core/image/detection.py ADDED
@@ -0,0 +1,914 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from typing import List, Optional, Tuple
3
+
4
+ import cv2
5
+ import numpy as np
6
+ import torch
7
+ from PIL import Image
8
+
9
+ from core.caching import get_cache
10
+ from core.ml.model_manager import ModelType, get_model_manager
11
+ from utils.exceptions import ImageProcessingError, ModelError
12
+ from utils.logging import log_message
13
+
14
+ # Detection Parameters
15
+ IOA_THRESHOLD = 0.50 # 50% IoA threshold for conjoined bubble detection
16
+ SAM_MASK_THRESHOLD = 0.5 # SAM2 mask binarization threshold
17
+ IOA_OVERLAP_THRESHOLD = 0.5 # IoA threshold for general overlap detection between boxes
18
+ IOU_DUPLICATE_THRESHOLD = 0.7 # IoU threshold for duplicate primary detection
19
+
20
+
21
+ def _box_contains(inner, outer) -> bool:
22
+ """Return True if inner box is fully contained in outer box."""
23
+ ix0, iy0, ix1, iy1 = inner
24
+ ox0, oy0, ox1, oy1 = outer
25
+ return ix0 >= ox0 and iy0 >= oy0 and ix1 <= ox1 and iy1 <= oy1
26
+
27
+
28
+ def _expand_boxes_with_osb_text(
29
+ image_cv,
30
+ image_pil,
31
+ primary_boxes: torch.Tensor,
32
+ cache,
33
+ model_manager,
34
+ device,
35
+ confidence: float,
36
+ hf_token: str,
37
+ verbose: bool,
38
+ ):
39
+ """Expand speech-bubble boxes to fully contain detected OSB text boxes."""
40
+ if primary_boxes is None or len(primary_boxes) == 0:
41
+ return primary_boxes
42
+
43
+ try:
44
+ model_path = str(model_manager.model_paths[ModelType.YOLO_OSBTEXT])
45
+ cache_key = cache.get_yolo_cache_key(image_pil, model_path, confidence)
46
+ cached = cache.get_yolo_detection(cache_key)
47
+
48
+ if cached is not None:
49
+ _, osb_boxes, _ = cached
50
+ else:
51
+ osb_model = model_manager.load_yolo_osbtext(token=hf_token)
52
+ osb_results = osb_model(
53
+ image_cv, conf=confidence, device=device, verbose=False
54
+ )[0]
55
+ osb_boxes = (
56
+ osb_results.boxes.xyxy
57
+ if osb_results.boxes is not None
58
+ else torch.tensor([])
59
+ )
60
+ osb_confs = (
61
+ osb_results.boxes.conf
62
+ if osb_results.boxes is not None
63
+ else torch.tensor([])
64
+ )
65
+ cache.set_yolo_detection(cache_key, (osb_results, osb_boxes, osb_confs))
66
+
67
+ if osb_boxes is None or len(osb_boxes) == 0:
68
+ return primary_boxes
69
+
70
+ pb_np = primary_boxes.detach().cpu().numpy()
71
+ osb_np = osb_boxes.detach().cpu().numpy()
72
+
73
+ for t_box in osb_np:
74
+ tx0, ty0, tx1, ty1 = t_box
75
+ best_idx = None
76
+ best_intersection = 0.0
77
+
78
+ for i, b_box in enumerate(pb_np):
79
+ bx0, by0, bx1, by1 = b_box
80
+ inter_x0 = max(bx0, tx0)
81
+ inter_y0 = max(by0, ty0)
82
+ inter_x1 = min(bx1, tx1)
83
+ inter_y1 = min(by1, ty1)
84
+ inter_w = max(0.0, inter_x1 - inter_x0)
85
+ inter_h = max(0.0, inter_y1 - inter_y0)
86
+ intersection = inter_w * inter_h
87
+ if intersection > best_intersection:
88
+ best_intersection = intersection
89
+ best_idx = i
90
+
91
+ if best_idx is None or best_intersection <= 0.0:
92
+ continue
93
+
94
+ if _box_contains(t_box, pb_np[best_idx]):
95
+ continue
96
+
97
+ bx0, by0, bx1, by1 = pb_np[best_idx]
98
+ pb_np[best_idx] = [
99
+ min(bx0, tx0),
100
+ min(by0, ty0),
101
+ max(bx1, tx1),
102
+ max(by1, ty1),
103
+ ]
104
+
105
+ return torch.tensor(
106
+ pb_np, device=primary_boxes.device, dtype=primary_boxes.dtype
107
+ )
108
+ except Exception as e:
109
+ log_message(f"OSB text verification skipped: {e}", verbose=verbose)
110
+ return primary_boxes
111
+
112
+
113
+ def _calculate_ioa(box_inner, box_outer):
114
+ """Calculate Intersection over Area (IoA) for two bounding boxes.
115
+
116
+ IoA = intersection_area / area_of_inner_box
117
+
118
+ Args:
119
+ box_inner: Tuple or list of (x0, y0, x1, y1) for the inner box
120
+ box_outer: Tuple or list of (x0, y0, x1, y1) for the outer box
121
+
122
+ Returns:
123
+ float: IoA value between 0 and 1
124
+ """
125
+ x_inner_min, y_inner_min, x_inner_max, y_inner_max = box_inner
126
+ x_outer_min, y_outer_min, x_outer_max, y_outer_max = box_outer
127
+
128
+ inter_x_min = max(x_inner_min, x_outer_min)
129
+ inter_y_min = max(y_inner_min, y_outer_min)
130
+ inter_x_max = min(x_inner_max, x_outer_max)
131
+ inter_y_max = min(y_inner_max, y_outer_max)
132
+
133
+ inter_w = max(0, inter_x_max - inter_x_min)
134
+ inter_h = max(0, inter_y_max - inter_y_min)
135
+ intersection = inter_w * inter_h
136
+
137
+ area_inner = (x_inner_max - x_inner_min) * (y_inner_max - y_inner_min)
138
+ return intersection / area_inner if area_inner > 0 else 0.0
139
+
140
+
141
+ def _calculate_iou(box_a, box_b):
142
+ """Calculate Intersection over Union (IoU) for two bounding boxes.
143
+
144
+ IoU = intersection_area / union_area
145
+
146
+ Args:
147
+ box_a: Tuple of (x0, y0, x1, y1)
148
+ box_b: Tuple of (x0, y0, x1, y1)
149
+
150
+ Returns:
151
+ float: IoU value between 0 and 1
152
+ """
153
+ inter_x_min = max(box_a[0], box_b[0])
154
+ inter_y_min = max(box_a[1], box_b[1])
155
+ inter_x_max = min(box_a[2], box_b[2])
156
+ inter_y_max = min(box_a[3], box_b[3])
157
+
158
+ inter_w = max(0, inter_x_max - inter_x_min)
159
+ inter_h = max(0, inter_y_max - inter_y_min)
160
+ intersection = inter_w * inter_h
161
+
162
+ area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
163
+ area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
164
+ union = area_a + area_b - intersection
165
+
166
+ return intersection / union if union > 0 else 0.0
167
+
168
+
169
+ def _deduplicate_primary_boxes(
170
+ boxes: torch.Tensor, confidences: torch.Tensor, threshold: float
171
+ ) -> Tuple[torch.Tensor, List[int]]:
172
+ """Remove duplicate primary detections using IoU-based NMS.
173
+
174
+ When two boxes have IoU > threshold, keeps the one with higher confidence.
175
+
176
+ Args:
177
+ boxes: Tensor of bounding boxes (N, 4)
178
+ confidences: Tensor of confidence scores (N,)
179
+ threshold: IoU threshold above which boxes are considered duplicates
180
+
181
+ Returns:
182
+ Tuple of (deduplicated boxes tensor, indices of kept boxes)
183
+ """
184
+ if len(boxes) <= 1:
185
+ return boxes, list(range(len(boxes)))
186
+
187
+ boxes_list = boxes.tolist()
188
+ confs_list = confidences.tolist()
189
+ n = len(boxes_list)
190
+
191
+ # Sort by confidence (descending)
192
+ indices = sorted(range(n), key=lambda i: confs_list[i], reverse=True)
193
+ keep = []
194
+
195
+ for i in indices:
196
+ is_duplicate = False
197
+ for k in keep:
198
+ if _calculate_iou(boxes_list[i], boxes_list[k]) > threshold:
199
+ is_duplicate = True
200
+ break
201
+ if not is_duplicate:
202
+ keep.append(i)
203
+
204
+ return boxes[keep], keep
205
+
206
+
207
+ def _categorize_detections(primary_boxes, secondary_boxes, ioa_threshold=IOA_THRESHOLD):
208
+ """Categorize detections into simple and conjoined bubbles.
209
+
210
+ Args:
211
+ primary_boxes: Tensor of primary YOLO detection boxes (N, 4)
212
+ secondary_boxes: Tensor of secondary YOLO detection boxes (M, 4)
213
+ ioa_threshold: Threshold for determining if a secondary box is contained in a primary box
214
+
215
+ Returns:
216
+ tuple: (conjoined_indices, simple_indices)
217
+ - conjoined_indices: List of tuples (primary_idx, [secondary_indices])
218
+ - simple_indices: List of primary indices that are simple bubbles
219
+ """
220
+ # Handle cases where one bubble is detected on the page and is conjoined
221
+ if primary_boxes.ndim == 1 and primary_boxes.numel() == 4:
222
+ primary_boxes = primary_boxes.unsqueeze(0)
223
+ if secondary_boxes.ndim == 1 and secondary_boxes.numel() == 4:
224
+ secondary_boxes = secondary_boxes.unsqueeze(0)
225
+
226
+ conjoined_indices = []
227
+ processed_secondary_indices = set()
228
+
229
+ for i, p_box in enumerate(primary_boxes):
230
+ contained_indices = []
231
+ for j, s_box in enumerate(secondary_boxes):
232
+ if j in processed_secondary_indices:
233
+ continue
234
+ ioa = _calculate_ioa(s_box.tolist(), p_box.tolist())
235
+ if ioa > ioa_threshold:
236
+ contained_indices.append(j)
237
+
238
+ if len(contained_indices) >= 2:
239
+ conjoined_indices.append((i, contained_indices))
240
+ processed_secondary_indices.update(contained_indices)
241
+
242
+ primary_simple_indices = []
243
+ conjoined_primary_indices = {c[0] for c in conjoined_indices}
244
+
245
+ for i in range(len(primary_boxes)):
246
+ if i in conjoined_primary_indices:
247
+ continue
248
+
249
+ # Check for duplication against processed secondary bubbles
250
+ is_duplicate = False
251
+ p_box_list = primary_boxes[i].tolist()
252
+
253
+ for s_idx in processed_secondary_indices:
254
+ s_box_list = secondary_boxes[s_idx].tolist()
255
+ if _calculate_ioa(s_box_list, p_box_list) > ioa_threshold:
256
+ is_duplicate = True
257
+ break
258
+
259
+ if not is_duplicate:
260
+ primary_simple_indices.append(i)
261
+
262
+ return conjoined_indices, primary_simple_indices
263
+
264
+
265
+ def _process_simple_bubbles(
266
+ image, primary_boxes, simple_indices, processor, sam_model, device
267
+ ):
268
+ """Process simple (non-conjoined) speech bubbles using SAM2.
269
+
270
+ Args:
271
+ image: PIL Image
272
+ primary_boxes: Tensor of primary YOLO detection boxes
273
+ simple_indices: List of indices for simple bubbles
274
+ processor: SAM2 processor
275
+ sam_model: SAM2 model
276
+ device: PyTorch device
277
+
278
+ Returns:
279
+ list: List of numpy boolean masks for simple bubbles
280
+ """
281
+ if not simple_indices:
282
+ return []
283
+
284
+ simple_boxes_to_sam = primary_boxes[simple_indices].unsqueeze(0).cpu()
285
+ inputs = processor(image, input_boxes=simple_boxes_to_sam, return_tensors="pt")
286
+
287
+ # Cast floating point tensors to model's dtype before moving to device
288
+ for key in inputs:
289
+ if isinstance(inputs[key], torch.Tensor) and inputs[key].is_floating_point():
290
+ inputs[key] = inputs[key].to(sam_model.dtype)
291
+
292
+ inputs = inputs.to(device)
293
+
294
+ with torch.no_grad():
295
+ outputs = sam_model(multimask_output=False, **inputs)
296
+
297
+ masks_tensor = processor.post_process_masks(
298
+ outputs.pred_masks, inputs["original_sizes"]
299
+ )[0][:, 0]
300
+ simple_masks_np = (masks_tensor > SAM_MASK_THRESHOLD).cpu().numpy()
301
+ return [mask for mask in simple_masks_np]
302
+
303
+
304
+ def _fallback_to_yolo_mask(primary_results, i, mask_type="points"):
305
+ """Extract YOLO mask as fallback when SAM2 fails.
306
+
307
+ Args:
308
+ primary_results: YOLO detection results
309
+ i: Detection index
310
+ mask_type: Type of mask to extract ("points" or "binary")
311
+
312
+ Returns:
313
+ Mask data or None if extraction fails
314
+ """
315
+ if getattr(primary_results, "masks", None) is None:
316
+ return None
317
+
318
+ try:
319
+ masks = primary_results.masks
320
+ if len(masks) <= i:
321
+ return None
322
+
323
+ if mask_type == "points":
324
+ mask_points = masks[i].xy[0]
325
+ return (
326
+ mask_points.tolist() if hasattr(mask_points, "tolist") else mask_points
327
+ )
328
+ elif mask_type == "binary":
329
+ mask_tensor = masks.data[i]
330
+ orig_h, orig_w = primary_results.orig_shape
331
+ mask_resized = torch.nn.functional.interpolate(
332
+ mask_tensor.float().unsqueeze(0).unsqueeze(0),
333
+ size=(orig_h, orig_w),
334
+ mode="bilinear",
335
+ align_corners=False,
336
+ ).squeeze()
337
+ binary_mask = (mask_resized > SAM_MASK_THRESHOLD).cpu().numpy()
338
+ return binary_mask.astype(np.uint8) * 255
339
+ else:
340
+ return None
341
+
342
+ except (IndexError, AttributeError) as e:
343
+ log_message(
344
+ f"Could not extract YOLO mask for detection {i}: {e}",
345
+ always_print=True,
346
+ )
347
+ return None
348
+
349
+
350
+ def detect_speech_bubbles(
351
+ image_path: Path,
352
+ model_path,
353
+ confidence=0.6,
354
+ verbose=False,
355
+ device=None,
356
+ use_sam2: bool = True,
357
+ conjoined_detection: bool = True,
358
+ conjoined_confidence=0.35,
359
+ image_override: Optional[Image.Image] = None,
360
+ osb_enabled: bool = False,
361
+ osb_text_verification: bool = False,
362
+ osb_text_hf_token: str = "",
363
+ ):
364
+ """Detect speech bubbles using dual YOLO models and SAM2.
365
+
366
+ For conjoined bubbles detected by the secondary model, uses the inner bounding boxes
367
+ directly and processes each as a separate simple bubble through SAM2.
368
+
369
+ Args:
370
+ image_path (Path): Path to the input image
371
+ model_path (str): Path to the primary YOLO segmentation model
372
+ confidence (float): Confidence threshold for primary YOLO model detections
373
+ verbose (bool): Whether to show detailed processing information
374
+ device (torch.device, optional): The device to run the model on. Autodetects if None.
375
+ use_sam2 (bool): Whether to use SAM2.1 for enhanced segmentation
376
+ conjoined_detection (bool): Whether to enable conjoined bubble detection using secondary YOLO model
377
+ conjoined_confidence (float): Confidence threshold for secondary YOLO model (conjoined bubble detection)
378
+ osb_text_verification (bool): When True, expand bubble boxes to fully cover OSB text detections
379
+ osb_text_hf_token (str): Optional token for gated OSB text model downloads
380
+
381
+ Returns:
382
+ tuple[list, list]: (speech bubble detections, text_free boxes from secondary model)
383
+ """
384
+ detections = []
385
+ text_free_boxes: List[List[float]] = []
386
+
387
+ _device = (
388
+ device
389
+ if device is not None
390
+ else torch.device(
391
+ "cuda"
392
+ if torch.cuda.is_available()
393
+ else "mps" if torch.backends.mps.is_available() else "cpu"
394
+ )
395
+ )
396
+ try:
397
+ if image_override is not None:
398
+ image_pil = (
399
+ image_override
400
+ if image_override.mode == "RGB"
401
+ else image_override.convert("RGB")
402
+ )
403
+ image_cv = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
404
+ else:
405
+ image_cv = cv2.imread(str(image_path))
406
+ if image_cv is None:
407
+ raise ImageProcessingError(f"Could not read image at {image_path}")
408
+ image_pil = Image.fromarray(cv2.cvtColor(image_cv, cv2.COLOR_BGR2RGB))
409
+ log_message(
410
+ f"Processing image: {image_path.name} ({image_cv.shape[1]}x{image_cv.shape[0]})",
411
+ verbose=verbose,
412
+ )
413
+ except Exception as e:
414
+ raise ImageProcessingError(f"Error loading image: {e}")
415
+
416
+ model_manager = get_model_manager()
417
+ cache = get_cache()
418
+ try:
419
+ primary_model = model_manager.load_yolo_speech_bubble(model_path)
420
+ log_message(f"Loaded primary YOLO model: {model_path}", verbose=verbose)
421
+ except Exception as e:
422
+ raise ModelError(f"Error loading primary model: {e}")
423
+
424
+ yolo_cache_key = cache.get_yolo_cache_key(image_pil, model_path, confidence)
425
+ cached_yolo = cache.get_yolo_detection(yolo_cache_key)
426
+
427
+ if cached_yolo is not None:
428
+ log_message("Using cached YOLO detections", verbose=verbose)
429
+ primary_results, primary_boxes = cached_yolo
430
+ else:
431
+ primary_results = primary_model(
432
+ image_cv, conf=confidence, device=_device, verbose=False
433
+ )[0]
434
+ primary_boxes = (
435
+ primary_results.boxes.xyxy
436
+ if primary_results.boxes is not None
437
+ else torch.tensor([])
438
+ )
439
+ cache.set_yolo_detection(yolo_cache_key, (primary_results, primary_boxes))
440
+
441
+ # Remove duplicate primary detections using IoU-based NMS
442
+ if len(primary_boxes) > 1:
443
+ original_count = len(primary_boxes)
444
+ primary_boxes, _ = _deduplicate_primary_boxes(
445
+ primary_boxes, primary_results.boxes.conf, IOU_DUPLICATE_THRESHOLD
446
+ )
447
+ if len(primary_boxes) < original_count:
448
+ log_message(
449
+ f"Removed {original_count - len(primary_boxes)} duplicate detections",
450
+ verbose=verbose,
451
+ )
452
+
453
+ if len(primary_boxes) == 0:
454
+ log_message("No detections found", verbose=verbose)
455
+ return detections, text_free_boxes
456
+
457
+ log_message(
458
+ f"Detected {len(primary_boxes)} speech bubbles with YOLO", always_print=True
459
+ )
460
+
461
+ secondary_boxes = torch.tensor([])
462
+ if use_sam2:
463
+ try:
464
+ secondary_model = model_manager.load_yolo_conjoined_bubble()
465
+ log_message(
466
+ "Loaded secondary YOLO model for conjoined/fallback detection",
467
+ verbose=verbose,
468
+ )
469
+
470
+ secondary_results = secondary_model(
471
+ image_cv, conf=conjoined_confidence, device=_device, verbose=False
472
+ )[0]
473
+ secondary_boxes = (
474
+ secondary_results.boxes.xyxy
475
+ if secondary_results.boxes is not None
476
+ else torch.tensor([])
477
+ )
478
+
479
+ # Fallback: Add bubbles detected by secondary model but missed by primary
480
+ if len(secondary_boxes) > 0 and hasattr(secondary_model, "names"):
481
+ text_bubble_id = None
482
+ text_free_id = None
483
+ for cid, cname in secondary_model.names.items():
484
+ if cname == "text_bubble":
485
+ text_bubble_id = cid
486
+ elif cname == "text_free":
487
+ text_free_id = cid
488
+
489
+ secondary_cls = secondary_results.boxes.cls
490
+
491
+ # Collect text_free boxes regardless of OSB setting
492
+ if text_free_id is not None:
493
+ for i, s_box in enumerate(secondary_boxes):
494
+ if int(secondary_cls[i]) == text_free_id:
495
+ text_free_boxes.append(s_box.tolist())
496
+
497
+ if text_bubble_id is not None:
498
+ new_boxes = []
499
+ primary_boxes_list = (
500
+ primary_boxes.tolist() if len(primary_boxes) > 0 else []
501
+ )
502
+
503
+ for i, s_box in enumerate(secondary_boxes):
504
+ if int(secondary_cls[i]) != text_bubble_id:
505
+ continue
506
+
507
+ s_box_list = s_box.tolist()
508
+
509
+ is_covered = False
510
+
511
+ for p_box_list in primary_boxes_list:
512
+ ioa_s_in_p = _calculate_ioa(s_box_list, p_box_list)
513
+ ioa_p_in_s = _calculate_ioa(p_box_list, s_box_list)
514
+
515
+ if (
516
+ ioa_s_in_p > IOA_OVERLAP_THRESHOLD
517
+ or ioa_p_in_s > IOA_OVERLAP_THRESHOLD
518
+ ):
519
+ is_covered = True
520
+ break
521
+
522
+ if not is_covered:
523
+ new_boxes.append(s_box)
524
+
525
+ if new_boxes:
526
+ log_message(
527
+ f"Found {len(new_boxes)} missed bubbles from secondary model",
528
+ always_print=True,
529
+ )
530
+ new_boxes_tensor = torch.stack(new_boxes)
531
+ if len(primary_boxes) > 0:
532
+ primary_boxes = torch.cat(
533
+ (primary_boxes, new_boxes_tensor), dim=0
534
+ )
535
+ else:
536
+ primary_boxes = new_boxes_tensor
537
+
538
+ # Remove text_free detections (route to OSB if enabled, discard otherwise)
539
+ if text_free_boxes and len(primary_boxes) > 0:
540
+ indices_to_remove = []
541
+ primary_boxes_list = primary_boxes.tolist()
542
+
543
+ for i, p_box in enumerate(primary_boxes_list):
544
+ overlaps_text_free = False
545
+ for tf_box in text_free_boxes:
546
+ if (
547
+ _calculate_ioa(p_box, tf_box) > IOA_OVERLAP_THRESHOLD
548
+ or _calculate_ioa(tf_box, p_box) > IOA_OVERLAP_THRESHOLD
549
+ ):
550
+ overlaps_text_free = True
551
+ break
552
+
553
+ if overlaps_text_free:
554
+ indices_to_remove.append(i)
555
+
556
+ if indices_to_remove:
557
+ action = (
558
+ "routing to OSB pipeline"
559
+ if osb_enabled
560
+ else "discarding (OSB disabled)"
561
+ )
562
+ log_message(
563
+ f"Removing {len(indices_to_remove)} bubbles marked text_free ({action})",
564
+ always_print=True,
565
+ )
566
+ keep_indices = [
567
+ i
568
+ for i in range(len(primary_boxes))
569
+ if i not in indices_to_remove
570
+ ]
571
+ if keep_indices:
572
+ primary_boxes = primary_boxes[keep_indices]
573
+ else:
574
+ primary_boxes = torch.tensor([])
575
+
576
+ except Exception as e:
577
+ log_message(
578
+ f"Warning: Could not load/run secondary YOLO model: {e}. "
579
+ "Proceeding without conjoined/fallback detection.",
580
+ verbose=verbose,
581
+ )
582
+ secondary_boxes = torch.tensor([])
583
+
584
+ if osb_text_verification and len(primary_boxes) > 0:
585
+ primary_boxes = _expand_boxes_with_osb_text(
586
+ image_cv,
587
+ image_pil,
588
+ primary_boxes,
589
+ cache,
590
+ model_manager,
591
+ _device,
592
+ confidence,
593
+ osb_text_hf_token,
594
+ verbose,
595
+ )
596
+
597
+ if not use_sam2:
598
+ log_message("SAM2 disabled, using YOLO segmentation masks", verbose=verbose)
599
+ for i, box in enumerate(primary_boxes):
600
+ x0_f, y0_f, x1_f, y1_f = box.tolist()
601
+ conf = float(primary_results.boxes.conf[i])
602
+ cls_id = int(primary_results.boxes.cls[i])
603
+ cls_name = primary_model.names[cls_id]
604
+
605
+ detection = {
606
+ "bbox": (
607
+ int(round(x0_f)),
608
+ int(round(y0_f)),
609
+ int(round(x1_f)),
610
+ int(round(y1_f)),
611
+ ),
612
+ "confidence": conf,
613
+ "class": cls_name,
614
+ }
615
+
616
+ detection["sam_mask"] = _fallback_to_yolo_mask(primary_results, i, "binary")
617
+
618
+ detections.append(detection)
619
+ return detections, text_free_boxes
620
+
621
+ conjoined_indices = []
622
+ simple_indices = list(range(len(primary_boxes)))
623
+ try:
624
+ log_message("Applying SAM2.1 segmentation refinement", verbose=verbose)
625
+ sam_cache_key = cache.get_sam_cache_key(
626
+ image_pil,
627
+ primary_boxes,
628
+ use_sam2,
629
+ conjoined_detection,
630
+ conjoined_confidence,
631
+ )
632
+ cached_sam = cache.get_sam_masks(sam_cache_key)
633
+
634
+ if cached_sam is not None:
635
+ log_message("Using cached SAM masks", verbose=verbose)
636
+ detections = cached_sam
637
+ return detections, text_free_boxes
638
+
639
+ processor, sam_model = model_manager.load_sam2()
640
+ if len(secondary_boxes) > 0 and conjoined_detection:
641
+ log_message(
642
+ "Categorizing detections (simple vs conjoined)...", verbose=verbose
643
+ )
644
+ conjoined_indices, simple_indices = _categorize_detections(
645
+ primary_boxes, secondary_boxes, ioa_threshold=IOA_THRESHOLD
646
+ )
647
+ log_message(
648
+ f"Found {len(simple_indices)} simple bubbles and {len(conjoined_indices)} conjoined groups",
649
+ verbose=verbose,
650
+ )
651
+ if len(conjoined_indices) > 0:
652
+ log_message(
653
+ f"Detected {len(conjoined_indices)} conjoined speech bubbles with second YOLO",
654
+ always_print=True,
655
+ )
656
+ else:
657
+ conjoined_indices = []
658
+ simple_indices = list(range(len(primary_boxes)))
659
+ log_message(
660
+ f"No secondary detections, processing all {len(simple_indices)} as simple bubbles",
661
+ verbose=verbose,
662
+ )
663
+ boxes_to_process = []
664
+
665
+ for idx in simple_indices:
666
+ boxes_to_process.append(primary_boxes[idx])
667
+
668
+ for _, s_indices in conjoined_indices:
669
+ for s_idx in s_indices:
670
+ boxes_to_process.append(secondary_boxes[s_idx])
671
+
672
+ if boxes_to_process:
673
+ all_boxes_tensor = torch.stack(boxes_to_process)
674
+ all_masks = _process_simple_bubbles(
675
+ image_pil,
676
+ all_boxes_tensor,
677
+ list(range(len(boxes_to_process))),
678
+ processor,
679
+ sam_model,
680
+ _device,
681
+ )
682
+ all_boxes = boxes_to_process
683
+
684
+ total_boxes = len(boxes_to_process)
685
+ simple_count = len(simple_indices)
686
+ conjoined_count = sum(len(s_indices) for _, s_indices in conjoined_indices)
687
+
688
+ if conjoined_indices:
689
+ log_message(
690
+ f"Processing {total_boxes} bubbles ({simple_count} simple + "
691
+ f"{conjoined_count} from conjoined groups)...",
692
+ verbose=verbose,
693
+ )
694
+ else:
695
+ log_message(
696
+ f"Processing {total_boxes} simple bubbles...", verbose=verbose
697
+ )
698
+ else:
699
+ all_masks = []
700
+ all_boxes = []
701
+
702
+ log_message(f"Refined {len(all_masks)} masks with SAM2", always_print=True)
703
+ log_message(f"Total masks generated: {len(all_masks)}", verbose=verbose)
704
+ img_h, img_w = image_cv.shape[:2]
705
+ for i, (mask, box) in enumerate(zip(all_masks, all_boxes)):
706
+ x0_f, y0_f, x1_f, y1_f = box.tolist()
707
+
708
+ x0 = int(np.floor(max(0, min(x0_f, img_w))))
709
+ y0 = int(np.floor(max(0, min(y0_f, img_h))))
710
+ x1 = int(np.ceil(max(0, min(x1_f, img_w))))
711
+ y1 = int(np.ceil(max(0, min(y1_f, img_h))))
712
+
713
+ if x1 <= x0 or y1 <= y0:
714
+ continue
715
+ bbox_mask = np.zeros((img_h, img_w), dtype=bool)
716
+ bbox_mask[y0:y1, x0:x1] = True
717
+ clipped_mask = np.logical_and(mask, bbox_mask)
718
+
719
+ detection = {
720
+ "bbox": (x0, y0, x1, y1),
721
+ "confidence": 1.0, # Masks from SAM are high confidence
722
+ "class": "speech bubble",
723
+ "sam_mask": clipped_mask.astype(np.uint8) * 255,
724
+ }
725
+ detections.append(detection)
726
+
727
+ log_message("SAM2.1 segmentation completed successfully", verbose=verbose)
728
+ cache.set_sam_masks(sam_cache_key, detections)
729
+
730
+ except Exception as e:
731
+ log_message(
732
+ f"SAM2.1 segmentation failed: {e}. Falling back to YOLO segmentation masks.",
733
+ always_print=True,
734
+ )
735
+ detections = []
736
+
737
+ # Process primary boxes first in fallback to avoid duplicating secondary splits
738
+ fallback_boxes = []
739
+ if conjoined_detection and len(secondary_boxes) > 0 and conjoined_indices:
740
+ for idx in simple_indices:
741
+ fallback_boxes.append(("primary", idx, primary_boxes[idx]))
742
+ for _, s_indices in conjoined_indices:
743
+ for s_idx in s_indices:
744
+ fallback_boxes.append(("secondary", s_idx, secondary_boxes[s_idx]))
745
+ elif len(primary_boxes) > 0:
746
+ for idx in range(len(primary_boxes)):
747
+ fallback_boxes.append(("primary", idx, primary_boxes[idx]))
748
+
749
+ img_h, img_w = image_cv.shape[:2]
750
+ primary_fallback_count = 0
751
+ secondary_fallback_count = 0
752
+
753
+ for _, (source, orig_idx, box) in enumerate(fallback_boxes):
754
+ x0_f, y0_f, x1_f, y1_f = box.tolist()
755
+
756
+ if source == "primary" and len(primary_results.boxes) > 0:
757
+ safe_idx = min(orig_idx, len(primary_results.boxes.conf) - 1)
758
+ conf = float(primary_results.boxes.conf[safe_idx])
759
+ cls_id = int(primary_results.boxes.cls[safe_idx])
760
+ cls_name = primary_model.names[cls_id]
761
+ sam_mask = _fallback_to_yolo_mask(primary_results, safe_idx, "binary")
762
+ primary_fallback_count += 1
763
+ elif source == "secondary" and "secondary_results" in locals():
764
+ try:
765
+ safe_idx = min(orig_idx, len(secondary_results.boxes.conf) - 1)
766
+ conf = float(secondary_results.boxes.conf[safe_idx])
767
+ except Exception:
768
+ conf = conjoined_confidence
769
+ cls_name = "speech_bubble"
770
+ x0 = int(max(0, min(x0_f, img_w)))
771
+ y0 = int(max(0, min(y0_f, img_h)))
772
+ x1 = int(max(0, min(x1_f, img_w)))
773
+ y1 = int(max(0, min(y1_f, img_h)))
774
+ mask = np.zeros((img_h, img_w), dtype=np.uint8)
775
+ mask[y0:y1, x0:x1] = 255
776
+ sam_mask = mask
777
+ secondary_fallback_count += 1
778
+ else:
779
+ conf = conjoined_confidence
780
+ cls_name = "speech_bubble"
781
+ x0 = int(max(0, min(x0_f, img_w)))
782
+ y0 = int(max(0, min(y0_f, img_h)))
783
+ x1 = int(max(0, min(x1_f, img_w)))
784
+ y1 = int(max(0, min(y1_f, img_h)))
785
+ mask = np.zeros((img_h, img_w), dtype=np.uint8)
786
+ mask[y0:y1, x0:x1] = 255
787
+ sam_mask = mask
788
+
789
+ detection = {
790
+ "bbox": (
791
+ int(round(x0_f)),
792
+ int(round(y0_f)),
793
+ int(round(x1_f)),
794
+ int(round(y1_f)),
795
+ ),
796
+ "confidence": conf,
797
+ "class": cls_name,
798
+ }
799
+ detection["sam_mask"] = sam_mask
800
+
801
+ detections.append(detection)
802
+
803
+ log_message(
804
+ f"Fallback segmentation used {len(detections)} boxes "
805
+ f"(primary: {primary_fallback_count}, secondary splits: {secondary_fallback_count})",
806
+ verbose=verbose,
807
+ )
808
+
809
+ return detections, text_free_boxes
810
+ return detections, text_free_boxes
811
+
812
+
813
+ def detect_panels(
814
+ image_path: Path,
815
+ confidence: float = 0.25,
816
+ device=None,
817
+ verbose=False,
818
+ image_override: Optional[Image.Image] = None,
819
+ ) -> List[Tuple[int, int, int, int]]:
820
+ """Detect manga/comic panels using YOLO model.
821
+
822
+ Args:
823
+ image_path (Path): Path to the input image
824
+ confidence (float): Confidence threshold for panel YOLO detections
825
+ device (torch.device, optional): The device to run the model on. Autodetects if None.
826
+ verbose (bool): Whether to show detailed processing information
827
+ image_override (Image.Image, optional): PIL Image to use instead of loading from path
828
+
829
+ Returns:
830
+ list: List of tuples (x1, y1, x2, y2) representing panel bounding boxes.
831
+ Only includes detections with class "frame".
832
+ """
833
+ _device = (
834
+ device
835
+ if device is not None
836
+ else torch.device(
837
+ "cuda"
838
+ if torch.cuda.is_available()
839
+ else "mps" if torch.backends.mps.is_available() else "cpu"
840
+ )
841
+ )
842
+
843
+ try:
844
+ if image_override is not None:
845
+ image_pil = (
846
+ image_override
847
+ if image_override.mode == "RGB"
848
+ else image_override.convert("RGB")
849
+ )
850
+ image_cv = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
851
+ else:
852
+ image_cv = cv2.imread(str(image_path))
853
+ if image_cv is None:
854
+ raise ImageProcessingError(f"Could not read image at {image_path}")
855
+ image_pil = Image.fromarray(cv2.cvtColor(image_cv, cv2.COLOR_BGR2RGB))
856
+ log_message(
857
+ f"Processing image for panel detection: {image_path.name if image_path else 'override'} "
858
+ f"({image_cv.shape[1]}x{image_cv.shape[0]})",
859
+ verbose=verbose,
860
+ )
861
+ except Exception as e:
862
+ raise ImageProcessingError(f"Error loading image: {e}")
863
+
864
+ model_manager = get_model_manager()
865
+ try:
866
+ panel_model = model_manager.load_yolo_panel(verbose=verbose)
867
+ except Exception as e:
868
+ raise ModelError(f"Error loading panel model: {e}")
869
+
870
+ try:
871
+ results = panel_model(image_cv, conf=confidence, device=_device, verbose=False)[
872
+ 0
873
+ ]
874
+ boxes = results.boxes.xyxy if results.boxes is not None else torch.tensor([])
875
+ classes = results.boxes.cls if results.boxes is not None else torch.tensor([])
876
+
877
+ if len(boxes) == 0:
878
+ log_message("No panels detected", verbose=verbose)
879
+ return []
880
+
881
+ # Filter for "frame" class (panel class)
882
+ frame_class_id = None
883
+ if hasattr(panel_model, "names"):
884
+ for class_id, class_name in panel_model.names.items():
885
+ if class_name.lower() == "frame":
886
+ frame_class_id = class_id
887
+ break
888
+
889
+ panel_boxes = []
890
+ for i, box in enumerate(boxes):
891
+ # If we found a frame class ID, only include detections of that class
892
+ # Otherwise, include all detections (fallback)
893
+ if frame_class_id is not None:
894
+ if int(classes[i]) != frame_class_id:
895
+ continue
896
+
897
+ x0_f, y0_f, x1_f, y1_f = box.tolist()
898
+ panel_boxes.append(
899
+ (
900
+ int(round(x0_f)),
901
+ int(round(y0_f)),
902
+ int(round(x1_f)),
903
+ int(round(y1_f)),
904
+ )
905
+ )
906
+
907
+ return panel_boxes
908
+
909
+ except Exception as e:
910
+ log_message(
911
+ f"Panel detection failed: {e}. Proceeding without panel information.",
912
+ always_print=True,
913
+ )
914
+ return []
core/image/image_utils.py ADDED
@@ -0,0 +1,779 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ import io
3
+ import os
4
+ import tempfile
5
+ from pathlib import Path
6
+ from typing import Tuple
7
+
8
+ import cv2
9
+ import numpy as np
10
+ try:
11
+ import oxipng
12
+ OXIPNG_AVAILABLE = True
13
+ except ImportError:
14
+ oxipng = None
15
+ OXIPNG_AVAILABLE = False
16
+ import torch
17
+ from PIL import Image
18
+
19
+ from core.caching import get_cache
20
+ from core.ml.model_manager import get_model_manager
21
+ from utils.exceptions import ImageProcessingError
22
+ from utils.logging import log_message
23
+
24
+
25
+ def pil_to_cv2(pil_image):
26
+ """
27
+ Convert PIL Image to OpenCV format (numpy array)
28
+
29
+ Args:
30
+ pil_image (PIL.Image): PIL Image object
31
+
32
+ Returns:
33
+ numpy.ndarray: OpenCV image in BGR format
34
+ """
35
+ rgb_image = np.array(pil_image)
36
+ if len(rgb_image.shape) == 3:
37
+ if rgb_image.shape[2] == 3: # RGB
38
+ return cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)
39
+ elif rgb_image.shape[2] == 4: # RGBA
40
+ return cv2.cvtColor(rgb_image, cv2.COLOR_RGBA2BGRA)
41
+ return rgb_image
42
+
43
+
44
+ def cv2_to_pil(cv2_image):
45
+ """
46
+ Convert OpenCV image to PIL Image
47
+
48
+ Args:
49
+ cv2_image (numpy.ndarray): OpenCV image in BGR or BGRA format
50
+
51
+ Returns:
52
+ PIL.Image: PIL Image object
53
+ """
54
+ if len(cv2_image.shape) == 3:
55
+ if cv2_image.shape[2] == 3: # BGR
56
+ rgb_image = cv2.cvtColor(cv2_image, cv2.COLOR_BGR2RGB)
57
+ return Image.fromarray(rgb_image)
58
+ elif cv2_image.shape[2] == 4: # BGRA
59
+ rgba_image = cv2.cvtColor(cv2_image, cv2.COLOR_BGRA2RGBA)
60
+ return Image.fromarray(rgba_image)
61
+ return Image.fromarray(cv2_image)
62
+
63
+
64
+ def save_image_with_compression(
65
+ image, output_path, jpeg_quality=95, png_compression=2, verbose=False
66
+ ):
67
+ """
68
+ Save an image with specified compression settings.
69
+
70
+ Args:
71
+ image (PIL.Image): Image to save
72
+ output_path (str or Path): Path to save the image
73
+ jpeg_quality (int): JPEG quality (1-100, higher is better quality)
74
+ png_compression (int): PNG compression level (0-6, higher is more compression)
75
+ verbose (bool): Whether to print verbose logging
76
+
77
+ Raises:
78
+ ImageProcessingError: If image saving fails
79
+ """
80
+ output_path = (
81
+ Path(output_path) if not isinstance(output_path, Path) else output_path
82
+ )
83
+
84
+ extension = output_path.suffix.lower()
85
+ output_format = None
86
+ save_options = {}
87
+
88
+ if extension in [".jpg", ".jpeg"]:
89
+ output_format = "JPEG"
90
+ # JPEG doesn't support transparency - composite on white background
91
+ if image.mode in ["RGBA", "LA"]:
92
+ log_message(
93
+ f"Converting {image.mode} to RGB for JPEG output", verbose=verbose
94
+ )
95
+ background = Image.new("RGB", image.size, (255, 255, 255))
96
+ alpha_channel = image.split()[-1] if image.mode in ["RGBA", "LA"] else None
97
+ background.paste(image, mask=alpha_channel)
98
+ image = background
99
+ elif image.mode == "P": # Handle Palette mode
100
+ log_message("Converting P mode to RGB for JPEG output", verbose=verbose)
101
+ image = image.convert("RGB")
102
+ elif image.mode != "RGB":
103
+ log_message(
104
+ f"Converting {image.mode} mode to RGB for JPEG output", verbose=verbose
105
+ )
106
+ image = image.convert("RGB")
107
+ save_options["quality"] = max(1, min(jpeg_quality, 100))
108
+ log_message(
109
+ f"Saving JPEG image with quality {save_options['quality']} to {output_path}",
110
+ verbose=verbose,
111
+ )
112
+
113
+ elif extension == ".png":
114
+ output_format = "PNG"
115
+ oxipng_level = min(6, max(0, int(png_compression)))
116
+ log_message(
117
+ f"Saving PNG image with compression level {oxipng_level} to {output_path}",
118
+ verbose=verbose,
119
+ )
120
+
121
+ elif extension == ".webp":
122
+ output_format = "WEBP"
123
+ save_options["lossless"] = True
124
+ log_message(
125
+ f"Saving WEBP image with lossless quality to {output_path}", verbose=verbose
126
+ )
127
+
128
+ else:
129
+ log_message(
130
+ f"Warning: Unknown output extension '{extension}'. Saving as PNG.",
131
+ verbose=verbose,
132
+ always_print=True,
133
+ )
134
+ output_format = "PNG"
135
+ output_path = output_path.with_suffix(".png")
136
+ oxipng_level = min(6, max(0, int(png_compression)))
137
+ log_message(
138
+ f"Saving PNG image with compression level {oxipng_level} to {output_path}",
139
+ verbose=verbose,
140
+ )
141
+
142
+ try:
143
+ os.makedirs(output_path.parent, exist_ok=True)
144
+
145
+ if output_format == "PNG":
146
+ if OXIPNG_AVAILABLE:
147
+ buffer = io.BytesIO()
148
+ image.save(buffer, format="PNG")
149
+ png_data = buffer.getvalue()
150
+
151
+ try:
152
+ optimized_data = oxipng.optimize_from_memory(
153
+ png_data, level=oxipng_level, optimize_alpha=True
154
+ )
155
+ with open(output_path, "wb") as f:
156
+ f.write(optimized_data)
157
+ except oxipng.PngError as e:
158
+ log_message(
159
+ f"oxipng optimization failed: {e}. Falling back to Pillow save.",
160
+ verbose=verbose,
161
+ always_print=True,
162
+ )
163
+ # Fallback to Pillow if oxipng fails
164
+ image.save(
165
+ str(output_path),
166
+ format="PNG",
167
+ compress_level=max(0, min(png_compression, 6)),
168
+ optimize=True,
169
+ )
170
+ else:
171
+ # oxipng not available, use Pillow directly
172
+ image.save(
173
+ str(output_path),
174
+ format="PNG",
175
+ compress_level=max(0, min(png_compression, 6)),
176
+ optimize=True,
177
+ )
178
+ else:
179
+ # Use Pillow for non-PNG formats
180
+ image.save(str(output_path), format=output_format, **save_options)
181
+ return True
182
+ except Exception as e:
183
+ log_message(f"Error saving image to {output_path}: {e}", always_print=True)
184
+ raise ImageProcessingError(f"Failed to save image to {output_path}") from e
185
+
186
+
187
+ def calculate_centroid_expansion_box(
188
+ cleaned_mask: np.ndarray, padding_pixels: float = 5.0, verbose: bool = False
189
+ ) -> Tuple[Tuple[int, int, int, int], Tuple[float, float]]:
190
+ """
191
+ Calculates guaranteed safe rendering box using the 5-step Distance Transform Insetting Method.
192
+
193
+ This function implements a sophisticated algorithm to find the optimal text placement area
194
+ within a speech bubble, ensuring text never touches the bubble boundaries. The method uses
195
+ computer vision techniques to create a safe zone for text rendering.
196
+
197
+ Algorithm Overview:
198
+ The 5-step Distance Transform Insetting Method works as follows:
199
+
200
+ 1. Establish Safe Zone:
201
+ - Uses cv2.distanceTransform() to compute the distance from each pixel to the nearest
202
+ bubble edge (0 pixels)
203
+ - Creates a safe_area_mask where distance >= padding_pixels
204
+ - This ensures all pixels in the safe zone are at least padding_pixels away from edges
205
+
206
+ 2. Find Unbiased Anchor:
207
+ - Calculates the centroid (geometric center) of the safe_area_mask using cv2.moments()
208
+ - This provides an unbiased starting point for text placement
209
+ - The centroid represents the "center of mass" of the safe area
210
+
211
+ 3. Measure Available Space:
212
+ - Performs ray-casting from the centroid in four cardinal directions (left, right, up, down)
213
+ - Measures distances to the nearest safe area boundary in each direction
214
+ - Uses numpy array operations for efficient distance calculation
215
+
216
+ 4. Calculate Symmetrical Dimensions:
217
+ - Takes the minimum distance in each axis to ensure the box fits in all directions
218
+ - Multiplies by 2 to create symmetrical width and height around the centroid
219
+ - Subtracts 1 pixel margin for safety
220
+
221
+ 5. Construct Final Box:
222
+ - Creates a centered rectangle within the safe zone
223
+ - Ensures the box is completely contained within the original mask bounds
224
+ - Returns both the box coordinates and the true centroid for precise text positioning
225
+
226
+ Why This Approach Works:
227
+ - Distance Transform provides accurate edge detection and safe zone calculation
228
+ - Ray-casting ensures the text box never touches bubble boundaries
229
+ - Centroid-based approach provides natural, visually appealing text placement
230
+ - Symmetrical dimensions prevent text from appearing off-center
231
+ - The method handles complex bubble shapes (ovals, irregular polygons, etc.)
232
+
233
+ Args:
234
+ cleaned_mask: Binary mask (0/255) of the cleaned speech bubble where 255 represents
235
+ the bubble interior and 0 represents the background
236
+ padding_pixels: Minimum distance in pixels that text must maintain from bubble edges.
237
+ Higher values create more padding but smaller text areas.
238
+ verbose: Whether to print detailed processing information for debugging
239
+
240
+ Returns:
241
+ Tuple containing:
242
+ - Tuple[int, int, int, int]: Safe box coordinates as [x, y, width, height] where
243
+ (x, y) is the top-left corner.
244
+ - Tuple[float, float]: True geometric center (centroid) of the safe area as (cx, cy).
245
+
246
+ Raises:
247
+ ImageProcessingError: If mask is invalid or calculation fails
248
+
249
+ Example:
250
+ >>> mask = np.zeros((100, 100), dtype=np.uint8)
251
+ >>> cv2.ellipse(mask, (50, 50), (40, 30), 0, 0, 360, 255, -1)
252
+ >>> box, centroid = calculate_centroid_expansion_box(mask, padding_pixels=10.0)
253
+ >>> log_message(f"Safe box: {box}, Centroid: {centroid}", verbose=True)
254
+ Safe box: (20, 30, 60, 40), Centroid: (50.0, 50.0)
255
+ """
256
+ if cleaned_mask is None or not np.any(cleaned_mask):
257
+ raise ImageProcessingError("Invalid or empty mask provided")
258
+
259
+ try:
260
+ # Create safe area using distance transform
261
+ distance_map = cv2.distanceTransform(
262
+ cleaned_mask, cv2.DIST_L2, cv2.DIST_MASK_PRECISE
263
+ )
264
+ safe_area_mask = (distance_map >= padding_pixels).astype(np.uint8) * 255
265
+
266
+ if not np.any(safe_area_mask):
267
+ log_message(
268
+ f"Safe area calculation failed: padding {padding_pixels:.0f}px too large",
269
+ verbose=verbose,
270
+ always_print=True,
271
+ )
272
+ raise ImageProcessingError("Failed to create safe area mask")
273
+
274
+ # Find centroid of safe area
275
+ moments = cv2.moments(safe_area_mask)
276
+
277
+ if moments["m00"] == 0:
278
+ raise ImageProcessingError("Safe area mask has no area")
279
+
280
+ centroid_x = moments["m10"] / moments["m00"]
281
+ centroid_y = moments["m01"] / moments["m00"]
282
+
283
+ # Check if centroid is in a constricted region (dual/conjoined bubbles)
284
+ _, max_val, _, max_loc = cv2.minMaxLoc(distance_map)
285
+
286
+ cx_int, cy_int = int(round(centroid_x)), int(round(centroid_y))
287
+ mask_h, mask_w = safe_area_mask.shape
288
+
289
+ cx_int = max(0, min(cx_int, mask_w - 1))
290
+ cy_int = max(0, min(cy_int, mask_h - 1))
291
+
292
+ dist_at_centroid = distance_map[cy_int, cx_int]
293
+
294
+ if dist_at_centroid < max_val * 0.70:
295
+ log_message(
296
+ f"Centroid in constricted region (dist={dist_at_centroid:.1f} vs max={max_val:.1f}). "
297
+ "Moving anchor to pole of inaccessibility.",
298
+ verbose=verbose,
299
+ )
300
+ centroid_x, centroid_y = float(max_loc[0]), float(max_loc[1])
301
+
302
+ centroid = (centroid_x, centroid_y)
303
+
304
+ # Ray-cast from centroid to find maximum safe dimensions
305
+ cx, cy = int(round(centroid_x)), int(round(centroid_y))
306
+ mask_h, mask_w = safe_area_mask.shape
307
+
308
+ # Verify centroid is within safe area, adjust if needed
309
+ if (
310
+ cy < 0
311
+ or cy >= mask_h
312
+ or cx < 0
313
+ or cx >= mask_w
314
+ or safe_area_mask[cy, cx] != 255
315
+ ):
316
+ # Centroid is outside safe area, find nearest safe pixel
317
+ safe_pixels = np.argwhere(safe_area_mask == 255)
318
+ if safe_pixels.size == 0:
319
+ raise ImageProcessingError("No safe pixels found in safe_area_mask")
320
+ # Find nearest safe pixel to calculated centroid
321
+ distances = np.sqrt(
322
+ (safe_pixels[:, 0] - centroid_y) ** 2
323
+ + (safe_pixels[:, 1] - centroid_x) ** 2
324
+ )
325
+ nearest_idx = np.argmin(distances)
326
+ cy, cx = safe_pixels[nearest_idx]
327
+ # Update centroid to the adjusted position
328
+ centroid_x, centroid_y = float(cx), float(cy)
329
+ centroid = (centroid_x, centroid_y)
330
+
331
+ left_zeros = np.where(safe_area_mask[cy, 0:cx] == 0)[0]
332
+ dist_to_left_edge = cx - (left_zeros.max() if left_zeros.size > 0 else 0)
333
+
334
+ right_zeros = np.where(safe_area_mask[cy, cx:] == 0)[0]
335
+ dist_to_right_edge = right_zeros.min() if right_zeros.size > 0 else mask_w - cx
336
+
337
+ up_zeros = np.where(safe_area_mask[0:cy, cx] == 0)[0]
338
+ dist_to_top_edge = cy - (up_zeros.max() if up_zeros.size > 0 else 0)
339
+
340
+ down_zeros = np.where(safe_area_mask[cy:, cx] == 0)[0]
341
+ dist_to_bottom_edge = down_zeros.min() if down_zeros.size > 0 else mask_h - cy
342
+
343
+ # Only subtract 1 if distance > 1, otherwise use the distance directly
344
+ # This prevents collapsing 1-pixel safe areas to 0x0
345
+ min_width_dist = min(dist_to_left_edge, dist_to_right_edge)
346
+ min_height_dist = min(dist_to_top_edge, dist_to_bottom_edge)
347
+ safe_width_base = min_width_dist - 1 if min_width_dist > 1 else min_width_dist
348
+ safe_height_base = (
349
+ min_height_dist - 1 if min_height_dist > 1 else min_height_dist
350
+ )
351
+ max_safe_width = 2 * max(0, safe_width_base)
352
+ max_safe_height = 2 * max(0, safe_height_base)
353
+
354
+ if max_safe_width <= 0 or max_safe_height <= 0:
355
+ log_message(
356
+ f"Invalid safe area dimensions: {max_safe_width:.0f}x{max_safe_height:.0f}",
357
+ verbose=verbose,
358
+ always_print=True,
359
+ )
360
+ raise ImageProcessingError("Failed to create safe area mask")
361
+
362
+ box_x_float = centroid_x - max_safe_width / 2.0
363
+ box_y_float = centroid_y - max_safe_height / 2.0
364
+
365
+ box_x = int(round(box_x_float))
366
+ box_y = int(round(box_y_float))
367
+
368
+ guaranteed_box = (box_x, box_y, max_safe_width, max_safe_height)
369
+
370
+ if (
371
+ box_x >= 0
372
+ and box_y >= 0
373
+ and box_x + max_safe_width <= mask_w
374
+ and box_y + max_safe_height <= mask_h
375
+ ):
376
+ log_message(
377
+ f"Safe area: {max_safe_width:.0f}x{max_safe_height:.0f} at ({centroid_x:.0f}, {centroid_y:.0f})",
378
+ verbose=verbose,
379
+ )
380
+ return guaranteed_box, centroid
381
+ else:
382
+ log_message(
383
+ f"Safe area validation failed: exceeds bounds {mask_w}x{mask_h}",
384
+ verbose=verbose,
385
+ always_print=True,
386
+ )
387
+ raise ImageProcessingError("Failed to create safe area mask")
388
+
389
+ except (cv2.error, ValueError, IndexError, ZeroDivisionError, OverflowError) as e:
390
+ log_message(
391
+ f"Safe area calculation error: {e}", verbose=verbose, always_print=True
392
+ )
393
+ except Exception as e:
394
+ log_message(
395
+ f"Safe area calculation failed: {e}", verbose=verbose, always_print=True
396
+ )
397
+
398
+ raise ImageProcessingError("Safe area calculation failed")
399
+
400
+
401
+ def image_to_tensor(image: Image.Image, device: torch.device) -> torch.Tensor:
402
+ """Converts a PIL Image to a PyTorch tensor."""
403
+ if image.mode != "RGB":
404
+ image = image.convert("RGB")
405
+ img_np = np.array(image).astype(np.float32) / 255.0
406
+ if img_np.ndim == 2: # Grayscale to RGB
407
+ img_np = np.stack((img_np,) * 3, axis=-1)
408
+ return torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(device)
409
+
410
+
411
+ def tensor_to_image(tensor: torch.Tensor) -> Image.Image:
412
+ """Converts a PyTorch tensor to a PIL Image."""
413
+ img_np = (
414
+ tensor.squeeze(0).permute(1, 2, 0).clamp(0, 1).cpu().numpy() * 255
415
+ ).astype(np.uint8)
416
+ return Image.fromarray(img_np)
417
+
418
+
419
+ def _upscale_image(model, image: Image.Image, device: torch.device) -> Image.Image:
420
+ """Upscales a PIL image using the provided model."""
421
+ tensor_in = image_to_tensor(image, device)
422
+ with torch.no_grad():
423
+ tensor_out = model(tensor_in)
424
+ return tensor_to_image(tensor_out)
425
+
426
+
427
+ def upscale_image_to_dimension(
428
+ model,
429
+ image: Image.Image,
430
+ target: int,
431
+ device: torch.device,
432
+ mode: str,
433
+ model_type: str = "model",
434
+ verbose: bool = False,
435
+ ) -> Image.Image:
436
+ """
437
+ Upscale until a dimensional target is reached.
438
+
439
+ Args:
440
+ mode: 'max' ensures max(width, height) >= target, 'min' ensures min(width, height) >= target
441
+ model_type: Model type identifier ("model" or "model_lite")
442
+ """
443
+ if mode not in {"max", "min"}:
444
+ raise ImageProcessingError("mode must be 'max' or 'min'")
445
+
446
+ # Validate input image dimensions
447
+ if image.width <= 0 or image.height <= 0:
448
+ log_message(
449
+ f"Invalid image dimensions: {image.width}x{image.height}. Cannot upscale 0x0 images.",
450
+ always_print=True,
451
+ )
452
+ raise ImageProcessingError(
453
+ f"Invalid image dimensions: {image.width}x{image.height}. Cannot upscale 0x0 images."
454
+ )
455
+
456
+ cache = get_cache()
457
+ cache_key = cache.get_upscale_dimension_cache_key(image, target, mode, model_type)
458
+ cached_result = cache.get_upscaled_image(cache_key)
459
+ if cached_result is not None:
460
+ log_message(" - Using cached upscaled image", verbose=verbose)
461
+ return cached_result
462
+
463
+ current_image = image
464
+
465
+ def met(w: int, h: int) -> bool:
466
+ return (max(w, h) >= target) if mode == "max" else (min(w, h) >= target)
467
+
468
+ if met(current_image.width, current_image.height):
469
+ cache.set_upscaled_image(cache_key, current_image, verbose)
470
+ return current_image
471
+
472
+ log_message(
473
+ f"Upscaling from {current_image.width}x{current_image.height}...",
474
+ verbose=verbose,
475
+ )
476
+ current_image = _upscale_image(model, current_image, device)
477
+ log_message(f"...to {current_image.width}x{current_image.height}", verbose=verbose)
478
+
479
+ # Save intermediate image to disk if more passes will be needed
480
+ if not met(current_image.width, current_image.height):
481
+ temp_file = None
482
+ try:
483
+ temp_fd, temp_file = tempfile.mkstemp(suffix=".png")
484
+ os.close(temp_fd)
485
+ current_image.save(temp_file, format="PNG")
486
+
487
+ with Image.open(temp_file) as img_tmp:
488
+ img_tmp.load()
489
+ new_image = img_tmp.copy()
490
+
491
+ del current_image
492
+ gc.collect()
493
+ current_image = new_image
494
+ log_message(
495
+ "Saved and reloaded intermediate image before additional passes",
496
+ verbose=verbose,
497
+ )
498
+ except Exception as e:
499
+ log_message(
500
+ f"Warning: Failed to save intermediate image to disk: {e}. Continuing with in-memory processing.",
501
+ verbose=verbose,
502
+ )
503
+ finally:
504
+ if temp_file and os.path.exists(temp_file):
505
+ try:
506
+ os.remove(temp_file)
507
+ except Exception:
508
+ pass
509
+
510
+ while not met(current_image.width, current_image.height):
511
+ log_message(
512
+ f"Upscaling from {current_image.width}x{current_image.height} (additional pass)...",
513
+ verbose=verbose,
514
+ )
515
+ current_image = _upscale_image(model, current_image, device)
516
+ log_message(
517
+ f"...to {current_image.width}x{current_image.height}", verbose=verbose
518
+ )
519
+
520
+ # Save intermediate image to disk to free memory
521
+ temp_file = None
522
+ try:
523
+ temp_fd, temp_file = tempfile.mkstemp(suffix=".png")
524
+ os.close(temp_fd)
525
+ current_image.save(temp_file, format="PNG")
526
+ del current_image
527
+ gc.collect()
528
+
529
+ with Image.open(temp_file) as img_tmp:
530
+ img_tmp.load()
531
+ current_image = img_tmp.copy()
532
+
533
+ log_message(
534
+ "Saved and reloaded intermediate image to free memory",
535
+ verbose=verbose,
536
+ )
537
+ except Exception as e:
538
+ log_message(
539
+ f"Warning: Failed to save intermediate image to disk: {e}. Continuing with in-memory processing.",
540
+ verbose=verbose,
541
+ )
542
+ finally:
543
+ if temp_file and os.path.exists(temp_file):
544
+ try:
545
+ os.remove(temp_file)
546
+ except Exception:
547
+ pass # Ignore errors during cleanup
548
+
549
+ cache.set_upscaled_image(cache_key, current_image, verbose)
550
+ return current_image
551
+
552
+
553
+ def upscale_image(
554
+ image: Image.Image, factor: float, model_type: str = "model", verbose: bool = False
555
+ ) -> Image.Image:
556
+ """Upscales an image by a given factor.
557
+
558
+ Args:
559
+ image: Image to upscale
560
+ factor: Upscaling factor
561
+ model_type: Model type to use - "model" or "model_lite"
562
+ verbose: Whether to print verbose logging
563
+ """
564
+ if factor == 1.0:
565
+ return image
566
+
567
+ cache = get_cache()
568
+ cache_key = cache.get_upscale_cache_key(image, factor, model_type)
569
+ cached_upscale = cache.get_upscaled_image(cache_key)
570
+ if cached_upscale is not None:
571
+ log_message(" - Using cached upscaled image", verbose=verbose)
572
+ return cached_upscale
573
+
574
+ model_manager = get_model_manager()
575
+ if model_type == "model_lite":
576
+ upscale_model = model_manager.load_upscale_lite()
577
+ log_message(f"Upscaling image by {factor}x with lite model...", verbose=verbose)
578
+ else:
579
+ upscale_model = model_manager.load_upscale()
580
+ log_message(f"Upscaling image by {factor}x...", verbose=verbose)
581
+ device = model_manager.device
582
+
583
+ target_width = int(image.width * factor)
584
+ target_height = int(image.height * factor)
585
+
586
+ upscaled_image = upscale_image_to_dimension(
587
+ upscale_model,
588
+ image,
589
+ max(target_width, target_height),
590
+ device,
591
+ "max",
592
+ model_type,
593
+ verbose,
594
+ )
595
+ result = upscaled_image.resize((target_width, target_height), Image.LANCZOS)
596
+
597
+ cache.set_upscaled_image(cache_key, result)
598
+ return result
599
+
600
+
601
+ def resize_to_max_side(
602
+ image: Image.Image, max_side: int, verbose: bool = False
603
+ ) -> Image.Image:
604
+ """Resize so that the largest side equals max_side (aspect ratio preserved)."""
605
+ width, height = image.size
606
+ current_max = max(width, height)
607
+ if current_max == max_side:
608
+ return image
609
+ scale = max_side / current_max
610
+ new_width = max(1, int(round(width * scale)))
611
+ new_height = max(1, int(round(height * scale)))
612
+ log_message(
613
+ f"Resizing to max-side {max_side}: {width}x{height} -> {new_width}x{new_height}",
614
+ verbose=verbose,
615
+ )
616
+ return image.resize((new_width, new_height), Image.LANCZOS)
617
+
618
+
619
+ def resize_to_min_side(
620
+ image: Image.Image, min_side: int, verbose: bool = False
621
+ ) -> Image.Image:
622
+ """Resize so that the smallest side equals min_side (aspect ratio preserved)."""
623
+ width, height = image.size
624
+
625
+ # Validate input image dimensions
626
+ if width <= 0 or height <= 0:
627
+ log_message(
628
+ f"Invalid image dimensions: {width}x{height}. Cannot resize 0x0 images.",
629
+ always_print=True,
630
+ )
631
+ raise ImageProcessingError(
632
+ f"Invalid image dimensions: {width}x{height}. Cannot resize 0x0 images."
633
+ )
634
+
635
+ current_min = min(width, height)
636
+ if current_min == min_side:
637
+ return image
638
+ scale = min_side / current_min
639
+ new_width = max(1, int(round(width * scale)))
640
+ new_height = max(1, int(round(height * scale)))
641
+ log_message(
642
+ f"Resizing to min-side {min_side}: {width}x{height} -> {new_width}x{new_height}",
643
+ verbose=verbose,
644
+ )
645
+ return image.resize((new_width, new_height), Image.LANCZOS)
646
+
647
+
648
+ def convert_image_to_target_mode(
649
+ pil_image: Image.Image, target_mode: str, verbose: bool = False
650
+ ) -> Image.Image:
651
+ """
652
+ Convert a PIL image to the target color mode (RGB or RGBA).
653
+
654
+ Handles complex transparency flattening and mode conversion with multiple
655
+ fallback strategies to ensure robust image processing.
656
+
657
+ Args:
658
+ pil_image: The PIL image to convert
659
+ target_mode: Target mode ("RGB" or "RGBA")
660
+ verbose: Whether to print detailed logging
661
+
662
+ Returns:
663
+ PIL.Image: The converted image in the target mode
664
+ """
665
+ if pil_image.mode == target_mode:
666
+ return pil_image
667
+
668
+ if target_mode == "RGB":
669
+ if (
670
+ pil_image.mode == "RGBA"
671
+ or pil_image.mode == "LA"
672
+ or (pil_image.mode == "P" and "transparency" in pil_image.info)
673
+ ):
674
+ log_message(
675
+ f"Converting {pil_image.mode} to RGB (flattening transparency)",
676
+ verbose=verbose,
677
+ )
678
+ background = Image.new("RGB", pil_image.size, (255, 255, 255))
679
+ try:
680
+ mask = None
681
+ if pil_image.mode == "RGBA":
682
+ mask = pil_image.split()[3]
683
+ elif pil_image.mode == "LA":
684
+ mask = pil_image.split()[1]
685
+ elif pil_image.mode == "P" and "transparency" in pil_image.info:
686
+ temp_rgba = pil_image.convert("RGBA")
687
+ mask = temp_rgba.split()[3]
688
+
689
+ if mask:
690
+ background.paste(pil_image, mask=mask)
691
+ pil_image = background
692
+ else:
693
+ pil_image = pil_image.convert("RGB")
694
+ except Exception as paste_err:
695
+ log_message(
696
+ f"Warning: Paste failed, trying alpha_composite: {paste_err}",
697
+ verbose=verbose,
698
+ )
699
+ try:
700
+ background_comp = Image.new("RGB", pil_image.size, (255, 255, 255))
701
+ img_rgba_for_composite = (
702
+ pil_image
703
+ if pil_image.mode == "RGBA"
704
+ else pil_image.convert("RGBA")
705
+ )
706
+ pil_image = Image.alpha_composite(
707
+ background_comp.convert("RGBA"), img_rgba_for_composite
708
+ ).convert("RGB")
709
+ log_message(
710
+ "Alpha composite conversion successful", verbose=verbose
711
+ )
712
+ except Exception as composite_err:
713
+ log_message(
714
+ f"Warning: Alpha composite failed, using simple convert: {composite_err}",
715
+ verbose=verbose,
716
+ )
717
+ pil_image = pil_image.convert("RGB") # Final fallback conversion
718
+ else: # Non-transparent conversion to RGB
719
+ log_message(f"Converting {pil_image.mode} to RGB", verbose=verbose)
720
+ pil_image = pil_image.convert("RGB")
721
+ elif target_mode == "RGBA":
722
+ log_message(f"Converting {pil_image.mode} to RGBA", verbose=verbose)
723
+ pil_image = pil_image.convert("RGBA")
724
+
725
+ return pil_image
726
+
727
+
728
+ def process_bubble_image_cached(
729
+ bubble_image_pil: Image.Image,
730
+ upscale_model,
731
+ device: torch.device,
732
+ target_min_side: int = 200,
733
+ mode: str = "min",
734
+ model_type: str = "model",
735
+ verbose: bool = False,
736
+ ) -> Image.Image:
737
+ """
738
+ Process a bubble image with upscaling, using cache for the complete pipeline.
739
+
740
+ This function handles the complete bubble processing pipeline:
741
+ 1. Upscales the bubble to meet minimum size requirements
742
+ 2. Resizes to exact minimum side length
743
+ 3. Caches the final result
744
+
745
+ Args:
746
+ bubble_image_pil: The bubble image to process
747
+ upscale_model: The upscaling model to use
748
+ device: PyTorch device for model inference
749
+ target_min_side: Target minimum side length
750
+ mode: Upscaling mode ('max' or 'min')
751
+ model_type: Model type identifier ("model" or "model_lite")
752
+ verbose: Whether to print detailed logging
753
+
754
+ Returns:
755
+ Image.Image: The processed bubble image
756
+ """
757
+ cache = get_cache()
758
+ cache_key = cache.get_bubble_processing_cache_key(
759
+ bubble_image_pil, target_min_side, mode, model_type
760
+ )
761
+ cached_result = cache.get_upscaled_image(cache_key)
762
+ if cached_result is not None:
763
+ log_message(" - Using cached bubble processing result", verbose=verbose)
764
+ return cached_result
765
+
766
+ upscaled_bubble = upscale_image_to_dimension(
767
+ upscale_model,
768
+ bubble_image_pil,
769
+ target_min_side,
770
+ device,
771
+ mode,
772
+ model_type,
773
+ verbose,
774
+ )
775
+
776
+ resized_bubble = resize_to_min_side(upscaled_bubble, target_min_side, verbose)
777
+
778
+ cache.set_upscaled_image(cache_key, resized_bubble, verbose)
779
+ return resized_bubble
core/image/inpainting.py ADDED
@@ -0,0 +1,773 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ import math
3
+ from typing import Dict, Optional, Tuple
4
+
5
+ import numpy as np
6
+ import torch
7
+ from PIL import Image
8
+ from scipy.ndimage import distance_transform_edt
9
+
10
+ from core.caching import get_cache
11
+ from core.ml.model_manager import get_model_manager
12
+ from utils.logging import log_message
13
+
14
+ # Blur Parameters
15
+ BLUR_SCALE_FACTOR = (
16
+ 0.1 # Multiplier for bounding box dimensions to calculate blur radius
17
+ )
18
+ MIN_BLUR_RADIUS = 1 # Minimum blur radius in pixels
19
+ MAX_BLUR_RADIUS = 10 # Maximum blur radius in pixels
20
+
21
+ # Inpainting Parameters
22
+ FLUX_GUIDANCE_SCALE = 2.5 # Flux Kontext guidance scale
23
+ CONTEXT_PADDING_RATIO = 0.5 # Context padding is 50% of detection size
24
+ MAX_CONTEXT_PADDING = 80 # Context padding capped at 80 pixels
25
+
26
+
27
+ class FluxKontextInpainter:
28
+ """Inpainter using Flux Kontext models for text removal."""
29
+
30
+ def __init__(
31
+ self,
32
+ device: Optional[torch.device] = None,
33
+ huggingface_token: str = "",
34
+ num_inference_steps: int = 15,
35
+ residual_diff_threshold: float = 0.15,
36
+ ):
37
+ """Initialize the Flux Kontext Inpaint class.
38
+
39
+ Args:
40
+ device: PyTorch device to use. Auto-detects if None.
41
+ huggingface_token: HuggingFace token for model downloads.
42
+ num_inference_steps: Number of denoising steps for inference.
43
+ residual_diff_threshold: Residual diff threshold for Flux caching (0.0-1.0).
44
+ """
45
+ self.DEVICE = (
46
+ device
47
+ if device is not None
48
+ else torch.device(
49
+ "cuda"
50
+ if torch.cuda.is_available()
51
+ else "mps"
52
+ if torch.backends.mps.is_available()
53
+ else "cpu"
54
+ )
55
+ )
56
+ self.DTYPE = (
57
+ torch.bfloat16
58
+ if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
59
+ else torch.float16
60
+ if self.DEVICE.type == "mps"
61
+ else torch.float32
62
+ )
63
+ self.huggingface_token = huggingface_token
64
+ self.num_inference_steps = num_inference_steps
65
+ self.residual_diff_threshold = residual_diff_threshold
66
+ self.manager = get_model_manager()
67
+ self.cache = get_cache()
68
+
69
+ # Preferred resolutions for optimal Flux performance
70
+ self.PREFERED_KONTEXT_RESOLUTIONS = [
71
+ (672, 1568),
72
+ (688, 1504),
73
+ (720, 1456),
74
+ (752, 1392),
75
+ (800, 1328),
76
+ (832, 1248),
77
+ (880, 1184),
78
+ (944, 1104),
79
+ (1024, 1024),
80
+ (1104, 944),
81
+ (1184, 880),
82
+ (1248, 832),
83
+ (1328, 800),
84
+ (1392, 752),
85
+ (1456, 720),
86
+ (1504, 688),
87
+ (1568, 672),
88
+ ]
89
+
90
+ self.pipeline = None
91
+ self.transformer = None
92
+ self.text_encoder_2 = None
93
+
94
+ # Fixed parameters optimized for text removal
95
+ self.guidance_scale = FLUX_GUIDANCE_SCALE
96
+ self.prompt = "Remove all text."
97
+ self.context_padding_ratio = CONTEXT_PADDING_RATIO
98
+ self.max_context_padding = MAX_CONTEXT_PADDING
99
+
100
+ def load_models(self):
101
+ """Load Flux Kontext models via model manager."""
102
+ if self.pipeline is not None:
103
+ return
104
+
105
+ if self.huggingface_token:
106
+ self.manager.set_flux_hf_token(self.huggingface_token)
107
+
108
+ self.manager.set_flux_residual_diff_threshold(self.residual_diff_threshold)
109
+
110
+ self.transformer, self.text_encoder_2, self.pipeline = (
111
+ self.manager.load_flux_models()
112
+ )
113
+
114
+ def unload_models(self):
115
+ """Unload Flux models via model manager to free up memory."""
116
+ self.pipeline = None
117
+ self.transformer = None
118
+ self.text_encoder_2 = None
119
+ self.manager.unload_flux_models()
120
+
121
+ def convert_mask_to_tensor(self, mask_np):
122
+ """Convert a numpy mask to the tensor format expected by the pipeline.
123
+
124
+ Args:
125
+ mask_np: Numpy mask array (H, W) with True/False values
126
+
127
+ Returns:
128
+ torch.Tensor: Mask tensor in CHW format (1.0 for areas to keep, 0.0 for areas to inpaint)
129
+ """
130
+ # Invert mask: True = inpaint (0.0), False = keep (1.0)
131
+ mask_float = mask_np.astype(np.float32)
132
+ mask_inverted = 1.0 - mask_float
133
+ mask_tensor = torch.from_numpy(mask_inverted).unsqueeze(0)
134
+
135
+ return mask_tensor
136
+
137
+ def flux_kontext_image_scale(self, image_pil):
138
+ """Find the closest preferred resolution and resize the image.
139
+
140
+ Args:
141
+ image_pil (PIL.Image): Input image to scale
142
+
143
+ Returns:
144
+ PIL.Image: Scaled image at the closest preferred resolution
145
+ """
146
+ w_in, h_in = image_pil.size
147
+ if w_in == 0 or h_in == 0:
148
+ return image_pil
149
+
150
+ ar = w_in / h_in
151
+ # Find resolution with minimum aspect ratio difference
152
+ _, w_opt, h_opt = min(
153
+ (abs(ar - w / h), w, h) for (w, h) in self.PREFERED_KONTEXT_RESOLUTIONS
154
+ )
155
+
156
+ log_message(
157
+ f" - Original image size: {w_in}x{h_in} (AR: {ar:.2f})", always_print=True
158
+ )
159
+ log_message(
160
+ f" - Scaling to nearest preferred resolution: {w_opt}x{h_opt}",
161
+ always_print=True,
162
+ )
163
+
164
+ if (w_in, h_in) == (w_opt, h_opt):
165
+ return image_pil
166
+
167
+ # Use LANCZOS for high-quality downscaling
168
+ image_scaled = image_pil.resize((w_opt, h_opt), Image.Resampling.LANCZOS)
169
+
170
+ return image_scaled
171
+
172
+ def compute_mask_bbox_aspect_ratio(
173
+ self,
174
+ mask_chw,
175
+ padding,
176
+ blur_radius,
177
+ target_ar=None,
178
+ transpose=False,
179
+ preferred_resolutions=None,
180
+ verbose=False,
181
+ ):
182
+ """Compute an optimized bounding box for the mask with aspect ratio adjustment.
183
+
184
+ Args:
185
+ mask_chw (torch.Tensor): Input mask tensor in CHW format
186
+ padding (int): Padding around the mask bounding box
187
+ blur_radius (int): Radius for edge blur effect
188
+ target_ar (float, optional): Target aspect ratio
189
+ transpose (bool): Whether to transpose the aspect ratio logic
190
+ preferred_resolutions (list, optional): List of preferred resolutions
191
+ verbose (bool): Whether to print verbose output
192
+
193
+ Returns:
194
+ tuple: (mask_for_composite, x, y, width, height)
195
+ """
196
+ if mask_chw.dim() == 4:
197
+ mask = mask_chw[0, 0]
198
+ else:
199
+ mask = mask_chw[0]
200
+
201
+ H, W = mask.shape[0], mask.shape[1]
202
+ hard = mask.clone().unsqueeze(0)
203
+ if blur_radius > 0:
204
+ # Create smooth falloff at mask edges for better blending
205
+ m_bool = hard[0].cpu().to(torch.float32).numpy().astype(bool)
206
+ d_out = distance_transform_edt(~m_bool)
207
+ d_in = distance_transform_edt(m_bool)
208
+ alpha = np.zeros_like(d_out, np.float32)
209
+ alpha[d_in > 0] = 1.0
210
+ ramp = np.clip(1.0 - (d_out / blur_radius), 0.0, 1.0)
211
+ alpha[d_out > 0] = ramp[d_out > 0]
212
+ mask_blur_full = torch.from_numpy(alpha)[None, ...].to(hard.device)
213
+ else:
214
+ mask_blur_full = hard.clone()
215
+
216
+ ys, xs = torch.where(hard[0] > 0)
217
+ if len(ys) == 0:
218
+ return (
219
+ torch.zeros((1, H, W), device=mask_chw.device, dtype=mask_chw.dtype),
220
+ 0,
221
+ 0,
222
+ W,
223
+ H,
224
+ )
225
+
226
+ x1 = max(0, int(xs.min()) - padding)
227
+ x2 = min(W, int(xs.max()) + 1 + padding)
228
+ y1 = max(0, int(ys.min()) - padding)
229
+ y2 = min(H, int(ys.max()) + 1 + padding)
230
+ w0 = x2 - x1
231
+ h0 = y2 - y1
232
+
233
+ if preferred_resolutions:
234
+ if h0 == 0:
235
+ initial_ar = W / H
236
+ else:
237
+ initial_ar = w0 / h0
238
+ log_message(
239
+ f" - Initial mask bounding box AR: {initial_ar:.2f}",
240
+ verbose=verbose,
241
+ )
242
+
243
+ # Snap to closest preferred aspect ratio
244
+ _, w_opt, h_opt = min(
245
+ (abs(initial_ar - w / h), w, h) for (w, h) in preferred_resolutions
246
+ )
247
+ ar = w_opt / h_opt
248
+ log_message(
249
+ f" - Snapping to closest preferred AR: {ar:.2f} ({w_opt}x{h_opt})",
250
+ verbose=verbose,
251
+ )
252
+ else:
253
+ ar = target_ar
254
+
255
+ req_w = math.ceil(h0 * ar)
256
+ req_h = math.floor(w0 / ar)
257
+
258
+ new_x1, new_x2 = x1, x2
259
+ new_y1, new_y2 = y1, y2
260
+
261
+ flush_left = x1 == 0
262
+ flush_right = x2 == W
263
+ flush_top = y1 == 0
264
+ flush_bot = y2 == H
265
+
266
+ if not transpose:
267
+ if req_w > w0:
268
+ target_w = min(W, req_w)
269
+ delta = target_w - w0
270
+ if flush_right:
271
+ new_x1, new_x2 = W - target_w, W
272
+ elif flush_left:
273
+ new_x1, new_x2 = 0, target_w
274
+ else:
275
+ off = delta // 2
276
+ new_x1 = max(0, x1 - off)
277
+ new_x2 = new_x1 + target_w
278
+ if new_x2 > W:
279
+ new_x2 = W
280
+ new_x1 = W - target_w
281
+
282
+ elif req_h > h0:
283
+ target_h = min(H, req_h)
284
+ delta = target_h - h0
285
+ if flush_bot:
286
+ new_y1, new_y2 = H - target_h, H
287
+ elif flush_top:
288
+ new_y1, new_y2 = 0, target_h
289
+ else:
290
+ off = delta // 2
291
+ new_y1 = max(0, y1 - off)
292
+ new_y2 = new_y1 + target_h
293
+ if new_y2 > H:
294
+ new_y2 = H
295
+ new_y1 = H - target_h
296
+
297
+ else: # Transpose logic
298
+ if req_h > h0:
299
+ target_h = min(H, req_h)
300
+ delta = target_h - h0
301
+ if flush_bot:
302
+ new_y1, new_y2 = H - target_h, H
303
+ elif flush_top:
304
+ new_y1, new_y2 = 0, target_h
305
+ else:
306
+ off = delta // 2
307
+ new_y1 = max(0, y1 - off)
308
+ new_y2 = new_y1 + target_h
309
+ if new_y2 > H:
310
+ new_y2 = H
311
+ new_y1 = H - target_h
312
+
313
+ elif req_w > w0:
314
+ target_w = min(W, req_w)
315
+ delta = target_w - w0
316
+ if flush_right:
317
+ new_x1, new_x2 = W - target_w, W
318
+ elif flush_left:
319
+ new_x1, new_x2 = 0, target_w
320
+ else:
321
+ off = delta // 2
322
+ new_x1 = max(0, x1 - off)
323
+ new_x2 = new_x1 + target_w
324
+ if new_x2 > W:
325
+ new_x2 = W
326
+ new_x1 = W - target_w
327
+
328
+ final_w = new_x2 - new_x1
329
+ final_h = new_y2 - new_y1
330
+
331
+ # Return cropped mask for compositing
332
+ mask_for_composite = mask_blur_full[:, new_y1:new_y2, new_x1:new_x2]
333
+
334
+ return (
335
+ mask_for_composite.to(mask_chw.device, dtype=mask_chw.dtype),
336
+ int(new_x1),
337
+ int(new_y1),
338
+ int(final_w),
339
+ int(final_h),
340
+ )
341
+
342
+ def image_alpha_fix(self, destination, source):
343
+ """Ensure destination and source tensors have compatible channel dimensions.
344
+
345
+ Args:
346
+ destination (torch.Tensor): Destination tensor
347
+ source (torch.Tensor): Source tensor
348
+
349
+ Returns:
350
+ tuple: (destination, source) with compatible dimensions
351
+ """
352
+ dest_channels = destination.shape[-1]
353
+ source_channels = source.shape[-1]
354
+
355
+ if dest_channels == source_channels:
356
+ return destination, source
357
+
358
+ if dest_channels > source_channels:
359
+ # Pad source to match destination's channel count
360
+ padding = torch.ones(
361
+ (*source.shape[:-1], dest_channels - source_channels),
362
+ device=source.device,
363
+ dtype=source.dtype,
364
+ )
365
+ source = torch.cat([source, padding], dim=-1)
366
+ else: # source_channels > dest_channels
367
+ # Truncate source to match destination's channel count
368
+ source = source[..., :dest_channels]
369
+
370
+ return destination, source
371
+
372
+ def repeat_to_batch_size(self, tensor, batch_size):
373
+ """Adjust tensor batch size by repeating or truncating as needed.
374
+
375
+ Args:
376
+ tensor (torch.Tensor): Input tensor
377
+ batch_size (int): Target batch size
378
+
379
+ Returns:
380
+ torch.Tensor: Tensor with the specified batch size
381
+ """
382
+ if tensor.shape[0] > batch_size:
383
+ return tensor[:batch_size]
384
+ elif tensor.shape[0] < batch_size:
385
+ return tensor.repeat(batch_size, 1, 1, 1)
386
+ return tensor
387
+
388
+ def composite(
389
+ self, destination, source, x, y, mask=None, multiplier=1, resize_source=False
390
+ ):
391
+ """Composite source image onto destination at specified coordinates.
392
+
393
+ Args:
394
+ destination (torch.Tensor): Destination image tensor
395
+ source (torch.Tensor): Source image tensor
396
+ x (int): X coordinate for placement
397
+ y (int): Y coordinate for placement
398
+ mask (torch.Tensor, optional): Alpha mask for blending
399
+ multiplier (int): Coordinate multiplier
400
+ resize_source (bool): Whether to resize source to match destination
401
+
402
+ Returns:
403
+ torch.Tensor: Composited image tensor
404
+ """
405
+ source = source.to(destination.device)
406
+ if resize_source:
407
+ source = torch.nn.functional.interpolate(
408
+ source,
409
+ size=(destination.shape[2], destination.shape[3]),
410
+ mode="bilinear",
411
+ )
412
+
413
+ source = self.repeat_to_batch_size(source, destination.shape[0])
414
+
415
+ x = max(
416
+ -source.shape[3] * multiplier, min(x, destination.shape[3] * multiplier)
417
+ )
418
+ y = max(
419
+ -source.shape[2] * multiplier, min(y, destination.shape[2] * multiplier)
420
+ )
421
+
422
+ left, top = (x // multiplier, y // multiplier)
423
+
424
+ if mask is None:
425
+ mask = torch.ones_like(source)
426
+ else:
427
+ mask = mask.to(destination.device, copy=True)
428
+ mask = torch.nn.functional.interpolate(
429
+ mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])),
430
+ size=(source.shape[2], source.shape[3]),
431
+ mode="bilinear",
432
+ )
433
+ mask = self.repeat_to_batch_size(mask, source.shape[0])
434
+
435
+ visible_width = max(0, min(source.shape[3], destination.shape[3] - left))
436
+ visible_height = max(0, min(source.shape[2], destination.shape[2] - top))
437
+
438
+ if visible_width == 0 or visible_height == 0:
439
+ return destination
440
+
441
+ source_portion = source[:, :, :visible_height, :visible_width]
442
+ mask_portion = mask[:, :, :visible_height, :visible_width]
443
+ inverse_mask_portion = torch.ones_like(mask_portion) - mask_portion
444
+
445
+ destination_portion = destination[
446
+ :, :, top : top + visible_height, left : left + visible_width
447
+ ]
448
+ # Alpha blend source and destination using mask
449
+ blended_portion = (source_portion * mask_portion) + (
450
+ destination_portion * inverse_mask_portion
451
+ )
452
+ destination[:, :, top : top + visible_height, left : left + visible_width] = (
453
+ blended_portion
454
+ )
455
+
456
+ return destination
457
+
458
+ def image_composite_masked(
459
+ self, destination, source, x, y, resize_source, mask=None
460
+ ):
461
+ """Wrapper function that handles channel dimension compatibility.
462
+
463
+ Args:
464
+ destination (torch.Tensor): Destination image tensor
465
+ source (torch.Tensor): Source image tensor
466
+ x (int): X coordinate for placement
467
+ y (int): Y coordinate for placement
468
+ resize_source (bool): Whether to resize source to match destination
469
+ mask (torch.Tensor, optional): Alpha mask for blending
470
+
471
+ Returns:
472
+ torch.Tensor: Composited image tensor
473
+ """
474
+ destination, source = self.image_alpha_fix(destination, source)
475
+ destination = destination.clone().movedim(-1, 1)
476
+ output = self.composite(
477
+ destination, source.movedim(-1, 1), x, y, mask, 1, resize_source
478
+ ).movedim(1, -1)
479
+ return output
480
+
481
+ def inpaint_mask(
482
+ self,
483
+ image_pil: Image.Image,
484
+ mask_np: np.ndarray,
485
+ seed: int = 1,
486
+ verbose: bool = False,
487
+ ocr_params: Optional[Dict] = None,
488
+ strict_mask_clipping: bool = False,
489
+ composite_clip_bbox: Optional[Tuple[int, int, int, int]] = None,
490
+ ) -> Image.Image:
491
+ """Inpaint a specific mask region in the image.
492
+
493
+ Args:
494
+ image_pil: PIL Image to inpaint
495
+ mask_np: Numpy mask array (H, W) with True for areas to inpaint
496
+ seed: Random seed for inference
497
+ verbose: Whether to print verbose output
498
+ ocr_params: Optional OCR parameters dict for cache key generation
499
+ strict_mask_clipping: When True, ensure compositing is limited to the
500
+ original mask extent (no bleed from padding/blur)
501
+ composite_clip_bbox: Optional (x1, y1, x2, y2) bbox to clip the final
502
+ composite mask to, in original image coordinates.
503
+
504
+ Returns:
505
+ PIL.Image: The inpainted image
506
+ """
507
+ mask_np = np.asarray(mask_np)
508
+ if mask_np.dtype != bool:
509
+ mask_np = mask_np.astype(bool)
510
+
511
+ if not np.any(mask_np):
512
+ return image_pil
513
+
514
+ log_message(
515
+ " - Computing optimized mask bounding box with blur and aspect ratio...",
516
+ verbose=verbose,
517
+ )
518
+
519
+ ys, xs = np.where(mask_np)
520
+ if len(ys) == 0 or len(xs) == 0:
521
+ return image_pil
522
+
523
+ x_min, x_max = int(xs.min()), int(xs.max())
524
+ y_min, y_max = int(ys.min()), int(ys.max())
525
+
526
+ bbox_width = x_max - x_min
527
+ bbox_height = y_max - y_min
528
+
529
+ padding_pixels = int(max(bbox_width, bbox_height) * self.context_padding_ratio)
530
+ padding = min(padding_pixels, self.max_context_padding)
531
+ log_message(
532
+ f" - Proportional context padding: {padding_pixels}px, capped to: {padding}px",
533
+ verbose=verbose,
534
+ )
535
+
536
+ blur_radius = int(max(bbox_width, bbox_height) * BLUR_SCALE_FACTOR)
537
+ blur_radius = max(
538
+ MIN_BLUR_RADIUS, min(blur_radius, MAX_BLUR_RADIUS)
539
+ ) # clamp between MIN and MAX
540
+ log_message(f" - Dynamic blur radius set to: {blur_radius}", verbose=verbose)
541
+
542
+ mask_tensor = (
543
+ torch.from_numpy(mask_np.astype(np.float32)).unsqueeze(0).unsqueeze(0)
544
+ )
545
+
546
+ mask_for_composite, x, y, width, height = self.compute_mask_bbox_aspect_ratio(
547
+ mask_chw=mask_tensor,
548
+ padding=padding,
549
+ blur_radius=blur_radius,
550
+ preferred_resolutions=self.PREFERED_KONTEXT_RESOLUTIONS,
551
+ transpose=False,
552
+ verbose=verbose,
553
+ )
554
+
555
+ # Quantize bbox to improve cache stability against minor detection jitter
556
+ quant = 2
557
+ img_h, img_w = mask_np.shape
558
+ qx1 = max(0, min(img_w, int(round(x / quant) * quant)))
559
+ qy1 = max(0, min(img_h, int(round(y / quant) * quant)))
560
+ qx2 = max(qx1 + 1, min(img_w, int(round((x + width) / quant) * quant)))
561
+ qy2 = max(qy1 + 1, min(img_h, int(round((y + height) / quant) * quant)))
562
+ qwidth = max(1, qx2 - qx1)
563
+ qheight = max(1, qy2 - qy1)
564
+
565
+ # Adjust mask_for_composite to the quantized bbox via pad/crop
566
+ dx_left = x - qx1
567
+ dy_top = y - qy1
568
+ dx_right = (qx1 + qwidth) - (x + width)
569
+ dy_bottom = (qy1 + qheight) - (y + height)
570
+
571
+ if dx_left > 0 or dx_right > 0 or dy_top > 0 or dy_bottom > 0:
572
+ pad_l = max(dx_left, 0)
573
+ pad_r = max(dx_right, 0)
574
+ pad_t = max(dy_top, 0)
575
+ pad_b = max(dy_bottom, 0)
576
+ mask_for_composite = torch.nn.functional.pad(
577
+ mask_for_composite, (pad_l, pad_r, pad_t, pad_b)
578
+ )
579
+
580
+ if dx_left < 0:
581
+ mask_for_composite = mask_for_composite[:, :, -dx_left:]
582
+ if dy_top < 0:
583
+ mask_for_composite = mask_for_composite[:, -dy_top:, :]
584
+ if mask_for_composite.shape[-1] > qwidth:
585
+ mask_for_composite = mask_for_composite[:, :, :qwidth]
586
+ if mask_for_composite.shape[-2] > qheight:
587
+ mask_for_composite = mask_for_composite[:, :qheight, :]
588
+
589
+ x, y, width, height = qx1, qy1, qwidth, qheight
590
+
591
+ if strict_mask_clipping:
592
+ original_mask_crop = mask_tensor[0, 0, y : y + height, x : x + width]
593
+ mask_for_composite = mask_for_composite * original_mask_crop
594
+
595
+ if composite_clip_bbox is not None:
596
+ clip_x1, clip_y1, clip_x2, clip_y2 = composite_clip_bbox
597
+
598
+ img_h, img_w = mask_np.shape
599
+ clip_x1 = max(0, min(img_w, clip_x1))
600
+ clip_x2 = max(0, min(img_w, clip_x2))
601
+ clip_y1 = max(0, min(img_h, clip_y1))
602
+ clip_y2 = max(0, min(img_h, clip_y2))
603
+
604
+ start_x = max(0, clip_x1 - x)
605
+ end_x = min(width, clip_x2 - x)
606
+ start_y = max(0, clip_y1 - y)
607
+ end_y = min(height, clip_y2 - y)
608
+
609
+ if end_x <= start_x or end_y <= start_y:
610
+ mask_for_composite = torch.zeros_like(mask_for_composite)
611
+ else:
612
+ clipped_mask = torch.zeros_like(mask_for_composite)
613
+ clipped_mask[:, start_y:end_y, start_x:end_x] = mask_for_composite[
614
+ :, start_y:end_y, start_x:end_x
615
+ ]
616
+ mask_for_composite = clipped_mask
617
+
618
+ log_message(
619
+ f" - Optimized bbox found at ({x}, {y}) with size {width}x{height}",
620
+ verbose=verbose,
621
+ )
622
+
623
+ image_cropped_pil = image_pil.crop((x, y, x + width, y + height))
624
+ mask_crop_np = mask_np[y : y + height, x : x + width]
625
+
626
+ cache_params = {
627
+ "bbox": (x, y, width, height),
628
+ "padding": padding,
629
+ "blur": blur_radius,
630
+ }
631
+ if strict_mask_clipping:
632
+ cache_params["strict_clip"] = True
633
+ if composite_clip_bbox is not None:
634
+ cache_params["clip_bbox"] = tuple(composite_clip_bbox)
635
+ if ocr_params:
636
+ cache_params.update(ocr_params)
637
+
638
+ cache_key = None
639
+ cached_patch = None
640
+ if self.cache.should_use_inpaint_cache(seed):
641
+ # Downsample mask signature to reduce sensitivity to minor jitter
642
+ if mask_crop_np.size > 0:
643
+ sig_h = min(64, max(4, mask_crop_np.shape[0]))
644
+ sig_w = min(64, max(4, mask_crop_np.shape[1]))
645
+ mask_sig = (
646
+ torch.from_numpy(mask_crop_np.astype(np.float32))
647
+ .unsqueeze(0)
648
+ .unsqueeze(0)
649
+ )
650
+ mask_sig = torch.nn.functional.interpolate(
651
+ mask_sig, size=(sig_h, sig_w), mode="bilinear", align_corners=False
652
+ )
653
+ mask_sig_np = (mask_sig > 0.5).cpu().numpy().astype(np.uint8)[0, 0]
654
+ else:
655
+ mask_sig_np = mask_crop_np
656
+
657
+ cache_key = self.cache.get_inpaint_cache_key(
658
+ image_cropped_pil,
659
+ mask_sig_np,
660
+ seed,
661
+ self.num_inference_steps,
662
+ self.residual_diff_threshold,
663
+ self.guidance_scale,
664
+ self.prompt,
665
+ cache_params,
666
+ )
667
+ cached_patch = self.cache.get_inpainted_image(cache_key)
668
+ if cached_patch is not None:
669
+ log_message(" - Using cached inpainting patch", verbose=verbose)
670
+
671
+ patch_pil = cached_patch
672
+
673
+ if patch_pil is None:
674
+ self.load_models()
675
+
676
+ if self.pipeline is None:
677
+ log_message(
678
+ "Warning: Flux Kontext pipeline not available. Skipping inpainting.",
679
+ always_print=True,
680
+ )
681
+ return image_pil
682
+
683
+ image_scaled_for_inference_pil = self.flux_kontext_image_scale(
684
+ image_cropped_pil
685
+ )
686
+ inference_width, inference_height = image_scaled_for_inference_pil.size
687
+
688
+ if image_scaled_for_inference_pil.mode == "RGBA":
689
+ image_scaled_for_inference_pil = image_scaled_for_inference_pil.convert(
690
+ "RGB"
691
+ )
692
+
693
+ log_message(" - Running inference...", verbose=verbose)
694
+
695
+ self.pipeline.text_encoder_2.to(self.DEVICE)
696
+
697
+ prompt_embeds, pooled_prompt_embeds, _ = self.pipeline.encode_prompt(
698
+ prompt=self.prompt,
699
+ prompt_2=None,
700
+ device=self.DEVICE,
701
+ )
702
+
703
+ self.pipeline.text_encoder_2.to("cpu")
704
+ gc.collect()
705
+ torch.cuda.empty_cache()
706
+
707
+ self.pipeline.transformer.to(self.DEVICE)
708
+
709
+ required_area = inference_width * inference_height
710
+ with torch.inference_mode():
711
+ gen = torch.Generator(device=self.DEVICE).manual_seed(seed)
712
+ out = self.pipeline(
713
+ prompt_embeds=prompt_embeds,
714
+ pooled_prompt_embeds=pooled_prompt_embeds,
715
+ image=image_scaled_for_inference_pil,
716
+ width=inference_width,
717
+ height=inference_height,
718
+ num_inference_steps=self.num_inference_steps,
719
+ guidance_scale=self.guidance_scale,
720
+ generator=gen,
721
+ output_type="pt",
722
+ max_area=required_area,
723
+ )
724
+ img = out.images[0]
725
+ torch.nan_to_num_(img, nan=0.0, posinf=1.0, neginf=0.0)
726
+ img.clamp_(0, 1)
727
+ generated_patch_pil = Image.fromarray(
728
+ (
729
+ img.mul(255)
730
+ .round()
731
+ .to(torch.uint8)
732
+ .permute(1, 2, 0)
733
+ .cpu()
734
+ .numpy()
735
+ )
736
+ )
737
+
738
+ self.pipeline.transformer.to("cpu")
739
+ gc.collect()
740
+ torch.cuda.empty_cache()
741
+
742
+ patch_pil = generated_patch_pil.resize(
743
+ (width, height), Image.Resampling.LANCZOS
744
+ )
745
+
746
+ dest_tensor = torch.from_numpy(
747
+ np.asarray(image_pil, dtype=np.float32) / 255.0
748
+ ).unsqueeze(0)
749
+ src_tensor = torch.from_numpy(
750
+ np.asarray(patch_pil, dtype=np.float32) / 255.0
751
+ ).unsqueeze(0)
752
+
753
+ composited_tensor = self.image_composite_masked(
754
+ destination=dest_tensor,
755
+ source=src_tensor,
756
+ x=x,
757
+ y=y,
758
+ resize_source=False,
759
+ mask=mask_for_composite,
760
+ )
761
+
762
+ composited_pil = Image.fromarray(
763
+ (composited_tensor[0].cpu().numpy() * 255).astype("uint8")
764
+ )
765
+
766
+ if (
767
+ self.cache.should_use_inpaint_cache(seed)
768
+ and cache_key is not None
769
+ and cached_patch is None
770
+ ):
771
+ self.cache.set_inpainted_image(cache_key, patch_pil)
772
+
773
+ return composited_pil
core/image/ocr_detection.py ADDED
@@ -0,0 +1,730 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List, Optional, Tuple
3
+
4
+ import cv2
5
+ import numpy as np
6
+ import torch
7
+ from PIL import Image
8
+
9
+ from core.caching import get_cache
10
+ from core.ml.model_manager import ModelType, get_model_manager
11
+ from utils.exceptions import ImageProcessingError
12
+ from utils.logging import log_message
13
+
14
+
15
+ class OutsideTextDetector:
16
+ """Detects text outside speech bubbles to isolate SFX/captions from dialogue."""
17
+
18
+ def __init__(
19
+ self,
20
+ device: Optional[torch.device] = None,
21
+ hf_token: Optional[str] = None,
22
+ ):
23
+ """Initialize the outside text detector.
24
+
25
+ Args:
26
+ device: PyTorch device to use. Auto-detects if None.
27
+ hf_token: Hugging Face token for gated repo access.
28
+ """
29
+ self.device = (
30
+ device
31
+ if device is not None
32
+ else torch.device(
33
+ "cuda"
34
+ if torch.cuda.is_available()
35
+ else "mps"
36
+ if torch.backends.mps.is_available()
37
+ else "cpu"
38
+ )
39
+ )
40
+ self.hf_token = hf_token
41
+ self.manager = get_model_manager()
42
+ self.cache = get_cache()
43
+
44
+ def boxes_overlap(self, box1, box2):
45
+ """Check if two bounding boxes overlap (have non-zero intersection).
46
+
47
+ Args:
48
+ box1: Bounding box in [x_min, y_min, x_max, y_max] format.
49
+ box2: Bounding box in YOLO format [x_min, y_min, x_max, y_max].
50
+
51
+ Returns:
52
+ bool: True if boxes overlap, False otherwise.
53
+ """
54
+ x1_min, y1_min, x1_max, y1_max = box1
55
+ x2_min, y2_min, x2_max, y2_max = box2
56
+
57
+ return not (
58
+ x1_max <= x2_min or x2_max <= x1_min or y1_max <= y2_min or y2_max <= y1_min
59
+ )
60
+
61
+ def box_is_inside(self, box1, box2):
62
+ """Check if box1 is completely inside box2.
63
+
64
+ Args:
65
+ box1: Bounding box in [x1, y1, x2, y2] format.
66
+ box2: Bounding box in [x1, y1, x2, y2] format.
67
+
68
+ Returns:
69
+ bool: True if box1 is completely inside box2, False otherwise.
70
+ """
71
+ x1_min, y1_min, x1_max, y1_max = box1
72
+ x2_min, y2_min, x2_max, y2_max = box2
73
+
74
+ return (
75
+ x1_min >= x2_min
76
+ and x1_max <= x2_max
77
+ and y1_min >= y2_min
78
+ and y1_max <= y2_max
79
+ )
80
+
81
+ def filter_nested_detections(self, results):
82
+ """Remove detections fully contained in larger ones to avoid duplicates.
83
+
84
+ Args:
85
+ results: List of detection results (bbox, text, confidence).
86
+
87
+ Returns:
88
+ list: Filtered results with nested detections removed.
89
+ """
90
+ if len(results) <= 1:
91
+ return results
92
+
93
+ # Prioritize larger detections to avoid removing important text
94
+ def get_area(result):
95
+ bbox = result[0]
96
+ x_min, y_min, x_max, y_max = bbox
97
+ return (x_max - x_min) * (y_max - y_min)
98
+
99
+ sorted_results = sorted(results, key=get_area, reverse=True)
100
+ filtered_results = []
101
+
102
+ for i, current_result in enumerate(sorted_results):
103
+ is_nested = False
104
+ current_bbox = current_result[0]
105
+
106
+ for kept_result in filtered_results:
107
+ kept_bbox = kept_result[0]
108
+ if self.box_is_inside(current_bbox, kept_bbox):
109
+ is_nested = True
110
+ break
111
+
112
+ if not is_nested:
113
+ filtered_results.append(current_result)
114
+
115
+ return filtered_results
116
+
117
+ def unload_models(self):
118
+ """Unload OCR models via model manager to free GPU/CPU memory."""
119
+ self.manager.unload_ocr_models()
120
+
121
+ def detect_outside_text(
122
+ self,
123
+ image_path: str,
124
+ yolo_model_path: Optional[str] = None,
125
+ confidence: float = 0.6,
126
+ conjoined_confidence: float = 0.35,
127
+ verbose: bool = False,
128
+ image_override: Optional[Image.Image] = None,
129
+ existing_bubbles: Optional[List] = None,
130
+ text_free_boxes: Optional[List] = None,
131
+ ):
132
+ """Detect non-dialogue text by subtracting YOLO speech bubbles from OCR results.
133
+
134
+ Args:
135
+ image_path: Path to the input image.
136
+ yolo_model_path: Optional custom YOLO model path.
137
+ confidence: Confidence threshold for primary YOLO model detections.
138
+ conjoined_confidence: Confidence threshold for secondary YOLO model (conjoined bubble detection).
139
+ verbose: If True, logs intermediate steps.
140
+ text_free_boxes: Optional list of text_free regions to use as fallback OSB detections.
141
+
142
+ Returns:
143
+ list: Detected regions outside bubbles as (bbox, confidence).
144
+ """
145
+ if image_override is None and not os.path.exists(image_path):
146
+ raise FileNotFoundError(f"Error: The file '{image_path}' was not found.")
147
+
148
+ try:
149
+ if image_override is not None:
150
+ image_pil = (
151
+ image_override
152
+ if image_override.mode == "RGB"
153
+ else image_override.convert("RGB")
154
+ )
155
+ image_cv = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
156
+ else:
157
+ image_cv = cv2.imread(str(image_path))
158
+ if image_cv is None:
159
+ raise ImageProcessingError(f"Could not read image at {image_path}")
160
+ image_pil = Image.fromarray(cv2.cvtColor(image_cv, cv2.COLOR_BGR2RGB))
161
+ image_name = image_path if image_override is None else "override"
162
+ log_message(
163
+ f"Processing image: {image_name} "
164
+ f"({image_cv.shape[1]}x{image_cv.shape[0]})",
165
+ verbose=verbose,
166
+ )
167
+ except Exception as e:
168
+ raise ImageProcessingError(f"Error loading image: {e}")
169
+
170
+ provided_bubble_boxes = None
171
+ if existing_bubbles is not None:
172
+ try:
173
+ provided_bubble_boxes = []
174
+ for b in existing_bubbles:
175
+ bbox = b.get("bbox") if isinstance(b, dict) else b
176
+ if bbox is None or len(bbox) != 4:
177
+ continue
178
+ x0, y0, x1, y1 = bbox
179
+ provided_bubble_boxes.append(
180
+ [float(x0), float(y0), float(x1), float(y1)]
181
+ )
182
+ if provided_bubble_boxes:
183
+ log_message(
184
+ f"Using {len(provided_bubble_boxes)} provided bubble boxes for OSB filtering",
185
+ verbose=verbose,
186
+ )
187
+ except Exception as e:
188
+ log_message(
189
+ f"Warning: Failed to parse provided bubbles: {e}. Falling back to YOLO.",
190
+ always_print=True,
191
+ )
192
+ provided_bubble_boxes = None
193
+
194
+ text_free_boxes = list(text_free_boxes) if text_free_boxes else []
195
+
196
+ if provided_bubble_boxes:
197
+ yolo_boxes = torch.tensor(
198
+ provided_bubble_boxes, device=self.device, dtype=torch.float32
199
+ )
200
+ num_yolo_boxes = len(yolo_boxes)
201
+ log_message(
202
+ f"Skipping YOLO; using provided bubbles ({num_yolo_boxes})",
203
+ verbose=verbose,
204
+ )
205
+ else:
206
+ log_message("Running YOLO detection for speech bubbles...", verbose=verbose)
207
+
208
+ sb_model_path = (
209
+ str(self.manager.model_paths[ModelType.YOLO_SPEECH_BUBBLE])
210
+ if yolo_model_path is None
211
+ else yolo_model_path
212
+ )
213
+ sb_cache_key = self.cache.get_yolo_cache_key(
214
+ image_pil, sb_model_path, confidence
215
+ )
216
+ cached_sb = self.cache.get_yolo_detection(sb_cache_key)
217
+
218
+ if cached_sb is not None:
219
+ log_message("Using cached Speech Bubble detections", verbose=verbose)
220
+ yolo_results, yolo_boxes = cached_sb
221
+ else:
222
+ yolo_model = self.manager.load_yolo_speech_bubble(yolo_model_path)
223
+ yolo_results = yolo_model(
224
+ image_cv, conf=confidence, device=self.device, verbose=False
225
+ )[0]
226
+ yolo_boxes = (
227
+ yolo_results.boxes.xyxy
228
+ if yolo_results.boxes is not None
229
+ else torch.tensor([])
230
+ )
231
+ self.cache.set_yolo_detection(sb_cache_key, (yolo_results, yolo_boxes))
232
+
233
+ num_yolo_boxes = len(yolo_boxes) if yolo_boxes.nelement() > 0 else 0
234
+ log_message(
235
+ f"YOLO detected {num_yolo_boxes} speech bubbles", verbose=verbose
236
+ )
237
+
238
+ log_message(
239
+ "Running Secondary YOLO to catch missed bubbles...", verbose=verbose
240
+ )
241
+ try:
242
+ sec_model = self.manager.load_yolo_conjoined_bubble()
243
+ sec_results = sec_model(
244
+ image_cv,
245
+ conf=conjoined_confidence,
246
+ device=self.device,
247
+ verbose=False,
248
+ )[0]
249
+
250
+ sec_boxes = (
251
+ sec_results.boxes.xyxy
252
+ if sec_results.boxes is not None
253
+ else torch.tensor([])
254
+ )
255
+ sec_cls = (
256
+ sec_results.boxes.cls
257
+ if sec_results.boxes is not None
258
+ else torch.tensor([])
259
+ )
260
+
261
+ # Find text_bubble and text_free classes
262
+ tb_id = None
263
+ tf_id = None
264
+ if hasattr(sec_model, "names"):
265
+ for cid, cname in sec_model.names.items():
266
+ if cname == "text_bubble":
267
+ tb_id = cid
268
+ elif cname == "text_free":
269
+ tf_id = cid
270
+
271
+ if tf_id is not None and len(sec_boxes) > 0:
272
+ for i, cls_id in enumerate(sec_cls):
273
+ if int(cls_id) == tf_id:
274
+ text_free_boxes.append(sec_boxes[i].detach().cpu().numpy())
275
+
276
+ if tb_id is not None and len(sec_boxes) > 0:
277
+ boxes_to_add = []
278
+ for i, cls_id in enumerate(sec_cls):
279
+ if int(cls_id) == tb_id:
280
+ boxes_to_add.append(sec_boxes[i])
281
+
282
+ if boxes_to_add:
283
+ log_message(
284
+ f"Secondary YOLO found {len(boxes_to_add)} potential bubbles",
285
+ verbose=verbose,
286
+ )
287
+ boxes_to_add_tensor = torch.stack(boxes_to_add)
288
+ if yolo_boxes.nelement() > 0:
289
+ yolo_boxes = torch.cat(
290
+ (yolo_boxes, boxes_to_add_tensor), dim=0
291
+ )
292
+ else:
293
+ yolo_boxes = boxes_to_add_tensor
294
+ except Exception as e:
295
+ log_message(f"Secondary YOLO failed: {e}", verbose=verbose)
296
+
297
+ log_message("Running YOLO OSB Text...", always_print=True)
298
+
299
+ osbtext_boxes = None
300
+ osbtext_confs = None
301
+ try:
302
+ osbtext_model_path = str(self.manager.model_paths[ModelType.YOLO_OSBTEXT])
303
+ osbtext_cache_key = self.cache.get_yolo_cache_key(
304
+ image_pil, osbtext_model_path, confidence
305
+ )
306
+
307
+ cached_osbtext = self.cache.get_yolo_detection(osbtext_cache_key)
308
+
309
+ if cached_osbtext is not None:
310
+ log_message("Using cached OSBText detections", verbose=verbose)
311
+ osbtext_results, osbtext_boxes, osbtext_confs = cached_osbtext
312
+ else:
313
+ osbtext_model = self.manager.load_yolo_osbtext(token=self.hf_token)
314
+ osbtext_results = osbtext_model(
315
+ image_cv, conf=confidence, device=self.device, verbose=False
316
+ )[0]
317
+ osbtext_boxes = (
318
+ osbtext_results.boxes.xyxy
319
+ if osbtext_results.boxes is not None
320
+ else None
321
+ )
322
+ osbtext_confs = (
323
+ osbtext_results.boxes.conf
324
+ if osbtext_results.boxes is not None
325
+ else None
326
+ )
327
+ self.cache.set_yolo_detection(
328
+ osbtext_cache_key, (osbtext_results, osbtext_boxes, osbtext_confs)
329
+ )
330
+ except Exception as e:
331
+ log_message(
332
+ f"OSB text model unavailable: {e}. Using text_free fallback if available.",
333
+ always_print=True,
334
+ )
335
+ if text_free_boxes:
336
+ log_message(
337
+ f"Using {len(text_free_boxes)} text_free boxes as OSB fallback",
338
+ always_print=True,
339
+ )
340
+ osbtext_boxes = torch.tensor(
341
+ text_free_boxes, device=self.device, dtype=torch.float32
342
+ )
343
+ osbtext_confs = torch.ones(
344
+ len(text_free_boxes), device=self.device, dtype=torch.float32
345
+ )
346
+ else:
347
+ log_message(
348
+ "No text_free fallback available; skipping OSB text detections",
349
+ always_print=True,
350
+ )
351
+
352
+ base_results = []
353
+ if osbtext_boxes is not None:
354
+ boxes_np = osbtext_boxes.detach().cpu().numpy()
355
+ confs_np = osbtext_confs.detach().cpu().numpy()
356
+
357
+ for i, box in enumerate(boxes_np):
358
+ conf = confs_np[i]
359
+ base_results.append((box, float(conf)))
360
+
361
+ final_results = list(base_results)
362
+
363
+ log_message("Filtering out nested detections...", verbose=verbose)
364
+ before_nested_filter = len(final_results)
365
+ final_results = self.filter_nested_detections(final_results)
366
+ after_nested_filter = len(final_results)
367
+ nested_removed = before_nested_filter - after_nested_filter
368
+ log_message(
369
+ f"Nested detections removed: {nested_removed}. Remaining detections: {after_nested_filter}.",
370
+ verbose=verbose,
371
+ )
372
+
373
+ if yolo_boxes is not None and yolo_boxes.nelement() > 0:
374
+ log_message(
375
+ "Filtering OCR results to keep text outside speech bubbles...",
376
+ verbose=verbose,
377
+ )
378
+ filtered_results = []
379
+ yolo_boxes_np = yolo_boxes.detach().cpu().numpy()
380
+
381
+ for ocr_result in final_results:
382
+ bbox, _ = ocr_result
383
+
384
+ overlaps_any_bubble = False
385
+
386
+ for yolo_box in yolo_boxes_np:
387
+ if self.boxes_overlap(bbox, yolo_box):
388
+ # Check if this bubble is actually a text_free region
389
+ is_text_free_bubble = False
390
+ if text_free_boxes:
391
+ for tf_box in text_free_boxes:
392
+ # We check if the YOLO bubble overlaps with a text_free detection
393
+ if self.boxes_overlap(yolo_box, tf_box):
394
+ is_text_free_bubble = True
395
+ break
396
+
397
+ if not is_text_free_bubble:
398
+ overlaps_any_bubble = True
399
+ break
400
+
401
+ if not overlaps_any_bubble:
402
+ filtered_results.append(ocr_result)
403
+
404
+ filtered_out = len(final_results) - len(filtered_results)
405
+ log_message(
406
+ f"Filtered out {filtered_out} OCR results that overlapped with speech bubbles",
407
+ verbose=verbose,
408
+ )
409
+ final_results = filtered_results
410
+
411
+ log_message(
412
+ f"Found {len(final_results)} outside text regions", always_print=True
413
+ )
414
+
415
+ return final_results
416
+
417
+ def get_text_masks(
418
+ self,
419
+ image_path: str,
420
+ bbox_expansion_percent: float = 0.0,
421
+ text_box_proximity_ratio: float = 0.02,
422
+ verbose: bool = False,
423
+ image_override: Optional[Image.Image] = None,
424
+ existing_results: Optional[List] = None,
425
+ ) -> Tuple[Optional[List], Optional[Image.Image]]:
426
+ """Create rectangular masks from OCR bounding boxes for inpainting.
427
+
428
+ Args:
429
+ image_path: Path to the input image.
430
+ bbox_expansion_percent: Percentage to expand bounding boxes.
431
+ text_box_proximity_ratio: Ratio for grouping nearby text boxes (as fraction of image dimension).
432
+ verbose: Whether to print verbose output.
433
+
434
+ Returns:
435
+ tuple: (groups, image_pil) where groups is a list of dicts with:
436
+ {
437
+ 'combined_mask': np.array[H,W,bool],
438
+ 'bbox': dict,
439
+ 'individual_masks': [np.array],
440
+ 'mask_indices': [int],
441
+ 'confidence': float,
442
+ }.
443
+ """
444
+ results = (
445
+ existing_results
446
+ if existing_results is not None
447
+ else self.detect_outside_text(
448
+ image_path,
449
+ verbose=verbose,
450
+ image_override=image_override,
451
+ )
452
+ )
453
+
454
+ if not results:
455
+ return None, None
456
+
457
+ if image_override is not None:
458
+ image_pil = (
459
+ image_override.convert("RGB")
460
+ if image_override.mode != "RGB"
461
+ else image_override
462
+ )
463
+ else:
464
+ image_pil = Image.open(image_path).convert("RGB")
465
+ img_w, img_h = image_pil.size
466
+
467
+ log_message("Converting OCR results to axis-aligned boxes...", verbose=verbose)
468
+ boxes = [[int(c) for c in result[0]] for result in results]
469
+
470
+ expanded_boxes = []
471
+ for box in boxes:
472
+ x0, y0, x1, y1 = box
473
+ width = x1 - x0
474
+ height = y1 - y0
475
+ expand_x = width * bbox_expansion_percent
476
+ expand_y = height * bbox_expansion_percent
477
+ x0e = int(np.floor(max(0, x0 - expand_x)))
478
+ y0e = int(np.floor(max(0, y0 - expand_y)))
479
+ x1e = int(np.ceil(min(img_w, x1 + expand_x)))
480
+ y1e = int(np.ceil(min(img_h, y1 + expand_y)))
481
+ if x1e > x0e and y1e > y0e:
482
+ expanded_boxes.append([x0e, y0e, x1e, y1e])
483
+
484
+ log_message(
485
+ f"Grouping {len(expanded_boxes)} text boxes spatially...",
486
+ verbose=verbose,
487
+ )
488
+
489
+ grouped_boxes = self._group_text_boxes_spatially(
490
+ expanded_boxes, results, img_w, img_h, text_box_proximity_ratio, verbose
491
+ )
492
+
493
+ groups = []
494
+ for group_boxes, group_results in grouped_boxes:
495
+ combined_mask = np.zeros((img_h, img_w), dtype=bool)
496
+ individual_masks = []
497
+ mask_indices = []
498
+ avg_confidence = 0.0
499
+
500
+ min_x = min(box[0] for box in group_boxes)
501
+ min_y = min(box[1] for box in group_boxes)
502
+ max_x = max(box[2] for box in group_boxes)
503
+ max_y = max(box[3] for box in group_boxes)
504
+
505
+ # Ensure combined region doesn't exceed Flux Kontext preferred resolutions
506
+ max_dimension = 1568
507
+ if max_x - min_x > max_dimension or max_y - min_y > max_dimension:
508
+ log_message(
509
+ f" - Group too large ({max_x - min_x}x{max_y - min_y}), splitting...",
510
+ verbose=verbose,
511
+ )
512
+ for i, (box, result) in enumerate(zip(group_boxes, group_results)):
513
+ x0, y0, x1, y1 = box
514
+ mask = np.zeros((img_h, img_w), dtype=bool)
515
+ mask[y0:y1, x0:x1] = True
516
+
517
+ bbox = {
518
+ "x": int(x0),
519
+ "y": int(y0),
520
+ "width": int(x1 - x0),
521
+ "height": int(y1 - y0),
522
+ }
523
+
524
+ raw_box = [int(c) for c in result[0]]
525
+ raw_x0, raw_y0, raw_x1, raw_y1 = raw_box
526
+ original_bbox = {
527
+ "x": raw_x0,
528
+ "y": raw_y0,
529
+ "width": raw_x1 - raw_x0,
530
+ "height": raw_y1 - raw_y0,
531
+ }
532
+
533
+ _, conf = result
534
+
535
+ groups.append(
536
+ {
537
+ "combined_mask": mask,
538
+ "bbox": bbox,
539
+ "original_bbox": original_bbox,
540
+ "individual_masks": [mask],
541
+ "mask_indices": [i],
542
+ "confidence": conf,
543
+ }
544
+ )
545
+ continue
546
+
547
+ raw_boxes = [[int(c) for c in res[0]] for res in group_results]
548
+
549
+ for i, (box, result, raw_box) in enumerate(
550
+ zip(group_boxes, group_results, raw_boxes)
551
+ ):
552
+ x0, y0, x1, y1 = box
553
+ mask = np.zeros((img_h, img_w), dtype=bool)
554
+ mask[y0:y1, x0:x1] = True
555
+
556
+ combined_mask |= mask
557
+ individual_masks.append(mask)
558
+ mask_indices.append(i)
559
+
560
+ _, conf = result
561
+ avg_confidence += conf
562
+
563
+ raw_min_x = min(box[0] for box in raw_boxes)
564
+ raw_min_y = min(box[1] for box in raw_boxes)
565
+ raw_max_x = max(box[2] for box in raw_boxes)
566
+ raw_max_y = max(box[3] for box in raw_boxes)
567
+
568
+ bbox = {
569
+ "x": int(min_x),
570
+ "y": int(min_y),
571
+ "width": int(max_x - min_x),
572
+ "height": int(max_y - min_y),
573
+ }
574
+
575
+ original_bbox = {
576
+ "x": int(raw_min_x),
577
+ "y": int(raw_min_y),
578
+ "width": int(raw_max_x - raw_min_x),
579
+ "height": int(raw_max_y - raw_min_y),
580
+ }
581
+
582
+ groups.append(
583
+ {
584
+ "combined_mask": combined_mask,
585
+ "bbox": bbox,
586
+ "original_bbox": original_bbox,
587
+ "individual_masks": individual_masks,
588
+ "mask_indices": mask_indices,
589
+ "confidence": avg_confidence / len(group_results),
590
+ }
591
+ )
592
+
593
+ log_message(
594
+ f"Created {len(groups)} grouped text regions for inpainting",
595
+ verbose=verbose,
596
+ )
597
+
598
+ return groups, image_pil
599
+
600
+ def _group_text_boxes_spatially(
601
+ self, boxes, results, img_w, img_h, text_box_proximity_ratio=0.02, verbose=False
602
+ ):
603
+ """
604
+ Group nearby text boxes based on spatial proximity.
605
+
606
+ Args:
607
+ boxes: List of bounding boxes [x0, y0, x1, y1]
608
+ results: List of OCR results corresponding to boxes
609
+ img_w: Image width
610
+ img_h: Image height
611
+ text_box_proximity_ratio: Ratio for grouping nearby text boxes (as fraction of image dimension).
612
+ verbose: Whether to print detailed logs
613
+
614
+ Returns:
615
+ List of tuples (group_boxes, group_results) where each group contains
616
+ spatially related text boxes
617
+ """
618
+ if not boxes:
619
+ return []
620
+
621
+ proximity_threshold = min(img_w, img_h) * text_box_proximity_ratio
622
+
623
+ parent = list(range(len(boxes)))
624
+
625
+ def find(x):
626
+ if parent[x] != x:
627
+ parent[x] = find(parent[x])
628
+ return parent[x]
629
+
630
+ def union(x, y):
631
+ px, py = find(x), find(y)
632
+ if px != py:
633
+ parent[px] = py
634
+
635
+ for i in range(len(boxes)):
636
+ for j in range(i + 1, len(boxes)):
637
+ if self._boxes_are_nearby(boxes[i], boxes[j], proximity_threshold):
638
+ union(i, j)
639
+
640
+ groups = {}
641
+ for i in range(len(boxes)):
642
+ root = find(i)
643
+ if root not in groups:
644
+ groups[root] = ([], [])
645
+ groups[root][0].append(boxes[i])
646
+ groups[root][1].append(results[i])
647
+
648
+ grouped_boxes = list(groups.values())
649
+
650
+ log_message(
651
+ f" - Grouped {len(boxes)} boxes into {len(grouped_boxes)} spatial groups",
652
+ verbose=verbose,
653
+ )
654
+
655
+ return grouped_boxes
656
+
657
+ def _boxes_are_nearby(self, box1, box2, threshold):
658
+ """
659
+ Check if two bounding boxes are spatially close enough to be grouped.
660
+
661
+ Args:
662
+ box1: First bounding box [x0, y0, x1, y1]
663
+ box2: Second bounding box [x0, y0, x1, y1]
664
+ threshold: Maximum distance for boxes to be considered nearby
665
+
666
+ Returns:
667
+ True if boxes are nearby, False otherwise
668
+ """
669
+ x1_min, y1_min, x1_max, y1_max = box1
670
+ x2_min, y2_min, x2_max, y2_max = box2
671
+
672
+ cx1 = (x1_min + x1_max) / 2
673
+ cy1 = (y1_min + y1_max) / 2
674
+ cx2 = (x2_min + x2_max) / 2
675
+ cy2 = (y2_min + y2_max) / 2
676
+
677
+ distance = np.sqrt((cx1 - cx2) ** 2 + (cy1 - cy2) ** 2)
678
+
679
+ return distance <= threshold
680
+
681
+
682
+ def extract_text_with_manga_ocr(
683
+ images: List[Image.Image], verbose: bool = False
684
+ ) -> List[str]:
685
+ """Extract text from images using manga-ocr library.
686
+
687
+ Args:
688
+ images: List of PIL Images to process
689
+ verbose: Whether to print verbose output
690
+
691
+ Returns:
692
+ List of extracted text strings (one per image). Returns [OCR FAILED] on errors.
693
+ """
694
+ if not images:
695
+ return []
696
+
697
+ try:
698
+ model_manager = get_model_manager()
699
+ manga_ocr_instance = model_manager.get_manga_ocr(verbose=verbose)
700
+
701
+ extracted_texts = []
702
+ for i, img in enumerate(images):
703
+ try:
704
+ if img is None:
705
+ log_message(
706
+ f"Image {i + 1} is None (decode failure), skipping",
707
+ always_print=True,
708
+ )
709
+ extracted_texts.append("[OCR FAILED]")
710
+ continue
711
+
712
+ log_message(
713
+ f"Processing image {i + 1}/{len(images)} with manga-ocr",
714
+ verbose=verbose,
715
+ )
716
+ text = manga_ocr_instance(img)
717
+
718
+ extracted_texts.append(text.strip() if text else "")
719
+
720
+ except Exception as e:
721
+ log_message(
722
+ f"manga-ocr failed for image {i + 1}: {e}", always_print=True
723
+ )
724
+ extracted_texts.append("[OCR FAILED]")
725
+
726
+ return extracted_texts
727
+
728
+ except Exception as e:
729
+ log_message(f"Error with manga-ocr: {e}", always_print=True)
730
+ return ["[OCR FAILED]"] * len(images)
core/image/sorting.py ADDED
@@ -0,0 +1,359 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List
2
+
3
+
4
+ def sort_bubbles_by_reading_order(detections, reading_direction="rtl", panels=None):
5
+ """
6
+ Hybrid Algorithm (veto system):
7
+ - Macro: graph sort with ceiling + right-neighbor veto to enforce Z flow.
8
+ - Micro: tuned spatial banding with looser thresholds for offset bubbles.
9
+ """
10
+
11
+ if not detections:
12
+ return []
13
+
14
+ rtl = (reading_direction or "rtl").lower() == "rtl"
15
+
16
+ # Micro layout: keep slightly offset bubbles grouped into lines/columns.
17
+ def _get_features(bbox):
18
+ x1, y1, x2, y2 = bbox
19
+ w = max(1.0, float(x2 - x1))
20
+ h = max(1.0, float(y2 - y1))
21
+ cx = (x1 + x2) / 2.0
22
+ cy = (y1 + y2) / 2.0
23
+ return x1, y1, x2, y2, w, h, cx, cy
24
+
25
+ def _spatial_sort(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
26
+ """Robust spatial sort for bubbles (vertical columns + horizontal rows)."""
27
+ if not items:
28
+ return []
29
+
30
+ # Tuned thresholds to keep slightly offset bubbles in the same line.
31
+ y_overlap_ratio_threshold = 0.25
32
+ y_center_band_factor = 0.5
33
+ x_overlap_ratio_threshold = 0.2
34
+ x_center_band_factor = 0.5
35
+
36
+ enriched = []
37
+ for item in items:
38
+ x1, y1, x2, y2, w, h, cx, cy = _get_features(item["bbox"])
39
+ enriched.append(
40
+ {
41
+ "item": item,
42
+ "x1": x1,
43
+ "y1": y1,
44
+ "x2": x2,
45
+ "y2": y2,
46
+ "w": w,
47
+ "h": h,
48
+ "cx": cx,
49
+ "cy": cy,
50
+ }
51
+ )
52
+
53
+ enriched.sort(key=lambda e: e["cy"])
54
+
55
+ bands = []
56
+ for e in enriched:
57
+ y1, y2, h = e["y1"], e["y2"], e["h"]
58
+ best_band_idx = -1
59
+ best_score = -1.0
60
+
61
+ for i, band in enumerate(bands):
62
+ band_h = max(1.0, float(band["y_max"] - band["y_min"]))
63
+ overlap_v = max(0.0, min(y2, band["y_max"]) - max(y1, band["y_min"]))
64
+ overlap_ratio = overlap_v / min(h, band_h)
65
+ center_delta_y = abs(e["cy"] - (band["y_min"] + band["y_max"]) / 2.0)
66
+
67
+ same_row = (overlap_ratio >= y_overlap_ratio_threshold) or (
68
+ center_delta_y <= y_center_band_factor * min(h, band_h)
69
+ )
70
+
71
+ if same_row:
72
+ score = overlap_ratio - (center_delta_y / (h + band_h)) * 0.1
73
+ if score > best_score:
74
+ best_score = score
75
+ best_band_idx = i
76
+
77
+ if best_band_idx == -1:
78
+ bands.append({"y_min": y1, "y_max": y2, "items": [e]})
79
+ else:
80
+ band = bands[best_band_idx]
81
+ band["items"].append(e)
82
+ band["y_min"] = min(band["y_min"], y1)
83
+ band["y_max"] = max(band["y_max"], y2)
84
+
85
+ bands.sort(key=lambda b: b["y_min"])
86
+
87
+ ordered_items = []
88
+ for band in bands:
89
+ items_in_band = band["items"]
90
+ columns = []
91
+
92
+ for e in items_in_band:
93
+ x1, x2, w = e["x1"], e["x2"], e["w"]
94
+ best_col_idx = -1
95
+ best_score = -1.0
96
+
97
+ for i, col in enumerate(columns):
98
+ col_w = max(1.0, float(col["x_max"] - col["x_min"]))
99
+ overlap_h = max(0.0, min(x2, col["x_max"]) - max(x1, col["x_min"]))
100
+ overlap_ratio = overlap_h / min(w, col_w)
101
+ col_center_x = (col["x_min"] + col["x_max"]) / 2.0
102
+ center_delta_x = abs(e["cx"] - col_center_x)
103
+
104
+ same_col = (overlap_ratio >= x_overlap_ratio_threshold) or (
105
+ center_delta_x <= x_center_band_factor * min(w, col_w)
106
+ )
107
+
108
+ if same_col:
109
+ score = overlap_ratio - (center_delta_x / (w + col_w)) * 0.1
110
+ if score > best_score:
111
+ best_score = score
112
+ best_col_idx = i
113
+
114
+ if best_col_idx == -1:
115
+ columns.append({"x_min": x1, "x_max": x2, "items": [e]})
116
+ else:
117
+ col = columns[best_col_idx]
118
+ col["items"].append(e)
119
+ col["x_min"] = min(col["x_min"], x1)
120
+ col["x_max"] = max(col["x_max"], x2)
121
+
122
+ if rtl:
123
+ columns.sort(key=lambda c: -((c["x_min"] + c["x_max"]) / 2.0))
124
+ else:
125
+ columns.sort(key=lambda c: ((c["x_min"] + c["x_max"]) / 2.0))
126
+
127
+ for col in columns:
128
+ col["items"].sort(key=lambda e: e["cy"])
129
+ ordered_items.extend([e["item"] for e in col["items"]])
130
+
131
+ return ordered_items
132
+
133
+ # Macro layout: panel graph with root detection and dual veto for Z-flow.
134
+ def _iou_x(boxA, boxB):
135
+ xa1, _, xa2, _ = boxA
136
+ xb1, _, xb2, _ = boxB
137
+ inter = max(0, min(xa2, xb2) - max(xa1, xb1))
138
+ union = (xa2 - xa1) + (xb2 - xb1) - inter
139
+ return inter / union if union > 0 else 0
140
+
141
+ def _iou_y_overlap(boxA, boxB):
142
+ _, ya1, _, ya2 = boxA
143
+ _, yb1, _, yb2 = boxB
144
+ inter = max(0, min(ya2, yb2) - max(ya1, yb1))
145
+ min_h = min(ya2 - ya1, yb2 - yb1)
146
+ return inter / min_h if min_h > 0 else 0
147
+
148
+ def sort_panels_strict(panels_list, rtl=True):
149
+ if not panels_list:
150
+ return []
151
+
152
+ nodes = []
153
+ for i, bbox in enumerate(panels_list):
154
+ nodes.append(
155
+ {
156
+ "id": i,
157
+ "bbox": bbox,
158
+ "center": ((bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2),
159
+ "visited": False,
160
+ }
161
+ )
162
+
163
+ sorted_indices = []
164
+
165
+ # Roots: panels with no panel above in the same column.
166
+ root_nodes = []
167
+ for n in nodes:
168
+ is_root = True
169
+ for parent in nodes:
170
+ if n["id"] == parent["id"]:
171
+ continue
172
+ is_above = parent["bbox"][3] <= (n["bbox"][1] + 50)
173
+ x_overlap = _iou_x(parent["bbox"], n["bbox"])
174
+ if is_above and x_overlap > 0.2:
175
+ is_root = False
176
+ break
177
+ if is_root:
178
+ root_nodes.append(n)
179
+
180
+ if root_nodes:
181
+ start_node = (
182
+ max(root_nodes, key=lambda n: n["bbox"][2])
183
+ if rtl
184
+ else min(root_nodes, key=lambda n: n["bbox"][0])
185
+ )
186
+ else:
187
+ start_node = min(nodes, key=lambda n: n["bbox"][1])
188
+
189
+ current = start_node
190
+ current["visited"] = True
191
+ sorted_indices.append(current["id"])
192
+
193
+ while len(sorted_indices) < len(nodes):
194
+ c_box = current["bbox"]
195
+ candidates = [n for n in nodes if not n["visited"]]
196
+ if not candidates:
197
+ break
198
+
199
+ col_cand = None
200
+ col_candidates = []
201
+ for cand in candidates:
202
+ cand_box = cand["bbox"]
203
+ overlap = _iou_x(c_box, cand_box)
204
+ is_below = cand_box[1] >= (c_box[1] + (c_box[3] - c_box[1]) * 0.5)
205
+ if overlap > 0.2 and is_below:
206
+ dist_y = max(0, cand_box[1] - c_box[3])
207
+ col_candidates.append((dist_y, cand))
208
+
209
+ if col_candidates:
210
+ col_candidates.sort(
211
+ key=lambda x: (
212
+ int(x[0] / 50),
213
+ -x[1]["center"][0] if rtl else x[1]["center"][0],
214
+ )
215
+ )
216
+ col_cand = col_candidates[0][1]
217
+
218
+ row_cand = None
219
+ row_candidates = []
220
+ for cand in candidates:
221
+ cand_box = cand["bbox"]
222
+ if rtl:
223
+ is_row_neighbor = cand_box[2] <= (c_box[0] + 50)
224
+ dist_x = c_box[0] - cand_box[2]
225
+ else:
226
+ is_row_neighbor = cand_box[0] >= (c_box[2] - 50)
227
+ dist_x = cand_box[0] - c_box[2]
228
+
229
+ if is_row_neighbor:
230
+ y_inter = max(
231
+ 0, min(c_box[3], cand_box[3]) - max(c_box[1], cand_box[1])
232
+ )
233
+ if y_inter > 0:
234
+ row_candidates.append((dist_x, cand))
235
+
236
+ if row_candidates:
237
+ row_candidates.sort(key=lambda x: x[0])
238
+ row_cand = row_candidates[0][1]
239
+
240
+ # Dual veto: ceiling (topological) + right-neighbor (row start).
241
+ if col_cand:
242
+ is_blocked = False
243
+ for other in candidates:
244
+ if other["id"] == col_cand["id"]:
245
+ continue
246
+ is_above = other["bbox"][3] <= (col_cand["bbox"][1] + 50)
247
+ x_overlap = _iou_x(other["bbox"], col_cand["bbox"])
248
+ if is_above and x_overlap > 0.2:
249
+ is_blocked = True
250
+ break
251
+ if rtl:
252
+ has_block_neighbor = other["bbox"][0] > (
253
+ col_cand["bbox"][0] + 20
254
+ )
255
+ else:
256
+ has_block_neighbor = other["bbox"][2] < (
257
+ col_cand["bbox"][2] - 20
258
+ )
259
+ y_overlap_ratio = _iou_y_overlap(col_cand["bbox"], other["bbox"])
260
+ if has_block_neighbor and y_overlap_ratio > 0.3:
261
+ is_blocked = True
262
+ break
263
+ if is_blocked:
264
+ col_cand = None
265
+
266
+ next_node = None
267
+ if row_cand and not col_cand:
268
+ next_node = row_cand
269
+ elif col_cand and not row_cand:
270
+ next_node = col_cand
271
+ elif row_cand and col_cand:
272
+ curr_h = c_box[3] - c_box[1]
273
+ bottom_diff = abs(c_box[3] - row_cand["bbox"][3])
274
+ is_row_aligned = bottom_diff < (curr_h * 0.25)
275
+ next_node = row_cand if is_row_aligned else col_cand
276
+
277
+ if not next_node:
278
+ # Recompute roots among remaining nodes to find a new entry.
279
+ sub_roots = []
280
+ for n in candidates:
281
+ is_root = True
282
+ for parent in candidates:
283
+ if n["id"] == parent["id"]:
284
+ continue
285
+ is_above = parent["bbox"][3] <= (n["bbox"][1] + 50)
286
+ x_overlap = _iou_x(parent["bbox"], n["bbox"])
287
+ if is_above and x_overlap > 0.2:
288
+ is_root = False
289
+ break
290
+ if is_root:
291
+ sub_roots.append(n)
292
+
293
+ if sub_roots:
294
+ next_node = (
295
+ max(sub_roots, key=lambda n: n["bbox"][2])
296
+ if rtl
297
+ else min(sub_roots, key=lambda n: n["bbox"][0])
298
+ )
299
+ else:
300
+ next_node = min(candidates, key=lambda n: n["bbox"][1])
301
+
302
+ current = next_node
303
+ current["visited"] = True
304
+ sorted_indices.append(current["id"])
305
+
306
+ return sorted_indices
307
+
308
+ if not panels:
309
+ return _spatial_sort(detections)
310
+
311
+ sorted_panel_indices = sort_panels_strict(panels, rtl)
312
+ if not sorted_panel_indices:
313
+ sorted_panel_indices = list(range(len(panels)))
314
+
315
+ panel_bins = {pid: [] for pid in sorted_panel_indices}
316
+ unassigned = []
317
+
318
+ for detection in detections:
319
+ bx1, by1, bx2, by2 = detection["bbox"]
320
+ bcx, bcy = (bx1 + bx2) / 2.0, (by1 + by2) / 2.0
321
+ assigned = False
322
+
323
+ for i, pbbox in enumerate(panels):
324
+ px1, py1, px2, py2 = pbbox
325
+ if px1 <= bcx <= px2 and py1 <= bcy <= py2:
326
+ panel_bins.setdefault(i, []).append(detection)
327
+ detection["panel_id"] = i
328
+ assigned = True
329
+ break
330
+
331
+ if not assigned:
332
+ best_dist = float("inf")
333
+ best_pid = -1
334
+ for i, pbbox in enumerate(panels):
335
+ px1, py1, px2, py2 = pbbox
336
+ dx = max(px1 - bcx, 0, bcx - px2)
337
+ dy = max(py1 - bcy, 0, bcy - py2)
338
+ dist = (dx**2 + dy**2) ** 0.5
339
+ if dist < best_dist:
340
+ best_dist = dist
341
+ best_pid = i
342
+
343
+ if best_dist < 300:
344
+ panel_bins.setdefault(best_pid, []).append(detection)
345
+ detection["panel_id"] = best_pid
346
+ assigned = True
347
+
348
+ if not assigned:
349
+ detection["panel_id"] = None
350
+ unassigned.append(detection)
351
+
352
+ final_order = []
353
+ for pid in sorted_panel_indices:
354
+ final_order.extend(_spatial_sort(panel_bins.get(pid, [])))
355
+
356
+ if unassigned:
357
+ final_order.extend(_spatial_sort(unassigned))
358
+
359
+ return final_order
core/llm_defaults.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Provider-specific default sampling parameters."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Dict, Optional
6
+
7
+ # Canonical provider names used across the app
8
+ DEFAULT_LLM_PROVIDER = "Google"
9
+
10
+ _PROVIDER_SAMPLING_DEFAULTS: Dict[str, Dict[str, float | int]] = {
11
+ "Google": {"temperature": 0.1, "top_p": 0.95, "top_k": 64},
12
+ "OpenAI": {"temperature": 0.1, "top_p": 1.0, "top_k": 0},
13
+ "Anthropic": {"temperature": 0.1, "top_p": 1.0, "top_k": 0},
14
+ "xAI": {"temperature": 0.1, "top_p": 1.0, "top_k": 0},
15
+ "DeepSeek": {"temperature": 0.1, "top_p": 0.95, "top_k": 0},
16
+ "Z.ai": {"temperature": 0.1, "top_p": 0.95, "top_k": 40},
17
+ "Moonshot AI": {"temperature": 0.1, "top_p": 1.0, "top_k": 0},
18
+ "OpenRouter": {"temperature": 0.1, "top_p": 0.95, "top_k": 64},
19
+ "OpenAI-Compatible": {"temperature": 0.1, "top_p": 0.95, "top_k": 40},
20
+ }
21
+
22
+
23
+ def get_provider_sampling_defaults(provider: Optional[str]) -> Dict[str, float | int]:
24
+ """Return a copy of the sampling defaults for the specified provider."""
25
+ fallback = _PROVIDER_SAMPLING_DEFAULTS[DEFAULT_LLM_PROVIDER]
26
+ if not provider:
27
+ return fallback.copy()
28
+ return _PROVIDER_SAMPLING_DEFAULTS.get(provider, fallback).copy()
29
+
30
+
31
+ __all__ = ["DEFAULT_LLM_PROVIDER", "get_provider_sampling_defaults"]
core/ml/__init__.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Machine learning model management for MangaTranslator.
3
+
4
+ This subpackage contains modules for:
5
+ - Centralized ML model loading and caching
6
+ - Model manager for YOLO, SAM, Flux, and other models
7
+ """
8
+
9
+ from .model_manager import ModelManager, get_model_manager
10
+
11
+ __all__ = [
12
+ "ModelManager",
13
+ "get_model_manager",
14
+ ]
core/ml/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (542 Bytes). View file
 
core/ml/__pycache__/__init__.cpython-314.pyc ADDED
Binary file (507 Bytes). View file
 
core/ml/__pycache__/model_manager.cpython-311.pyc ADDED
Binary file (44.7 kB). View file
 
core/ml/__pycache__/model_manager.cpython-314.pyc ADDED
Binary file (45.6 kB). View file
 
core/ml/model_manager.py ADDED
@@ -0,0 +1,854 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ import shutil
3
+ import threading
4
+ import urllib.request
5
+ from contextlib import contextmanager
6
+ from enum import Enum
7
+ from pathlib import Path
8
+ from typing import Optional
9
+
10
+ import torch
11
+ from huggingface_hub import hf_hub_download, snapshot_download
12
+ from spandrel import ModelLoader
13
+ from transformers import Sam2Model, Sam2Processor
14
+ from ultralytics import YOLO
15
+
16
+ from utils.exceptions import ModelError
17
+ from utils.logging import log_message
18
+
19
+
20
+ class ModelType(Enum):
21
+ """Enumeration of available model types."""
22
+
23
+ UPSCALE = "upscale"
24
+ UPSCALE_LITE = "upscale_lite"
25
+ YOLO_SPEECH_BUBBLE = "yolo_speech_bubble"
26
+ YOLO_CONJOINED_BUBBLE = "yolo_conjoined_bubble"
27
+ YOLO_OSBTEXT = "yolo_osbtext"
28
+ YOLO_PANEL = "yolo_panel"
29
+ SAM2 = "sam2"
30
+ MANGA_OCR = "manga_ocr"
31
+ FLUX_TRANSFORMER = "flux_transformer"
32
+ FLUX_TEXT_ENCODER = "flux_text_encoder"
33
+ FLUX_PIPELINE = "flux_pipeline"
34
+
35
+
36
+ class ModelManager:
37
+ """Singleton model manager for MangaTranslator."""
38
+
39
+ _instance = None
40
+ _lock = threading.RLock()
41
+
42
+ def __new__(cls):
43
+ if cls._instance is None:
44
+ with cls._lock:
45
+ if cls._instance is None:
46
+ cls._instance = super().__new__(cls)
47
+ cls._instance._initialized = False
48
+ return cls._instance
49
+
50
+ def __init__(self):
51
+ """Initialize the model manager (only once due to singleton pattern)."""
52
+ with self._lock:
53
+ if self._initialized:
54
+ return
55
+
56
+ self.device = torch.device(
57
+ "cuda"
58
+ if torch.cuda.is_available()
59
+ else "mps" if torch.backends.mps.is_available() else "cpu"
60
+ )
61
+ self.dtype = (
62
+ torch.bfloat16
63
+ if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
64
+ else torch.float16 if self.device.type == "mps" else torch.float32
65
+ )
66
+
67
+ # Model storage
68
+ self.models = {}
69
+ self.model_paths = self._init_model_paths()
70
+ self.model_urls = self._init_model_urls()
71
+ self.model_hf_repos = self._init_hf_repos()
72
+
73
+ # Flux-specific configuration
74
+ self.flux_cache_dir = Path("./models/flux")
75
+ self.flux_hf_token = None
76
+ self.flux_residual_diff_threshold = 0.15
77
+
78
+ self._initialized = True
79
+ log_message(
80
+ f"Model Manager initialized on device: {self.device}", always_print=True
81
+ )
82
+
83
+ def _init_model_paths(self):
84
+ """Initialize model file paths."""
85
+ model_dir = Path("./models").resolve()
86
+ return {
87
+ ModelType.UPSCALE: (
88
+ model_dir / "upscale" / "2x-AnimeSharpV4_RCAN.safetensors"
89
+ ),
90
+ ModelType.UPSCALE_LITE: (
91
+ model_dir / "upscale" / "2x-AnimeSharpV4_Fast_RCAN_PU.safetensors"
92
+ ),
93
+ ModelType.YOLO_SPEECH_BUBBLE: (
94
+ model_dir / "yolo" / "yolov8m_seg-speech-bubble.pt"
95
+ ),
96
+ ModelType.YOLO_CONJOINED_BUBBLE: (
97
+ model_dir / "yolo" / "comic-speech-bubble-detector-yolov8m.pt"
98
+ ),
99
+ ModelType.YOLO_OSBTEXT: (model_dir / "yolo" / "animetext_yolov12x.pt"),
100
+ ModelType.YOLO_PANEL: (
101
+ model_dir / "yolo" / "manga109_v2023.12.07_l_yolov11.pt"
102
+ ),
103
+ ModelType.MANGA_OCR: (model_dir / "manga-ocr-base"),
104
+ }
105
+
106
+ def _init_model_urls(self):
107
+ """Initialize model download URLs."""
108
+ return {
109
+ ModelType.UPSCALE: (
110
+ "https://huggingface.co/Kim2091/2x-AnimeSharpV4/resolve/main/"
111
+ "2x-AnimeSharpV4_RCAN.safetensors"
112
+ ),
113
+ ModelType.UPSCALE_LITE: (
114
+ "https://huggingface.co/Kim2091/2x-AnimeSharpV4/resolve/main/"
115
+ "2x-AnimeSharpV4_Fast_RCAN_PU.safetensors"
116
+ ),
117
+ }
118
+
119
+ def _init_hf_repos(self):
120
+ """Initialize Hugging Face repository information."""
121
+ repos = {
122
+ ModelType.UPSCALE: {
123
+ "repo_id": "Kim2091/2x-AnimeSharpV4",
124
+ "filename": "2x-AnimeSharpV4_RCAN.safetensors",
125
+ },
126
+ ModelType.UPSCALE_LITE: {
127
+ "repo_id": "Kim2091/2x-AnimeSharpV4",
128
+ "filename": "2x-AnimeSharpV4_Fast_RCAN_PU.safetensors",
129
+ },
130
+ ModelType.YOLO_SPEECH_BUBBLE: {
131
+ "repo_id": "kitsumed/yolov8m_seg-speech-bubble",
132
+ "filename": "model.pt",
133
+ },
134
+ ModelType.YOLO_CONJOINED_BUBBLE: {
135
+ "repo_id": "ogkalu/comic-speech-bubble-detector-yolov8m",
136
+ "filename": "comic-speech-bubble-detector.pt",
137
+ },
138
+ ModelType.YOLO_OSBTEXT: {
139
+ "repo_id": "deepghs/AnimeText_yolo",
140
+ "filename": "yolo12x_animetext/model.pt",
141
+ },
142
+ ModelType.YOLO_PANEL: {
143
+ "repo_id": "deepghs/manga109_yolo",
144
+ "filename": "v2023.12.07_l_yv11/model.pt",
145
+ },
146
+ ModelType.SAM2: {
147
+ "repo_id": "facebook/sam2.1-hiera-large",
148
+ },
149
+ ModelType.FLUX_PIPELINE: {
150
+ "repo_id": "black-forest-labs/FLUX.1-Kontext-dev",
151
+ "filename": None, # Pipeline loaded via from_pretrained
152
+ },
153
+ }
154
+
155
+ repos[ModelType.FLUX_TRANSFORMER] = {
156
+ "repo_id": "nunchaku-tech/nunchaku-flux.1-kontext-dev",
157
+ "filename": None, # Will be constructed dynamically in load_flux_models()
158
+ }
159
+ repos[ModelType.FLUX_TEXT_ENCODER] = {
160
+ "repo_id": "nunchaku-tech/nunchaku-t5",
161
+ "filename": "awq-int4-flux.1-t5xxl.safetensors",
162
+ }
163
+ repos[ModelType.MANGA_OCR] = {
164
+ "repo_id": "kha-white/manga-ocr-base",
165
+ }
166
+
167
+ return repos
168
+
169
+ def _ensure_file(self, path: Path, url: str, verbose: bool = False) -> None:
170
+ """Download file from URL if it doesn't exist.
171
+
172
+ Args:
173
+ path: Path where file should be saved
174
+ url: URL to download from
175
+ verbose: Whether to print verbose logging
176
+ """
177
+ if path.exists():
178
+ return
179
+ path.parent.mkdir(parents=True, exist_ok=True)
180
+ log_message(f"Downloading {path.name}...", verbose=verbose)
181
+ try:
182
+ req = urllib.request.Request(url, headers={"User-Agent": "Mozilla/5.0"})
183
+ with urllib.request.urlopen(req) as response, open(path, "wb") as f:
184
+ shutil.copyfileobj(response, f)
185
+ log_message(f"Downloaded {path.name} successfully.", verbose=verbose)
186
+ except Exception as e:
187
+ if path.exists():
188
+ path.unlink()
189
+ raise ModelError(f"Failed to download {path.name}: {e}")
190
+
191
+ def _ensure_hf_file(
192
+ self,
193
+ repo_id: str,
194
+ filename: str,
195
+ target: Path,
196
+ token: Optional[str] = None,
197
+ verbose: bool = False,
198
+ ) -> Path:
199
+ """Download file from Hugging Face if it doesn't exist.
200
+
201
+ Args:
202
+ repo_id: Hugging Face repository ID
203
+ filename: Name of file to download
204
+ target: Path where file should be saved
205
+ token: Optional Hugging Face token
206
+ verbose: Whether to print verbose logging
207
+ """
208
+ if target.exists():
209
+ return target
210
+ target.parent.mkdir(parents=True, exist_ok=True)
211
+ log_message(
212
+ f"Downloading {target.name} from Hugging Face ({repo_id})...",
213
+ verbose=verbose,
214
+ )
215
+ downloaded = hf_hub_download(
216
+ repo_id=repo_id,
217
+ filename=filename,
218
+ local_dir=str(target.parent),
219
+ token=token,
220
+ )
221
+ downloaded_path = Path(downloaded)
222
+ if downloaded_path != target:
223
+ downloaded_parent = downloaded_path.parent
224
+ try:
225
+ downloaded_path.replace(target)
226
+ except Exception:
227
+ shutil.copyfile(downloaded_path, target)
228
+ try:
229
+ downloaded_path.unlink()
230
+ except Exception:
231
+ pass
232
+
233
+ # Clean up empty directory if it was created by hf_hub_download
234
+ if downloaded_parent != target.parent and downloaded_parent.exists():
235
+ try:
236
+ if not any(downloaded_parent.iterdir()):
237
+ downloaded_parent.rmdir()
238
+ except (OSError, PermissionError):
239
+ pass
240
+ log_message(f"Downloaded {target.name} successfully.", verbose=verbose)
241
+ return target
242
+
243
+ def _ensure_hf_repo(
244
+ self,
245
+ repo_id: str,
246
+ target_dir: Path,
247
+ token: Optional[str] = None,
248
+ verbose: bool = False,
249
+ ) -> Path:
250
+ """Download entire repository from Hugging Face if it doesn't exist.
251
+
252
+ Args:
253
+ repo_id: Hugging Face repository ID
254
+ target_dir: Directory where repository should be saved
255
+ token: Optional Hugging Face token
256
+ verbose: Whether to print verbose logging
257
+
258
+ Returns:
259
+ Path to the downloaded repository directory
260
+ """
261
+ # Check for larger model file to ensure download is complete
262
+ critical_file = target_dir / "pytorch_model.bin"
263
+ if target_dir.exists() and critical_file.exists():
264
+ return target_dir
265
+ target_dir.mkdir(parents=True, exist_ok=True)
266
+ log_message(
267
+ f"Downloading repository {repo_id} from Hugging Face...",
268
+ verbose=verbose,
269
+ )
270
+ try:
271
+ snapshot_download(
272
+ repo_id=repo_id,
273
+ local_dir=str(target_dir),
274
+ token=token,
275
+ )
276
+ log_message(
277
+ f"Downloaded repository {repo_id} successfully.", verbose=verbose
278
+ )
279
+ except Exception as e:
280
+ if target_dir.exists():
281
+ models_dir = Path("./models").resolve()
282
+ target_resolved = target_dir.resolve()
283
+ try:
284
+ target_resolved.relative_to(models_dir)
285
+ # Safe to delete
286
+ try:
287
+ shutil.rmtree(target_dir)
288
+ except Exception:
289
+ pass
290
+ except ValueError:
291
+ # target_dir is not within models/, skip deletion for safety
292
+ log_message(
293
+ f"Warning: Skipping deletion of {target_dir} as it is outside models/ directory",
294
+ always_print=True,
295
+ )
296
+ raise ModelError(f"Failed to download repository {repo_id}: {e}") from e
297
+ return target_dir
298
+
299
+ def is_loaded(self, model_type: ModelType) -> bool:
300
+ """Check if a model is currently loaded."""
301
+ with self._lock:
302
+ return model_type in self.models and self.models[model_type] is not None
303
+
304
+ def load_upscale(self, verbose: bool = False):
305
+ """Load upscale model (AnimeSharpV4 RCAN)."""
306
+ with self._lock:
307
+ if self.is_loaded(ModelType.UPSCALE):
308
+ return self.models[ModelType.UPSCALE]
309
+
310
+ log_message(
311
+ "Loading upscale model (2x-AnimeSharpV4_RCAN)...", verbose=verbose
312
+ )
313
+ path = self.model_paths[ModelType.UPSCALE]
314
+
315
+ # Try HF download first, fallback to direct URL
316
+ try:
317
+ hf_info = self.model_hf_repos[ModelType.UPSCALE]
318
+ self._ensure_hf_file(
319
+ hf_info["repo_id"], hf_info["filename"], path, verbose=verbose
320
+ )
321
+ except Exception:
322
+ self._ensure_file(
323
+ path, self.model_urls[ModelType.UPSCALE], verbose=verbose
324
+ )
325
+
326
+ # Load model
327
+ if path.suffix == ".safetensors":
328
+ from safetensors import safe_open
329
+
330
+ state_dict = {}
331
+ with safe_open(path, framework="pt", device=str(self.device)) as f:
332
+ for key in f.keys():
333
+ state_dict[key] = f.get_tensor(key)
334
+ else:
335
+ state_dict = torch.load(
336
+ path, map_location=self.device, weights_only=False
337
+ )
338
+
339
+ model = (
340
+ ModelLoader().load_from_state_dict(state_dict).to(self.device).eval()
341
+ )
342
+ self.models[ModelType.UPSCALE] = model
343
+ log_message("Upscale model loaded.", verbose=verbose)
344
+ return model
345
+
346
+ def load_upscale_lite(self, verbose: bool = False):
347
+ """Load upscale lite model (AnimeSharpV4 Fast RCAN PU)."""
348
+ with self._lock:
349
+ if self.is_loaded(ModelType.UPSCALE_LITE):
350
+ return self.models[ModelType.UPSCALE_LITE]
351
+
352
+ log_message(
353
+ "Loading upscale lite model (2x-AnimeSharpV4_Fast_RCAN_PU)...",
354
+ verbose=verbose,
355
+ )
356
+ path = self.model_paths[ModelType.UPSCALE_LITE]
357
+
358
+ # Try HF download first, fallback to direct URL
359
+ try:
360
+ hf_info = self.model_hf_repos[ModelType.UPSCALE_LITE]
361
+ self._ensure_hf_file(
362
+ hf_info["repo_id"], hf_info["filename"], path, verbose=verbose
363
+ )
364
+ except Exception:
365
+ self._ensure_file(
366
+ path, self.model_urls[ModelType.UPSCALE_LITE], verbose=verbose
367
+ )
368
+
369
+ # Load model
370
+ if path.suffix == ".safetensors":
371
+ from safetensors import safe_open
372
+
373
+ state_dict = {}
374
+ with safe_open(path, framework="pt", device=str(self.device)) as f:
375
+ for key in f.keys():
376
+ state_dict[key] = f.get_tensor(key)
377
+ else:
378
+ state_dict = torch.load(
379
+ path, map_location=self.device, weights_only=False
380
+ )
381
+
382
+ model = (
383
+ ModelLoader().load_from_state_dict(state_dict).to(self.device).eval()
384
+ )
385
+ self.models[ModelType.UPSCALE_LITE] = model
386
+ log_message("Upscale lite model loaded.", verbose=verbose)
387
+ return model
388
+
389
+ def load_yolo_speech_bubble(
390
+ self, model_path: Optional[str] = None, verbose: bool = False
391
+ ):
392
+ """Load YOLO model for speech bubble detection.
393
+
394
+ Args:
395
+ model_path: Optional custom path to YOLO model. If None, uses default.
396
+ verbose: Whether to print verbose logging
397
+ """
398
+ with self._lock:
399
+ if self.is_loaded(ModelType.YOLO_SPEECH_BUBBLE):
400
+ return self.models[ModelType.YOLO_SPEECH_BUBBLE]
401
+
402
+ log_message(
403
+ "Loading YOLO speech bubble detection model...", verbose=verbose
404
+ )
405
+
406
+ path = (
407
+ self.model_paths[ModelType.YOLO_SPEECH_BUBBLE]
408
+ if model_path is None
409
+ else Path(model_path)
410
+ )
411
+
412
+ if path == self.model_paths[ModelType.YOLO_SPEECH_BUBBLE]:
413
+ hf_info = self.model_hf_repos[ModelType.YOLO_SPEECH_BUBBLE]
414
+ self._ensure_hf_file(
415
+ hf_info["repo_id"], hf_info["filename"], path, verbose=verbose
416
+ )
417
+
418
+ model = YOLO(str(path))
419
+ self.models[ModelType.YOLO_SPEECH_BUBBLE] = model
420
+ log_message("YOLO model loaded.", verbose=verbose)
421
+ return model
422
+
423
+ def load_yolo_conjoined_bubble(self, verbose: bool = False):
424
+ """Load YOLO model for conjoined speech bubble detection."""
425
+ with self._lock:
426
+ if self.is_loaded(ModelType.YOLO_CONJOINED_BUBBLE):
427
+ return self.models[ModelType.YOLO_CONJOINED_BUBBLE]
428
+
429
+ log_message(
430
+ "Loading YOLO conjoined bubble detection model...", verbose=verbose
431
+ )
432
+ path = self.model_paths[ModelType.YOLO_CONJOINED_BUBBLE]
433
+
434
+ # Try HF download
435
+ hf_info = self.model_hf_repos[ModelType.YOLO_CONJOINED_BUBBLE]
436
+ self._ensure_hf_file(
437
+ hf_info["repo_id"], hf_info["filename"], path, verbose=verbose
438
+ )
439
+
440
+ model = YOLO(str(path))
441
+ self.models[ModelType.YOLO_CONJOINED_BUBBLE] = model
442
+ log_message("YOLO conjoined bubble model loaded.", verbose=verbose)
443
+ return model
444
+
445
+ def load_yolo_osbtext(self, token: Optional[str] = None, verbose: bool = False):
446
+ """Load YOLO model for outside text detection.
447
+
448
+ Args:
449
+ token: Hugging Face token for gated repo access.
450
+ verbose: Whether to print verbose logging
451
+ """
452
+ with self._lock:
453
+ if self.is_loaded(ModelType.YOLO_OSBTEXT):
454
+ return self.models[ModelType.YOLO_OSBTEXT]
455
+
456
+ log_message("Loading YOLO OSB Text detection model...", verbose=verbose)
457
+
458
+ path = self.model_paths[ModelType.YOLO_OSBTEXT]
459
+ hf_info = self.model_hf_repos[ModelType.YOLO_OSBTEXT]
460
+
461
+ self._ensure_hf_file(
462
+ hf_info["repo_id"],
463
+ hf_info["filename"],
464
+ path,
465
+ token=token,
466
+ verbose=verbose,
467
+ )
468
+
469
+ model = YOLO(str(path))
470
+ self.models[ModelType.YOLO_OSBTEXT] = model
471
+ log_message("YOLO OSB Text model loaded.", verbose=verbose)
472
+ return model
473
+
474
+ def load_yolo_panel(self, verbose: bool = False):
475
+ """Load YOLO model for panel detection.
476
+
477
+ Args:
478
+ verbose: Whether to print verbose logging
479
+ """
480
+ with self._lock:
481
+ if self.is_loaded(ModelType.YOLO_PANEL):
482
+ return self.models[ModelType.YOLO_PANEL]
483
+
484
+ log_message("Loading YOLO panel detection model...", verbose=verbose)
485
+ path = self.model_paths[ModelType.YOLO_PANEL]
486
+ hf_info = self.model_hf_repos[ModelType.YOLO_PANEL]
487
+
488
+ self._ensure_hf_file(
489
+ hf_info["repo_id"],
490
+ hf_info["filename"],
491
+ path,
492
+ verbose=verbose,
493
+ )
494
+
495
+ model = YOLO(str(path))
496
+ self.models[ModelType.YOLO_PANEL] = model
497
+ log_message("YOLO panel model loaded.", verbose=verbose)
498
+ return model
499
+
500
+ def load_manga_ocr(self, verbose: bool = False) -> Path:
501
+ """Ensure manga-ocr model repository is downloaded.
502
+
503
+ Args:
504
+ verbose: Whether to print verbose logging
505
+
506
+ Returns:
507
+ Path to the downloaded manga-ocr model directory
508
+ """
509
+ with self._lock:
510
+ model_path = self.model_paths[ModelType.MANGA_OCR]
511
+ hf_info = self.model_hf_repos[ModelType.MANGA_OCR]
512
+ self._ensure_hf_repo(hf_info["repo_id"], model_path, verbose=verbose)
513
+ log_message("manga-ocr model repository ready.", verbose=verbose)
514
+ return model_path
515
+
516
+ def get_manga_ocr(self, verbose: bool = False):
517
+ """Get manga-ocr instance, loading it if necessary.
518
+
519
+ Args:
520
+ verbose: Whether to print verbose logging
521
+
522
+ Returns:
523
+ MangaOcr instance
524
+ """
525
+ with self._lock:
526
+ if self.is_loaded(ModelType.MANGA_OCR):
527
+ return self.models[ModelType.MANGA_OCR]
528
+
529
+ log_message("Initializing manga-ocr...", verbose=verbose)
530
+
531
+ # Fix for MeCab/Fugashi on non-Windows systems
532
+ try:
533
+ import os
534
+
535
+ import unidic_lite
536
+
537
+ os.environ["MECABRC"] = os.path.join(unidic_lite.DICDIR, "mecabrc")
538
+ except ImportError:
539
+ log_message(
540
+ "Warning: unidic_lite not found, skipping MeCab fix",
541
+ verbose=verbose,
542
+ )
543
+ except Exception as e:
544
+ log_message(f"Warning: Failed to apply MeCab fix: {e}", verbose=verbose)
545
+
546
+ from manga_ocr import MangaOcr
547
+
548
+ # Ensure model is downloaded
549
+ model_path = self.load_manga_ocr(verbose=verbose)
550
+ manga_ocr_instance = MangaOcr(pretrained_model_name_or_path=str(model_path))
551
+ self.models[ModelType.MANGA_OCR] = manga_ocr_instance
552
+ log_message("manga-ocr initialized", verbose=verbose)
553
+ return manga_ocr_instance
554
+
555
+ def load_sam2(self, verbose: bool = False):
556
+ """Load SAM 2.1 model and processor.
557
+
558
+ Returns:
559
+ tuple: (processor, model) - SAM2 processor and model instances
560
+ """
561
+ with self._lock:
562
+ if self.is_loaded(ModelType.SAM2):
563
+ return self.models[ModelType.SAM2]
564
+
565
+ log_message("Loading SAM 2.1 model...", verbose=verbose)
566
+ hf_info = self.model_hf_repos[ModelType.SAM2]
567
+ cache_dir = "models/sam"
568
+
569
+ processor = Sam2Processor.from_pretrained(
570
+ hf_info["repo_id"], cache_dir=cache_dir
571
+ )
572
+ model = Sam2Model.from_pretrained(
573
+ hf_info["repo_id"], torch_dtype=self.dtype, cache_dir=cache_dir
574
+ ).to(self.device)
575
+ model.eval()
576
+
577
+ # Store as tuple
578
+ self.models[ModelType.SAM2] = (processor, model)
579
+ log_message("SAM 2.1 model loaded.", verbose=verbose)
580
+ return self.models[ModelType.SAM2]
581
+
582
+ def set_flux_hf_token(self, token: str):
583
+ """Set the HuggingFace token for Flux model downloads.
584
+
585
+ Args:
586
+ token: HuggingFace API token
587
+ """
588
+ self.flux_hf_token = token if token else None
589
+
590
+ def set_flux_residual_diff_threshold(self, threshold: float):
591
+ """Set the residual diff threshold for Flux caching.
592
+
593
+ Args:
594
+ threshold: Residual diff threshold (0.0-1.0)
595
+ """
596
+ self.flux_residual_diff_threshold = max(0.0, min(1.0, threshold))
597
+
598
+ def load_flux_models(self, verbose: bool = False):
599
+ """Load all Flux Kontext inpainting models (transformer, text encoder, pipeline).
600
+
601
+ Returns:
602
+ tuple: (transformer, text_encoder, pipeline)
603
+ """
604
+ with self._lock:
605
+ if self.is_loaded(ModelType.FLUX_PIPELINE):
606
+ return (
607
+ self.models[ModelType.FLUX_TRANSFORMER],
608
+ self.models[ModelType.FLUX_TEXT_ENCODER],
609
+ self.models[ModelType.FLUX_PIPELINE],
610
+ )
611
+
612
+ log_message("Loading Flux Kontext inpainting models...", verbose=verbose)
613
+ try:
614
+ # Lazy imports for Nunchaku and diffusers
615
+ from diffusers import FluxKontextPipeline
616
+ from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe
617
+ from nunchaku.models.text_encoders.t5_encoder import (
618
+ NunchakuT5EncoderModel,
619
+ )
620
+ from nunchaku.models.transformers.transformer_flux import (
621
+ NunchakuFluxTransformer2dModel,
622
+ )
623
+ from nunchaku.utils import get_precision
624
+
625
+ hf_info = self.model_hf_repos[ModelType.FLUX_TRANSFORMER]
626
+ if hf_info["filename"] is None:
627
+ hf_info["filename"] = (
628
+ f"svdq-{get_precision()}_r32-flux.1-kontext-dev.safetensors"
629
+ )
630
+ transformer_path = self._ensure_hf_file(
631
+ hf_info["repo_id"],
632
+ hf_info["filename"],
633
+ self.flux_cache_dir / hf_info["filename"],
634
+ verbose=verbose,
635
+ )
636
+ transformer = NunchakuFluxTransformer2dModel.from_pretrained(
637
+ str(transformer_path),
638
+ torch_dtype=self.dtype,
639
+ offload=True,
640
+ precision="int4",
641
+ set_attention_impl="nunchaku-fp16",
642
+ )
643
+ self.models[ModelType.FLUX_TRANSFORMER] = transformer
644
+
645
+ # Load text encoder
646
+ hf_info = self.model_hf_repos[ModelType.FLUX_TEXT_ENCODER]
647
+ text_encoder_path = self._ensure_hf_file(
648
+ hf_info["repo_id"],
649
+ hf_info["filename"],
650
+ self.flux_cache_dir / hf_info["filename"],
651
+ verbose=verbose,
652
+ )
653
+ text_encoder = NunchakuT5EncoderModel.from_pretrained(
654
+ str(text_encoder_path),
655
+ torch_dtype=self.dtype,
656
+ )
657
+ self.models[ModelType.FLUX_TEXT_ENCODER] = text_encoder
658
+
659
+ # Load pipeline
660
+ pipeline_repo = self.model_hf_repos[ModelType.FLUX_PIPELINE]["repo_id"]
661
+ pipeline = FluxKontextPipeline.from_pretrained(
662
+ pipeline_repo,
663
+ transformer=transformer,
664
+ text_encoder_2=text_encoder,
665
+ torch_dtype=self.dtype,
666
+ cache_dir=str(self.flux_cache_dir),
667
+ token=self.flux_hf_token,
668
+ ).to(self.device)
669
+
670
+ # Apply caching for faster inference
671
+ apply_cache_on_pipe(
672
+ pipeline, residual_diff_threshold=self.flux_residual_diff_threshold
673
+ )
674
+ self.models[ModelType.FLUX_PIPELINE] = pipeline
675
+
676
+ log_message("Flux Kontext models loaded successfully.", verbose=verbose)
677
+ return transformer, text_encoder, pipeline
678
+ except ImportError as e:
679
+ raise ModelError(
680
+ "Nunchaku not installed or incompatible. Inpainting requires Nunchaku."
681
+ ) from e
682
+ except Exception as e:
683
+ raise ModelError(
684
+ f"Failed to load Flux/Nunchaku inpainting models: {e}"
685
+ ) from e
686
+
687
+ def unload_model(
688
+ self, model_type: ModelType, force_gc: bool = True, verbose: bool = False
689
+ ):
690
+ """Unload a specific model and free memory.
691
+
692
+ Args:
693
+ model_type: Type of model to unload
694
+ force_gc: Whether to force garbage collection
695
+ verbose: Whether to print verbose logging
696
+ """
697
+ with self._lock:
698
+ if not self.is_loaded(model_type):
699
+ return
700
+
701
+ log_message(f"Unloading {model_type.value}...", verbose=verbose)
702
+ del self.models[model_type]
703
+ self.models[model_type] = None
704
+
705
+ if force_gc and torch.cuda.is_available():
706
+ gc.collect()
707
+ torch.cuda.empty_cache()
708
+
709
+ def unload_upscale_models(self, verbose: bool = False):
710
+ """Unload upscale models (both regular and lite)."""
711
+ self.unload_model(ModelType.UPSCALE, force_gc=False, verbose=verbose)
712
+ self.unload_model(ModelType.UPSCALE_LITE, force_gc=False, verbose=verbose)
713
+ if torch.cuda.is_available():
714
+ gc.collect()
715
+ torch.cuda.empty_cache()
716
+ log_message("Upscale models unloaded.", verbose=verbose)
717
+
718
+ def unload_ocr_models(self, verbose: bool = False):
719
+ """Unload OCR-related models (YOLO, SAM2, and manga-ocr)."""
720
+ models_unloaded = []
721
+ if self.is_loaded(ModelType.YOLO_SPEECH_BUBBLE):
722
+ models_unloaded.append("yolo_speech_bubble")
723
+ if self.is_loaded(ModelType.YOLO_CONJOINED_BUBBLE):
724
+ models_unloaded.append("yolo_conjoined_bubble")
725
+ if self.is_loaded(ModelType.SAM2):
726
+ models_unloaded.append("sam2")
727
+ if self.is_loaded(ModelType.YOLO_OSBTEXT):
728
+ models_unloaded.append("yolo_osbtext")
729
+ if self.is_loaded(ModelType.YOLO_PANEL):
730
+ models_unloaded.append("yolo_panel")
731
+ if self.is_loaded(ModelType.MANGA_OCR):
732
+ models_unloaded.append("manga_ocr")
733
+
734
+ self.unload_model(ModelType.YOLO_SPEECH_BUBBLE, force_gc=False, verbose=verbose)
735
+ self.unload_model(
736
+ ModelType.YOLO_CONJOINED_BUBBLE, force_gc=False, verbose=verbose
737
+ )
738
+ self.unload_model(ModelType.SAM2, force_gc=False, verbose=verbose)
739
+ self.unload_model(ModelType.YOLO_OSBTEXT, force_gc=False, verbose=verbose)
740
+ self.unload_model(ModelType.YOLO_PANEL, force_gc=False, verbose=verbose)
741
+ self.unload_model(ModelType.MANGA_OCR, force_gc=True, verbose=verbose)
742
+
743
+ if models_unloaded:
744
+ log_message("OCR models unloaded.", verbose=verbose)
745
+
746
+ def unload_flux_models(self, verbose: bool = False):
747
+ """Unload all Flux Kontext models."""
748
+ models_unloaded = []
749
+ if self.is_loaded(ModelType.FLUX_TRANSFORMER):
750
+ models_unloaded.append("flux_transformer")
751
+ if self.is_loaded(ModelType.FLUX_TEXT_ENCODER):
752
+ models_unloaded.append("flux_text_encoder")
753
+ if self.is_loaded(ModelType.FLUX_PIPELINE):
754
+ models_unloaded.append("flux_pipeline")
755
+
756
+ self.unload_model(ModelType.FLUX_TRANSFORMER, force_gc=False, verbose=verbose)
757
+ self.unload_model(ModelType.FLUX_TEXT_ENCODER, force_gc=False, verbose=verbose)
758
+ self.unload_model(ModelType.FLUX_PIPELINE, force_gc=True, verbose=verbose)
759
+
760
+ if models_unloaded:
761
+ log_message("Flux Kontext models unloaded.", verbose=verbose)
762
+
763
+ def unload_all(self, verbose: bool = False):
764
+ """Unload all models and free all GPU memory."""
765
+ with self._lock:
766
+ log_message("Unloading all models...", verbose=verbose)
767
+ for model_type in list(self.models.keys()):
768
+ if self.is_loaded(model_type):
769
+ del self.models[model_type]
770
+ self.models[model_type] = None
771
+
772
+ if torch.cuda.is_available():
773
+ gc.collect()
774
+ torch.cuda.empty_cache()
775
+ log_message("All models unloaded.", verbose=verbose)
776
+
777
+ def get_memory_stats(self):
778
+ """Get current GPU memory usage statistics."""
779
+ if not torch.cuda.is_available():
780
+ return {"device": "cpu", "memory": "N/A"}
781
+
782
+ allocated = torch.cuda.memory_allocated() / 1024**3
783
+ reserved = torch.cuda.memory_reserved() / 1024**3
784
+ return {
785
+ "device": torch.cuda.get_device_name(0),
786
+ "allocated_gb": f"{allocated:.2f}",
787
+ "reserved_gb": f"{reserved:.2f}",
788
+ }
789
+
790
+ def print_memory_stats(self):
791
+ """Print current GPU memory usage."""
792
+ stats = self.get_memory_stats()
793
+ if stats["memory"] == "N/A":
794
+ log_message(f"Device: {stats['device']}", always_print=True)
795
+ else:
796
+ log_message(
797
+ f"GPU Memory - Allocated: {stats['allocated_gb']} GB, "
798
+ f"Reserved: {stats['reserved_gb']} GB",
799
+ always_print=True,
800
+ )
801
+
802
+ @contextmanager
803
+ def upscale_context(self, verbose: bool = False):
804
+ """Context manager for upscale model - auto-loads and unloads."""
805
+ try:
806
+ self.load_upscale(verbose=verbose)
807
+ yield self.models[ModelType.UPSCALE]
808
+ finally:
809
+ self.unload_upscale_models(verbose=verbose)
810
+
811
+ @contextmanager
812
+ def upscale_lite_context(self, verbose: bool = False):
813
+ """Context manager for upscale lite model - auto-loads and unloads."""
814
+ try:
815
+ self.load_upscale_lite(verbose=verbose)
816
+ yield self.models[ModelType.UPSCALE_LITE]
817
+ finally:
818
+ self.unload_upscale_models(verbose=verbose)
819
+
820
+ @contextmanager
821
+ def ocr_context(self, hf_token=None, verbose: bool = False):
822
+ """Context manager for OCR models - auto-loads and unloads.
823
+
824
+ Args:
825
+ hf_token: Hugging Face token for gated repo access
826
+ verbose: Whether to print verbose logging
827
+ """
828
+ try:
829
+ yolo = self.load_yolo_speech_bubble(verbose=verbose)
830
+ yolo_osbtext = self.load_yolo_osbtext(token=hf_token, verbose=verbose)
831
+ yield yolo, yolo_osbtext
832
+ finally:
833
+ self.unload_ocr_models(verbose=verbose)
834
+
835
+ @contextmanager
836
+ def flux_context(self, verbose: bool = False):
837
+ """Context manager for Flux models - auto-loads and unloads."""
838
+ try:
839
+ transformer, text_encoder, pipeline = self.load_flux_models(verbose=verbose)
840
+ yield transformer, text_encoder, pipeline
841
+ finally:
842
+ self.unload_flux_models(verbose=verbose)
843
+
844
+
845
+ # Global singleton instance
846
+ _model_manager = None
847
+
848
+
849
+ def get_model_manager() -> ModelManager:
850
+ """Get the global model manager instance."""
851
+ global _model_manager
852
+ if _model_manager is None:
853
+ _model_manager = ModelManager()
854
+ return _model_manager
core/outside_text_processor.py ADDED
@@ -0,0 +1,638 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import gc
3
+ import os
4
+ import random
5
+ import re
6
+ import tempfile
7
+ from pathlib import Path
8
+ from typing import Any, Dict, List, Optional, Tuple, Union
9
+
10
+ import cv2
11
+ import numpy as np
12
+ from PIL import Image
13
+ from sklearn.cluster import KMeans
14
+
15
+ from core.config import MangaTranslatorConfig
16
+ from core.image.image_utils import cv2_to_pil, pil_to_cv2, process_bubble_image_cached
17
+ from core.image.inpainting import FluxKontextInpainter
18
+ from core.image.ocr_detection import OutsideTextDetector, extract_text_with_manga_ocr
19
+ from core.ml.model_manager import get_model_manager
20
+ from utils.logging import log_message
21
+
22
+
23
+ def process_outside_text(
24
+ pil_image: Image.Image,
25
+ config: MangaTranslatorConfig,
26
+ image_path: Union[str, Path],
27
+ image_format: Optional[str],
28
+ verbose: bool = False,
29
+ bubble_data: Optional[List[Dict[str, Any]]] = None,
30
+ text_free_boxes: Optional[List[List[float]]] = None,
31
+ ) -> Tuple[Image.Image, List[Dict[str, Any]]]:
32
+ """
33
+ Process outside text detection, inpainting, and prepare data for translation.
34
+
35
+ This function handles the complete outside text processing pipeline:
36
+ 1. Detects text outside speech bubbles using OCR
37
+ 2. Inpaints the detected text regions using FluxKontext
38
+ 3. Prepares the outside text data for translation API calls
39
+
40
+ Args:
41
+ pil_image: The PIL image to process
42
+ config: MangaTranslatorConfig containing all settings
43
+ image_path: Path to the original image file
44
+ image_format: Original image format (PNG, JPEG, etc.)
45
+ processing_scale: The scale factor for image processing
46
+ verbose: Whether to print detailed logging
47
+
48
+ Returns:
49
+ Tuple containing:
50
+ - processed_pil_image: The image after outside text inpainting
51
+ - outside_text_data: List of dicts with outside text information for translation
52
+ """
53
+ if not config.outside_text.enabled:
54
+ return pil_image, []
55
+
56
+ log_message("Detecting text outside speech bubbles...", verbose=verbose)
57
+
58
+ try:
59
+ outside_detector = OutsideTextDetector(
60
+ device=config.device, hf_token=config.outside_text.huggingface_token
61
+ )
62
+ outside_text_results = outside_detector.detect_outside_text(
63
+ str(image_path),
64
+ yolo_model_path=config.yolo_model_path,
65
+ confidence=config.outside_text.osb_confidence,
66
+ conjoined_confidence=config.detection.conjoined_confidence,
67
+ verbose=verbose,
68
+ image_override=pil_image,
69
+ existing_bubbles=bubble_data,
70
+ text_free_boxes=text_free_boxes,
71
+ )
72
+
73
+ if not outside_text_results:
74
+ log_message("No outside text regions found", verbose=verbose)
75
+ outside_detector.unload_models()
76
+ return pil_image, []
77
+
78
+ img_w, img_h = pil_image.size
79
+
80
+ # Filter out probable page numbers
81
+ # Only run OCR on "suspicious" detections (small & in margin)
82
+ if config.outside_text.enable_page_number_filtering and outside_text_results:
83
+ suspicious_crops = []
84
+ suspicious_indices = []
85
+ safe_results = []
86
+
87
+ margin_threshold = max(
88
+ 0.0, min(0.3, config.outside_text.page_filter_margin_threshold)
89
+ )
90
+ min_area_threshold = max(
91
+ 0.0, min(0.2, config.outside_text.page_filter_min_area_ratio)
92
+ )
93
+
94
+ for i, res in enumerate(outside_text_results):
95
+ bbox, _ = res
96
+ x1, y1, x2, y2 = [int(c) for c in bbox]
97
+ cy = (y1 + y2) / 2
98
+
99
+ is_in_margin = (cy < img_h * margin_threshold) or (
100
+ cy > img_h * (1 - margin_threshold)
101
+ )
102
+
103
+ area = (x2 - x1) * (y2 - y1)
104
+ is_small = area < (img_w * img_h * min_area_threshold)
105
+
106
+ if is_in_margin and is_small:
107
+ suspicious_crops.append(pil_image.crop((x1, y1, x2, y2)))
108
+ suspicious_indices.append(i)
109
+ else:
110
+ safe_results.append(res)
111
+
112
+ if suspicious_crops:
113
+ log_message(
114
+ f"Verifying {len(suspicious_crops)} suspicious OSB regions with OCR...",
115
+ verbose=verbose,
116
+ )
117
+ suspicious_texts = extract_text_with_manga_ocr(
118
+ suspicious_crops, verbose=verbose
119
+ )
120
+
121
+ kept_suspicious_count = 0
122
+ for i, text in enumerate(suspicious_texts):
123
+ # Regex for page numbers: digits, "Page 20", "p. 20", etc.
124
+ is_page_number = bool(
125
+ re.match(
126
+ r"^\s*(?:page\.?|p\.?)?\s*\d+\s*$", text, re.IGNORECASE
127
+ )
128
+ )
129
+
130
+ if not is_page_number:
131
+ safe_results.append(outside_text_results[suspicious_indices[i]])
132
+ kept_suspicious_count += 1
133
+ else:
134
+ log_message(
135
+ f"Filtered out page number: '{text}'", verbose=verbose
136
+ )
137
+
138
+ outside_text_results = safe_results
139
+ log_message(
140
+ f"Remaining OSB regions after filtering: {len(outside_text_results)}",
141
+ verbose=verbose,
142
+ )
143
+
144
+ # Build a mask of all detected speech bubbles to prevent OSB inpainting overlap
145
+ total_bubble_mask = np.zeros((img_h, img_w), dtype=bool)
146
+ if bubble_data:
147
+ for bubble in bubble_data:
148
+ try:
149
+ mask = bubble.get("sam_mask") if isinstance(bubble, dict) else None
150
+ if mask is not None:
151
+ mask_np = np.asarray(mask)
152
+ if mask_np.ndim == 3:
153
+ mask_np = mask_np[..., 0]
154
+ mask_bool = mask_np > 0
155
+ if mask_bool.shape[0] == img_h and mask_bool.shape[1] == img_w:
156
+ total_bubble_mask |= mask_bool
157
+ continue
158
+
159
+ bbox = bubble.get("bbox") if isinstance(bubble, dict) else None
160
+ if bbox and len(bbox) == 4:
161
+ x0, y0, x1, y1 = [int(c) for c in bbox]
162
+ x0 = max(0, min(img_w, x0))
163
+ x1 = max(0, min(img_w, x1))
164
+ y0 = max(0, min(img_h, y0))
165
+ y1 = max(0, min(img_h, y1))
166
+ if x1 > x0 and y1 > y0:
167
+ total_bubble_mask[y0:y1, x0:x1] = True
168
+ except Exception as e:
169
+ log_message(
170
+ f"Warning: Failed to apply bubble mask for OSB exclusion: {e}",
171
+ verbose=verbose,
172
+ )
173
+
174
+ mime_type = (
175
+ "image/png"
176
+ if image_format and image_format.upper() == "PNG"
177
+ else "image/jpeg"
178
+ )
179
+ cv2_ext = ".png" if image_format and image_format.upper() == "PNG" else ".jpg"
180
+
181
+ # Probe original text color for OSB rendering
182
+ original_text_colors = {}
183
+ for ocr_result in outside_text_results:
184
+ bbox_coords, conf = ocr_result
185
+ x1, y1, x2, y2 = [int(c) for c in bbox_coords]
186
+ bbox_tuple = (x1, y1, x2, y2)
187
+
188
+ bbox_area_img = pil_image.crop((x1, y1, x2, y2))
189
+ bbox_array = np.array(bbox_area_img)
190
+
191
+ if bbox_array.shape[-1] == 4:
192
+ bbox_array = bbox_array[..., :3]
193
+
194
+ pixels = bbox_array.reshape(-1, 3)
195
+
196
+ # Use K-Means to find 2 dominant colors
197
+ kmeans = KMeans(n_clusters=2, random_state=42, n_init=10)
198
+ kmeans.fit(pixels)
199
+
200
+ labels = kmeans.labels_
201
+ centers = kmeans.cluster_centers_
202
+
203
+ unique, counts = np.unique(labels, return_counts=True)
204
+ dominant_cluster_idx = unique[np.argmax(counts)]
205
+
206
+ # Dominant cluster is usually the background (text pixels are sparse)
207
+ bg_color_rgb = centers[dominant_cluster_idx]
208
+ # Use proper luminance calculation (ITU-R BT.601)
209
+ bg_brightness = (
210
+ 0.299 * bg_color_rgb[0]
211
+ + 0.587 * bg_color_rgb[1]
212
+ + 0.114 * bg_color_rgb[2]
213
+ )
214
+ is_dark_text = (
215
+ bg_brightness < 128
216
+ ) # passed downstream; renderer inverts for text color
217
+ original_text_colors[bbox_tuple] = is_dark_text
218
+
219
+ log_message(
220
+ f"OSB bbox {bbox_tuple}: "
221
+ f"{'Dark' if is_dark_text else 'Light'} background detected "
222
+ f"(luminance={bg_brightness:.1f})",
223
+ verbose=verbose,
224
+ )
225
+
226
+ log_message("Inpainting outside text regions...", verbose=verbose)
227
+ inpainter = (
228
+ None
229
+ if config.outside_text.force_cv2_inpainting
230
+ else FluxKontextInpainter(
231
+ device=config.device,
232
+ huggingface_token=config.outside_text.huggingface_token,
233
+ num_inference_steps=config.outside_text.flux_num_inference_steps,
234
+ residual_diff_threshold=config.outside_text.flux_residual_diff_threshold,
235
+ )
236
+ )
237
+
238
+ mask_groups, _ = outside_detector.get_text_masks(
239
+ str(image_path),
240
+ bbox_expansion_percent=config.outside_text.bbox_expansion_percent,
241
+ text_box_proximity_ratio=config.outside_text.text_box_proximity_ratio,
242
+ verbose=verbose,
243
+ image_override=pil_image,
244
+ existing_results=outside_text_results,
245
+ )
246
+
247
+ current_image = pil_image
248
+ temp_files = []
249
+ try:
250
+ if mask_groups:
251
+ base_seed = (
252
+ random.randint(1, 999999)
253
+ if config.outside_text.seed == -1
254
+ else config.outside_text.seed
255
+ )
256
+
257
+ flux_inpaints = 0
258
+ cv2_inpaints = 0
259
+ for i, group in enumerate(mask_groups):
260
+ log_message(
261
+ f"Inpainting outside text region {i + 1}/{len(mask_groups)}",
262
+ verbose=verbose,
263
+ )
264
+ combined_mask = group["combined_mask"]
265
+ combined_mask = np.logical_and(
266
+ combined_mask, np.logical_not(total_bubble_mask)
267
+ )
268
+ if not np.any(combined_mask):
269
+ log_message(
270
+ "Skipping outside text region after bubble masking (no remaining area)",
271
+ verbose=verbose,
272
+ )
273
+ continue
274
+ region_seed = base_seed + i if base_seed > 0 else base_seed
275
+
276
+ original_bbox_dict = group.get("original_bbox")
277
+ composite_clip_bbox = None
278
+ fill_color = None
279
+ fallback_fill_color = None
280
+ ox0 = oy0 = ox1 = oy1 = None
281
+ if original_bbox_dict:
282
+ ox = int(original_bbox_dict.get("x", 0))
283
+ oy = int(original_bbox_dict.get("y", 0))
284
+ ow = int(original_bbox_dict.get("width", 0))
285
+ oh = int(original_bbox_dict.get("height", 0))
286
+ if ow > 0 and oh > 0:
287
+ ox0 = max(0, min(img_w, ox))
288
+ oy0 = max(0, min(img_h, oy))
289
+ ox1 = max(0, min(img_w, ox + ow))
290
+ oy1 = max(0, min(img_h, oy + oh))
291
+ composite_clip_bbox = (ox, oy, ox + ow, oy + oh)
292
+
293
+ # Determine detected text color for this region to ensure contrast
294
+ group_bg_is_dark = None
295
+ if original_text_colors:
296
+ votes_dark = 0
297
+ votes_light = 0
298
+ gx1, gy1, gx2, gy2 = ox, oy, ox + ow, oy + oh
299
+
300
+ for (
301
+ bx1,
302
+ by1,
303
+ bx2,
304
+ by2,
305
+ ), t_dark in original_text_colors.items():
306
+ # Check if center of OCR box is inside group box
307
+ bcx = (bx1 + bx2) / 2
308
+ bcy = (by1 + by2) / 2
309
+ if (
310
+ bcx >= gx1
311
+ and bcx <= gx2
312
+ and bcy >= gy1
313
+ and bcy <= gy2
314
+ ):
315
+ if t_dark:
316
+ votes_dark += 1
317
+ else:
318
+ votes_light += 1
319
+ if votes_dark > 0 or votes_light > 0:
320
+ group_bg_is_dark = votes_dark >= votes_light
321
+
322
+ # Detected value represents background brightness
323
+ fallback_fill_color = (
324
+ (0, 0, 0)
325
+ if group_bg_is_dark
326
+ else (255, 255, 255)
327
+ )
328
+
329
+ t_type = "Dark" if group_bg_is_dark else "Light"
330
+ f_col = (
331
+ "White"
332
+ if fallback_fill_color == (255, 255, 255)
333
+ else "Black"
334
+ )
335
+ log_message(
336
+ f"OSB Region {i + 1}: Detected {t_type} background. "
337
+ f"Fallback fill: {f_col}.",
338
+ verbose=verbose,
339
+ )
340
+
341
+ # Expanded sampling around the original bbox to find background color
342
+ expansion_px = 2
343
+ sx1 = max(0, ox - expansion_px)
344
+ sy1 = max(0, oy - expansion_px)
345
+ sx2 = min(img_w, ox + ow + expansion_px)
346
+ sy2 = min(img_h, oy + oh + expansion_px)
347
+
348
+ if sx2 > sx1 and sy2 > sy1:
349
+ mask_h, mask_w = sy2 - sy1, sx2 - sx1
350
+ local_mask = np.ones((mask_h, mask_w), dtype=bool)
351
+
352
+ lx0 = max(0, ox0 - sx1)
353
+ ly0 = max(0, oy0 - sy1)
354
+ lx1 = min(mask_w, ox1 - sx1)
355
+ ly1 = min(mask_h, oy1 - sy1)
356
+
357
+ if lx1 > lx0 and ly1 > ly0:
358
+ local_mask[ly0:ly1, lx0:lx1] = False
359
+
360
+ border_pixels = None
361
+ min_border_pixels = 20
362
+ if np.count_nonzero(local_mask) >= min_border_pixels:
363
+ sampling_crop = current_image.crop(
364
+ (sx1, sy1, sx2, sy2)
365
+ )
366
+ crop_np = np.array(sampling_crop.convert("RGB"))
367
+ border_pixels = crop_np[local_mask]
368
+
369
+ if border_pixels is not None and border_pixels.size > 0:
370
+ white_thresh = 250
371
+ black_thresh = 5
372
+ ratio_threshold = 0.95
373
+
374
+ white_ratio = np.mean(
375
+ np.all(border_pixels >= white_thresh, axis=1)
376
+ )
377
+ black_ratio = np.mean(
378
+ np.all(border_pixels <= black_thresh, axis=1)
379
+ )
380
+
381
+ if fallback_fill_color is None:
382
+ fallback_fill_color = (
383
+ (255, 255, 255)
384
+ if white_ratio >= black_ratio
385
+ else (0, 0, 0)
386
+ )
387
+
388
+ force_fill = (
389
+ config.outside_text.force_cv2_inpainting
390
+ )
391
+ should_simple_fill = (
392
+ white_ratio >= ratio_threshold
393
+ or black_ratio >= ratio_threshold
394
+ or force_fill
395
+ )
396
+
397
+ if should_simple_fill:
398
+ fill_color = fallback_fill_color
399
+
400
+ if force_fill and not (
401
+ white_ratio >= ratio_threshold
402
+ or black_ratio >= ratio_threshold
403
+ ):
404
+ log_message(
405
+ "Forcing CV2 fill: defaulting to "
406
+ f"{'white' if fill_color == (255, 255, 255) else 'black'} background",
407
+ verbose=verbose,
408
+ )
409
+ else:
410
+ log_message(
411
+ "Skipping Flux for OSB region: detected pure "
412
+ f"{'white' if fill_color == (255, 255, 255) else 'black'} background",
413
+ verbose=verbose,
414
+ )
415
+
416
+ def apply_simple_fill(color_to_use):
417
+ new_img = current_image.copy()
418
+
419
+ if (
420
+ original_bbox_dict
421
+ and ox1 is not None
422
+ and ox0 is not None
423
+ and oy1 is not None
424
+ and oy0 is not None
425
+ and ox1 > ox0
426
+ and oy1 > oy0
427
+ ):
428
+ # Restricted fill logic: Clip mask to bbox
429
+ region_mask = combined_mask[oy0:oy1, ox0:ox1]
430
+ if not np.any(region_mask):
431
+ return new_img
432
+
433
+ mask_pil = Image.fromarray(
434
+ (region_mask * 255).astype(np.uint8), mode="L"
435
+ )
436
+ patch = Image.new(
437
+ "RGB", (ox1 - ox0, oy1 - oy0), color_to_use
438
+ )
439
+ new_img.paste(patch, (ox0, oy0), mask=mask_pil)
440
+ else:
441
+ # Full mask fill
442
+ mask_pil = Image.fromarray(
443
+ (combined_mask * 255).astype(np.uint8), mode="L"
444
+ )
445
+ patch = Image.new("RGB", new_img.size, color_to_use)
446
+ new_img.paste(patch, (0, 0), mask=mask_pil)
447
+
448
+ return new_img
449
+
450
+ if fill_color is not None:
451
+ current_image = apply_simple_fill(fill_color)
452
+ cv2_inpaints += 1
453
+ continue
454
+
455
+ flux_failed = False
456
+ flux_fail_reason = None
457
+ inpainted_image = None
458
+
459
+ if inpainter is None:
460
+ flux_failed = True
461
+ flux_fail_reason = "Flux inpainter unavailable"
462
+ else:
463
+ try:
464
+ inpainted_image = inpainter.inpaint_mask(
465
+ current_image,
466
+ combined_mask,
467
+ seed=region_seed,
468
+ verbose=verbose,
469
+ strict_mask_clipping=True,
470
+ composite_clip_bbox=composite_clip_bbox,
471
+ )
472
+ if inpainted_image is current_image:
473
+ flux_failed = True
474
+ flux_fail_reason = (
475
+ "Flux returned original image (no inpaint)"
476
+ )
477
+ except Exception as e:
478
+ flux_failed = True
479
+ flux_fail_reason = f"Flux inpainting error: {e}"
480
+
481
+ if flux_failed:
482
+ fallback_color_to_use = (
483
+ fallback_fill_color
484
+ if fallback_fill_color
485
+ else (255, 255, 255)
486
+ )
487
+ log_message(
488
+ f"Flux failed for OSB region {i + 1}"
489
+ + (f" ({flux_fail_reason})" if flux_fail_reason else "")
490
+ + f"; falling back to CV2 fill ({fallback_color_to_use})",
491
+ always_print=True,
492
+ )
493
+ current_image = apply_simple_fill(fallback_color_to_use)
494
+ cv2_inpaints += 1
495
+ continue
496
+
497
+ flux_inpaints += 1
498
+ # Save to disk if more regions remain to reduce memory usage
499
+ if i < len(mask_groups) - 1:
500
+ temp_file = None
501
+ try:
502
+ temp_fd, temp_file = tempfile.mkstemp(suffix=".png")
503
+ os.close(temp_fd)
504
+ inpainted_image.save(temp_file, format="PNG")
505
+ temp_files.append(temp_file)
506
+
507
+ with Image.open(temp_file) as img_tmp:
508
+ img_tmp.load()
509
+ current_image = img_tmp.copy()
510
+
511
+ del inpainted_image
512
+ gc.collect()
513
+ log_message(
514
+ "Saved intermediate inpainting result to disk",
515
+ verbose=verbose,
516
+ )
517
+ except Exception as e:
518
+ log_message(
519
+ "Warning: Failed to save intermediate image to disk: "
520
+ f"{e}. Continuing with in-memory processing.",
521
+ verbose=verbose,
522
+ )
523
+ # Fallback to in-memory if disk save fails
524
+ current_image = inpainted_image
525
+ if temp_file and temp_file in temp_files:
526
+ temp_files.remove(temp_file)
527
+ else:
528
+ current_image = inpainted_image
529
+
530
+ log_message("Outside text inpainting completed", verbose=verbose)
531
+ log_message(
532
+ f"Inpainted {len(mask_groups)} outside text regions (Flux: {flux_inpaints}, CV2: {cv2_inpaints})",
533
+ always_print=True,
534
+ )
535
+ finally:
536
+ for temp_file in temp_files:
537
+ if temp_file and os.path.exists(temp_file):
538
+ try:
539
+ os.remove(temp_file)
540
+ except Exception:
541
+ pass
542
+
543
+ outside_text_data = []
544
+ original_cv_image = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
545
+
546
+ for ocr_result in outside_text_results:
547
+ bbox_coords, conf = ocr_result
548
+ x1, y1, x2, y2 = [int(c) for c in bbox_coords]
549
+ bbox_tuple = (x1, y1, x2, y2)
550
+
551
+ outside_text_image_cv = original_cv_image[y1:y2, x1:x2].copy()
552
+
553
+ outside_text_image_pil = cv2_to_pil(outside_text_image_cv)
554
+
555
+ original_crop_pil = outside_text_image_pil.copy()
556
+
557
+ # Disable upscaling in test_mode
558
+ osb_upscale_method = (
559
+ "none" if config.test_mode else config.translation.upscale_method
560
+ )
561
+
562
+ if osb_upscale_method == "model":
563
+ model_manager = get_model_manager()
564
+ with model_manager.upscale_context() as upscale_model:
565
+ final_text_pil = process_bubble_image_cached(
566
+ outside_text_image_pil,
567
+ upscale_model,
568
+ config.device,
569
+ config.translation.osb_min_side_pixels,
570
+ "min",
571
+ "model",
572
+ verbose,
573
+ )
574
+ elif osb_upscale_method == "model_lite":
575
+ model_manager = get_model_manager()
576
+ with model_manager.upscale_lite_context() as upscale_model:
577
+ final_text_pil = process_bubble_image_cached(
578
+ outside_text_image_pil,
579
+ upscale_model,
580
+ config.device,
581
+ config.translation.osb_min_side_pixels,
582
+ "min",
583
+ "model_lite",
584
+ verbose,
585
+ )
586
+ elif osb_upscale_method == "lanczos":
587
+ w, h = outside_text_image_pil.size
588
+ min_side = min(w, h)
589
+ if min_side < config.translation.osb_min_side_pixels:
590
+ scale_factor = config.translation.osb_min_side_pixels / min_side
591
+ new_w = int(w * scale_factor)
592
+ new_h = int(h * scale_factor)
593
+ resized_text = outside_text_image_pil.resize(
594
+ (new_w, new_h), Image.LANCZOS
595
+ )
596
+ else:
597
+ resized_text = outside_text_image_pil
598
+ final_text_pil = resized_text
599
+ else:
600
+ final_text_pil = outside_text_image_pil
601
+
602
+ outside_text_image_cv = pil_to_cv2(final_text_pil)
603
+
604
+ w = max(1, x2 - x1)
605
+ h = max(1, y2 - y1)
606
+ aspect_ratio = float(h) / float(w)
607
+
608
+ try:
609
+ is_success, buffer = cv2.imencode(cv2_ext, outside_text_image_cv)
610
+ if is_success:
611
+ image_b64 = base64.b64encode(buffer).decode("utf-8")
612
+
613
+ outside_text_data.append(
614
+ {
615
+ "bbox": bbox_tuple,
616
+ "confidence": conf,
617
+ "is_outside_text": True,
618
+ "image_b64": image_b64,
619
+ "mime_type": mime_type,
620
+ "is_dark_text": original_text_colors.get(bbox_tuple, True),
621
+ "aspect_ratio": aspect_ratio,
622
+ "original_crop_pil": original_crop_pil,
623
+ }
624
+ )
625
+ except Exception as e:
626
+ log_message(
627
+ f"Error encoding outside text bbox {(x1, y1, x2, y2)}: {e}",
628
+ verbose=verbose,
629
+ )
630
+
631
+ return current_image, outside_text_data
632
+
633
+ except Exception as e:
634
+ log_message(
635
+ f"Error during outside text detection/inpainting: {e}",
636
+ always_print=True,
637
+ )
638
+ return pil_image, []
core/pipeline.py ADDED
@@ -0,0 +1,1295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import math
3
+ import os
4
+ import time
5
+ from pathlib import Path
6
+ from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
7
+
8
+ import cv2
9
+ from PIL import Image
10
+
11
+ from core.caching import get_cache
12
+ from core.config import MangaTranslatorConfig, PreprocessingConfig, RenderingConfig
13
+ from core.scaling import scale_font_size, scale_length, scale_scalar
14
+ from utils.exceptions import (
15
+ CancellationError,
16
+ CleaningError,
17
+ FontError,
18
+ ImageProcessingError,
19
+ RenderingError,
20
+ TranslationError,
21
+ )
22
+ from utils.logging import log_message
23
+
24
+ from .image.cleaning import clean_speech_bubbles, retry_cleaning_with_otsu
25
+ from .image.detection import detect_panels, detect_speech_bubbles
26
+ from .image.image_utils import (
27
+ convert_image_to_target_mode,
28
+ cv2_to_pil,
29
+ pil_to_cv2,
30
+ resize_to_max_side,
31
+ save_image_with_compression,
32
+ upscale_image,
33
+ upscale_image_to_dimension,
34
+ )
35
+ from .image.sorting import sort_bubbles_by_reading_order
36
+ from .ml.model_manager import get_model_manager
37
+ from .outside_text_processor import process_outside_text
38
+ from .services.translation import (
39
+ call_translation_api_batch,
40
+ prepare_bubble_images_for_translation,
41
+ )
42
+ from .text.text_processing import is_latin_style_language
43
+ from .text.text_renderer import render_text_skia
44
+
45
+ if TYPE_CHECKING:
46
+ from ui.cancellation import CancellationManager
47
+
48
+
49
+ def get_image_encoding_params(pil_image_format: Optional[str]) -> Tuple[str, str]:
50
+ """Returns (mime_type, cv2_ext) for a given PIL image format."""
51
+ if pil_image_format and pil_image_format.upper() == "PNG":
52
+ return "image/png", ".png"
53
+ return "image/jpeg", ".jpg"
54
+
55
+
56
+ def _resolve_pre_upscale_factor(
57
+ pre_cfg: Optional[PreprocessingConfig],
58
+ verbose: bool = False,
59
+ ) -> float:
60
+ if pre_cfg is None or not pre_cfg.enabled:
61
+ return 1.0
62
+
63
+ factor = max(1.0, min(float(pre_cfg.factor or 1.0), 8.0))
64
+ if factor <= 1.01:
65
+ return 1.0
66
+
67
+ log_message(f"Initial upscaling enabled: {factor:.2f}x", verbose=verbose)
68
+ return factor
69
+
70
+
71
+ def _apply_pre_upscale_if_needed(
72
+ image: Image.Image,
73
+ config: MangaTranslatorConfig,
74
+ verbose: bool = False,
75
+ ) -> Tuple[Image.Image, float]:
76
+ factor = _resolve_pre_upscale_factor(
77
+ getattr(config, "preprocessing", None), verbose
78
+ )
79
+ if factor == 1.0:
80
+ return image, 1.0
81
+
82
+ # Use the output upscale model setting for initial upscaling as well
83
+ model_type = (
84
+ getattr(config.output, "image_upscale_model", "model_lite")
85
+ if hasattr(config, "output")
86
+ else "model_lite"
87
+ )
88
+ upscaled = upscale_image(image, factor, model_type=model_type, verbose=verbose)
89
+ return upscaled, factor
90
+
91
+
92
+ def translate_and_render(
93
+ image_path: Union[str, Path],
94
+ config: MangaTranslatorConfig,
95
+ output_path: Optional[Union[str, Path]] = None,
96
+ cancellation_manager: Optional["CancellationManager"] = None,
97
+ ):
98
+ """
99
+ Main function to translate manga speech bubbles and render translations using a config object.
100
+
101
+ Args:
102
+ image_path (str or Path): Path to input image
103
+ config (MangaTranslatorConfig): Configuration object containing all settings.
104
+ output_path (str or Path, optional): Path to save the final image. If None, image is not saved.
105
+
106
+ Returns:
107
+ PIL.Image: Final translated image
108
+ """
109
+ start_time = time.time()
110
+ image_path = Path(image_path)
111
+ verbose = config.verbose
112
+ device = config.device
113
+
114
+ log_message(f"Using device: {device}", verbose=verbose)
115
+
116
+ try:
117
+ pil_original = Image.open(image_path)
118
+ image_format = pil_original.format
119
+ mime_type, cv2_ext = get_image_encoding_params(image_format)
120
+ log_message(
121
+ f"Original image format: {image_format} -> MIME: {mime_type}",
122
+ verbose=verbose,
123
+ )
124
+ except FileNotFoundError:
125
+ log_message(f"Error: Input image not found at {image_path}", always_print=True)
126
+ raise
127
+ except Exception as e:
128
+ log_message(f"Error opening image {image_path}: {e}", always_print=True)
129
+ raise
130
+
131
+ if cancellation_manager and cancellation_manager.is_cancelled():
132
+ raise TranslationError("Process cancelled by user.")
133
+
134
+ desired_format = config.output.output_format
135
+ output_ext_for_mode = (
136
+ Path(output_path).suffix.lower() if output_path else image_path.suffix.lower()
137
+ )
138
+
139
+ if desired_format == "jpeg" or (
140
+ desired_format == "auto" and output_ext_for_mode in [".jpg", ".jpeg"]
141
+ ):
142
+ target_mode = "RGB"
143
+ else: # Default to RGBA for PNG, WEBP, or other formats in auto mode
144
+ target_mode = "RGBA"
145
+ log_message(f"Target mode: {target_mode}", verbose=verbose)
146
+
147
+ pil_image_processed = convert_image_to_target_mode(
148
+ pil_original, target_mode, verbose
149
+ )
150
+ pil_image_processed, _ = _apply_pre_upscale_if_needed(
151
+ pil_image_processed, config, verbose
152
+ )
153
+
154
+ # Check for Upscaling Only Mode (skip detection, cleaning, and translation)
155
+ if config.upscaling_only:
156
+ log_message(
157
+ "Upscaling only mode - skipping detection and translation",
158
+ always_print=True,
159
+ )
160
+ final_image_to_save = pil_image_processed
161
+
162
+ if config.output.upscale_final_image:
163
+ log_message("Upscaling final image...", verbose=verbose, always_print=True)
164
+ final_image_to_save = upscale_image(
165
+ final_image_to_save,
166
+ config.output.image_upscale_factor,
167
+ model_type=config.output.image_upscale_model,
168
+ verbose=verbose,
169
+ )
170
+
171
+ if output_path:
172
+ if final_image_to_save.mode != target_mode:
173
+ log_message(f"Converting final image to {target_mode}", verbose=verbose)
174
+ final_image_to_save = final_image_to_save.convert(target_mode)
175
+
176
+ try:
177
+ save_image_with_compression(
178
+ final_image_to_save,
179
+ output_path,
180
+ jpeg_quality=config.output.jpeg_quality,
181
+ png_compression=config.output.png_compression,
182
+ verbose=verbose,
183
+ )
184
+ except ImageProcessingError as e:
185
+ log_message(f"Failed to save image: {e}", always_print=True)
186
+ raise
187
+
188
+ end_time = time.time()
189
+ processing_time = end_time - start_time
190
+ log_message(
191
+ f"Processing completed in {processing_time:.2f}s", always_print=True
192
+ )
193
+
194
+ return final_image_to_save
195
+
196
+ # Calculate dynamic processing scale based on image area relative to 1MP (if enabled)
197
+ if config.preprocessing.auto_scale:
198
+ width, height = pil_image_processed.size
199
+ processing_scale = math.sqrt((width * height) / 1_000_000)
200
+ log_message(
201
+ f"Dynamic processing scale: {processing_scale:.2f}x", verbose=verbose
202
+ )
203
+ else:
204
+ processing_scale = 1.0
205
+
206
+ get_cache().set_current_image(pil_image_processed, verbose)
207
+
208
+ original_cv_image = pil_to_cv2(pil_image_processed)
209
+
210
+ # Detect speech bubbles first so OSB processing can respect bubble regions
211
+ log_message("Detecting speech bubbles...", verbose=verbose)
212
+ try:
213
+ bubble_data, text_free_boxes = detect_speech_bubbles(
214
+ image_path,
215
+ config.yolo_model_path,
216
+ config.detection.confidence,
217
+ verbose=verbose,
218
+ device=device,
219
+ use_sam2=config.detection.use_sam2,
220
+ conjoined_detection=config.detection.conjoined_detection,
221
+ conjoined_confidence=config.detection.conjoined_confidence,
222
+ image_override=pil_image_processed,
223
+ osb_enabled=config.outside_text.enabled,
224
+ osb_text_verification=config.detection.use_osb_text_verification,
225
+ osb_text_hf_token=config.outside_text.huggingface_token,
226
+ )
227
+ except Exception as e:
228
+ log_message(f"Error during detection: {e}", always_print=True)
229
+ bubble_data = []
230
+ text_free_boxes = []
231
+
232
+ # Process outside text detection and inpainting (bubble-aware)
233
+ pil_image_processed, outside_text_data = process_outside_text(
234
+ pil_image_processed,
235
+ config,
236
+ image_path,
237
+ image_format,
238
+ verbose,
239
+ bubble_data=bubble_data,
240
+ text_free_boxes=text_free_boxes,
241
+ )
242
+ original_cv_image = pil_to_cv2(pil_image_processed)
243
+
244
+ full_image_b64 = None
245
+ full_image_mime_type = None
246
+ if config.translation.send_full_page_context:
247
+ try:
248
+ # processing_scale is intentionally not used for context_image_max_side_pixels
249
+ context_image_pil = cv2_to_pil(original_cv_image)
250
+ effective_context_max_side = scale_length(
251
+ config.translation.context_image_max_side_pixels,
252
+ None,
253
+ minimum=512,
254
+ maximum=4096,
255
+ )
256
+
257
+ # Disable upscaling in test_mode
258
+ context_upscale_method = (
259
+ "none" if config.test_mode else config.translation.upscale_method
260
+ )
261
+
262
+ if context_upscale_method == "model":
263
+ # Use upscaling model for full page context
264
+ model_manager = get_model_manager()
265
+ with model_manager.upscale_context() as upscale_model:
266
+ context_image_pil = upscale_image_to_dimension(
267
+ upscale_model,
268
+ context_image_pil,
269
+ effective_context_max_side,
270
+ config.device,
271
+ "max",
272
+ "model",
273
+ verbose,
274
+ )
275
+ # Resize to exact target dimension (downscale if needed)
276
+ context_image_pil = resize_to_max_side(
277
+ context_image_pil,
278
+ effective_context_max_side,
279
+ verbose=verbose,
280
+ )
281
+ log_message(
282
+ "Upscaled full image for context with model", verbose=verbose
283
+ )
284
+ elif context_upscale_method == "model_lite":
285
+ # Use upscaling lite model for full page context
286
+ model_manager = get_model_manager()
287
+ with model_manager.upscale_lite_context() as upscale_model:
288
+ context_image_pil = upscale_image_to_dimension(
289
+ upscale_model,
290
+ context_image_pil,
291
+ effective_context_max_side,
292
+ config.device,
293
+ "max",
294
+ "model_lite",
295
+ verbose,
296
+ )
297
+ # Resize to exact target dimension (downscale if needed)
298
+ context_image_pil = resize_to_max_side(
299
+ context_image_pil,
300
+ effective_context_max_side,
301
+ verbose=verbose,
302
+ )
303
+ log_message(
304
+ "Upscaled full image for context with lite model",
305
+ verbose=verbose,
306
+ )
307
+ elif context_upscale_method == "lanczos":
308
+ # Use LANCZOS resampling
309
+ context_image_pil = resize_to_max_side(
310
+ context_image_pil,
311
+ effective_context_max_side,
312
+ verbose=verbose,
313
+ )
314
+ log_message(
315
+ "Resized full image for context with LANCZOS", verbose=verbose
316
+ )
317
+ else: # upscale_method == "none"
318
+ # No resizing/upscaling
319
+ log_message(
320
+ "Using full image for context without resizing", verbose=verbose
321
+ )
322
+
323
+ context_image_cv = pil_to_cv2(context_image_pil)
324
+ is_success, buffer = cv2.imencode(cv2_ext, context_image_cv)
325
+ if not is_success:
326
+ raise ImageProcessingError(f"Full image encoding to {cv2_ext} failed")
327
+ full_image_b64 = base64.b64encode(buffer).decode("utf-8")
328
+ full_image_mime_type = mime_type
329
+ log_message("Encoded full image for context", verbose=verbose)
330
+ except Exception as e:
331
+ log_message(
332
+ f"Warning: Failed to encode full image context: {e}", always_print=True
333
+ )
334
+
335
+ if cancellation_manager and cancellation_manager.is_cancelled():
336
+ raise CancellationError("Process cancelled by user.")
337
+
338
+ final_image_to_save = pil_image_processed
339
+
340
+ if not bubble_data and not outside_text_data:
341
+ log_message("No speech bubbles or outside text detected", always_print=True)
342
+ else:
343
+ if bubble_data:
344
+ log_message(f"Detected {len(bubble_data)} bubbles", verbose=verbose)
345
+ if outside_text_data:
346
+ log_message(
347
+ f"Detected {len(outside_text_data)} outside text regions",
348
+ verbose=verbose,
349
+ )
350
+
351
+ if cancellation_manager and cancellation_manager.is_cancelled():
352
+ raise CancellationError("Process cancelled by user.")
353
+
354
+ if bubble_data:
355
+ log_message("Cleaning speech bubbles...", verbose=verbose)
356
+ try:
357
+ use_otsu = config.cleaning.use_otsu_threshold
358
+ if config.cleaning.inpaint_colored_bubbles:
359
+ log_message(
360
+ "Flux inpainting enabled for colored bubbles",
361
+ verbose=verbose,
362
+ )
363
+
364
+ cleaned_image_cv, processed_bubbles_info = clean_speech_bubbles(
365
+ pil_image_processed,
366
+ config.yolo_model_path,
367
+ config.detection.confidence,
368
+ pre_computed_detections=bubble_data,
369
+ device=device,
370
+ thresholding_value=config.cleaning.thresholding_value,
371
+ use_otsu_threshold=use_otsu,
372
+ roi_shrink_px=config.cleaning.roi_shrink_px,
373
+ verbose=verbose,
374
+ processing_scale=processing_scale,
375
+ conjoined_confidence=config.detection.conjoined_confidence,
376
+ inpaint_colored_bubbles=config.cleaning.inpaint_colored_bubbles,
377
+ flux_hf_token=config.outside_text.huggingface_token,
378
+ flux_num_inference_steps=config.outside_text.flux_num_inference_steps,
379
+ flux_residual_diff_threshold=config.outside_text.flux_residual_diff_threshold,
380
+ flux_seed=config.outside_text.seed,
381
+ osb_text_verification=config.detection.use_osb_text_verification,
382
+ osb_text_hf_token=config.outside_text.huggingface_token,
383
+ force_cv2_inpainting=config.outside_text.force_cv2_inpainting,
384
+ )
385
+ except CleaningError as e:
386
+ log_message(f"Cleaning failed: {e}", always_print=True)
387
+ cleaned_image_cv = original_cv_image.copy()
388
+ processed_bubbles_info = []
389
+ except Exception as e:
390
+ log_message(f"Error during cleaning: {e}", always_print=True)
391
+ cleaned_image_cv = original_cv_image.copy()
392
+ processed_bubbles_info = []
393
+
394
+ pil_cleaned_image = cv2_to_pil(cleaned_image_cv)
395
+ if pil_cleaned_image.mode != target_mode:
396
+ log_message(
397
+ f"Converting cleaned image to {target_mode}", verbose=verbose
398
+ )
399
+ pil_cleaned_image = pil_cleaned_image.convert(target_mode)
400
+ final_image_to_save = pil_cleaned_image
401
+ else:
402
+ processed_bubbles_info = []
403
+ pil_cleaned_image = pil_image_processed
404
+ if pil_cleaned_image.mode != target_mode:
405
+ log_message(f"Converting image to {target_mode}", verbose=verbose)
406
+ pil_cleaned_image = pil_cleaned_image.convert(target_mode)
407
+ final_image_to_save = pil_cleaned_image
408
+
409
+ # Check for Cleaning Only Mode
410
+ if config.cleaning_only:
411
+ log_message("Cleaning only mode - skipping translation", always_print=True)
412
+ else:
413
+ main_min_font = scale_font_size(
414
+ config.rendering.min_font_size, processing_scale, minimum=4, maximum=256
415
+ )
416
+ main_max_font = scale_font_size(
417
+ config.rendering.max_font_size,
418
+ processing_scale,
419
+ minimum=main_min_font,
420
+ maximum=384,
421
+ )
422
+ padding_pixels = scale_scalar(
423
+ config.rendering.padding_pixels,
424
+ processing_scale,
425
+ minimum=1.0,
426
+ maximum=80.0,
427
+ )
428
+ osb_min_font = scale_font_size(
429
+ config.outside_text.osb_min_font_size,
430
+ processing_scale,
431
+ minimum=4,
432
+ maximum=512,
433
+ )
434
+ osb_max_font = scale_font_size(
435
+ config.outside_text.osb_max_font_size,
436
+ processing_scale,
437
+ minimum=osb_min_font,
438
+ maximum=640,
439
+ )
440
+ osb_outline_width = scale_scalar(
441
+ config.outside_text.osb_outline_width,
442
+ processing_scale,
443
+ minimum=0.0,
444
+ maximum=24.0,
445
+ )
446
+ # Prepare images for Translation
447
+ log_message("Preparing bubble images...", verbose=verbose)
448
+
449
+ # Disable upscaling in test_mode
450
+ bubble_upscale_method = (
451
+ "none" if config.test_mode else config.translation.upscale_method
452
+ )
453
+
454
+ model_manager = get_model_manager()
455
+ # Use appropriate context manager based on upscale_method
456
+ if bubble_upscale_method == "model":
457
+ context_manager = model_manager.upscale_context()
458
+ elif bubble_upscale_method == "model_lite":
459
+ context_manager = model_manager.upscale_lite_context()
460
+ else:
461
+ # For lanczos/none, create a dummy context manager that yields None
462
+ from contextlib import nullcontext
463
+
464
+ context_manager = nullcontext(None)
465
+
466
+ with context_manager as upscale_model:
467
+ bubble_data = prepare_bubble_images_for_translation(
468
+ bubble_data,
469
+ original_cv_image,
470
+ upscale_model,
471
+ config.device,
472
+ mime_type,
473
+ config.translation.bubble_min_side_pixels,
474
+ bubble_upscale_method,
475
+ verbose,
476
+ )
477
+
478
+ if bubble_upscale_method != "none":
479
+ log_message(
480
+ f"Upscaled {len(bubble_data)} bubble images for translation",
481
+ always_print=True,
482
+ )
483
+ else:
484
+ log_message(
485
+ f"Prepared {len(bubble_data)} bubble images for translation",
486
+ always_print=True,
487
+ )
488
+ valid_bubble_data = [b for b in bubble_data if b.get("image_b64")]
489
+ if not valid_bubble_data and not outside_text_data:
490
+ log_message(
491
+ "No valid bubble images or outside text for translation",
492
+ always_print=True,
493
+ )
494
+ else: # Proceed if we have valid bubble data or outside text
495
+ if cancellation_manager and cancellation_manager.is_cancelled():
496
+ raise CancellationError("Process cancelled by user.")
497
+
498
+ # Sort and Translate
499
+ reading_direction = config.translation.reading_direction
500
+ # Merge outside text data with speech bubbles for reading order calculation
501
+ if outside_text_data:
502
+ log_message(
503
+ f"Including {len(outside_text_data)} outside text regions in reading order calculation",
504
+ verbose=verbose,
505
+ )
506
+ # Combine speech bubbles and OSB text for unified reading order sorting
507
+ all_text_data = valid_bubble_data + outside_text_data
508
+ else:
509
+ all_text_data = valid_bubble_data
510
+
511
+ log_message(
512
+ f"Sorting all text elements ({reading_direction.upper()})",
513
+ verbose=verbose,
514
+ )
515
+
516
+ # Detect panels if panel-aware sorting is enabled
517
+ panels = None
518
+ if config.detection.use_panel_sorting:
519
+ try:
520
+ log_message(
521
+ "Detecting panels for panel-aware sorting...",
522
+ verbose=verbose,
523
+ )
524
+ panels = detect_panels(
525
+ image_path,
526
+ confidence=config.detection.panel_confidence,
527
+ device=config.device,
528
+ verbose=verbose,
529
+ )
530
+ if panels:
531
+ log_message(
532
+ f"Detected {len(panels)} panels for sorting",
533
+ always_print=True,
534
+ )
535
+ else:
536
+ log_message(
537
+ "No panels detected, using global sorting",
538
+ verbose=verbose,
539
+ )
540
+ except Exception as e:
541
+ log_message(
542
+ f"Panel detection failed: {e}. Using global sorting.",
543
+ always_print=True,
544
+ )
545
+ panels = None
546
+
547
+ # Sort all text elements (speech bubbles + OSB text) by reading order
548
+ sorted_bubble_data = sort_bubbles_by_reading_order(
549
+ all_text_data, reading_direction, panels=panels
550
+ )
551
+
552
+ bubble_images_b64 = [
553
+ bubble["image_b64"]
554
+ for bubble in sorted_bubble_data
555
+ if "image_b64" in bubble
556
+ ]
557
+ bubble_mime_types = [
558
+ bubble["mime_type"]
559
+ for bubble in sorted_bubble_data
560
+ if "image_b64" in bubble and "mime_type" in bubble
561
+ ]
562
+ translated_texts = []
563
+ if not bubble_images_b64:
564
+ log_message("No valid bubbles after sorting", always_print=True)
565
+ else:
566
+ if getattr(config, "test_mode", False):
567
+ placeholder_long = "Lorem **ipsum** *dolor* sit amet, consectetur adipiscing elit."
568
+ placeholder_short = "Lorem **ipsum** *dolor* sit amet..."
569
+ placeholder_osb = "Lorem"
570
+ log_message(
571
+ f"Test mode: generating placeholders for {len(sorted_bubble_data)} bubbles",
572
+ always_print=True,
573
+ )
574
+ # Map for rendering info used in probe
575
+ bubble_render_info_map_probe = {
576
+ tuple(info["bbox"]): {
577
+ "color": info["color"],
578
+ "mask": info.get("mask"),
579
+ }
580
+ for info in processed_bubbles_info
581
+ if "bbox" in info and "color" in info and "mask" in info
582
+ }
583
+ for i, bubble in enumerate(sorted_bubble_data):
584
+ bbox = bubble["bbox"]
585
+ is_outside_text = bubble.get("is_outside_text", False)
586
+
587
+ # Use simple "Lorem ipsum" for OSB text in test mode
588
+ if is_outside_text:
589
+ translated_texts.append(placeholder_osb)
590
+ continue
591
+
592
+ probe_info = bubble_render_info_map_probe.get(
593
+ tuple(bbox), {}
594
+ )
595
+ bubble_color_bgr = probe_info.get("color", (255, 255, 255))
596
+ cleaned_mask = probe_info.get("mask")
597
+ # Probe fit at max size without mutating the working image
598
+ _probe_canvas = pil_cleaned_image.copy()
599
+ probe_config = RenderingConfig(
600
+ min_font_size=main_max_font,
601
+ max_font_size=main_max_font,
602
+ line_spacing_mult=config.rendering.line_spacing_mult,
603
+ use_subpixel_rendering=config.rendering.use_subpixel_rendering,
604
+ font_hinting=config.rendering.font_hinting,
605
+ use_ligatures=config.rendering.use_ligatures,
606
+ hyphenate_before_scaling=config.rendering.hyphenate_before_scaling,
607
+ hyphen_penalty=config.rendering.hyphen_penalty,
608
+ hyphenation_min_word_length=config.rendering.hyphenation_min_word_length,
609
+ badness_exponent=config.rendering.badness_exponent,
610
+ padding_pixels=padding_pixels,
611
+ supersampling_factor=1, # No supersampling for probe
612
+ )
613
+ try:
614
+ _ = render_text_skia(
615
+ pil_image=_probe_canvas,
616
+ text=placeholder_long,
617
+ bbox=bbox,
618
+ font_dir=config.rendering.font_dir,
619
+ cleaned_mask=cleaned_mask,
620
+ bubble_color_bgr=bubble_color_bgr,
621
+ config=probe_config,
622
+ verbose=verbose,
623
+ bubble_id=str(i + 1),
624
+ )
625
+ fits = True
626
+ except (RenderingError, FontError) as e:
627
+ log_message(
628
+ f"Probe rendering failed: {e}", verbose=verbose
629
+ )
630
+ fits = False
631
+ except Exception as e:
632
+ log_message(
633
+ f"Probe rendering unexpected error: {e}",
634
+ always_print=True,
635
+ )
636
+ fits = False
637
+ translated_texts.append(
638
+ placeholder_long if fits else placeholder_short
639
+ )
640
+ else:
641
+ log_message(
642
+ f"Translating {len(bubble_images_b64)} bubbles: "
643
+ f"{config.translation.input_language} → {config.translation.output_language}",
644
+ always_print=True,
645
+ )
646
+ try:
647
+ translated_texts = call_translation_api_batch(
648
+ config=config.translation,
649
+ images_b64=bubble_images_b64,
650
+ full_image_b64=full_image_b64 or "",
651
+ mime_types=bubble_mime_types,
652
+ full_image_mime_type=full_image_mime_type
653
+ or "image/jpeg",
654
+ bubble_metadata=sorted_bubble_data,
655
+ debug=verbose,
656
+ )
657
+ except TranslationError as e:
658
+ error_str = str(e).lower()
659
+ critical_tokens = (
660
+ "429",
661
+ "rate limit",
662
+ "rate-limit",
663
+ "auth",
664
+ "unauthorized",
665
+ "forbidden",
666
+ "payment",
667
+ "quota",
668
+ "empty response",
669
+ "api failed",
670
+ )
671
+ if any(token in error_str for token in critical_tokens):
672
+ raise
673
+
674
+ log_message(f"Translation failed: {e}", always_print=True)
675
+ translated_texts = [f"[Translation Error: {e}]"] * len(
676
+ bubble_images_b64
677
+ )
678
+ except Exception as e:
679
+ log_message(
680
+ f"Translation API error: {e}", always_print=True
681
+ )
682
+ translated_texts = [
683
+ "[Translation Error: API call raised exception]"
684
+ for _ in sorted_bubble_data
685
+ ]
686
+
687
+ valid_translations = [
688
+ t
689
+ for t in translated_texts
690
+ if t
691
+ and not t.startswith("[Translation Error")
692
+ and not t.startswith("API Error")
693
+ and t.strip()
694
+ not in {
695
+ "[OCR FAILED]",
696
+ "[Empty response / no content]",
697
+ f"[{config.translation.provider}: API call failed/blocked]",
698
+ f"[{config.translation.provider}: OCR call failed/blocked]",
699
+ f"[{config.translation.provider}: Failed to parse response]",
700
+ }
701
+ ]
702
+
703
+ if bubble_images_b64 and not valid_translations:
704
+ raise TranslationError(
705
+ "Total translation failure: All bubbles failed."
706
+ )
707
+
708
+ # Render Translations
709
+ bubble_render_info_map = {
710
+ tuple(info["bbox"]): {
711
+ "color": info["color"],
712
+ "mask": info.get("mask"),
713
+ "base_mask": info.get("base_mask"),
714
+ "is_sam": info.get("is_sam", False),
715
+ "is_colored": info.get("is_colored", False),
716
+ "text_bbox": info.get("text_bbox"),
717
+ }
718
+ for info in processed_bubbles_info
719
+ if "bbox" in info and "color" in info and "mask" in info
720
+ }
721
+ log_message("Rendering translations...", verbose=verbose)
722
+ if len(translated_texts) == len(sorted_bubble_data):
723
+ for i, bubble in enumerate(sorted_bubble_data):
724
+ bubble["translation"] = translated_texts[i]
725
+ bbox = bubble["bbox"]
726
+ text = bubble.get("translation", "")
727
+ is_outside_text = bubble.get("is_outside_text", False)
728
+
729
+ # Convert OSB text to uppercase
730
+ if is_outside_text and text:
731
+ text = text.upper()
732
+ bubble["translation"] = text
733
+
734
+ if (
735
+ not text
736
+ or text.startswith("API Error")
737
+ or text.startswith("[Translation Error]")
738
+ or text.startswith("[Translation Error:")
739
+ or text.strip()
740
+ in {
741
+ "[OCR FAILED]",
742
+ "[Empty response / no content]",
743
+ f"[{config.translation.provider}: API call failed/blocked]",
744
+ f"[{config.translation.provider}: OCR call failed/blocked]",
745
+ f"[{config.translation.provider}: Failed to parse response]",
746
+ }
747
+ ):
748
+ entry_type = "outside text" if is_outside_text else "bubble"
749
+ log_message(
750
+ f"Skipping {entry_type} {bbox} - invalid translation",
751
+ verbose=verbose,
752
+ )
753
+ continue
754
+
755
+ # Use OSB-specific settings for outside text, regular settings for speech bubbles
756
+ if is_outside_text:
757
+ log_message(
758
+ f"Rendering outside text {bbox}: '{text[:30]}...'",
759
+ verbose=verbose,
760
+ )
761
+ font_dir = (
762
+ config.outside_text.osb_font_name
763
+ if config.outside_text.osb_font_name
764
+ else config.rendering.font_dir
765
+ )
766
+ min_font = osb_min_font
767
+ max_font = osb_max_font
768
+ line_spacing = config.outside_text.osb_line_spacing
769
+ use_ligs = config.outside_text.osb_use_ligatures
770
+ # Outside text was inpainted, no mask needed
771
+ cleaned_mask = None
772
+ # Use the detected text color from outside_text_processor
773
+ is_dark_text = bubble.get("is_dark_text", True)
774
+ # Set bubble_color_bgr to mimic the original text color
775
+ # Dark text → dark background value → white rendering
776
+ # Light text → light background value → black rendering
777
+ bubble_color_bgr = (
778
+ (50, 50, 50) if is_dark_text else (255, 255, 255)
779
+ )
780
+ # OSB renders default to horizontal; vertical stacking is fallback-only
781
+ rotation_deg = 0.0
782
+ vertical_stack = False
783
+ else:
784
+ log_message(
785
+ f"Rendering bubble {bbox}: '{text[:30]}...'",
786
+ verbose=verbose,
787
+ )
788
+ font_dir = config.rendering.font_dir
789
+ min_font = main_min_font
790
+ max_font = main_max_font
791
+ line_spacing = config.rendering.line_spacing_mult
792
+ use_ligs = config.rendering.use_ligatures
793
+ render_info = bubble_render_info_map.get(tuple(bbox))
794
+ bubble_color_bgr = (255, 255, 255)
795
+ cleaned_mask = None
796
+ base_mask = None
797
+ is_sam_mask = False
798
+ if render_info:
799
+ bubble_color_bgr = render_info["color"]
800
+ cleaned_mask = render_info.get("mask")
801
+ base_mask = render_info.get("base_mask")
802
+ is_sam_mask = render_info.get("is_sam", False)
803
+ # No rotation/stacking for regular bubbles
804
+ vertical_stack = False
805
+ rotation_deg = 0.0
806
+
807
+ # Only apply hyphenation for Latin-style languages
808
+ should_hyphenate = config.rendering.hyphenate_before_scaling
809
+ if not is_latin_style_language(
810
+ config.translation.output_language
811
+ ):
812
+ should_hyphenate = False
813
+
814
+ render_config = RenderingConfig(
815
+ min_font_size=min_font,
816
+ max_font_size=max_font,
817
+ line_spacing_mult=line_spacing,
818
+ use_subpixel_rendering=(
819
+ config.outside_text.osb_use_subpixel_rendering
820
+ if is_outside_text
821
+ else config.rendering.use_subpixel_rendering
822
+ ),
823
+ font_hinting=(
824
+ config.outside_text.osb_font_hinting
825
+ if is_outside_text
826
+ else config.rendering.font_hinting
827
+ ),
828
+ use_ligatures=use_ligs,
829
+ hyphenate_before_scaling=should_hyphenate,
830
+ hyphen_penalty=config.rendering.hyphen_penalty,
831
+ hyphenation_min_word_length=config.rendering.hyphenation_min_word_length,
832
+ badness_exponent=config.rendering.badness_exponent,
833
+ padding_pixels=padding_pixels,
834
+ outline_width=(
835
+ osb_outline_width if is_outside_text else 0.0
836
+ ),
837
+ supersampling_factor=config.rendering.supersampling_factor,
838
+ )
839
+ success = False
840
+ if is_outside_text:
841
+ try:
842
+ rendered_image = render_text_skia(
843
+ pil_image=pil_cleaned_image,
844
+ text=text,
845
+ bbox=bbox,
846
+ font_dir=font_dir,
847
+ cleaned_mask=cleaned_mask,
848
+ bubble_color_bgr=bubble_color_bgr,
849
+ config=render_config,
850
+ verbose=verbose,
851
+ bubble_id=str(i + 1),
852
+ rotation_deg=rotation_deg,
853
+ vertical_stack=vertical_stack,
854
+ raise_on_safe_error=False,
855
+ )
856
+ success = True
857
+ except Exception as e:
858
+ log_message(
859
+ f"Text rendering failed: {e}", verbose=verbose
860
+ )
861
+ rendered_image = pil_cleaned_image
862
+ success = False
863
+
864
+ # Absolute last-chance fallback: force vertical stacking before giving up
865
+ if not vertical_stack:
866
+ # Fallback uses neutral rotation since we no longer track orientation
867
+ forced_stack_rotation = 0.0
868
+ try:
869
+ log_message(
870
+ "OSB render failed, retrying with vertical-stack fallback",
871
+ verbose=verbose,
872
+ always_print=True,
873
+ )
874
+ rendered_image = render_text_skia(
875
+ pil_image=pil_cleaned_image,
876
+ text=text,
877
+ bbox=bbox,
878
+ font_dir=font_dir,
879
+ cleaned_mask=cleaned_mask,
880
+ bubble_color_bgr=bubble_color_bgr,
881
+ config=render_config,
882
+ verbose=verbose,
883
+ bubble_id=str(i + 1),
884
+ rotation_deg=forced_stack_rotation,
885
+ vertical_stack=True,
886
+ raise_on_safe_error=False,
887
+ )
888
+ log_message(
889
+ "Vertical-stack fallback succeeded",
890
+ verbose=verbose,
891
+ )
892
+ success = True
893
+ except Exception as e2:
894
+ log_message(
895
+ f"Vertical-stack fallback failed: {e2}",
896
+ verbose=verbose,
897
+ )
898
+ # Restore original OSB patch if available
899
+ if "original_crop_pil" in bubble:
900
+ log_message(
901
+ f"Restoring original OSB patch for {bbox}",
902
+ verbose=verbose,
903
+ always_print=True,
904
+ )
905
+ rendered_image = pil_cleaned_image.copy()
906
+ original_patch = bubble["original_crop_pil"]
907
+ rendered_image.paste(
908
+ original_patch, (bbox[0], bbox[1])
909
+ )
910
+ success = True
911
+ else:
912
+ rendered_image = pil_cleaned_image
913
+ success = False
914
+ else:
915
+ if "original_crop_pil" in bubble:
916
+ log_message(
917
+ f"Restoring original OSB patch for {bbox}",
918
+ verbose=verbose,
919
+ always_print=True,
920
+ )
921
+ rendered_image = pil_cleaned_image.copy()
922
+ original_patch = bubble["original_crop_pil"]
923
+ rendered_image.paste(
924
+ original_patch, (bbox[0], bbox[1])
925
+ )
926
+ success = True
927
+ else:
928
+ rendered_image = pil_cleaned_image
929
+ success = False
930
+ else:
931
+ try:
932
+ rendered_image = render_text_skia(
933
+ pil_image=pil_cleaned_image,
934
+ text=text,
935
+ bbox=bbox,
936
+ font_dir=font_dir,
937
+ cleaned_mask=cleaned_mask,
938
+ bubble_color_bgr=bubble_color_bgr,
939
+ config=render_config,
940
+ verbose=verbose,
941
+ bubble_id=str(i + 1),
942
+ rotation_deg=rotation_deg,
943
+ vertical_stack=vertical_stack,
944
+ raise_on_safe_error=True,
945
+ )
946
+ success = True
947
+ except ImageProcessingError as e:
948
+ safe_area_failed = (
949
+ "Safe area calculation failed" in str(e)
950
+ )
951
+ retry_result = None
952
+ if safe_area_failed and base_mask is not None:
953
+ log_message(
954
+ f"Safe area failed for bubble {bbox}, retrying mask with Otsu",
955
+ verbose=verbose,
956
+ always_print=True,
957
+ )
958
+ retry_result = retry_cleaning_with_otsu(
959
+ original_cv_image,
960
+ {
961
+ "base_mask": base_mask,
962
+ "bbox": bbox,
963
+ "is_sam": is_sam_mask,
964
+ "is_colored": (
965
+ render_info.get("is_colored", False)
966
+ if render_info
967
+ else False
968
+ ),
969
+ "text_bbox": (
970
+ render_info.get("text_bbox")
971
+ if render_info
972
+ else None
973
+ ),
974
+ },
975
+ config.cleaning.thresholding_value,
976
+ config.cleaning.roi_shrink_px,
977
+ processing_scale,
978
+ verbose=verbose,
979
+ classify_colored=(
980
+ config.cleaning.inpaint_colored_bubbles
981
+ ),
982
+ )
983
+
984
+ if (
985
+ retry_result
986
+ and retry_result.get("mask") is not None
987
+ ):
988
+ cleaned_mask = retry_result["mask"]
989
+ bubble_color_bgr = retry_result.get(
990
+ "color", bubble_color_bgr
991
+ )
992
+ base_mask = retry_result.get("base_mask", base_mask)
993
+ if render_info is not None:
994
+ render_info.update(
995
+ {
996
+ "mask": cleaned_mask,
997
+ "color": bubble_color_bgr,
998
+ "base_mask": base_mask,
999
+ "is_colored": retry_result.get(
1000
+ "is_colored",
1001
+ render_info.get(
1002
+ "is_colored", False
1003
+ ),
1004
+ ),
1005
+ "text_bbox": retry_result.get(
1006
+ "text_bbox",
1007
+ render_info.get("text_bbox"),
1008
+ ),
1009
+ }
1010
+ )
1011
+
1012
+ try:
1013
+ rendered_image = render_text_skia(
1014
+ pil_image=pil_cleaned_image,
1015
+ text=text,
1016
+ bbox=bbox,
1017
+ font_dir=font_dir,
1018
+ cleaned_mask=cleaned_mask,
1019
+ bubble_color_bgr=bubble_color_bgr,
1020
+ config=render_config,
1021
+ verbose=verbose,
1022
+ bubble_id=str(i + 1),
1023
+ rotation_deg=rotation_deg,
1024
+ vertical_stack=vertical_stack,
1025
+ raise_on_safe_error=False,
1026
+ )
1027
+ success = True
1028
+ except (
1029
+ RenderingError,
1030
+ FontError,
1031
+ ImageProcessingError,
1032
+ ) as e2:
1033
+ log_message(
1034
+ f"Text rendering failed after Otsu retry: {e2}",
1035
+ verbose=verbose,
1036
+ )
1037
+ rendered_image = pil_cleaned_image
1038
+ success = False
1039
+ if not success:
1040
+ # Final fallback to padded bbox path
1041
+ fallback_msg = (
1042
+ f"Safe area calculation failed for {bbox}, using padded bbox fallback"
1043
+ if safe_area_failed
1044
+ else f"Rendering retry fallback for {bbox}, using padded bbox method"
1045
+ )
1046
+ log_message(
1047
+ fallback_msg,
1048
+ verbose=verbose,
1049
+ )
1050
+ try:
1051
+ rendered_image = render_text_skia(
1052
+ pil_image=pil_cleaned_image,
1053
+ text=text,
1054
+ bbox=bbox,
1055
+ font_dir=font_dir,
1056
+ cleaned_mask=cleaned_mask,
1057
+ bubble_color_bgr=bubble_color_bgr,
1058
+ config=render_config,
1059
+ verbose=verbose,
1060
+ bubble_id=str(i + 1),
1061
+ rotation_deg=rotation_deg,
1062
+ vertical_stack=vertical_stack,
1063
+ raise_on_safe_error=False,
1064
+ )
1065
+ success = True
1066
+ except (RenderingError, FontError) as e2:
1067
+ log_message(
1068
+ f"Text rendering failed: {e2}",
1069
+ verbose=verbose,
1070
+ )
1071
+ rendered_image = pil_cleaned_image
1072
+ success = False
1073
+ except (RenderingError, FontError) as e:
1074
+ log_message(
1075
+ f"Text rendering failed: {e}", verbose=verbose
1076
+ )
1077
+ rendered_image = pil_cleaned_image
1078
+ success = False
1079
+
1080
+ if success:
1081
+ pil_cleaned_image = rendered_image
1082
+ final_image_to_save = pil_cleaned_image
1083
+ else:
1084
+ log_message(
1085
+ f"Failed to render bubble {bbox}", verbose=verbose
1086
+ )
1087
+ else:
1088
+ log_message(
1089
+ f"Warning: Bubble/translation count mismatch "
1090
+ f"({len(sorted_bubble_data)}/{len(translated_texts)})",
1091
+ always_print=True,
1092
+ )
1093
+
1094
+ # Final Image Upscaling (optional)
1095
+ if config.output.upscale_final_image:
1096
+ log_message("Upscaling final image...", verbose=verbose, always_print=True)
1097
+ final_image_to_save = upscale_image(
1098
+ final_image_to_save,
1099
+ config.output.image_upscale_factor,
1100
+ model_type=config.output.image_upscale_model,
1101
+ verbose=verbose,
1102
+ )
1103
+
1104
+ # Save Output
1105
+ if output_path:
1106
+ if final_image_to_save.mode != target_mode:
1107
+ log_message(f"Converting final image to {target_mode}", verbose=verbose)
1108
+ final_image_to_save = final_image_to_save.convert(target_mode)
1109
+
1110
+ try:
1111
+ save_image_with_compression(
1112
+ final_image_to_save,
1113
+ output_path,
1114
+ jpeg_quality=config.output.jpeg_quality,
1115
+ png_compression=config.output.png_compression,
1116
+ verbose=verbose,
1117
+ )
1118
+ except ImageProcessingError as e:
1119
+ log_message(f"Failed to save image: {e}", always_print=True)
1120
+ raise
1121
+
1122
+ end_time = time.time()
1123
+ processing_time = end_time - start_time
1124
+ log_message(f"Processing completed in {processing_time:.2f}s", always_print=True)
1125
+
1126
+ return final_image_to_save
1127
+
1128
+
1129
+ def batch_translate_images(
1130
+ input_dir: Union[str, Path],
1131
+ config: MangaTranslatorConfig,
1132
+ output_dir: Optional[Union[str, Path]] = None,
1133
+ progress_callback: Optional[Callable[[float, str], None]] = None,
1134
+ preserve_structure: bool = False,
1135
+ cancellation_manager: Optional["CancellationManager"] = None,
1136
+ ) -> Dict[str, Any]:
1137
+ """
1138
+ Process all images in a directory using a configuration object.
1139
+
1140
+ Args:
1141
+ input_dir (str or Path): Directory containing images to process
1142
+ config (MangaTranslatorConfig): Configuration object containing all settings.
1143
+ output_dir (str or Path, optional): Directory to save translated images.
1144
+ If None, uses input_dir / "output_translated".
1145
+ progress_callback (callable, optional): Function to call with progress updates (0.0-1.0, message).
1146
+ preserve_structure (bool): If True, recursively process subdirectories and preserve folder structure
1147
+ in the output. If False, only processes files in the root directory.
1148
+
1149
+ Returns:
1150
+ dict: Processing results with keys:
1151
+ - "success_count": Number of successfully processed images
1152
+ - "error_count": Number of images that failed to process
1153
+ - "errors": Dictionary mapping filenames to error messages
1154
+ """
1155
+ input_dir = Path(input_dir)
1156
+ if not input_dir.is_dir():
1157
+ log_message(f"Input path '{input_dir}' is not a directory", always_print=True)
1158
+ return {"success_count": 0, "error_count": 0, "errors": {}}
1159
+
1160
+ if output_dir:
1161
+ output_dir = Path(output_dir)
1162
+ else:
1163
+ timestamp = time.strftime("%Y%m%d_%H%M%S")
1164
+ output_dir = Path("./output") / timestamp
1165
+
1166
+ os.makedirs(output_dir, exist_ok=True)
1167
+
1168
+ image_extensions = [".jpg", ".jpeg", ".png", ".webp"]
1169
+
1170
+ if preserve_structure:
1171
+ # Recursively find all image files preserving directory structure
1172
+ image_files = []
1173
+ for root, dirs, files in os.walk(input_dir):
1174
+ for file in files:
1175
+ file_path = Path(root) / file
1176
+ if file_path.suffix.lower() in image_extensions:
1177
+ image_files.append(file_path)
1178
+ else:
1179
+ image_files = [
1180
+ f
1181
+ for f in input_dir.iterdir()
1182
+ if f.is_file() and f.suffix.lower() in image_extensions
1183
+ ]
1184
+
1185
+ if not image_files:
1186
+ log_message(f"No image files found in '{input_dir}'", always_print=True)
1187
+ return {"success_count": 0, "error_count": 0, "errors": {}}
1188
+
1189
+ results = {"success_count": 0, "error_count": 0, "errors": {}}
1190
+
1191
+ total_images = len(image_files)
1192
+ start_batch_time = time.time()
1193
+
1194
+ log_message(f"Starting batch processing: {total_images} images", always_print=True)
1195
+
1196
+ if progress_callback:
1197
+ progress_callback(0.0, f"Starting batch processing of {total_images} images...")
1198
+
1199
+ for i, img_path in enumerate(image_files):
1200
+ try:
1201
+ # Calculate relative path from input directory for structure preservation
1202
+ if preserve_structure:
1203
+ relative_path = img_path.relative_to(input_dir)
1204
+ # Create output subdirectory structure
1205
+ output_subdir = output_dir / relative_path.parent
1206
+ os.makedirs(output_subdir, exist_ok=True)
1207
+ # Use relative path for output filename
1208
+ output_filename = f"{relative_path.stem}_translated"
1209
+ display_path = str(relative_path)
1210
+ error_key = str(relative_path)
1211
+ else:
1212
+ output_subdir = output_dir
1213
+ output_filename = f"{img_path.stem}_translated"
1214
+ display_path = img_path.name
1215
+ error_key = img_path.name
1216
+
1217
+ if cancellation_manager and cancellation_manager.is_cancelled():
1218
+ raise CancellationError("Batch process cancelled by user.")
1219
+
1220
+ if progress_callback:
1221
+ current_progress = i / total_images
1222
+ progress_callback(
1223
+ current_progress,
1224
+ f"Processing image {i + 1}/{total_images}: {display_path}",
1225
+ )
1226
+
1227
+ original_ext = img_path.suffix.lower()
1228
+ desired_format = config.output.output_format
1229
+ if desired_format == "jpeg":
1230
+ output_ext = ".jpg"
1231
+ elif desired_format == "png":
1232
+ output_ext = ".png"
1233
+ elif desired_format == "auto":
1234
+ output_ext = original_ext
1235
+ else:
1236
+ output_ext = original_ext
1237
+ log_message(
1238
+ f"Warning: Invalid output_format '{desired_format}' in config. "
1239
+ f"Using original extension '{original_ext}'.",
1240
+ always_print=True,
1241
+ )
1242
+
1243
+ output_path = output_subdir / f"{output_filename}{output_ext}"
1244
+ log_message(
1245
+ f"Processing {i + 1}/{total_images}: {display_path}", always_print=True
1246
+ )
1247
+
1248
+ translate_and_render(
1249
+ img_path, config, output_path, cancellation_manager=cancellation_manager
1250
+ )
1251
+
1252
+ results["success_count"] += 1
1253
+
1254
+ if progress_callback:
1255
+ completed_progress = (i + 1) / total_images
1256
+ progress_callback(
1257
+ completed_progress, f"Completed {i + 1}/{total_images} images"
1258
+ )
1259
+
1260
+ except CancellationError:
1261
+ log_message(
1262
+ f"Batch cancelled during processing of {display_path}",
1263
+ verbose=config.verbose,
1264
+ )
1265
+ raise
1266
+ except Exception as e:
1267
+ log_message(f"Error processing {display_path}: {str(e)}", always_print=True)
1268
+ results["error_count"] += 1
1269
+ results["errors"][error_key] = str(e)
1270
+
1271
+ if progress_callback:
1272
+ completed_progress = (i + 1) / total_images
1273
+ progress_callback(
1274
+ completed_progress,
1275
+ f"Completed {i + 1}/{total_images} images (with errors)",
1276
+ )
1277
+
1278
+ if progress_callback:
1279
+ progress_callback(1.0, "Processing complete")
1280
+
1281
+ end_batch_time = time.time()
1282
+ total_batch_time = end_batch_time - start_batch_time
1283
+ seconds_per_image = total_batch_time / total_images if total_images > 0 else 0
1284
+
1285
+ log_message(
1286
+ f"Batch complete: {results['success_count']}/{total_images} images in "
1287
+ f"{total_batch_time:.2f}s ({seconds_per_image:.2f}s/image)",
1288
+ always_print=True,
1289
+ )
1290
+ if results["error_count"] > 0:
1291
+ log_message(f"Failed: {results['error_count']} images", always_print=True)
1292
+ for filename, error_msg in results["errors"].items():
1293
+ log_message(f" - {filename}: {error_msg}", always_print=True)
1294
+
1295
+ return results
core/scaling.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple
2
+
3
+
4
+ def _normalize_scale(scale: Optional[float]) -> float:
5
+ if scale is None or scale <= 0:
6
+ return 1.0
7
+ return float(scale)
8
+
9
+
10
+ def _clamp(value: float, minimum: Optional[float], maximum: Optional[float]) -> float:
11
+ if minimum is not None:
12
+ value = max(minimum, value)
13
+ if maximum is not None:
14
+ value = min(maximum, value)
15
+ return value
16
+
17
+
18
+ def scale_scalar(
19
+ value: float,
20
+ scale: Optional[float],
21
+ *,
22
+ minimum: Optional[float] = None,
23
+ maximum: Optional[float] = None,
24
+ ) -> float:
25
+ """
26
+ Scale an arbitrary scalar (float) value by the processing scale.
27
+ """
28
+ effective_scale = _normalize_scale(scale)
29
+ scaled = value * effective_scale
30
+ return _clamp(scaled, minimum, maximum)
31
+
32
+
33
+ def scale_length(
34
+ value: float,
35
+ scale: Optional[float],
36
+ *,
37
+ minimum: Optional[float] = 1.0,
38
+ maximum: Optional[float] = None,
39
+ ) -> int:
40
+ """
41
+ Scale a pixel length and return an int with rounding and clamping.
42
+ """
43
+ scaled = scale_scalar(value, scale, minimum=minimum, maximum=maximum)
44
+ # Round to nearest integer for pixel units
45
+ return max(1, int(round(scaled)))
46
+
47
+
48
+ def scale_area(
49
+ value: float,
50
+ scale: Optional[float],
51
+ *,
52
+ minimum: Optional[float] = 1.0,
53
+ maximum: Optional[float] = None,
54
+ ) -> int:
55
+ """
56
+ Scale an area-like value (square pixels). Uses scale^2.
57
+ """
58
+ effective_scale = _normalize_scale(scale)
59
+ scaled = value * (effective_scale * effective_scale)
60
+ scaled = _clamp(scaled, minimum, maximum)
61
+ return max(1, int(round(scaled)))
62
+
63
+
64
+ def scale_kernel(
65
+ kernel: Tuple[int, int],
66
+ scale: Optional[float],
67
+ *,
68
+ minimum: int = 1,
69
+ maximum: int = 63,
70
+ ) -> Tuple[int, int]:
71
+ """
72
+ Scale a 2D kernel size while ensuring odd dimensions (required for many morphology ops).
73
+ """
74
+ width, height = kernel
75
+ effective_scale = _normalize_scale(scale)
76
+
77
+ def _scale_dimension(base: int) -> int:
78
+ dimension = scale_scalar(
79
+ base,
80
+ effective_scale,
81
+ minimum=float(minimum),
82
+ maximum=float(maximum),
83
+ )
84
+ dim_int = max(minimum, int(round(dimension)))
85
+ # Ensure result stays within bounds
86
+ dim_int = min(maximum, dim_int)
87
+ if dim_int % 2 == 0:
88
+ # Prefer rounding up to keep padding generous, but clamp again
89
+ dim_int = min(maximum, dim_int + 1)
90
+ if dim_int % 2 == 0:
91
+ dim_int = max(minimum, dim_int - 1)
92
+ if dim_int % 2 == 0:
93
+ dim_int = max(minimum, dim_int + 1)
94
+ return max(minimum, dim_int)
95
+
96
+ return (_scale_dimension(width), _scale_dimension(height))
97
+
98
+
99
+ def scale_font_size(
100
+ value: float,
101
+ scale: Optional[float],
102
+ *,
103
+ minimum: int = 4,
104
+ maximum: int = 256,
105
+ ) -> int:
106
+ """
107
+ Scale a font size (int) using linear scaling with clamping.
108
+ """
109
+ return scale_length(value, scale, minimum=minimum, maximum=maximum)
core/services/__init__.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ External service integration modules for MangaTranslator.
3
+
4
+ This subpackage contains modules for:
5
+ - Translation API calls to various LLM providers
6
+ - External service communication
7
+ """
8
+
9
+ from core.image.sorting import sort_bubbles_by_reading_order
10
+
11
+ from .translation import (
12
+ call_translation_api_batch,
13
+ prepare_bubble_images_for_translation,
14
+ )
15
+
16
+ __all__ = [
17
+ "call_translation_api_batch",
18
+ "prepare_bubble_images_for_translation",
19
+ "sort_bubbles_by_reading_order",
20
+ ]
core/services/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (686 Bytes). View file
 
core/services/__pycache__/translation.cpython-311.pyc ADDED
Binary file (49.8 kB). View file
 
core/services/translation.py ADDED
@@ -0,0 +1,1385 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import re
3
+ from io import BytesIO
4
+ from typing import Any, Dict, List, Optional
5
+
6
+ import cv2
7
+ import numpy as np
8
+ from PIL import Image
9
+
10
+ from core.caching import get_cache
11
+ from core.config import TranslationConfig, calculate_reasoning_budget
12
+ from core.image.image_utils import cv2_to_pil, pil_to_cv2, process_bubble_image_cached
13
+ from core.image.ocr_detection import extract_text_with_manga_ocr
14
+ from utils.endpoints import (
15
+ call_anthropic_endpoint,
16
+ call_deepseek_endpoint,
17
+ call_gemini_endpoint,
18
+ call_moonshot_endpoint,
19
+ call_openai_compatible_endpoint,
20
+ call_openai_endpoint,
21
+ call_openrouter_endpoint,
22
+ call_xai_endpoint,
23
+ call_zai_endpoint,
24
+ openrouter_is_reasoning_model,
25
+ )
26
+ from utils.exceptions import TranslationError
27
+ from utils.logging import log_message
28
+ from utils.model_metadata import (
29
+ get_max_tokens_cap,
30
+ is_deepseek_reasoning_model,
31
+ is_openai_compatible_reasoning_model,
32
+ is_opus_45_model,
33
+ is_xai_reasoning_model,
34
+ is_zai_reasoning_model,
35
+ )
36
+
37
+ TRANSLATION_PATTERN = re.compile(
38
+ r'^\s*(\d+)\s*:\s*"?\s*(.*?)\s*"?\s*(?=\s*\n\s*\d+\s*:|\s*$)',
39
+ re.MULTILINE | re.DOTALL,
40
+ )
41
+
42
+
43
+ def _build_system_prompt_ocr(
44
+ input_language: Optional[str],
45
+ reading_direction: str,
46
+ ) -> str:
47
+ lang_label = f"{input_language} " if input_language else ""
48
+ direction = (
49
+ "right-to-left"
50
+ if (reading_direction or "rtl").lower() == "rtl"
51
+ else "left-to-right"
52
+ )
53
+
54
+ return f"""
55
+ ## ROLE
56
+ You are an expert manga OCR transcriber.
57
+
58
+ ## OBJECTIVE
59
+ Your sole purpose is to accurately transcribe the original text from a series of provided images. You must not translate, interpret, or add commentary.
60
+
61
+ ## CORE RULES
62
+ - **Reading Context:** The image crops are presented in a {direction} reading order. Do not reorder them.
63
+ - **Transcription Policy:** Preserve all original punctuation, ellipses, and casing. Collapse multi-line text into a single line, separated by a single space.
64
+ - **Ignore Policy:** You must ignore image borders, speech bubble tails, watermarks, page numbers, and any decorative elements outside the text itself.
65
+ - **Language Focus:** Transcribe only the original {lang_label}text.
66
+ - **Ruby/Furigana Policy:** If small phonetic characters (ruby/furigana) are present, you must ignore them and transcribe only the main, larger base text.
67
+ - **Visual Emphasis Policy:** If the source text is visually emphasized (bold, slanted, etc.), you must mirror that emphasis in your transcription using markdown-style markers: `*italic*` for slanted text, `**bold**` for bold text, `***bold-italic***` for both.
68
+ - **Edge Cases:**
69
+ - If an image contains standalone periods/ellipses, you must return it exactly as it appears.
70
+ - If text is indecipherable, you must return the exact token: `[OCR FAILED]`.
71
+
72
+ ## OUTPUT SCHEMA
73
+ - You must return your response as a single numbered list with exactly one line per input image.
74
+ - The numbering must correspond to the input image order (1, 2, 3...).
75
+ - The format must be `i: <transcribed {lang_label}text>` where `i` is the input image number.
76
+ - Do not include section headers, explanations, or formatting outside of this list.
77
+ """ # noqa
78
+
79
+
80
+ def _build_system_prompt_translation(
81
+ output_language: str,
82
+ mode: str,
83
+ reading_direction: str,
84
+ full_page_context: bool = False,
85
+ ) -> str:
86
+ direction = (
87
+ "right-to-left"
88
+ if (reading_direction or "rtl").lower() == "rtl"
89
+ else "left-to-right"
90
+ )
91
+ input_type = "transcriptions" if mode == "two-step" else "image crops"
92
+
93
+ cohesion_visual = (
94
+ " Refer to the full-page image to resolve ambiguous context."
95
+ if full_page_context
96
+ else ""
97
+ )
98
+
99
+ if mode == "two-step":
100
+ edge_cases = """- **Edge Cases:**
101
+ - If an input line contains standalone periods/ellipses, you must return it exactly as it appears.
102
+ - If an input line is the exact token `[OCR FAILED]`, you must output it unchanged."""
103
+ else:
104
+ edge_cases = """- **Edge Cases:**
105
+ - If an image contains standalone periods/ellipses, you must return it exactly as it appears.
106
+ - If text is indecipherable, you must return the exact token: `[OCR FAILED]`."""
107
+
108
+ core_rules = f"""
109
+ ## CORE RULES
110
+ - **Reading Context:** The {input_type} are presented in a {direction} reading order. Do not reorder them.
111
+ - **Cohesion:** Treat the input lines as a continuous narrative. Ensure the translation flows logically and naturally as a cohesive whole.{cohesion_visual}
112
+ - **Fidelity:** Focus on intent; translate functionally rather than literally.
113
+ - **Conciseness:** Keep translations idiomatic and concise.
114
+ - **Emphasis:** If the source text is visually emphasized (bold, slanted, etc.), mirror that emphasis using the STYLING GUIDE.
115
+ - **Punctuation:** Replace ellipses (e.g., "…") with consecutive periods (e.g., "...").
116
+ - **Text Types:**
117
+ - **Spoken Dialogue/Internal Monologue:** Translate naturally, matching the character's personality.
118
+ - **Narration:** Translate neutrally without special styling.
119
+ - **Audible SFX:** Translate physical sounds (Giongo) as standard onomatopoeia.
120
+ - **Mimetic FX:** Translate atmospheric text (Gitaigo) or silent actions as descriptive verbs or adjectives. Do not add a period at the end of the word.
121
+ {edge_cases}
122
+ """ # noqa
123
+
124
+ shared_components = f"""
125
+ ## ROLE
126
+ You are a professional manga localization translator and editor.
127
+
128
+ ## OBJECTIVE
129
+ Your goal is to produce natural-sounding, high-quality translations in {output_language} that are faithful to the original source's meaning, tone, and visual emphasis.
130
+
131
+ ## STYLING GUIDE
132
+ You must use the following markdown-style markers to convey emphasis:
133
+ - `*italic*`: Used for onomatopoeias, thoughts, flashbacks, distant sounds, or dialogue mediated by a device (e.g., phone, radio).
134
+ - `**bold**`: Used for sound effects (SFX), shouting, timestamps, or individual emphatic words.
135
+ - `***bold-italic***`: Used for extremely loud sounds or dialogue that also meets the criteria for italics (e.g., shouting over a radio).
136
+
137
+ {core_rules}
138
+ """ # noqa
139
+
140
+ if mode == "one-step":
141
+ output_schema = f"""
142
+ ## OUTPUT SCHEMA
143
+ - You must return your response as a single numbered list with exactly one line per input image.
144
+ - The numbering must correspond to the input image order (1, 2, 3...).
145
+ - For each item, provide both transcription and translation in the format:
146
+ `i: <transcribed text> || <translated {output_language} text>` where `i` is the input image number.
147
+ - Do not include section headers, explanations, or formatting outside of this list.
148
+ """
149
+ elif mode == "two-step":
150
+ output_schema = f"""
151
+ ## OUTPUT SCHEMA
152
+ - You must return your response as a single numbered list with exactly one line per input text.
153
+ - The numbering must correspond to the input order (1, 2, 3...).
154
+ - The format must be `i: <translated {output_language} text>` where `i` is the input text number.
155
+ - Do not include section headers, explanations, or formatting outside of this list.
156
+ """ # noqa
157
+ else:
158
+ raise ValueError(
159
+ f"Invalid mode '{mode}' specified for translation system prompt."
160
+ )
161
+
162
+ return shared_components + output_schema
163
+
164
+
165
+ def _is_reasoning_model_google(model_name: str) -> bool:
166
+ """Check if a Google model is reasoning-capable."""
167
+ name = model_name or ""
168
+ return (
169
+ name.startswith("gemini-2.5")
170
+ or "gemini-2.5" in name
171
+ or "gemini-3" in name.lower()
172
+ )
173
+
174
+
175
+ def _is_reasoning_model_openai(model_name: str) -> bool:
176
+ """Check if an OpenAI model is reasoning-capable."""
177
+ lm = (model_name or "").lower()
178
+ return (
179
+ lm.startswith("gpt-5")
180
+ or lm.startswith("o1")
181
+ or lm.startswith("o3")
182
+ or lm.startswith("o4-mini")
183
+ )
184
+
185
+
186
+ def _is_reasoning_model_anthropic(model_name: str) -> bool:
187
+ """Check if an Anthropic model is reasoning-capable."""
188
+ lm = (model_name or "").lower()
189
+ reasoning_prefixes = [
190
+ "claude-opus-4",
191
+ "claude-sonnet-4",
192
+ "claude-haiku-4-5",
193
+ "claude-3-7-sonnet",
194
+ ]
195
+ return any(lm.startswith(p) for p in reasoning_prefixes)
196
+
197
+
198
+ def _add_media_resolution_to_part(
199
+ part: Dict[str, Any],
200
+ media_resolution_ui: str,
201
+ is_gemini_3: bool,
202
+ ) -> Dict[str, Any]:
203
+ """
204
+ Add media_resolution to an inline_data part for Gemini 3 models.
205
+
206
+ Args:
207
+ part: Part dictionary with inline_data
208
+ media_resolution_ui: UI format media resolution ("auto"/"high"/"medium"/"low")
209
+ is_gemini_3: Whether the model is Gemini 3
210
+
211
+ Returns:
212
+ Part dictionary with media_resolution added if Gemini 3, otherwise unchanged
213
+ """
214
+ if not is_gemini_3 or "inline_data" not in part:
215
+ return part
216
+
217
+ media_resolution_mapping = {
218
+ "auto": "MEDIA_RESOLUTION_UNSPECIFIED",
219
+ "high": "MEDIA_RESOLUTION_HIGH",
220
+ "medium": "MEDIA_RESOLUTION_MEDIUM",
221
+ "low": "MEDIA_RESOLUTION_LOW",
222
+ }
223
+ backend_media_resolution = media_resolution_mapping.get(
224
+ media_resolution_ui.lower(), "MEDIA_RESOLUTION_UNSPECIFIED"
225
+ )
226
+
227
+ result = part.copy()
228
+ result["media_resolution"] = {"level": backend_media_resolution}
229
+ return result
230
+
231
+
232
+ def _build_generation_config(
233
+ provider: str,
234
+ model_name: str,
235
+ config: TranslationConfig,
236
+ debug: bool = False,
237
+ ) -> Dict[str, Any]:
238
+ """
239
+ Build provider-specific generation config dictionary.
240
+
241
+ Centralizes logic for:
242
+ - Base parameters (temperature, top_p, top_k)
243
+ - Provider-specific parameter names and constraints
244
+ - Reasoning model detection and token limits
245
+ - Special features (thinking, reasoning_effort, etc.)
246
+
247
+ Args:
248
+ provider: Provider name (Google, OpenAI, Anthropic, xAI, OpenRouter, OpenAI-Compatible)
249
+ model_name: Model identifier
250
+ config: TranslationConfig with all settings
251
+ debug: Whether to log debug messages
252
+
253
+ Returns:
254
+ Dictionary with generation config parameters for the specific provider
255
+ """
256
+ temperature = config.temperature
257
+ top_p = config.top_p
258
+ top_k = config.top_k
259
+
260
+ if config.max_tokens is not None:
261
+ max_tokens_value = config.max_tokens
262
+ else:
263
+ is_reasoning = False
264
+ if provider == "Google":
265
+ is_reasoning = _is_reasoning_model_google(model_name)
266
+ elif provider == "OpenAI":
267
+ is_reasoning = _is_reasoning_model_openai(model_name)
268
+ elif provider == "Anthropic":
269
+ is_reasoning = _is_reasoning_model_anthropic(model_name)
270
+ elif provider == "xAI":
271
+ is_reasoning = is_xai_reasoning_model(model_name)
272
+ elif provider == "OpenRouter":
273
+ is_reasoning = openrouter_is_reasoning_model(model_name, debug)
274
+ elif provider == "OpenAI-Compatible":
275
+ is_reasoning = is_openai_compatible_reasoning_model(model_name)
276
+ elif provider == "DeepSeek":
277
+ is_reasoning = is_deepseek_reasoning_model(model_name)
278
+ elif provider == "Z.ai":
279
+ is_reasoning = is_zai_reasoning_model(model_name)
280
+ max_tokens_value = 16384 if is_reasoning else 4096
281
+
282
+ max_tokens_cap = get_max_tokens_cap(provider, model_name)
283
+ if max_tokens_cap is not None and max_tokens_value > max_tokens_cap:
284
+ max_tokens_value = max_tokens_cap
285
+
286
+ if provider == "Google":
287
+ is_gemini_3 = "gemini-3" in model_name.lower()
288
+ generation_config = {
289
+ "temperature": temperature,
290
+ "topP": top_p,
291
+ "topK": top_k,
292
+ "maxOutputTokens": max_tokens_value,
293
+ }
294
+ if not is_gemini_3:
295
+ media_resolution_mapping = {
296
+ "auto": "MEDIA_RESOLUTION_UNSPECIFIED",
297
+ "high": "MEDIA_RESOLUTION_HIGH",
298
+ "medium": "MEDIA_RESOLUTION_MEDIUM",
299
+ "low": "MEDIA_RESOLUTION_LOW",
300
+ }
301
+ backend_media_resolution = media_resolution_mapping.get(
302
+ config.media_resolution.lower(), "MEDIA_RESOLUTION_UNSPECIFIED"
303
+ )
304
+ generation_config["media_resolution"] = backend_media_resolution
305
+ if is_gemini_3:
306
+ reasoning_effort = config.reasoning_effort or "high"
307
+ generation_config["thinkingConfig"] = {"thinkingLevel": reasoning_effort}
308
+ log_message(
309
+ f"Using reasoning effort '{reasoning_effort}' for {model_name}",
310
+ verbose=debug,
311
+ )
312
+ elif _is_reasoning_model_google(model_name) and not is_gemini_3:
313
+ reasoning_effort = config.reasoning_effort or "auto"
314
+ is_flash = "gemini-2.5-flash" in model_name.lower()
315
+ is_pro = "gemini-2.5-pro" in model_name.lower()
316
+ if reasoning_effort == "none":
317
+ if is_flash:
318
+ generation_config["thinkingConfig"] = {"thinkingBudget": 0}
319
+ log_message(f"Disabled reasoning for {model_name}", verbose=debug)
320
+ elif is_pro:
321
+ generation_config["thinkingConfig"] = {"thinkingBudget": 128}
322
+ log_message(
323
+ f"Using 'none' reasoning effort (thinkingBudget: 128) for {model_name}",
324
+ verbose=debug,
325
+ )
326
+ else:
327
+ log_message(
328
+ f"Warning: 'none' not supported for {model_name}, using 'auto'",
329
+ verbose=debug,
330
+ )
331
+ elif reasoning_effort == "auto":
332
+ log_message(
333
+ f"Using auto reasoning allocation for {model_name}", verbose=debug
334
+ )
335
+ else:
336
+ thinking_budget = calculate_reasoning_budget(
337
+ max_tokens_value, reasoning_effort
338
+ )
339
+ generation_config["thinkingConfig"] = {
340
+ "thinkingBudget": thinking_budget
341
+ }
342
+ log_message(
343
+ f"Using reasoning effort '{reasoning_effort}' (budget: {thinking_budget} tokens) for {model_name}",
344
+ verbose=debug,
345
+ )
346
+ return generation_config
347
+
348
+ elif provider == "OpenAI":
349
+ generation_config = {
350
+ "temperature": temperature,
351
+ "top_p": top_p,
352
+ "max_output_tokens": max_tokens_value,
353
+ } # top_k not supported by OpenAI
354
+ if config.reasoning_effort:
355
+ lm = (model_name or "").lower()
356
+ is_chat_variant = "chat" in lm
357
+ is_gpt5_1 = lm.startswith("gpt-5.1")
358
+ is_gpt5_2 = lm.startswith("gpt-5.2")
359
+ effort = config.reasoning_effort
360
+ if effort == "xhigh" and not is_gpt5_2:
361
+ effort = "high"
362
+ if not is_chat_variant and (is_gpt5_1 or is_gpt5_2 or effort != "none"):
363
+ generation_config["reasoning_effort"] = effort
364
+ return generation_config
365
+
366
+ elif provider == "Anthropic":
367
+ is_reasoning = _is_reasoning_model_anthropic(model_name)
368
+ is_opus_45 = is_opus_45_model(model_name)
369
+ clamped_temp = min(temperature, 1.0) # Anthropic caps at 1.0
370
+ generation_config = {
371
+ "temperature": clamped_temp,
372
+ "top_p": top_p,
373
+ "top_k": top_k,
374
+ "max_tokens": max_tokens_value,
375
+ }
376
+ if is_reasoning:
377
+ generation_config["reasoning_effort"] = config.reasoning_effort or "none"
378
+ if is_opus_45 and config.effort:
379
+ generation_config["effort"] = config.effort
380
+ return generation_config
381
+
382
+ elif provider == "xAI":
383
+ is_reasoning = is_xai_reasoning_model(model_name)
384
+ generation_config = {
385
+ "temperature": temperature,
386
+ "top_p": top_p,
387
+ "max_tokens": max_tokens_value,
388
+ }
389
+ if is_reasoning:
390
+ generation_config["reasoning_effort"] = config.reasoning_effort or "high"
391
+ return generation_config
392
+
393
+ elif provider == "DeepSeek":
394
+ is_reasoning = is_deepseek_reasoning_model(model_name)
395
+ generation_config = {
396
+ "temperature": temperature,
397
+ "top_p": top_p,
398
+ "max_tokens": max_tokens_value,
399
+ }
400
+ return generation_config
401
+
402
+ elif provider == "Z.ai":
403
+ is_reasoning = is_zai_reasoning_model(model_name)
404
+ generation_config = {
405
+ "temperature": temperature,
406
+ "top_p": top_p,
407
+ "top_k": top_k,
408
+ "max_tokens": max_tokens_value,
409
+ }
410
+ if is_reasoning:
411
+ # Z.ai uses thinking parameter with {"type": "enabled"} or {"type": "disabled"}
412
+ # Map reasoning_effort: "high" -> enabled, "none" -> disabled
413
+ reasoning_effort = config.reasoning_effort or "high"
414
+ thinking_type = "enabled" if reasoning_effort == "high" else "disabled"
415
+ generation_config["thinking"] = {"type": thinking_type}
416
+ return generation_config
417
+
418
+ elif provider == "Moonshot AI":
419
+ # Moonshot AI is text-only, reasoning models have always-on reasoning
420
+ generation_config = {
421
+ "temperature": min(temperature, 1.0), # Moonshot caps at 1.0
422
+ "top_p": top_p,
423
+ "max_tokens": max_tokens_value,
424
+ }
425
+ return generation_config
426
+
427
+ elif provider == "OpenRouter":
428
+ model_lower = (model_name or "").lower()
429
+ is_openai_model = "openai/" in model_lower or model_lower.startswith("gpt-")
430
+ is_anthropic_model = "anthropic/" in model_lower or model_lower.startswith(
431
+ "claude-"
432
+ )
433
+ is_grok_model = "grok-4" in model_lower
434
+ is_gemini_3 = "gemini-3" in model_lower
435
+
436
+ generation_config = {
437
+ "temperature": temperature,
438
+ "top_p": top_p,
439
+ "top_k": top_k,
440
+ "max_tokens": max_tokens_value,
441
+ }
442
+
443
+ is_openai_reasoning = is_openai_model and (
444
+ "gpt-5" in model_lower
445
+ or "o1" in model_lower
446
+ or "o3" in model_lower
447
+ or "o4-mini" in model_lower
448
+ )
449
+ is_gpt5_1 = is_openai_model and "gpt-5.1" in model_lower
450
+ is_gpt5 = is_openai_model and "gpt-5" in model_lower and not is_gpt5_1
451
+ # For OpenRouter, Anthropic models use dots (4.5) not hyphens (4-5)
452
+ # Claude 3.7 Sonnet :thinking variant is reasoning-capable, non-thinking is not
453
+ is_claude_37_sonnet_thinking = (
454
+ is_anthropic_model
455
+ and "claude-3.7-sonnet" in model_lower
456
+ and ":thinking" in model_lower
457
+ )
458
+ is_anthropic_reasoning = is_anthropic_model and (
459
+ "claude-opus-4" in model_lower
460
+ or "claude-sonnet-4" in model_lower
461
+ or "claude-haiku-4.5" in model_lower
462
+ or is_claude_37_sonnet_thinking
463
+ )
464
+ # For OpenRouter, Grok models don't have "reasoning" in the name (e.g., "grok-4.1-fast")
465
+ is_grok_reasoning = is_grok_model and "non-reasoning" not in model_lower
466
+
467
+ # Add metadata flags for OpenRouter endpoint to avoid re-parsing model names
468
+ generation_config["_metadata"] = {
469
+ "is_openai_model": is_openai_model,
470
+ "is_anthropic_model": is_anthropic_model,
471
+ "is_grok_model": is_grok_model,
472
+ "is_gemini_3": is_gemini_3,
473
+ "is_google_model": "google/" in model_lower or "gemini" in model_lower,
474
+ "is_openai_reasoning": is_openai_reasoning,
475
+ "is_anthropic_reasoning": is_anthropic_reasoning,
476
+ "is_grok_reasoning": is_grok_reasoning,
477
+ "is_claude_37_sonnet_thinking": is_claude_37_sonnet_thinking,
478
+ "is_gpt5_1": is_gpt5_1,
479
+ "is_gpt5": is_gpt5,
480
+ }
481
+
482
+ if is_openai_reasoning or is_anthropic_reasoning or is_grok_reasoning:
483
+ if is_anthropic_reasoning:
484
+ reasoning_effort = config.reasoning_effort or "none"
485
+ generation_config["reasoning_effort"] = reasoning_effort
486
+ elif is_gpt5_1:
487
+ generation_config["reasoning_effort"] = config.reasoning_effort
488
+ elif config.reasoning_effort and config.reasoning_effort != "none":
489
+ generation_config["reasoning_effort"] = config.reasoning_effort
490
+ elif "gemini" in model_lower or "google/" in model_lower:
491
+ if config.reasoning_effort:
492
+ generation_config["reasoning_effort"] = config.reasoning_effort
493
+
494
+ return generation_config
495
+
496
+ elif provider == "OpenAI-Compatible":
497
+ return {
498
+ "temperature": temperature,
499
+ "top_p": top_p,
500
+ "top_k": top_k,
501
+ "max_tokens": max_tokens_value,
502
+ }
503
+
504
+ else:
505
+ raise TranslationError(f"Unknown provider for generation config: {provider}")
506
+
507
+
508
+ def _call_llm_endpoint(
509
+ config: TranslationConfig,
510
+ parts: List[Dict[str, Any]],
511
+ prompt_text: str,
512
+ debug: bool = False,
513
+ system_prompt: Optional[str] = None,
514
+ ) -> Optional[str]:
515
+ """Internal helper to dispatch API calls based on provider."""
516
+ provider = config.provider
517
+ model_name = config.model_name
518
+ api_parts = parts + [{"text": prompt_text}]
519
+
520
+ try:
521
+ if provider == "Google":
522
+ api_key = config.google_api_key
523
+ if not api_key:
524
+ raise TranslationError("Google API key is missing.")
525
+ generation_config = _build_generation_config(
526
+ provider, model_name, config, debug
527
+ )
528
+ return call_gemini_endpoint(
529
+ api_key=api_key,
530
+ model_name=model_name,
531
+ parts=api_parts,
532
+ generation_config=generation_config,
533
+ system_prompt=system_prompt,
534
+ debug=debug,
535
+ enable_web_search=config.enable_web_search,
536
+ )
537
+ elif provider == "OpenAI":
538
+ api_key = config.openai_api_key
539
+ if not api_key:
540
+ raise TranslationError("OpenAI API key is missing.")
541
+ generation_config = _build_generation_config(
542
+ provider, model_name, config, debug
543
+ )
544
+ return call_openai_endpoint(
545
+ api_key=api_key,
546
+ model_name=model_name,
547
+ parts=api_parts,
548
+ generation_config=generation_config,
549
+ system_prompt=system_prompt,
550
+ debug=debug,
551
+ enable_web_search=config.enable_web_search,
552
+ )
553
+ elif provider == "Anthropic":
554
+ api_key = config.anthropic_api_key
555
+ if not api_key:
556
+ raise TranslationError("Anthropic API key is missing.")
557
+ generation_config = _build_generation_config(
558
+ provider, model_name, config, debug
559
+ )
560
+ return call_anthropic_endpoint(
561
+ api_key=api_key,
562
+ model_name=model_name,
563
+ parts=api_parts,
564
+ generation_config=generation_config,
565
+ system_prompt=system_prompt,
566
+ debug=debug,
567
+ enable_web_search=config.enable_web_search,
568
+ )
569
+ elif provider == "xAI":
570
+ api_key = config.xai_api_key
571
+ if not api_key:
572
+ raise TranslationError("xAI API key is missing.")
573
+ generation_config = _build_generation_config(
574
+ provider, model_name, config, debug
575
+ )
576
+ return call_xai_endpoint(
577
+ api_key=api_key,
578
+ model_name=model_name,
579
+ parts=api_parts,
580
+ generation_config=generation_config,
581
+ system_prompt=system_prompt,
582
+ debug=debug,
583
+ enable_web_search=config.enable_web_search,
584
+ )
585
+ elif provider == "DeepSeek":
586
+ api_key = config.deepseek_api_key
587
+ if not api_key:
588
+ raise TranslationError("DeepSeek API key is missing.")
589
+ generation_config = _build_generation_config(
590
+ provider, model_name, config, debug
591
+ )
592
+ return call_deepseek_endpoint(
593
+ api_key=api_key,
594
+ model_name=model_name,
595
+ parts=api_parts,
596
+ generation_config=generation_config,
597
+ system_prompt=system_prompt,
598
+ debug=debug,
599
+ )
600
+ elif provider == "Z.ai":
601
+ api_key = config.zai_api_key
602
+ if not api_key:
603
+ raise TranslationError("Z.ai API key is missing.")
604
+ generation_config = _build_generation_config(
605
+ provider, model_name, config, debug
606
+ )
607
+ return call_zai_endpoint(
608
+ api_key=api_key,
609
+ model_name=model_name,
610
+ parts=api_parts,
611
+ generation_config=generation_config,
612
+ system_prompt=system_prompt,
613
+ debug=debug,
614
+ enable_web_search=config.enable_web_search,
615
+ )
616
+ elif provider == "Moonshot AI":
617
+ api_key = config.moonshot_api_key
618
+ if not api_key:
619
+ raise TranslationError("Moonshot API key is missing.")
620
+ generation_config = _build_generation_config(
621
+ provider, model_name, config, debug
622
+ )
623
+ return call_moonshot_endpoint(
624
+ api_key=api_key,
625
+ model_name=model_name,
626
+ parts=api_parts,
627
+ generation_config=generation_config,
628
+ system_prompt=system_prompt,
629
+ debug=debug,
630
+ enable_web_search=config.enable_web_search,
631
+ )
632
+ elif provider == "OpenRouter":
633
+ api_key = config.openrouter_api_key
634
+ if not api_key:
635
+ raise TranslationError("OpenRouter API key is missing.")
636
+ generation_config = _build_generation_config(
637
+ provider, model_name, config, debug
638
+ )
639
+ return call_openrouter_endpoint(
640
+ api_key=api_key,
641
+ model_name=model_name,
642
+ parts=api_parts,
643
+ generation_config=generation_config,
644
+ system_prompt=system_prompt,
645
+ debug=debug,
646
+ enable_web_search=config.enable_web_search,
647
+ )
648
+ elif provider == "OpenAI-Compatible":
649
+ base_url = config.openai_compatible_url
650
+ api_key = config.openai_compatible_api_key # Optional
651
+ if not base_url:
652
+ raise TranslationError("OpenAI-Compatible URL is missing.")
653
+ generation_config = _build_generation_config(
654
+ provider, model_name, config, debug
655
+ )
656
+ return call_openai_compatible_endpoint(
657
+ base_url=base_url,
658
+ api_key=api_key,
659
+ model_name=model_name,
660
+ parts=api_parts,
661
+ generation_config=generation_config,
662
+ system_prompt=system_prompt,
663
+ debug=debug,
664
+ )
665
+ else:
666
+ raise TranslationError(
667
+ f"Unknown translation provider specified: {provider}"
668
+ )
669
+
670
+ except (ValueError, RuntimeError):
671
+ raise
672
+
673
+
674
+ def _parse_llm_response_unified(
675
+ response_text: Optional[str],
676
+ total_elements: int,
677
+ provider: str,
678
+ debug: bool = False,
679
+ ) -> List[str]:
680
+ """Parse LLM response with a single numbered list."""
681
+ if response_text is None:
682
+ log_message(f"API call failed: {provider} returned None", always_print=True)
683
+ raise TranslationError(f"{provider}: API failed (returned None)")
684
+ elif response_text == "":
685
+ log_message(f"API call returned empty response: {provider}", always_print=True)
686
+ raise TranslationError(f"{provider}: Empty response")
687
+
688
+ try:
689
+ log_message(
690
+ f"Parsing {provider} unified response: {len(response_text)} chars",
691
+ verbose=debug,
692
+ )
693
+ log_message(f"Raw response:\n---\n{response_text}\n---", always_print=True)
694
+
695
+ # Pattern matches "1: text" or "1. text" or "1 text" etc.
696
+ pattern = re.compile(
697
+ r'^\s*(\d+)\s*[:.]\s*"?\s*(.*?)\s*"?\s*(?=\s*\n\s*\d+\s*[:.]|\s*$)',
698
+ re.MULTILINE | re.DOTALL,
699
+ )
700
+
701
+ matches = pattern.findall(response_text)
702
+ result_dict = {}
703
+
704
+ for num_str, text in matches:
705
+ try:
706
+ num = int(num_str)
707
+ if 1 <= num <= total_elements:
708
+ result_dict[num] = text.strip()
709
+ except ValueError:
710
+ continue
711
+
712
+ final_list = []
713
+ for i in range(1, total_elements + 1):
714
+ if i in result_dict:
715
+ final_list.append(result_dict[i])
716
+ else:
717
+ final_list.append(f"[{provider}: Missing item {i}]")
718
+
719
+ log_message(
720
+ f"Parsed {len(result_dict)} items from unified response (expected {total_elements})",
721
+ verbose=debug,
722
+ )
723
+ return final_list
724
+
725
+ except Exception as e:
726
+ log_message(
727
+ f"Failed to parse {provider} unified response: {str(e)}",
728
+ always_print=True,
729
+ )
730
+ return [f"[{provider}: Parse error]"] * total_elements
731
+
732
+
733
+ def _prepare_images_for_ocr(
734
+ images_b64: List[str], verbose: bool = False
735
+ ) -> List[Optional[Image.Image]]:
736
+ """Prepare base64-encoded images for OCR by decoding and converting to RGB.
737
+
738
+ Args:
739
+ images_b64: List of base64-encoded image strings
740
+ verbose: Whether to print verbose logging
741
+
742
+ Returns:
743
+ List of PIL Images (or None for decode failures), all in RGB mode
744
+ """
745
+ pil_images = []
746
+ for img_b64 in images_b64:
747
+ try:
748
+ image_data = base64.b64decode(img_b64)
749
+ pil_img = Image.open(BytesIO(image_data))
750
+ if pil_img.mode != "RGB":
751
+ pil_img = pil_img.convert("RGB")
752
+ pil_images.append(pil_img)
753
+ except Exception as e:
754
+ log_message(
755
+ f"Failed to decode image for manga-ocr: {e}",
756
+ always_print=True,
757
+ )
758
+ pil_images.append(None)
759
+ return pil_images
760
+
761
+
762
+ def _format_ocr_results(
763
+ extracted_texts: List[str],
764
+ bubble_metadata: List[Dict[str, Any]],
765
+ ) -> None:
766
+ """Format and log OCR results.
767
+
768
+ Args:
769
+ extracted_texts: List of extracted text strings
770
+ bubble_metadata: List of metadata dicts for text elements
771
+ verbose: Whether to print verbose logging
772
+ """
773
+ log_lines = []
774
+
775
+ for i, text in enumerate(extracted_texts):
776
+ metadata = bubble_metadata[i] if i < len(bubble_metadata) else {}
777
+ is_osb = metadata.get("is_outside_text", False)
778
+ prefix = f"{i + 1}"
779
+ type_label = "[OSB]" if is_osb else "[Bubble]"
780
+
781
+ log_lines.append(f"{prefix}: {type_label} {text}")
782
+
783
+ if log_lines:
784
+ log_message(
785
+ f"Raw OCR output:\n---\n{chr(10).join(log_lines)}\n---",
786
+ always_print=True,
787
+ )
788
+
789
+
790
+ def _check_ocr_failure(texts: List[str], provider: Optional[str] = None) -> bool:
791
+ """Check if all OCR results indicate failure.
792
+
793
+ Args:
794
+ texts: List of extracted text strings
795
+ provider: Optional provider name for LLM OCR failure detection
796
+
797
+ Returns:
798
+ True if all texts indicate failure, False otherwise
799
+ """
800
+ if not texts:
801
+ return True
802
+
803
+ if provider:
804
+ for text in texts:
805
+ if f"[{provider}-OCR:" not in text:
806
+ return False
807
+ return True
808
+ else:
809
+ return all(text == "[OCR FAILED]" for text in texts)
810
+
811
+
812
+ def _format_special_instructions(config: TranslationConfig) -> str:
813
+ """Format user's special instructions section for prompts.
814
+
815
+ Args:
816
+ config: TranslationConfig with special_instructions
817
+
818
+ Returns:
819
+ Formatted special instructions string (empty if none)
820
+ """
821
+ if config.special_instructions and config.special_instructions.strip():
822
+ return f"""
823
+
824
+ ## SPECIAL INSTRUCTIONS
825
+ {config.special_instructions.strip()}
826
+ """
827
+ return ""
828
+
829
+
830
+ def _perform_manga_ocr(
831
+ images_b64: List[str],
832
+ bubble_metadata: List[Dict[str, Any]],
833
+ debug: bool = False,
834
+ ) -> List[str]:
835
+ """Perform OCR using manga-ocr model.
836
+
837
+ Args:
838
+ images_b64: List of base64-encoded images
839
+ bubble_metadata: List of metadata dicts for text elements
840
+ debug: Whether to print verbose logging
841
+
842
+ Returns:
843
+ List of extracted text strings, or early return with failure list
844
+ """
845
+ total_elements = len(images_b64)
846
+ log_message("Using manga-ocr for text extraction", verbose=debug)
847
+
848
+ cache = get_cache()
849
+ cache_key = cache.get_manga_ocr_cache_key(images_b64, total_elements)
850
+ cached_ocr = cache.get_manga_ocr_result(cache_key)
851
+ if cached_ocr is not None:
852
+ if len(cached_ocr) == total_elements:
853
+ log_message("Using cached manga-ocr results", verbose=debug)
854
+ return cached_ocr
855
+ log_message("Discarding manga-ocr cache due to length mismatch", verbose=debug)
856
+
857
+ pil_images = _prepare_images_for_ocr(images_b64, verbose=debug)
858
+ extracted_texts = extract_text_with_manga_ocr(pil_images, verbose=debug)
859
+
860
+ formatted_texts = []
861
+ for i, text in enumerate(extracted_texts):
862
+ if text == "[OCR FAILED]" or not text:
863
+ formatted_texts.append(text if text else "[OCR FAILED]")
864
+ else:
865
+ formatted_texts.append(text)
866
+
867
+ extracted_texts = formatted_texts
868
+
869
+ _format_ocr_results(extracted_texts, bubble_metadata)
870
+
871
+ if len(extracted_texts) != total_elements:
872
+ msg = (
873
+ f"Warning: extracted_texts length ({len(extracted_texts)}) "
874
+ f"doesn't match total_elements ({total_elements})"
875
+ )
876
+ log_message(msg, always_print=True)
877
+ while len(extracted_texts) < total_elements:
878
+ extracted_texts.append("[OCR FAILED]")
879
+ extracted_texts = extracted_texts[:total_elements]
880
+
881
+ if not extracted_texts:
882
+ log_message("manga-ocr returned empty results", verbose=debug)
883
+ failure_results = ["[OCR FAILED]"] * total_elements
884
+ cache.set_manga_ocr_result(cache_key, failure_results, debug)
885
+ return failure_results
886
+
887
+ if _check_ocr_failure(extracted_texts):
888
+ log_message("manga-ocr returned only failures", verbose=debug)
889
+ cache.set_manga_ocr_result(cache_key, extracted_texts, debug)
890
+ return extracted_texts
891
+
892
+ cache.set_manga_ocr_result(cache_key, extracted_texts, debug)
893
+ return extracted_texts
894
+
895
+
896
+ def _perform_llm_ocr(
897
+ config: TranslationConfig,
898
+ images_b64: List[str],
899
+ mime_types: List[str],
900
+ ocr_prompt: str,
901
+ is_gemini_3: bool,
902
+ provider: str,
903
+ input_language: Optional[str],
904
+ reading_direction: str,
905
+ debug: bool = False,
906
+ ) -> List[str]:
907
+ """Perform OCR using vision LLM.
908
+
909
+ Args:
910
+ config: TranslationConfig
911
+ images_b64: List of base64-encoded images
912
+ mime_types: List of MIME types for each image
913
+ ocr_prompt: OCR prompt text
914
+ is_gemini_3: Whether model is Gemini 3
915
+ provider: Provider name
916
+ input_language: Input language
917
+ reading_direction: Reading direction
918
+ debug: Whether to print verbose logging
919
+
920
+ Returns:
921
+ List of extracted text strings, or early return with failure list
922
+ """
923
+ total_elements = len(images_b64)
924
+ ocr_parts = []
925
+ for i, img_b64 in enumerate(images_b64):
926
+ mime_type = mime_types[i] if i < len(mime_types) else "image/jpeg"
927
+ bubble_part = {"inline_data": {"mime_type": mime_type, "data": img_b64}}
928
+ if is_gemini_3:
929
+ bubble_part = _add_media_resolution_to_part(
930
+ bubble_part, config.media_resolution_bubbles, is_gemini_3
931
+ )
932
+ ocr_parts.append(bubble_part)
933
+
934
+ ocr_system = _build_system_prompt_ocr(input_language, reading_direction)
935
+ ocr_response_text = _call_llm_endpoint(
936
+ config,
937
+ ocr_parts,
938
+ ocr_prompt,
939
+ debug,
940
+ system_prompt=ocr_system,
941
+ )
942
+ extracted_texts = _parse_llm_response_unified(
943
+ ocr_response_text,
944
+ total_elements,
945
+ provider + "-OCR",
946
+ debug,
947
+ )
948
+
949
+ if extracted_texts is None:
950
+ log_message("OCR API call failed", always_print=True)
951
+ return [f"[{provider}: OCR failed]"] * total_elements
952
+
953
+ if _check_ocr_failure(extracted_texts, provider):
954
+ log_message("OCR returned only placeholders", verbose=debug)
955
+ return extracted_texts
956
+
957
+ return extracted_texts
958
+
959
+
960
+ def call_translation_api_batch(
961
+ config: TranslationConfig,
962
+ images_b64: List[str],
963
+ full_image_b64: str,
964
+ mime_types: List[str],
965
+ full_image_mime_type: str,
966
+ bubble_metadata: List[Dict[str, Any]],
967
+ debug: bool = False,
968
+ ) -> List[str]:
969
+ """
970
+ Generates prompts and calls the appropriate LLM API endpoint based on the provider and mode
971
+ specified in the configuration, translating text from speech bubbles and outside-bubble text.
972
+
973
+ Supports "one-step" (OCR+Translate+Style) and "two-step" (OCR then Translate+Style) modes.
974
+
975
+ Args:
976
+ config (TranslationConfig): Configuration object.
977
+ images_b64 (list): List of base64 encoded images of all text elements, in reading order.
978
+ full_image_b64 (str): Base64 encoded image of the full manga page.
979
+ mime_types (List[str]): List of MIME types for each text element image.
980
+ full_image_mime_type (str): MIME type of the full page image.
981
+ bubble_metadata (List[Dict]): List of metadata dicts with 'is_outside_text' flags for each image.
982
+ debug (bool): Whether to print debugging information.
983
+
984
+ Returns:
985
+ list: List of translated strings (potentially with style markers), one for each input text element.
986
+ Returns placeholder messages on errors or empty responses.
987
+
988
+ Raises:
989
+ ValueError: If required config (API key, provider, URL) is missing or invalid.
990
+ RuntimeError: If an API call fails irrecoverably after retries (raised by endpoint functions).
991
+ """
992
+ provider = config.provider
993
+ input_language = config.input_language
994
+ output_language = config.output_language
995
+ reading_direction = config.reading_direction
996
+ translation_mode = config.translation_mode
997
+
998
+ # Include conditional bubble hints
999
+ total_elements = len(images_b64)
1000
+ dialogue_indices = [
1001
+ i + 1
1002
+ for i, meta in enumerate(bubble_metadata)
1003
+ if not meta.get("is_outside_text", False)
1004
+ ]
1005
+ osb_indices = [
1006
+ i + 1
1007
+ for i, meta in enumerate(bubble_metadata)
1008
+ if meta.get("is_outside_text", False)
1009
+ ]
1010
+
1011
+ hints = []
1012
+ if dialogue_indices:
1013
+ dialogue_list_str = ", ".join(map(str, dialogue_indices))
1014
+ hints.append(f"Items [{dialogue_list_str}] contain spoken dialogue.")
1015
+ if osb_indices:
1016
+ osb_list_str = ", ".join(map(str, osb_indices))
1017
+ hints.append(
1018
+ f"Items [{osb_list_str}] contain sound effects, mimetic effects, narration, or internal monologues."
1019
+ )
1020
+
1021
+ context_hints = ""
1022
+ if hints:
1023
+ context_hints = "\nNote: " + " ".join(hints) + " Translate them accordingly."
1024
+
1025
+ reading_order_desc = (
1026
+ "right-to-left, top-to-bottom"
1027
+ if reading_direction == "rtl"
1028
+ else "left-to-right, top-to-bottom"
1029
+ )
1030
+
1031
+ cache = get_cache()
1032
+ cache_key = cache.get_translation_cache_key(images_b64, full_image_b64, config)
1033
+ cached_translation = cache.get_translation(cache_key)
1034
+ if cached_translation is not None:
1035
+ log_message(" - Using cached translation", verbose=debug)
1036
+ return cached_translation
1037
+
1038
+ model_name = config.model_name
1039
+ is_gemini_3 = provider == "Google" and "gemini-3" in model_name.lower()
1040
+
1041
+ base_parts = []
1042
+ for i, img_b64 in enumerate(images_b64):
1043
+ mime_type = mime_types[i] if i < len(mime_types) else "image/jpeg"
1044
+ bubble_part = {"inline_data": {"mime_type": mime_type, "data": img_b64}}
1045
+ if is_gemini_3:
1046
+ bubble_part = _add_media_resolution_to_part(
1047
+ bubble_part, config.media_resolution_bubbles, is_gemini_3
1048
+ )
1049
+ base_parts.append(bubble_part)
1050
+
1051
+ if config.send_full_page_context and full_image_b64:
1052
+ context_part = {
1053
+ "inline_data": {
1054
+ "mime_type": full_image_mime_type,
1055
+ "data": full_image_b64,
1056
+ }
1057
+ }
1058
+ if is_gemini_3:
1059
+ context_part = _add_media_resolution_to_part(
1060
+ context_part, config.media_resolution_context, is_gemini_3
1061
+ )
1062
+ base_parts.append(context_part)
1063
+
1064
+ try:
1065
+ if translation_mode == "two-step":
1066
+ special_instructions_section = _format_special_instructions(config)
1067
+
1068
+ ocr_prompt = f"""
1069
+ ## CONTEXT
1070
+ You have been provided with {total_elements} individual text images from a manga page. They are presented in their natural reading order ({reading_order_desc}).
1071
+
1072
+ ## TASK
1073
+ Apply your OCR transcription rules to each image provided.{special_instructions_section}
1074
+ """ # noqa
1075
+
1076
+ log_message("Starting OCR step", verbose=debug)
1077
+
1078
+ if config.ocr_method == "manga-ocr":
1079
+ extracted_texts = _perform_manga_ocr(
1080
+ images_b64,
1081
+ bubble_metadata,
1082
+ debug,
1083
+ )
1084
+ else:
1085
+ extracted_texts = _perform_llm_ocr(
1086
+ config,
1087
+ images_b64,
1088
+ mime_types,
1089
+ ocr_prompt,
1090
+ is_gemini_3,
1091
+ provider,
1092
+ input_language,
1093
+ reading_direction,
1094
+ debug,
1095
+ )
1096
+
1097
+ log_message("Starting translation step", verbose=debug)
1098
+
1099
+ formatted_texts = []
1100
+ ocr_failed_indices = set()
1101
+ for i, text in enumerate(extracted_texts):
1102
+ if f"[{provider}-OCR:" in text or text == "[OCR FAILED]":
1103
+ formatted_texts.append("[OCR FAILED]")
1104
+ ocr_failed_indices.add(i)
1105
+ else:
1106
+ formatted_texts.append(text)
1107
+
1108
+ ocr_input_section = """
1109
+ ## INPUT DATA
1110
+ """
1111
+ for i, text in enumerate(formatted_texts):
1112
+ ocr_input_section += f"{i + 1}: {text}\n"
1113
+
1114
+ full_page_context = (
1115
+ "A full-page image is also provided for visual and narrative context."
1116
+ if (
1117
+ config.ocr_method != "manga-ocr"
1118
+ and config.send_full_page_context
1119
+ and full_image_b64
1120
+ )
1121
+ else ""
1122
+ )
1123
+
1124
+ special_instructions_section = _format_special_instructions(config)
1125
+
1126
+ translation_prompt = f"""
1127
+ ## CONTEXT
1128
+ You have been provided with a list of {total_elements} transcribed text segments from a manga page. {full_page_context}
1129
+ {context_hints}
1130
+
1131
+ {ocr_input_section}
1132
+
1133
+ ## TASK
1134
+ Apply your translation and styling rules to the text in the `## INPUT DATA` section.
1135
+ The target language is {output_language}. Use the appropriate translation approach for each text type.{special_instructions_section}
1136
+ """ # noqa
1137
+
1138
+ translation_parts = []
1139
+ if (
1140
+ config.ocr_method != "manga-ocr"
1141
+ and config.send_full_page_context
1142
+ and full_image_b64
1143
+ ):
1144
+ context_part = {
1145
+ "inline_data": {
1146
+ "mime_type": full_image_mime_type,
1147
+ "data": full_image_b64,
1148
+ }
1149
+ }
1150
+ if is_gemini_3:
1151
+ context_part = _add_media_resolution_to_part(
1152
+ context_part, config.media_resolution_context, is_gemini_3
1153
+ )
1154
+ translation_parts.append(context_part)
1155
+
1156
+ translation_system = _build_system_prompt_translation(
1157
+ output_language,
1158
+ mode="two-step",
1159
+ reading_direction=reading_direction,
1160
+ full_page_context=(
1161
+ config.send_full_page_context and bool(full_image_b64)
1162
+ ),
1163
+ )
1164
+ translation_response_text = _call_llm_endpoint(
1165
+ config,
1166
+ translation_parts,
1167
+ translation_prompt,
1168
+ debug,
1169
+ system_prompt=translation_system,
1170
+ )
1171
+ final_translations = _parse_llm_response_unified(
1172
+ translation_response_text,
1173
+ total_elements,
1174
+ provider + "-Translate",
1175
+ debug,
1176
+ )
1177
+
1178
+ if final_translations is None:
1179
+ log_message("Translation API call failed", always_print=True)
1180
+ combined_results = []
1181
+ for i in range(total_elements):
1182
+ if i in ocr_failed_indices:
1183
+ combined_results.append(f"[{provider}: OCR Failed]")
1184
+ else:
1185
+ combined_results.append(f"[{provider}: Translation failed]")
1186
+ return combined_results
1187
+
1188
+ combined_results = []
1189
+ for i in range(total_elements):
1190
+ if i in ocr_failed_indices:
1191
+ if final_translations[i] == "[OCR FAILED]":
1192
+ combined_results.append("[OCR FAILED]")
1193
+ else:
1194
+ log_message(
1195
+ f"Element {i + 1}: LLM ignored OCR failure instruction",
1196
+ verbose=debug,
1197
+ )
1198
+ combined_results.append("[OCR FAILED]")
1199
+ else:
1200
+ combined_results.append(final_translations[i])
1201
+
1202
+ cache.set_translation(cache_key, combined_results)
1203
+ return combined_results
1204
+
1205
+ elif translation_mode == "one-step":
1206
+ log_message("Starting one-step translation", verbose=debug)
1207
+
1208
+ full_page_context = (
1209
+ "A full-page image is also provided for visual and narrative context."
1210
+ if config.send_full_page_context
1211
+ else ""
1212
+ )
1213
+
1214
+ special_instructions_section = _format_special_instructions(config)
1215
+
1216
+ one_step_prompt = f"""
1217
+ ## CONTEXT
1218
+ You have been provided with {total_elements} individual text images from a manga page. {full_page_context}
1219
+ {context_hints}
1220
+
1221
+ ## TASK
1222
+ For each image, you must perform two steps:
1223
+ 1. **Transcribe:** Extract the original text exactly as it appears.
1224
+ 2. **Translate:** Translate the text you just transcribed into {output_language}, applying your translation and styling rules.{special_instructions_section}
1225
+
1226
+ ## OUTPUT FORMAT
1227
+ You must return your response as a single numbered list with exactly one line per input image.
1228
+ The numbering must correspond to the input image order (1, 2, 3...).
1229
+ Format: `i: <transcribed text> || <translated {output_language} text>`
1230
+ """ # noqa
1231
+
1232
+ one_step_system = _build_system_prompt_translation(
1233
+ output_language,
1234
+ mode="one-step",
1235
+ reading_direction=reading_direction,
1236
+ full_page_context=(
1237
+ config.send_full_page_context and bool(full_image_b64)
1238
+ ),
1239
+ )
1240
+ response_text = _call_llm_endpoint(
1241
+ config,
1242
+ base_parts,
1243
+ one_step_prompt,
1244
+ debug,
1245
+ system_prompt=one_step_system,
1246
+ )
1247
+
1248
+ # Parse one-step format ("Original || Translated")
1249
+ raw_lines = _parse_llm_response_unified(
1250
+ response_text, total_elements, provider, debug
1251
+ )
1252
+
1253
+ translations = []
1254
+ for line in raw_lines:
1255
+ if "||" in line:
1256
+ parts = line.split("||", 1)
1257
+ translations.append(parts[1].strip())
1258
+ else:
1259
+ translations.append(line)
1260
+
1261
+ cache.set_translation(cache_key, translations)
1262
+ return translations
1263
+ else:
1264
+ raise TranslationError(
1265
+ f"Unknown translation_mode specified in config: {translation_mode}"
1266
+ )
1267
+ except TranslationError:
1268
+ raise
1269
+ except (ValueError, RuntimeError) as e:
1270
+ log_message(f"Translation error: {e}", always_print=True)
1271
+ return [f"[Translation Error: {e}]"] * total_elements
1272
+
1273
+
1274
+ def prepare_bubble_images_for_translation(
1275
+ bubble_data: List[Dict[str, Any]],
1276
+ original_cv_image: np.ndarray,
1277
+ upscale_model: Any,
1278
+ device: Any,
1279
+ mime_type: str,
1280
+ bubble_min_side_pixels: int,
1281
+ upscale_method: str = "model_lite",
1282
+ verbose: bool = False,
1283
+ ) -> List[Dict[str, Any]]:
1284
+ """
1285
+ Prepare bubble images for translation by cropping, upscaling, color matching, and encoding.
1286
+
1287
+ This function processes each speech bubble to prepare it for the translation API:
1288
+ 1. Crops the bubble from the original image
1289
+ 2. Upscales the bubble to meet minimum size requirements (based on upscale_method)
1290
+ 3. Matches colors to preserve visual consistency (only for model upscaling)
1291
+ 4. Encodes the processed bubble as base64 for API transmission
1292
+
1293
+ Args:
1294
+ bubble_data: List of bubble detection dicts with 'bbox' keys
1295
+ original_cv_image: OpenCV image array of the original image
1296
+ upscale_model: Loaded upscaling model
1297
+ device: PyTorch device for model inference
1298
+ mime_type: MIME type for image encoding
1299
+ upscale_method: Method for upscaling - "model", "lanczos", or "none"
1300
+ verbose: Whether to print detailed logging
1301
+
1302
+ Returns:
1303
+ List of bubble dicts with added 'image_b64' and 'mime_type' keys
1304
+ (immutable approach - returns new list without mutating input)
1305
+ """
1306
+ cv2_ext = ".png" if mime_type == "image/png" else ".jpg"
1307
+
1308
+ prepared_bubbles = []
1309
+
1310
+ if upscale_method == "model":
1311
+ log_message(
1312
+ f"Upscaling {len(bubble_data)} bubble images with 2x-AnimeSharpV4_RCAN",
1313
+ always_print=True,
1314
+ )
1315
+ elif upscale_method == "model_lite":
1316
+ log_message(
1317
+ f"Upscaling {len(bubble_data)} bubble images with 2x-AnimeSharpV4_Fast_RCAN_PU (Lite)",
1318
+ always_print=True,
1319
+ )
1320
+ elif upscale_method == "lanczos":
1321
+ log_message(
1322
+ f"Upscaling {len(bubble_data)} bubble images with LANCZOS",
1323
+ always_print=True,
1324
+ )
1325
+ else: # upscale_method == "none"
1326
+ log_message(
1327
+ f"Processing {len(bubble_data)} bubble images without upscaling",
1328
+ always_print=True,
1329
+ )
1330
+
1331
+ for bubble in bubble_data:
1332
+ prepared_bubble = bubble.copy()
1333
+ x1, y1, x2, y2 = bubble["bbox"]
1334
+
1335
+ bubble_image_cv = original_cv_image[y1:y2, x1:x2].copy()
1336
+ bubble_image_pil = cv2_to_pil(bubble_image_cv)
1337
+
1338
+ if upscale_method == "model" or upscale_method == "model_lite":
1339
+ final_bubble_pil = process_bubble_image_cached(
1340
+ bubble_image_pil,
1341
+ upscale_model,
1342
+ device,
1343
+ bubble_min_side_pixels,
1344
+ "min",
1345
+ upscale_method,
1346
+ verbose,
1347
+ )
1348
+ elif upscale_method == "lanczos":
1349
+ w, h = bubble_image_pil.size
1350
+ min_side = min(w, h)
1351
+ if min_side < bubble_min_side_pixels:
1352
+ scale_factor = bubble_min_side_pixels / min_side
1353
+ new_w = int(w * scale_factor)
1354
+ new_h = int(h * scale_factor)
1355
+ resized_bubble = bubble_image_pil.resize((new_w, new_h), Image.LANCZOS)
1356
+ else:
1357
+ resized_bubble = bubble_image_pil
1358
+ final_bubble_pil = resized_bubble
1359
+ else: # upscale_method == "none"
1360
+ final_bubble_pil = bubble_image_pil
1361
+
1362
+ final_bubble_cv = pil_to_cv2(final_bubble_pil)
1363
+
1364
+ try:
1365
+ is_success, buffer = cv2.imencode(cv2_ext, final_bubble_cv)
1366
+ if is_success:
1367
+ image_b64 = base64.b64encode(buffer).decode("utf-8")
1368
+ prepared_bubble["image_b64"] = image_b64
1369
+ prepared_bubble["mime_type"] = mime_type
1370
+ log_message(
1371
+ f"Bubble {x1},{y1} ({final_bubble_pil.size[0]}x{final_bubble_pil.size[1]})",
1372
+ verbose=verbose,
1373
+ )
1374
+ else:
1375
+ log_message(
1376
+ f"Failed to encode bubble {bubble['bbox']}", verbose=verbose
1377
+ )
1378
+ prepared_bubble["image_b64"] = None
1379
+ except Exception as e:
1380
+ log_message(f"Error encoding bubble {bubble['bbox']}: {e}", verbose=verbose)
1381
+ prepared_bubble["image_b64"] = None
1382
+
1383
+ prepared_bubbles.append(prepared_bubble)
1384
+
1385
+ return prepared_bubbles
core/text/__init__.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Text processing and rendering modules for MangaTranslator.
3
+
4
+ This subpackage contains modules for:
5
+ - Text processing and tokenization
6
+ - Font management and loading
7
+ - Layout engine for optimal text placement
8
+ - Drawing engine using Skia
9
+ - High-level text rendering orchestration
10
+ """
11
+
12
+ from .drawing_engine import (
13
+ draw_layout,
14
+ load_font_resources,
15
+ pil_to_skia_surface,
16
+ skia_surface_to_pil,
17
+ )
18
+ from .font_manager import (
19
+ LRUCache,
20
+ find_font_variants,
21
+ get_font_features,
22
+ load_font_data,
23
+ )
24
+ from .layout_engine import find_optimal_layout, shape_line
25
+ from .text_processing import (
26
+ find_optimal_breaks_dp,
27
+ parse_styled_segments,
28
+ tokenize_styled_text,
29
+ try_hyphenate_word,
30
+ )
31
+ from .text_renderer import render_text_skia
32
+
33
+ __all__ = [
34
+ "draw_layout",
35
+ "load_font_resources",
36
+ "pil_to_skia_surface",
37
+ "skia_surface_to_pil",
38
+ "find_font_variants",
39
+ "get_font_features",
40
+ "LRUCache",
41
+ "load_font_data",
42
+ "find_optimal_layout",
43
+ "shape_line",
44
+ "find_optimal_breaks_dp",
45
+ "parse_styled_segments",
46
+ "tokenize_styled_text",
47
+ "try_hyphenate_word",
48
+ "render_text_skia",
49
+ ]
core/text/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (1.37 kB). View file