import gradio as gr
import numpy as np
import tempfile
import os
import logging
from pathlib import Path
import cv2
from PIL import Image
import matplotlib.pyplot as plt
import io
from typing import Optional

from corrupt_mask import MaskCorruptor

logger = logging.getLogger(__name__)

# Bundled example mask. Looked up relative to the current working directory
# first (where the app has always looked for it) and then next to this file.
EXAMPLE_MASK_PATH = Path("in") / "semantic_class_0.png"

# Text shown in the statistics panel before any result / after clearing.
STATS_PLACEHOLDER = "## 📊 Corruption Statistics:\n\n*Upload a mask to see results here...*"

# Stylesheet for the app; handed to both gr.Blocks(...) and app.launch(...).
APP_CSS = """
.gradio-container {
    max-width: 1200px !important;
}
.output-image {
    border: 2px solid #4CAF50;
    border-radius: 10px;
}
.stats-box {
    background-color: #f0f8ff;
    padding: 15px;
    border-radius: 10px;
    border-left: 5px solid #2196F3;
}
"""


class GradioMaskCorruptor:
    """Wrapper for MaskCorruptor with Gradio-specific functionality."""

    def __init__(self):
        # Corruptor and input mask from the most recent request. One instance
        # is shared by every Gradio session, so these reflect the last call.
        self.corruptor = None
        self.original_mask = None

    @staticmethod
    def _to_label_array(input_image) -> np.ndarray:
        """Convert an uploaded image (PIL image or ndarray) to a 2-D array.

        Raises:
            ValueError: if the input cannot be interpreted as a 2-D mask.
        """
        mask = input_image if isinstance(input_image, np.ndarray) else np.array(input_image)
        if mask.ndim == 3:
            channels = mask.shape[2]
            if channels in (3, 4):
                # Colour input: collapse to a single grayscale channel.
                mask = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
            elif channels in (1, 2):
                # Single channel (optionally with alpha): cvtColor's RGB2GRAY
                # only accepts 3/4 channels, so keep the intensity channel.
                mask = mask[..., 0]
            else:
                raise ValueError(f"Unsupported number of channels: {channels}")
        if mask.ndim != 2:
            raise ValueError(f"Expected a 2-D mask, got array of shape {mask.shape}")
        return mask

    def visualize_masks(self, original_mask, corrupted_mask, colormap='viridis'):
        """Render original, corrupted and difference masks side by side.

        Args:
            original_mask: 2-D array of input labels.
            corrupted_mask: 2-D array of output labels, same shape as the input.
            colormap: Matplotlib colormap name for the two mask panels.

        Returns:
            A PIL image containing the rendered PNG figure.
        """
        fig, axes = plt.subplots(1, 3, figsize=(12, 4))
        try:
            # Pixels whose label changed are 1, unchanged pixels are 0.
            diff = np.where(corrupted_mask != original_mask, 1, 0)
            panels = (
                (original_mask, colormap, 'Original Mask', {}),
                (corrupted_mask, colormap, 'Corrupted Mask (Output)', {}),
                (diff, 'Reds', 'Diff (Red)', {'vmin': 0, 'vmax': 1}),
            )
            # Use the figure's own methods instead of plt.* so drawing never
            # depends on pyplot's global "current figure".
            for ax, (data, cmap, title, imshow_kwargs) in zip(axes, panels):
                im = ax.imshow(data, cmap=cmap, **imshow_kwargs)
                ax.set_title(title)
                ax.axis('off')
                fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

            fig.tight_layout()
            buf = io.BytesIO()
            fig.savefig(buf, format='png', dpi=100, bbox_inches='tight')
        finally:
            # Always release the figure, even if drawing failed.
            plt.close(fig)

        buf.seek(0)
        img = Image.open(buf)
        img.load()  # decode now rather than lazily on first access
        return img

    def process_single_mask(self, input_image, drop_probability: float,
                            mislabel_probability: float, blur_level: float,
                            max_label: Optional[int], preserve_background: bool,
                            seed: Optional[int], colormap: str):
        """Corrupt one uploaded mask and describe the result.

        Args:
            input_image: Uploaded mask as a PIL image or numpy array; None when
                nothing has been uploaded.
            drop_probability: Forwarded to ``MaskCorruptor(drop_probability=...)``.
            mislabel_probability: Forwarded to
                ``MaskCorruptor(mislabel_probability=...)``.
            blur_level: Forwarded as ``MaskCorruptor(boundary_blur_level=...)``.
            max_label: Forwarded to ``corrupt_single_mask``; 0, negative or
                None means "auto" and is passed on as None.
            preserve_background: Forwarded to ``corrupt_single_mask``.
            seed: RNG seed; 0 or None means "random" and is passed on as None.
            colormap: Matplotlib colormap name used for the visualization.

        Returns:
            ``(visualization_image, markdown_stats)`` on success, otherwise
            ``(None, error_message)``.
        """
        if input_image is None:
            return None, "❌ Please upload a mask image (or load the example) first."

        try:
            mask = self._to_label_array(input_image)
            self.original_mask = mask.copy()

            # 0 / empty number fields mean "random seed" and "auto max label".
            effective_seed = int(seed) if seed else None
            effective_max_label = int(max_label) if max_label and max_label > 0 else None

            # Work on a local corruptor so this request always uses its own
            # parameters, even if another request replaces self.corruptor.
            corruptor = MaskCorruptor(
                drop_probability=drop_probability,
                mislabel_probability=mislabel_probability,
                boundary_blur_level=blur_level,
                seed=effective_seed,
            )
            self.corruptor = corruptor

            corrupted_mask = corruptor.corrupt_single_mask(
                mask=mask,
                max_label=effective_max_label,
                preserve_background=preserve_background,
            )

            original_labels = np.unique(mask)
            corrupted_labels = np.unique(corrupted_mask)
            # Count non-background (non-zero) labels present in the input but
            # absent from the output. Comparing unique-label counts and then
            # subtracting one for the background under-reported by one whenever
            # the background label survived in both masks.
            # NOTE(review): a label that disappears through mislabeling rather
            # than dropping is counted here as well.
            vanished_labels = np.setdiff1d(original_labels, corrupted_labels)
            dropped_instances = int(np.count_nonzero(vanished_labels))

            viz_image = self.visualize_masks(mask, corrupted_mask, colormap)

            max_label_text = effective_max_label if effective_max_label is not None else 'Auto'
            seed_text = effective_seed if effective_seed is not None else 'Random'
            stats_text = "\n".join([
                "",
                "## 📊 Corruption Statistics",
                "",
                "### Original Mask:",
                f"- Unique labels: {len(original_labels)}",
                f"- Label values: {original_labels.tolist()}",
                f"- Shape: {mask.shape}",
                "",
                "### Corrupted Mask:",
                f"- Unique labels: {len(corrupted_labels)}",
                f"- Label values: {corrupted_labels.tolist()}",
                f"- Dropped instances: {dropped_instances}",
                f"- Drop probability: {drop_probability * 100:.1f}%",
                f"- Mislabel probability: {mislabel_probability * 100:.1f}%",
                # The blur control is a plain factor (slider 0-16), not a percentage.
                f"- Boundary blur factor: {blur_level:g}",
                "",
                "### Parameters:",
                f"- Preserve background: {preserve_background}",
                f"- Max label: {max_label_text}",
                f"- Random seed: {seed_text}",
                "",
            ])

            return viz_image, stats_text

        except Exception as e:
            # UI boundary: show the error to the user, keep the traceback in logs.
            logger.exception("Failed to corrupt mask")
            return None, f"❌ Error processing image: {str(e)}"

    def create_example_mask(self):
        """Create a synthetic demo mask (256x256, uint8) with labels 1-7.

        Labels 1-5 are 20x20 squares along the diagonal; label 6 is a filled
        circle (drawn over part of square 3) and label 7 a filled ellipse.
        """
        mask = np.zeros((256, 256), dtype=np.uint8)

        for label in range(1, 6):
            start = 30 * label
            mask[start:start + 20, start:start + 20] = label

        # Add some non-rectangular shapes
        cv2.circle(mask, (100, 100), 15, 6, -1)
        cv2.ellipse(mask, (200, 150), (20, 10), 0, 0, 360, 7, -1)

        return Image.fromarray(mask)


