""" Visualize results for all folds + cross-validation summary. Usage: python visualize_all_folds.py """ import os import sys import json import statistics import numpy as np from pathlib import Path import matplotlib matplotlib.use('Agg') import matplotlib.pyplot as plt import matplotlib.patches as mpatches from matplotlib.colors import ListedColormap import torch from torch.utils.data import DataLoader from torch.amp import autocast sys.path.insert(0, str(Path(__file__).parent)) from models.agrifm import build_agrifm_pastis_small from datasets.pastis_dataset import PASTISDataset, PASTIS_CLASSES, IGNORE_INDEX from evaluation.metrics import SegmentationMetrics import numpy as np import torch.serialization torch.serialization.add_safe_globals([np.core.multiarray.scalar]) # --------------------------------------------------------------------------- FOLD_DIRS = { 1: './work_dirs/fold1_v3', 2: './work_dirs/fold2_small', 3: './work_dirs/fold3_small', 4: './work_dirs/fold4_small', 5: './work_dirs/fold5_small', } DATA_ROOT = '/workspace/project/PASTIS' OUT_DIR = './work_dirs/all_folds_summary' NUM_CLASSES = 20 NUM_FRAMES = 32 CLASS_NAMES = [PASTIS_CLASSES[i] for i in range(20)] SHORT_NAMES = ['BG','Meadow','S.Wheat','Corn','W.Barley','W.Rape', 'Sp.Barley','Sunflwr','Grapevn','Beet','W.Trit', 'W.Durum','Fruits','Potato','Leg.Fod','Soybeans', 'Orchard','Mixed','Sorghum','Void'] CLASS_COLORS = [ '#333333','#3cb371','#ffd700','#ff8c00','#8b4513','#ff1493', '#adff2f','#ffff00','#800080','#dc143c','#00bfff','#daa520', '#32cd32','#a0522d','#90ee90','#006400','#ff7f50','#87ceeb', '#bc8f8f','#808080', ] METRICS_KEYS = ['OA','mIoU','mFscore','mPrecision','mRecall','Kappa'] os.makedirs(OUT_DIR, exist_ok=True) # --------------------------------------------------------------------------- # Load all fold results # --------------------------------------------------------------------------- def load_results(): results = {} for fold, d in FOLD_DIRS.items(): path = os.path.join(d, 
'test_results.json') log_path = os.path.join(d, 'log.json') if os.path.exists(path): with open(path) as f: results[fold] = json.load(f) if os.path.exists(log_path): with open(log_path) as f: results[fold]['log'] = json.load(f) print(f" Fold {fold}: mFscore={results[fold]['test_metrics']['mFscore']:.2f}%") else: print(f" Fold {fold}: NOT FOUND — skipping") return results # --------------------------------------------------------------------------- # Plot 1: Cross-validation metrics bar chart # --------------------------------------------------------------------------- def plot_cv_metrics(results, out_dir): print(" Plotting CV metrics comparison...") folds = sorted(results.keys()) x = np.arange(len(METRICS_KEYS)) width = 0.15 colors= ['#2196F3','#4CAF50','#FF9800','#E91E63','#9C27B0','#795548'] fig, ax = plt.subplots(figsize=(14, 7)) for i, fold in enumerate(folds): vals = [results[fold]['test_metrics'][k] for k in METRICS_KEYS] bars = ax.bar(x + i*width, vals, width, label=f'Fold {fold}', color=colors[i], alpha=0.85, edgecolor='white') # Mean line means = [] for k in METRICS_KEYS: v = [results[f]['test_metrics'][k] for f in folds] means.append(sum(v)/len(v)) ax.plot(x + (len(folds)-1)*width/2, means, 'k--o', linewidth=2, markersize=8, label='Mean', zorder=5) for xi, mv in zip(x + (len(folds)-1)*width/2, means): ax.annotate(f'{mv:.1f}%', xy=(xi, mv), xytext=(xi, mv+1.5), ha='center', fontsize=9, fontweight='bold') ax.set_xticks(x + (len(folds)-1)*width/2) ax.set_xticklabels(METRICS_KEYS, fontsize=12, fontweight='bold') ax.set_ylabel('Score (%)', fontsize=12) ax.set_ylim(0, 100) ax.set_title('AgriFM × PASTIS — Cross-Validation Results (All Folds)', fontsize=14, fontweight='bold') ax.legend(fontsize=10, ncol=len(folds)+1) ax.grid(True, axis='y', alpha=0.3) ax.set_facecolor('#f8f8f8') path = os.path.join(out_dir, 'CV1_metrics_all_folds.png') plt.tight_layout() plt.savefig(path, dpi=150, bbox_inches='tight') plt.close() print(f" Saved: {path}") # 
# ---------------------------------------------------------------------------
# Plot 2: Per-class IoU heatmap across folds
# ---------------------------------------------------------------------------
def _heatmap_cell_text(ax, x, y, val):
    """Write ``val`` rounded to an integer at cell (x, y) of a heatmap.

    White text is used on the saturated ends of the 0-100 colour scale
    (below 20 or above 80) so the label stays readable.
    """
    color = 'white' if val < 20 or val > 80 else 'black'
    ax.text(x, y, f'{val:.0f}', ha='center', va='center',
            fontsize=8, color=color, fontweight='bold')


