r"""
AgriFM PASTIS - Training visualization and analysis.

Generates all graphs, confusion matrix, prediction maps, and comparison plots.

Usage:
    python visualize_results.py \
        --work_dir ./work_dirs/fold1_v3 \
        --data_root /workspace/project/PASTIS \
        --fold 1
"""
# NOTE: raw docstring so the shell line-continuation backslashes in the
# usage example are kept literally (a plain string would treat them as
# escapes and join the lines).
import os
import sys
import json
import argparse
import numpy as np
import torch
from pathlib import Path
from torch.utils.data import DataLoader
from torch.amp import autocast

# Make the project packages (models/, datasets/, ...) importable when the
# script is run directly from its own directory.
sys.path.insert(0, str(Path(__file__).parent))

import matplotlib
matplotlib.use('Agg')  # headless backend: figures are only saved to disk
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.gridspec as gridspec
from matplotlib.colors import ListedColormap
import matplotlib.ticker as mticker
from sklearn.metrics import confusion_matrix
import seaborn as sns

from models.agrifm import build_agrifm_pastis_small, build_agrifm_pastis_tiny, build_agrifm_pastis
from datasets.pastis_dataset import PASTISDataset, PASTIS_CLASSES, IGNORE_INDEX
from losses.loss import CropCELoss
from evaluation.metrics import SegmentationMetrics

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

# Full class names for the 20 PASTIS class ids (0..19).
CLASS_NAMES = [PASTIS_CLASSES[i] for i in range(20)]

# Abbreviated names (same index order) used as confusion-matrix tick labels.
SHORT_NAMES = [
    'BG', 'Meadow', 'S.Wheat', 'Corn', 'W.Barley', 'W.Rape', 'Sp.Barley',
    'Sunflwr', 'Grapevn', 'Beet', 'W.Trit', 'W.Durum', 'Fruits', 'Potato',
    'Leg.Fod', 'Soybeans', 'Orchard', 'Mixed', 'Sorghum', 'Void'
]

# Distinct color for each class
CLASS_COLORS = [
    '#000000',  # 0  Background - black
    '#3cb371',  # 1  Meadow - medium sea green
    '#ffd700',  # 2  Soft wtr wheat - gold
    '#ff8c00',  # 3  Corn - dark orange
    '#8b4513',  # 4  Winter barley - saddle brown
    '#ff1493',  # 5  Winter rapeseed - deep pink
    '#adff2f',  # 6  Spring barley - green yellow
    '#ffff00',  # 7  Sunflower - yellow
    '#800080',  # 8  Grapevine - purple
    '#dc143c',  # 9  Beet - crimson
    '#00bfff',  # 10 Winter trit - deep sky blue
    '#daa520',  # 11 Winter durum - goldenrod
    '#32cd32',  # 12 Fruits/veg - lime green
    '#a0522d',  # 13 Potatoes - sienna
    '#90ee90',  # 14 Leg fodder - light green
    '#006400',  # 15 Soybeans - dark green
    '#ff7f50',  # 16 Orchard - coral
    '#87ceeb',  # 17 Mixed cereal - sky blue
    '#bc8f8f',  # 18 Sorghum - rosy brown
    '#808080',  # 19 Void - gray
]

# Discrete colormap indexed by class id (used with vmin=0, vmax=19).
CMAP = ListedColormap(CLASS_COLORS)


def get_args():
    """Parse the command-line options of the visualization script.

    Returns:
        argparse.Namespace with ``work_dir``, ``data_root``, ``fold``,
        ``model_size``, ``num_classes``, ``num_frames``, ``batch_size``,
        ``num_workers``, ``n_samples`` and ``out_dir``.
    """
    p = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    p.add_argument('--work_dir', default='./work_dirs/fold1_v3',
                   help='Run directory containing log.json, test_results.json '
                        'and best_model.pth')
    p.add_argument('--data_root', default='/workspace/project/PASTIS',
                   help='Root directory of the PASTIS dataset')
    p.add_argument('--fold', type=int, default=1,
                   help='Fold whose test split is used for the model-based plots')
    p.add_argument('--model_size', default='small', choices=['small', 'tiny', 'base'],
                   help="Model variant to build ('base' uses build_agrifm_pastis)")
    p.add_argument('--num_classes', type=int, default=20,
                   help='Number of output classes of the model')
    p.add_argument('--num_frames', type=int, default=32,
                   help='num_frames value passed to PASTISDataset')
    p.add_argument('--batch_size', type=int, default=8,
                   help='Batch size of the test DataLoader')
    p.add_argument('--num_workers', type=int, default=4,
                   help='Worker processes of the test DataLoader')
    p.add_argument('--n_samples', type=int, default=6,
                   help='Number of sample prediction maps to show')
    p.add_argument('--out_dir', default=None,
                   help='Output dir for plots (default: work_dir/plots)')
    return p.parse_args()
# ---------------------------------------------------------------------------
# 1. Training curves
# ---------------------------------------------------------------------------

