# rural-voltage-demo / tabs /tab4_detection.py
# Sheldon123z's picture
# Initial deployment of voltage anomaly detection demo
# 5d08a33 verified
"""
自定义检测标签页
农村低压配电网电压异常检测项目
本模块实现用户自定义数据的异常检测功能:
- CSV 文件上传
- 数据预览
- 模型选择和阈值调节
- 实时推理 (CPU)
- 检测结果可视化
Author: Rural Voltage Detection Project
Date: 2026
"""
import sys
from pathlib import Path
from typing import Optional, Tuple, Dict, Any
import gradio as gr
import numpy as np
import pandas as pd
# Put the demo root (the directory two levels above this file) at the front of
# sys.path, unless it is already listed. CODE_DIR is an alias of DEMO_DIR, so
# the second pass of the loop never inserts anything.
DEMO_DIR = CODE_DIR = Path(__file__).parent.parent
for _import_root in (CODE_DIR, DEMO_DIR):
    if str(_import_root) not in sys.path:
        sys.path.insert(0, str(_import_root))
# Import core modules
from core.inference import VoltageAnomalyDetector
from core.data_processor import DataProcessor
from visualization.detection_plots import (
create_detection_timeline,
create_score_distribution,
)
from config import (
MODEL_DIR,
INFERENCE_CONFIG,
DEMO_DATA_CONFIG,
)
# ============================================================================
# 全局变量和配置
# ============================================================================
# One-line description shown in the UI for each selectable model.
MODEL_DESCRIPTIONS = {
    "VoltageTimesNet_v2": "推荐模型: 基于 TimesNet 的改进版本,针对召回率进行优化,适合农村电压异常检测",
    "TimesNet": "基线模型: 使用 FFT 发现周期性,结合 2D 卷积捕获时序模式",
    "DLinear": "轻量级模型: 基于线性分解的简单模型,速度快但精度略低",
}
# Selectable model names, in the order they are offered to the user
# (dicts keep insertion order, so this mirrors the mapping above).
AVAILABLE_MODELS = list(MODEL_DESCRIPTIONS)
# Detector instances keyed by model name. get_detector() adds an entry the
# first time a model is requested and returns the stored instance afterwards.
_detector_cache: Dict[str, VoltageAnomalyDetector] = {}
# ============================================================================
# 辅助函数
# ============================================================================
def get_detector(model_name: str) -> VoltageAnomalyDetector:
    """
    Return the detector for ``model_name``, creating and caching it on first use.

    A checkpoint named ``best_<lower-cased model name>.pth`` is looked up under
    ``MODEL_DIR``; when that file does not exist the detector is constructed
    with ``checkpoint_path=None``. The device is read from
    ``INFERENCE_CONFIG["device"]`` and falls back to ``"cpu"``.

    Args:
        model_name: Name of the model to run.

    Returns:
        The cached ``VoltageAnomalyDetector`` instance for ``model_name``.
    """
    # Cache hit: hand back the instance built by an earlier call.
    if model_name in _detector_cache:
        return _detector_cache[model_name]

    weights_file = MODEL_DIR / f"best_{model_name.lower()}.pth"
    checkpoint = str(weights_file) if weights_file.exists() else None

    _detector_cache[model_name] = VoltageAnomalyDetector(
        model_name=model_name,
        checkpoint_path=checkpoint,
        device=INFERENCE_CONFIG.get("device", "cpu"),
    )
    return _detector_cache[model_name]
def validate_csv_file(file_path: str) -> Tuple[bool, str, Optional[pd.DataFrame]]:
    """
    Load a CSV file and check that it is usable for detection.

    The file is accepted when it parses, is not empty, has at least
    ``window_size + 10`` rows (``window_size`` comes from ``DEMO_DATA_CONFIG``,
    default 100) and contains at least one numeric column.

    Args:
        file_path: Path of the CSV file, or ``None`` when nothing was uploaded.

    Returns:
        ``(is_valid, message, dataframe)``; the dataframe is ``None`` whenever
        ``is_valid`` is ``False``. Read and parse failures are reported through
        the message instead of being raised.
    """
    if file_path is None:
        return False, "请先上传 CSV 文件", None

    try:
        table = pd.read_csv(file_path)
        if table.empty:
            return False, "CSV 文件为空", None

        # One full window plus a small margin is the minimum usable length.
        row_count = len(table)
        required_rows = DEMO_DATA_CONFIG.get("window_size", 100) + 10
        if row_count < required_rows:
            return (
                False,
                f"数据长度不足,至少需要 {required_rows} 行数据(当前: {row_count} 行)",
                None,
            )

        numeric_count = len(table.select_dtypes(include=[np.number]).columns.tolist())
        if numeric_count == 0:
            return False, "CSV 文件中没有找到数值列", None

        return True, f"文件验证成功: {row_count} 行, {numeric_count} 个数值特征", table
    except pd.errors.EmptyDataError:
        return False, "CSV 文件为空或格式错误", None
    except pd.errors.ParserError as err:
        return False, f"CSV 解析错误: {err!s}", None
    except Exception as err:
        # Anything else (missing file, permissions, ...) becomes a message for the UI.
        return False, f"读取文件失败: {err!s}", None
