"""Corrupt segmentation masks by dropping instances or assigning wrong labels."""

import argparse
import json
import random
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import cv2
import numpy as np


class MaskCorruptor:
    """
    Corrupt segmentation masks by randomly dropping mask instances or
    assigning them wrong labels, then coarsening the instance boundaries.

    Example:
        from corrupt_mask import MaskCorruptor
        mask = np.zeros((600,800))  # dummy mask
        corruptor = MaskCorruptor(drop_probability=0.2, mislabel_probability=0.1)
        corr_mask = corruptor.corrupt_single_mask(mask)
    """

    # File suffixes (lower-case) that corrupt_masks_from_directory will read.
    _SUPPORTED_SUFFIXES = ('.png', '.jpg', '.jpeg')

    def __init__(self,
                 drop_probability: float = 0.1,
                 mislabel_probability: float = 0.1,
                 boundary_blur_level: int = 8,
                 seed: Optional[int] = None):
        """
        Initialise the corruptor.

        Args:
            drop_probability: Probability of dropping a mask instance entirely.
            mislabel_probability: Probability of giving a surviving instance a
                wrong label.
            boundary_blur_level: The mask is down-sampled by a factor of
                ``boundary_blur_level + 1`` and up-sampled back with
                nearest-neighbour interpolation; 0 leaves boundaries intact.
            seed: Random seed for reproducibility. When given, it seeds the
                *global* ``numpy.random`` and ``random`` generators.

        Raises:
            ValueError: If ``boundary_blur_level`` is negative.
        """
        if boundary_blur_level < 0:
            raise ValueError(
                f"boundary_blur_level must be >= 0, got {boundary_blur_level}"
            )

        self.drop_probability = drop_probability
        self.mislabel_probability = mislabel_probability
        self.boundary_blur_level = boundary_blur_level
        self.seed = seed

        if seed is not None:
            np.random.seed(seed)
            random.seed(seed)

    def corrupt_single_mask(self,
                            mask: np.ndarray,
                            max_label: Optional[int] = None,
                            preserve_background: bool = True) -> np.ndarray:
        """
        Corrupt a single mask image.

        Args:
            mask: Input label array (2D or 3D). It is not modified.
            max_label: Largest label value that may be assigned when
                mislabeling (if None, the maximum value in ``mask`` is used).
            preserve_background: If True, label 0 is treated as background
                and is neither dropped nor used as a replacement label.

        Returns:
            The corrupted mask. Dropped instances become 0.
        """
        mask = mask.copy()
        if mask.size == 0:
            return mask  # Nothing to corrupt (and np.max would raise).

        if max_label is None:
            max_label = mask.max()
        # Work with a plain Python int: a NumPy scalar such as uint8(255)
        # overflows on ``+ 1`` under NumPy >= 2, and a float scalar is not
        # accepted by range().
        max_label = int(max_label)

        unique_labels = np.unique(mask)
        if preserve_background:
            unique_labels = unique_labels[unique_labels != 0]

        if unique_labels.size == 0:
            return mask  # No instances to corrupt

        lowest_label = 1 if preserve_background else 0
        # Starts as all zeros, so dropped instances and background stay 0.
        corrupted_mask = np.zeros_like(mask)

        for label in unique_labels:
            # Randomly decide whether to drop this instance entirely.
            if np.random.random() < self.drop_probability:
                continue

            new_label = label
            # Randomly decide whether to mislabel the surviving instance.
            if np.random.random() < self.mislabel_probability:
                candidates = [
                    candidate
                    for candidate in range(lowest_label, max_label + 1)
                    if candidate != label
                ]
                if candidates:  # Otherwise no alternative label exists.
                    new_label = np.random.choice(candidates)

            corrupted_mask[mask == label] = new_label

        return self._blur_boundaries(corrupted_mask)

    def _blur_boundaries(self, mask: np.ndarray) -> np.ndarray:
        """Coarsen boundaries by nearest-neighbour down- then up-sampling."""
        factor = self.boundary_blur_level + 1
        height, width = mask.shape[:2]
        # Clamp to 1 pixel so masks smaller than ``factor`` do not produce an
        # empty target size, which cv2.resize rejects.
        small_size = (max(1, int(width // factor)),
                      max(1, int(height // factor)))
        coarse = cv2.resize(mask, small_size, interpolation=cv2.INTER_NEAREST)
        return cv2.resize(coarse, (width, height),
                          interpolation=cv2.INTER_NEAREST)

    def corrupt_masks_from_directory(self,
                                     input_dir: Union[str, Path],
                                     output_dir: Union[str, Path],
                                     file_pattern: str = "*.png",
                                     max_label: Optional[int] = None,
                                     preserve_background: bool = True):
        """
        Corrupt every mask in a directory and save the results.

        Files are read as single-channel grayscale images, corrupted with
        :meth:`corrupt_single_mask`, and written under the same file name in
        ``output_dir``. The corruption parameters are saved alongside them.

        Args:
            input_dir: Directory containing the input masks.
            output_dir: Directory to write corrupted masks to (created if
                missing).
            file_pattern: Glob pattern selecting mask files.
            max_label: Maximum label value, forwarded to
                :meth:`corrupt_single_mask`.
            preserve_background: Whether to keep label 0 as background.
        """
        input_path = Path(input_dir)
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        # Sort so that a fixed seed yields the same result on every
        # filesystem; glob order itself is not guaranteed.
        mask_files = sorted(input_path.glob(file_pattern))
        total = len(mask_files)
        print(f"Found {total} mask files to corrupt")

        for i, mask_file in enumerate(mask_files, 1):
            if mask_file.suffix.lower() not in self._SUPPORTED_SUFFIXES:
                print(f"Unsupported file format: {mask_file.suffix}")
                continue

            # NOTE: JPEG is lossy, so label values may not survive a
            # read/write round trip in that format.
            mask = cv2.imread(str(mask_file), cv2.IMREAD_GRAYSCALE)
            if mask is None:
                # cv2.imread signals failure by returning None, not raising.
                print(f"Could not read mask file: {mask_file}")
                continue

            corrupted_mask = self.corrupt_single_mask(
                mask, max_label, preserve_background
            )

            output_file = output_path / mask_file.name
            if not cv2.imwrite(str(output_file), corrupted_mask):
                print(f"Could not write mask file: {output_file}")

            if i % 10 == 0 or i == total:
                print(f"Processed {i}/{total} files")

        # Save corruption parameters as metadata
        self.save_parameters(output_path)

    def save_parameters(self, output_dir: Union[str, Path]):
        """Save the corruption parameters to ``corruption_parameters.json``."""
        params = {
            'drop_probability': self.drop_probability,
            'mislabel_probability': self.mislabel_probability,
            'boundary_blur_level': self.boundary_blur_level,
            'seed': self.seed,
            'corruption_type': 'random_drop_and_mislabel'
        }

        output_path = Path(output_dir)
        with open(output_path / 'corruption_parameters.json', 'w',
                  encoding='utf-8') as f:
            json.dump(params, f, indent=2)


def main():
    """Command-line entry point: corrupt all masks in a directory."""
    parser = argparse.ArgumentParser(description='Corrupt segmentation masks')
    parser.add_argument('--input_dir', type=str, required=True,
                        help='Input directory containing masks')
    parser.add_argument('--output_dir', type=str, required=True,
                        help='Output directory for corrupted masks')
    parser.add_argument('--drop_prob', type=float, default=0.1,
                        help='Probability of dropping a mask (0-1)')
    parser.add_argument('--mislabel_prob', type=float, default=0.1,
                        help='Probability of mislabeling a mask (0-1)')
    parser.add_argument('--boundary_blur_level', type=int, default=8,
                        help='Boundary coarsening level (0 disables it)')
    parser.add_argument('--max_label', type=int, default=None,
                        help='Maximum label value (if not specified, use max from data)')
    parser.add_argument('--file_pattern', type=str, default='*',
                        help='File pattern for mask files')
    # NOTE(review): as a store_true flag this defaults to False on the command
    # line, whereas the MaskCorruptor methods default to True - confirm which
    # default is intended before changing it.
    parser.add_argument('--preserve_background', action='store_true',
                        help='Preserve label 0 as background (unchanged)')
    parser.add_argument('--seed', type=int, default=None,
                        help='Random seed for reproducibility')

    args = parser.parse_args()

    # Create corruptor instance
    corruptor = MaskCorruptor(
        drop_probability=args.drop_prob,
        mislabel_probability=args.mislabel_prob,
        boundary_blur_level=args.boundary_blur_level,
        seed=args.seed
    )

    # Corrupt masks
    corruptor.corrupt_masks_from_directory(
        input_dir=args.input_dir,
        output_dir=args.output_dir,
        file_pattern=args.file_pattern,
        max_label=args.max_label,
        preserve_background=args.preserve_background
    )

    print(f"Corruption complete! Masks saved to {args.output_dir}")


if __name__ == "__main__":
    main()