tanishq74 committed on
Commit
f78cecd
·
verified ·
1 Parent(s): 649e096

Add threshold_optimization.py

Browse files
Files changed (1) hide show
  1. threshold_optimization.py +523 -0
threshold_optimization.py ADDED
@@ -0,0 +1,523 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Threshold Optimization for RetinaSense v2
4
+ ==========================================
5
+
6
+ Optimizes classification thresholds per class to maximize F1 scores.
7
+ Current model has AUC=0.91 but uses fixed argmax decision.
8
+ With class imbalance, per-class thresholds can significantly improve performance.
9
+ """
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torchvision.models as models
14
+ from torch.utils.data import Dataset, DataLoader
15
+ import numpy as np
16
+ import pandas as pd
17
+ from pathlib import Path
18
+ import json
19
+ from sklearn.metrics import f1_score, precision_score, recall_score, confusion_matrix, roc_auc_score
20
+ import matplotlib.pyplot as plt
21
+ import seaborn as sns
22
+ from tqdm import tqdm
23
+ import warnings
24
+ warnings.filterwarnings('ignore')
25
+
26
# Device: use CUDA when PyTorch reports it as available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🔧 Using device: {device}")

# Paths (all relative to the current working directory)
DATA_DIR = Path('./data')
CACHE_DIR = Path('./preprocessed_cache')
MODEL_PATH = Path('./outputs_v2/best_model.pth')
OUTPUT_DIR = Path('./outputs_v2')
# parents=True keeps this working if OUTPUT_DIR is ever pointed at a nested path.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Config
BATCH_SIZE = 64
NUM_WORKERS = 8
IMG_SIZE = 300  # NOTE(review): not referenced by the code visible in this file

# Class names; list position is the integer disease label used by the model.
DISEASE_CLASSES = ['Normal', 'Diabetes/DR', 'Glaucoma', 'Cataract', 'AMD']
44
+
45
+
46
class CachedDataset(Dataset):
    """Dataset that loads pre-cached preprocessed images.

    The CSV at ``csv_path`` must provide ``image_id`` and ``disease_label``
    columns; ``severity_label`` is optional. Each image is read from
    ``<cache_dir>/<image_id>.npy``.

    The split is positional and unshuffled: the first 15% of CSV rows form
    the validation set and the remaining rows form the training set. Any
    ``mode`` other than ``'train'`` selects the validation rows.

    Args:
        csv_path: Path to the CSV describing the samples.
        cache_dir: Directory holding the cached ``.npy`` arrays.
        mode: ``'train'`` for the training rows, anything else for validation.
    """

    def __init__(self, csv_path, cache_dir, mode='train'):
        self.cache_dir = Path(cache_dir)
        self.mode = mode

        # Load CSV
        df = pd.read_csv(csv_path)

        # Split train/val by row position (no shuffling).
        val_size = int(0.15 * len(df))
        if mode == 'train':
            self.df = df.iloc[val_size:].reset_index(drop=True)
        else:
            self.df = df.iloc[:val_size].reset_index(drop=True)

        print(f"📊 {mode.upper()} set: {len(self.df)} samples")

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, disease_label, severity_label, image_id)``.

        ``image`` is the cached array converted to a float32 tensor.
        ``severity_label`` is 0 when the column is absent or the cell is empty.
        """
        # Read each field from its own column: a whole-row Series built from
        # mixed numeric columns is upcast to float, which would turn a
        # numeric image_id such as 12 into "12.0" in the file name below.
        img_id = self.df['image_id'].iloc[idx]

        # Load cached image and convert to tensor
        img = np.load(self.cache_dir / f"{img_id}.npy")
        img = torch.from_numpy(img).float()

        # Labels
        disease = int(self.df['disease_label'].iloc[idx])
        severity = 0
        if 'severity_label' in self.df.columns:
            raw_severity = self.df['severity_label'].iloc[idx]
            # An empty CSV cell is read as NaN, which int() cannot convert.
            if not pd.isna(raw_severity):
                severity = int(raw_severity)

        return img, disease, severity, img_id