def _epoch_series(log_data, key):
    """Return ``(epochs, values)`` for the log entries whose ``key`` is not None.

    Each metric is paired with its own epochs, so a metric that is missing on
    some epochs can never leave the x and y lists with different lengths.
    """
    pairs = [(d['epoch'], d.get(key)) for d in log_data]
    pairs = [(e, v) for e, v in pairs if v is not None]
    return [e for e, _ in pairs], [v for _, v in pairs]


def _style_history_axis(ax, title, ylabel):
    """Apply the title/label/grid/background styling shared by the history panels."""
    ax.set_title(title, fontweight='bold')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.grid(True, alpha=0.3)
    ax.set_facecolor('#f8f8f8')


def plot_training_curves(log_data, out_dir):
    """Plot a 2x3 grid of loss and validation-metric curves from the training log.

    Args:
        log_data: list of per-epoch dicts. Every entry needs ``'epoch'`` and
            ``'train_loss'``. ``'val_loss'``, ``'mFscore'``, ``'mIoU'``,
            ``'OA'``, ``'Kappa'``, ``'mPrecision'`` and ``'mRecall'`` may be
            missing or None (e.g. on epochs without validation).
        out_dir: directory where ``1_training_curves.png`` is written.
    """
    print(" Plotting training curves...")
    epochs = [d['epoch'] for d in log_data]
    train_loss = [d['train_loss'] for d in log_data]

    ve, vl = _epoch_series(log_data, 'val_loss')
    me, mf = _epoch_series(log_data, 'mFscore')
    ie, mi = _epoch_series(log_data, 'mIoU')
    oe, oa_v = _epoch_series(log_data, 'OA')
    ke, ka_v = _epoch_series(log_data, 'Kappa')
    pe, pr_v = _epoch_series(log_data, 'mPrecision')
    rce, re_v = _epoch_series(log_data, 'mRecall')

    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    fig.suptitle('AgriFM × PASTIS — Training History', fontsize=16, fontweight='bold')

    # Loss curves
    ax = axes[0, 0]
    ax.plot(epochs, train_loss, 'b-', linewidth=2, label='Train Loss')
    ax.plot(ve, vl, 'r-', linewidth=2, label='Val Loss')
    _style_history_axis(ax, 'Loss Curves', 'Loss')
    ax.legend()

    # mFscore, with the best validation epoch annotated
    ax = axes[0, 1]
    ax.plot(me, mf, 'g-', linewidth=2, label='mFscore')
    if mf:  # max() would raise on a log without any validation epochs
        best_val = max(mf)
        best_epoch = me[mf.index(best_val)]
        ax.axhline(best_val, color='g', linestyle='--', alpha=0.5)
        ax.annotate(f'Best: {best_val:.1f}%\n@ epoch {best_epoch}',
                    xy=(best_epoch, best_val),
                    xytext=(best_epoch + len(me) * 0.05, best_val - 5),
                    fontsize=9, color='green',
                    arrowprops=dict(arrowstyle='->', color='green', lw=1.5))
    _style_history_axis(ax, 'mFscore (F1)', 'mFscore (%)')

    # mIoU
    ax = axes[0, 2]
    ax.plot(ie, mi, 'orange', linewidth=2, label='mIoU')
    _style_history_axis(ax, 'Mean IoU', 'mIoU (%)')

    # OA + Kappa
    ax = axes[1, 0]
    ax.plot(oe, oa_v, 'purple', linewidth=2, label='OA')
    ax.plot(ke, ka_v, 'brown', linewidth=2, label='Kappa')
    _style_history_axis(ax, 'Overall Accuracy & Kappa', '%')
    ax.legend()

    # Precision vs Recall
    ax = axes[1, 1]
    ax.plot(pe, pr_v, 'teal', linewidth=2, label='mPrecision')
    ax.plot(rce, re_v, 'salmon', linewidth=2, label='mRecall')
    _style_history_axis(ax, 'Precision vs Recall', '%')
    ax.legend()

    # All metrics together
    ax = axes[1, 2]
    for xs, ys, name, color in [
        (me, mf, 'mFscore', 'green'),
        (ie, mi, 'mIoU', 'orange'),
        (oe, oa_v, 'OA', 'purple'),
        (ke, ka_v, 'Kappa', 'brown'),
    ]:
        ax.plot(xs, ys, linewidth=2, label=name, color=color)
    _style_history_axis(ax, 'All Metrics', '%')
    ax.legend(fontsize=8)

    plt.tight_layout()
    path = os.path.join(out_dir, '1_training_curves.png')
    plt.savefig(path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f" Saved: {path}")
# ---------------------------------------------------------------------------
# 2. Per-class IoU bar chart
# ---------------------------------------------------------------------------

