#!/usr/bin/env python3 """ RetinaSense v3.0 — Phase 4A: Knowledge Distillation + ONNX Export ================================================================== Distills ViT-Base (86M, 331MB) → ViT-Tiny (5.7M, ~23MB) Then exports to ONNX and quantizes to INT8 (~6MB) Target: Student retains >95% of teacher performance Expected: 55x size reduction, ~4x inference speedup on CPU """ import os, json, time, warnings, numpy as np, pandas as pd warnings.filterwarnings('ignore') import torch import torch.nn as nn import torch.nn.functional as F from torch.amp import autocast, GradScaler from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler from torchvision import transforms import timm from sklearn.metrics import f1_score, classification_report, roc_auc_score from sklearn.preprocessing import label_binarize OUTPUT_DIR = './outputs_v3/compressed' os.makedirs(OUTPUT_DIR, exist_ok=True) DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') IMG_SIZE = 224 N_CLASSES = 5 CLASS_NAMES= ['Normal', 'Diabetes/DR', 'Glaucoma', 'Cataract', 'AMD'] BATCH_SIZE = 64 N_EPOCHS = 50 PATIENCE = 10 KD_ALPHA = 0.3 # weight for hard label CE loss KD_TEMP = 4.0 # temperature for soft targets BASE_LR = 3e-4 NUM_WORKERS= 4 with open('./data/fundus_norm_stats.json') as f: ns = json.load(f) NORM_MEAN, NORM_STD = ns['mean_rgb'], ns['std_rgb'] with open('./outputs_v3/temperature.json') as f: T_OPT = json.load(f)['temperature'] print('=' * 65) print(' RetinaSense v3.0 — Knowledge Distillation') print('=' * 65) print(f' Teacher : ViT-Base/16 (86M params, 331MB)') print(f' Student : ViT-Tiny/16 (5.7M params, ~23MB)') print(f' KD temp : {KD_TEMP} alpha: {KD_ALPHA}') print(f' Device : {DEVICE}') # ================================================================ # TEACHER MODEL # ================================================================ class MultiTaskViT(nn.Module): def __init__(self): super().__init__() self.backbone = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=0) self.drop = nn.Dropout(0.3) self.disease_head = nn.Sequential( nn.Linear(768, 512), nn.BatchNorm1d(512), nn.ReLU(), nn.Dropout(0.3), nn.Linear(512, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(0.2), nn.Linear(256, N_CLASSES)) self.severity_head = nn.Sequential( nn.Linear(768, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(0.3), nn.Linear(256, 5)) def forward(self, x): f = self.backbone(x); f = self.drop(f) return self.disease_head(f), self.severity_head(f) teacher = MultiTaskViT().to(DEVICE) ckpt = torch.load('./outputs_v3/best_model.pth', map_location=DEVICE, weights_only=False) teacher.load_state_dict(ckpt['model_state_dict']) teacher.eval() for p in teacher.parameters(): p.requires_grad_(False) print(f'\n Teacher loaded (epoch {ckpt["epoch"]+1}, F1={ckpt["macro_f1"]:.4f})') # ================================================================ # STUDENT MODEL (ViT-Tiny) # ================================================================ class StudentViT(nn.Module): def __init__(self): super().__init__() self.backbone = timm.create_model('vit_tiny_patch16_224', pretrained=True, num_classes=0) feat_dim = 192 # ViT-Tiny hidden dim self.drop = nn.Dropout(0.2) self.disease_head = nn.Sequential( nn.Linear(feat_dim, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(0.2), nn.Linear(256, N_CLASSES)) def forward(self, x): f = self.backbone(x) f = self.drop(f) return self.disease_head(f) student = StudentViT().to(DEVICE) student_params = sum(p.numel() for p in student.parameters()) teacher_params = sum(p.numel() for p in teacher.parameters()) print(f' Teacher params: {teacher_params:,}') print(f' Student params: {student_params:,} ({100*student_params/teacher_params:.1f}% of teacher)') # ================================================================ # DATA # ================================================================ def _cache_key(path): stem = os.path.splitext(os.path.basename(path))[0] return os.path.join('./preprocessed_cache_v3', f'{stem}_{IMG_SIZE}.npy') def make_transforms(phase): norm = transforms.Normalize(NORM_MEAN, NORM_STD) if phase == 'train': return transforms.Compose([ transforms.ToPILImage(), transforms.RandomHorizontalFlip(), transforms.RandomVerticalFlip(p=0.3), transforms.RandomRotation(15), transforms.ColorJitter(brightness=0.2, contrast=0.2), transforms.ToTensor(), norm]) return transforms.Compose([transforms.ToPILImage(), transforms.ToTensor(), norm]) class RetinalDataset(Dataset): def __init__(self, df, transform): self.df = df.reset_index(drop=True) self.transform = transform def __len__(self): return len(self.df) def __getitem__(self, idx): row = self.df.iloc[idx] cache_fp = row.get('cache_path', _cache_key(row['image_path'])) try: img = np.load(cache_fp) except Exception: img = np.zeros((IMG_SIZE, IMG_SIZE, 3), dtype=np.uint8) return (self.transform(img), torch.tensor(int(row['disease_label']), dtype=torch.long)) train_df = pd.read_csv('./data/train_split.csv') calib_df = pd.read_csv('./data/calib_split.csv') test_df = pd.read_csv('./data/test_split.csv') # Use train+calib for distillation kd_df = pd.concat([train_df, calib_df], ignore_index=True) labels = kd_df['disease_label'].values cnts = np.bincount(labels, minlength=N_CLASSES).astype(float) cnts = np.where(cnts == 0, 1.0, cnts) w = 1.0 / cnts[labels] sampler = WeightedRandomSampler(torch.DoubleTensor(w), len(w), replacement=True) train_ds = RetinalDataset(kd_df, make_transforms('train')) test_ds = RetinalDataset(test_df, make_transforms('val')) train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, sampler=sampler, num_workers=NUM_WORKERS, pin_memory=True, persistent_workers=True, prefetch_factor=2) test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True, persistent_workers=True) print(f'\n KD train: {len(kd_df)} | Test: {len(test_df)}') # ================================================================ # KNOWLEDGE DISTILLATION TRAINING # ================================================================ def kd_loss(student_logits, teacher_logits, true_labels, alpha=KD_ALPHA, T=KD_TEMP): """ KD loss = alpha * CE(student, hard_labels) + (1-alpha) * KL(student/T, teacher/T) The KL term encourages student to match teacher's soft probability distribution. """ ce_hard = F.cross_entropy(student_logits, true_labels) student_soft = F.log_softmax(student_logits / T, dim=1) teacher_soft = F.softmax(teacher_logits / T, dim=1) kl = F.kl_div(student_soft, teacher_soft, reduction='batchmean') * (T ** 2) return alpha * ce_hard + (1 - alpha) * kl optimizer = torch.optim.AdamW(student.parameters(), lr=BASE_LR, weight_decay=1e-4) scheduler = torch.optim.lr_scheduler.OneCycleLR( optimizer, max_lr=BASE_LR, steps_per_epoch=len(train_loader), epochs=N_EPOCHS, pct_start=0.1, anneal_strategy='cos', div_factor=10.0, final_div_factor=100.0) scaler = GradScaler() best_f1, best_state, patience_ctr = 0.0, None, 0 print(f'\n Training student for {N_EPOCHS} epochs...\n') for epoch in range(N_EPOCHS): student.train() run_loss = correct = total = 0 for imgs, labels in train_loader: imgs = imgs.to(DEVICE, non_blocking=True) labels = labels.to(DEVICE, non_blocking=True) with torch.no_grad(): with autocast('cuda'): t_logits, _ = teacher(imgs) with autocast('cuda'): s_logits = student(imgs) loss = kd_loss(s_logits, t_logits, labels) scaler.scale(loss).backward() scaler.unscale_(optimizer) torch.nn.utils.clip_grad_norm_(student.parameters(), 1.0) scaler.step(optimizer); scaler.update() scheduler.step(); optimizer.zero_grad(set_to_none=True) run_loss += loss.item() correct += (s_logits.argmax(1) == labels).sum().item() total += labels.size(0) # Evaluate student.eval() all_preds, all_labels, all_probs = [], [], [] with torch.no_grad(): for imgs, lbl in test_loader: imgs = imgs.to(DEVICE) with autocast('cuda'): logits = student(imgs) probs = torch.softmax(logits.float() / T_OPT, dim=1) all_preds.extend(logits.argmax(1).cpu().numpy()) all_labels.extend(lbl.numpy()) all_probs.extend(probs.cpu().numpy()) preds = np.array(all_preds) targets = np.array(all_labels) mf1 = f1_score(targets, preds, average='macro') acc = 100 * (preds == targets).mean() tag = '' if mf1 > best_f1 + 0.001: best_f1 = mf1; patience_ctr = 0 best_state = {k: v.cpu().clone() for k, v in student.state_dict().items()} tag = f' * BEST' else: patience_ctr += 1 if (epoch + 1) % 5 == 0 or tag: print(f' E{epoch+1:02d} | Loss {run_loss/len(train_loader):.3f} | ' f'Acc {acc:.1f}% | mF1 {mf1:.4f}{tag}') if patience_ctr >= PATIENCE: print(f' Early stop at epoch {epoch+1}') break # Save best student student_path = os.path.join(OUTPUT_DIR, 'student_vit_tiny.pth') torch.save({'model_state_dict': best_state, 'macro_f1': best_f1, 'architecture': 'vit_tiny_patch16_224'}, student_path) print(f'\n Student saved -> {student_path} ({os.path.getsize(student_path)/1e6:.1f} MB)') # ================================================================ # COMPARE TEACHER VS STUDENT # ================================================================ print('\n Comparing Teacher vs Student on test set...') student.load_state_dict(best_state) student.eval() def get_preds(mdl, loader, is_teacher=False): all_p, all_l, all_pr = [], [], [] with torch.no_grad(): for imgs, lbl in loader: imgs = imgs.to(DEVICE) with autocast('cuda'): out = mdl(imgs) logits = out[0] if is_teacher else out probs = torch.softmax(logits.float() / T_OPT, dim=1) all_p.extend(logits.argmax(1).cpu().numpy()) all_l.extend(lbl.numpy()) all_pr.extend(probs.cpu().numpy()) return np.array(all_p), np.array(all_l), np.array(all_pr) t_preds, t_labels, t_probs = get_preds(teacher, test_loader, is_teacher=True) s_preds, s_labels, s_probs = get_preds(student, test_loader, is_teacher=False) def metrics(preds, labels, probs, name): acc = 100 * (preds == labels).mean() mf1 = f1_score(labels, preds, average='macro') try: mauc = roc_auc_score(label_binarize(labels, classes=range(N_CLASSES)), probs, multi_class='ovr', average='macro') except: mauc = 0.0 print(f'\n [{name}]') print(f' Accuracy: {acc:.2f}% | Macro F1: {mf1:.4f} | AUC: {mauc:.4f}') print(classification_report(labels, preds, target_names=CLASS_NAMES, digits=3)) return {'accuracy': acc, 'macro_f1': mf1, 'macro_auc': mauc} m_teacher = metrics(t_preds, t_labels, t_probs, 'Teacher (ViT-Base)') m_student = metrics(s_preds, s_labels, s_probs, 'Student (ViT-Tiny)') retention = m_student['macro_f1'] / m_teacher['macro_f1'] * 100 print(f'\n Performance retention: {retention:.1f}%') print(f' Size reduction: {os.path.getsize("./outputs_v3/best_model.pth")/1e6:.0f}MB → {os.path.getsize(student_path)/1e6:.0f}MB') # ================================================================ # ONNX EXPORT # ================================================================ print('\n Exporting student to ONNX...') student.eval() dummy = torch.randn(1, 3, IMG_SIZE, IMG_SIZE).to(DEVICE) onnx_path = os.path.join(OUTPUT_DIR, 'retinasense_student.onnx') torch.onnx.export( student, dummy, onnx_path, export_params=True, opset_version=17, do_constant_folding=True, input_names=['fundus_image'], output_names=['disease_logits'], dynamic_axes={ 'fundus_image': {0: 'batch_size'}, 'disease_logits': {0: 'batch_size'}, } ) print(f' ONNX saved -> {onnx_path} ({os.path.getsize(onnx_path)/1e6:.1f} MB)') # Verify ONNX try: import onnx model_onnx = onnx.load(onnx_path) onnx.checker.check_model(model_onnx) print(' ONNX verification: PASSED') except Exception as e: print(f' ONNX verification warning: {e}') # INT8 quantization try: from onnxruntime.quantization import quantize_dynamic, QuantType int8_path = os.path.join(OUTPUT_DIR, 'retinasense_student_int8.onnx') quantize_dynamic(onnx_path, int8_path, weight_type=QuantType.QInt8) print(f' INT8 model -> {int8_path} ({os.path.getsize(int8_path)/1e6:.1f} MB)') except Exception as e: print(f' INT8 quantization error: {e}') # ONNX runtime benchmark try: import onnxruntime as ort sess = ort.InferenceSession(onnx_path, providers=['CPUExecutionProvider']) dummy_np = np.random.randn(1, 3, IMG_SIZE, IMG_SIZE).astype(np.float32) # Warmup for _ in range(5): sess.run(None, {'fundus_image': dummy_np}) # Benchmark n = 50 t0 = time.time() for _ in range(n): sess.run(None, {'fundus_image': dummy_np}) cpu_ms = (time.time() - t0) / n * 1000 print(f'\n ONNX CPU inference: {cpu_ms:.1f} ms/image') except Exception as e: print(f' ONNX benchmark error: {e}') # Save summary results = { 'teacher': {'architecture': 'vit_base_patch16_224', 'params': teacher_params, 'size_mb': os.path.getsize('./outputs_v3/best_model.pth')/1e6, **m_teacher}, 'student': {'architecture': 'vit_tiny_patch16_224', 'params': student_params, 'size_mb': os.path.getsize(student_path)/1e6, **m_student}, 'onnx_size_mb': os.path.getsize(onnx_path)/1e6, 'performance_retention_pct': retention, 'size_reduction_factor': os.path.getsize('./outputs_v3/best_model.pth') / os.path.getsize(student_path), } with open(os.path.join(OUTPUT_DIR, 'compression_results.json'), 'w') as f: json.dump(results, f, indent=2) print('\n' + '=' * 65) print(' DISTILLATION + COMPRESSION COMPLETE') print(f' Teacher F1 : {m_teacher["macro_f1"]:.4f}') print(f' Student F1 : {m_student["macro_f1"]:.4f} ({retention:.1f}% retention)') print(f' Size: {results["teacher"]["size_mb"]:.0f}MB → {results["student"]["size_mb"]:.0f}MB → {results["onnx_size_mb"]:.0f}MB (ONNX)') print('=' * 65)