Spaces:
Runtime error
Runtime error
Initial deployment of voltage anomaly detection demo
Browse files
This view is limited to 50 files because it contains too many changes. See raw diff
- README.md +30 -6
- app.py +122 -0
- assets/custom.css +96 -0
- config.py +101 -0
- core/__init__.py +14 -0
- core/__pycache__/__init__.cpython-310.pyc +0 -0
- core/__pycache__/__init__.cpython-311.pyc +0 -0
- core/__pycache__/data_processor.cpython-310.pyc +0 -0
- core/__pycache__/data_processor.cpython-311.pyc +0 -0
- core/__pycache__/inference.cpython-310.pyc +0 -0
- core/__pycache__/inference.cpython-311.pyc +0 -0
- core/__pycache__/model_loader.cpython-310.pyc +0 -0
- core/__pycache__/model_loader.cpython-311.pyc +0 -0
- core/data_processor.py +417 -0
- core/inference.py +426 -0
- core/model_loader.py +308 -0
- docs/model_architectures/svg/01_TimesNet.svg +150 -0
- docs/model_architectures/svg/02_VoltageTimesNet.svg +149 -0
- docs/model_architectures/svg/03_VoltageTimesNet_v2.svg +197 -0
- docs/model_architectures/svg/04_TPATimesNet.svg +175 -0
- docs/model_architectures/svg/05_MTSTimesNet.svg +201 -0
- docs/model_architectures/svg/06_DLinear.svg +149 -0
- layers/AutoCorrelation.py +186 -0
- layers/Autoformer_EncDec.py +219 -0
- layers/Conv_Blocks.py +85 -0
- layers/Embed.py +229 -0
- layers/SelfAttention_Family.py +378 -0
- layers/StandardNorm.py +85 -0
- layers/ThreePhaseAttention.py +504 -0
- layers/Transformer_EncDec.py +159 -0
- layers/VoltageEmbed.py +465 -0
- layers/__init__.py +27 -0
- layers/__pycache__/AutoCorrelation.cpython-310.pyc +0 -0
- layers/__pycache__/AutoCorrelation.cpython-311.pyc +0 -0
- layers/__pycache__/Autoformer_EncDec.cpython-310.pyc +0 -0
- layers/__pycache__/Autoformer_EncDec.cpython-311.pyc +0 -0
- layers/__pycache__/Conv_Blocks.cpython-310.pyc +0 -0
- layers/__pycache__/Conv_Blocks.cpython-311.pyc +0 -0
- layers/__pycache__/Embed.cpython-310.pyc +0 -0
- layers/__pycache__/Embed.cpython-311.pyc +0 -0
- layers/__pycache__/SelfAttention_Family.cpython-310.pyc +0 -0
- layers/__pycache__/SelfAttention_Family.cpython-311.pyc +0 -0
- layers/__pycache__/StandardNorm.cpython-310.pyc +0 -0
- layers/__pycache__/StandardNorm.cpython-311.pyc +0 -0
- layers/__pycache__/Transformer_EncDec.cpython-310.pyc +0 -0
- layers/__pycache__/Transformer_EncDec.cpython-311.pyc +0 -0
- layers/__pycache__/__init__.cpython-310.pyc +0 -0
- layers/__pycache__/__init__.cpython-311.pyc +0 -0
- models/DLinear.py +127 -0
- models/MTSTimesNet.py +418 -0
README.md
CHANGED
|
@@ -1,12 +1,36 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version:
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
|
|
|
| 10 |
---
|
| 11 |
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: 农村低压配电网电压异常检测
|
| 3 |
+
emoji: ⚡
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: green
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: 4.44.0
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
+
license: mit
|
| 11 |
---
|
| 12 |
|
| 13 |
+
# 农村低压配电网电压异常检测系统
|
| 14 |
+
|
| 15 |
+
基于 TimesNet 的时间序列异常检测方法研究与应用
|
| 16 |
+
|
| 17 |
+
## 功能
|
| 18 |
+
|
| 19 |
+
- **原理演示**: FFT 周期发现可视化
|
| 20 |
+
- **创新对比**: VoltageTimesNet vs TimesNet
|
| 21 |
+
- **模型竞技场**: 6 模型性能对比
|
| 22 |
+
- **模型结构**: 网络架构图展示
|
| 23 |
+
- **自定义检测**: 上传 CSV 进行异常检测
|
| 24 |
+
|
| 25 |
+
## 模型
|
| 26 |
+
|
| 27 |
+
| 模型 | F1 | Recall | 说明 |
|
| 28 |
+
|:-----|:--:|:------:|:-----|
|
| 29 |
+
| VoltageTimesNet_v2 | 0.6622 | 0.5858 | 最优模型 |
|
| 30 |
+
| TimesNet | 0.6520 | 0.5705 | 基线 |
|
| 31 |
+
|
| 32 |
+
## 链接
|
| 33 |
+
|
| 34 |
+
- [GitHub](https://github.com/sheldon123z/Rural-Low-Voltage-Detection)
|
| 35 |
+
- [数据集](https://huggingface.co/datasets/Sheldon123z/rural-voltage-datasets)
|
| 36 |
+
- [模型权重](https://huggingface.co/Sheldon123z/rural-voltage-detection-models)
|
app.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
农村低压配电网电压异常检测 - Gradio 交互式演示
|
| 3 |
+
Rural Low-Voltage Distribution Network Voltage Anomaly Detection Demo
|
| 4 |
+
|
| 5 |
+
用于论文答辩演示,展示 TimesNet 周期建模原理、VoltageTimesNet 创新点、多模型性能对比
|
| 6 |
+
HuggingFace Spaces 版本
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import gradio as gr
|
| 10 |
+
import sys
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
|
| 13 |
+
# 添加项目路径 (HuggingFace Spaces 兼容)
|
| 14 |
+
DEMO_DIR = Path(__file__).parent
|
| 15 |
+
sys.path.insert(0, str(DEMO_DIR))
|
| 16 |
+
|
| 17 |
+
# 导入标签页
|
| 18 |
+
from tabs.tab1_principle import create_principle_tab
|
| 19 |
+
from tabs.tab2_innovation import create_innovation_tab
|
| 20 |
+
from tabs.tab3_arena import create_arena_tab
|
| 21 |
+
from tabs.tab4_detection import create_detection_tab
|
| 22 |
+
from tabs.tab5_architecture import create_architecture_tab
|
| 23 |
+
|
| 24 |
+
# 导入配置
|
| 25 |
+
from config import GRADIO_THEME, THESIS_COLORS
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def create_header():
    """Build the page header as a Markdown component.

    The header carries the system title, the subtitle and a one-line note
    that the demo is meant for the thesis defense, followed by a rule.

    Returns:
        The ``gr.Markdown`` component created from the header text.
    """
    header_markdown = """
    # ⚡ 农村低压配电网电压异常检测系统

    **基于 TimesNet 的时间序列异常检测方法研究与应用**

    本演示系统用于论文答辩,展示研究成果和模型性能。

    ---
    """
    return gr.Markdown(header_markdown)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def create_footer():
    """Build the page footer as a Markdown component.

    The footer contains a table describing each tab, the technology stack,
    the list of models and a centered row of contact / repository links.

    Returns:
        The ``gr.Markdown`` component created from the footer text.
    """
    footer_markdown = """
    ---

    ### 📚 系统说明

    | 标签页 | 功能 | 说明 |
    |--------|------|------|
    | 原理演示 | FFT 周期发现 | 展示 TimesNet 核心算法原理 |
    | 创新对比 | 模型改进 | VoltageTimesNet 与 TimesNet 的差异对比 |
    | 模型竞技场 | 性能对比 | 6 个模型的多维度性能对比 |
    | 模型结构 | 架构展示 | 6 个模型的网络结构图和详细说明 |
    | 自定义检测 | 实时推理 | 上传 CSV 进行异常检测 |

    **技术栈**: PyTorch + Gradio + Plotly

    **模型**: VoltageTimesNet_v2 (最优) | VoltageTimesNet | TimesNet | TPATimesNet | MTSTimesNet | DLinear

    ---

    <center>

    📧 联系作者 | 📖 [GitHub](https://github.com/sheldon123z/Rural-Low-Voltage-Detection) | 🤗 [HuggingFace](https://huggingface.co/Sheldon123z)

    </center>
    """
    return gr.Markdown(footer_markdown)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def create_app():
    """Assemble the Gradio Blocks application.

    Reads ``assets/custom.css`` next to this file (when it exists) and
    builds a ``gr.Blocks`` container holding the header, the five tabs and
    the footer.

    Returns:
        The constructed ``gr.Blocks`` application.
    """
    # Optional stylesheet: a missing file simply means "no custom CSS".
    css_path = DEMO_DIR / "assets" / "custom.css"
    custom_css = css_path.read_text(encoding="utf-8") if css_path.exists() else ""

    # The stylesheet and the configured theme must be handed to gr.Blocks to
    # take effect; previously the CSS was read and then discarded, and the
    # imported GRADIO_THEME was never used.
    with gr.Blocks(
        title="农村低压配电网电压异常检测系统",
        theme=GRADIO_THEME,
        css=custom_css or None,
    ) as app:
        # Page header
        create_header()

        with gr.Tabs():
            # Tab 1: principle demo
            create_principle_tab()

            # Tab 2: innovation comparison
            create_innovation_tab()

            # Tab 3: model arena
            create_arena_tab()

            # Tab 4: model architecture
            create_architecture_tab()

            # Tab 5: custom detection
            create_detection_tab()

        # Page footer
        create_footer()

    return app
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
# Module-level application instance, built when this module is imported.
app = create_app()

if __name__ == "__main__":
    # Launch settings used when the file is run directly as a script.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "show_error": True,
    }
    app.launch(**launch_options)
|
assets/custom.css
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/* Gradio Demo 自定义样式 */
|
| 2 |
+
|
| 3 |
+
/* 主题颜色 */
|
| 4 |
+
:root {
|
| 5 |
+
--primary-color: #4878A8;
|
| 6 |
+
--secondary-color: #72A86D;
|
| 7 |
+
--accent-color: #C4785C;
|
| 8 |
+
--warning-color: #D4A84C;
|
| 9 |
+
--neutral-color: #808080;
|
| 10 |
+
}
|
| 11 |
+
|
| 12 |
+
/* 标签页样式 */
|
| 13 |
+
.tab-nav button {
|
| 14 |
+
font-size: 16px !important;
|
| 15 |
+
font-weight: 500 !important;
|
| 16 |
+
}
|
| 17 |
+
|
| 18 |
+
.tab-nav button.selected {
|
| 19 |
+
background-color: var(--primary-color) !important;
|
| 20 |
+
color: white !important;
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
/* 标题样式 */
|
| 24 |
+
h1, h2, h3 {
|
| 25 |
+
color: #2c3e50 !important;
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
/* 卡片样式 */
|
| 29 |
+
.gr-box {
|
| 30 |
+
border-radius: 8px !important;
|
| 31 |
+
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1) !important;
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
/* 按钮样式 */
|
| 35 |
+
.gr-button-primary {
|
| 36 |
+
background-color: var(--primary-color) !important;
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
.gr-button-primary:hover {
|
| 40 |
+
background-color: #3a6a9a !important;
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
/* 滑块样式 */
|
| 44 |
+
.gr-slider input[type="range"] {
|
| 45 |
+
accent-color: var(--primary-color);
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
/* 表格样式 */
|
| 49 |
+
.gr-dataframe {
|
| 50 |
+
font-size: 14px !important;
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
.gr-dataframe th {
|
| 54 |
+
background-color: var(--primary-color) !important;
|
| 55 |
+
color: white !important;
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
/* 图表容器 */
|
| 59 |
+
.plotly-container {
|
| 60 |
+
border-radius: 8px;
|
| 61 |
+
overflow: hidden;
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
/* 说明文字 */
|
| 65 |
+
.description-text {
|
| 66 |
+
color: #666;
|
| 67 |
+
font-size: 14px;
|
| 68 |
+
line-height: 1.6;
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
/* 指标卡片 */
|
| 72 |
+
.metric-card {
|
| 73 |
+
background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
|
| 74 |
+
border-radius: 12px;
|
| 75 |
+
padding: 20px;
|
| 76 |
+
text-align: center;
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
.metric-value {
|
| 80 |
+
font-size: 32px;
|
| 81 |
+
font-weight: bold;
|
| 82 |
+
color: var(--primary-color);
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
.metric-label {
|
| 86 |
+
font-size: 14px;
|
| 87 |
+
color: #666;
|
| 88 |
+
margin-top: 8px;
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
/* 响应式布局 */
|
| 92 |
+
@media (max-width: 768px) {
|
| 93 |
+
.gr-row {
|
| 94 |
+
flex-direction: column !important;
|
| 95 |
+
}
|
| 96 |
+
}
|
config.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Gradio Demo 配置文件
|
| 3 |
+
农村低压配电网电压异常检测项目 - HuggingFace Spaces 版本
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
# Path configuration (HuggingFace Spaces compatible)
DEMO_DIR = Path(__file__).parent
CODE_DIR = DEMO_DIR  # inside a Space the demo directory is the root directory
PROJECT_DIR = DEMO_DIR

# Model paths (weights are fetched from the HuggingFace Hub)
MODEL_DIR = DEMO_DIR / "models"
BEST_MODEL_PATH = None  # will be downloaded from the HuggingFace Hub
MODEL_CONFIG_PATH = None

# Dataset paths
DATASET_DIR = DEMO_DIR / "dataset"
RURAL_VOLTAGE_DIR = DATASET_DIR / "RuralVoltage" / "realistic_v2"
PSM_DIR = DATASET_DIR / "PSM"

# Pre-computed data path
PRECOMPUTED_DIR = DEMO_DIR / "precomputed"

# SVG architecture diagrams
SVG_DIR = DEMO_DIR / "docs" / "model_architectures" / "svg"

# Model configuration.  VoltageTimesNet_v2 and TimesNet use the same values,
# so both entries are independent copies of one template.
_TIMESNET_FAMILY_CONFIG = {
    "enc_in": 16,
    "c_out": 16,
    "seq_len": 100,
    "d_model": 64,
    "d_ff": 64,
    "e_layers": 2,
    "top_k": 5,
    "num_kernels": 6,
}

MODEL_CONFIGS = {
    "VoltageTimesNet_v2": dict(_TIMESNET_FAMILY_CONFIG),
    "TimesNet": dict(_TIMESNET_FAMILY_CONFIG),
    "DLinear": {
        "enc_in": 16,
        "seq_len": 100,
        "pred_len": 100,
        "individual": False,
    },
}

# Visualization palette (soft, publication-style colors)
THESIS_COLORS = dict(
    primary="#4878A8",
    secondary="#72A86D",
    accent="#C4785C",
    warning="#D4A84C",
    neutral="#808080",
    light_gray="#B0B0B0",
    anomaly="#E74C3C",
    normal="#2ECC71",
)

# Per-model colors used in comparisons
MODEL_COLORS = dict(
    VoltageTimesNet_v2="#4878A8",
    VoltageTimesNet="#72A86D",
    TimesNet="#C4785C",
    TPATimesNet="#D4A84C",
    MTSTimesNet="#9B59B6",
    DLinear="#808080",
)

# Gradio theme
GRADIO_THEME = "soft"

# Inference settings
INFERENCE_CONFIG = dict(batch_size=32, device="cpu", default_threshold=0.5)

# Demo data settings
DEMO_DATA_CONFIG = dict(sample_length=1000, window_size=100, step_size=1)

# HuggingFace Hub repositories
HF_MODEL_REPO = "Sheldon123z/rural-voltage-detection-models"
HF_DATASET_REPO = "Sheldon123z/rural-voltage-datasets"
|
core/__init__.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Gradio Demo 核心模块
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from .model_loader import load_model, get_available_models
|
| 6 |
+
from .data_processor import DataProcessor
|
| 7 |
+
from .inference import VoltageAnomalyDetector
|
| 8 |
+
|
| 9 |
+
__all__ = [
|
| 10 |
+
"load_model",
|
| 11 |
+
"get_available_models",
|
| 12 |
+
"DataProcessor",
|
| 13 |
+
"VoltageAnomalyDetector",
|
| 14 |
+
]
|
core/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (447 Bytes). View file
|
|
|
core/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (536 Bytes). View file
|
|
|
core/__pycache__/data_processor.cpython-310.pyc
ADDED
|
Binary file (10.9 kB). View file
|
|
|
core/__pycache__/data_processor.cpython-311.pyc
ADDED
|
Binary file (17.3 kB). View file
|
|
|
core/__pycache__/inference.cpython-310.pyc
ADDED
|
Binary file (10.3 kB). View file
|
|
|
core/__pycache__/inference.cpython-311.pyc
ADDED
|
Binary file (15.4 kB). View file
|
|
|
core/__pycache__/model_loader.cpython-310.pyc
ADDED
|
Binary file (5.87 kB). View file
|
|
|
core/__pycache__/model_loader.cpython-311.pyc
ADDED
|
Binary file (8.56 kB). View file
|
|
|
core/data_processor.py
ADDED
|
@@ -0,0 +1,417 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Data Processor Module for Gradio Demo
|
| 3 |
+
农村低压配电网电压异常检测项目
|
| 4 |
+
|
| 5 |
+
Provides:
|
| 6 |
+
- DataProcessor: Class for data preprocessing, normalization, and windowing
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import Optional, Tuple, List, Union, Dict, Any
|
| 12 |
+
|
| 13 |
+
import numpy as np
|
| 14 |
+
import pandas as pd
|
| 15 |
+
from sklearn.preprocessing import StandardScaler
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class DataProcessor:
    """
    Data preprocessing module for voltage anomaly detection.

    Supports:
    - CSV data loading
    - StandardScaler normalization
    - Sliding window segmentation
    - Feature column selection

    Example:
        >>> processor = DataProcessor(seq_len=100)
        >>> processor.fit(train_data)
        >>> windows = processor.transform(test_data)
    """

    # Default feature columns for RuralVoltage dataset
    DEFAULT_FEATURE_COLS = [
        "Va", "Vb", "Vc",              # Three-phase voltages
        "Ia", "Ib", "Ic",              # Three-phase currents
        "P", "Q", "S", "PF",           # Power metrics
        "THD_Va", "THD_Vb", "THD_Vc",  # Harmonic distortion
        "Freq",                        # Frequency
        "V_unbalance", "I_unbalance",  # Unbalance ratios
    ]

    # Columns never treated as features.  Shared by DataFrame conversion and
    # CSV loading so both paths select the same feature set.
    _NON_FEATURE_COLS = (
        "timestamp", "date", "time", "label", "index", "Unnamed: 0",
    )

    def __init__(
        self,
        seq_len: int = 100,
        step: int = 1,
        feature_cols: Optional[List[str]] = None,
        normalize: bool = True
    ):
        """
        Initialize DataProcessor.

        Args:
            seq_len: Length of sliding window (default: 100)
            step: Step size for sliding window (default: 1)
            feature_cols: List of feature column names to use
            normalize: Whether to apply StandardScaler normalization
        """
        self.seq_len = seq_len
        self.step = step
        self.feature_cols = feature_cols
        self.normalize = normalize

        self.scaler = StandardScaler() if normalize else None
        self._is_fitted = False
        self._n_features = None

    def fit(self, data: Union[np.ndarray, pd.DataFrame]) -> "DataProcessor":
        """
        Fit the scaler on training data.

        Args:
            data: Training data as numpy array [N, C] or DataFrame

        Returns:
            self for chaining
        """
        data_array = self._to_numpy(data)

        if self.normalize:
            self.scaler.fit(data_array)

        self._n_features = data_array.shape[1]
        self._is_fitted = True

        return self

    def transform(
        self,
        data: Union[np.ndarray, pd.DataFrame],
        return_windows: bool = True
    ) -> np.ndarray:
        """
        Transform data using fitted scaler and optionally create windows.

        Scaling is applied only when normalization is enabled AND the
        processor has been fitted; otherwise the data passes through
        unscaled.

        Args:
            data: Data to transform [N, C] or DataFrame
            return_windows: If True, return sliding windows [num_windows, seq_len, C]
                            If False, return normalized data [N, C]

        Returns:
            Transformed data
        """
        data_array = self._to_numpy(data)

        # Normalize
        if self.normalize and self._is_fitted:
            data_array = self.scaler.transform(data_array)

        # Create windows
        if return_windows:
            return self.create_windows(data_array)

        return data_array

    def fit_transform(
        self,
        data: Union[np.ndarray, pd.DataFrame],
        return_windows: bool = True
    ) -> np.ndarray:
        """
        Fit scaler and transform data.

        Args:
            data: Training data
            return_windows: Whether to return sliding windows

        Returns:
            Transformed data
        """
        self.fit(data)
        return self.transform(data, return_windows=return_windows)

    def create_windows(
        self,
        data: np.ndarray,
        step: Optional[int] = None
    ) -> np.ndarray:
        """
        Create sliding windows from sequential data.

        Args:
            data: Input data [N, C]
            step: Optional override for step size

        Returns:
            Windows array [num_windows, seq_len, C] as a float32 copy

        Raises:
            ValueError: If ``step`` is not positive, ``data`` is not 2-D, or
                the data is too short to hold a single window.
        """
        if step is None:
            step = self.step
        if step < 1:
            # A zero step would divide by zero below; a negative one is meaningless.
            raise ValueError(f"step must be a positive integer, got {step}")

        data = np.asarray(data)
        if data.ndim != 2:
            raise ValueError(
                f"Expected 2-D data [N, C], got array with shape {data.shape}"
            )
        n_samples = data.shape[0]

        # Calculate number of windows
        n_windows = (n_samples - self.seq_len) // step + 1

        if n_windows <= 0:
            raise ValueError(
                f"Data length {n_samples} is too short for "
                f"seq_len={self.seq_len}, step={step}"
            )

        # Strided view of every window: [N - seq_len + 1, C, seq_len].
        view = np.lib.stride_tricks.sliding_window_view(
            data, self.seq_len, axis=0
        )
        # Keep every `step`-th window, move time before channels, and copy
        # into an independent float32 array.
        return np.ascontiguousarray(
            view[::step].transpose(0, 2, 1), dtype=np.float32
        )

    def inverse_transform(self, data: np.ndarray) -> np.ndarray:
        """
        Inverse transform normalized data back to original scale.

        Args:
            data: Normalized data [N, C] or [B, T, C]

        Returns:
            Data in original scale (returned unchanged when normalization
            is disabled or no scaler exists)
        """
        if not self.normalize or self.scaler is None:
            return data

        original_shape = data.shape

        # Handle 3D input: flatten to [B*T, C] for the scaler, then restore.
        if len(original_shape) == 3:
            B, T, C = original_shape
            data = data.reshape(-1, C)
            data = self.scaler.inverse_transform(data)
            data = data.reshape(B, T, C)
        else:
            data = self.scaler.inverse_transform(data)

        return data

    def _to_numpy(self, data: Union[np.ndarray, pd.DataFrame]) -> np.ndarray:
        """Convert input to a float32 array, selecting feature columns if needed.

        For a DataFrame, the configured ``feature_cols`` that are present are
        used; when none are configured or present, every column outside
        ``_NON_FEATURE_COLS`` is used.  NaN values are replaced by 0.
        """
        if isinstance(data, pd.DataFrame):
            selected = [c for c in (self.feature_cols or []) if c in data.columns]
            if not selected:
                # Fall back to all columns except known non-feature ones.
                selected = [
                    c for c in data.columns if c not in self._NON_FEATURE_COLS
                ]
            data = data[selected].values

        # Cast first so NaN replacement also works for object-dtype input.
        data = np.asarray(data, dtype=np.float32)

        # Handle NaN values
        return np.nan_to_num(data, nan=0.0)

    @classmethod
    def load_csv(
        cls,
        file_path: Union[str, Path],
        feature_cols: Optional[List[str]] = None
    ) -> Tuple[np.ndarray, Optional[np.ndarray], List[str]]:
        """
        Load data from CSV file.

        Args:
            file_path: Path to CSV file
            feature_cols: Optional list of feature columns to use

        Returns:
            Tuple of (data, labels, feature_names)
            - data: Feature values [N, C]
            - labels: Label values [N] if 'label' column exists, else None
            - feature_names: List of feature column names

        Raises:
            FileNotFoundError: If ``file_path`` does not exist.
        """
        file_path = Path(file_path)
        if not file_path.exists():
            raise FileNotFoundError(f"File not found: {file_path}")

        df = pd.read_csv(file_path)

        # Extract labels if present
        labels = df["label"].values if "label" in df.columns else None

        # Requested columns that exist; otherwise every non-excluded column.
        use_cols = [c for c in (feature_cols or []) if c in df.columns]
        if not use_cols:
            use_cols = [c for c in df.columns if c not in cls._NON_FEATURE_COLS]

        data = np.asarray(df[use_cols].values, dtype=np.float32)
        data = np.nan_to_num(data, nan=0.0)

        return data, labels, use_cols

    @classmethod
    def load_dataset(
        cls,
        root_path: Union[str, Path],
        dataset_type: str = "RuralVoltage"
    ) -> Dict[str, Any]:
        """
        Load a complete dataset (train, test, labels).

        Args:
            root_path: Path to dataset directory
            dataset_type: Type of dataset ("RuralVoltage" or "PSM")

        Returns:
            Dict with train_data, test_data, test_labels, feature_names

        Raises:
            ValueError: If ``dataset_type`` is not recognized.
        """
        if dataset_type not in ("RuralVoltage", "PSM"):
            raise ValueError(f"Unknown dataset type: {dataset_type}")

        # Both supported datasets use the same three file names.
        root_path = Path(root_path)
        train_path = root_path / "train.csv"
        test_path = root_path / "test.csv"
        label_path = root_path / "test_label.csv"

        if dataset_type == "RuralVoltage":
            train_data, _, feature_names = cls.load_csv(train_path)
            test_data, _, _ = cls.load_csv(test_path, feature_cols=feature_names)
            _, test_labels, _ = cls.load_csv(label_path)
        else:
            train_df = pd.read_csv(train_path)
            test_df = pd.read_csv(test_path)
            label_df = pd.read_csv(label_path)

            train_data = train_df.values[:, 1:]  # Skip first column
            test_data = test_df.values[:, 1:]
            # NOTE(review): labels stay 2-D here ([N, k]) while the
            # RuralVoltage branch yields a 1-D array; confirm callers cope.
            test_labels = label_df.values[:, 1:]
            feature_names = list(train_df.columns[1:])

            train_data = np.nan_to_num(train_data, nan=0.0).astype(np.float32)
            test_data = np.nan_to_num(test_data, nan=0.0).astype(np.float32)

        return {
            "train_data": train_data,
            "test_data": test_data,
            "test_labels": test_labels,
            "feature_names": feature_names
        }

    def get_scaler_params(self) -> Optional[Dict[str, np.ndarray]]:
        """
        Get scaler parameters (mean and scale).

        Returns:
            Dict with 'mean' and 'scale' arrays, or None if not fitted
        """
        if not self._is_fitted or self.scaler is None:
            return None

        return {
            "mean": self.scaler.mean_,
            "scale": self.scaler.scale_
        }

    def set_scaler_params(self, mean: np.ndarray, scale: np.ndarray) -> None:
        """
        Set scaler parameters directly (useful for loading saved parameters).

        Args:
            mean: Mean values for each feature (array or sequence)
            scale: Scale (std) values for each feature (array or sequence)

        Raises:
            ValueError: If ``mean`` and ``scale`` are not 1-D arrays of the
                same length.
        """
        # Accept plain sequences as well as arrays.
        mean = np.asarray(mean, dtype=np.float64)
        scale = np.asarray(scale, dtype=np.float64)
        if mean.ndim != 1 or mean.shape != scale.shape:
            raise ValueError(
                "mean and scale must be 1-D arrays of equal length, got "
                f"shapes {mean.shape} and {scale.shape}"
            )

        # NOTE(review): if the processor was built with normalize=False,
        # transform() still skips scaling after this call; confirm intent.
        if self.scaler is None:
            self.scaler = StandardScaler()

        self.scaler.mean_ = mean
        # A zero scale would divide by zero in transform(); use 1 instead.
        self.scaler.scale_ = np.where(scale == 0.0, 1.0, scale)
        self.scaler.var_ = scale ** 2
        self.scaler.n_features_in_ = len(mean)
        self._n_features = len(mean)
        self._is_fitted = True

    @property
    def n_features(self) -> Optional[int]:
        """Number of features."""
        return self._n_features

    @property
    def is_fitted(self) -> bool:
        """Whether the processor has been fitted."""
        return self._is_fitted
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
def preprocess_for_inference(
    data: Union[np.ndarray, pd.DataFrame],
    scaler_mean: Optional[np.ndarray] = None,
    scaler_scale: Optional[np.ndarray] = None,
    seq_len: int = 100,
    step: int = 1
) -> np.ndarray:
    """
    One-call helper that turns raw data into windows for model inference.

    A fresh DataProcessor is created on every call. When both scaler
    statistics are supplied they are installed and the data is only
    transformed; otherwise the processor is fitted on ``data`` itself.
    Supplying just one of the two statistics has no effect: the call falls
    back to fitting on ``data``.

    Args:
        data: Raw input data
        scaler_mean: Optional pre-computed scaler mean
        scaler_scale: Optional pre-computed scaler scale
        seq_len: Window length
        step: Step size

    Returns:
        Preprocessed windows ready for model input
    """
    pipeline = DataProcessor(seq_len=seq_len, step=step)

    have_saved_stats = scaler_mean is not None and scaler_scale is not None
    if not have_saved_stats:
        # No complete set of saved statistics: fit on the incoming data.
        return pipeline.fit_transform(data, return_windows=True)

    pipeline.set_scaler_params(scaler_mean, scaler_scale)
    return pipeline.transform(data, return_windows=True)
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
if __name__ == "__main__":
    # Manual smoke test, executed only when this file is run as a script.
    print("Testing DataProcessor...")

    # Create sample data: seeded random values, 1000 rows x 16 columns.
    np.random.seed(42)
    sample_data = np.random.randn(1000, 16).astype(np.float32)

    # Fit the processor on the sample and transform it in one step.
    processor = DataProcessor(seq_len=100, step=1)
    windows = processor.fit_transform(sample_data)

    print(f"Input shape: {sample_data.shape}")
    print(f"Windows shape: {windows.shape}")
    # The literal below mirrors the numbers used above (1000 rows, seq_len=100,
    # step=1); presumably (rows - seq_len) // step + 1 -- confirm against
    # the windowing code.
    print(f"Expected windows: {(1000 - 100) // 1 + 1}")
    print(f"Is fitted: {processor.is_fitted}")
    print(f"N features: {processor.n_features}")

    # Test inverse transform on the first five windows.
    original = processor.inverse_transform(windows[:5])
    print(f"Inverse transform shape: {original.shape}")
|
core/inference.py
ADDED
|
@@ -0,0 +1,426 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Voltage Anomaly Detection Inference Module
|
| 3 |
+
|
| 4 |
+
This module provides a high-level API for voltage anomaly detection inference.
|
| 5 |
+
Supports CPU inference with torch.no_grad() optimization.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import json
|
| 9 |
+
import sys
|
| 10 |
+
from argparse import Namespace
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
from typing import Dict, Optional, Union
|
| 13 |
+
|
| 14 |
+
import numpy as np
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
|
| 18 |
+
# Add code directory to path for model imports
|
| 19 |
+
CODE_DIR = Path(__file__).parent.parent.parent
|
| 20 |
+
if str(CODE_DIR) not in sys.path:
|
| 21 |
+
sys.path.insert(0, str(CODE_DIR))
|
| 22 |
+
|
| 23 |
+
from models import model_dict
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class VoltageAnomalyDetector:
    """
    High-level API for voltage anomaly detection inference.

    This class wraps the model loading, preprocessing, and inference logic
    for easy use in applications like Gradio demos. Anomaly scores are the
    per-time-step mean squared error between the input and the model output.

    Example:
        >>> detector = VoltageAnomalyDetector("VoltageTimesNet_v2", checkpoint_path)
        >>> detector.load_model()
        >>> results = detector.predict(data, threshold=0.5)
        >>> print(results["labels"])  # Anomaly labels (0 or 1)
    """

    # Hyper-parameters shared by the TimesNet-style entries of DEFAULT_CONFIGS.
    _TIMESNET_DEFAULTS = {
        "enc_in": 16,
        "c_out": 16,
        "seq_len": 100,
        "d_model": 64,
        "d_ff": 64,
        "e_layers": 2,
        "top_k": 5,
        "num_kernels": 6,
        "dropout": 0.1,
        "embed": "fixed",
        "freq": "h",
        "task_name": "anomaly_detection",
        "pred_len": 0,
        "label_len": 0,
    }

    # Default model configurations (each model gets its own dict copy).
    DEFAULT_CONFIGS = {
        "VoltageTimesNet_v2": dict(_TIMESNET_DEFAULTS),
        "VoltageTimesNet": dict(_TIMESNET_DEFAULTS),
        "TimesNet": dict(_TIMESNET_DEFAULTS),
        "DLinear": {
            "enc_in": 16,
            "c_out": 16,
            "seq_len": 100,
            "pred_len": 100,
            "individual": False,
            "task_name": "anomaly_detection",
            "label_len": 0,
        },
    }

    def __init__(
        self,
        model_name: str,
        checkpoint_path: Optional[str] = None,
        device: str = "cpu",
        config_path: Optional[str] = None,
    ):
        """
        Initialize the VoltageAnomalyDetector.

        Nothing is loaded here; call load_model() before running inference.

        Args:
            model_name: Name of the model (e.g., "VoltageTimesNet_v2", "TimesNet")
            checkpoint_path: Path to the model checkpoint file (.pth)
            device: Device to run inference on ("cpu" or "cuda")
            config_path: Path to model config JSON file (optional)
        """
        self.model_name = model_name
        self.checkpoint_path = checkpoint_path
        self.device = torch.device(device)
        self.config_path = config_path

        self.model: Optional[nn.Module] = None
        self.config: Dict = {}
        self._is_loaded = False

        # MSE criterion for reconstruction error; reduction='none' keeps the
        # element-wise errors so they can be averaged per time step later.
        self.anomaly_criterion = nn.MSELoss(reduction='none')

    def _load_config(self) -> Dict:
        """
        Load model configuration from file or use defaults.

        A non-empty config read from ``config_path`` takes precedence;
        otherwise the DEFAULT_CONFIGS entry for the model (if any) is copied.
        The "model" key is always set to the model name.

        Returns:
            Dictionary containing model configuration
        """
        config: Dict = {}

        # Try loading from config file first
        if self.config_path is not None:
            config_file = Path(self.config_path)
            if config_file.exists():
                with open(config_file, "r", encoding="utf-8") as f:
                    config = json.load(f)
            else:
                # Keep the best-effort fallback, but do not hide the fact
                # that the requested file was not used.
                print(f"Warning: Config file not found: {self.config_path}")

        # Fall back to default config
        if not config and self.model_name in self.DEFAULT_CONFIGS:
            config = self.DEFAULT_CONFIGS[self.model_name].copy()

        # Ensure model name is set
        config["model"] = self.model_name

        return config

    def _config_to_args(self, config: Dict) -> Namespace:
        """
        Convert config dictionary to argparse Namespace.

        Values in ``config`` override the model's DEFAULT_CONFIGS entry, and a
        few generic fields are filled in when still missing.

        Args:
            config: Configuration dictionary

        Returns:
            Namespace object with configuration attributes
        """
        # Merge with defaults (explicit config wins)
        default_config = self.DEFAULT_CONFIGS.get(self.model_name, {})
        merged_config = {**default_config, **config}

        # Ensure required fields
        merged_config.setdefault("task_name", "anomaly_detection")
        merged_config.setdefault("embed", "fixed")
        merged_config.setdefault("freq", "h")
        merged_config.setdefault("dropout", 0.1)
        merged_config.setdefault("pred_len", 0)
        merged_config.setdefault("label_len", 0)

        return Namespace(**merged_config)

    @staticmethod
    def _extract_state_dict(checkpoint):
        """Return the weights mapping, unwrapping a {"model_state_dict": ...} checkpoint."""
        if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
            return checkpoint["model_state_dict"]
        return checkpoint

    def load_model(self, strict: bool = False) -> None:
        """
        Load the model and weights.

        This method initializes the model architecture and loads
        pretrained weights if a checkpoint path is provided. Checkpoints that
        wrap the weights under a "model_state_dict" key are unwrapped first.

        Args:
            strict: Whether to strictly enforce that the keys in state_dict
                match the keys returned by the model's state_dict function.
                If False, allows loading checkpoints with minor mismatches.
                Default is False for better compatibility.

        Raises:
            ValueError: If model name is not found in model registry
            FileNotFoundError: If checkpoint file does not exist
        """
        # Load configuration
        self.config = self._load_config()
        args = self._config_to_args(self.config)

        # Check model exists
        if self.model_name not in model_dict:
            available = list(model_dict.keys())
            raise ValueError(
                f"Model '{self.model_name}' not found. "
                f"Available models: {available}"
            )

        # Build model
        Model = model_dict[self.model_name]
        self.model = Model(args)

        # Load checkpoint if provided
        if self.checkpoint_path is not None:
            checkpoint_file = Path(self.checkpoint_path)
            if not checkpoint_file.exists():
                raise FileNotFoundError(
                    f"Checkpoint file not found: {self.checkpoint_path}"
                )

            # Load weights
            checkpoint = torch.load(
                self.checkpoint_path,
                map_location=self.device,
                weights_only=True
            )

            # Without unwrapping, a wrapped checkpoint combined with
            # strict=False would match no parameter at all and leave the
            # model with its freshly initialized weights.
            state_dict = self._extract_state_dict(checkpoint)

            # Load state dict with optional strict mode
            missing_keys, unexpected_keys = self.model.load_state_dict(
                state_dict, strict=strict
            )

            if missing_keys:
                print(f"Warning: Missing keys in checkpoint: {missing_keys}")
            if unexpected_keys:
                print(f"Warning: Unexpected keys in checkpoint: {unexpected_keys}")

            print(f"Loaded checkpoint from: {self.checkpoint_path}")

        # Move to device and set eval mode
        self.model = self.model.to(self.device)
        self.model.eval()
        self._is_loaded = True

        print(f"Model '{self.model_name}' loaded successfully on {self.device}")

    def preprocess(self, data: np.ndarray) -> torch.Tensor:
        """
        Preprocess input data for model inference.

        Args:
            data: Input numpy array with shape:
                - (seq_len, n_features) for single sample
                - (batch_size, seq_len, n_features) for batch

        Returns:
            Preprocessed tensor with shape (batch_size, seq_len, n_features)

        Raises:
            ValueError: If the input is neither 2D nor 3D
        """
        # Convert to a float32 array; np.array copies, so the resulting
        # tensor never shares memory with the caller's data.
        data = np.array(data, dtype=np.float32)

        # Add batch dimension if needed
        if data.ndim == 2:
            data = data[np.newaxis, ...]  # (1, seq_len, n_features)

        # Validate shape
        if data.ndim != 3:
            raise ValueError(
                f"Expected 2D or 3D array, got shape {data.shape}"
            )

        # Convert to tensor
        return torch.from_numpy(data).to(self.device)

    def get_reconstruction_error(self, data: np.ndarray) -> np.ndarray:
        """
        Compute reconstruction error for input data.

        The reconstruction error is computed as the mean squared error
        between the input and the model's reconstruction, averaged over the
        feature axis.

        Args:
            data: Input numpy array with shape (batch_size, seq_len, n_features)
                or (seq_len, n_features) for single sample

        Returns:
            Reconstruction error array with shape (n_samples,)
            where n_samples = batch_size * seq_len

        Raises:
            RuntimeError: If load_model() has not been called
        """
        if not self._is_loaded:
            raise RuntimeError("Model not loaded. Call load_model() first.")

        # Preprocess data
        batch_x = self.preprocess(data)

        # Inference with no gradient computation
        with torch.no_grad():
            # Forward pass - reconstruct input; the three extra positional
            # arguments of the model call are passed as None.
            outputs = self.model(batch_x, None, None, None)

            # Compute reconstruction error per sample
            # Shape: (batch, seq_len, features) -> (batch, seq_len)
            error = torch.mean(
                self.anomaly_criterion(batch_x, outputs),
                dim=-1
            )

        # Flatten to (n_samples,) and convert to numpy
        return error.reshape(-1).cpu().numpy()

    def predict(
        self,
        data: np.ndarray,
        threshold: float = 0.5,
        return_scores: bool = True,
    ) -> Dict[str, Union[np.ndarray, float]]:
        """
        Perform anomaly detection inference.

        Args:
            data: Input numpy array with shape (batch_size, seq_len, n_features)
                or (seq_len, n_features) for single sample
            threshold: Anomaly score threshold for binary classification.
                Samples with scores above threshold are labeled as anomalies.
            return_scores: Whether to return raw anomaly scores

        Returns:
            Dictionary containing:
                - "scores": np.ndarray of anomaly scores (if return_scores=True)
                - "labels": np.ndarray of binary labels (0=normal, 1=anomaly)
                - "threshold": float threshold used for classification

        Raises:
            RuntimeError: If load_model() has not been called
        """
        if not self._is_loaded:
            raise RuntimeError("Model not loaded. Call load_model() first.")

        # Get reconstruction errors as anomaly scores
        scores = self.get_reconstruction_error(data)

        # Apply threshold to get binary labels (strictly greater than)
        labels = (scores > threshold).astype(np.int32)

        # Build result dictionary
        result = {
            "labels": labels,
            "threshold": threshold,
        }

        if return_scores:
            result["scores"] = scores

        return result

    def predict_with_percentile_threshold(
        self,
        data: np.ndarray,
        anomaly_ratio: float = 1.0,
        return_scores: bool = True,
    ) -> Dict[str, Union[np.ndarray, float]]:
        """
        Perform anomaly detection using percentile-based threshold.

        The threshold is the (100 - anomaly_ratio)-th percentile of the scores
        computed for ``data`` itself, so it adapts to every call.

        Args:
            data: Input numpy array
            anomaly_ratio: Expected percentage of anomalies (e.g., 1.0 means 1%)
            return_scores: Whether to return raw anomaly scores

        Returns:
            Dictionary containing scores, labels, and computed threshold

        Raises:
            RuntimeError: If load_model() has not been called
        """
        if not self._is_loaded:
            raise RuntimeError("Model not loaded. Call load_model() first.")

        # Get reconstruction errors
        scores = self.get_reconstruction_error(data)

        # Compute threshold using percentile
        threshold = np.percentile(scores, 100 - anomaly_ratio)

        # Apply threshold
        labels = (scores > threshold).astype(np.int32)

        result = {
            "labels": labels,
            "threshold": float(threshold),
        }

        if return_scores:
            result["scores"] = scores

        return result

    @property
    def seq_len(self) -> int:
        """Get the expected input sequence length (100 if the config has none)."""
        return self.config.get("seq_len", 100)

    @property
    def n_features(self) -> int:
        """Get the expected number of input features (16 if the config has none)."""
        return self.config.get("enc_in", 16)

    @property
    def is_loaded(self) -> bool:
        """Check if the model is loaded."""
        return self._is_loaded

    def __repr__(self) -> str:
        status = "loaded" if self._is_loaded else "not loaded"
        return (
            f"VoltageAnomalyDetector("
            f"model={self.model_name}, "
            f"device={self.device}, "
            f"status={status})"
        )
|
core/model_loader.py
ADDED
|
@@ -0,0 +1,308 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Model Loader Module for Gradio Demo
|
| 3 |
+
Rural low-voltage distribution network voltage anomaly detection project
|
| 4 |
+
|
| 5 |
+
Provides:
|
| 6 |
+
- load_model(): Load a model with optional checkpoint
|
| 7 |
+
- get_available_models(): List available models
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import sys
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
from typing import Optional, Dict, List, Any
|
| 13 |
+
from argparse import Namespace
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
|
| 17 |
+
# Add code directory to path for importing models
|
| 18 |
+
CODE_DIR = Path(__file__).parent.parent.parent
|
| 19 |
+
if str(CODE_DIR) not in sys.path:
|
| 20 |
+
sys.path.insert(0, str(CODE_DIR))
|
| 21 |
+
|
| 22 |
+
from models import model_dict
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# Default model configurations for anomaly detection.
# Every TimesNet-style model in the demo starts from the same settings.
_TIMESNET_FAMILY_CONFIG: Dict[str, Any] = {
    "task_name": "anomaly_detection",
    "enc_in": 16,
    "c_out": 16,
    "seq_len": 100,
    "pred_len": 0,
    "label_len": 0,
    "d_model": 64,
    "d_ff": 64,
    "e_layers": 2,
    "top_k": 5,
    "num_kernels": 6,
    "embed": "timeF",
    "freq": "h",
    "dropout": 0.1,
}

# Each model gets its own copy so that editing one entry never leaks into
# another.
DEFAULT_MODEL_CONFIGS: Dict[str, Dict[str, Any]] = {
    family_member: dict(_TIMESNET_FAMILY_CONFIG)
    for family_member in (
        "VoltageTimesNet_v2",
        "VoltageTimesNet",
        "TimesNet",
        "TPATimesNet",
        "MTSTimesNet",
    )
}
DEFAULT_MODEL_CONFIGS["DLinear"] = {
    "task_name": "anomaly_detection",
    "enc_in": 16,
    "c_out": 16,
    "seq_len": 100,
    "pred_len": 100,  # DLinear requires pred_len = seq_len for anomaly detection
    "individual": False,
    "moving_avg": 25,
}

# Models suitable for voltage anomaly detection demo
DEMO_MODELS = [
    "VoltageTimesNet_v2",
    "VoltageTimesNet",
    "TimesNet",
    "TPATimesNet",
    "MTSTimesNet",
    "DLinear",
]
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def get_available_models() -> List[str]:
    """
    List the demo models that can actually be instantiated.

    Returns:
        The entries of DEMO_MODELS that are also present in model_dict,
        in DEMO_MODELS order
    """
    loadable: List[str] = []
    for candidate in DEMO_MODELS:
        if candidate in model_dict:
            loadable.append(candidate)
    return loadable
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def get_all_models() -> List[str]:
    """
    List every model name known to the registry.

    Returns:
        All keys of model_dict, as a new list
    """
    return [*model_dict.keys()]
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def create_model_config(
    model_name: str,
    config_override: Optional[Dict[str, Any]] = None
) -> Namespace:
    """
    Build the configuration namespace used to construct a model.

    The starting point is a copy of the model's DEFAULT_MODEL_CONFIGS entry,
    or a generic TimesNet-style configuration when the model has no entry.
    Overrides are applied on top of that starting point.

    Args:
        model_name: Name of the model
        config_override: Optional dict to override default config

    Returns:
        Namespace object with model configuration
    """
    if model_name not in DEFAULT_MODEL_CONFIGS:
        # Minimal config for unknown models
        settings = dict(
            task_name="anomaly_detection",
            enc_in=16,
            c_out=16,
            seq_len=100,
            pred_len=0,
            label_len=0,
            d_model=64,
            d_ff=64,
            e_layers=2,
            top_k=5,
            num_kernels=6,
            embed="timeF",
            freq="h",
            dropout=0.1,
        )
    else:
        # Copy so the shared defaults are never modified by overrides.
        settings = DEFAULT_MODEL_CONFIGS[model_name].copy()

    # Apply overrides (an empty or missing override changes nothing)
    if config_override:
        settings.update(config_override)

    return Namespace(**settings)
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def load_model(
    model_name: str,
    checkpoint_path: Optional[str] = None,
    config_override: Optional[Dict[str, Any]] = None,
    device: str = "cpu"
) -> torch.nn.Module:
    """
    Load a model with optional checkpoint.

    The model is built from its default configuration (plus overrides),
    optionally initialized from a checkpoint, moved to ``device`` and put in
    eval mode. Weights are loaded non-strictly; any missing or unexpected
    keys are printed as warnings instead of being dropped silently.

    Args:
        model_name: Name of the model (e.g., "VoltageTimesNet_v2", "TimesNet")
        checkpoint_path: Optional path to model checkpoint (.pth file)
        config_override: Optional dict to override default model config
        device: Device to load model on ("cpu" or "cuda")

    Returns:
        Loaded model in eval mode

    Raises:
        ValueError: If model_name is not found
        FileNotFoundError: If checkpoint_path doesn't exist

    Example:
        >>> model = load_model("VoltageTimesNet_v2")
        >>> model = load_model("TimesNet", checkpoint_path="./best_model.pth")
        >>> model = load_model("TimesNet", config_override={"seq_len": 50})
    """
    # Validate model name
    if model_name not in model_dict:
        available = get_available_models()
        raise ValueError(
            f"Model '{model_name}' not found. "
            f"Available models: {available}"
        )

    # Create config
    config = create_model_config(model_name, config_override)

    # Build model
    Model = model_dict[model_name]
    model = Model(config)

    # Load checkpoint if provided
    if checkpoint_path:
        checkpoint_path = Path(checkpoint_path)
        if not checkpoint_path.exists():
            raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

        state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)

        # Handle different checkpoint formats
        if isinstance(state_dict, dict) and "model_state_dict" in state_dict:
            state_dict = state_dict["model_state_dict"]

        # Load state dict with strict=False to handle minor mismatches, but
        # surface what did not match: a checkpoint that matches no parameter
        # would otherwise leave the model untrained without any hint.
        load_result = model.load_state_dict(state_dict, strict=False)
        if load_result.missing_keys:
            print(f"Warning: Missing keys in checkpoint: {load_result.missing_keys}")
        if load_result.unexpected_keys:
            print(f"Warning: Unexpected keys in checkpoint: {load_result.unexpected_keys}")
        print(f"Loaded checkpoint from: {checkpoint_path}")

    # Move to device and set to eval mode
    model = model.to(device)
    model.eval()

    return model
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def get_model_info(model_name: str) -> Dict[str, Any]:
    """
    Get information about a model.

    Args:
        model_name: Name of the model

    Returns:
        Dict with model information: "name", "available", "config" (a copy of
        the model's default configuration, empty if it has none) and
        "description". For a model missing from model_dict, a dict with a
        single "error" message is returned instead.
    """
    if model_name not in model_dict:
        return {"error": f"Model '{model_name}' not found"}

    # Hand out a copy: returning the shared defaults by reference would let a
    # caller's edits change the configuration of every later load.
    config = dict(DEFAULT_MODEL_CONFIGS.get(model_name, {}))

    info = {
        "name": model_name,
        "available": True,
        "config": config,
        "description": _get_model_description(model_name),
    }

    return info
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
def _get_model_description(model_name: str) -> str:
|
| 283 |
+
"""Get model description."""
|
| 284 |
+
descriptions = {
|
| 285 |
+
"VoltageTimesNet_v2": "Enhanced TimesNet with recall optimization for voltage anomaly detection",
|
| 286 |
+
"VoltageTimesNet": "TimesNet variant with preset periods for voltage patterns",
|
| 287 |
+
"TimesNet": "FFT-based period discovery with 2D convolution for temporal patterns",
|
| 288 |
+
"TPATimesNet": "Three-Phase Attention TimesNet for multi-phase voltage analysis",
|
| 289 |
+
"MTSTimesNet": "Multi-scale Temporal TimesNet for multi-resolution patterns",
|
| 290 |
+
"DLinear": "Lightweight linear model with trend-seasonal decomposition",
|
| 291 |
+
}
|
| 292 |
+
return descriptions.get(model_name, "No description available")
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
if __name__ == "__main__":
    # Manual smoke test, executed only when this file is run as a script.
    print("Available models:", get_available_models())

    # Build every demo model without a checkpoint and report its size.
    for model_name in get_available_models():
        print(f"\nLoading {model_name}...")
        try:
            model = load_model(model_name)
            # Count parameters
            params = sum(p.numel() for p in model.parameters())
            print(f" - Parameters: {params:,}")
            print(f" - Config: {get_model_info(model_name)['config']}")
        except Exception as e:
            # Report the failure and continue with the next model.
            print(f" - Error: {e}")
|
docs/model_architectures/svg/01_TimesNet.svg
ADDED
|
|
docs/model_architectures/svg/02_VoltageTimesNet.svg
ADDED
|
|
docs/model_architectures/svg/03_VoltageTimesNet_v2.svg
ADDED
|
|
docs/model_architectures/svg/04_TPATimesNet.svg
ADDED
|
|
docs/model_architectures/svg/05_MTSTimesNet.svg
ADDED
|
|
docs/model_architectures/svg/06_DLinear.svg
ADDED
|
|
layers/AutoCorrelation.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
AutoCorrelation Mechanism for Autoformer.
|
| 3 |
+
|
| 4 |
+
Period-based dependencies discovery with time delay aggregation.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import math
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class AutoCorrelation(nn.Module):
    """
    AutoCorrelation Mechanism with:
    (1) period-based dependencies discovery
    (2) time delay aggregation

    This block can replace the self-attention family mechanism seamlessly:
    ``forward`` takes ``(queries, keys, values, attn_mask)`` and returns an
    ``(output, attention)`` pair.

    Args:
        mask_flag: Stored on the module; not read anywhere in this class.
        factor: Multiplier on ``log(length)`` that sets how many delays are kept.
        scale: Stored on the module; not read anywhere in this class.
        attention_dropout: Probability for ``self.dropout`` (built here but not
            applied in ``forward``).
        output_attention: When true, ``forward`` also returns the correlation map.
    """

    def __init__(
        self,
        mask_flag=True,
        factor=1,
        scale=None,
        attention_dropout=0.1,
        output_attention=False,
    ):
        super(AutoCorrelation, self).__init__()
        self.factor = factor
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)

    def _num_delays(self, length):
        """Return how many delays to aggregate, clamped to ``[1, length]``.

        ``int(factor * log(length))`` is 0 for ``length == 1`` and can exceed
        ``length`` for a large ``factor``; both would break ``torch.topk``.
        """
        return max(1, min(length, int(self.factor * math.log(length))))

    def time_delay_agg_training(self, values, corr):
        """SpeedUp version for training phase.

        One set of delays is shared by the whole batch: it is picked from the
        correlation averaged over heads, channels and batch.

        Args:
            values: 4-D tensor whose last axis is the time axis.
            corr: Correlation tensor with the same layout as ``values``.

        Returns:
            Tensor shaped like ``values``: softmax-weighted sum of rolled copies.
        """
        length = values.shape[3]
        top_k = self._num_delays(length)

        # Average over heads, then channels -> (batch, length).
        mean_value = torch.mean(torch.mean(corr, dim=1), dim=1)
        index = torch.topk(torch.mean(mean_value, dim=0), top_k, dim=-1)[1]
        weights = mean_value[:, index]
        tmp_corr = torch.softmax(weights, dim=-1)

        delays_agg = torch.zeros_like(values).float()
        # One host transfer for all shifts instead of one per iteration.
        for i, shift in enumerate(index.tolist()):
            pattern = torch.roll(values, -shift, -1)
            delays_agg = delays_agg + pattern * tmp_corr[:, i, None, None, None]
        return delays_agg

    def time_delay_agg_inference(self, values, corr):
        """SpeedUp version for inference phase.

        Delays are picked per batch element from the correlation averaged over
        heads and channels.

        Args:
            values: 4-D tensor whose last axis is the time axis.
            corr: Correlation tensor with the same layout as ``values``.

        Returns:
            Tensor shaped like ``values``.
        """
        batch, head, channel, length = values.shape

        init_index = torch.arange(length, device=values.device).expand(
            batch, head, channel, length
        )
        top_k = self._num_delays(length)
        mean_value = torch.mean(torch.mean(corr, dim=1), dim=1)
        weights, delay = torch.topk(mean_value, top_k, dim=-1)
        tmp_corr = torch.softmax(weights, dim=-1)

        # Doubling the time axis lets a plain gather emulate a circular shift.
        tmp_values = values.repeat(1, 1, 1, 2)
        delays_agg = torch.zeros_like(values).float()
        for i in range(top_k):
            tmp_delay = init_index + delay[:, i, None, None, None]
            pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay)
            delays_agg = delays_agg + pattern * tmp_corr[:, i, None, None, None]
        return delays_agg

    def time_delay_agg_full(self, values, corr):
        """Standard version of Autocorrelation.

        Delays are picked independently for every (batch, head, channel) row.

        Args:
            values: 4-D tensor whose last axis is the time axis.
            corr: Correlation tensor with the same layout as ``values``.

        Returns:
            Tensor shaped like ``values``.
        """
        batch, head, channel, length = values.shape

        init_index = torch.arange(length, device=values.device).expand(
            batch, head, channel, length
        )
        top_k = self._num_delays(length)
        weights, delay = torch.topk(corr, top_k, dim=-1)
        tmp_corr = torch.softmax(weights, dim=-1)

        tmp_values = values.repeat(1, 1, 1, 2)
        delays_agg = torch.zeros_like(values).float()
        for i in range(top_k):
            tmp_delay = init_index + delay[..., i].unsqueeze(-1)
            pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay)
            delays_agg = delays_agg + pattern * (tmp_corr[..., i].unsqueeze(-1))
        return delays_agg

    def forward(self, queries, keys, values, attn_mask):
        """Aggregate ``values`` along the delays where queries and keys correlate.

        Args:
            queries: 4-D tensor; axis 1 is the query length ``L``.
            keys: 4-D tensor; axis 1 is the key length ``S``.
            values: 4-D tensor; axis 1 is the key length ``S``.
            attn_mask: Accepted for interface compatibility; not used.

        Returns:
            ``(V, corr)`` when ``output_attention`` is set, else ``(V, None)``.
            Both tensors have ``L`` on axis 1.
        """
        _, L, _, _ = queries.shape
        _, S, _, _ = values.shape
        if L > S:
            # Zero-pad keys/values up to the query length.
            zeros = torch.zeros_like(queries[:, : (L - S), :]).float()
            values = torch.cat([values, zeros], dim=1)
            keys = torch.cat([keys, zeros], dim=1)
        else:
            values = values[:, :L, :, :]
            keys = keys[:, :L, :, :]

        # period-based dependencies
        q_fft = torch.fft.rfft(queries.permute(0, 2, 3, 1).contiguous(), dim=-1)
        k_fft = torch.fft.rfft(keys.permute(0, 2, 3, 1).contiguous(), dim=-1)
        res = q_fft * torch.conj(k_fft)
        # n=L is required: the default output length is 2 * (bins - 1), which
        # drops one lag whenever L is odd.
        corr = torch.fft.irfft(res, n=L, dim=-1)

        # time delay agg
        if self.training:
            aggregate = self.time_delay_agg_training
        else:
            aggregate = self.time_delay_agg_inference
        V = aggregate(values.permute(0, 2, 3, 1).contiguous(), corr).permute(
            0, 3, 1, 2
        )

        if self.output_attention:
            return (V.contiguous(), corr.permute(0, 3, 1, 2))
        return (V.contiguous(), None)
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
class AutoCorrelationLayer(nn.Module):
    """Multi-head wrapper around a correlation module.

    Inputs are linearly projected, split into ``n_heads`` heads, handed to
    ``correlation``, and the result is merged and projected back to ``d_model``.
    """

    def __init__(self, correlation, d_model, n_heads, d_keys=None, d_values=None):
        super().__init__()

        per_head = d_model // n_heads
        key_width = d_keys if d_keys else per_head
        value_width = d_values if d_values else per_head

        self.inner_correlation = correlation
        self.query_projection = nn.Linear(d_model, key_width * n_heads)
        self.key_projection = nn.Linear(d_model, key_width * n_heads)
        self.value_projection = nn.Linear(d_model, value_width * n_heads)
        self.out_projection = nn.Linear(value_width * n_heads, d_model)
        self.n_heads = n_heads

    def forward(self, queries, keys, values, attn_mask):
        """Return ``(projected_output, attn)`` where ``attn`` comes from the inner module."""
        batch, q_len, _ = queries.shape
        _, kv_len, _ = keys.shape
        heads = self.n_heads

        # Split the projected feature axis into (heads, per-head width).
        q = self.query_projection(queries).view(batch, q_len, heads, -1)
        k = self.key_projection(keys).view(batch, kv_len, heads, -1)
        v = self.value_projection(values).view(batch, kv_len, heads, -1)

        mixed, attn = self.inner_correlation(q, k, v, attn_mask)
        merged = mixed.view(batch, q_len, -1)

        return self.out_projection(merged), attn
|
layers/Autoformer_EncDec.py
ADDED
|
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Autoformer Encoder-Decoder with Series Decomposition.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class my_Layernorm(nn.Module):
    """LayerNorm for the seasonal part: normalise, then remove the mean over axis 1."""

    def __init__(self, channels):
        super().__init__()
        self.layernorm = nn.LayerNorm(channels)

    def forward(self, x):
        normed = self.layernorm(x)
        # Broadcasting the kept axis replaces the explicit repeat along axis 1.
        return normed - normed.mean(dim=1, keepdim=True)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class moving_avg(nn.Module):
    """Moving average block to highlight the trend of time series.

    The input is edge-padded along axis 1 by repeating its first and last
    steps, then average-pooled along that axis.

    Args:
        kernel_size: Width of the averaging window.
        stride: Stride of the pooling window.
    """

    def __init__(self, kernel_size, stride):
        super(moving_avg, self).__init__()
        self.kernel_size = kernel_size
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)

    def forward(self, x):
        """Return the moving average of ``x`` taken along axis 1.

        The pooling runs over axis 1 after a ``permute(0, 2, 1)``, so ``x`` is
        expected to be 3-D. With ``stride == 1`` the output has the same length
        as the input for odd and even ``kernel_size`` alike.
        """
        # Total padding must be kernel_size - 1 to preserve length at stride 1.
        # For odd kernels both sides get (k - 1) // 2, as before; for even
        # kernels the extra step goes to the front.
        pad_front = self.kernel_size // 2
        pad_end = (self.kernel_size - 1) // 2
        front = x[:, 0:1, :].repeat(1, pad_front, 1)
        end = x[:, -1:, :].repeat(1, pad_end, 1)
        x = torch.cat([front, x, end], dim=1)
        x = self.avg(x.permute(0, 2, 1))
        x = x.permute(0, 2, 1)
        return x
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class series_decomp(nn.Module):
    """Split a series into a residual part and its moving-average trend."""

    def __init__(self, kernel_size):
        super().__init__()
        self.moving_avg = moving_avg(kernel_size, stride=1)

    def forward(self, x):
        """Return ``(x - trend, trend)`` where ``trend`` is the moving average of ``x``."""
        trend = self.moving_avg(x)
        return x - trend, trend
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class series_decomp_multi(nn.Module):
    """Multiple Series decomposition block from FEDformer.

    Runs one ``series_decomp`` per entry of ``kernel_size`` and averages the
    resulting seasonal parts and trends.

    Args:
        kernel_size: Iterable of window widths, one per decomposition.
    """

    def __init__(self, kernel_size):
        super(series_decomp_multi, self).__init__()
        self.kernel_size = kernel_size
        # ModuleList (not a plain list) so the decompositions are registered as
        # submodules and follow .to()/.train()/.eval() of the parent.
        self.series_decomp = nn.ModuleList(
            [series_decomp(kernel) for kernel in kernel_size]
        )

    def forward(self, x):
        """Return ``(seasonal, trend)``, each the mean over all decompositions."""
        trends = []
        seasonals = []
        for decomp in self.series_decomp:
            seasonal, trend = decomp(x)
            trends.append(trend)
            seasonals.append(seasonal)

        sea = sum(seasonals) / len(seasonals)
        moving_mean = sum(trends) / len(trends)
        return sea, moving_mean
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class EncoderLayer(nn.Module):
    """Autoformer encoder layer: attention and a 1x1-conv feed-forward, each followed by decomposition."""

    def __init__(
        self,
        attention,
        d_model,
        d_ff=None,
        moving_avg=25,
        dropout=0.1,
        activation="relu",
    ):
        super().__init__()
        hidden_width = d_ff or 4 * d_model
        self.attention = attention
        self.conv1 = nn.Conv1d(
            in_channels=d_model, out_channels=hidden_width, kernel_size=1, bias=False
        )
        self.conv2 = nn.Conv1d(
            in_channels=hidden_width, out_channels=d_model, kernel_size=1, bias=False
        )
        self.decomp1 = series_decomp(moving_avg)
        self.decomp2 = series_decomp(moving_avg)
        self.dropout = nn.Dropout(dropout)
        # Anything other than "relu" selects GELU.
        self.activation = F.gelu if activation != "relu" else F.relu

    def forward(self, x, attn_mask=None):
        """Return ``(seasonal_output, attn)``; both trend parts are discarded."""
        attended, attn = self.attention(x, x, x, attn_mask=attn_mask)
        seasonal, _ = self.decomp1(x + self.dropout(attended))

        # Position-wise feed-forward, applied on the transposed layout Conv1d needs.
        hidden = self.conv1(seasonal.transpose(-1, 1))
        hidden = self.dropout(self.activation(hidden))
        hidden = self.dropout(self.conv2(hidden).transpose(-1, 1))

        res, _ = self.decomp2(seasonal + hidden)
        return res, attn
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
class Encoder(nn.Module):
    """Autoformer encoder.

    Args:
        attn_layers: Layers called as ``layer(x, attn_mask=...)`` and returning
            ``(x, attn)``.
        conv_layers: Optional layers called as ``layer(x)`` after each attention
            layer except the last one.
        norm_layer: Optional module applied to the final output.
    """

    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
        super(Encoder, self).__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        self.conv_layers = (
            nn.ModuleList(conv_layers) if conv_layers is not None else None
        )
        self.norm = norm_layer

    def forward(self, x, attn_mask=None):
        """Run all layers and return ``(x, attns)`` with one ``attn`` per attention call."""
        attns = []
        if self.conv_layers is not None:
            for attn_layer, conv_layer in zip(self.attn_layers, self.conv_layers):
                x, attn = attn_layer(x, attn_mask=attn_mask)
                x = conv_layer(x)
                attns.append(attn)
            # The trailing attention layer has no conv partner; it must receive
            # the same mask as every other layer.
            x, attn = self.attn_layers[-1](x, attn_mask=attn_mask)
            attns.append(attn)
        else:
            for attn_layer in self.attn_layers:
                x, attn = attn_layer(x, attn_mask=attn_mask)
                attns.append(attn)

        if self.norm is not None:
            x = self.norm(x)

        return x, attns
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
class DecoderLayer(nn.Module):
    """Autoformer decoder layer.

    Self-attention, cross-attention and a 1x1-conv feed-forward are each
    followed by a series decomposition; the three extracted trends are summed
    and projected to ``c_out`` channels.
    """

    def __init__(
        self,
        self_attention,
        cross_attention,
        d_model,
        c_out,
        d_ff=None,
        moving_avg=25,
        dropout=0.1,
        activation="relu",
    ):
        super().__init__()
        hidden_width = d_ff or 4 * d_model
        self.self_attention = self_attention
        self.cross_attention = cross_attention
        self.conv1 = nn.Conv1d(
            in_channels=d_model, out_channels=hidden_width, kernel_size=1, bias=False
        )
        self.conv2 = nn.Conv1d(
            in_channels=hidden_width, out_channels=d_model, kernel_size=1, bias=False
        )
        self.decomp1 = series_decomp(moving_avg)
        self.decomp2 = series_decomp(moving_avg)
        self.decomp3 = series_decomp(moving_avg)
        self.dropout = nn.Dropout(dropout)
        self.projection = nn.Conv1d(
            in_channels=d_model,
            out_channels=c_out,
            kernel_size=3,
            stride=1,
            padding=1,
            padding_mode="circular",
            bias=False,
        )
        # Anything other than "relu" selects GELU.
        self.activation = F.gelu if activation != "relu" else F.relu

    def forward(self, x, cross, x_mask=None, cross_mask=None):
        """Return ``(seasonal_output, projected_trend)``."""
        self_out = self.self_attention(x, x, x, attn_mask=x_mask)[0]
        x, trend_self = self.decomp1(x + self.dropout(self_out))

        cross_out = self.cross_attention(x, cross, cross, attn_mask=cross_mask)[0]
        x, trend_cross = self.decomp2(x + self.dropout(cross_out))

        # Position-wise feed-forward on the transposed layout Conv1d needs.
        hidden = self.conv1(x.transpose(-1, 1))
        hidden = self.dropout(self.activation(hidden))
        hidden = self.dropout(self.conv2(hidden).transpose(-1, 1))
        x, trend_ff = self.decomp3(x + hidden)

        summed_trend = trend_self + trend_cross + trend_ff
        residual_trend = self.projection(summed_trend.permute(0, 2, 1)).transpose(
            1, 2
        )
        return x, residual_trend
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
class Decoder(nn.Module):
    """Autoformer decoder.

    Args:
        layers: Layers called as ``layer(x, cross, x_mask=..., cross_mask=...)``
            and returning ``(x, residual_trend)``.
        norm_layer: Optional module applied to ``x`` after the last layer.
        projection: Optional module applied to ``x`` after the norm.
    """

    def __init__(self, layers, norm_layer=None, projection=None):
        super(Decoder, self).__init__()
        self.layers = nn.ModuleList(layers)
        self.norm = norm_layer
        self.projection = projection

    def forward(self, x, cross, x_mask=None, cross_mask=None, trend=None):
        """Return ``(x, trend)`` with every layer's residual trend accumulated.

        ``trend`` may be omitted; accumulation then starts from the first
        layer's residual trend instead of failing on ``None + tensor``.
        """
        for layer in self.layers:
            x, residual_trend = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask)
            trend = residual_trend if trend is None else trend + residual_trend

        if self.norm is not None:
            x = self.norm(x)

        if self.projection is not None:
            x = self.projection(x)
        return x, trend
|
layers/Conv_Blocks.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Convolution Blocks for Voltage Anomaly Detection
|
| 3 |
+
Standalone version - independent from main TSLib
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class Inception_Block_V1(nn.Module):
    """Inception block for TimesNet: parallel square Conv2d kernels whose outputs are averaged.

    Kernel ``i`` has size ``2 * i + 1`` with padding ``i``.
    """

    def __init__(self, in_channels, out_channels, num_kernels=6, init_weight=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_kernels = num_kernels
        self.kernels = nn.ModuleList(
            [
                nn.Conv2d(in_channels, out_channels, kernel_size=2 * idx + 1, padding=idx)
                for idx in range(self.num_kernels)
            ]
        )
        if init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal init for every Conv2d weight; biases are zeroed."""
        convs = (mod for mod in self.modules() if isinstance(mod, nn.Conv2d))
        for conv in convs:
            nn.init.kaiming_normal_(conv.weight, mode="fan_out", nonlinearity="relu")
            if conv.bias is not None:
                nn.init.constant_(conv.bias, 0)

    def forward(self, x):
        """Return the element-wise mean of all kernel outputs."""
        branches = [self.kernels[idx](x) for idx in range(self.num_kernels)]
        return torch.stack(branches, dim=-1).mean(-1)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class Inception_Block_V2(nn.Module):
    """Inception block V2: pairs of 1xk / kx1 Conv2d kernels plus one 1x1 kernel, averaged."""

    def __init__(self, in_channels, out_channels, num_kernels=6, init_weight=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_kernels = num_kernels

        branches = []
        for idx in range(self.num_kernels // 2):
            width = 2 * idx + 3
            pad = idx + 1
            # Horizontal strip first, then the matching vertical strip.
            branches.append(
                nn.Conv2d(in_channels, out_channels, kernel_size=[1, width], padding=[0, pad])
            )
            branches.append(
                nn.Conv2d(in_channels, out_channels, kernel_size=[width, 1], padding=[pad, 0])
            )
        branches.append(nn.Conv2d(in_channels, out_channels, kernel_size=1))
        self.kernels = nn.ModuleList(branches)

        if init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal init for every Conv2d weight; biases are zeroed."""
        convs = (mod for mod in self.modules() if isinstance(mod, nn.Conv2d))
        for conv in convs:
            nn.init.kaiming_normal_(conv.weight, mode="fan_out", nonlinearity="relu")
            if conv.bias is not None:
                nn.init.constant_(conv.bias, 0)

    def forward(self, x):
        """Return the element-wise mean of all kernel outputs."""
        count = self.num_kernels // 2 * 2 + 1
        outputs = [self.kernels[idx](x) for idx in range(count)]
        return torch.stack(outputs, dim=-1).mean(-1)
|
layers/Embed.py
ADDED
|
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Embedding Layers for Voltage Anomaly Detection
|
| 3 |
+
Standalone version - independent from main TSLib
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class PositionalEmbedding(nn.Module):
    """Positional encoding using sinusoidal functions.

    The table is precomputed once and stored as the buffer ``pe`` with shape
    ``(1, max_len, d_model)``; it is not a trainable parameter.

    Args:
        d_model: Width of the encoding. Odd values are supported.
        max_len: Number of positions to precompute.
    """

    def __init__(self, d_model, max_len=5000):
        super(PositionalEmbedding, self).__init__()
        # Compute the positional encodings once in log space
        pe = torch.zeros(max_len, d_model).float()

        position = torch.arange(0, max_len).float().unsqueeze(1)
        div_term = (
            torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
        ).exp()
        angles = position * div_term

        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one cosine column fewer than sine columns.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Return the first ``x.size(1)`` rows of the table; only the size of ``x`` is used."""
        return self.pe[:, : x.size(1)]
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class TokenEmbedding(nn.Module):
    """Value embedding: a bias-free circular Conv1d (kernel 3) from ``c_in`` to ``d_model`` channels."""

    def __init__(self, c_in, d_model):
        super().__init__()
        # NOTE(review): the wider pad is presumably a workaround for circular
        # padding in PyTorch releases before 1.5.0 - confirm before removing.
        if torch.__version__ >= "1.5.0":
            pad = 1
        else:
            pad = 2
        self.tokenConv = nn.Conv1d(
            in_channels=c_in,
            out_channels=d_model,
            kernel_size=3,
            padding=pad,
            padding_mode="circular",
            bias=False,
        )
        conv_layers = [mod for mod in self.modules() if isinstance(mod, nn.Conv1d)]
        for conv in conv_layers:
            nn.init.kaiming_normal_(
                conv.weight, mode="fan_in", nonlinearity="leaky_relu"
            )

    def forward(self, x):
        """Convolve along axis 1 of ``x``; the feature axis is swapped in and out for Conv1d."""
        channels_first = x.permute(0, 2, 1)
        return self.tokenConv(channels_first).transpose(1, 2)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class FixedEmbedding(nn.Module):
    """Fixed embedding (non-trainable) using sinusoidal functions.

    Args:
        c_in: Number of rows in the lookup table (valid indices are ``0..c_in-1``).
        d_model: Width of each embedding vector. Odd values are supported.
    """

    def __init__(self, c_in, d_model):
        super(FixedEmbedding, self).__init__()

        w = torch.zeros(c_in, d_model).float()

        position = torch.arange(0, c_in).float().unsqueeze(1)
        div_term = (
            torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
        ).exp()
        angles = position * div_term

        w[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one cosine column fewer than sine columns.
        w[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        self.emb = nn.Embedding(c_in, d_model)
        # Freezing happens here, via requires_grad=False on the Parameter.
        self.emb.weight = nn.Parameter(w, requires_grad=False)

    def forward(self, x):
        """Look up the fixed vectors for integer indices ``x``; the result is detached."""
        return self.emb(x).detach()
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class TemporalEmbedding(nn.Module):
    """Sum of per-field embeddings for integer calendar features.

    ``forward`` reads the last axis of its input as
    ``[month, day, weekday, hour, minute]``; the minute column is only used
    when the module was built with ``freq == "t"``.
    """

    def __init__(self, d_model, embed_type="fixed", freq="h"):
        super().__init__()

        # Table sizes for each calendar field.
        minute_size, hour_size = 4, 24
        weekday_size, day_size, month_size = 7, 32, 13

        if embed_type == "fixed":
            Embed = FixedEmbedding
        else:
            Embed = nn.Embedding

        if freq == "t":
            self.minute_embed = Embed(minute_size, d_model)
        self.hour_embed = Embed(hour_size, d_model)
        self.weekday_embed = Embed(weekday_size, d_model)
        self.day_embed = Embed(day_size, d_model)
        self.month_embed = Embed(month_size, d_model)

    def forward(self, x):
        """Return the summed embeddings for the calendar fields in ``x``."""
        fields = x.long()
        if hasattr(self, "minute_embed"):
            minute_x = self.minute_embed(fields[:, :, 4])
        else:
            minute_x = 0.0
        hour_x = self.hour_embed(fields[:, :, 3])
        weekday_x = self.weekday_embed(fields[:, :, 2])
        day_x = self.day_embed(fields[:, :, 1])
        month_x = self.month_embed(fields[:, :, 0])

        return hour_x + weekday_x + day_x + month_x + minute_x
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
class TimeFeatureEmbedding(nn.Module):
    """Bias-free linear projection of time features to ``d_model``.

    The input width is chosen from ``freq``; an unknown ``freq`` raises ``KeyError``.
    """

    def __init__(self, d_model, embed_type="timeF", freq="h"):
        super().__init__()

        # Input feature count for each supported frequency code.
        features_per_freq = {
            "h": 4,
            "t": 5,
            "s": 6,
            "m": 1,
            "a": 1,
            "w": 2,
            "d": 3,
            "b": 3,
        }
        n_features = features_per_freq[freq]
        self.embed = nn.Linear(n_features, d_model, bias=False)

    def forward(self, x):
        """Project ``x`` along its last axis."""
        projected = self.embed(x)
        return projected
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class DataEmbedding(nn.Module):
    """Value + positional (+ optional temporal) embedding followed by dropout."""

    def __init__(self, c_in, d_model, embed_type="fixed", freq="h", dropout=0.1):
        super().__init__()

        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
        self.position_embedding = PositionalEmbedding(d_model=d_model)
        # "timeF" selects the linear time-feature projection; anything else
        # uses the lookup-table temporal embedding.
        if embed_type == "timeF":
            temporal = TimeFeatureEmbedding(
                d_model=d_model, embed_type=embed_type, freq=freq
            )
        else:
            temporal = TemporalEmbedding(
                d_model=d_model, embed_type=embed_type, freq=freq
            )
        self.temporal_embedding = temporal
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        """Embed ``x``; ``x_mark`` may be ``None`` to skip the temporal term."""
        embedded = self.value_embedding(x)
        if x_mark is not None:
            embedded = embedded + self.temporal_embedding(x_mark)
        return self.dropout(embedded + self.position_embedding(x))
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
class DataEmbedding_inverted(nn.Module):
    """Inverted embedding: axes 1 and 2 are swapped and each resulting row is linearly projected."""

    def __init__(self, c_in, d_model, embed_type="fixed", freq="h", dropout=0.1):
        super().__init__()
        self.value_embedding = nn.Linear(c_in, d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        """Embed ``x``; when ``x_mark`` is given its rows are appended after those of ``x``."""
        # After the swap the last axis is the one the Linear layer consumes.
        tokens = x.permute(0, 2, 1)
        if x_mark is not None:
            tokens = torch.cat([tokens, x_mark.permute(0, 2, 1)], 1)
        return self.dropout(self.value_embedding(tokens))
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
class DataEmbedding_wo_pos(nn.Module):
    """Value (+ optional temporal) embedding followed by dropout; no positional term is added.

    A ``position_embedding`` submodule is still constructed, but ``forward``
    never calls it.
    """

    def __init__(self, c_in, d_model, embed_type="fixed", freq="h", dropout=0.1):
        super().__init__()

        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
        self.position_embedding = PositionalEmbedding(d_model=d_model)
        # "timeF" selects the linear time-feature projection; anything else
        # uses the lookup-table temporal embedding.
        if embed_type == "timeF":
            temporal = TimeFeatureEmbedding(
                d_model=d_model, embed_type=embed_type, freq=freq
            )
        else:
            temporal = TemporalEmbedding(
                d_model=d_model, embed_type=embed_type, freq=freq
            )
        self.temporal_embedding = temporal
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        """Embed ``x``; ``x_mark`` may be ``None`` to skip the temporal term."""
        embedded = self.value_embedding(x)
        if x_mark is not None:
            embedded = embedded + self.temporal_embedding(x_mark)
        return self.dropout(embedded)
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
class PatchEmbedding(nn.Module):
    """Cut the last axis into patches, project each patch to ``d_model`` and add positions."""

    def __init__(self, d_model, patch_len, stride, padding, dropout):
        super().__init__()
        # Patch geometry; the input is replication-padded on the right only.
        self.patch_len = patch_len
        self.stride = stride
        self.padding_patch_layer = nn.ReplicationPad1d((0, padding))

        # Each patch of length patch_len becomes one d_model-wide token.
        self.value_embedding = nn.Linear(patch_len, d_model, bias=False)

        # Sinusoidal position table indexed by patch number.
        self.position_embedding = PositionalEmbedding(d_model)

        # Dropout applied to the summed embedding.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return ``(embedded_patches, n_vars)`` where ``n_vars`` is ``x.shape[1]``."""
        n_vars = x.shape[1]

        padded = self.padding_patch_layer(x)
        patches = padded.unfold(dimension=-1, size=self.patch_len, step=self.stride)
        # Fold the first two axes together so every row is one patch sequence.
        lead, second, n_patches, width = (
            patches.shape[0],
            patches.shape[1],
            patches.shape[2],
            patches.shape[3],
        )
        flat = torch.reshape(patches, (lead * second, n_patches, width))

        embedded = self.value_embedding(flat) + self.position_embedding(flat)
        return self.dropout(embedded), n_vars
|
layers/SelfAttention_Family.py
ADDED
|
@@ -0,0 +1,378 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Self-Attention Family for Time Series Models.
|
| 3 |
+
|
| 4 |
+
Includes:
|
| 5 |
+
- FullAttention: Standard scaled dot-product attention
|
| 6 |
+
- ProbAttention: Informer's ProbSparse attention
|
| 7 |
+
- DSAttention: De-stationary attention
|
| 8 |
+
- AttentionLayer: Attention wrapper with projections
|
| 9 |
+
- ReformerLayer: LSH attention (requires reformer_pytorch)
|
| 10 |
+
- TwoStageAttentionLayer: Crossformer's two-stage attention
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from math import sqrt
|
| 14 |
+
|
| 15 |
+
import numpy as np
|
| 16 |
+
import torch
|
| 17 |
+
import torch.nn as nn
|
| 18 |
+
|
| 19 |
+
from utils.masking import ProbMask, TriangularCausalMask
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class DSAttention(nn.Module):
    """De-stationary attention.

    Scaled dot-product attention whose raw scores are rescaled by ``tau`` and
    shifted by ``delta`` before the softmax.
    """

    def __init__(
        self,
        mask_flag=True,
        factor=5,
        scale=None,
        attention_dropout=0.1,
        output_attention=False,
    ):
        """
        Args:
            mask_flag: Apply a mask to the scores (a causal mask is built
                when ``attn_mask`` is ``None``).
            factor: Accepted for signature parity with the other attention
                classes; not used by this class.
            scale: Softmax scale; falls back to ``1 / sqrt(E)`` when falsy.
            attention_dropout: Dropout probability on the attention weights.
            output_attention: Return the attention weights alongside the
                output.
        """
        super().__init__()
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
        """Compute de-stationary attention.

        Args:
            queries: Tensor ``[B, L, H, E]``.
            keys: Tensor ``[B, S, H, E]``.
            values: Tensor ``[B, S, H, D]``.
            attn_mask: Object exposing a boolean ``.mask`` (``True`` = masked),
                or ``None``. Only consulted when ``mask_flag`` is set.
            tau: Optional multiplicative factor; two singleton axes are
                inserted at dim 1 before broadcasting over the scores.
            delta: Optional additive shift, expanded the same way as ``tau``.

        Returns:
            ``(output, weights)`` with ``output`` of shape ``[B, L, H, D]``;
            ``weights`` is ``[B, H, L, S]`` or ``None`` when
            ``output_attention`` is off.
        """
        batch, q_len, _, head_dim = queries.shape
        softmax_scale = self.scale or 1.0 / sqrt(head_dim)

        # Neutral defaults (x1, +0) make this plain dot-product attention.
        tau_term = tau.unsqueeze(1).unsqueeze(1) if tau is not None else 1.0
        delta_term = delta.unsqueeze(1).unsqueeze(1) if delta is not None else 0.0

        logits = torch.einsum("blhe,bshe->bhls", queries, keys) * tau_term + delta_term

        if self.mask_flag:
            mask_obj = attn_mask
            if mask_obj is None:
                mask_obj = TriangularCausalMask(batch, q_len, device=queries.device)
            logits.masked_fill_(mask_obj.mask, -np.inf)

        weights = self.dropout(torch.softmax(softmax_scale * logits, dim=-1))
        attended = torch.einsum("bhls,bshd->blhd", weights, values).contiguous()

        return attended, (weights if self.output_attention else None)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class FullAttention(nn.Module):
    """Standard scaled dot-product attention over multi-head inputs."""

    def __init__(
        self,
        mask_flag=True,
        factor=5,
        scale=None,
        attention_dropout=0.1,
        output_attention=False,
    ):
        """
        Args:
            mask_flag: Apply a mask to the scores (a causal mask is built
                when ``attn_mask`` is ``None``).
            factor: Accepted for signature parity with the other attention
                classes; not used by this class.
            scale: Softmax scale; falls back to ``1 / sqrt(E)`` when falsy.
            attention_dropout: Dropout probability on the attention weights.
            output_attention: Return the attention weights alongside the
                output.
        """
        super().__init__()
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
        """Compute attention.

        Args:
            queries: Tensor ``[B, L, H, E]``.
            keys: Tensor ``[B, S, H, E]``.
            values: Tensor ``[B, S, H, D]``.
            attn_mask: Object exposing a boolean ``.mask`` (``True`` = masked),
                or ``None``. Only consulted when ``mask_flag`` is set.
            tau: Ignored; present so all attention classes share a signature.
            delta: Ignored; present so all attention classes share a signature.

        Returns:
            ``(output, weights)`` with ``output`` of shape ``[B, L, H, D]``;
            ``weights`` is ``[B, H, L, S]`` or ``None`` when
            ``output_attention`` is off.
        """
        batch, q_len, _, head_dim = queries.shape
        softmax_scale = self.scale or 1.0 / sqrt(head_dim)

        logits = torch.einsum("blhe,bshe->bhls", queries, keys)

        if self.mask_flag:
            mask_obj = attn_mask
            if mask_obj is None:
                mask_obj = TriangularCausalMask(batch, q_len, device=queries.device)
            # Masked positions get -inf so the softmax assigns them zero weight.
            logits.masked_fill_(mask_obj.mask, -np.inf)

        weights = self.dropout(torch.softmax(softmax_scale * logits, dim=-1))
        attended = torch.einsum("bhls,bshd->blhd", weights, values).contiguous()

        if not self.output_attention:
            return attended, None
        return attended, weights
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class ProbAttention(nn.Module):
    """Informer's ProbSparse self-attention mechanism.

    Only the ``factor * ceil(ln(L_Q))`` queries with the largest sparsity
    measurement receive real attention. Every other query position keeps a
    default context: the mean of V, or the cumulative sum of V when
    ``mask_flag`` is set.
    """

    def __init__(
        self,
        mask_flag=True,
        factor=5,
        scale=None,
        attention_dropout=0.1,
        output_attention=False,
    ):
        """
        Args:
            mask_flag: Use causal masking (requires ``L_Q == L_V``).
            factor: Multiplier for the number of sampled keys and of
                selected ("active") queries.
            scale: Score scale; falls back to ``1 / sqrt(D)`` when falsy.
            attention_dropout: Probability for ``self.dropout``. The module
                is registered but never applied by this class.
            output_attention: Return a dense attention map alongside the
                context.
        """
        super(ProbAttention, self).__init__()
        self.factor = factor
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)

    def _prob_QK(self, Q, K, sample_k, n_top):
        """Select the ``n_top`` most informative queries and score them.

        Args:
            Q: Queries ``[B, H, L_Q, E]``.
            K: Keys ``[B, H, L_K, E]``.
            sample_k: Number of randomly sampled keys per query.
            n_top: Number of queries to keep.

        Returns:
            ``(Q_K, M_top)``: scores ``[B, H, n_top, L_K]`` of the selected
            queries against all keys, and their indices ``[B, H, n_top]``.
        """
        B, H, L_K, E = K.shape
        _, _, L_Q, _ = Q.shape

        # Score every query against a random subset of keys only.
        K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)
        index_sample = torch.randint(L_K, (L_Q, sample_k))
        K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :]
        # Drop ONLY the singleton matmul axis. A bare ``squeeze()`` would also
        # remove the batch/head/sample axes whenever one of them has size 1,
        # which corrupts the top-k selection below.
        Q_K_sample = torch.matmul(
            Q.unsqueeze(-2), K_sample.transpose(-2, -1)
        ).squeeze(-2)

        # Sparsity measurement: max minus (sum / L_K) of the sampled scores.
        M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K)
        M_top = M.topk(n_top, sorted=False)[1]

        Q_reduce = Q[
            torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], M_top, :
        ]
        Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1))

        return Q_K, M_top

    def _get_initial_context(self, V, L_Q):
        """Build the default context used for non-selected queries.

        Args:
            V: Values ``[B, H, L_V, D]``.
            L_Q: Number of query positions.

        Returns:
            Tensor ``[B, H, L_Q, D]``: the mean of V repeated for every query
            (no masking) or the cumulative sum of V along time (masking).
        """
        B, H, L_V, D = V.shape
        if not self.mask_flag:
            V_sum = V.mean(dim=-2)
            # clone() because the expanded view is written in place later.
            contex = V_sum.unsqueeze(-2).expand(B, H, L_Q, V_sum.shape[-1]).clone()
        else:
            # Causal self-attention only.
            assert L_Q == L_V
            contex = V.cumsum(dim=-2)
        return contex

    def _update_context(self, context_in, V, scores, index, L_Q, attn_mask):
        """Overwrite the context rows of the selected queries.

        Args:
            context_in: Default context ``[B, H, L_Q, D]``; modified in place.
            V: Values ``[B, H, L_V, D]``.
            scores: Scaled scores ``[B, H, n_top, L_K]`` of selected queries.
            index: Indices ``[B, H, n_top]`` of the selected queries.
            L_Q: Number of query positions.
            attn_mask: Ignored; when ``mask_flag`` is set a ``ProbMask`` is
                built here instead.

        Returns:
            ``(context, attns)``; ``attns`` is a dense ``[B, H, L_V, L_V]``
            map (uniform rows for non-selected queries) or ``None``.
        """
        B, H, L_V, D = V.shape

        if self.mask_flag:
            attn_mask = ProbMask(B, H, L_Q, index, scores, device=V.device)
            scores.masked_fill_(attn_mask.mask, -np.inf)

        attn = torch.softmax(scores, dim=-1)

        context_in[
            torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :
        ] = torch.matmul(attn, V).type_as(context_in)

        if self.output_attention:
            attns = (torch.ones([B, H, L_V, L_V]) / L_V).type_as(attn).to(attn.device)
            attns[
                torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :
            ] = attn
            return context_in, attns
        else:
            return context_in, None

    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
        """Compute ProbSparse attention.

        Args:
            queries: Tensor ``[B, L_Q, H, D]``.
            keys: Tensor ``[B, L_K, H, D]``.
            values: Tensor with the same layout as ``keys``.
            attn_mask: Unused by this class (see ``_update_context``).
            tau: Ignored; present so all attention classes share a signature.
            delta: Ignored; present so all attention classes share a signature.

        Returns:
            ``(context, attn)``. ``context`` is returned in ``[B, H, L_Q, D]``
            layout (the head axis is NOT moved back behind the length axis);
            ``attn`` is the dense map from ``_update_context`` or ``None``.
        """
        B, L_Q, H, D = queries.shape
        _, L_K, _, _ = keys.shape

        queries = queries.transpose(2, 1)
        keys = keys.transpose(2, 1)
        values = values.transpose(2, 1)

        # factor * ceil(ln(L)), capped at the sequence length.
        U_part = self.factor * np.ceil(np.log(L_K)).astype("int").item()
        u = self.factor * np.ceil(np.log(L_Q)).astype("int").item()

        U_part = U_part if U_part < L_K else L_K
        u = u if u < L_Q else L_Q

        scores_top, index = self._prob_QK(queries, keys, sample_k=U_part, n_top=u)

        scale = self.scale or 1.0 / sqrt(D)
        scores_top = scores_top * scale

        context = self._get_initial_context(values, L_Q)
        context, attn = self._update_context(
            context, values, scores_top, index, L_Q, attn_mask
        )

        return context.contiguous(), attn
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
class AttentionLayer(nn.Module):
    """Multi-head wrapper around an attention kernel.

    Projects queries, keys and values to ``n_heads`` heads, runs
    ``inner_attention`` on the per-head tensors and projects the merged heads
    back to ``d_model``.
    """

    def __init__(self, attention, d_model, n_heads, d_keys=None, d_values=None):
        """
        Args:
            attention: Callable module taking
                ``(queries, keys, values, attn_mask, tau=..., delta=...)`` and
                returning ``(output, attention_weights)``.
            d_model: Input and output feature size.
            n_heads: Number of attention heads.
            d_keys: Per-head query/key size; ``d_model // n_heads`` if falsy.
            d_values: Per-head value size; ``d_model // n_heads`` if falsy.
        """
        super().__init__()

        per_head_default = d_model // n_heads
        key_dim = d_keys or per_head_default
        value_dim = d_values or per_head_default

        self.inner_attention = attention
        self.query_projection = nn.Linear(d_model, key_dim * n_heads)
        self.key_projection = nn.Linear(d_model, key_dim * n_heads)
        self.value_projection = nn.Linear(d_model, value_dim * n_heads)
        self.out_projection = nn.Linear(value_dim * n_heads, d_model)
        self.n_heads = n_heads

    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
        """Apply multi-head attention.

        Args:
            queries: Tensor ``[B, L, d_model]``.
            keys: Tensor ``[B, S, d_model]``.
            values: Tensor ``[B, S, d_model]``.
            attn_mask: Forwarded unchanged to the inner attention.
            tau: Forwarded unchanged to the inner attention.
            delta: Forwarded unchanged to the inner attention.

        Returns:
            ``(output, attn)``: ``output`` is ``[B, L, d_model]`` and ``attn``
            is whatever the inner attention returned as its second value.
        """
        batch, q_len, _ = queries.shape
        _, kv_len, _ = keys.shape
        heads = self.n_heads

        # Split the projected features into per-head chunks.
        q_heads = self.query_projection(queries).view(batch, q_len, heads, -1)
        k_heads = self.key_projection(keys).view(batch, kv_len, heads, -1)
        v_heads = self.value_projection(values).view(batch, kv_len, heads, -1)

        attended, attn = self.inner_attention(
            q_heads, k_heads, v_heads, attn_mask, tau=tau, delta=delta
        )
        # Collapse the trailing axes into one feature axis for the projection.
        merged = attended.view(batch, q_len, -1)

        return self.out_projection(merged), attn
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
class ReformerLayer(nn.Module):
    """LSH self-attention layer (Reformer).

    Thin adapter that exposes ``reformer_pytorch.LSHSelfAttention`` through
    the same call signature as the other attention layers in this module.
    """

    def __init__(
        self,
        attention,
        d_model,
        n_heads,
        d_keys=None,
        d_values=None,
        causal=False,
        bucket_size=4,
        n_hashes=4,
    ):
        """
        Args:
            attention: Unused; accepted for signature parity.
            d_model: Model dimension passed to ``LSHSelfAttention`` as ``dim``.
            n_heads: Number of heads.
            d_keys: Unused; accepted for signature parity.
            d_values: Unused; accepted for signature parity.
            causal: Passed through to ``LSHSelfAttention``.
            bucket_size: LSH bucket size; inputs are padded to a multiple of
                ``2 * bucket_size``.
            n_hashes: Number of hash rounds.

        Raises:
            ImportError: If the optional ``reformer_pytorch`` package is not
                installed.
        """
        super().__init__()
        self.bucket_size = bucket_size
        # Note: requires reformer_pytorch package
        try:
            from reformer_pytorch import LSHSelfAttention
        except ImportError as err:
            raise ImportError(
                "ReformerLayer requires reformer_pytorch. Install with: pip install reformer_pytorch"
            ) from err

        self.attn = LSHSelfAttention(
            dim=d_model,
            heads=n_heads,
            bucket_size=bucket_size,
            n_hashes=n_hashes,
            causal=causal,
        )

    def fit_length(self, queries):
        """Zero-pad the sequence axis to a multiple of ``2 * bucket_size``.

        Args:
            queries: Tensor ``[B, N, C]``.

        Returns:
            ``queries`` itself when ``N`` is already aligned, otherwise a new
            tensor with zeros appended along dim 1. The padding shares the
            dtype and device of ``queries`` so half/bfloat16 inputs are not
            silently promoted to float32 by the concatenation.
        """
        B, N, C = queries.shape
        chunk = self.bucket_size * 2
        remainder = N % chunk
        if remainder == 0:
            return queries
        fill_len = chunk - remainder
        return torch.cat([queries, queries.new_zeros((B, fill_len, C))], dim=1)

    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
        """Run LSH self-attention on ``queries``.

        Only ``queries`` is used; ``keys``, ``values``, ``attn_mask``, ``tau``
        and ``delta`` exist for signature parity with the other layers
        (``tau``/``delta`` now default to ``None`` like everywhere else).

        Returns:
            ``(output, None)`` with ``output`` cropped back to the original
            sequence length ``N``.
        """
        B, N, C = queries.shape
        queries = self.attn(self.fit_length(queries))[:, :N, :]
        return queries, None
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
class TwoStageAttentionLayer(nn.Module):
    """
    Two Stage Attention (TSA) layer for Crossformer.

    Stage 1 attends along the segment axis of every series; stage 2 exchanges
    information across series through a small set of learned router vectors.
    input/output shape: [batch_size, Data_dim(D), Seg_num(L), d_model]
    """

    def __init__(
        self, configs, seg_num, factor, d_model, n_heads, d_ff=None, dropout=0.1
    ):
        """
        Args:
            configs: Object providing ``factor`` and ``dropout`` for the inner
                ``FullAttention`` modules.
            seg_num: Number of segments (first axis of the router parameter).
            factor: Number of router vectors per segment.
            d_model: Feature size.
            n_heads: Number of attention heads.
            d_ff: Hidden size of the two MLPs; ``4 * d_model`` when falsy.
            dropout: Dropout probability on the residual branches.
        """
        super().__init__()
        hidden = d_ff or 4 * d_model

        def build_attention():
            # Unmasked full attention that does not return its weights.
            kernel = FullAttention(
                False,
                configs.factor,
                attention_dropout=configs.dropout,
                output_attention=False,
            )
            return AttentionLayer(kernel, d_model, n_heads)

        def build_mlp():
            return nn.Sequential(
                nn.Linear(d_model, hidden), nn.GELU(), nn.Linear(hidden, d_model)
            )

        self.time_attention = build_attention()
        self.dim_sender = build_attention()
        self.dim_receiver = build_attention()
        self.router = nn.Parameter(torch.randn(seg_num, factor, d_model))

        self.dropout = nn.Dropout(dropout)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.norm4 = nn.LayerNorm(d_model)

        self.MLP1 = build_mlp()
        self.MLP2 = build_mlp()

    def forward(self, x, attn_mask=None, tau=None, delta=None):
        """Apply both attention stages.

        Args:
            x: Tensor ``[batch, ts_d, seg_num, d_model]``.
            attn_mask: Not used; the inner attentions are called with ``None``.
            tau: Not used; the inner attentions are called with ``None``.
            delta: Not used; the inner attentions are called with ``None``.

        Returns:
            Tensor with the same axis layout as ``x``.

        Raises:
            ImportError: If ``einops`` is not installed.
        """
        try:
            from einops import rearrange, repeat
        except ImportError:
            raise ImportError(
                "TwoStageAttentionLayer requires einops. Install with: pip install einops"
            )

        batch = x.shape[0]

        # Stage 1: attention across segments, each series handled separately.
        per_series = rearrange(x, "b ts_d seg_num d_model -> (b ts_d) seg_num d_model")
        time_out, _ = self.time_attention(
            per_series, per_series, per_series, attn_mask=None, tau=None, delta=None
        )
        hidden = self.norm1(per_series + self.dropout(time_out))
        hidden = self.norm2(hidden + self.dropout(self.MLP1(hidden)))

        # Stage 2: attention across series, routed through the learned vectors.
        per_segment = rearrange(
            hidden, "(b ts_d) seg_num d_model -> (b seg_num) ts_d d_model", b=batch
        )
        routers = repeat(
            self.router,
            "seg_num factor d_model -> (repeat seg_num) factor d_model",
            repeat=batch,
        )
        gathered, _ = self.dim_sender(
            routers, per_segment, per_segment, attn_mask=None, tau=None, delta=None
        )
        distributed, _ = self.dim_receiver(
            per_segment, gathered, gathered, attn_mask=None, tau=None, delta=None
        )
        mixed = self.norm3(per_segment + self.dropout(distributed))
        mixed = self.norm4(mixed + self.dropout(self.MLP2(mixed)))

        return rearrange(
            mixed, "(b seg_num) ts_d d_model -> b ts_d seg_num d_model", b=batch
        )
|
layers/StandardNorm.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
StandardNorm / Normalize layer (RevIN-style normalization)
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class Normalize(nn.Module):
    """
    Reversible instance normalization (RevIN) for time series.

    Calling the layer with ``mode="norm"`` stores per-instance statistics and
    normalizes; ``mode="denorm"`` uses the stored statistics to undo it.
    """

    def __init__(
        self,
        num_features: int,
        eps=1e-5,
        affine=False,
        subtract_last=False,
        non_norm=False,
    ):
        """
        :param num_features: the number of features or channels
        :param eps: a value added for numerical stability
        :param affine: if True, RevIN has learnable affine parameters
        :param subtract_last: if True, center on the last time step instead of
            the mean
        :param non_norm: if True, both modes return their input unchanged
        """
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.affine = affine
        self.subtract_last = subtract_last
        self.non_norm = non_norm
        if self.affine:
            self._init_params()

    def forward(self, x, mode: str):
        """Normalize (``"norm"``) or de-normalize (``"denorm"``) ``x``.

        Raises:
            NotImplementedError: For any other ``mode``.
        """
        if mode == "norm":
            # Statistics are refreshed on every "norm" call and reused by
            # the next "denorm" call.
            self._get_statistics(x)
            return self._normalize(x)
        if mode == "denorm":
            return self._denormalize(x)
        raise NotImplementedError

    def _init_params(self):
        """Create the learnable per-feature scale and shift, shape ``(C,)``."""
        self.affine_weight = nn.Parameter(torch.ones(self.num_features))
        self.affine_bias = nn.Parameter(torch.zeros(self.num_features))

    def _get_statistics(self, x):
        """Store the centering term and standard deviation of ``x``.

        Statistics are reduced over every axis except the first and the last
        and are detached from the autograd graph.
        """
        reduce_dims = tuple(range(1, x.ndim - 1))
        if self.subtract_last:
            self.last = x[:, -1, :].unsqueeze(1)
        else:
            self.mean = torch.mean(x, dim=reduce_dims, keepdim=True).detach()
        variance = torch.var(x, dim=reduce_dims, keepdim=True, unbiased=False)
        self.stdev = torch.sqrt(variance + self.eps).detach()

    def _normalize(self, x):
        """Center, scale and (optionally) apply the affine transform."""
        if self.non_norm:
            return x
        center = self.last if self.subtract_last else self.mean
        x = (x - center) / self.stdev
        if self.affine:
            x = x * self.affine_weight
            x = x + self.affine_bias
        return x

    def _denormalize(self, x):
        """Invert ``_normalize`` using the most recently stored statistics."""
        if self.non_norm:
            return x
        if self.affine:
            x = x - self.affine_bias
            # eps**2 keeps the division finite if a learned weight hits zero.
            x = x / (self.affine_weight + self.eps * self.eps)
        x = x * self.stdev
        center = self.last if self.subtract_last else self.mean
        return x + center
|
layers/ThreePhaseAttention.py
ADDED
|
@@ -0,0 +1,504 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ThreePhaseAttention: Attention Mechanisms for Three-Phase Power Systems
|
| 3 |
+
|
| 4 |
+
This module provides attention mechanisms specifically designed for analyzing
|
| 5 |
+
three-phase voltage signals in power grid anomaly detection:
|
| 6 |
+
|
| 7 |
+
1. InterPhaseAttention: Captures relationships between Va, Vb, Vc phases
|
| 8 |
+
2. SymmetricalComponentAttention: Analyzes positive/negative/zero sequences
|
| 9 |
+
3. TransientAttention: Multi-scale attention for transient event detection
|
| 10 |
+
4. VoltageChannelAttention: Channel-wise attention for voltage features
|
| 11 |
+
|
| 12 |
+
Key concepts:
|
| 13 |
+
- In balanced systems, Va, Vb, Vc have 120° phase difference
|
| 14 |
+
- Unbalance indicates issues (asymmetric loads, faults)
|
| 15 |
+
- Symmetrical components help diagnose fault types
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
import math
|
| 19 |
+
|
| 20 |
+
import numpy as np
|
| 21 |
+
import torch
|
| 22 |
+
import torch.nn as nn
|
| 23 |
+
import torch.nn.functional as F
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class InterPhaseAttention(nn.Module):
    """
    Multi-head self-attention with a learnable phase-relationship bias.

    In three-phase systems:
    - Normal: Va, Vb, Vc are balanced with 120° phase shifts
    - Fault: Phase relationships deviate from normal patterns

    A ``num_phases x num_phases`` bias is tiled over the score matrix so that
    positions with the same index modulo ``num_phases`` share a bias entry.
    """

    def __init__(self, d_model, n_heads=4, dropout=0.1, num_phases=3):
        """
        Args:
            d_model: Model dimension
            n_heads: Number of attention heads
            dropout: Dropout rate applied to the attention weights
            num_phases: Number of phases (default 3 for three-phase)
        """
        super().__init__()

        self.d_model = d_model
        self.n_heads = n_heads
        self.num_phases = num_phases
        self.d_k = d_model // n_heads

        # Shared query/key/value projections.
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)

        # Per-phase linear maps; registered here but not used in forward().
        self.phase_transforms = nn.ModuleList(
            [nn.Linear(d_model, d_model) for _ in range(num_phases)]
        )

        self.W_o = nn.Linear(d_model, d_model)

        # Nominal 0° / 120° / 240° angles kept as a buffer, plus the learnable
        # pairwise bias that forward() adds to the attention scores.
        nominal_angles = torch.tensor([0, 2 * math.pi / 3, 4 * math.pi / 3])
        self.register_buffer("phase_angles", nominal_angles)
        self.phase_bias = nn.Parameter(torch.zeros(num_phases, num_phases))

        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, x, phase_mask=None):
        """
        Args:
            x: Input tensor [B, T, d_model]; the last dimension must equal
                d_model because of the projections and the residual sum.
            phase_mask: Optional mask broadcastable to the [B, n_heads, T, T]
                scores; entries equal to 0 are masked out.

        Returns:
            Attention output [B, T, d_model]
        """
        batch, steps, _ = x.size()

        def split_heads(projected):
            # [B, T, d_model] -> [B, n_heads, T, d_k]
            return projected.view(batch, steps, self.n_heads, self.d_k).transpose(1, 2)

        q = split_heads(self.W_q(x))
        k = split_heads(self.W_k(x))
        v = split_heads(self.W_v(x))

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)

        # The bias is only applied when the sequence spans at least one full
        # set of phases.
        if self.num_phases <= steps:
            tiles = (steps + self.num_phases - 1) // self.num_phases
            bias = self.phase_bias.unsqueeze(0).unsqueeze(0).repeat(1, 1, tiles, tiles)
            scores = scores + bias[:, :, :steps, :steps]

        if phase_mask is not None:
            scores = scores.masked_fill(phase_mask == 0, -1e9)

        weights = self.dropout(F.softmax(scores, dim=-1))

        # Merge the heads back into a single feature axis.
        merged = torch.matmul(weights, v).transpose(1, 2).contiguous()
        merged = merged.view(batch, steps, self.d_model)

        # Residual connection followed by layer norm.
        return self.layer_norm(self.W_o(merged) + x)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
