# Source: retinasense-vit / retinasense_v2_extended.py
# Commit eb2c80b ("Add retinasense_v2_extended.py", verified) by tanishq74
#!/usr/bin/env python3
"""
RetinaSense v2 Extended — 50 Epoch Training Pipeline
====================================================
Fixes from v1:
1. Focal Loss (handles class imbalance far better than weighted CE)
2. Stratified train/val split; batches are plain shuffled (imbalance is
   handled by focal-loss class weights, not by oversampling)
3. LR warmup + cosine decay (stable optimisation)
4. Gradient accumulation (effective batch 64, actual batch 32)
5. Early stopping on Macro F1 (not accuracy — misleading with imbalance)
6. Per-class metrics tracked every epoch
7. Pre-cached preprocessing (GPU efficiency)
8. Proper NaN handling in mixed precision
9. Comprehensive plots after training

Stages: [1/7] metadata (single-label ODIR + APTOS) → [2/7] Ben Graham
pre-cache → [3/7] loaders → [4/7] EfficientNet-B3 multi-task model →
[5/7] training → [6/7] evaluation → [7/7] plots.

Expected inputs (relative to the working directory):
odir/full_df.csv, odir/preprocessed_images/, aptos/train.csv,
aptos/train_images/.
Outputs go to ./outputs_v2_extended: best_model.pth, dashboard.png,
lr_schedule.png, metrics.csv, history.json.
"""
import os, sys, time, warnings, json
import numpy as np
import pandas as pd
import cv2
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from tqdm import tqdm
from collections import Counter
warnings.filterwarnings('ignore')
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.amp import autocast, GradScaler
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import models, transforms
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import (
classification_report, confusion_matrix,
roc_auc_score, f1_score, roc_curve, auc
)
from sklearn.preprocessing import label_binarize
# ═══════════════════════════════════════════════════════════
# CONFIG
# ═══════════════════════════════════════════════════════════
# Output / cache directories are created up front so later stages can write.
SAVE_DIR = './outputs_v2_extended'
CACHE_DIR = './preprocessed_cache'
os.makedirs(SAVE_DIR, exist_ok=True)
os.makedirs(CACHE_DIR, exist_ok=True)

EPOCHS = 50
WARMUP_EPOCHS = 3        # heads-only warmup
LR_WARMUP_STEPS = 3      # linear warmup epochs after unfreeze
BATCH_SIZE = 32          # actual batch size (stable)
ACCUM_STEPS = 2          # gradient accumulation → effective batch 64
NUM_WORKERS = 8
PATIENCE = 12            # early stopping on macro-F1 (extended)
FOCAL_GAMMA = 1.0        # reduced from 2.0 — less aggressive
IMG_SIZE = 300           # EfficientNet-B3 optimal input
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Index i of CLASS_NAMES is the integer disease label i used throughout.
CLASS_NAMES = ['Normal', 'Diabetes/DR', 'Glaucoma', 'Cataract', 'AMD']
NUM_CLASSES = len(CLASS_NAMES)

print('='*65)
print(f' RetinaSense v2 Extended — {EPOCHS} Epochs Pipeline')
print('='*65)
if torch.cuda.is_available():
    print(f' GPU : {torch.cuda.get_device_name(0)}')
    print(f' VRAM : {round(torch.cuda.get_device_properties(0).total_memory/1e9,1)} GB')
print(f' Epochs : {EPOCHS}')
print(f' Batch : {BATCH_SIZE} (effective {BATCH_SIZE*ACCUM_STEPS} via grad accum)')
print(f' Image Size : {IMG_SIZE}')
print(f' Focal Loss γ: {FOCAL_GAMMA} (mild — avoids over-correction)')
print(f' Early Stop : patience={PATIENCE} on macro-F1')
print('='*65)
# ═══════════════════════════════════════════════════════════
# 1 METADATA
# ═══════════════════════════════════════════════════════════
print('\n[1/7] Building metadata...')
BASE = './'
# ODIR one-hot disease columns; position in this list == class index,
# matching the order of CLASS_NAMES.
disease_cols = ['N', 'D', 'G', 'C', 'A']
label_map = {col: idx for idx, col in enumerate(disease_cols)}
df_odir = pd.read_csv(f'{BASE}/odir/full_df.csv')
df_odir['disease_count'] = df_odir[disease_cols].sum(axis=1)
# Keep only rows with exactly one of the five diseases flagged.
df_odir = df_odir.loc[df_odir['disease_count'].eq(1)].copy()
def get_label(row, cols=None, mapping=None):
    """Return the class index of the first indicator column set to 1 in *row*.

    Args:
        row: mapping / pandas Series holding a 0/1 value for every column in
            *cols* (e.g. one row passed by ``DataFrame.apply(..., axis=1)``).
        cols: indicator column names scanned in order; defaults to the
            module-level ``disease_cols``.
        mapping: column name -> integer class index; defaults to the
            module-level ``label_map``.

    Returns:
        The mapped class index of the first column equal to 1, or ``None`` if
        no column is set (the single-disease filter on ``disease_count == 1``
        should prevent that for the ODIR rows this is applied to).
    """
    cols = disease_cols if cols is None else cols
    mapping = label_map if mapping is None else mapping
    for d in cols:
        if row[d] == 1:
            return mapping[d]
    return None
