jeffliulab commited on
Commit
e22f65c
Β·
1 Parent(s): cc6a769

Initial deploy: real-time weather forecast demo

Browse files
.gitattributes CHANGED
@@ -1,35 +1 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
  *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  *.pt filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
README.md CHANGED
@@ -1,14 +1,46 @@
1
  ---
2
- title: Weather Predict
3
- emoji: 🏒
4
- colorFrom: gray
5
- colorTo: red
6
  sdk: gradio
7
- sdk_version: 6.11.0
8
  app_file: app.py
9
  pinned: false
10
  license: mit
11
- short_description: A real time weather prediction system with CNN based model.
12
  ---
13
 
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: Tufts Jumbo Weather Forecast
3
+ emoji: "\U0001F324"
4
+ colorFrom: blue
5
+ colorTo: indigo
6
  sdk: gradio
7
+ sdk_version: "4.44.1"
8
  app_file: app.py
9
  pinned: false
10
  license: mit
 
11
  ---
12
 
13
+ # Tufts Jumbo β€” 24h Weather Forecast
14
+
15
+ Real-time deep-learning weather prediction for the Jumbo Statue at Tufts University.
16
+
17
+ ## How It Works
18
+
19
+ 1. **Fetches** the latest HRRR 3 km analysis data from NOAA (42 atmospheric channels, 450x449 grid covering the US Northeast)
20
+ 2. **Runs** a trained CNN through the spatial snapshot
21
+ 3. **Predicts** 6 weather variables 24 hours ahead at a single target point (Jumbo Statue, Medford MA)
22
+
23
+ ## Models
24
+
25
+ | Model | Parameters | Architecture |
26
+ |-------|-----------|-------------|
27
+ | CNN Baseline | 11.3M | 6 residual blocks, progressive spatial downsampling |
28
+ | ResNet-18 | 11.2M | Modified torchvision ResNet-18 (42-channel input) |
29
+
30
+ ## Input Channels (42)
31
+
32
+ Surface: 2m temperature, 2m humidity, 10m U/V wind, surface gust, solar radiation, 1hr precipitation.
33
+ Atmospheric: CAPE, dew point (5 levels), geopotential height (5 levels), temperature (5 levels),
34
+ U-wind (6 levels), V-wind (6 levels), cloud cover (4 layers), precipitable water, relative humidity, VIL.
35
+
36
+ ## Output Variables
37
+
38
+ Temperature (K), Relative Humidity (%), U-Wind (m/s), V-Wind (m/s), Wind Gust (m/s), Precipitation (mm).
39
+
40
+ ## Data Source
41
+
42
+ [HRRR (High-Resolution Rapid Refresh)](https://rapidrefresh.noaa.gov/hrrr/) β€” NOAA's 3 km hourly weather model, fetched in real-time from AWS S3 via [Herbie](https://herbie.readthedocs.io/).
43
+
44
+ ## Course
45
+
46
+ Tufts CS 137 β€” Deep Neural Networks, Spring 2026
__pycache__/app.cpython-313.pyc ADDED
Binary file (11.9 kB). View file
 
__pycache__/hrrr_fetch.cpython-313.pyc ADDED
Binary file (5.6 kB). View file
 
__pycache__/model_utils.cpython-313.pyc ADDED
Binary file (6.83 kB). View file
 
__pycache__/var_mapping.cpython-313.pyc ADDED
Binary file (3.68 kB). View file
 
__pycache__/visualization.cpython-313.pyc ADDED
Binary file (11.5 kB). View file
 
app.py ADDED
@@ -0,0 +1,356 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Tufts Jumbo Weather Forecast β€” Deep Learning Demo
3
+
4
+ Usage:
5
+ cd demo && python app.py
6
+ """
7
+
8
+ import logging
9
+ from datetime import timedelta
10
+
11
+ import gradio as gr
12
+
13
+ from hrrr_fetch import fetch_hrrr_input
14
+ from model_utils import run_forecast, load_model, AVAILABLE_MODELS
15
+ from visualization import (
16
+ get_static_maps,
17
+ plot_temperature,
18
+ plot_temperature_placeholder,
19
+ )
20
+
21
+ logging.basicConfig(level=logging.INFO)
22
+ logger = logging.getLogger(__name__)
23
+
24
+ # ── CSS ───────────────────────────────────────────────────────────────
25
+
26
+ CUSTOM_CSS = """
27
+ @import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap');
28
+
29
+ :root {
30
+ --font: -apple-system, BlinkMacSystemFont, "SF Pro Display",
31
+ "SF Pro Text", Inter, "Helvetica Neue", Arial, sans-serif;
32
+ --bg: #F2F2F7;
33
+ --card: #FFFFFF;
34
+ --border: #E5E5EA;
35
+ --text: #1D1D1F;
36
+ --muted: #86868B;
37
+ --accent: #0A84FF;
38
+ --dark: #1C1C1E;
39
+ }
40
+ * { font-family: var(--font) !important; }
41
+
42
+ .gradio-container {
43
+ max-width: 1320px !important;
44
+ margin: 0 auto !important;
45
+ background: var(--bg) !important;
46
+ padding-bottom: 24px !important;
47
+ }
48
+
49
+ /* ── Top bar ── */
50
+ .top-bar {
51
+ background: linear-gradient(135deg, #1C1C1E 0%, #2C2C2E 100%);
52
+ border-radius: 16px;
53
+ padding: 28px 36px;
54
+ margin-bottom: 16px;
55
+ display: flex;
56
+ justify-content: space-between;
57
+ align-items: center;
58
+ }
59
+ .top-bar .title {
60
+ font-size: 24px; font-weight: 700;
61
+ color: #F5F5F7; letter-spacing: -0.3px;
62
+ }
63
+ .top-bar .subtitle {
64
+ font-size: 13px; color: #98989D;
65
+ margin-top: 2px;
66
+ }
67
+ .top-bar .location {
68
+ text-align: right;
69
+ font-size: 13px; color: #98989D;
70
+ line-height: 1.6;
71
+ }
72
+ .top-bar .location b {
73
+ color: #F5F5F7; font-weight: 600;
74
+ }
75
+
76
+ /* ── Hero card ── */
77
+ .hero-card {
78
+ background: var(--card);
79
+ border-radius: 16px;
80
+ border: 1px solid var(--border);
81
+ box-shadow: 0 2px 8px rgba(0,0,0,0.04);
82
+ padding: 32px 36px 28px;
83
+ margin-bottom: 16px;
84
+ }
85
+ .hero-main {
86
+ display: flex;
87
+ align-items: baseline;
88
+ gap: 20px;
89
+ margin-bottom: 4px;
90
+ }
91
+ .hero-temp {
92
+ font-size: 64px; font-weight: 300;
93
+ color: var(--text); letter-spacing: -2px;
94
+ line-height: 1;
95
+ }
96
+ .hero-temp-unit {
97
+ font-size: 28px; font-weight: 400;
98
+ color: var(--muted); margin-left: 2px;
99
+ }
100
+ .hero-status {
101
+ font-size: 20px; font-weight: 500;
102
+ color: var(--text); padding-left: 8px;
103
+ border-left: 3px solid var(--accent);
104
+ }
105
+ .hero-metrics {
106
+ display: flex;
107
+ gap: 12px;
108
+ margin: 20px 0 18px;
109
+ }
110
+ .metric-tile {
111
+ flex: 1;
112
+ background: var(--bg);
113
+ border-radius: 12px;
114
+ padding: 14px 16px;
115
+ text-align: center;
116
+ }
117
+ .metric-value {
118
+ font-size: 22px; font-weight: 600;
119
+ color: var(--text); line-height: 1.2;
120
+ }
121
+ .metric-label {
122
+ font-size: 12px; font-weight: 500;
123
+ color: var(--muted);
124
+ text-transform: uppercase;
125
+ letter-spacing: 0.5px;
126
+ margin-top: 4px;
127
+ }
128
+ .hero-meta {
129
+ font-size: 13px; color: var(--muted);
130
+ line-height: 1.6;
131
+ }
132
+ .hero-meta code {
133
+ background: var(--bg); padding: 2px 6px;
134
+ border-radius: 4px; font-size: 12px;
135
+ }
136
+ .hero-placeholder {
137
+ text-align: center;
138
+ padding: 36px 0;
139
+ color: var(--muted);
140
+ font-size: 16px; font-weight: 500;
141
+ }
142
+
143
+ /* ── Map section ── */
144
+ .maps-heading {
145
+ font-size: 11px; font-weight: 600;
146
+ text-transform: uppercase; letter-spacing: 0.8px;
147
+ color: var(--muted);
148
+ margin: 8px 0 8px 4px;
149
+ }
150
+
151
+ .map-cell {
152
+ background: var(--card) !important;
153
+ border-radius: 14px !important;
154
+ border: 1px solid var(--border) !important;
155
+ box-shadow: 0 1px 4px rgba(0,0,0,0.04) !important;
156
+ overflow: hidden !important;
157
+ min-height: 380px !important;
158
+ }
159
+
160
+ /* ── Controls inside hero ── */
161
+ .controls-row {
162
+ display: flex; align-items: end; gap: 10px;
163
+ margin-top: 18px; padding-top: 16px;
164
+ border-top: 1px solid var(--border);
165
+ }
166
+
167
+ /* ── Status ── */
168
+ .status-text p, .status-text em {
169
+ font-size: 12px !important; color: var(--muted) !important;
170
+ }
171
+
172
+ /* ── About ── */
173
+ .about-section {
174
+ font-size: 13px !important; color: #6E6E73 !important;
175
+ line-height: 1.65 !important;
176
+ }
177
+
178
+ /* ── Button ── */
179
+ button.primary {
180
+ background: var(--accent) !important;
181
+ border: none !important; border-radius: 10px !important;
182
+ font-weight: 600 !important; font-size: 15px !important;
183
+ padding: 10px 28px !important;
184
+ }
185
+ button.primary:hover { background: #0A74E0 !important; }
186
+ """
187
+
188
+ # ── Helpers ────────────────────────────────────────────────────────────
189
+
190
+ model_choices = [
191
+ f"{v['display_name']} ({v['params']})" for v in AVAILABLE_MODELS.values()
192
+ ]
193
+ model_keys = list(AVAILABLE_MODELS.keys())
194
+
195
+
196
+ def _resolve_model(display: str) -> str:
197
+ return model_keys[model_choices.index(display)]
198
+
199
+
200
+ def _hero_placeholder() -> str:
201
+ return (
202
+ '<div class="hero-card">'
203
+ '<div class="hero-placeholder">'
204
+ "Click <b>Run Forecast</b> to fetch real-time HRRR data and generate a 24-hour prediction."
205
+ "</div></div>"
206
+ )
207
+
208
+
209
+ def _hero_html(r: dict, cycle_str: str, forecast_str: str, model_label: str) -> str:
210
+ return (
211
+ '<div class="hero-card">'
212
+ # temperature + status
213
+ '<div class="hero-main">'
214
+ f'<div><span class="hero-temp">{r["temperature_c"]:.1f}</span>'
215
+ f'<span class="hero-temp-unit">Β°C</span></div>'
216
+ f'<div class="hero-status">{r["rain_status"]}</div>'
217
+ "</div>"
218
+ # metric tiles
219
+ '<div class="hero-metrics">'
220
+ f'<div class="metric-tile"><div class="metric-value">{r["temperature_f"]:.0f}Β°F</div>'
221
+ '<div class="metric-label">Temperature</div></div>'
222
+ f'<div class="metric-tile"><div class="metric-value">{r["humidity_pct"]:.0f}%</div>'
223
+ '<div class="metric-label">Humidity</div></div>'
224
+ f'<div class="metric-tile"><div class="metric-value">{r["wind_speed_ms"]:.1f}</div>'
225
+ f'<div class="metric-label">Wind m/s {r["wind_dir_str"]}</div></div>'
226
+ f'<div class="metric-tile"><div class="metric-value">{r["gust_ms"]:.1f}</div>'
227
+ '<div class="metric-label">Gust m/s</div></div>'
228
+ f'<div class="metric-tile"><div class="metric-value">{r["precipitation_mm"]:.2f}</div>'
229
+ '<div class="metric-label">Precip mm</div></div>'
230
+ "</div>"
231
+ # meta line
232
+ '<div class="hero-meta">'
233
+ f"Based on &ensp;<code>{cycle_str}</code> &ensp; "
234
+ f"Forecast valid &ensp;<b>{forecast_str}</b> &ensp; "
235
+ f"Model &ensp;<b>{model_label}</b>"
236
+ "</div>"
237
+ "</div>"
238
+ )
239
+
240
+
241
+ # ── Main callback ──────────────────────────────────────────────────────
242
+
243
+ def do_forecast(model_display: str, progress=gr.Progress()):
244
+ model_name = _resolve_model(model_display)
245
+
246
+ progress(0.02, desc="Finding latest HRRR cycle...")
247
+ try:
248
+ input_array, cycle_time = fetch_hrrr_input(
249
+ progress_callback=lambda f, m: progress(f, desc=m),
250
+ )
251
+ except Exception as e:
252
+ raise gr.Error(f"HRRR fetch failed: {e}")
253
+
254
+ cycle_str = cycle_time.strftime("%Y-%m-%d %H:%M UTC")
255
+ forecast_time = cycle_time + timedelta(hours=24)
256
+ forecast_str = forecast_time.strftime("%Y-%m-%d %H:%M UTC")
257
+
258
+ progress(0.95, desc="Running model inference...")
259
+ try:
260
+ r = run_forecast(model_name, input_array)
261
+ except Exception as e:
262
+ raise gr.Error(f"Inference failed: {e}")
263
+
264
+ model_label = model_display.split("(")[0].strip()
265
+ hero = _hero_html(r, cycle_str, forecast_str, model_label)
266
+ temp_fig = plot_temperature(input_array, r, cycle_str, forecast_str)
267
+ status = f"Forecast complete β€” HRRR cycle {cycle_str}"
268
+
269
+ return hero, temp_fig, status
270
+
271
+
272
+ # ── Build UI ──────────────────────────────────────────────────────────
273
+
274
+ # Pre-render static maps at import time
275
+ logger.info("Rendering static basemaps...")
276
+ _sat_fig, _street_fig = get_static_maps()
277
+ _temp_placeholder = plot_temperature_placeholder()
278
+ logger.info("Basemaps ready.")
279
+
280
+ with gr.Blocks(title="Tufts Jumbo Weather Forecast") as demo:
281
+
282
+ # ── Top bar ───────────────────────────────────────────────────
283
+ gr.HTML(
284
+ '<div class="top-bar">'
285
+ '<div>'
286
+ '<div class="title">Tufts Jumbo Weather</div>'
287
+ '<div class="subtitle">Real-time deep-learning forecast</div>'
288
+ "</div>"
289
+ '<div class="location">'
290
+ "<b>Medford, MA</b><br>"
291
+ "42.41Β°N &ensp; 71.12Β°W"
292
+ "</div>"
293
+ "</div>"
294
+ )
295
+
296
+ # ── Hero card ─────────────────────────────────────────────────
297
+ hero_html = gr.HTML(_hero_placeholder())
298
+
299
+ # ── Controls ──────────────────────────────────────────────────
300
+ with gr.Row(elem_classes=["controls-row"]):
301
+ model_dd = gr.Dropdown(
302
+ choices=model_choices, value=model_choices[0],
303
+ label="Model", scale=3,
304
+ )
305
+ run_btn = gr.Button("Run Forecast", variant="primary", scale=1)
306
+
307
+ status_bar = gr.Markdown(
308
+ "_Ready β€” click **Run Forecast**._",
309
+ elem_classes=["status-text"],
310
+ )
311
+
312
+ # ── Maps ──────────────────────────────────────────────────────
313
+ gr.HTML('<div class="maps-heading">Coverage Maps β€” 1 350 km Γ— 1 350 km &ensp; 3 km resolution</div>')
314
+
315
+ with gr.Row(equal_height=True):
316
+ sat_plot = gr.Plot(
317
+ value=_sat_fig, label="Satellite",
318
+ elem_classes=["map-cell"],
319
+ )
320
+ street_plot = gr.Plot(
321
+ value=_street_fig, label="Reference Map",
322
+ elem_classes=["map-cell"],
323
+ )
324
+ temp_plot = gr.Plot(
325
+ value=_temp_placeholder, label="Temperature",
326
+ elem_classes=["map-cell"],
327
+ )
328
+
329
+ # ── About ─────────────────────────────────────────────────────
330
+ with gr.Accordion("About this demo", open=False):
331
+ gr.Markdown(
332
+ "**Data** &ensp; HRRR 3 km analysis from NOAA (AWS S3, via Herbie). "
333
+ "42 atmospheric channels covering the US Northeast.\n\n"
334
+ "**Models** &ensp; CNN Baseline (11.3 M params) Β· ResNet-18 (11.2 M params) β€” "
335
+ "predict 6 weather variables 24 h ahead for a single target point.\n\n"
336
+ "**Course** &ensp; Tufts CS 137 β€” Deep Neural Networks, Spring 2026",
337
+ elem_classes=["about-section"],
338
+ )
339
+
340
+ # ── Callbacks ─────────────────────────────────────────────────
341
+ run_btn.click(
342
+ fn=do_forecast,
343
+ inputs=[model_dd],
344
+ outputs=[hero_html, temp_plot, status_bar],
345
+ )
346
+
347
+
348
+ if __name__ == "__main__":
349
+ logger.info("Pre-loading default model...")
350
+ try:
351
+ load_model(model_keys[0])
352
+ logger.info("Model loaded.")
353
+ except Exception as e:
354
+ logger.warning(f"Pre-load failed: {e}")
355
+
356
+ demo.launch(share=False, css=CUSTOM_CSS)
checkpoints/cnn_baseline.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9978778103ca02974eec82bd676a58639b0e13fa893656bd622153a44ec68603
3
+ size 45379961
checkpoints/resnet18.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9d74d4559da02319248276d3d616cda17766f422350dfce61de041c03b9cfff0
3
+ size 45288199
hrrr_fetch.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Fetch real-time HRRR analysis data from NOAA AWS S3 via Herbie.
3
+
4
+ Downloads 42 individual GRIB2 messages (one per input channel),
5
+ extracts the New England subgrid (450x449), and stacks them into
6
+ the model's expected input format.
7
+ """
8
+
9
+ import logging
10
+ from datetime import datetime, timedelta, timezone
11
+
12
+ import numpy as np
13
+
14
+ from var_mapping import HRRR_MAPPING, NE_Y_SLICE, NE_X_SLICE
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
+ def find_latest_hrrr_cycle(max_lookback_hours: int = 6) -> datetime:
20
+ """
21
+ Find the most recent HRRR cycle that is available on AWS S3.
22
+ HRRR data typically becomes available ~45-90 minutes after valid time.
23
+
24
+ Returns a tz-naive datetime in UTC (Herbie requirement).
25
+ """
26
+ from herbie import Herbie
27
+
28
+ now = datetime.now(timezone.utc).replace(tzinfo=None) # tz-naive UTC
29
+ for hours_ago in range(2, max_lookback_hours + 1):
30
+ cycle_time = (now - timedelta(hours=hours_ago)).replace(
31
+ minute=0, second=0, microsecond=0
32
+ )
33
+ try:
34
+ H = Herbie(
35
+ cycle_time,
36
+ model="hrrr",
37
+ product="sfc",
38
+ fxx=0,
39
+ verbose=False,
40
+ )
41
+ # Check if the index file exists (fast check without downloading data)
42
+ if H.idx is not None:
43
+ logger.info(f"Found HRRR cycle: {cycle_time:%Y-%m-%d %H:%M UTC}")
44
+ return cycle_time
45
+ except Exception:
46
+ continue
47
+
48
+ raise RuntimeError(
49
+ f"No HRRR data available in the last {max_lookback_hours} hours. "
50
+ "NOAA servers may be temporarily unavailable."
51
+ )
52
+
53
+
54
+ def _fetch_single_variable(cycle_time: datetime, mapping: dict) -> np.ndarray:
55
+ """
56
+ Fetch one variable from HRRR and extract the NE subgrid.
57
+
58
+ Returns:
59
+ np.ndarray of shape (450, 449)
60
+ """
61
+ from herbie import Herbie
62
+
63
+ H = Herbie(
64
+ cycle_time,
65
+ model="hrrr",
66
+ product=mapping["product"],
67
+ fxx=mapping["fxx"],
68
+ verbose=False,
69
+ )
70
+
71
+ ds = H.xarray(mapping["search"], remove_grib=True)
72
+
73
+ # xarray dataset may have different variable names depending on the GRIB message.
74
+ # Get the first data variable (excluding coordinates).
75
+ data_vars = [v for v in ds.data_vars if v not in ("latitude", "longitude", "gribfile_projection")]
76
+ if not data_vars:
77
+ raise ValueError(f"No data variable found for {mapping['name']}")
78
+
79
+ field = ds[data_vars[0]].values # Full CONUS grid
80
+
81
+ # Extract NE subgrid
82
+ subgrid = field[NE_Y_SLICE, NE_X_SLICE]
83
+
84
+ if subgrid.shape != (450, 449):
85
+ raise ValueError(
86
+ f"Unexpected shape {subgrid.shape} for {mapping['name']}, expected (450, 449)"
87
+ )
88
+
89
+ return subgrid.astype(np.float32)
90
+
91
+
92
+ def fetch_hrrr_input(
93
+ cycle_time: datetime = None,
94
+ progress_callback=None,
95
+ ) -> tuple[np.ndarray, datetime]:
96
+ """
97
+ Fetch all 42 HRRR channels and stack into model input format.
98
+
99
+ Args:
100
+ cycle_time: Specific HRRR cycle to fetch. If None, finds the latest.
101
+ progress_callback: Optional callable(fraction, description) for progress updates.
102
+
103
+ Returns:
104
+ (input_array, cycle_time) where input_array is (450, 449, 42) float32.
105
+ """
106
+ if cycle_time is None:
107
+ if progress_callback:
108
+ progress_callback(0.05, "Finding latest HRRR cycle...")
109
+ cycle_time = find_latest_hrrr_cycle()
110
+
111
+ channels = []
112
+ n_channels = len(HRRR_MAPPING)
113
+ failed = []
114
+
115
+ for i, mapping in enumerate(HRRR_MAPPING):
116
+ if progress_callback:
117
+ frac = 0.1 + 0.85 * (i / n_channels)
118
+ progress_callback(frac, f"Fetching {mapping['name']} ({i+1}/{n_channels})...")
119
+
120
+ try:
121
+ field = _fetch_single_variable(cycle_time, mapping)
122
+ channels.append(field)
123
+ except Exception as e:
124
+ logger.warning(f"Failed to fetch {mapping['name']}: {e}")
125
+ failed.append(mapping["name"])
126
+ # Fill with zeros as fallback for individual missing channels
127
+ channels.append(np.zeros((450, 449), dtype=np.float32))
128
+
129
+ if failed:
130
+ logger.warning(f"Failed channels ({len(failed)}/{n_channels}): {failed}")
131
+ if len(failed) > n_channels // 2:
132
+ raise RuntimeError(
133
+ f"Too many channels failed ({len(failed)}/{n_channels}). "
134
+ "HRRR data may be unavailable."
135
+ )
136
+
137
+ input_array = np.stack(channels, axis=-1) # (450, 449, 42)
138
+
139
+ if progress_callback:
140
+ progress_callback(1.0, "Data fetch complete!")
141
+
142
+ return input_array, cycle_time
model_utils.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Model loading and inference utilities for the weather forecast demo.
3
+
4
+ Wraps the existing inference/predict.py logic, adding user-friendly
5
+ post-processing (Celsius, wind speed/direction, rain likelihood).
6
+ """
7
+
8
+ import math
9
+ import sys
10
+ from pathlib import Path
11
+
12
+ import numpy as np
13
+ import torch
14
+
15
+ # In HF Space, models/ is in the same directory as this file
16
+ PROJECT_ROOT = Path(__file__).resolve().parent
17
+ sys.path.insert(0, str(PROJECT_ROOT))
18
+
19
+ from models import create_model, get_model_defaults
20
+
21
+ # ── Model cache (loaded once, reused across requests) ──────────────────
22
+ _model_cache: dict = {}
23
+
24
+
25
+ TARGET_VARS = [
26
+ ("TMP@2m_above_ground", "Temperature (2m)", "K"),
27
+ ("RH@2m_above_ground", "Relative Humidity", "%"),
28
+ ("UGRD@10m_above_ground", "U-Wind (10m)", "m/s"),
29
+ ("VGRD@10m_above_ground", "V-Wind (10m)", "m/s"),
30
+ ("GUST@surface", "Wind Gust", "m/s"),
31
+ ("APCP_1hr_acc_fcst@surface", "Precipitation (1hr)", "mm"),
32
+ ]
33
+
34
+ # Available models with display info
35
+ AVAILABLE_MODELS = {
36
+ "cnn_baseline": {
37
+ "display_name": "CNN Baseline",
38
+ "checkpoint": "checkpoints/cnn_baseline.pt",
39
+ "params": "11.3M",
40
+ },
41
+ "resnet18": {
42
+ "display_name": "ResNet-18",
43
+ "checkpoint": "checkpoints/resnet18.pt",
44
+ "params": "11.2M",
45
+ },
46
+ }
47
+
48
+
49
+ def load_model(model_name: str, device: str = "cpu"):
50
+ """
51
+ Load a trained model from checkpoint. Caches in memory for reuse.
52
+
53
+ Returns:
54
+ (model, norm_stats) tuple
55
+ """
56
+ if model_name in _model_cache:
57
+ return _model_cache[model_name]
58
+
59
+ ckpt_path = PROJECT_ROOT / AVAILABLE_MODELS[model_name]["checkpoint"]
60
+ if not ckpt_path.exists():
61
+ raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
62
+
63
+ ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
64
+ args = ckpt["args"]
65
+ ckpt_model_name = args["model"]
66
+
67
+ defaults = get_model_defaults(ckpt_model_name)
68
+ n_frames = args.get("n_frames") or defaults["n_frames"]
69
+
70
+ model_kwargs = {
71
+ "n_input_channels": 42,
72
+ "n_targets": 6,
73
+ "base_channels": args.get("base_channels", 64),
74
+ }
75
+ if n_frames > 1:
76
+ model_kwargs["n_frames"] = n_frames
77
+
78
+ model = create_model(ckpt_model_name, **model_kwargs)
79
+ model.load_state_dict(ckpt["model"])
80
+ model.to(device).eval()
81
+
82
+ norm_stats = ckpt.get("norm_stats")
83
+
84
+ _model_cache[model_name] = (model, norm_stats)
85
+ return model, norm_stats
86
+
87
+
88
+ def predict_raw(model, norm_stats, input_array: np.ndarray, device: str = "cpu") -> np.ndarray:
89
+ """
90
+ Run inference on a (450, 449, 42) input array.
91
+
92
+ Returns:
93
+ np.ndarray of shape (6,) with denormalized physical values.
94
+ """
95
+ x = torch.from_numpy(input_array).float()
96
+ x = x.permute(2, 0, 1).unsqueeze(0) # (1, 42, 450, 449)
97
+
98
+ if norm_stats:
99
+ mean = norm_stats["input_mean"]
100
+ std = norm_stats["input_std"]
101
+ # Ensure correct device
102
+ if isinstance(mean, torch.Tensor):
103
+ mean = mean.float()
104
+ std = std.float()
105
+ x = (x - mean) / (std + 1e-7)
106
+
107
+ x = x.to(device)
108
+ with torch.no_grad():
109
+ pred = model(x).squeeze(0).cpu() # (6,)
110
+
111
+ if norm_stats:
112
+ target_mean = norm_stats["target_mean"]
113
+ target_std = norm_stats["target_std"]
114
+ if isinstance(target_mean, torch.Tensor):
115
+ target_mean = target_mean.float()
116
+ target_std = target_std.float()
117
+ pred = pred * target_std + target_mean
118
+
119
+ return pred.numpy()
120
+
121
+
122
+ def _wind_direction_str(degrees: float) -> str:
123
+ """Convert wind direction in degrees to compass string."""
124
+ dirs = ["N", "NNE", "NE", "ENE", "E", "ESE", "SE", "SSE",
125
+ "S", "SSW", "SW", "WSW", "W", "WNW", "NW", "NNW"]
126
+ idx = round(degrees / 22.5) % 16
127
+ return dirs[idx]
128
+
129
+
130
+ def format_forecast(pred: np.ndarray) -> dict:
131
+ """
132
+ Convert raw model output (6 physical values) into a user-friendly forecast dict.
133
+ """
134
+ temp_k = float(pred[0])
135
+ rh = float(pred[1])
136
+ u_wind = float(pred[2])
137
+ v_wind = float(pred[3])
138
+ gust = float(pred[4])
139
+ apcp = float(pred[5])
140
+
141
+ # Derived quantities
142
+ temp_c = temp_k - 273.15
143
+ temp_f = temp_c * 9 / 5 + 32
144
+ wind_speed = math.sqrt(u_wind**2 + v_wind**2)
145
+ # Meteorological wind direction: direction FROM which wind blows
146
+ wind_dir_deg = (math.degrees(math.atan2(-u_wind, -v_wind)) + 360) % 360
147
+ wind_dir_str = _wind_direction_str(wind_dir_deg)
148
+
149
+ # Rain likelihood based on APCP threshold
150
+ apcp = max(apcp, 0.0) # Clamp negative predictions
151
+ if apcp > 5.0:
152
+ rain_str = "Heavy Rain Likely"
153
+ elif apcp > 2.0:
154
+ rain_str = "Rain Likely"
155
+ elif apcp > 0.5:
156
+ rain_str = "Light Rain Possible"
157
+ else:
158
+ rain_str = "No Rain Expected"
159
+
160
+ return {
161
+ "temperature_k": temp_k,
162
+ "temperature_c": temp_c,
163
+ "temperature_f": temp_f,
164
+ "humidity_pct": max(0.0, min(100.0, rh)),
165
+ "u_wind_ms": u_wind,
166
+ "v_wind_ms": v_wind,
167
+ "wind_speed_ms": wind_speed,
168
+ "wind_dir_deg": wind_dir_deg,
169
+ "wind_dir_str": wind_dir_str,
170
+ "gust_ms": max(gust, 0.0),
171
+ "precipitation_mm": apcp,
172
+ "rain_status": rain_str,
173
+ }
174
+
175
+
176
+ def run_forecast(model_name: str, input_array: np.ndarray, device: str = "cpu") -> dict:
177
+ """Full pipeline: load model β†’ predict β†’ format results."""
178
+ model, norm_stats = load_model(model_name, device)
179
+ pred = predict_raw(model, norm_stats, input_array, device)
180
+ return format_forecast(pred)
models/__init__.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Model registry for weather forecasting architectures.
3
+
4
+ Usage:
5
+ from models import create_model, MODEL_REGISTRY
6
+ model = create_model("cnn_baseline", n_input_channels=42)
7
+ """
8
+
9
+ from .cnn_baseline import BaselineCNN
10
+ from .cnn_multi_frame import MultiFrameCNN
11
+ from .cnn_3d import CNN3D
12
+ from .vit import WeatherViT
13
+ from .resnet_baseline import ResNet18Baseline
14
+ from .convnext_baseline import ConvNeXtBaseline
15
+
16
+ MODEL_REGISTRY = {
17
+ "cnn_baseline": BaselineCNN,
18
+ "cnn_multi_frame": MultiFrameCNN,
19
+ "cnn_3d": CNN3D,
20
+ "vit": WeatherViT,
21
+ "resnet18": ResNet18Baseline,
22
+ "convnext_tiny": ConvNeXtBaseline,
23
+ }
24
+
25
+ # Default model-specific settings
26
+ MODEL_DEFAULTS = {
27
+ "cnn_baseline": {"n_frames": 1, "stack_mode": "channel"},
28
+ "cnn_multi_frame": {"n_frames": 4, "stack_mode": "channel"},
29
+ "cnn_3d": {"n_frames": 4, "stack_mode": "temporal"},
30
+ "vit": {"n_frames": 1, "stack_mode": "channel"},
31
+ "resnet18": {"n_frames": 1, "stack_mode": "channel"},
32
+ "convnext_tiny": {"n_frames": 1, "stack_mode": "channel"},
33
+ }
34
+
35
+
36
+ def create_model(name, **kwargs):
37
+ """Instantiate a model by name with given kwargs."""
38
+ if name not in MODEL_REGISTRY:
39
+ raise ValueError(f"Unknown model: {name}. Available: {list(MODEL_REGISTRY.keys())}")
40
+ return MODEL_REGISTRY[name](**kwargs)
41
+
42
+
43
+ def get_model_defaults(name):
44
+ """Return default n_frames and stack_mode for a model."""
45
+ return MODEL_DEFAULTS.get(name, {"n_frames": 1, "stack_mode": "channel"})
models/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (1.63 kB). View file
 
models/__pycache__/cnn_3d.cpython-313.pyc ADDED
Binary file (5.48 kB). View file
 
models/__pycache__/cnn_baseline.cpython-313.pyc ADDED
Binary file (4.87 kB). View file
 
models/__pycache__/cnn_multi_frame.cpython-313.pyc ADDED
Binary file (3.51 kB). View file
 
models/__pycache__/convnext_baseline.cpython-313.pyc ADDED
Binary file (1.91 kB). View file
 
models/__pycache__/resnet_baseline.cpython-313.pyc ADDED
Binary file (1.84 kB). View file
 
models/__pycache__/vit.cpython-313.pyc ADDED
Binary file (6.42 kB). View file
 
models/cnn_3d.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 3D CNN for spatiotemporal weather forecasting.
3
+
4
+ Uses Conv3d to jointly model spatial and temporal patterns
5
+ from multiple consecutive weather snapshots.
6
+
7
+ Input: (B, k, C, H, W) β€” k frames, each with C channels
8
+ Output: (B, 6)
9
+ """
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+
14
+
15
+ class ResBlock3D(nn.Module):
16
+ """3D residual block with separate temporal and spatial convolutions."""
17
+
18
+ def __init__(self, in_ch, out_ch, stride_spatial=1, stride_temporal=1):
19
+ super().__init__()
20
+ stride = (stride_temporal, stride_spatial, stride_spatial)
21
+
22
+ self.conv1 = nn.Conv3d(in_ch, out_ch, 3, stride=stride, padding=1, bias=False)
23
+ self.bn1 = nn.BatchNorm3d(out_ch)
24
+ self.conv2 = nn.Conv3d(out_ch, out_ch, 3, stride=1, padding=1, bias=False)
25
+ self.bn2 = nn.BatchNorm3d(out_ch)
26
+ self.relu = nn.ReLU(inplace=True)
27
+
28
+ self.shortcut = nn.Identity()
29
+ if any(s != 1 for s in stride) or in_ch != out_ch:
30
+ self.shortcut = nn.Sequential(
31
+ nn.Conv3d(in_ch, out_ch, 1, stride=stride, bias=False),
32
+ nn.BatchNorm3d(out_ch),
33
+ )
34
+
35
+ def forward(self, x):
36
+ out = self.relu(self.bn1(self.conv1(x)))
37
+ out = self.bn2(self.conv2(out))
38
+ out = self.relu(out + self.shortcut(x))
39
+ return out
40
+
41
+
42
+ class CNN3D(nn.Module):
43
+ """
44
+ 3D CNN for spatiotemporal weather forecasting.
45
+
46
+ Input: (B, k, C, H, W) where k=n_frames, C=n_input_channels
47
+ Output: (B, 6)
48
+
49
+ The temporal dimension is collapsed early (by stride-2 temporal convolutions
50
+ and pooling), while spatial dimensions are progressively downsampled.
51
+ """
52
+
53
+ def __init__(self, n_input_channels=42, n_targets=6, n_frames=4, base_channels=64):
54
+ super().__init__()
55
+ ch = base_channels
56
+
57
+ # (B, k, C, H, W) -> (B, C, k, H, W) is done in forward()
58
+ self.stem = nn.Sequential(
59
+ nn.Conv3d(n_input_channels, ch, kernel_size=(3, 7, 7),
60
+ stride=(1, 2, 2), padding=(1, 3, 3), bias=False),
61
+ nn.BatchNorm3d(ch),
62
+ nn.ReLU(inplace=True),
63
+ )
64
+ # spatial: 450x449 -> 225x225, temporal: k -> k
65
+
66
+ self.layer1 = ResBlock3D(ch, ch, stride_spatial=1, stride_temporal=1)
67
+ # Collapse temporal dimension
68
+ self.layer2 = ResBlock3D(ch, ch * 2, stride_spatial=2, stride_temporal=2)
69
+ # temporal: k -> k//2, spatial: 225 -> 113
70
+ self.layer3 = ResBlock3D(ch * 2, ch * 4, stride_spatial=2, stride_temporal=2)
71
+ # temporal: k//2 -> 1 (for k=4), spatial: 113 -> 57
72
+ self.layer4 = ResBlock3D(ch * 4, ch * 4, stride_spatial=2, stride_temporal=1)
73
+ # spatial: 57 -> 29
74
+ self.layer5 = ResBlock3D(ch * 4, ch * 8, stride_spatial=2, stride_temporal=1)
75
+ # spatial: 29 -> 15
76
+ self.layer6 = ResBlock3D(ch * 8, ch * 8, stride_spatial=2, stride_temporal=1)
77
+ # spatial: 15 -> 8
78
+
79
+ self.pool = nn.AdaptiveAvgPool3d(1)
80
+ self.head = nn.Sequential(
81
+ nn.Flatten(),
82
+ nn.Linear(ch * 8, ch * 2),
83
+ nn.ReLU(inplace=True),
84
+ nn.Dropout(0.3),
85
+ nn.Linear(ch * 2, n_targets),
86
+ )
87
+
88
+ def forward(self, x):
89
+ # x: (B, k, C, H, W) -> (B, C, k, H, W) for Conv3d
90
+ x = x.permute(0, 2, 1, 3, 4)
91
+ x = self.stem(x)
92
+ x = self.layer1(x)
93
+ x = self.layer2(x)
94
+ x = self.layer3(x)
95
+ x = self.layer4(x)
96
+ x = self.layer5(x)
97
+ x = self.layer6(x)
98
+ x = self.pool(x)
99
+ return self.head(x)
models/cnn_baseline.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Baseline 2D CNN for single-frame weather forecasting.
3
+
4
+ Takes a single spatial snapshot (C, H, W) and predicts 6 weather variables 24h ahead.
5
+ Architecture: progressive downsampling with residual blocks, global average pooling, FC head.
6
+ """
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+
12
+ class ResBlock(nn.Module):
13
+ """Residual block with two conv layers and optional downsampling."""
14
+
15
+ def __init__(self, in_ch, out_ch, stride=1):
16
+ super().__init__()
17
+ self.conv1 = nn.Conv2d(in_ch, out_ch, 3, stride=stride, padding=1, bias=False)
18
+ self.bn1 = nn.BatchNorm2d(out_ch)
19
+ self.conv2 = nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1, bias=False)
20
+ self.bn2 = nn.BatchNorm2d(out_ch)
21
+ self.relu = nn.ReLU(inplace=True)
22
+
23
+ self.shortcut = nn.Identity()
24
+ if stride != 1 or in_ch != out_ch:
25
+ self.shortcut = nn.Sequential(
26
+ nn.Conv2d(in_ch, out_ch, 1, stride=stride, bias=False),
27
+ nn.BatchNorm2d(out_ch),
28
+ )
29
+
30
+ def forward(self, x):
31
+ out = self.relu(self.bn1(self.conv1(x)))
32
+ out = self.bn2(self.conv2(out))
33
+ out = self.relu(out + self.shortcut(x))
34
+ return out
35
+
36
+
37
+ class BaselineCNN(nn.Module):
38
+ """
39
+ Single-frame 2D CNN encoder for weather forecasting.
40
+
41
+ Input: (B, C, 450, 449) where C = n_input_channels (42)
42
+ Output: (B, 6) β€” predicted weather variables
43
+ """
44
+
45
+ def __init__(self, n_input_channels=42, n_targets=6, base_channels=64):
46
+ super().__init__()
47
+
48
+ ch = base_channels
49
+ self.stem = nn.Sequential(
50
+ nn.Conv2d(n_input_channels, ch, 7, stride=2, padding=3, bias=False),
51
+ nn.BatchNorm2d(ch),
52
+ nn.ReLU(inplace=True),
53
+ )
54
+ # 450x449 -> 225x225
55
+
56
+ self.layer1 = ResBlock(ch, ch, stride=1) # 225x225
57
+ self.layer2 = ResBlock(ch, ch * 2, stride=2) # 113x113
58
+ self.layer3 = ResBlock(ch * 2, ch * 4, stride=2) # 57x57
59
+ self.layer4 = ResBlock(ch * 4, ch * 4, stride=2) # 29x29
60
+ self.layer5 = ResBlock(ch * 4, ch * 8, stride=2) # 15x15
61
+ self.layer6 = ResBlock(ch * 8, ch * 8, stride=2) # 8x8
62
+
63
+ self.pool = nn.AdaptiveAvgPool2d(1)
64
+ self.head = nn.Sequential(
65
+ nn.Flatten(),
66
+ nn.Linear(ch * 8, ch * 2),
67
+ nn.ReLU(inplace=True),
68
+ nn.Dropout(0.3),
69
+ nn.Linear(ch * 2, n_targets),
70
+ )
71
+
72
+ def forward(self, x):
73
+ x = self.stem(x)
74
+ x = self.layer1(x)
75
+ x = self.layer2(x)
76
+ x = self.layer3(x)
77
+ x = self.layer4(x)
78
+ x = self.layer5(x)
79
+ x = self.layer6(x)
80
+ x = self.pool(x)
81
+ return self.head(x)
models/cnn_multi_frame.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Multi-frame 2D CNN for weather forecasting.
3
+
4
+ Stacks k consecutive spatial snapshots along the channel dimension,
5
+ allowing the model to learn temporal patterns with standard 2D convolutions.
6
+
7
+ Input: (B, k*C, H, W)
8
+ Output: (B, 6)
9
+ """
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ from .cnn_baseline import ResBlock
14
+
15
+
16
+ class MultiFrameCNN(nn.Module):
17
+ """
18
+ Multi-frame 2D CNN that concatenates consecutive frames along channels.
19
+
20
+ Identical backbone to BaselineCNN but with an adapted stem for k*C input channels,
21
+ plus a temporal mixing layer after the stem.
22
+ """
23
+
24
+ def __init__(self, n_input_channels=42, n_targets=6, n_frames=4, base_channels=64):
25
+ super().__init__()
26
+ self.n_frames = n_frames
27
+ in_ch = n_input_channels * n_frames
28
+ ch = base_channels
29
+
30
+ self.stem = nn.Sequential(
31
+ nn.Conv2d(in_ch, ch * 2, 7, stride=2, padding=3, bias=False),
32
+ nn.BatchNorm2d(ch * 2),
33
+ nn.ReLU(inplace=True),
34
+ nn.Conv2d(ch * 2, ch, 1, bias=False),
35
+ nn.BatchNorm2d(ch),
36
+ nn.ReLU(inplace=True),
37
+ )
38
+
39
+ self.layer1 = ResBlock(ch, ch, stride=1)
40
+ self.layer2 = ResBlock(ch, ch * 2, stride=2)
41
+ self.layer3 = ResBlock(ch * 2, ch * 4, stride=2)
42
+ self.layer4 = ResBlock(ch * 4, ch * 4, stride=2)
43
+ self.layer5 = ResBlock(ch * 4, ch * 8, stride=2)
44
+ self.layer6 = ResBlock(ch * 8, ch * 8, stride=2)
45
+
46
+ self.pool = nn.AdaptiveAvgPool2d(1)
47
+ self.head = nn.Sequential(
48
+ nn.Flatten(),
49
+ nn.Linear(ch * 8, ch * 2),
50
+ nn.ReLU(inplace=True),
51
+ nn.Dropout(0.3),
52
+ nn.Linear(ch * 2, n_targets),
53
+ )
54
+
55
+ def forward(self, x):
56
+ x = self.stem(x)
57
+ x = self.layer1(x)
58
+ x = self.layer2(x)
59
+ x = self.layer3(x)
60
+ x = self.layer4(x)
61
+ x = self.layer5(x)
62
+ x = self.layer6(x)
63
+ x = self.pool(x)
64
+ return self.head(x)
models/convnext_baseline.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ConvNeXt-Tiny baseline for weather forecasting.
3
+
4
+ Uses torchvision's ConvNeXt-Tiny with modified input/output layers
5
+ to accept 42-channel weather data and predict 6 target variables.
6
+ Trained from scratch (no pretrained weights).
7
+
8
+ Input: (B, 42, 450, 449)
9
+ Output: (B, 6)
10
+ """
11
+
12
+ import torch.nn as nn
13
+ from torchvision.models import convnext_tiny
14
+
15
+
16
+ class ConvNeXtBaseline(nn.Module):
17
+ """
18
+ ConvNeXt-Tiny adapted for weather forecasting.
19
+
20
+ Modifications from ImageNet ConvNeXt-Tiny:
21
+ - Stem conv: 3 β†’ 42 input channels
22
+ - Classifier head: 1000 β†’ n_targets outputs
23
+ - No pretrained weights (trained from scratch)
24
+ """
25
+
26
+ def __init__(self, n_input_channels=42, n_targets=6, **kwargs):
27
+ super().__init__()
28
+ model = convnext_tiny(weights=None)
29
+
30
+ # Replace stem conv: 3ch β†’ 42ch
31
+ # ConvNeXt stem: features[0][0] is Conv2d(3, 96, 4, stride=4)
32
+ model.features[0][0] = nn.Conv2d(
33
+ n_input_channels, 96, kernel_size=4, stride=4
34
+ )
35
+
36
+ # Replace classifier head: 1000 β†’ n_targets
37
+ # ConvNeXt head: classifier = Sequential(LayerNorm, Flatten, Linear(768, 1000))
38
+ model.classifier[2] = nn.Linear(768, n_targets)
39
+
40
+ self.model = model
41
+
42
+ def forward(self, x):
43
+ return self.model(x)
models/resnet_baseline.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ResNet-18 baseline for weather forecasting.
3
+
4
+ Uses torchvision's ResNet-18 with modified input/output layers
5
+ to accept 42-channel weather data and predict 6 target variables.
6
+ Trained from scratch (no pretrained weights).
7
+
8
+ Input: (B, 42, 450, 449)
9
+ Output: (B, 6)
10
+ """
11
+
12
+ import torch.nn as nn
13
+ from torchvision.models import resnet18
14
+
15
+
16
+ class ResNet18Baseline(nn.Module):
17
+ """
18
+ Standard ResNet-18 adapted for weather forecasting.
19
+
20
+ Modifications from ImageNet ResNet-18:
21
+ - conv1: 3 β†’ 42 input channels
22
+ - fc: 1000 β†’ n_targets output classes
23
+ - No pretrained weights (trained from scratch)
24
+ """
25
+
26
+ def __init__(self, n_input_channels=42, n_targets=6, **kwargs):
27
+ super().__init__()
28
+ model = resnet18(weights=None)
29
+
30
+ # Replace first conv: 3ch β†’ 42ch
31
+ model.conv1 = nn.Conv2d(
32
+ n_input_channels, 64, kernel_size=7, stride=2, padding=3, bias=False
33
+ )
34
+
35
+ # Replace classifier head: 1000 β†’ n_targets
36
+ model.fc = nn.Linear(512, n_targets)
37
+
38
+ self.model = model
39
+
40
+ def forward(self, x):
41
+ return self.model(x)
models/vit.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Vision Transformer (ViT) for weather forecasting.
3
+
4
+ Splits the spatial input into non-overlapping patches, projects each patch
5
+ into an embedding, and processes them through a Transformer encoder.
6
+
7
+ Input: (B, C, H, W) β€” single frame with C channels
8
+ Output: (B, 6)
9
+ """
10
+
11
+ import math
12
+ import torch
13
+ import torch.nn as nn
14
+
15
+
16
+ class PatchEmbedding(nn.Module):
17
+ """Convert spatial input into a sequence of patch embeddings."""
18
+
19
+ def __init__(self, in_channels, embed_dim, patch_size, img_h, img_w):
20
+ super().__init__()
21
+ self.patch_size = patch_size
22
+ self.n_patches_h = img_h // patch_size
23
+ self.n_patches_w = img_w // patch_size
24
+ self.n_patches = self.n_patches_h * self.n_patches_w
25
+
26
+ self.proj = nn.Conv2d(in_channels, embed_dim,
27
+ kernel_size=patch_size, stride=patch_size)
28
+
29
+ def forward(self, x):
30
+ # x: (B, C, H, W) -> (B, embed_dim, nH, nW) -> (B, n_patches, embed_dim)
31
+ x = self.proj(x)
32
+ x = x.flatten(2).transpose(1, 2)
33
+ return x
34
+
35
+
36
+ class TransformerBlock(nn.Module):
37
+ """Standard Transformer encoder block with pre-norm."""
38
+
39
+ def __init__(self, embed_dim, n_heads, mlp_ratio=4.0, dropout=0.1):
40
+ super().__init__()
41
+ self.norm1 = nn.LayerNorm(embed_dim)
42
+ self.attn = nn.MultiheadAttention(embed_dim, n_heads,
43
+ dropout=dropout, batch_first=True)
44
+ self.norm2 = nn.LayerNorm(embed_dim)
45
+ self.mlp = nn.Sequential(
46
+ nn.Linear(embed_dim, int(embed_dim * mlp_ratio)),
47
+ nn.GELU(),
48
+ nn.Dropout(dropout),
49
+ nn.Linear(int(embed_dim * mlp_ratio), embed_dim),
50
+ nn.Dropout(dropout),
51
+ )
52
+
53
+ def forward(self, x):
54
+ h = self.norm1(x)
55
+ h, _ = self.attn(h, h, h)
56
+ x = x + h
57
+ x = x + self.mlp(self.norm2(x))
58
+ return x
59
+
60
+
61
+ class WeatherViT(nn.Module):
62
+ """
63
+ Vision Transformer for weather forecasting.
64
+
65
+ Input: (B, C, 450, 449) β€” pads width to 450 internally
66
+ Output: (B, 6)
67
+
68
+ Patches the input into 15x15 patches (30x30 = 900 tokens),
69
+ adds CLS token and positional embeddings, runs through Transformer,
70
+ and predicts from the CLS token.
71
+ """
72
+
73
+ def __init__(self, n_input_channels=42, n_targets=6, patch_size=15,
74
+ embed_dim=256, n_layers=6, n_heads=8, mlp_ratio=4.0, dropout=0.1,
75
+ **kwargs):
76
+ super().__init__()
77
+ self.patch_size = patch_size
78
+ img_h, img_w = 450, 450 # pad to square
79
+
80
+ self.patch_embed = PatchEmbedding(n_input_channels, embed_dim,
81
+ patch_size, img_h, img_w)
82
+ n_patches = self.patch_embed.n_patches
83
+
84
+ self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim) * 0.02)
85
+ self.pos_embed = nn.Parameter(torch.randn(1, n_patches + 1, embed_dim) * 0.02)
86
+ self.pos_drop = nn.Dropout(dropout)
87
+
88
+ self.blocks = nn.Sequential(*[
89
+ TransformerBlock(embed_dim, n_heads, mlp_ratio, dropout)
90
+ for _ in range(n_layers)
91
+ ])
92
+ self.norm = nn.LayerNorm(embed_dim)
93
+
94
+ self.head = nn.Sequential(
95
+ nn.Linear(embed_dim, embed_dim // 2),
96
+ nn.ReLU(inplace=True),
97
+ nn.Dropout(0.3),
98
+ nn.Linear(embed_dim // 2, n_targets),
99
+ )
100
+
101
+ def forward(self, x):
102
+ B, C, H, W = x.shape
103
+ # Pad width from 449 to 450 if needed
104
+ if W < 450:
105
+ x = nn.functional.pad(x, (0, 450 - W))
106
+
107
+ patches = self.patch_embed(x) # (B, n_patches, D)
108
+ cls = self.cls_token.expand(B, -1, -1) # (B, 1, D)
109
+ x = torch.cat([cls, patches], dim=1) # (B, n_patches+1, D)
110
+ x = self.pos_drop(x + self.pos_embed)
111
+
112
+ x = self.blocks(x)
113
+ x = self.norm(x)
114
+
115
+ cls_out = x[:, 0] # (B, D)
116
+ return self.head(cls_out)
packages.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ libeccodes-dev
2
+ libeccodes-tools
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ torch>=2.1.0
2
+ torchvision
3
+ numpy
4
+ matplotlib
5
+ gradio>=4.0
6
+ herbie-data>=2024.1
7
+ cfgrib
8
+ xarray
9
+ eccodes
10
+ cartopy
var_mapping.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Mapping from 42-channel VAR_LEVELS to HRRR GRIB2 Herbie search strings.
3
+
4
+ Each entry corresponds to one channel in the model input, following the exact
5
+ order defined in data_preparation/data_spec.py (VAR_LEVELS = TARGET_VARS + ATMOS_VARS).
6
+ """
7
+
8
+ # fmt: off
9
+ HRRR_MAPPING = [
10
+ # ── Target / surface variables (channels 0-6) ──────────────────────
11
+ {"name": "TMP@2m_above_ground", "search": ":TMP:2 m above ground", "product": "sfc", "fxx": 0},
12
+ {"name": "RH@2m_above_ground", "search": ":RH:2 m above ground", "product": "sfc", "fxx": 0},
13
+ {"name": "UGRD@10m_above_ground", "search": ":UGRD:10 m above ground", "product": "sfc", "fxx": 0},
14
+ {"name": "VGRD@10m_above_ground", "search": ":VGRD:10 m above ground", "product": "sfc", "fxx": 0},
15
+ {"name": "GUST@surface", "search": ":GUST:surface", "product": "sfc", "fxx": 0},
16
+ {"name": "DSWRF@surface", "search": ":DSWRF:surface", "product": "sfc", "fxx": 0},
17
+ # APCP is accumulated; not in analysis (fxx=0). Use 1-hour forecast from same cycle.
18
+ {"name": "APCP_1hr_acc_fcst@surface", "search": ":APCP:surface:0-1 hour acc fcst", "product": "sfc", "fxx": 1},
19
+
20
+ # ── Atmospheric variables (channels 7-41) ──────────────────────────
21
+ # CAPE
22
+ {"name": "CAPE@surface", "search": ":CAPE:surface", "product": "sfc", "fxx": 0},
23
+
24
+ # Dew point temperature at pressure levels
25
+ {"name": "DPT@1000mb", "search": ":DPT:1000 mb", "product": "prs", "fxx": 0},
26
+ {"name": "DPT@500mb", "search": ":DPT:500 mb", "product": "prs", "fxx": 0},
27
+ {"name": "DPT@700mb", "search": ":DPT:700 mb", "product": "prs", "fxx": 0},
28
+ {"name": "DPT@850mb", "search": ":DPT:850 mb", "product": "prs", "fxx": 0},
29
+ {"name": "DPT@925mb", "search": ":DPT:925 mb", "product": "prs", "fxx": 0},
30
+
31
+ # Geopotential height at pressure levels + surface
32
+ {"name": "HGT@1000mb", "search": ":HGT:1000 mb", "product": "prs", "fxx": 0},
33
+ {"name": "HGT@500mb", "search": ":HGT:500 mb", "product": "prs", "fxx": 0},
34
+ {"name": "HGT@700mb", "search": ":HGT:700 mb", "product": "prs", "fxx": 0},
35
+ {"name": "HGT@850mb", "search": ":HGT:850 mb", "product": "prs", "fxx": 0},
36
+ {"name": "HGT@surface", "search": ":HGT:surface", "product": "sfc", "fxx": 0},
37
+
38
+ # Temperature at pressure levels
39
+ {"name": "TMP@1000mb", "search": ":TMP:1000 mb", "product": "prs", "fxx": 0},
40
+ {"name": "TMP@500mb", "search": ":TMP:500 mb", "product": "prs", "fxx": 0},
41
+ {"name": "TMP@700mb", "search": ":TMP:700 mb", "product": "prs", "fxx": 0},
42
+ {"name": "TMP@850mb", "search": ":TMP:850 mb", "product": "prs", "fxx": 0},
43
+ {"name": "TMP@925mb", "search": ":TMP:925 mb", "product": "prs", "fxx": 0},
44
+
45
+ # U-component wind at pressure levels
46
+ {"name": "UGRD@1000mb", "search": ":UGRD:1000 mb", "product": "prs", "fxx": 0},
47
+ {"name": "UGRD@250mb", "search": ":UGRD:250 mb", "product": "prs", "fxx": 0},
48
+ {"name": "UGRD@500mb", "search": ":UGRD:500 mb", "product": "prs", "fxx": 0},
49
+ {"name": "UGRD@700mb", "search": ":UGRD:700 mb", "product": "prs", "fxx": 0},
50
+ {"name": "UGRD@850mb", "search": ":UGRD:850 mb", "product": "prs", "fxx": 0},
51
+ {"name": "UGRD@925mb", "search": ":UGRD:925 mb", "product": "prs", "fxx": 0},
52
+
53
+ # V-component wind at pressure levels
54
+ {"name": "VGRD@1000mb", "search": ":VGRD:1000 mb", "product": "prs", "fxx": 0},
55
+ {"name": "VGRD@250mb", "search": ":VGRD:250 mb", "product": "prs", "fxx": 0},
56
+ {"name": "VGRD@500mb", "search": ":VGRD:500 mb", "product": "prs", "fxx": 0},
57
+ {"name": "VGRD@700mb", "search": ":VGRD:700 mb", "product": "prs", "fxx": 0},
58
+ {"name": "VGRD@850mb", "search": ":VGRD:850 mb", "product": "prs", "fxx": 0},
59
+ {"name": "VGRD@925mb", "search": ":VGRD:925 mb", "product": "prs", "fxx": 0},
60
+
61
+ # Cloud cover
62
+ {"name": "TCDC@entire_atmosphere", "search": ":TCDC:entire atmosphere", "product": "sfc", "fxx": 0},
63
+ {"name": "HCDC@high_cloud_layer", "search": ":HCDC:high cloud layer", "product": "sfc", "fxx": 0},
64
+ {"name": "MCDC@middle_cloud_layer", "search": ":MCDC:middle cloud layer", "product": "sfc", "fxx": 0},
65
+ {"name": "LCDC@low_cloud_layer", "search": ":LCDC:low cloud layer", "product": "sfc", "fxx": 0},
66
+
67
+ # Moisture
68
+ {"name": "PWAT@entire_atmosphere_single_layer", "search": ":PWAT:entire atmosphere", "product": "sfc", "fxx": 0},
69
+ {"name": "RHPW@entire_atmosphere", "search": ":RHPW:entire atmosphere", "product": "sfc", "fxx": 0},
70
+ {"name": "VIL@entire_atmosphere", "search": ":VIL:entire atmosphere", "product": "sfc", "fxx": 0},
71
+ ]
72
+ # fmt: on
73
+
74
+ assert len(HRRR_MAPPING) == 42, f"Expected 42 channels, got {len(HRRR_MAPPING)}"
75
+
76
+ # New England subgrid slice in the full HRRR CONUS grid (from data_spec.py)
77
+ NE_Y_SLICE = slice(600, 1050) # 450 rows
78
+ NE_X_SLICE = slice(1350, 1799) # 449 columns
79
+
80
+ # Jumbo Statue grid location within the NE subgrid
81
+ JUMBO_ROW = 177
82
+ JUMBO_COL = 263
visualization.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Forecast visualization β€” satellite, street, and temperature maps.
3
+
4
+ Satellite and reference maps are static (rendered once at startup).
5
+ Temperature map updates each time a forecast is run.
6
+ """
7
+
8
+ import logging
9
+ import numpy as np
10
+ import matplotlib
11
+ matplotlib.use("Agg")
12
+ import matplotlib.pyplot as plt
13
+ from matplotlib.figure import Figure
14
+
15
+ import cartopy.crs as ccrs
16
+ import cartopy.feature as cfeature
17
+ import cartopy.io.img_tiles as cimgt
18
+
19
+ from var_mapping import JUMBO_ROW, JUMBO_COL
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+ # ── Projection & coordinates (from data_spec.py) ──────────────────────
24
+
25
+ PROJ = ccrs.LambertConformal(
26
+ central_longitude=262.5,
27
+ central_latitude=38.5,
28
+ standard_parallels=(38.5, 38.5),
29
+ globe=ccrs.Globe(semimajor_axis=6371229, semiminor_axis=6371229),
30
+ )
31
+
32
+ _x = 1352479.8574780696 + np.arange(449) * 3000
33
+ _y = 212693.8474433364 + np.arange(450) * 3000
34
+ EXTENT = [_x[0], _x[-1], _y[0], _y[-1]]
35
+
36
+ JUMBO_LON, JUMBO_LAT = -71.1204, 42.4078
37
+ X_GRID, Y_GRID = np.meshgrid(_x, _y)
38
+
39
+ CITIES = [
40
+ ("Boston", 42.36, -71.06),
41
+ ("Providence", 41.82, -71.41),
42
+ ("Hartford", 41.76, -72.68),
43
+ ("Portland", 43.66, -70.26),
44
+ ("Burlington", 44.48, -73.21),
45
+ ("Concord", 43.21, -71.54),
46
+ ("Albany", 42.65, -73.76),
47
+ ("New York", 40.71, -74.01),
48
+ ("Montreal", 45.50, -73.57),
49
+ ]
50
+
51
+
52
+ # ── Tile sources ──────────────────────────────────────────────────────
53
+
54
+ class _EsriSatellite(cimgt.GoogleWTS):
55
+ def _image_url(self, tile):
56
+ x, y, z = tile
57
+ return (
58
+ "https://server.arcgisonline.com/ArcGIS/rest/services/"
59
+ f"World_Imagery/MapServer/tile/{z}/{y}/{x}"
60
+ )
61
+
62
+
63
+ class _EsriStreetMap(cimgt.GoogleWTS):
64
+ def _image_url(self, tile):
65
+ x, y, z = tile
66
+ return (
67
+ "https://server.arcgisonline.com/ArcGIS/rest/services/"
68
+ f"World_Street_Map/MapServer/tile/{z}/{y}/{x}"
69
+ )
70
+
71
+
72
+ # ── Shared style ─────────────────────────────────────────────────────
73
+
74
+ _MARKER = dict(
75
+ marker="*", color="#FF3B30", markersize=14,
76
+ markeredgecolor="white", markeredgewidth=1.0,
77
+ transform=ccrs.PlateCarree(), zorder=20,
78
+ )
79
+
80
+ _TAG = dict(
81
+ fontsize=8.5, fontweight="bold", fontfamily="sans-serif",
82
+ color="white", transform=ccrs.PlateCarree(), zorder=25,
83
+ bbox=dict(boxstyle="round,pad=0.25", fc="#1C1C1E", ec="none", alpha=0.80),
84
+ )
85
+
86
+
87
+ def _make_ax(fig_or_ax=None, figsize=(7.2, 6.8)):
88
+ """Create a single GeoAxes with consistent extent."""
89
+ if fig_or_ax is None:
90
+ fig, ax = plt.subplots(
91
+ figsize=figsize, subplot_kw={"projection": PROJ},
92
+ )
93
+ else:
94
+ fig, ax = fig_or_ax.figure, fig_or_ax
95
+ ax.set_extent(EXTENT, crs=PROJ)
96
+ return fig, ax
97
+
98
+
99
+ # ── Individual map renderers ─────────────────────────────────────────
100
+
101
+ def plot_satellite() -> Figure:
102
+ """Render satellite basemap (static, no weather data needed)."""
103
+ fig, ax = _make_ax()
104
+ try:
105
+ ax.add_image(_EsriSatellite(), 7)
106
+ except Exception:
107
+ ax.add_feature(cfeature.LAND, facecolor="#5C4A32")
108
+ ax.add_feature(cfeature.OCEAN, facecolor="#1B3A4B")
109
+ ax.add_feature(cfeature.COASTLINE, linewidth=0.6, color="white")
110
+ ax.add_feature(cfeature.STATES, linewidth=0.3, edgecolor="#aaa")
111
+ ax.plot(JUMBO_LON, JUMBO_LAT, **_MARKER)
112
+ ax.text(JUMBO_LON + 0.35, JUMBO_LAT + 0.25, "Jumbo", **_TAG)
113
+ ax.set_title(
114
+ "Satellite", fontsize=12, fontweight="600",
115
+ fontfamily="sans-serif", pad=8, color="#1D1D1F",
116
+ )
117
+ fig.subplots_adjust(left=0.02, right=0.98, bottom=0.02, top=0.93)
118
+ return fig
119
+
120
+
121
+ def plot_street() -> Figure:
122
+ """Render street / reference basemap (static)."""
123
+ fig, ax = _make_ax()
124
+ try:
125
+ ax.add_image(_EsriStreetMap(), 7)
126
+ except Exception:
127
+ ax.add_feature(cfeature.LAND, facecolor="#E8E4D8")
128
+ ax.add_feature(cfeature.OCEAN, facecolor="#AAD3DF")
129
+ ax.add_feature(cfeature.LAKES, facecolor="#AAD3DF", edgecolor="#888", linewidth=0.3)
130
+ ax.add_feature(cfeature.RIVERS, edgecolor="#AAD3DF", linewidth=0.4)
131
+ ax.add_feature(cfeature.COASTLINE, linewidth=0.5)
132
+ ax.add_feature(cfeature.BORDERS, linewidth=0.5, linestyle="--")
133
+ ax.add_feature(cfeature.STATES, linewidth=0.3, edgecolor="#888")
134
+ pc = ccrs.PlateCarree()
135
+ for name, lat, lon in CITIES:
136
+ ax.text(
137
+ lon, lat, name,
138
+ fontsize=7, fontfamily="sans-serif", fontweight="500",
139
+ color="#333", transform=pc, zorder=15,
140
+ bbox=dict(boxstyle="round,pad=0.15", fc="white", ec="none", alpha=0.7),
141
+ )
142
+ ax.plot(JUMBO_LON, JUMBO_LAT, **_MARKER)
143
+ ax.text(JUMBO_LON + 0.35, JUMBO_LAT + 0.25, "Jumbo", **_TAG)
144
+ ax.set_title(
145
+ "Reference Map", fontsize=12, fontweight="600",
146
+ fontfamily="sans-serif", pad=8, color="#1D1D1F",
147
+ )
148
+ fig.subplots_adjust(left=0.02, right=0.98, bottom=0.02, top=0.93)
149
+ return fig
150
+
151
+
152
+ def plot_temperature(
153
+ input_array: np.ndarray,
154
+ forecast: dict,
155
+ cycle_str: str,
156
+ forecast_str: str,
157
+ ) -> Figure:
158
+ """Render 2 m temperature map with forecast annotation."""
159
+ fig, ax = _make_ax()
160
+
161
+ temp_field = input_array[:, :, 0] - 273.15
162
+ masked = np.ma.masked_invalid(temp_field)
163
+
164
+ im = ax.pcolormesh(
165
+ X_GRID, Y_GRID, masked,
166
+ cmap="RdYlBu_r", shading="auto", transform=PROJ, zorder=5,
167
+ )
168
+ ax.add_feature(cfeature.COASTLINE, linewidth=0.5, color="#444", zorder=10)
169
+ ax.add_feature(cfeature.STATES, linewidth=0.3, edgecolor="#666", zorder=10)
170
+
171
+ cbar = fig.colorbar(im, ax=ax, shrink=0.72, pad=0.03, aspect=28)
172
+ cbar.set_label("Β°C", fontsize=10, fontfamily="sans-serif")
173
+ cbar.ax.tick_params(labelsize=8)
174
+
175
+ ax.plot(JUMBO_LON, JUMBO_LAT, **_MARKER)
176
+
177
+ temp_c = forecast["temperature_c"]
178
+ temp_f = forecast["temperature_f"]
179
+ label = f"Forecast {forecast_str}\n{temp_c:+.1f} Β°C / {temp_f:.0f} Β°F"
180
+ ax.text(
181
+ JUMBO_LON + 0.45, JUMBO_LAT + 0.35, label,
182
+ fontsize=8.5, fontweight="bold", fontfamily="sans-serif",
183
+ color="white", transform=ccrs.PlateCarree(), zorder=25,
184
+ bbox=dict(boxstyle="round,pad=0.35", fc="#1C1C1E", ec="white", alpha=0.88, lw=0.8),
185
+ )
186
+
187
+ ax.set_title(
188
+ f"2 m Temperature β€” {cycle_str}",
189
+ fontsize=12, fontweight="600",
190
+ fontfamily="sans-serif", pad=8, color="#1D1D1F",
191
+ )
192
+ fig.subplots_adjust(left=0.02, right=0.95, bottom=0.02, top=0.93)
193
+ return fig
194
+
195
+
196
+ def plot_temperature_placeholder() -> Figure:
197
+ """Empty temperature panel shown before first forecast."""
198
+ fig, ax = _make_ax()
199
+ ax.add_feature(cfeature.LAND, facecolor="#E8E4D8", zorder=1)
200
+ ax.add_feature(cfeature.OCEAN, facecolor="#D6EAF0", zorder=1)
201
+ ax.add_feature(cfeature.COASTLINE, linewidth=0.5, color="#999", zorder=2)
202
+ ax.add_feature(cfeature.STATES, linewidth=0.3, edgecolor="#bbb", zorder=2)
203
+ ax.plot(JUMBO_LON, JUMBO_LAT, **_MARKER)
204
+ ax.text(
205
+ JUMBO_LON + 0.35, JUMBO_LAT + 0.25, "Jumbo", **_TAG,
206
+ )
207
+ ax.text(
208
+ 0.5, 0.50, "Click Run Forecast",
209
+ transform=ax.transAxes, ha="center", va="center",
210
+ fontsize=14, fontweight="600", fontfamily="sans-serif",
211
+ color="#86868B",
212
+ )
213
+ ax.set_title(
214
+ "2 m Temperature", fontsize=12, fontweight="600",
215
+ fontfamily="sans-serif", pad=8, color="#1D1D1F",
216
+ )
217
+ fig.subplots_adjust(left=0.02, right=0.98, bottom=0.02, top=0.93)
218
+ return fig
219
+
220
+
221
+ # ── Startup cache ─────────────────────────────────────────────────────
222
+
223
+ _cache = {}
224
+
225
+
226
+ def get_static_maps() -> tuple[Figure, Figure]:
227
+ """Return cached satellite and street map figures (rendered once)."""
228
+ if "satellite" not in _cache:
229
+ logger.info("Rendering satellite basemap...")
230
+ _cache["satellite"] = plot_satellite()
231
+ if "street" not in _cache:
232
+ logger.info("Rendering reference basemap...")
233
+ _cache["street"] = plot_street()
234
+ return _cache["satellite"], _cache["street"]