Spaces:
Sleeping
Sleeping
Commit Β·
e22f65c
1
Parent(s): cc6a769
Initial deploy: real-time weather forecast demo
Browse files- .gitattributes +0 -34
- README.md +39 -7
- __pycache__/app.cpython-313.pyc +0 -0
- __pycache__/hrrr_fetch.cpython-313.pyc +0 -0
- __pycache__/model_utils.cpython-313.pyc +0 -0
- __pycache__/var_mapping.cpython-313.pyc +0 -0
- __pycache__/visualization.cpython-313.pyc +0 -0
- app.py +356 -0
- checkpoints/cnn_baseline.pt +3 -0
- checkpoints/resnet18.pt +3 -0
- hrrr_fetch.py +142 -0
- model_utils.py +180 -0
- models/__init__.py +45 -0
- models/__pycache__/__init__.cpython-313.pyc +0 -0
- models/__pycache__/cnn_3d.cpython-313.pyc +0 -0
- models/__pycache__/cnn_baseline.cpython-313.pyc +0 -0
- models/__pycache__/cnn_multi_frame.cpython-313.pyc +0 -0
- models/__pycache__/convnext_baseline.cpython-313.pyc +0 -0
- models/__pycache__/resnet_baseline.cpython-313.pyc +0 -0
- models/__pycache__/vit.cpython-313.pyc +0 -0
- models/cnn_3d.py +99 -0
- models/cnn_baseline.py +81 -0
- models/cnn_multi_frame.py +64 -0
- models/convnext_baseline.py +43 -0
- models/resnet_baseline.py +41 -0
- models/vit.py +116 -0
- packages.txt +2 -0
- requirements.txt +10 -0
- var_mapping.py +82 -0
- visualization.py +234 -0
.gitattributes
CHANGED
|
@@ -1,35 +1 @@
|
|
| 1 |
-
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
README.md
CHANGED
|
@@ -1,14 +1,46 @@
|
|
| 1 |
---
|
| 2 |
-
title: Weather
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version:
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
license: mit
|
| 11 |
-
short_description: A real time weather prediction system with CNN based model.
|
| 12 |
---
|
| 13 |
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: Tufts Jumbo Weather Forecast
|
| 3 |
+
emoji: "\U0001F324"
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: indigo
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: "4.44.1"
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
license: mit
|
|
|
|
| 11 |
---
|
| 12 |
|
| 13 |
+
# Tufts Jumbo β 24h Weather Forecast
|
| 14 |
+
|
| 15 |
+
Real-time deep-learning weather prediction for the Jumbo Statue at Tufts University.
|
| 16 |
+
|
| 17 |
+
## How It Works
|
| 18 |
+
|
| 19 |
+
1. **Fetches** the latest HRRR 3 km analysis data from NOAA (42 atmospheric channels, 450x449 grid covering the US Northeast)
|
| 20 |
+
2. **Runs** a trained CNN through the spatial snapshot
|
| 21 |
+
3. **Predicts** 6 weather variables 24 hours ahead at a single target point (Jumbo Statue, Medford MA)
|
| 22 |
+
|
| 23 |
+
## Models
|
| 24 |
+
|
| 25 |
+
| Model | Parameters | Architecture |
|
| 26 |
+
|-------|-----------|-------------|
|
| 27 |
+
| CNN Baseline | 11.3M | 6 residual blocks, progressive spatial downsampling |
|
| 28 |
+
| ResNet-18 | 11.2M | Modified torchvision ResNet-18 (42-channel input) |
|
| 29 |
+
|
| 30 |
+
## Input Channels (42)
|
| 31 |
+
|
| 32 |
+
Surface: 2m temperature, 2m humidity, 10m U/V wind, surface gust, solar radiation, 1hr precipitation.
|
| 33 |
+
Atmospheric: CAPE, dew point (5 levels), geopotential height (5 levels), temperature (5 levels),
|
| 34 |
+
U-wind (6 levels), V-wind (6 levels), cloud cover (4 layers), precipitable water, relative humidity, VIL.
|
| 35 |
+
|
| 36 |
+
## Output Variables
|
| 37 |
+
|
| 38 |
+
Temperature (K), Relative Humidity (%), U-Wind (m/s), V-Wind (m/s), Wind Gust (m/s), Precipitation (mm).
|
| 39 |
+
|
| 40 |
+
## Data Source
|
| 41 |
+
|
| 42 |
+
[HRRR (High-Resolution Rapid Refresh)](https://rapidrefresh.noaa.gov/hrrr/) β NOAA's 3 km hourly weather model, fetched in real-time from AWS S3 via [Herbie](https://herbie.readthedocs.io/).
|
| 43 |
+
|
| 44 |
+
## Course
|
| 45 |
+
|
| 46 |
+
Tufts CS 137 β Deep Neural Networks, Spring 2026
|
__pycache__/app.cpython-313.pyc
ADDED
|
Binary file (11.9 kB). View file
|
|
|
__pycache__/hrrr_fetch.cpython-313.pyc
ADDED
|
Binary file (5.6 kB). View file
|
|
|
__pycache__/model_utils.cpython-313.pyc
ADDED
|
Binary file (6.83 kB). View file
|
|
|
__pycache__/var_mapping.cpython-313.pyc
ADDED
|
Binary file (3.68 kB). View file
|
|
|
__pycache__/visualization.cpython-313.pyc
ADDED
|
Binary file (11.5 kB). View file
|
|
|
app.py
ADDED
|
@@ -0,0 +1,356 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Tufts Jumbo Weather Forecast β Deep Learning Demo
|
| 3 |
+
|
| 4 |
+
Usage:
|
| 5 |
+
cd demo && python app.py
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import logging
|
| 9 |
+
from datetime import timedelta
|
| 10 |
+
|
| 11 |
+
import gradio as gr
|
| 12 |
+
|
| 13 |
+
from hrrr_fetch import fetch_hrrr_input
|
| 14 |
+
from model_utils import run_forecast, load_model, AVAILABLE_MODELS
|
| 15 |
+
from visualization import (
|
| 16 |
+
get_static_maps,
|
| 17 |
+
plot_temperature,
|
| 18 |
+
plot_temperature_placeholder,
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
logging.basicConfig(level=logging.INFO)
|
| 22 |
+
logger = logging.getLogger(__name__)
|
| 23 |
+
|
| 24 |
+
# ββ CSS βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 25 |
+
|
| 26 |
+
CUSTOM_CSS = """
|
| 27 |
+
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap');
|
| 28 |
+
|
| 29 |
+
:root {
|
| 30 |
+
--font: -apple-system, BlinkMacSystemFont, "SF Pro Display",
|
| 31 |
+
"SF Pro Text", Inter, "Helvetica Neue", Arial, sans-serif;
|
| 32 |
+
--bg: #F2F2F7;
|
| 33 |
+
--card: #FFFFFF;
|
| 34 |
+
--border: #E5E5EA;
|
| 35 |
+
--text: #1D1D1F;
|
| 36 |
+
--muted: #86868B;
|
| 37 |
+
--accent: #0A84FF;
|
| 38 |
+
--dark: #1C1C1E;
|
| 39 |
+
}
|
| 40 |
+
* { font-family: var(--font) !important; }
|
| 41 |
+
|
| 42 |
+
.gradio-container {
|
| 43 |
+
max-width: 1320px !important;
|
| 44 |
+
margin: 0 auto !important;
|
| 45 |
+
background: var(--bg) !important;
|
| 46 |
+
padding-bottom: 24px !important;
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
/* ββ Top bar ββ */
|
| 50 |
+
.top-bar {
|
| 51 |
+
background: linear-gradient(135deg, #1C1C1E 0%, #2C2C2E 100%);
|
| 52 |
+
border-radius: 16px;
|
| 53 |
+
padding: 28px 36px;
|
| 54 |
+
margin-bottom: 16px;
|
| 55 |
+
display: flex;
|
| 56 |
+
justify-content: space-between;
|
| 57 |
+
align-items: center;
|
| 58 |
+
}
|
| 59 |
+
.top-bar .title {
|
| 60 |
+
font-size: 24px; font-weight: 700;
|
| 61 |
+
color: #F5F5F7; letter-spacing: -0.3px;
|
| 62 |
+
}
|
| 63 |
+
.top-bar .subtitle {
|
| 64 |
+
font-size: 13px; color: #98989D;
|
| 65 |
+
margin-top: 2px;
|
| 66 |
+
}
|
| 67 |
+
.top-bar .location {
|
| 68 |
+
text-align: right;
|
| 69 |
+
font-size: 13px; color: #98989D;
|
| 70 |
+
line-height: 1.6;
|
| 71 |
+
}
|
| 72 |
+
.top-bar .location b {
|
| 73 |
+
color: #F5F5F7; font-weight: 600;
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
/* ββ Hero card ββ */
|
| 77 |
+
.hero-card {
|
| 78 |
+
background: var(--card);
|
| 79 |
+
border-radius: 16px;
|
| 80 |
+
border: 1px solid var(--border);
|
| 81 |
+
box-shadow: 0 2px 8px rgba(0,0,0,0.04);
|
| 82 |
+
padding: 32px 36px 28px;
|
| 83 |
+
margin-bottom: 16px;
|
| 84 |
+
}
|
| 85 |
+
.hero-main {
|
| 86 |
+
display: flex;
|
| 87 |
+
align-items: baseline;
|
| 88 |
+
gap: 20px;
|
| 89 |
+
margin-bottom: 4px;
|
| 90 |
+
}
|
| 91 |
+
.hero-temp {
|
| 92 |
+
font-size: 64px; font-weight: 300;
|
| 93 |
+
color: var(--text); letter-spacing: -2px;
|
| 94 |
+
line-height: 1;
|
| 95 |
+
}
|
| 96 |
+
.hero-temp-unit {
|
| 97 |
+
font-size: 28px; font-weight: 400;
|
| 98 |
+
color: var(--muted); margin-left: 2px;
|
| 99 |
+
}
|
| 100 |
+
.hero-status {
|
| 101 |
+
font-size: 20px; font-weight: 500;
|
| 102 |
+
color: var(--text); padding-left: 8px;
|
| 103 |
+
border-left: 3px solid var(--accent);
|
| 104 |
+
}
|
| 105 |
+
.hero-metrics {
|
| 106 |
+
display: flex;
|
| 107 |
+
gap: 12px;
|
| 108 |
+
margin: 20px 0 18px;
|
| 109 |
+
}
|
| 110 |
+
.metric-tile {
|
| 111 |
+
flex: 1;
|
| 112 |
+
background: var(--bg);
|
| 113 |
+
border-radius: 12px;
|
| 114 |
+
padding: 14px 16px;
|
| 115 |
+
text-align: center;
|
| 116 |
+
}
|
| 117 |
+
.metric-value {
|
| 118 |
+
font-size: 22px; font-weight: 600;
|
| 119 |
+
color: var(--text); line-height: 1.2;
|
| 120 |
+
}
|
| 121 |
+
.metric-label {
|
| 122 |
+
font-size: 12px; font-weight: 500;
|
| 123 |
+
color: var(--muted);
|
| 124 |
+
text-transform: uppercase;
|
| 125 |
+
letter-spacing: 0.5px;
|
| 126 |
+
margin-top: 4px;
|
| 127 |
+
}
|
| 128 |
+
.hero-meta {
|
| 129 |
+
font-size: 13px; color: var(--muted);
|
| 130 |
+
line-height: 1.6;
|
| 131 |
+
}
|
| 132 |
+
.hero-meta code {
|
| 133 |
+
background: var(--bg); padding: 2px 6px;
|
| 134 |
+
border-radius: 4px; font-size: 12px;
|
| 135 |
+
}
|
| 136 |
+
.hero-placeholder {
|
| 137 |
+
text-align: center;
|
| 138 |
+
padding: 36px 0;
|
| 139 |
+
color: var(--muted);
|
| 140 |
+
font-size: 16px; font-weight: 500;
|
| 141 |
+
}
|
| 142 |
+
|
| 143 |
+
/* ββ Map section ββ */
|
| 144 |
+
.maps-heading {
|
| 145 |
+
font-size: 11px; font-weight: 600;
|
| 146 |
+
text-transform: uppercase; letter-spacing: 0.8px;
|
| 147 |
+
color: var(--muted);
|
| 148 |
+
margin: 8px 0 8px 4px;
|
| 149 |
+
}
|
| 150 |
+
|
| 151 |
+
.map-cell {
|
| 152 |
+
background: var(--card) !important;
|
| 153 |
+
border-radius: 14px !important;
|
| 154 |
+
border: 1px solid var(--border) !important;
|
| 155 |
+
box-shadow: 0 1px 4px rgba(0,0,0,0.04) !important;
|
| 156 |
+
overflow: hidden !important;
|
| 157 |
+
min-height: 380px !important;
|
| 158 |
+
}
|
| 159 |
+
|
| 160 |
+
/* ββ Controls inside hero ββ */
|
| 161 |
+
.controls-row {
|
| 162 |
+
display: flex; align-items: end; gap: 10px;
|
| 163 |
+
margin-top: 18px; padding-top: 16px;
|
| 164 |
+
border-top: 1px solid var(--border);
|
| 165 |
+
}
|
| 166 |
+
|
| 167 |
+
/* ββ Status ββ */
|
| 168 |
+
.status-text p, .status-text em {
|
| 169 |
+
font-size: 12px !important; color: var(--muted) !important;
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
/* ββ About ββ */
|
| 173 |
+
.about-section {
|
| 174 |
+
font-size: 13px !important; color: #6E6E73 !important;
|
| 175 |
+
line-height: 1.65 !important;
|
| 176 |
+
}
|
| 177 |
+
|
| 178 |
+
/* ββ Button ββ */
|
| 179 |
+
button.primary {
|
| 180 |
+
background: var(--accent) !important;
|
| 181 |
+
border: none !important; border-radius: 10px !important;
|
| 182 |
+
font-weight: 600 !important; font-size: 15px !important;
|
| 183 |
+
padding: 10px 28px !important;
|
| 184 |
+
}
|
| 185 |
+
button.primary:hover { background: #0A74E0 !important; }
|
| 186 |
+
"""
|
| 187 |
+
|
| 188 |
+
# ββ Helpers ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 189 |
+
|
| 190 |
+
model_choices = [
|
| 191 |
+
f"{v['display_name']} ({v['params']})" for v in AVAILABLE_MODELS.values()
|
| 192 |
+
]
|
| 193 |
+
model_keys = list(AVAILABLE_MODELS.keys())
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def _resolve_model(display: str) -> str:
|
| 197 |
+
return model_keys[model_choices.index(display)]
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def _hero_placeholder() -> str:
|
| 201 |
+
return (
|
| 202 |
+
'<div class="hero-card">'
|
| 203 |
+
'<div class="hero-placeholder">'
|
| 204 |
+
"Click <b>Run Forecast</b> to fetch real-time HRRR data and generate a 24-hour prediction."
|
| 205 |
+
"</div></div>"
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def _hero_html(r: dict, cycle_str: str, forecast_str: str, model_label: str) -> str:
|
| 210 |
+
return (
|
| 211 |
+
'<div class="hero-card">'
|
| 212 |
+
# temperature + status
|
| 213 |
+
'<div class="hero-main">'
|
| 214 |
+
f'<div><span class="hero-temp">{r["temperature_c"]:.1f}</span>'
|
| 215 |
+
f'<span class="hero-temp-unit">Β°C</span></div>'
|
| 216 |
+
f'<div class="hero-status">{r["rain_status"]}</div>'
|
| 217 |
+
"</div>"
|
| 218 |
+
# metric tiles
|
| 219 |
+
'<div class="hero-metrics">'
|
| 220 |
+
f'<div class="metric-tile"><div class="metric-value">{r["temperature_f"]:.0f}Β°F</div>'
|
| 221 |
+
'<div class="metric-label">Temperature</div></div>'
|
| 222 |
+
f'<div class="metric-tile"><div class="metric-value">{r["humidity_pct"]:.0f}%</div>'
|
| 223 |
+
'<div class="metric-label">Humidity</div></div>'
|
| 224 |
+
f'<div class="metric-tile"><div class="metric-value">{r["wind_speed_ms"]:.1f}</div>'
|
| 225 |
+
f'<div class="metric-label">Wind m/s {r["wind_dir_str"]}</div></div>'
|
| 226 |
+
f'<div class="metric-tile"><div class="metric-value">{r["gust_ms"]:.1f}</div>'
|
| 227 |
+
'<div class="metric-label">Gust m/s</div></div>'
|
| 228 |
+
f'<div class="metric-tile"><div class="metric-value">{r["precipitation_mm"]:.2f}</div>'
|
| 229 |
+
'<div class="metric-label">Precip mm</div></div>'
|
| 230 |
+
"</div>"
|
| 231 |
+
# meta line
|
| 232 |
+
'<div class="hero-meta">'
|
| 233 |
+
f"Based on  <code>{cycle_str}</code>   "
|
| 234 |
+
f"Forecast valid  <b>{forecast_str}</b>   "
|
| 235 |
+
f"Model  <b>{model_label}</b>"
|
| 236 |
+
"</div>"
|
| 237 |
+
"</div>"
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
# ββ Main callback ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 242 |
+
|
| 243 |
+
def do_forecast(model_display: str, progress=gr.Progress()):
|
| 244 |
+
model_name = _resolve_model(model_display)
|
| 245 |
+
|
| 246 |
+
progress(0.02, desc="Finding latest HRRR cycle...")
|
| 247 |
+
try:
|
| 248 |
+
input_array, cycle_time = fetch_hrrr_input(
|
| 249 |
+
progress_callback=lambda f, m: progress(f, desc=m),
|
| 250 |
+
)
|
| 251 |
+
except Exception as e:
|
| 252 |
+
raise gr.Error(f"HRRR fetch failed: {e}")
|
| 253 |
+
|
| 254 |
+
cycle_str = cycle_time.strftime("%Y-%m-%d %H:%M UTC")
|
| 255 |
+
forecast_time = cycle_time + timedelta(hours=24)
|
| 256 |
+
forecast_str = forecast_time.strftime("%Y-%m-%d %H:%M UTC")
|
| 257 |
+
|
| 258 |
+
progress(0.95, desc="Running model inference...")
|
| 259 |
+
try:
|
| 260 |
+
r = run_forecast(model_name, input_array)
|
| 261 |
+
except Exception as e:
|
| 262 |
+
raise gr.Error(f"Inference failed: {e}")
|
| 263 |
+
|
| 264 |
+
model_label = model_display.split("(")[0].strip()
|
| 265 |
+
hero = _hero_html(r, cycle_str, forecast_str, model_label)
|
| 266 |
+
temp_fig = plot_temperature(input_array, r, cycle_str, forecast_str)
|
| 267 |
+
status = f"Forecast complete β HRRR cycle {cycle_str}"
|
| 268 |
+
|
| 269 |
+
return hero, temp_fig, status
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
# ββ Build UI ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 273 |
+
|
| 274 |
+
# Pre-render static maps at import time
|
| 275 |
+
logger.info("Rendering static basemaps...")
|
| 276 |
+
_sat_fig, _street_fig = get_static_maps()
|
| 277 |
+
_temp_placeholder = plot_temperature_placeholder()
|
| 278 |
+
logger.info("Basemaps ready.")
|
| 279 |
+
|
| 280 |
+
with gr.Blocks(title="Tufts Jumbo Weather Forecast") as demo:
|
| 281 |
+
|
| 282 |
+
# ββ Top bar βββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 283 |
+
gr.HTML(
|
| 284 |
+
'<div class="top-bar">'
|
| 285 |
+
'<div>'
|
| 286 |
+
'<div class="title">Tufts Jumbo Weather</div>'
|
| 287 |
+
'<div class="subtitle">Real-time deep-learning forecast</div>'
|
| 288 |
+
"</div>"
|
| 289 |
+
'<div class="location">'
|
| 290 |
+
"<b>Medford, MA</b><br>"
|
| 291 |
+
"42.41Β°N   71.12Β°W"
|
| 292 |
+
"</div>"
|
| 293 |
+
"</div>"
|
| 294 |
+
)
|
| 295 |
+
|
| 296 |
+
# ββ Hero card βββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 297 |
+
hero_html = gr.HTML(_hero_placeholder())
|
| 298 |
+
|
| 299 |
+
# ββ Controls ββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 300 |
+
with gr.Row(elem_classes=["controls-row"]):
|
| 301 |
+
model_dd = gr.Dropdown(
|
| 302 |
+
choices=model_choices, value=model_choices[0],
|
| 303 |
+
label="Model", scale=3,
|
| 304 |
+
)
|
| 305 |
+
run_btn = gr.Button("Run Forecast", variant="primary", scale=1)
|
| 306 |
+
|
| 307 |
+
status_bar = gr.Markdown(
|
| 308 |
+
"_Ready β click **Run Forecast**._",
|
| 309 |
+
elem_classes=["status-text"],
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
# ββ Maps ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 313 |
+
gr.HTML('<div class="maps-heading">Coverage Maps β 1 350 km Γ 1 350 km   3 km resolution</div>')
|
| 314 |
+
|
| 315 |
+
with gr.Row(equal_height=True):
|
| 316 |
+
sat_plot = gr.Plot(
|
| 317 |
+
value=_sat_fig, label="Satellite",
|
| 318 |
+
elem_classes=["map-cell"],
|
| 319 |
+
)
|
| 320 |
+
street_plot = gr.Plot(
|
| 321 |
+
value=_street_fig, label="Reference Map",
|
| 322 |
+
elem_classes=["map-cell"],
|
| 323 |
+
)
|
| 324 |
+
temp_plot = gr.Plot(
|
| 325 |
+
value=_temp_placeholder, label="Temperature",
|
| 326 |
+
elem_classes=["map-cell"],
|
| 327 |
+
)
|
| 328 |
+
|
| 329 |
+
# ββ About βββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 330 |
+
with gr.Accordion("About this demo", open=False):
|
| 331 |
+
gr.Markdown(
|
| 332 |
+
"**Data**   HRRR 3 km analysis from NOAA (AWS S3, via Herbie). "
|
| 333 |
+
"42 atmospheric channels covering the US Northeast.\n\n"
|
| 334 |
+
"**Models**   CNN Baseline (11.3 M params) Β· ResNet-18 (11.2 M params) β "
|
| 335 |
+
"predict 6 weather variables 24 h ahead for a single target point.\n\n"
|
| 336 |
+
"**Course**   Tufts CS 137 β Deep Neural Networks, Spring 2026",
|
| 337 |
+
elem_classes=["about-section"],
|
| 338 |
+
)
|
| 339 |
+
|
| 340 |
+
# ββ Callbacks βββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 341 |
+
run_btn.click(
|
| 342 |
+
fn=do_forecast,
|
| 343 |
+
inputs=[model_dd],
|
| 344 |
+
outputs=[hero_html, temp_plot, status_bar],
|
| 345 |
+
)
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
if __name__ == "__main__":
|
| 349 |
+
logger.info("Pre-loading default model...")
|
| 350 |
+
try:
|
| 351 |
+
load_model(model_keys[0])
|
| 352 |
+
logger.info("Model loaded.")
|
| 353 |
+
except Exception as e:
|
| 354 |
+
logger.warning(f"Pre-load failed: {e}")
|
| 355 |
+
|
| 356 |
+
demo.launch(share=False, css=CUSTOM_CSS)
|
checkpoints/cnn_baseline.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9978778103ca02974eec82bd676a58639b0e13fa893656bd622153a44ec68603
|
| 3 |
+
size 45379961
|
checkpoints/resnet18.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9d74d4559da02319248276d3d616cda17766f422350dfce61de041c03b9cfff0
|
| 3 |
+
size 45288199
|
hrrr_fetch.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Fetch real-time HRRR analysis data from NOAA AWS S3 via Herbie.
|
| 3 |
+
|
| 4 |
+
Downloads 42 individual GRIB2 messages (one per input channel),
|
| 5 |
+
extracts the New England subgrid (450x449), and stacks them into
|
| 6 |
+
the model's expected input format.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import logging
|
| 10 |
+
from datetime import datetime, timedelta, timezone
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
|
| 14 |
+
from var_mapping import HRRR_MAPPING, NE_Y_SLICE, NE_X_SLICE
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def find_latest_hrrr_cycle(max_lookback_hours: int = 6) -> datetime:
|
| 20 |
+
"""
|
| 21 |
+
Find the most recent HRRR cycle that is available on AWS S3.
|
| 22 |
+
HRRR data typically becomes available ~45-90 minutes after valid time.
|
| 23 |
+
|
| 24 |
+
Returns a tz-naive datetime in UTC (Herbie requirement).
|
| 25 |
+
"""
|
| 26 |
+
from herbie import Herbie
|
| 27 |
+
|
| 28 |
+
now = datetime.now(timezone.utc).replace(tzinfo=None) # tz-naive UTC
|
| 29 |
+
for hours_ago in range(2, max_lookback_hours + 1):
|
| 30 |
+
cycle_time = (now - timedelta(hours=hours_ago)).replace(
|
| 31 |
+
minute=0, second=0, microsecond=0
|
| 32 |
+
)
|
| 33 |
+
try:
|
| 34 |
+
H = Herbie(
|
| 35 |
+
cycle_time,
|
| 36 |
+
model="hrrr",
|
| 37 |
+
product="sfc",
|
| 38 |
+
fxx=0,
|
| 39 |
+
verbose=False,
|
| 40 |
+
)
|
| 41 |
+
# Check if the index file exists (fast check without downloading data)
|
| 42 |
+
if H.idx is not None:
|
| 43 |
+
logger.info(f"Found HRRR cycle: {cycle_time:%Y-%m-%d %H:%M UTC}")
|
| 44 |
+
return cycle_time
|
| 45 |
+
except Exception:
|
| 46 |
+
continue
|
| 47 |
+
|
| 48 |
+
raise RuntimeError(
|
| 49 |
+
f"No HRRR data available in the last {max_lookback_hours} hours. "
|
| 50 |
+
"NOAA servers may be temporarily unavailable."
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _fetch_single_variable(cycle_time: datetime, mapping: dict) -> np.ndarray:
|
| 55 |
+
"""
|
| 56 |
+
Fetch one variable from HRRR and extract the NE subgrid.
|
| 57 |
+
|
| 58 |
+
Returns:
|
| 59 |
+
np.ndarray of shape (450, 449)
|
| 60 |
+
"""
|
| 61 |
+
from herbie import Herbie
|
| 62 |
+
|
| 63 |
+
H = Herbie(
|
| 64 |
+
cycle_time,
|
| 65 |
+
model="hrrr",
|
| 66 |
+
product=mapping["product"],
|
| 67 |
+
fxx=mapping["fxx"],
|
| 68 |
+
verbose=False,
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
ds = H.xarray(mapping["search"], remove_grib=True)
|
| 72 |
+
|
| 73 |
+
# xarray dataset may have different variable names depending on the GRIB message.
|
| 74 |
+
# Get the first data variable (excluding coordinates).
|
| 75 |
+
data_vars = [v for v in ds.data_vars if v not in ("latitude", "longitude", "gribfile_projection")]
|
| 76 |
+
if not data_vars:
|
| 77 |
+
raise ValueError(f"No data variable found for {mapping['name']}")
|
| 78 |
+
|
| 79 |
+
field = ds[data_vars[0]].values # Full CONUS grid
|
| 80 |
+
|
| 81 |
+
# Extract NE subgrid
|
| 82 |
+
subgrid = field[NE_Y_SLICE, NE_X_SLICE]
|
| 83 |
+
|
| 84 |
+
if subgrid.shape != (450, 449):
|
| 85 |
+
raise ValueError(
|
| 86 |
+
f"Unexpected shape {subgrid.shape} for {mapping['name']}, expected (450, 449)"
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
return subgrid.astype(np.float32)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def fetch_hrrr_input(
|
| 93 |
+
cycle_time: datetime = None,
|
| 94 |
+
progress_callback=None,
|
| 95 |
+
) -> tuple[np.ndarray, datetime]:
|
| 96 |
+
"""
|
| 97 |
+
Fetch all 42 HRRR channels and stack into model input format.
|
| 98 |
+
|
| 99 |
+
Args:
|
| 100 |
+
cycle_time: Specific HRRR cycle to fetch. If None, finds the latest.
|
| 101 |
+
progress_callback: Optional callable(fraction, description) for progress updates.
|
| 102 |
+
|
| 103 |
+
Returns:
|
| 104 |
+
(input_array, cycle_time) where input_array is (450, 449, 42) float32.
|
| 105 |
+
"""
|
| 106 |
+
if cycle_time is None:
|
| 107 |
+
if progress_callback:
|
| 108 |
+
progress_callback(0.05, "Finding latest HRRR cycle...")
|
| 109 |
+
cycle_time = find_latest_hrrr_cycle()
|
| 110 |
+
|
| 111 |
+
channels = []
|
| 112 |
+
n_channels = len(HRRR_MAPPING)
|
| 113 |
+
failed = []
|
| 114 |
+
|
| 115 |
+
for i, mapping in enumerate(HRRR_MAPPING):
|
| 116 |
+
if progress_callback:
|
| 117 |
+
frac = 0.1 + 0.85 * (i / n_channels)
|
| 118 |
+
progress_callback(frac, f"Fetching {mapping['name']} ({i+1}/{n_channels})...")
|
| 119 |
+
|
| 120 |
+
try:
|
| 121 |
+
field = _fetch_single_variable(cycle_time, mapping)
|
| 122 |
+
channels.append(field)
|
| 123 |
+
except Exception as e:
|
| 124 |
+
logger.warning(f"Failed to fetch {mapping['name']}: {e}")
|
| 125 |
+
failed.append(mapping["name"])
|
| 126 |
+
# Fill with zeros as fallback for individual missing channels
|
| 127 |
+
channels.append(np.zeros((450, 449), dtype=np.float32))
|
| 128 |
+
|
| 129 |
+
if failed:
|
| 130 |
+
logger.warning(f"Failed channels ({len(failed)}/{n_channels}): {failed}")
|
| 131 |
+
if len(failed) > n_channels // 2:
|
| 132 |
+
raise RuntimeError(
|
| 133 |
+
f"Too many channels failed ({len(failed)}/{n_channels}). "
|
| 134 |
+
"HRRR data may be unavailable."
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
input_array = np.stack(channels, axis=-1) # (450, 449, 42)
|
| 138 |
+
|
| 139 |
+
if progress_callback:
|
| 140 |
+
progress_callback(1.0, "Data fetch complete!")
|
| 141 |
+
|
| 142 |
+
return input_array, cycle_time
|
model_utils.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Model loading and inference utilities for the weather forecast demo.
|
| 3 |
+
|
| 4 |
+
Wraps the existing inference/predict.py logic, adding user-friendly
|
| 5 |
+
post-processing (Celsius, wind speed/direction, rain likelihood).
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import math
|
| 9 |
+
import sys
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
|
| 15 |
+
# In HF Space, models/ is in the same directory as this file
|
| 16 |
+
PROJECT_ROOT = Path(__file__).resolve().parent
|
| 17 |
+
sys.path.insert(0, str(PROJECT_ROOT))
|
| 18 |
+
|
| 19 |
+
from models import create_model, get_model_defaults
|
| 20 |
+
|
| 21 |
+
# ββ Model cache (loaded once, reused across requests) ββββββββββββββββββ
|
| 22 |
+
_model_cache: dict = {}
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
TARGET_VARS = [
|
| 26 |
+
("TMP@2m_above_ground", "Temperature (2m)", "K"),
|
| 27 |
+
("RH@2m_above_ground", "Relative Humidity", "%"),
|
| 28 |
+
("UGRD@10m_above_ground", "U-Wind (10m)", "m/s"),
|
| 29 |
+
("VGRD@10m_above_ground", "V-Wind (10m)", "m/s"),
|
| 30 |
+
("GUST@surface", "Wind Gust", "m/s"),
|
| 31 |
+
("APCP_1hr_acc_fcst@surface", "Precipitation (1hr)", "mm"),
|
| 32 |
+
]
|
| 33 |
+
|
| 34 |
+
# Available models with display info
|
| 35 |
+
AVAILABLE_MODELS = {
|
| 36 |
+
"cnn_baseline": {
|
| 37 |
+
"display_name": "CNN Baseline",
|
| 38 |
+
"checkpoint": "checkpoints/cnn_baseline.pt",
|
| 39 |
+
"params": "11.3M",
|
| 40 |
+
},
|
| 41 |
+
"resnet18": {
|
| 42 |
+
"display_name": "ResNet-18",
|
| 43 |
+
"checkpoint": "checkpoints/resnet18.pt",
|
| 44 |
+
"params": "11.2M",
|
| 45 |
+
},
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def load_model(model_name: str, device: str = "cpu"):
|
| 50 |
+
"""
|
| 51 |
+
Load a trained model from checkpoint. Caches in memory for reuse.
|
| 52 |
+
|
| 53 |
+
Returns:
|
| 54 |
+
(model, norm_stats) tuple
|
| 55 |
+
"""
|
| 56 |
+
if model_name in _model_cache:
|
| 57 |
+
return _model_cache[model_name]
|
| 58 |
+
|
| 59 |
+
ckpt_path = PROJECT_ROOT / AVAILABLE_MODELS[model_name]["checkpoint"]
|
| 60 |
+
if not ckpt_path.exists():
|
| 61 |
+
raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
|
| 62 |
+
|
| 63 |
+
ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
|
| 64 |
+
args = ckpt["args"]
|
| 65 |
+
ckpt_model_name = args["model"]
|
| 66 |
+
|
| 67 |
+
defaults = get_model_defaults(ckpt_model_name)
|
| 68 |
+
n_frames = args.get("n_frames") or defaults["n_frames"]
|
| 69 |
+
|
| 70 |
+
model_kwargs = {
|
| 71 |
+
"n_input_channels": 42,
|
| 72 |
+
"n_targets": 6,
|
| 73 |
+
"base_channels": args.get("base_channels", 64),
|
| 74 |
+
}
|
| 75 |
+
if n_frames > 1:
|
| 76 |
+
model_kwargs["n_frames"] = n_frames
|
| 77 |
+
|
| 78 |
+
model = create_model(ckpt_model_name, **model_kwargs)
|
| 79 |
+
model.load_state_dict(ckpt["model"])
|
| 80 |
+
model.to(device).eval()
|
| 81 |
+
|
| 82 |
+
norm_stats = ckpt.get("norm_stats")
|
| 83 |
+
|
| 84 |
+
_model_cache[model_name] = (model, norm_stats)
|
| 85 |
+
return model, norm_stats
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def predict_raw(model, norm_stats, input_array: np.ndarray, device: str = "cpu") -> np.ndarray:
|
| 89 |
+
"""
|
| 90 |
+
Run inference on a (450, 449, 42) input array.
|
| 91 |
+
|
| 92 |
+
Returns:
|
| 93 |
+
np.ndarray of shape (6,) with denormalized physical values.
|
| 94 |
+
"""
|
| 95 |
+
x = torch.from_numpy(input_array).float()
|
| 96 |
+
x = x.permute(2, 0, 1).unsqueeze(0) # (1, 42, 450, 449)
|
| 97 |
+
|
| 98 |
+
if norm_stats:
|
| 99 |
+
mean = norm_stats["input_mean"]
|
| 100 |
+
std = norm_stats["input_std"]
|
| 101 |
+
# Ensure correct device
|
| 102 |
+
if isinstance(mean, torch.Tensor):
|
| 103 |
+
mean = mean.float()
|
| 104 |
+
std = std.float()
|
| 105 |
+
x = (x - mean) / (std + 1e-7)
|
| 106 |
+
|
| 107 |
+
x = x.to(device)
|
| 108 |
+
with torch.no_grad():
|
| 109 |
+
pred = model(x).squeeze(0).cpu() # (6,)
|
| 110 |
+
|
| 111 |
+
if norm_stats:
|
| 112 |
+
target_mean = norm_stats["target_mean"]
|
| 113 |
+
target_std = norm_stats["target_std"]
|
| 114 |
+
if isinstance(target_mean, torch.Tensor):
|
| 115 |
+
target_mean = target_mean.float()
|
| 116 |
+
target_std = target_std.float()
|
| 117 |
+
pred = pred * target_std + target_mean
|
| 118 |
+
|
| 119 |
+
return pred.numpy()
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def _wind_direction_str(degrees: float) -> str:
|
| 123 |
+
"""Convert wind direction in degrees to compass string."""
|
| 124 |
+
dirs = ["N", "NNE", "NE", "ENE", "E", "ESE", "SE", "SSE",
|
| 125 |
+
"S", "SSW", "SW", "WSW", "W", "WNW", "NW", "NNW"]
|
| 126 |
+
idx = round(degrees / 22.5) % 16
|
| 127 |
+
return dirs[idx]
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def format_forecast(pred: np.ndarray) -> dict:
|
| 131 |
+
"""
|
| 132 |
+
Convert raw model output (6 physical values) into a user-friendly forecast dict.
|
| 133 |
+
"""
|
| 134 |
+
temp_k = float(pred[0])
|
| 135 |
+
rh = float(pred[1])
|
| 136 |
+
u_wind = float(pred[2])
|
| 137 |
+
v_wind = float(pred[3])
|
| 138 |
+
gust = float(pred[4])
|
| 139 |
+
apcp = float(pred[5])
|
| 140 |
+
|
| 141 |
+
# Derived quantities
|
| 142 |
+
temp_c = temp_k - 273.15
|
| 143 |
+
temp_f = temp_c * 9 / 5 + 32
|
| 144 |
+
wind_speed = math.sqrt(u_wind**2 + v_wind**2)
|
| 145 |
+
# Meteorological wind direction: direction FROM which wind blows
|
| 146 |
+
wind_dir_deg = (math.degrees(math.atan2(-u_wind, -v_wind)) + 360) % 360
|
| 147 |
+
wind_dir_str = _wind_direction_str(wind_dir_deg)
|
| 148 |
+
|
| 149 |
+
# Rain likelihood based on APCP threshold
|
| 150 |
+
apcp = max(apcp, 0.0) # Clamp negative predictions
|
| 151 |
+
if apcp > 5.0:
|
| 152 |
+
rain_str = "Heavy Rain Likely"
|
| 153 |
+
elif apcp > 2.0:
|
| 154 |
+
rain_str = "Rain Likely"
|
| 155 |
+
elif apcp > 0.5:
|
| 156 |
+
rain_str = "Light Rain Possible"
|
| 157 |
+
else:
|
| 158 |
+
rain_str = "No Rain Expected"
|
| 159 |
+
|
| 160 |
+
return {
|
| 161 |
+
"temperature_k": temp_k,
|
| 162 |
+
"temperature_c": temp_c,
|
| 163 |
+
"temperature_f": temp_f,
|
| 164 |
+
"humidity_pct": max(0.0, min(100.0, rh)),
|
| 165 |
+
"u_wind_ms": u_wind,
|
| 166 |
+
"v_wind_ms": v_wind,
|
| 167 |
+
"wind_speed_ms": wind_speed,
|
| 168 |
+
"wind_dir_deg": wind_dir_deg,
|
| 169 |
+
"wind_dir_str": wind_dir_str,
|
| 170 |
+
"gust_ms": max(gust, 0.0),
|
| 171 |
+
"precipitation_mm": apcp,
|
| 172 |
+
"rain_status": rain_str,
|
| 173 |
+
}
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def run_forecast(model_name: str, input_array: np.ndarray, device: str = "cpu") -> dict:
|
| 177 |
+
"""Full pipeline: load model β predict β format results."""
|
| 178 |
+
model, norm_stats = load_model(model_name, device)
|
| 179 |
+
pred = predict_raw(model, norm_stats, input_array, device)
|
| 180 |
+
return format_forecast(pred)
|
models/__init__.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Model registry for weather forecasting architectures.
|
| 3 |
+
|
| 4 |
+
Usage:
|
| 5 |
+
from models import create_model, MODEL_REGISTRY
|
| 6 |
+
model = create_model("cnn_baseline", n_input_channels=42)
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from .cnn_baseline import BaselineCNN
|
| 10 |
+
from .cnn_multi_frame import MultiFrameCNN
|
| 11 |
+
from .cnn_3d import CNN3D
|
| 12 |
+
from .vit import WeatherViT
|
| 13 |
+
from .resnet_baseline import ResNet18Baseline
|
| 14 |
+
from .convnext_baseline import ConvNeXtBaseline
|
| 15 |
+
|
| 16 |
+
MODEL_REGISTRY = {
|
| 17 |
+
"cnn_baseline": BaselineCNN,
|
| 18 |
+
"cnn_multi_frame": MultiFrameCNN,
|
| 19 |
+
"cnn_3d": CNN3D,
|
| 20 |
+
"vit": WeatherViT,
|
| 21 |
+
"resnet18": ResNet18Baseline,
|
| 22 |
+
"convnext_tiny": ConvNeXtBaseline,
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
# Default model-specific settings
|
| 26 |
+
MODEL_DEFAULTS = {
|
| 27 |
+
"cnn_baseline": {"n_frames": 1, "stack_mode": "channel"},
|
| 28 |
+
"cnn_multi_frame": {"n_frames": 4, "stack_mode": "channel"},
|
| 29 |
+
"cnn_3d": {"n_frames": 4, "stack_mode": "temporal"},
|
| 30 |
+
"vit": {"n_frames": 1, "stack_mode": "channel"},
|
| 31 |
+
"resnet18": {"n_frames": 1, "stack_mode": "channel"},
|
| 32 |
+
"convnext_tiny": {"n_frames": 1, "stack_mode": "channel"},
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def create_model(name, **kwargs):
|
| 37 |
+
"""Instantiate a model by name with given kwargs."""
|
| 38 |
+
if name not in MODEL_REGISTRY:
|
| 39 |
+
raise ValueError(f"Unknown model: {name}. Available: {list(MODEL_REGISTRY.keys())}")
|
| 40 |
+
return MODEL_REGISTRY[name](**kwargs)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def get_model_defaults(name):
|
| 44 |
+
"""Return default n_frames and stack_mode for a model."""
|
| 45 |
+
return MODEL_DEFAULTS.get(name, {"n_frames": 1, "stack_mode": "channel"})
|
models/__pycache__/__init__.cpython-313.pyc
ADDED
|
Binary file (1.63 kB). View file
|
|
|
models/__pycache__/cnn_3d.cpython-313.pyc
ADDED
|
Binary file (5.48 kB). View file
|
|
|
models/__pycache__/cnn_baseline.cpython-313.pyc
ADDED
|
Binary file (4.87 kB). View file
|
|
|
models/__pycache__/cnn_multi_frame.cpython-313.pyc
ADDED
|
Binary file (3.51 kB). View file
|
|
|
models/__pycache__/convnext_baseline.cpython-313.pyc
ADDED
|
Binary file (1.91 kB). View file
|
|
|
models/__pycache__/resnet_baseline.cpython-313.pyc
ADDED
|
Binary file (1.84 kB). View file
|
|
|
models/__pycache__/vit.cpython-313.pyc
ADDED
|
Binary file (6.42 kB). View file
|
|
|
models/cnn_3d.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
3D CNN for spatiotemporal weather forecasting.
|
| 3 |
+
|
| 4 |
+
Uses Conv3d to jointly model spatial and temporal patterns
|
| 5 |
+
from multiple consecutive weather snapshots.
|
| 6 |
+
|
| 7 |
+
Input: (B, k, C, H, W) β k frames, each with C channels
|
| 8 |
+
Output: (B, 6)
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class ResBlock3D(nn.Module):
|
| 16 |
+
"""3D residual block with separate temporal and spatial convolutions."""
|
| 17 |
+
|
| 18 |
+
def __init__(self, in_ch, out_ch, stride_spatial=1, stride_temporal=1):
|
| 19 |
+
super().__init__()
|
| 20 |
+
stride = (stride_temporal, stride_spatial, stride_spatial)
|
| 21 |
+
|
| 22 |
+
self.conv1 = nn.Conv3d(in_ch, out_ch, 3, stride=stride, padding=1, bias=False)
|
| 23 |
+
self.bn1 = nn.BatchNorm3d(out_ch)
|
| 24 |
+
self.conv2 = nn.Conv3d(out_ch, out_ch, 3, stride=1, padding=1, bias=False)
|
| 25 |
+
self.bn2 = nn.BatchNorm3d(out_ch)
|
| 26 |
+
self.relu = nn.ReLU(inplace=True)
|
| 27 |
+
|
| 28 |
+
self.shortcut = nn.Identity()
|
| 29 |
+
if any(s != 1 for s in stride) or in_ch != out_ch:
|
| 30 |
+
self.shortcut = nn.Sequential(
|
| 31 |
+
nn.Conv3d(in_ch, out_ch, 1, stride=stride, bias=False),
|
| 32 |
+
nn.BatchNorm3d(out_ch),
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
def forward(self, x):
|
| 36 |
+
out = self.relu(self.bn1(self.conv1(x)))
|
| 37 |
+
out = self.bn2(self.conv2(out))
|
| 38 |
+
out = self.relu(out + self.shortcut(x))
|
| 39 |
+
return out
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class CNN3D(nn.Module):
|
| 43 |
+
"""
|
| 44 |
+
3D CNN for spatiotemporal weather forecasting.
|
| 45 |
+
|
| 46 |
+
Input: (B, k, C, H, W) where k=n_frames, C=n_input_channels
|
| 47 |
+
Output: (B, 6)
|
| 48 |
+
|
| 49 |
+
The temporal dimension is collapsed early (by stride-2 temporal convolutions
|
| 50 |
+
and pooling), while spatial dimensions are progressively downsampled.
|
| 51 |
+
"""
|
| 52 |
+
|
| 53 |
+
def __init__(self, n_input_channels=42, n_targets=6, n_frames=4, base_channels=64):
|
| 54 |
+
super().__init__()
|
| 55 |
+
ch = base_channels
|
| 56 |
+
|
| 57 |
+
# (B, k, C, H, W) -> (B, C, k, H, W) is done in forward()
|
| 58 |
+
self.stem = nn.Sequential(
|
| 59 |
+
nn.Conv3d(n_input_channels, ch, kernel_size=(3, 7, 7),
|
| 60 |
+
stride=(1, 2, 2), padding=(1, 3, 3), bias=False),
|
| 61 |
+
nn.BatchNorm3d(ch),
|
| 62 |
+
nn.ReLU(inplace=True),
|
| 63 |
+
)
|
| 64 |
+
# spatial: 450x449 -> 225x225, temporal: k -> k
|
| 65 |
+
|
| 66 |
+
self.layer1 = ResBlock3D(ch, ch, stride_spatial=1, stride_temporal=1)
|
| 67 |
+
# Collapse temporal dimension
|
| 68 |
+
self.layer2 = ResBlock3D(ch, ch * 2, stride_spatial=2, stride_temporal=2)
|
| 69 |
+
# temporal: k -> k//2, spatial: 225 -> 113
|
| 70 |
+
self.layer3 = ResBlock3D(ch * 2, ch * 4, stride_spatial=2, stride_temporal=2)
|
| 71 |
+
# temporal: k//2 -> 1 (for k=4), spatial: 113 -> 57
|
| 72 |
+
self.layer4 = ResBlock3D(ch * 4, ch * 4, stride_spatial=2, stride_temporal=1)
|
| 73 |
+
# spatial: 57 -> 29
|
| 74 |
+
self.layer5 = ResBlock3D(ch * 4, ch * 8, stride_spatial=2, stride_temporal=1)
|
| 75 |
+
# spatial: 29 -> 15
|
| 76 |
+
self.layer6 = ResBlock3D(ch * 8, ch * 8, stride_spatial=2, stride_temporal=1)
|
| 77 |
+
# spatial: 15 -> 8
|
| 78 |
+
|
| 79 |
+
self.pool = nn.AdaptiveAvgPool3d(1)
|
| 80 |
+
self.head = nn.Sequential(
|
| 81 |
+
nn.Flatten(),
|
| 82 |
+
nn.Linear(ch * 8, ch * 2),
|
| 83 |
+
nn.ReLU(inplace=True),
|
| 84 |
+
nn.Dropout(0.3),
|
| 85 |
+
nn.Linear(ch * 2, n_targets),
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
def forward(self, x):
|
| 89 |
+
# x: (B, k, C, H, W) -> (B, C, k, H, W) for Conv3d
|
| 90 |
+
x = x.permute(0, 2, 1, 3, 4)
|
| 91 |
+
x = self.stem(x)
|
| 92 |
+
x = self.layer1(x)
|
| 93 |
+
x = self.layer2(x)
|
| 94 |
+
x = self.layer3(x)
|
| 95 |
+
x = self.layer4(x)
|
| 96 |
+
x = self.layer5(x)
|
| 97 |
+
x = self.layer6(x)
|
| 98 |
+
x = self.pool(x)
|
| 99 |
+
return self.head(x)
|
models/cnn_baseline.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Baseline 2D CNN for single-frame weather forecasting.
|
| 3 |
+
|
| 4 |
+
Takes a single spatial snapshot (C, H, W) and predicts 6 weather variables 24h ahead.
|
| 5 |
+
Architecture: progressive downsampling with residual blocks, global average pooling, FC head.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class ResBlock(nn.Module):
|
| 13 |
+
"""Residual block with two conv layers and optional downsampling."""
|
| 14 |
+
|
| 15 |
+
def __init__(self, in_ch, out_ch, stride=1):
|
| 16 |
+
super().__init__()
|
| 17 |
+
self.conv1 = nn.Conv2d(in_ch, out_ch, 3, stride=stride, padding=1, bias=False)
|
| 18 |
+
self.bn1 = nn.BatchNorm2d(out_ch)
|
| 19 |
+
self.conv2 = nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1, bias=False)
|
| 20 |
+
self.bn2 = nn.BatchNorm2d(out_ch)
|
| 21 |
+
self.relu = nn.ReLU(inplace=True)
|
| 22 |
+
|
| 23 |
+
self.shortcut = nn.Identity()
|
| 24 |
+
if stride != 1 or in_ch != out_ch:
|
| 25 |
+
self.shortcut = nn.Sequential(
|
| 26 |
+
nn.Conv2d(in_ch, out_ch, 1, stride=stride, bias=False),
|
| 27 |
+
nn.BatchNorm2d(out_ch),
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
def forward(self, x):
|
| 31 |
+
out = self.relu(self.bn1(self.conv1(x)))
|
| 32 |
+
out = self.bn2(self.conv2(out))
|
| 33 |
+
out = self.relu(out + self.shortcut(x))
|
| 34 |
+
return out
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class BaselineCNN(nn.Module):
|
| 38 |
+
"""
|
| 39 |
+
Single-frame 2D CNN encoder for weather forecasting.
|
| 40 |
+
|
| 41 |
+
Input: (B, C, 450, 449) where C = n_input_channels (42)
|
| 42 |
+
Output: (B, 6) β predicted weather variables
|
| 43 |
+
"""
|
| 44 |
+
|
| 45 |
+
def __init__(self, n_input_channels=42, n_targets=6, base_channels=64):
|
| 46 |
+
super().__init__()
|
| 47 |
+
|
| 48 |
+
ch = base_channels
|
| 49 |
+
self.stem = nn.Sequential(
|
| 50 |
+
nn.Conv2d(n_input_channels, ch, 7, stride=2, padding=3, bias=False),
|
| 51 |
+
nn.BatchNorm2d(ch),
|
| 52 |
+
nn.ReLU(inplace=True),
|
| 53 |
+
)
|
| 54 |
+
# 450x449 -> 225x225
|
| 55 |
+
|
| 56 |
+
self.layer1 = ResBlock(ch, ch, stride=1) # 225x225
|
| 57 |
+
self.layer2 = ResBlock(ch, ch * 2, stride=2) # 113x113
|
| 58 |
+
self.layer3 = ResBlock(ch * 2, ch * 4, stride=2) # 57x57
|
| 59 |
+
self.layer4 = ResBlock(ch * 4, ch * 4, stride=2) # 29x29
|
| 60 |
+
self.layer5 = ResBlock(ch * 4, ch * 8, stride=2) # 15x15
|
| 61 |
+
self.layer6 = ResBlock(ch * 8, ch * 8, stride=2) # 8x8
|
| 62 |
+
|
| 63 |
+
self.pool = nn.AdaptiveAvgPool2d(1)
|
| 64 |
+
self.head = nn.Sequential(
|
| 65 |
+
nn.Flatten(),
|
| 66 |
+
nn.Linear(ch * 8, ch * 2),
|
| 67 |
+
nn.ReLU(inplace=True),
|
| 68 |
+
nn.Dropout(0.3),
|
| 69 |
+
nn.Linear(ch * 2, n_targets),
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
def forward(self, x):
|
| 73 |
+
x = self.stem(x)
|
| 74 |
+
x = self.layer1(x)
|
| 75 |
+
x = self.layer2(x)
|
| 76 |
+
x = self.layer3(x)
|
| 77 |
+
x = self.layer4(x)
|
| 78 |
+
x = self.layer5(x)
|
| 79 |
+
x = self.layer6(x)
|
| 80 |
+
x = self.pool(x)
|
| 81 |
+
return self.head(x)
|
models/cnn_multi_frame.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Multi-frame 2D CNN for weather forecasting.
|
| 3 |
+
|
| 4 |
+
Stacks k consecutive spatial snapshots along the channel dimension,
|
| 5 |
+
allowing the model to learn temporal patterns with standard 2D convolutions.
|
| 6 |
+
|
| 7 |
+
Input: (B, k*C, H, W)
|
| 8 |
+
Output: (B, 6)
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
from .cnn_baseline import ResBlock
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class MultiFrameCNN(nn.Module):
|
| 17 |
+
"""
|
| 18 |
+
Multi-frame 2D CNN that concatenates consecutive frames along channels.
|
| 19 |
+
|
| 20 |
+
Identical backbone to BaselineCNN but with an adapted stem for k*C input channels,
|
| 21 |
+
plus a temporal mixing layer after the stem.
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
def __init__(self, n_input_channels=42, n_targets=6, n_frames=4, base_channels=64):
|
| 25 |
+
super().__init__()
|
| 26 |
+
self.n_frames = n_frames
|
| 27 |
+
in_ch = n_input_channels * n_frames
|
| 28 |
+
ch = base_channels
|
| 29 |
+
|
| 30 |
+
self.stem = nn.Sequential(
|
| 31 |
+
nn.Conv2d(in_ch, ch * 2, 7, stride=2, padding=3, bias=False),
|
| 32 |
+
nn.BatchNorm2d(ch * 2),
|
| 33 |
+
nn.ReLU(inplace=True),
|
| 34 |
+
nn.Conv2d(ch * 2, ch, 1, bias=False),
|
| 35 |
+
nn.BatchNorm2d(ch),
|
| 36 |
+
nn.ReLU(inplace=True),
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
self.layer1 = ResBlock(ch, ch, stride=1)
|
| 40 |
+
self.layer2 = ResBlock(ch, ch * 2, stride=2)
|
| 41 |
+
self.layer3 = ResBlock(ch * 2, ch * 4, stride=2)
|
| 42 |
+
self.layer4 = ResBlock(ch * 4, ch * 4, stride=2)
|
| 43 |
+
self.layer5 = ResBlock(ch * 4, ch * 8, stride=2)
|
| 44 |
+
self.layer6 = ResBlock(ch * 8, ch * 8, stride=2)
|
| 45 |
+
|
| 46 |
+
self.pool = nn.AdaptiveAvgPool2d(1)
|
| 47 |
+
self.head = nn.Sequential(
|
| 48 |
+
nn.Flatten(),
|
| 49 |
+
nn.Linear(ch * 8, ch * 2),
|
| 50 |
+
nn.ReLU(inplace=True),
|
| 51 |
+
nn.Dropout(0.3),
|
| 52 |
+
nn.Linear(ch * 2, n_targets),
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
def forward(self, x):
|
| 56 |
+
x = self.stem(x)
|
| 57 |
+
x = self.layer1(x)
|
| 58 |
+
x = self.layer2(x)
|
| 59 |
+
x = self.layer3(x)
|
| 60 |
+
x = self.layer4(x)
|
| 61 |
+
x = self.layer5(x)
|
| 62 |
+
x = self.layer6(x)
|
| 63 |
+
x = self.pool(x)
|
| 64 |
+
return self.head(x)
|
models/convnext_baseline.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ConvNeXt-Tiny baseline for weather forecasting.
|
| 3 |
+
|
| 4 |
+
Uses torchvision's ConvNeXt-Tiny with modified input/output layers
|
| 5 |
+
to accept 42-channel weather data and predict 6 target variables.
|
| 6 |
+
Trained from scratch (no pretrained weights).
|
| 7 |
+
|
| 8 |
+
Input: (B, 42, 450, 449)
|
| 9 |
+
Output: (B, 6)
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
from torchvision.models import convnext_tiny
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class ConvNeXtBaseline(nn.Module):
|
| 17 |
+
"""
|
| 18 |
+
ConvNeXt-Tiny adapted for weather forecasting.
|
| 19 |
+
|
| 20 |
+
Modifications from ImageNet ConvNeXt-Tiny:
|
| 21 |
+
- Stem conv: 3 β 42 input channels
|
| 22 |
+
- Classifier head: 1000 β n_targets outputs
|
| 23 |
+
- No pretrained weights (trained from scratch)
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
def __init__(self, n_input_channels=42, n_targets=6, **kwargs):
|
| 27 |
+
super().__init__()
|
| 28 |
+
model = convnext_tiny(weights=None)
|
| 29 |
+
|
| 30 |
+
# Replace stem conv: 3ch β 42ch
|
| 31 |
+
# ConvNeXt stem: features[0][0] is Conv2d(3, 96, 4, stride=4)
|
| 32 |
+
model.features[0][0] = nn.Conv2d(
|
| 33 |
+
n_input_channels, 96, kernel_size=4, stride=4
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
# Replace classifier head: 1000 β n_targets
|
| 37 |
+
# ConvNeXt head: classifier = Sequential(LayerNorm, Flatten, Linear(768, 1000))
|
| 38 |
+
model.classifier[2] = nn.Linear(768, n_targets)
|
| 39 |
+
|
| 40 |
+
self.model = model
|
| 41 |
+
|
| 42 |
+
def forward(self, x):
|
| 43 |
+
return self.model(x)
|
models/resnet_baseline.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ResNet-18 baseline for weather forecasting.
|
| 3 |
+
|
| 4 |
+
Uses torchvision's ResNet-18 with modified input/output layers
|
| 5 |
+
to accept 42-channel weather data and predict 6 target variables.
|
| 6 |
+
Trained from scratch (no pretrained weights).
|
| 7 |
+
|
| 8 |
+
Input: (B, 42, 450, 449)
|
| 9 |
+
Output: (B, 6)
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
from torchvision.models import resnet18
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class ResNet18Baseline(nn.Module):
|
| 17 |
+
"""
|
| 18 |
+
Standard ResNet-18 adapted for weather forecasting.
|
| 19 |
+
|
| 20 |
+
Modifications from ImageNet ResNet-18:
|
| 21 |
+
- conv1: 3 β 42 input channels
|
| 22 |
+
- fc: 1000 β n_targets output classes
|
| 23 |
+
- No pretrained weights (trained from scratch)
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
def __init__(self, n_input_channels=42, n_targets=6, **kwargs):
|
| 27 |
+
super().__init__()
|
| 28 |
+
model = resnet18(weights=None)
|
| 29 |
+
|
| 30 |
+
# Replace first conv: 3ch β 42ch
|
| 31 |
+
model.conv1 = nn.Conv2d(
|
| 32 |
+
n_input_channels, 64, kernel_size=7, stride=2, padding=3, bias=False
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
# Replace classifier head: 1000 β n_targets
|
| 36 |
+
model.fc = nn.Linear(512, n_targets)
|
| 37 |
+
|
| 38 |
+
self.model = model
|
| 39 |
+
|
| 40 |
+
def forward(self, x):
|
| 41 |
+
return self.model(x)
|
models/vit.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Vision Transformer (ViT) for weather forecasting.
|
| 3 |
+
|
| 4 |
+
Splits the spatial input into non-overlapping patches, projects each patch
|
| 5 |
+
into an embedding, and processes them through a Transformer encoder.
|
| 6 |
+
|
| 7 |
+
Input: (B, C, H, W) β single frame with C channels
|
| 8 |
+
Output: (B, 6)
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import math
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class PatchEmbedding(nn.Module):
|
| 17 |
+
"""Convert spatial input into a sequence of patch embeddings."""
|
| 18 |
+
|
| 19 |
+
def __init__(self, in_channels, embed_dim, patch_size, img_h, img_w):
|
| 20 |
+
super().__init__()
|
| 21 |
+
self.patch_size = patch_size
|
| 22 |
+
self.n_patches_h = img_h // patch_size
|
| 23 |
+
self.n_patches_w = img_w // patch_size
|
| 24 |
+
self.n_patches = self.n_patches_h * self.n_patches_w
|
| 25 |
+
|
| 26 |
+
self.proj = nn.Conv2d(in_channels, embed_dim,
|
| 27 |
+
kernel_size=patch_size, stride=patch_size)
|
| 28 |
+
|
| 29 |
+
def forward(self, x):
|
| 30 |
+
# x: (B, C, H, W) -> (B, embed_dim, nH, nW) -> (B, n_patches, embed_dim)
|
| 31 |
+
x = self.proj(x)
|
| 32 |
+
x = x.flatten(2).transpose(1, 2)
|
| 33 |
+
return x
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class TransformerBlock(nn.Module):
|
| 37 |
+
"""Standard Transformer encoder block with pre-norm."""
|
| 38 |
+
|
| 39 |
+
def __init__(self, embed_dim, n_heads, mlp_ratio=4.0, dropout=0.1):
|
| 40 |
+
super().__init__()
|
| 41 |
+
self.norm1 = nn.LayerNorm(embed_dim)
|
| 42 |
+
self.attn = nn.MultiheadAttention(embed_dim, n_heads,
|
| 43 |
+
dropout=dropout, batch_first=True)
|
| 44 |
+
self.norm2 = nn.LayerNorm(embed_dim)
|
| 45 |
+
self.mlp = nn.Sequential(
|
| 46 |
+
nn.Linear(embed_dim, int(embed_dim * mlp_ratio)),
|
| 47 |
+
nn.GELU(),
|
| 48 |
+
nn.Dropout(dropout),
|
| 49 |
+
nn.Linear(int(embed_dim * mlp_ratio), embed_dim),
|
| 50 |
+
nn.Dropout(dropout),
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
def forward(self, x):
|
| 54 |
+
h = self.norm1(x)
|
| 55 |
+
h, _ = self.attn(h, h, h)
|
| 56 |
+
x = x + h
|
| 57 |
+
x = x + self.mlp(self.norm2(x))
|
| 58 |
+
return x
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class WeatherViT(nn.Module):
|
| 62 |
+
"""
|
| 63 |
+
Vision Transformer for weather forecasting.
|
| 64 |
+
|
| 65 |
+
Input: (B, C, 450, 449) β pads width to 450 internally
|
| 66 |
+
Output: (B, 6)
|
| 67 |
+
|
| 68 |
+
Patches the input into 15x15 patches (30x30 = 900 tokens),
|
| 69 |
+
adds CLS token and positional embeddings, runs through Transformer,
|
| 70 |
+
and predicts from the CLS token.
|
| 71 |
+
"""
|
| 72 |
+
|
| 73 |
+
def __init__(self, n_input_channels=42, n_targets=6, patch_size=15,
|
| 74 |
+
embed_dim=256, n_layers=6, n_heads=8, mlp_ratio=4.0, dropout=0.1,
|
| 75 |
+
**kwargs):
|
| 76 |
+
super().__init__()
|
| 77 |
+
self.patch_size = patch_size
|
| 78 |
+
img_h, img_w = 450, 450 # pad to square
|
| 79 |
+
|
| 80 |
+
self.patch_embed = PatchEmbedding(n_input_channels, embed_dim,
|
| 81 |
+
patch_size, img_h, img_w)
|
| 82 |
+
n_patches = self.patch_embed.n_patches
|
| 83 |
+
|
| 84 |
+
self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim) * 0.02)
|
| 85 |
+
self.pos_embed = nn.Parameter(torch.randn(1, n_patches + 1, embed_dim) * 0.02)
|
| 86 |
+
self.pos_drop = nn.Dropout(dropout)
|
| 87 |
+
|
| 88 |
+
self.blocks = nn.Sequential(*[
|
| 89 |
+
TransformerBlock(embed_dim, n_heads, mlp_ratio, dropout)
|
| 90 |
+
for _ in range(n_layers)
|
| 91 |
+
])
|
| 92 |
+
self.norm = nn.LayerNorm(embed_dim)
|
| 93 |
+
|
| 94 |
+
self.head = nn.Sequential(
|
| 95 |
+
nn.Linear(embed_dim, embed_dim // 2),
|
| 96 |
+
nn.ReLU(inplace=True),
|
| 97 |
+
nn.Dropout(0.3),
|
| 98 |
+
nn.Linear(embed_dim // 2, n_targets),
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
def forward(self, x):
|
| 102 |
+
B, C, H, W = x.shape
|
| 103 |
+
# Pad width from 449 to 450 if needed
|
| 104 |
+
if W < 450:
|
| 105 |
+
x = nn.functional.pad(x, (0, 450 - W))
|
| 106 |
+
|
| 107 |
+
patches = self.patch_embed(x) # (B, n_patches, D)
|
| 108 |
+
cls = self.cls_token.expand(B, -1, -1) # (B, 1, D)
|
| 109 |
+
x = torch.cat([cls, patches], dim=1) # (B, n_patches+1, D)
|
| 110 |
+
x = self.pos_drop(x + self.pos_embed)
|
| 111 |
+
|
| 112 |
+
x = self.blocks(x)
|
| 113 |
+
x = self.norm(x)
|
| 114 |
+
|
| 115 |
+
cls_out = x[:, 0] # (B, D)
|
| 116 |
+
return self.head(cls_out)
|
packages.txt
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
libeccodes-dev
|
| 2 |
+
libeccodes-tools
|
requirements.txt
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=2.1.0
|
| 2 |
+
torchvision
|
| 3 |
+
numpy
|
| 4 |
+
matplotlib
|
| 5 |
+
gradio>=4.0
|
| 6 |
+
herbie-data>=2024.1
|
| 7 |
+
cfgrib
|
| 8 |
+
xarray
|
| 9 |
+
eccodes
|
| 10 |
+
cartopy
|
var_mapping.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Mapping from 42-channel VAR_LEVELS to HRRR GRIB2 Herbie search strings.
|
| 3 |
+
|
| 4 |
+
Each entry corresponds to one channel in the model input, following the exact
|
| 5 |
+
order defined in data_preparation/data_spec.py (VAR_LEVELS = TARGET_VARS + ATMOS_VARS).
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
# fmt: off
|
| 9 |
+
HRRR_MAPPING = [
|
| 10 |
+
# ββ Target / surface variables (channels 0-6) ββββββββββββββββββββββ
|
| 11 |
+
{"name": "TMP@2m_above_ground", "search": ":TMP:2 m above ground", "product": "sfc", "fxx": 0},
|
| 12 |
+
{"name": "RH@2m_above_ground", "search": ":RH:2 m above ground", "product": "sfc", "fxx": 0},
|
| 13 |
+
{"name": "UGRD@10m_above_ground", "search": ":UGRD:10 m above ground", "product": "sfc", "fxx": 0},
|
| 14 |
+
{"name": "VGRD@10m_above_ground", "search": ":VGRD:10 m above ground", "product": "sfc", "fxx": 0},
|
| 15 |
+
{"name": "GUST@surface", "search": ":GUST:surface", "product": "sfc", "fxx": 0},
|
| 16 |
+
{"name": "DSWRF@surface", "search": ":DSWRF:surface", "product": "sfc", "fxx": 0},
|
| 17 |
+
# APCP is accumulated; not in analysis (fxx=0). Use 1-hour forecast from same cycle.
|
| 18 |
+
{"name": "APCP_1hr_acc_fcst@surface", "search": ":APCP:surface:0-1 hour acc fcst", "product": "sfc", "fxx": 1},
|
| 19 |
+
|
| 20 |
+
# ββ Atmospheric variables (channels 7-41) ββββββββββββββββββββββββββ
|
| 21 |
+
# CAPE
|
| 22 |
+
{"name": "CAPE@surface", "search": ":CAPE:surface", "product": "sfc", "fxx": 0},
|
| 23 |
+
|
| 24 |
+
# Dew point temperature at pressure levels
|
| 25 |
+
{"name": "DPT@1000mb", "search": ":DPT:1000 mb", "product": "prs", "fxx": 0},
|
| 26 |
+
{"name": "DPT@500mb", "search": ":DPT:500 mb", "product": "prs", "fxx": 0},
|
| 27 |
+
{"name": "DPT@700mb", "search": ":DPT:700 mb", "product": "prs", "fxx": 0},
|
| 28 |
+
{"name": "DPT@850mb", "search": ":DPT:850 mb", "product": "prs", "fxx": 0},
|
| 29 |
+
{"name": "DPT@925mb", "search": ":DPT:925 mb", "product": "prs", "fxx": 0},
|
| 30 |
+
|
| 31 |
+
# Geopotential height at pressure levels + surface
|
| 32 |
+
{"name": "HGT@1000mb", "search": ":HGT:1000 mb", "product": "prs", "fxx": 0},
|
| 33 |
+
{"name": "HGT@500mb", "search": ":HGT:500 mb", "product": "prs", "fxx": 0},
|
| 34 |
+
{"name": "HGT@700mb", "search": ":HGT:700 mb", "product": "prs", "fxx": 0},
|
| 35 |
+
{"name": "HGT@850mb", "search": ":HGT:850 mb", "product": "prs", "fxx": 0},
|
| 36 |
+
{"name": "HGT@surface", "search": ":HGT:surface", "product": "sfc", "fxx": 0},
|
| 37 |
+
|
| 38 |
+
# Temperature at pressure levels
|
| 39 |
+
{"name": "TMP@1000mb", "search": ":TMP:1000 mb", "product": "prs", "fxx": 0},
|
| 40 |
+
{"name": "TMP@500mb", "search": ":TMP:500 mb", "product": "prs", "fxx": 0},
|
| 41 |
+
{"name": "TMP@700mb", "search": ":TMP:700 mb", "product": "prs", "fxx": 0},
|
| 42 |
+
{"name": "TMP@850mb", "search": ":TMP:850 mb", "product": "prs", "fxx": 0},
|
| 43 |
+
{"name": "TMP@925mb", "search": ":TMP:925 mb", "product": "prs", "fxx": 0},
|
| 44 |
+
|
| 45 |
+
# U-component wind at pressure levels
|
| 46 |
+
{"name": "UGRD@1000mb", "search": ":UGRD:1000 mb", "product": "prs", "fxx": 0},
|
| 47 |
+
{"name": "UGRD@250mb", "search": ":UGRD:250 mb", "product": "prs", "fxx": 0},
|
| 48 |
+
{"name": "UGRD@500mb", "search": ":UGRD:500 mb", "product": "prs", "fxx": 0},
|
| 49 |
+
{"name": "UGRD@700mb", "search": ":UGRD:700 mb", "product": "prs", "fxx": 0},
|
| 50 |
+
{"name": "UGRD@850mb", "search": ":UGRD:850 mb", "product": "prs", "fxx": 0},
|
| 51 |
+
{"name": "UGRD@925mb", "search": ":UGRD:925 mb", "product": "prs", "fxx": 0},
|
| 52 |
+
|
| 53 |
+
# V-component wind at pressure levels
|
| 54 |
+
{"name": "VGRD@1000mb", "search": ":VGRD:1000 mb", "product": "prs", "fxx": 0},
|
| 55 |
+
{"name": "VGRD@250mb", "search": ":VGRD:250 mb", "product": "prs", "fxx": 0},
|
| 56 |
+
{"name": "VGRD@500mb", "search": ":VGRD:500 mb", "product": "prs", "fxx": 0},
|
| 57 |
+
{"name": "VGRD@700mb", "search": ":VGRD:700 mb", "product": "prs", "fxx": 0},
|
| 58 |
+
{"name": "VGRD@850mb", "search": ":VGRD:850 mb", "product": "prs", "fxx": 0},
|
| 59 |
+
{"name": "VGRD@925mb", "search": ":VGRD:925 mb", "product": "prs", "fxx": 0},
|
| 60 |
+
|
| 61 |
+
# Cloud cover
|
| 62 |
+
{"name": "TCDC@entire_atmosphere", "search": ":TCDC:entire atmosphere", "product": "sfc", "fxx": 0},
|
| 63 |
+
{"name": "HCDC@high_cloud_layer", "search": ":HCDC:high cloud layer", "product": "sfc", "fxx": 0},
|
| 64 |
+
{"name": "MCDC@middle_cloud_layer", "search": ":MCDC:middle cloud layer", "product": "sfc", "fxx": 0},
|
| 65 |
+
{"name": "LCDC@low_cloud_layer", "search": ":LCDC:low cloud layer", "product": "sfc", "fxx": 0},
|
| 66 |
+
|
| 67 |
+
# Moisture
|
| 68 |
+
{"name": "PWAT@entire_atmosphere_single_layer", "search": ":PWAT:entire atmosphere", "product": "sfc", "fxx": 0},
|
| 69 |
+
{"name": "RHPW@entire_atmosphere", "search": ":RHPW:entire atmosphere", "product": "sfc", "fxx": 0},
|
| 70 |
+
{"name": "VIL@entire_atmosphere", "search": ":VIL:entire atmosphere", "product": "sfc", "fxx": 0},
|
| 71 |
+
]
|
| 72 |
+
# fmt: on
|
| 73 |
+
|
| 74 |
+
assert len(HRRR_MAPPING) == 42, f"Expected 42 channels, got {len(HRRR_MAPPING)}"
|
| 75 |
+
|
| 76 |
+
# New England subgrid slice in the full HRRR CONUS grid (from data_spec.py)
|
| 77 |
+
NE_Y_SLICE = slice(600, 1050) # 450 rows
|
| 78 |
+
NE_X_SLICE = slice(1350, 1799) # 449 columns
|
| 79 |
+
|
| 80 |
+
# Jumbo Statue grid location within the NE subgrid
|
| 81 |
+
JUMBO_ROW = 177
|
| 82 |
+
JUMBO_COL = 263
|
visualization.py
ADDED
|
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Forecast visualization β satellite, street, and temperature maps.
|
| 3 |
+
|
| 4 |
+
Satellite and reference maps are static (rendered once at startup).
|
| 5 |
+
Temperature map updates each time a forecast is run.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import logging
|
| 9 |
+
import numpy as np
|
| 10 |
+
import matplotlib
|
| 11 |
+
matplotlib.use("Agg")
|
| 12 |
+
import matplotlib.pyplot as plt
|
| 13 |
+
from matplotlib.figure import Figure
|
| 14 |
+
|
| 15 |
+
import cartopy.crs as ccrs
|
| 16 |
+
import cartopy.feature as cfeature
|
| 17 |
+
import cartopy.io.img_tiles as cimgt
|
| 18 |
+
|
| 19 |
+
from var_mapping import JUMBO_ROW, JUMBO_COL
|
| 20 |
+
|
| 21 |
+
logger = logging.getLogger(__name__)
|
| 22 |
+
|
| 23 |
+
# ββ Projection & coordinates (from data_spec.py) ββββββββββββββββββββββ
|
| 24 |
+
|
| 25 |
+
PROJ = ccrs.LambertConformal(
|
| 26 |
+
central_longitude=262.5,
|
| 27 |
+
central_latitude=38.5,
|
| 28 |
+
standard_parallels=(38.5, 38.5),
|
| 29 |
+
globe=ccrs.Globe(semimajor_axis=6371229, semiminor_axis=6371229),
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
_x = 1352479.8574780696 + np.arange(449) * 3000
|
| 33 |
+
_y = 212693.8474433364 + np.arange(450) * 3000
|
| 34 |
+
EXTENT = [_x[0], _x[-1], _y[0], _y[-1]]
|
| 35 |
+
|
| 36 |
+
JUMBO_LON, JUMBO_LAT = -71.1204, 42.4078
|
| 37 |
+
X_GRID, Y_GRID = np.meshgrid(_x, _y)
|
| 38 |
+
|
| 39 |
+
CITIES = [
|
| 40 |
+
("Boston", 42.36, -71.06),
|
| 41 |
+
("Providence", 41.82, -71.41),
|
| 42 |
+
("Hartford", 41.76, -72.68),
|
| 43 |
+
("Portland", 43.66, -70.26),
|
| 44 |
+
("Burlington", 44.48, -73.21),
|
| 45 |
+
("Concord", 43.21, -71.54),
|
| 46 |
+
("Albany", 42.65, -73.76),
|
| 47 |
+
("New York", 40.71, -74.01),
|
| 48 |
+
("Montreal", 45.50, -73.57),
|
| 49 |
+
]
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
# ββ Tile sources ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 53 |
+
|
| 54 |
+
class _EsriSatellite(cimgt.GoogleWTS):
|
| 55 |
+
def _image_url(self, tile):
|
| 56 |
+
x, y, z = tile
|
| 57 |
+
return (
|
| 58 |
+
"https://server.arcgisonline.com/ArcGIS/rest/services/"
|
| 59 |
+
f"World_Imagery/MapServer/tile/{z}/{y}/{x}"
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class _EsriStreetMap(cimgt.GoogleWTS):
|
| 64 |
+
def _image_url(self, tile):
|
| 65 |
+
x, y, z = tile
|
| 66 |
+
return (
|
| 67 |
+
"https://server.arcgisonline.com/ArcGIS/rest/services/"
|
| 68 |
+
f"World_Street_Map/MapServer/tile/{z}/{y}/{x}"
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
# ββ Shared style βββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 73 |
+
|
| 74 |
+
_MARKER = dict(
|
| 75 |
+
marker="*", color="#FF3B30", markersize=14,
|
| 76 |
+
markeredgecolor="white", markeredgewidth=1.0,
|
| 77 |
+
transform=ccrs.PlateCarree(), zorder=20,
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
_TAG = dict(
|
| 81 |
+
fontsize=8.5, fontweight="bold", fontfamily="sans-serif",
|
| 82 |
+
color="white", transform=ccrs.PlateCarree(), zorder=25,
|
| 83 |
+
bbox=dict(boxstyle="round,pad=0.25", fc="#1C1C1E", ec="none", alpha=0.80),
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def _make_ax(fig_or_ax=None, figsize=(7.2, 6.8)):
|
| 88 |
+
"""Create a single GeoAxes with consistent extent."""
|
| 89 |
+
if fig_or_ax is None:
|
| 90 |
+
fig, ax = plt.subplots(
|
| 91 |
+
figsize=figsize, subplot_kw={"projection": PROJ},
|
| 92 |
+
)
|
| 93 |
+
else:
|
| 94 |
+
fig, ax = fig_or_ax.figure, fig_or_ax
|
| 95 |
+
ax.set_extent(EXTENT, crs=PROJ)
|
| 96 |
+
return fig, ax
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
# ββ Individual map renderers βββββββββββββββββββββββββββββββββββββββββ
|
| 100 |
+
|
| 101 |
+
def plot_satellite() -> Figure:
|
| 102 |
+
"""Render satellite basemap (static, no weather data needed)."""
|
| 103 |
+
fig, ax = _make_ax()
|
| 104 |
+
try:
|
| 105 |
+
ax.add_image(_EsriSatellite(), 7)
|
| 106 |
+
except Exception:
|
| 107 |
+
ax.add_feature(cfeature.LAND, facecolor="#5C4A32")
|
| 108 |
+
ax.add_feature(cfeature.OCEAN, facecolor="#1B3A4B")
|
| 109 |
+
ax.add_feature(cfeature.COASTLINE, linewidth=0.6, color="white")
|
| 110 |
+
ax.add_feature(cfeature.STATES, linewidth=0.3, edgecolor="#aaa")
|
| 111 |
+
ax.plot(JUMBO_LON, JUMBO_LAT, **_MARKER)
|
| 112 |
+
ax.text(JUMBO_LON + 0.35, JUMBO_LAT + 0.25, "Jumbo", **_TAG)
|
| 113 |
+
ax.set_title(
|
| 114 |
+
"Satellite", fontsize=12, fontweight="600",
|
| 115 |
+
fontfamily="sans-serif", pad=8, color="#1D1D1F",
|
| 116 |
+
)
|
| 117 |
+
fig.subplots_adjust(left=0.02, right=0.98, bottom=0.02, top=0.93)
|
| 118 |
+
return fig
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def plot_street() -> Figure:
|
| 122 |
+
"""Render street / reference basemap (static)."""
|
| 123 |
+
fig, ax = _make_ax()
|
| 124 |
+
try:
|
| 125 |
+
ax.add_image(_EsriStreetMap(), 7)
|
| 126 |
+
except Exception:
|
| 127 |
+
ax.add_feature(cfeature.LAND, facecolor="#E8E4D8")
|
| 128 |
+
ax.add_feature(cfeature.OCEAN, facecolor="#AAD3DF")
|
| 129 |
+
ax.add_feature(cfeature.LAKES, facecolor="#AAD3DF", edgecolor="#888", linewidth=0.3)
|
| 130 |
+
ax.add_feature(cfeature.RIVERS, edgecolor="#AAD3DF", linewidth=0.4)
|
| 131 |
+
ax.add_feature(cfeature.COASTLINE, linewidth=0.5)
|
| 132 |
+
ax.add_feature(cfeature.BORDERS, linewidth=0.5, linestyle="--")
|
| 133 |
+
ax.add_feature(cfeature.STATES, linewidth=0.3, edgecolor="#888")
|
| 134 |
+
pc = ccrs.PlateCarree()
|
| 135 |
+
for name, lat, lon in CITIES:
|
| 136 |
+
ax.text(
|
| 137 |
+
lon, lat, name,
|
| 138 |
+
fontsize=7, fontfamily="sans-serif", fontweight="500",
|
| 139 |
+
color="#333", transform=pc, zorder=15,
|
| 140 |
+
bbox=dict(boxstyle="round,pad=0.15", fc="white", ec="none", alpha=0.7),
|
| 141 |
+
)
|
| 142 |
+
ax.plot(JUMBO_LON, JUMBO_LAT, **_MARKER)
|
| 143 |
+
ax.text(JUMBO_LON + 0.35, JUMBO_LAT + 0.25, "Jumbo", **_TAG)
|
| 144 |
+
ax.set_title(
|
| 145 |
+
"Reference Map", fontsize=12, fontweight="600",
|
| 146 |
+
fontfamily="sans-serif", pad=8, color="#1D1D1F",
|
| 147 |
+
)
|
| 148 |
+
fig.subplots_adjust(left=0.02, right=0.98, bottom=0.02, top=0.93)
|
| 149 |
+
return fig
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def plot_temperature(
|
| 153 |
+
input_array: np.ndarray,
|
| 154 |
+
forecast: dict,
|
| 155 |
+
cycle_str: str,
|
| 156 |
+
forecast_str: str,
|
| 157 |
+
) -> Figure:
|
| 158 |
+
"""Render 2 m temperature map with forecast annotation."""
|
| 159 |
+
fig, ax = _make_ax()
|
| 160 |
+
|
| 161 |
+
temp_field = input_array[:, :, 0] - 273.15
|
| 162 |
+
masked = np.ma.masked_invalid(temp_field)
|
| 163 |
+
|
| 164 |
+
im = ax.pcolormesh(
|
| 165 |
+
X_GRID, Y_GRID, masked,
|
| 166 |
+
cmap="RdYlBu_r", shading="auto", transform=PROJ, zorder=5,
|
| 167 |
+
)
|
| 168 |
+
ax.add_feature(cfeature.COASTLINE, linewidth=0.5, color="#444", zorder=10)
|
| 169 |
+
ax.add_feature(cfeature.STATES, linewidth=0.3, edgecolor="#666", zorder=10)
|
| 170 |
+
|
| 171 |
+
cbar = fig.colorbar(im, ax=ax, shrink=0.72, pad=0.03, aspect=28)
|
| 172 |
+
cbar.set_label("Β°C", fontsize=10, fontfamily="sans-serif")
|
| 173 |
+
cbar.ax.tick_params(labelsize=8)
|
| 174 |
+
|
| 175 |
+
ax.plot(JUMBO_LON, JUMBO_LAT, **_MARKER)
|
| 176 |
+
|
| 177 |
+
temp_c = forecast["temperature_c"]
|
| 178 |
+
temp_f = forecast["temperature_f"]
|
| 179 |
+
label = f"Forecast {forecast_str}\n{temp_c:+.1f} Β°C / {temp_f:.0f} Β°F"
|
| 180 |
+
ax.text(
|
| 181 |
+
JUMBO_LON + 0.45, JUMBO_LAT + 0.35, label,
|
| 182 |
+
fontsize=8.5, fontweight="bold", fontfamily="sans-serif",
|
| 183 |
+
color="white", transform=ccrs.PlateCarree(), zorder=25,
|
| 184 |
+
bbox=dict(boxstyle="round,pad=0.35", fc="#1C1C1E", ec="white", alpha=0.88, lw=0.8),
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
ax.set_title(
|
| 188 |
+
f"2 m Temperature β {cycle_str}",
|
| 189 |
+
fontsize=12, fontweight="600",
|
| 190 |
+
fontfamily="sans-serif", pad=8, color="#1D1D1F",
|
| 191 |
+
)
|
| 192 |
+
fig.subplots_adjust(left=0.02, right=0.95, bottom=0.02, top=0.93)
|
| 193 |
+
return fig
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def plot_temperature_placeholder() -> Figure:
|
| 197 |
+
"""Empty temperature panel shown before first forecast."""
|
| 198 |
+
fig, ax = _make_ax()
|
| 199 |
+
ax.add_feature(cfeature.LAND, facecolor="#E8E4D8", zorder=1)
|
| 200 |
+
ax.add_feature(cfeature.OCEAN, facecolor="#D6EAF0", zorder=1)
|
| 201 |
+
ax.add_feature(cfeature.COASTLINE, linewidth=0.5, color="#999", zorder=2)
|
| 202 |
+
ax.add_feature(cfeature.STATES, linewidth=0.3, edgecolor="#bbb", zorder=2)
|
| 203 |
+
ax.plot(JUMBO_LON, JUMBO_LAT, **_MARKER)
|
| 204 |
+
ax.text(
|
| 205 |
+
JUMBO_LON + 0.35, JUMBO_LAT + 0.25, "Jumbo", **_TAG,
|
| 206 |
+
)
|
| 207 |
+
ax.text(
|
| 208 |
+
0.5, 0.50, "Click Run Forecast",
|
| 209 |
+
transform=ax.transAxes, ha="center", va="center",
|
| 210 |
+
fontsize=14, fontweight="600", fontfamily="sans-serif",
|
| 211 |
+
color="#86868B",
|
| 212 |
+
)
|
| 213 |
+
ax.set_title(
|
| 214 |
+
"2 m Temperature", fontsize=12, fontweight="600",
|
| 215 |
+
fontfamily="sans-serif", pad=8, color="#1D1D1F",
|
| 216 |
+
)
|
| 217 |
+
fig.subplots_adjust(left=0.02, right=0.98, bottom=0.02, top=0.93)
|
| 218 |
+
return fig
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
# ββ Startup cache βββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 222 |
+
|
| 223 |
+
_cache = {}
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def get_static_maps() -> tuple[Figure, Figure]:
|
| 227 |
+
"""Return cached satellite and street map figures (rendered once)."""
|
| 228 |
+
if "satellite" not in _cache:
|
| 229 |
+
logger.info("Rendering satellite basemap...")
|
| 230 |
+
_cache["satellite"] = plot_satellite()
|
| 231 |
+
if "street" not in _cache:
|
| 232 |
+
logger.info("Rendering reference basemap...")
|
| 233 |
+
_cache["street"] = plot_street()
|
| 234 |
+
return _cache["satellite"], _cache["street"]
|