df_odir['disease_label'] = df_odir.apply(get_label, axis=1)
# Locate the per-row image filename column. An exact 'filename' column is
# preferred: the keyword search alone returns the FIRST column whose name
# contains 'filename', 'fundus' or 'image', so a column such as 'Left-Fundus'
# listed before 'filename' would make every row point at the same eye's image.
img_col = next((c for c in df_odir.columns if c.lower() == 'filename'), None)
if img_col is None:
    img_col = next(c for c in df_odir.columns
                   if any(k in c.lower() for k in ['filename', 'fundus', 'image']))
odir_meta = pd.DataFrame({
    'image_path': f'{BASE}/odir/preprocessed_images/' + df_odir[img_col].astype(str),
    'dataset': 'ODIR',
    'disease_label': df_odir['disease_label'],
    # ODIR rows carry no severity grade; -1 is the ignore_index of the
    # severity CrossEntropyLoss.
    'severity_label': -1,
})
df_aptos = pd.read_csv(f'{BASE}/aptos/train.csv')
# Every APTOS image is labelled Diabetes/DR (class 1); its `diagnosis`
# column supplies the severity label.
aptos_meta = pd.DataFrame({
    'image_path': f'{BASE}/aptos/train_images/' + df_aptos['id_code'] + '.png',
    'dataset': 'APTOS',
    'disease_label': 1,
    'severity_label': df_aptos['diagnosis'],
})
meta = pd.concat([odir_meta, aptos_meta], ignore_index=True)
# Drop rows whose image file is not on disk.
meta = meta[meta['image_path'].apply(os.path.exists)].reset_index(drop=True)
# Drop repeated image paths: a duplicated image could land in both the train
# and the validation split and inflate validation metrics.
n_dup = int(meta['image_path'].duplicated().sum())
if n_dup:
    print(f' Dropping {n_dup} duplicate image path(s)')
    meta = meta.drop_duplicates(subset='image_path', keep='first').reset_index(drop=True)
print(f' Total samples: {len(meta)}')
dist = meta['disease_label'].value_counts().sort_index()
for i, cnt in dist.items():
    print(f' {CLASS_NAMES[i]:15s}: {cnt:4d} ({100*cnt/len(meta):.1f}%)')
# ═══════════════════════════════════════════════════════════
# 2 PRE-CACHE
# ═══════════════════════════════════════════════════════════
print(f'\n[2/7] Pre-caching @ {IMG_SIZE}×{IMG_SIZE}...')