def plot_per_class_heatmap(results, out_dir):
    """Heatmap of per-class IoU (folds x classes) with a per-class mean row.

    Args:
        results: mapping fold -> results dict whose ``'per_class_iou'`` entry
            maps class name -> IoU (%). A class missing for a fold is drawn
            as 0.
        out_dir: directory where ``CV2_per_class_heatmap.png`` is written.
    """
    print(" Plotting per-class IoU heatmap...")
    folds = sorted(results.keys())
    classes = [c for c in range(NUM_CLASSES) if c != IGNORE_INDEX]
    names = [CLASS_NAMES[c] for c in classes]

    data = np.zeros((len(folds), len(classes)))
    for i, fold in enumerate(folds):
        per_cls = results[fold]['per_class_iou']
        for j, c in enumerate(classes):
            data[i, j] = per_cls.get(CLASS_NAMES[c], 0.)
    means = data.mean(axis=0)

    # The mean row is a real subplot sharing the class axis. Previously it
    # was placed by hand just below the main axes, which covered the rotated
    # class-name tick labels and was not moved by tight_layout.
    fig, (ax, ax_mean) = plt.subplots(
        2, 1, sharex=True,
        figsize=(18, len(folds) * 1.2 + 3),
        gridspec_kw={'height_ratios': [len(folds), 1]},
        constrained_layout=True,
    )
    im = ax.imshow(data, cmap='RdYlGn', vmin=0, vmax=100, aspect='auto')
    ax_mean.imshow(means[np.newaxis, :], cmap='RdYlGn', vmin=0, vmax=100,
                   aspect='auto')

    for i in range(len(folds)):
        for j in range(len(classes)):
            _heatmap_cell_text(ax, j, i, data[i, j])
    for j, mv in enumerate(means):
        _heatmap_cell_text(ax_mean, j, 0, mv)

    ax.set_yticks(range(len(folds)))
    ax.set_yticklabels([f'Fold {f}' for f in folds], fontsize=11)
    ax.set_title('Per-Class IoU Heatmap Across All Folds (%)',
                 fontsize=13, fontweight='bold')

    # With sharex only the bottom axes shows x tick labels.
    ax_mean.set_xticks(range(len(classes)))
    ax_mean.set_xticklabels(names, rotation=45, ha='right', fontsize=9)
    ax_mean.set_yticks([0])
    ax_mean.set_yticklabels(['Mean'], fontsize=11)

    fig.colorbar(im, ax=[ax, ax_mean], shrink=0.8, label='IoU (%)')

    # No tight_layout(): layout is handled by constrained_layout.
    path = os.path.join(out_dir, 'CV2_per_class_heatmap.png')
    plt.savefig(path, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f" Saved: {path}")
# ---------------------------------------------------------------------------
# Plot 3: Training curves overlay (all folds on same plot)
# ---------------------------------------------------------------------------
def _log_series(log, key):
    """Return aligned ``(epochs, values)`` lists for the entries logging ``key``.

    Entries where ``key`` is absent or None are dropped from both lists
    together, so the returned lists always have equal length.
    """
    kept = [(d['epoch'], d.get(key)) for d in log]
    kept = [(e, v) for e, v in kept if v is not None]
    return [e for e, _ in kept], [v for _, v in kept]


