"""Seat/section extraction API.

Pipeline: optional BiRefNet (ONNX) background removal, a mask of "colored"
pixels (not near-black / near-white), K-means color clustering, contour
based section detection, and per-section mean color extraction.
"""

from fastapi import FastAPI, File, UploadFile, HTTPException, Query, BackgroundTasks
import numpy as np
import cv2
import uvicorn
from PIL import Image
import io
from typing import List, Dict, Any, Optional, Tuple
from pydantic import BaseModel
import logging
from pathlib import Path
import time
import hashlib
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from concurrent.futures import ThreadPoolExecutor
from collections import defaultdict
from dataclasses import dataclass, field
import warnings
import torch
from torchvision import transforms
import onnxruntime as ort
from sklearn.cluster import KMeans
import os

# NOTE(review): this is set after numpy/torch/sklearn were imported; whether
# the native thread pools still honor it depends on when each library reads
# the variable - confirm, and move above the imports if it is not effective.
os.environ["OMP_NUM_THREADS"] = "1"

warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="Seat Extraction API v9.0 (No OCR)",
    description="BG removal → Section detection )",
    version="9.0.0"
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

CACHE_DIR = Path("cache")
CACHE_DIR.mkdir(exist_ok=True)

# In-memory result cache: key -> PolygonResponse. Evicted oldest-first once
# MAX_CACHE_SIZE entries are stored.
RESULTS_CACHE = {}
MAX_CACHE_SIZE = 100

# Module-wide extractor instance, created by the startup handler.
extractor = None


# Response schema of the extraction endpoint. The list fields are built in
# parallel, one entry per detected section.
class PolygonResponse(BaseModel):
    polygons: List[List[List[float]]]
    confidence_scores: List[float]
    areas: List[float]
    bounding_boxes: List[List[float]]
    labels: List[str]
    colors: List[str]
    seat_groups: Dict[str, List[int]]
    processing_info: Dict[str, Any]
    cache_hit: bool = False
    detected_text: List[Dict[str, Any]] = []
    geojson: Optional[Dict[str, Any]] = None


@dataclass
class OptimizationConfig:
    """Configuration for seat detection (OCR removed)"""
    use_background_removal: bool = True

    # Color detection
    exclude_pure_black: bool = True
    exclude_pure_white: bool = True
    use_color_clustering: bool = True
    n_color_clusters: int = 20

    # Detection thresholds
    min_section_area: int = 500
    max_section_area: int = 50000
    min_solidity: float = 0.3

    # Morphology
    morphology_kernel_size: int = 3