def format_preview_df(df: pd.DataFrame, max_rows: int = 10) -> pd.DataFrame:
    """
    Build a small, display-friendly copy of ``df`` for the preview table.

    Bookkeeping columns (``timestamp``, ``date``, ``time``, ``label``,
    ``index``, ``Unnamed: 0``) are dropped by name, the first ``max_rows`` rows
    are kept, and float32/float64 columns are rounded to 4 decimal places.
    Columns are filtered by name only, so non-numeric columns are kept.

    Args:
        df: Source DataFrame; it is not modified.
        max_rows: Maximum number of rows to keep.

    Returns:
        A new DataFrame holding the preview rows.
    """
    hidden_cols = ("timestamp", "date", "time", "label", "index", "Unnamed: 0")
    shown_cols = [name for name in df.columns if name not in hidden_cols]

    # Work on a copy so rounding never touches the caller's data.
    preview = df[shown_cols].head(max_rows).copy()

    float_cols = [
        name
        for name in preview.columns
        if preview[name].dtype in (np.float64, np.float32)
    ]
    for name in float_cols:
        preview[name] = preview[name].round(4)
    return preview
def generate_detection_stats(
    labels: np.ndarray,
    scores: np.ndarray,
    threshold: float,
    total_samples: int,
) -> str:
    """
    Render the detection summary as a Markdown report.

    The report has two tables: anomaly/normal counts with their percentages
    and the threshold used, followed by mean, standard deviation, maximum and
    minimum of the anomaly scores.

    Empty input is handled explicitly: with no labels both percentages are
    reported as 0.00%, and with no scores the four score statistics are shown
    as ``N/A`` (``np.max``/``np.min`` raise ``ValueError`` on empty arrays).

    Args:
        labels: Predicted labels; their sum is taken as the anomaly count.
        scores: Anomaly scores.
        threshold: Threshold that was used for the prediction.
        total_samples: Number of rows in the original data.

    Returns:
        The Markdown text of the report.
    """
    n_windows = len(labels)
    n_anomaly = int(np.sum(labels))
    n_normal = n_windows - n_anomaly

    if n_windows > 0:
        anomaly_ratio = n_anomaly / n_windows * 100
        normal_ratio = 100 - anomaly_ratio
    else:
        # Nothing was classified, so neither share can be non-zero.
        anomaly_ratio = normal_ratio = 0.0

    if np.size(scores) > 0:
        mean_cell, std_cell, max_cell, min_cell = (
            f"{float(reduce_fn(scores)):.4f}"
            for reduce_fn in (np.mean, np.std, np.max, np.min)
        )
    else:
        mean_cell = std_cell = max_cell = min_cell = "N/A"

    # The leading and trailing empty items reproduce the newline at the start
    # and end of the report text.
    report_lines = [
        "",
        "### 检测统计",
        "| 指标 | 数值 |",
        "|------|------|",
        f"| 原始数据样本数 | {total_samples} |",
        f"| 检测窗口数 | {n_windows} |",
        f"| 检测到的异常 | {n_anomaly} ({anomaly_ratio:.2f}%) |",
        f"| 正常样本 | {n_normal} ({normal_ratio:.2f}%) |",
        f"| 使用阈值 | {threshold:.4f} |",
        "### 异常分数统计",
        "| 指标 | 数值 |",
        "|------|------|",
        f"| 平均分数 | {mean_cell} |",
        f"| 标准差 | {std_cell} |",
        f"| 最大分数 | {max_cell} |",
        f"| 最小分数 | {min_cell} |",
        "> 注: 异常分数表示重构误差,分数越高表示越可能是异常。",
        "",
    ]
    return "\n".join(report_lines)
