Commit d88ab6d (verified), committed by bakhshaliyev
1 parent: 8e728a9

Upload folder using huggingface_hub

.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/architecture.pdf filter=lfs diff=lfs merge=lfs -text
+ assets/spikf-go-architecture.png filter=lfs diff=lfs merge=lfs -text
+ assets/supplementary.pdf filter=lfs diff=lfs merge=lfs -text
CITATION.cff ADDED
@@ -0,0 +1,12 @@
+ cff-version: 1.2.0
+ title: "SpikF-GO: Spiking Fourier Graph Operators for Multivariate Time Series Forecasting"
+ message: "If you use this code, please cite our ECML PKDD 2026 paper."
+ authors:
+   - family-names: Bakhshaliyev
+     given-names: Jafar
+   - family-names: Landwehr
+     given-names: Niels
+ year: 2026
+ conference: "ECML PKDD 2026"
+ repository-code: "https://github.com/jafarbakhshaliyev/SpikF-GO"
+ license: MIT
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2026 Jafar Bakhshaliyev
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,3 +1,194 @@
  ---
  license: mit
+ tags:
+ - time-series
+ - forecasting
+ - spiking-neural-networks
+ - graph-neural-networks
+ - multivariate-time-series
  ---
+
+ # SpikF-GO: Spiking Fourier Graph Operators for Multivariate Time Series Forecasting
+
+ [![arXiv](https://img.shields.io/badge/arXiv-2606.13901-b31b1b.svg)](https://arxiv.org/abs/2606.13901)
+ [![ECML PKDD 2026](https://img.shields.io/badge/ECML%20PKDD-2026-blue.svg)](https://arxiv.org/abs/2606.13901)
+ [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](LICENSE)
+
+ 📄 **Paper (arXiv):** https://arxiv.org/abs/2606.13901
+ 💻 **GitHub:** https://github.com/jafarbakhshaliyev/SpikF-GO
+
+ Official implementation of **SpikF-GO: Spiking Fourier Graph Operators for Multivariate Time Series Forecasting**, accepted to the **ECML PKDD 2026 Research Track**.
+
+ ![SpikF-GO architecture](assets/spikf-go-architecture.png)
+
+ ---
+
+ ## Abstract
+
+ SpikF-GO is a spiking neural architecture for multivariate time series forecasting. It combines the hypervariate graph formulation of FourierGNN with spike-driven Fourier-domain graph processing, enabling joint modeling of intra-series temporal dependencies, inter-series dependencies, and time-varying cross-variable interactions. The model introduces sparse frequency selection and Complex LIF-based spectral gating to preserve event-driven computation in the Fourier domain. We also provide **SpikF-GO w/ CPG**, which incorporates Central Pattern Generator-based positional signals for improved long-range temporal modeling.
+
+ ---
+
+ ## Key Contributions
+
+ - **Graph-based SNN forecasting:** SpikF-GO brings hypervariate graph modeling into SNN-based multivariate time series forecasting.
+ - **Spike-driven Fourier graph operators:** The model combines sparse frequency gating with Complex LIF-based spectral processing to preserve event-driven computation in the Fourier domain.
+ - **Unified SNN benchmark:** We evaluate SpikF-GO against major SNN forecasting families under a common experimental protocol across eight benchmark datasets.
+ - **Energy-aware forecasting:** SpikF-GO achieves competitive-to-superior forecasting performance while reducing theoretical energy consumption relative to FourierGNN.
+
+ ---
+
+ ## Related Library: SpikingTSF
+
+ We also maintain **[SpikingTSF](https://github.com/spikora/SpikingTSF)**, a broader open-source benchmark library for spiking neural network-based time series forecasting. SpikingTSF unifies SNN forecasting architectures and ANN baselines under a common training and evaluation protocol across datasets, horizons, metrics, and random seeds.
+
+ > **Note:** SpikingTSF is a benchmarking library and may not reproduce all experiments from this repository directly.
+
+ ---
+
+ ## Repository Structure
+
+ ```
+ SpikF-GO/
+ ├── README.md
+ ├── LICENSE
+ ├── CITATION.cff
+ ├── requirements.txt
+ ├── train.py              # main training & evaluation entry point
+ ├── model/                # SpikF-GO + all baseline implementations
+ ├── utils/                # shared utilities (metrics, helpers)
+ ├── data/
+ │   └── data_loader.py    # dataset loading (raw files placed here at runtime)
+ ├── scripts/
+ │   ├── ecg.sh
+ │   ├── covid.sh
+ │   ├── solar.sh
+ │   ├── ecl.sh
+ │   ├── traffic.sh
+ │   ├── metr_la.sh
+ │   ├── pems_bay.sh
+ │   └── wiki.sh
+ └── assets/
+     ├── spikf-go-architecture.png
+     └── supplementary.pdf
+ ```
+
+ ---
+
+ ## Environment Setup
+
+ Create and activate a virtual environment:
+
+ **Linux / macOS**
+ ```bash
+ python3 -m venv venv
+ source venv/bin/activate
+ ```
+
+ **Windows**
+ ```bash
+ python -m venv venv
+ venv\Scripts\activate
+ ```
+
+ Install dependencies:
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ Experiments were run with **PyTorch 2.5.1** on a single **NVIDIA RTX 4090**.
+
+ ---
+
+ ## Dataset
+
+ Download the processed datasets from Figshare:
+
+ https://figshare.com/s/7617530bce306584fe95?file=62576929
+
+ Place all dataset files **directly** inside the `data/` folder (do **not** create subfolders):
+
+ ```
+ SpikF-GO/
+ ├── data/
+ │   ├── dataset_file_1
+ │   ├── dataset_file_2
+ │   └── ...
+ ├── model/
+ ├── scripts/
+ └── train.py
+ ```
+
+ ---
+
+ ## Run Experiments
+
+ Scripts are in `scripts/`, one per dataset:
+
+ ```bash
+ bash scripts/ecg.sh
+ bash scripts/covid.sh
+ bash scripts/solar.sh
+ bash scripts/ecl.sh
+ bash scripts/traffic.sh
+ bash scripts/metr_la.sh
+ bash scripts/pems_bay.sh
+ bash scripts/wiki.sh
+ ```
+
+ Each script sets the exact hyperparameters used to produce the results reported in the paper.
+
+ ---
+
+ ## Supplementary Material
+
+ Available at [`assets/supplementary.pdf`](assets/supplementary.pdf).
+
+ ---
+
+ ## Citation
+
+ If you use this code or build on SpikF-GO, please cite our paper:
+
+ **arXiv preprint:**
+ ```bibtex
+ @misc{bakhshaliyev2026spikfgo,
+   title         = {SpikF-GO: Spiking Fourier Graph Operators for Multivariate Time Series Forecasting},
+   author        = {Bakhshaliyev, Jafar and Landwehr, Niels},
+   year          = {2026},
+   eprint        = {2606.13901},
+   archivePrefix = {arXiv},
+   primaryClass  = {cs.LG},
+   url           = {https://arxiv.org/abs/2606.13901}
+ }
+ ```
+
+ **ECML PKDD 2026 proceedings:**
+ ```bibtex
+ @inproceedings{bakhshaliyev2026spikfgo,
+   title     = {SpikF-GO: Spiking Fourier Graph Operators for Multivariate Time Series Forecasting},
+   author    = {Bakhshaliyev, Jafar and Landwehr, Niels},
+   booktitle = {ECML PKDD},
+   year      = {2026}
+ }
+ ```
+
+ See [`CITATION.cff`](CITATION.cff) for full citation metadata.
+
+ ---
+
+ ## Acknowledgements
+
+ The baselines in `model/` build on prior work. We thank the authors for releasing their code; original licenses are respected.
+
+ - **`SpikF.py`**: adapted from **SpikF** (Wu, Huo & Chen, *"SpikF: Spiking Fourier Network for Efficient Long-term Prediction"*, [ICML 2025 / PMLR v267](https://proceedings.mlr.press/v267/wu25m.html)).
+ - **`TS_Former.py`, `TS_GRU.py`, `TS_TCN.py`**: adapted from **TS-LIF** (Feng et al., *"TS-LIF: A Temporal Segment Spiking Neuron Network for Time Series Forecasting"*, [arXiv:2503.05108](https://arxiv.org/abs/2503.05108)).
+ - **`iSpikformer.py`, `SpikeGRU.py`**: adapted from **SeqSNN** (Lv et al., *"Efficient and Effective Time-Series Forecasting with Spiking Neural Networks"*, [arXiv:2402.01533](https://arxiv.org/abs/2402.01533)), [microsoft/SeqSNN](https://github.com/microsoft/SeqSNN).
+ - **`SpikeRNN_CPG.py`, `SpikeTCN_CPG.py`, `Spikformer_CPG.py`**: CPG variants build on [arXiv:2405.14362](https://arxiv.org/abs/2405.14362) / [microsoft/SeqSNN](https://github.com/microsoft/SeqSNN).
+ - **`FourierGNN.py`**: adapted from **FourierGNN**, [arXiv:2311.06190](https://arxiv.org/abs/2311.06190) / [aikunyi/FourierGNN](https://github.com/aikunyi/FourierGNN).
+
+ ---
+
+ ## License
+
+ This project is released under the [MIT License](LICENSE).
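
The README's Dataset section asks for all dataset files to sit directly inside `data/`, with no subfolders. A minimal sketch of a pre-flight check for that layout follows; `find_subfolders` is a hypothetical helper and is not part of the repository, and the temporary directory only stands in for a real checkout.

```python
from pathlib import Path
import tempfile


def find_subfolders(data_dir):
    """Return the names of any subdirectories inside data_dir.

    Hypothetical helper: an empty list means the flat layout is respected.
    """
    return sorted(p.name for p in Path(data_dir).iterdir() if p.is_dir())


# Demonstration on a throwaway directory (stand-in for SpikF-GO/data/).
with tempfile.TemporaryDirectory() as tmp:
    data_dir = Path(tmp) / "data"
    data_dir.mkdir()
    (data_dir / "dataset_file_1").touch()   # a file placed directly in data/
    (data_dir / "nested").mkdir()           # a subfolder the README warns against
    nested = find_subfolders(data_dir)

print(nested)
```

Running such a check before the scripts in `scripts/` would surface a misplaced archive folder early, rather than as a file-not-found error inside a loader.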
assets/architecture.pdf ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5b4851b633d30ef57fa79e8c1096ad660b31e8a9dad9376e7840b39ff3adc0a4
+ size 194717
assets/spikf-go-architecture.png ADDED

Git LFS Details

  • SHA256: b040e6768b43142c925d80b2d60394e29ac526aa3f592b81a020e438feb7a0c6
  • Pointer size: 131 Bytes
  • Size of remote file: 179 kB
assets/supplementary.pdf ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:353d376afbe7e500c8af0a7cf21f927a95eeac9feb781a324991c5b2b063dc4e
+ size 219947
data/.gitkeep ADDED
File without changes
data/data_loader.py ADDED
@@ -0,0 +1,243 @@
+ import os
+ import datetime
+ import numpy as np
+ import pandas as pd
+
+ from torch.utils.data import Dataset
+ from sklearn.preprocessing import StandardScaler
+
+
+
+ def _split_with_overlap(data: np.ndarray, train_ratio: float, val_ratio: float, seq_len: int):
+     """
+     Time split with overlap for val/test to allow past context:
+       train: [0 : train_end)
+       val  : [train_end - seq_len : val_end)
+       test : [val_end - seq_len : T)
+     """
+     T = len(data)
+     train_end = int(T * train_ratio)
+     val_end = int(T * (train_ratio + val_ratio))
+
+     train_end = max(0, min(train_end, T))
+     val_end = max(train_end, min(val_end, T))
+
+     val_start = max(0, train_end - seq_len)
+     test_start = max(0, val_end - seq_len)
+
+     train_data = data[:train_end]
+     val_data = data[val_start:val_end]
+     test_data = data[test_start:]
+
+     return train_data, val_data, test_data
+
+
+ def _fit_transform_splits(train_data, val_data, test_data, type_flag: str, scaler=None):
+     if type_flag == "1":
+         if scaler is None:
+             scaler = StandardScaler()
+             scaler.fit(train_data)
+         train_data = scaler.transform(train_data)
+         val_data = scaler.transform(val_data)
+         test_data = scaler.transform(test_data)
+         return train_data, val_data, test_data, scaler
+     else:
+         return train_data, val_data, test_data, None
+
+
+ def _to_float32(x: np.ndarray) -> np.ndarray:
+     return np.asarray(x, dtype=np.float32)
+
+
+ def _clean_numeric_csv(df: pd.DataFrame) -> np.ndarray:
+     """
+     Keep only numeric columns, and drop common junk index columns.
+     """
+     drop_cols = [c for c in df.columns if str(c).lower().startswith("unnamed")]
+     if drop_cols:
+         df = df.drop(columns=drop_cols, errors="ignore")
+
+     num_df = df.select_dtypes(include=[np.number])
+
+     if num_df.shape[1] == 0:
+         raise ValueError("No numeric columns found in CSV after cleaning. Check your file format.")
+
+     num_df = num_df.dropna(axis=0, how="any")
+
+     return num_df.values.astype(np.float32)
+
+
+
+ class _BaseTimeSeriesDataset(Dataset):
+
+     def __init__(self, flag, seq_len, pre_len):
+         assert flag in ["train", "val", "test"]
+         self.flag = flag
+         self.seq_len = int(seq_len)
+         self.pre_len = int(pre_len)
+         self.scaler = None
+         self.split = None
+
+     def __getitem__(self, index):
+         s_begin = index
+         s_end = s_begin + self.seq_len
+         r_end = s_end + self.pre_len
+
+         x = self.split[s_begin:s_end]
+         y = self.split[s_end:r_end]
+         return x, y
+
+     def __len__(self):
+         if self.split is None:
+             return 0
+         return max(0, len(self.split) - self.seq_len - self.pre_len)
+
+
+ class Dataset_Dhfm(_BaseTimeSeriesDataset):
+     def __init__(self, root_path, flag, seq_len, pre_len, type, train_ratio, val_ratio, scaler=None):
+         super().__init__(flag, seq_len, pre_len)
+         self.path = root_path
+
+         load_data = np.load(root_path)
+         data = np.array(load_data).transpose()
+         data = _to_float32(data)
+
+         train_data, val_data, test_data = _split_with_overlap(data, train_ratio, val_ratio, self.seq_len)
+         train_data, val_data, test_data, self.scaler = _fit_transform_splits(train_data, val_data, test_data, type, scaler)
+
+         if self.flag == "train":
+             self.split = train_data
+         elif self.flag == "val":
+             self.split = val_data
+         else:
+             self.split = test_data
+
+
+ class Dataset_ECG(_BaseTimeSeriesDataset):
+     def __init__(self, root_path, flag, seq_len, pre_len, type, train_ratio, val_ratio, scaler=None):
+         super().__init__(flag, seq_len, pre_len)
+         self.path = root_path
+
+         df = pd.read_csv(root_path)
+         data = _clean_numeric_csv(df)
+
+         train_data, val_data, test_data = _split_with_overlap(data, train_ratio, val_ratio, self.seq_len)
+         train_data, val_data, test_data, self.scaler = _fit_transform_splits(train_data, val_data, test_data, type, scaler)
+
+         if self.flag == "train":
+             self.split = train_data
+         elif self.flag == "val":
+             self.split = val_data
+         else:
+             self.split = test_data
+
+ class Dataset_Solar(_BaseTimeSeriesDataset):
+     def __init__(self, root_path, flag, seq_len, pre_len, type, train_ratio, val_ratio, scaler=None):
+         super().__init__(flag, seq_len, pre_len)
+         self.path = root_path
+
+         files = os.listdir(root_path)
+         solar_data = []
+         time_data = None
+
+         for file in files:
+             full = os.path.join(root_path, file)
+             if os.path.isdir(full):
+                 continue
+             if file.startswith("DA_"):
+                 arr = pd.read_csv(full).values
+                 raw_time = arr[:, 0:1]
+                 if time_data is None:
+                     time_data = raw_time
+                 raw_data = arr[:, 1:arr.shape[1]]
+                 raw_data = raw_data.transpose()
+                 solar_data.append(raw_data)
+
+         if len(solar_data) == 0 or time_data is None:
+             raise ValueError(f"No solar files found in {root_path} with prefix 'DA_'.")
+
+         solar_data = np.array(solar_data).squeeze(1).transpose()  # (T, N)
+         time_data = np.array(time_data)  # (T, 1)
+         out = np.concatenate((time_data, solar_data), axis=1)  # (T, 1+N)
+
+         filtered = []
+         for item in out:
+             dt = datetime.datetime.strptime(item[0], "%m/%d/%y %H:%M")
+             if 8 <= dt.hour <= 17:
+                 filtered.append(item[1:out.shape[1]-1])
+
+         data = _to_float32(np.array(filtered))
+
+         train_data, val_data, test_data = _split_with_overlap(data, train_ratio, val_ratio, self.seq_len)
+         train_data, val_data, test_data, self.scaler = _fit_transform_splits(train_data, val_data, test_data, type, scaler)
+
+         if self.flag == "train":
+             self.split = train_data
+         elif self.flag == "val":
+             self.split = val_data
+         else:
+             self.split = test_data
+
+
+ class Dataset_Wiki(_BaseTimeSeriesDataset):
+     def __init__(self, root_path, flag, seq_len, pre_len, type, train_ratio, val_ratio, scaler=None):
+         super().__init__(flag, seq_len, pre_len)
+         self.path = root_path
+
+         df = pd.read_csv(root_path)
+
+         if df.shape[1] < 2:
+             raise ValueError("Wiki CSV must have at least 2 columns (time + features).")
+
+         df_feat = df.iloc[:, 1:]
+         data = _clean_numeric_csv(df_feat)
+
+         train_data, val_data, test_data = _split_with_overlap(data, train_ratio, val_ratio, self.seq_len)
+         train_data, val_data, test_data, self.scaler = _fit_transform_splits(train_data, val_data, test_data, type, scaler)
+
+         if self.flag == "train":
+             self.split = train_data
+         elif self.flag == "val":
+             self.split = val_data
+         else:
+             self.split = test_data
+
+
+
+ class Dataset_PEMS_BAY(_BaseTimeSeriesDataset):
+     def __init__(self, root_path, flag, seq_len, pre_len, type, train_ratio, val_ratio, scaler=None, fillna="ffill"):
+         super().__init__(flag, seq_len, pre_len)
+         self.path = root_path
+
+         obj = pd.read_hdf(root_path)
+
+         if isinstance(obj, pd.Series):
+             df = obj.to_frame()
+         elif isinstance(obj, pd.DataFrame):
+             df = obj
+         else:
+             df = pd.DataFrame(obj)
+
+         if fillna == "ffill":
+             df = df.ffill()
+             df = df.fillna(0.0)
+         elif fillna == "zero":
+             df = df.fillna(0.0)
+         elif fillna == "drop":
+             df = df.dropna(axis=0, how="any")
+         elif fillna is None:
+             pass
+         else:
+             raise ValueError("fillna must be one of: 'ffill', 'zero', 'drop', or None")
+
+         data = df.values.astype(np.float32)
+
+         train_data, val_data, test_data = _split_with_overlap(data, train_ratio, val_ratio, self.seq_len)
+         train_data, val_data, test_data, self.scaler = _fit_transform_splits(train_data, val_data, test_data, type, scaler)
+
+         if self.flag == "train":
+             self.split = train_data
+         elif self.flag == "val":
+             self.split = val_data
+         else:
+             self.split = test_data
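
The overlap split documented in `_split_with_overlap` and the window count in `_BaseTimeSeriesDataset.__len__` can be illustrated on plain indices. The sketch below mirrors that boundary arithmetic without NumPy; the function names `split_bounds` and `num_windows`, and the ratios 0.5 / 0.25, are illustrative assumptions chosen so the float products are exact, not values taken from the repository's scripts.

```python
def split_bounds(n_steps, train_ratio, val_ratio, seq_len):
    """Return (start, end) index pairs for train/val/test.

    Val and test start seq_len steps early so their first window
    has a full look-back context.
    """
    train_end = int(n_steps * train_ratio)
    val_end = int(n_steps * (train_ratio + val_ratio))
    train_end = max(0, min(train_end, n_steps))
    val_end = max(train_end, min(val_end, n_steps))
    val_start = max(0, train_end - seq_len)
    test_start = max(0, val_end - seq_len)
    return (0, train_end), (val_start, val_end), (test_start, n_steps)


def num_windows(split_len, seq_len, pre_len):
    """Number of (input, target) windows, using the same formula as __len__."""
    return max(0, split_len - seq_len - pre_len)


train, val, test = split_bounds(n_steps=1000, train_ratio=0.5, val_ratio=0.25, seq_len=12)
val_windows = num_windows(val[1] - val[0], seq_len=12, pre_len=12)
print(train, val, test, val_windows)
```

Because ratios such as 0.7 + 0.2 are not exactly representable in floating point, `int(T * (train_ratio + val_ratio))` can land one step below the nominal boundary; the clamping lines keep the result valid either way.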
model/FourierGNN.py ADDED
@@ -0,0 +1,168 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ class FGN(nn.Module):
+     def __init__(self, args, pre_length, embed_size,
+                  feature_size, seq_length, hidden_size, hard_thresholding_fraction=1, hidden_size_factor=1, sparsity_threshold=0.01):
+         super().__init__()
+         self.embed_size = embed_size
+         self.hidden_size = hidden_size
+         self.number_frequency = 1
+         self.pre_length = pre_length
+         self.feature_size = feature_size
+         self.seq_length = seq_length
+         self.frequency_size = self.embed_size // self.number_frequency
+         self.hidden_size_factor = hidden_size_factor
+         self.sparsity_threshold = sparsity_threshold
+         self.hard_thresholding_fraction = hard_thresholding_fraction
+         self.scale = 0.02
+         self.embeddings = nn.Parameter(torch.randn(1, self.embed_size))
+         self.args = args
+
+         self.w1 = nn.Parameter(
+             self.scale * torch.randn(2, self.frequency_size, self.frequency_size * self.hidden_size_factor))
+         self.b1 = nn.Parameter(self.scale * torch.randn(2, self.frequency_size * self.hidden_size_factor))
+         self.w2 = nn.Parameter(
+             self.scale * torch.randn(2, self.frequency_size * self.hidden_size_factor, self.frequency_size))
+         self.b2 = nn.Parameter(self.scale * torch.randn(2, self.frequency_size))
+         self.w3 = nn.Parameter(
+             self.scale * torch.randn(2, self.frequency_size,
+                                      self.frequency_size * self.hidden_size_factor))
+         self.b3 = nn.Parameter(
+             self.scale * torch.randn(2, self.frequency_size * self.hidden_size_factor))
+         self.embeddings_10 = nn.Parameter(torch.randn(self.seq_length, 8))
+         self.fc = nn.Sequential(
+             nn.Linear(self.embed_size * 8, 64),
+             nn.LeakyReLU(),
+             nn.Linear(64, self.hidden_size),
+             nn.LeakyReLU(),
+             nn.Linear(self.hidden_size, self.pre_length)
+         )
+         self.to('cuda:0')
+
+     def tokenEmb(self, x):
+         x = x.unsqueeze(2)
+         y = self.embeddings
+         return x * y
+
+     # FourierGNN
+     def fourierGC(self, x, B, N, L):
+         o1_real = torch.zeros([B, (N*L)//2 + 1, self.frequency_size * self.hidden_size_factor],
+                               device=x.device)
+         o1_imag = torch.zeros([B, (N*L)//2 + 1, self.frequency_size * self.hidden_size_factor],
+                               device=x.device)
+         o2_real = torch.zeros(x.shape, device=x.device)
+         o2_imag = torch.zeros(x.shape, device=x.device)
+
+         o3_real = torch.zeros(x.shape, device=x.device)
+         o3_imag = torch.zeros(x.shape, device=x.device)
+
+         o1_real = F.relu(
+             torch.einsum('bli,ii->bli', x.real, self.w1[0]) - \
+             torch.einsum('bli,ii->bli', x.imag, self.w1[1]) + \
+             self.b1[0]
+         )
+
+         o1_imag = F.relu(
+             torch.einsum('bli,ii->bli', x.imag, self.w1[0]) + \
+             torch.einsum('bli,ii->bli', x.real, self.w1[1]) + \
+             self.b1[1]
+         )
+
+         # 1 layer
+         y = torch.stack([o1_real, o1_imag], dim=-1)
+         y = F.softshrink(y, lambd=self.sparsity_threshold)
+
+         o2_real = F.relu(
+             torch.einsum('bli,ii->bli', o1_real, self.w2[0]) - \
+             torch.einsum('bli,ii->bli', o1_imag, self.w2[1]) + \
+             self.b2[0]
+         )
+
+         o2_imag = F.relu(
+             torch.einsum('bli,ii->bli', o1_imag, self.w2[0]) + \
+             torch.einsum('bli,ii->bli', o1_real, self.w2[1]) + \
+             self.b2[1]
+         )
+
+         # 2 layer
+         x = torch.stack([o2_real, o2_imag], dim=-1)
+         x = F.softshrink(x, lambd=self.sparsity_threshold)
+         x = x + y
+
+         o3_real = F.relu(
+             torch.einsum('bli,ii->bli', o2_real, self.w3[0]) - \
+             torch.einsum('bli,ii->bli', o2_imag, self.w3[1]) + \
+             self.b3[0]
+         )
+
+         o3_imag = F.relu(
+             torch.einsum('bli,ii->bli', o2_imag, self.w3[0]) + \
+             torch.einsum('bli,ii->bli', o2_real, self.w3[1]) + \
+             self.b3[1]
+         )
+
+         # 3 layer
+         z = torch.stack([o3_real, o3_imag], dim=-1)
+         z = F.softshrink(z, lambd=self.sparsity_threshold)
+         z = z + x
+         z = torch.view_as_complex(z)
+         return z
+
+     def forward(self, x):
+
+         if self.args.normalize:
+
+             mean = x.mean(dim=1, keepdim=True).detach()
+             x = x - mean
+
+             std = torch.sqrt(torch.var(x, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()
+             x = x / std
+
+
+         x = x.permute(0, 2, 1).contiguous()
+         B, N, L = x.shape
+         # B*N*L ==> B*NL
+         x = x.reshape(B, -1)
+         # embedding B*NL ==> B*NL*D
+         x = self.tokenEmb(x)
+
+
+         # FFT B*NL*D ==> B*NT/2*D
+         x = torch.fft.rfft(x, dim=1, norm='ortho')
+
+         x = x.reshape(B, (N*L)//2+1, self.frequency_size)
+
+         bias = x
+
+         # FourierGNN
+         x = self.fourierGC(x, B, N, L)
+
+         x = x + bias
+
+         x = x.reshape(B, (N*L)//2+1, self.embed_size)
+
+         # ifft
+         x = torch.fft.irfft(x, n=N*L, dim=1, norm="ortho")
+
+         x = x.reshape(B, N, L, self.embed_size)
+         x = x.permute(0, 1, 3, 2)  # B, N, D, L
+
+         # projection
+         x = torch.matmul(x, self.embeddings_10)
+         x = x.reshape(B, N, -1)
+         x = self.fc(x)
+         x = x.permute(0, 2, 1)
+
+         if self.args.normalize:
+             x = x * std
+             x = x + mean
+
+
+         aux = {
+             'gate_l0': torch.tensor(0.0, device=x.device)  # placeholder
+         }
+
+         return x, aux
+
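
Each real/imaginary pair in `fourierGC` above has the shape of a complex multiply-add applied before the ReLU: the real part takes `x.real * w[0] - x.imag * w[1] + b[0]` and the imaginary part takes `x.imag * w[0] + x.real * w[1] + b[1]`. Reading the repeated `ii` subscript under standard einsum semantics, only the diagonal of each weight matrix is used, so the update acts per channel. The sketch below checks that equivalence for a single channel with Python's built-in `complex`; the numeric values are arbitrary and the ReLU and soft-shrink steps are deliberately left out.

```python
def complex_affine(x, w, b):
    """Reference result: x * w + b using Python's built-in complex arithmetic."""
    return x * w + b


def split_affine(x_re, x_im, w_re, w_im, b_re, b_im):
    """The same update written on separate real and imaginary parts,
    following the pattern of the o*_real / o*_imag expressions (ReLU omitted)."""
    out_re = x_re * w_re - x_im * w_im + b_re
    out_im = x_im * w_re + x_re * w_im + b_im
    return out_re, out_im


# Arbitrary illustrative values for one frequency bin and one channel.
x = complex(1.5, -2.0)
w = complex(0.25, 0.5)   # plays the role of (w[0][i, i], w[1][i, i])
b = complex(0.1, -0.3)   # plays the role of (b[0][i], b[1][i])

ref = complex_affine(x, w, b)
re, im = split_affine(x.real, x.imag, w.real, w.imag, b.real, b.imag)
print(ref, (re, im))
```

Note that once the ReLU is applied to the real and imaginary parts separately, the layer is no longer a pure complex-linear map; the sketch only covers the affine part.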
model/SpikF.py ADDED
@@ -0,0 +1,151 @@
+ import torch
+ from torch import nn
+ from spikingjelly.clock_driven.neuron import MultiStepLIFNode
+
+ class SPE(nn.Module):
+     def __init__(self, input_len, patch_num, patch_dim, T, tau, D):
+         super().__init__()
+         self.patch_projector = nn.Linear(input_len // patch_num, patch_dim)
+         self.bn = nn.BatchNorm2d(patch_dim)
+         self.encoder_lif = MultiStepLIFNode(tau=tau, detach_reset=False, backend='torch')
+
+         self.D = D
+         self.T = T
+         self.patch_dim = patch_dim
+         self.patch_num = patch_num
+
+     def forward(self, x):
+         B, L, D = x.shape
+
+         x = x.view(B, self.patch_num, L // self.patch_num, D).contiguous()
+         x = x.transpose(-1, -2).contiguous()
+         x = self.patch_projector(x)
+         x = x.repeat(self.T, 1, 1, 1, 1)
+         x = x.permute(0, 1, 4, 2, 3).contiguous()
+         x = x.flatten(0, 1)
+         x = self.bn(x)
+         x = x.view(self.T, B, self.patch_dim, self.patch_num, D)
+         x = self.encoder_lif(x)
+
+         return x
+
+ class SFS(nn.Module):
+     def __init__(self, patch_num, D, patch_dim, tau, alpha):
+         super().__init__()
+         self.time2freq = nn.Linear(patch_num, patch_num // 2 + 1)
+
+         self.intra_conv = nn.Conv2d(in_channels=patch_dim, out_channels=patch_dim, kernel_size=[5, 1], stride=[1, 1], padding=[2, 0])
+         self.inter_conv = nn.Conv2d(in_channels=patch_dim, out_channels=patch_dim, kernel_size=[3, 1], stride=[1, 1], padding=[1, 0])
+
+         self.generator_lif = MultiStepLIFNode(tau=tau, detach_reset=True, backend='torch', v_threshold=0.1)
+         self.mp_lif = MultiStepLIFNode(tau=tau, detach_reset=True, backend='torch')
+         self.sfs_lif = MultiStepLIFNode(tau=tau, detach_reset=True, backend='torch')
+         self.intra_lif = MultiStepLIFNode(tau=tau, detach_reset=True, backend='torch')
+         self.inter_lif = MultiStepLIFNode(tau=tau, detach_reset=True, backend='torch')
+
+         self.bn1 = nn.BatchNorm2d(patch_dim)
+         self.bn2 = nn.BatchNorm2d(patch_dim)
+         self.bn3 = nn.BatchNorm2d(patch_dim)
+         self.bn4 = nn.BatchNorm2d(patch_dim)
+
+     def forward(self, x):
+         res_x = x
+         T, B, pd, pn, D = x.shape
+
+         x = x.transpose(-1, -2).contiguous()
+         freq_spec = torch.fft.rfft(x)
+
+         selector = self.time2freq(x)
+         selector = selector.flatten(0, 1)
+         selector = self.bn1(selector)
+         selector = selector.view(T, B, pd, D, -1)
+         selector = self.generator_lif(selector)
+         selector = selector.sum(dim=0, keepdim=True)
+         selector = self.mp_lif(selector)
+         selector = selector.repeat(T, 1, 1, 1, 1).float()
+         selector_imag = torch.zeros(selector.size()).to(x.device)
+         selector = torch.complex(selector, selector_imag).to(x.device)
+
+         remain_freq = selector * freq_spec
+
+         current = torch.fft.irfft(remain_freq)
+         current = current.transpose(-1, -2).contiguous()
+         current = current.flatten(0, 1)
+         current = self.bn2(current)
+         current = current.view(T, B, pd, pn, D)
+
+         spike = self.sfs_lif(current)
+         x = spike + res_x
+         res_x = x
+
+         x = x.flatten(0, 1)
+         x = self.intra_conv(x)
+         x = self.bn3(x)
+         x = x.view(T, B, pd, pn, D)
+         x = self.intra_lif(x) + res_x
+         res_x = x
+
+         x = x.transpose(0, 3).contiguous()
+         x = x.flatten(0, 1)
+         x = self.inter_conv(x)
+         x = self.bn4(x)
+         x = x.view(pn, B, pd, T, D)
+         x = x.transpose(0, 3)
+         x = self.inter_lif(x)
+         x = x + res_x
+
+         return x
+
+ class SpikF(nn.Module):
+     def __init__(self, args, input_len, patch_num, patch_dim, T, blocks, D, pred_len, tau, alpha, hidden_dim):
+         super().__init__()
+         self.SPE = SPE(input_len, patch_num, patch_dim, T, tau, D)
+         self.args = args
+
+         self.SFSs = nn.ModuleList()
+         for i in range(blocks):
+             self.SFSs.append(SFS(patch_num, D, patch_dim, tau, alpha))
+
+         self.dense1 = nn.Linear(patch_num * patch_dim, hidden_dim)
+         self.dense2 = nn.Linear(hidden_dim, pred_len)
+
+         self.bn = nn.BatchNorm1d(D)
+
+         self.activ = nn.GELU()
+
+     def forward(self, x):
+
+         if self.args.normalize:
+
+             mean = x.mean(dim=1, keepdim=True).detach()
+             x = x - mean
+
+             std = torch.sqrt(torch.var(x, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()
+             x = x / std
+
+         x = self.SPE(x)
+         T, B, pd, pn, D = x.shape
+
+         for i in range(len(self.SFSs)):
+             x = self.SFSs[i](x)
+
+         x = x.permute(0, 1, 4, 2, 3).contiguous()
+         x = x.flatten(-2, -1)
+         x = self.dense1(x)
+         x = x.flatten(0, 1)
+         x = self.bn(x)
+         x = self.activ(x)
+         x = self.dense2(x)
+         x = x.transpose(-1, -2).contiguous()
+         x = x.view(T, B, -1, D)
+
+         if self.args.normalize:
+             x = x * std
+             x = x + mean.repeat(T, 1, 1, 1)
+
+         aux = {
+             'gate_l0': torch.tensor(0.0, device=x.device)  # placeholder
+         }
+
+
+         return x, aux
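
The `MultiStepLIFNode` layers in `SpikF.py` above turn real-valued currents into binary spikes over `T` simulation steps. As a rough intuition aid, the sketch below simulates a single textbook leaky integrate-and-fire neuron with a hard reset in plain Python. It is an assumption-laden illustration: the function `lif_run` is hypothetical, the defaults (`tau=2.0`, `v_threshold=1.0`, `v_reset=0.0`) are chosen for the example, and the update inside SpikingJelly may differ in details such as reset mode and surrogate gradients.

```python
def lif_run(inputs, tau=2.0, v_threshold=1.0, v_reset=0.0):
    """Simulate one leaky integrate-and-fire neuron over a list of input currents.

    Charge: v <- v + (x - (v - v_reset)) / tau
    Fire:   emit 1 if v >= v_threshold, else 0
    Reset:  v <- v_reset after a spike (hard reset)
    """
    v = v_reset
    spikes = []
    for x in inputs:
        v = v + (x - (v - v_reset)) / tau   # leaky integration of the input
        if v >= v_threshold:
            spikes.append(1)
            v = v_reset                      # hard reset after firing
        else:
            spikes.append(0)
    return spikes


# A constant current of 1.5 fires on the second step; a later pulse fires again.
spikes = lif_run([1.5, 1.5, 1.5, 0.0, 3.0], tau=2.0)
print(spikes)
```

Lowering `v_threshold`, as `generator_lif` does with `v_threshold=0.1`, makes the neuron fire on much weaker inputs, which is consistent with its use as a frequency selector.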
model/SpikF_GO.py ADDED
@@ -0,0 +1,445 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Dict, Tuple
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from torch.nn.utils import weight_norm
8
+
9
+ from spikingjelly.clock_driven.neuron import MultiStepLIFNode
10
+ from spikingjelly.activation_based import surrogate
11
+
12
+
13
+
14
+
15
+ class Affine(nn.Module):
16
+ def __init__(self, D: int):
17
+ super().__init__()
18
+ self.gamma = nn.Parameter(torch.ones(D))
19
+ self.beta = nn.Parameter(torch.zeros(D))
20
+
21
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
22
+ return x * self.gamma + self.beta
23
+
24
+
25
+
26
+ class RMSNorm(nn.Module):
27
+ """
28
+ tok: [B, M, E]
29
+ Normalize over M per sample and per channel, then apply a learned affine.
30
+ """
31
+ def __init__(self, E: int, eps: float = 1e-6):
32
+ super().__init__()
33
+ self.eps = eps
34
+ self.affine = Affine(E)
35
+
36
+ def forward(self, tok: torch.Tensor) -> torch.Tensor:
37
+ rms = torch.rsqrt(tok.pow(2).mean(dim=1, keepdim=True) + self.eps) # [B,1,E]
38
+ y = tok * rms
39
+ y = self.affine(y)
40
+ return y
41
+
42
+
43
+
44
+ class SFFT(nn.Module):
45
+ """
46
+ S-FFT: implementing FFT on GPU; for theoretical information (spiking FFT),
47
+ refer to our paper and the SpikF paper.
48
+ """
49
+ def __init__(self, M: int):
50
+ super().__init__()
51
+ self.M = M
52
+ self.F = M // 2 + 1
53
+
54
+ def rfft(self, s_t: torch.Tensor) -> torch.Tensor:
55
+ T, B, M, E = s_t.shape
56
+ x = s_t.permute(0, 1, 3, 2).contiguous().view(T * B * E, M) # [T*B*E, M]
57
+ Z = torch.fft.rfft(x, n=self.M, dim=-1, norm="ortho") # [T*B*E, F] complex
58
+ Z = Z.view(T, B, E, self.F).permute(0, 1, 3, 2).contiguous() # [T,B,F,E]
59
+ return Z
60
+
61
+ def irfft(self, Z_t: torch.Tensor) -> torch.Tensor:
62
+ T, B, Freq, E = Z_t.shape
63
+ x = Z_t.permute(0, 1, 3, 2).contiguous().view(T * B * E, Freq) # [T*B*E, F]
64
+ y = torch.fft.irfft(x, n=self.M, dim=-1, norm="ortho") # [T*B*E, M]
65
+ y = y.view(T, B, E, self.M).permute(0, 1, 3, 2).contiguous() # [T,B,M,E]
66
+ return y
67
+
68
+
69
+
70
+ class HardConcreteGate(nn.Module):
71
+ """
72
+ Gate over frequency bins.
73
+ Z: [T,B,F,E]
74
+ mask m: [1,1,F,1] in [0,1]
75
+ """
76
+ def __init__(self, F_bins: int, init_logit: float = 2.0, eps: float = 1e-6):
77
+ super().__init__()
78
+ self.log_alpha = nn.Parameter(torch.full((F_bins,), float(init_logit)))
79
+ self.eps = eps
80
+
81
+ def _sample_u(self, shape, device):
82
+ return torch.empty(shape, device=device).uniform_(self.eps, 1.0 - self.eps)
83
+
84
+ def _hard_concrete(self, training: bool, device, tau: float):
85
+ if training:
86
+ u = self._sample_u(self.log_alpha.shape, device)
87
+ s = torch.sigmoid((torch.log(u) - torch.log(1 - u) + self.log_alpha) / tau)
88
+ else:
89
+ s = torch.sigmoid(self.log_alpha)
90
+ s_bar = s * 1.2 - 0.1
91
+ return s_bar.clamp(0.0, 1.0)
92
+
93
+ def forward(self, Z: torch.Tensor, tau: float) -> Tuple[torch.Tensor, torch.Tensor]:
94
+ m = self._hard_concrete(self.training, Z.device, tau=tau) # [F]
95
+ m = m.view(1, 1, -1, 1).to(Z.real.dtype) # [1,1,F,1]
96
+ return Z * m, m
97
+
98
+ def l0(self) -> torch.Tensor:
99
+ return torch.sigmoid(self.log_alpha).mean()
100
+
101
+
102
+
103
+
104
+ class ComplexAffine(nn.Module):
105
+ def __init__(self, E: int):
106
+ super().__init__()
107
+ self.gamma_r = nn.Parameter(torch.ones(E))
108
+ self.beta_r = nn.Parameter(torch.zeros(E))
109
+ self.gamma_i = nn.Parameter(torch.ones(E))
110
+ self.beta_i = nn.Parameter(torch.zeros(E))
111
+
112
+ def forward(self, z: torch.Tensor) -> torch.Tensor:
113
+ zr = z.real * self.gamma_r + self.beta_r
114
+ zi = z.imag * self.gamma_i + self.beta_i
115
+ return torch.complex(zr, zi)
116
+
117
+
118
+
119
+ class ComplexLinear(nn.Module):
120
+ def __init__(self, E_in: int, E_out: int, init_scale: float = 0.02):
121
+ super().__init__()
122
+ self.Wr = nn.Parameter(init_scale * torch.randn(E_in, E_out))
123
+ self.Wi = nn.Parameter(init_scale * torch.randn(E_in, E_out))
124
+ self.br = nn.Parameter(torch.zeros(E_out))
125
+ self.bi = nn.Parameter(torch.zeros(E_out))
126
+
127
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
128
+ xr, xi = x.real, x.imag
129
+ yr = xr @ self.Wr - xi @ self.Wi + self.br
130
+ yi = xi @ self.Wr + xr @ self.Wi + self.bi
131
+ return torch.complex(yr, yi)
132
+
133
+
134
+ class ComplexLIFGate(nn.Module):
135
+ def __init__(self, tau: float, v_th: float):
136
+ super().__init__()
137
+ self.lif_r = MultiStepLIFNode(
138
+ tau=tau, v_threshold=v_th, detach_reset=True,
139
+ surrogate_function=surrogate.ATan(alpha=4.0), backend="torch"
140
+ )
141
+ self.lif_i = MultiStepLIFNode(
142
+ tau=tau, v_threshold=v_th, detach_reset=True,
143
+ surrogate_function=surrogate.ATan(alpha=4.0), backend="torch"
144
+ )
145
+
146
+ def forward(self, z: torch.Tensor) -> torch.Tensor:
147
+ s_r = self.lif_r(z.real) # [T,B,F,D], binary spikes in {0,1}
148
+ s_i = self.lif_i(z.imag)
149
+ g = ((s_r > 0) | (s_i > 0)).to(z.real.dtype)
150
+ return g
151
+
152
+
153
+
154
+ class SFGO(nn.Module):
155
+ def __init__(
156
+ self,
157
+ args,
158
+ E: int,
159
+ hidden_size_factor: int,
160
+ tau: float = 2.0,
161
+ v_th: float = 1.0,
162
+ apply_gate_to_complex: bool = True,
163
+ ):
164
+ super().__init__()
165
+ H = int(E * hidden_size_factor)
166
+
167
+ self.args = args
168
+
169
+ self.lin1 = ComplexLinear(E, H)
170
+ self.lin2 = ComplexLinear(H, E)
171
+ self.lin3 = ComplexLinear(E, E)
172
+
173
+ self.g1 = ComplexLIFGate(tau=tau, v_th=v_th)
174
+ self.g2 = ComplexLIFGate(tau=tau, v_th=v_th)
175
+ self.g3 = ComplexLIFGate(tau=tau, v_th=v_th)
176
+
177
+ self.apply_gate_to_complex = apply_gate_to_complex
178
+
179
+ self.r2 = nn.Parameter(torch.tensor(0.1))
180
+ self.r3 = nn.Parameter(torch.tensor(0.1))
181
+
182
+ if self.args.affine:
183
+
184
+ self.a1 = ComplexAffine(E)
185
+ self.a2 = ComplexAffine(H)
186
+ self.a3 = ComplexAffine(E)
187
+
188
+ self.ga1 = ComplexLIFGate(tau=tau, v_th=v_th)
189
+ self.ga2 = ComplexLIFGate(tau=tau, v_th=v_th)
190
+ self.ga3 = ComplexLIFGate(tau=tau, v_th=v_th)
191
+
192
+
193
+ def _apply_gate(self, z: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
194
+ if not self.apply_gate_to_complex:
195
+ return z
196
+ return z * g.to(z.real.dtype)
197
+
198
+ def forward(self, Z: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
199
+ stats: Dict[str, torch.Tensor] = {}
200
+
201
+
202
+ if self.args.affine:
203
+ A1 = self.a1(Z)
204
+ GA1 = self.ga1(A1)
205
+ A1 = self._apply_gate(A1, GA1)
206
+ else:
207
+ A1 = Z
208
+
209
+ Y = self.lin1(A1)
210
+ G1 = self.g1(Y)
211
+ Y = self._apply_gate(Y, G1)
212
+
213
+ if self.args.affine:
214
+ A2 = self.a2(Y)
215
+ GA2 = self.ga2(A2)
216
+ A2 = self._apply_gate(A2, GA2)
217
+ else:
218
+ A2 = Y
219
+
220
+ X = self.lin2(A2)
221
+ G2 = self.g2(X)
222
+ X = self._apply_gate(X, G2)
223
+
224
+ Z2 = Z + self.r2 * X
225
+
226
+ if self.args.affine:
227
+ A3 = self.a3(Z2)
228
+ GA3 = self.ga3(A3)
229
+ A3 = self._apply_gate(A3, GA3)
230
+ else:
231
+ A3 = Z2
232
+
233
+
234
+ W = self.lin3(A3)
235
+ G3 = self.g3(W)
236
+ W = self._apply_gate(W, G3)
237
+
238
+ out = Z2 + self.r3 * W
239
+
240
+ with torch.no_grad():
241
+ mag2 = out.real * out.real + out.imag * out.imag
242
+ stats["freq_active_frac"] = (mag2 > 0).float().mean()
243
+
244
+ stats["rezero_r2"] = self.r2.detach()
245
+ stats["rezero_r3"] = self.r3.detach()
246
+
247
+ stats["gate_lin_frac_1"] = G1.mean().detach()
248
+ stats["gate_lin_frac_2"] = G2.mean().detach()
249
+ stats["gate_lin_frac_3"] = G3.mean().detach()
250
+
251
+ return out, stats
252
+
253
+
254
+
255
+ class Decoder(nn.Module):
256
+ def __init__(
257
+ self,
258
+ E: int,
259
+ L: int,
260
+ pred_len: int,
261
+ T: int,
262
+ tau: float,
263
+ v_th: float,
264
+ proj_dim: int = 4,
265
+ reduced_dim: int = 64,
266
+ ):
267
+ super().__init__()
268
+ self.E, self.L, self.P, self.T = E, L, pred_len, T
269
+ self.proj_dim = int(proj_dim)
270
+
271
+ self.time_proj = nn.Linear(L, self.proj_dim, bias=False)
272
+ D_in = E * self.proj_dim
273
+ self.reduced_dim = int(reduced_dim)
274
+
275
+ self.lif = MultiStepLIFNode(
276
+ tau=tau,
277
+ v_threshold=v_th,
278
+ detach_reset=True,
279
+ surrogate_function=surrogate.ATan(alpha=4.0),
280
+ backend="torch",
281
+ )
282
+
283
+ self.fc_reduce = weight_norm(nn.Linear(D_in, int(reduced_dim), bias=True))
284
+ self.fc_out = weight_norm(nn.Linear(int(reduced_dim), pred_len, bias=True))
285
+
286
+ nn.init.xavier_uniform_(self.time_proj.weight, gain=0.5)
287
+ nn.init.xavier_uniform_(self.fc_reduce.weight, gain=0.6)
288
+ nn.init.xavier_uniform_(self.fc_out.weight, gain=0.2)
289
+ nn.init.zeros_(self.fc_reduce.bias)
290
+ nn.init.zeros_(self.fc_out.bias)
291
+
292
+ def forward(self, y_t: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
293
+ T, B, N, E, L = y_t.shape
294
+
295
+ y_p = self.time_proj(y_t) # [T,B,N,E,p]
296
+ x = y_p.reshape(T, B * N, E * self.proj_dim) # [T,B*N,D]
297
+ s = self.lif(x) # [T,B*N,D] spikes
298
+ h_t = self.fc_reduce(s.reshape(T * B * N, -1)).view(T, B * N, self.reduced_dim)
299
+
300
+ h = h_t.mean(dim=0) # [B*N,reduced_dim]
301
+ h = F.gelu(h)
302
+ out = self.fc_out(h) # [B*N,P], P = pred_len
303
+
304
+ preds = out.view(B, N, self.P).permute(0, 2, 1).contiguous()
305
+ stats = {"dec_spike_rate": s.mean().detach()}
306
+ return preds, stats
307
+
308
+
309
+
310
+ class SpikF_GO(nn.Module):
311
+ def __init__(
312
+ self,
313
+ args,
314
+ pre_length: int,
315
+ embed_size: int,
316
+ feature_size: int,
317
+ seq_length: int,
318
+ hidden_size: int,
319
+ hard_thresholding_fraction=1,
320
+ hidden_size_factor: int = 1,
321
+ sparsity_threshold: float = 0.01,
322
+ ):
323
+ super().__init__()
324
+ self.args = args
325
+
326
+ self.N = feature_size
327
+ self.L = seq_length
328
+ self.E = embed_size
329
+ self.T = args.T
330
+ self.M = self.N * self.L
331
+
332
+
333
+ self.embeddings = nn.Parameter(torch.randn(1, self.E) * 0.02)
334
+ self.node_aff = Affine(self.E)
335
+ self.node_rms = RMSNorm(E=self.E, eps=1e-6)
336
+
337
+ # step modulation
338
+ self.step_gamma = nn.Parameter(torch.ones(self.T))
339
+ self.step_beta = nn.Parameter(torch.zeros(self.T))
340
+ self.register_buffer("step_scale", torch.linspace(0, 1, steps=self.T).view(self.T, 1, 1, 1))
341
+
342
+ # Encoder LIF
343
+ self.encoder_lif = MultiStepLIFNode(
344
+ tau=args.tau,
345
+ v_threshold=args.alpha,
346
+ detach_reset=True,
347
+ surrogate_function=surrogate.ATan(alpha=4.0),
348
+ backend="torch",
349
+ )
350
+
351
+ self.sfft = SFFT(self.M)
352
+ self.F_bins = self.sfft.F
353
+
354
+ # frequency gate
355
+ self.freq_gate = HardConcreteGate(self.F_bins, init_logit=2.0)
356
+ self.register_buffer("gate_tau", torch.tensor(0.10))
357
+
358
+ self.sfgo = SFGO(
359
+ self.args,
360
+ E=self.E,
361
+ hidden_size_factor=hidden_size_factor,
362
+ tau=args.tau,
363
+ v_th=args.alpha,
364
+ apply_gate_to_complex=True,
365
+ )
366
+
367
+ # decoder
368
+ proj_dim = self.args.proj_dim
369
+ reduced_dim = max(16, min(128, hidden_size // 4))
370
+ self.decoder = Decoder(
371
+ E=self.E,
372
+ L=self.L,
373
+ pred_len=pre_length,
374
+ T=self.T,
375
+ tau=args.tau,
376
+ v_th=args.alpha,
377
+ proj_dim=proj_dim,
378
+ reduced_dim=reduced_dim,
379
+ )
380
+
381
+ def node_embed(self, x: torch.Tensor) -> torch.Tensor:
382
+ # x: [B,L,N] -> [B,M,E]
383
+ B, L, N = x.shape
384
+ x_flat = x.permute(0, 2, 1).contiguous().reshape(B, self.M) # [B,M]
385
+ tok = x_flat.unsqueeze(-1) * self.embeddings # [B,M,E]
386
+ tok = self.node_aff(tok)
387
+ return tok
388
+
389
+
390
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
391
+ B, L, N = x.shape
392
+
393
+ # normalize
394
+ if self.args.normalize:
395
+ mean = x.mean(dim=1, keepdim=True).detach()
396
+ x0 = x - mean
397
+ std = torch.sqrt(torch.var(x0, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()
398
+ x0 = x0 / std
399
+ else:
400
+ mean, std = None, None
401
+ x0 = x
402
+
403
+
404
+ tok = self.node_embed(x0) # [B,M,E]
405
+ tok = self.node_rms(tok) # RMSNorm
406
+
407
+
408
+ # step modulation
409
+ cur_t = tok.unsqueeze(0).repeat(self.T, 1, 1, 1)
410
+ cur_t = cur_t * self.step_gamma.view(self.T, 1, 1, 1) + self.step_beta.view(self.T, 1, 1, 1)
411
+ cur_t = cur_t * (1.0 + 0.02 * self.step_scale.to(cur_t.dtype))
412
+
413
+
414
+ # spikes
415
+ s_t = self.encoder_lif(cur_t)
416
+ enc_rate = s_t.mean()
417
+
418
+ # FFT
419
+ Z_t = self.sfft.rfft(s_t)
420
+
421
+ # prune
422
+ Z_t, m = self.freq_gate(Z_t, tau=float(self.gate_tau))
423
+
424
+ # S-FGO blocks
425
+ Z_t, fb_stats = self.sfgo(Z_t)
426
+
427
+ # iFFT
428
+ y_time_t = self.sfft.irfft(Z_t).to(tok.dtype)
429
+
430
+ y_t = y_time_t.view(self.T, B, N, self.L, self.E).permute(0, 1, 2, 4, 3).contiguous()
431
+
432
+ preds, dec_stats = self.decoder(y_t)
433
+
434
+ if self.args.normalize:
435
+ preds = preds * std + mean # denormalize
436
+
437
+ aux = {
438
+ "enc_rate": enc_rate.detach(),
439
+ "rho_hat": self.freq_gate.l0().detach(),
440
+ "freq_mask_mean": m.mean().detach(),
441
+ "freq_mask_active": (m > 0.5).float().mean().detach(),
442
+ **fb_stats,
443
+ **dec_stats,
444
+ }
445
+ return preds, aux
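The `HardConcreteGate` in this file stretches a sigmoid to the interval (-0.1, 1.1) and clamps it to [0, 1], so that a frequency bin can be switched fully off or fully on. A minimal scalar sketch of that rule, assuming a single bin and plain Python floats (the function name `hard_concrete_mask` and the explicit `u` argument are illustrative, not part of the commit):

```python
import math

def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))

def hard_concrete_mask(log_alpha, tau=0.1, u=None):
    """Gate value for one bin.

    u=None mimics eval mode (deterministic sigmoid of log_alpha);
    otherwise u in (0, 1) plays the role of the uniform noise drawn in training.
    """
    if u is None:
        s = sigmoid(log_alpha)
    else:
        s = sigmoid((math.log(u) - math.log(1.0 - u) + log_alpha) / tau)
    s_bar = s * 1.2 - 0.1             # stretch to (-0.1, 1.1)
    return min(1.0, max(0.0, s_bar))  # clamp to [0, 1]

# Eval mode: a low logit is pruned exactly, a high logit saturates at 1.
eval_masks = [hard_concrete_mask(a) for a in (-3.0, 2.0, 5.0)]
# Training mode: with a small temperature, an unlucky noise draw closes the gate.
train_mask = hard_concrete_mask(2.0, tau=0.1, u=0.01)
```

The stretch-and-clamp step is what lets the mask reach exact zeros, which a plain sigmoid cannot do.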
model/SpikF_GO_CPG.py ADDED
@@ -0,0 +1,514 @@
1
+ from __future__ import annotations
2
+ from dataclasses import dataclass
3
+ from typing import Tuple, Dict, Optional
4
+ import torch
5
+ import math
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from torch.nn.utils import weight_norm
9
+
10
+ from spikingjelly.clock_driven.neuron import MultiStepLIFNode
11
+ from spikingjelly.activation_based import surrogate
12
+
13
+
14
+
15
+
16
+ class Affine(nn.Module):
17
+ def __init__(self, D: int):
18
+ super().__init__()
19
+ self.gamma = nn.Parameter(torch.ones(D))
20
+ self.beta = nn.Parameter(torch.zeros(D))
21
+
22
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
23
+ return x * self.gamma + self.beta
24
+
25
+
26
+
27
+ class RMSNorm(nn.Module):
28
+ """
29
+ tok: [B, M, E]
30
+ Normalize over M per sample and per channel, then apply a learned affine.
31
+ """
32
+ def __init__(self, E: int, eps: float = 1e-6):
33
+ super().__init__()
34
+ self.eps = eps
35
+ self.affine = Affine(E)
36
+
37
+ def forward(self, tok: torch.Tensor) -> torch.Tensor:
38
+ rms = torch.rsqrt(tok.pow(2).mean(dim=1, keepdim=True) + self.eps) # [B,1,E]
39
+ y = tok * rms
40
+ y = self.affine(y)
41
+ return y
42
+
43
+
44
+
45
+
46
+ class CPGSpikePE(nn.Module):
47
+ """
48
+ Spike-form positional encoding (CPG-PE).
49
+ Generates 2*N_pe binary channels with log-spaced rhythms over the flattened index t in [0, T*M).
50
+ Shapes:
51
+ returns pe: [T, B, M, 2*N_pe] with 0/1 spikes (no learnable params).
52
+ """
53
+ def __init__(self,
54
+ num_pairs: int = 20,
55
+ tau: float = 10000.0,
56
+ eta: float = 1.0,
57
+ vthres: float = 0.8,
58
+ w_max: float = 10000.0):
59
+ super().__init__()
60
+ self.num_pairs = num_pairs
61
+ self.tau = tau
62
+ self.eta = eta
63
+ self.vthres = vthres
64
+ self.w_max = w_max
65
+
66
+ def forward(self, T: int, B: int, M: int, device) -> torch.Tensor:
67
+ t = torch.arange(T * M, device=device, dtype=torch.float32) # [T*M]
68
+ i = torch.arange(self.num_pairs, device=device, dtype=torch.float32)
69
+ freq = torch.exp(-torch.log(torch.tensor(self.w_max, device=device)) * (i / max(1, self.num_pairs))) # [N_pe]
70
+
71
+ arg = self.eta * (t[:, None] * freq[None, :] / self.tau) # [T*M, N_pe]
72
+ cos_spk = (torch.cos(arg) - self.vthres > 0).float()
73
+ sin_spk = (torch.sin(arg) - self.vthres > 0).float()
74
+
75
+ pe = torch.cat([cos_spk, sin_spk], dim=1) # [T*M, 2*N_pe]
76
+ pe = pe.view(T, M, 2 * self.num_pairs).unsqueeze(1) # [T, 1, M, 2*N_pe]
77
+ pe = pe.expand(-1, B, -1, -1).contiguous() # [T, B, M, 2*N_pe]
78
+ return pe
79
+
80
+
81
+
82
+
83
+ class SFFT(nn.Module):
84
+ """
85
+ S-FFT: implementing FFT on GPU; for theoretical information (spiking FFT),
86
+ refer to our paper and the SpikF paper.
87
+ """
88
+ def __init__(self, M: int):
89
+ super().__init__()
90
+ self.M = M
91
+ self.F = M // 2 + 1
92
+
93
+ def rfft(self, s_t: torch.Tensor) -> torch.Tensor:
94
+ T, B, M, E = s_t.shape
95
+ x = s_t.permute(0, 1, 3, 2).contiguous().view(T * B * E, M) # [T*B*E, M]
96
+ Z = torch.fft.rfft(x, n=self.M, dim=-1, norm="ortho") # [T*B*E, F] complex
97
+ Z = Z.view(T, B, E, self.F).permute(0, 1, 3, 2).contiguous() # [T,B,F,E]
98
+ return Z
99
+
100
+ def irfft(self, Z_t: torch.Tensor) -> torch.Tensor:
101
+ T, B, Freq, E = Z_t.shape
102
+ x = Z_t.permute(0, 1, 3, 2).contiguous().view(T * B * E, Freq) # [T*B*E, F]
103
+ y = torch.fft.irfft(x, n=self.M, dim=-1, norm="ortho") # [T*B*E, M]
104
+ y = y.view(T, B, E, self.M).permute(0, 1, 3, 2).contiguous() # [T,B,M,E]
105
+ return y
106
+
107
+
108
+
109
+ class HardConcreteGate(nn.Module):
110
+ """
111
+ Gate over frequency bins.
112
+ Z: [T,B,F,E]
113
+ mask m: [1,1,F,1] in [0,1]
114
+ """
115
+ def __init__(self, F_bins: int, init_logit: float = 2.0, eps: float = 1e-6):
116
+ super().__init__()
117
+ self.log_alpha = nn.Parameter(torch.full((F_bins,), float(init_logit)))
118
+ self.eps = eps
119
+
120
+ def _sample_u(self, shape, device):
121
+ return torch.empty(shape, device=device).uniform_(self.eps, 1.0 - self.eps)
122
+
123
+ def _hard_concrete(self, training: bool, device, tau: float):
124
+ if training:
125
+ u = self._sample_u(self.log_alpha.shape, device)
126
+ s = torch.sigmoid((torch.log(u) - torch.log(1 - u) + self.log_alpha) / tau)
127
+ else:
128
+ s = torch.sigmoid(self.log_alpha)
129
+ s_bar = s * 1.2 - 0.1
130
+ return s_bar.clamp(0.0, 1.0)
131
+
132
+ def forward(self, Z: torch.Tensor, tau: float) -> Tuple[torch.Tensor, torch.Tensor]:
133
+ m = self._hard_concrete(self.training, Z.device, tau=tau) # [F]
134
+ m = m.view(1, 1, -1, 1).to(Z.real.dtype) # [1,1,F,1]
135
+ return Z * m, m
136
+
137
+ def l0(self) -> torch.Tensor:
138
+ return torch.sigmoid(self.log_alpha).mean()
139
+
140
+
141
+
142
+
143
+ class ComplexAffine(nn.Module):
144
+ def __init__(self, E: int):
145
+ super().__init__()
146
+ self.gamma_r = nn.Parameter(torch.ones(E))
147
+ self.beta_r = nn.Parameter(torch.zeros(E))
148
+ self.gamma_i = nn.Parameter(torch.ones(E))
149
+ self.beta_i = nn.Parameter(torch.zeros(E))
150
+
151
+ def forward(self, z: torch.Tensor) -> torch.Tensor:
152
+ zr = z.real * self.gamma_r + self.beta_r
153
+ zi = z.imag * self.gamma_i + self.beta_i
154
+ return torch.complex(zr, zi)
155
+
156
+
157
+
158
+ class ComplexLinear(nn.Module):
159
+ def __init__(self, E_in: int, E_out: int, init_scale: float = 0.02):
160
+ super().__init__()
161
+ self.Wr = nn.Parameter(init_scale * torch.randn(E_in, E_out))
162
+ self.Wi = nn.Parameter(init_scale * torch.randn(E_in, E_out))
163
+ self.br = nn.Parameter(torch.zeros(E_out))
164
+ self.bi = nn.Parameter(torch.zeros(E_out))
165
+
166
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
167
+ xr, xi = x.real, x.imag
168
+ yr = xr @ self.Wr - xi @ self.Wi + self.br
169
+ yi = xi @ self.Wr + xr @ self.Wi + self.bi
170
+ return torch.complex(yr, yi)
171
+
172
+
173
+ class ComplexLIFGate(nn.Module):
174
+ def __init__(self, tau: float, v_th: float):
175
+ super().__init__()
176
+ self.lif_r = MultiStepLIFNode(
177
+ tau=tau, v_threshold=v_th, detach_reset=True,
178
+ surrogate_function=surrogate.ATan(alpha=4.0), backend="torch"
179
+ )
180
+ self.lif_i = MultiStepLIFNode(
181
+ tau=tau, v_threshold=v_th, detach_reset=True,
182
+ surrogate_function=surrogate.ATan(alpha=4.0), backend="torch"
183
+ )
184
+
185
+ def forward(self, z: torch.Tensor) -> torch.Tensor:
186
+ s_r = self.lif_r(z.real) # [T,B,F,D], binary spikes in {0,1}
187
+ s_i = self.lif_i(z.imag)
188
+ g = ((s_r > 0) | (s_i > 0)).to(z.real.dtype)
189
+ return g
190
+
191
+
192
+
193
+ class SFGO(nn.Module):
194
+ def __init__(
195
+ self,
196
+ args,
197
+ E: int,
198
+ hidden_size_factor: int,
199
+ tau: float = 2.0,
200
+ v_th: float = 1.0,
201
+ apply_gate_to_complex: bool = True,
202
+ ):
203
+ super().__init__()
204
+ H = int(E * hidden_size_factor)
205
+
206
+ self.args = args
207
+
208
+ self.lin1 = ComplexLinear(E, H)
209
+ self.lin2 = ComplexLinear(H, E)
210
+ self.lin3 = ComplexLinear(E, E)
211
+
212
+ self.g1 = ComplexLIFGate(tau=tau, v_th=v_th)
213
+ self.g2 = ComplexLIFGate(tau=tau, v_th=v_th)
214
+ self.g3 = ComplexLIFGate(tau=tau, v_th=v_th)
215
+
216
+ self.apply_gate_to_complex = apply_gate_to_complex
217
+
218
+ self.r2 = nn.Parameter(torch.tensor(0.1))
219
+ self.r3 = nn.Parameter(torch.tensor(0.1))
220
+
221
+ if self.args.affine:
222
+
223
+ self.a1 = ComplexAffine(E)
224
+ self.a2 = ComplexAffine(H)
225
+ self.a3 = ComplexAffine(E)
226
+
227
+ self.ga1 = ComplexLIFGate(tau=tau, v_th=v_th)
228
+ self.ga2 = ComplexLIFGate(tau=tau, v_th=v_th)
229
+ self.ga3 = ComplexLIFGate(tau=tau, v_th=v_th)
230
+
231
+
232
+ def _apply_gate(self, z: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
233
+ if not self.apply_gate_to_complex:
234
+ return z
235
+ return z * g.to(z.real.dtype)
236
+
237
+ def forward(self, Z: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
238
+ stats: Dict[str, torch.Tensor] = {}
239
+
240
+
241
+ if self.args.affine:
242
+ A1 = self.a1(Z)
243
+ GA1 = self.ga1(A1)
244
+ A1 = self._apply_gate(A1, GA1)
245
+ else:
246
+ A1 = Z
247
+
248
+ Y = self.lin1(A1)
249
+ G1 = self.g1(Y)
250
+ Y = self._apply_gate(Y, G1)
251
+
252
+ if self.args.affine:
253
+ A2 = self.a2(Y)
254
+ GA2 = self.ga2(A2)
255
+ A2 = self._apply_gate(A2, GA2)
256
+ else:
257
+ A2 = Y
258
+
259
+ X = self.lin2(A2)
260
+ G2 = self.g2(X)
261
+ X = self._apply_gate(X, G2)
262
+
263
+ Z2 = Z + self.r2 * X
264
+
265
+ if self.args.affine:
266
+ A3 = self.a3(Z2)
267
+ GA3 = self.ga3(A3)
268
+ A3 = self._apply_gate(A3, GA3)
269
+ else:
270
+ A3 = Z2
271
+
272
+
273
+ W = self.lin3(A3)
274
+ G3 = self.g3(W)
275
+ W = self._apply_gate(W, G3)
276
+
277
+ out = Z2 + self.r3 * W
278
+
279
+ with torch.no_grad():
280
+ mag2 = out.real * out.real + out.imag * out.imag
281
+ stats["freq_active_frac"] = (mag2 > 0).float().mean()
282
+
283
+ stats["rezero_r2"] = self.r2.detach()
284
+ stats["rezero_r3"] = self.r3.detach()
285
+
286
+ stats["gate_lin_frac_1"] = G1.mean().detach()
287
+ stats["gate_lin_frac_2"] = G2.mean().detach()
288
+ stats["gate_lin_frac_3"] = G3.mean().detach()
289
+
290
+ return out, stats
291
+
292
+
293
+
294
+ class Decoder(nn.Module):
295
+ def __init__(
296
+ self,
297
+ E: int,
298
+ L: int,
299
+ pred_len: int,
300
+ T: int,
301
+ tau: float,
302
+ v_th: float,
303
+ proj_dim: int = 4,
304
+ reduced_dim: int = 64,
305
+ ):
306
+ super().__init__()
307
+ self.E, self.L, self.P, self.T = E, L, pred_len, T
308
+ self.proj_dim = int(proj_dim)
309
+
310
+ self.time_proj = nn.Linear(L, self.proj_dim, bias=False)
311
+ D_in = E * self.proj_dim
312
+ self.reduced_dim = int(reduced_dim)
313
+
314
+ self.lif = MultiStepLIFNode(
315
+ tau=tau,
316
+ v_threshold=v_th,
317
+ detach_reset=True,
318
+ surrogate_function=surrogate.ATan(alpha=4.0),
319
+ backend="torch",
320
+ )
321
+
322
+ self.fc_reduce = weight_norm(nn.Linear(D_in, int(reduced_dim), bias=True))
323
+ self.fc_out = weight_norm(nn.Linear(int(reduced_dim), pred_len, bias=True))
324
+
325
+ nn.init.xavier_uniform_(self.time_proj.weight, gain=0.5)
326
+ nn.init.xavier_uniform_(self.fc_reduce.weight, gain=0.6)
327
+ nn.init.xavier_uniform_(self.fc_out.weight, gain=0.2)
328
+ nn.init.zeros_(self.fc_reduce.bias)
329
+ nn.init.zeros_(self.fc_out.bias)
330
+
331
+ def forward(self, y_t: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
332
+ T, B, N, E, L = y_t.shape
333
+
334
+ y_p = self.time_proj(y_t) # [T,B,N,E,p]
335
+ x = y_p.reshape(T, B * N, E * self.proj_dim) # [T,B*N,D]
336
+ s = self.lif(x) # [T,B*N,D] spikes
337
+ h_t = self.fc_reduce(s.reshape(T * B * N, -1)).view(T, B * N, self.reduced_dim)
338
+
339
+ h = h_t.mean(dim=0) # [B*N,reduced_dim]
340
+ h = F.gelu(h)
341
+ out = self.fc_out(h) # [B*N,P], P = pred_len
342
+
343
+ preds = out.view(B, N, self.P).permute(0, 2, 1).contiguous()
344
+ stats = {"dec_spike_rate": s.mean().detach()}
345
+ return preds, stats
346
+
347
+
348
+
349
+ class SpikF_GO_CPG(nn.Module):
350
+ def __init__(
351
+ self,
352
+ args,
353
+ pre_length: int,
354
+ embed_size: int,
355
+ feature_size: int,
356
+ seq_length: int,
357
+ hidden_size: int,
358
+ hard_thresholding_fraction=1,
359
+ hidden_size_factor: int = 1,
360
+ sparsity_threshold: float = 0.01,
361
+ ):
362
+ super().__init__()
363
+ self.args = args
364
+
365
+ self.N = feature_size
366
+ self.L = seq_length
367
+ self.E = embed_size
368
+ self.T = args.T
369
+ self.M = self.N * self.L
370
+
371
+ self.use_cpg_pe = True
372
+ self.num_pe_pairs = 20
373
+ self.pe_tau = 10000.0
374
+ self.pe_eta = 1.0
375
+ self.pe_vthres = 0.8
376
+ self.pe_wmax = 10000.0
377
+
378
+
379
+ if self.use_cpg_pe:
380
+ self.cpg_pe = CPGSpikePE(
381
+ num_pairs=self.num_pe_pairs,
382
+ tau=self.pe_tau, eta=self.pe_eta,
383
+ vthres=self.pe_vthres, w_max=self.pe_wmax
384
+ )
385
+ self.pe_linear = nn.Linear(self.E + 2 * self.num_pe_pairs, self.E, bias=False)
386
+ self.pe_bn = nn.BatchNorm1d(self.E)
387
+ self.pe_lif = MultiStepLIFNode(
388
+ tau=self.args.tau, v_threshold=self.args.alpha, detach_reset=True,
389
+ surrogate_function=surrogate.ATan(alpha=4.0), backend='torch'
390
+ )
391
+
392
+
393
+
394
+ self.embeddings = nn.Parameter(torch.randn(1, self.E) * 0.02)
395
+ self.node_aff = Affine(self.E)
396
+ self.node_rms = RMSNorm(E=self.E, eps=1e-6)
397
+
398
+ # step modulation
399
+ self.step_gamma = nn.Parameter(torch.ones(self.T))
400
+ self.step_beta = nn.Parameter(torch.zeros(self.T))
401
+ self.register_buffer("step_scale", torch.linspace(0, 1, steps=self.T).view(self.T, 1, 1, 1))
402
+
403
+ # Encoder LIF
404
+ self.encoder_lif = MultiStepLIFNode(
405
+ tau=args.tau,
406
+ v_threshold=args.alpha,
407
+ detach_reset=True,
408
+ surrogate_function=surrogate.ATan(alpha=4.0),
409
+ backend="torch",
410
+ )
411
+
412
+ self.sfft = SFFT(self.M)
413
+ self.F_bins = self.sfft.F
414
+
415
+ # frequency gate
416
+ self.freq_gate = HardConcreteGate(self.F_bins, init_logit=2.0)
417
+ self.register_buffer("gate_tau", torch.tensor(0.10))
418
+
419
+ self.sfgo = SFGO(
420
+ self.args,
421
+ E=self.E,
422
+ hidden_size_factor=hidden_size_factor,
423
+ tau=args.tau,
424
+ v_th=args.alpha,
425
+ apply_gate_to_complex=True,
426
+ )
427
+
428
+ # decoder
429
+ proj_dim = self.args.proj_dim
430
+ reduced_dim = max(16, min(128, hidden_size // 4))
431
+ self.decoder = Decoder(
432
+ E=self.E,
433
+ L=self.L,
434
+ pred_len=pre_length,
435
+ T=self.T,
436
+ tau=args.tau,
437
+ v_th=args.alpha,
438
+ proj_dim=proj_dim,
439
+ reduced_dim=reduced_dim,
440
+ )
441
+
442
+ def node_embed(self, x: torch.Tensor) -> torch.Tensor:
443
+ # x: [B,L,N] -> [B,M,E]
444
+ B, L, N = x.shape
445
+ x_flat = x.permute(0, 2, 1).contiguous().reshape(B, self.M) # [B,M]
446
+ tok = x_flat.unsqueeze(-1) * self.embeddings # [B,M,E]
447
+ tok = self.node_aff(tok)
448
+ return tok
449
+
450
+
451
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
452
+ B, L, N = x.shape
453
+
454
+ # normalize
455
+ if self.args.normalize:
456
+ mean = x.mean(dim=1, keepdim=True).detach()
457
+ x0 = x - mean
458
+ std = torch.sqrt(torch.var(x0, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()
459
+ x0 = x0 / std
460
+ else:
461
+ mean, std = None, None
462
+ x0 = x
463
+
464
+
465
+ tok = self.node_embed(x0) # [B,M,E]
466
+ tok = self.node_rms(tok) # RMSNorm
467
+
468
+
469
+ # step modulation
470
+ cur_t = tok.unsqueeze(0).repeat(self.T, 1, 1, 1)
471
+ cur_t = cur_t * self.step_gamma.view(self.T, 1, 1, 1) + self.step_beta.view(self.T, 1, 1, 1)
472
+ cur_t = cur_t * (1.0 + 0.02 * self.step_scale.to(cur_t.dtype))
473
+
474
+
475
+ # spikes
476
+ s_t = self.encoder_lif(cur_t)
477
+ if self.use_cpg_pe:
478
+ pe_spk = self.cpg_pe(T=self.T, B=B, M=self.M, device=x.device) # [T,B,M,2*N_pe]
479
+ s_cat = torch.cat([s_t, pe_spk], dim=-1) # [T,B,M,E+2*N_pe]
480
+ h = self.pe_linear(s_cat) # [T,B,M,E]
481
+ h = h.reshape(self.T * B * self.M, self.E)
482
+ h = self.pe_bn(h).view(self.T, B, self.M, self.E)
483
+ s_t = self.pe_lif(h)
484
+
485
+ enc_rate = s_t.mean()
486
+
487
+ # FFT
488
+ Z_t = self.sfft.rfft(s_t)
489
+
490
+ # prune
491
+ Z_t, m = self.freq_gate(Z_t, tau=float(self.gate_tau))
492
+
493
+ # S-FGO blocks
494
+ Z_t, fb_stats = self.sfgo(Z_t)
495
+
496
+ # iFFT
497
+ y_time_t = self.sfft.irfft(Z_t).to(tok.dtype)
498
+
499
+ y_t = y_time_t.view(self.T, B, N, self.L, self.E).permute(0, 1, 2, 4, 3).contiguous()
500
+
501
+ preds, dec_stats = self.decoder(y_t)
502
+
503
+ if self.args.normalize:
504
+ preds = preds * std + mean # denormalize
505
+
506
+ aux = {
507
+ "enc_rate": enc_rate.detach(),
508
+ "rho_hat": self.freq_gate.l0().detach(),
509
+ "freq_mask_mean": m.mean().detach(),
510
+ "freq_mask_active": (m > 0.5).float().mean().detach(),
511
+ **fb_stats,
512
+ **dec_stats,
513
+ }
514
+ return preds, aux
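The `CPGSpikePE` module in this file turns a flattened position index into binary channels by thresholding cosine and sine rhythms. A scalar sketch of the same arithmetic for one position, using the default hyperparameters from the class and only the standard library (the function name `cpg_spike_pe` is illustrative):

```python
import math

def cpg_spike_pe(t, num_pairs=20, tau=10000.0, eta=1.0, vthres=0.8, w_max=10000.0):
    """Return 2*num_pairs binary values for position t: cos channels first, then sin."""
    # Log-spaced frequencies from 1 down towards 1/w_max.
    freqs = [math.exp(-math.log(w_max) * i / max(1, num_pairs)) for i in range(num_pairs)]
    args = [eta * t * f / tau for f in freqs]
    # A channel spikes when its rhythm exceeds the threshold.
    cos_spk = [1.0 if math.cos(a) - vthres > 0 else 0.0 for a in args]
    sin_spk = [1.0 if math.sin(a) - vthres > 0 else 0.0 for a in args]
    return cos_spk + sin_spk

pe0 = cpg_spike_pe(0)        # position 0
pe_far = cpg_spike_pe(12000) # a much later position
```

At position 0 every cosine channel fires and every sine channel is silent, since cos(0) = 1 and sin(0) = 0. The pattern only starts to differ between positions once `t * freq / tau` becomes large enough to cross the threshold.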
model/SpikeGRU.py ADDED
@@ -0,0 +1,241 @@
1
+ from typing import Optional
2
+ from pathlib import Path
3
+
4
+ from spikingjelly.activation_based import surrogate as sj_surrogate
5
+ from snntorch import utils
6
+ import snntorch as snn
7
+ from snntorch import surrogate
8
+ import torch
9
+ from torch import nn
10
+
11
+
12
+
13
+ class GRUCell(nn.Module):
14
+ def __init__(
15
+ self,
16
+ input_size: int,
17
+ hidden_size: int,
18
+ num_steps: int = 4,
19
+ grad_slope: float = 25.0,
20
+ beta: float = 0.99,
21
+ output_mems: bool = False,
22
+ ):
23
+ super().__init__()
24
+ self.spike_grad = surrogate.atan(alpha=2.0)
25
+ self.input_size = input_size
26
+ self.num_steps = num_steps
27
+ self.hidden_size = hidden_size
28
+ self.beta = beta
29
+ self.full_rec = output_mems
30
+ self.lif = snn.Leaky(
31
+ beta=self.beta,
32
+ spike_grad=self.spike_grad,
33
+ init_hidden=True,
34
+ output=output_mems,
35
+ )
36
+ self.linear_ih = nn.Linear(input_size, 3 * hidden_size)
37
+ self.linear_hh = nn.Linear(hidden_size, 3 * hidden_size)
38
+ self.surrogate_function1 = sj_surrogate.ATan()
39
+
40
+ def forward(self, inputs):
41
+ if inputs.size(-1) == self.input_size:
42
+ # assume static spikes:
43
+ h = torch.zeros(
44
+ size=[inputs.shape[0], self.hidden_size],
45
+ dtype=torch.float,
46
+ device=inputs.device,
47
+ )
48
+ y_ih = torch.split(self.linear_ih(inputs), self.hidden_size, dim=1)
49
+ y_hh = torch.split(self.linear_hh(h), self.hidden_size, dim=1)
50
+ r = self.surrogate_function1(y_ih[0] + y_hh[0])
51
+ z = self.surrogate_function1(y_ih[1] + y_hh[1])
52
+ n = self.surrogate_function1(y_ih[2] + r * y_hh[2])
53
+ h = (1.0 - z) * n + z * h
54
+ cur = h
55
+ static = True
56
+ elif inputs.size(-1) == self.num_steps and inputs.size(-2) == self.input_size:
57
+ inputs = inputs.transpose(-1, -2) # BC, T, H
58
+ h = torch.zeros(
59
+ size=[inputs.shape[0], self.hidden_size, self.num_steps],
60
+ dtype=torch.float,
61
+ device=inputs.device,
62
+ )
63
+ y_ih = torch.split(
64
+ self.linear_ih(inputs).transpose(-1, -2), self.hidden_size, dim=1
65
+ )
66
+ y_hh = torch.split(
67
+ self.linear_hh(h.transpose(-1, -2)).transpose(-1, -2),
68
+ self.hidden_size,
69
+ dim=1,
70
+ )
71
+ r = self.surrogate_function1(y_ih[0] + y_hh[0])
72
+ z = self.surrogate_function1(y_ih[1] + y_hh[1])
73
+ n = self.surrogate_function1(y_ih[2] + r * y_hh[2])
74
+ h = (1.0 - z) * n + z * h
75
+ cur = h
76
+ static = False
77
+ else:
78
+ raise ValueError(
79
+ "Input size mismatch! "
80
+ f"Got {inputs.size()} but expected (..., {self.input_size}, {self.num_steps}) or (..., {self.input_size})"
81
+ )
82
+
83
+ spk_rec = []
84
+ mem_rec = []
85
+ if self.full_rec:
86
+ for i_step in range(self.num_steps):
87
+ if static:
88
+ spk, mem = self.lif(cur)
89
+ else:
90
+ spk, mem = self.lif(cur[:, :, i_step])
91
+ spk_rec.append(spk)
92
+ mem_rec.append(mem)
93
+ spks = torch.stack(spk_rec, dim=-1)
94
+ mems = torch.stack(mem_rec, dim=-1)
95
+ return spks, mems
96
+ else:
97
+ for i_step in range(self.num_steps):
98
+ if static:
99
+ spk = self.lif(cur)
100
+ else:
101
+ spk = self.lif(cur[:, :, i_step])
102
+ spk_rec.append(spk)
103
+ spks = torch.stack(spk_rec, dim=-1)
104
+ return spks
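The gate arithmetic in `GRUCell.forward` can be isolated from the tensor plumbing. The following is a minimal scalar sketch, not the repository's implementation: it assumes the surrogate function acts as a Heaviside step in the forward pass, and the names `heaviside` and `gru_step` are hypothetical.

```python
def heaviside(x: float) -> float:
    """Binary gate: 1.0 when the pre-activation is non-negative, else 0.0."""
    return 1.0 if x >= 0.0 else 0.0


def gru_step(i_r, i_z, i_n, h_r, h_z, h_n, h_prev):
    """One GRU-style update on scalars.

    i_* are the input projections and h_* the hidden projections for the
    reset (r), update (z) and candidate (n) paths.
    """
    r = heaviside(i_r + h_r)          # reset gate
    z = heaviside(i_z + h_z)          # update gate
    n = heaviside(i_n + r * h_n)      # candidate state, gated by r
    return (1.0 - z) * n + z * h_prev  # z == 1 keeps the previous state


if __name__ == "__main__":
    # Update gate closed (z == 0): the candidate replaces the state.
    print(gru_step(1.0, -1.0, 0.5, 0.0, 0.0, 0.0, 0.0))
    # Update gate open (z == 1): the previous state is carried over.
    print(gru_step(1.0, 1.0, 0.5, 0.0, 0.0, 0.0, 0.25))
```

With binary gates the update reduces to a hard switch between the candidate and the previous state, which is why the result is then passed to a LIF layer as an input current.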
105
+
106
+
107
+ class DeltaEncoder(nn.Module):
108
+ def __init__(self, output_size: int):
109
+ super().__init__()
110
+ self.norm = nn.BatchNorm2d(1)
111
+ self.enc = nn.Linear(1, output_size)
112
+ self.lif = snn.Leaky(
113
+ beta=0.99, spike_grad=surrogate.atan(), init_hidden=True, output=False
114
+ )
115
+
116
+ def forward(self, inputs: torch.Tensor):
117
+ # inputs: batch, L, C
118
+ delta = torch.zeros_like(inputs)
119
+ delta[:, 1:] = inputs[:, 1:, :] - inputs[:, :-1, :]
120
+ delta = delta.unsqueeze(1).permute(0, 1, 3, 2) # batch, 1, C, L
121
+ delta = self.norm(delta)
122
+ delta = delta.permute(0, 2, 3, 1) # batch, C, L, 1
123
+ enc = self.enc(delta) # batch, C, L, output_size
124
+ enc = enc.permute(0, 3, 1, 2) # batch, output_size, C, L
125
+ spks = self.lif(enc)
126
+ return spks
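The first step of `DeltaEncoder.forward` is a first difference along the time axis, with the first position left at zero. A plain-Python sketch of that step on a single series follows; the function name `delta_encode` is hypothetical, and the normalisation, linear projection and LIF layer are deliberately left out.

```python
def delta_encode(series):
    """Return first differences of a 1-D sequence, with 0.0 at position 0."""
    deltas = [0.0] * len(series)
    for t in range(1, len(series)):
        deltas[t] = series[t] - series[t - 1]
    return deltas


if __name__ == "__main__":
    print(delta_encode([1.0, 3.0, 2.0, 2.0]))
```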
127
+
128
+
129
+ class ConvEncoder(nn.Module):
130
+ def __init__(self, output_size: int, kernel_size: int = 3):
131
+ super().__init__()
132
+ self.encoder = nn.Sequential(
133
+ nn.Conv2d(
134
+ in_channels=1,
135
+ out_channels=output_size,
136
+ kernel_size=(1, kernel_size),
137
+ stride=1,
138
+ padding=(0, kernel_size // 2),
139
+ ),
140
+ nn.BatchNorm2d(output_size),
141
+ )
142
+ self.lif = snn.Leaky(
143
+ beta=0.99,
144
+ spike_grad=surrogate.atan(alpha=2.0),
145
+ init_hidden=True,
146
+ output=False,
147
+ )
148
+
149
+ def forward(self, inputs: torch.Tensor):
150
+ # inputs: batch, L, C
151
+ inputs = inputs.permute(0, 2, 1).unsqueeze(1) # batch, 1, C, L
152
+ enc = self.encoder(inputs) # batch, output_size, C, L
153
+ spks = self.lif(enc)
154
+ return spks
155
+
156
+
157
+
158
+ class SpikeGRU(nn.Module):
159
+ def __init__(
160
+ self,
161
+ args,
162
+ hidden_size: int,
163
+ layers: int = 1,
164
+ num_steps: int = 50,
165
+ grad_slope: float = 25.0,
166
+ input_size: Optional[int] = None,
167
+ max_length: Optional[int] = None,
168
+ weight_file: Optional[Path] = None,
169
+ encoder_type: Optional[str] = "conv",
170
+ ):
171
+ super().__init__()
172
+ self.args = args
173
+ self.hidden_size = args.hidden_size
174
+ self.num_steps = args.T
175
+ self.input_size = args.feature_size
176
+ self.pre_length = args.pre_length
177
+ self.layers = args.blocks
178
+
179
+
180
+ if encoder_type == "conv":
181
+ self.encoder = ConvEncoder(self.hidden_size)
182
+ elif encoder_type == "delta":
183
+ self.encoder = DeltaEncoder(self.hidden_size)
184
+ else:
185
+ raise ValueError(f"Unknown encoder type {encoder_type}")
186
+
187
+ self.net = nn.Sequential(
188
+ *[
189
+ GRUCell(
190
+ self.hidden_size,
191
+ self.hidden_size,
192
+ num_steps=self.num_steps,
193
+ grad_slope=grad_slope,
194
+ output_mems=(i == self.layers - 1),
195
+ )
196
+ for i in range(self.layers)
197
+ ]
198
+ )
199
+
200
+ self.__output_size = self.hidden_size
201
+ self.fc = nn.Linear(self.__output_size, self.pre_length)
202
+
203
+ self.to("cuda:0" if torch.cuda.is_available() else "cpu") # fall back to CPU when CUDA is unavailable
204
+
205
+ def forward(
206
+ self,
207
+ inputs: torch.Tensor,
208
+ ):
209
+ utils.reset(self.encoder)
210
+ for layer in self.net:
211
+ utils.reset(layer)
212
+
213
+
214
+ bs, length, c_num = inputs.size()
215
+
216
+ if self.args.normalize:
217
+ mean = inputs.mean(dim=1, keepdim=True).detach() # shape [B, 1, D]
218
+ inputs = inputs - mean
219
+ std = torch.sqrt(torch.var(inputs, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()
220
+ inputs = inputs / std
221
+
222
+ h = self.encoder(inputs) # B, H, C, L
223
+ hidden_size = h.size(1)
224
+ h = h.permute(0, 2, 3, 1).reshape(bs * c_num, length, hidden_size) # BC, L, H
225
+ for i in range(length):
226
+ spks, mems = self.net(h[:, i, :])
227
+ spks = spks.reshape(bs, c_num * hidden_size, -1) # B, CH, Time Step
228
+ spks = spks[:, :, -1] # keep only the last time step; shape (B, CH)
229
+ preds = self.fc(spks.view(bs, c_num, -1)).squeeze(-1) # B, C, O (permuted to B, O, C below)
230
+ preds = preds.permute(0, 2, 1).contiguous()
231
+
232
+ if self.args.normalize:
233
+ preds = preds * std + mean # denormalize
234
+
235
+ aux = {'gate_l0': torch.tensor(0.0, device=preds.device)} # placeholder
236
+
237
+ return preds, aux
238
+
239
+ @property
240
+ def output_size(self):
241
+ return self.__output_size
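When `args.normalize` is set, `SpikeGRU.forward` standardises each input window along the time axis and reverses the transform on the predictions. A minimal sketch of that round trip on one series is given below. It uses plain Python instead of torch; the helper names `normalize` and `denormalize` are hypothetical, and the biased variance with `eps = 1e-5` mirrors `unbiased=False` and the constant in the code above.

```python
import math


def normalize(window, eps=1e-5):
    """Centre and scale one window; return the result with its statistics."""
    mean = sum(window) / len(window)
    centered = [x - mean for x in window]
    var = sum(c * c for c in centered) / len(window)  # biased variance
    std = math.sqrt(var + eps)
    return [c / std for c in centered], mean, std


def denormalize(values, mean, std):
    """Invert normalize() using the stored statistics."""
    return [v * std + mean for v in values]


if __name__ == "__main__":
    normed, mean, std = normalize([1.0, 2.0, 3.0, 4.0])
    print(denormalize(normed, mean, std))
```

Because the statistics come from the input window, the forecast is mapped back to the scale of the window it was predicted from.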
model/SpikeRNN_CPG.py ADDED
@@ -0,0 +1,489 @@
1
+ from typing import Optional
2
+ from pathlib import Path
3
+ import torch
4
+ from torch import nn
5
+ from spikingjelly.activation_based import surrogate, neuron, functional
6
+ import math
7
+ import copy
8
+
9
+
10
+ tau = 2.0
11
+ backend = "torch"
12
+ detach_reset = True
13
+
14
+
15
+
16
+ def generate_ones_and_minus_ones_matrix(rows, cols):
17
+ random_matrix = torch.randint(0, 2, (rows, cols))
18
+ binary_matrix = torch.where(
19
+ random_matrix == 0,
20
+ -1 * torch.ones_like(random_matrix),
21
+ torch.ones_like(random_matrix),
22
+ )
23
+ return binary_matrix.float()
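`generate_ones_and_minus_ones_matrix` draws each entry from {-1, +1} with equal probability. The same idea can be sketched without torch by mapping a random bit `b` to `2b - 1`; the name `random_sign_matrix` and its `seed` parameter are hypothetical and only illustrate the mapping.

```python
import random


def random_sign_matrix(rows, cols, seed=None):
    """Return a rows x cols nested list with entries in {-1.0, +1.0}."""
    rng = random.Random(seed)
    # randint(0, 1) yields a bit b; 2 * b - 1 maps 0 -> -1 and 1 -> +1.
    return [[2.0 * rng.randint(0, 1) - 1.0 for _ in range(cols)] for _ in range(rows)]


if __name__ == "__main__":
    for row in random_sign_matrix(2, 4, seed=0):
        print(row)
```

In torch the same mapping could probably be written as `(torch.randint(0, 2, (rows, cols)) * 2 - 1).float()`, which would avoid the two `ones_like` temporaries.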
24
+
25
+
26
+ class RandomPE(nn.Module):
27
+ def __init__(
28
+ self,
29
+ d_model,
30
+ pe_mode="concat",
31
+ num_pe_neuron=10,
32
+ neuron_pe_scale=1000.0,
33
+ dropout=0.1,
34
+ num_steps=4,
35
+ ):
36
+ super().__init__()
37
+ self.max_len = 5000 # different from windows
38
+ self.pe_mode = pe_mode
39
+ self.neuron_pe_scale = neuron_pe_scale
40
+ self.dropout = nn.Dropout(p=dropout)
41
+ if self.pe_mode == "concat":
42
+ self.num_pe_neuron = copy.deepcopy(num_pe_neuron)
43
+ elif self.pe_mode == "add":
44
+ self.num_pe_neuron = copy.deepcopy(d_model)
45
+ pe = generate_ones_and_minus_ones_matrix(
46
+ self.max_len, self.num_pe_neuron
47
+ ) # MaxL, Neur
48
+ pe = pe.unsqueeze(0).transpose(0, 1) # MaxL, 1, Neur
49
+ print("pe.shape: ", pe.shape)
50
+ self.register_buffer("pe", pe)
51
+
52
+ def forward(self, x):
53
+ # T, B, L, D
54
+ T, B, L, _ = x.shape
55
+ x = x.permute(1, 0, 2, 3) # B, T, L, D
56
+ x = x.flatten(1, 2) # B, TL, D
57
+ if self.pe_mode == "concat":
58
+ # tmp: TL, 1, Neur -> TL, B, Neur -> B, TL, Neur
59
+ tmp = self.pe[: x.size(-2), :].repeat(1, B, 1).transpose(0, 1)
60
+ x = torch.concat([x, tmp], dim=-1)
61
+ # print(x.shape) # B, TL, D'
62
+ elif self.pe_mode == "add":
63
+ # [B, TL, D] + [1, TL, Neur]
64
+ x = x + self.pe[: x.size(-2), :].transpose(0, 1)
65
+ # print(x.shape) # B, TL, D
66
+ x = x.transpose(0, 1) # TL, B D
67
+ x = x.reshape(T, L, B, -1) # T, L, B, D
68
+ x = x.permute(0, 2, 1, 3) # T, B, L, D
69
+ return self.dropout(x)
70
+
71
+
72
+ class NeuronPE(nn.Module):
73
+ def __init__(
74
+ self,
75
+ d_model,
76
+ pe_mode="concat",
77
+ num_pe_neuron=10,
78
+ neuron_pe_scale=10000.0,
79
+ dropout=0.1,
80
+ num_steps=4,
81
+ ):
82
+ super().__init__()
83
+ self.max_len = 50000 # different from windows
84
+ self.pe_mode = pe_mode
85
+ self.neuron_pe_scale = neuron_pe_scale
86
+ self.dropout = nn.Dropout(p=dropout)
87
+ if self.pe_mode == "concat":
88
+ self.num_pe_neuron = copy.deepcopy(num_pe_neuron)
89
+ elif self.pe_mode == "add":
90
+ self.num_pe_neuron = copy.deepcopy(d_model)
91
+ pe = torch.zeros(self.max_len, self.num_pe_neuron) # MaxL, Neur
92
+ position = torch.arange(0, self.max_len, dtype=torch.float).unsqueeze(
93
+ 1
94
+ ) # MaxL, 1
95
+ div_term = torch.exp(
96
+ torch.arange(0, self.num_pe_neuron, 2).float()
97
+ * (-math.log(neuron_pe_scale) / self.num_pe_neuron)
98
+ )
99
+ div_term_single = torch.exp(
100
+ torch.arange(0, self.num_pe_neuron - 1, 2).float()
101
+ * (-math.log(neuron_pe_scale) / self.num_pe_neuron)
102
+ )
103
+ pe[:, 0::2] = torch.heaviside(
104
+ torch.sin(position * div_term) - 0.8, torch.tensor([1.0])
105
+ )
106
+ pe[:, 1::2] = torch.heaviside(
107
+ torch.cos(position * div_term_single) - 0.8, torch.tensor([1.0])
108
+ )
109
+ pe = pe.unsqueeze(0).transpose(0, 1) # MaxL, 1, Neur
110
+ print("pe.shape: ", pe.shape)
111
+ self.register_buffer("pe", pe)
112
+
113
+ def forward(self, x):
114
+ # T, B, L, D
115
+ T, B, L, _ = x.shape
116
+ x = x.permute(1, 0, 2, 3) # B, T, L, D
117
+ x = x.flatten(1, 2) # B, TL, D
118
+ if self.pe_mode == "concat":
119
+ # tmp: TL, 1, Neur -> TL, B, Neur -> B, TL, Neur
120
+ tmp = self.pe[: x.size(-2), :].repeat(1, B, 1).transpose(0, 1)
121
+ x = torch.concat([x, tmp], dim=-1)
122
+ # print(x.shape) # B, TL, D'
123
+ elif self.pe_mode == "add":
124
+ # [B, TL, D] + [1, TL, Neur]
125
+ x = x + self.pe[: x.size(-2), :].transpose(0, 1)
126
+ # print(x.shape) # B, TL, D
127
+ x = x.transpose(0, 1) # TL, B D
128
+ x = x.reshape(T, L, B, -1) # T, L, B, D
129
+ x = x.permute(0, 2, 1, 3) # T, B, L, D
130
+ return self.dropout(x)
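The table built in `NeuronPE.__init__` is a sinusoidal positional encoding thresholded at 0.8, so each "neuron" emits a spike only near the peak of its wave. The sketch below rebuilds that table in plain Python under the same formulas. The function name `binary_positional_code` is hypothetical, and `threshold` is exposed as a parameter only for illustration.

```python
import math


def binary_positional_code(max_len, num_neurons, scale=10000.0, threshold=0.8):
    """Return a max_len x num_neurons table of 0.0 / 1.0 spikes."""
    table = []
    for pos in range(max_len):
        row = []
        for k in range(num_neurons):
            pair = k - (k % 2)  # even index shared by a sin/cos pair
            freq = math.exp(-math.log(scale) * pair / num_neurons)
            wave = math.sin(pos * freq) if k % 2 == 0 else math.cos(pos * freq)
            # Spike when the wave reaches the threshold (ties count as a spike).
            row.append(1.0 if wave - threshold >= 0.0 else 0.0)
        table.append(row)
    return table


if __name__ == "__main__":
    for row in binary_positional_code(4, 4):
        print(row)
```

At position 0 every sine neuron is silent and every cosine neuron fires, since `sin(0) = 0` and `cos(0) = 1`.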
131
+
132
+
133
+ class StaticPE(nn.Module):
134
+     r"""Inject some information about the relative or absolute position of the tokens
+     in the sequence. The positional encodings have the same dimension as
+     the embeddings, so that the two can be summed. Here, we use sine and cosine
+     functions of different frequencies.
+
+     .. math::
+         \text{PosEncoder}(pos, 2i) = \sin(pos / 10000^{2i / d_{model}})
+         \text{PosEncoder}(pos, 2i+1) = \cos(pos / 10000^{2i / d_{model}})
+
+     where pos is the word position and i is the embedding index.
+     """
142
+
143
+ def __init__(self, d_model, dropout=0.1, max_len=5000):
144
+ super().__init__()
145
+ self.dropout = nn.Dropout(p=dropout)
146
+ pe = torch.zeros(max_len, d_model) # MaxL, D
147
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) # MaxL, 1
148
+ div_term = torch.exp(
149
+ torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
150
+ )
151
+ div_term_single = torch.exp(
152
+ torch.arange(0, d_model - 1, 2).float() * (-math.log(10000.0) / d_model)
153
+ )
154
+ pe[:, 0::2] = torch.sin(position * div_term)
155
+ pe[:, 1::2] = torch.cos(position * div_term_single)
156
+ pe = pe.unsqueeze(0).transpose(0, 1)
157
+ self.register_buffer("pe", pe)
158
+
159
+ def forward(self, x):
160
+ # x: L, TB, D
161
+ x = x + self.pe[: x.size(0), :]
162
+ x = self.dropout(x)
163
+ return x
164
+
165
+
166
+ class ConvPE(nn.Module):
167
+ def __init__(self, d_model, dropout=0.1, max_len=5000, num_steps=4):
168
+
169
+ super().__init__()
170
+ self.T = num_steps
171
+ self.rpe_conv = nn.Conv1d(
172
+ d_model, d_model, kernel_size=3, stride=1, padding=1, bias=False
173
+ )
174
+ self.rpe_bn = nn.BatchNorm1d(d_model)
175
+ self.rpe_lif = neuron.LIFNode(
176
+ step_mode="m",
177
+ detach_reset=True,
178
+ surrogate_function=surrogate.ATan(),
179
+ v_threshold=1.0,
180
+ )
181
+ self.dropout = nn.Dropout(p=dropout)
182
+
183
+ def forward(self, x):
184
+ # x: L, TB, D
185
+ L, TB, D = x.shape
186
+ x_feat = x.permute(1, 2, 0) # TB, D, L
187
+ x_feat = self.rpe_conv(x_feat) # TB, D, L
188
+ x_feat = (
189
+ self.rpe_bn(x_feat).reshape(self.T, int(TB / self.T), D, L).contiguous()
190
+ ) # T, B, D, L
191
+ x_feat = self.rpe_lif(x_feat)
192
+ x_feat = x_feat.flatten(0, 1) # TB, D, L
193
+ x_feat = self.dropout(x_feat) # TB, D, L
194
+ x_feat = x_feat.permute(2, 0, 1) # L, TB, D
195
+ x = x + x_feat
196
+ return x
197
+
198
+
199
+ class PositionEmbedding(nn.Module):
200
+ def __init__(
201
+ self,
202
+ input_size: int,
203
+ pe_type: str,
204
+ max_len: int = 5000,
205
+ pe_mode: str = "add",
206
+ num_pe_neuron: int = 10,
207
+ neuron_pe_scale: float = 1000.0,
208
+ dropout=0.1,
209
+ num_steps=4,
210
+ ):
211
+ super().__init__()
212
+ self.emb_type = pe_type
213
+ if pe_type in ["learn", "none"]:
214
+ self.emb = nn.Embedding(max_len, input_size)
215
+ elif pe_type == "conv":
216
+ self.emb = ConvPE(
217
+ d_model=input_size,
218
+ max_len=max_len,
219
+ dropout=dropout,
220
+ num_steps=num_steps,
221
+ )
222
+ elif pe_type == "static":
223
+ self.emb = StaticPE(d_model=input_size, max_len=max_len, dropout=dropout)
224
+ elif pe_type == "neuron":
225
+ self.emb = NeuronPE(
226
+ d_model=input_size,
227
+ pe_mode=pe_mode,
228
+ num_pe_neuron=num_pe_neuron,
229
+ neuron_pe_scale=neuron_pe_scale,
230
+ dropout=dropout,
231
+ num_steps=num_steps,
232
+ )
233
+ elif pe_type == "random":
234
+ self.emb = RandomPE(
235
+ d_model=input_size,
236
+ pe_mode=pe_mode,
237
+ num_pe_neuron=num_pe_neuron,
238
+ neuron_pe_scale=neuron_pe_scale,
239
+ dropout=dropout,
240
+ num_steps=num_steps,
241
+ )
242
+ else:
243
+ raise ValueError("Unknown embedding type: {}".format(pe_type))
244
+
245
+ def forward(self, x):
246
+ if self.emb_type == "learn":
247
+ # T, B, L, D = x.shape # x: T, B, L, D
248
+ # x = x.flatten(0, 1) # TB, L, D
249
+ tmp = torch.arange(
250
+ end=x.size()[1], device=x.device
251
+ ) # [0,1,2,...,L-1], shape: L
252
+ embedding = self.emb(tmp) # shape: L, D
253
+ embedding = embedding.repeat([x.size()[0], 1, 1]) # TB, L, D'
254
+ x = x + embedding
255
+ # x = x.reshape(T, B, L, -1)
256
+ elif self.emb_type in ["static", "conv"]:
257
+ T, B, L, _ = x.shape # x: T, B, L, D
258
+ x = x.flatten(0, 1) # TB, L, D
259
+ x = self.emb(x.transpose(0, 1)).transpose(0, 1) # x: TB, L, D'
260
+ x = x.reshape(T, B, L, -1)
261
+ elif self.emb_type in ["neuron", "random"]:
262
+ T, B, L, _ = x.shape # x: T, B, L, D
263
+ # T, B, L, D
264
+ x = self.emb(x)
265
+ x = x.reshape(T, B, L, -1)
266
+ return x # T, B, L, D'
267
+
268
+
269
+ class RepeatEncoder(nn.Module):
270
+ def __init__(self, output_size: int):
271
+ super().__init__()
272
+ self.out_size = output_size
273
+ self.lif = neuron.LIFNode(
274
+ tau=tau,
275
+ step_mode="m",
276
+ detach_reset=detach_reset,
277
+ surrogate_function=surrogate.ATan(),
278
+ )
279
+
280
+ def forward(self, inputs: torch.Tensor):
281
+ # inputs: B, L, C
282
+ inputs = inputs.repeat(
283
+ tuple([self.out_size] + torch.ones(len(inputs.size()), dtype=int).tolist())
284
+ ) # T B L C
285
+ inputs = inputs.permute(0, 1, 3, 2) # T B C L
286
+ spks = self.lif(inputs) # T B C L
287
+ return spks
288
+
289
+
290
+ class DeltaEncoder(nn.Module):
291
+ def __init__(self, output_size: int):
292
+ super().__init__()
293
+ self.norm = nn.BatchNorm2d(1)
294
+ self.enc = nn.Linear(1, output_size)
295
+ self.lif = neuron.LIFNode(
296
+ tau=tau,
297
+ step_mode="m",
298
+ detach_reset=detach_reset,
299
+ surrogate_function=surrogate.ATan(),
300
+ )
301
+
302
+ def forward(self, inputs: torch.Tensor):
303
+ # inputs: B, L, C
304
+ delta = torch.zeros_like(inputs)
305
+ delta[:, 1:] = inputs[:, 1:, :] - inputs[:, :-1, :]
306
+ delta = delta.unsqueeze(1).permute(0, 1, 3, 2) # B, 1, C, L
307
+ delta = self.norm(delta)
308
+ delta = delta.permute(0, 2, 3, 1) # B, C, L, 1
309
+ enc = self.enc(delta) # B, C, L, T
310
+ enc = enc.permute(3, 0, 1, 2) # T, B, C, L
311
+ spks = self.lif(enc)
312
+ return spks
313
+
314
+
315
+ class ConvEncoder(nn.Module):
316
+ def __init__(self, output_size: int, kernel_size: int = 3):
317
+ super().__init__()
318
+ self.encoder = nn.Sequential(
319
+ nn.Conv2d(
320
+ in_channels=1,
321
+ out_channels=output_size,
322
+ kernel_size=(1, kernel_size),
323
+ stride=1,
324
+ padding=(0, kernel_size // 2),
325
+ ),
326
+ nn.BatchNorm2d(output_size),
327
+ )
328
+ self.lif = neuron.LIFNode(
329
+ tau=tau,
330
+ step_mode="m",
331
+ detach_reset=detach_reset,
332
+ surrogate_function=surrogate.ATan(),
333
+ )
334
+
335
+ def forward(self, inputs: torch.Tensor):
336
+ # inputs: B, L, C
337
+ inputs = inputs.permute(0, 2, 1).unsqueeze(1) # B, 1, C, L
338
+ enc = self.encoder(inputs) # B, T, C, L
339
+ enc = enc.permute(1, 0, 2, 3) # T, B, C, L
340
+ spks = self.lif(enc) # T, B, C, L
341
+ return spks
342
+
343
+
344
+
345
+ SpikeEncoder = {
346
+ "snntorch": {
347
+ "repeat": RepeatEncoder,
348
+ "conv": ConvEncoder,
349
+ "delta": DeltaEncoder,
350
+ },
351
+ "spikingjelly": {
352
+ "repeat": RepeatEncoder,
353
+ "conv": ConvEncoder,
354
+ "delta": DeltaEncoder,
355
+ },
356
+ }
357
+
358
+
359
+
360
+ class SpikeRNNCell(nn.Module):
361
+ def __init__(self, input_size: int, output_size: int):
362
+ super().__init__()
363
+ self.input_size = input_size
364
+ self.linear = nn.Linear(input_size, output_size)
365
+ self.lif = neuron.LIFNode(
366
+ tau=tau,
367
+ step_mode="m",
368
+ detach_reset=detach_reset,
369
+ surrogate_function=surrogate.ATan(),
370
+ )
371
+
372
+ def forward(self, x):
373
+ # T, B, L, C'
374
+ T, B, L, _ = x.shape
375
+ x = x.flatten(0, 1) # TB, L, C'
376
+ x = self.linear(x)
377
+ x = x.reshape(T, B, L, -1)
378
+ x = self.lif(x) # T, B, L, C'
379
+ return x
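Every spiking layer in this file ends in a leaky integrate-and-fire (LIF) node. The following scalar sketch shows the charge, fire and reset cycle. It assumes the input-decaying update `v += (x - (v - v_reset)) / tau` with a hard reset, which is believed to match the defaults of the LIF node used above but is not confirmed by this diff. The function name `lif_run` is hypothetical.

```python
def lif_run(inputs, tau=2.0, v_threshold=1.0, v_reset=0.0):
    """Simulate one LIF neuron over a sequence of input currents."""
    v = v_reset
    spikes = []
    for x in inputs:
        v = v + (x - (v - v_reset)) / tau  # leaky charging towards the input
        fired = v >= v_threshold
        spikes.append(1 if fired else 0)
        if fired:
            v = v_reset                     # hard reset after a spike
    return spikes


if __name__ == "__main__":
    print(lif_run([1.5, 1.5, 1.5, 1.5]))  # strong input: periodic spikes
    print(lif_run([0.5, 0.5, 0.5, 0.5]))  # weak input: stays below threshold
```

With `tau = 2.0` a constant input of 1.5 crosses the threshold every second step, while a constant input of 0.5 converges to 0.5 and never fires.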
380
+
381
+
382
+ class SpikeRNN_CPG(nn.Module):
383
+
384
+ def __init__(
385
+ self,
386
+ args,
387
+ hidden_size: int,
388
+ layers: int = 1,
389
+ num_steps: int = 4,
390
+ input_size: Optional[int] = None,
391
+ max_length: Optional[int] = 5000,
392
+ weight_file: Optional[Path] = None,
393
+ encoder_type: Optional[str] = "conv",
394
+ num_pe_neuron: int = 40,
395
+ pe_type: str = "neuron",
396
+ pe_mode: str = "concat", # "add" or "concat"
397
+ neuron_pe_scale: float = 10000.0, # "100" or "1000" or "10000"
398
+ ):
399
+ super().__init__()
400
+ self._snn_backend = "spikingjelly"
401
+ self.hidden_size = args.hidden_size
402
+ self.num_steps = args.T
403
+ self.input_size = args.feature_size
404
+ self.pre_length = args.pre_length
405
+ self.layers = args.blocks
406
+ self.pe_type = pe_type
407
+ self.pe_mode = pe_mode
408
+ self.num_pe_neuron = num_pe_neuron
409
+ self.neuron_pe_scale = neuron_pe_scale
410
+ self.temporal_encoder = SpikeEncoder[self._snn_backend][encoder_type](self.num_steps)
411
+ self.args = args
412
+
413
+ self.pe = PositionEmbedding(
414
+ pe_type=pe_type,
415
+ pe_mode=pe_mode,
416
+ neuron_pe_scale=neuron_pe_scale,
417
+ input_size=self.input_size,
418
+ max_len=max_length,
419
+ num_pe_neuron=self.num_pe_neuron,
420
+ dropout=0.1,
421
+ num_steps=self.num_steps,
422
+ )
423
+
424
+         # Concatenated positional encodings ("neuron" or "random") widen the
+         # feature axis by num_pe_neuron; additive ones leave it unchanged.
+         concat_pe = self.pe_mode == "concat" and self.pe_type in ("neuron", "random")
+         pe_width = self.num_pe_neuron if concat_pe else 0
+         self.dim = hidden_size + pe_width
+         # Use self.input_size (args.feature_size): it is what self.pe was built
+         # with, and the input_size argument may be left as None.
+         self.encoder = nn.Linear(self.input_size + pe_width, self.dim)
433
+
434
+ self.init_lif = neuron.LIFNode(
435
+ tau=tau,
436
+ step_mode="m",
437
+ detach_reset=detach_reset,
438
+ surrogate_function=surrogate.ATan(),
439
+ v_threshold=1.0,
440
+ backend=backend,
441
+ )
442
+
443
+ self.net = nn.Sequential(
444
+ *[
445
+ SpikeRNNCell(input_size=self.dim, output_size=self.dim)
446
+ for i in range(layers)
447
+ ]
448
+ )
449
+
450
+ self.__output_size = self.dim
451
+ self.fc1 = nn.Linear(self.__output_size, args.feature_size)
452
+ self.fc2 = nn.Linear(args.seq_length, self.pre_length)
453
+ self.to("cuda:0" if torch.cuda.is_available() else "cpu") # fall back to CPU when CUDA is unavailable
454
+
455
+
456
+ def forward(
457
+ self,
458
+ inputs: torch.Tensor,
459
+ ):
460
+ functional.reset_net(self)
461
+ if self.args.normalize:
462
+ mean = inputs.mean(dim=1, keepdim=True).detach() # shape [B, 1, D]
463
+ inputs = inputs - mean
464
+
465
+ std = torch.sqrt(torch.var(inputs, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()
466
+ inputs = inputs / std
467
+
468
+
469
+ hiddens = self.temporal_encoder(inputs) # T, B, C, L
470
+ hiddens = hiddens.transpose(-2, -1) # T, B, L, C
471
+ T, B, L, _ = hiddens.size() # T, B, L, D
472
+ if self.pe_type != "none":
473
+ hiddens = self.pe(hiddens) # T B L C'
474
+ hiddens = self.encoder(hiddens.flatten(0, 1)).reshape(T, B, L, -1) # T B L D
475
+ hiddens = self.init_lif(hiddens)
476
+ hiddens = self.net(hiddens) # T, B, L, D
477
+ out = hiddens.mean(0) # B, L, D
478
+ preds = self.fc1(out) # B, L, C
479
+ preds = self.fc2(preds.permute(0, 2, 1)) # B, C, L
480
+ preds = preds.permute(0, 2, 1).contiguous()
481
+
482
+ if self.args.normalize:
483
+ preds = preds * std + mean # denormalize
484
+
485
+ aux = {'gate_l0': torch.tensor(0.0, device=preds.device)} # placeholder
486
+
487
+ return preds, aux
488
+
489
+
model/SpikeTCN_CPG.py ADDED
@@ -0,0 +1,596 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn.utils import weight_norm
6
+ import snntorch as snn
7
+ from snntorch import surrogate
8
+ from snntorch import utils
9
+ import copy
10
+ import math
11
+
12
+ def generate_ones_and_minus_ones_matrix(rows, cols):
13
+ random_matrix = torch.randint(0, 2, (rows, cols))
14
+ binary_matrix = torch.where(
15
+ random_matrix == 0,
16
+ -1 * torch.ones_like(random_matrix),
17
+ torch.ones_like(random_matrix),
18
+ )
19
+ return binary_matrix.float()
20
+
21
+
22
+ class RandomPE(nn.Module):
23
+ def __init__(
24
+ self,
25
+ d_model,
26
+ pe_mode="concat",
27
+ num_pe_neuron=10,
28
+ neuron_pe_scale=1000.0,
29
+ dropout=0.1,
30
+ num_steps=4,
31
+ ):
32
+ super().__init__()
33
+ self.max_len = 5000 # different from windows
34
+ self.pe_mode = pe_mode
35
+ self.neuron_pe_scale = neuron_pe_scale
36
+ self.dropout = nn.Dropout(p=dropout)
37
+ if self.pe_mode == "concat":
38
+ self.num_pe_neuron = copy.deepcopy(num_pe_neuron)
39
+ elif self.pe_mode == "add":
40
+ self.num_pe_neuron = copy.deepcopy(d_model)
41
+ pe = generate_ones_and_minus_ones_matrix(
42
+ self.max_len, self.num_pe_neuron
43
+ ) # MaxL, Neur
44
+ pe = pe.unsqueeze(0).transpose(0, 1) # MaxL, 1, Neur
45
+ print("pe.shape: ", pe.shape)
46
+ self.register_buffer("pe", pe)
47
+
48
+ def forward(self, x):
49
+ # T, B, L, D
50
+ T, B, L, _ = x.shape
51
+ x = x.permute(1, 0, 2, 3) # B, T, L, D
52
+ x = x.flatten(1, 2) # B, TL, D
53
+ if self.pe_mode == "concat":
54
+ # tmp: TL, 1, Neur -> TL, B, Neur -> B, TL, Neur
55
+ tmp = self.pe[: x.size(-2), :].repeat(1, B, 1).transpose(0, 1)
56
+ x = torch.concat([x, tmp], dim=-1)
57
+ # print(x.shape) # B, TL, D'
58
+ elif self.pe_mode == "add":
59
+ # [B, TL, D] + [1, TL, Neur]
60
+ x = x + self.pe[: x.size(-2), :].transpose(0, 1)
61
+ # print(x.shape) # B, TL, D
62
+ x = x.transpose(0, 1) # TL, B D
63
+ x = x.reshape(T, L, B, -1) # T, L, B, D
64
+ x = x.permute(0, 2, 1, 3) # T, B, L, D
65
+ return self.dropout(x)
66
+
67
+
68
+ class NeuronPE(nn.Module):
69
+ def __init__(
70
+ self,
71
+ d_model,
72
+ pe_mode="concat",
73
+ num_pe_neuron=10,
74
+ neuron_pe_scale=10000.0,
75
+ dropout=0.1,
76
+ num_steps=4,
77
+ ):
78
+ super().__init__()
79
+ self.max_len = 50000 # different from windows
80
+ self.pe_mode = pe_mode
81
+ self.neuron_pe_scale = neuron_pe_scale
82
+ self.dropout = nn.Dropout(p=dropout)
83
+ if self.pe_mode == "concat":
84
+ self.num_pe_neuron = copy.deepcopy(num_pe_neuron)
85
+ elif self.pe_mode == "add":
86
+ self.num_pe_neuron = copy.deepcopy(d_model)
87
+ pe = torch.zeros(self.max_len, self.num_pe_neuron) # MaxL, Neur
88
+ position = torch.arange(0, self.max_len, dtype=torch.float).unsqueeze(
89
+ 1
90
+ ) # MaxL, 1
91
+ div_term = torch.exp(
92
+ torch.arange(0, self.num_pe_neuron, 2).float()
93
+ * (-math.log(neuron_pe_scale) / self.num_pe_neuron)
94
+ )
95
+ div_term_single = torch.exp(
96
+ torch.arange(0, self.num_pe_neuron - 1, 2).float()
97
+ * (-math.log(neuron_pe_scale) / self.num_pe_neuron)
98
+ )
99
+ pe[:, 0::2] = torch.heaviside(
100
+ torch.sin(position * div_term) - 0.8, torch.tensor([1.0])
101
+ )
102
+ pe[:, 1::2] = torch.heaviside(
103
+ torch.cos(position * div_term_single) - 0.8, torch.tensor([1.0])
104
+ )
105
+ pe = pe.unsqueeze(0).transpose(0, 1) # MaxL, 1, Neur
106
+ print("pe.shape: ", pe.shape)
107
+ self.register_buffer("pe", pe)
108
+
109
+ def forward(self, x):
110
+ # T, B, L, D
111
+ T, B, L, _ = x.shape
112
+ x = x.permute(1, 0, 2, 3) # B, T, L, D
113
+ x = x.flatten(1, 2) # B, TL, D
114
+ if self.pe_mode == "concat":
115
+ # tmp: TL, 1, Neur -> TL, B, Neur -> B, TL, Neur
116
+ tmp = self.pe[: x.size(-2), :].repeat(1, B, 1).transpose(0, 1)
117
+ x = torch.concat([x, tmp], dim=-1)
118
+ # print(x.shape) # B, TL, D'
119
+ elif self.pe_mode == "add":
120
+ # [B, TL, D] + [1, TL, Neur]
121
+ # print(self.pe[:x.size(-2), :].shape)
122
+ x = x + self.pe[: x.size(-2), :].transpose(0, 1)
123
+ # print(x.shape) # B, TL, D
124
+ x = x.transpose(0, 1) # TL, B D
125
+ x = x.reshape(T, L, B, -1) # T, L, B, D
126
+ x = x.permute(0, 2, 1, 3) # T, B, L, D
127
+ return self.dropout(x)
128
+
129
+
130
+ class StaticPE(nn.Module):
131
+     r"""Inject some information about the relative or absolute position of the tokens
+     in the sequence. The positional encodings have the same dimension as
+     the embeddings, so that the two can be summed. Here, we use sine and cosine
+     functions of different frequencies.
+
+     .. math::
+         \text{PosEncoder}(pos, 2i) = \sin(pos / 10000^{2i / d_{model}})
+         \text{PosEncoder}(pos, 2i+1) = \cos(pos / 10000^{2i / d_{model}})
+
+     where pos is the word position and i is the embedding index.
+     """
139
+
140
+ def __init__(self, d_model, dropout=0.1, max_len=5000):
141
+ super().__init__()
142
+ self.dropout = nn.Dropout(p=dropout)
143
+ pe = torch.zeros(max_len, d_model) # MaxL, D
144
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) # MaxL, 1
145
+ div_term = torch.exp(
146
+ torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
147
+ )
148
+ div_term_single = torch.exp(
149
+ torch.arange(0, d_model - 1, 2).float() * (-math.log(10000.0) / d_model)
150
+ )
151
+ pe[:, 0::2] = torch.sin(position * div_term)
152
+ pe[:, 1::2] = torch.cos(position * div_term_single)
153
+ pe = pe.unsqueeze(0).transpose(0, 1)
154
+ self.register_buffer("pe", pe)
155
+
156
+ def forward(self, x):
157
+ # x: L, TB, D
158
+ x = x + self.pe[: x.size(0), :]
159
+ x = self.dropout(x)
160
+ return x
161
+
162
+
163
+ class ConvPE(nn.Module):
+     def __init__(self, d_model, dropout=0.1, max_len=5000, num_steps=4):
+         # This file imports only snntorch at module level, so the spikingjelly
+         # neuron and surrogate used below are imported locally.
+         from spikingjelly.activation_based import neuron, surrogate as sj_surrogate
+
+         super().__init__()
+         self.T = num_steps
+         self.rpe_conv = nn.Conv1d(
+             d_model, d_model, kernel_size=3, stride=1, padding=1, bias=False
+         )
+         self.rpe_bn = nn.BatchNorm1d(d_model)
+         self.rpe_lif = neuron.LIFNode(
+             step_mode="m",
+             detach_reset=True,
+             surrogate_function=sj_surrogate.ATan(),
+             v_threshold=1.0,
+         )
+         self.dropout = nn.Dropout(p=dropout)
+
+     def forward(self, x):
+         # x: L, TB, D
+         L, TB, D = x.shape
+         x_feat = x.permute(1, 2, 0)  # TB, D, L
+         x_feat = self.rpe_conv(x_feat)  # TB, D, L
+         x_feat = (
+             self.rpe_bn(x_feat).reshape(self.T, int(TB / self.T), D, L).contiguous()
+         )  # T, B, D, L
+         x_feat = self.rpe_lif(x_feat)
+         x_feat = x_feat.flatten(0, 1)  # TB, D, L
+         x_feat = self.dropout(x_feat)  # TB, D, L
+         x_feat = x_feat.permute(2, 0, 1)  # L, TB, D
+         x = x + x_feat
+         return x
194
+
195
+
196
+ class PositionEmbedding(nn.Module):
197
+ def __init__(
198
+ self,
199
+ input_size: int,
200
+ pe_type: str,
201
+ max_len: int = 5000,
202
+ pe_mode: str = "add",
203
+ num_pe_neuron: int = 10,
204
+ neuron_pe_scale: float = 1000.0,
205
+ dropout=0.1,
206
+ num_steps=4,
207
+ ):
208
+ super().__init__()
209
+ self.emb_type = pe_type
210
+ if pe_type in ["learn", "none"]:
211
+ self.emb = nn.Embedding(max_len, input_size)
212
+ elif pe_type == "conv":
213
+ self.emb = ConvPE(
214
+ d_model=input_size,
215
+ max_len=max_len,
216
+ dropout=dropout,
217
+ num_steps=num_steps,
218
+ )
219
+ elif pe_type == "static":
220
+ self.emb = StaticPE(d_model=input_size, max_len=max_len, dropout=dropout)
221
+ elif pe_type == "neuron":
222
+ self.emb = NeuronPE(
223
+ d_model=input_size,
224
+ pe_mode=pe_mode,
225
+ num_pe_neuron=num_pe_neuron,
226
+ neuron_pe_scale=neuron_pe_scale,
227
+ dropout=dropout,
228
+ num_steps=num_steps,
229
+ )
230
+ elif pe_type == "random":
231
+ self.emb = RandomPE(
232
+ d_model=input_size,
233
+ pe_mode=pe_mode,
234
+ num_pe_neuron=num_pe_neuron,
235
+ neuron_pe_scale=neuron_pe_scale,
236
+ dropout=dropout,
237
+ num_steps=num_steps,
238
+ )
239
+ else:
240
+ raise ValueError("Unknown embedding type: {}".format(pe_type))
241
+
242
+ def forward(self, x):
243
+ if self.emb_type == "learn":
244
+ # T, B, L, D = x.shape # x: T, B, L, D
245
+ # x = x.flatten(0, 1) # TB, L, D
246
+ tmp = torch.arange(
247
+ end=x.size()[1], device=x.device
248
+ ) # [0,1,2,...,L-1], shape: L
249
+ embedding = self.emb(tmp) # shape: L, D
250
+ embedding = embedding.repeat([x.size()[0], 1, 1]) # TB, L, D'
251
+ x = x + embedding
252
+ # x = x.reshape(T, B, L, -1)
253
+ elif self.emb_type in ["static", "conv"]:
254
+ T, B, L, _ = x.shape # x: T, B, L, D
255
+ x = x.flatten(0, 1) # TB, L, D
256
+ x = self.emb(x.transpose(0, 1)).transpose(0, 1) # x: TB, L, D'
257
+ x = x.reshape(T, B, L, -1)
258
+ elif self.emb_type in ["neuron", "random"]:
259
+ T, B, L, _ = x.shape # x: T, B, L, D
260
+ # T, B, L, D
261
+ x = self.emb(x)
262
+ x = x.reshape(T, B, L, -1)
263
+ return x # T, B, L, D'
264
+
265
+
266
+
267
+
268
+
269
+
270
+ class RepeatEncoder(nn.Module):
271
+ def __init__(self, output_size: int):
272
+ super().__init__()
273
+ self.out_size = output_size
274
+ self.lif = snn.Leaky(
275
+ beta=0.99, spike_grad=surrogate.atan(), init_hidden=True, output=False
276
+ )
277
+
278
+ def forward(self, inputs: torch.Tensor):
279
+ # inputs: batch, L, C
280
+ inputs = inputs.repeat(
281
+ tuple([self.out_size] + torch.ones(len(inputs.size()), dtype=int).tolist())
282
+ ) # out_size batch L C
283
+ inputs = inputs.permute(1, 0, 3, 2) # batch, out_size, C, L
284
+ spks = self.lif(inputs)
285
+ return spks
286
+
287
+
288
+ class ConvEncoder(nn.Module):
289
+ def __init__(self, output_size: int, kernel_size: int = 3):
290
+ super().__init__()
291
+ self.encoder = nn.Sequential(
292
+ nn.Conv2d(
293
+ in_channels=1,
294
+ out_channels=output_size,
295
+ kernel_size=(1, kernel_size),
296
+ stride=1,
297
+ padding=(0, kernel_size // 2),
298
+ ),
299
+ nn.BatchNorm2d(output_size),
300
+ )
301
+ self.lif = snn.Leaky(
302
+ beta=0.99,
303
+ spike_grad=surrogate.atan(alpha=2.0),
304
+ init_hidden=True,
305
+ output=False,
306
+ )
307
+
308
+ def forward(self, inputs: torch.Tensor):
309
+ # inputs: batch, L, C
310
+ inputs = inputs.permute(0, 2, 1).unsqueeze(1) # batch, 1, C, L
311
+ enc = self.encoder(inputs) # batch, output_size, C, L
312
+ spks = self.lif(enc)
313
+ return spks
314
+
315
+
316
+ class DeltaEncoder(nn.Module):
317
+ def __init__(self, output_size: int):
318
+ super().__init__()
319
+ self.norm = nn.BatchNorm2d(1)
320
+ self.enc = nn.Linear(1, output_size)
321
+ self.lif = snn.Leaky(
322
+ beta=0.99, spike_grad=surrogate.atan(), init_hidden=True, output=False
323
+ )
324
+
325
+ def forward(self, inputs: torch.Tensor):
326
+ # inputs: batch, L, C
327
+ delta = torch.zeros_like(inputs)
328
+ delta[:, 1:] = inputs[:, 1:, :] - inputs[:, :-1, :]
329
+ delta = delta.unsqueeze(1).permute(0, 1, 3, 2) # batch, 1, C, L
330
+ delta = self.norm(delta)
331
+ delta = delta.permute(0, 2, 3, 1) # batch, C, L, 1
332
+ enc = self.enc(delta) # batch, C, L, output_size
333
+ enc = enc.permute(0, 3, 1, 2) # batch, output_size, C, L
334
+ spks = self.lif(enc)
335
+ return spks
336
+
337
+
338
+
339
+ class Chomp1d(nn.Module):
340
+ def __init__(self, chomp_size):
341
+ super().__init__()
342
+ self.chomp_size = chomp_size
343
+
344
+ def forward(self, x):
345
+ return x[:, :, : x.size(-1) - self.chomp_size].contiguous() # explicit end index: ": -0" would return an empty tensor
346
+
347
+
348
+ class Chomp2d(nn.Module):
349
+ def __init__(self, chomp_size):
350
+ super().__init__()
351
+ self.chomp_size = chomp_size
352
+
353
+ def forward(self, x):
354
+ return x[:, :, :, : x.size(-1) - self.chomp_size].contiguous() # explicit end index: ": -0" would return an empty tensor
355
+
356
+
357
+
358
+ SpikeEncoder = {
359
+ "snntorch": {
360
+ "repeat": RepeatEncoder,
361
+ "conv": ConvEncoder,
362
+ "delta": DeltaEncoder,
363
+ },
364
+ "spikingjelly": {
365
+ "repeat": RepeatEncoder,
366
+ "conv": ConvEncoder,
367
+ "delta": DeltaEncoder,
368
+ },
369
+ }
370
+
371
+
372
+
373
+ class SpikeTemporalBlock2D(nn.Module):
374
+ def __init__(
375
+ self,
376
+ n_inputs,
377
+ n_outputs,
378
+ kernel_size,
379
+ stride,
380
+ dilation,
381
+ padding,
382
+ num_steps=4,
383
+ ):
384
+ super().__init__()
385
+ self.num_steps = num_steps
386
+ self.conv1 = weight_norm(
387
+ nn.Conv2d(
388
+ n_inputs,
389
+ n_outputs,
390
+ (1, kernel_size),
391
+ stride=stride,
392
+ padding=(0, padding),
393
+ dilation=(1, dilation),
394
+ )
395
+ )
396
+ self.bn1 = nn.BatchNorm2d(n_outputs)
397
+ self.chomp1 = Chomp2d(padding)
398
+ self.lif1 = snn.Leaky(
399
+ beta=0.99,
400
+ spike_grad=surrogate.atan(alpha=2.0),
401
+ init_hidden=True,
402
+ threshold=1.0,
403
+ )
404
+
405
+ self.conv2 = weight_norm(
406
+ nn.Conv2d(
407
+ n_outputs,
408
+ n_outputs,
409
+ (1, kernel_size),
410
+ stride=stride,
411
+ padding=(0, padding),
412
+ dilation=(1, dilation),
413
+ )
414
+ )
415
+ self.bn2 = nn.BatchNorm2d(n_outputs)
416
+ self.chomp2 = Chomp2d(padding)
417
+ self.lif2 = snn.Leaky(
418
+ beta=0.99,
419
+ spike_grad=surrogate.atan(alpha=2.0),
420
+ init_hidden=True,
421
+ threshold=1.0,
422
+ )
423
+
424
+ self.downsample = (
425
+ nn.Conv2d(n_inputs, n_outputs, (1, 1)) if n_inputs != n_outputs else None
426
+ )
427
+ self.lif = snn.Leaky(
428
+ beta=0.99,
429
+ spike_grad=surrogate.atan(alpha=2.0),
430
+ init_hidden=True,
431
+ threshold=1.0,
432
+ )
433
+
434
+ def init_weights(self):
435
+ self.conv1.weight.data.normal_(0, 0.01)
436
+ self.conv2.weight.data.normal_(0, 0.01)
437
+ if self.downsample is not None:
438
+ self.downsample.weight.data.normal_(0, 0.01)
439
+
440
+ def forward(self, x):
441
+ out1 = self.chomp1(self.bn1(self.conv1(x)))
442
+ spk_rec1 = []
443
+ for _ in range(self.num_steps):
444
+ spk = self.lif1(out1)
445
+ spk_rec1.append(spk)
446
+ spks1 = torch.stack(spk_rec1, dim=-1) # spks1: B, H, C, L, T
447
+ spks1 = spks1.mean(-1) # spks1: B, H, C, L
448
+
449
+ out2 = self.chomp2(self.bn2(self.conv2(spks1)))
450
+ spk_rec2 = []
451
+ for _ in range(self.num_steps):
452
+ spk = self.lif2(out2)
453
+ spk_rec2.append(spk)
454
+ spks2 = torch.stack(spk_rec2, dim=-1) # spks2: B, H, C, L, T
455
+ spks2 = spks2.mean(-1) # spks2: B, H, C, L
456
+
457
+ if torch.isnan(spks2).any() or torch.isinf(spks2).any():
458
+ print("NaN or Inf detected in SpikeTemporalBlock2D output")
459
+
460
+ if self.downsample is None:
461
+ res = x
462
+ else:
463
+ res = self.downsample(x)
464
+ spk_rec3 = []
465
+ for _ in range(self.num_steps):
466
+ spk = self.lif(spks2 + res)
467
+ spk_rec3.append(spk)
468
+
469
+ res = torch.stack(spk_rec3, dim=-1) # res: B, H, C, L, T
470
+ res = res.mean(-1)
471
+
472
+ return res
473
+
474
+
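Review note: each stage of `SpikeTemporalBlock2D.forward` above presents the same pre-activation to a leaky integrate-and-fire neuron `num_steps` times and averages the emitted spikes, which turns the spike train into a firing rate in [0, 1]. The sketch below mimics that loop for a single scalar in plain Python, so it runs without torch or snntorch. It is a simplified model under stated assumptions (subtractive reset applied one step after a spike, strict `>` threshold comparison). The name `lif_firing_rate` is hypothetical, and this is not snnTorch's actual implementation.

```python
def lif_firing_rate(current, num_steps=4, beta=0.99, threshold=1.0):
    """Average spike output of one leaky integrate-and-fire neuron fed a constant input."""
    mem = 0.0  # membrane potential, starts at rest
    spikes = []
    for _ in range(num_steps):
        # Assumption: a spike in the previous step subtracts the threshold now.
        reset = threshold if mem > threshold else 0.0
        mem = beta * mem + current - reset  # leak, integrate, reset
        spikes.append(1.0 if mem > threshold else 0.0)
    return sum(spikes) / num_steps


if __name__ == "__main__":
    # Weak inputs never reach the threshold; strong inputs fire on every step.
    for current in (0.2, 0.6, 1.5):
        print(current, lif_firing_rate(current))
```

With only `num_steps` presentations the rate can take just `num_steps + 1` distinct values, which is one reason the block's output resolution depends on `args.T`.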
475
+ class SpikeTCN_CPG(nn.Module):
476
+
477
+
478
+ def __init__(
479
+ self,
480
+ args,
481
+ num_levels: int=3,
482
+ channel: int=16,
483
+ dilation: int=2,
484
+ stride: int = 1,
485
+ num_steps: int = 16,
486
+ kernel_size: int = 2,
487
+ dropout: float = 0.2,
488
+ max_length: int = 100,
489
+ input_size: Optional[int] = None,
490
+ hidden_size: int = 128,
491
+ encoder_type: Optional[str] = "conv",
492
+ num_pe_neuron: int = 40,
493
+ pe_type: str = "neuron",
494
+ pe_mode: str = "concat", # "add" or "concat"
495
+ neuron_pe_scale: float = 10000.0, # "100" or "1000" or "10000"
496
+ ):
497
+ """
498
+ Args:
499
+ channel: The number of convolutional channels in each layer.
500
+ kernel_size: The kernel size of convolutional layers.
501
+ dropout: Dropout rate.
502
+ """
503
+ super().__init__()
504
+ self.pe_type = pe_type
505
+ self._snn_backend = "snntorch"
506
+ self.pe_mode = pe_mode
507
+ self.num_pe_neuron = num_pe_neuron
508
+ self.hidden_size = args.hidden_size
509
+ self.num_steps = args.T
510
+ self.input_size = args.feature_size
511
+ self.pre_length = args.pre_length
512
+ self.num_levels = args.blocks
513
+ self.pe_type = pe_type
514
+ self.pe_mode = pe_mode
515
+ self.num_pe_neuron = num_pe_neuron
516
+ self.kernel_size = args.kernel_size
517
+
518
+ self.encoder = SpikeEncoder[self._snn_backend][encoder_type](self.hidden_size)
519
+ self.args = args
520
+
521
+
522
+ self.pe = PositionEmbedding(
523
+ pe_type=pe_type,
524
+ pe_mode=pe_mode,
525
+ neuron_pe_scale=neuron_pe_scale,
526
+ input_size=self.input_size,
527
+ max_len=max_length,
528
+ num_pe_neuron=self.num_pe_neuron,
529
+ dropout=0.1,
530
+ num_steps=self.num_steps,
531
+ )
532
+ layers = []
533
+ num_channels = [channel] * self.num_levels
534
+ num_channels.append(1)
535
+ for i in range(self.num_levels + 1):
536
+ dilation_size = dilation**i
537
+ in_channels = self.hidden_size if i == 0 else num_channels[i - 1]
538
+ out_channels = num_channels[i]
539
+ layers += [
540
+ SpikeTemporalBlock2D(
541
+ in_channels,
542
+ out_channels,
543
+ self.kernel_size,
544
+ stride=stride,
545
+ dilation=dilation_size,
546
+ padding=(self.kernel_size - 1) * dilation_size,
547
+ num_steps=self.num_steps,
548
+ )
549
+ ]
550
+
551
+ self.network = nn.Sequential(*layers)
552
+ if (self.pe_type == "neuron" and self.pe_mode == "concat") or (
553
+ self.pe_type == "random" and self.pe_mode == "concat"
554
+ ):
555
+ self.__output_size = args.feature_size + num_pe_neuron
556
+ else:
557
+ self.__output_size = args.seq_length
558
+
559
+ self.fc1 = nn.Linear(self.__output_size, args.feature_size)
560
+ self.fc2 = nn.Linear(args.seq_length, self.pre_length)
561
+ self.to("cuda:0" if torch.cuda.is_available() else "cpu") # fall back to CPU when no GPU is available
562
+
563
+ def forward(self, inputs: torch.Tensor):
564
+ utils.reset(self.encoder)
565
+ for layer in self.network:
566
+ utils.reset(layer)
567
+
568
+ if self.args.normalize:
569
+ mean = inputs.mean(dim=1, keepdim=True).detach() # shape [B, 1, D]
570
+ inputs = inputs - mean
571
+
572
+ std = torch.sqrt(torch.var(inputs, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()
573
+ inputs = inputs / std
574
+
575
+ inputs = self.encoder(inputs) # B, H, C, L
576
+ if self.pe_type != "none":
577
+ # B, H, C, L -> H B L C' -> B H C' L
578
+ inputs = self.pe(inputs.permute(1, 0, 3, 2)).permute(1, 0, 3, 2)
579
+ spks = self.network(inputs)
580
+ spks = spks.squeeze(1) # B, C', L
581
+
582
+ preds = self.fc1(spks.permute(0, 2, 1)) # B, L, C
583
+ preds = self.fc2(preds.permute(0, 2, 1)) # B, C, O (O = pre_length)
584
+ #.squeeze(-1) # B, O, C'
585
+ preds = preds.permute(0, 2, 1).contiguous() # B, O, C
586
+ if self.args.normalize:
587
+ preds = preds * std + mean # denormalize
588
+ aux = {'gate_l0': torch.tensor(0.0, device=preds.device)} # placeholder
589
+
590
+ return preds, aux
591
+
592
+ @property
593
+ def output_size(self):
594
+ return self.__output_size
595
+
596
+
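Review note: `SpikeTCN_CPG` builds each `SpikeTemporalBlock2D` with `padding=(kernel_size - 1) * dilation_size`, and `Chomp2d` then drops the same number of trailing steps. `nn.Conv2d` pads both ends of the time axis, so trimming the tail restores the input length and leaves every output step depending only on the current and earlier inputs. A minimal plain-Python sketch of that pad-then-chomp idea for a single channel follows. `causal_dilated_conv1d` is a hypothetical helper written for illustration, not code from this repository.

```python
def causal_dilated_conv1d(x, weights, dilation=1):
    """Pad both ends by (k - 1) * dilation, convolve, then chomp the trailing pad."""
    k = len(weights)
    pad = (k - 1) * dilation
    padded = [0.0] * pad + list(x) + [0.0] * pad
    out_len = len(padded) - pad  # length of a stride-1 dilated convolution
    full = [
        sum(weights[j] * padded[t + j * dilation] for j in range(k))
        for t in range(out_len)
    ]
    # Explicit end index, so pad == 0 keeps the whole sequence.
    return full[: len(full) - pad]


if __name__ == "__main__":
    signal = [1.0, 2.0, 3.0, 4.0, 5.0]
    print(causal_dilated_conv1d(signal, [1.0, 1.0], dilation=1))
    print(causal_dilated_conv1d(signal, [1.0, 1.0], dilation=2))
```

Because the last kernel tap lines up with the current step, changing a future input leaves earlier outputs untouched. Stacking such layers with `dilation ** i`, as the constructor does, widens the receptive field geometrically.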
model/Spikformer_CPG.py ADDED
@@ -0,0 +1,487 @@
1
+ from typing import Optional
2
+
3
+ from pathlib import Path
4
+ import torch
5
+ from torch import nn
6
+ from spikingjelly.activation_based import surrogate, neuron, functional
7
+
8
+ import math
9
+ from dataclasses import dataclass
10
+ import warnings
11
+
12
+
13
+
14
+ tau = 2.0 # beta = 1 - 1/tau
15
+ backend = "torch"
16
+ detach_reset = True
17
+
18
+
19
+
20
+ @dataclass
21
+ class CPG(nn.Module):
22
+ num_neurons: int = 40
23
+ w_max: float = 10000.0
24
+ l_max: int = 5000
25
+
26
+ def __post_init__(self):
27
+ self._cpg = torch.zeros(self.l_max, self.num_neurons)
28
+ position = torch.arange(0, self.l_max, dtype=torch.float).unsqueeze(
29
+ 1
30
+ ) # MaxL, 1
31
+ div_term = torch.exp(
32
+ torch.arange(0, self.num_neurons, 2).float()
33
+ * (-math.log(self.w_max) / self.num_neurons)
34
+ )
35
+ div_term_single = torch.exp(
36
+ torch.arange(0, self.num_neurons - 1, 2).float()
37
+ * (-math.log(self.w_max) / self.num_neurons)
38
+ )
39
+ self._cpg[:, 0::2] = torch.heaviside(
40
+ torch.sin(position * div_term) - 0.8, torch.tensor([1.0])
41
+ )
42
+ self._cpg[:, 1::2] = torch.heaviside(
43
+ torch.cos(position * div_term_single) - 0.8, torch.tensor([1.0])
44
+ )
45
+
46
+ @property
47
+ def cpg(self):
48
+ return self._cpg
49
+
50
+
51
+ class CPGLinear(nn.Module):
52
+ def __init__(
53
+ self, input_size: int, output_size: int, cpg: CPG = CPG(), dropout: float = 0.1
54
+ ):
55
+ super().__init__()
56
+ self.cpg = nn.Parameter(cpg.cpg, requires_grad=False)
57
+ self.inp_linear = nn.Linear(input_size, output_size)
58
+ self.cpg_linear = nn.Linear(cpg.num_neurons, output_size)
59
+ self.dropout = nn.Dropout(dropout)
60
+
61
+ def forward(self, x: torch.Tensor):
62
+ # B TL D
63
+ cpg = self.cpg[: x.size(-2)]
64
+ x = self.dropout(x)
65
+ return self.inp_linear(x) + self.cpg_linear(cpg)
66
+
67
+
68
+
69
+
70
+ class RepeatEncoder(nn.Module):
71
+ def __init__(self, output_size: int):
72
+ super().__init__()
73
+ self.out_size = output_size
74
+ self.lif = neuron.LIFNode(
75
+ tau=tau,
76
+ step_mode="m",
77
+ detach_reset=detach_reset,
78
+ surrogate_function=surrogate.ATan(),
79
+ )
80
+
81
+ def forward(self, inputs: torch.Tensor):
82
+ # inputs: B, L, C
83
+ inputs = inputs.repeat(
84
+ tuple([self.out_size] + torch.ones(len(inputs.size()), dtype=int).tolist())
85
+ ) # T B L C
86
+ inputs = inputs.permute(0, 1, 3, 2) # T B C L
87
+ spks = self.lif(inputs) # T B C L
88
+ return spks
89
+
90
+
91
+ class DeltaEncoder(nn.Module):
92
+ def __init__(self, output_size: int):
93
+ super().__init__()
94
+ self.norm = nn.BatchNorm2d(1)
95
+ self.enc = nn.Linear(1, output_size)
96
+ self.lif = neuron.LIFNode(
97
+ tau=tau,
98
+ step_mode="m",
99
+ detach_reset=detach_reset,
100
+ surrogate_function=surrogate.ATan(),
101
+ )
102
+
103
+ def forward(self, inputs: torch.Tensor):
104
+ # inputs: B, L, C
105
+ delta = torch.zeros_like(inputs)
106
+ delta[:, 1:] = inputs[:, 1:, :] - inputs[:, :-1, :]
107
+ delta = delta.unsqueeze(1).permute(0, 1, 3, 2) # B, 1, C, L
108
+ delta = self.norm(delta)
109
+ delta = delta.permute(0, 2, 3, 1) # B, C, L, 1
110
+ enc = self.enc(delta) # B, C, L, T
111
+ enc = enc.permute(3, 0, 1, 2) # T, B, C, L
112
+ spks = self.lif(enc)
113
+ return spks
114
+
115
+
116
+ class ConvEncoder(nn.Module):
117
+ def __init__(self, output_size: int, kernel_size: int = 3):
118
+ super().__init__()
119
+ self.encoder = nn.Sequential(
120
+ nn.Conv2d(
121
+ in_channels=1,
122
+ out_channels=output_size,
123
+ kernel_size=(1, kernel_size),
124
+ stride=1,
125
+ padding=(0, kernel_size // 2),
126
+ ),
127
+ nn.BatchNorm2d(output_size),
128
+ )
129
+ self.lif = neuron.LIFNode(
130
+ tau=tau,
131
+ step_mode="m",
132
+ detach_reset=detach_reset,
133
+ surrogate_function=surrogate.ATan(),
134
+ )
135
+
136
+ def forward(self, inputs: torch.Tensor):
137
+ # inputs: B, L, C
138
+ inputs = inputs.permute(0, 2, 1).unsqueeze(1) # B, 1, C, L
139
+ enc = self.encoder(inputs) # B, T, C, L
140
+ enc = enc.permute(1, 0, 2, 3) # T, B, C, L
141
+ spks = self.lif(enc) # T, B, C, L
142
+ return spks
143
+
144
+
145
+
146
+ SpikeEncoder = {
147
+ "snntorch": {
148
+ "repeat": RepeatEncoder,
149
+ "conv": ConvEncoder,
150
+ "delta": DeltaEncoder,
151
+ },
152
+ "spikingjelly": {
153
+ "repeat": RepeatEncoder,
154
+ "conv": ConvEncoder,
155
+ "delta": DeltaEncoder,
156
+ },
157
+ }
158
+
159
+
160
+
161
+
162
+
163
+ class SSA(nn.Module):
164
+ def __init__(
165
+ self, length, tau, common_thr, dim, heads=8, qkv_bias=False, qk_scale=0.25
166
+ ):
167
+ super().__init__()
168
+ assert dim % heads == 0, f"dim {dim} should be divisible by num_heads {heads}."
169
+
170
+ self.dim = dim
171
+ self.heads = heads
172
+ self.qk_scale = qk_scale
173
+
174
+ self.q_m = nn.Linear(dim, dim)
175
+ self.q_bn = nn.BatchNorm1d(dim)
176
+ self.q_lif = neuron.LIFNode(
177
+ tau=tau,
178
+ step_mode="m",
179
+ detach_reset=detach_reset,
180
+ surrogate_function=surrogate.ATan(),
181
+ v_threshold=common_thr,
182
+ backend=backend,
183
+ )
184
+
185
+ self.k_m = nn.Linear(dim, dim)
186
+ self.k_bn = nn.BatchNorm1d(dim)
187
+ self.k_lif = neuron.LIFNode(
188
+ tau=tau,
189
+ step_mode="m",
190
+ detach_reset=detach_reset,
191
+ surrogate_function=surrogate.ATan(),
192
+ v_threshold=common_thr,
193
+ backend=backend,
194
+ )
195
+
196
+ self.v_m = nn.Linear(dim, dim)
197
+ self.v_bn = nn.BatchNorm1d(dim)
198
+ self.v_lif = neuron.LIFNode(
199
+ tau=tau,
200
+ step_mode="m",
201
+ detach_reset=detach_reset,
202
+ surrogate_function=surrogate.ATan(),
203
+ v_threshold=common_thr,
204
+ backend=backend,
205
+ )
206
+
207
+ self.attn_lif = neuron.LIFNode(
208
+ tau=tau,
209
+ step_mode="m",
210
+ detach_reset=detach_reset,
211
+ surrogate_function=surrogate.ATan(),
212
+ v_threshold=common_thr / 2,
213
+ backend=backend,
214
+ )
215
+
216
+ self.last_m = nn.Linear(dim, dim)
217
+ self.last_bn = nn.BatchNorm1d(dim)
218
+ self.last_lif = neuron.LIFNode(
219
+ tau=tau,
220
+ step_mode="m",
221
+ detach_reset=detach_reset,
222
+ surrogate_function=surrogate.ATan(),
223
+ v_threshold=common_thr,
224
+ backend=backend,
225
+ )
226
+
227
+ def forward(self, x):
228
+ T, B, L, D = x.shape
229
+ x_for_qkv = x.flatten(0, 1) # TB L D
230
+ q_m_out = self.q_m(x_for_qkv) # TB L D
231
+ q_m_out = (
232
+ self.q_bn(q_m_out.transpose(-1, -2))
233
+ .transpose(-1, -2)
234
+ .reshape(T, B, L, D)
235
+ .contiguous()
236
+ )
237
+ q_m_out = self.q_lif(q_m_out)
238
+ q = (
239
+ q_m_out.reshape(T, B, L, self.heads, D // self.heads)
240
+ .permute(0, 1, 3, 2, 4)
241
+ .contiguous()
242
+ )
243
+
244
+ k_m_out = self.k_m(x_for_qkv)
245
+ k_m_out = (
246
+ self.k_bn(k_m_out.transpose(-1, -2))
247
+ .transpose(-1, -2)
248
+ .reshape(T, B, L, D)
249
+ .contiguous()
250
+ )
251
+ k_m_out = self.k_lif(k_m_out)
252
+ k = (
253
+ k_m_out.reshape(T, B, L, self.heads, D // self.heads)
254
+ .permute(0, 1, 3, 2, 4)
255
+ .contiguous()
256
+ )
257
+
258
+ v_m_out = self.v_m(x_for_qkv)
259
+ v_m_out = (
260
+ self.v_bn(v_m_out.transpose(-1, -2))
261
+ .transpose(-1, -2)
262
+ .reshape(T, B, L, D)
263
+ .contiguous()
264
+ )
265
+ v_m_out = self.v_lif(v_m_out)
266
+ v = (
267
+ v_m_out.reshape(T, B, L, self.heads, D // self.heads)
268
+ .permute(0, 1, 3, 2, 4)
269
+ .contiguous()
270
+ )
271
+
272
+ attn = (q @ k.transpose(-2, -1)) * self.qk_scale
273
+ x = attn @ v # x_shape: T * B * heads * L * D//heads
274
+
275
+ x = x.transpose(2, 3).reshape(T, B, L, D).contiguous()
276
+ x = self.attn_lif(x)
277
+
278
+ x = x.flatten(0, 1)
279
+ x = self.last_m(x)
280
+ x = self.last_bn(x.transpose(-1, -2)).transpose(-1, -2)
281
+ x = self.last_lif(x.reshape(T, B, L, D).contiguous())
282
+ return x
283
+
284
+
285
+ class MLP(nn.Module):
286
+ def __init__(
287
+ self,
288
+ length,
289
+ tau,
290
+ common_thr,
291
+ in_features,
292
+ hidden_features=None,
293
+ out_features=None,
294
+ ):
295
+ super().__init__()
296
+ out_features = out_features or in_features
297
+ self.in_features = in_features
298
+ self.hidden_features = hidden_features
299
+ self.out_features = out_features
300
+
301
+ self.fc1 = CPGLinear(in_features, hidden_features)
302
+ self.bn1 = nn.BatchNorm1d(hidden_features)
303
+ self.lif1 = neuron.LIFNode(
304
+ tau=tau,
305
+ step_mode="m",
306
+ detach_reset=detach_reset,
307
+ surrogate_function=surrogate.ATan(),
308
+ v_threshold=common_thr,
309
+ backend=backend,
310
+ )
311
+
312
+ self.fc2 = CPGLinear(hidden_features, out_features)
313
+ self.bn2 = nn.BatchNorm1d(out_features)
314
+ self.lif2 = neuron.LIFNode(
315
+ tau=tau,
316
+ step_mode="m",
317
+ detach_reset=detach_reset,
318
+ surrogate_function=surrogate.ATan(),
319
+ v_threshold=common_thr,
320
+ backend=backend,
321
+ )
322
+
323
+ def forward(self, x):
324
+ T, B, L, D = x.shape
325
+ x = x.transpose(0, 1).flatten(1, 2) # B TL D
326
+ x = self.fc1(x) # B TL H
327
+ x = (
328
+ self.bn1(x.transpose(-1, -2))
329
+ .transpose(-1, -2)
330
+ .reshape(B, T, L, self.hidden_features)
331
+ .contiguous()
332
+ ) # B T L H
333
+ x = self.lif1(x.transpose(0, 1)).transpose(0, 1) # B T L H
334
+ x = x.flatten(1, 2) # B TL H
335
+ x = self.fc2(x) # B TL D
336
+ x = (
337
+ self.bn2(x.transpose(-1, -2))
338
+ .transpose(-1, -2)
339
+ .reshape(B, T, L, D)
340
+ .contiguous()
341
+ ) # B T L D
342
+ x = self.lif2(x.transpose(0, 1)) # T B L D
343
+ return x
344
+
345
+
346
+ class Block(nn.Module):
347
+ def __init__(
348
+ self,
349
+ length,
350
+ tau,
351
+ common_thr,
352
+ dim,
353
+ d_ff,
354
+ heads=8,
355
+ qkv_bias=False,
356
+ qk_scale=0.125,
357
+ ):
358
+ super().__init__()
359
+ self.attn = SSA(
360
+ length=length,
361
+ tau=tau,
362
+ common_thr=common_thr,
363
+ dim=dim,
364
+ heads=heads,
365
+ qkv_bias=qkv_bias,
366
+ qk_scale=qk_scale,
367
+ )
368
+ self.mlp = MLP(
369
+ length=length,
370
+ tau=tau,
371
+ common_thr=common_thr,
372
+ in_features=dim,
373
+ hidden_features=d_ff,
374
+ )
375
+
376
+ def forward(self, x):
377
+ # T B L D
378
+ x = x + self.attn(x)
379
+ x = x + self.mlp(x)
380
+ return x
381
+
382
+
383
+ class Spikformer_CPG(nn.Module):
384
+ def __init__(
385
+ self,
386
+ args,
387
+ dim: int=256,
388
+ d_ff: Optional[int] = None,
389
+ num_pe_neuron: int = 40,
390
+ pe_type: str = "neuron",
391
+ pe_mode: str = "concat", # "add" or "concat"
392
+ neuron_pe_scale: float = 10000.0, # "100" or "1000" or "10000"
393
+ depths: int = 2,
394
+ common_thr: float = 1.0,
395
+ max_length: int = 5000,
396
+ num_steps: int = 4,
397
+ heads: int = 8,
398
+ qkv_bias: bool = False,
399
+ qk_scale: float = 0.125,
400
+ input_size: Optional[int] = None,
401
+ weight_file: Optional[Path] = None,
402
+ ):
403
+ super().__init__()
404
+ self.dim = dim # keep the attribute in sync with the `dim` argument (default 256)
405
+ self.d_ff = 1024
406
+ self.T = args.T
407
+ self.depths = args.blocks
408
+ self.pe_type = pe_type
409
+ self.pe_mode = pe_mode
410
+ self.num_pe_neuron = num_pe_neuron
411
+ self.input_size = args.feature_size
412
+ self.pre_length = args.pre_length
413
+ self.args = args
414
+
415
+
416
+ self._snn_backend = "spikingjelly"
417
+
418
+ self.temporal_encoder = SpikeEncoder[self._snn_backend]["conv"](num_steps)
419
+ self.encoder = CPGLinear(self.input_size, dim, CPG(num_neurons=num_pe_neuron))
420
+
421
+ self.init_lif = neuron.LIFNode(
422
+ tau=tau,
423
+ step_mode="m",
424
+ detach_reset=detach_reset,
425
+ surrogate_function=surrogate.ATan(),
426
+ v_threshold=common_thr,
427
+ backend=backend,
428
+ )
429
+
430
+ self.blocks = nn.ModuleList(
431
+ [
432
+ Block(
433
+ length=max_length,
434
+ tau=tau,
435
+ common_thr=common_thr,
436
+ dim=dim,
437
+ d_ff=self.d_ff,
438
+ heads=heads,
439
+ qkv_bias=qkv_bias,
440
+ qk_scale=qk_scale,
441
+ )
442
+ for _ in range(depths)
443
+ ]
444
+ )
445
+
446
+ self.apply(self._init_weights)
447
+
448
+ self.fc = nn.Linear(args.seq_length*dim, args.pre_length*args.feature_size)
449
+
450
+ def _init_weights(self, m):
451
+ if isinstance(m, nn.Linear):
452
+ nn.init.normal_(m.weight, std=0.02)
453
+ if isinstance(m, nn.Linear) and m.bias is not None:
454
+ nn.init.constant_(m.bias, 0.0)
455
+ elif isinstance(m, nn.LayerNorm):
456
+ nn.init.constant_(m.weight, 1.0)
457
+ nn.init.constant_(m.bias, 0.0)
458
+
459
+ def forward(self, x: torch.Tensor):
460
+ functional.reset_net(self)
461
+
462
+ if self.args.normalize:
463
+
464
+ mean = x.mean(dim=1, keepdim=True).detach() # shape [B, 1, D]
465
+ x = x - mean
466
+
467
+ std = torch.sqrt(torch.var(x, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()
468
+ x = x / std
469
+
470
+
471
+ x = self.temporal_encoder(x) # B L C -> T B C L
472
+ T, B, _, L = x.shape
473
+ x = x.permute(1, 0, 3, 2) # B T L C
474
+ x = x.flatten(1, 2) # B TL C
475
+ x = self.encoder(x) # B TL D
476
+ x = x.reshape(B, T, L, -1).permute(1, 0, 2, 3) # T B L D
477
+ x = self.init_lif(x)
478
+
479
+ for blk in self.blocks:
480
+ x = blk(x) # T B L D
481
+ out = x.mean(0)
482
+ out = self.fc(out.flatten(-2, -1)).reshape(-1, self.pre_length, self.input_size) # B, L*D -> B, O*C -> B, O, C
483
+ if self.args.normalize:
484
+ out = out * std + mean # denormalization
485
+ aux = {'gate_l0': torch.tensor(0.0, device=out.device)} # placeholder
486
+ return out, aux # out: B, O, C
487
+
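Review note: the `CPG` class at the top of this file thresholds sinusoids of geometrically spaced frequencies at 0.8, so each position receives a binary (spike / no spike) code instead of a real-valued positional encoding. The sketch below rebuilds that table in plain Python so it can be inspected without torch. `cpg_spike_table` is a hypothetical name, and because it uses Python floats rather than float32 tensors, entries that fall extremely close to the threshold could in principle differ from the tensor version.

```python
import math


def cpg_spike_table(num_positions, num_neurons=40, w_max=10000.0, threshold=0.8):
    """Binary positional code: a neuron spikes when its sinusoid reaches the threshold."""
    table = []
    for pos in range(num_positions):
        row = []
        for n in range(num_neurons):
            # Columns 2i and 2i + 1 share the frequency exp(-2i * ln(w_max) / num_neurons).
            freq = math.exp((n - n % 2) * (-math.log(w_max) / num_neurons))
            wave = math.sin(pos * freq) if n % 2 == 0 else math.cos(pos * freq)
            # Same convention as heaviside(wave - threshold, 1.0): ties count as a spike.
            row.append(1.0 if wave - threshold >= 0 else 0.0)
        table.append(row)
    return table


if __name__ == "__main__":
    for pos, row in enumerate(cpg_spike_table(4, num_neurons=8)):
        print(pos, row)
```

At position 0 every sine column is silent and every cosine column fires, since sin(0) = 0 and cos(0) = 1. Later positions differ as each frequency crosses the threshold at its own pace.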
model/TS_Former.py ADDED
@@ -0,0 +1,1365 @@
1
+ from typing import Optional, Callable
2
+
3
+ from pathlib import Path
4
+ import torch
5
+ from torch import nn
6
+ from spikingjelly.activation_based import surrogate, neuron, functional
7
+
8
+ import math
9
+ import copy
10
+ from spikingjelly.activation_based import surrogate, neuron
11
+ from abc import abstractmethod
12
+ import snntorch as snn
13
+ from snntorch import utils
14
+ import warnings
15
+
16
+ surrogate.ATan = lambda alpha=2.0: SG.apply
17
+
18
+
19
+ def generate_ones_and_minus_ones_matrix(rows, cols):
20
+ random_matrix = torch.randint(0, 2, (rows, cols))
21
+ binary_matrix = torch.where(
22
+ random_matrix == 0,
23
+ -1 * torch.ones_like(random_matrix),
24
+ torch.ones_like(random_matrix),
25
+ )
26
+ return binary_matrix.float()
27
+
28
+
29
+ class RandomPE(nn.Module):
30
+ def __init__(
31
+ self,
32
+ d_model,
33
+ pe_mode="concat",
34
+ num_pe_neuron=10,
35
+ neuron_pe_scale=1000.0,
36
+ dropout=0.1,
37
+ num_steps=4,
38
+ ):
39
+ super().__init__()
40
+ self.max_len = 5000 # upper bound on T * L; independent of the input window length
41
+ self.pe_mode = pe_mode
42
+ self.neuron_pe_scale = neuron_pe_scale
43
+ self.dropout = nn.Dropout(p=dropout)
44
+ if self.pe_mode == "concat":
45
+ self.num_pe_neuron = copy.deepcopy(num_pe_neuron)
46
+ elif self.pe_mode == "add":
47
+ self.num_pe_neuron = copy.deepcopy(d_model)
48
+ pe = generate_ones_and_minus_ones_matrix(
49
+ self.max_len, self.num_pe_neuron
50
+ ) # MaxL, Neur
51
+ pe = pe.unsqueeze(0).transpose(0, 1) # MaxL, 1, Neur
52
+ print("pe.shape: ", pe.shape)
53
+ self.register_buffer("pe", pe)
54
+
55
+ def forward(self, x):
56
+ # T, B, L, D
57
+ T, B, L, _ = x.shape
58
+ x = x.permute(1, 0, 2, 3) # B, T, L, D
59
+ x = x.flatten(1, 2) # B, TL, D
60
+ if self.pe_mode == "concat":
61
+ # tmp: TL, 1, Neur -> TL, B, Neur -> B, TL, Neur
62
+ tmp = self.pe[: x.size(-2), :].repeat(1, B, 1).transpose(0, 1)
63
+ x = torch.concat([x, tmp], dim=-1)
64
+ # print(x.shape) # B, TL, D'
65
+ elif self.pe_mode == "add":
66
+ # [B, TL, D] + [1, TL, Neur]
67
+ x = x + self.pe[: x.size(-2), :].transpose(0, 1)
68
+ # print(x.shape) # B, TL, D
69
+ x = x.transpose(0, 1) # TL, B D
70
+ x = x.reshape(T, L, B, -1) # T, L, B, D
71
+ x = x.permute(0, 2, 1, 3) # T, B, L, D
72
+ return self.dropout(x)
73
+
74
+
75
+ class NeuronPE(nn.Module):
76
+ def __init__(
77
+ self,
78
+ d_model,
79
+ pe_mode="concat",
80
+ num_pe_neuron=10,
81
+ neuron_pe_scale=10000.0,
82
+ dropout=0.1,
83
+ num_steps=4,
84
+ ):
85
+ super().__init__()
86
+ self.max_len = 50000 # upper bound on T * L; independent of the input window length
87
+ self.pe_mode = pe_mode
88
+ self.neuron_pe_scale = neuron_pe_scale
89
+ self.dropout = nn.Dropout(p=dropout)
90
+ if self.pe_mode == "concat":
91
+ self.num_pe_neuron = copy.deepcopy(num_pe_neuron)
92
+ elif self.pe_mode == "add":
93
+ self.num_pe_neuron = copy.deepcopy(d_model)
94
+ pe = torch.zeros(self.max_len, self.num_pe_neuron) # MaxL, Neur
95
+ position = torch.arange(0, self.max_len, dtype=torch.float).unsqueeze(
96
+ 1
97
+ ) # MaxL, 1
98
+ div_term = torch.exp(
99
+ torch.arange(0, self.num_pe_neuron, 2).float()
100
+ * (-math.log(neuron_pe_scale) / self.num_pe_neuron)
101
+ )
102
+ div_term_single = torch.exp(
103
+ torch.arange(0, self.num_pe_neuron - 1, 2).float()
104
+ * (-math.log(neuron_pe_scale) / self.num_pe_neuron)
105
+ )
106
+ pe[:, 0::2] = torch.heaviside(
107
+ torch.sin(position * div_term) - 0.8, torch.tensor([1.0])
108
+ )
109
+ pe[:, 1::2] = torch.heaviside(
110
+ torch.cos(position * div_term_single) - 0.8, torch.tensor([1.0])
111
+ )
112
+ pe = pe.unsqueeze(0).transpose(0, 1) # MaxL, 1, Neur
113
+ print("pe.shape: ", pe.shape)
114
+ self.register_buffer("pe", pe)
115
+
116
+ def forward(self, x):
117
+ # T, B, L, D
118
+ T, B, L, _ = x.shape
119
+ x = x.permute(1, 0, 2, 3) # B, T, L, D
120
+ x = x.flatten(1, 2) # B, TL, D
121
+ if self.pe_mode == "concat":
122
+ # tmp: TL, 1, Neur -> TL, B, Neur -> B, TL, Neur
123
+ tmp = self.pe[: x.size(-2), :].repeat(1, B, 1).transpose(0, 1)
124
+ x = torch.concat([x, tmp], dim=-1)
125
+ # print(x.shape) # B, TL, D'
126
+ elif self.pe_mode == "add":
127
+ # [B, TL, D] + [1, TL, Neur]
128
+ # print(self.pe[:x.size(-2), :].shape)
129
+ x = x + self.pe[: x.size(-2), :].transpose(0, 1)
130
+ # print(x.shape) # B, TL, D
131
+ x = x.transpose(0, 1) # TL, B D
132
+ x = x.reshape(T, L, B, -1) # T, L, B, D
133
+ x = x.permute(0, 2, 1, 3) # T, B, L, D
134
+ return self.dropout(x)
135
+
136
+
137
+ class StaticPE(nn.Module):
138
+ r"""Inject some information about the relative or absolute position of the tokens
139
+ in the sequence. The positional encodings have the same dimension as
140
+ the embeddings, so that the two can be summed. Here, we use sine and cosine
141
+ functions of different frequencies.
142
+ .. math::
143
+ \text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/d_model))
144
+ \text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/d_model))
145
+ \text{where pos is the word position and i is the embed idx}"""
146
+
147
+ def __init__(self, d_model, dropout=0.1, max_len=5000):
148
+ super().__init__()
149
+ self.dropout = nn.Dropout(p=dropout)
150
+ pe = torch.zeros(max_len, d_model) # MaxL, D
151
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) # MaxL, 1
152
+ div_term = torch.exp(
153
+ torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
154
+ )
155
+ div_term_single = torch.exp(
156
+ torch.arange(0, d_model - 1, 2).float() * (-math.log(10000.0) / d_model)
157
+ )
158
+ pe[:, 0::2] = torch.sin(position * div_term)
159
+ pe[:, 1::2] = torch.cos(position * div_term_single)
160
+ pe = pe.unsqueeze(0).transpose(0, 1)
161
+ self.register_buffer("pe", pe)
162
+
163
+ def forward(self, x):
164
+ # x: L, TB, D
165
+ x = x + self.pe[: x.size(0), :]
166
+ x = self.dropout(x)
167
+ return x
168
+
169
+
170
+ class ConvPE(nn.Module):
171
+ def __init__(self, d_model, dropout=0.1, max_len=5000, num_steps=4):
172
+
173
+ super().__init__()
174
+ self.T = num_steps
175
+ self.rpe_conv = nn.Conv1d(
176
+ d_model, d_model, kernel_size=3, stride=1, padding=1, bias=False
177
+ )
178
+ self.rpe_bn = nn.BatchNorm1d(d_model)
179
+ self.rpe_lif = neuron.LIFNode(
180
+ step_mode="m",
181
+ detach_reset=True,
182
+ surrogate_function=surrogate.ATan(),
183
+ v_threshold=1.0,
184
+ )
185
+ self.dropout = nn.Dropout(p=dropout)
186
+
187
+ def forward(self, x):
188
+ # x: L, TB, D
189
+ L, TB, D = x.shape
190
+ x_feat = x.permute(1, 2, 0) # TB, D, L
191
+ x_feat = self.rpe_conv(x_feat) # TB, D, L
192
+ x_feat = (
193
+ self.rpe_bn(x_feat).reshape(self.T, TB // self.T, D, L).contiguous()
194
+ ) # T, B, D, L
195
+ x_feat = self.rpe_lif(x_feat)
196
+ x_feat = x_feat.flatten(0, 1) # TB, D, L
197
+ x_feat = self.dropout(x_feat) # TB, D, L
198
+ x_feat = x_feat.permute(2, 0, 1) # L, TB, D
199
+ x = x + x_feat
200
+ return x
201
+
202
+
203
+ class PositionEmbedding(nn.Module):
204
+ def __init__(
205
+ self,
206
+ input_size: int,
207
+ pe_type: str,
208
+ max_len: int = 5000,
209
+ pe_mode: str = "add",
210
+ num_pe_neuron: int = 10,
211
+ neuron_pe_scale: float = 1000.0,
212
+ dropout=0.1,
213
+ num_steps=4,
214
+ ):
215
+ super().__init__()
216
+ self.emb_type = pe_type
217
+ if pe_type in ["learn", "none"]:
218
+ self.emb = nn.Embedding(max_len, input_size)
219
+ elif pe_type == "conv":
220
+ self.emb = ConvPE(
221
+ d_model=input_size,
222
+ max_len=max_len,
223
+ dropout=dropout,
224
+ num_steps=num_steps,
225
+ )
226
+ elif pe_type == "static":
227
+ self.emb = StaticPE(d_model=input_size, max_len=max_len, dropout=dropout)
228
+ elif pe_type == "neuron":
229
+ self.emb = NeuronPE(
230
+ d_model=input_size,
231
+ pe_mode=pe_mode,
232
+ num_pe_neuron=num_pe_neuron,
233
+ neuron_pe_scale=neuron_pe_scale,
234
+ dropout=dropout,
235
+ num_steps=num_steps,
236
+ )
237
+ elif pe_type == "random":
238
+ self.emb = RandomPE(
239
+ d_model=input_size,
240
+ pe_mode=pe_mode,
241
+ num_pe_neuron=num_pe_neuron,
242
+ neuron_pe_scale=neuron_pe_scale,
243
+ dropout=dropout,
244
+ num_steps=num_steps,
245
+ )
246
+ else:
247
+ raise ValueError("Unknown embedding type: {}".format(pe_type))
248
+
249
+ def forward(self, x):
250
+ if self.emb_type == "learn":
251
+ # T, B, L, D = x.shape # x: T, B, L, D
252
+ # x = x.flatten(0, 1) # TB, L, D
253
+ tmp = torch.arange(
254
+ end=x.size()[1], device=x.device
255
+ ) # [0,1,2,...,L-1], shape: L
256
+ embedding = self.emb(tmp) # shape: L, D
257
+ embedding = embedding.repeat([x.size()[0], 1, 1]) # TB, L, D'
258
+ x = x + embedding
259
+ # x = x.reshape(T, B, L, -1)
260
+ elif self.emb_type in ["static", "conv"]:
261
+ T, B, L, _ = x.shape # x: T, B, L, D
262
+ x = x.flatten(0, 1) # TB, L, D
263
+ x = self.emb(x.transpose(0, 1)).transpose(0, 1) # x: TB, L, D'
264
+ x = x.reshape(T, B, L, -1)
265
+ elif self.emb_type in ["neuron", "random"]:
266
+ T, B, L, _ = x.shape # x: T, B, L, D
267
+ # T, B, L, D
268
+ x = self.emb(x)
269
+ x = x.reshape(T, B, L, -1)
270
+ return x # T, B, L, D'
271
+
272
+
273
+ tau = 2.0 # beta = 1 - 1/tau
274
+ backend = "torch"
275
+ detach_reset = True
276
+
277
+
278
+ class RepeatEncoder(nn.Module):
279
+ def __init__(self, output_size: int):
280
+ super().__init__()
281
+ self.out_size = output_size
282
+ self.lif = neuron.LIFNode(
283
+ tau=tau,
284
+ step_mode="m",
285
+ detach_reset=detach_reset,
286
+ surrogate_function=surrogate.ATan(),
287
+ )
288
+
289
+ def forward(self, inputs: torch.Tensor):
290
+ # inputs: B, L, C
291
+ inputs = inputs.repeat(
292
+ tuple([self.out_size] + torch.ones(len(inputs.size()), dtype=int).tolist())
293
+ ) # T B L C
294
+ inputs = inputs.permute(0, 1, 3, 2) # T B C L
295
+ spks = self.lif(inputs) # T B C L
296
+ return spks
297
+
298
+
299
+ class DeltaEncoder(nn.Module):
300
+ def __init__(self, output_size: int):
301
+ super().__init__()
302
+ self.norm = nn.BatchNorm2d(1)
303
+ self.enc = nn.Linear(1, output_size)
304
+ self.lif = neuron.LIFNode(
305
+ tau=tau,
306
+ step_mode="m",
307
+ detach_reset=detach_reset,
308
+ surrogate_function=surrogate.ATan(),
309
+ )
310
+
311
+ def forward(self, inputs: torch.Tensor):
312
+ # inputs: B, L, C
313
+ delta = torch.zeros_like(inputs)
314
+ delta[:, 1:] = inputs[:, 1:, :] - inputs[:, :-1, :]
315
+ delta = delta.unsqueeze(1).permute(0, 1, 3, 2) # B, 1, C, L
316
+ delta = self.norm(delta)
317
+ delta = delta.permute(0, 2, 3, 1) # B, C, L, 1
318
+ enc = self.enc(delta) # B, C, L, T
319
+ enc = enc.permute(3, 0, 1, 2) # T, B, C, L
320
+ spks = self.lif(enc)
321
+ return spks
322
+
323
+
324
+ class ConvEncoder(nn.Module):
325
+ def __init__(self, output_size: int, kernel_size: int = 3):
326
+ super().__init__()
327
+ self.encoder = nn.Sequential(
328
+ nn.Conv2d(
329
+ in_channels=1,
330
+ out_channels=output_size,
331
+ kernel_size=(1, kernel_size),
332
+ stride=1,
333
+ padding=(0, kernel_size // 2),
334
+ ),
335
+ nn.BatchNorm2d(output_size),
336
+ )
337
+ self.lif = neuron.LIFNode(
338
+ tau=tau,
339
+ step_mode="m",
340
+ detach_reset=detach_reset,
341
+ surrogate_function=surrogate.ATan(),
342
+ )
343
+
344
+ def forward(self, inputs: torch.Tensor):
345
+ # inputs: B, L, C
346
+ inputs = inputs.permute(0, 2, 1).unsqueeze(1) # B, 1, C, L
347
+ enc = self.encoder(inputs) # B, T, C, L
348
+ enc = enc.permute(1, 0, 2, 3) # T, B, C, L
349
+ spks = self.lif(enc) # T, B, C, L
350
+ return spks
351
+
352
+
353
+
354
+
355
+ SpikeEncoder = {
356
+ "snntorch": {
357
+ "repeat": RepeatEncoder,
358
+ "conv": ConvEncoder,
359
+ "delta": DeltaEncoder,
360
+ },
361
+ "spikingjelly": {
362
+ "repeat": RepeatEncoder,
363
+ "conv": ConvEncoder,
364
+ "delta": DeltaEncoder,
365
+ },
366
+ }
367
+
368
+
369
+
370
+ class SSA(nn.Module):
371
+ def __init__(
372
+ self, length, tau, common_thr, dim, heads=8, qkv_bias=False, qk_scale=0.25
373
+ ):
374
+ super().__init__()
375
+ assert dim % heads == 0, f"dim {dim} should be divisible by num_heads {heads}."
376
+
377
+ self.dim = dim
378
+ self.heads = heads
379
+ self.qk_scale = qk_scale
380
+
381
+ self.q_m = nn.Linear(dim, dim)
382
+ self.q_bn = nn.BatchNorm1d(dim)
383
+
384
+
385
+ self.q_tslif = TSLIFNode(
386
+ surrogate_function=SG.apply,
387
+ )
388
+
389
+ self.k_m = nn.Linear(dim, dim)
390
+ self.k_bn = nn.BatchNorm1d(dim)
391
+
392
+
393
+ self.k_tslif = TSLIFNode(
394
+ surrogate_function=SG.apply,
395
+ )
396
+
397
+ self.v_m = nn.Linear(dim, dim)
398
+ self.v_bn = nn.BatchNorm1d(dim)
399
+
400
+ self.v_tslif = TSLIFNode(
401
+ surrogate_function=SG.apply,
402
+ )
403
+
404
+
405
+ self.attn_tslif = TSLIFNode(
406
+ v_threshold=0.7,
407
+ surrogate_function=SG.apply
408
+ )
409
+
410
+ self.last_m = nn.Linear(dim, dim)
411
+ self.last_bn = nn.BatchNorm1d(dim)
412
+
413
+ self.last_tslif = TSLIFNode(
414
+ surrogate_function=SG.apply
415
+ )
416
+
417
+ def forward(self, x):
418
+ utils.reset(self.q_tslif)
419
+ utils.reset(self.k_tslif)
420
+ utils.reset(self.v_tslif)
421
+ utils.reset(self.attn_tslif)
422
+ utils.reset(self.last_tslif)
423
+ # x = x.transpose(0, 1)
424
+
425
+ # T, B, L, D = x.shape
426
+ B, T, L, D = x.shape
427
+ x_for_qkv = x.flatten(0, 1) # BT L D
428
+ q_m_out = self.q_m(x_for_qkv) # BT L D
429
+
430
+ q_m_out = (
431
+ self.q_bn(q_m_out.transpose(-1, -2))
432
+ .transpose(-1, -2)
433
+ .reshape(B, T, L, D)
434
+ .contiguous()
435
+ )
436
+ q_m_out = self.q_tslif(q_m_out)
437
+
438
+ q = (
439
+ q_m_out.reshape(B, T, L, self.heads, D // self.heads)
440
+ .permute(0, 1, 3, 2, 4)
441
+ .contiguous()
442
+ )
443
+
444
+ k_m_out = self.k_m(x_for_qkv)
445
+
446
+ k_m_out = (
447
+ self.k_bn(k_m_out.transpose(-1, -2))
448
+ .transpose(-1, -2)
449
+ .reshape(B, T, L, D)
450
+ .contiguous()
451
+ )
452
+
453
+ k_m_out = self.k_tslif(k_m_out)
454
+ k = (
455
+ k_m_out.reshape(B, T, L, self.heads, D // self.heads)
456
+ .permute(0, 1, 3, 2, 4)
457
+ .contiguous()
458
+ )
459
+
460
+ v_m_out = self.v_m(x_for_qkv)
461
+ v_m_out = (
462
+ self.v_bn(v_m_out.transpose(-1, -2))
463
+ .transpose(-1, -2)
464
+ .reshape(B, T, L, D)
465
+ .contiguous()
466
+ )
467
+
468
+ v_m_out = self.v_tslif(v_m_out)
469
+
470
+
471
+ v = (
472
+ v_m_out.reshape(B, T, L, self.heads, D // self.heads)
473
+ .permute(0, 1, 3, 2, 4)
474
+ .contiguous()
475
+ )
476
+
477
+ attn = (q @ k.transpose(-2, -1)) * self.qk_scale
478
+ x = attn @ v # x shape: B, T, heads, L, D // heads
479
+
480
+ x = x.transpose(2, 3).reshape(B, T, L, D).contiguous()
481
+ x = self.attn_tslif(x)
482
+ x = x.flatten(0, 1)
483
+ x = self.last_m(x)
484
+ x = self.last_bn(x.transpose(-1, -2)).transpose(-1, -2)
485
+ x = self.last_tslif(x.reshape(B, T, L, D).contiguous())
486
+ return x
487
+
488
+
489
+ class MLP(nn.Module):
490
+ def __init__(
491
+ self,
492
+ length,
493
+ tau,
494
+ common_thr,
495
+ in_features,
496
+ hidden_features=None,
497
+ out_features=None,
498
+ ):
499
+ super().__init__()
500
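+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features # avoid nn.Linear(in_features, None) when left at the default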
501
+ self.in_features = in_features
502
+ self.hidden_features = hidden_features
503
+ self.out_features = out_features
504
+
505
+ self.fc1 = nn.Linear(in_features, hidden_features)
506
+ self.bn1 = nn.BatchNorm1d(hidden_features)
507
+
508
+ self.mlp_tclif1 = TCLIFNode2(
509
+ surrogate_function=SG.apply,
510
+ )
511
+
512
+ self.fc2 = nn.Linear(hidden_features, out_features)
513
+ self.bn2 = nn.BatchNorm1d(out_features)
514
+
515
+
516
+
517
+ self.mlp_tclif2 = TCLIFNode(
518
+ surrogate_function=SG.apply,
519
+ )
520
+
521
+ def forward(self, x):
522
+ utils.reset(self.mlp_tclif1)
523
+ utils.reset(self.mlp_tclif2)
524
+ # T, B, L, D = x.shape
525
+ B, T, L, D = x.shape
526
+ x = x.flatten(0, 1) # BT L D
527
+ x = self.fc1(x) # BT L H
528
+ x = (
529
+ self.bn1(x.transpose(-1, -2))
530
+ .transpose(-1, -2)
531
+ .reshape(B, T, L, self.hidden_features)
532
+ .contiguous()
533
+ )
534
+ x = self.mlp_tclif1(x)
535
+ x = x.flatten(0, 1) # BT L H
536
+ x = self.fc2(x) # BT L D
537
+ x = (
538
+ self.bn2(x.transpose(-1, -2))
539
+ .transpose(-1, -2)
540
+ .reshape(B, T, L, D)
541
+ .contiguous()
542
+ )
543
+ x = self.mlp_tclif2(x)
544
+ return x
545
+
546
+
547
+ class Block(nn.Module):
548
+ def __init__(
549
+ self,
550
+ length,
551
+ tau,
552
+ common_thr,
553
+ dim,
554
+ d_ff,
555
+ heads=8,
556
+ qkv_bias=False,
557
+ qk_scale=0.125,
558
+ ):
559
+ super().__init__()
560
+ self.attn = SSA(
561
+ length=length,
562
+ tau=tau,
563
+ common_thr=common_thr,
564
+ dim=dim,
565
+ heads=heads,
566
+ qkv_bias=qkv_bias,
567
+ qk_scale=qk_scale,
568
+ )
569
+ self.mlp = MLP(
570
+ length=length,
571
+ tau=tau,
572
+ common_thr=common_thr,
573
+ in_features=dim,
574
+ hidden_features=d_ff,
575
+ )
576
+
577
+ def forward(self, x):
578
+ x = x + self.attn(x)
579
+ x = x + self.mlp(x)
580
+ return x
581
+
582
+
583
+
584
+
585
+
586
+ @torch.jit.script
587
+ def heaviside(x: torch.Tensor):
588
+ return (x >= 0).to(x)
589
+
590
+ @torch.jit.script
591
+ def atan_backward(grad_output: torch.Tensor, x: torch.Tensor, alpha: float):
592
+
593
+ return alpha / 2 / (1 + (math.pi / 2 * alpha * x).pow_(2)) * grad_output, None
594
+ #
595
+
596
+ class SG(torch.autograd.Function):
597
+ @staticmethod
598
+ def forward(ctx, x, alpha=2.0):
599
+ if x.requires_grad:
600
+ #ctx.save_for_backward(x.detach().clone()) # additional instead
601
+ ctx.save_for_backward(x)
602
+ ctx.alpha = alpha
603
+ return heaviside(x)
604
+
605
+ @staticmethod
606
+ def backward(ctx, grad_output):
607
+ return atan_backward(grad_output, ctx.saved_tensors[0], ctx.alpha)
608
+
609
+
610
+ class MemoryModule(nn.Module):
611
+ def __init__(self):
612
+ """
613
+ * :ref:`API in English <MemoryModule.__init__-en>`
614
+
615
+ .. _MemoryModule.__init__-cn:
616
+
617
+ ``MemoryModule`` is the base class of all stateful (memory) modules in SpikingJelly.
618
+
619
+ * :ref:`API in Chinese <MemoryModule.__init__-cn>`
620
+
621
+ .. _MemoryModule.__init__-en:
622
+
623
+ ``MemoryModule`` is the base class of all stateful modules in SpikingJelly.
624
+
625
+ """
626
+ super().__init__()
627
+ self._memories = {}
628
+ self._memories_rv = {}
629
+
630
+ def register_memory(self, name: str, value):
631
+ """
632
+ * :ref:`API in English <MemoryModule.register_memory-en>`
633
+
634
+ .. _MemoryModule.register_memory-cn:
635
+
636
+ :param name: the variable's name
637
+ :type name: str
638
+ :param value: the variable's value
639
+ :type value: any
640
+
641
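+ Store the variable in the dictionary that holds stateful variables (e.g., the membrane potential of a spiking neuron). The reset value of this variable is set to ``value``. After each call to ``self.reset()``,
642
+ ``self.name`` is reset to ``value``.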
643
+
644
+ * :ref:`API in Chinese <MemoryModule.register_memory-cn>`
645
+
646
+ .. _MemoryModule.register_memory-en:
647
+
648
+ :param name: variable's name
649
+ :type name: str
650
+ :param value: variable's value
651
+ :type value: any
652
+
653
+ Register the variable to memory dict, which saves stateful variables (e.g., the membrane potential of a
654
+ spiking neuron). The reset value of this variable will be ``value``. ``self.name`` will be set to ``value`` after
655
+ each calling of ``self.reset()``.
656
+
657
+ """
658
+ assert not hasattr(self, name), f'{name} has been set as a member variable!'
659
+ self._memories[name] = value
660
+ self.set_reset_value(name, value)
661
+
662
+ def reset(self):
663
+ """
664
+ * :ref:`API in English <MemoryModule.reset-en>`
665
+
666
+ .. _MemoryModule.reset-cn:
667
+
668
+ Reset all stateful variables to their default values.
669
+
670
+ * :ref:`API in Chinese <MemoryModule.reset-cn>`
671
+
672
+ .. _MemoryModule.reset-en:
673
+
674
+ Reset all stateful variables to their default values.
675
+ """
676
+ for key in self._memories.keys():
677
+ self._memories[key] = copy.deepcopy(self._memories_rv[key])
678
+
679
+ def set_reset_value(self, name: str, value):
680
+ self._memories_rv[name] = copy.deepcopy(value)
681
+
682
+ def __getattr__(self, name: str):
683
+ if '_memories' in self.__dict__:
684
+ memories = self.__dict__['_memories']
685
+ if name in memories:
686
+ return memories[name]
687
+
688
+ return super().__getattr__(name)
689
+
690
+ def __setattr__(self, name: str, value) -> None:
691
+ _memories = self.__dict__.get('_memories')
692
+ if _memories is not None and name in _memories:
693
+ _memories[name] = value
694
+ else:
695
+ super().__setattr__(name, value)
696
+
697
+ def __delattr__(self, name):
698
+ if name in self._memories:
699
+ del self._memories[name]
700
+ del self._memories_rv[name]
701
+ else:
702
+ return super().__delattr__(name)
703
+
704
+ def __dir__(self):
705
+ module_attrs = dir(self.__class__)
706
+ attrs = list(self.__dict__.keys())
707
+ parameters = list(self._parameters.keys())
708
+ modules = list(self._modules.keys())
709
+ buffers = list(self._buffers.keys())
710
+ memories = list(self._memories.keys())
711
+ keys = module_attrs + attrs + parameters + modules + buffers + memories
712
+
713
+ # Eliminate attrs that are not legal Python variable names
714
+ keys = [key for key in keys if not key[0].isdigit()]
715
+
716
+ return sorted(keys)
717
+
718
+ def memories(self):
719
+ """
720
+ * :ref:`API in English <MemoryModule.memories-en>`
721
+
722
+ .. _MemoryModule.memories-cn:
723
+
724
+ :return: an iterator over all stateful variables
725
+ :rtype: Iterator
726
+
727
+ * :ref:`API in Chinese <MemoryModule.memories-cn>`
728
+
729
+ .. _MemoryModule.memories-en:
730
+
731
+ :return: an iterator over all stateful variables
732
+ :rtype: Iterator
733
+ """
734
+ for name, value in self._memories.items():
735
+ yield value
736
+
737
+ def named_memories(self):
738
+ """
739
+ * :ref:`API in English <MemoryModule.named_memories-en>`
740
+
741
+ .. _MemoryModule.named_memories-cn:
742
+
743
+ :return: an iterator over all stateful variables and their names
744
+ :rtype: Iterator
745
+
746
+ * :ref:`API in Chinese <MemoryModule.named_memories-cn>`
747
+
748
+ .. _MemoryModule.named_memories-en:
749
+
750
+ :return: an iterator over all stateful variables and their names
751
+ :rtype: Iterator
752
+ """
753
+
754
+ for name, value in self._memories.items():
755
+ yield name, value
756
+
757
+ def detach(self):
758
+ """
759
+ * :ref:`API in English <MemoryModule.detach-en>`
760
+
761
+ .. _MemoryModule.detach-cn:
762
+
763
+ Detach all stateful variables from the computation graph.
764
+
765
+ .. tip::
766
+
767
+ This function can be used to implement TBPTT (Truncated Back Propagation Through Time).
768
+
769
+
770
+ * :ref:`API in Chinese <MemoryModule.detach-cn>`
771
+
772
+ .. _MemoryModule.detach-en:
773
+
774
+ Detach all stateful variables.
775
+
776
+ .. admonition:: Tip
777
+ :class: tip
778
+
779
+ We can use this function to implement TBPTT (Truncated Back Propagation Through Time).
780
+
781
+ """
782
+
783
+ for key in self._memories.keys():
784
+ if isinstance(self._memories[key], torch.Tensor):
785
+ self._memories[key].detach_()
786
+
787
+ def _apply(self, fn):
788
+ for key, value in self._memories.items():
789
+ if isinstance(value, torch.Tensor):
790
+ self._memories[key] = fn(value)
791
+ # do not apply on default values
792
+ # for key, value in self._memories_rv.items():
793
+ # if isinstance(value, torch.Tensor):
794
+ # self._memories_rv[key] = fn(value)
795
+ return super()._apply(fn)
796
+
797
+ def _replicate_for_data_parallel(self):
798
+ replica = super()._replicate_for_data_parallel()
799
+ replica._memories = self._memories.copy()
800
+ return replica
801
+
802
+
803
+ class StepModule:
804
+ def supported_step_mode(self):
805
+ """
806
+ * :ref:`API in English <StepModule.supported_step_mode-en>`
807
+ .. _StepModule.supported_step_mode-cn:
808
+ :return: a tuple containing the supported step modes
809
+ :rtype: tuple[str]
810
+ Return the step modes supported by this module.
811
+ * :ref:`API in Chinese <StepModule.supported_step_mode-cn>`
812
+ .. _StepModule.supported_step_mode-en:
813
+ :return: a tuple that contains the supported step modes
814
+ :rtype: tuple[str]
815
+ """
816
+ return ('s', 'm')
817
+
818
+ @property
819
+ def step_mode(self):
820
+ """
821
+ * :ref:`API in English <StepModule.step_mode-en>`
822
+ .. _StepModule.step_mode-cn:
823
+ :return: the step mode currently used by this module
824
+ :rtype: str
825
+ * :ref:`API in Chinese <StepModule.step_mode-cn>`
826
+ .. _StepModule.step_mode-en:
827
+ :return: the current step mode of this module
828
+ :rtype: str
829
+ """
830
+ return self._step_mode
831
+
832
+ @step_mode.setter
833
+ def step_mode(self, value: str):
834
+ """
835
+ * :ref:`API in English <StepModule.step_mode-setter-en>`
836
+ .. _StepModule.step_mode-setter-cn:
837
+ :param value: the step mode
838
+ :type value: str
839
+ Set the step mode of this module to ``value``
840
+ * :ref:`API in Chinese <StepModule.step_mode-setter-cn>`
841
+ .. _StepModule.step_mode-setter-en:
842
+ :param value: the step mode
843
+ :type value: str
844
+ Set the step mode of this module to be ``value``
845
+ """
846
+ if value not in self.supported_step_mode():
847
+ raise ValueError(f'step_mode can only be {self.supported_step_mode()}, but got "{value}"!')
848
+ self._step_mode = value
849
+
850
+
851
+
852
+ class BaseNode(MemoryModule):
853
+ def __init__(self,
854
+ v_threshold: float = 1.,
855
+ v_reset: float = 0.,
856
+ surrogate_function: Callable = None,
857
+ detach_reset: bool = False,
858
+ step_mode='s', backend='torch',
859
+ store_v_seq: bool = True):
860
+
861
+ assert isinstance(v_reset, float) or v_reset is None
862
+ assert isinstance(v_threshold, float)
863
+ assert isinstance(detach_reset, bool)
864
+ super().__init__()
865
+
866
+ if v_reset is None:
867
+ self.register_memory('v', 0.)
868
+ self.register_memory('v_s', 0.)
869
+ else:
870
+ self.register_memory('v', v_reset)
871
+
872
+ self.v_threshold = v_threshold
873
+
874
+ self.v_reset = v_reset
875
+ self.detach_reset = detach_reset
876
+ self.surrogate_function = surrogate_function
877
+
878
+ self.step_mode = step_mode
879
+ self.backend = backend
880
+
881
+ self.store_v_seq = store_v_seq
882
+
883
+
884
+ self.alpha_s = torch.nn.Parameter(torch.tensor(0.5, dtype=torch.float))
885
+ self.alpha_l = torch.nn.Parameter(torch.tensor(0.5, dtype=torch.float))
886
+ #self.alpha_s = torch.nn.Parameter(torch.randn([1, 128], dtype=torch.float))
887
+ #self.alpha_l = torch.nn.Parameter(torch.randn([1, 128], dtype=torch.float))
888
+
889
+ @property
890
+ def store_v_seq(self):
891
+ return self._store_v_seq
892
+
893
+ @store_v_seq.setter
894
+ def store_v_seq(self, value: bool):
895
+ self._store_v_seq = value
896
+ if value:
897
+ if not hasattr(self, 'v_seq'):
898
+ self.register_memory('v_seq', None)
899
+
900
+ @staticmethod
901
+ @torch.jit.script
902
+ def jit_hard_reset(v: torch.Tensor, spike: torch.Tensor, v_reset: float):
903
+ v = (1. - spike) * v + spike * v_reset
904
+
905
+ return v
906
+
907
+ @staticmethod
908
+ @torch.jit.script
909
+ def jit_soft_reset(v: torch.Tensor, spike: torch.Tensor, v_threshold: float):
910
+ v = v - spike * v_threshold
911
+ return v
912
+
913
+
914
+ @abstractmethod
915
+ def neuronal_charge(self, x: torch.Tensor):
916
+ raise NotImplementedError
917
+
918
+ def neuronal_fire(self):
919
+ return self.surrogate_function(self.v - self.v_threshold, 2.0)
920
+
921
+ def sl_neuronal_fire(self):
922
+ s_s = self.surrogate_function(self.v - self.v_threshold, 2.0)
923
+ s_l = self.surrogate_function(self.v_s - self.v_threshold, 2.0)
924
+ return s_s, s_l
925
+
926
+ def extra_repr(self):
927
+ return f'v_threshold={self.v_threshold}, v_reset={self.v_reset}, detach_reset={self.detach_reset}, step_mode={self.step_mode}, backend={self.backend}'
928
+
929
+ def single_step_forward(self, x: torch.Tensor):
930
+ self.v_float_to_tensor(x)
931
+ self.neuronal_charge(x)
932
+ s_s, s_l = self.sl_neuronal_fire()
933
+ spike = self.alpha_s * s_s + self.alpha_l * s_l
934
+ self.neuronal_reset(s_s, s_l)
935
+
936
+ return spike
937
+
938
+ def multi_step_forward(self, x_seq: torch.Tensor):
939
+
940
+ #### time series ###
941
+ T = x_seq.shape[-1]
942
+ y_seq = []
943
+ if self.store_v_seq:
944
+ v_seq = []
945
+ for t in range(T):
946
+ y = self.single_step_forward(x_seq[:, t])
947
+ y_seq.append(y)
948
+ if self.store_v_seq:
949
+ v_seq.append(self.v)
950
+ if self.store_v_seq:
951
+ self.v_seq = torch.stack(v_seq)
952
+
953
+ # if self.store_v_seq:
954
+ # self.v_seq = torch.stack(v_seq)
955
+ outputs = torch.stack(y_seq, dim=0).permute(1, 0)
956
+
957
+ return outputs
958
+
959
+ def v_float_to_tensor(self, x: torch.Tensor):
960
+ if isinstance(self.v, float):
961
+ v_init = self.v
962
+ self.v = torch.full_like(x.data, v_init)
963
+
964
+
965
+ class TSLIFNode(BaseNode):
966
+ def __init__(self,
967
+ v_threshold=1.0,
968
+ v_reset=0.,
969
+ surrogate_function: Callable = None,
970
+ detach_reset=False,
971
+ hard_reset=False,
972
+ step_mode='s',
973
+ k=2,
974
+ decay_factor: torch.Tensor = torch.tensor([0.8, 0.2, 0.3, 0.7], dtype=torch.float),
975
+ gamma: float = 0.5):
976
+ super(TSLIFNode, self).__init__(v_threshold, v_reset, surrogate_function, detach_reset, step_mode)
977
+ self.k = k
978
+ for i in range(1, self.k + 1):
979
+ self.register_memory('v' + str(i), 0.)
980
+
981
+
982
+ self.names = self._memories
983
+ self.hard_reset = hard_reset
984
+ self.gamma = gamma
985
+ self.decay_factor = torch.nn.Parameter(decay_factor.clone()) # clone so instances do not share the default tensor's storage
986
+ self.kk = torch.nn.Parameter(torch.tensor([0.8], dtype=torch.float))
987
+ self.yy = torch.nn.Parameter(torch.tensor([0.1], dtype=torch.float))
988
+
989
+ @property
990
+ def supported_backends(self):
991
+ if self.step_mode == 's':
992
+ return ('torch',)
993
+ elif self.step_mode == 'm':
994
+ return ('torch', 'cupy')
995
+ else:
996
+ raise ValueError(self.step_mode)
997
+
998
+ def neuronal_charge(self, x: torch.Tensor):
999
+ self.names['v1'] = self.decay_factor[0] * self.names['v1'] + self.decay_factor[1] * x - self.yy * self.names['v2']
1000
+ self.names['v2'] = self.decay_factor[2] * self.names['v2'] + self.decay_factor[3] * x - self.kk * self.names['v1']
1001
+ self.v = self.names['v2']
1002
+ self.v_s = self.names['v1']
1003
+
1004
+ def neuronal_reset(self, spike_s, spike_l):
1005
+ if not self.hard_reset:
1006
+ self.names['v1'] = self.jit_soft_reset(self.names['v1'], spike_l, self.gamma)
1007
+ self.names['v2'] = self.jit_soft_reset(self.names['v2'], spike_s, self.v_threshold)
1008
+ else:
1009
+ for i in range(2, self.k + 1):
1010
+ self.names['v' + str(i)] = self.jit_hard_reset(self.names['v' + str(i)], spike_s, self.v_reset)
1011
+
1012
+ def forward(self, x: torch.Tensor):
1013
+ return super().single_step_forward(x)
1014
+ def extra_repr(self):
1015
+ return f"v_threshold={self.v_threshold}, v_reset={self.v_reset}, detach_reset={self.detach_reset}, " \
1016
+ f"hard_reset={self.hard_reset}, " \
1017
+ f"gamma={self.gamma}, k={self.k}, step_mode={self.step_mode}, backend={self.backend}"
1018
+
1019
+
1020
+
1021
+
1022
+
1023
+ class BaseNode1(MemoryModule):
1024
+ def __init__(self,
1025
+ v_threshold: float = 1.,
1026
+ v_reset: float = 0.,
1027
+ surrogate_function: Callable = None,
1028
+ detach_reset: bool = False,
1029
+ step_mode='s', backend='torch',
1030
+ store_v_seq: bool = True):
1031
+
1032
+ assert isinstance(v_reset, float) or v_reset is None
1033
+ assert isinstance(v_threshold, float)
1034
+ assert isinstance(detach_reset, bool)
1035
+ super().__init__()
1036
+
1037
+ if v_reset is None:
1038
+ self.register_memory('v', 0.)
1039
+ self.register_memory('v_s', 0.)
1040
+ else:
1041
+ self.register_memory('v', v_reset)
1042
+
1043
+ self.v_threshold = v_threshold
1044
+
1045
+ self.v_reset = v_reset
1046
+ self.detach_reset = detach_reset
1047
+ self.surrogate_function = surrogate_function
1048
+
1049
+ self.step_mode = step_mode
1050
+ self.backend = backend
1051
+
1052
+ self.store_v_seq = store_v_seq
1053
+ self.alpha_s = torch.nn.Parameter(torch.tensor(0.5, dtype=torch.float))
1054
+ self.alpha_l = torch.nn.Parameter(torch.tensor(0.5, dtype=torch.float))
1055
+
1056
+ @property
1057
+ def store_v_seq(self):
1058
+ return self._store_v_seq
1059
+
1060
+ @store_v_seq.setter
1061
+ def store_v_seq(self, value: bool):
1062
+ self._store_v_seq = value
1063
+ if value:
1064
+ if not hasattr(self, 'v_seq'):
1065
+ self.register_memory('v_seq', None)
1066
+
1067
+ @staticmethod
1068
+ @torch.jit.script
1069
+ def jit_hard_reset(v: torch.Tensor, spike: torch.Tensor, v_reset: float):
1070
+ v = (1. - spike) * v + spike * v_reset
1071
+
1072
+ return v
1073
+
1074
+ @staticmethod
1075
+ @torch.jit.script
1076
+ def jit_soft_reset(v: torch.Tensor, spike: torch.Tensor, v_threshold: float):
1077
+ v = v - spike * v_threshold
1078
+ return v
1079
+
1080
+
1081
+ @abstractmethod
1082
+ def neuronal_charge(self, x: torch.Tensor):
1083
+ raise NotImplementedError
1084
+
1085
+ def neuronal_fire(self):
1086
+ return self.surrogate_function(self.v - self.v_threshold, 2.0)
1087
+
1088
+ def sl_neuronal_fire(self):
1089
+ s_s = self.surrogate_function(self.v - self.v_threshold, 2.0)
1090
+ s_l = self.surrogate_function(self.v_s - self.v_threshold, 2.0)
1091
+ return s_s, s_l
1092
+
1093
+ def extra_repr(self):
1094
+ return f'v_threshold={self.v_threshold}, v_reset={self.v_reset}, detach_reset={self.detach_reset}, step_mode={self.step_mode}, backend={self.backend}'
1095
+
1096
+ def single_step_forward(self, x: torch.Tensor):
1097
+ self.v_float_to_tensor(x)
1098
+ self.neuronal_charge(x)
1099
+ s_s, s_l = self.sl_neuronal_fire()
1100
+ spike = self.alpha_s * s_s + self.alpha_l * s_l
1101
+ self.neuronal_reset(s_s, s_l)
1102
+ return spike
1103
+
1104
+ def multi_step_forward(self, x_seq: torch.Tensor):
1105
+
1106
+ #### time series ###
1107
+ T = x_seq.shape[-1]
1108
+ y_seq = []
1109
+ if self.store_v_seq:
1110
+ v_seq = []
1111
+ for t in range(2):
1112
+ y = self.single_step_forward(x_seq[:, t, :, :])
1113
+ y_seq.append(y)
1114
+ if self.store_v_seq:
1115
+ v_seq.append(self.v)
1116
+ if self.store_v_seq:
1117
+ self.v_seq = torch.stack(v_seq)
1118
+ outputs = torch.stack(y_seq, dim=0)
1119
+ outputs = outputs.permute(1, 0, 2, 3)
1120
+
1121
+ return outputs
1122
+
1123
+
1124
+ def v_float_to_tensor(self, x: torch.Tensor):
1125
+ if isinstance(self.v, float):
1126
+ v_init = self.v
1127
+ self.v = torch.full_like(x.data, v_init)
1128
+
1129
+
1130
+
1131
+ class TCLIFNode2(BaseNode1):
1132
+ def __init__(self,
1133
+ v_threshold=0.8,
1134
+ v_reset=0.,
1135
+ surrogate_function: Callable = None,
1136
+ detach_reset=False,
1137
+ hard_reset=False,
1138
+ step_mode='s',
1139
+ k=2,
1140
+ decay_factor: torch.Tensor = torch.tensor([0.8, 0.2, 0.3, 0.7], dtype=torch.float),
1141
+ gamma: float = 0.5):
1142
+ super(TCLIFNode2, self).__init__(v_threshold, v_reset, surrogate_function, detach_reset, step_mode)
1143
+ self.k = k
1144
+ for i in range(1, self.k + 1):
1145
+ self.register_memory('v' + str(i), 0.)
1146
+
1147
+ self.names = self._memories
1148
+ self.hard_reset = hard_reset
1149
+ self.gamma = gamma
1150
+ self.decay_factor = torch.nn.Parameter(decay_factor.clone()) # clone so instances do not share the default tensor's storage
1151
+ self.kk = torch.nn.Parameter(torch.tensor([0.8], dtype=torch.float))
1152
+ self.yy = torch.nn.Parameter(torch.tensor([0.1], dtype=torch.float))
1153
+
1154
+ @property
1155
+ def supported_backends(self):
1156
+ if self.step_mode == 's':
1157
+ return ('torch',)
1158
+ elif self.step_mode == 'm':
1159
+ return ('torch', 'cupy')
1160
+ else:
1161
+ raise ValueError(self.step_mode)
1162
+
1163
+ def neuronal_charge(self, x: torch.Tensor):
1164
+ self.names['v1'] = self.decay_factor[0] * self.names['v1'] + self.decay_factor[1] * x - self.yy * self.names['v2']
1165
+ self.names['v2'] = self.decay_factor[2] * self.names['v2'] + self.decay_factor[3] * x - self.kk * self.names['v1']
1166
+ self.v = self.names['v2']
1167
+ self.v_s = self.names['v1']
1168
+
1169
+ def neuronal_reset(self, spike_s, spike_l):
1170
+ if not self.hard_reset:
1171
+ self.names['v1'] = self.jit_soft_reset(self.names['v1'], spike_l, self.gamma)
1172
+ self.names['v2'] = self.jit_soft_reset(self.names['v2'], spike_s, self.v_threshold)
1173
+ else:
1174
+ # hard reset
1175
+ for i in range(2, self.k + 1):
1176
+ self.names['v' + str(i)] = self.jit_hard_reset(self.names['v' + str(i)], spike_s, self.v_reset)
1177
+
1178
+ def forward(self, x: torch.Tensor):
1179
+ return super().single_step_forward(x)
1180
+
1181
+ def extra_repr(self):
1182
+ return f"v_threshold={self.v_threshold}, v_reset={self.v_reset}, detach_reset={self.detach_reset}, " \
1183
+ f"hard_reset={self.hard_reset}, " \
1184
+ f"gamma={self.gamma}, k={self.k}, step_mode={self.step_mode}, backend={self.backend}"
1185
+
1186
+
1187
+
1188
+
1189
+
1190
+ class TCLIFNode(BaseNode):
1191
+ def __init__(self,
1192
+ v_threshold=1.0,
1193
+ v_reset=0.,
1194
+ surrogate_function: Callable = None,
1195
+ detach_reset=False,
1196
+ hard_reset=False,
1197
+ step_mode='s',
1198
+ k=2,
1199
+ decay_factor: torch.Tensor = torch.tensor([0.8, 0.2, 0.3, 0.7], dtype=torch.float),
1200
+ gamma: float = 0.5):
1201
+ super(TCLIFNode, self).__init__(v_threshold, v_reset, surrogate_function, detach_reset, step_mode)
1202
+ self.k = k
1203
+ for i in range(1, self.k + 1):
1204
+ self.register_memory('v' + str(i), 0.)
1205
+
1206
+ self.names = self._memories
1207
+ self.hard_reset = hard_reset
1208
+ self.gamma = gamma
1209
+ self.decay_factor = torch.nn.Parameter(decay_factor.clone()) # clone so instances do not share the default tensor's storage
1210
+ self.kk = torch.nn.Parameter(torch.tensor([0.8], dtype=torch.float))
1211
+ self.yy = torch.nn.Parameter(torch.tensor([0.1], dtype=torch.float))
1212
+
1213
+ @property
1214
+ def supported_backends(self):
1215
+ if self.step_mode == 's':
1216
+ return ('torch',)
1217
+ elif self.step_mode == 'm':
1218
+ return ('torch', 'cupy')
1219
+ else:
1220
+ raise ValueError(self.step_mode)
1221
+
1222
+ def neuronal_charge(self, x: torch.Tensor):
1223
+ self.names['v1'] = self.decay_factor[0] * self.names['v1'] + self.decay_factor[1] * x - self.yy * self.names['v2']
1224
+ self.names['v2'] = self.decay_factor[2] * self.names['v2'] + self.decay_factor[3] * x - self.kk * self.names['v1']
1225
+ self.v = self.names['v2']
1226
+ self.v_s = self.names['v1']
1227
+
1228
+ def neuronal_reset(self, spike_s, spike_l):
1229
+ if not self.hard_reset:
1230
+ self.names['v1'] = self.jit_soft_reset(self.names['v1'], spike_l, self.gamma)
1231
+ self.names['v2'] = self.jit_soft_reset(self.names['v2'], spike_s, self.v_threshold)
1232
+ else:
1233
+ # hard reset
1234
+ for i in range(2, self.k + 1):
1235
+ self.names['v' + str(i)] = self.jit_hard_reset(self.names['v' + str(i)], spike_s, self.v_reset)
1236
+
1237
+ def forward(self, x: torch.Tensor):
1238
+ return super().single_step_forward(x)
1239
+ def extra_repr(self):
1240
+ return f"v_threshold={self.v_threshold}, v_reset={self.v_reset}, detach_reset={self.detach_reset}, " \
1241
+ f"hard_reset={self.hard_reset}, " \
1242
+ f"gamma={self.gamma}, k={self.k}, step_mode={self.step_mode}, backend={self.backend}"
1243
+
1244
+
1245
+
1246
+
1247
+
+ class TSFormer(nn.Module):
+
+     def __init__(
+         self,
+         args,
+         dim: int = 256,
+         d_ff: Optional[int] = None,
+         num_pe_neuron: int = 40,
+         pe_type: str = "neuron",
+         pe_mode: str = "concat",  # "add" or "concat"
+         neuron_pe_scale: float = 10000.0,  # 100, 1000 or 10000
+         depths: int = 2,
+         common_thr: float = 1.0,
+         max_length: int = 5000,
+         num_steps: int = 4,
+         heads: int = 8,
+         qkv_bias: bool = False,
+         qk_scale: float = 0.125,
+         input_size: Optional[int] = None,
+         weight_file: Optional[Path] = None,
+     ):
+         super().__init__()
+         # self.dim and self.d_ff are fixed values; the `dim` argument (not self.dim) sizes the layers below.
+         self.dim = 256
+         self.d_ff = 1024
+         self.T = args.T
+         self.depths = args.blocks
+         self.pe_type = pe_type
+         self.pe_mode = pe_mode
+         self.num_pe_neuron = num_pe_neuron
+         self.input_size = args.feature_size
+         self._snn_backend = "spikingjelly"
+         self.temporal_encoder = SpikeEncoder[self._snn_backend]["conv"](num_steps)
+         self.pre_length = args.pre_length
+         self.feature_size = args.feature_size
+         self.args = args
+         self.pe = PositionEmbedding(
+             pe_type=pe_type,
+             pe_mode=pe_mode,
+             neuron_pe_scale=neuron_pe_scale,
+             input_size=self.input_size,
+             max_len=max_length,
+             num_pe_neuron=self.num_pe_neuron,
+             dropout=0.1,
+             num_steps=num_steps,
+         )
+         # Concatenated position embeddings widen the encoder input by num_pe_neuron features.
+         if (self.pe_type == "neuron" and self.pe_mode == "concat") or (
+             self.pe_type == "random" and self.pe_mode == "concat"
+         ):
+             self.encoder = nn.Linear(self.input_size + num_pe_neuron, dim)
+         else:
+             self.encoder = nn.Linear(self.input_size, dim)
+
+         # `tau`, `detach_reset` and `backend` are not arguments of this constructor;
+         # they are assumed to be module-level names defined earlier in this file.
+         self.init_lif = neuron.LIFNode(
+             tau=tau,
+             step_mode="m",
+             detach_reset=detach_reset,
+             surrogate_function=surrogate.ATan(),
+             v_threshold=common_thr,
+             backend=backend,
+         )
+
+         self.blocks = nn.ModuleList(
+             [
+                 Block(
+                     length=max_length,
+                     tau=tau,
+                     common_thr=common_thr,
+                     dim=dim,
+                     d_ff=self.d_ff,
+                     heads=heads,
+                     qkv_bias=qkv_bias,
+                     qk_scale=qk_scale,
+                 )
+                 for _ in range(depths)  # the `depths` argument, not self.depths (args.blocks)
+             ]
+         )
+
+         self.apply(self._init_weights)
+
+         self.fc = nn.Linear(args.seq_length*dim, args.pre_length*args.feature_size)
+
+     def _init_weights(self, m):
+         if isinstance(m, nn.Linear):
+             nn.init.normal_(m.weight, std=0.02)
+             if isinstance(m, nn.Linear) and m.bias is not None:
+                 nn.init.constant_(m.bias, 0.0)
+         elif isinstance(m, nn.LayerNorm):
+             nn.init.constant_(m.weight, 1.0)
+             nn.init.constant_(m.bias, 0.0)
+
+     def forward(self, x):
+         functional.reset_net(self)
+
+         if self.args.normalize:
+             # Per-window standardization along the time axis; statistics are detached from the graph.
+             mean = x.mean(dim=1, keepdim=True).detach()  # shape [B, 1, D]
+             x = x - mean
+
+             std = torch.sqrt(torch.var(x, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()
+             x = x / std
+
+         x = self.temporal_encoder(x)  # B L C -> T B C L
+         x = x.transpose(-2, -1)  # T B L C
+         if self.pe_type != "none":
+             x = self.pe(x)  # T B L C'
+         T, B, L, _ = x.shape
+         x = self.encoder(x.flatten(0, 1)).reshape(T, B, L, -1)  # T B L D
+         x = self.init_lif(x)
+
+         for blk in self.blocks:
+             x = blk(x)  # T B L D
+         out = x.mean(0)  # B L D
+         # Flatten to (B, L*D), project, then reshape to (B, pre_length, feature_size).
+         out = self.fc(out.flatten(-2, -1)).reshape(-1, self.pre_length, self.feature_size)
+         if self.args.normalize:
+             out = out * std + mean  # denormalization
+         aux = {'gate_l0': torch.tensor(0.0, device=out.device)}  # placeholder
+         return out, aux  # out: (B, pre_length, feature_size)
+
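The `forward` method above standardizes each input window along the time axis and restores the statistics on the output. The round trip can be sketched for a single channel in plain Python. This is only an illustration: the helper names `normalize_window` and `denormalize` are hypothetical and not part of the repository, while the `1e-5` epsilon and the biased variance follow the code above.

```python
import math


def normalize_window(window, eps=1e-5):
    """Standardize one channel of one window; returns (normalized, mean, std)."""
    mean = sum(window) / len(window)
    centered = [v - mean for v in window]
    # Biased variance (divide by N), matching unbiased=False in the model code.
    var = sum(c * c for c in centered) / len(centered)
    std = math.sqrt(var + eps)
    return [c / std for c in centered], mean, std


def denormalize(values, mean, std):
    """Map values on the normalized scale back to the scale of the input window."""
    return [v * std + mean for v in values]


window = [10.0, 12.0, 14.0, 16.0]
normed, mean, std = normalize_window(window)
restored = denormalize(normed, mean, std)
```

In the model, `denormalize` is applied to the forecast rather than to the normalized input, so the prediction inherits the mean and scale of its own look-back window.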
model/TS_GRU.py ADDED
@@ -0,0 +1,640 @@
+ from typing import Optional, Callable
+ from pathlib import Path
+ from spikingjelly.activation_based import surrogate as sj_surrogate
+ from snntorch import utils
+ import snntorch as snn
+ from snntorch import surrogate
+ import torch
+ from torch import nn
+ import numpy as np
+ import copy
+ import torch.nn.functional as F
+ import math
+ from abc import abstractmethod
+
+
+ @torch.jit.script
+ def heaviside(x: torch.Tensor):
+     # Step function: 1 where x >= 0, else 0, in the dtype and device of x.
+     return (x >= 0).to(x)
+
+
+ @torch.jit.script
+ def atan_backward(grad_output: torch.Tensor, x: torch.Tensor, alpha: float):
+     # Derivative of the arctan surrogate: alpha / 2 / (1 + (pi / 2 * alpha * x)^2).
+     return alpha / 2 / (1 + (math.pi / 2 * alpha * x).pow_(2)) * grad_output, None
+
+
+ class SG(torch.autograd.Function):
+     # Heaviside spike in the forward pass, arctan surrogate gradient in the backward pass.
+     @staticmethod
+     def forward(ctx, x, alpha=2.0):
+         if x.requires_grad:
+             ctx.save_for_backward(x)
+             ctx.alpha = alpha
+         return heaviside(x)
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         return atan_backward(grad_output, ctx.saved_tensors[0], ctx.alpha)
+
+
+ class MemoryModule(nn.Module):
+     def __init__(self):
+         """
+         ``MemoryModule`` is the base class of all stateful modules in SpikingJelly.
+         """
+         super().__init__()
+         self._memories = {}
+         self._memories_rv = {}
+
+     def register_memory(self, name: str, value):
+         """
+         :param name: variable's name
+         :type name: str
+         :param value: variable's value
+         :type value: any
+
+         Register the variable in the memory dict, which stores stateful variables (e.g., the membrane
+         potential of a spiking neuron). The reset value of this variable is ``value``: ``self.name`` is
+         set back to ``value`` each time ``self.reset()`` is called.
+         """
+         assert not hasattr(self, name), f'{name} has been set as a member variable!'
+         self._memories[name] = value
+         self.set_reset_value(name, value)
+
+     def reset(self):
+         """
+         Reset all stateful variables to their default values.
+         """
+         for key in self._memories.keys():
+             self._memories[key] = copy.deepcopy(self._memories_rv[key])
+
+     def set_reset_value(self, name: str, value):
+         self._memories_rv[name] = copy.deepcopy(value)
+
+     def __getattr__(self, name: str):
+         if '_memories' in self.__dict__:
+             memories = self.__dict__['_memories']
+             if name in memories:
+                 return memories[name]
+
+         return super().__getattr__(name)
+
+     def __setattr__(self, name: str, value) -> None:
+         _memories = self.__dict__.get('_memories')
+         if _memories is not None and name in _memories:
+             _memories[name] = value
+         else:
+             super().__setattr__(name, value)
+
+     def __delattr__(self, name):
+         if name in self._memories:
+             del self._memories[name]
+             del self._memories_rv[name]
+         else:
+             return super().__delattr__(name)
+
+     def __dir__(self):
+         module_attrs = dir(self.__class__)
+         attrs = list(self.__dict__.keys())
+         parameters = list(self._parameters.keys())
+         modules = list(self._modules.keys())
+         buffers = list(self._buffers.keys())
+         memories = list(self._memories.keys())
+         keys = module_attrs + attrs + parameters + modules + buffers + memories
+         keys = [key for key in keys if not key[0].isdigit()]
+
+         return sorted(keys)
+
+     def memories(self):
+         """
+         :return: an iterator over all stateful variables
+         :rtype: Iterator
+         """
+         for name, value in self._memories.items():
+             yield value
+
+     def named_memories(self):
+         """
+         :return: an iterator over all stateful variables and their names
+         :rtype: Iterator
+         """
+
+         for name, value in self._memories.items():
+             yield name, value
+
+     def detach(self):
+         """
+         Detach all stateful variables from the computation graph.
+
+         .. admonition:: Tip
+             :class: tip
+
+             This function can be used to implement TBPTT (Truncated Back Propagation Through Time).
+
+         """
+
+         for key in self._memories.keys():
+             if isinstance(self._memories[key], torch.Tensor):
+                 self._memories[key].detach_()
+
+     def _apply(self, fn):
+         # Apply device/dtype conversions to tensor-valued memories as well as to parameters and buffers.
+         for key, value in self._memories.items():
+             if isinstance(value, torch.Tensor):
+                 self._memories[key] = fn(value)
+         return super()._apply(fn)
+
+     def _replicate_for_data_parallel(self):
+         replica = super()._replicate_for_data_parallel()
+         replica._memories = self._memories.copy()
+         return replica
+
+
+ class StepModule:
+     def supported_step_mode(self):
+         """
+         :return: a tuple that contains the supported step modes
+         :rtype: tuple[str]
+         """
+         return ('s', 'm')
+
+     @property
+     def step_mode(self):
+         """
+         :return: the current step mode of this module
+         :rtype: str
+         """
+         return self._step_mode
+
+     @step_mode.setter
+     def step_mode(self, value: str):
+         """
+         :param value: the step mode
+         :type value: str
+
+         Set the step mode of this module to ``value``.
+         """
+         if value not in self.supported_step_mode():
+             raise ValueError(f'step_mode can only be {self.supported_step_mode()}, but got "{value}"!')
+         self._step_mode = value
+
+
+ class BaseNode(MemoryModule):
+     def __init__(self,
+                  v_threshold: float = 1.,
+                  v_reset: float = 0.,
+                  surrogate_function: Callable = None,
+                  detach_reset: bool = False,
+                  step_mode='s', backend='torch',
+                  store_v_seq: bool = True):
+
+         assert isinstance(v_reset, float) or v_reset is None
+         assert isinstance(v_threshold, float)
+         assert isinstance(detach_reset, bool)
+         super().__init__()
+
+         if v_reset is None:
+             self.register_memory('v', 0.)
+             self.register_memory('v_s', 0.)
+         else:
+             self.register_memory('v', v_reset)
+
+         self.v_threshold = v_threshold
+
+         self.v_reset = v_reset
+         self.detach_reset = detach_reset
+         self.surrogate_function = surrogate_function
+
+         self.step_mode = step_mode
+         self.backend = backend
+
+         self.store_v_seq = store_v_seq
+         self.alpha_s = torch.nn.Parameter(torch.tensor(0.5, dtype=torch.float))
+         self.alpha_l = torch.nn.Parameter(torch.tensor(0.5, dtype=torch.float))
+
+     @property
+     def store_v_seq(self):
+         return self._store_v_seq
+
+     @store_v_seq.setter
+     def store_v_seq(self, value: bool):
+         self._store_v_seq = value
+         if value:
+             if not hasattr(self, 'v_seq'):
+                 self.register_memory('v_seq', None)
+
+     @staticmethod
+     @torch.jit.script
+     def jit_hard_reset(v: torch.Tensor, spike: torch.Tensor, v_reset: float):
+         v = (1. - spike) * v + spike * v_reset
+
+         return v
+
+     @staticmethod
+     @torch.jit.script
+     def jit_soft_reset(v: torch.Tensor, spike: torch.Tensor, v_threshold: float):
+         v = v - spike * v_threshold
+         return v
+
+
+     @abstractmethod
+     def neuronal_charge(self, x: torch.Tensor):
+         raise NotImplementedError
+
+     def neuronal_fire(self):
+         return self.surrogate_function(self.v - self.v_threshold, 2.0)
+
+     def sl_neuronal_fire(self):
+         s_s = self.surrogate_function(self.v - self.v_threshold, 2.0)
+         s_l = self.surrogate_function(self.v_s - self.v_threshold, 2.0)
+         return s_s, s_l
+
+     def extra_repr(self):
+         return f'v_threshold={self.v_threshold}, v_reset={self.v_reset}, detach_reset={self.detach_reset}, step_mode={self.step_mode}, backend={self.backend}'
+
+     def single_step_forward(self, x: torch.Tensor):
+         self.v_float_to_tensor(x)
+         self.neuronal_charge(x)
+         s_s, s_l = self.sl_neuronal_fire()
+         spike = self.alpha_s * s_s + self.alpha_l * s_l
+         self.neuronal_reset(s_s, s_l)
+
+         return spike
+
+     def multi_step_forward(self, x_seq: torch.Tensor):
+
+         T = x_seq.shape[-1]
+         y_seq = []
+         if self.store_v_seq:
+             v_seq = []
+         for t in range(T):
+             y = self.single_step_forward(x_seq[:, t])
+             y_seq.append(y)
+             if self.store_v_seq:
+                 v_seq.append(self.v)
+         if self.store_v_seq:
+             self.v_seq = torch.stack(v_seq)
+         outputs = torch.stack(y_seq, dim=0).permute(1, 0)
+
+         return outputs
+
+     def v_float_to_tensor(self, x: torch.Tensor):
+         if isinstance(self.v, float):
+             v_init = self.v
+             self.v = torch.full_like(x.data, v_init)
+
+
+ class TSLIFNode(BaseNode):
+     def __init__(self,
+                  v_threshold=1.0,
+                  v_reset=0.,
+                  surrogate_function: Callable = None,
+                  detach_reset=False,
+                  hard_reset=False,
+                  step_mode='s',
+                  k=2,
+                  decay_factor: torch.Tensor = torch.tensor([0.8, 0.2, 0.3, 0.7], dtype=torch.float),
+                  gamma: float = 0.5):
+         super(TSLIFNode, self).__init__(v_threshold, v_reset, surrogate_function, detach_reset, step_mode)
+         self.k = k
+         for i in range(1, self.k + 1):
+             self.register_memory('v' + str(i), 0.)
+         self.names = self._memories
+         self.hard_reset = hard_reset
+         self.gamma = gamma
+         self.decay_factor = torch.nn.Parameter(decay_factor)
+         self.kk = torch.nn.Parameter(torch.tensor([0.8], dtype=torch.float))
+         self.yy = torch.nn.Parameter(torch.tensor([0.1], dtype=torch.float))
+
+     @property
+     def supported_backends(self):
+         if self.step_mode == 's':
+             return ('torch',)
+         elif self.step_mode == 'm':
+             return ('torch', 'cupy')
+         else:
+             raise ValueError(self.step_mode)
+
+     def neuronal_charge(self, x: torch.Tensor):
+         self.names['v1'] = self.decay_factor[0] * self.names['v1'] + self.decay_factor[1] * x - self.yy * self.names['v2']
+         self.names['v2'] = self.decay_factor[2] * self.names['v2'] + self.decay_factor[3] * x - self.kk * self.names['v1']
+         self.v = self.names['v2']
+         self.v_s = self.names['v1']
+
+     def neuronal_reset(self, spike_s, spike_l):
+         if not self.hard_reset:
+             self.names['v1'] = self.jit_soft_reset(self.names['v1'], spike_l, self.gamma)
+             self.names['v2'] = self.jit_soft_reset(self.names['v2'], spike_s, self.v_threshold)
+         else:
+             for i in range(2, self.k + 1):
+                 self.names['v' + str(i)] = self.jit_hard_reset(self.names['v' + str(i)], spike_s, self.v_reset)
+
+     def forward(self, x: torch.Tensor):
+         return super().single_step_forward(x)
+
+     def extra_repr(self):
+         return f"v_threshold={self.v_threshold}, v_reset={self.v_reset}, detach_reset={self.detach_reset}, " \
+                f"hard_reset={self.hard_reset}, " \
+                f"gamma={self.gamma}, k={self.k}, step_mode={self.step_mode}, backend={self.backend}"
+
+
+ class GRUCell(nn.Module):
+     def __init__(
+         self,
+         input_size: int,
+         hidden_size: int,
+         num_steps: int = 4,
+         grad_slope: float = 25.0,
+         beta: float = 0.99,
+         output_mems: bool = False,
+     ):
+         super().__init__()
+         self.spike_grad = surrogate.atan(alpha=2.0)
+         self.input_size = input_size
+         self.num_steps = num_steps
+         self.hidden_size = hidden_size
+         self.beta = beta
+         self.full_rec = output_mems
+
+         self.linear_ih = nn.Linear(input_size, 3 * hidden_size)
+         self.linear_hh = nn.Linear(hidden_size, 3 * hidden_size)
+         self.surrogate_function1 = sj_surrogate.ATan()
+
+         self.tslif = TSLIFNode(
+             surrogate_function=SG.apply
+         )
+
+     def forward(self, inputs):
+         if inputs.size(-1) == self.input_size:
+             h = torch.zeros(
+                 size=[inputs.shape[0], self.hidden_size],
+                 dtype=torch.float,
+                 device=inputs.device,
+             )
+             y_ih = torch.split(self.linear_ih(inputs), self.hidden_size, dim=1)
+             y_hh = torch.split(self.linear_hh(h), self.hidden_size, dim=1)
+             r = self.surrogate_function1(y_ih[0] + y_hh[0])
+             z = self.surrogate_function1(y_ih[1] + y_hh[1])
+             n = self.surrogate_function1(y_ih[2] + r * y_hh[2])
+             h = (1.0 - z) * n + z * h
+             cur = h
+         elif inputs.size(-1) == self.num_steps and inputs.size(-2) == self.input_size:
+             inputs = inputs.transpose(-1, -2)  # BC, T, H
+             h = torch.zeros(
+                 size=[inputs.shape[0], self.hidden_size, self.num_steps],
+                 dtype=torch.float,
+                 device=inputs.device,
+             )
+             y_ih = torch.split(
+                 self.linear_ih(inputs).transpose(-1, -2), self.hidden_size, dim=1
+             )
+             y_hh = torch.split(
+                 self.linear_hh(h.transpose(-1, -2)).transpose(-1, -2),
+                 self.hidden_size,
+                 dim=1,
+             )
+             r = self.surrogate_function1(y_ih[0] + y_hh[0])
+             z = self.surrogate_function1(y_ih[1] + y_hh[1])
+             n = self.surrogate_function1(y_ih[2] + r * y_hh[2])
+             h = (1.0 - z) * n + z * h
+             cur = h
+             static = False
+         else:
+             raise ValueError(
+                 f"Input size mismatch! Got {inputs.size()} but expected "
+                 f"(..., {self.input_size}, {self.num_steps}) or (..., {self.input_size})"
+             )
+
+         spks = self.tslif(cur)
+         return spks
+
+
+ class DeltaEncoder(nn.Module):
+     def __init__(self, output_size: int):
+         super().__init__()
+         self.norm = nn.BatchNorm2d(1)
+         self.enc = nn.Linear(1, output_size)
+         self.lif = snn.Leaky(
+             beta=0.99, spike_grad=SG.apply, init_hidden=True, output=False
+         )
+
+     def forward(self, inputs: torch.Tensor):
+         # inputs: batch, L, C
+         delta = torch.zeros_like(inputs)
+         delta[:, 1:] = inputs[:, 1:, :] - inputs[:, :-1, :]
+         delta = delta.unsqueeze(1).permute(0, 1, 3, 2)  # batch, 1, C, L
+         delta = self.norm(delta)
+         delta = delta.permute(0, 2, 3, 1)  # batch, C, L, 1
+         enc = self.enc(delta)  # batch, C, L, output_size
+         enc = enc.permute(0, 3, 1, 2)  # batch, output_size, C, L
+         spks = self.lif(enc)
+         return spks
+
+
+ class ConvEncoder(nn.Module):
+     def __init__(self, output_size: int, kernel_size: int = 3):
+         super().__init__()
+         self.encoder = nn.Sequential(
+             nn.Conv2d(
+                 in_channels=1,
+                 out_channels=output_size,
+                 kernel_size=(1, kernel_size),
+                 stride=1,
+                 padding=(0, kernel_size // 2),
+             ),
+             nn.BatchNorm2d(output_size),
+         )
+         self.lif = snn.Leaky(
+             beta=0.99,
+             spike_grad=surrogate.atan(alpha=2.0),
+             init_hidden=True,
+             output=False,
+         )
+
+     def forward(self, inputs: torch.Tensor):
+         # inputs: batch, L, C
+         inputs = inputs.permute(0, 2, 1).unsqueeze(1)  # batch, 1, C, L
+         enc = self.encoder(inputs)  # batch, output_size, C, L
+         spks = self.lif(enc)
+         return spks
+
+
+ class TSGRU(nn.Module):
+     def __init__(
+         self,
+         args,
+         hidden_size: int,
+         layers: int = 1,
+         num_steps: int = 50,
+         grad_slope: float = 25.0,
+         input_size: Optional[int] = None,
+         max_length: Optional[int] = None,
+         weight_file: Optional[Path] = None,
+         encoder_type: Optional[str] = "conv",
+     ):
+         super().__init__()
+
+         # Sizes are read from `args`; the matching keyword arguments are not used for them.
+         self.hidden_size = args.hidden_size
+         self.num_steps = args.T
+         self.input_size = args.feature_size
+         self.pre_length = args.pre_length
+         self.layers = args.blocks
+         self.args = args
+
+         if encoder_type == "conv":
+             self.encoder = ConvEncoder(self.hidden_size)
+         elif encoder_type == "delta":
+             self.encoder = DeltaEncoder(self.hidden_size)
+         else:
+             raise ValueError(f"Unknown encoder type {encoder_type}")
+
+         self.net = nn.Sequential(
+             *[
+                 GRUCell(
+                     self.hidden_size,
+                     self.hidden_size,
+                     num_steps=self.num_steps,
+                     grad_slope=grad_slope,
+                     output_mems=(i == self.layers - 1),
+                 )
+                 for i in range(self.layers)
+             ]
+         )
+
+         self.__output_size = self.hidden_size
+         self.fc = nn.Linear(self.__output_size, self.pre_length)
+
+         # Use the first GPU when one is available, otherwise stay on the CPU.
+         self.to('cuda:0' if torch.cuda.is_available() else 'cpu')
+
+     def forward(self, inputs: torch.Tensor):
+
+         utils.reset(self.encoder)
+         for layer in self.net:
+             utils.reset(layer)
+
+         bs, length, c_num = inputs.size()
+
+         if self.args.normalize:
+             mean = inputs.mean(dim=1, keepdim=True).detach()  # shape [B, 1, D]
+             inputs = inputs - mean
+
+             std = torch.sqrt(torch.var(inputs, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()
+             inputs = inputs / std
+
+         h = self.encoder(inputs)
+         hidden_size = h.size(1)
+         h = h.permute(0, 2, 3, 1).reshape(bs * c_num, length, hidden_size)  # (BC, L, H)
+
+         # Feed the sequence one position at a time; only the output of the last position is kept.
+         for i in range(length):
+             spks = self.net(h[:, i, :])
+
+         spks = spks.reshape(bs, c_num * hidden_size, -1)  # B, CH, trailing dim
+
+         spks = spks[:, :, -1]  # last slice of the trailing dimension: (B, CH)
+         preds = self.fc(spks.view(bs, c_num, -1)).squeeze(-1)  # B, C, O
+         preds = preds.permute(0, 2, 1).contiguous()  # B, O, C
+         if self.args.normalize:
+             preds = preds * std + mean  # denormalize
+         aux = {'gate_l0': torch.tensor(0.0, device=preds.device)}  # placeholder
+         return preds, aux
+
+     @property
+     def output_size(self):
+         return self.__output_size
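The two-compartment update in `TSLIFNode` (charge, fire, soft reset, then the weighted spike sum from `BaseNode.single_step_forward`) can be traced with scalars. The sketch below is a plain-Python illustration, not the repository's implementation: `tslif_step` is a hypothetical name, the state is a pair of floats rather than tensors, and the constants are the initial values from the code above, where `decay_factor`, `kk`, `yy`, `alpha_s` and `alpha_l` are learnable parameters.

```python
def heaviside(x):
    """Step function used for spiking: 1.0 if x >= 0, else 0.0."""
    return 1.0 if x >= 0 else 0.0


def tslif_step(x, v1, v2, decay=(0.8, 0.2, 0.3, 0.7), kk=0.8, yy=0.1,
               v_threshold=1.0, gamma=0.5, alpha_s=0.5, alpha_l=0.5):
    """One soft-reset step on scalar state; returns (output, v1, v2)."""
    # Charge: v1 is updated first, and the v2 update then sees the new v1.
    v1 = decay[0] * v1 + decay[1] * x - yy * v2
    v2 = decay[2] * v2 + decay[3] * x - kk * v1
    # Fire: both compartments are compared against the same threshold.
    s_s = heaviside(v2 - v_threshold)
    s_l = heaviside(v1 - v_threshold)
    out = alpha_s * s_s + alpha_l * s_l
    # Soft reset: subtract gamma from v1 and v_threshold from v2 where a spike occurred.
    v1 = v1 - s_l * gamma
    v2 = v2 - s_s * v_threshold
    return out, v1, v2


out, v1, v2 = tslif_step(10.0, 0.0, 0.0)
```

With these initial constants the output takes one of the values 0.0, 0.5 or 1.0, depending on which of the two compartments crossed the threshold.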
model/TS_TCN.py ADDED
@@ -0,0 +1,1030 @@
+ from typing import Optional, Callable
+
+ import torch
+ from torch import nn
+ from torch.nn.utils import weight_norm
+ import snntorch as snn
+ from snntorch import surrogate
+ from snntorch import utils
+ import numpy as np
+ import math
+ import copy
+ # This import rebinds `surrogate`, shadowing the snntorch import above.
+ from spikingjelly.activation_based import surrogate, neuron
+ from abc import abstractmethod
+ import warnings
+
+
+
+ # Attach an `atan` factory to the module bound to `surrogate`; it ignores `alpha` and returns SG.apply.
+ # `SG` must be defined in this module before the factory is called.
+ surrogate.atan = lambda alpha=2.0: SG.apply
+
+
+ class Chomp1d(nn.Module):
+     def __init__(self, chomp_size):
+         super().__init__()
+         self.chomp_size = chomp_size
+
+     def forward(self, x):
+         # Drop the last chomp_size steps of the final dimension.
+         return x[:, :, : -self.chomp_size].contiguous()
+
+
+ class Chomp2d(nn.Module):
+     def __init__(self, chomp_size):
+         super().__init__()
+         self.chomp_size = chomp_size
+
+     def forward(self, x):
+         return x[:, :, :, : -self.chomp_size].contiguous()
+
+
+ class RepeatEncoder(nn.Module):
+     def __init__(self, output_size: int):
+         super().__init__()
+         self.out_size = output_size
+         self.lif = snn.Leaky(
+             beta=0.99, spike_grad=surrogate.atan(), init_hidden=True, output=False
+         )
+
+     def forward(self, inputs: torch.Tensor):
+         # inputs: batch, L, C
+         inputs = inputs.repeat(
+             tuple([self.out_size] + torch.ones(len(inputs.size()), dtype=int).tolist())
+         )  # out_size, batch, L, C
+         inputs = inputs.permute(1, 0, 3, 2)  # batch, out_size, C, L
+         spks = self.lif(inputs)
+         return spks
+
+
+ class ConvEncoder(nn.Module):
+     def __init__(self, output_size: int, kernel_size: int = 3):
+         super().__init__()
+         self.encoder = nn.Sequential(
+             nn.Conv2d(
+                 in_channels=1,
+                 out_channels=output_size,
+                 kernel_size=(1, kernel_size),
+                 stride=1,
+                 padding=(0, kernel_size // 2),
+             ),
+             nn.BatchNorm2d(output_size),
+         )
+         self.lif = snn.Leaky(
+             beta=0.99,
+             spike_grad=surrogate.atan(alpha=2.0),
+             init_hidden=True,
+             output=False,
+         )
+
+     def forward(self, inputs: torch.Tensor):
+         # inputs: batch, L, C
+         inputs = inputs.permute(0, 2, 1).unsqueeze(1)  # batch, 1, C, L
+         enc = self.encoder(inputs)  # batch, output_size, C, L
+         spks = self.lif(enc)
+         return spks
+
+
+ class DeltaEncoder(nn.Module):
+     def __init__(self, output_size: int):
+         super().__init__()
+         self.norm = nn.BatchNorm2d(1)
+         self.enc = nn.Linear(1, output_size)
+         self.lif = snn.Leaky(
+             beta=0.99, spike_grad=surrogate.atan(), init_hidden=True, output=False
+         )
+
+     def forward(self, inputs: torch.Tensor):
+         # inputs: batch, L, C
+         delta = torch.zeros_like(inputs)
+         delta[:, 1:] = inputs[:, 1:, :] - inputs[:, :-1, :]
+         delta = delta.unsqueeze(1).permute(0, 1, 3, 2)  # batch, 1, C, L
+         delta = self.norm(delta)
+         delta = delta.permute(0, 2, 3, 1)  # batch, C, L, 1
+         enc = self.enc(delta)  # batch, C, L, output_size
+         enc = enc.permute(0, 3, 1, 2)  # batch, output_size, C, L
+         spks = self.lif(enc)
+         return spks
+
+
+ SpikeEncoder = {
+     "snntorch": {
+         "repeat": RepeatEncoder,
+         "conv": ConvEncoder,
+         "delta": DeltaEncoder,
+     },
+     "spikingjelly": {
+         "repeat": RepeatEncoder,
+         "conv": ConvEncoder,
+         "delta": DeltaEncoder,
+     },
+ }
+
+
+ def generate_ones_and_minus_ones_matrix(rows, cols):
+     random_matrix = torch.randint(0, 2, (rows, cols))
+     binary_matrix = torch.where(
+         random_matrix == 0,
+         -1 * torch.ones_like(random_matrix),
+         torch.ones_like(random_matrix),
+     )
+     return binary_matrix.float()
+
+
+ class RandomPE(nn.Module):
+     def __init__(
+         self,
+         d_model,
+         pe_mode="concat",
+         num_pe_neuron=10,
+         neuron_pe_scale=1000.0,
+         dropout=0.1,
+         num_steps=4,
+     ):
+         super().__init__()
+         self.max_len = 5000  # different from windows
+         self.pe_mode = pe_mode
+         self.neuron_pe_scale = neuron_pe_scale
+         self.dropout = nn.Dropout(p=dropout)
+         if self.pe_mode == "concat":
+             self.num_pe_neuron = copy.deepcopy(num_pe_neuron)
+         elif self.pe_mode == "add":
+             self.num_pe_neuron = copy.deepcopy(d_model)
+         pe = generate_ones_and_minus_ones_matrix(
+             self.max_len, self.num_pe_neuron
+         )  # MaxL, Neur
+         pe = pe.unsqueeze(0).transpose(0, 1)  # MaxL, 1, Neur
+         print("pe.shape: ", pe.shape)
+         self.register_buffer("pe", pe)
+
+     def forward(self, x):
+         # T, B, L, D
+         T, B, L, _ = x.shape
+         x = x.permute(1, 0, 2, 3)  # B, T, L, D
+         x = x.flatten(1, 2)  # B, TL, D
+         if self.pe_mode == "concat":
+             # tmp: TL, 1, Neur -> TL, B, Neur -> B, TL, Neur
+             tmp = self.pe[: x.size(-2), :].repeat(1, B, 1).transpose(0, 1)
+             x = torch.concat([x, tmp], dim=-1)
+             # print(x.shape)  # B, TL, D'
+         elif self.pe_mode == "add":
+             # [B, TL, D] + [1, TL, Neur]
+             x = x + self.pe[: x.size(-2), :].transpose(0, 1)
+             # print(x.shape)  # B, TL, D
+         x = x.transpose(0, 1)  # TL, B, D
+         x = x.reshape(T, L, B, -1)  # T, L, B, D
+         x = x.permute(0, 2, 1, 3)  # T, B, L, D
+         return self.dropout(x)
+
+
181
+ class NeuronPE(nn.Module):
+     def __init__(
+         self,
+         d_model,
+         pe_mode="concat",
+         num_pe_neuron=10,
+         neuron_pe_scale=10000.0,
+         dropout=0.1,
+         num_steps=4,
+     ):
+         super().__init__()
+         self.max_len = 50000  # different from windows
+         self.pe_mode = pe_mode
+         self.neuron_pe_scale = neuron_pe_scale
+         self.dropout = nn.Dropout(p=dropout)
+         if self.pe_mode == "concat":
+             self.num_pe_neuron = copy.deepcopy(num_pe_neuron)
+         elif self.pe_mode == "add":
+             self.num_pe_neuron = copy.deepcopy(d_model)
+         pe = torch.zeros(self.max_len, self.num_pe_neuron)  # MaxL, Neur
+         position = torch.arange(0, self.max_len, dtype=torch.float).unsqueeze(
+             1
+         )  # MaxL, 1
+         div_term = torch.exp(
+             torch.arange(0, self.num_pe_neuron, 2).float()
+             * (-math.log(neuron_pe_scale) / self.num_pe_neuron)
+         )
+         div_term_single = torch.exp(
+             torch.arange(0, self.num_pe_neuron - 1, 2).float()
+             * (-math.log(neuron_pe_scale) / self.num_pe_neuron)
+         )
+         pe[:, 0::2] = torch.heaviside(
+             torch.sin(position * div_term) - 0.8, torch.tensor([1.0])
+         )
+         pe[:, 1::2] = torch.heaviside(
+             torch.cos(position * div_term_single) - 0.8, torch.tensor([1.0])
+         )
+         pe = pe.unsqueeze(0).transpose(0, 1)  # MaxL, 1, Neur
+         print("pe.shape: ", pe.shape)
+         self.register_buffer("pe", pe)
+
+     def forward(self, x):
+         # T, B, L, D
+         T, B, L, _ = x.shape
+         x = x.permute(1, 0, 2, 3)  # B, T, L, D
+         x = x.flatten(1, 2)  # B, TL, D
+         if self.pe_mode == "concat":
+             # tmp: TL, 1, Neur -> TL, B, Neur -> B, TL, Neur
+             tmp = self.pe[: x.size(-2), :].repeat(1, B, 1).transpose(0, 1)
+             x = torch.concat([x, tmp], dim=-1)
+             # print(x.shape) # B, TL, D'
+         elif self.pe_mode == "add":
+             # [B, TL, D] + [1, TL, Neur]
+             # print(self.pe[:x.size(-2), :].shape)
+             x = x + self.pe[: x.size(-2), :].transpose(0, 1)
+             # print(x.shape) # B, TL, D
+         x = x.transpose(0, 1)  # TL, B D
+         x = x.reshape(T, L, B, -1)  # T, L, B, D
+         x = x.permute(0, 2, 1, 3)  # T, B, L, D
+         return self.dropout(x)
+
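Reviewer note: `NeuronPE` turns the usual sinusoidal position code into a binary (spike-like) one by thresholding `sin` and `cos` at 0.8. The pure-Python sketch below reproduces that table for small sizes so the pattern can be inspected without PyTorch. The function name `binary_sinusoidal_pe` and its keyword names are hypothetical. The frequencies, the 0.8 threshold, and the "value at zero counts as 1" convention follow the constructor above.

```python
import math


def binary_sinusoidal_pe(max_len, num_neurons, scale=10000.0, threshold=0.8):
    """Return a max_len x num_neurons table of 0.0/1.0 position codes.

    Even columns threshold sin(pos * freq), odd columns threshold
    cos(pos * freq); each sin/cos pair shares one frequency.
    """
    table = []
    for pos in range(max_len):
        row = []
        for j in range(num_neurons):
            i = j - (j % 2)  # even index shared by the sin/cos pair
            freq = math.exp(-math.log(scale) * i / num_neurons)
            wave = math.sin(pos * freq) if j % 2 == 0 else math.cos(pos * freq)
            # Heaviside with value 1 at zero, matching torch.heaviside(..., 1.0)
            row.append(1.0 if wave - threshold >= 0 else 0.0)
        table.append(row)
    return table
```

At position 0 every sine column is 0 and every cosine column is 1, because sin(0) = 0 is below the threshold and cos(0) = 1 is above it. Later positions switch columns on and off at column-specific rates.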
+
+ class StaticPE(nn.Module):
+     r"""Inject some information about the relative or absolute position of the tokens
+     in the sequence. The positional encodings have the same dimension as
+     the embeddings, so that the two can be summed. Here, we use sine and cosine
+     functions of different frequencies.
+     .. math::
+         \text{PosEncoder}(pos, 2i) = \sin(pos / 10000^{2i / d_{model}})
+         \text{PosEncoder}(pos, 2i+1) = \cos(pos / 10000^{2i / d_{model}})
+         \text{where pos is the word position and i is the embed idx}"""
+
+     def __init__(self, d_model, dropout=0.1, max_len=5000):
+         super().__init__()
+         self.dropout = nn.Dropout(p=dropout)
+         pe = torch.zeros(max_len, d_model)  # MaxL, D
+         position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # MaxL, 1
+         div_term = torch.exp(
+             torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
+         )
+         div_term_single = torch.exp(
+             torch.arange(0, d_model - 1, 2).float() * (-math.log(10000.0) / d_model)
+         )
+         pe[:, 0::2] = torch.sin(position * div_term)
+         pe[:, 1::2] = torch.cos(position * div_term_single)
+         pe = pe.unsqueeze(0).transpose(0, 1)
+         self.register_buffer("pe", pe)
+
+     def forward(self, x):
+         # x: L, TB, D
+         x = x + self.pe[: x.size(0), :]
+         x = self.dropout(x)
+         return x
+
+
+ class ConvPE(nn.Module):
+     def __init__(self, d_model, dropout=0.1, max_len=5000, num_steps=4):
+
+         super().__init__()
+         self.T = num_steps
+         self.rpe_conv = nn.Conv1d(
+             d_model, d_model, kernel_size=3, stride=1, padding=1, bias=False
+         )
+         self.rpe_bn = nn.BatchNorm1d(d_model)
+         self.rpe_lif = neuron.LIFNode(
+             step_mode="m",
+             detach_reset=True,
+             surrogate_function=surrogate.ATan(),
+             v_threshold=1.0,
+         )
+         self.dropout = nn.Dropout(p=dropout)
+
+     def forward(self, x):
+         # x: L, TB, D
+         L, TB, D = x.shape
+         x_feat = x.permute(1, 2, 0)  # TB, D, L
+         x_feat = self.rpe_conv(x_feat)  # TB, D, L
+         x_feat = (
+             self.rpe_bn(x_feat).reshape(self.T, int(TB / self.T), D, L).contiguous()
+         )  # T, B, D, L
+         x_feat = self.rpe_lif(x_feat)
+         x_feat = x_feat.flatten(0, 1)  # TB, D, L
+         x_feat = self.dropout(x_feat)  # TB, D, L
+         x_feat = x_feat.permute(2, 0, 1)  # L, TB, D
+         x = x + x_feat
+         return x
+
+
+ class PositionEmbedding(nn.Module):
+     def __init__(
+         self,
+         input_size: int,
+         pe_type: str,
+         max_len: int = 5000,
+         pe_mode: str = "add",
+         num_pe_neuron: int = 10,
+         neuron_pe_scale: float = 1000.0,
+         dropout=0.1,
+         num_steps=4,
+     ):
+         super().__init__()
+         self.emb_type = pe_type
+         if pe_type in ["learn", "none"]:
+             self.emb = nn.Embedding(max_len, input_size)
+         elif pe_type == "conv":
+             self.emb = ConvPE(
+                 d_model=input_size,
+                 max_len=max_len,
+                 dropout=dropout,
+                 num_steps=num_steps,
+             )
+         elif pe_type == "static":
+             self.emb = StaticPE(d_model=input_size, max_len=max_len, dropout=dropout)
+         elif pe_type == "neuron":
+             self.emb = NeuronPE(
+                 d_model=input_size,
+                 pe_mode=pe_mode,
+                 num_pe_neuron=num_pe_neuron,
+                 neuron_pe_scale=neuron_pe_scale,
+                 dropout=dropout,
+                 num_steps=num_steps,
+             )
+         elif pe_type == "random":
+             self.emb = RandomPE(
+                 d_model=input_size,
+                 pe_mode=pe_mode,
+                 num_pe_neuron=num_pe_neuron,
+                 neuron_pe_scale=neuron_pe_scale,
+                 dropout=dropout,
+                 num_steps=num_steps,
+             )
+         else:
+             raise ValueError("Unknown embedding type: {}".format(pe_type))
+
+     def forward(self, x):
+         if self.emb_type == "learn":
+             # T, B, L, D = x.shape # x: T, B, L, D
+             # x = x.flatten(0, 1) # TB, L, D
+             tmp = torch.arange(
+                 end=x.size()[1], device=x.device
+             )  # [0,1,2,...,L-1], shape: L
+             embedding = self.emb(tmp)  # shape: L, D
+             embedding = embedding.repeat([x.size()[0], 1, 1])  # TB, L, D'
+             x = x + embedding
+             # x = x.reshape(T, B, L, -1)
+         elif self.emb_type in ["static", "conv"]:
+             T, B, L, _ = x.shape  # x: T, B, L, D
+             x = x.flatten(0, 1)  # TB, L, D
+             x = self.emb(x.transpose(0, 1)).transpose(0, 1)  # x: TB, L, D'
+             x = x.reshape(T, B, L, -1)
+         elif self.emb_type in ["neuron", "random"]:
+             T, B, L, _ = x.shape  # x: T, B, L, D
+             # T, B, L, D
+             x = self.emb(x)
+             x = x.reshape(T, B, L, -1)
+         return x  # T, B, L, D'
+
+
+ @torch.jit.script
+ def heaviside(x: torch.Tensor):
+     return (x >= 0).to(x)
+
+ @torch.jit.script
+ def atan_backward(grad_output: torch.Tensor, x: torch.Tensor, alpha: float):
+
+     return alpha / 2 / (1 + (math.pi / 2 * alpha * x).pow_(2)) * grad_output, None
+
+
+ class SG(torch.autograd.Function):
+     @staticmethod
+     def forward(ctx, x, alpha=2.0):
+         if x.requires_grad:
+             ctx.save_for_backward(x)
+             ctx.alpha = alpha
+         return heaviside(x)
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         return atan_backward(grad_output, ctx.saved_tensors[0], ctx.alpha)
+
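Reviewer note: `SG` uses a hard step in the forward pass and, in the backward pass, the derivative of a smooth arctan curve. The scalar sketch below (plain Python, no autograd) spells out the formula in `atan_backward` and the function it is the derivative of. The names `atan_surrogate` and `atan_surrogate_grad` are hypothetical. That the surrogate is g(x) = (1/pi) * atan(pi/2 * alpha * x) + 1/2 is inferred from the gradient expression, not stated in the diff.

```python
import math


def heaviside(x):
    """Forward pass: emit a spike (1.0) when x >= 0, else 0.0."""
    return 1.0 if x >= 0 else 0.0


def atan_surrogate(x, alpha=2.0):
    """Smooth stand-in for the step: (1/pi) * atan(pi/2 * alpha * x) + 1/2."""
    return math.atan(math.pi / 2.0 * alpha * x) / math.pi + 0.5


def atan_surrogate_grad(x, alpha=2.0):
    """Derivative of atan_surrogate, used instead of the step's zero gradient."""
    return alpha / 2.0 / (1.0 + (math.pi / 2.0 * alpha * x) ** 2)
```

With `alpha = 2.0` the gradient peaks at exactly 1.0 when the membrane potential sits on the threshold (x = 0) and decays as the potential moves away. A larger `alpha` narrows that window.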
+
+ class MemoryModule(nn.Module):
+     def __init__(self):
+         """
+         ``MemoryModule`` is the base class of all stateful modules in SpikingJelly.
+         """
+         super().__init__()
+         self._memories = {}
+         self._memories_rv = {}
+
+     def register_memory(self, name: str, value):
+         """
+         :param name: variable's name
+         :type name: str
+         :param value: variable's value
+         :type value: any
+
+         Register the variable in the memory dict, which saves stateful variables (e.g., the membrane potential of a
+         spiking neuron). The reset value of this variable will be ``value``. ``self.name`` will be set to ``value`` after
+         each call to ``self.reset()``.
+         """
+         assert not hasattr(self, name), f'{name} has been set as a member variable!'
+         self._memories[name] = value
+         self.set_reset_value(name, value)
+
+     def reset(self):
+         """
+         Reset all stateful variables to their default values.
+         """
+         for key in self._memories.keys():
+             self._memories[key] = copy.deepcopy(self._memories_rv[key])
+
+     def set_reset_value(self, name: str, value):
+         self._memories_rv[name] = copy.deepcopy(value)
+
+     def __getattr__(self, name: str):
+         if '_memories' in self.__dict__:
+             memories = self.__dict__['_memories']
+             if name in memories:
+                 return memories[name]
+
+         return super().__getattr__(name)
+
+     def __setattr__(self, name: str, value) -> None:
+         _memories = self.__dict__.get('_memories')
+         if _memories is not None and name in _memories:
+             _memories[name] = value
+         else:
+             super().__setattr__(name, value)
+
+     def __delattr__(self, name):
+         if name in self._memories:
+             del self._memories[name]
+             del self._memories_rv[name]
+         else:
+             return super().__delattr__(name)
+
+     def __dir__(self):
+         module_attrs = dir(self.__class__)
+         attrs = list(self.__dict__.keys())
+         parameters = list(self._parameters.keys())
+         modules = list(self._modules.keys())
+         buffers = list(self._buffers.keys())
+         memories = list(self._memories.keys())
+         keys = module_attrs + attrs + parameters + modules + buffers + memories
+
+         # Eliminate attrs that are not legal Python variable names
+         keys = [key for key in keys if not key[0].isdigit()]
+
+         return sorted(keys)
+
+     def memories(self):
+         """
+         :return: an iterator over all stateful variables
+         :rtype: Iterator
+         """
+         for name, value in self._memories.items():
+             yield value
+
+     def named_memories(self):
+         """
+         :return: an iterator over all stateful variables and their names
+         :rtype: Iterator
+         """
+
+         for name, value in self._memories.items():
+             yield name, value
+
+     def detach(self):
+         """
+         Detach all stateful variables from the computation graph.
+
+         .. admonition:: Tip
+             :class: tip
+
+             We can use this function to implement TBPTT (Truncated Back Propagation Through Time).
+         """
+
+         for key in self._memories.keys():
+             if isinstance(self._memories[key], torch.Tensor):
+                 self._memories[key].detach_()
+
+     def _apply(self, fn):
+         for key, value in self._memories.items():
+             if isinstance(value, torch.Tensor):
+                 self._memories[key] = fn(value)
+         return super()._apply(fn)
+
+     def _replicate_for_data_parallel(self):
+         replica = super()._replicate_for_data_parallel()
+         replica._memories = self._memories.copy()
+         return replica
+
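Reviewer note: the core contract of `MemoryModule` is that each registered state keeps a deep copy of its initial value, and `reset()` restores from that copy. The stripped-down sketch below shows only that contract. The class `TinyMemory` and its `get`/`set` helpers are hypothetical and leave out the attribute interception (`__getattr__`/`__setattr__`) and the `nn.Module` integration of the real class.

```python
import copy


class TinyMemory:
    """Minimal register/reset store; an illustration, not the real base class."""

    def __init__(self):
        self._memories = {}      # current values of stateful variables
        self._reset_values = {}  # deep copies restored by reset()

    def register_memory(self, name, value):
        if name in self._memories:
            raise KeyError(f"{name} is already registered")
        self._memories[name] = value
        self._reset_values[name] = copy.deepcopy(value)

    def get(self, name):
        return self._memories[name]

    def set(self, name, value):
        self._memories[name] = value

    def reset(self):
        # Deep-copy again so later in-place edits cannot corrupt the reset value.
        for name in self._memories:
            self._memories[name] = copy.deepcopy(self._reset_values[name])
```

The deep copy matters for mutable state: a list or tensor modified in place after registration would otherwise also change the stored reset value.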
+
+ class StepModule:
+     def supported_step_mode(self):
+         """
+         :return: a tuple that contains the supported step modes
+         :rtype: tuple[str]
+
+         Return the step modes supported by this module.
+         """
+         return ('s', 'm')
+
+     @property
+     def step_mode(self):
+         """
+         :return: the current step mode of this module
+         :rtype: str
+         """
+         return self._step_mode
+
+     @step_mode.setter
+     def step_mode(self, value: str):
+         """
+         :param value: the step mode
+         :type value: str
+
+         Set the step mode of this module to ``value``.
+         """
+         if value not in self.supported_step_mode():
+             raise ValueError(f'step_mode can only be {self.supported_step_mode()}, but got "{value}"!')
+         self._step_mode = value
+
+
+
+ class BaseNode(MemoryModule):
+     def __init__(self,
+                  v_threshold: float = 1.,
+                  v_reset: float = 0.,
+                  surrogate_function: Callable = None,
+                  detach_reset: bool = False,
+                  step_mode='s', backend='torch',
+                  store_v_seq: bool = True):
+
+         assert isinstance(v_reset, float) or v_reset is None
+         assert isinstance(v_threshold, float)
+         assert isinstance(detach_reset, bool)
+         super().__init__()
+
+         if v_reset is None:
+             self.register_memory('v', 0.)
+             self.register_memory('v_s', 0.)
+         else:
+             self.register_memory('v', v_reset)
+
+         self.v_threshold = v_threshold
+
+         self.v_reset = v_reset
+         self.detach_reset = detach_reset
+         self.surrogate_function = surrogate_function
+
+         self.step_mode = step_mode
+         self.backend = backend
+
+         self.store_v_seq = store_v_seq
+
+
+         self.alpha_s = torch.nn.Parameter(torch.tensor(0.5, dtype=torch.float))
+         self.alpha_l = torch.nn.Parameter(torch.tensor(0.5, dtype=torch.float))
+
+     @property
+     def store_v_seq(self):
+         return self._store_v_seq
+
+     @store_v_seq.setter
+     def store_v_seq(self, value: bool):
+         self._store_v_seq = value
+         if value:
+             if not hasattr(self, 'v_seq'):
+                 self.register_memory('v_seq', None)
+
+     @staticmethod
+     @torch.jit.script
+     def jit_hard_reset(v: torch.Tensor, spike: torch.Tensor, v_reset: float):
+         v = (1. - spike) * v + spike * v_reset
+
+         return v
+
+     @staticmethod
+     @torch.jit.script
+     def jit_soft_reset(v: torch.Tensor, spike: torch.Tensor, v_threshold: float):
+         v = v - spike * v_threshold
+         return v
+
+
+     @abstractmethod
+     def neuronal_charge(self, x: torch.Tensor):
+         raise NotImplementedError
+
+     def neuronal_fire(self):
+         return self.surrogate_function(self.v - self.v_threshold, 2.0)
+
+     def sl_neuronal_fire(self):
+         s_s = self.surrogate_function(self.v - self.v_threshold, 2.0)
+         s_l = self.surrogate_function(self.v_s - self.v_threshold, 2.0)
+         return s_s, s_l
+
+     def extra_repr(self):
+         return f'v_threshold={self.v_threshold}, v_reset={self.v_reset}, detach_reset={self.detach_reset}, step_mode={self.step_mode}, backend={self.backend}'
+
+     def single_step_forward(self, x: torch.Tensor):
+         self.v_float_to_tensor(x)
+         self.neuronal_charge(x)
+         s_s, s_l = self.sl_neuronal_fire()
+         spike = self.alpha_s * s_s + self.alpha_l * s_l
+         self.neuronal_reset(s_s, s_l)
+
+         return spike
+
+     def multi_step_forward(self, x_seq: torch.Tensor):
+
+         T = x_seq.shape[-1]
+         y_seq = []
+         if self.store_v_seq:
+             v_seq = []
+         for t in range(T):
+             y = self.single_step_forward(x_seq[:, t])
+             y_seq.append(y)
+             if self.store_v_seq:
+                 v_seq.append(self.v)
+         if self.store_v_seq:
+             self.v_seq = torch.stack(v_seq)
+
+         outputs = torch.stack(y_seq, dim=0).permute(1, 0)
+
+         return outputs
+
+     def v_float_to_tensor(self, x: torch.Tensor):
+         if isinstance(self.v, float):
+             v_init = self.v
+             self.v = torch.full_like(x.data, v_init)
+
+
+ class TSLIFNode(BaseNode):
+     def __init__(self,
+                  v_threshold=1.0,
+                  v_reset=0.,
+                  surrogate_function: Callable = None,
+                  detach_reset=False,
+                  hard_reset=False,
+                  step_mode='s',
+                  k=2,
+                  decay_factor: torch.Tensor = torch.tensor([0.8, 0.2, 0.3, 0.7], dtype=torch.float),
+                  gamma: float = 0.5):
+         super(TSLIFNode, self).__init__(v_threshold, v_reset, surrogate_function, detach_reset, step_mode)
+         self.k = k
+         for i in range(1, self.k + 1):
+             self.register_memory('v' + str(i), 0.)
+         self.names = self._memories
+         self.hard_reset = hard_reset
+         self.gamma = gamma
+         self.decay_factor = torch.nn.Parameter(decay_factor)
+         self.kk = torch.nn.Parameter(torch.tensor([0.8], dtype=torch.float))
+         self.yy = torch.nn.Parameter(torch.tensor([0.1], dtype=torch.float))
+
+     @property
+     def supported_backends(self):
+         if self.step_mode == 's':
+             return ('torch',)
+         elif self.step_mode == 'm':
+             return ('torch', 'cupy')
+         else:
+             raise ValueError(self.step_mode)
+
+     def neuronal_charge(self, x: torch.Tensor):
+         self.names['v1'] = self.decay_factor[0] * self.names['v1'] + self.decay_factor[1] * x - self.yy * self.names['v2']
+         self.names['v2'] = self.decay_factor[2] * self.names['v2'] + self.decay_factor[3] * x - self.kk * self.names['v1']
+         self.v = self.names['v2']
+         self.v_s = self.names['v1']
+
+     def neuronal_reset(self, spike_s, spike_l):
+         if not self.hard_reset:
+             self.names['v1'] = self.jit_soft_reset(self.names['v1'], spike_l, self.gamma)
+             self.names['v2'] = self.jit_soft_reset(self.names['v2'], spike_s, self.v_threshold)
+         else:
+             for i in range(2, self.k + 1):
+                 self.names['v' + str(i)] = self.jit_hard_reset(self.names['v' + str(i)], spike_s, self.v_reset)
+
+     def forward(self, x: torch.Tensor):
+         return super().single_step_forward(x)
+
+     def extra_repr(self):
+         return f"v_threshold={self.v_threshold}, v_reset={self.v_reset}, detach_reset={self.detach_reset}, " \
+                f"hard_reset={self.hard_reset}, " \
+                f"gamma={self.gamma}, k={self.k}, step_mode={self.step_mode}, backend={self.backend}"
+
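Reviewer note: one `TSLIFNode` step couples two membrane potentials, fires each against the same threshold, mixes the two spikes, and then soft-resets. The scalar sketch below follows `neuronal_charge`, `sl_neuronal_fire` and the soft-reset branch of `neuronal_reset` for a single neuron, using the default constants from the constructors above. The function name `tslif_step` is hypothetical, and the learnable parameters are treated as fixed floats.

```python
def tslif_step(x, v1, v2,
               decay=(0.8, 0.2, 0.3, 0.7), kk=0.8, yy=0.1,
               v_th=1.0, gamma=0.5, alpha_s=0.5, alpha_l=0.5):
    """Advance one neuron by one step; returns (output, new_v1, new_v2)."""
    # Charge: v1 is updated first, and the new v1 feeds into v2.
    v1 = decay[0] * v1 + decay[1] * x - yy * v2
    v2 = decay[2] * v2 + decay[3] * x - kk * v1
    # Fire: both compartments are compared with the same threshold.
    s_l = 1.0 if v1 >= v_th else 0.0  # spike from v1 (self.v_s in the diff)
    s_s = 1.0 if v2 >= v_th else 0.0  # spike from v2 (self.v in the diff)
    out = alpha_s * s_s + alpha_l * s_l
    # Soft reset: subtract gamma from v1 and the threshold from v2 on a spike.
    v1 = v1 - s_l * gamma
    v2 = v2 - s_s * v_th
    return out, v1, v2
```

Because the output is a weighted sum of two binary spikes, it takes values in {0, `alpha_l`, `alpha_s`, `alpha_s + alpha_l`} rather than only 0 or 1. `SpikeTemporalBlock` averages these outputs over `num_steps` repeated calls.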
+
+ class SpikeTemporalBlock(nn.Module):
+     def __init__(
+         self,
+         n_inputs,
+         n_outputs,
+         kernel_size,
+         stride,
+         dilation,
+         padding,
+         num_steps=4,
+     ):
+         super().__init__()
+         self.num_steps = num_steps
+         self.conv1 = weight_norm(
+             nn.Conv2d(
+                 n_inputs,
+                 n_outputs,
+                 (1, kernel_size),
+                 stride=stride,
+                 padding=(0, padding),
+                 dilation=(1, dilation),
+             )
+         )
+         self.bn1 = nn.BatchNorm2d(n_outputs)
+         self.chomp1 = Chomp2d(padding)
+
+         self.tslif1 = TSLIFNode(
+             surrogate_function=SG.apply,
+         )
+
+         self.conv2 = weight_norm(
+             nn.Conv2d(
+                 n_outputs,
+                 n_outputs,
+                 (1, kernel_size),
+                 stride=stride,
+                 padding=(0, padding),
+                 dilation=(1, dilation),
+             )
+         )
+         self.bn2 = nn.BatchNorm2d(n_outputs)
+         self.chomp2 = Chomp2d(padding)
+
+         self.tslif2 = TSLIFNode(
+             surrogate_function=SG.apply,
+         )
+
+         self.downsample = (
+             nn.Conv2d(n_inputs, n_outputs, (1, 1)) if n_inputs != n_outputs else None
+         )
+
+         self.tslif = TSLIFNode(
+             surrogate_function=SG.apply,
+         )
+
+     def init_weights(self):
+         self.conv1.weight.data.normal_(0, 0.01)
+         self.conv2.weight.data.normal_(0, 0.01)
+         if self.downsample is not None:
+             self.downsample.weight.data.normal_(0, 0.01)
+
+     def forward(self, x):
+         # out1: 24, 16, 361, 168
+
+         out1 = self.chomp1(self.bn1(self.conv1(x)))
+         spk_rec1 = []
+         for _ in range(self.num_steps):
+             spk = self.tslif1(out1)
+             spk_rec1.append(spk)
+
+         spks1 = torch.stack(spk_rec1, dim=-1)  # spks1: B, H, C, L, T
+         spks1 = spks1.mean(-1)  # spks1: B, H, C, L
+
+         out2 = self.chomp2(self.bn2(self.conv2(spks1)))
+         spk_rec2 = []
+         for _ in range(self.num_steps):
+             # spk: 24, 16, 361, 168
+             spk = self.tslif2(out2)
+             spk_rec2.append(spk)
+
+         spks2 = torch.stack(spk_rec2, dim=-1)  # spks2: B, H, C, L, T
+         spks2 = spks2.mean(-1)  # spks2: B, H, C, L
+
+         if torch.isnan(spks2).any() or torch.isinf(spks2).any():
+             print("illegal value in TemporalBlock2D")
+
+         if self.downsample is None:
+             res = x
+         else:
+             res = self.downsample(x)
+
+         spk_rec3 = []
+         for _ in range(self.num_steps):
+             spk = self.tslif(spks2 + res)
+             spk_rec3.append(spk)
+
+         res = torch.stack(spk_rec3, dim=-1)  # res: B, H, C, L, T
+         res = res.mean(-1)
+
+         return res
+
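Reviewer note: each block pads the time axis by `(kernel_size - 1) * dilation` on both sides and then applies `Chomp2d(padding)`. `Chomp2d` is not shown in this part of the diff, so the sketch below assumes it trims the trailing `padding` time steps, as in common TCN implementations. Under that assumption the sequence length is preserved and the convolution stays causal. The arithmetic uses the standard convolution output-length formula. The helper names are hypothetical.

```python
def conv_output_length(length, kernel_size, dilation, padding, stride=1):
    """Standard output length of a 1-D convolution with symmetric padding."""
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def length_after_chomp(length, kernel_size, dilation):
    """Length after conv + chomp, assuming chomp removes `padding` trailing steps."""
    padding = (kernel_size - 1) * dilation
    return conv_output_length(length, kernel_size, dilation, padding) - padding


def dilation_schedule(num_levels, base=2):
    """Dilations used by the stacked blocks: base**i for i = 0..num_levels."""
    return [base ** i for i in range(num_levels + 1)]
```

For a window of 168 steps and `kernel_size = 2`, every dilation in `dilation_schedule(3)` leaves the length at 168, which is what allows the residual addition `spks2 + res` in `forward` to line up.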
+
+ class TSTCN(nn.Module):
+     def __init__(
+         self,
+         args,
+         num_levels: int = 3,
+         channel: int = 16,
+         dilation: int = 2,
+         stride: int = 1,
+         kernel_size: int = 2,
+         dropout: float = 0.2,
+         max_length: int = 100,
+         encoder_type: str = "conv",
+         pe_type: str = "neuron",
+         pe_mode: str = "concat",
+         num_pe_neuron: int = 40,
+         neuron_pe_scale: float = 1000.0,
+     ):
+         super().__init__()
+
+         self.hidden_size = args.hidden_size
+         self.num_steps = args.T
+         self.input_size = args.feature_size
+         self.feature_size = args.feature_size
+         self.pre_length = args.pre_length
+         self.num_levels = args.blocks
+         self.pe_type = pe_type
+         self.pe_mode = pe_mode
+         self.num_pe_neuron = num_pe_neuron
+         self.kernel_size = args.kernel_size
+         self.args = args
+
+         self._snn_backend = "snntorch"
+         self.encoder = SpikeEncoder[self._snn_backend][encoder_type](self.hidden_size)
+
+         self.pe = PositionEmbedding(
+             pe_type=pe_type,
+             pe_mode=pe_mode,
+             neuron_pe_scale=neuron_pe_scale,
+             input_size=self.input_size,
+             max_len=max_length,
+             num_pe_neuron=self.num_pe_neuron,
+             dropout=0.1,
+             num_steps=self.num_steps,
+         )
+
+         layers = []
+         num_channels = [channel] * self.num_levels
+         num_channels.append(1)
+         for i in range(self.num_levels + 1):
+             dilation_size = dilation**i
+             in_channels = self.hidden_size if i == 0 else num_channels[i - 1]
+             out_channels = num_channels[i]
+             layers += [
+                 SpikeTemporalBlock(
+                     in_channels,
+                     out_channels,
+                     self.kernel_size,
+                     stride=stride,
+                     dilation=dilation_size,
+                     padding=(self.kernel_size - 1) * dilation_size,
+                     num_steps=self.num_steps,
+                 )
+             ]
+
+         self.network = nn.Sequential(*layers)
+
+         if (self.pe_type == "neuron" and self.pe_mode == "concat") or (
+             self.pe_type == "random" and self.pe_mode == "concat"
+         ):
+             self.__output_size = self.feature_size + num_pe_neuron
+         else:
+             self.__output_size = args.seq_length
+
+         self.fc1 = nn.Linear(self.__output_size, args.feature_size)
+         self.fc2 = nn.Linear(args.seq_length, self.pre_length)
+         self.to('cuda:0')
+
+     def forward(self, inputs: torch.Tensor):
+         utils.reset(self.encoder)
+
+         if self.args.normalize:
+             mean = inputs.mean(dim=1, keepdim=True).detach()  # shape [B, 1, D]
+             inputs = inputs - mean
+
+             std = torch.sqrt(torch.var(inputs, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()
+             inputs = inputs / std
+
+         inputs = self.encoder(inputs)  # B, H, C, L
+         # inputs: 24, 64, 321, 168
+
+         if self.pe_type != "none":
+             inputs = self.pe(inputs.permute(1, 0, 3, 2)).permute(1, 0, 3, 2)
+
+         spks = self.network(inputs)
+         spks = spks.squeeze(1)  # B, C', L
+         preds = self.fc1(spks.permute(0, 2, 1))  # B, L, C
+         preds = self.fc2(preds.permute(0, 2, 1))  # B, C, pre_length
+         preds = preds.permute(0, 2, 1).contiguous()  # B, pre_length, C
+         if self.args.normalize:
+             preds = preds * std + mean  # denormalize
+
+         # Create auxiliary output
+         aux = {'gate_l0': torch.tensor(0.0, device=preds.device)}  # placeholder
+
+         return preds, aux
+
+     @property
+     def output_size(self):
+         return self.__output_size
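Reviewer note: when `args.normalize` is set, `TSTCN.forward` standardizes each input window along the time axis (biased variance plus 1e-5 under the square root) and applies the inverse transform to the predictions. The list-based sketch below shows that round trip for a single series. The function names `normalize` and `denormalize` are hypothetical; the epsilon and the biased variance follow the code above.

```python
import math


def normalize(series, eps=1e-5):
    """Return (standardized series, mean, std) using the biased variance."""
    mean = sum(series) / len(series)
    centered = [v - mean for v in series]
    var = sum(c * c for c in centered) / len(series)  # unbiased=False
    std = math.sqrt(var + eps)
    return [c / std for c in centered], mean, std


def denormalize(values, mean, std):
    """Inverse transform, as applied to the model's predictions."""
    return [v * std + mean for v in values]
```

The statistics come from the input window only and are reused for the forecast horizon. That is why `mean` and `std` are kept with `keepdim=True` in the model: they can then broadcast over a prediction of a different length.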
model/iSpikformer.py ADDED
@@ -0,0 +1,129 @@
+ import torch
+ from torch import nn
+ from spikingjelly.clock_driven.neuron import MultiStepLIFNode
+
+ class SPE(nn.Module):
+     def __init__(self, input_len, patch_num, patch_dim, T, tau, D):
+         super().__init__()
+         self.patch_projector = nn.Linear(input_len // patch_num, patch_dim)
+         self.bn = nn.BatchNorm2d(patch_dim)
+         self.encoder_lif = MultiStepLIFNode(tau=tau, detach_reset=False, backend='torch')
+
+         self.D = D
+         self.T = T
+         self.patch_dim = patch_dim
+         self.patch_num = patch_num
+
+     def forward(self, x):
+         B, L, D = x.shape
+
+         x = x.view(B, self.patch_num, L // self.patch_num, D).contiguous()
+         x = x.transpose(-1, -2).contiguous()
+         x = self.patch_projector(x)
+         x = x.repeat(self.T, 1, 1, 1, 1)
+         x = x.permute(0, 1, 4, 2, 3).contiguous()
+         x = x.flatten(0, 1)
+         x = self.bn(x)
+         x = x.view(self.T, B, self.patch_dim, self.patch_num, D)
+         x = self.encoder_lif(x)
+
+         return x
+
+ class iSSA(nn.Module):
+     def __init__(self, patch_num, D, patch_dim, tau, alpha):
+         super().__init__()
+         self.lin1 = nn.Linear(patch_num, patch_num)
+         self.lin2 = nn.Linear(patch_num, patch_num)
+         self.lin3 = nn.Linear(patch_num, patch_num)
+
+         self.lif1 = MultiStepLIFNode(tau=tau, detach_reset=True, backend='torch')
+         self.lif2 = MultiStepLIFNode(tau=tau, detach_reset=True, backend='torch')
+         self.lif3 = MultiStepLIFNode(tau=tau, detach_reset=True, backend='torch')
+         self.lif4 = MultiStepLIFNode(tau=tau, detach_reset=True, backend='torch')
+
+         self.b1 = nn.BatchNorm2d(patch_dim)
+         self.b2 = nn.BatchNorm2d(patch_dim)
+         self.b3 = nn.BatchNorm2d(patch_dim)
+         self.b4 = nn.BatchNorm2d(patch_dim)
+
+     def forward(self, x):
+         res_x = x
+         T, B, pd, pn, D = x.shape
+
+
+         x = x.transpose(-1, -2).contiguous()
+         q = self.lin1(x).flatten(0, 1)
+         k = self.lin2(x).flatten(0, 1)
+         v = self.lin3(x).flatten(0, 1)
+
+         q = self.b1(q)
+         k = self.b2(k)
+         v = self.b3(v)
+
+         q = q.view(T, B, pd, D, -1)
+         k = k.view(T, B, pd, D, -1)
+         v = v.view(T, B, pd, D, -1)
+
+         q = self.lif1(q)
+         k = self.lif2(k).transpose(-1, -2).contiguous()
+         v = self.lif3(v)
+
+         attn = q @ k
+         attn = attn @ v
+         attn = attn.flatten(0, 1)
+         attn = self.b4(attn)
+         attn = attn.view(T, B, pd, D, pn)
+         attn = self.lif4(attn)
+         attn = attn.transpose(-1, -2).contiguous()
+
+         return attn
+
+ class iSpikformer(nn.Module):
+     def __init__(self, args, input_len, patch_num, patch_dim, T, blocks, D, pred_len, tau, alpha, hidden_dim):
+         super().__init__()
+         self.emb = SPE(input_len, patch_num, patch_dim, T, tau, D)
+         self.args = args
+         self.attn = nn.ModuleList()
+         for i in range(blocks):
+             self.attn.append(iSSA(patch_num, D, patch_dim, tau, alpha))
+
+         self.dense1 = nn.Linear(patch_num*patch_dim, hidden_dim)
+         self.dense2 = nn.Linear(hidden_dim, pred_len)
+         self.bn = nn.BatchNorm1d(D)
+         self.activ = MultiStepLIFNode(tau=tau, detach_reset=True, backend='torch')
+         self.to('cuda:0')
+
+
+     def forward(self, x):
+         if self.args.normalize:
+             mean = x.mean(dim=1, keepdim=True).detach()
+             x = x - mean
+
+             std = torch.sqrt(torch.var(x, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()
+             x = x / std
+
+         x = self.emb(x)
+         T, B, pd, pn, D = x.shape
+
+         for i in range(len(self.attn)):
+             x = self.attn[i](x)
+         x = x.permute(0, 1, 4, 2, 3).contiguous()
+         x = x.flatten(-2, -1)
+         x = self.dense1(x)
+         x = x.flatten(0, 1)
+         x = self.bn(x)
+         x = self.activ(x)
+         x = self.dense2(x)
+         x = x.transpose(-1, -2).contiguous()
+         x = x.view(T, B, -1, D)
+
+         if self.args.normalize:
+             x = x * std
+             x = x + mean.repeat(T, 1, 1, 1)
+
+         aux = {
+             'gate_l0': torch.tensor(0.0, device=x.device)  # placeholder
+         }
+
+         return x.mean(dim=0), aux
+
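Reviewer note: the tensor reshapes in `SPE.forward` are easy to lose track of, so the sketch below records the shape after each step as plain tuples, without running any model. The function name `spe_shapes` is hypothetical. The example sizes are illustrative choices, not values taken from the repository's configuration.

```python
def spe_shapes(B, L, D, patch_num, patch_dim, T):
    """List (step, shape) pairs for SPE.forward given an input of shape (B, L, D)."""
    if L % patch_num != 0:
        # x.view(B, patch_num, L // patch_num, D) only works for exact division.
        raise ValueError("L must be divisible by patch_num")
    patch_len = L // patch_num
    return [
        ("input", (B, L, D)),
        ("view into patches", (B, patch_num, patch_len, D)),
        ("transpose(-1, -2)", (B, patch_num, D, patch_len)),
        ("patch_projector", (B, patch_num, D, patch_dim)),
        ("repeat over T", (T, B, patch_num, D, patch_dim)),
        ("permute(0, 1, 4, 2, 3)", (T, B, patch_dim, patch_num, D)),
        ("flatten(0, 1) for BatchNorm2d", (T * B, patch_dim, patch_num, D)),
        ("view back", (T, B, patch_dim, patch_num, D)),
    ]
```

For example, `spe_shapes(16, 12, 370, 4, 8, 4)` ends at `(4, 16, 8, 4, 370)`, the `T, B, pd, pn, D` layout that `iSSA.forward` unpacks. The divisibility check makes explicit a constraint that the model code leaves implicit.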
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ numpy
+ pandas
+ torch
+ scikit-learn
+ snntorch
+ spikingjelly
scripts/ecl.sh ADDED
@@ -0,0 +1,232 @@
+ #!/bin/bash
+
+
+
+
+
+
+ if [ ! -d "./logs" ]; then
+     mkdir ./logs
+ fi
+
+ if [ ! -d "./logs/LongForecasting" ]; then
+     mkdir ./logs/LongForecasting
+ fi
+
+
+
+
+
+
+
+
+ python train.py \
+   --model FGN \
+   --data electricity \
+   --feature_size 370 \
+   --embed_size 128 \
+   --hidden_size 256 \
+   --batch_size 16 \
+   --train_ratio 0.7 \
+   --val_ratio 0.2 \
+   --seq_length 12 \
+   --pre_length 12 \
+   --train_epochs 100 \
+   --learning_rate 0.00001 \
+   --device cuda:0 >logs/LongForecasting/ECL_FGN.log
+
+
+ python train.py \
+   --model SpikF \
+   --data electricity \
+   --feature_size 370 \
+   --embed_size 128 \
+   --hidden_size 256 \
+   --batch_size 16 \
+   --train_ratio 0.7 \
+   --val_ratio 0.2 \
+   --seq_length 12 \
+   --pre_length 12 \
+   --train_epochs 100 \
+   --learning_rate 0.00001 \
+   --T 16 \
+   --blocks 2 \
+   --device cuda:0 >logs/LongForecasting/ECL_SpikF.log
+
+
+ python train.py \
+   --model iSpikformer \
+   --data electricity \
+   --feature_size 370 \
+   --embed_size 128 \
+   --hidden_size 256 \
+   --batch_size 16 \
+   --train_ratio 0.7 \
+   --val_ratio 0.2 \
+   --seq_length 12 \
+   --pre_length 12 \
+   --train_epochs 100 \
+   --learning_rate 0.00001 \
+   --blocks 2 \
+   --device cuda:0 >logs/LongForecasting/ECL_iSpikformer.log
+
+
+
+ python train.py \
+   --model SpikF_GO \
+   --data electricity \
+   --feature_size 370 \
+   --embed_size 128 \
+   --hidden_size 256 \
+   --batch_size 16 \
+   --train_ratio 0.7 \
+   --val_ratio 0.2 \
+   --seq_length 12 \
+   --pre_length 12 \
+   --train_epochs 100 \
+   --learning_rate 0.00001 \
+   --energy_loss True \
+   --device cuda:0 >logs/LongForecasting/ECL_SpikFGO.log
+
+
+ python train.py \
+   --model SpikF_GO_CPG \
+   --data electricity \
+   --feature_size 370 \
+   --embed_size 128 \
+   --hidden_size 256 \
+   --batch_size 16 \
+   --train_ratio 0.7 \
+   --val_ratio 0.2 \
+   --seq_length 12 \
+   --pre_length 12 \
+   --train_epochs 100 \
+   --learning_rate 0.00001 \
+   --energy_loss True \
+   --device cuda:0 >logs/LongForecasting/ECL_SpikFGOCPG.log
+
+
+ python train.py \
+   --model SpikeRNN_CPG \
+   --data electricity \
+   --feature_size 370 \
+   --embed_size 128 \
+   --hidden_size 128 \
+   --batch_size 16 \
+   --train_ratio 0.7 \
+   --val_ratio 0.2 \
+   --seq_length 12 \
+   --pre_length 12 \
+   --train_epochs 100 \
+   --learning_rate 0.00001 \
+   --blocks 2 \
+   --device cuda:0 >logs/LongForecasting/ECL_SpikeRNNCPG.log
+
+
+
+
+ python train.py \
+   --model SpikeGRU \
+   --data electricity \
+   --feature_size 370 \
+   --embed_size 128 \
+   --hidden_size 64 \
+   --batch_size 16 \
+   --train_ratio 0.7 \
+   --val_ratio 0.2 \
+   --seq_length 12 \
+   --pre_length 12 \
+   --train_epochs 100 \
+   --learning_rate 0.00001 \
+   --device cuda:0 >logs/LongForecasting/ECL_SpikeGRU.log
+
+
+
+
+ python train.py \
+   --model SpikeTCN_CPG \
+   --data electricity \
+   --feature_size 370 \
+   --embed_size 128 \
+   --hidden_size 64 \
+   --batch_size 16 \
+   --train_ratio 0.7 \
+   --val_ratio 0.2 \
+   --seq_length 12 \
+   --pre_length 12 \
+   --train_epochs 100 \
+   --learning_rate 0.00001 \
+   --blocks 3 \
+   --device cuda:0 >logs/LongForecasting/ECL_SpikeTCNCPG.log
+
+
+
+
+ python train.py \
+   --model Spikformer_CPG \
+   --data electricity \
+   --feature_size 370 \
+   --embed_size 128 \
+   --hidden_size 128 \
+   --batch_size 16 \
+   --train_ratio 0.7 \
+   --val_ratio 0.2 \
+   --seq_length 12 \
+   --pre_length 12 \
+   --train_epochs 100 \
+   --learning_rate 0.00001 \
+   --blocks 2 \
+   --device cuda:0 >logs/LongForecasting/ECL_SpikformerCPG.log
+
+
+
+ python train.py \
+   --model TSTCN \
+   --data electricity \
+   --feature_size 370 \
+   --embed_size 128 \
+   --hidden_size 64 \
+   --batch_size 16 \
+   --train_ratio 0.7 \
+   --val_ratio 0.2 \
+   --seq_length 12 \
+   --pre_length 12 \
+   --train_epochs 100 \
+   --learning_rate 0.00001 \
+   --kernel_size 3 \
+   --blocks 3 \
+   --device cuda:0 >logs/LongForecasting/ECL_TSTCN.log
+
+
+ python train.py \
+   --model TSGRU \
+   --data electricity \
+   --feature_size 370 \
+   --embed_size 128 \
+   --hidden_size 64 \
+   --batch_size 16 \
+   --train_ratio 0.7 \
+   --val_ratio 0.2 \
+   --seq_length 12 \
+   --pre_length 12 \
+   --train_epochs 100 \
+   --learning_rate 0.00001 \
+   --device cuda:0 >logs/LongForecasting/ECL_TSGRU.log
+
+
+
+
+ python train.py \
+   --model TSFormer \
+   --data electricity \
+   --feature_size 370 \
+   --embed_size 128 \
+   --hidden_size 64 \
+   --batch_size 16 \
+   --train_ratio 0.7 \
+   --val_ratio 0.2 \
+   --seq_length 12 \
+   --pre_length 12 \
+   --train_epochs 100 \
+   --learning_rate 0.00001 \
+   --device cuda:0 >logs/LongForecasting/ECL_TSFormer.log
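The runs in scripts/ecl.sh share almost all flags and differ only in the model name and a few per-model options. The following is a minimal sketch of that pattern, not part of the commit: `COMMON`, `RUNS`, `build_command` and `log_name` are hypothetical names, and only a subset of the models is listed. The flag values are copied from the script.

```python
import shlex

# Flags that every run in scripts/ecl.sh passes with the same value.
COMMON = {
    "data": "electricity", "feature_size": 370, "embed_size": 128,
    "batch_size": 16, "train_ratio": 0.7, "val_ratio": 0.2,
    "seq_length": 12, "pre_length": 12, "train_epochs": 100,
    "learning_rate": 0.00001, "device": "cuda:0",
}

# Per-model overrides (illustrative subset of the runs in the script).
RUNS = {
    "FGN": {"hidden_size": 256},
    "SpikF": {"hidden_size": 256, "T": 16, "blocks": 2},
    "SpikF_GO": {"hidden_size": 256, "energy_loss": True},
    "SpikF_GO_CPG": {"hidden_size": 256, "energy_loss": True},
    "TSTCN": {"hidden_size": 64, "kernel_size": 3, "blocks": 3},
}


def build_command(model, overrides):
    """Return the shell command line for one model run."""
    flags = {"model": model, **COMMON, **overrides}
    argv = ["python", "train.py"]
    for name, value in flags.items():
        argv += [f"--{name}", str(value)]
    return shlex.join(argv)


def log_name(model):
    """Mirror the script's log naming: underscores are dropped from the model name."""
    return "logs/LongForecasting/ECL_" + model.replace("_", "") + ".log"


if __name__ == "__main__":
    for model, overrides in RUNS.items():
        print(build_command(model, overrides), ">" + log_name(model))
```

Keeping the shared flags in one place makes it harder for a single run to drift from the others by accident, which matters when results are compared across models.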
train.py ADDED
@@ -0,0 +1,545 @@
+ import argparse
+ import torch
+ import torch.nn as nn
+ from torch.utils.data import DataLoader
+ import snntorch as snn
+ import time
+ import os
+ import numpy as np
+ import warnings
+ from spikingjelly.clock_driven import functional
+
+ from data.data_loader import (
+     Dataset_ECG, Dataset_Dhfm, Dataset_Solar, Dataset_Wiki, Dataset_PEMS_BAY
+ )
+ from utils.utils import save_model_ts, load_model_ts, evaluate
+
+ from model.FourierGNN import FGN
+ from model.SpikF import SpikF
+ from model.iSpikformer import iSpikformer
+ from model.SpikF_GO import SpikF_GO
+ from model.SpikF_GO_CPG import SpikF_GO_CPG
+ from model.TS_GRU import TSGRU
+ from model.TS_TCN import TSTCN
+ from model.TS_Former import TSFormer
+ from model.SpikeGRU import SpikeGRU
+ from model.Spikformer_CPG import Spikformer_CPG
+ from model.SpikeRNN_CPG import SpikeRNN_CPG
+ from model.SpikeTCN_CPG import SpikeTCN_CPG
+ from model.TS_TCN import TSLIFNode
+
+
+ def remove(model):
+     """Reset states of spiking neurons with warning suppression"""
+     if model is None:
+         return
+     with warnings.catch_warnings():
+         warnings.filterwarnings("ignore", message=".*not base.MemoryModule.*")
+         if hasattr(model, '__iter__'):
+             for m in model:
+                 if hasattr(m, 'reset'):
+                     m.reset()
+                 elif hasattr(m, 'v'):
+                     m.v = 0.0
+         elif hasattr(model, 'reset'):
+             model.reset()
+         elif hasattr(model, 'v'):
+             model.v = 0.0
+
+
+ def reset_states(model):
+     """Reset states of all spiking neurons (TSLIFNode, Leaky, etc.) with warning suppression."""
+     if model is None:
+         return
+     with warnings.catch_warnings():
+         warnings.filterwarnings("ignore", message=".*not base.MemoryModule.*")
+         if hasattr(model, '__iter__'):
+             for m in model:
+                 reset_states(m)
+         elif hasattr(model, 'modules'):
+             for module in model.modules():
+                 if isinstance(module, (snn.Leaky, TSLIFNode)):
+                     try:
+                         module.reset()
+                     except Exception:
+                         if hasattr(module, 'v'):
+                             module.v = 0.0
+         elif hasattr(model, 'reset'):
+             model.reset()
+         elif hasattr(model, 'v'):
+             model.v = 0.0
+
+
+ def _inverse_if_possible(arr: np.ndarray, scaler):
+     """
+     Inverse-transform arr of shape (..., D) using scaler fitted on train.
+     If scaler is None, returns arr unchanged.
+     """
+     if scaler is None:
+         return arr
+     if not hasattr(scaler, "inverse_transform"):
+         return arr
+
+     if arr.ndim < 2:
+         return arr
+
+     D = arr.shape[-1]
+     flat = arr.reshape(-1, D)
+     inv = scaler.inverse_transform(flat)
+     return inv.reshape(arr.shape)
+
+
+ def compute_scores_scaled_and_orig(trues: np.ndarray, preds: np.ndarray, scaler):
+     score_scaled = evaluate(trues, preds)
+
+     trues_inv = _inverse_if_possible(trues, scaler)
+     preds_inv = _inverse_if_possible(preds, scaler)
+     score_orig = evaluate(trues_inv, preds_inv)
+
+     return score_scaled, score_orig
+
+
+ def _fmt_score(tag, score):
+     mape, mae, rmse, r2, rse = score
+     mape_pct = mape * 100.0
+     return f"{tag}: MAPE {mape_pct:10.6f}; MAE {mae:10.6f}; RMSE {rmse:10.6f}; R2 {r2:10.6f}; RSE {rse:10.6f}."
+
+
+ # args
+ parser = argparse.ArgumentParser(description='SpikF-GO: Spiking Fourier Graph Operators for Multivariate Time Series Forecasting')
+ parser.add_argument('--data', type=str, default='ECG', help='data set')
+ parser.add_argument('--feature_size', type=int, default=140, help='feature size')
+ parser.add_argument('--seq_length', type=int, default=12, help='input length')
+ parser.add_argument('--pre_length', type=int, default=12, help='predict length')
+ parser.add_argument('--embed_size', type=int, default=128, help='embedding dimensions')
+ parser.add_argument('--hidden_size', type=int, default=256, help='hidden dimensions')
+ parser.add_argument('--train_epochs', type=int, default=100, help='train epochs')
+ parser.add_argument('--batch_size', type=int, default=4, help='input data batch size')
+ parser.add_argument('--learning_rate', type=float, default=0.00001, help='optimizer learning rate')
+ parser.add_argument('--exponential_decay_step', type=int, default=5)
+ parser.add_argument('--validate_freq', type=int, default=1)
+ parser.add_argument('--early_stop', type=lambda s: str(s).lower() in ('1', 'true', 'yes'), default=False)  # type=bool would parse "False" as True
+ parser.add_argument('--decay_rate', type=float, default=0.5)
+ parser.add_argument('--train_ratio', type=float, default=0.6)
+ parser.add_argument('--val_ratio', type=float, default=0.2)
+ parser.add_argument('--device', type=str, default='cuda:0', help='device')
+ parser.add_argument('--tau', type=float, default=2.0, help='tau')
+ parser.add_argument('--alpha', type=float, default=1.0)
+ parser.add_argument('--T', type=int, default=4)
+ parser.add_argument('--proj_dim', type=int, default=32, help='proj dim')
+ parser.add_argument('--model', type=str, default='FGN', help='model name')
+
+ parser.add_argument('--patch_num', type=int, default=4)
+ parser.add_argument('--patch_dim', type=int, default=16)
+ parser.add_argument('--blocks', type=int, default=1)
+ parser.add_argument('--energy_loss', type=lambda s: str(s).lower() in ('1', 'true', 'yes'), default=False)  # type=bool would parse "False" as True
+ parser.add_argument('--normalize', action='store_false', help='Disable normalization')
+ parser.add_argument('--affine', action='store_false', help='Disable affine layer')
+ parser.add_argument('--kernel_size', type=int, default=16)
+
+ args = parser.parse_args()
+ print(f'Training configs: {args}')
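Boolean options such as `--energy_loss` and `--early_stop` take a value on the command line (`--energy_loss True`), so they need a string-to-bool converter: plain `type=bool` calls `bool()` on the raw string, and any non-empty string, including `"False"`, is truthy. The sketch below demonstrates the difference. `str2bool` and `make_parser` are hypothetical helpers written for this illustration, not functions from the repository.

```python
import argparse


def str2bool(text):
    """Parse common true/false spellings and reject anything else."""
    value = str(text).strip().lower()
    if value in ("1", "true", "yes", "y", "on"):
        return True
    if value in ("0", "false", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {text!r}")


def make_parser(bool_type):
    """Build a one-flag parser so the two converters can be compared."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--energy_loss", type=bool_type, default=False)
    return parser


# bool("False") is True because the string is non-empty.
naive = make_parser(bool).parse_args(["--energy_loss", "False"])
# The explicit converter maps the text to the intended value.
strict = make_parser(str2bool).parse_args(["--energy_loss", "False"])
```

An alternative design is `action='store_true'`, which removes the value entirely, but that would change the `--energy_loss True` call syntax used in scripts/ecl.sh.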
+
+
+ data_parser = {
+     'traffic': {'root_path': 'data/traffic.npy', 'type': '0'},
+     'ECG': {'root_path': 'data/ECG_data.csv', 'type': '0'},
+     'COVID': {'root_path': 'data/covid.csv', 'type': '0'},
+     'electricity': {'root_path': 'data/electricity.csv', 'type': '0'},
+     'solar': {'root_path': './data/solar', 'type': '0'},
+     'metr': {'root_path': 'data/metr.csv', 'type': '0'},
+     'wiki': {'root_path': 'data/wiki.csv', 'type': '0'},
+     'pems_bay': {'root_path': 'data/pems-bay.h5', 'type': '0'},
+ }
+
+ data_dict = {
+     'ECG': Dataset_ECG,
+     'COVID': Dataset_ECG,
+     'traffic': Dataset_Dhfm,
+     'solar': Dataset_Solar,
+     'wiki': Dataset_Wiki,
+     'electricity': Dataset_ECG,
+     'metr': Dataset_ECG,
+     'pems_bay': Dataset_PEMS_BAY,
+ }
+
+ if args.data not in data_parser:
+     raise ValueError(f"Unknown dataset {args.data}. Available: {list(data_parser.keys())}")
+
+ data_info = data_parser[args.data]
+ Data = data_dict[args.data]
+
+
+ train_set = Data(
+     root_path=data_info['root_path'], flag='train',
+     seq_len=args.seq_length, pre_len=args.pre_length,
+     type=data_info['type'], train_ratio=args.train_ratio, val_ratio=args.val_ratio,
+     scaler=None
+ )
+ train_scaler = getattr(train_set, "scaler", None)
+
+ val_set = Data(
+     root_path=data_info['root_path'], flag='val',
+     seq_len=args.seq_length, pre_len=args.pre_length,
+     type=data_info['type'], train_ratio=args.train_ratio, val_ratio=args.val_ratio,
+     scaler=train_scaler
+ )
+
+ test_set = Data(
+     root_path=data_info['root_path'], flag='test',
+     seq_len=args.seq_length, pre_len=args.pre_length,
+     type=data_info['type'], train_ratio=args.train_ratio, val_ratio=args.val_ratio,
+     scaler=train_scaler
+ )
+
+ train_dataloader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=True)
+ val_dataloader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, num_workers=0, drop_last=False)
+ test_dataloader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False, num_workers=0, drop_last=False)
+
+ print("Train samples:", len(train_set))
+ print("Val samples:", len(val_set))
+ print("Test samples:", len(test_set))
+
+
+ MODELS_SET2 = ["TSGRU", "TSTCN", "TSFormer", "Spikformer_CPG", "SpikeGRU", "SpikeRNN_CPG", "SpikeTCN_CPG"]
+
+
+
+ def validate(model, vali_loader, scaler):
+     model.eval()
+     cnt = 0
+     loss_total = 0.0
+     preds_list = []
+     trues_list = []
+
+     for x, y in vali_loader:
+         if args.model in MODELS_SET2 and args.model != 'TSGRU':
+             reset_states(model=model)
+         elif args.model == 'TSGRU':
+             remove(model=model.net[0].tslif)
+
+         # Use the resolved module-level `device` (falls back to CPU), matching the training loop.
+         x = x.float().to(device)
+         y = y.float().to(device)
+
+         forecast, _ = model(x)
+         if len(forecast.shape) == 4:
+             forecast = forecast.mean(dim=0)
+
+         loss = forecast_loss(forecast, y)
+         loss_total += float(loss)
+         cnt += 1
+
+         if args.model not in MODELS_SET2:
+             functional.reset_net(model)
+
+         preds_list.append(forecast.detach().cpu().numpy())
+         trues_list.append(y.detach().cpu().numpy())
+
+     preds = np.concatenate(preds_list, axis=0)
+     trues = np.concatenate(trues_list, axis=0)
+
+     score_scaled, score_orig = compute_scores_scaled_and_orig(trues, preds, scaler)
+
+     print(_fmt_score("SCALED", score_scaled))
+     print(_fmt_score("ORIG  ", score_orig))
+
+     model.train()
+     return loss_total / max(1, cnt)
+
+
+ def test(model, result_test_file, scaler, load_epoch=97):
+     model = load_model_ts(model, result_test_file, load_epoch)
+     model.eval()
+
+     preds_list = []
+     trues_list = []
+
+     for x, y in test_dataloader:
+         if args.model in MODELS_SET2 and args.model != 'TSGRU':
+             reset_states(model=model)
+         elif args.model == 'TSGRU':
+             remove(model=model.net[0].tslif)
+
+         # Use the resolved module-level `device` (falls back to CPU), matching the training loop.
+         x = x.float().to(device)
+         y = y.float().to(device)
+
+         forecast, _ = model(x)
+         if len(forecast.shape) == 4:
+             forecast = forecast.mean(dim=0)
+
+         if args.model not in MODELS_SET2:
+             functional.reset_net(model)
+
+         preds_list.append(forecast.detach().cpu().numpy())
+         trues_list.append(y.detach().cpu().numpy())
+
+     preds = np.concatenate(preds_list, axis=0)
+     trues = np.concatenate(trues_list, axis=0)
+
+     score_scaled, score_orig = compute_scores_scaled_and_orig(trues, preds, scaler)
+
+     print(_fmt_score("SCALED", score_scaled))
+     print(_fmt_score("ORIG  ", score_orig))
+
+     return score_scaled, score_orig
+
+
+ def build_opt_sched(model, lr=3e-4, wd=0.01, gate_lr_ratio=0.3,
+                     warmup_epochs=8, total_epochs=100):
+     decay, no_decay, gate = [], [], []
+     for name, p in model.named_parameters():
+         if not p.requires_grad:
+             continue
+         name_l = name.lower()
+         is_bias = name.endswith('bias')
+         is_norm = ('norm' in name_l) or ('bn' in name_l)
+         is_embed = ('embeddings' in name_l) or ('time_basis' in name_l)
+         if 'freq_gate' in name_l and 'log_alpha' in name_l:
+             no_decay.append(p)
+         elif is_bias or is_norm or is_embed or p.ndim == 1:
+             no_decay.append(p)
+         else:
+             decay.append(p)
+
+     optim = torch.optim.AdamW([
+         {'params': decay, 'lr': lr, 'weight_decay': wd},
+         {'params': no_decay, 'lr': lr, 'weight_decay': 0.0},
+     ], betas=(0.9, 0.99), eps=1e-8)
+
+     warmup = torch.optim.lr_scheduler.LinearLR(
+         optim, start_factor=0.1, end_factor=1.0, total_iters=warmup_epochs
+     )
+     cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
+         optim, T_max=max(1, total_epochs - warmup_epochs), eta_min=lr * 0.1
+     )
+     sched = torch.optim.lr_scheduler.SequentialLR(
+         optim, schedulers=[warmup, cosine], milestones=[warmup_epochs]
+     )
+     return optim, sched
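`build_opt_sched` chains a linear warmup (factor 0.1 to 1.0) into a cosine decay that bottoms out at 10% of the base learning rate. The sketch below evaluates an idealised closed form of that schedule in plain Python; `lr_factor` is a hypothetical helper, and its values are not guaranteed to match PyTorch's schedulers step for step. It also takes into account that the training loop in train.py calls `my_lr_scheduler.step()` only every `--exponential_decay_step` epochs (default 5).

```python
import math


def lr_factor(step, warmup_steps, total_steps, start_factor=0.1, min_factor=0.1):
    """Multiplier on the base learning rate after `step` scheduler steps."""
    if step < warmup_steps:
        # Linear warmup from start_factor up to 1.0.
        return start_factor + (1.0 - start_factor) * step / warmup_steps
    # Cosine decay from 1.0 down to min_factor over the remaining steps.
    t_max = max(1, total_steps - warmup_steps)
    t = min(step - warmup_steps, t_max)
    cosine = 0.5 * (1.0 + math.cos(math.pi * t / t_max))
    return min_factor + (1.0 - min_factor) * cosine


train_epochs = 100
warmup = max(4, train_epochs // 8)             # 12, as passed for SpikF_GO
decay_step = 5                                 # default --exponential_decay_step
scheduler_steps = train_epochs // decay_step   # number of .step() calls in a run

# Factor after each scheduler step actually taken during training.
factors = [lr_factor(k, warmup, train_epochs) for k in range(scheduler_steps + 1)]
```

Under these defaults only 20 scheduler steps happen in 100 epochs, so the run ends early in the cosine phase with the learning rate still close to its base value. Whether that is intended is not stated in the source.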
+
+
+
+ if __name__ == '__main__':
+
+     seeds = [2021, 2022, 2023, 2024, 2025]
+
+     scaled_results = {'mape': [], 'mae': [], 'rmse': [], 'r2': [], 'rse': []}
+     orig_results = {'mape': [], 'mae': [], 'rmse': [], 'r2': [], 'rse': []}
+
+     for run_idx, seed in enumerate(seeds):
+         print(f"\n{'='*60}")
+         print(f"Starting Run {run_idx + 1}/{len(seeds)} | seed={seed}")
+         print(f"{'='*60}")
+
+         torch.manual_seed(seed)
+         np.random.seed(seed)
+         if torch.cuda.is_available():
+             torch.cuda.manual_seed(seed)
+             torch.cuda.manual_seed_all(seed)
+
+         result_train_file = os.path.join('output', args.data, args.model, f'train_run_{run_idx+1}_seed_{seed}')
+         result_test_file = os.path.join('output', args.data, args.model, f'train_run_{run_idx+1}_seed_{seed}')
+         os.makedirs(result_train_file, exist_ok=True)
+         os.makedirs(result_test_file, exist_ok=True)
+
+         device = torch.device(args.device if torch.cuda.is_available() else "cpu")
+
+         if args.model == 'SpikF_GO':
+             model = SpikF_GO(args, pre_length=args.pre_length, embed_size=args.embed_size,
+                              feature_size=args.feature_size, seq_length=args.seq_length, hidden_size=args.hidden_size)
+             my_optim, my_lr_scheduler = build_opt_sched(
+                 model, lr=args.learning_rate, wd=0.01,
+                 warmup_epochs=max(4, args.train_epochs//8), total_epochs=args.train_epochs
+             )
+         elif args.model == 'SpikF_GO_CPG':
+             model = SpikF_GO_CPG(args, pre_length=args.pre_length, embed_size=args.embed_size,
+                                  feature_size=args.feature_size, seq_length=args.seq_length, hidden_size=args.hidden_size)
+             my_optim, my_lr_scheduler = build_opt_sched(
+                 model, lr=args.learning_rate, wd=0.01,
+                 warmup_epochs=max(4, args.train_epochs//8), total_epochs=args.train_epochs
+             )
+         elif args.model == 'FGN':
+             model = FGN(args, pre_length=args.pre_length, embed_size=args.embed_size,
+                         feature_size=args.feature_size, seq_length=args.seq_length, hidden_size=args.hidden_size)
+             my_optim = torch.optim.RMSprop(params=model.parameters(), lr=args.learning_rate, eps=1e-08)
+             my_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=my_optim, gamma=args.decay_rate)
+         elif args.model == 'SpikF':
+             model = SpikF(args, input_len=args.seq_length, patch_num=args.patch_num, patch_dim=args.patch_dim,
+                           T=args.T, blocks=args.blocks, D=args.feature_size, pred_len=args.pre_length,
+                           tau=args.tau, alpha=args.alpha, hidden_dim=args.hidden_size)
+             my_optim = torch.optim.RMSprop(params=model.parameters(), lr=args.learning_rate, eps=1e-08)
+             my_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=my_optim, gamma=args.decay_rate)
+         elif args.model == 'iSpikformer':
+             model = iSpikformer(args, input_len=args.seq_length, patch_num=args.patch_num, patch_dim=args.patch_dim,
+                                 T=args.T, blocks=args.blocks, D=args.feature_size, pred_len=args.pre_length,
+                                 tau=args.tau, alpha=args.alpha, hidden_dim=args.hidden_size)
+             my_optim = torch.optim.RMSprop(params=model.parameters(), lr=args.learning_rate, eps=1e-08)
+             my_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=my_optim, gamma=args.decay_rate)
+         elif args.model == 'TSGRU':
+             model = TSGRU(args, hidden_size=args.hidden_size, layers=args.blocks,
+                           num_steps=args.T, input_size=args.feature_size)
+             my_optim = torch.optim.RMSprop(params=model.parameters(), lr=args.learning_rate, eps=1e-08)
+             my_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=my_optim, gamma=args.decay_rate)
+         elif args.model == 'TSTCN':
+             model = TSTCN(args=args, num_levels=args.blocks)
+             my_optim = torch.optim.RMSprop(params=model.parameters(), lr=args.learning_rate, eps=1e-08)
+             my_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=my_optim, gamma=args.decay_rate)
+         elif args.model == 'TSFormer':
+             model = TSFormer(args=args)
+             my_optim = torch.optim.RMSprop(params=model.parameters(), lr=args.learning_rate, eps=1e-08)
+             my_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=my_optim, gamma=args.decay_rate)
+         elif args.model == 'Spikformer_CPG':
+             model = Spikformer_CPG(args=args)
+             my_optim = torch.optim.RMSprop(params=model.parameters(), lr=args.learning_rate, eps=1e-08)
+             my_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=my_optim, gamma=args.decay_rate)
+         elif args.model == 'SpikeGRU':
+             model = SpikeGRU(args, hidden_size=args.hidden_size, layers=args.blocks,
+                              num_steps=args.T, input_size=args.feature_size)
+             my_optim = torch.optim.RMSprop(params=model.parameters(), lr=args.learning_rate, eps=1e-08)
+             my_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=my_optim, gamma=args.decay_rate)
+         elif args.model == 'SpikeRNN_CPG':
+             model = SpikeRNN_CPG(args, hidden_size=args.hidden_size, layers=args.blocks,
+                                  num_steps=args.T, input_size=args.feature_size)
+             my_optim = torch.optim.RMSprop(params=model.parameters(), lr=args.learning_rate, eps=1e-08)
+             my_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=my_optim, gamma=args.decay_rate)
+         elif args.model == 'SpikeTCN_CPG':
+             model = SpikeTCN_CPG(args=args, num_levels=args.blocks)
+             my_optim = torch.optim.RMSprop(params=model.parameters(), lr=args.learning_rate, eps=1e-08)
+             my_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=my_optim, gamma=args.decay_rate)
+         else:
+             raise ValueError(f"Unknown model: {args.model}")
+
+         model = model.to(device)
+         forecast_loss = nn.MSELoss(reduction='mean').to(device)
+
+         # train
+         for epoch in range(args.train_epochs):
+             warm = int(0.3 * args.train_epochs)
+             cool = epoch >= warm
+
+             epoch_start_time = time.time()
+             model.train()
+             loss_total = 0.0
+             cnt = 0
+
+             for x, y in train_dataloader:
+                 if args.model in MODELS_SET2 and args.model != 'TSGRU':
+                     reset_states(model=model)
+                 elif args.model == 'TSGRU':
+                     remove(model=model.net[0].tslif)
+
+                 x = x.float().to(device)
+                 y = y.float().to(device)
+
+                 forecast, aux = model(x)
+
+                 if len(forecast.shape) == 4:
+                     y_rep = y.repeat(args.T, 1, 1, 1)
+                 else:
+                     y_rep = y
+
+                 if (args.model in ['SpikF_GO', 'SpikF_GO_CPG']) and args.energy_loss:
+                     energy_lambda = 20.0
+                     mse = forecast_loss(forecast, y_rep)
+                     adaptive_lambda = (mse.detach() / 100.0) * energy_lambda
+                     loss = mse + adaptive_lambda * aux["rho_hat"]
+                 else:
+                     loss = forecast_loss(forecast, y_rep)
+
+                 my_optim.zero_grad(set_to_none=True)
+                 loss.backward()
+                 torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+                 my_optim.step()
+
+                 loss_total += float(loss)
+                 cnt += 1
+
+                 if args.model not in MODELS_SET2:
+                     functional.reset_net(model)
+
+             if (epoch + 1) % args.exponential_decay_step == 0:
+                 my_lr_scheduler.step()
+
+             if (epoch + 1) % args.validate_freq == 0:
+                 val_loss = validate(model, val_dataloader, train_scaler)
+                 enc_rate_v = float(aux.get('enc_rate', torch.tensor(0.0)))
+                 gate_l0_v = float(aux.get('rho_hat', torch.tensor(0.0)))
+                 freq_act_v = float(aux.get('freq_mask_active', torch.tensor(0.0)))
+
+                 print('Run {} | epoch {:03d} | {:5.2f}s | train_loss {:5.4f} | val_loss {:5.4f} | enc_rate {:.3f} | gate_L0 {:.3f} | f_active {:.3f}'.format(
+                     run_idx + 1, epoch, (time.time() - epoch_start_time), loss_total / max(1, cnt), val_loss,
+                     enc_rate_v, gate_l0_v, freq_act_v))
+
+             save_model_ts(model, result_train_file, epoch)
+
+         save_model_ts(model, result_train_file, f'final_run_{run_idx+1}')
+
+         print("--- TEST ---")
+         score_scaled, score_orig = test(model, result_test_file, train_scaler, load_epoch=min(97, args.train_epochs - 1))  # epoch 97 only exists when train_epochs >= 98
+
+         scaled_results['mape'].append(score_scaled[0])
+         scaled_results['mae'].append(score_scaled[1])
+         scaled_results['rmse'].append(score_scaled[2])
+         scaled_results['r2'].append(score_scaled[3])
+         scaled_results['rse'].append(score_scaled[4])
+
+         orig_results['mape'].append(score_orig[0])
+         orig_results['mae'].append(score_orig[1])
+         orig_results['rmse'].append(score_orig[2])
+         orig_results['r2'].append(score_orig[3])
+         orig_results['rse'].append(score_orig[4])
+
+         print(f"Run {run_idx + 1} completed.")
+         print(_fmt_score("Results", score_scaled))
+
+     def _mean_std(arr):
+         arr = np.asarray(arr, dtype=np.float64)
+         return float(np.mean(arr)), float(np.std(arr))
+
+     print(f"\n{'='*60}")
+     print("FINAL RESULTS ACROSS RUNS ")
+     print(f"{'='*60}")
+
+     for tag, store in [("SCALED", scaled_results)]:
+         mape_pct = np.asarray(store['mape'], dtype=np.float64) * 100.0
+         m_mean, m_std = _mean_std(mape_pct)
+         a_mean, a_std = _mean_std(store['mae'])
+         r_mean, r_std = _mean_std(store['rmse'])
+         r2_mean, r2_std = _mean_std(store['r2'])
+         rse_mean, rse_std = _mean_std(store['rse'])
+
+         print(f"\n[{tag}]")
+         print(f"MAPE: {mape_pct} | mean={m_mean:.6f} std={m_std:.6f}")
+         print(f"MAE : {np.array(store['mae'])} | mean={a_mean:.6f} std={a_std:.6f}")
+         print(f"RMSE: {np.array(store['rmse'])} | mean={r_mean:.6f} std={r_std:.6f}")
+         print(f"R2  : {np.array(store['r2'])} | mean={r2_mean:.6f} std={r2_std:.6f}")
+         print(f"RSE : {np.array(store['rse'])} | mean={rse_mean:.6f} std={rse_std:.6f}")
+
+     summary_file = os.path.join('output', args.data, args.model, 'summary_results.txt')
+     os.makedirs(os.path.dirname(summary_file), exist_ok=True)
+
+     with open(summary_file, 'w') as f:
+         f.write(f"Results across {len(seeds)} runs:\n")
+         f.write(f"Seeds used: {seeds}\n\n")
+
+         for tag, store in [("SCALED", scaled_results)]:
+             mape_pct = np.asarray(store['mape'], dtype=np.float64) * 100.0
+             m_mean, m_std = _mean_std(mape_pct)
+             a_mean, a_std = _mean_std(store['mae'])
+             r_mean, r_std = _mean_std(store['rmse'])
+             r2_mean, r2_std = _mean_std(store['r2'])
+             rse_mean, rse_std = _mean_std(store['rse'])
+
+             f.write(f"[{tag}]\n")
+             f.write(f"MAPE - Individual: {mape_pct}\n")
+             f.write(f"MAPE - Mean: {m_mean:.6f}, Std: {m_std:.6f}\n")
+             f.write(f"MAE - Individual: {np.array(store['mae'])}\n")
+             f.write(f"MAE - Mean: {a_mean:.6f}, Std: {a_std:.6f}\n")
+             f.write(f"RMSE - Individual: {np.array(store['rmse'])}\n")
+             f.write(f"RMSE - Mean: {r_mean:.6f}, Std: {r_std:.6f}\n")
+             f.write(f"R2 - Individual: {np.array(store['r2'])}\n")
+             f.write(f"R2 - Mean: {r2_mean:.6f}, Std: {r2_std:.6f}\n")
+             f.write(f"RSE - Individual: {np.array(store['rse'])}\n")
+             f.write(f"RSE - Mean: {rse_mean:.6f}, Std: {rse_std:.6f}\n\n")
+
+     print(f"\nSaved summary to: {summary_file}")
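For `SpikF_GO` and `SpikF_GO_CPG` with `--energy_loss` enabled, the training loop adds a penalty on `aux["rho_hat"]` whose weight scales with the current MSE: `adaptive_lambda = (mse / 100) * 20`. The plain-Python sketch below works through that arithmetic on scalars. `energy_regularized_loss` is a hypothetical name, and what `rho_hat` measures (it is logged as `gate_L0`) is defined in the model files, not here.

```python
def energy_regularized_loss(mse, rho_hat, energy_lambda=20.0, scale=100.0):
    """Scalar version of the loss used when --energy_loss is enabled.

    In train.py the MSE inside the weight is detached, so the weight acts as a
    constant during backpropagation; with plain floats that distinction vanishes.
    """
    adaptive_lambda = (mse / scale) * energy_lambda
    return mse + adaptive_lambda * rho_hat, adaptive_lambda


# Worked example: MSE 0.5 and rho_hat 0.3 give a weight of 0.1 and a loss of 0.53.
loss, weight = energy_regularized_loss(mse=0.5, rho_hat=0.3)
```

Since the weight is `0.2 * mse`, the loss value equals `mse * (1 + 0.2 * rho_hat)`. The penalty therefore shrinks together with the forecasting error rather than dominating late in training.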
utils/utils.py ADDED
@@ -0,0 +1,252 @@
+ # -*- coding:utf-8 -*-
+
+ import numpy as np
+ import torch
+ import os
+
+
+ def concat_fun(inputs, axis=-1):
+     if len(inputs) == 1:
+         return inputs[0]
+     else:
+         return torch.cat(inputs, dim=axis)
+
+
+ def slice_arrays(arrays, start=None, stop=None):
+     """Slice an array or list of arrays.
+
+     This takes an array-like, or a list of
+     array-likes, and outputs:
+         - arrays[start:stop] if `arrays` is an array-like
+         - [x[start:stop] for x in arrays] if `arrays` is a list
+
+     Can also work on list/array of indices: `slice_arrays(x, indices)`
+
+     Arguments:
+         arrays: Single array or list of arrays.
+         start: can be an integer index (start index)
+             or a list/array of indices
+         stop: integer (stop index); should be None if
+             `start` was a list.
+
+     Returns:
+         A slice of the array(s).
+
+     Raises:
+         ValueError: If the value of start is a list and stop is not None.
+     """
+
+     if arrays is None:
+         return [None]
+
+     if isinstance(arrays, np.ndarray):
+         arrays = [arrays]
+
+     if isinstance(start, list) and stop is not None:
+         raise ValueError('The stop argument has to be None if the value of start '
+                          'is a list.')
+     elif isinstance(arrays, list):
+         if hasattr(start, '__len__'):
+             # hdf5 datasets only support list objects as indices
+             if hasattr(start, 'shape'):
+                 start = start.tolist()
+             return [None if x is None else x[start] for x in arrays]
+         else:
+             if len(arrays) == 1:
+                 return arrays[0][start:stop]
+             return [None if x is None else x[start:stop] for x in arrays]
+     else:
+         if hasattr(start, '__len__'):
+             if hasattr(start, 'shape'):
+                 start = start.tolist()
+             return arrays[start]
+         elif hasattr(start, '__getitem__'):
+             return arrays[start:stop]
+         else:
+             return [None]
+
+
+ def save_model(model, model_dir, epoch=None):
+     if model_dir is None:
+         return
+     if not os.path.exists(model_dir):
+         os.makedirs(model_dir)
+     epoch = str(epoch) if epoch else ''
+     file_name = os.path.join(model_dir, epoch + '_dhfm.pt')
+     with open(file_name, 'wb') as f:
+         torch.save(model, f)
+
+
+ def load_model(model_dir, epoch=None):
+     if not model_dir:
+         return
+     epoch = str(epoch) if epoch else ''
+     file_name = os.path.join(model_dir, epoch + '_dhfm.pt')
+     if not os.path.exists(model_dir):
+         os.makedirs(model_dir)
+     if not os.path.exists(file_name):
+         return
+     with open(file_name, 'rb') as f:
+         model = torch.load(f)
+     return model
+
93
+ def masked_MAPE(v, v_, axis=None):
94
+ '''
95
+ Mean absolute percentage error.
96
+ :param v: np.ndarray or int, ground truth.
97
+ :param v_: np.ndarray or int, prediction.
98
+ :param axis: axis to do calculation.
99
+ :return: int, MAPE averages on all elements of input.
100
+ '''
101
+ mask = (v == 0)
102
+ percentage = np.abs(v_ - v) / np.abs(v)
103
+ if np.any(mask):
104
+ masked_array = np.ma.masked_array(percentage, mask=mask) # mask the dividing-zero as invalid
105
+ result = masked_array.mean(axis=axis)
106
+ if isinstance(result, np.ma.MaskedArray):
107
+ return result.filled(np.nan)
108
+ else:
109
+ return result
110
+ return np.mean(percentage, axis).astype(np.float64)
+
+ """
+ original
+ def MAPE(v, v_, axis=None):
+     '''
+     Mean absolute percentage error.
+     :param v: np.ndarray or int, ground truth.
+     :param v_: np.ndarray or int, prediction.
+     :param axis: axis to do calculation.
+     :return: int, MAPE averages on all elements of input.
+     '''
+     mape = (np.abs(v_ - v) / np.abs(v)+1e-5).astype(np.float64)
+     mape = np.where(mape > 5, 5, mape)
+     return np.mean(mape, axis)
+
+ """
+
+ def MAPE(v, v_, axis=None):
+     '''
+     Mean absolute percentage error, with each ratio clipped at 5.
+     :param v: np.ndarray or int, ground truth.
+     :param v_: np.ndarray or int, prediction.
+     :param axis: axis to do calculation.
+     :return: float, MAPE averaged over all elements of the input (or reduced along axis).
+     '''
+     mape = (np.abs(v_ - v) / (np.abs(v) + 1e-5)).astype(np.float64)
+     mape = np.where(mape > 5, 5, mape)  # clip extreme values
+     return np.mean(mape, axis)
+
+
+ #def MAPE(true, pred):
+ #    return np.mean(np.abs((pred - true) / (true+1e-5)))
+
+ def smape(P, A):
+     # Symmetric MAPE, computed only over entries where the actual value A is positive
+     nz = np.where(A > 0)
+     Pz = P[nz]
+     Az = A[nz]
+
+     return np.mean(2 * np.abs(Az - Pz) / (np.abs(Az) + np.abs(Pz)))
+
+
+ def R2(y, y_hat, axis=None, eps=1e-12):
+     """
+     R^2 score for arrays shaped like [count, time_step, node] (or compatible).
+     axis=None -> global scalar R2 over all elements.
+     axis can be int or tuple of ints: reduce over those axes, keeping the others.
+     """
+     y = np.asarray(y, dtype=np.float64)
+     y_hat = np.asarray(y_hat, dtype=np.float64)
+
+     # residual sum of squares
+     ss_res = np.sum((y - y_hat) ** 2, axis=axis)
+
+     # total sum of squares around mean of y along the same reduction axis
+     y_mean = np.mean(y, axis=axis, keepdims=True)
+     ss_tot = np.sum((y - y_mean) ** 2, axis=axis)
+
+     # Avoid division by zero (constant targets)
+     denom = ss_tot + eps
+     r2 = 1.0 - (ss_res / denom)
+
+     # If ss_tot is truly ~0, R2 is not well-defined; mark as nan
+     # (Optional) If you want 0.0 instead, replace np.nan with 0.0
+     if np.isscalar(ss_tot):
+         if ss_tot < eps:
+             return np.nan
+         return float(r2)
+
+     r2 = np.where(ss_tot < eps, np.nan, r2)
+     return r2.astype(np.float64)
+
+ def RSE(v, v_, axis=None, eps=1e-12):
+     '''
+     Root relative squared error:
+         sqrt( sum((v_ - v)^2) / sum((v - mean(v))^2) )
+     :param v: np.ndarray or int, ground truth.
+     :param v_: np.ndarray or int, prediction.
+     :param axis: axis to do calculation.
+     :return: float, RSE on all elements of input (or reduced by axis).
+     '''
+     v = np.asarray(v, dtype=np.float64)
+     v_ = np.asarray(v_, dtype=np.float64)
+
+     v_mean = np.mean(v, axis=axis, keepdims=True)
+     num = np.sum((v_ - v) ** 2, axis=axis)
+     denom = np.sum((v - v_mean) ** 2, axis=axis)
+     return np.sqrt(num / (denom + eps)).astype(np.float64)
+
+ def RMSE(v, v_, axis=None):
+     '''
+     Root mean squared error.
+     :param v: np.ndarray or int, ground truth.
+     :param v_: np.ndarray or int, prediction.
+     :param axis: axis to do calculation.
+     :return: float, RMSE over all elements of input (or reduced along axis).
+     '''
+     return np.sqrt(np.mean((v_ - v) ** 2, axis)).astype(np.float64)
+
+
+ def MAE(v, v_, axis=None):
+     '''
+     Mean absolute error.
+     :param v: np.ndarray or int, ground truth.
+     :param v_: np.ndarray or int, prediction.
+     :param axis: axis to do calculation.
+     :return: float, MAE over all elements of input (or reduced along axis).
+     '''
+     return np.mean(np.abs(v_ - v), axis).astype(np.float64)
+
+
+ def evaluate(y, y_hat, by_step=False, by_node=False):
+     '''
+     :param y: array in shape of [count, time_step, node].
+     :param y_hat: in same shape with y.
+     :param by_step: evaluate by time_step dim.
+     :param by_node: evaluate by node dim.
+     :return: tuple (MAPE, MAE, RMSE, R2); when neither flag is set, RSE is appended as a fifth element.
+     '''
+     if not by_step and not by_node:
+         return MAPE(y, y_hat), MAE(y, y_hat), RMSE(y, y_hat), R2(y, y_hat), RSE(y, y_hat)
+     if by_step and by_node:
+         return MAPE(y, y_hat, axis=0), MAE(y, y_hat, axis=0), RMSE(y, y_hat, axis=0), R2(y, y_hat, axis=0)
+     if by_step:
+         return MAPE(y, y_hat, axis=(0, 2)), MAE(y, y_hat, axis=(0, 2)), RMSE(y, y_hat, axis=(0, 2)), R2(y, y_hat, axis=(0, 2))
+     if by_node:
+         return MAPE(y, y_hat, axis=(0, 1)), MAE(y, y_hat, axis=(0, 1)), RMSE(y, y_hat, axis=(0, 1)), R2(y, y_hat, axis=(0, 1))
+
+
+ def save_model_ts(model, path, epoch):
+     os.makedirs(path, exist_ok=True)
+     filename = 'epoch_{}.pth'.format(epoch)
+     f = os.path.join(path, filename)
+     # Save state_dict instead of the entire model
+     torch.save(model.state_dict(), f)
+
+ def load_model_ts(model, path, epoch):
+     """Load state dict into an existing model instance"""
+     filename = 'epoch_{}.pth'.format(epoch)
+     f = os.path.join(path, filename)
+     model.load_state_dict(torch.load(f))
+     return model
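The axis arguments that `evaluate` passes to the metrics can be checked on a toy array. The following is a minimal, self-contained sketch, not part of the commit: it restates `MAPE` and `MAE` as they are defined in this file, and the data (4 samples, 3 forecast steps, 2 nodes, constant error of 0.1) is hypothetical.

```python
import numpy as np


def MAPE(v, v_, axis=None):
    # Same formula as the MAPE in this file: epsilon in the denominator,
    # per-element ratios clipped at 5.
    mape = (np.abs(v_ - v) / (np.abs(v) + 1e-5)).astype(np.float64)
    mape = np.where(mape > 5, 5, mape)
    return np.mean(mape, axis)


def MAE(v, v_, axis=None):
    # Same formula as the MAE in this file.
    return np.mean(np.abs(v_ - v), axis).astype(np.float64)


# Hypothetical toy data shaped [count, time_step, node].
rng = np.random.default_rng(0)
y = rng.uniform(1.0, 2.0, size=(4, 3, 2))
y_hat = y + 0.1  # constant over-prediction of 0.1

overall = MAE(y, y_hat)                # axis=None: one scalar
per_step = MAE(y, y_hat, axis=(0, 2))  # by_step: one value per time step
per_node = MAE(y, y_hat, axis=(0, 1))  # by_node: one value per node
per_step_node = MAE(y, y_hat, axis=0)  # by_step and by_node: [time_step, node]

print(overall, per_step.shape, per_node.shape, per_step_node.shape)

# A zero ground-truth value does not blow up MAPE: its ratio is clipped at 5,
# so the mean over [5, 0] is 2.5.
print(MAPE(np.array([0.0, 2.0]), np.array([1.0, 2.0])))
```

Reducing over `(0, 2)` leaves the time-step dimension and reducing over `(0, 1)` leaves the node dimension, which matches the `by_step` and `by_node` branches of `evaluate`.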