def plot_per_class_iou(test_results, out_dir):
    """Draw a horizontal bar chart of test IoU per class, ranked by IoU.

    Args:
        test_results: dict holding ``'per_class_iou'`` (class name -> IoU %)
            and ``'test_metrics'`` (must contain ``'mIoU'`` and ``'mFscore'``,
            shown in the title).
        out_dir: directory where ``2_per_class_iou.png`` is written.
    """
    print(" Plotting per-class IoU...")
    ranked = sorted(test_results['per_class_iou'].items(), key=lambda kv: -kv[1])
    names = [name for name, _ in ranked]
    values = [iou for _, iou in ranked]

    def bar_color(name):
        color = CLASS_COLORS[CLASS_NAMES.index(name)] if name in CLASS_NAMES else '#888888'
        # Render the black background class as dark gray instead.
        return '#444444' if color == '#000000' else color

    colors = [bar_color(name) for name in names]

    fig, ax = plt.subplots(figsize=(14, 7))
    ax.barh(range(len(names)), values, color=colors, edgecolor='white', linewidth=0.5)

    # Numeric label just right of each bar
    for row, iou in enumerate(values):
        ax.text(iou + 0.5, row, f'{iou:.1f}%', va='center', fontsize=9, fontweight='bold')

    # Reference lines: 50% dashed, 25% dotted
    for x_ref, line_style in ((50, '--'), (25, ':')):
        ax.axvline(x_ref, color='gray', linestyle=line_style, alpha=0.5, linewidth=1)

    ax.set_yticks(range(len(names)))
    ax.set_yticklabels(names, fontsize=10)
    ax.set_xlabel('IoU (%)', fontsize=12)
    overall = test_results["test_metrics"]
    ax.set_title('Per-Class IoU — Test Set\n'
                 f'(mIoU = {overall["mIoU"]:.2f}% '
                 f'mFscore = {overall["mFscore"]:.2f}%)',
                 fontsize=13, fontweight='bold')
    ax.set_xlim(0, 100)
    ax.grid(True, axis='x', alpha=0.3)
    ax.set_facecolor('#f8f8f8')

    # Captions for the reference lines
    for x_ref in (25, 50):
        ax.text(x_ref + 1, -1.2, f'{x_ref}%', fontsize=8, color='gray', ha='center')

    plt.tight_layout()
    path = os.path.join(out_dir, '2_per_class_iou.png')
    plt.savefig(path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f" Saved: {path}")
# ---------------------------------------------------------------------------
# 3. Metrics summary radar chart
# ---------------------------------------------------------------------------

def plot_radar(test_results, out_dir):
    """Plot the six headline test metrics on a 0-100 radar (polar) chart.

    Args:
        test_results: dict whose ``'test_metrics'`` contains ``'OA'``,
            ``'mIoU'``, ``'mFscore'``, ``'mPrecision'``, ``'mRecall'`` and
            ``'Kappa'``.
        out_dir: directory where ``3_metrics_radar.png`` is written.
    """
    print(" Plotting radar chart...")
    metrics = test_results['test_metrics']
    keys = ['OA', 'mIoU', 'mFscore', 'mPrecision', 'mRecall', 'Kappa']
    scores = [metrics[name] for name in keys]
    theta = np.linspace(0, 2 * np.pi, len(keys), endpoint=False).tolist()

    # Repeat the first vertex so the polygon closes.
    ring_theta = theta + theta[:1]
    ring_scores = scores + scores[:1]

    fig, ax = plt.subplots(figsize=(8, 8), subplot_kw={'polar': True})
    ax.plot(ring_theta, ring_scores, 'o-', linewidth=2, color='#2196F3', markersize=8)
    ax.fill(ring_theta, ring_scores, alpha=0.25, color='#2196F3')

    ax.set_xticks(theta)
    ax.set_xticklabels(keys, fontsize=12, fontweight='bold')
    ax.set_ylim(0, 100)
    radial_ticks = [20, 40, 60, 80, 100]
    ax.set_yticks(radial_ticks)
    ax.set_yticklabels([f'{tick}%' for tick in radial_ticks], fontsize=8, color='gray')
    ax.grid(color='gray', linestyle='--', linewidth=0.5, alpha=0.7)

    # Value label slightly outside each vertex
    for angle, score in zip(theta, scores):
        ax.annotate(f'{score:.1f}%', xy=(angle, score), xytext=(angle, score + 5),
                    ha='center', fontsize=10, fontweight='bold', color='#1565C0')

    ax.set_title('Test Set Metrics Overview', fontsize=14, fontweight='bold', pad=20)
    plt.tight_layout()
    path = os.path.join(out_dir, '3_metrics_radar.png')
    plt.savefig(path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f" Saved: {path}")
# ---------------------------------------------------------------------------
# 4. Confusion matrix
# ---------------------------------------------------------------------------

def _amp_context(device):
    """Return a CUDA autocast context that is enabled only when running on CUDA.

    Requesting CUDA autocast unconditionally warns ("CUDA is not available")
    on CPU-only machines; with ``enabled=False`` the context is a silent no-op.
    """
    return autocast('cuda', enabled=torch.device(device).type == 'cuda')


