tanishq74 commited on
Commit
232b144
Β·
verified Β·
1 Parent(s): 449bcc1

Add retinasense_v2.py

Browse files
Files changed (1) hide show
  1. retinasense_v2.py +567 -0
retinasense_v2.py ADDED
@@ -0,0 +1,567 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ RetinaSense v2 β€” Production-Grade Training Pipeline
4
+ ====================================================
5
+ Fixes from v1:
6
+ 1. Focal Loss (handles class imbalance far better than weighted CE)
7
+ 2. Stratified batch sampler (every batch sees all classes)
8
+ 3. LR warmup + cosine decay (stable optimisation)
9
+ 4. Gradient accumulation (effective batch 128, actual batch 32)
10
+ 5. Early stopping on Macro F1 (not accuracy β€” misleading with imbalance)
11
+ 6. Per-class metrics tracked every epoch
12
+ 7. Pre-cached preprocessing (GPU efficiency)
13
+ 8. Proper NaN handling in mixed precision
14
+ 9. Comprehensive plots after training
15
+ """
16
+
17
+ import os, sys, time, warnings, json
18
+ import numpy as np
19
+ import pandas as pd
20
+ import cv2
21
+ import matplotlib
22
+ matplotlib.use('Agg')
23
+ import matplotlib.pyplot as plt
24
+ import seaborn as sns
25
+ from PIL import Image
26
+ from tqdm import tqdm
27
+ from collections import Counter
28
+ warnings.filterwarnings('ignore')
29
+
30
+ import torch
31
+ import torch.nn as nn
32
+ import torch.nn.functional as F
33
+ from torch.amp import autocast, GradScaler
34
+ from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
35
+ from torchvision import models, transforms
36
+
37
+ from sklearn.model_selection import train_test_split
38
+ from sklearn.utils.class_weight import compute_class_weight
39
+ from sklearn.metrics import (
40
+ classification_report, confusion_matrix,
41
+ roc_auc_score, f1_score, roc_curve, auc
42
+ )
43
+ from sklearn.preprocessing import label_binarize
44
+
45
+ # ═══════════════════════════════════════════════════════════
46
+ # CONFIG
47
+ # ═══════════════════════════════════════════════════════════
48
+ SAVE_DIR = './outputs_v2'
49
+ CACHE_DIR = './preprocessed_cache'
50
+ os.makedirs(SAVE_DIR, exist_ok=True)
51
+ os.makedirs(CACHE_DIR, exist_ok=True)
52
+
53
+ EPOCHS = 20
54
+ WARMUP_EPOCHS = 3 # heads-only warmup
55
+ LR_WARMUP_STEPS = 3 # linear warmup epochs after unfreeze
56
+ BATCH_SIZE = 32 # actual batch size (stable)
57
+ ACCUM_STEPS = 2 # gradient accumulation β†’ effective batch 64
58
+ NUM_WORKERS = 8
59
+ PATIENCE = 7 # early stopping on macro-F1
60
+ FOCAL_GAMMA = 1.0 # reduced from 2.0 β€” less aggressive
61
+ IMG_SIZE = 300 # EfficientNet-B3 optimal input
62
+
63
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
64
+ CLASS_NAMES = ['Normal', 'Diabetes/DR', 'Glaucoma', 'Cataract', 'AMD']
65
+ NUM_CLASSES = len(CLASS_NAMES)
66
+
67
+ print('='*65)
68
+ print(' RetinaSense v2 β€” Production Pipeline')
69
+ print('='*65)
70
+ if torch.cuda.is_available():
71
+ print(f' GPU : {torch.cuda.get_device_name(0)}')
72
+ print(f' VRAM : {round(torch.cuda.get_device_properties(0).total_memory/1e9,1)} GB')
73
+ print(f' Epochs : {EPOCHS}')
74
+ print(f' Batch : {BATCH_SIZE} (effective {BATCH_SIZE*ACCUM_STEPS} via grad accum)')
75
+ print(f' Image Size : {IMG_SIZE}')
76
+ print(f' Focal Loss Ξ³: {FOCAL_GAMMA} (mild β€” avoids over-correction)')
77
+ print(f' Early Stop : patience={PATIENCE} on macro-F1')
78
+ print('='*65)
79
+
80
+ # ═══════════════════════════════════════════════════════════
81
+ # 1 METADATA
82
+ # ═══════════════════════════════════════════════════════════
83
+ print('\n[1/7] Building metadata...')
84
+ BASE = './'
85
+ disease_cols = ['N','D','G','C','A']
86
+ label_map = {'N':0,'D':1,'G':2,'C':3,'A':4}
87
+
88
+ df_odir = pd.read_csv(f'{BASE}/odir/full_df.csv')
89
+ df_odir['disease_count'] = df_odir[disease_cols].sum(axis=1)
90
+ df_odir = df_odir[df_odir['disease_count']==1].copy()
91
+ def get_label(row):
92
+ for d in disease_cols:
93
+ if row[d]==1: return label_map[d]
94
+ df_odir['disease_label'] = df_odir.apply(get_label, axis=1)
95
+
96
+ img_col = next(c for c in df_odir.columns
97
+ if any(k in c.lower() for k in ['filename','fundus','image']))
98
+
99
+ odir_meta = pd.DataFrame({
100
+ 'image_path': f'{BASE}/odir/preprocessed_images/'+df_odir[img_col].astype(str),
101
+ 'dataset': 'ODIR',
102
+ 'disease_label': df_odir['disease_label'],
103
+ 'severity_label':-1
104
+ })
105
+
106
+ df_aptos = pd.read_csv(f'{BASE}/aptos/train.csv')
107
+ aptos_meta = pd.DataFrame({
108
+ 'image_path': f'{BASE}/aptos/train_images/'+df_aptos['id_code']+'.png',
109
+ 'dataset': 'APTOS',
110
+ 'disease_label': 1,
111
+ 'severity_label':df_aptos['diagnosis']
112
+ })
113
+
114
+ meta = pd.concat([odir_meta, aptos_meta], ignore_index=True)
115
+ meta = meta[meta['image_path'].apply(os.path.exists)].reset_index(drop=True)
116
+ print(f' Total samples: {len(meta)}')
117
+ dist = meta['disease_label'].value_counts().sort_index()
118
+ for i,cnt in dist.items():
119
+ print(f' {CLASS_NAMES[i]:15s}: {cnt:4d} ({100*cnt/len(meta):.1f}%)')
120
+
121
+ # ═══════════════════════════════════════════════════════════
122
+ # 2 PRE-CACHE
123
+ # ═══════════════════════════════════════════════════════════
124
+ print(f'\n[2/7] Pre-caching @ {IMG_SIZE}Γ—{IMG_SIZE}...')
125
+
126
+ def ben_graham(path, sz=IMG_SIZE, sigma=10):
127
+ img = cv2.imread(path)
128
+ if img is None:
129
+ img = np.array(Image.open(path).convert('RGB'))
130
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
131
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
132
+ img = cv2.resize(img, (sz, sz))
133
+ img = cv2.addWeighted(img, 4, cv2.GaussianBlur(img,(0,0),sigma), -4, 128)
134
+ mask = np.zeros(img.shape[:2], dtype=np.uint8)
135
+ cv2.circle(mask, (sz//2, sz//2), int(sz*0.48), 255, -1)
136
+ return cv2.bitwise_and(img, img, mask=mask)
137
+
138
+ cache_paths = []
139
+ cached = 0
140
+ for _, row in tqdm(meta.iterrows(), total=len(meta), desc='Caching'):
141
+ stem = os.path.splitext(os.path.basename(row['image_path']))[0]
142
+ fp = f'{CACHE_DIR}/{stem}_{IMG_SIZE}.npy'
143
+ if not os.path.exists(fp):
144
+ try:
145
+ np.save(fp, ben_graham(row['image_path']))
146
+ except Exception:
147
+ np.save(fp, np.zeros((IMG_SIZE,IMG_SIZE,3), dtype=np.uint8))
148
+ cached += 1
149
+ cache_paths.append(fp)
150
+ meta['cache_path'] = cache_paths
151
+ print(f' Newly cached: {cached} | Already cached: {len(meta)-cached}')
152
+
153
+ # ═══════════════════════════════════════════════════════════
154
+ # 3 DATASET + LOADERS
155
+ # ═══════════════════════════════════════════════════════════
156
+ print('\n[3/7] Creating data loaders...')
157
+
158
+ train_df, val_df = train_test_split(
159
+ meta, test_size=0.2, stratify=meta['disease_label'], random_state=42)
160
+
161
+ def make_transforms(phase):
162
+ if phase == 'train':
163
+ return transforms.Compose([
164
+ transforms.ToPILImage(),
165
+ transforms.RandomHorizontalFlip(),
166
+ transforms.RandomVerticalFlip(p=0.3),
167
+ transforms.RandomRotation(20),
168
+ transforms.RandomAffine(degrees=0, translate=(0.05,0.05), scale=(0.95,1.05)),
169
+ transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2, hue=0.02),
170
+ transforms.ToTensor(),
171
+ transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
172
+ transforms.RandomErasing(p=0.2),
173
+ ])
174
+ return transforms.Compose([
175
+ transforms.ToPILImage(),
176
+ transforms.ToTensor(),
177
+ transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
178
+ ])
179
+
180
+ class RetDS(Dataset):
181
+ def __init__(self, df, tfm):
182
+ self.df = df.reset_index(drop=True)
183
+ self.tfm = tfm
184
+ def __len__(self): return len(self.df)
185
+ def __getitem__(self, i):
186
+ r = self.df.iloc[i]
187
+ try: img = np.load(r['cache_path'])
188
+ except: img = np.zeros((IMG_SIZE,IMG_SIZE,3), dtype=np.uint8)
189
+ return (self.tfm(img),
190
+ torch.tensor(int(r['disease_label']), dtype=torch.long),
191
+ torch.tensor(int(r['severity_label']), dtype=torch.long))
192
+
193
+ train_ds = RetDS(train_df, make_transforms('train'))
194
+ val_ds = RetDS(val_df, make_transforms('val'))
195
+
196
+ # Use shuffle (not WeightedRandomSampler β€” that over-corrects with focal loss)
197
+ train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
198
+ num_workers=NUM_WORKERS, pin_memory=True,
199
+ persistent_workers=True, prefetch_factor=2)
200
+ val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False,
201
+ num_workers=NUM_WORKERS, pin_memory=True,
202
+ persistent_workers=True)
203
+
204
+ print(f' Train : {len(train_ds):5d} ({len(train_loader):3d} batches)')
205
+ print(f' Val : {len(val_ds):5d} ({len(val_loader):3d} batches)')
206
+ print(f' ⚑ Focal Loss + class weights handle imbalance (no oversampling)')
207
+
208
+ # ═══════════════════════════════════════════════════════════
209
+ # 4 MODEL + FOCAL LOSS
210
+ # ═══════════════════════════════════════════════════════════
211
+ print('\n[4/7] Building model...')
212
+
213
+ class FocalLoss(nn.Module):
214
+ """Focal Loss β€” down-weights easy examples, focuses on hard ones."""
215
+ def __init__(self, alpha=None, gamma=2.0):
216
+ super().__init__()
217
+ self.gamma = gamma
218
+ if alpha is not None:
219
+ self.register_buffer('alpha', alpha)
220
+ else:
221
+ self.alpha = None
222
+
223
+ def forward(self, logits, targets):
224
+ ce = F.cross_entropy(logits, targets, reduction='none')
225
+ pt = torch.exp(-ce)
226
+ focal = ((1 - pt) ** self.gamma) * ce
227
+ if self.alpha is not None:
228
+ at = self.alpha.gather(0, targets)
229
+ focal = at * focal
230
+ return focal.mean()
231
+
232
+
233
+ class MultiTaskModel(nn.Module):
234
+ def __init__(self, n_disease=5, n_severity=5, drop=0.4):
235
+ super().__init__()
236
+ bb = models.efficientnet_b3(weights='IMAGENET1K_V1')
237
+ self.backbone = nn.Sequential(*list(bb.children())[:-1])
238
+ feat = 1536
239
+ self.drop = nn.Dropout(drop)
240
+ self.disease_head = nn.Sequential(
241
+ nn.Linear(feat, 512), nn.BatchNorm1d(512), nn.ReLU(), nn.Dropout(0.3),
242
+ nn.Linear(512, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(0.2),
243
+ nn.Linear(256, n_disease))
244
+ self.severity_head = nn.Sequential(
245
+ nn.Linear(feat, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(0.3),
246
+ nn.Linear(256, n_severity))
247
+
248
+ def forward(self, x):
249
+ f = self.backbone(x).flatten(1)
250
+ f = self.drop(f)
251
+ return self.disease_head(f), self.severity_head(f)
252
+
253
+ model = MultiTaskModel().to(device)
254
+
255
+ # class-weight alpha for focal loss
256
+ cw = compute_class_weight('balanced', classes=np.arange(5), y=train_df['disease_label'].values)
257
+ alpha = torch.tensor(cw, dtype=torch.float32).to(device)
258
+ alpha = alpha / alpha.sum() * NUM_CLASSES # normalize
259
+ print(f' Focal Ξ±: {[f"{a:.2f}" for a in alpha.tolist()]}')
260
+
261
+ criterion_d = FocalLoss(alpha=alpha, gamma=FOCAL_GAMMA)
262
+ criterion_s = nn.CrossEntropyLoss(ignore_index=-1)
263
+
264
+ total_p = sum(p.numel() for p in model.parameters())
265
+ print(f' Params: {total_p:,}')
266
+
267
+ # ═══════════════════════════════════════════════════════════
268
+ # 5 TRAINING LOOP
269
+ # ═══════════════════════════════════════════════════════════
270
+ print('\n[5/7] Training...')
271
+
272
+ # freeze backbone first
273
+ for p in model.backbone.parameters():
274
+ p.requires_grad = False
275
+
276
+ optimizer = torch.optim.AdamW(
277
+ filter(lambda p: p.requires_grad, model.parameters()),
278
+ lr=3e-4, weight_decay=1e-3)
279
+ scaler = GradScaler()
280
+
281
+ def get_scheduler(opt, warmup_steps, total_steps):
282
+ """Linear warmup then cosine decay."""
283
+ def lr_lambda(step):
284
+ if step < warmup_steps:
285
+ return float(step) / max(1, warmup_steps)
286
+ progress = float(step - warmup_steps) / max(1, total_steps - warmup_steps)
287
+ return max(0.05, 0.5 * (1.0 + np.cos(np.pi * progress)))
288
+ return torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda)
289
+
290
+ CHECKPOINT = f'{SAVE_DIR}/best_model.pth'
291
+
292
+ history = {k:[] for k in [
293
+ 'train_loss','val_loss','train_acc','val_acc',
294
+ 'macro_f1','weighted_f1','lr',
295
+ *(f'f1_{c}' for c in CLASS_NAMES)
296
+ ]}
297
+
298
+ best_f1 = 0.0
299
+ patience_ctr = 0
300
+ total_steps = EPOCHS * len(train_loader) // ACCUM_STEPS
301
+ sched = get_scheduler(optimizer, warmup_steps=len(train_loader)//ACCUM_STEPS, total_steps=total_steps)
302
+
303
+ print('='*65)
304
+
305
+ for epoch in range(EPOCHS):
306
+ t0 = time.time()
307
+
308
+ # ── Unfreeze backbone after warmup ──
309
+ if epoch == WARMUP_EPOCHS:
310
+ print('\n πŸ”“ Unfreezing backbone with LR warmup')
311
+ for p in model.backbone.parameters():
312
+ p.requires_grad = True
313
+ # new optimizer for full model with lower LR for backbone
314
+ optimizer = torch.optim.AdamW([
315
+ {'params': model.backbone.parameters(), 'lr': 1e-5},
316
+ {'params': model.disease_head.parameters(), 'lr': 1e-4},
317
+ {'params': model.severity_head.parameters(), 'lr': 1e-4},
318
+ ], weight_decay=1e-3)
319
+ remaining = (EPOCHS - WARMUP_EPOCHS) * len(train_loader) // ACCUM_STEPS
320
+ sched = get_scheduler(optimizer,
321
+ warmup_steps=LR_WARMUP_STEPS * len(train_loader) // ACCUM_STEPS,
322
+ total_steps=remaining)
323
+ scaler = GradScaler()
324
+
325
+ # ── TRAIN ──
326
+ model.train()
327
+ run_loss = 0.0
328
+ correct = 0
329
+ total = 0
330
+ optimizer.zero_grad(set_to_none=True)
331
+
332
+ pbar = tqdm(train_loader, desc=f'E{epoch+1:02d}/{EPOCHS} train', leave=False)
333
+ for step, (imgs, d_lbl, s_lbl) in enumerate(pbar):
334
+ imgs = imgs.to(device, non_blocking=True)
335
+ d_lbl = d_lbl.to(device, non_blocking=True)
336
+ s_lbl = s_lbl.to(device, non_blocking=True)
337
+
338
+ with autocast('cuda'):
339
+ d_out, s_out = model(imgs)
340
+ loss_d = criterion_d(d_out, d_lbl)
341
+ loss_s = criterion_s(s_out, s_lbl)
342
+ loss = (loss_d + 0.2 * loss_s) / ACCUM_STEPS
343
+
344
+ # check for NaN
345
+ if torch.isnan(loss) or torch.isinf(loss):
346
+ optimizer.zero_grad(set_to_none=True)
347
+ continue
348
+
349
+ scaler.scale(loss).backward()
350
+
351
+ if (step + 1) % ACCUM_STEPS == 0:
352
+ scaler.unscale_(optimizer)
353
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
354
+ scaler.step(optimizer)
355
+ scaler.update()
356
+ optimizer.zero_grad(set_to_none=True)
357
+ sched.step()
358
+
359
+ run_loss += loss.item() * ACCUM_STEPS
360
+ preds = d_out.argmax(1)
361
+ correct += (preds == d_lbl).sum().item()
362
+ total += d_lbl.size(0)
363
+ pbar.set_postfix(loss=f'{loss.item()*ACCUM_STEPS:.3f}',
364
+ acc=f'{100*correct/total:.1f}%')
365
+
366
+ train_loss = run_loss / len(train_loader)
367
+ train_acc = 100 * correct / total
368
+
369
+ # ── VALIDATE ──
370
+ model.eval()
371
+ vl = 0.0
372
+ all_p, all_t, all_prob = [], [], []
373
+ with torch.no_grad():
374
+ for imgs, d_lbl, s_lbl in tqdm(val_loader, desc=f'E{epoch+1:02d}/{EPOCHS} val ', leave=False):
375
+ imgs = imgs.to(device, non_blocking=True)
376
+ d_lbl = d_lbl.to(device, non_blocking=True)
377
+ s_lbl = s_lbl.to(device, non_blocking=True)
378
+ with autocast('cuda'):
379
+ d_out, s_out = model(imgs)
380
+ ld = criterion_d(d_out, d_lbl)
381
+ ls = criterion_s(s_out, s_lbl)
382
+ loss = ld + 0.2 * ls
383
+ if not (torch.isnan(loss) or torch.isinf(loss)):
384
+ vl += loss.item()
385
+ probs = torch.softmax(d_out.float(), dim=1)
386
+ all_p.extend(d_out.argmax(1).cpu().numpy())
387
+ all_t.extend(d_lbl.cpu().numpy())
388
+ all_prob.extend(probs.cpu().numpy())
389
+
390
+ val_loss = vl / len(val_loader)
391
+ all_p, all_t, all_prob = np.array(all_p), np.array(all_t), np.array(all_prob)
392
+ val_acc = 100 * (all_p == all_t).mean()
393
+
394
+ mf1 = f1_score(all_t, all_p, average='macro')
395
+ wf1 = f1_score(all_t, all_p, average='weighted')
396
+ per_f1 = f1_score(all_t, all_p, average=None, labels=range(NUM_CLASSES), zero_division=0)
397
+
398
+ lr = optimizer.param_groups[0]['lr']
399
+
400
+ history['train_loss'].append(train_loss)
401
+ history['val_loss'].append(val_loss)
402
+ history['train_acc'].append(train_acc)
403
+ history['val_acc'].append(val_acc)
404
+ history['macro_f1'].append(mf1)
405
+ history['weighted_f1'].append(wf1)
406
+ history['lr'].append(lr)
407
+ for ci, cn in enumerate(CLASS_NAMES):
408
+ history[f'f1_{cn}'].append(per_f1[ci])
409
+
410
+ elapsed = time.time() - t0
411
+
412
+ tag = ''
413
+ if mf1 > best_f1:
414
+ best_f1 = mf1
415
+ patience_ctr = 0
416
+ torch.save({
417
+ 'epoch': epoch, 'model_state_dict': model.state_dict(),
418
+ 'val_acc': val_acc, 'macro_f1': mf1, 'history': history
419
+ }, CHECKPOINT)
420
+ tag = f' β˜… NEW BEST (macro-F1={mf1:.4f})'
421
+ else:
422
+ patience_ctr += 1
423
+
424
+ cls_str = ' | '.join(f'{cn[:3]}:{per_f1[ci]:.2f}' for ci,cn in enumerate(CLASS_NAMES))
425
+ print(f'E{epoch+1:02d} | {elapsed:.0f}s | LR {lr:.1e} | '
426
+ f'TrL {train_loss:.3f} TrA {train_acc:.1f}% | '
427
+ f'VL {val_loss:.3f} VA {val_acc:.1f}% | '
428
+ f'mF1 {mf1:.3f} wF1 {wf1:.3f}{tag}')
429
+ print(f' {cls_str}')
430
+
431
+ if patience_ctr >= PATIENCE:
432
+ print(f'\n ⏹ Early stopping β€” no improvement for {PATIENCE} epochs')
433
+ break
434
+
435
+ print(f'\nβœ… Training done. Best macro-F1: {best_f1:.4f}')
436
+
437
+ # ═══════════════════════════════════════════════════════════
438
+ # 6 EVALUATION + PLOTS
439
+ # ═══════════════════════════════════════════════════════════
440
+ print('\n[6/7] Full evaluation...')
441
+
442
+ ckpt = torch.load(CHECKPOINT, map_location=device, weights_only=False)
443
+ model.load_state_dict(ckpt['model_state_dict'])
444
+ model.eval()
445
+ history = ckpt['history']
446
+
447
+ all_p, all_t, all_prob = [], [], []
448
+ with torch.no_grad():
449
+ for imgs, d_lbl, _ in tqdm(val_loader, desc='Evaluating'):
450
+ imgs = imgs.to(device)
451
+ d_out, _ = model(imgs)
452
+ all_p.extend(d_out.argmax(1).cpu().numpy())
453
+ all_t.extend(d_lbl.numpy())
454
+ all_prob.extend(torch.softmax(d_out.float(), dim=1).cpu().numpy())
455
+
456
+ all_p = np.array(all_p)
457
+ all_t = np.array(all_t)
458
+ all_prob = np.array(all_prob)
459
+
460
+ print('\n' + '='*65)
461
+ print(' CLASSIFICATION REPORT')
462
+ print('='*65)
463
+ report = classification_report(all_t, all_p, target_names=CLASS_NAMES, digits=4)
464
+ print(report)
465
+ mf1 = f1_score(all_t, all_p, average='macro')
466
+ wf1 = f1_score(all_t, all_p, average='weighted')
467
+ try: mauc = roc_auc_score(all_t, all_prob, multi_class='ovr', average='macro')
468
+ except: mauc = 0.0
469
+ print(f'Macro F1 : {mf1:.4f}')
470
+ print(f'Weighted F1 : {wf1:.4f}')
471
+ print(f'Macro AUC : {mauc:.4f}')
472
+
473
+ # ═══════════════════════════════════════════════════════════
474
+ # 7 COMPREHENSIVE PLOTS
475
+ # ═══════════════════════════════════════════════════════════
476
+ print('\n[7/7] Generating plots...')
477
+ ep = range(1, len(history['train_loss'])+1)
478
+ colors = ['#2ecc71','#3498db','#e74c3c','#f39c12','#9b59b6']
479
+
480
+ fig, axes = plt.subplots(2, 3, figsize=(20, 12))
481
+
482
+ # ── 1. Loss ──
483
+ axes[0,0].plot(ep, history['train_loss'], 'b-o', ms=4, label='Train')
484
+ axes[0,0].plot(ep, history['val_loss'], 'r-o', ms=4, label='Val')
485
+ axes[0,0].set_title('Loss', fontweight='bold')
486
+ axes[0,0].legend(); axes[0,0].grid(alpha=.3)
487
+
488
+ # ── 2. Accuracy ──
489
+ axes[0,1].plot(ep, history['train_acc'], 'b-o', ms=4, label='Train')
490
+ axes[0,1].plot(ep, history['val_acc'], 'r-o', ms=4, label='Val')
491
+ axes[0,1].set_title('Accuracy (%)', fontweight='bold')
492
+ axes[0,1].legend(); axes[0,1].grid(alpha=.3)
493
+
494
+ # ── 3. Macro / Weighted F1 ──
495
+ axes[0,2].plot(ep, history['macro_f1'], 'g-o', ms=4, label='Macro F1')
496
+ axes[0,2].plot(ep, history['weighted_f1'], 'm-o', ms=4, label='Weighted F1')
497
+ axes[0,2].set_title('F1 Scores', fontweight='bold')
498
+ axes[0,2].legend(); axes[0,2].grid(alpha=.3)
499
+
500
+ # ── 4. Per-class F1 ──
501
+ for ci, cn in enumerate(CLASS_NAMES):
502
+ axes[1,0].plot(ep, history[f'f1_{cn}'], '-o', ms=3, color=colors[ci], label=cn)
503
+ axes[1,0].set_title('Per-Class F1', fontweight='bold')
504
+ axes[1,0].legend(); axes[1,0].grid(alpha=.3)
505
+
506
+ # ── 5. Confusion Matrix ──
507
+ cm = confusion_matrix(all_t, all_p)
508
+ cm_n = cm.astype(float) / cm.sum(axis=1, keepdims=True)
509
+ sns.heatmap(cm_n, annot=True, fmt='.2f', cmap='Blues', ax=axes[1,1],
510
+ xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES)
511
+ axes[1,1].set_title('Confusion Matrix (norm)', fontweight='bold')
512
+ axes[1,1].set_ylabel('True'); axes[1,1].set_xlabel('Pred')
513
+
514
+ # ── 6. ROC ──
515
+ y_bin = label_binarize(all_t, classes=list(range(NUM_CLASSES)))
516
+ for ci, (cn, col) in enumerate(zip(CLASS_NAMES, colors)):
517
+ fpr, tpr, _ = roc_curve(y_bin[:,ci], all_prob[:,ci])
518
+ axes[1,2].plot(fpr, tpr, color=col, lw=2, label=f'{cn} ({auc(fpr,tpr):.3f})')
519
+ axes[1,2].plot([0,1],[0,1],'k--',lw=1)
520
+ axes[1,2].set_title('ROC Curves', fontweight='bold')
521
+ axes[1,2].legend(loc='lower right', fontsize=8)
522
+ axes[1,2].grid(alpha=.3)
523
+
524
+ plt.suptitle(f'RetinaSense v2 β€” Macro F1={mf1:.3f} | AUC={mauc:.3f} | Val Acc={100*(all_p==all_t).mean():.1f}%',
525
+ fontsize=15, fontweight='bold', y=1.01)
526
+ plt.tight_layout()
527
+ plt.savefig(f'{SAVE_DIR}/dashboard.png', dpi=150, bbox_inches='tight')
528
+ plt.close()
529
+
530
+ # LR schedule plot
531
+ fig, ax = plt.subplots(figsize=(8,3))
532
+ ax.plot(ep, history['lr'], 'b-o', ms=3)
533
+ ax.set_title('Learning Rate Schedule', fontweight='bold')
534
+ ax.set_xlabel('Epoch'); ax.set_ylabel('LR')
535
+ ax.grid(alpha=.3)
536
+ plt.tight_layout()
537
+ plt.savefig(f'{SAVE_DIR}/lr_schedule.png', dpi=150)
538
+ plt.close()
539
+
540
+ # Save metrics
541
+ pd.DataFrame([{
542
+ 'val_accuracy': 100*(all_p==all_t).mean(),
543
+ 'macro_f1': mf1, 'weighted_f1': wf1, 'macro_auc': mauc,
544
+ **{f'f1_{cn}': f1_score(all_t, all_p, average=None, labels=range(NUM_CLASSES))[ci]
545
+ for ci,cn in enumerate(CLASS_NAMES)}
546
+ }]).to_csv(f'{SAVE_DIR}/metrics.csv', index=False)
547
+
548
+ # Save history
549
+ with open(f'{SAVE_DIR}/history.json','w') as f:
550
+ json.dump({k:[float(v) for v in vs] for k,vs in history.items()}, f, indent=2)
551
+
552
+ print(f'\n{"="*65}')
553
+ print(f' RETINASENSE v2 β€” FINAL RESULTS')
554
+ print(f'{"="*65}')
555
+ print(f' Best Macro F1 : {best_f1:.4f}')
556
+ print(f' Val Accuracy : {100*(all_p==all_t).mean():.2f}%')
557
+ print(f' Macro AUC : {mauc:.4f}')
558
+ per_f1 = f1_score(all_t, all_p, average=None, labels=range(NUM_CLASSES), zero_division=0)
559
+ for ci, cn in enumerate(CLASS_NAMES):
560
+ print(f' {cn:15s}: F1={per_f1[ci]:.3f}')
561
+ print(f'{"="*65}')
562
+ print(f'\nπŸ“ {SAVE_DIR}/')
563
+ print(f' β”œβ”€β”€ best_model.pth')
564
+ print(f' β”œβ”€β”€ dashboard.png')
565
+ print(f' β”œβ”€β”€ lr_schedule.png')
566
+ print(f' β”œβ”€β”€ metrics.csv')
567
+ print(f' └── history.json')