def ben_graham(path, sz=IMG_SIZE, sigma=10):
    """Load a fundus image and apply Ben Graham preprocessing.

    Steps: decode (OpenCV, falling back to PIL) -> RGB -> resize to sz x sz ->
    local-contrast enhancement ``4*img - 4*GaussianBlur(img, sigma) + 128`` ->
    zero everything outside a centred circle of radius ``0.48*sz``.

    Args:
        path: image file path.
        sz: output side length in pixels.
        sigma: Gaussian blur sigma for the contrast step.

    Returns:
        An ``(sz, sz, 3)`` RGB ``uint8`` array.

    Raises:
        Whatever PIL raises when neither OpenCV nor PIL can read *path*.
    """
    img = cv2.imread(path)
    if img is None:
        # OpenCV could not decode the file: fall back to PIL and convert to
        # OpenCV's BGR order so both paths go through the same conversion below.
        img = np.array(Image.open(path).convert('RGB'))
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (sz, sz))
    img = cv2.addWeighted(img, 4, cv2.GaussianBlur(img, (0, 0), sigma), -4, 128)
    # Circular mask removes the image border / outer ring.
    mask = np.zeros(img.shape[:2], dtype=np.uint8)
    cv2.circle(mask, (sz // 2, sz // 2), int(sz * 0.48), 255, -1)
    return cv2.bitwise_and(img, img, mask=mask)
# Preprocess each image once; later runs reuse the cached .npy files.
cache_paths = []
cached = 0
failed = []   # (image_path, error) for images that could not be preprocessed
for _, row in tqdm(meta.iterrows(), total=len(meta), desc='Caching'):
    stem = os.path.splitext(os.path.basename(row['image_path']))[0]
    # NOTE(review): the cache key is the bare filename stem, so two datasets
    # sharing a stem would share one cache file — confirm stems are unique.
    fp = f'{CACHE_DIR}/{stem}_{IMG_SIZE}.npy'
    if not os.path.exists(fp):
        try:
            arr = ben_graham(row['image_path'])
        except Exception as exc:  # best effort: keep going, but record it
            failed.append((row['image_path'], f'{type(exc).__name__}: {exc}'))
            arr = np.zeros((IMG_SIZE, IMG_SIZE, 3), dtype=np.uint8)
        np.save(fp, arr)
        cached += 1
    cache_paths.append(fp)
meta['cache_path'] = cache_paths
print(f' Newly cached: {cached} | Already cached: {len(meta)-cached}')
if failed:
    # Placeholders persist in CACHE_DIR (the exists-check skips them next run);
    # delete those files to retry preprocessing.
    print(f' WARNING: {len(failed)} image(s) failed preprocessing and were cached '
          f'as black placeholders (first: {failed[0][0]} — {failed[0][1]})')
# ═══════════════════════════════════════════════════════════
# 3 DATASET + LOADERS
# ═══════════════════════════════════════════════════════════
print('\n[3/7] Creating data loaders...')
# 80/20 split, stratified on the disease label so both halves keep similar
# class proportions; fixed seed for reproducibility.
train_df, val_df = train_test_split(
    meta,
    random_state=42,
    test_size=0.2,
    stratify=meta['disease_label'],
)
def make_transforms(phase):
    """Return the torchvision transform pipeline for *phase*.

    ``'train'`` adds flips, rotation, small affine jitter, colour jitter and
    RandomErasing. Any other value returns the deterministic evaluation
    pipeline. Both convert the input array with ToPILImage (the cached images
    are HxWx3 uint8 arrays) and end with ImageNet mean/std normalisation.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    if phase == 'train':
        return transforms.Compose([
            transforms.ToPILImage(),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(p=0.3),
            transforms.RandomRotation(20),
            transforms.RandomAffine(degrees=0, translate=(0.05, 0.05), scale=(0.95, 1.05)),
            transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2, hue=0.02),
            transforms.ToTensor(),
            transforms.Normalize(imagenet_mean, imagenet_std),
            # RandomErasing works on tensors, so it must come after ToTensor.
            transforms.RandomErasing(p=0.2),
        ])
    return transforms.Compose([
        transforms.ToPILImage(),
        transforms.ToTensor(),
        transforms.Normalize(imagenet_mean, imagenet_std),
    ])
class RetDS(Dataset):
    """Map-style dataset over cached, preprocessed fundus images.

    Each item is ``(image, disease_label, severity_label)``: the transformed
    cached array plus both labels as ``torch.long`` scalars (severity is -1
    for images without a grade).

    Args:
        df: DataFrame with ``cache_path``, ``disease_label`` and
            ``severity_label`` columns.
        tfm: callable applied to the loaded numpy array.
    """

    def __init__(self, df, tfm):
        self.df = df.reset_index(drop=True)
        self.tfm = tfm

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i):
        r = self.df.iloc[i]
        try:
            img = np.load(r['cache_path'])
        except (OSError, ValueError, EOFError):
            # Missing/corrupt cache file: substitute a black image instead of
            # killing the DataLoader worker. Specific exceptions only, so
            # KeyboardInterrupt / SystemExit still propagate.
            img = np.zeros((IMG_SIZE, IMG_SIZE, 3), dtype=np.uint8)
        return (self.tfm(img),
                torch.tensor(int(r['disease_label']), dtype=torch.long),
                torch.tensor(int(r['severity_label']), dtype=torch.long))
train_ds = RetDS(train_df, make_transforms('train'))
val_ds = RetDS(val_df, make_transforms('val'))
# Use shuffle (not WeightedRandomSampler — that over-corrects with focal loss).
# drop_last=True: the model heads use BatchNorm1d, which raises in train mode
# on a batch of size 1 (a final batch of size 1 happens whenever
# len(train_ds) % BATCH_SIZE == 1). With shuffling, the dropped remainder
# differs every epoch.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=True,
                          persistent_workers=True, prefetch_factor=2,
                          drop_last=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=NUM_WORKERS, pin_memory=True,
                        persistent_workers=True)
print(f' Train : {len(train_ds):5d} ({len(train_loader):3d} batches)')
print(f' Val : {len(val_ds):5d} ({len(val_loader):3d} batches)')
print(f' ⚑ Focal Loss + class weights handle imbalance (no oversampling)')
# ═══════════════════════════════════════════════════════════
# 4 MODEL + FOCAL LOSS
# ═══════════════════════════════════════════════════════════
print('\n[4/7] Building model...')


class FocalLoss(nn.Module):
    """Multi-class focal loss — down-weights easy examples, focuses on hard ones.

    Per sample: ``alpha[y] * (1 - p_t) ** gamma * CE``, where
    ``p_t = exp(-CE)`` is the predicted probability of the true class. The
    result is the plain batch mean (not normalised by the sum of alpha weights).

    Args:
        alpha: optional 1-D tensor of per-class weights indexed by target
            class. It is registered as a buffer, so it moves with ``.to()``.
        gamma: focusing parameter; ``gamma=0`` reduces to (alpha-weighted) CE.
    """

    def __init__(self, alpha=None, gamma=2.0):
        super().__init__()
        self.gamma = gamma
        if alpha is not None:
            self.register_buffer('alpha', alpha)
        else:
            self.alpha = None

    def forward(self, logits, targets):
        """logits: (B, C) raw scores; targets: (B,) integer class ids."""
        ce = F.cross_entropy(logits, targets, reduction='none')
        pt = torch.exp(-ce)
        focal = ((1 - pt) ** self.gamma) * ce
        if self.alpha is not None:
            at = self.alpha.gather(0, targets)
            focal = at * focal
        return focal.mean()
class MultiTaskModel(nn.Module):
    """ImageNet-pretrained EfficientNet-B3 shared backbone with two heads.

    ``forward(x)`` returns ``(disease_logits, severity_logits)`` of shapes
    ``(B, n_disease)`` and ``(B, n_severity)``. Both heads read the same pooled
    backbone features after dropout.

    Args:
        n_disease: number of disease classes.
        n_severity: number of severity grades.
        drop: dropout probability applied to the pooled features.
    """

    def __init__(self, n_disease=5, n_severity=5, drop=0.4):
        super().__init__()
        bb = models.efficientnet_b3(weights='IMAGENET1K_V1')
        # All children except the final classifier (torchvision: features + avgpool).
        self.backbone = nn.Sequential(*list(bb.children())[:-1])
        # Read the pooled feature width from the discarded classifier's Linear
        # layer instead of hard-coding 1536, so it always matches the backbone.
        feat = bb.classifier[-1].in_features
        self.drop = nn.Dropout(drop)
        self.disease_head = nn.Sequential(
            nn.Linear(feat, 512), nn.BatchNorm1d(512), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(512, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(256, n_disease))
        self.severity_head = nn.Sequential(
            nn.Linear(feat, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(256, n_severity))

    def forward(self, x):
        f = self.backbone(x).flatten(1)   # pooled map -> (B, feat)
        f = self.drop(f)
        return self.disease_head(f), self.severity_head(f)
model = MultiTaskModel(n_disease=NUM_CLASSES).to(device)
# Class-weight alpha for focal loss, computed from the training split only.
cw = compute_class_weight('balanced', classes=np.arange(NUM_CLASSES),
                          y=train_df['disease_label'].values)
alpha = torch.tensor(cw, dtype=torch.float32).to(device)
alpha = alpha / alpha.sum() * NUM_CLASSES   # rescale so the weights average to 1
print(f' Focal α: {[f"{a:.2f}" for a in alpha.tolist()]}')
criterion_d = FocalLoss(alpha=alpha, gamma=FOCAL_GAMMA)
criterion_s = nn.CrossEntropyLoss(ignore_index=-1)   # -1 = no severity grade
total_p = sum(p.numel() for p in model.parameters())
print(f' Params: {total_p:,}')
# ═══════════════════════════════════════════════════════════
# 5 TRAINING LOOP
# ═══════════════════════════════════════════════════════════
print('\n[5/7] Training...')
# Phase 1: freeze the backbone so only the heads receive gradient updates.
# NOTE: BatchNorm running statistics inside the backbone still update while
# the model is in train mode; requires_grad does not affect them.
for p in model.backbone.parameters():
    p.requires_grad = False
optimizer = torch.optim.AdamW(
    (p for p in model.parameters() if p.requires_grad),
    lr=3e-4, weight_decay=1e-3)
scaler = GradScaler()
def get_scheduler(opt, warmup_steps, total_steps):
    """Return a LambdaLR: linear warmup, then cosine decay with a 5% floor.

    The multiplier applied to each param group's base LR is
      * ``step / warmup_steps`` while ``step < warmup_steps`` (starts at 0), then
      * ``max(0.05, 0.5 * (1 + cos(pi * progress)))``, where ``progress`` goes
        from 0 at the end of warmup to 1 at ``total_steps``.

    ``progress`` is clamped to 1. Stepping past ``total_steps`` therefore holds
    the LR at the floor instead of letting the cosine climb back up.

    Args:
        opt: optimizer whose param groups are scheduled.
        warmup_steps: scheduler steps of linear warmup.
        total_steps: scheduler step at which the cosine reaches its minimum.
    """
    def lr_lambda(step):
        if step < warmup_steps:
            return float(step) / max(1, warmup_steps)
        progress = min(1.0, float(step - warmup_steps) / max(1, total_steps - warmup_steps))
        return float(max(0.05, 0.5 * (1.0 + np.cos(np.pi * progress))))
    return torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda)
CHECKPOINT = SAVE_DIR + '/best_model.pth'
# One independent list per tracked metric; per-class F1 keys follow CLASS_NAMES.
history = dict(
    (metric, [])
    for metric in (
        'train_loss', 'val_loss', 'train_acc', 'val_acc',
        'macro_f1', 'weighted_f1', 'lr',
        *[f'f1_{name}' for name in CLASS_NAMES],
    )
)
# Early-stopping state: best macro-F1 so far, epochs since it last improved.
best_f1, patience_ctr = 0.0, 0
# Heads-only phase schedule: one epoch of linear warmup, cosine over the whole
# run (replaced by a fresh scheduler when the backbone is unfrozen).
total_steps = EPOCHS * len(train_loader) // ACCUM_STEPS
sched = get_scheduler(
    optimizer,
    total_steps=total_steps,
    warmup_steps=len(train_loader) // ACCUM_STEPS,
)
print('=' * 65)
def _severity_loss(s_out, s_lbl):
    """Severity CE that is 0 (instead of NaN) when no sample has a grade.

    ``criterion_s`` ignores label -1. If *every* label in the batch is -1, its
    mean over zero samples is NaN. The non-finite guard in the loop would then
    skip the whole batch, discarding the disease loss as well.
    """
    if (s_lbl >= 0).any():
        return criterion_s(s_out, s_lbl)
    return torch.zeros((), device=s_out.device)


for epoch in range(EPOCHS):
    t0 = time.time()

    # ── Unfreeze backbone after warmup ──
    if epoch == WARMUP_EPOCHS:
        print('\n 🔓 Unfreezing backbone with LR warmup')
        for p in model.backbone.parameters():
            p.requires_grad = True
        # New optimizer for the full model, with a lower LR for the backbone.
        optimizer = torch.optim.AdamW([
            {'params': model.backbone.parameters(), 'lr': 1e-5},
            {'params': model.disease_head.parameters(), 'lr': 1e-4},
            {'params': model.severity_head.parameters(), 'lr': 1e-4},
        ], weight_decay=1e-3)
        # Scheduler steps count optimizer updates, hence the // ACCUM_STEPS.
        remaining = (EPOCHS - WARMUP_EPOCHS) * len(train_loader) // ACCUM_STEPS
        sched = get_scheduler(optimizer,
                              warmup_steps=LR_WARMUP_STEPS * len(train_loader) // ACCUM_STEPS,
                              total_steps=remaining)
        scaler = GradScaler()

    # ── TRAIN ──
    model.train()
    run_loss = 0.0
    correct = 0
    total = 0
    optimizer.zero_grad(set_to_none=True)
    pbar = tqdm(train_loader, desc=f'E{epoch+1:02d}/{EPOCHS} train', leave=False)
    for step, (imgs, d_lbl, s_lbl) in enumerate(pbar):
        imgs = imgs.to(device, non_blocking=True)
        d_lbl = d_lbl.to(device, non_blocking=True)
        s_lbl = s_lbl.to(device, non_blocking=True)
        with autocast('cuda'):
            d_out, s_out = model(imgs)
            loss_d = criterion_d(d_out, d_lbl)
            loss_s = _severity_loss(s_out, s_lbl)
            # Divide so ACCUM_STEPS micro-batch gradients sum to one averaged update.
            loss = (loss_d + 0.2 * loss_s) / ACCUM_STEPS
        # Skip non-finite batches. This also drops any gradient already
        # accumulated in the current accumulation window.
        if not torch.isfinite(loss):
            optimizer.zero_grad(set_to_none=True)
            continue
        scaler.scale(loss).backward()
        if (step + 1) % ACCUM_STEPS == 0:
            scaler.unscale_(optimizer)   # clip the true (unscaled) gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
            sched.step()
        run_loss += loss.item() * ACCUM_STEPS
        preds = d_out.argmax(1)
        correct += (preds == d_lbl).sum().item()
        total += d_lbl.size(0)
        pbar.set_postfix(loss=f'{loss.item()*ACCUM_STEPS:.3f}',
                         acc=f'{100*correct/total:.1f}%')
    train_loss = run_loss / len(train_loader)
    train_acc = 100 * correct / total

    # ── VALIDATE ──
    model.eval()
    vl = 0.0
    all_p, all_t, all_prob = [], [], []
    with torch.no_grad():
        for imgs, d_lbl, s_lbl in tqdm(val_loader, desc=f'E{epoch+1:02d}/{EPOCHS} val ', leave=False):
            imgs = imgs.to(device, non_blocking=True)
            d_lbl = d_lbl.to(device, non_blocking=True)
            s_lbl = s_lbl.to(device, non_blocking=True)
            with autocast('cuda'):
                d_out, s_out = model(imgs)
                ld = criterion_d(d_out, d_lbl)
                ls = _severity_loss(s_out, s_lbl)
                loss = ld + 0.2 * ls
            if torch.isfinite(loss):
                vl += loss.item()
            probs = torch.softmax(d_out.float(), dim=1)
            all_p.extend(d_out.argmax(1).cpu().numpy())
            all_t.extend(d_lbl.cpu().numpy())
            all_prob.extend(probs.cpu().numpy())
    val_loss = vl / len(val_loader)
    all_p, all_t, all_prob = np.array(all_p), np.array(all_t), np.array(all_prob)
    val_acc = 100 * (all_p == all_t).mean()
    mf1 = f1_score(all_t, all_p, average='macro')
    wf1 = f1_score(all_t, all_p, average='weighted')
    per_f1 = f1_score(all_t, all_p, average=None, labels=range(NUM_CLASSES), zero_division=0)
    # Param group 0: the heads before unfreezing, the backbone afterwards.
    lr = optimizer.param_groups[0]['lr']

    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['train_acc'].append(train_acc)
    history['val_acc'].append(val_acc)
    history['macro_f1'].append(mf1)
    history['weighted_f1'].append(wf1)
    history['lr'].append(lr)
    for ci, cn in enumerate(CLASS_NAMES):
        history[f'f1_{cn}'].append(per_f1[ci])
    elapsed = time.time() - t0

    # ── Checkpoint on macro-F1 improvement / early stopping ──
    tag = ''
    if mf1 > best_f1:
        best_f1 = mf1
        patience_ctr = 0
        torch.save({
            'epoch': epoch, 'model_state_dict': model.state_dict(),
            'val_acc': val_acc, 'macro_f1': mf1, 'history': history
        }, CHECKPOINT)
        tag = f' ★ NEW BEST (macro-F1={mf1:.4f})'
    else:
        patience_ctr += 1
    cls_str = ' | '.join(f'{cn[:3]}:{per_f1[ci]:.2f}' for ci, cn in enumerate(CLASS_NAMES))
    print(f'E{epoch+1:02d} | {elapsed:.0f}s | LR {lr:.1e} | '
          f'TrL {train_loss:.3f} TrA {train_acc:.1f}% | '
          f'VL {val_loss:.3f} VA {val_acc:.1f}% | '
          f'mF1 {mf1:.3f} wF1 {wf1:.3f}{tag}')
    print(f' {cls_str}')
    if patience_ctr >= PATIENCE:
        print(f'\n ⏹ Early stopping — no improvement for {PATIENCE} epochs')
        break
print(f'\n✅ Training done. Best macro-F1: {best_f1:.4f}')
# ═══════════════════════════════════════════════════════════
# 6 EVALUATION + PLOTS
# ═══════════════════════════════════════════════════════════
print('\n[6/7] Full evaluation...')
# Reload the best (highest validation macro-F1) weights saved during training.
# weights_only=False because the checkpoint also pickles non-tensor metadata
# (metrics, history). Only load checkpoints this script wrote itself.
ckpt = torch.load(CHECKPOINT, map_location=device, weights_only=False)
model.load_state_dict(ckpt['model_state_dict'])
model.eval()
# Keep the in-memory `history`, which covers every trained epoch.
# ckpt['history'] is a snapshot taken at the best epoch; using it would drop
# the later epochs from the plots and from history.json.
all_p, all_t, all_prob = [], [], []
with torch.no_grad():
    for imgs, d_lbl, _ in tqdm(val_loader, desc='Evaluating'):
        imgs = imgs.to(device)
        d_out, _ = model(imgs)
        all_p.extend(d_out.argmax(1).cpu().numpy())
        all_t.extend(d_lbl.numpy())
        all_prob.extend(torch.softmax(d_out.float(), dim=1).cpu().numpy())
all_p = np.array(all_p)
all_t = np.array(all_t)
all_prob = np.array(all_prob)
print('\n' + '='*65)
print(' CLASSIFICATION REPORT')
print('='*65)
# Pin the label set so target_names always lines up, even if a class is
# absent from both the ground truth and the predictions.
report = classification_report(all_t, all_p, labels=list(range(NUM_CLASSES)),
                               target_names=CLASS_NAMES, digits=4, zero_division=0)
print(report)
mf1 = f1_score(all_t, all_p, average='macro')
wf1 = f1_score(all_t, all_p, average='weighted')
try:
    mauc = roc_auc_score(all_t, all_prob, multi_class='ovr', average='macro')
except ValueError:
    # roc_auc_score raises ValueError when AUC is undefined (e.g. a missing class).
    mauc = 0.0
print(f'Macro F1 : {mf1:.4f}')
print(f'Weighted F1 : {wf1:.4f}')
print(f'Macro AUC : {mauc:.4f}')
# ═══════════════════════════════════════════════════════════
# 7 COMPREHENSIVE PLOTS
# ═══════════════════════════════════════════════════════════
print('\n[7/7] Generating plots...')
ep = range(1, len(history['train_loss']) + 1)   # 1-based epoch axis
colors = ['#2ecc71', '#3498db', '#e74c3c', '#f39c12', '#9b59b6']   # one per class
fig, axes = plt.subplots(2, 3, figsize=(20, 12))
# ── 1. Loss ──
axes[0, 0].plot(ep, history['train_loss'], 'b-o', ms=4, label='Train')
axes[0, 0].plot(ep, history['val_loss'], 'r-o', ms=4, label='Val')
axes[0, 0].set_title('Loss', fontweight='bold')
axes[0, 0].legend(); axes[0, 0].grid(alpha=.3)
# ── 2. Accuracy ──
axes[0, 1].plot(ep, history['train_acc'], 'b-o', ms=4, label='Train')
axes[0, 1].plot(ep, history['val_acc'], 'r-o', ms=4, label='Val')
axes[0, 1].set_title('Accuracy (%)', fontweight='bold')
axes[0, 1].legend(); axes[0, 1].grid(alpha=.3)
# ── 3. Macro / Weighted F1 ──
axes[0, 2].plot(ep, history['macro_f1'], 'g-o', ms=4, label='Macro F1')
axes[0, 2].plot(ep, history['weighted_f1'], 'm-o', ms=4, label='Weighted F1')
axes[0, 2].set_title('F1 Scores', fontweight='bold')
axes[0, 2].legend(); axes[0, 2].grid(alpha=.3)
# ── 4. Per-class F1 ──
for ci, cn in enumerate(CLASS_NAMES):
    axes[1, 0].plot(ep, history[f'f1_{cn}'], '-o', ms=3, color=colors[ci], label=cn)
axes[1, 0].set_title('Per-Class F1', fontweight='bold')
axes[1, 0].legend(); axes[1, 0].grid(alpha=.3)
# ── 5. Confusion Matrix ──
# Fixed label set keeps the matrix NUM_CLASSES x NUM_CLASSES and aligned with
# the tick labels even if a class is missing. Rows with no true samples are
# divided by 1 instead of 0, so they show zeros rather than NaN.
cm = confusion_matrix(all_t, all_p, labels=list(range(NUM_CLASSES)))
cm_n = cm.astype(float) / np.maximum(cm.sum(axis=1, keepdims=True), 1)
sns.heatmap(cm_n, annot=True, fmt='.2f', cmap='Blues', ax=axes[1, 1],
            xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES)
axes[1, 1].set_title('Confusion Matrix (norm)', fontweight='bold')
axes[1, 1].set_ylabel('True'); axes[1, 1].set_xlabel('Pred')
# ── 6. ROC (one-vs-rest per class) ──
y_bin = label_binarize(all_t, classes=list(range(NUM_CLASSES)))
for ci, (cn, col) in enumerate(zip(CLASS_NAMES, colors)):
    fpr, tpr, _ = roc_curve(y_bin[:, ci], all_prob[:, ci])
    axes[1, 2].plot(fpr, tpr, color=col, lw=2, label=f'{cn} ({auc(fpr,tpr):.3f})')
axes[1, 2].plot([0, 1], [0, 1], 'k--', lw=1)
axes[1, 2].set_title('ROC Curves', fontweight='bold')
axes[1, 2].legend(loc='lower right', fontsize=8)
axes[1, 2].grid(alpha=.3)
plt.suptitle(f'RetinaSense v2 Extended — Macro F1={mf1:.3f} | AUC={mauc:.3f} | Val Acc={100*(all_p==all_t).mean():.1f}%',
             fontsize=15, fontweight='bold', y=1.01)
plt.tight_layout()
plt.savefig(f'{SAVE_DIR}/dashboard.png', dpi=150, bbox_inches='tight')
plt.close()
# LR schedule plot: one point per epoch (the LR recorded in history['lr']).
fig, ax = plt.subplots(figsize=(8, 3))
ax.plot(ep, history['lr'], 'b-o', ms=3)
ax.set_title('Learning Rate Schedule', fontweight='bold')
ax.set_xlabel('Epoch')
ax.set_ylabel('LR')
ax.grid(alpha=0.3)
fig.tight_layout()
fig.savefig(f'{SAVE_DIR}/lr_schedule.png', dpi=150)
plt.close(fig)
# Save metrics (single-row CSV). Per-class F1 is computed once, with the same
# zero_division setting used for per-class F1 elsewhere in the script.
final_per_f1 = f1_score(all_t, all_p, average=None,
                        labels=range(NUM_CLASSES), zero_division=0)
pd.DataFrame([{
    'val_accuracy': 100*(all_p == all_t).mean(),
    'macro_f1': mf1, 'weighted_f1': wf1, 'macro_auc': mauc,
    **{f'f1_{cn}': final_per_f1[ci] for ci, cn in enumerate(CLASS_NAMES)},
}]).to_csv(f'{SAVE_DIR}/metrics.csv', index=False)
# Save history (values cast to plain float so json can serialise numpy scalars).
with open(f'{SAVE_DIR}/history.json', 'w', encoding='utf-8') as f:
    json.dump({k: [float(v) for v in vs] for k, vs in history.items()}, f, indent=2)
print(f'\n{"="*65}')
print(' RETINASENSE v2 EXTENDED — FINAL RESULTS')
print(f'{"="*65}')
# best_f1 is the best macro-F1 seen during training-time validation; the
# values below are recomputed from the reloaded best checkpoint.
print(f' Best Macro F1 : {best_f1:.4f}')
print(f' Val Accuracy : {100*(all_p==all_t).mean():.2f}%')
print(f' Macro AUC : {mauc:.4f}')
per_f1 = f1_score(all_t, all_p, average=None, labels=range(NUM_CLASSES), zero_division=0)
for ci, cn in enumerate(CLASS_NAMES):
    print(f' {cn:15s}: F1={per_f1[ci]:.3f}')
print(f'{"="*65}')
print(f'\n📁 {SAVE_DIR}/')
print(' ├── best_model.pth')
print(' ├── dashboard.png')
print(' ├── lr_schedule.png')
print(' ├── metrics.csv')
print(' └── history.json')