def plot_confusion_matrix(model, loader, device, args, out_dir):
    """Run ``model`` over ``loader`` and plot raw and row-normalized confusion matrices.

    Pixels whose label equals ``IGNORE_INDEX`` are excluded. Only classes that
    occur in the ground truth or the predictions appear as rows/columns.

    Args:
        model: segmentation model; ``model(batch['S2'])`` returns logits whose
            class axis is dim 1.
        loader: iterable of dicts with ``'S2'`` (model input tensor) and
            ``'label'`` (per-pixel class ids tensor).
        device: device the model lives on.
        args: parsed CLI args (unused; kept for interface compatibility).
        out_dir: directory where ``4_confusion_matrix.png`` is written.
    """
    print(" Computing confusion matrix...")
    model.eval()

    all_preds = []
    all_labels = []
    with torch.no_grad():
        for batch in loader:
            s2 = batch['S2'].to(device)
            label = batch['label']
            with _amp_context(device):
                logits = model(s2)
            pred = logits.argmax(dim=1).cpu().numpy()
            lbl = label.numpy()
            mask = lbl != IGNORE_INDEX
            all_preds.append(pred[mask])
            all_labels.append(lbl[mask])

    if not all_preds:
        print(" Test loader is empty — skipping confusion matrix.")
        return

    all_preds = np.concatenate(all_preds)
    all_labels = np.concatenate(all_labels)

    # Sorted classes present in either labels or predictions. np.union1d avoids
    # turning every pixel into a Python int just to build a set.
    present = [int(c) for c in np.union1d(all_labels, all_preds) if c != IGNORE_INDEX]
    if not present:
        print(" No valid (non-ignored) pixels — skipping confusion matrix.")
        return

    cm = confusion_matrix(all_labels, all_preds, labels=present)
    # Row-normalize (recall per true class); epsilon guards empty rows.
    cm_norm = cm.astype(float) / (cm.sum(axis=1, keepdims=True) + 1e-8)
    short = [SHORT_NAMES[c] for c in present]
    n = len(present)

    side = max(12, n * 0.7)
    fig, axes = plt.subplots(1, 2, figsize=(side * 2 + 2, side))

    # Raw counts
    ax = axes[0]
    im = ax.imshow(cm, cmap='Blues')
    ax.set_xticks(range(n))
    ax.set_xticklabels(short, rotation=45, ha='right', fontsize=8)
    ax.set_yticks(range(n))
    ax.set_yticklabels(short, fontsize=8)
    ax.set_title('Confusion Matrix (counts)', fontweight='bold', fontsize=12)
    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')
    plt.colorbar(im, ax=ax, shrink=0.8)

    # Normalized
    ax = axes[1]
    im2 = ax.imshow(cm_norm, cmap='Blues', vmin=0, vmax=1)
    ax.set_xticks(range(n))
    ax.set_xticklabels(short, rotation=45, ha='right', fontsize=8)
    ax.set_yticks(range(n))
    ax.set_yticklabels(short, fontsize=8)

    # Annotate cells above 5% to keep the plot readable
    for i in range(n):
        for j in range(n):
            val = cm_norm[i, j]
            if val > 0.05:
                color = 'white' if val > 0.5 else 'black'
                ax.text(j, i, f'{val:.2f}', ha='center', va='center', fontsize=6, color=color)

    ax.set_title('Confusion Matrix (normalized)', fontweight='bold', fontsize=12)
    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')
    plt.colorbar(im2, ax=ax, shrink=0.8)

    plt.suptitle('AgriFM × PASTIS — Test Set Confusion Matrix', fontsize=14, fontweight='bold')
    plt.tight_layout()
    path = os.path.join(out_dir, '4_confusion_matrix.png')
    plt.savefig(path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f" Saved: {path}")


# ---------------------------------------------------------------------------
# 5. Prediction maps — qualitative examples
# ---------------------------------------------------------------------------