def plot_training_overlay(results, out_dir):
    """Overlay loss, mFscore and mIoU curves of every fold on shared axes.

    Folds without a ``'log'`` entry are skipped.

    Args:
        results: mapping fold -> results dict. ``results[fold]['log']`` is
            iterated as per-epoch dicts with an ``'epoch'`` key and optional
            ``'train_loss'``, ``'val_loss'``, ``'mFscore'``, ``'mIoU'`` keys.
        out_dir: directory where ``CV3_training_curves_overlay.png`` is
            written.
    """
    print(" Plotting training curves overlay...")
    folds = sorted(results.keys())
    colors = ['#2196F3', '#4CAF50', '#FF9800', '#E91E63', '#9C27B0']

    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    fig.suptitle('Training Curves — All Folds Overlay',
                 fontsize=14, fontweight='bold')

    for i, fold in enumerate(folds):
        if 'log' not in results[fold]:
            continue
        log = results[fold]['log']
        c = colors[i % len(colors)]

        # Each metric gets its own epoch list. mFscore and mIoU may be logged
        # on different epochs; sharing one list gave mismatched lengths.
        te, tl = _log_series(log, 'train_loss')
        ve, vl = _log_series(log, 'val_loss')
        fe, mf = _log_series(log, 'mFscore')
        ie, mi = _log_series(log, 'mIoU')

        axes[0].plot(te, tl, color=c, linewidth=1.5, label=f'Fold {fold} train')
        axes[0].plot(ve, vl, color=c, linewidth=1.5, linestyle='--', alpha=0.6)
        axes[1].plot(fe, mf, color=c, linewidth=2, label=f'Fold {fold}')
        axes[2].plot(ie, mi, color=c, linewidth=2, label=f'Fold {fold}')

    titles = ['Loss (solid=train, dashed=val)', 'mFscore (%)', 'mIoU (%)']
    ylabels = ['Loss', 'mFscore (%)', 'mIoU (%)']
    for ax, title, ylabel in zip(axes, titles, ylabels):
        ax.set_title(title, fontweight='bold')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        # Skip the legend when nothing labelled was drawn (no fold had a
        # log); otherwise matplotlib warns about an empty legend.
        if ax.get_legend_handles_labels()[0]:
            ax.legend(fontsize=8)
        ax.grid(True, alpha=0.3)
        ax.set_facecolor('#f8f8f8')

    plt.tight_layout()
    path = os.path.join(out_dir, 'CV3_training_curves_overlay.png')
    plt.savefig(path, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f" Saved: {path}")
# ---------------------------------------------------------------------------
# Plot 4: Final summary table (mean ± std)
# ---------------------------------------------------------------------------
def plot_summary_table(results, out_dir):
    """Render per-fold metrics plus mean and std rows as a styled table image.

    Args:
        results: mapping fold -> results dict with a ``'test_metrics'`` entry
            holding every key in METRICS_KEYS.
        out_dir: directory where ``CV4_summary_table.png`` is written.
    """
    print(" Plotting summary table...")
    folds = sorted(results.keys())
    n_folds = len(folds)
    n_total = len(FOLD_DIRS)

    fig, ax = plt.subplots(figsize=(14, n_folds * 0.6 + 3))
    ax.axis('off')

    # Build table data: one row per fold, then Mean and Std rows.
    col_labels = ['Fold'] + METRICS_KEYS
    table_data = []
    for fold in folds:
        m = results[fold]['test_metrics']
        table_data.append([f'Fold {fold}'] + [f"{m[k]:.2f}%" for k in METRICS_KEYS])

    means = []
    stds = []
    for k in METRICS_KEYS:
        v = [results[f]['test_metrics'][k] for f in folds]
        means.append(f"{sum(v)/len(v):.2f}%")
        # Sample standard deviation is undefined for a single fold.
        stds.append(f"{statistics.stdev(v):.2f}%" if len(v) > 1 else "0.00%")
    table_data.append(['Mean'] + means)
    table_data.append(['Std'] + stds)

    table = ax.table(
        cellText=table_data,
        colLabels=col_labels,
        cellLoc='center',
        loc='center',
    )
    table.auto_set_font_size(False)
    table.set_fontsize(11)
    table.scale(1.2, 2.0)

    # Row 0 is the header; fold rows are 1..n_folds; then Mean, then Std.
    mean_row, std_row = n_folds + 1, n_folds + 2
    for j in range(len(col_labels)):
        table[0, j].set_facecolor('#2E75B6')
        table[0, j].set_text_props(color='white', fontweight='bold')
        table[mean_row, j].set_facecolor('#E2EFDA')
        table[mean_row, j].set_text_props(fontweight='bold')
        table[std_row, j].set_facecolor('#FFF2CC')

    # Alternating background for the fold rows.
    for i in range(1, n_folds + 1):
        bg = '#F8F8F8' if i % 2 == 0 else '#FFFFFF'
        for j in range(len(col_labels)):
            table[i, j].set_facecolor(bg)

    # The fold count follows the configuration. If some folds were skipped
    # by load_results, say how many are summarised instead of implying all.
    title = f'AgriFM × PASTIS — {n_total}-Fold Cross-Validation Summary'
    if n_folds != n_total:
        title += f' ({n_folds}/{n_total} folds)'
    ax.set_title(title, fontsize=14, fontweight='bold', pad=20)

    path = os.path.join(out_dir, 'CV4_summary_table.png')
    plt.tight_layout()
    plt.savefig(path, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f" Saved: {path}")