83
+
84
+
85
class MultiTaskModel(nn.Module):
    """Multi-task model for disease classification + severity grading.

    A shared EfficientNet-B3 feature extractor feeds two MLP heads: one
    producing disease logits, the other severity logits.

    Args:
        num_disease_classes: Output width of the disease head.
        num_severity_classes: Output width of the severity head.
        dropout: Dropout probability applied to the pooled features.
        pretrained_backbone: When True (default, the original behaviour) the
            backbone is created with ``IMAGENET1K_V1`` weights. Pass False to
            skip fetching them when a checkpoint will be loaded right after.
    """

    def __init__(self, num_disease_classes=5, num_severity_classes=5, dropout=0.4,
                 pretrained_backbone=True):
        super().__init__()

        # Load EfficientNet-B3 backbone
        weights = 'IMAGENET1K_V1' if pretrained_backbone else None
        backbone = models.efficientnet_b3(weights=weights)
        # Keep every child module except the last one.
        # NOTE(review): in torchvision's EfficientNet the children are
        # presumably (features, avgpool, classifier), so this keeps avgpool
        # and self.pool below is a harmless second pooling step — confirm.
        self.backbone = nn.Sequential(*list(backbone.children())[:-1])

        # Feature dimension fed to both heads.
        self.feature_dim = 1536

        # Global pooling and dropout
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.dropout = nn.Dropout(dropout)

        # Disease classification head
        self.disease_head = nn.Sequential(
            nn.Linear(self.feature_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, num_disease_classes),
        )

        # Severity grading head (simpler than disease head)
        self.severity_head = nn.Sequential(
            nn.Linear(self.feature_dim, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_severity_classes),
        )

    def forward(self, x):
        """Return ``(disease_logits, severity_logits)`` for an image batch ``x``."""
        # Extract features
        features = self.backbone(x)
        features = self.pool(features)
        features = features.flatten(1)
        features = self.dropout(features)

        # Predictions
        disease_logits = self.disease_head(features)
        severity_logits = self.severity_head(features)

        return disease_logits, severity_logits
135
+
136
+
137
def load_model():
    """Load the trained model from ``MODEL_PATH`` and put it in eval mode.

    The checkpoint is expected to be a dict with a ``'model_state_dict'``
    entry plus optional ``'epoch'``, ``'val_acc'`` and ``'val_macro_f1'`` /
    ``'val_f1'`` metadata. A bare state dict is also accepted.

    Returns:
        The ``MultiTaskModel`` moved to ``device`` with ``eval()`` applied.
    """
    print(f"📥 Loading model from {MODEL_PATH}")

    model = MultiTaskModel(num_disease_classes=5, num_severity_classes=5, dropout=0.4)

    # SECURITY: weights_only=False allows arbitrary pickled objects to be
    # deserialised. Only load checkpoints from a trusted source.
    checkpoint = torch.load(MODEL_PATH, map_location=device, weights_only=False)

    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
        metadata = checkpoint
    else:
        # Bare state dict: no training metadata to report.
        state_dict = checkpoint
        metadata = {}

    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()

    epoch = metadata.get('epoch', 'unknown')
    val_acc = metadata.get('val_acc', 0)
    val_f1 = metadata.get('val_macro_f1', metadata.get('val_f1', 0))

    print(f"✅ Loaded model from epoch {epoch}")
    if val_acc > 0:
        print(f" Val Acc: {val_acc:.2f}%, Macro F1: {val_f1:.3f}")

    return model