def plot_prediction_maps(model, dataset, device, args, out_dir):
    """Plot RGB / ground truth / prediction / difference for random test patches.

    Args:
        model: segmentation model; ``model(x)`` returns logits with classes on dim 1.
        dataset: indexable dataset whose items are dicts with ``'S2'`` and
            ``'label'`` tensors; must expose ``dataset.ids`` (patch ids).
        device: device the model lives on.
        args: parsed CLI args; ``args.n_samples`` sets how many patches are
            shown (clamped to the dataset size).
        out_dir: directory where ``5_prediction_maps.png`` is written.
    """
    n_samples = min(args.n_samples, len(dataset))
    print(f" Plotting {n_samples} prediction maps...")
    if n_samples <= 0:
        print(" Nothing to plot — skipping prediction maps.")
        return
    model.eval()

    # Fixed seed so the same patches are shown on every run. A private
    # RandomState(42) yields exactly the indices that np.random.seed(42) +
    # np.random.choice produced, without mutating NumPy's global RNG.
    rng = np.random.RandomState(42)
    indices = rng.choice(len(dataset), n_samples, replace=False)

    # squeeze=False keeps a 2-D axes array even for a single row.
    fig, axes = plt.subplots(n_samples, 4, figsize=(20, 5 * n_samples), squeeze=False)
    fig.suptitle('AgriFM × PASTIS — Prediction Examples\n'
                 '(RGB | Ground Truth | Prediction | Difference)',
                 fontsize=14, fontweight='bold')

    # 0 = wrong (red), 1 = correct (green), 2 = ignored/void (gray)
    diff_cmap = ListedColormap(['#e74c3c', '#2ecc71', '#95a5a6'])

    for row, idx in enumerate(indices):
        sample = dataset[idx]
        s2 = sample['S2']  # indexed below as (T, C, H, W) — per original comment
        label = sample['label'].numpy()

        # Forward pass
        with torch.no_grad():
            inp = s2.unsqueeze(0).to(device)
            with _amp_context(device):
                logits = model(inp)
        pred = logits.argmax(dim=1).squeeze(0).cpu().numpy()

        # RGB preview from the middle frame using channels 2,1,0 — presumably
        # red/green/blue in the dataset's band order (TODO confirm).
        mid = s2.shape[0] // 2
        rgb = s2[mid, [2, 1, 0], :, :].numpy()
        # Denormalize roughly for display
        rgb = (rgb * 0.3 + 0.5).clip(0, 1)
        rgb = np.transpose(rgb, (1, 2, 0))

        # Show ignored pixels with the same value in GT and prediction
        ignored = label == IGNORE_INDEX
        pred_show = pred.copy()
        pred_show[ignored] = IGNORE_INDEX
        label_show = label

        # Difference map: 1 correct, 0 wrong, 2 ignored
        diff = (pred_show == label_show).astype(pred.dtype)
        diff[ignored] = 2

        pid = dataset.ids[idx]

        # Col 0: RGB
        axes[row, 0].imshow(rgb)
        axes[row, 0].set_title(f'Patch {pid} — RGB (T={mid})', fontsize=9)
        axes[row, 0].axis('off')

        # Col 1: Ground truth
        axes[row, 1].imshow(label_show, cmap=CMAP, vmin=0, vmax=19, interpolation='nearest')
        axes[row, 1].set_title('Ground Truth', fontsize=9)
        axes[row, 1].axis('off')

        # Col 2: Prediction
        axes[row, 2].imshow(pred_show, cmap=CMAP, vmin=0, vmax=19, interpolation='nearest')
        axes[row, 2].set_title('Prediction', fontsize=9)
        axes[row, 2].axis('off')

        # Col 3: Difference, titled with the patch pixel accuracy
        axes[row, 3].imshow(diff, cmap=diff_cmap, vmin=0, vmax=2, interpolation='nearest')
        valid = ~ignored
        if valid.sum() > 0:
            acc = (pred[valid] == label[valid]).mean() * 100
            axes[row, 3].set_title(f'Diff (acc={acc:.1f}%)\n'
                                   f'■ Wrong ■ Correct ■ Void', fontsize=9)
        else:
            axes[row, 3].set_title('Diff (no valid pixels)', fontsize=9)
        axes[row, 3].axis('off')

    # Class legend (classes 0-18; Void is not listed)
    patches = []
    for c in range(19):
        col = CLASS_COLORS[c]
        if col == '#000000':
            col = '#333333'
        patches.append(mpatches.Patch(color=col, label=f'{c}: {CLASS_NAMES[c]}'))
    fig.legend(handles=patches, loc='lower center', ncol=7, fontsize=7,
               bbox_to_anchor=(0.5, -0.01), framealpha=0.9)

    plt.tight_layout(rect=[0, 0.04, 1, 1])
    path = os.path.join(out_dir, '5_prediction_maps.png')
    plt.savefig(path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f" Saved: {path}")
# ---------------------------------------------------------------------------
# 6. Class frequency vs IoU scatter
# ---------------------------------------------------------------------------

