Add retinasense_vit.py
Browse files- retinasense_vit.py +581 -0
retinasense_vit.py
ADDED
|
@@ -0,0 +1,581 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
RetinaSense ViT — Vision Transformer Variant
|
| 4 |
+
=============================================
|
| 5 |
+
Based on RetinaSense v2 pipeline, replacing EfficientNet-B3 with
|
| 6 |
+
ViT-Base-Patch16-224 from the timm library.
|
| 7 |
+
|
| 8 |
+
Key changes from v2:
|
| 9 |
+
- Backbone: ViT-Base-Patch16-224 (timm) instead of EfficientNet-B3
|
| 10 |
+
- Feature dimension: 768 (ViT) instead of 1536 (EfficientNet-B3)
|
| 11 |
+
- Image size: 224x224 (ViT native) instead of 300x300
|
| 12 |
+
- EPOCHS: 30, PATIENCE: 10
|
| 13 |
+
- Output directory: ./outputs_vit
|
| 14 |
+
|
| 15 |
+
Everything else preserved from v2:
|
| 16 |
+
- Focal Loss with class weights
|
| 17 |
+
- Multi-task architecture (disease + severity heads)
|
| 18 |
+
- LR warmup + cosine decay
|
| 19 |
+
- Gradient accumulation
|
| 20 |
+
- Early stopping on Macro F1
|
| 21 |
+
- Pre-cached preprocessing
|
| 22 |
+
- Comprehensive plots
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
import os, sys, time, warnings, json
|
| 26 |
+
import numpy as np
|
| 27 |
+
import pandas as pd
|
| 28 |
+
import cv2
|
| 29 |
+
import matplotlib
|
| 30 |
+
matplotlib.use('Agg')
|
| 31 |
+
import matplotlib.pyplot as plt
|
| 32 |
+
import seaborn as sns
|
| 33 |
+
from PIL import Image
|
| 34 |
+
from tqdm import tqdm
|
| 35 |
+
from collections import Counter
|
| 36 |
+
warnings.filterwarnings('ignore')
|
| 37 |
+
|
| 38 |
+
import torch
|
| 39 |
+
import torch.nn as nn
|
| 40 |
+
import torch.nn.functional as F
|
| 41 |
+
from torch.amp import autocast, GradScaler
|
| 42 |
+
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
|
| 43 |
+
from torchvision import transforms
|
| 44 |
+
|
| 45 |
+
import timm
|
| 46 |
+
|
| 47 |
+
from sklearn.model_selection import train_test_split
|
| 48 |
+
from sklearn.utils.class_weight import compute_class_weight
|
| 49 |
+
from sklearn.metrics import (
|
| 50 |
+
classification_report, confusion_matrix,
|
| 51 |
+
roc_auc_score, f1_score, roc_curve, auc
|
| 52 |
+
)
|
| 53 |
+
from sklearn.preprocessing import label_binarize
|
| 54 |
+
|
| 55 |
+
# ===============================================================
|
| 56 |
+
# CONFIG
|
| 57 |
+
# ===============================================================
|
| 58 |
+
SAVE_DIR = './outputs_vit'
|
| 59 |
+
CACHE_DIR = './preprocessed_cache_vit'
|
| 60 |
+
os.makedirs(SAVE_DIR, exist_ok=True)
|
| 61 |
+
os.makedirs(CACHE_DIR, exist_ok=True)
|
| 62 |
+
|
| 63 |
+
EPOCHS = 30
|
| 64 |
+
WARMUP_EPOCHS = 3 # heads-only warmup
|
| 65 |
+
LR_WARMUP_STEPS = 3 # linear warmup epochs after unfreeze
|
| 66 |
+
BATCH_SIZE = 32 # actual batch size
|
| 67 |
+
ACCUM_STEPS = 2 # gradient accumulation -> effective batch 64
|
| 68 |
+
NUM_WORKERS = 8
|
| 69 |
+
PATIENCE = 10 # early stopping on macro-F1
|
| 70 |
+
FOCAL_GAMMA = 1.0 # reduced from 2.0 -- less aggressive
|
| 71 |
+
IMG_SIZE = 224 # ViT native resolution
|
| 72 |
+
|
| 73 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 74 |
+
CLASS_NAMES = ['Normal', 'Diabetes/DR', 'Glaucoma', 'Cataract', 'AMD']
|
| 75 |
+
NUM_CLASSES = len(CLASS_NAMES)
|
| 76 |
+
|
| 77 |
+
print('='*65)
|
| 78 |
+
print(' RetinaSense ViT -- Vision Transformer Pipeline')
|
| 79 |
+
print('='*65)
|
| 80 |
+
if torch.cuda.is_available():
|
| 81 |
+
print(f' GPU : {torch.cuda.get_device_name(0)}')
|
| 82 |
+
print(f' VRAM : {round(torch.cuda.get_device_properties(0).total_memory/1e9,1)} GB')
|
| 83 |
+
print(f' Backbone : ViT-Base-Patch16-224 (timm)')
|
| 84 |
+
print(f' Epochs : {EPOCHS}')
|
| 85 |
+
print(f' Batch : {BATCH_SIZE} (effective {BATCH_SIZE*ACCUM_STEPS} via grad accum)')
|
| 86 |
+
print(f' Image Size : {IMG_SIZE}')
|
| 87 |
+
print(f' Focal Loss g: {FOCAL_GAMMA} (mild -- avoids over-correction)')
|
| 88 |
+
print(f' Early Stop : patience={PATIENCE} on macro-F1')
|
| 89 |
+
print('='*65)
|
| 90 |
+
|
| 91 |
+
# ===============================================================
|
| 92 |
+
# 1 METADATA
|
| 93 |
+
# ===============================================================
|
| 94 |
+
print('\n[1/7] Building metadata...')
|
| 95 |
+
BASE = './'
|
| 96 |
+
disease_cols = ['N','D','G','C','A']
|
| 97 |
+
label_map = {'N':0,'D':1,'G':2,'C':3,'A':4}
|
| 98 |
+
|
| 99 |
+
df_odir = pd.read_csv(f'{BASE}/odir/full_df.csv')
|
| 100 |
+
df_odir['disease_count'] = df_odir[disease_cols].sum(axis=1)
|
| 101 |
+
df_odir = df_odir[df_odir['disease_count']==1].copy()
|
| 102 |
+
def get_label(row):
|
| 103 |
+
for d in disease_cols:
|
| 104 |
+
if row[d]==1: return label_map[d]
|
| 105 |
+
df_odir['disease_label'] = df_odir.apply(get_label, axis=1)
|
| 106 |
+
|
| 107 |
+
img_col = next(c for c in df_odir.columns
|
| 108 |
+
if any(k in c.lower() for k in ['filename','fundus','image']))
|
| 109 |
+
|
| 110 |
+
odir_meta = pd.DataFrame({
|
| 111 |
+
'image_path': f'{BASE}/odir/preprocessed_images/'+df_odir[img_col].astype(str),
|
| 112 |
+
'dataset': 'ODIR',
|
| 113 |
+
'disease_label': df_odir['disease_label'],
|
| 114 |
+
'severity_label':-1
|
| 115 |
+
})
|
| 116 |
+
|
| 117 |
+
df_aptos = pd.read_csv(f'{BASE}/aptos/train.csv')
|
| 118 |
+
aptos_meta = pd.DataFrame({
|
| 119 |
+
'image_path': f'{BASE}/aptos/train_images/'+df_aptos['id_code']+'.png',
|
| 120 |
+
'dataset': 'APTOS',
|
| 121 |
+
'disease_label': 1,
|
| 122 |
+
'severity_label':df_aptos['diagnosis']
|
| 123 |
+
})
|
| 124 |
+
|
| 125 |
+
meta = pd.concat([odir_meta, aptos_meta], ignore_index=True)
|
| 126 |
+
meta = meta[meta['image_path'].apply(os.path.exists)].reset_index(drop=True)
|
| 127 |
+
print(f' Total samples: {len(meta)}')
|
| 128 |
+
dist = meta['disease_label'].value_counts().sort_index()
|
| 129 |
+
for i,cnt in dist.items():
|
| 130 |
+
print(f' {CLASS_NAMES[i]:15s}: {cnt:4d} ({100*cnt/len(meta):.1f}%)')
|
| 131 |
+
|
| 132 |
+
# ===============================================================
|
| 133 |
+
# 2 PRE-CACHE
|
| 134 |
+
# ===============================================================
|
| 135 |
+
print(f'\n[2/7] Pre-caching @ {IMG_SIZE}x{IMG_SIZE}...')
|
| 136 |
+
|
| 137 |
+
def ben_graham(path, sz=IMG_SIZE, sigma=10):
|
| 138 |
+
img = cv2.imread(path)
|
| 139 |
+
if img is None:
|
| 140 |
+
img = np.array(Image.open(path).convert('RGB'))
|
| 141 |
+
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
|
| 142 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
| 143 |
+
img = cv2.resize(img, (sz, sz))
|
| 144 |
+
img = cv2.addWeighted(img, 4, cv2.GaussianBlur(img,(0,0),sigma), -4, 128)
|
| 145 |
+
mask = np.zeros(img.shape[:2], dtype=np.uint8)
|
| 146 |
+
cv2.circle(mask, (sz//2, sz//2), int(sz*0.48), 255, -1)
|
| 147 |
+
return cv2.bitwise_and(img, img, mask=mask)
|
| 148 |
+
|
| 149 |
+
cache_paths = []
|
| 150 |
+
cached = 0
|
| 151 |
+
for _, row in tqdm(meta.iterrows(), total=len(meta), desc='Caching'):
|
| 152 |
+
stem = os.path.splitext(os.path.basename(row['image_path']))[0]
|
| 153 |
+
fp = f'{CACHE_DIR}/{stem}_{IMG_SIZE}.npy'
|
| 154 |
+
if not os.path.exists(fp):
|
| 155 |
+
try:
|
| 156 |
+
np.save(fp, ben_graham(row['image_path']))
|
| 157 |
+
except Exception:
|
| 158 |
+
np.save(fp, np.zeros((IMG_SIZE,IMG_SIZE,3), dtype=np.uint8))
|
| 159 |
+
cached += 1
|
| 160 |
+
cache_paths.append(fp)
|
| 161 |
+
meta['cache_path'] = cache_paths
|
| 162 |
+
print(f' Newly cached: {cached} | Already cached: {len(meta)-cached}')
|
| 163 |
+
|
| 164 |
+
# ===============================================================
|
| 165 |
+
# 3 DATASET + LOADERS
|
| 166 |
+
# ===============================================================
|
| 167 |
+
print('\n[3/7] Creating data loaders...')
|
| 168 |
+
|
| 169 |
+
train_df, val_df = train_test_split(
|
| 170 |
+
meta, test_size=0.2, stratify=meta['disease_label'], random_state=42)
|
| 171 |
+
|
| 172 |
+
def make_transforms(phase):
|
| 173 |
+
if phase == 'train':
|
| 174 |
+
return transforms.Compose([
|
| 175 |
+
transforms.ToPILImage(),
|
| 176 |
+
transforms.RandomHorizontalFlip(),
|
| 177 |
+
transforms.RandomVerticalFlip(p=0.3),
|
| 178 |
+
transforms.RandomRotation(20),
|
| 179 |
+
transforms.RandomAffine(degrees=0, translate=(0.05,0.05), scale=(0.95,1.05)),
|
| 180 |
+
transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2, hue=0.02),
|
| 181 |
+
transforms.ToTensor(),
|
| 182 |
+
transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
|
| 183 |
+
transforms.RandomErasing(p=0.2),
|
| 184 |
+
])
|
| 185 |
+
return transforms.Compose([
|
| 186 |
+
transforms.ToPILImage(),
|
| 187 |
+
transforms.ToTensor(),
|
| 188 |
+
transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
|
| 189 |
+
])
|
| 190 |
+
|
| 191 |
+
class RetDS(Dataset):
|
| 192 |
+
def __init__(self, df, tfm):
|
| 193 |
+
self.df = df.reset_index(drop=True)
|
| 194 |
+
self.tfm = tfm
|
| 195 |
+
def __len__(self): return len(self.df)
|
| 196 |
+
def __getitem__(self, i):
|
| 197 |
+
r = self.df.iloc[i]
|
| 198 |
+
try: img = np.load(r['cache_path'])
|
| 199 |
+
except: img = np.zeros((IMG_SIZE,IMG_SIZE,3), dtype=np.uint8)
|
| 200 |
+
return (self.tfm(img),
|
| 201 |
+
torch.tensor(int(r['disease_label']), dtype=torch.long),
|
| 202 |
+
torch.tensor(int(r['severity_label']), dtype=torch.long))
|
| 203 |
+
|
| 204 |
+
train_ds = RetDS(train_df, make_transforms('train'))
|
| 205 |
+
val_ds = RetDS(val_df, make_transforms('val'))
|
| 206 |
+
|
| 207 |
+
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
|
| 208 |
+
num_workers=NUM_WORKERS, pin_memory=True,
|
| 209 |
+
persistent_workers=True, prefetch_factor=2)
|
| 210 |
+
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False,
|
| 211 |
+
num_workers=NUM_WORKERS, pin_memory=True,
|
| 212 |
+
persistent_workers=True)
|
| 213 |
+
|
| 214 |
+
print(f' Train : {len(train_ds):5d} ({len(train_loader):3d} batches)')
|
| 215 |
+
print(f' Val : {len(val_ds):5d} ({len(val_loader):3d} batches)')
|
| 216 |
+
print(f' Focal Loss + class weights handle imbalance (no oversampling)')
|
| 217 |
+
|
| 218 |
+
# ===============================================================
|
| 219 |
+
# 4 MODEL + FOCAL LOSS
|
| 220 |
+
# ===============================================================
|
| 221 |
+
print('\n[4/7] Building ViT model...')
|
| 222 |
+
|
| 223 |
+
class FocalLoss(nn.Module):
|
| 224 |
+
"""Focal Loss -- down-weights easy examples, focuses on hard ones."""
|
| 225 |
+
def __init__(self, alpha=None, gamma=2.0):
|
| 226 |
+
super().__init__()
|
| 227 |
+
self.gamma = gamma
|
| 228 |
+
if alpha is not None:
|
| 229 |
+
self.register_buffer('alpha', alpha)
|
| 230 |
+
else:
|
| 231 |
+
self.alpha = None
|
| 232 |
+
|
| 233 |
+
def forward(self, logits, targets):
|
| 234 |
+
ce = F.cross_entropy(logits, targets, reduction='none')
|
| 235 |
+
pt = torch.exp(-ce)
|
| 236 |
+
focal = ((1 - pt) ** self.gamma) * ce
|
| 237 |
+
if self.alpha is not None:
|
| 238 |
+
at = self.alpha.gather(0, targets)
|
| 239 |
+
focal = at * focal
|
| 240 |
+
return focal.mean()
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
class MultiTaskViT(nn.Module):
|
| 244 |
+
def __init__(self, n_disease=5, n_severity=5, drop=0.4):
|
| 245 |
+
super().__init__()
|
| 246 |
+
# ViT-Base-Patch16-224: outputs 768-dim features
|
| 247 |
+
self.backbone = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=0)
|
| 248 |
+
feat = 768
|
| 249 |
+
self.drop = nn.Dropout(drop)
|
| 250 |
+
self.disease_head = nn.Sequential(
|
| 251 |
+
nn.Linear(feat, 512), nn.BatchNorm1d(512), nn.ReLU(), nn.Dropout(0.3),
|
| 252 |
+
nn.Linear(512, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(0.2),
|
| 253 |
+
nn.Linear(256, n_disease))
|
| 254 |
+
self.severity_head = nn.Sequential(
|
| 255 |
+
nn.Linear(feat, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(0.3),
|
| 256 |
+
nn.Linear(256, n_severity))
|
| 257 |
+
|
| 258 |
+
def forward(self, x):
|
| 259 |
+
f = self.backbone(x) # timm ViT with num_classes=0 returns pooled features
|
| 260 |
+
f = self.drop(f)
|
| 261 |
+
return self.disease_head(f), self.severity_head(f)
|
| 262 |
+
|
| 263 |
+
model = MultiTaskViT().to(device)
|
| 264 |
+
|
| 265 |
+
# class-weight alpha for focal loss
|
| 266 |
+
cw = compute_class_weight('balanced', classes=np.arange(5), y=train_df['disease_label'].values)
|
| 267 |
+
alpha = torch.tensor(cw, dtype=torch.float32).to(device)
|
| 268 |
+
alpha = alpha / alpha.sum() * NUM_CLASSES # normalize
|
| 269 |
+
print(f' Focal a: {[f"{a:.2f}" for a in alpha.tolist()]}')
|
| 270 |
+
|
| 271 |
+
criterion_d = FocalLoss(alpha=alpha, gamma=FOCAL_GAMMA)
|
| 272 |
+
criterion_s = nn.CrossEntropyLoss(ignore_index=-1)
|
| 273 |
+
|
| 274 |
+
total_p = sum(p.numel() for p in model.parameters())
|
| 275 |
+
print(f' Params: {total_p:,}')
|
| 276 |
+
|
| 277 |
+
# ===============================================================
|
| 278 |
+
# 5 TRAINING LOOP
|
| 279 |
+
# ===============================================================
|
| 280 |
+
print('\n[5/7] Training...')
|
| 281 |
+
|
| 282 |
+
# freeze backbone first
|
| 283 |
+
for p in model.backbone.parameters():
|
| 284 |
+
p.requires_grad = False
|
| 285 |
+
|
| 286 |
+
optimizer = torch.optim.AdamW(
|
| 287 |
+
filter(lambda p: p.requires_grad, model.parameters()),
|
| 288 |
+
lr=3e-4, weight_decay=1e-3)
|
| 289 |
+
scaler = GradScaler()
|
| 290 |
+
|
| 291 |
+
def get_scheduler(opt, warmup_steps, total_steps):
|
| 292 |
+
"""Linear warmup then cosine decay."""
|
| 293 |
+
def lr_lambda(step):
|
| 294 |
+
if step < warmup_steps:
|
| 295 |
+
return float(step) / max(1, warmup_steps)
|
| 296 |
+
progress = float(step - warmup_steps) / max(1, total_steps - warmup_steps)
|
| 297 |
+
return max(0.05, 0.5 * (1.0 + np.cos(np.pi * progress)))
|
| 298 |
+
return torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda)
|
| 299 |
+
|
| 300 |
+
CHECKPOINT = f'{SAVE_DIR}/best_model.pth'
|
| 301 |
+
|
| 302 |
+
history = {k:[] for k in [
|
| 303 |
+
'train_loss','val_loss','train_acc','val_acc',
|
| 304 |
+
'macro_f1','weighted_f1','lr',
|
| 305 |
+
*(f'f1_{c}' for c in CLASS_NAMES)
|
| 306 |
+
]}
|
| 307 |
+
|
| 308 |
+
best_f1 = 0.0
|
| 309 |
+
patience_ctr = 0
|
| 310 |
+
total_steps = EPOCHS * len(train_loader) // ACCUM_STEPS
|
| 311 |
+
sched = get_scheduler(optimizer, warmup_steps=len(train_loader)//ACCUM_STEPS, total_steps=total_steps)
|
| 312 |
+
|
| 313 |
+
t_start = time.time()
|
| 314 |
+
print('='*65)
|
| 315 |
+
|
| 316 |
+
for epoch in range(EPOCHS):
|
| 317 |
+
t0 = time.time()
|
| 318 |
+
|
| 319 |
+
# -- Unfreeze backbone after warmup --
|
| 320 |
+
if epoch == WARMUP_EPOCHS:
|
| 321 |
+
print('\n Unfreezing ViT backbone with LR warmup')
|
| 322 |
+
for p in model.backbone.parameters():
|
| 323 |
+
p.requires_grad = True
|
| 324 |
+
# new optimizer for full model with lower LR for backbone
|
| 325 |
+
optimizer = torch.optim.AdamW([
|
| 326 |
+
{'params': model.backbone.parameters(), 'lr': 1e-5},
|
| 327 |
+
{'params': model.disease_head.parameters(), 'lr': 1e-4},
|
| 328 |
+
{'params': model.severity_head.parameters(), 'lr': 1e-4},
|
| 329 |
+
], weight_decay=1e-3)
|
| 330 |
+
remaining = (EPOCHS - WARMUP_EPOCHS) * len(train_loader) // ACCUM_STEPS
|
| 331 |
+
sched = get_scheduler(optimizer,
|
| 332 |
+
warmup_steps=LR_WARMUP_STEPS * len(train_loader) // ACCUM_STEPS,
|
| 333 |
+
total_steps=remaining)
|
| 334 |
+
scaler = GradScaler()
|
| 335 |
+
|
| 336 |
+
# -- TRAIN --
|
| 337 |
+
model.train()
|
| 338 |
+
run_loss = 0.0
|
| 339 |
+
correct = 0
|
| 340 |
+
total = 0
|
| 341 |
+
optimizer.zero_grad(set_to_none=True)
|
| 342 |
+
|
| 343 |
+
pbar = tqdm(train_loader, desc=f'E{epoch+1:02d}/{EPOCHS} train', leave=False)
|
| 344 |
+
for step, (imgs, d_lbl, s_lbl) in enumerate(pbar):
|
| 345 |
+
imgs = imgs.to(device, non_blocking=True)
|
| 346 |
+
d_lbl = d_lbl.to(device, non_blocking=True)
|
| 347 |
+
s_lbl = s_lbl.to(device, non_blocking=True)
|
| 348 |
+
|
| 349 |
+
with autocast('cuda'):
|
| 350 |
+
d_out, s_out = model(imgs)
|
| 351 |
+
loss_d = criterion_d(d_out, d_lbl)
|
| 352 |
+
loss_s = criterion_s(s_out, s_lbl)
|
| 353 |
+
loss = (loss_d + 0.2 * loss_s) / ACCUM_STEPS
|
| 354 |
+
|
| 355 |
+
# check for NaN
|
| 356 |
+
if torch.isnan(loss) or torch.isinf(loss):
|
| 357 |
+
optimizer.zero_grad(set_to_none=True)
|
| 358 |
+
continue
|
| 359 |
+
|
| 360 |
+
scaler.scale(loss).backward()
|
| 361 |
+
|
| 362 |
+
if (step + 1) % ACCUM_STEPS == 0:
|
| 363 |
+
scaler.unscale_(optimizer)
|
| 364 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
|
| 365 |
+
scaler.step(optimizer)
|
| 366 |
+
scaler.update()
|
| 367 |
+
optimizer.zero_grad(set_to_none=True)
|
| 368 |
+
sched.step()
|
| 369 |
+
|
| 370 |
+
run_loss += loss.item() * ACCUM_STEPS
|
| 371 |
+
preds = d_out.argmax(1)
|
| 372 |
+
correct += (preds == d_lbl).sum().item()
|
| 373 |
+
total += d_lbl.size(0)
|
| 374 |
+
pbar.set_postfix(loss=f'{loss.item()*ACCUM_STEPS:.3f}',
|
| 375 |
+
acc=f'{100*correct/total:.1f}%')
|
| 376 |
+
|
| 377 |
+
train_loss = run_loss / len(train_loader)
|
| 378 |
+
train_acc = 100 * correct / total
|
| 379 |
+
|
| 380 |
+
# -- VALIDATE --
|
| 381 |
+
model.eval()
|
| 382 |
+
vl = 0.0
|
| 383 |
+
all_p, all_t, all_prob = [], [], []
|
| 384 |
+
with torch.no_grad():
|
| 385 |
+
for imgs, d_lbl, s_lbl in tqdm(val_loader, desc=f'E{epoch+1:02d}/{EPOCHS} val ', leave=False):
|
| 386 |
+
imgs = imgs.to(device, non_blocking=True)
|
| 387 |
+
d_lbl = d_lbl.to(device, non_blocking=True)
|
| 388 |
+
s_lbl = s_lbl.to(device, non_blocking=True)
|
| 389 |
+
with autocast('cuda'):
|
| 390 |
+
d_out, s_out = model(imgs)
|
| 391 |
+
ld = criterion_d(d_out, d_lbl)
|
| 392 |
+
ls = criterion_s(s_out, s_lbl)
|
| 393 |
+
loss = ld + 0.2 * ls
|
| 394 |
+
if not (torch.isnan(loss) or torch.isinf(loss)):
|
| 395 |
+
vl += loss.item()
|
| 396 |
+
probs = torch.softmax(d_out.float(), dim=1)
|
| 397 |
+
all_p.extend(d_out.argmax(1).cpu().numpy())
|
| 398 |
+
all_t.extend(d_lbl.cpu().numpy())
|
| 399 |
+
all_prob.extend(probs.cpu().numpy())
|
| 400 |
+
|
| 401 |
+
val_loss = vl / len(val_loader)
|
| 402 |
+
all_p, all_t, all_prob = np.array(all_p), np.array(all_t), np.array(all_prob)
|
| 403 |
+
val_acc = 100 * (all_p == all_t).mean()
|
| 404 |
+
|
| 405 |
+
mf1 = f1_score(all_t, all_p, average='macro')
|
| 406 |
+
wf1 = f1_score(all_t, all_p, average='weighted')
|
| 407 |
+
per_f1 = f1_score(all_t, all_p, average=None, labels=range(NUM_CLASSES), zero_division=0)
|
| 408 |
+
|
| 409 |
+
lr = optimizer.param_groups[0]['lr']
|
| 410 |
+
|
| 411 |
+
history['train_loss'].append(train_loss)
|
| 412 |
+
history['val_loss'].append(val_loss)
|
| 413 |
+
history['train_acc'].append(train_acc)
|
| 414 |
+
history['val_acc'].append(val_acc)
|
| 415 |
+
history['macro_f1'].append(mf1)
|
| 416 |
+
history['weighted_f1'].append(wf1)
|
| 417 |
+
history['lr'].append(lr)
|
| 418 |
+
for ci, cn in enumerate(CLASS_NAMES):
|
| 419 |
+
history[f'f1_{cn}'].append(per_f1[ci])
|
| 420 |
+
|
| 421 |
+
elapsed = time.time() - t0
|
| 422 |
+
|
| 423 |
+
tag = ''
|
| 424 |
+
if mf1 > best_f1:
|
| 425 |
+
best_f1 = mf1
|
| 426 |
+
patience_ctr = 0
|
| 427 |
+
torch.save({
|
| 428 |
+
'epoch': epoch, 'model_state_dict': model.state_dict(),
|
| 429 |
+
'val_acc': val_acc, 'macro_f1': mf1, 'history': history
|
| 430 |
+
}, CHECKPOINT)
|
| 431 |
+
tag = f' * NEW BEST (macro-F1={mf1:.4f})'
|
| 432 |
+
else:
|
| 433 |
+
patience_ctr += 1
|
| 434 |
+
|
| 435 |
+
cls_str = ' | '.join(f'{cn[:3]}:{per_f1[ci]:.2f}' for ci,cn in enumerate(CLASS_NAMES))
|
| 436 |
+
print(f'E{epoch+1:02d} | {elapsed:.0f}s | LR {lr:.1e} | '
|
| 437 |
+
f'TrL {train_loss:.3f} TrA {train_acc:.1f}% | '
|
| 438 |
+
f'VL {val_loss:.3f} VA {val_acc:.1f}% | '
|
| 439 |
+
f'mF1 {mf1:.3f} wF1 {wf1:.3f}{tag}')
|
| 440 |
+
print(f' {cls_str}')
|
| 441 |
+
|
| 442 |
+
if patience_ctr >= PATIENCE:
|
| 443 |
+
print(f'\n Early stopping -- no improvement for {PATIENCE} epochs')
|
| 444 |
+
break
|
| 445 |
+
|
| 446 |
+
total_train_time = time.time() - t_start
|
| 447 |
+
print(f'\nTraining done. Best macro-F1: {best_f1:.4f}')
|
| 448 |
+
print(f'Total training time: {total_train_time/60:.1f} minutes')
|
| 449 |
+
|
| 450 |
+
# ===============================================================
|
| 451 |
+
# 6 EVALUATION + PLOTS
|
| 452 |
+
# ===============================================================
|
| 453 |
+
print('\n[6/7] Full evaluation...')
|
| 454 |
+
|
| 455 |
+
ckpt = torch.load(CHECKPOINT, map_location=device, weights_only=False)
|
| 456 |
+
model.load_state_dict(ckpt['model_state_dict'])
|
| 457 |
+
model.eval()
|
| 458 |
+
history = ckpt['history']
|
| 459 |
+
|
| 460 |
+
all_p, all_t, all_prob = [], [], []
|
| 461 |
+
with torch.no_grad():
|
| 462 |
+
for imgs, d_lbl, _ in tqdm(val_loader, desc='Evaluating'):
|
| 463 |
+
imgs = imgs.to(device)
|
| 464 |
+
d_out, _ = model(imgs)
|
| 465 |
+
all_p.extend(d_out.argmax(1).cpu().numpy())
|
| 466 |
+
all_t.extend(d_lbl.numpy())
|
| 467 |
+
all_prob.extend(torch.softmax(d_out.float(), dim=1).cpu().numpy())
|
| 468 |
+
|
| 469 |
+
all_p = np.array(all_p)
|
| 470 |
+
all_t = np.array(all_t)
|
| 471 |
+
all_prob = np.array(all_prob)
|
| 472 |
+
|
| 473 |
+
print('\n' + '='*65)
|
| 474 |
+
print(' CLASSIFICATION REPORT')
|
| 475 |
+
print('='*65)
|
| 476 |
+
report = classification_report(all_t, all_p, target_names=CLASS_NAMES, digits=4)
|
| 477 |
+
print(report)
|
| 478 |
+
mf1 = f1_score(all_t, all_p, average='macro')
|
| 479 |
+
wf1 = f1_score(all_t, all_p, average='weighted')
|
| 480 |
+
try: mauc = roc_auc_score(all_t, all_prob, multi_class='ovr', average='macro')
|
| 481 |
+
except: mauc = 0.0
|
| 482 |
+
print(f'Macro F1 : {mf1:.4f}')
|
| 483 |
+
print(f'Weighted F1 : {wf1:.4f}')
|
| 484 |
+
print(f'Macro AUC : {mauc:.4f}')
|
| 485 |
+
|
| 486 |
+
# ===============================================================
|
| 487 |
+
# 7 COMPREHENSIVE PLOTS
|
| 488 |
+
# ===============================================================
|
| 489 |
+
print('\n[7/7] Generating plots...')
|
| 490 |
+
ep = range(1, len(history['train_loss'])+1)
|
| 491 |
+
colors = ['#2ecc71','#3498db','#e74c3c','#f39c12','#9b59b6']
|
| 492 |
+
|
| 493 |
+
fig, axes = plt.subplots(2, 3, figsize=(20, 12))
|
| 494 |
+
|
| 495 |
+
# -- 1. Loss --
|
| 496 |
+
axes[0,0].plot(ep, history['train_loss'], 'b-o', ms=4, label='Train')
|
| 497 |
+
axes[0,0].plot(ep, history['val_loss'], 'r-o', ms=4, label='Val')
|
| 498 |
+
axes[0,0].set_title('Loss', fontweight='bold')
|
| 499 |
+
axes[0,0].legend(); axes[0,0].grid(alpha=.3)
|
| 500 |
+
|
| 501 |
+
# -- 2. Accuracy --
|
| 502 |
+
axes[0,1].plot(ep, history['train_acc'], 'b-o', ms=4, label='Train')
|
| 503 |
+
axes[0,1].plot(ep, history['val_acc'], 'r-o', ms=4, label='Val')
|
| 504 |
+
axes[0,1].set_title('Accuracy (%)', fontweight='bold')
|
| 505 |
+
axes[0,1].legend(); axes[0,1].grid(alpha=.3)
|
| 506 |
+
|
| 507 |
+
# -- 3. Macro / Weighted F1 --
|
| 508 |
+
axes[0,2].plot(ep, history['macro_f1'], 'g-o', ms=4, label='Macro F1')
|
| 509 |
+
axes[0,2].plot(ep, history['weighted_f1'], 'm-o', ms=4, label='Weighted F1')
|
| 510 |
+
axes[0,2].set_title('F1 Scores', fontweight='bold')
|
| 511 |
+
axes[0,2].legend(); axes[0,2].grid(alpha=.3)
|
| 512 |
+
|
| 513 |
+
# -- 4. Per-class F1 --
|
| 514 |
+
for ci, cn in enumerate(CLASS_NAMES):
|
| 515 |
+
axes[1,0].plot(ep, history[f'f1_{cn}'], '-o', ms=3, color=colors[ci], label=cn)
|
| 516 |
+
axes[1,0].set_title('Per-Class F1', fontweight='bold')
|
| 517 |
+
axes[1,0].legend(); axes[1,0].grid(alpha=.3)
|
| 518 |
+
|
| 519 |
+
# -- 5. Confusion Matrix --
|
| 520 |
+
cm = confusion_matrix(all_t, all_p)
|
| 521 |
+
cm_n = cm.astype(float) / cm.sum(axis=1, keepdims=True)
|
| 522 |
+
sns.heatmap(cm_n, annot=True, fmt='.2f', cmap='Blues', ax=axes[1,1],
|
| 523 |
+
xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES)
|
| 524 |
+
axes[1,1].set_title('Confusion Matrix (norm)', fontweight='bold')
|
| 525 |
+
axes[1,1].set_ylabel('True'); axes[1,1].set_xlabel('Pred')
|
| 526 |
+
|
| 527 |
+
# -- 6. ROC --
|
| 528 |
+
y_bin = label_binarize(all_t, classes=list(range(NUM_CLASSES)))
|
| 529 |
+
for ci, (cn, col) in enumerate(zip(CLASS_NAMES, colors)):
|
| 530 |
+
fpr, tpr, _ = roc_curve(y_bin[:,ci], all_prob[:,ci])
|
| 531 |
+
axes[1,2].plot(fpr, tpr, color=col, lw=2, label=f'{cn} ({auc(fpr,tpr):.3f})')
|
| 532 |
+
axes[1,2].plot([0,1],[0,1],'k--',lw=1)
|
| 533 |
+
axes[1,2].set_title('ROC Curves', fontweight='bold')
|
| 534 |
+
axes[1,2].legend(loc='lower right', fontsize=8)
|
| 535 |
+
axes[1,2].grid(alpha=.3)
|
| 536 |
+
|
| 537 |
+
plt.suptitle(f'RetinaSense ViT -- Macro F1={mf1:.3f} | AUC={mauc:.3f} | Val Acc={100*(all_p==all_t).mean():.1f}%',
|
| 538 |
+
fontsize=15, fontweight='bold', y=1.01)
|
| 539 |
+
plt.tight_layout()
|
| 540 |
+
plt.savefig(f'{SAVE_DIR}/dashboard.png', dpi=150, bbox_inches='tight')
|
| 541 |
+
plt.close()
|
| 542 |
+
|
| 543 |
+
# LR schedule plot
|
| 544 |
+
fig, ax = plt.subplots(figsize=(8,3))
|
| 545 |
+
ax.plot(ep, history['lr'], 'b-o', ms=3)
|
| 546 |
+
ax.set_title('Learning Rate Schedule', fontweight='bold')
|
| 547 |
+
ax.set_xlabel('Epoch'); ax.set_ylabel('LR')
|
| 548 |
+
ax.grid(alpha=.3)
|
| 549 |
+
plt.tight_layout()
|
| 550 |
+
plt.savefig(f'{SAVE_DIR}/lr_schedule.png', dpi=150)
|
| 551 |
+
plt.close()
|
| 552 |
+
|
| 553 |
+
# Save metrics
|
| 554 |
+
pd.DataFrame([{
|
| 555 |
+
'val_accuracy': 100*(all_p==all_t).mean(),
|
| 556 |
+
'macro_f1': mf1, 'weighted_f1': wf1, 'macro_auc': mauc,
|
| 557 |
+
**{f'f1_{cn}': f1_score(all_t, all_p, average=None, labels=range(NUM_CLASSES))[ci]
|
| 558 |
+
for ci,cn in enumerate(CLASS_NAMES)}
|
| 559 |
+
}]).to_csv(f'{SAVE_DIR}/metrics.csv', index=False)
|
| 560 |
+
|
| 561 |
+
# Save history
|
| 562 |
+
with open(f'{SAVE_DIR}/history.json','w') as f:
|
| 563 |
+
json.dump({k:[float(v) for v in vs] for k,vs in history.items()}, f, indent=2)
|
| 564 |
+
|
| 565 |
+
print(f'\n{"="*65}')
|
| 566 |
+
print(f' RETINASENSE ViT -- FINAL RESULTS')
|
| 567 |
+
print(f'{"="*65}')
|
| 568 |
+
print(f' Best Macro F1 : {best_f1:.4f}')
|
| 569 |
+
print(f' Val Accuracy : {100*(all_p==all_t).mean():.2f}%')
|
| 570 |
+
print(f' Macro AUC : {mauc:.4f}')
|
| 571 |
+
print(f' Training Time : {total_train_time/60:.1f} minutes')
|
| 572 |
+
per_f1 = f1_score(all_t, all_p, average=None, labels=range(NUM_CLASSES), zero_division=0)
|
| 573 |
+
for ci, cn in enumerate(CLASS_NAMES):
|
| 574 |
+
print(f' {cn:15s}: F1={per_f1[ci]:.3f}')
|
| 575 |
+
print(f'{"="*65}')
|
| 576 |
+
print(f'\n {SAVE_DIR}/')
|
| 577 |
+
print(f' -- best_model.pth')
|
| 578 |
+
print(f' -- dashboard.png')
|
| 579 |
+
print(f' -- lr_schedule.png')
|
| 580 |
+
print(f' -- metrics.csv')
|
| 581 |
+
print(f' -- history.json')
|