156
+
157
+
158
def get_predictions(model, dataloader):
    """Run the model over ``dataloader`` and collect disease probabilities.

    Args:
        model: Callable returning ``(disease_logits, severity_logits)``.
        dataloader: Iterable yielding ``(imgs, diseases, severities, img_ids)``.

    Returns:
        ``(all_probs, all_labels, all_ids)``: softmax probabilities stacked
        into a 2-D array, the concatenated disease labels, and the list of
        image ids in iteration order.

    Raises:
        ValueError: If the dataloader yields no batches.
    """
    print("🔮 Getting predictions on validation set...")

    all_probs = []
    all_labels = []
    all_ids = []

    # Dropout and BatchNorm must be in inference mode for stable outputs,
    # regardless of the state the caller left the model in.
    model.eval()

    with torch.no_grad():
        for imgs, diseases, _severities, img_ids in tqdm(dataloader, desc="Predicting"):
            imgs = imgs.to(device, non_blocking=True)

            # Get predictions (severity logits are not needed here)
            disease_logits, _ = model(imgs)
            probs = torch.softmax(disease_logits, dim=1)

            all_probs.append(probs.cpu().numpy())
            all_labels.append(diseases.numpy())

            # Numeric ids arrive as a tensor after collation; store plain
            # Python values rather than 0-d tensors.
            if torch.is_tensor(img_ids):
                img_ids = img_ids.tolist()
            all_ids.extend(img_ids)

    if not all_probs:
        raise ValueError("dataloader yielded no batches; nothing to predict")

    all_probs = np.vstack(all_probs)
    all_labels = np.concatenate(all_labels)

    print(f"✅ Got predictions for {len(all_labels)} samples")
    print(f" Probability shape: {all_probs.shape}")

    return all_probs, all_labels, all_ids
185
+
186
+
187
def find_optimal_threshold_ovr(y_true, y_probs, class_idx):
    """Find the one-vs-rest threshold that maximises F1 for one class.

    Thresholds 0.10, 0.11, ..., 0.90 are evaluated; a sample is predicted
    positive when its probability is ``>=`` the threshold. The lowest
    threshold reaching the maximum F1 wins.

    Args:
        y_true: Ground truth labels (n_samples,)
        y_probs: Predicted probabilities for this class (n_samples,)
        class_idx: Index of the class

    Returns:
        best_threshold, best_f1 as Python floats. When no threshold yields
        a positive F1 the result is ``(0.5, 0.0)``.
    """
    y_true = np.asarray(y_true)
    y_probs = np.asarray(y_probs, dtype=float)

    # Convert to binary (one-vs-rest)
    is_positive = (y_true == class_idx)[:, None]

    # linspace + round gives exact two-decimal grid points; arange with a
    # float step accumulates error and can change the number of points.
    thresholds = np.round(np.linspace(0.1, 0.9, 81), 2)

    # One column of predictions per threshold: shape (n_samples, n_thresholds).
    predicted = y_probs[:, None] >= thresholds[None, :]

    tp = (predicted & is_positive).sum(axis=0)
    fp = (predicted & ~is_positive).sum(axis=0)
    fn = (~predicted & is_positive).sum(axis=0)

    # F1 = 2TP / (2TP + FP + FN), defined as 0 when the denominator is 0.
    denom = 2 * tp + fp + fn
    f1 = np.divide(2 * tp, denom, out=np.zeros(len(thresholds)), where=denom > 0)

    best = int(np.argmax(f1))  # argmax returns the first (lowest) maximiser
    if f1[best] <= 0:
        return 0.5, 0.0

    return float(thresholds[best]), float(f1[best])
220
+
221
+
222
def optimize_thresholds(y_true, y_probs):
    """Optimize thresholds for all classes using a one-vs-rest approach.

    Args:
        y_true: Ground truth labels (n_samples,)
        y_probs: Predicted probabilities (n_samples, n_classes)

    Returns:
        optimal_thresholds: dict mapping class_idx -> threshold (float)
    """
    print("🎯 Optimizing thresholds per class...")

    optimal_thresholds = {}

    # Follow the probability matrix instead of assuming exactly 5 classes.
    for class_idx in range(y_probs.shape[1]):
        if class_idx < len(DISEASE_CLASSES):
            class_name = DISEASE_CLASSES[class_idx]
        else:
            class_name = f"class_{class_idx}"
        class_probs = y_probs[:, class_idx]

        # Find optimal threshold
        best_thresh, best_f1 = find_optimal_threshold_ovr(y_true, class_probs, class_idx)

        # Store a plain float so the mapping serialises to JSON cleanly.
        optimal_thresholds[class_idx] = float(best_thresh)

        # Count samples of this class
        n_samples = int((y_true == class_idx).sum())

        print(f" {class_name:15s}: threshold={best_thresh:.3f}, F1={best_f1:.3f}, n={n_samples}")

    return optimal_thresholds
