| |
| """ |
| RetinaSense v3.0 — Phase 4A: Knowledge Distillation + ONNX Export |
| ================================================================== |
| Distills ViT-Base (86M, 331MB) → ViT-Tiny (5.7M, ~23MB) |
| Then exports to ONNX and quantizes to INT8 (~6MB) |
| |
| Target: Student retains >95% of teacher performance |
| Expected: 55x size reduction, ~4x inference speedup on CPU |
| """ |
|
|
| import os, json, time, warnings, numpy as np, pandas as pd |
| warnings.filterwarnings('ignore') |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.amp import autocast, GradScaler |
| from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler |
| from torchvision import transforms |
| import timm |
| from sklearn.metrics import f1_score, classification_report, roc_auc_score |
| from sklearn.preprocessing import label_binarize |
|
|
# Directory receiving every artefact written by this script (student
# checkpoint, ONNX files, results JSON). Created up front.
OUTPUT_DIR = './outputs_v3/compressed'
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Run on the GPU when one is visible, otherwise on the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Side length used in cache file names, the blank fallback image and the
# ONNX dummy input.
IMG_SIZE = 224
# Width of the disease heads; CLASS_NAMES is passed as the per-class names
# of the classification report, so the two must agree.
N_CLASSES = 5
CLASS_NAMES = ['Normal', 'Diabetes/DR', 'Glaucoma', 'Cataract', 'AMD']
BATCH_SIZE = 64      # used for both the training and the test loader
N_EPOCHS = 50        # upper bound; early stopping may end training sooner
PATIENCE = 10        # epochs without a macro-F1 improvement before stopping
KD_ALPHA = 0.3       # weight of the hard-label CE term; the KL term gets 1 - KD_ALPHA
KD_TEMP = 4.0        # temperature dividing both logit sets inside the KD loss
BASE_LR = 3e-4       # AdamW learning rate and OneCycleLR peak
NUM_WORKERS = 4      # DataLoader worker processes

# Per-channel normalisation statistics consumed by make_transforms().
with open('./data/fundus_norm_stats.json', encoding='utf-8') as f:
    ns = json.load(f)
NORM_MEAN, NORM_STD = ns['mean_rgb'], ns['std_rgb']

# Temperature that divides the logits before softmax at evaluation time.
# NOTE(review): presumably fitted for the teacher on the calibration split;
# confirm it is also appropriate for the student.
with open('./outputs_v3/temperature.json', encoding='utf-8') as f:
    T_OPT = json.load(f)['temperature']

print('=' * 65)
print(' RetinaSense v3.0 — Knowledge Distillation')
print('=' * 65)
print(' Teacher : ViT-Base/16 (86M params, 331MB)')
print(' Student : ViT-Tiny/16 (5.7M params, ~23MB)')
print(f' KD temp : {KD_TEMP} alpha: {KD_ALPHA}')
print(f' Device : {DEVICE}')
|
|
|
|
| |
| |
| |
class MultiTaskViT(nn.Module):
    """Teacher network: a ViT-Base/16 backbone feeding two MLP heads.

    ``forward`` returns ``(disease_logits, severity_logits)``. Attribute
    names and the order of layers inside each ``nn.Sequential`` determine
    the ``state_dict`` keys, so they must stay as they are for the saved
    teacher checkpoint to load.
    """

    def __init__(self):
        super().__init__()
        # Randomly initialised here; the trained weights come from the checkpoint.
        self.backbone = timm.create_model(
            'vit_base_patch16_224',
            pretrained=False,
            num_classes=0,
        )
        self.drop = nn.Dropout(0.3)

        disease_layers = [
            nn.Linear(768, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, N_CLASSES),
        ]
        self.disease_head = nn.Sequential(*disease_layers)

        severity_layers = [
            nn.Linear(768, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 5),
        ]
        self.severity_head = nn.Sequential(*severity_layers)

    def forward(self, x):
        """Return ``(disease_logits, severity_logits)`` for an image batch."""
        features = self.drop(self.backbone(x))
        disease_logits = self.disease_head(features)
        severity_logits = self.severity_head(features)
        return disease_logits, severity_logits