class SymmetricalComponentAttention(nn.Module):
|
| 128 |
+
"""
|
| 129 |
+
Attention based on symmetrical component analysis.
|
| 130 |
+
|
| 131 |
+
Symmetrical components decompose unbalanced three-phase systems into:
|
| 132 |
+
- Positive sequence (balanced, normal operation)
|
| 133 |
+
- Negative sequence (indicates unbalance)
|
| 134 |
+
- Zero sequence (indicates ground faults)
|
| 135 |
+
|
| 136 |
+
This attention helps identify different types of power system faults.
|
| 137 |
+
"""
|
| 138 |
+
|
| 139 |
+
def __init__(self, d_model, dropout=0.1):
|
| 140 |
+
"""
|
| 141 |
+
Args:
|
| 142 |
+
d_model: Model dimension
|
| 143 |
+
dropout: Dropout rate
|
| 144 |
+
"""
|
| 145 |
+
super(SymmetricalComponentAttention, self).__init__()
|
| 146 |
+
|
| 147 |
+
self.d_model = d_model
|
| 148 |
+
|
| 149 |
+
# Fortescue transformation matrix for symmetrical components
|
| 150 |
+
# a = exp(j*2π/3) = -0.5 + j*0.866
|
| 151 |
+
a_real = -0.5
|
| 152 |
+
a_imag = math.sqrt(3) / 2
|
| 153 |
+
|
| 154 |
+
# Transformation matrix (real and imaginary parts)
|
| 155 |
+
# [1, 1, 1; 1, a², a; 1, a, a²] for positive, negative, zero
|
| 156 |
+
self.register_buffer(
|
| 157 |
+
"fortescue_real",
|
| 158 |
+
torch.tensor(
|
| 159 |
+
[
|
| 160 |
+
[1, 1, 1],
|
| 161 |
+
[1, a_real**2 - a_imag**2, a_real],
|
| 162 |
+
[1, a_real, a_real**2 - a_imag**2],
|
| 163 |
+
]
|
| 164 |
+
)
|
| 165 |
+
/ 3.0,
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
self.register_buffer(
|
| 169 |
+
"fortescue_imag",
|
| 170 |
+
torch.tensor(
|
| 171 |
+
[
|
| 172 |
+
[0, 0, 0],
|
| 173 |
+
[0, 2 * a_real * a_imag, a_imag],
|
| 174 |
+
[0, a_imag, 2 * a_real * a_imag],
|
| 175 |
+
]
|
| 176 |
+
)
|
| 177 |
+
/ 3.0,
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
# Sequence-specific attention
|
| 181 |
+
self.pos_seq_attn = nn.MultiheadAttention(
|
| 182 |
+
d_model, num_heads=4, dropout=dropout, batch_first=True
|
| 183 |
+
)
|
| 184 |
+
self.neg_seq_attn = nn.MultiheadAttention(
|
| 185 |
+
d_model, num_heads=4, dropout=dropout, batch_first=True
|
| 186 |
+
)
|
| 187 |
+
self.zero_seq_attn = nn.MultiheadAttention(
|
| 188 |
+
d_model, num_heads=4, dropout=dropout, batch_first=True
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
# Sequence weighting (learnable importance of each sequence)
|
| 192 |
+
self.sequence_weights = nn.Parameter(torch.tensor([1.0, 0.5, 0.3]))
|
| 193 |
+
|
| 194 |
+
# Output projection
|
| 195 |
+
self.output_proj = nn.Linear(d_model * 3, d_model)
|
| 196 |
+
|
| 197 |
+
self.dropout = nn.Dropout(dropout)
|
| 198 |
+
self.layer_norm = nn.LayerNorm(d_model)
|
| 199 |
+
|
| 200 |
+
def compute_symmetrical_components(self, x):
|
| 201 |
+
"""
|
| 202 |
+
Compute symmetrical components from three-phase signals.
|
| 203 |
+
|
| 204 |
+
Args:
|
| 205 |
+
x: Three-phase signals [B, T, 3]
|
| 206 |
+
|
| 207 |
+
Returns:
|
| 208 |
+
Tuple of (positive, negative, zero) sequences, each [B, T, 1]
|
| 209 |
+
"""
|
| 210 |
+
B, T, _ = x.size()
|
| 211 |
+
|
| 212 |
+
# Apply Fortescue transformation (simplified real-valued version)
|
| 213 |
+
# For real signals, we approximate the transformation
|
| 214 |
+
x_transformed = torch.matmul(x, self.fortescue_real.T)
|
| 215 |
+
|
| 216 |
+
pos_seq = x_transformed[:, :, 0:1] # Positive sequence
|
| 217 |
+
neg_seq = x_transformed[:, :, 1:2] # Negative sequence
|
| 218 |
+
zero_seq = x_transformed[:, :, 2:3] # Zero sequence
|
| 219 |
+
|
| 220 |
+
return pos_seq, neg_seq, zero_seq
|
| 221 |
+
|
| 222 |
+
def forward(self, x, return_sequences=False):
    """
    Run the three sequence-specific attentions and fuse their outputs.

    Args:
        x: Input tensor [B, T, d_model]
        return_sequences: Whether to also return the per-sequence outputs

    Returns:
        Attention output [B, T, d_model]
        Optionally: (pos_out, neg_out, zero_out) if return_sequences=True
    """
    # Unpacking doubles as a rank check: a non 3-D input raises here.
    _batch, _steps, _width = x.size()

    # Self-attention per sequence type; the attention maps are discarded.
    branches = (self.pos_seq_attn, self.neg_seq_attn, self.zero_seq_attn)
    per_sequence = tuple(attn(x, x, x)[0] for attn in branches)

    # Softmax keeps the learnable sequence importances positive and
    # summing to one before they scale each branch.
    mix = F.softmax(self.sequence_weights, dim=0)
    stacked = torch.cat(
        [mix[k] * branch_out for k, branch_out in enumerate(per_sequence)],
        dim=-1,
    )

    # Project the concatenation back to d_model, then residual + norm.
    fused = self.dropout(self.output_proj(stacked))
    fused = self.layer_norm(fused + x)

    if not return_sequences:
        return fused
    return fused, per_sequence
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
class TransientAttention(nn.Module):
    """
    Multi-scale attention for detecting transient events in voltage signals.

    Transient events in power systems include:
    - Voltage sags (ms to seconds)
    - Voltage swells
    - Momentary interruptions
    - Switching transients (μs to ms)

    Each temporal scale gets its own depthwise convolution followed by its
    own self-attention; the per-scale outputs are weighted, concatenated and
    projected back to ``d_model``.
    """

    def __init__(self, d_model, n_heads=4, dropout=0.1, scales=None):
        """
        Args:
            d_model: Model dimension
            n_heads: Number of attention heads
            dropout: Dropout rate
            scales: Iterable of temporal scales (convolution kernel sizes)
                for multi-scale attention. Defaults to ``[1, 3, 5, 10]``.

        Raises:
            ValueError: If ``scales`` is empty.
        """
        super(TransientAttention, self).__init__()

        # A ``None`` sentinel plus a private copy keeps instances from
        # sharing (and mutating) one default list object.
        scales = [1, 3, 5, 10] if scales is None else list(scales)
        if not scales:
            raise ValueError("scales must contain at least one temporal scale")

        self.d_model = d_model
        self.n_heads = n_heads
        self.scales = scales

        # Multi-scale depthwise convolutions for different transient durations
        self.scale_convs = nn.ModuleList(
            [
                nn.Conv1d(
                    d_model, d_model, kernel_size=s, padding=s // 2, groups=d_model
                )
                for s in scales
            ]
        )

        # Scale-specific attention
        self.scale_attentions = nn.ModuleList(
            [
                nn.MultiheadAttention(
                    d_model, num_heads=n_heads, dropout=dropout, batch_first=True
                )
                for _ in scales
            ]
        )

        # Scale importance weights (uniform at initialisation)
        self.scale_weights = nn.Parameter(torch.ones(len(scales)) / len(scales))

        # Output projection
        self.output_proj = nn.Linear(d_model * len(scales), d_model)

        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, x):
        """
        Args:
            x: Input tensor [B, T, d_model]

        Returns:
            Multi-scale attention output [B, T, d_model]
        """
        seq_len = x.size(1)
        # Conv1d expects channels first; the permuted view is loop-invariant.
        channels_first = x.permute(0, 2, 1)

        scale_outputs = []
        for conv, attn in zip(self.scale_convs, self.scale_attentions):
            x_conv = conv(channels_first).permute(0, 2, 1)

            # Even kernel sizes with padding s // 2 produce one extra step;
            # trimming keeps every scale aligned with the input length.
            x_conv = x_conv[:, :seq_len, :]

            scale_out, _ = attn(x_conv, x_conv, x_conv)
            scale_outputs.append(scale_out)

        # Weighted combination of scales
        weights = F.softmax(self.scale_weights, dim=0)
        combined = torch.cat(
            [w * out for w, out in zip(weights, scale_outputs)], dim=-1
        )

        # Project to output dimension
        output = self.dropout(self.output_proj(combined))

        # Residual connection and layer norm
        return self.layer_norm(output + x)
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
class VoltageChannelAttention(nn.Module):
    """
    Channel-wise attention for voltage feature selection.

    Different voltage features have varying importance for anomaly detection:
    - Va, Vb, Vc: Direct voltage measurements
    - Ia, Ib, Ic: Current measurements
    - P, Q, S: Power metrics
    - THD: Harmonic distortion
    - Unbalance: Phase unbalance factor

    A per-channel gate in (0, 1) is derived from time-pooled statistics and
    multiplied onto the input.
    """

    def __init__(self, num_channels, reduction_ratio=4):
        """
        Args:
            num_channels: Number of input channels/features
            reduction_ratio: Reduction ratio for the bottleneck
        """
        super(VoltageChannelAttention, self).__init__()

        self.num_channels = num_channels
        # Never let the bottleneck collapse to zero units.
        bottleneck = max(1, num_channels // reduction_ratio)

        # Pool over the time axis down to a single value per channel.
        self.avg_pool = nn.AdaptiveAvgPool1d(1)
        self.max_pool = nn.AdaptiveMaxPool1d(1)

        # Bottleneck MLP shared by the average- and max-pooled descriptors.
        self.fc = nn.Sequential(
            nn.Linear(num_channels, bottleneck, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(bottleneck, num_channels, bias=False),
        )

        # Learnable importance for four feature groups (voltage, current,
        # power, quality). Registered as a parameter but not read by
        # forward() in this class.
        self.group_weights = nn.Parameter(torch.ones(4))

    def forward(self, x):
        """
        Args:
            x: Input tensor [B, T, C]

        Returns:
            Channel-attended output [B, T, C]
        """
        # Unpacking doubles as a rank check: a non 3-D input raises here.
        _batch, _steps, _channels = x.size()

        channels_first = x.transpose(1, 2)  # [B, C, T]
        avg_desc = self.avg_pool(channels_first).squeeze(-1)  # [B, C]
        max_desc = self.max_pool(channels_first).squeeze(-1)  # [B, C]

        # Both descriptors go through the same MLP; their sum is squashed
        # into a per-channel gate.
        gate = torch.sigmoid(self.fc(avg_desc) + self.fc(max_desc))  # [B, C]

        # Broadcast the gate across the time axis.
        return x * gate.unsqueeze(1)
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
class VoltageAttentionBlock(nn.Module):
    """
    Combined attention block for voltage anomaly detection.

    Chains up to three optional attention mechanisms:
    1. Inter-phase attention for phase relationships
    2. Transient attention for multi-scale temporal patterns
    3. Channel attention for feature importance

    The result is projected and added back to the block input.
    """

    def __init__(
        self,
        d_model,
        num_channels,
        n_heads=4,
        dropout=0.1,
        use_inter_phase=True,
        use_transient=True,
        use_channel=True,
    ):
        """
        Args:
            d_model: Model dimension
            num_channels: Number of input channels
            n_heads: Number of attention heads
            dropout: Dropout rate
            use_inter_phase: Whether to use inter-phase attention
            use_transient: Whether to use transient attention
            use_channel: Whether to use channel attention
        """
        super(VoltageAttentionBlock, self).__init__()

        self.use_inter_phase = use_inter_phase
        self.use_transient = use_transient
        self.use_channel = use_channel

        # Sub-modules are only created for the stages that are switched on.
        if use_inter_phase:
            self.inter_phase_attn = InterPhaseAttention(d_model, n_heads, dropout)
        if use_transient:
            self.transient_attn = TransientAttention(d_model, n_heads, dropout)
        if use_channel:
            self.channel_attn = VoltageChannelAttention(num_channels)

        # Final projection
        self.output_proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, x, x_raw=None):
        """
        Args:
            x: Embedded input [B, T, d_model]
            x_raw: Raw input for channel attention [B, T, num_channels] (optional)

        Returns:
            Attention output [B, T, d_model]
        """
        hidden = x

        if self.use_inter_phase:
            hidden = self.inter_phase_attn(hidden)

        if self.use_transient:
            hidden = self.transient_attn(hidden)

        # Channel attention runs on the raw features, so it is skipped
        # when no raw input is supplied.
        if x_raw is not None and self.use_channel:
            attended_raw = self.channel_attn(x_raw)
            # Simplified application: the channel-attended raw features are
            # averaged over channels into one scalar per time step, which
            # then scales every embedded dimension.
            step_gate = attended_raw.mean(dim=-1, keepdim=True)
            hidden = hidden * step_gate.expand_as(hidden)

        # Final projection with a residual back to the block input.
        hidden = self.dropout(self.output_proj(hidden))
        return self.layer_norm(hidden + x)
|
layers/Transformer_EncDec.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Transformer Encoder-Decoder Layers.
|
| 3 |
+
|
| 4 |
+
Standard Transformer architecture components for time series.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class ConvLayer(nn.Module):
    """Convolutional distilling layer that downsamples along the time axis."""

    def __init__(self, c_in):
        """
        Args:
            c_in: Number of feature channels (kept unchanged by the layer)
        """
        super(ConvLayer, self).__init__()
        self.downConv = nn.Conv1d(
            c_in, c_in, kernel_size=3, padding=2, padding_mode="circular"
        )
        self.norm = nn.BatchNorm1d(c_in)
        self.activation = nn.ELU()
        # Stride 2 makes the pooling the downsampling step.
        self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        """
        Args:
            x: Input tensor [B, T, c_in]

        Returns:
            Downsampled tensor [B, T', c_in]
        """
        # Conv/BatchNorm/Pool expect channels first.
        features = x.permute(0, 2, 1)
        for stage in (self.downConv, self.norm, self.activation, self.maxPool):
            features = stage(features)
        # Back to the [B, T', C] layout used by the attention layers.
        return features.transpose(1, 2)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class EncoderLayer(nn.Module):
    """Transformer encoder layer: attention plus a position-wise feed-forward."""

    def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"):
        """
        Args:
            attention: Attention module called as
                ``attention(q, k, v, attn_mask=..., tau=..., delta=...)``
                and returning ``(output, attn)``
            d_model: Model dimension
            d_ff: Feed-forward width; falls back to ``4 * d_model`` when falsy
            dropout: Dropout rate
            activation: ``"relu"`` selects ReLU, any other value selects GELU
        """
        super(EncoderLayer, self).__init__()
        hidden_width = d_ff or 4 * d_model
        self.attention = attention
        # Kernel-size-1 convolutions act as the position-wise feed-forward.
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=hidden_width, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=hidden_width, out_channels=d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = F.gelu if activation != "relu" else F.relu

    def forward(self, x, attn_mask=None, tau=None, delta=None):
        """
        Args:
            x: Input tensor [B, T, d_model]
            attn_mask, tau, delta: Forwarded unchanged to the attention module

        Returns:
            Tuple of (output [B, T, d_model], attention returned by the module)
        """
        attended, attn = self.attention(
            x, x, x, attn_mask=attn_mask, tau=tau, delta=delta
        )
        residual = self.norm1(x + self.dropout(attended))

        # Conv1d wants channels first, hence the transposes around the FFN.
        hidden = self.conv1(residual.transpose(-1, 1))
        hidden = self.dropout(self.activation(hidden))
        hidden = self.dropout(self.conv2(hidden).transpose(-1, 1))

        return self.norm2(residual + hidden), attn
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class Encoder(nn.Module):
    """Stack of encoder layers, optionally interleaved with distilling convs."""

    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
        """
        Args:
            attn_layers: Iterable of encoder layers
            conv_layers: Optional iterable of layers applied after each paired
                encoder layer
            norm_layer: Optional module applied to the final output
        """
        super(Encoder, self).__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        self.conv_layers = None if conv_layers is None else nn.ModuleList(conv_layers)
        self.norm = norm_layer

    def forward(self, x, attn_mask=None, tau=None, delta=None):
        """
        Args:
            x: Input tensor [B, T, d_model]
            attn_mask, tau, delta: Forwarded to the encoder layers

        Returns:
            Tuple of (encoded tensor, list with one attention entry per layer call)
        """
        attns = []

        if self.conv_layers is None:
            for layer in self.attn_layers:
                x, attn = layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
                attns.append(attn)
        else:
            pairs = zip(self.attn_layers, self.conv_layers)
            for idx, (layer, distill) in enumerate(pairs):
                # Only the first layer receives delta; later ones get None.
                layer_delta = delta if idx == 0 else None
                x, attn = layer(x, attn_mask=attn_mask, tau=tau, delta=layer_delta)
                x = distill(x)
                attns.append(attn)
            # The last attention layer runs after the paired loop, without
            # a mask or delta and with no conv after it.
            x, attn = self.attn_layers[-1](x, tau=tau, delta=None)
            attns.append(attn)

        if self.norm is not None:
            x = self.norm(x)

        return x, attns
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class DecoderLayer(nn.Module):
    """Transformer decoder layer: self-attention, cross-attention, feed-forward."""

    def __init__(
        self,
        self_attention,
        cross_attention,
        d_model,
        d_ff=None,
        dropout=0.1,
        activation="relu",
    ):
        """
        Args:
            self_attention: Attention module applied to the decoder input
            cross_attention: Attention module attending over ``cross``
            d_model: Model dimension
            d_ff: Feed-forward width; falls back to ``4 * d_model`` when falsy
            dropout: Dropout rate
            activation: ``"relu"`` selects ReLU, any other value selects GELU
        """
        super(DecoderLayer, self).__init__()
        hidden_width = d_ff or 4 * d_model
        self.self_attention = self_attention
        self.cross_attention = cross_attention
        # Kernel-size-1 convolutions act as the position-wise feed-forward.
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=hidden_width, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=hidden_width, out_channels=d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = F.gelu if activation != "relu" else F.relu

    def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):
        """
        Args:
            x: Decoder input [B, T, d_model]
            cross: Tensor used as keys and values for cross-attention
            x_mask: Mask forwarded to self-attention
            cross_mask: Mask forwarded to cross-attention
            tau: Forwarded to both attention modules
            delta: Forwarded to cross-attention only (self-attention gets None)

        Returns:
            Decoded tensor [B, T, d_model]
        """
        self_out = self.self_attention(
            x, x, x, attn_mask=x_mask, tau=tau, delta=None
        )[0]
        x = self.norm1(x + self.dropout(self_out))

        cross_out = self.cross_attention(
            x, cross, cross, attn_mask=cross_mask, tau=tau, delta=delta
        )[0]
        residual = self.norm2(x + self.dropout(cross_out))

        # Conv1d wants channels first, hence the transposes around the FFN.
        hidden = self.conv1(residual.transpose(-1, 1))
        hidden = self.dropout(self.activation(hidden))
        hidden = self.dropout(self.conv2(hidden).transpose(-1, 1))

        return self.norm3(residual + hidden)
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
class Decoder(nn.Module):
    """Stack of decoder layers with optional final norm and projection."""

    def __init__(self, layers, norm_layer=None, projection=None):
        """
        Args:
            layers: Iterable of decoder layers
            norm_layer: Optional module applied after the last layer
            projection: Optional module applied after the norm
        """
        super(Decoder, self).__init__()
        self.layers = nn.ModuleList(layers)
        self.norm = norm_layer
        self.projection = projection

    def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):
        """
        Args:
            x: Decoder input
            cross: Tensor passed to every layer as ``cross``
            x_mask, cross_mask, tau, delta: Forwarded unchanged to every layer

        Returns:
            Output of the last layer after the optional norm and projection
        """
        for layer in self.layers:
            x = layer(
                x, cross, x_mask=x_mask, cross_mask=cross_mask, tau=tau, delta=delta
            )

        # Norm first, projection second; either may be absent.
        for post in (self.norm, self.projection):
            if post is not None:
                x = post(x)
        return x
|
layers/VoltageEmbed.py
ADDED
|
@@ -0,0 +1,465 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
VoltageEmbed: Specialized Embedding Layers for Power Grid Voltage Signals
|
| 3 |
+
|
| 4 |
+
This module provides domain-specific embeddings designed for rural power grid
|
| 5 |
+
voltage anomaly detection:
|
| 6 |
+
|
| 7 |
+
1. PowerFrequencyEmbedding: Encodes 50Hz power frequency cycles
|
| 8 |
+
2. DailyLoadEmbedding: Captures daily load patterns (24-hour cycles)
|
| 9 |
+
3. ThreePhaseEmbedding: Encodes Va-Vb-Vc phase relationships (120-degree shift)
|
| 10 |
+
4. VoltageDataEmbedding: Combined embedding for voltage signals
|
| 11 |
+
|
| 12 |
+
Key innovations:
|
| 13 |
+
- Exploit known periodicities in power systems (50Hz, daily, weekly)
|
| 14 |
+
- Encode three-phase relationships (phase angle differences)
|
| 15 |
+
- Integrate voltage quality features (THD, unbalance) into embeddings
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
import math
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
import torch.nn as nn
|
| 22 |
+
import torch.nn.functional as F
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class PowerFrequencyEmbedding(nn.Module):
    """
    Fixed sinusoidal embedding tied to the power frequency and its harmonics.

    The embedding width is divided among the harmonics 1, 2, 3, 5 and 7 of
    ``power_freq``; each harmonic fills its columns with sin/cos pairs at a
    range of phase offsets. Columns left over when ``d_model`` is not a
    multiple of 10 stay zero.

    NOTE(review): with the defaults (``power_freq=50``, ``sample_rate=1``)
    every angular step is a whole multiple of 2π, so the table carries
    almost no positional information beyond float rounding — confirm the
    intended sample rate with callers.
    """

    def __init__(self, d_model, max_len=5000, power_freq=50.0, sample_rate=1.0):
        """
        Args:
            d_model: Embedding dimension
            max_len: Maximum sequence length
            power_freq: Power system frequency (50Hz or 60Hz)
            sample_rate: Sampling rate in Hz
        """
        super(PowerFrequencyEmbedding, self).__init__()

        self.d_model = d_model
        self.power_freq = power_freq
        self.sample_rate = sample_rate

        # The table is precomputed once and never trained.
        table = torch.zeros(max_len, d_model).float()
        table.requires_grad = False

        steps = torch.arange(0, max_len).float()

        # Fundamental + common harmonics; each owns an equal, even-width
        # slice of columns (sin/cos pairs).
        harmonics = (1, 2, 3, 5, 7)
        width = 2 * (d_model // (len(harmonics) * 2))

        for slot, harmonic in enumerate(harmonics):
            # Angular advance per sample for this harmonic.
            omega = 2.0 * math.pi * (power_freq * harmonic) / sample_rate

            first = slot * width
            last = min(first + width, d_model)
            span = last - first

            for offset in range(0, span, 2):
                # Spread the pairs of one harmonic over different phases.
                phase_shift = offset * math.pi / span
                angle = steps * omega + phase_shift
                col = first + offset
                if col < d_model:
                    table[:, col] = torch.sin(angle)
                if col + 1 < d_model:
                    table[:, col + 1] = torch.cos(angle)

        # Leading batch axis of size 1 so the table broadcasts over batches.
        self.register_buffer("pe", table.unsqueeze(0))

    def forward(self, x):
        """
        Args:
            x: Input tensor [B, T, C]; only its length T is used

        Returns:
            Power frequency positional encoding [1, T, d_model]
        """
        return self.pe[:, : x.size(1)]
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class DailyLoadEmbedding(nn.Module):
    """
    Learned embedding of the position within several daily sub-cycles.

    Power consumption follows predictable daily patterns:
    - Morning peak (7-9 AM)
    - Midday trough (12-2 PM)
    - Evening peak (6-9 PM)
    - Night low (12-6 AM)

    One embedding table is kept per period (full day, 1/2, 1/3, 1/4 and 1/6
    of a day); the looked-up slices are concatenated and projected to
    ``d_model``.
    """

    def __init__(self, d_model, samples_per_day=86400):
        """
        Args:
            d_model: Embedding dimension
            samples_per_day: Number of samples per day (86400 for 1Hz sampling)
        """
        super(DailyLoadEmbedding, self).__init__()

        self.d_model = d_model
        self.samples_per_day = samples_per_day

        # Full day, half-day (AM/PM), 8-hour, 6-hour and 4-hour cycles.
        day_fractions = (1, 2, 3, 4, 6)
        slice_width = d_model // len(day_fractions)

        # One table per period; max(..., 1) guards against a zero-row table
        # when samples_per_day is tiny.
        self.period_embeddings = nn.ModuleList(
            [
                nn.Embedding(max(samples_per_day // fraction, 1), slice_width)
                for fraction in day_fractions
            ]
        )

        # Projection to combine period embeddings back to d_model.
        self.projection = nn.Linear(slice_width * len(day_fractions), d_model)

    def forward(self, x, time_indices=None):
        """
        Args:
            x: Input tensor [B, T, C]; only its shape and device are used
            time_indices: Optional time indices within day [B, T]

        Returns:
            Daily load pattern embedding [B, T, d_model]
        """
        batch, steps, _ = x.size()

        if time_indices is None:
            # Without explicit indices, assume the window starts at index 0
            # and advances one sample per step.
            time_indices = (
                torch.arange(steps, device=x.device).unsqueeze(0).expand(batch, -1)
            )

        # Each table is indexed by the position inside its own period.
        slices = [
            table((time_indices % table.num_embeddings).long())
            for table in self.period_embeddings
        ]
        return self.projection(torch.cat(slices, dim=-1))
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
class ThreePhaseEmbedding(nn.Module):
    """
    Embedding that encodes three-phase relationships for Va, Vb, Vc.

    In balanced three-phase systems:
    - Va, Vb, Vc are 120 degrees (2π/3) apart
    - Positive sequence: Va leads Vb leads Vc
    - Negative sequence: Va leads Vc leads Vb (indicates unbalance)

    The embedding is the sum of a linear map of (Va, Vb, Vc), a down-weighted
    linear map of the swapped order (Va, Vc, Vb), and the mean of the learned
    per-phase embeddings.
    """

    def __init__(self, d_model, num_phases=3):
        """
        Args:
            d_model: Embedding dimension
            num_phases: Number of phases (3 for three-phase systems)
        """
        super(ThreePhaseEmbedding, self).__init__()

        self.d_model = d_model
        self.num_phases = num_phases

        # Phase angle reference (0, 120, 240 degrees); stored as a buffer,
        # not read by forward().
        phase_angles = torch.tensor([0, 2 * math.pi / 3, 4 * math.pi / 3])
        self.register_buffer("phase_angles", phase_angles)

        # Learnable phase embedding
        self.phase_embed = nn.Embedding(num_phases, d_model)

        # Positive and negative sequence embeddings
        self.pos_seq_embed = nn.Linear(num_phases, d_model)
        self.neg_seq_embed = nn.Linear(num_phases, d_model)

    def forward(self, x, channel_ids=None):
        """
        Args:
            x: Input tensor [B, T, C] where C includes three-phase channels
            channel_ids: Optional sequence or 1-D tensor with the channel
                indices of Va, Vb, Vc (in that order). When omitted, the
                first ``min(3, C)`` channels are used.

        Returns:
            Three-phase embedding [B, T, d_model]

        Raises:
            ValueError: If more channels are selected than ``num_phases``.
        """
        B, T, C = x.size()

        if channel_ids is None:
            # Assume the leading channels are Va, Vb, Vc.
            phases = x[:, :, : min(3, C)]
        else:
            # Honor the caller's selection instead of only counting it, so
            # phase channels need not sit in the first three columns.
            index = torch.as_tensor(channel_ids, dtype=torch.long, device=x.device)
            phases = x.index_select(-1, index)
        voltage_channels = phases.size(-1)

        if voltage_channels > self.num_phases:
            raise ValueError(
                f"{voltage_channels} phase channels selected but the embedding "
                f"was built for num_phases={self.num_phases}"
            )

        # Mean of the learned phase embeddings acts as a shared bias.
        phase_ids = torch.arange(voltage_channels, device=x.device)
        phase_emb = self.phase_embed(phase_ids)  # [voltage_channels, d_model]
        phase_bias = phase_emb.mean(0).unsqueeze(0).unsqueeze(0)

        if voltage_channels < 3:
            # Not enough phases for sequence terms; return the bias alone.
            return phase_bias.expand(B, T, -1)

        # Symmetrical components (simplified): learned linear maps stand in
        # for the complex Fortescue operators.
        v_abc = phases[:, :, :3]  # [B, T, 3]
        pos_emb = self.pos_seq_embed(v_abc)  # [B, T, d_model]

        # Negative sequence swaps the roles of Vb and Vc.
        v_acb = torch.stack(
            [v_abc[:, :, 0], v_abc[:, :, 2], v_abc[:, :, 1]], dim=-1
        )
        neg_emb = self.neg_seq_embed(v_acb)  # [B, T, d_model]

        # The negative-sequence term is down-weighted by a fixed 0.1.
        return pos_emb + 0.1 * neg_emb + phase_bias
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
class VoltageQualityEmbedding(nn.Module):
    """
    Embedding that encodes voltage quality indicators.

    Key voltage quality metrics:
    - Voltage deviation from nominal
    - Total Harmonic Distortion (THD)
    - Voltage unbalance factor
    - Frequency deviation

    Each indicator is projected to ``d_model // 4`` features; the four parts
    are concatenated and mapped to ``d_model`` features by ``combine_proj``.
    """

    def __init__(self, d_model, nominal_voltage=220.0, nominal_freq=50.0):
        """
        Args:
            d_model: Embedding dimension
            nominal_voltage: Nominal voltage value (V)
            nominal_freq: Nominal frequency (Hz)
        """
        super(VoltageQualityEmbedding, self).__init__()

        self.d_model = d_model
        self.nominal_voltage = nominal_voltage
        self.nominal_freq = nominal_freq

        # Width of each of the four indicator embeddings.
        part_dim = d_model // 4

        # Quality indicator projections
        self.voltage_deviation_proj = nn.Linear(1, part_dim)
        self.thd_proj = nn.Linear(1, part_dim)
        self.unbalance_proj = nn.Linear(1, part_dim)
        self.freq_deviation_proj = nn.Linear(1, part_dim)

        # Combination projection. The concatenated input is 4 * part_dim wide,
        # which equals d_model only when d_model is a multiple of 4; sizing the
        # layer from the real width keeps other d_model values working while
        # leaving the parameter shapes unchanged for multiples of 4.
        self.combine_proj = nn.Linear(4 * part_dim, d_model)

    @staticmethod
    def _missing_indicator(reference):
        """Return an all-zero embedding shaped like ``reference`` (same dtype/device)."""
        return reference.new_zeros(reference.shape)

    def forward(self, voltage, thd=None, unbalance=None, freq=None):
        """
        Args:
            voltage: Voltage values [B, T, num_phases]
            thd: Total harmonic distortion [B, T, 1] or None
            unbalance: Voltage unbalance factor [B, T, 1] or None
            freq: System frequency [B, T, 1] or None

        Returns:
            Voltage quality embedding [B, T, d_model]
        """
        # Relative deviation of the per-step phase average from nominal.
        mean_voltage = voltage.mean(dim=-1, keepdim=True)
        voltage_dev = (mean_voltage - self.nominal_voltage) / self.nominal_voltage
        volt_emb = self.voltage_deviation_proj(voltage_dev)

        # THD embedding (zeros when the indicator is not supplied)
        if thd is not None:
            thd_emb = self.thd_proj(thd)
        else:
            thd_emb = self._missing_indicator(volt_emb)

        # Unbalance embedding
        if unbalance is not None:
            unb_emb = self.unbalance_proj(unbalance)
        elif voltage.size(-1) >= 3:
            # Simple estimate from the first three channels: largest absolute
            # deviation from their mean, as a percentage of that mean.
            phases = voltage[:, :, :3]
            v_mean = phases.mean(dim=-1, keepdim=True)
            v_max_dev = (phases - v_mean).abs().max(dim=-1, keepdim=True)[0]
            unb = v_max_dev / (v_mean + 1e-8) * 100
            unb_emb = self.unbalance_proj(unb)
        else:
            # Fewer than three channels: no estimate is computed.
            unb_emb = self._missing_indicator(volt_emb)

        # Frequency deviation embedding
        if freq is not None:
            freq_dev = (freq - self.nominal_freq) / self.nominal_freq
            freq_emb = self.freq_deviation_proj(freq_dev)
        else:
            freq_emb = self._missing_indicator(volt_emb)

        # Combine all quality embeddings
        combined = torch.cat([volt_emb, thd_emb, unb_emb, freq_emb], dim=-1)
        return self.combine_proj(combined)
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
class VoltageDataEmbedding(nn.Module):
    """
    Complete data embedding for voltage anomaly detection.

    The output is a learned, softmax-weighted sum of:
    1. Token embedding (from raw values)
    2. Power frequency embedding
    3. Daily load embedding
    4. Three-phase embedding
    5. Voltage quality embedding

    Items 2-5 are optional and controlled by the ``use_*`` flags.
    """

    def __init__(
        self,
        c_in,
        d_model,
        embed_type="fixed",
        freq="h",
        dropout=0.1,
        use_power_freq=True,
        use_daily=True,
        use_three_phase=True,
        use_quality=True,
        sample_rate=1.0,
    ):
        """
        Args:
            c_in: Number of input channels
            d_model: Embedding dimension
            embed_type: Type of temporal embedding (not used in this class body)
            freq: Frequency of data (not used in this class body)
            dropout: Dropout rate
            use_power_freq: Whether to use power frequency embedding
            use_daily: Whether to use daily load embedding
            use_three_phase: Whether to use three-phase embedding
            use_quality: Whether to use voltage quality embedding
            sample_rate: Sampling rate in Hz
        """
        super(VoltageDataEmbedding, self).__init__()

        self.d_model = d_model
        self.use_power_freq = use_power_freq
        self.use_daily = use_daily
        self.use_three_phase = use_three_phase
        self.use_quality = use_quality

        # Token embedding: width-3 circular convolution over the time axis.
        self.token_conv = nn.Conv1d(
            c_in,
            d_model,
            kernel_size=3,
            padding=1,
            padding_mode="circular",
            bias=False,
        )
        nn.init.kaiming_normal_(
            self.token_conv.weight, mode="fan_in", nonlinearity="leaky_relu"
        )

        # Optional embeddings, created in a fixed order.
        if use_power_freq:
            self.power_freq_embedding = PowerFrequencyEmbedding(
                d_model, sample_rate=sample_rate
            )

        if use_daily:
            # One day holds 86400 seconds' worth of samples at `sample_rate`.
            self.daily_embedding = DailyLoadEmbedding(
                d_model, samples_per_day=int(86400 * sample_rate)
            )

        if use_three_phase:
            self.three_phase_embedding = ThreePhaseEmbedding(d_model)

        if use_quality:
            self.quality_embedding = VoltageQualityEmbedding(d_model)

        # One mixing logit per active embedding (token embedding included),
        # initialised uniformly.
        num_embeddings = 1 + sum(
            (use_power_freq, use_daily, use_three_phase, use_quality)
        )
        self.combination_weights = nn.Parameter(
            torch.ones(num_embeddings) / num_embeddings
        )

        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark=None):
        """
        Args:
            x: Input tensor [B, T, C]
            x_mark: Optional temporal marks [B, T, mark_dim]; accepted but
                not used by this method

        Returns:
            Embedded tensor [B, T, d_model]
        """
        # Conv1d expects [B, C, T]; transpose back to [B, T, d_model] after.
        parts = [self.token_conv(x.permute(0, 2, 1)).transpose(1, 2)]

        if self.use_power_freq:
            parts.append(self.power_freq_embedding(x))

        if self.use_daily:
            parts.append(self.daily_embedding(x))

        if self.use_three_phase:
            parts.append(self.three_phase_embedding(x))

        if self.use_quality:
            # Voltage channels are assumed to be the first (up to) three.
            n_voltage = min(3, x.size(-1))
            parts.append(self.quality_embedding(x[:, :, :n_voltage]))

        # The i-th active embedding uses the i-th mixing logit.
        logits = torch.stack(
            [self.combination_weights[i] for i in range(len(parts))]
        )
        mix = F.softmax(logits, dim=0)

        combined = 0
        for weight, part in zip(mix, parts):
            combined = combined + weight * part

        return self.dropout(combined)
|
layers/__init__.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Layers module for Voltage Anomaly Detection
|
| 2 |
+
from .Conv_Blocks import Inception_Block_V1, Inception_Block_V2
|
| 3 |
+
from .Embed import (
|
| 4 |
+
DataEmbedding,
|
| 5 |
+
DataEmbedding_inverted,
|
| 6 |
+
DataEmbedding_wo_pos,
|
| 7 |
+
FixedEmbedding,
|
| 8 |
+
PatchEmbedding,
|
| 9 |
+
PositionalEmbedding,
|
| 10 |
+
TemporalEmbedding,
|
| 11 |
+
TimeFeatureEmbedding,
|
| 12 |
+
TokenEmbedding,
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
__all__ = [
|
| 16 |
+
"PositionalEmbedding",
|
| 17 |
+
"TokenEmbedding",
|
| 18 |
+
"FixedEmbedding",
|
| 19 |
+
"TemporalEmbedding",
|
| 20 |
+
"TimeFeatureEmbedding",
|
| 21 |
+
"DataEmbedding",
|
| 22 |
+
"DataEmbedding_inverted",
|
| 23 |
+
"DataEmbedding_wo_pos",
|
| 24 |
+
"PatchEmbedding",
|
| 25 |
+
"Inception_Block_V1",
|
| 26 |
+
"Inception_Block_V2",
|
| 27 |
+
]
|
layers/__pycache__/AutoCorrelation.cpython-310.pyc
ADDED
|
Binary file (5.38 kB). View file
|
|
|
layers/__pycache__/AutoCorrelation.cpython-311.pyc
ADDED
|
Binary file (11.7 kB). View file
|
|
|
layers/__pycache__/Autoformer_EncDec.cpython-310.pyc
ADDED
|
Binary file (6.85 kB). View file
|
|
|
layers/__pycache__/Autoformer_EncDec.cpython-311.pyc
ADDED
|
Binary file (13.7 kB). View file
|
|
|
layers/__pycache__/Conv_Blocks.cpython-310.pyc
ADDED
|
Binary file (2.65 kB). View file
|
|
|
layers/__pycache__/Conv_Blocks.cpython-311.pyc
ADDED
|
Binary file (5.28 kB). View file
|
|
|
layers/__pycache__/Embed.cpython-310.pyc
ADDED
|
Binary file (7.66 kB). View file
|
|
|
layers/__pycache__/Embed.cpython-311.pyc
ADDED
|
Binary file (15.4 kB). View file
|
|
|
layers/__pycache__/SelfAttention_Family.cpython-310.pyc
ADDED
|
Binary file (9.84 kB). View file
|
|
|
layers/__pycache__/SelfAttention_Family.cpython-311.pyc
ADDED
|
Binary file (20 kB). View file
|
|
|
layers/__pycache__/StandardNorm.cpython-310.pyc
ADDED
|
Binary file (2.57 kB). View file
|
|
|
layers/__pycache__/StandardNorm.cpython-311.pyc
ADDED
|
Binary file (4.31 kB). View file
|
|
|
layers/__pycache__/Transformer_EncDec.cpython-310.pyc
ADDED
|
Binary file (4.97 kB). View file
|
|
|
layers/__pycache__/Transformer_EncDec.cpython-311.pyc
ADDED
|
Binary file (9.95 kB). View file
|
|
|
layers/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (599 Bytes). View file
|
|
|
layers/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (756 Bytes). View file
|
|
|
models/DLinear.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
DLinear: Decomposition Linear Model for Time Series.
|
| 3 |
+
|
| 4 |
+
Paper: Are Transformers Effective for Time Series Forecasting?
|
| 5 |
+
Link: https://arxiv.org/pdf/2205.13504.pdf
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
|
| 11 |
+
from layers.Autoformer_EncDec import series_decomp
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class Model(nn.Module):
    """
    DLinear: linear layers applied to a decomposed series.

    The input is split by ``series_decomp`` into two components (unpacked as
    seasonal and trend); each is mapped along the time axis by its own linear
    layer and the two results are summed.
    """

    def __init__(self, configs, individual=False):
        """
        Args:
            configs: Object providing ``task_name``, ``seq_len``, ``moving_avg``
                and ``enc_in``; also ``pred_len`` for forecasting tasks and
                ``num_class`` for classification.
            individual: If True, every channel gets its own pair of linear
                layers; otherwise one pair is shared by all channels.
        """
        super(Model, self).__init__()
        self.task_name = configs.task_name
        self.seq_len = configs.seq_len

        # Non-forecasting tasks produce an output as long as the input.
        same_length_tasks = ("classification", "anomaly_detection", "imputation")
        if self.task_name in same_length_tasks:
            self.pred_len = configs.seq_len
        else:
            self.pred_len = configs.pred_len

        # Series decomposition block from Autoformer
        self.decompsition = series_decomp(configs.moving_avg)
        self.individual = individual
        self.channels = configs.enc_in

        if self.individual:
            self.Linear_Seasonal = nn.ModuleList()
            self.Linear_Trend = nn.ModuleList()
            for _ in range(self.channels):
                self.Linear_Seasonal.append(self._averaging_linear())
                self.Linear_Trend.append(self._averaging_linear())
        else:
            self.Linear_Seasonal = self._averaging_linear()
            self.Linear_Trend = self._averaging_linear()

        if self.task_name == "classification":
            self.projection = nn.Linear(
                configs.enc_in * configs.seq_len, configs.num_class
            )

    def _averaging_linear(self):
        """Build a seq_len -> pred_len linear layer whose weights all equal 1/seq_len."""
        layer = nn.Linear(self.seq_len, self.pred_len)
        layer.weight = nn.Parameter(
            (1 / self.seq_len) * torch.ones([self.pred_len, self.seq_len])
        )
        return layer

    def encoder(self, x):
        """Decompose ``x``, map both components along time and return their sum.

        The components are permuted so the linear layers act on the time axis,
        and the result is permuted back before returning.
        """
        seasonal_init, trend_init = self.decompsition(x)
        seasonal_init = seasonal_init.permute(0, 2, 1)
        trend_init = trend_init.permute(0, 2, 1)

        if self.individual:
            # Per-channel layers: fill pre-allocated output buffers.
            seasonal_output = seasonal_init.new_zeros(
                [seasonal_init.size(0), seasonal_init.size(1), self.pred_len]
            )
            trend_output = trend_init.new_zeros(
                [trend_init.size(0), trend_init.size(1), self.pred_len]
            )
            for ch in range(self.channels):
                seasonal_output[:, ch, :] = self.Linear_Seasonal[ch](
                    seasonal_init[:, ch, :]
                )
                trend_output[:, ch, :] = self.Linear_Trend[ch](trend_init[:, ch, :])
        else:
            seasonal_output = self.Linear_Seasonal(seasonal_init)
            trend_output = self.Linear_Trend(trend_init)

        return (seasonal_output + trend_output).permute(0, 2, 1)

    def forecast(self, x_enc):
        """Run the encoder for forecasting tasks."""
        return self.encoder(x_enc)

    def imputation(self, x_enc):
        """Run the encoder for the imputation task."""
        return self.encoder(x_enc)

    def anomaly_detection(self, x_enc):
        """Run the encoder for the anomaly-detection task."""
        return self.encoder(x_enc)

    def classification(self, x_enc):
        """Encode, flatten everything but the batch axis, and project to class scores."""
        enc_out = self.encoder(x_enc)
        flat = enc_out.reshape(enc_out.shape[0], -1)
        return self.projection(flat)

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        """Dispatch on ``task_name``; only ``x_enc`` is used.

        Returns None when ``task_name`` is not one of the supported tasks.
        """
        task = self.task_name
        if task in ("long_term_forecast", "short_term_forecast"):
            # Keep only the final pred_len steps.
            return self.forecast(x_enc)[:, -self.pred_len :, :]
        if task == "imputation":
            return self.imputation(x_enc)
        if task == "anomaly_detection":
            return self.anomaly_detection(x_enc)
        if task == "classification":
            return self.classification(x_enc)
        return None
|
models/MTSTimesNet.py
ADDED
|
@@ -0,0 +1,418 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
MTSTimesNet: Multi-scale Temporal TimesNet for Rural Voltage Anomaly Detection
|
| 3 |
+
|
| 4 |
+
Core Innovation: Parallel multi-scale temporal branches that simultaneously capture
|
| 5 |
+
patterns at different time scales (short-term fluctuations, medium-term trends, long-term patterns).
|
| 6 |
+
|
| 7 |
+
Rural power grids exhibit multi-scale temporal patterns:
|
| 8 |
+
- Short-term (seconds to minutes): Transient events, voltage sags
|
| 9 |
+
- Medium-term (minutes to hours): Load variations, daily patterns
|
| 10 |
+
- Long-term (hours to days): Seasonal patterns, systematic issues
|
| 11 |
+
|
| 12 |
+
Key Components:
|
| 13 |
+
1. Multi-scale TimesBlocks: Parallel branches with different period focus
|
| 14 |
+
2. Adaptive Fusion Gate: Learns optimal combination of scales
|
| 15 |
+
3. Cross-scale Residual Connections: Information flow across scales
|
| 16 |
+
|
| 17 |
+
Author: Voltage Anomaly Detection Research
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
from typing import List, Optional, Tuple
|
| 21 |
+
|
| 22 |
+
import numpy as np
|
| 23 |
+
import torch
|
| 24 |
+
import torch.fft
|
| 25 |
+
import torch.nn as nn
|
| 26 |
+
import torch.nn.functional as F
|
| 27 |
+
|
| 28 |
+
from layers.Conv_Blocks import Inception_Block_V1
|
| 29 |
+
from layers.Embed import DataEmbedding
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def FFT_for_Period_Range(x, k=2, min_period=2, max_period=None):
    """FFT-based period discovery with range constraints.

    Args:
        x: Tensor unpacked as [B, T, C]; the FFT is taken along dim 1.
        k: Maximum number of frequencies to select.
        min_period: Smallest period (in steps) to consider.
        max_period: Largest period to consider; defaults to ``T // 2``.

    Returns:
        A pair ``(period_list, weights)``: a NumPy array with one period per
        selected frequency, clipped to ``[min_period, max_period]``, and the
        channel-averaged amplitudes of the selected frequencies for every
        batch element.
    """
    _, T, _ = x.size()
    upper = T // 2 if max_period is None else max_period

    spectrum = torch.fft.rfft(x, dim=1)
    amplitude = abs(spectrum)

    # Average amplitude per frequency bin over batch and channels; the
    # zero-frequency bin is excluded.
    score = amplitude.mean(0).mean(-1)
    score[0] = 0

    # Zero out bins whose period T / index is outside the requested range
    # (the epsilon avoids dividing by zero for bin 0).
    bin_periods = T / (torch.arange(len(score), device=x.device) + 1e-8)
    in_range = (bin_periods >= min_period) & (bin_periods <= upper)
    score = score * in_range.float()

    n_select = min(k, in_range.sum().item())
    _, top_list = torch.topk(score, n_select)
    top_list = top_list.detach().cpu().numpy()

    # NOTE(review): bins are filtered by T / index above but converted with
    # T // (index + 1) here, so the reported period can differ from the one
    # that passed the filter - confirm whether the offset is intended.
    period_list = np.clip(T // (top_list + 1), min_period, upper)

    return period_list, amplitude.mean(-1)[:, top_list]
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class ScaleSpecificTimesBlock(nn.Module):
    """TimesBlock focused on a specific temporal scale.

    Periods are discovered by ``FFT_for_Period_Range`` inside
    ``[min_period, max_period]``; for each one the sequence is folded into a
    2-D layout, passed through the convolution stack and unfolded again. The
    per-period results are combined with softmax weights and added to the
    input before layer normalisation.
    """

    def __init__(
        self,
        configs,
        scale_name: str = "medium",
        min_period: int = 10,
        max_period: int = 50,
    ):
        """
        Args:
            configs: Object providing ``seq_len``, ``top_k``, ``d_model``,
                ``d_ff`` and ``num_kernels``; ``pred_len`` is optional.
            scale_name: Label of this scale (stored only).
            min_period: Smallest period this block looks for.
            max_period: Largest period; capped at ``seq_len // 2``.
        """
        super(ScaleSpecificTimesBlock, self).__init__()

        self.seq_len = configs.seq_len
        # pred_len is optional, matching how the enclosing model reads it.
        self.pred_len = getattr(configs, "pred_len", 0)
        self.k = configs.top_k
        self.scale_name = scale_name
        self.min_period = min_period
        self.max_period = min(max_period, configs.seq_len // 2)

        self.conv = nn.Sequential(
            Inception_Block_V1(
                configs.d_model, configs.d_ff, num_kernels=configs.num_kernels
            ),
            nn.GELU(),
            Inception_Block_V1(
                configs.d_ff, configs.d_model, num_kernels=configs.num_kernels
            ),
        )

        self.layer_norm = nn.LayerNorm(configs.d_model)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: Tensor unpacked as [B, T, N].

        Returns:
            ``(output, weights)``: ``output`` has the shape of ``x``;
            ``weights`` holds the softmax period weights per batch element.
            When no period is found, ``x`` is returned unchanged together
            with a [B, 1] tensor of ones.
        """
        B, T, N = x.size()

        period_list, period_weight = FFT_for_Period_Range(
            x, self.k, self.min_period, self.max_period
        )

        res = []
        for raw_period in period_list:
            period = max(int(raw_period), 2)

            # Fold using the actual input length: the reshape below only
            # works when the (padded) length matches the tensor, and T equals
            # seq_len + pred_len in every configuration that worked before.
            remainder = T % period
            if remainder:
                # Zero-pad up to the next multiple of the period, matching
                # the dtype and device of the input.
                padding = x.new_zeros(B, period - remainder, N)
                out = torch.cat([x, padding], dim=1)
            else:
                out = x
            length = out.size(1)

            # [B, length, N] -> [B, N, length // period, period] for the 2-D conv.
            out = (
                out.reshape(B, length // period, period, N)
                .permute(0, 3, 1, 2)
                .contiguous()
            )
            out = self.conv(out)
            # Back to [B, length, N], then drop the padded tail.
            out = out.permute(0, 2, 3, 1).reshape(B, -1, N)
            res.append(out[:, :T, :])

        if not res:
            return x, x.new_ones(B, 1)

        res = torch.stack(res, dim=-1)

        # Weight each period's result by its softmax-normalised amplitude.
        period_weight = F.softmax(period_weight, dim=1)
        period_weight = period_weight.unsqueeze(1).unsqueeze(1).repeat(1, T, N, 1)
        output = torch.sum(res * period_weight, dim=-1)
        output = self.layer_norm(output + x)

        return output, period_weight.mean(dim=(1, 2))
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
class AdaptiveFusionGate(nn.Module):
    """Adaptive gate for fusing multi-scale features.

    A small network looks at the time-averaged features of every scale and
    produces one softmax weight per scale; the output is the weighted sum of
    the scale features.
    """

    def __init__(self, d_model: int, n_scales: int = 3):
        """
        Args:
            d_model: Feature width of every scale tensor.
            n_scales: Number of scale tensors that will be fused.
        """
        super(AdaptiveFusionGate, self).__init__()

        self.n_scales = n_scales
        self.global_pool = nn.AdaptiveAvgPool1d(1)

        gate_layers = [
            nn.Linear(d_model * n_scales, d_model),
            nn.ReLU(),
            nn.Linear(d_model, n_scales),
            nn.Softmax(dim=-1),
        ]
        self.gate_network = nn.Sequential(*gate_layers)

    def forward(self, scale_features: List[torch.Tensor]) -> torch.Tensor:
        """
        Args:
            scale_features: One tensor per scale, each unpacked as [B, T, D].

        Returns:
            Fused tensor with the shape of a single scale tensor.
        """
        # Unpacking requires a non-empty list of rank-3 tensors.
        _batch, _steps, _width = scale_features[0].size()

        # Per-scale context: average over time -> [B, D].
        pooled = [
            self.global_pool(feat.transpose(1, 2)).squeeze(-1)
            for feat in scale_features
        ]

        # One weight per scale and batch element -> [B, n_scales].
        gate = self.gate_network(torch.cat(pooled, dim=-1))

        # Broadcast the weights over time and features: [B, 1, n_scales, 1].
        gate = gate.unsqueeze(1).unsqueeze(-1)
        per_scale = torch.stack(scale_features, dim=2)

        return (per_scale * gate).sum(dim=2)
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
class CrossScaleConnection(nn.Module):
    """Cross-scale residual connections for information exchange between scales.

    Every scale attends over the concatenation of all scales (one attention
    module shared by all scales); the attended result is added to the scale's
    own features and passed through a scale-specific linear projection.
    """

    def __init__(
        self,
        d_model: int,
        n_scales: int = 3,
        num_heads: int = 4,
        dropout: float = 0.1,
    ):
        """
        Args:
            d_model: Feature width of every scale tensor.
            n_scales: Number of scales (one projection is created per scale).
            num_heads: Number of attention heads; ``d_model`` must be
                divisible by it. Defaults to the previously fixed value 4.
            dropout: Dropout applied inside the attention module. Defaults
                to the previously fixed value 0.1.
        """
        super(CrossScaleConnection, self).__init__()

        self.n_scales = n_scales
        self.cross_attention = nn.MultiheadAttention(
            d_model, num_heads=num_heads, dropout=dropout, batch_first=True
        )
        self.projections = nn.ModuleList(
            [nn.Linear(d_model, d_model) for _ in range(n_scales)]
        )

    def forward(self, scale_features: List[torch.Tensor]) -> List[torch.Tensor]:
        """
        Args:
            scale_features: One tensor per scale, each [B, T, D] (batch first).

        Returns:
            A list with one enhanced tensor per input, same shapes as the inputs.
        """
        # Keys/values: every scale's time steps concatenated along the time axis.
        all_scales = torch.cat(scale_features, dim=1)

        enhanced = []
        for i, feat in enumerate(scale_features):
            attended, _ = self.cross_attention(feat, all_scales, all_scales)
            # Residual connection, then the projection belonging to scale i.
            enhanced.append(self.projections[i](feat + attended))

        return enhanced
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
class Model(nn.Module):
|
| 192 |
+
"""
|
| 193 |
+
MTSTimesNet: Multi-scale Temporal TimesNet
|
| 194 |
+
|
| 195 |
+
Architecture:
|
| 196 |
+
- Shared Embedding Layer
|
| 197 |
+
- Parallel Multi-scale TimesBlocks
|
| 198 |
+
- Cross-scale Connections
|
| 199 |
+
- Adaptive Fusion Gate
|
| 200 |
+
- Output Projection
|
| 201 |
+
"""
|
| 202 |
+
|
| 203 |
+
SCALE_CONFIGS = {
|
| 204 |
+
"short": {"min_period": 2, "max_period": 20},
|
| 205 |
+
"medium": {"min_period": 20, "max_period": 60},
|
| 206 |
+
"long": {"min_period": 60, "max_period": 200},
|
| 207 |
+
}
|
| 208 |
+
|
| 209 |
+
def __init__(self, configs):
|
| 210 |
+
super(Model, self).__init__()
|
| 211 |
+
|
| 212 |
+
self.configs = configs
|
| 213 |
+
self.task_name = configs.task_name
|
| 214 |
+
self.seq_len = configs.seq_len
|
| 215 |
+
self.label_len = getattr(configs, "label_len", 0)
|
| 216 |
+
self.pred_len = getattr(configs, "pred_len", 0)
|
| 217 |
+
|
| 218 |
+
# 创建实例级别的 scale_configs 副本,避免修改类属性
|
| 219 |
+
import copy
|
| 220 |
+
self.scale_configs = copy.deepcopy(self.SCALE_CONFIGS)
|
| 221 |
+
self._adjust_scale_configs()
|
| 222 |
+
|
| 223 |
+
self.enc_embedding = DataEmbedding(
|
| 224 |
+
configs.enc_in,
|
| 225 |
+
configs.d_model,
|
| 226 |
+
configs.embed,
|
| 227 |
+
configs.freq,
|
| 228 |
+
configs.dropout,
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
self.n_layers = configs.e_layers
|
| 232 |
+
self.n_scales = len(self.scale_configs)
|
| 233 |
+
|
| 234 |
+
self.scale_blocks = nn.ModuleDict()
|
| 235 |
+
for scale_name, scale_cfg in self.scale_configs.items():
|
| 236 |
+
self.scale_blocks[scale_name] = nn.ModuleList(
|
| 237 |
+
[
|
| 238 |
+
ScaleSpecificTimesBlock(
|
| 239 |
+
configs,
|
| 240 |
+
scale_name=scale_name,
|
| 241 |
+
min_period=scale_cfg["min_period"],
|
| 242 |
+
max_period=scale_cfg["max_period"],
|
| 243 |
+
)
|
| 244 |
+
for _ in range(self.n_layers)
|
| 245 |
+
]
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
self.cross_scale = nn.ModuleList(
|
| 249 |
+
[
|
| 250 |
+
CrossScaleConnection(configs.d_model, self.n_scales)
|
| 251 |
+
for _ in range(self.n_layers - 1)
|
| 252 |
+
]
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
+
self.fusion_gate = AdaptiveFusionGate(configs.d_model, self.n_scales)
|
| 256 |
+
self.layer_norm = nn.LayerNorm(configs.d_model)
|
| 257 |
+
|
| 258 |
+
if self.task_name == "anomaly_detection" or self.task_name == "imputation":
|
| 259 |
+
self.projection = nn.Linear(configs.d_model, configs.c_out, bias=True)
|
| 260 |
+
|
| 261 |
+
if (
|
| 262 |
+
self.task_name == "long_term_forecast"
|
| 263 |
+
or self.task_name == "short_term_forecast"
|
| 264 |
+
):
|
| 265 |
+
self.predict_linear = nn.Linear(self.seq_len, self.pred_len + self.seq_len)
|
| 266 |
+
self.projection = nn.Linear(configs.d_model, configs.c_out, bias=True)
|
| 267 |
+
|
| 268 |
+
if self.task_name == "classification":
|
| 269 |
+
self.act = F.gelu
|
| 270 |
+
self.dropout = nn.Dropout(configs.dropout)
|
| 271 |
+
self.projection = nn.Linear(
|
| 272 |
+
configs.d_model * configs.seq_len, configs.num_class
|
| 273 |
+
)
|
| 274 |
+
|
| 275 |
+
def _adjust_scale_configs(self):
|
| 276 |
+
"""调整尺度配置,使用实例属性避免类属性污染"""
|
| 277 |
+
seq_len = self.seq_len
|
| 278 |
+
for scale_name in self.scale_configs:
|
| 279 |
+
max_period = self.scale_configs[scale_name]["max_period"]
|
| 280 |
+
if max_period > seq_len // 2:
|
| 281 |
+
self.scale_configs[scale_name]["max_period"] = seq_len // 2
|
| 282 |
+
|
| 283 |
+
valid_scales = {
|
| 284 |
+
k: v
|
| 285 |
+
for k, v in self.scale_configs.items()
|
| 286 |
+
if v["min_period"] < v["max_period"]
|
| 287 |
+
}
|
| 288 |
+
|
| 289 |
+
if len(valid_scales) < 3:
|
| 290 |
+
self.scale_configs = {
|
| 291 |
+
"short": {"min_period": 2, "max_period": max(4, seq_len // 10)},
|
| 292 |
+
"medium": {
|
| 293 |
+
"min_period": max(4, seq_len // 10),
|
| 294 |
+
"max_period": max(8, seq_len // 5),
|
| 295 |
+
},
|
| 296 |
+
"long": {
|
| 297 |
+
"min_period": max(8, seq_len // 5),
|
| 298 |
+
"max_period": seq_len // 2,
|
| 299 |
+
},
|
| 300 |
+
}
|
| 301 |
+
|
| 302 |
+
def _process_multi_scale(self, x: torch.Tensor) -> torch.Tensor:
|
| 303 |
+
scale_names = list(self.scale_blocks.keys())
|
| 304 |
+
scale_features = {name: x for name in scale_names}
|
| 305 |
+
|
| 306 |
+
for layer_idx in range(self.n_layers):
|
| 307 |
+
new_features = {}
|
| 308 |
+
for scale_name in scale_names:
|
| 309 |
+
# self.scale_blocks[scale_name] 是 ModuleList,直接索引
|
| 310 |
+
block_list = self.scale_blocks[scale_name]
|
| 311 |
+
feat, _ = block_list[layer_idx](scale_features[scale_name])
|
| 312 |
+
new_features[scale_name] = feat
|
| 313 |
+
|
| 314 |
+
if layer_idx < self.n_layers - 1:
|
| 315 |
+
feature_list = [new_features[name] for name in scale_names]
|
| 316 |
+
enhanced_list = self.cross_scale[layer_idx](feature_list)
|
| 317 |
+
for i, name in enumerate(scale_names):
|
| 318 |
+
new_features[name] = enhanced_list[i]
|
| 319 |
+
|
| 320 |
+
scale_features = new_features
|
| 321 |
+
|
| 322 |
+
feature_list = [scale_features[name] for name in scale_names]
|
| 323 |
+
fused = self.fusion_gate(feature_list)
|
| 324 |
+
|
| 325 |
+
return self.layer_norm(fused)
|
| 326 |
+
|
| 327 |
+
def anomaly_detection(self, x_enc: torch.Tensor) -> torch.Tensor:
|
| 328 |
+
means = x_enc.mean(1, keepdim=True).detach()
|
| 329 |
+
x_enc = x_enc - means
|
| 330 |
+
stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
|
| 331 |
+
x_enc = x_enc / stdev
|
| 332 |
+
|
| 333 |
+
enc_out = self.enc_embedding(x_enc, None)
|
| 334 |
+
enc_out = self._process_multi_scale(enc_out)
|
| 335 |
+
dec_out = self.projection(enc_out)
|
| 336 |
+
|
| 337 |
+
dec_out = dec_out * stdev[:, 0, :].unsqueeze(1).repeat(
|
| 338 |
+
1, self.seq_len + self.pred_len, 1
|
| 339 |
+
)
|
| 340 |
+
dec_out = dec_out + means[:, 0, :].unsqueeze(1).repeat(
|
| 341 |
+
1, self.seq_len + self.pred_len, 1
|
| 342 |
+
)
|
| 343 |
+
|
| 344 |
+
return dec_out
|
| 345 |
+
|
| 346 |
+
def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
|
| 347 |
+
means = x_enc.mean(1, keepdim=True).detach()
|
| 348 |
+
x_enc = x_enc - means
|
| 349 |
+
stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
|
| 350 |
+
x_enc = x_enc / stdev
|
| 351 |
+
|
| 352 |
+
enc_out = self.enc_embedding(x_enc, x_mark_enc)
|
| 353 |
+
enc_out = self.predict_linear(enc_out.permute(0, 2, 1)).permute(0, 2, 1)
|
| 354 |
+
enc_out = self._process_multi_scale(enc_out)
|
| 355 |
+
dec_out = self.projection(enc_out)
|
| 356 |
+
|
| 357 |
+
dec_out = dec_out * stdev[:, 0, :].unsqueeze(1).repeat(
|
| 358 |
+
1, self.pred_len + self.seq_len, 1
|
| 359 |
+
)
|
| 360 |
+
dec_out = dec_out + means[:, 0, :].unsqueeze(1).repeat(
|
| 361 |
+
1, self.pred_len + self.seq_len, 1
|
| 362 |
+
)
|
| 363 |
+
|
| 364 |
+
return dec_out
|
| 365 |
+
|
| 366 |
+
def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask):
    """Reconstruct a partially observed window.

    Statistics are computed from the entries where ``mask == 1`` only,
    the normalised input is run through the embedding, the multi-scale
    stage and the projection, and the result is mapped back to the input
    scale.

    Args:
        x_enc: Input window. The mean is computed from a plain sum over
            dim 1, so masked entries are presumably already zeroed by the
            caller -- TODO confirm.
        x_mark_enc: Passed to ``enc_embedding`` alongside ``x_enc``.
        x_dec: Unused here; kept for the shared task signature.
        x_mark_dec: Unused here; kept for the shared task signature.
        mask: Same shape as ``x_enc``; 1 marks an observed entry and 0 a
            missing one.

    Returns:
        The projection output rescaled with the masked statistics.
    """
    # Number of observed entries along dim 1. Clamp to at least 1 so a
    # fully masked column divides by 1 instead of producing 0/0 = NaN
    # that would contaminate the whole output.
    observed = torch.sum(mask == 1, dim=1).clamp(min=1)

    means = torch.sum(x_enc, dim=1) / observed
    means = means.unsqueeze(1).detach()
    x_enc = x_enc - means
    # Re-zero the missing entries so they do not contribute to the variance.
    x_enc = x_enc.masked_fill(mask == 0, 0)
    stdev = torch.sqrt(torch.sum(x_enc * x_enc, dim=1) / observed + 1e-5)
    stdev = stdev.unsqueeze(1).detach()
    x_enc = x_enc / stdev

    enc_out = self.enc_embedding(x_enc, x_mark_enc)
    enc_out = self._process_multi_scale(enc_out)
    dec_out = self.projection(enc_out)

    # Both statistics have a size-1 dim 1, so broadcasting applies them to
    # every step. This replaces a repeat to a hard-coded
    # ``pred_len + seq_len`` although this path never extends the sequence.
    dec_out = dec_out * stdev
    dec_out = dec_out + means

    return dec_out
|
| 389 |
+
|
| 390 |
+
def classification(self, x_enc, x_mark_enc):
    """Compute the classification output for one batch.

    The input is embedded (without time-mark features), run through the
    multi-scale stage, activated and passed through dropout. The result is
    multiplied by ``x_mark_enc``, flattened per sample and projected.

    Args:
        x_enc: Input batch.
        x_mark_enc: Per-step weights multiplied into the features
            (presumably a padding mask with 0 on padded steps -- TODO
            confirm against the data loader).

    Returns:
        The output of ``self.projection`` on the flattened features.
    """
    features = self._process_multi_scale(self.enc_embedding(x_enc, None))
    activated = self.dropout(self.act(features))

    # A trailing axis lets the per-step weights apply to every channel.
    weighted = activated * x_mark_enc.unsqueeze(-1)

    # Collapse everything except dim 0 into a single feature vector.
    flattened = weighted.reshape(weighted.shape[0], -1)
    return self.projection(flattened)
|
| 401 |
+
|
| 402 |
+
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
    """Dispatch to the task-specific method selected by ``self.task_name``.

    Args:
        x_enc: Input batch, forwarded to every task.
        x_mark_enc: Forwarded to the forecast, imputation and
            classification tasks.
        x_dec: Forwarded to the forecast and imputation tasks.
        x_mark_dec: Forwarded to the forecast and imputation tasks.
        mask: Forwarded to the imputation task only.

    Returns:
        * forecasting tasks: the last ``pred_len`` steps (dim 1) of the
          forecast output;
        * imputation, anomaly detection, classification: the task output
          unchanged;
        * any other ``task_name``: ``None``.
    """
    task = self.task_name

    if task in ("long_term_forecast", "short_term_forecast"):
        full_sequence = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
        # Only the trailing horizon is returned to the caller.
        return full_sequence[:, -self.pred_len :, :]

    if task == "imputation":
        return self.imputation(x_enc, x_mark_enc, x_dec, x_mark_dec, mask)

    if task == "anomaly_detection":
        return self.anomaly_detection(x_enc)

    if task == "classification":
        return self.classification(x_enc, x_mark_enc)

    # Unrecognised task names fall through without raising.
    return None
|