248
+
249
+
250
def predict_with_thresholds(y_probs, thresholds):
    """Make predictions using per-class thresholds.

    Strategy: for each sample, predict the most probable class among those
    whose probability reaches their own threshold. If no class reaches its
    threshold, fall back to the plain argmax class. Ties go to the lowest
    class index, matching ``np.argmax``.

    Args:
        y_probs: Array of shape (n_samples, n_classes).
        thresholds: Mapping or sequence giving the threshold for each class
            index ``0 .. n_classes - 1``.

    Returns:
        Integer array of predicted class indices, shape (n_samples,).
    """
    y_probs = np.asarray(y_probs)
    n_classes = y_probs.shape[1]

    cutoffs = np.array([thresholds[c] for c in range(n_classes)], dtype=float)

    # eligible[i, c] is True when class c clears its threshold for sample i.
    eligible = y_probs >= cutoffs

    # Hide ineligible classes so argmax only sees classes that cleared.
    masked = np.where(eligible, y_probs, -np.inf)

    predictions = np.where(
        eligible.any(axis=1),
        masked.argmax(axis=1),
        y_probs.argmax(axis=1),  # nothing cleared: fall back to max probability
    )

    return predictions.astype(int)
285
+
286
+
287
def evaluate(y_true, y_pred, y_probs, title="Evaluation"):
    """Print and return overall and per-class classification metrics.

    Args:
        y_true: Ground truth labels (n_samples,)
        y_pred: Predicted labels (n_samples,)
        y_probs: Predicted probabilities (n_samples, n_classes), AUC only.
        title: Heading printed above the report.

    Returns:
        Dict with ``accuracy`` (percent), ``macro_f1``, ``weighted_f1``,
        ``auc`` (0.0 when it cannot be computed), ``per_class`` (keyed by
        class name) and ``confusion_matrix`` (nested lists).
    """
    print(f"\n{'='*50}")
    print(f"{title}")
    print(f"{'='*50}")

    # Overall metrics
    accuracy = float((y_true == y_pred).mean() * 100)
    macro_f1 = float(f1_score(y_true, y_pred, average='macro', zero_division=0))
    weighted_f1 = float(f1_score(y_true, y_pred, average='weighted', zero_division=0))

    print(f"Accuracy: {accuracy:.2f}%")
    print(f"Macro F1: {macro_f1:.3f}")
    print(f"Weighted F1: {weighted_f1:.3f}")

    # AUC-ROC. Only ValueError is treated as "not computable" so that
    # unrelated failures and KeyboardInterrupt are not swallowed.
    try:
        auc = float(roc_auc_score(y_true, y_probs, multi_class='ovr', average='macro'))
        print(f"Macro AUC-ROC: {auc:.3f}")
    except ValueError as exc:
        auc = 0.0
        print(f"AUC-ROC: N/A ({exc})")

    # Per-class metrics
    print(f"\n{'Class':<15} {'F1':>6} {'Prec':>6} {'Rec':>6} {'Supp':>6}")
    print("-" * 50)

    # Pin the label set: without it the per-class arrays only cover classes
    # that occur in y_true or y_pred, so index i would stop matching
    # DISEASE_CLASSES[i] whenever a class is absent from both.
    labels = list(range(len(DISEASE_CLASSES)))
    f1_scores = f1_score(y_true, y_pred, labels=labels, average=None, zero_division=0)
    precisions = precision_score(y_true, y_pred, labels=labels, average=None, zero_division=0)
    recalls = recall_score(y_true, y_pred, labels=labels, average=None, zero_division=0)

    per_class_results = {}

    for i, class_name in enumerate(DISEASE_CLASSES):
        support = int((y_true == i).sum())
        per_class_results[class_name] = {
            'f1': float(f1_scores[i]),
            'precision': float(precisions[i]),
            'recall': float(recalls[i]),
            'support': support
        }
        print(f"{class_name:<15} {f1_scores[i]:>6.3f} {precisions[i]:>6.3f} {recalls[i]:>6.3f} {support:>6d}")

    return {
        'accuracy': accuracy,
        'macro_f1': macro_f1,
        'weighted_f1': weighted_f1,
        'auc': auc,
        'per_class': per_class_results,
        'confusion_matrix': confusion_matrix(y_true, y_pred, labels=labels).tolist()
    }
