""" 自定义检测标签页 农村低压配电网电压异常检测项目 本模块实现用户自定义数据的异常检测功能: - CSV 文件上传 - 数据预览 - 模型选择和阈值调节 - 实时推理 (CPU) - 检测结果可视化 Author: Rural Voltage Detection Project Date: 2026 """ import sys from pathlib import Path from typing import Optional, Tuple, Dict, Any import gradio as gr import numpy as np import pandas as pd # Add parent directories to path for imports DEMO_DIR = Path(__file__).parent.parent CODE_DIR = DEMO_DIR if str(CODE_DIR) not in sys.path: sys.path.insert(0, str(CODE_DIR)) if str(DEMO_DIR) not in sys.path: sys.path.insert(0, str(DEMO_DIR)) # Import core modules from core.inference import VoltageAnomalyDetector from core.data_processor import DataProcessor from visualization.detection_plots import ( create_detection_timeline, create_score_distribution, ) from config import ( MODEL_DIR, INFERENCE_CONFIG, DEMO_DATA_CONFIG, ) # ============================================================================ # 全局变量和配置 # ============================================================================ # 可用模型列表 AVAILABLE_MODELS = ["VoltageTimesNet_v2", "TimesNet", "DLinear"] # 模型描述 MODEL_DESCRIPTIONS = { "VoltageTimesNet_v2": "推荐模型: 基于 TimesNet 的改进版本,针对召回率进行优化,适合农村电压异常检测", "TimesNet": "基线模型: 使用 FFT 发现周期性,结合 2D 卷积捕获时序模式", "DLinear": "轻量级模型: 基于线性分解的简单模型,速度快但精度略低", } # 检测器缓存 _detector_cache: Dict[str, VoltageAnomalyDetector] = {} # ============================================================================ # 辅助函数 # ============================================================================ def get_detector(model_name: str) -> VoltageAnomalyDetector: """ 获取或创建检测器实例(带缓存) Args: model_name: 模型名称 Returns: VoltageAnomalyDetector 实例 """ if model_name not in _detector_cache: # 查找模型检查点 checkpoint_path = None model_file = MODEL_DIR / f"best_{model_name.lower()}.pth" if model_file.exists(): checkpoint_path = str(model_file) # 创建检测器 detector = VoltageAnomalyDetector( model_name=model_name, checkpoint_path=checkpoint_path, device=INFERENCE_CONFIG.get("device", "cpu"), ) _detector_cache[model_name] = detector return _detector_cache[model_name] def validate_csv_file(file_path: str) -> Tuple[bool, str, Optional[pd.DataFrame]]: """ 验证 CSV 文件 Args: file_path: CSV 文件路径 Returns: (是否有效, 消息, DataFrame 或 None) """ try: if file_path is None: return False, "请先上传 CSV 文件", None # 读取文件 df = pd.read_csv(file_path) if df.empty: return False, "CSV 文件为空", None # 检查数据长度 min_length = DEMO_DATA_CONFIG.get("window_size", 100) + 10 if len(df) < min_length: return False, f"数据长度不足,至少需要 {min_length} 行数据(当前: {len(df)} 行)", None # 检查是否有数值列 numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist() if len(numeric_cols) == 0: return False, "CSV 文件中没有找到数值列", None return True, f"文件验证成功: {len(df)} 行, {len(numeric_cols)} 个数值特征", df except pd.errors.EmptyDataError: return False, "CSV 文件为空或格式错误", None except pd.errors.ParserError as e: return False, f"CSV 解析错误: {str(e)}", None except Exception as e: return False, f"读取文件失败: {str(e)}", None def format_preview_df(df: pd.DataFrame, max_rows: int = 10) -> pd.DataFrame: """ 格式化预览 DataFrame Args: df: 原始 DataFrame max_rows: 最大显示行数 Returns: 格式化后的 DataFrame """ # 选择数值列 exclude_cols = ["timestamp", "date", "time", "label", "index", "Unnamed: 0"] feature_cols = [c for c in df.columns if c not in exclude_cols] # 截取前 N 行 preview_df = df[feature_cols].head(max_rows).copy() # 四舍五入数值 for col in preview_df.columns: if preview_df[col].dtype in [np.float64, np.float32]: preview_df[col] = preview_df[col].round(4) return preview_df def generate_detection_stats( labels: np.ndarray, scores: np.ndarray, threshold: float, total_samples: int, ) -> str: """ 生成检测统计信息 Args: labels: 预测标签 scores: 异常分数 threshold: 阈值 total_samples: 原始数据样本数 Returns: Markdown 格式的统计信息 """ n_anomaly = int(np.sum(labels)) n_normal = len(labels) - n_anomaly anomaly_ratio = n_anomaly / len(labels) * 100 if len(labels) > 0 else 0 # 分数统计 score_mean = float(np.mean(scores)) score_std = float(np.std(scores)) score_max = float(np.max(scores)) score_min = float(np.min(scores)) stats_md = f""" ### 检测统计 | 指标 | 数值 | |------|------| | 原始数据样本数 | {total_samples} | | 检测窗口数 | {len(labels)} | | 检测到的异常 | {n_anomaly} ({anomaly_ratio:.2f}%) | | 正常样本 | {n_normal} ({100-anomaly_ratio:.2f}%) | | 使用阈值 | {threshold:.4f} | ### 异常分数统计 | 指标 | 数值 | |------|------| | 平均分数 | {score_mean:.4f} | | 标准差 | {score_std:.4f} | | 最大分数 | {score_max:.4f} | | 最小分数 | {score_min:.4f} | > 注: 异常分数表示重构误差,分数越高表示越可能是异常。 """ return stats_md # ============================================================================ # 事件处理函数 # ============================================================================ def handle_file_upload(file) -> Tuple[str, pd.DataFrame, str]: """ 处理文件上传事件 Args: file: 上传的文件对象 Returns: (状态消息, 预览 DataFrame, 详细信息) """ if file is None: return ( "等待上传文件...", pd.DataFrame(), "请上传 CSV 格式的时序数据文件", ) # 验证文件 is_valid, message, df = validate_csv_file(file.name) if not is_valid: return ( f"文件验证失败: {message}", pd.DataFrame(), message, ) # 生成预览 preview_df = format_preview_df(df, max_rows=10) # 生成详细信息 numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist() exclude_cols = ["timestamp", "date", "time", "label", "index", "Unnamed: 0"] feature_cols = [c for c in numeric_cols if c not in exclude_cols] detail_info = f""" **文件信息:** - 文件名: {Path(file.name).name} - 数据行数: {len(df)} - 特征数量: {len(feature_cols)} - 特征列表: {', '.join(feature_cols[:8])}{'...' if len(feature_cols) > 8 else ''} """ return ( f"文件验证成功: {len(df)} 行数据", preview_df, detail_info, ) def handle_model_change(model_name: str) -> str: """ 处理模型选择变更 Args: model_name: 选择的模型名称 Returns: 模型描述信息 """ desc = MODEL_DESCRIPTIONS.get(model_name, "未知模型") return f"**{model_name}**: {desc}" def run_detection( file, model_name: str, threshold: float, progress=gr.Progress(), ) -> Tuple[Any, Any, str]: """ 执行异常检测 Args: file: 上传的文件对象 model_name: 模型名称 threshold: 异常阈值 progress: Gradio 进度条 Returns: (时间线图, 分布图, 统计信息) """ # 验证输入 if file is None: return None, None, "请先上传 CSV 文件" try: progress(0, desc="验证文件...") is_valid, message, df = validate_csv_file(file.name) if not is_valid: return None, None, f"文件验证失败: {message}" progress(0.1, desc="加载数据...") # 提取特征数据 exclude_cols = ["timestamp", "date", "time", "label", "index", "Unnamed: 0"] feature_cols = [c for c in df.columns if c not in exclude_cols and df[c].dtype in [np.float64, np.float32, np.int64, np.int32]] if len(feature_cols) == 0: return None, None, "未找到有效的数值特征列" data = df[feature_cols].values.astype(np.float32) data = np.nan_to_num(data, nan=0.0) total_samples = len(data) progress(0.2, desc="创建数据窗口...") # 创建数据处理器 window_size = DEMO_DATA_CONFIG.get("window_size", 100) step_size = DEMO_DATA_CONFIG.get("step_size", 1) processor = DataProcessor( seq_len=window_size, step=step_size, normalize=True, ) # 处理数据 windows = processor.fit_transform(data) n_windows = windows.shape[0] progress(0.3, desc=f"加载模型 {model_name}...") # 获取检测器 detector = get_detector(model_name) # 加载模型(如果尚未加载) if not detector.is_loaded: detector.load_model() progress(0.5, desc="执行推理...") # 执行检测 results = detector.predict(windows, threshold=threshold) scores = results["scores"] labels = results["labels"] progress(0.8, desc="生成可视化...") # 准备可视化数据 # 使用原始数据的前 N 个点(与 scores 长度对齐) vis_data = data[:len(scores), :min(3, data.shape[1])] vis_feature_names = feature_cols[:min(3, len(feature_cols))] # 创建时间线图 timeline_fig = create_detection_timeline( data=vis_data, scores=scores, labels=labels, threshold=threshold, feature_names=vis_feature_names, max_features=3, title="异常检测结果时间线", ) progress(0.9, desc="生成统计信息...") # 创建分数分布图(使用预测标签) dist_fig = create_score_distribution( scores=scores, labels=labels, threshold=threshold, bins=50, title="异常分数分布", ) # 生成统计信息 stats_md = generate_detection_stats( labels=labels, scores=scores, threshold=threshold, total_samples=total_samples, ) progress(1.0, desc="完成!") return timeline_fig, dist_fig, stats_md except ValueError as e: return None, None, f"数据处理错误: {str(e)}" except RuntimeError as e: return None, None, f"模型推理错误: {str(e)}" except Exception as e: import traceback error_detail = traceback.format_exc() return None, None, f"检测失败: {str(e)}\n\n详细信息:\n```\n{error_detail}\n```" # ============================================================================ # 创建标签页 # ============================================================================ def create_detection_tab(): """ 创建自定义检测标签页 Returns: Gradio Tab 组件 """ with gr.Tab("自定义检测") as tab: gr.Markdown(""" ## 自定义数据异常检测 上传您自己的时序数据文件,使用训练好的模型进行异常检测。 **使用说明:** 1. 上传 CSV 格式的时序数据文件 2. 选择检测模型 3. 调整异常阈值 4. 点击「开始检测」按钮 > 提示: 数据文件应包含数值型特征列,至少需要 100+ 行数据。 """) with gr.Row(): # 左侧: 输入区域 with gr.Column(scale=1): gr.Markdown("### 数据上传") # 文件上传 file_input = gr.File( label="上传 CSV 文件", file_types=[".csv"], file_count="single", ) # 文件状态 file_status = gr.Textbox( label="文件状态", value="等待上传文件...", interactive=False, lines=1, ) # 文件详情 file_detail = gr.Markdown( value="请上传 CSV 格式的时序数据文件" ) gr.Markdown("### 模型配置") # 模型选择 model_selector = gr.Radio( choices=AVAILABLE_MODELS, value="VoltageTimesNet_v2", label="选择模型", ) # 模型描述 model_desc = gr.Markdown( value=f"**VoltageTimesNet_v2**: {MODEL_DESCRIPTIONS['VoltageTimesNet_v2']}" ) # 阈值滑块 threshold_slider = gr.Slider( minimum=0.1, maximum=1.0, value=0.5, step=0.01, label="异常阈值", info="阈值越低,检测越敏感(检出更多异常);阈值越高,检测越保守", ) # 检测按钮 detect_btn = gr.Button( "开始检测", variant="primary", size="lg", ) # 右侧: 数据预览 with gr.Column(scale=1): gr.Markdown("### 数据预览 (前 10 行)") preview_table = gr.Dataframe( label="数据预览", headers=None, interactive=False, wrap=True, ) gr.Markdown("---") gr.Markdown("### 检测结果") # 结果区域 with gr.Row(): # 时间线图 with gr.Column(scale=2): timeline_plot = gr.Plot( label="异常检测时间线", ) with gr.Row(): # 分布图 with gr.Column(scale=1): dist_plot = gr.Plot( label="异常分数分布", ) # 统计信息 with gr.Column(scale=1): stats_output = gr.Markdown( value="等待检测...", label="检测统计", ) # ==================================================================== # 事件绑定 # ==================================================================== # 文件上传事件 file_input.change( fn=handle_file_upload, inputs=[file_input], outputs=[file_status, preview_table, file_detail], ) # 模型选择事件 model_selector.change( fn=handle_model_change, inputs=[model_selector], outputs=[model_desc], ) # 检测按钮事件 detect_btn.click( fn=run_detection, inputs=[file_input, model_selector, threshold_slider], outputs=[timeline_plot, dist_plot, stats_output], ) return tab # ============================================================================ # 测试代码 # ============================================================================ if __name__ == "__main__": # 创建测试应用 with gr.Blocks(title="自定义检测测试", theme="soft") as demo: create_detection_tab() demo.launch( server_name="0.0.0.0", server_port=7862, share=False, )