|
|
# Rebuild the teacher architecture and restore its trained weights.
teacher = MultiTaskViT().to(DEVICE)
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python
# objects; only point this at a checkpoint from a trusted source.
ckpt = torch.load(
    './outputs_v3/best_model.pth',
    map_location=DEVICE,
    weights_only=False,
)
teacher.load_state_dict(ckpt['model_state_dict'])
# The teacher only supplies soft targets: eval mode, gradients switched off
# for every parameter.
teacher.eval()
teacher.requires_grad_(False)
print(f'\n Teacher loaded (epoch {ckpt["epoch"]+1}, F1={ckpt["macro_f1"]:.4f})')
|
|
|
|
| |
| |
| |
class StudentViT(nn.Module):
    """Student network: a ViT-Tiny/16 backbone with one disease head.

    Unlike the teacher it has no severity head; ``forward`` returns the
    disease logits only.
    """

    def __init__(self):
        super().__init__()
        # pretrained=True asks timm for pretrained backbone weights.
        self.backbone = timm.create_model(
            'vit_tiny_patch16_224',
            pretrained=True,
            num_classes=0,
        )
        embed_dim = 192  # input width of the first head layer
        self.drop = nn.Dropout(0.2)
        head_layers = [
            nn.Linear(embed_dim, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, N_CLASSES),
        ]
        self.disease_head = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return disease logits for an image batch."""
        features = self.backbone(x)
        return self.disease_head(self.drop(features))
|
|
# Build the student and report how small it is relative to the teacher.
student = StudentViT()
student = student.to(DEVICE)
teacher_params = sum(weight.numel() for weight in teacher.parameters())
student_params = sum(weight.numel() for weight in student.parameters())
print(f' Teacher params: {teacher_params:,}')
print(f' Student params: {student_params:,} ({100*student_params/teacher_params:.1f}% of teacher)')
|
|
|
|
| |
| |
| |
def _cache_key(path):
    """Return the cache location of the preprocessed array for ``path``.

    The file name is the image's base name without extension, suffixed with
    ``_<IMG_SIZE>.npy``, inside ``./preprocessed_cache_v3``.
    """
    filename = os.path.basename(path)
    stem, _ext = os.path.splitext(filename)
    return os.path.join('./preprocessed_cache_v3', f'{stem}_{IMG_SIZE}.npy')
|
|
def make_transforms(phase):
    """Build the torchvision pipeline for ``phase``.

    ``'train'`` inserts random flips, rotation and colour jitter between the
    PIL conversion and the tensor/normalise steps; any other value yields the
    plain convert-and-normalise pipeline.
    """
    to_pil = [transforms.ToPILImage()]
    finish = [transforms.ToTensor(), transforms.Normalize(NORM_MEAN, NORM_STD)]

    augment = []
    if phase == 'train':
        augment = [
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(p=0.3),
            transforms.RandomRotation(15),
            transforms.ColorJitter(brightness=0.2, contrast=0.2),
        ]

    return transforms.Compose(to_pil + augment + finish)
|
|
class RetinalDataset(Dataset):
    """Dataset of cached, preprocessed fundus images with disease labels.

    Args:
        df: frame with an ``image_path`` and a ``disease_label`` column and,
            optionally, a ``cache_path`` column pointing at the ``.npy`` file.
        transform: callable applied to the loaded array.

    Each item is ``(transform(image), label)`` with ``label`` a scalar
    ``torch.long`` tensor.
    """

    def __init__(self, df, transform):
        self.df = df.reset_index(drop=True)
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        # Use the recorded cache path only when it is a real path. A missing
        # column or an empty/NaN cell (e.g. after concatenating frames where
        # only one has the column) falls back to the derived cache key.
        cache_fp = row.get('cache_path')
        if not isinstance(cache_fp, (str, os.PathLike)) or not str(cache_fp):
            cache_fp = _cache_key(row['image_path'])

        try:
            img = np.load(cache_fp)
        except (OSError, ValueError, EOFError):
            # Best effort: a missing or unreadable cache file becomes a black
            # image so one bad file does not abort the run.
            # NOTE(review): the sample keeps its real label, so many missing
            # files would silently degrade training.
            img = np.zeros((IMG_SIZE, IMG_SIZE, 3), dtype=np.uint8)

        transformed = self.transform(img)
        label = torch.tensor(int(row['disease_label']), dtype=torch.long)
        return transformed, label
|
|
train_df = pd.read_csv('./data/train_split.csv')
calib_df = pd.read_csv('./data/calib_split.csv')
test_df = pd.read_csv('./data/test_split.csv')

# The student is distilled on the union of the train and calibration splits.
kd_df = pd.concat([train_df, calib_df], ignore_index=True)

# Inverse-frequency sample weights for the sampler. Classes with no samples
# get a count of 1 to avoid division by zero (no row uses that weight).
labels = kd_df['disease_label'].values
cnts = np.bincount(labels, minlength=N_CLASSES).astype(float)
cnts = np.where(cnts == 0, 1.0, cnts)
w = 1.0 / cnts[labels]
sampler = WeightedRandomSampler(torch.DoubleTensor(w), len(w), replacement=True)

train_ds = RetinalDataset(kd_df, make_transforms('train'))
test_ds = RetinalDataset(test_df, make_transforms('val'))

# persistent_workers / prefetch_factor are only legal with worker processes;
# DataLoader raises ValueError for them when num_workers == 0, so derive them
# from NUM_WORKERS. Pinned memory is only requested when a GPU is in use.
_use_workers = NUM_WORKERS > 0
_pin = DEVICE.type == 'cuda'
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, sampler=sampler,
                          num_workers=NUM_WORKERS, pin_memory=_pin,
                          persistent_workers=_use_workers,
                          prefetch_factor=2 if _use_workers else None)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False,
                         num_workers=NUM_WORKERS, pin_memory=_pin,
                         persistent_workers=_use_workers)
print(f'\n KD train: {len(kd_df)} | Test: {len(test_df)}')
|
|
|
|
| |
| |
| |
def kd_loss(student_logits, teacher_logits, true_labels, alpha=KD_ALPHA, T=KD_TEMP):
    """Blend hard-label cross-entropy with a temperature-softened KL term.

    loss = alpha * CE(student, true_labels)
           + (1 - alpha) * T^2 * KL(softmax(teacher / T) || softmax(student / T))

    Args:
        student_logits: student outputs; classes along dim 1.
        teacher_logits: teacher outputs for the same batch.
        true_labels: ground-truth class indices.
        alpha: weight of the hard-label term.
        T: temperature applied to both sets of logits.

    Returns:
        The combined loss tensor.
    """
    # Soft targets from the teacher and log-probabilities from the student,
    # both at temperature T.
    p_teacher = F.softmax(teacher_logits / T, dim=1)
    log_p_student = F.log_softmax(student_logits / T, dim=1)
    # The KL term is multiplied by T^2, as in the original block.
    soft_term = F.kl_div(log_p_student, p_teacher, reduction='batchmean') * (T ** 2)

    hard_term = F.cross_entropy(student_logits, true_labels)
    return alpha * hard_term + (1 - alpha) * soft_term
|
|
# AdamW over the student's parameters only; the teacher stays frozen.
optimizer = torch.optim.AdamW(
    student.parameters(),
    lr=BASE_LR,
    weight_decay=1e-4,
)
# One-cycle schedule stepped once per batch: peak at BASE_LR, cosine
# annealing, first 10% of the steps spent ramping up.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=BASE_LR,
    epochs=N_EPOCHS,
    steps_per_epoch=len(train_loader),
    pct_start=0.1,
    anneal_strategy='cos',
    div_factor=10.0,
    final_div_factor=100.0,
)
scaler = GradScaler()

# Early-stopping / model-selection state.
best_f1 = 0.0
best_state = None
patience_ctr = 0
print(f'\n Training student for {N_EPOCHS} epochs...\n')
|
|
# Distillation loop: one training pass, one evaluation pass, then early-stop
# bookkeeping per epoch.
# NOTE(review): the best checkpoint and the early-stop decision are both
# driven by macro-F1 on test_loader, so the final test metrics are
# optimistically biased; a held-out validation split would avoid that.
for epoch in range(N_EPOCHS):
    student.train()
    run_loss = 0.0

    # `batch_labels` avoids clobbering the module-level `labels` array.
    for imgs, batch_labels in train_loader:
        imgs = imgs.to(DEVICE, non_blocking=True)
        batch_labels = batch_labels.to(DEVICE, non_blocking=True)

        # Teacher targets come from the frozen network; no graph is needed.
        with torch.no_grad(), autocast('cuda'):
            t_logits, _ = teacher(imgs)

        with autocast('cuda'):
            s_logits = student(imgs)
            loss = kd_loss(s_logits, t_logits, batch_labels)

        # Gradients are unscaled before clipping so the threshold of 1.0
        # applies to their true norm.
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(student.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()  # OneCycleLR is stepped per batch
        optimizer.zero_grad(set_to_none=True)

        run_loss += loss.item()

    # Per-epoch evaluation: only hard predictions are needed for macro-F1
    # and accuracy.
    student.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for imgs, lbl in test_loader:
            imgs = imgs.to(DEVICE)
            with autocast('cuda'):
                logits = student(imgs)
            all_preds.extend(logits.argmax(1).cpu().numpy())
            all_labels.extend(lbl.numpy())

    preds = np.array(all_preds)
    targets = np.array(all_labels)
    mf1 = f1_score(targets, preds, average='macro')
    acc = 100 * (preds == targets).mean()

    # An epoch counts as an improvement only if it beats the best macro-F1
    # by more than 0.001; the winning weights are snapshotted on the CPU.
    tag = ''
    if mf1 > best_f1 + 0.001:
        best_f1 = mf1
        patience_ctr = 0
        best_state = {k: v.cpu().clone() for k, v in student.state_dict().items()}
        tag = ' * BEST'
    else:
        patience_ctr += 1

    # Log every fifth epoch and every improvement.
    if (epoch + 1) % 5 == 0 or tag:
        print(f' E{epoch+1:02d} | Loss {run_loss/len(train_loader):.3f} | '
              f'Acc {acc:.1f}% | mF1 {mf1:.4f}{tag}')

    if patience_ctr >= PATIENCE:
        print(f' Early stop at epoch {epoch+1}')
        break
|
|
| |
# Persist the best student weights together with their score.
student_path = os.path.join(OUTPUT_DIR, 'student_vit_tiny.pth')
if best_state is None:
    # No epoch ever beat the improvement threshold (or no epoch ran), so no
    # snapshot exists. Fall back to the current weights rather than saving
    # an empty checkpoint that the later load_state_dict() cannot use.
    best_state = {k: v.cpu().clone() for k, v in student.state_dict().items()}
torch.save({'model_state_dict': best_state, 'macro_f1': best_f1,
            'architecture': 'vit_tiny_patch16_224'}, student_path)
print(f'\n Student saved -> {student_path} ({os.path.getsize(student_path)/1e6:.1f} MB)')
|
|
|
|
| |
| |
| |
print('\n Comparing Teacher vs Student on test set...')

# Evaluate the best snapshot kept during training, not the weights from the
# last epoch that ran.
student.load_state_dict(best_state)
student.eval()
|
|
def get_preds(mdl, loader, is_teacher=False):
    """Run ``mdl`` over ``loader`` and collect predictions.

    Args:
        mdl: model to evaluate.
        loader: iterable yielding ``(images, labels)`` batches.
        is_teacher: when true the model output is a tuple and its first
            element is taken as the disease logits.

    Returns:
        ``(preds, labels, probs)`` as numpy arrays. ``preds`` is the argmax
        of the raw logits; ``probs`` is the softmax of the logits divided
        by ``T_OPT``.
    """
    pred_rows, label_rows, prob_rows = [], [], []
    with torch.no_grad():
        for batch, batch_lbl in loader:
            batch = batch.to(DEVICE)
            with autocast('cuda'):
                outputs = mdl(batch)
            if is_teacher:
                logits = outputs[0]
            else:
                logits = outputs
            # Cast to float32 before the temperature-scaled softmax.
            scaled = logits.float() / T_OPT
            probs = torch.softmax(scaled, dim=1)
            pred_rows.extend(logits.argmax(1).cpu().numpy())
            label_rows.extend(batch_lbl.numpy())
            prob_rows.extend(probs.cpu().numpy())
    return np.array(pred_rows), np.array(label_rows), np.array(prob_rows)
|
|
# Predictions of both models on the same test loader; the teacher call takes
# the first element of the model's tuple output.
t_preds, t_labels, t_probs = get_preds(teacher, test_loader, is_teacher=True)
s_preds, s_labels, s_probs = get_preds(student, test_loader, is_teacher=False)
|
|
def metrics(preds, labels, probs, name):
    """Print and return accuracy, macro-F1 and macro one-vs-rest AUC.

    Args:
        preds: predicted class indices (numpy array).
        labels: ground-truth class indices (numpy array).
        probs: per-class probabilities, one row per sample.
        name: heading printed above the report.

    Returns:
        Dict with ``accuracy`` (percent), ``macro_f1`` and ``macro_auc`` as
        plain floats. ``macro_auc`` is ``0.0`` when it cannot be computed.
    """
    class_ids = list(range(N_CLASSES))
    acc = 100 * (preds == labels).mean()
    mf1 = f1_score(labels, preds, average='macro')

    auc_error = None
    try:
        mauc = roc_auc_score(label_binarize(labels, classes=class_ids),
                             probs, multi_class='ovr', average='macro')
    except ValueError as err:
        # roc_auc_score rejects inputs it cannot score (e.g. a class with no
        # positive samples). Keep the 0.0 placeholder but say why.
        mauc = 0.0
        auc_error = err

    print(f'\n [{name}]')
    print(f' Accuracy: {acc:.2f}% | Macro F1: {mf1:.4f} | AUC: {mauc:.4f}')
    if auc_error is not None:
        print(f' AUC could not be computed: {auc_error}')
    # Passing labels= keeps the report aligned with CLASS_NAMES even when a
    # class is absent from both labels and predictions; without it
    # classification_report raises on the length mismatch.
    print(classification_report(labels, preds, labels=class_ids,
                                target_names=CLASS_NAMES, digits=3,
                                zero_division=0))
    return {'accuracy': float(acc), 'macro_f1': float(mf1), 'macro_auc': float(mauc)}
|
|
# Score both models, then express the student's macro-F1 as a percentage of
# the teacher's.
m_teacher = metrics(t_preds, t_labels, t_probs, 'Teacher (ViT-Base)')
m_student = metrics(s_preds, s_labels, s_probs, 'Student (ViT-Tiny)')
# NOTE(review): a teacher macro-F1 of exactly 0 would make this division fail
# or yield inf/nan — confirm that cannot happen.
retention = m_student['macro_f1'] / m_teacher['macro_f1'] * 100
print(f'\n Performance retention: {retention:.1f}%')
# Checkpoint sizes on disk, in MB (bytes / 1e6).
print(f' Size reduction: {os.path.getsize("./outputs_v3/best_model.pth")/1e6:.0f}MB → {os.path.getsize(student_path)/1e6:.0f}MB')
|
|
|
|
| |
| |
| |
print('\n Exporting student to ONNX...')
student.eval()
# Example input handed to the exporter: one random 3-channel image.
dummy = torch.randn(1, 3, IMG_SIZE, IMG_SIZE).to(DEVICE)

onnx_path = os.path.join(OUTPUT_DIR, 'retinasense_student.onnx')
torch.onnx.export(
    student, dummy, onnx_path,
    export_params=True,
    opset_version=17,
    do_constant_folding=True,
    # Tensor names used later to feed the ONNX Runtime session.
    input_names=['fundus_image'],
    output_names=['disease_logits'],
    # Dimension 0 is declared dynamic on both tensors so the exported graph
    # is not tied to the batch size of the dummy input.
    dynamic_axes={
        'fundus_image': {0: 'batch_size'},
        'disease_logits': {0: 'batch_size'},
    }
)
print(f' ONNX saved -> {onnx_path} ({os.path.getsize(onnx_path)/1e6:.1f} MB)')
|
|
| |
# Optional check of the exported file. Any failure, including a missing
# `onnx` package, is only reported and the script continues.
try:
    import onnx
    model_onnx = onnx.load(onnx_path)
    onnx.checker.check_model(model_onnx)
    print(' ONNX verification: PASSED')
except Exception as e:
    print(f' ONNX verification warning: {e}')
|
|
| |
# Optional dynamic quantization of the exported model with INT8 weights,
# written next to the float model. Any failure, including a missing
# onnxruntime install, is reported and skipped.
try:
    from onnxruntime.quantization import quantize_dynamic, QuantType
    int8_path = os.path.join(OUTPUT_DIR, 'retinasense_student_int8.onnx')
    quantize_dynamic(onnx_path, int8_path, weight_type=QuantType.QInt8)
    print(f' INT8 model -> {int8_path} ({os.path.getsize(int8_path)/1e6:.1f} MB)')
except Exception as e:
    print(f' INT8 quantization error: {e}')
|
|
| |
# Optional CPU latency benchmark of the float ONNX model. Any failure,
# including a missing onnxruntime install, is reported and skipped.
try:
    import onnxruntime as ort
    sess = ort.InferenceSession(onnx_path, providers=['CPUExecutionProvider'])
    dummy_np = np.random.randn(1, 3, IMG_SIZE, IMG_SIZE).astype(np.float32)
    feed = {'fundus_image': dummy_np}

    # Warm-up runs, excluded from the timing.
    for _ in range(5):
        sess.run(None, feed)

    # Timed runs. perf_counter() is the monotonic, high-resolution clock
    # intended for measuring intervals; time.time() can jump with wall-clock
    # adjustments.
    n = 50
    t0 = time.perf_counter()
    for _ in range(n):
        sess.run(None, feed)
    cpu_ms = (time.perf_counter() - t0) / n * 1000
    print(f'\n ONNX CPU inference: {cpu_ms:.1f} ms/image')
except Exception as e:
    print(f' ONNX benchmark error: {e}')
|
|
| |
# File sizes are read once and reused for the JSON and the summary banner.
teacher_bytes = os.path.getsize('./outputs_v3/best_model.pth')
student_bytes = os.path.getsize(student_path)
onnx_bytes = os.path.getsize(onnx_path)

# Per-model entries: static description first, then the measured metrics.
teacher_entry = {
    'architecture': 'vit_base_patch16_224',
    'params': teacher_params,
    'size_mb': teacher_bytes / 1e6,
}
teacher_entry.update(m_teacher)

student_entry = {
    'architecture': 'vit_tiny_patch16_224',
    'params': student_params,
    'size_mb': student_bytes / 1e6,
}
student_entry.update(m_student)

results = {
    'teacher': teacher_entry,
    'student': student_entry,
    'onnx_size_mb': onnx_bytes / 1e6,
    'performance_retention_pct': retention,
    'size_reduction_factor': teacher_bytes / student_bytes,
}
results_path = os.path.join(OUTPUT_DIR, 'compression_results.json')
with open(results_path, 'w') as f:
    json.dump(results, f, indent=2)

banner = '=' * 65
print('\n' + banner)
print(' DISTILLATION + COMPRESSION COMPLETE')
print(f' Teacher F1 : {m_teacher["macro_f1"]:.4f}')
print(f' Student F1 : {m_student["macro_f1"]:.4f} ({retention:.1f}% retention)')
print(f' Size: {results["teacher"]["size_mb"]:.0f}MB → {results["student"]["size_mb"]:.0f}MB → {results["onnx_size_mb"]:.0f}MB (ONNX)')
print(banner)
|
|