---
license: apache-2.0
library_name: onnxruntime
tags:
  - sleep-staging
  - knowledge-distillation
  - onnx
  - torchscript
  - edge-deployment
  - eeg
  - polysomnography
datasets:
  - sleep-edf
metrics:
  - accuracy
pipeline_tag: tabular-classification
model-index:
  - name: Conv1dStack_T2_a0.3
    results:
      - task:
          type: tabular-classification
          name: Sleep Stage Classification
        dataset:
          type: sleep-edf
          name: Sleep-EDF
        metrics:
          - type: accuracy
            value: 0.752
            name: Validation Accuracy
---

# Conv1dStack_T2_a0.3 — Distilled Sleep Stage Classifier

A tiny (103.3 KB, 25,957-parameter) sleep stage classifier distilled from
[SleepFM](https://arxiv.org/abs/2311.07919) for real-time edge deployment on the
NVIDIA Jetson TK1 and similarly constrained devices.

## Model Details

| Property | Value |
|----------|-------|
| Architecture | Conv1dStack |
| Parameters | 25,957 |
| Model size | 103.3 KB |
| Distillation temperature | 2 |
| Alpha (hard label weight) | 0.3 |
| Validation accuracy | 75.2% |
| Input shape | `(B, S, 128)`: batch, sequence length, 128-dim pre-pooled embeddings |
| Output | 5-class logits, index order 0=Wake, 1=REM, 2=N1, 3=N2, 4=N3 |
| ONNX opset | 11 |

### Student Architecture Config

```yaml
Conv1dStack:
  hidden_channels: 32
  kernel_size: 5
```

### Distillation Setup

- **Teacher**: SleepFM (`SleepEventLSTMClassifier`), a biLSTM over 128-dim embeddings
- **Sweep**: 3 temperatures × 3 alphas × 3 architectures = 27 experiments; this checkpoint is the T=2, alpha=0.3 Conv1dStack run
- **Training**: up to 50 epochs, AdamW (lr=0.001), early stopping (patience=10)
- **Data**: Sleep-EDF, split by subject (5 train / 1 val / 1 test)
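
The training code is not part of this card, so the following is only a minimal NumPy sketch of how a temperature and an alpha of this kind are commonly combined, assuming the standard Hinton-style objective. The function name `distillation_loss`, the `T^2` rescaling of the soft term, and the random example inputs are all assumptions for illustration, not the authors' implementation. The only detail taken from this card is that alpha weights the hard-label term.

```python
import numpy as np


def log_softmax(logits, axis=-1):
    """Numerically stable log-softmax."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))


def distillation_loss(student_logits, teacher_logits, labels,
                      temperature=2.0, alpha=0.3):
    """Hypothetical KD loss: alpha * hard CE + (1 - alpha) * T^2 * soft KL."""
    # Hard-label term: cross-entropy against ground-truth stage indices.
    log_p = log_softmax(student_logits)
    hard = -np.take_along_axis(log_p, labels[..., None], axis=-1).mean()

    # Soft-label term: KL(teacher || student) on temperature-softened outputs.
    log_q_t = log_softmax(teacher_logits / temperature)
    log_p_t = log_softmax(student_logits / temperature)
    soft = (np.exp(log_q_t) * (log_q_t - log_p_t)).sum(axis=-1).mean()

    # T^2 keeps the soft-term gradient scale comparable across temperatures
    # (an assumed convention, not confirmed by this card).
    return alpha * hard + (1.0 - alpha) * temperature ** 2 * soft


# Toy inputs: 4 epochs, 5 sleep stages (random values, for illustration only).
rng = np.random.default_rng(0)
student = rng.normal(size=(4, 5))
teacher = rng.normal(size=(4, 5))
labels = rng.integers(0, 5, size=4)

loss = distillation_loss(student, teacher, labels)
print(float(loss))
```

With `alpha=0.3`, 70% of the weight goes to matching the teacher's softened distribution and 30% to the ground-truth stage labels.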

### Target Hardware

| Spec | Value |
|------|-------|
| Device | NVIDIA Jetson TK1 |
| CUDA cores | 192 |
| RAM | 2 GB |
| Compute capability | 3.2 |

## Usage

### ONNX Runtime

```python
import numpy as np
import onnxruntime as ort

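# CPU provider shown here; pass a GPU provider instead if your build has one.
session = ort.InferenceSession(
    "Conv1dStack_T2_a0.3.onnx", providers=["CPUExecutionProvider"]
)
input_name = session.get_inputs()[0].name  # "input" in this export

# Input: pre-pooled embeddings (batch, seq_len, 128); random values as a stand-in
embeddings = np.random.randn(1, 120, 128).astype(np.float32)
logits = session.run(None, {input_name: embeddings})[0]
predicted_stages = np.argmax(logits, axis=-1)

# Stage mapping: 0=Wake, 1=REM, 2=N1, 3=N2, 4=N3
print(predicted_stages)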
```
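
To turn the raw logits into readable labels, something like the sketch below may help. It assumes the logits have shape `(batch, seq_len, 5)`, which the `argmax` over the last axis above suggests but this card does not state explicitly. The helper name `decode_stages` and the hand-built logits are hypothetical; the stage order is the one listed in this card.

```python
import numpy as np

# Index order taken from the stage mapping in this card.
STAGE_NAMES = ["Wake", "REM", "N1", "N2", "N3"]


def decode_stages(logits):
    """Map logits of shape (batch, seq_len, 5) to stage names and confidences."""
    # Numerically stable softmax over the class axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)
    names = np.array(STAGE_NAMES)[probs.argmax(axis=-1)]
    confidence = probs.max(axis=-1)
    return names, confidence


# Hand-built logits standing in for real model output.
logits = np.zeros((1, 3, 5), dtype=np.float32)
logits[0, 0, 0] = 4.0  # Wake
logits[0, 1, 3] = 4.0  # N2
logits[0, 2, 4] = 4.0  # N3

names, confidence = decode_stages(logits)
print(names[0].tolist())  # ['Wake', 'N2', 'N3']
```

The softmax confidence is uncalibrated, so treat it as a relative score rather than a probability of correctness.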

### TorchScript

```python
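import torch

# Load on CPU and switch to inference mode.
model = torch.jit.load("Conv1dStack_T2_a0.3.pt", map_location="cpu")
model.eval()

# Input: pre-pooled embeddings (batch, seq_len, 128); random values as a stand-in
embeddings = torch.randn(1, 120, 128)
with torch.no_grad():  # no gradients needed for inference
    logits = model(embeddings)
predicted_stages = logits.argmax(dim=-1)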
```

## Files

| File | Format | Description |
|------|--------|-------------|
| `Conv1dStack_T2_a0.3.onnx` | ONNX (opset 11) | For ONNX Runtime / TensorRT |
| `Conv1dStack_T2_a0.3.pt` | TorchScript | For PyTorch / LibTorch on-device |

## Limitations

- Trained on Sleep-EDF only (7 subjects: 5 train, 1 val, 1 test), so it may not generalize to other PSG datasets, and the reported accuracy comes from a single validation subject
- Expects pre-pooled 128-dim embeddings from SleepFM's encoder, not raw EEG
- No per-class metrics reported (overall accuracy only)
- Distilled from a single teacher checkpoint
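
Because only overall accuracy is reported, you may want per-class numbers on your own labeled data before relying on the model. The sketch below is one way to compute a confusion matrix and per-stage recall with NumPy. The function name `per_class_recall` and the toy label arrays are hypothetical and are not results from this model.

```python
import numpy as np

# Index order taken from the stage mapping in this card.
STAGE_NAMES = ["Wake", "REM", "N1", "N2", "N3"]


def per_class_recall(y_true, y_pred, num_classes=5):
    """Return (confusion_matrix, recall); rows are true stages, columns predicted."""
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    np.add.at(cm, (y_true, y_pred), 1)  # accumulate counts, including repeats
    support = cm.sum(axis=1)
    # Recall is undefined (NaN) for stages with no true examples.
    recall = np.divide(
        np.diag(cm), support,
        out=np.full(num_classes, np.nan), where=support > 0,
    )
    return cm, recall


# Toy labels for illustration only.
y_true = np.array([0, 0, 1, 2, 3, 3, 4])
y_pred = np.array([0, 2, 1, 2, 3, 3, 3])

cm, recall = per_class_recall(y_true, y_pred)
for name, r in zip(STAGE_NAMES, recall):
    print(f"{name}: {r:.2f}")
```

Sleep stages are typically imbalanced, so per-stage recall can differ sharply from overall accuracy.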

## Citation

```bibtex
@misc{circadia-distill-2026,
  title={Distilled Sleep Stage Classifier for Edge Deployment},
  year={2026},
  url={https://github.com/circadia}
}
```