def plot_freq_vs_iou(test_results, out_dir):
    """Scatter the test IoU of every class in class-index order, with a mean-IoU line.

    Despite the function name, no class-frequency data is plotted; the x axis
    is the class order. The name is kept for compatibility with callers.

    Args:
        test_results: dict whose ``'per_class_iou'`` maps class name -> IoU (%).
        out_dir: directory where ``6_class_iou_scatter.png`` is written.
    """
    print(" Plotting frequency vs IoU scatter...")
    per_cls = test_results['per_class_iou']
    if not per_cls:
        print(" No per-class IoU values — skipping scatter.")
        return

    def class_index(name):
        # Unknown names go after all known classes (stable sort keeps their order).
        return CLASS_NAMES.index(name) if name in CLASS_NAMES else len(CLASS_NAMES)

    # Enforce the order the title promises instead of relying on dict order.
    names = sorted(per_cls, key=class_index)
    ious = [per_cls[name] for name in names]
    colors = [CLASS_COLORS[CLASS_NAMES.index(n)] if n in CLASS_NAMES else '#888' for n in names]
    colors = ['#444444' if c == '#000000' else c for c in colors]
    mean_iou = np.mean(ious)

    fig, ax = plt.subplots(figsize=(11, 7))
    ax.scatter(range(len(names)), ious, c=colors, s=200, zorder=5,
               edgecolors='white', linewidths=1.5)

    for i, iou in enumerate(ious):
        ax.annotate(f'{iou:.1f}%', xy=(i, iou), xytext=(i, iou + 1.5),
                    ha='center', fontsize=8, fontweight='bold')

    ax.axhline(mean_iou, color='red', linestyle='--', linewidth=2,
               label=f'mIoU = {mean_iou:.1f}%')
    ax.axhline(50, color='gray', linestyle=':', alpha=0.7, label='50% threshold')
    ax.set_xticks(range(len(names)))
    ax.set_xticklabels(names, rotation=45, ha='right', fontsize=9)
    ax.set_ylabel('IoU (%)', fontsize=12)
    ax.set_title('Per-Class IoU Overview\n(sorted by class index)', fontsize=13, fontweight='bold')
    ax.legend(fontsize=10)
    ax.grid(True, axis='y', alpha=0.3)
    ax.set_facecolor('#f8f8f8')
    ax.set_ylim(0, 100)

    plt.tight_layout()
    path = os.path.join(out_dir, '6_class_iou_scatter.png')
    plt.savefig(path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f" Saved: {path}")
# ---------------------------------------------------------------------------
# 7. Loss gap analysis (overfitting monitor)
# ---------------------------------------------------------------------------

def plot_overfitting_analysis(log_data, out_dir):
    """Compare train and validation loss and plot their gap over epochs.

    Only epochs that have a ``'val_loss'`` entry are used. A positive gap
    (val > train) is shaded red, a non-positive gap green.

    Args:
        log_data: list of per-epoch dicts with ``'epoch'``, ``'train_loss'``
            and optionally ``'val_loss'``.
        out_dir: directory where ``7_overfitting_analysis.png`` is written.
    """
    print(" Plotting overfitting analysis...")
    epochs = [d['epoch'] for d in log_data]
    train_loss = [d['train_loss'] for d in log_data]
    val_loss = [d.get('val_loss', None) for d in log_data]

    validated = [(e, t, v) for e, t, v in zip(epochs, train_loss, val_loss) if v is not None]
    ve = [epoch for epoch, _, _ in validated]
    tl = [train for _, train, _ in validated]
    vl = [val for _, _, val in validated]
    gap = [val - train for train, val in zip(tl, vl)]

    fig, (loss_ax, gap_ax) = plt.subplots(1, 2, figsize=(14, 5))
    fig.suptitle('Overfitting Analysis', fontsize=14, fontweight='bold')

    def finish(ax, title, ylabel):
        # Shared decoration; legend is built after all labelled artists exist.
        ax.set_title(title, fontweight='bold')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.grid(True, alpha=0.3)
        ax.set_facecolor('#f8f8f8')

    # Left: both losses, shading where validation exceeds training
    loss_ax.plot(ve, tl, 'b-', linewidth=2, label='Train Loss')
    loss_ax.plot(ve, vl, 'r-', linewidth=2, label='Val Loss')
    above = [val > train for train, val in zip(tl, vl)]
    loss_ax.fill_between(ve, tl, vl, where=above,
                         alpha=0.15, color='red', label='Overfit gap')
    finish(loss_ax, 'Train vs Val Loss', 'Loss')

    # Right: the gap itself
    gap_ax.plot(ve, gap, 'purple', linewidth=2)
    gap_ax.fill_between(ve, 0, gap, where=[g > 0 for g in gap],
                        alpha=0.3, color='red', label='Overfit')
    gap_ax.fill_between(ve, 0, gap, where=[g <= 0 for g in gap],
                        alpha=0.3, color='green', label='Underfit')
    gap_ax.axhline(0, color='black', linewidth=1)
    finish(gap_ax, 'Val Loss − Train Loss (gap)', 'Loss Gap')

    plt.tight_layout()
    path = os.path.join(out_dir, '7_overfitting_analysis.png')
    plt.savefig(path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f" Saved: {path}")
# ---------------------------------------------------------------------------
# 8. Summary card
# ---------------------------------------------------------------------------

