---
license: apache-2.0
library_name: onnxruntime
tags:
- sleep-staging
- knowledge-distillation
- onnx
- torchscript
- edge-deployment
- eeg
- polysomnography
datasets:
- sleep-edf
metrics:
- accuracy
pipeline_tag: tabular-classification
model-index:
- name: Conv1dStack_T2_a0.3
  results:
  - task:
      type: tabular-classification
      name: Sleep Stage Classification
    dataset:
      type: sleep-edf
      name: Sleep-EDF
    metrics:
    - type: accuracy
      value: 0.752
      name: Validation Accuracy
---
# Conv1dStack_T2_a0.3 — Distilled Sleep Stage Classifier
A tiny (103.3 KB, 25,957 parameters) sleep stage classifier distilled from
[SleepFM](https://arxiv.org/abs/2311.07919) for real-time edge deployment on the
NVIDIA Jetson TK1 and similar resource-constrained devices.
## Model Details
| Property | Value |
|----------|-------|
| Architecture | Conv1dStack |
| Parameters | 25,957 |
| Model size | 103.3 KB |
| Distillation temperature | 2 |
| Alpha (hard label weight) | 0.3 |
| Validation accuracy | 75.2% |
| Input shape | `(B, S, 128)` — pre-pooled embeddings |
| Output | 5-class logits (Wake, REM, N1, N2, N3) |
| ONNX opset | 11 |
### Student Architecture Config
```yaml
Conv1dStack:
  hidden_channels: 32
  kernel_size: 5
```
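The card does not spell out the layer layout behind `hidden_channels: 32` and `kernel_size: 5`. The PyTorch sketch below is one hypothetical layout that happens to reproduce the reported 25,957 parameters: two same-padded `Conv1d` layers (128→32, then 32→32), each followed by `BatchNorm1d` and ReLU, and a per-epoch `Linear(32, 5)` head. The class name, layer order, and activation are assumptions, not the released architecture.

```python
import torch
import torch.nn as nn


class Conv1dStackSketch(nn.Module):
    """Hypothetical Conv1dStack layout; not taken from the released code."""

    def __init__(self, in_dim=128, hidden_channels=32, kernel_size=5, num_classes=5):
        super().__init__()
        pad = kernel_size // 2  # "same" padding keeps one output per input epoch
        self.convs = nn.Sequential(
            nn.Conv1d(in_dim, hidden_channels, kernel_size, padding=pad),
            nn.BatchNorm1d(hidden_channels),
            nn.ReLU(),
            nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=pad),
            nn.BatchNorm1d(hidden_channels),
            nn.ReLU(),
        )
        self.head = nn.Linear(hidden_channels, num_classes)

    def forward(self, x):
        # x: (B, S, 128); Conv1d expects channels first, i.e. (B, 128, S)
        h = self.convs(x.transpose(1, 2))
        # back to (B, S, hidden) so the linear head scores every epoch
        return self.head(h.transpose(1, 2))  # (B, S, 5)


model = Conv1dStackSketch()
n_params = sum(p.numel() for p in model.parameters())
print(n_params)  # 25957
print(model(torch.randn(1, 120, 128)).shape)  # torch.Size([1, 120, 5])
```

The count breaks down as 20,512 + 5,152 (convolutions) + 128 (two batch norms) + 165 (head). Matching the total does not prove this is the exact layout, only that a two-layer stack at this width is consistent with it.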
### Distillation Setup
- **Teacher**: SleepFM (SleepEventLSTMClassifier) — biLSTM, 128-dim embeddings
- **Sweep**: 3 temperatures × 3 alphas × 3 architectures = 27 experiments
- **Training**: 50 epochs, AdamW lr=0.001, early stopping (patience=10)
- **Data**: Sleep-EDF (5 train / 1 val / 1 test subjects)
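The card gives T = 2 and describes alpha as the hard-label weight, but it does not state the loss itself. Assuming the common Hinton-style formulation, the objective mixes cross-entropy on the true stages with a temperature-softened KL term against the teacher, scaled by T² so its gradient magnitude stays comparable across temperatures. The function name and the per-epoch flattening below are illustrative assumptions.

```python
import torch
import torch.nn.functional as F


def distillation_loss(student_logits, teacher_logits, labels, T=2.0, alpha=0.3):
    """Hinton-style KD loss (assumed formulation, not confirmed by the card).

    student_logits, teacher_logits: (B, S, 5); labels: (B, S) integer stages.
    alpha weights the hard-label term; (1 - alpha) weights the soft term.
    """
    num_classes = student_logits.shape[-1]
    s = student_logits.reshape(-1, num_classes)  # one row per sleep epoch
    t = teacher_logits.reshape(-1, num_classes)
    y = labels.reshape(-1)

    hard = F.cross_entropy(s, y)
    # KL(teacher || student) between temperature-softened distributions
    soft = F.kl_div(
        F.log_softmax(s / T, dim=-1),
        F.softmax(t / T, dim=-1),
        reduction="batchmean",
    ) * (T * T)
    return alpha * hard + (1.0 - alpha) * soft


student = torch.randn(4, 120, 5, requires_grad=True)
teacher = torch.randn(4, 120, 5)
labels = torch.randint(0, 5, (4, 120))
loss = distillation_loss(student, teacher, labels)
loss.backward()
print(loss.item())
```

In a training loop this loss would be minimized over the student model's parameters with `torch.optim.AdamW(..., lr=1e-3)`, matching the optimizer listed above; setting `alpha=1.0` reduces it to plain supervised cross-entropy.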
### Target Hardware
| Spec | Value |
|------|-------|
| Device | NVIDIA Jetson TK1 |
| CUDA cores | 192 |
| RAM | 2 GB |
| Compute capability | 3.2 |
## Usage
### ONNX Runtime
```python
import numpy as np
import onnxruntime as ort
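# CPU provider shown for portability; GPU-enabled ORT builds (1.9+) require an explicit providers list
session = ort.InferenceSession(
    "Conv1dStack_T2_a0.3.onnx", providers=["CPUExecutionProvider"]
)
input_name = session.get_inputs()[0].name  # avoids hard-coding the graph's input name
# Input: pre-pooled embeddings (batch, seq_len, 128)
embeddings = np.random.randn(1, 120, 128).astype(np.float32)
logits = session.run(None, {input_name: embeddings})[0]
predicted_stages = np.argmax(logits, axis=-1)
# Stage mapping: 0=Wake, 1=REM, 2=N1, 3=N2, 4=N3
print(predicted_stages)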
```
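The snippet above stops at integer stage indices. A small post-processing sketch that turns the logits into stage names plus softmax confidences, using the index mapping from the comment above. The helper name is hypothetical, and the `(batch, seq_len, 5)` logits shape is an assumption inferred from the example's `argmax(axis=-1)`.

```python
import numpy as np

# Index order taken from the stage mapping in the usage example above
STAGE_NAMES = np.array(["Wake", "REM", "N1", "N2", "N3"])


def decode_stages(logits):
    """Map (B, S, 5) logits to stage names and softmax confidences (hypothetical helper)."""
    # numerically stable softmax over the class axis
    shifted = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(shifted)
    probs /= probs.sum(axis=-1, keepdims=True)
    idx = probs.argmax(axis=-1)  # (B, S)
    return STAGE_NAMES[idx], probs.max(axis=-1)


logits = np.array([[[4.0, 0.0, 0.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0, 3.0, 0.0]]], dtype=np.float32)
names, conf = decode_stages(logits)
print(names)  # [['Wake' 'N2']]
```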
### TorchScript
```python
import torch
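model = torch.jit.load("Conv1dStack_T2_a0.3.pt", map_location="cpu")
model.eval()  # no-op if the module was saved in eval mode
# Input: pre-pooled embeddings (batch, seq_len, 128)
embeddings = torch.randn(1, 120, 128)
with torch.no_grad():
    logits = model(embeddings)
predicted_stages = logits.argmax(dim=-1)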
```
## Files
| File | Format | Description |
|------|--------|-------------|
| `Conv1dStack_T2_a0.3.onnx` | ONNX (opset 11) | For ONNX Runtime / TensorRT |
| `Conv1dStack_T2_a0.3.pt` | TorchScript | For PyTorch / LibTorch on-device |
## Limitations
- Trained on Sleep-EDF only (7 subjects) — may not generalize to other PSG datasets
- Expects pre-pooled 128-dim embeddings from SleepFM's encoder, not raw EEG
- No per-class metrics reported (overall accuracy only)
- Distilled from a single teacher checkpoint
## Citation
```bibtex
@misc{circadia-distill-2026,
title={Distilled Sleep Stage Classifier for Edge Deployment},
year={2026},
url={https://github.com/circadia}
}
```