def _load_example_mask(corruptor: GradioMaskCorruptor) -> Image.Image:
    """Return the bundled example mask, or a synthetic one if the file is missing."""
    candidates = (EXAMPLE_MASK_PATH, Path(__file__).resolve().parent / EXAMPLE_MASK_PATH)
    for candidate in candidates:
        if candidate.is_file():
            with Image.open(candidate) as img:
                # copy() reads the pixels fully so the file can be closed here.
                return img.copy()
    logger.warning("Example mask %s not found; using a synthetic mask instead.",
                   EXAMPLE_MASK_PATH)
    return corruptor.create_example_mask()


def create_gradio_app():
    """Create and configure the Gradio interface."""
    corruptor = GradioMaskCorruptor()

    # Example mask used by the "Load Example" button and gr.Examples rows.
    example_mask = _load_example_mask(corruptor)

    with gr.Blocks(title="Mask Corruption Tool", css=APP_CSS) as app:
        gr.Markdown("""
        # 🎭 Mask Corruption Tool

        Upload a segmentation mask and artificially corrupt it by:
        1. **Randomly dropping** mask instances
        2. **Assigning wrong labels** to mask instances
        3. **Blurring** mask boundaries
        """)

        with gr.Row():
            with gr.Column(scale=1):
                # Input section
                gr.Markdown("## 📤 Input Settings")

                input_image = gr.Image(
                    label="Upload Mask Image",
                    type="pil",
                    format='png',
                    height=300,
                    image_mode='L',
                    elem_classes=["input-image"],
                )

                with gr.Row():
                    use_example = gr.Button("📋 Load Example Mask", variant="secondary")
                    clear_btn = gr.Button("🗑️ Clear", variant="secondary")

                # Parameters section
                gr.Markdown("## ⚙️ Corruption Parameters")

                drop_prob = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.1,
                    step=0.05,
                    label="Drop Probability",
                    info="Probability of completely removing each mask instance",
                )

                mislabel_prob = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.1,
                    step=0.05,
                    label="Mislabel Probability",
                    info="Probability of assigning wrong label to each instance",
                )

                blur_level = gr.Slider(
                    minimum=0.0,
                    maximum=16.0,
                    value=2,
                    step=0.5,
                    label="Blur Factor",
                    info="Makes mask boundaries less sharp",
                )

                with gr.Row():
                    max_label = gr.Number(
                        value=10,
                        label="Max Label Value",
                        info="Set to 0 for auto-detect",
                        precision=0,
                    )

                    seed = gr.Number(
                        value=42,
                        label="Random Seed",
                        info="Set to 0 for random",
                        precision=0,
                    )

                preserve_bg = gr.Checkbox(
                    value=True,
                    label="Preserve Background (label 0)",
                    info="Keep background label unchanged",
                )

                colormap = gr.Dropdown(
                    choices=['viridis', 'plasma', 'inferno', 'magma', 'cividis',
                             'tab20', 'Set3', 'Set2', 'tab20c'],
                    value='viridis',
                    label="Colormap for Visualization",
                )

                process_btn = gr.Button("✨ Corrupt Mask!", variant="primary")

            with gr.Column(scale=2):
                # Output section
                gr.Markdown("## 📊 Results")

                output_image = gr.Image(
                    label="Visualization Comparison",
                    type="pil",
                    height=400,
                    elem_classes=["output-image"],
                )

                # Statistics section
                stats_output = gr.Markdown(
                    value=STATS_PLACEHOLDER,
                    label="Statistics",
                    elem_classes=["stats-box"],
                )

        # Example callbacks
        use_example.click(
            fn=lambda: example_mask,
            outputs=[input_image],
        )

        clear_btn.click(
            fn=lambda: (None, None, STATS_PLACEHOLDER),
            outputs=[input_image, output_image, stats_output],
        )

        # Main processing callback
        process_btn.click(
            fn=corruptor.process_single_mask,
            inputs=[input_image, drop_prob, mislabel_prob, blur_level,
                    max_label, preserve_bg, seed, colormap],
            outputs=[output_image, stats_output],
        )

        # Examples section
        gr.Markdown("## 🚀 Quick Examples")

        gr.Examples(
            examples=[
                [example_mask, 0.2, 0.3, 1, 17, True, 42, 'viridis'],
                [example_mask, 0.5, 0.1, 2, 17, True, 123, 'plasma'],
                [example_mask, 0.1, 0.5, 8, 17, False, 42, 'inferno'],
            ],
            inputs=[input_image, drop_prob, mislabel_prob, blur_level,
                    max_label, preserve_bg, seed, colormap],
            outputs=[output_image, stats_output],
            fn=corruptor.process_single_mask,
            cache_examples=True,
        )

        # Footer
        gr.Markdown("""
        ---
        ### 📝 How to use:
        1. Upload a mask image (grayscale, each instance with unique integer label)
        2. Adjust corruption parameters using the sliders
        3. Click "Corrupt Mask!" to process
        4. View the comparison visualization and statistics

        ### 💡 Tips:
        - Use the example mask to get started quickly
        - Set Random Seed to 0 for different results each time
        - Higher drop/mislabel probabilities = more corruption
        - Preserve background keeps label 0 unchanged (recommended for most cases)
        """)

    return app


# For HuggingFace Spaces deployment
app = create_gradio_app()

if __name__ == "__main__":
    # For local testing
    app.launch(
        debug=True,
        # Pass the actual stylesheet text, not the word "css".
        css=APP_CSS,
        theme=gr.themes.Soft(),
        show_error=True,
    )