# ---------------------------------------------------------------------------
# Plot 5: Per-class mean IoU bar + std error bars
# ---------------------------------------------------------------------------
def plot_mean_per_class(results, out_dir):
    """Bar chart of per-class IoU averaged over folds, with ±1 std error bars.

    Classes are sorted by mean IoU (descending). A class missing from a
    fold's ``'per_class_iou'`` counts as 0 for that fold.

    Args:
        results: mapping fold -> results dict with a ``'per_class_iou'``
            entry mapping class name -> IoU (%).
        out_dir: directory where ``CV5_mean_per_class_iou.png`` is written.
    """
    print(" Plotting mean per-class IoU with std...")
    folds = sorted(results.keys())
    n_total = len(FOLD_DIRS)
    classes = [c for c in range(NUM_CLASSES) if c != IGNORE_INDEX]
    names = [CLASS_NAMES[c] for c in classes]

    means = []
    stds = []
    for c in classes:
        vals = [results[f]['per_class_iou'].get(CLASS_NAMES[c], 0.) for f in folds]
        means.append(sum(vals) / len(vals))
        stds.append(statistics.stdev(vals) if len(vals) > 1 else 0.)

    # Sort by mean IoU, descending.
    sorted_idx = np.argsort(means)[::-1]
    names_s = [names[i] for i in sorted_idx]
    means_s = [means[i] for i in sorted_idx]
    stds_s = [stds[i] for i in sorted_idx]
    colors_s = [CLASS_COLORS[classes[i]] for i in sorted_idx]
    # Swap pure black for dark grey so the bar stays visible.
    colors_s = ['#444444' if c == '#000000' else c for c in colors_s]

    fig, ax = plt.subplots(figsize=(14, 7))
    positions = range(len(names_s))
    ax.bar(positions, means_s, color=colors_s,
           edgecolor='white', linewidth=0.5, alpha=0.85)
    ax.errorbar(positions, means_s, yerr=stds_s, fmt='none', color='black',
                capsize=4, linewidth=1.5, label='±1 std across folds')

    mean_all = sum(means) / len(means)
    ax.axhline(mean_all, color='red', linestyle='--', linewidth=2,
               label=f'Mean mIoU = {mean_all:.1f}%')
    ax.axhline(50, color='gray', linestyle=':', alpha=0.5)

    for i, (mv, sv) in enumerate(zip(means_s, stds_s)):
        ax.text(i, mv + sv + 1, f'{mv:.1f}', ha='center',
                fontsize=8, fontweight='bold')

    # "5-Fold CV" when all configured folds are present, otherwise e.g. "3/5 folds".
    cv_tag = (f'{n_total}-Fold CV' if len(folds) == n_total
              else f'{len(folds)}/{n_total} folds')

    ax.set_xticks(positions)
    ax.set_xticklabels(names_s, rotation=45, ha='right', fontsize=9)
    ax.set_ylabel('Mean IoU (%)', fontsize=12)
    ax.set_ylim(0, 105)
    ax.set_title(f'Mean Per-Class IoU with Standard Deviation ({cv_tag})',
                 fontsize=13, fontweight='bold')
    ax.legend(fontsize=10)
    ax.grid(True, axis='y', alpha=0.3)
    ax.set_facecolor('#f8f8f8')

    path = os.path.join(out_dir, 'CV5_mean_per_class_iou.png')
    plt.tight_layout()
    plt.savefig(path, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f" Saved: {path}")