# ============================================================================
# 事件处理函数
# ============================================================================
def handle_file_upload(file) -> Tuple[str, pd.DataFrame, str]:
    """
    React to a change of the upload widget.

    Args:
        file: Upload object whose ``name`` attribute is read as the CSV path,
            or ``None`` when the widget is empty.

    Returns:
        ``(status_text, preview_dataframe, detail_markdown)``. The preview is
        an empty DataFrame when no file is present or validation fails.
    """
    if file is None:
        return "等待上传文件...", pd.DataFrame(), "请上传 CSV 格式的时序数据文件"

    ok, reason, table = validate_csv_file(file.name)
    if not ok:
        return f"文件验证失败: {reason}", pd.DataFrame(), reason

    preview = format_preview_df(table, max_rows=10)

    # Features reported to the user: numeric columns minus bookkeeping columns.
    skipped = ("timestamp", "date", "time", "label", "index", "Unnamed: 0")
    features = [
        col
        for col in table.select_dtypes(include=[np.number]).columns.tolist()
        if col not in skipped
    ]

    base_name = Path(file.name).name
    row_total = len(table)
    # Only the first eight names are listed; an ellipsis marks the rest.
    listed = ", ".join(features[:8])
    overflow = "..." if len(features) > 8 else ""
    detail = (
        "\n"
        "**文件信息:**\n"
        f"- 文件名: {base_name}\n"
        f"- 数据行数: {row_total}\n"
        f"- 特征数量: {len(features)}\n"
        f"- 特征列表: {listed}{overflow}\n"
    )
    return f"文件验证成功: {row_total} 行数据", preview, detail
def handle_model_change(model_name: str) -> str:
    """
    Produce the Markdown line shown under the model selector.

    Args:
        model_name: Name of the newly selected model.

    Returns:
        ``**<model_name>**: <description>``, where the description comes from
        ``MODEL_DESCRIPTIONS`` or is "未知模型" for a name not listed there.
    """
    if model_name in MODEL_DESCRIPTIONS:
        summary = MODEL_DESCRIPTIONS[model_name]
    else:
        summary = "未知模型"
    return "**{}**: {}".format(model_name, summary)
def run_detection(
    file,
    model_name: str,
    threshold: float,
    progress=gr.Progress(),
) -> Tuple[Any, Any, str]:
    """
    Run the whole detection pipeline on an uploaded CSV file.

    Steps, each announced through ``progress``: validate the file, pick the
    feature columns, window the data with ``DataProcessor`` (normalisation
    enabled), fetch the detector for ``model_name`` and load it if needed,
    score the windows, then build the timeline figure, the score-distribution
    figure and the Markdown statistics.

    Args:
        file: Upload object whose ``name`` attribute is read as the CSV path,
            or ``None`` when nothing was uploaded.
        model_name: Name of the model to use.
        threshold: Anomaly threshold handed to the detector and the plots.
        progress: Gradio progress tracker.

    Returns:
        ``(timeline_figure, distribution_figure, stats_markdown)``. On any
        failure both figures are ``None`` and the third item is the error
        message; no exception is propagated.
    """
    if file is None:
        return None, None, "请先上传 CSV 文件"

    try:
        progress(0, desc="验证文件...")
        ok, reason, frame = validate_csv_file(file.name)
        if not ok:
            return None, None, f"文件验证失败: {reason}"

        progress(0.1, desc="加载数据...")
        # Keep columns that are not bookkeeping columns and have one of the
        # accepted numeric dtypes.
        skipped = ("timestamp", "date", "time", "label", "index", "Unnamed: 0")
        accepted_dtypes = (np.float64, np.float32, np.int64, np.int32)
        feature_names = [
            col
            for col in frame.columns
            if col not in skipped and frame[col].dtype in accepted_dtypes
        ]
        if not feature_names:
            return None, None, "未找到有效的数值特征列"

        # Missing values are replaced with 0.0 before windowing.
        raw = np.nan_to_num(
            frame[feature_names].values.astype(np.float32),
            nan=0.0,
        )
        sample_count = len(raw)

        progress(0.2, desc="创建数据窗口...")
        win_len = DEMO_DATA_CONFIG.get("window_size", 100)
        win_step = DEMO_DATA_CONFIG.get("step_size", 1)
        windower = DataProcessor(seq_len=win_len, step=win_step, normalize=True)
        windows = windower.fit_transform(raw)
        # Not used further; the shape access is kept so a result without a
        # usable ``shape`` fails at this point rather than later.
        window_count = windows.shape[0]

        progress(0.3, desc=f"加载模型 {model_name}...")
        detector = get_detector(model_name)
        if not detector.is_loaded:
            detector.load_model()

        progress(0.5, desc="执行推理...")
        outcome = detector.predict(windows, threshold=threshold)
        scores, labels = outcome["scores"], outcome["labels"]

        progress(0.8, desc="生成可视化...")
        # Plot at most the first three features, cut to as many rows as there
        # are scores.
        timeline_fig = create_detection_timeline(
            data=raw[:len(scores), :3],
            scores=scores,
            labels=labels,
            threshold=threshold,
            feature_names=feature_names[:3],
            max_features=3,
            title="异常检测结果时间线",
        )

        progress(0.9, desc="生成统计信息...")
        dist_fig = create_score_distribution(
            scores=scores,
            labels=labels,
            threshold=threshold,
            bins=50,
            title="异常分数分布",
        )
        stats_md = generate_detection_stats(
            labels=labels,
            scores=scores,
            threshold=threshold,
            total_samples=sample_count,
        )

        progress(1.0, desc="完成!")
        return timeline_fig, dist_fig, stats_md
    except ValueError as err:
        return None, None, f"数据处理错误: {err!s}"
    except RuntimeError as err:
        return None, None, f"模型推理错误: {err!s}"
    except Exception as err:
        # Unexpected failure: include the traceback in the message shown to the user.
        import traceback

        trace_text = traceback.format_exc()
        return None, None, f"检测失败: {err!s}\n\n详细信息:\n```\n{trace_text}\n```"