338
+
339
+
340
def plot_comparison(results_baseline, results_optimized, optimal_thresholds, output_path):
    """Plot a 2x2 before/after comparison and save it to ``output_path``.

    Panels: per-class F1 (baseline vs optimized), overall metrics, the
    chosen threshold per class, and the per-class F1 change.

    Args:
        results_baseline: Result dict from ``evaluate`` for argmax predictions.
        results_optimized: Result dict from ``evaluate`` for thresholded predictions.
        optimal_thresholds: Mapping class index -> threshold.
        output_path: Destination image path.
    """
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    classes = DISEASE_CLASSES
    width = 0.35

    # F1 scores comparison
    ax = axes[0, 0]
    baseline_f1 = [results_baseline['per_class'][c]['f1'] for c in classes]
    optimized_f1 = [results_optimized['per_class'][c]['f1'] for c in classes]

    x = np.arange(len(classes))

    ax.bar(x - width/2, baseline_f1, width, label='Baseline (argmax)', alpha=0.8)
    ax.bar(x + width/2, optimized_f1, width, label='Optimized thresholds', alpha=0.8)

    ax.set_ylabel('F1 Score')
    ax.set_title('Per-Class F1 Score Comparison')
    ax.set_xticks(x)
    ax.set_xticklabels(classes, rotation=45, ha='right')
    ax.legend()
    ax.grid(axis='y', alpha=0.3)

    # Overall metrics comparison
    ax = axes[0, 1]
    metrics = ['Accuracy', 'Macro F1', 'Weighted F1', 'AUC-ROC']
    # Accuracy is stored as a percentage; rescale to share the 0-1 axis.
    baseline_vals = [
        results_baseline['accuracy']/100,
        results_baseline['macro_f1'],
        results_baseline['weighted_f1'],
        results_baseline['auc']
    ]
    optimized_vals = [
        results_optimized['accuracy']/100,
        results_optimized['macro_f1'],
        results_optimized['weighted_f1'],
        results_optimized['auc']
    ]

    x = np.arange(len(metrics))
    ax.bar(x - width/2, baseline_vals, width, label='Baseline', alpha=0.8)
    ax.bar(x + width/2, optimized_vals, width, label='Optimized', alpha=0.8)

    ax.set_ylabel('Score')
    ax.set_title('Overall Metrics Comparison')
    ax.set_xticks(x)
    ax.set_xticklabels(metrics, rotation=45, ha='right')
    ax.legend()
    ax.set_ylim([0, 1])
    ax.grid(axis='y', alpha=0.3)

    # Optimal thresholds
    ax = axes[1, 0]
    thresholds_list = [optimal_thresholds[i] for i in range(len(classes))]
    x = np.arange(len(classes))
    bars = ax.bar(x, thresholds_list, alpha=0.8, color='steelblue')

    # Add default threshold line
    ax.axhline(y=0.5, color='red', linestyle='--', label='Default (0.5)', alpha=0.5)

    ax.set_ylabel('Optimal Threshold')
    ax.set_title('Optimized Thresholds per Class')
    # Fix the tick positions before assigning labels to them.
    ax.set_xticks(x)
    ax.set_xticklabels(classes, rotation=45, ha='right')
    ax.legend()
    ax.set_ylim([0, 1])
    ax.grid(axis='y', alpha=0.3)

    # Add threshold values on bars
    for bar, thresh in zip(bars, thresholds_list):
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height,
                f'{thresh:.2f}',
                ha='center', va='bottom', fontsize=9)

    # Per-class F1 change (horizontal bar chart)
    ax = axes[1, 1]
    improvements = [
        results_optimized['per_class'][c]['f1'] - results_baseline['per_class'][c]['f1']
        for c in classes
    ]

    colors = ['red' if delta < 0 else 'green' for delta in improvements]
    bars = ax.barh(classes, improvements, color=colors, alpha=0.7)

    ax.axvline(x=0, color='black', linestyle='-', linewidth=0.8)
    ax.set_xlabel('F1 Score Change')
    ax.set_title('Per-Class F1 Improvement')
    ax.grid(axis='x', alpha=0.3)

    # Add values; zero is treated as non-negative, like the bar colour.
    for i, val in enumerate(improvements):
        x_pos = val + (0.01 if val >= 0 else -0.01)
        ha = 'left' if val >= 0 else 'right'
        ax.text(x_pos, i, f'{val:+.3f}', va='center', ha=ha, fontsize=9)

    fig.tight_layout()
    fig.savefig(output_path, dpi=150, bbox_inches='tight')
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    print(f"📊 Comparison plot saved to {output_path}")
439
+
440
+
441
def main():
    """Run the full threshold-optimization pipeline.

    Loads the model and validation split, scores argmax predictions,
    searches per-class thresholds, scores thresholded predictions, then
    writes a JSON report and a comparison plot into ``OUTPUT_DIR``.
    """
    print("🎯 Threshold Optimization for RetinaSense v2")
    print("=" * 50)

    # Load model
    model = load_model()

    # Load validation data
    val_dataset = CachedDataset(
        csv_path=DATA_DIR / 'train_processed.csv',
        cache_dir=CACHE_DIR,
        mode='val'
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        # Pinned memory only helps host-to-GPU copies.
        pin_memory=(device.type == 'cuda'),
        # DataLoader rejects persistent_workers=True when num_workers is 0.
        persistent_workers=(NUM_WORKERS > 0)
    )

    # Get predictions (image ids are returned but not needed here)
    y_probs, y_true, _img_ids = get_predictions(model, val_loader)

    # Baseline: argmax predictions
    y_pred_baseline = np.argmax(y_probs, axis=1)

    # Evaluate baseline
    print("\n" + "="*50)
    print("BASELINE EVALUATION (argmax)")
    print("="*50)
    results_baseline = evaluate(y_true, y_pred_baseline, y_probs, "Baseline")

    # Optimize thresholds
    # NOTE(review): thresholds are tuned and then scored on the same
    # validation split, so the "optimized" numbers are optimistically
    # biased; confirm on a held-out set before relying on the gain.
    print("\n" + "="*50)
    print("THRESHOLD OPTIMIZATION")
    print("="*50)
    optimal_thresholds = optimize_thresholds(y_true, y_probs)

    # Predict with optimized thresholds
    y_pred_optimized = predict_with_thresholds(y_probs, optimal_thresholds)

    # Evaluate optimized (AUC depends only on y_probs, so it equals baseline)
    print("\n" + "="*50)
    print("OPTIMIZED EVALUATION")
    print("="*50)
    results_optimized = evaluate(y_true, y_pred_optimized, y_probs, "Optimized")

    # Save results; coerce keys/values to built-in types for json.
    results = {
        'optimal_thresholds': {int(k): float(v) for k, v in optimal_thresholds.items()},
        'baseline': results_baseline,
        'optimized': results_optimized
    }

    output_json = OUTPUT_DIR / 'threshold_optimization_results.json'
    with open(output_json, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)
    print(f"\n✅ Results saved to {output_json}")

    # Plot comparison
    plot_path = OUTPUT_DIR / 'threshold_comparison.png'
    plot_comparison(results_baseline, results_optimized, optimal_thresholds, plot_path)

    # Summary
    print("\n" + "="*50)
    print("SUMMARY")
    print("="*50)
    print(f"Baseline Macro F1: {results_baseline['macro_f1']:.3f}")
    print(f"Optimized Macro F1: {results_optimized['macro_f1']:.3f}")
    print(f"Improvement: {results_optimized['macro_f1'] - results_baseline['macro_f1']:+.3f}")
    print(f"\nBaseline Accuracy: {results_baseline['accuracy']:.2f}%")
    print(f"Optimized Accuracy: {results_optimized['accuracy']:.2f}%")
    print(f"Improvement: {results_optimized['accuracy'] - results_baseline['accuracy']:+.2f}%")

    print("\n✅ Threshold optimization complete!")
    print(f"📁 Results saved to {OUTPUT_DIR}/")
520
+
521
+
522
# Run the pipeline only when executed as a script, not when imported.
if __name__ == '__main__':
    main()