def plot_summary_card(test_results, log_data, out_dir):
    """Render a dark dashboard-style figure summarising a training run.

    Layout (3x4 grid): six metric cards, a mini training-progress chart
    (mFscore on the left axis, train/val loss on a twin right axis) and a
    per-class IoU bar strip across the bottom row.

    Args:
        test_results: dict with ``'test_metrics'`` (mFscore, mIoU, OA, Kappa,
            mPrecision, mRecall), ``'per_class_iou'`` (name -> IoU) and an
            optional ``'args'`` dict used for the subtitle.
        log_data: list of per-epoch dicts with ``'epoch'``, ``'train_loss'``
            and optionally ``'mFscore'`` / ``'val_loss'``.
        out_dir: directory where ``0_summary_card.png`` is written.
    """
    print(" Plotting summary card...")
    metrics = test_results['test_metrics']
    args_d = test_results.get('args', {})

    fig = plt.figure(figsize=(16, 9))
    fig.patch.set_facecolor('#1a1a2e')
    gs = gridspec.GridSpec(3, 4, figure=fig, hspace=0.5, wspace=0.4)

    title_color = '#e0e0e0'
    val_color = '#00d4ff'
    bg_card_color = '#16213e'

    def add_card(ax, title, value, unit='%', color=val_color):
        # One metric "card": big value on top, caption below, no ticks.
        ax.set_facecolor(bg_card_color)
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_edgecolor('#0f3460')
            spine.set_linewidth(2)
        ax.text(0.5, 0.65, f'{value}{unit}', transform=ax.transAxes,
                ha='center', va='center', fontsize=22, fontweight='bold', color=color)
        ax.text(0.5, 0.25, title, transform=ax.transAxes,
                ha='center', va='center', fontsize=10, color='#a0a0a0')

    # Metric cards
    card_data = [
        ('mFscore', f"{metrics['mFscore']:.1f}", '#00d4ff'),
        ('mIoU', f"{metrics['mIoU']:.1f}", '#00ff88'),
        ('OA', f"{metrics['OA']:.1f}", '#ffaa00'),
        ('Kappa', f"{metrics['Kappa']:.1f}", '#ff6b6b'),
        ('mPrecision', f"{metrics['mPrecision']:.1f}", '#c084fc'),
        ('mRecall', f"{metrics['mRecall']:.1f}", '#fb923c'),
    ]
    positions = [(0, 0), (0, 1), (0, 2), (0, 3), (1, 0), (1, 1)]
    for (r, c), (name, val, color) in zip(positions, card_data):
        ax = fig.add_subplot(gs[r, c])
        add_card(ax, name, val, '%', color)

    # Mini training curve
    ax = fig.add_subplot(gs[1, 2:])
    ax.set_facecolor(bg_card_color)
    for spine in ax.spines.values():
        spine.set_edgecolor('#0f3460')
        spine.set_linewidth(2)

    mf = [d.get('mFscore', None) for d in log_data]
    me = [d['epoch'] for d, v in zip(log_data, mf) if v is not None]
    mf = [v for v in mf if v is not None]
    tl = [d['train_loss'] for d in log_data]
    vl = [d.get('val_loss', None) for d in log_data]
    ve = [d['epoch'] for d, v in zip(log_data, vl) if v is not None]
    vl = [v for v in vl if v is not None]

    ax2 = ax.twinx()
    ax.plot(me, mf, color='#00d4ff', linewidth=2, label='mFscore')
    ax2.plot([d['epoch'] for d in log_data], tl, color='#ff6b6b',
             linewidth=1.5, alpha=0.7, label='Train Loss')
    ax2.plot(ve, vl, color='#ffaa00', linewidth=1.5, alpha=0.7, label='Val Loss')
    ax.set_ylabel('mFscore (%)', color='#00d4ff', fontsize=9)
    ax2.set_ylabel('Loss', color='#ffaa00', fontsize=9)
    ax.set_xlabel('Epoch', color=title_color, fontsize=9)
    ax.tick_params(colors=title_color)
    ax2.tick_params(colors=title_color)
    ax.set_title('Training Progress', color=title_color, fontsize=10, fontweight='bold')
    ax.set_facecolor(bg_card_color)
    ax.grid(True, alpha=0.2, color='white')

    # Per-class IoU mini bar
    ax = fig.add_subplot(gs[2, :])
    ax.set_facecolor(bg_card_color)
    for spine in ax.spines.values():
        spine.set_edgecolor('#0f3460')
        spine.set_linewidth(2)

    per_cls = test_results['per_class_iou']
    items = sorted(per_cls.items(), key=lambda x: -x[1])
    names = [x[0][:10] for x in items]  # truncate long class names for the x axis
    vals = [x[1] for x in items]
    cols = [CLASS_COLORS[CLASS_NAMES.index(x[0])] if x[0] in CLASS_NAMES else '#888'
            for x in items]
    cols = ['#444444' if c == '#000000' else c for c in cols]
    ax.bar(range(len(names)), vals, color=cols, edgecolor='#1a1a2e', linewidth=0.5)
    if vals:  # np.mean([]) would be NaN (with a RuntimeWarning)
        ax.axhline(np.mean(vals), color='white', linestyle='--', alpha=0.7, linewidth=1)
    ax.set_xticks(range(len(names)))
    ax.set_xticklabels(names, rotation=45, ha='right', fontsize=7, color=title_color)
    ax.set_ylabel('IoU (%)', color=title_color, fontsize=9)
    ax.tick_params(colors=title_color)
    ax.set_title('Per-Class IoU (sorted)', color=title_color, fontsize=10, fontweight='bold')
    ax.set_facecolor(bg_card_color)
    ax.set_ylim(0, 100)
    ax.grid(True, axis='y', alpha=0.2, color='white')

    # Subtitle: take the model size from the saved run args instead of always
    # claiming 'small'; the parameter count is only known for 'small'.
    model_size = args_d.get('model_size', 'small')
    model_desc = f'{model_size} (39.6M params)' if model_size == 'small' else str(model_size)

    fig.text(0.5, 0.97, 'AgriFM × PASTIS — Training Summary',
             ha='center', fontsize=16, fontweight='bold', color=title_color)
    fig.text(0.5, 0.935,
             f'Model: {model_desc} | '
             f'Fold: {args_d.get("fold","1")} | '
             f'Epochs: {len(log_data)} | '
             f'Batch: {args_d.get("batch_size",16)} | '
             f'LR: {args_d.get("lr","5e-5")}',
             ha='center', fontsize=10, color='#a0a0a0')

    plt.tight_layout(rect=[0, 0, 1, 0.93])
    path = os.path.join(out_dir, '0_summary_card.png')
    plt.savefig(path, dpi=150, bbox_inches='tight', facecolor='#1a1a2e')
    plt.close()
    print(f" Saved: {path}")


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def _format_pct(value):
    """Format a numeric score as ``'12.34%'``, or ``'?'`` if it is missing/non-numeric."""
    try:
        return f"{float(value):.2f}%"
    except (TypeError, ValueError):
        return "?"