# ============================================================================
# 创建标签页
# ============================================================================
def create_detection_tab():
    """
    Build the custom-detection tab and wire up its events.

    Layout: an intro text; a row with the input column (file upload, status,
    details, model selector, model description, threshold slider, run button)
    next to the data-preview column; then the result area with the timeline
    plot, the score-distribution plot and the statistics text.

    Events: changing the uploaded file refreshes status, preview and details;
    changing the model refreshes its description; clicking the run button
    calls ``run_detection`` and fills the two plots and the statistics.

    Intended to be called inside a ``gr.Blocks`` context, as the ``__main__``
    section of this file does.

    Returns:
        The ``gr.Tab`` container holding the layout.
    """
    default_model = "VoltageTimesNet_v2"
    intro_md = "\n".join([
        "",
        "## 自定义数据异常检测",
        "上传您自己的时序数据文件,使用训练好的模型进行异常检测。",
        "**使用说明:**",
        "1. 上传 CSV 格式的时序数据文件",
        "2. 选择检测模型",
        "3. 调整异常阈值",
        "4. 点击「开始检测」按钮",
        "> 提示: 数据文件应包含数值型特征列,至少需要 100+ 行数据。",
        "",
    ])

    with gr.Tab("自定义检测") as tab:
        gr.Markdown(intro_md)

        with gr.Row():
            # Left column: upload, model configuration and the run button.
            with gr.Column(scale=1):
                gr.Markdown("### 数据上传")
                csv_upload = gr.File(
                    label="上传 CSV 文件", file_types=[".csv"], file_count="single"
                )
                upload_status = gr.Textbox(
                    label="文件状态", value="等待上传文件...", interactive=False, lines=1
                )
                upload_detail = gr.Markdown(value="请上传 CSV 格式的时序数据文件")

                gr.Markdown("### 模型配置")
                model_choice = gr.Radio(
                    choices=AVAILABLE_MODELS, value=default_model, label="选择模型"
                )
                model_blurb = gr.Markdown(
                    value=f"**{default_model}**: {MODEL_DESCRIPTIONS[default_model]}"
                )
                threshold_input = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.5,
                    step=0.01,
                    label="异常阈值",
                    info="阈值越低,检测越敏感(检出更多异常);阈值越高,检测越保守",
                )
                run_button = gr.Button("开始检测", variant="primary", size="lg")

            # Right column: preview of the uploaded data.
            with gr.Column(scale=1):
                gr.Markdown("### 数据预览 (前 10 行)")
                preview_grid = gr.Dataframe(
                    label="数据预览", headers=None, interactive=False, wrap=True
                )

        gr.Markdown("---")
        gr.Markdown("### 检测结果")

        # Result area, first row: timeline plot.
        with gr.Row():
            with gr.Column(scale=2):
                timeline_view = gr.Plot(label="异常检测时间线")

        # Result area, second row: score distribution next to the statistics.
        with gr.Row():
            with gr.Column(scale=1):
                distribution_view = gr.Plot(label="异常分数分布")
            with gr.Column(scale=1):
                stats_view = gr.Markdown(value="等待检测...", label="检测统计")

        # Event wiring.
        csv_upload.change(
            fn=handle_file_upload,
            inputs=[csv_upload],
            outputs=[upload_status, preview_grid, upload_detail],
        )
        model_choice.change(
            fn=handle_model_change,
            inputs=[model_choice],
            outputs=[model_blurb],
        )
        run_button.click(
            fn=run_detection,
            inputs=[csv_upload, model_choice, threshold_input],
            outputs=[timeline_view, distribution_view, stats_view],
        )

    return tab
# ============================================================================
# 测试代码
# ============================================================================
if __name__ == "__main__":
    # Stand-alone smoke test: host only this tab in a bare Blocks app.
    with gr.Blocks(theme="soft", title="自定义检测测试") as demo:
        create_detection_tab()

    # Serve on port 7862 without creating a public share link.
    demo.launch(share=False, server_name="0.0.0.0", server_port=7862)