class BackgroundRemover:
    """Background removal using BiRefNet ONNX"""

    def __init__(self):
        # The ONNX session is created lazily by load_model(); until then
        # remove_background() passes images through unchanged.
        self.session = None
        self.input_name = None
        self.output_name = None
        self.transform = transforms.Compose([
            transforms.Resize((1024, 1024)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def load_model(self):
        """Create the ONNX Runtime session if it does not exist yet.

        Any failure is logged and leaves ``self.session`` as ``None`` so the
        service keeps running without background removal.
        """
        if self.session is not None:
            return
        try:
            providers = []
            if ort.get_device() == 'GPU' and 'CUDAExecutionProvider' in ort.get_available_providers():
                providers.append('CUDAExecutionProvider')
            providers.append('CPUExecutionProvider')

            model_path = "models/BiRefNet.onnx"
            self.session = ort.InferenceSession(model_path, providers=providers)
            self.input_name = self.session.get_inputs()[0].name
            self.output_name = self.session.get_outputs()[0].name
            logger.info(f"BiRefNet loaded: {self.session.get_providers()}")
        except Exception as e:
            logger.error(f"BiRefNet load failed: {e}")
            self.session = None

    def remove_background(self, image: Image.Image) -> Tuple[Image.Image, Optional[np.ndarray]]:
        """Black out the background of ``image`` using the model's mask.

        Args:
            image: Input PIL image; converted to RGB when it is in another mode.

        Returns:
            ``(processed_image, mask)``. ``mask`` is a uint8 array at the
            image's size, or ``None`` when no model is loaded or inference
            failed, in which case the RGB image is returned unmodified.
        """
        if image.mode != 'RGB':
            image = image.convert('RGB')
        if self.session is None:
            return image, None

        image_size = image.size
        input_numpy = self.transform(image).unsqueeze(0).numpy()

        try:
            outputs = self.session.run([self.output_name], {self.input_name: input_numpy})
            pred_numpy = outputs[0][0]
            # Sigmoid over the raw model output.
            pred_numpy = 1 / (1 + np.exp(-pred_numpy))
            if len(pred_numpy.shape) == 3:
                pred_numpy = pred_numpy[0]
            pred_numpy = (pred_numpy * 255).astype(np.uint8)
            pred_pil = Image.fromarray(pred_numpy, mode='L')
            mask = pred_pil.resize(image_size)
        except Exception as e:
            logger.error(f"ONNX inference failed: {e}")
            return image, None

        mask_np = np.array(mask)
        if len(mask_np.shape) == 3:
            mask_np = mask_np[:, :, 0]

        image_array = np.array(image)
        if len(image_array.shape) == 2:
            image_array = cv2.cvtColor(image_array, cv2.COLOR_GRAY2RGB)
        elif image_array.shape[2] == 4:
            image_array = cv2.cvtColor(image_array, cv2.COLOR_RGBA2RGB)

        # Scale every channel by the mask (0..1); broadcasting over the
        # channel axis replaces the per-channel loop.
        mask_normalized = mask_np.astype(np.float32) / 255.0
        masked_array = (image_array[:, :, :3] * mask_normalized[:, :, None]).astype(np.uint8)

        processed_image = Image.fromarray(masked_array)
        return processed_image, mask_np


class SmartColorDetector:
    """Detect all colors except pure black/white"""

    def __init__(self, config: OptimizationConfig):
        self.config = config

    def create_valid_color_mask(self, image: np.ndarray) -> np.ndarray:
        """Create mask for all colored pixels (not pure black/white)

        Args:
            image: RGB image array.

        Returns:
            uint8 mask of the image's height/width: 255 for kept pixels,
            0 for pixels excluded as near-black (V < 20) or near-white
            (V > 235 and S < 25), depending on the config flags.
        """
        hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
        _, s, v = cv2.split(hsv)

        valid_mask = np.ones(image.shape[:2], dtype=np.uint8) * 255

        if self.config.exclude_pure_black:
            black_mask = v < 20
            valid_mask[black_mask] = 0
            logger.info(f"Excluded {np.sum(black_mask)} pure black pixels")

        if self.config.exclude_pure_white:
            white_mask = (v > 235) & (s < 25)
            valid_mask[white_mask] = 0
            logger.info(f"Excluded {np.sum(white_mask)} pure white pixels")

        logger.info(f"Valid colored pixels: {np.sum(valid_mask > 0)}")
        return valid_mask

    def cluster_colors(self, image: np.ndarray, valid_mask: np.ndarray) -> List[np.ndarray]:
        """Group similar colors using K-means clustering

        Args:
            image: RGB image array.
            valid_mask: Mask whose non-zero pixels take part in clustering.

        Returns:
            One uint8 (0/255) mask per retained cluster. Falls back to
            ``[valid_mask]`` when there are too few pixels, fewer than two
            clusters, or clustering raises.
        """
        masks = []

        valid_pixels = image[valid_mask > 0]
        if len(valid_pixels) < 100:
            logger.warning("Not enough valid pixels for clustering")
            return [valid_mask]

        pixels_flat = valid_pixels.reshape(-1, 3).astype(np.float32)
        n_clusters = min(self.config.n_color_clusters, len(pixels_flat) // 100)
        if n_clusters < 2:
            return [valid_mask]

        logger.info(f"Clustering into {n_clusters} color groups...")

        try:
            kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
            labels = kmeans.fit_predict(pixels_flat)
            centers = kmeans.cluster_centers_.astype(np.uint8)

            # Row-major (row, col) coordinates, in the same order as the
            # boolean-indexed pixels that produced `labels`.
            pixel_coords = np.argwhere(valid_mask > 0)
            kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))

            for cluster_id in range(n_clusters):
                cluster_pixels = pixel_coords[labels == cluster_id]
                if len(cluster_pixels) < 50:
                    continue

                # Vectorized fill instead of a per-pixel Python loop.
                cluster_mask = np.zeros(image.shape[:2], dtype=np.uint8)
                cluster_mask[cluster_pixels[:, 0], cluster_pixels[:, 1]] = 255

                cluster_mask = cv2.morphologyEx(cluster_mask, cv2.MORPH_CLOSE, kernel, iterations=2)
                cluster_mask = cv2.morphologyEx(cluster_mask, cv2.MORPH_OPEN, kernel, iterations=1)

                # NOTE(review): the mask holds 0/255, so this sum exceeds 100
                # as soon as a single pixel survives - confirm whether a
                # pixel-count threshold was intended.
                if np.sum(cluster_mask) > 100:
                    masks.append(cluster_mask)
                    logger.info(f"  Cluster {cluster_id}: {cv2.countNonZero(cluster_mask)} pixels, "
                                f"center color: {centers[cluster_id]}")
        except Exception as e:
            logger.error(f"Clustering failed: {e}")
            return [valid_mask]

        return masks


class EnhancedSeatExtractor:
    """Runs the full section-extraction pipeline on an image array."""

    def __init__(self, config: Optional[OptimizationConfig] = None):
        # A fresh config per instance: a dataclass instance used as the
        # default argument would be shared (and mutated) across extractors.
        self.config = config if config is not None else OptimizationConfig()
        self.executor = ThreadPoolExecutor(max_workers=4)
        self.bg_remover = BackgroundRemover()
        self.color_detector = SmartColorDetector(self.config)
        logger.info("Enhanced Extractor initialized")

    def compute_image_hash(self, image: np.ndarray) -> str:
        """Return the MD5 hex digest of the image's raw bytes."""
        return hashlib.md5(image.tobytes()).hexdigest()

    def _cache_key(self, image: np.ndarray) -> str:
        """Build a cache key covering pixel data, array layout and config.

        The raw-byte hash alone collides for arrays of different shape/dtype
        and ignores the per-request settings stored on ``self.config``.
        """
        raw = f"{self.compute_image_hash(image)}|{image.shape}|{image.dtype}|{self.config!r}"
        return hashlib.md5(raw.encode("utf-8")).hexdigest()

    def extract_dominant_color(self, image: np.ndarray, contour: np.ndarray) -> str:
        """Extract the dominant color inside a contour and convert it to HEX.

        Args:
            image: RGB image array.
            contour: OpenCV contour delimiting the region.

        Returns:
            ``"#rrggbb"`` of the mean color of the pixels inside the filled
            contour, or ``"#808080"`` when the contour covers no pixel.
        """
        # Build a filled mask for the contour.
        mask = np.zeros(image.shape[:2], dtype=np.uint8)
        cv2.drawContours(mask, [contour], 0, 255, -1)

        # Collect the pixels inside the contour region.
        pixels = image[mask > 0]
        if len(pixels) == 0:
            return "#808080"  # default gray

        # Mean color, then RGB -> HEX.
        mean_color = np.mean(pixels, axis=0).astype(int)
        return "#{:02x}{:02x}{:02x}".format(
            int(mean_color[0]), int(mean_color[1]), int(mean_color[2])
        )

    def detect_sections_in_mask(self, mask: np.ndarray, image: np.ndarray) -> List[Dict]:
        """Detect sections from a color mask and extract their colors.

        Args:
            mask: uint8 mask of candidate pixels.
            image: RGB image used to sample each section's color.

        Returns:
            One dict per accepted contour with keys ``contour``, ``bbox``
            (``[x1, y1, x2, y2]``), ``area``, ``confidence``, ``center``,
            ``solidity`` and ``color``.
        """
        sections = []

        # NOTE(review): the sum is over 0/255 values, not a pixel count, so
        # this only rejects nearly empty masks - confirm the intended unit.
        if np.sum(mask) < self.config.min_section_area:
            return sections

        kernel = cv2.getStructuringElement(
            cv2.MORPH_ELLIPSE,
            (self.config.morphology_kernel_size, self.config.morphology_kernel_size)
        )
        cleaned_mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel, iterations=2)

        contours, _ = cv2.findContours(cleaned_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        for contour in contours:
            area = cv2.contourArea(contour)
            if area < self.config.min_section_area or area > self.config.max_section_area:
                continue

            # Solidity = contour area / convex hull area.
            hull_area = cv2.contourArea(cv2.convexHull(contour))
            solidity = area / hull_area if hull_area > 0 else 0
            if solidity < self.config.min_solidity:
                continue

            epsilon = 0.01 * cv2.arcLength(contour, True)
            approx = cv2.approxPolyDP(contour, epsilon, True)
            if len(approx) < 3:
                continue

            x, y, w, h = cv2.boundingRect(contour)

            # Extract the dominant color of the simplified polygon.
            dominant_color = self.extract_dominant_color(image, approx)

            sections.append({
                'contour': approx,
                'bbox': [x, y, x + w, y + h],
                'area': area,
                'confidence': min(1.0, solidity),
                'center': (x + w // 2, y + h // 2),
                'solidity': solidity,
                'color': dominant_color
            })

        return sections

    def extract_polygons_enhanced(self, image: np.ndarray) -> PolygonResponse:
        """
        PIPELINE:
        1. Background removal for section detection
        2. Color detection & clustering
        3. Section detection + Color extraction

        Args:
            image: Image array (grayscale, RGB or RGBA).

        Returns:
            The populated ``PolygonResponse``. Results are cached in
            ``RESULTS_CACHE`` per image and configuration; a cached response
            is returned with ``cache_hit`` set to ``True``.
        """
        start_time = time.time()

        cache_key = self._cache_key(image)
        if cache_key in RESULTS_CACHE:
            logger.info("Returning cached results")
            cached_result = RESULTS_CACHE[cache_key]
            cached_result.cache_hit = True
            return cached_result

        # Ensure RGB
        if len(image.shape) == 2:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        elif len(image.shape) == 3:
            if image.shape[2] == 4:
                image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)

        # Step 1: Background Removal
        processed_image = image
        if self.config.use_background_removal:
            logger.info("Removing background for section detection...")
            pil_image = Image.fromarray(image).convert('RGB')
            bg_removed, _ = self.bg_remover.remove_background(pil_image)
            processed_image = np.array(bg_removed)

            if len(processed_image.shape) != 3 or processed_image.shape[2] != 3:
                if len(processed_image.shape) == 2:
                    processed_image = cv2.cvtColor(processed_image, cv2.COLOR_GRAY2RGB)

        # Step 2: Smart Color Detection
        logger.info("Detecting all colors (excluding black/white)...")
        valid_color_mask = self.color_detector.create_valid_color_mask(processed_image)

        # Step 3: Cluster Colors & Detect Sections
        all_sections = []
        if self.config.use_color_clustering:
            logger.info("Clustering colors...")
            color_masks = self.color_detector.cluster_colors(processed_image, valid_color_mask)
            logger.info(f"Found {len(color_masks)} color groups")

            for i, mask in enumerate(color_masks):
                logger.info(f"Processing color group {i + 1}/{len(color_masks)}...")
                sections = self.detect_sections_in_mask(mask, processed_image)
                for section in sections:
                    section['color_group'] = i
                all_sections.extend(sections)
                logger.info(f"  Found {len(sections)} sections in group {i}")
        else:
            all_sections = self.detect_sections_in_mask(valid_color_mask, processed_image)

        # Step 4: Remove overlapping sections
        filtered_sections = self.remove_overlapping_sections(all_sections)

        # Convert to response format (parallel lists, one entry per section).
        polygons = []
        confidence_scores = []
        areas = []
        bounding_boxes = []
        labels = []
        colors = []

        for i, section in enumerate(filtered_sections):
            polygons.append(section['contour'].reshape(-1, 2).tolist())
            confidence_scores.append(section['confidence'])
            areas.append(section['area'])
            bounding_boxes.append(section['bbox'])
            labels.append(f"Section_{i + 1}")
            colors.append(section['color'])

        seat_groups = self.group_sections(filtered_sections)
        processing_time = time.time() - start_time
        geojson_output = self.to_geojson(filtered_sections)

        response = PolygonResponse(
            polygons=polygons,
            confidence_scores=confidence_scores,
            areas=areas,
            bounding_boxes=bounding_boxes,
            labels=labels,
            colors=colors,
            seat_groups=seat_groups,
            detected_text=[],
            processing_info={
                "total_sections": len(polygons),
                "total_text_regions": 0,
                "sections_with_text": 0,
                "vietnamese_text": 0,
                "english_text": 0,
                "processing_time": processing_time,
                "ocr_engine": "None (Disabled for performance)",
                "pipeline": "BG Removal → Color Detection → Section Detection + Color Extraction",
                "techniques": [
                    "BiRefNet BG removal for section detection",
                    "Smart color detection (exclude black/white)",
                    "K-means color clustering",
                    "Morphological cleaning",
                    "Dominant color extraction (HEX format)"
                ]
            },
            cache_hit=False,
            geojson=geojson_output
        )

        # Dicts preserve insertion order, so this evicts the oldest entry.
        if len(RESULTS_CACHE) >= MAX_CACHE_SIZE:
            RESULTS_CACHE.pop(next(iter(RESULTS_CACHE)))
        RESULTS_CACHE[cache_key] = response

        return response

    def remove_overlapping_sections(self, sections: List[Dict]) -> List[Dict]:
        """Greedy suppression of overlapping sections.

        Sections are visited by descending confidence; one is kept unless its
        bounding box has IoU > 0.5 with an already kept section.
        """
        if not sections:
            return sections

        sorted_sections = sorted(sections, key=lambda x: x['confidence'], reverse=True)
        filtered = []

        for section in sorted_sections:
            overlaps_kept = any(
                self.calculate_overlap(section['bbox'], accepted['bbox']) > 0.5
                for accepted in filtered
            )
            if not overlaps_kept:
                filtered.append(section)

        return filtered

    def calculate_overlap(self, bbox1: List, bbox2: List) -> float:
        """Return the intersection-over-union of two ``[x1, y1, x2, y2]`` boxes."""
        x1_1, y1_1, x2_1, y2_1 = bbox1
        x1_2, y1_2, x2_2, y2_2 = bbox2

        x1_int = max(x1_1, x1_2)
        y1_int = max(y1_1, y1_2)
        x2_int = min(x2_1, x2_2)
        y2_int = min(y2_1, y2_2)

        # Disjoint or merely touching boxes do not overlap.
        if x2_int <= x1_int or y2_int <= y1_int:
            return 0.0

        intersection = (x2_int - x1_int) * (y2_int - y1_int)
        area1 = (x2_1 - x1_1) * (y2_1 - y1_1)
        area2 = (x2_2 - x1_2) * (y2_2 - y1_2)
        union = area1 + area2 - intersection

        return intersection / union if union > 0 else 0.0

    def group_sections(self, sections: List[Dict]) -> Dict[str, List[int]]:
        """Map ``"ColorGroup_<id>"`` to the indices of the sections in that group.

        Sections without a ``color_group`` key are placed in group 0.
        """
        groups = defaultdict(list)
        for idx, section in enumerate(sections):
            group_id = section.get('color_group', 0)
            groups[f"ColorGroup_{group_id}"].append(idx)
        return dict(groups)

    def to_geojson(self, sections: List[Dict]) -> Dict[str, Any]:
        """Serialize sections as a GeoJSON ``FeatureCollection`` of polygons.

        Coordinates are the contour's pixel coordinates. Each exterior ring
        is closed (first position repeated at the end), as RFC 7946 requires
        for linear rings.
        """
        features = []
        for section in sections:
            ring = [list(map(float, p)) for p in section['contour'].reshape(-1, 2).tolist()]
            if ring and ring[0] != ring[-1]:
                ring.append(list(ring[0]))

            features.append({
                "type": "Feature",
                "properties": {
                    "confidence": section.get("confidence"),
                    "area": section.get("area"),
                    "color_group": section.get("color_group"),
                    "color": section.get("color")
                },
                "geometry": {
                    "type": "Polygon",
                    "coordinates": [ring]
                }
            })
        return {
            "type": "FeatureCollection",
            "features": features
        }


@app.on_event("startup")
async def startup_event():
    # Builds the module-wide extractor and loads the ONNX model. Failures are
    # logged, not raised; if the extractor could not be built it stays None
    # and the endpoint answers 503.
    global extractor
    try:
        config = OptimizationConfig(
            use_background_removal=True,
            exclude_pure_black=True,
            exclude_pure_white=True,
            use_color_clustering=True,
            n_color_clusters=20,
            min_section_area=500,
            max_section_area=50000
        )

        extractor = EnhancedSeatExtractor(config)
        logger.info("Loading BiRefNet...")
        extractor.bg_remover.load_model()
        logger.info("System initialized successfully")
    except Exception as e:
        # logger.exception records the traceback alongside the message.
        logger.exception(f"Initialization failed: {e}")


@app.post("/extract-seats/", response_model=PolygonResponse)
async def extract_seats_endpoint(
    file: UploadFile = File(...),
    use_background_removal: bool = Query(True),
    use_clustering: bool = Query(True),
    n_clusters: int = Query(20, ge=2, le=50)
):
    """
    Extract sections with color information

    PIPELINE:
    1. Background removal for section detection
    2. Color detection & clustering
    3. Section detection + Color extraction

    Response includes:
    - colors: List of HEX color strings for each section
    """
    if extractor is None:
        raise HTTPException(status_code=503, detail="System not initialized")

    # content_type may be missing on the upload; treat that as "not an image"
    # instead of failing with an AttributeError.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="Must be an image")

    try:
        contents = await file.read()
        # Normalize every PIL mode (palette, CMYK, LA, 1-bit, ...) to RGB so
        # the array is always HxWx3 with real color values.
        with Image.open(io.BytesIO(contents)) as image:
            image_array = np.array(image.convert("RGB"))

        # Per-request settings are written onto the shared config; the cache
        # key includes the config, so results are not mixed across settings.
        extractor.config.use_background_removal = use_background_removal
        extractor.config.use_color_clustering = use_clustering
        extractor.config.n_color_clusters = n_clusters

        return extractor.extract_polygons_enhanced(image_array)

    except Exception as e:
        logger.exception(f"Processing failed: {e}")
        raise HTTPException(status_code=500, detail=f"Failed: {str(e)}")


if __name__ == "__main__":
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=int(os.environ.get("PORT", 7860)),
        reload=False,
        log_level="info"
    )