def main():
    """Load the run artefacts from --work_dir and write every plot to the output dir.

    Reads ``log.json``, ``test_results.json`` and ``best_model.pth`` from
    ``args.work_dir``. The model-free plots are produced first; the checkpoint
    is then loaded to compute the confusion matrix and prediction maps on the
    test split of ``args.fold``.
    """
    args = get_args()
    out_dir = args.out_dir or os.path.join(args.work_dir, 'plots')
    os.makedirs(out_dir, exist_ok=True)

    print("\nAgriFM PASTIS — Visualization")
    print(f"Work dir : {args.work_dir}")
    print(f"Out dir : {out_dir}")
    print(f"{'─'*50}")

    # Load training log
    log_path = os.path.join(args.work_dir, 'log.json')
    with open(log_path, encoding='utf-8') as f:
        log_data = json.load(f)
    print(f"Loaded {len(log_data)} epochs from log.json")

    # Load test results
    res_path = os.path.join(args.work_dir, 'test_results.json')
    with open(res_path, encoding='utf-8') as f:
        test_results = json.load(f)
    print(f"Loaded test results: mFscore={test_results['test_metrics']['mFscore']}%")

    # Apply the seaborn style when the package is importable (nothing is installed).
    try:
        import seaborn as sns
        sns.set_style("whitegrid")
    except ImportError:
        pass

    # Plots that don't need the model
    print("\nGenerating plots...")
    plot_summary_card(test_results, log_data, out_dir)
    plot_training_curves(log_data, out_dir)
    plot_per_class_iou(test_results, out_dir)
    plot_radar(test_results, out_dir)
    plot_freq_vs_iou(test_results, out_dir)
    plot_overfitting_analysis(log_data, out_dir)

    # Plots that need the model
    print("\nLoading model for prediction maps...")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if args.model_size == 'small':
        model = build_agrifm_pastis_small(num_classes=args.num_classes)
    elif args.model_size == 'tiny':
        model = build_agrifm_pastis_tiny(num_classes=args.num_classes)
    else:
        model = build_agrifm_pastis(num_classes=args.num_classes)

    # NOTE(security): weights_only=False unpickles arbitrary Python objects —
    # only load checkpoints you produced yourself or otherwise trust.
    ckpt = torch.load(
        os.path.join(args.work_dir, 'best_model.pth'),
        map_location=device, weights_only=False
    )
    model.load_state_dict(ckpt['model'])
    model = model.to(device)
    model.eval()
    # _format_pct avoids the ValueError that ':.2f' raised on the '?' fallback.
    print(f"Loaded best model (epoch {ckpt.get('epoch','?')}, "
          f"mFscore={_format_pct(ckpt.get('best_mfscore'))})")

    # Test dataset
    test_ds = PASTISDataset(
        args.data_root, fold=args.fold, split='test',
        num_frames=args.num_frames, augment=False
    )
    test_loader = DataLoader(
        test_ds, batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers,
        pin_memory=device.type == 'cuda',  # pinning only helps host->GPU copies
    )

    plot_confusion_matrix(model, test_loader, device, args, out_dir)
    plot_prediction_maps(model, test_ds, device, args, out_dir)

    print(f"\n{'═'*50}")
    print(f"All plots saved to: {out_dir}")
    print("\nFiles created:")
    for fname in sorted(os.listdir(out_dir)):
        if fname.endswith('.png'):
            size = os.path.getsize(os.path.join(out_dir, fname)) / 1024
            print(f" {fname} ({size:.0f} KB)")


if __name__ == '__main__':
    main()