tanishq74 committed on
Commit
458ff46
·
verified ·
1 Parent(s): c3e6a8f

Add retinasense_vit.py

Browse files
Files changed (1) hide show
  1. retinasense_vit.py +581 -0
retinasense_vit.py ADDED
@@ -0,0 +1,581 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ RetinaSense ViT — Vision Transformer Variant
4
+ =============================================
5
+ Based on RetinaSense v2 pipeline, replacing EfficientNet-B3 with
6
+ ViT-Base-Patch16-224 from the timm library.
7
+
8
+ Key changes from v2:
9
+ - Backbone: ViT-Base-Patch16-224 (timm) instead of EfficientNet-B3
10
+ - Feature dimension: 768 (ViT) instead of 1536 (EfficientNet-B3)
11
+ - Image size: 224x224 (ViT native) instead of 300x300
12
+ - EPOCHS: 30, PATIENCE: 10
13
+ - Output directory: ./outputs_vit
14
+
15
+ Everything else preserved from v2:
16
+ - Focal Loss with class weights
17
+ - Multi-task architecture (disease + severity heads)
18
+ - LR warmup + cosine decay
19
+ - Gradient accumulation
20
+ - Early stopping on Macro F1
21
+ - Pre-cached preprocessing
22
+ - Comprehensive plots
23
+ """
24
+
25
+ import os, sys, time, warnings, json
26
+ import numpy as np
27
+ import pandas as pd
28
+ import cv2
29
+ import matplotlib
30
+ matplotlib.use('Agg')
31
+ import matplotlib.pyplot as plt
32
+ import seaborn as sns
33
+ from PIL import Image
34
+ from tqdm import tqdm
35
+ from collections import Counter
36
+ warnings.filterwarnings('ignore')
37
+
38
+ import torch
39
+ import torch.nn as nn
40
+ import torch.nn.functional as F
41
+ from torch.amp import autocast, GradScaler
42
+ from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
43
+ from torchvision import transforms
44
+
45
+ import timm
46
+
47
+ from sklearn.model_selection import train_test_split
48
+ from sklearn.utils.class_weight import compute_class_weight
49
+ from sklearn.metrics import (
50
+ classification_report, confusion_matrix,
51
+ roc_auc_score, f1_score, roc_curve, auc
52
+ )
53
+ from sklearn.preprocessing import label_binarize
54
+
55
# ===============================================================
# CONFIG
# ===============================================================
# Output and cache directories are created immediately so that every later
# stage can write into them without checking for existence.
SAVE_DIR = './outputs_vit'
CACHE_DIR = './preprocessed_cache_vit'
os.makedirs(SAVE_DIR, exist_ok=True)
os.makedirs(CACHE_DIR, exist_ok=True)

EPOCHS = 30
WARMUP_EPOCHS = 3  # heads-only warmup
# NOTE: despite the name this is a number of *epochs*; it is multiplied by the
# per-epoch optimizer step count where the post-unfreeze scheduler is built.
LR_WARMUP_STEPS = 3  # linear warmup epochs after unfreeze
BATCH_SIZE = 32  # actual batch size
ACCUM_STEPS = 2  # gradient accumulation -> effective batch 64
NUM_WORKERS = 8
PATIENCE = 10  # early stopping on macro-F1
FOCAL_GAMMA = 1.0  # reduced from 2.0 -- less aggressive
IMG_SIZE = 224  # ViT native resolution

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# The position in this list is the integer class id used for disease labels.
CLASS_NAMES = ['Normal', 'Diabetes/DR', 'Glaucoma', 'Cataract', 'AMD']
NUM_CLASSES = len(CLASS_NAMES)

# Startup banner summarising the run configuration.
print('='*65)
print(' RetinaSense ViT -- Vision Transformer Pipeline')
print('='*65)
if torch.cuda.is_available():
    print(f' GPU : {torch.cuda.get_device_name(0)}')
    print(f' VRAM : {round(torch.cuda.get_device_properties(0).total_memory/1e9,1)} GB')
print(f' Backbone : ViT-Base-Patch16-224 (timm)')
print(f' Epochs : {EPOCHS}')
print(f' Batch : {BATCH_SIZE} (effective {BATCH_SIZE*ACCUM_STEPS} via grad accum)')
print(f' Image Size : {IMG_SIZE}')
print(f' Focal Loss g: {FOCAL_GAMMA} (mild -- avoids over-correction)')
print(f' Early Stop : patience={PATIENCE} on macro-F1')
print('='*65)
90
+
91
+ # ===============================================================
92
+ # 1 METADATA
93
+ # ===============================================================
94
print('\n[1/7] Building metadata...')
BASE = './'
disease_cols = ['N','D','G','C','A']
label_map = {'N':0,'D':1,'G':2,'C':3,'A':4}

# ODIR: keep only rows where exactly one of the five disease flags is set, so
# every retained image maps to a single class.
df_odir = pd.read_csv(f'{BASE}/odir/full_df.csv')
df_odir['disease_count'] = df_odir[disease_cols].sum(axis=1)
df_odir = df_odir[df_odir['disease_count']==1].copy()


def get_label(row):
    """Return the class id of the first disease column equal to 1, or None."""
    for d in disease_cols:
        if row[d] == 1:
            return label_map[d]
    return None


df_odir['disease_label'] = df_odir.apply(get_label, axis=1)

# Locate the column holding the image file name; fail with a readable message
# instead of a bare StopIteration when none matches.
img_col = next((c for c in df_odir.columns
                if any(k in c.lower() for k in ['filename','fundus','image'])), None)
if img_col is None:
    raise KeyError('No image-filename column found in odir/full_df.csv '
                   f'(columns: {list(df_odir.columns)})')

odir_meta = pd.DataFrame({
    'image_path': f'{BASE}/odir/preprocessed_images/'+df_odir[img_col].astype(str),
    'dataset': 'ODIR',
    'disease_label': df_odir['disease_label'],
    'severity_label':-1   # ODIR has no severity grade; -1 is ignored by the severity loss
})

df_aptos = pd.read_csv(f'{BASE}/aptos/train.csv')
aptos_meta = pd.DataFrame({
    'image_path': f'{BASE}/aptos/train_images/'+df_aptos['id_code']+'.png',
    'dataset': 'APTOS',
    # Per the APTOS 2019 data description, diagnosis 0 means "No DR"; those
    # images must be labelled Normal (0), not Diabetes/DR (1). Grades 1-4 are DR.
    'disease_label': np.where(df_aptos['diagnosis'] == 0, 0, 1),
    'severity_label':df_aptos['diagnosis']
})

meta = pd.concat([odir_meta, aptos_meta], ignore_index=True)
# Drop rows whose image file is not present on disk.
meta = meta[meta['image_path'].apply(os.path.exists)].reset_index(drop=True)
print(f' Total samples: {len(meta)}')
dist = meta['disease_label'].value_counts().sort_index()
for i,cnt in dist.items():
    print(f' {CLASS_NAMES[i]:15s}: {cnt:4d} ({100*cnt/len(meta):.1f}%)')
131
+
132
+ # ===============================================================
133
+ # 2 PRE-CACHE
134
+ # ===============================================================
135
+ print(f'\n[2/7] Pre-caching @ {IMG_SIZE}x{IMG_SIZE}...')
136
+
137
def ben_graham(path, sz=IMG_SIZE, sigma=10):
    """Load a fundus image and return a contrast-enhanced, circularly masked RGB array.

    The image is read with OpenCV (falling back to PIL when OpenCV cannot
    decode it), resized to ``sz`` x ``sz``, sharpened by subtracting a Gaussian
    blur of standard deviation ``sigma`` (weights 4 / -4, offset 128) and
    finally masked with a centred filled circle of radius ``0.48 * sz``.
    """
    bgr = cv2.imread(path)
    if bgr is not None:
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    else:
        # OpenCV could not decode the file; PIL already yields RGB, so the
        # RGB -> BGR -> RGB round trip of the direct path is unnecessary here.
        rgb = np.array(Image.open(path).convert('RGB'))

    resized = cv2.resize(rgb, (sz, sz))
    blurred = cv2.GaussianBlur(resized, (0, 0), sigma)
    enhanced = cv2.addWeighted(resized, 4, blurred, -4, 128)

    # Keep only the pixels inside the centred circle.
    circle_mask = np.zeros(enhanced.shape[:2], dtype=np.uint8)
    centre = (sz // 2, sz // 2)
    cv2.circle(circle_mask, centre, int(sz * 0.48), 255, -1)
    return cv2.bitwise_and(enhanced, enhanced, mask=circle_mask)
148
+
149
# Preprocess every image once and store the result as <stem>_<IMG_SIZE>.npy.
# Images that cannot be preprocessed are reported and removed from `meta`
# instead of being cached as an all-black placeholder that would then be
# trained on with a real label.
# NOTE(review): the cache key is the file stem only, so two source images with
# the same stem in different datasets would share one cache file -- confirm
# that stems are unique across ODIR and APTOS.
cache_paths = []
failed = []   # index labels of rows that could not be preprocessed
cached = 0
for idx, src_path in tqdm(meta['image_path'].items(), total=len(meta), desc='Caching'):
    stem = os.path.splitext(os.path.basename(src_path))[0]
    fp = f'{CACHE_DIR}/{stem}_{IMG_SIZE}.npy'
    cache_paths.append(fp)
    if os.path.exists(fp):
        continue
    try:
        arr = ben_graham(src_path)
        # Write to a temp file and rename so an interrupted run never leaves a
        # truncated .npy that later looks like a valid cache entry.
        tmp = f'{fp}.tmp'
        with open(tmp, 'wb') as fh:
            np.save(fh, arr)
        os.replace(tmp, fp)
    except Exception as err:   # best-effort: one bad file must not abort the run
        tqdm.write(f' [cache] skipping {src_path}: {err!r}')
        failed.append(idx)
    else:
        cached += 1
meta['cache_path'] = cache_paths
print(f' Newly cached: {cached} | Already cached: {len(meta)-cached-len(failed)}')
if failed:
    meta = meta.drop(index=failed).reset_index(drop=True)
    print(f' Skipped {len(failed)} unreadable image(s); {len(meta)} samples remain')
163
+
164
+ # ===============================================================
165
+ # 3 DATASET + LOADERS
166
+ # ===============================================================
167
+ print('\n[3/7] Creating data loaders...')
168
+
169
# Stratified 80/20 split on the disease label with a fixed seed for
# reproducibility.
# NOTE(review): the split is per image. If the metadata contains several images
# of the same patient they can end up on both sides of the split -- confirm
# whether a patient-grouped split is required.
train_df, val_df = train_test_split(
    meta, test_size=0.2, stratify=meta['disease_label'], random_state=42)
171
+
172
def make_transforms(phase):
    """Build the torchvision pipeline for a phase.

    ``phase == 'train'`` adds geometric and colour augmentation plus random
    erasing; any other value yields the plain ToPILImage -> ToTensor ->
    Normalize pipeline.
    """
    is_train = phase == 'train'

    steps = [transforms.ToPILImage()]
    if is_train:
        steps.extend([
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(p=0.3),
            transforms.RandomRotation(20),
            transforms.RandomAffine(degrees=0, translate=(0.05,0.05), scale=(0.95,1.05)),
            transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2, hue=0.02),
        ])
    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]))
    if is_train:
        # RandomErasing operates on tensors, so it has to follow ToTensor/Normalize.
        steps.append(transforms.RandomErasing(p=0.2))
    return transforms.Compose(steps)
190
+
191
class RetDS(Dataset):
    """Dataset over a metadata frame of pre-cached ``.npy`` images.

    Each row of ``df`` must provide ``cache_path``, ``disease_label`` and
    ``severity_label``. ``tfm`` is applied to the loaded array.
    """

    def __init__(self, df, tfm):
        # Positional indexing via iloc needs a clean 0..n-1 index.
        self.df = df.reset_index(drop=True)
        self.tfm = tfm

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i):
        """Return ``(transformed_image, disease_label, severity_label)`` for row ``i``."""
        r = self.df.iloc[i]
        try:
            img = np.load(r['cache_path'])
        except (OSError, ValueError, EOFError):
            # Missing or corrupt cache file: fall back to a black image rather
            # than crash a DataLoader worker. Only I/O and decode errors are
            # caught, so KeyboardInterrupt/SystemExit still propagate.
            img = np.zeros((IMG_SIZE,IMG_SIZE,3), dtype=np.uint8)
        return (self.tfm(img),
                torch.tensor(int(r['disease_label']), dtype=torch.long),
                torch.tensor(int(r['severity_label']), dtype=torch.long))
203
+
204
train_ds = RetDS(train_df, make_transforms('train'))
val_ds = RetDS(val_df, make_transforms('val'))

# persistent_workers / prefetch_factor are only valid with worker processes.
use_workers = NUM_WORKERS > 0
# The heads contain BatchNorm1d, which raises in training mode on a batch of a
# single sample. Drop the trailing batch only when it would have exactly one.
drop_tail = len(train_ds) % BATCH_SIZE == 1

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=True,
                          persistent_workers=use_workers,
                          prefetch_factor=2 if use_workers else None,
                          drop_last=drop_tail)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=NUM_WORKERS, pin_memory=True,
                        persistent_workers=use_workers)

print(f' Train : {len(train_ds):5d} ({len(train_loader):3d} batches)')
print(f' Val : {len(val_ds):5d} ({len(val_loader):3d} batches)')
print(f' Focal Loss + class weights handle imbalance (no oversampling)')
217
+
218
+ # ===============================================================
219
+ # 4 MODEL + FOCAL LOSS
220
+ # ===============================================================
221
+ print('\n[4/7] Building ViT model...')
222
+
223
class FocalLoss(nn.Module):
    """Focal loss: cross-entropy scaled by ``(1 - p_true) ** gamma``.

    ``alpha`` is an optional 1-D tensor of per-class weights indexed by the
    target class; it is stored as a buffer so it moves with the module.
    The result is the mean over the batch.
    """

    def __init__(self, alpha=None, gamma=2.0):
        super().__init__()
        self.gamma = gamma
        if alpha is None:
            self.alpha = None
        else:
            self.register_buffer('alpha', alpha)

    def forward(self, logits, targets):
        per_sample_ce = F.cross_entropy(logits, targets, reduction='none')
        # exp(-CE) recovers the softmax probability assigned to the true class.
        p_true = torch.exp(-per_sample_ce)
        modulating = (1 - p_true) ** self.gamma
        loss = modulating * per_sample_ce
        if self.alpha is None:
            return loss.mean()
        class_weight = self.alpha.gather(0, targets)
        return (class_weight * loss).mean()
241
+
242
+
243
class MultiTaskViT(nn.Module):
    """timm backbone with two classification heads (disease and severity).

    Args:
        n_disease: number of disease classes.
        n_severity: number of severity grades.
        drop: dropout probability applied to the pooled backbone features.
        backbone_name: timm model name; defaults to ViT-Base-Patch16-224.
        pretrained: whether to load timm's pretrained weights.

    ``forward`` returns ``(disease_logits, severity_logits)``.
    """

    def __init__(self, n_disease=5, n_severity=5, drop=0.4,
                 backbone_name='vit_base_patch16_224', pretrained=True):
        super().__init__()
        # num_classes=0 removes the classifier so the backbone returns pooled features.
        self.backbone = timm.create_model(backbone_name, pretrained=pretrained, num_classes=0)
        # Read the feature width from the model instead of hard-coding 768, so
        # the heads stay correct if another backbone is selected.
        feat = self.backbone.num_features
        self.drop = nn.Dropout(drop)
        self.disease_head = nn.Sequential(
            nn.Linear(feat, 512), nn.BatchNorm1d(512), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(512, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(256, n_disease))
        self.severity_head = nn.Sequential(
            nn.Linear(feat, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(256, n_severity))

    def forward(self, x):
        f = self.backbone(x)  # timm ViT with num_classes=0 returns pooled features
        f = self.drop(f)
        return self.disease_head(f), self.severity_head(f)
262
+
263
# Size the disease head from CLASS_NAMES rather than relying on the default.
model = MultiTaskViT(n_disease=NUM_CLASSES).to(device)

# class-weight alpha for focal loss ('balanced' weights from the training split)
cw = compute_class_weight('balanced', classes=np.arange(NUM_CLASSES),
                          y=train_df['disease_label'].values)
alpha = torch.tensor(cw, dtype=torch.float32).to(device)
alpha = alpha / alpha.sum() * NUM_CLASSES  # normalize so the weights average to 1
print(f' Focal a: {[f"{a:.2f}" for a in alpha.tolist()]}')

criterion_d = FocalLoss(alpha=alpha, gamma=FOCAL_GAMMA)
# Samples without a severity grade carry label -1 and are skipped by this loss.
criterion_s = nn.CrossEntropyLoss(ignore_index=-1)

total_p = sum(p.numel() for p in model.parameters())
print(f' Params: {total_p:,}')
276
+
277
+ # ===============================================================
278
+ # 5 TRAINING LOOP
279
+ # ===============================================================
280
print('\n[5/7] Training...')

# Phase 1 trains the heads only: switch off gradients for every backbone weight.
model.backbone.requires_grad_(False)

# Only parameters that still require gradients are handed to the optimizer.
head_params = [w for w in model.parameters() if w.requires_grad]
optimizer = torch.optim.AdamW(head_params, lr=3e-4, weight_decay=1e-3)
scaler = GradScaler()
290
+
291
def get_scheduler(opt, warmup_steps, total_steps):
    """Return a LambdaLR doing linear warmup followed by cosine decay.

    The multiplier rises linearly from 0 to 1 over ``warmup_steps`` scheduler
    steps, then follows a half cosine down to a floor of 0.05 at
    ``total_steps``. Past ``total_steps`` it stays at the floor; without the
    clamp the cosine would turn around and raise the learning rate again.
    """
    decay_steps = max(1, total_steps - warmup_steps)

    def lr_lambda(step):
        if step < warmup_steps:
            return float(step) / max(1, warmup_steps)
        progress = min(1.0, float(step - warmup_steps) / decay_steps)
        return max(0.05, 0.5 * (1.0 + np.cos(np.pi * progress)))

    return torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda)
299
+
300
CHECKPOINT = f'{SAVE_DIR}/best_model.pth'

# One list per tracked metric, appended to once per epoch.
history_keys = [
    'train_loss','val_loss','train_acc','val_acc',
    'macro_f1','weighted_f1','lr',
]
history_keys += [f'f1_{c}' for c in CLASS_NAMES]
history = {key: [] for key in history_keys}

best_f1 = 0.0
patience_ctr = 0
# The scheduler advances once per optimizer step, i.e. once every ACCUM_STEPS batches.
steps_per_epoch = len(train_loader) // ACCUM_STEPS
total_steps = EPOCHS * len(train_loader) // ACCUM_STEPS
sched = get_scheduler(optimizer, warmup_steps=steps_per_epoch, total_steps=total_steps)

t_start = time.time()
print('='*65)
315
+
316
# Main loop: WARMUP_EPOCHS epochs with a frozen backbone, then full fine-tuning.
# Each epoch trains with gradient accumulation, validates, records metrics,
# checkpoints on a new best macro-F1 and applies early stopping.
for epoch in range(EPOCHS):
    t0 = time.time()

    # -- Unfreeze backbone after warmup --
    if epoch == WARMUP_EPOCHS:
        print('\n Unfreezing ViT backbone with LR warmup')
        for p in model.backbone.parameters():
            p.requires_grad = True
        # new optimizer for full model with lower LR for backbone
        optimizer = torch.optim.AdamW([
            {'params': model.backbone.parameters(), 'lr': 1e-5},
            {'params': model.disease_head.parameters(), 'lr': 1e-4},
            {'params': model.severity_head.parameters(), 'lr': 1e-4},
        ], weight_decay=1e-3)
        remaining = (EPOCHS - WARMUP_EPOCHS) * len(train_loader) // ACCUM_STEPS
        sched = get_scheduler(optimizer,
                              warmup_steps=LR_WARMUP_STEPS * len(train_loader) // ACCUM_STEPS,
                              total_steps=remaining)
        scaler = GradScaler()

    # -- TRAIN --
    model.train()
    run_loss = 0.0
    correct = 0
    total = 0
    n_batches = len(train_loader)
    optimizer.zero_grad(set_to_none=True)

    pbar = tqdm(train_loader, desc=f'E{epoch+1:02d}/{EPOCHS} train', leave=False)
    for step, (imgs, d_lbl, s_lbl) in enumerate(pbar):
        imgs = imgs.to(device, non_blocking=True)
        d_lbl = d_lbl.to(device, non_blocking=True)
        s_lbl = s_lbl.to(device, non_blocking=True)

        with autocast('cuda'):
            d_out, s_out = model(imgs)
            loss_d = criterion_d(d_out, d_lbl)
            # Mean-reduced CrossEntropyLoss is NaN when every target equals
            # ignore_index. A batch without any severity label must contribute
            # zero severity loss, not cause the whole batch to be discarded.
            if (s_lbl != criterion_s.ignore_index).any():
                loss_s = criterion_s(s_out, s_lbl)
            else:
                loss_s = loss_d.new_zeros(())
            loss = (loss_d + 0.2 * loss_s) / ACCUM_STEPS

        # check for NaN
        if torch.isnan(loss) or torch.isinf(loss):
            optimizer.zero_grad(set_to_none=True)
            continue

        scaler.scale(loss).backward()

        # Step every ACCUM_STEPS batches, and also on the last batch of the
        # epoch so leftover accumulated gradients are applied instead of being
        # wiped by the zero_grad at the start of the next epoch.
        if (step + 1) % ACCUM_STEPS == 0 or (step + 1) == n_batches:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
            sched.step()

        run_loss += loss.item() * ACCUM_STEPS
        preds = d_out.argmax(1)
        correct += (preds == d_lbl).sum().item()
        total += d_lbl.size(0)
        pbar.set_postfix(loss=f'{loss.item()*ACCUM_STEPS:.3f}',
                         acc=f'{100*correct/total:.1f}%')

    train_loss = run_loss / max(1, n_batches)
    train_acc = 100 * correct / max(1, total)   # guard: every batch may have been skipped

    # -- VALIDATE --
    model.eval()
    vl = 0.0
    all_p, all_t, all_prob = [], [], []
    with torch.no_grad():
        for imgs, d_lbl, s_lbl in tqdm(val_loader, desc=f'E{epoch+1:02d}/{EPOCHS} val ', leave=False):
            imgs = imgs.to(device, non_blocking=True)
            d_lbl = d_lbl.to(device, non_blocking=True)
            s_lbl = s_lbl.to(device, non_blocking=True)
            with autocast('cuda'):
                d_out, s_out = model(imgs)
                ld = criterion_d(d_out, d_lbl)
                # Same guard as in training: no severity labels -> zero severity loss.
                if (s_lbl != criterion_s.ignore_index).any():
                    ls = criterion_s(s_out, s_lbl)
                else:
                    ls = ld.new_zeros(())
                loss = ld + 0.2 * ls
            if not (torch.isnan(loss) or torch.isinf(loss)):
                vl += loss.item()
            probs = torch.softmax(d_out.float(), dim=1)
            all_p.extend(d_out.argmax(1).cpu().numpy())
            all_t.extend(d_lbl.cpu().numpy())
            all_prob.extend(probs.cpu().numpy())

    val_loss = vl / len(val_loader)
    all_p, all_t, all_prob = np.array(all_p), np.array(all_t), np.array(all_prob)
    val_acc = 100 * (all_p == all_t).mean()

    mf1 = f1_score(all_t, all_p, average='macro')
    wf1 = f1_score(all_t, all_p, average='weighted')
    per_f1 = f1_score(all_t, all_p, average=None, labels=range(NUM_CLASSES), zero_division=0)

    # First param group: the heads during warmup, the backbone after unfreezing.
    lr = optimizer.param_groups[0]['lr']

    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['train_acc'].append(train_acc)
    history['val_acc'].append(val_acc)
    history['macro_f1'].append(mf1)
    history['weighted_f1'].append(wf1)
    history['lr'].append(lr)
    for ci, cn in enumerate(CLASS_NAMES):
        history[f'f1_{cn}'].append(per_f1[ci])

    elapsed = time.time() - t0

    # Checkpoint on a new best macro-F1; otherwise count towards early stopping.
    tag = ''
    if mf1 > best_f1:
        best_f1 = mf1
        patience_ctr = 0
        torch.save({
            'epoch': epoch, 'model_state_dict': model.state_dict(),
            'val_acc': val_acc, 'macro_f1': mf1, 'history': history
        }, CHECKPOINT)
        tag = f' * NEW BEST (macro-F1={mf1:.4f})'
    else:
        patience_ctr += 1

    cls_str = ' | '.join(f'{cn[:3]}:{per_f1[ci]:.2f}' for ci,cn in enumerate(CLASS_NAMES))
    print(f'E{epoch+1:02d} | {elapsed:.0f}s | LR {lr:.1e} | '
          f'TrL {train_loss:.3f} TrA {train_acc:.1f}% | '
          f'VL {val_loss:.3f} VA {val_acc:.1f}% | '
          f'mF1 {mf1:.3f} wF1 {wf1:.3f}{tag}')
    print(f' {cls_str}')

    if patience_ctr >= PATIENCE:
        print(f'\n Early stopping -- no improvement for {PATIENCE} epochs')
        break

total_train_time = time.time() - t_start
print(f'\nTraining done. Best macro-F1: {best_f1:.4f}')
print(f'Total training time: {total_train_time/60:.1f} minutes')
449
+
450
+ # ===============================================================
451
+ # 6 EVALUATION + PLOTS
452
+ # ===============================================================
453
print('\n[6/7] Full evaluation...')

# SECURITY NOTE: weights_only=False unpickles arbitrary objects. It is needed
# here because the checkpoint also stores the history dict, and it is only
# acceptable because the file was written by this same run. Never point
# CHECKPOINT at an untrusted file.
ckpt = torch.load(CHECKPOINT, map_location=device, weights_only=False)
model.load_state_dict(ckpt['model_state_dict'])
model.eval()
# `history` is deliberately NOT replaced by ckpt['history']: the checkpoint
# holds a snapshot taken at the best epoch, which would cut every later epoch
# out of the curves and history.json. The in-memory dict is complete.

# Re-run the validation set with the best weights.
all_p, all_t, all_prob = [], [], []
with torch.no_grad():
    for imgs, d_lbl, _ in tqdm(val_loader, desc='Evaluating'):
        imgs = imgs.to(device)
        d_out, _ = model(imgs)
        all_p.extend(d_out.argmax(1).cpu().numpy())
        all_t.extend(d_lbl.numpy())
        all_prob.extend(torch.softmax(d_out.float(), dim=1).cpu().numpy())

all_p = np.array(all_p)
all_t = np.array(all_t)
all_prob = np.array(all_prob)

print('\n' + '='*65)
print(' CLASSIFICATION REPORT')
print('='*65)
# Explicit labels keep target_names aligned even if a class never appears.
report = classification_report(all_t, all_p, labels=list(range(NUM_CLASSES)),
                               target_names=CLASS_NAMES, digits=4, zero_division=0)
print(report)
mf1 = f1_score(all_t, all_p, average='macro')
wf1 = f1_score(all_t, all_p, average='weighted')
try:
    mauc = roc_auc_score(all_t, all_prob, multi_class='ovr', average='macro')
except ValueError:
    # roc_auc_score raises ValueError when AUC is undefined (e.g. a class is
    # missing from the labels); report 0.0 in that case.
    mauc = 0.0
print(f'Macro F1 : {mf1:.4f}')
print(f'Weighted F1 : {wf1:.4f}')
print(f'Macro AUC : {mauc:.4f}')
485
+
486
+ # ===============================================================
487
+ # 7 COMPREHENSIVE PLOTS
488
+ # ===============================================================
489
print('\n[7/7] Generating plots...')
ep = range(1, len(history['train_loss'])+1)
colors = ['#2ecc71','#3498db','#e74c3c','#f39c12','#9b59b6']

# 2x3 dashboard: loss, accuracy, F1, per-class F1, confusion matrix, ROC.
fig, axes = plt.subplots(2, 3, figsize=(20, 12))

# -- 1. Loss --
axes[0,0].plot(ep, history['train_loss'], 'b-o', ms=4, label='Train')
axes[0,0].plot(ep, history['val_loss'], 'r-o', ms=4, label='Val')
axes[0,0].set_title('Loss', fontweight='bold')
axes[0,0].legend(); axes[0,0].grid(alpha=.3)

# -- 2. Accuracy --
axes[0,1].plot(ep, history['train_acc'], 'b-o', ms=4, label='Train')
axes[0,1].plot(ep, history['val_acc'], 'r-o', ms=4, label='Val')
axes[0,1].set_title('Accuracy (%)', fontweight='bold')
axes[0,1].legend(); axes[0,1].grid(alpha=.3)

# -- 3. Macro / Weighted F1 --
axes[0,2].plot(ep, history['macro_f1'], 'g-o', ms=4, label='Macro F1')
axes[0,2].plot(ep, history['weighted_f1'], 'm-o', ms=4, label='Weighted F1')
axes[0,2].set_title('F1 Scores', fontweight='bold')
axes[0,2].legend(); axes[0,2].grid(alpha=.3)

# -- 4. Per-class F1 --
for ci, cn in enumerate(CLASS_NAMES):
    axes[1,0].plot(ep, history[f'f1_{cn}'], '-o', ms=3, color=colors[ci], label=cn)
axes[1,0].set_title('Per-Class F1', fontweight='bold')
axes[1,0].legend(); axes[1,0].grid(alpha=.3)

# -- 5. Confusion Matrix --
# Fix the label set so the matrix is always NUM_CLASSES x NUM_CLASSES and lines
# up with the tick labels, even if a class is absent from the predictions.
cm = confusion_matrix(all_t, all_p, labels=list(range(NUM_CLASSES)))
row_totals = cm.sum(axis=1, keepdims=True)
# Rows without any true sample stay all-zero instead of becoming NaN.
cm_n = cm.astype(float) / np.maximum(row_totals, 1)
sns.heatmap(cm_n, annot=True, fmt='.2f', cmap='Blues', ax=axes[1,1],
            xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES)
axes[1,1].set_title('Confusion Matrix (norm)', fontweight='bold')
axes[1,1].set_ylabel('True'); axes[1,1].set_xlabel('Pred')

# -- 6. ROC --
# One-vs-rest ROC curve per class from the softmax probabilities.
y_bin = label_binarize(all_t, classes=list(range(NUM_CLASSES)))
for ci, (cn, col) in enumerate(zip(CLASS_NAMES, colors)):
    fpr, tpr, _ = roc_curve(y_bin[:,ci], all_prob[:,ci])
    axes[1,2].plot(fpr, tpr, color=col, lw=2, label=f'{cn} ({auc(fpr,tpr):.3f})')
axes[1,2].plot([0,1],[0,1],'k--',lw=1)
axes[1,2].set_title('ROC Curves', fontweight='bold')
axes[1,2].legend(loc='lower right', fontsize=8)
axes[1,2].grid(alpha=.3)

plt.suptitle(f'RetinaSense ViT -- Macro F1={mf1:.3f} | AUC={mauc:.3f} | Val Acc={100*(all_p==all_t).mean():.1f}%',
             fontsize=15, fontweight='bold', y=1.01)
plt.tight_layout()
plt.savefig(f'{SAVE_DIR}/dashboard.png', dpi=150, bbox_inches='tight')
plt.close()

# LR schedule plot
fig, ax = plt.subplots(figsize=(8,3))
ax.plot(ep, history['lr'], 'b-o', ms=3)
ax.set_title('Learning Rate Schedule', fontweight='bold')
ax.set_xlabel('Epoch'); ax.set_ylabel('LR')
ax.grid(alpha=.3)
plt.tight_layout()
plt.savefig(f'{SAVE_DIR}/lr_schedule.png', dpi=150)
plt.close()
552
+
553
# Save metrics
# Per-class F1 is computed once and reused for the CSV and the console summary
# (it was previously recomputed inside the dict comprehension for every class).
per_f1 = f1_score(all_t, all_p, average=None, labels=range(NUM_CLASSES), zero_division=0)
val_accuracy = 100*(all_p==all_t).mean()
pd.DataFrame([{
    'val_accuracy': val_accuracy,
    'macro_f1': mf1, 'weighted_f1': wf1, 'macro_auc': mauc,
    **{f'f1_{cn}': per_f1[ci] for ci,cn in enumerate(CLASS_NAMES)}
}]).to_csv(f'{SAVE_DIR}/metrics.csv', index=False)

# Save history (values are cast to float so NumPy scalars serialise as JSON)
with open(f'{SAVE_DIR}/history.json','w') as f:
    json.dump({k:[float(v) for v in vs] for k,vs in history.items()}, f, indent=2)

print(f'\n{"="*65}')
print(f' RETINASENSE ViT -- FINAL RESULTS')
print(f'{"="*65}')
print(f' Best Macro F1 : {best_f1:.4f}')
print(f' Val Accuracy : {val_accuracy:.2f}%')
print(f' Macro AUC : {mauc:.4f}')
print(f' Training Time : {total_train_time/60:.1f} minutes')
for ci, cn in enumerate(CLASS_NAMES):
    print(f' {cn:15s}: F1={per_f1[ci]:.3f}')
print(f'{"="*65}')
print(f'\n {SAVE_DIR}/')
print(f' -- best_model.pth')
print(f' -- dashboard.png')
print(f' -- lr_schedule.png')
print(f' -- metrics.csv')
print(f' -- history.json')