# ---------------------------------------------------------------------------
# Plot 6: Box plots of metrics across folds
# ---------------------------------------------------------------------------
def plot_boxplots(results, out_dir):
    """Box plot per metric over folds, with each fold's value overlaid.

    Args:
        results: mapping fold -> results dict with a ``'test_metrics'`` entry
            holding every key in METRICS_KEYS.
        out_dir: directory where ``CV6_metric_boxplots.png`` is written.
    """
    print(" Plotting metric boxplots...")
    folds = sorted(results.keys())
    data = [[results[f]['test_metrics'][k] for f in folds] for k in METRICS_KEYS]

    fig, ax = plt.subplots(figsize=(12, 6))
    # Tick labels are set explicitly. boxplot(labels=...) was renamed to
    # tick_labels (and deprecated) in Matplotlib 3.9.
    bp = ax.boxplot(data, patch_artist=True,
                    medianprops=dict(color='black', linewidth=2))
    positions = list(range(1, len(METRICS_KEYS) + 1))  # boxplot default positions
    ax.set_xticks(positions)
    ax.set_xticklabels(METRICS_KEYS)

    colors = ['#2196F3', '#4CAF50', '#FF9800', '#E91E63', '#9C27B0', '#795548']
    for patch, color in zip(bp['boxes'], colors):
        patch.set_facecolor(color)
        patch.set_alpha(0.7)

    # Overlay individual fold points with a small horizontal jitter. A fixed
    # seed keeps the figure reproducible between runs.
    rng = np.random.default_rng(0)
    for pos, vals in zip(positions, data):
        xs = rng.normal(pos, 0.04, size=len(vals))
        ax.scatter(xs, vals, color='black', s=40, zorder=5, alpha=0.8)
        for fold, xi, v in zip(folds, xs, vals):
            ax.annotate(f'F{fold}', (xi, v), textcoords='offset points',
                        xytext=(5, 0), fontsize=7)

    ax.set_ylabel('Score (%)', fontsize=12)
    ax.set_title(f'Metric Distribution Across {len(folds)} Folds',
                 fontsize=13, fontweight='bold')
    ax.grid(True, axis='y', alpha=0.3)
    ax.set_facecolor('#f8f8f8')

    path = os.path.join(out_dir, 'CV6_metric_boxplots.png')
    plt.tight_layout()
    plt.savefig(path, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f" Saved: {path}")


# ---------------------------------------------------------------------------
# Run individual fold visualizations too
# ---------------------------------------------------------------------------
def run_fold_visualizations():
    """Run visualize_results.py for every evaluated fold lacking plots.

    Folds without test_results.json are skipped. A fold is also skipped when
    its ``plots`` dir already holds 7 or more entries. Failures are reported
    with the tail of stderr but do not abort the loop.
    """
    print("\nRunning individual fold visualizations...")
    import subprocess

    for fold, d in FOLD_DIRS.items():
        if not os.path.exists(os.path.join(d, 'test_results.json')):
            continue
        plots_dir = os.path.join(d, 'plots')
        if os.path.exists(plots_dir) and len(os.listdir(plots_dir)) >= 7:
            print(f" Fold {fold}: plots already exist, skipping")
            continue

        print(f" Running visualize_results.py for fold {fold}...")
        # Use the interpreter running this script (same venv and packages)
        # rather than whatever 'python' happens to be first on PATH.
        cmd = [
            sys.executable, 'visualize_results.py',
            '--work_dir', d,
            '--data_root', DATA_ROOT,
            '--fold', str(fold),
            '--model_size', 'small',
            '--num_classes', '20',
            '--num_frames', '32',
            '--n_samples', '6',
            '--out_dir', plots_dir,
        ]
        result = subprocess.run(cmd, capture_output=True, text=True)
        if result.returncode == 0:
            print(f" Fold {fold} plots done")
        else:
            print(f" Fold {fold} error: {result.stderr[-200:]}")


# ---------------------------------------------------------------------------
# Print CV table to console
# ---------------------------------------------------------------------------
def print_cv_table(results, out_dir=OUT_DIR):
    """Print a per-fold metrics table with mean/std and save it as JSON.

    Args:
        results: mapping fold -> results dict with a ``'test_metrics'`` entry
            holding every key in METRICS_KEYS.
        out_dir: directory where ``cv_summary.json`` is written (defaults to
            OUT_DIR, as before).
    """
    folds = sorted(results.keys())
    print(f"\n{'='*75}")
    print("CROSS-VALIDATION RESULTS SUMMARY")
    print(f"{'='*75}")
    print(f"{'Fold':<6} {'OA':>7} {'mIoU':>7} {'mFscore':>8} "
          f"{'Prec':>7} {'Recall':>8} {'Kappa':>7}")
    print("─"*60)

    all_vals = {k: [] for k in METRICS_KEYS}
    for fold in folds:
        m = results[fold]['test_metrics']
        for k in METRICS_KEYS:
            all_vals[k].append(m[k])
        print(f" {fold} "
              f"{m['OA']:>7.2f} "
              f"{m['mIoU']:>7.2f} "
              f"{m['mFscore']:>8.2f} "
              f"{m['mPrecision']:>7.2f} "
              f"{m['mRecall']:>8.2f} "
              f"{m['Kappa']:>7.2f}")

    print("─"*60)
    means = [sum(all_vals[k]) / len(all_vals[k]) for k in METRICS_KEYS]
    # Sample std needs at least two folds.
    stds = [statistics.stdev(all_vals[k]) if len(all_vals[k]) > 1 else 0
            for k in METRICS_KEYS]
    print(" Mean " + "".join(f" {v:>7.2f}" for v in means))
    print(" Std " + "".join(f" {v:>7.2f}" for v in stds))
    print(f"{'='*75}")

    # Save to JSON (int fold keys become strings in JSON).
    summary = {
        'per_fold': {f: results[f]['test_metrics'] for f in folds},
        'mean': {k: round(sum(all_vals[k]) / len(all_vals[k]), 2) for k in METRICS_KEYS},
        'std': {k: round(statistics.stdev(all_vals[k]), 2) if len(all_vals[k]) > 1 else 0
                for k in METRICS_KEYS},
        'num_folds': len(folds),
    }
    with open(os.path.join(out_dir, 'cv_summary.json'), 'w', encoding='utf-8') as f:
        json.dump(summary, f, indent=2)
    print(f"\nSaved CV summary to {out_dir}/cv_summary.json")


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    print("\nAgriFM PASTIS — All Folds Visualization")
    print(f"Output dir: {OUT_DIR}")
    print(f"{'─'*50}")

    print("Loading fold results...")
    results = load_results()
    if not results:
        print("No fold results found!")
        # sys.exit rather than the site-provided exit(), which is not
        # guaranteed to exist (e.g. under `python -S`).
        sys.exit(1)

    print_cv_table(results)

    print("\nGenerating cross-validation plots...")
    plot_cv_metrics(results, OUT_DIR)
    plot_per_class_heatmap(results, OUT_DIR)
    plot_training_overlay(results, OUT_DIR)
    plot_summary_table(results, OUT_DIR)
    plot_mean_per_class(results, OUT_DIR)
    plot_boxplots(results, OUT_DIR)

    # Individual fold plots
    run_fold_visualizations()

    # Copy every fold's PNGs into OUT_DIR, prefixed with the fold number.
    import shutil
    for fold, d in FOLD_DIRS.items():
        plots_d = os.path.join(d, 'plots')
        if not os.path.exists(plots_d):
            continue
        for f in os.listdir(plots_d):
            if f.endswith('.png'):
                shutil.copy(
                    os.path.join(plots_d, f),
                    os.path.join(OUT_DIR, f'fold{fold}_{f}')
                )

    print(f"\n{'='*50}")
    print(f"All plots saved to: {OUT_DIR}")
    print("Files created:")
    for f in sorted(os.listdir(OUT_DIR)):
        if f.endswith(('.png', '.json')):
            size = os.path.getsize(os.path.join(OUT_DIR, f)) / 1024
            print(f" {f} ({size:.0f} KB)")