import json
import math
import os
from pathlib import Path

import gradio as gr
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from huggingface_hub import snapshot_download
from PIL import Image

# Project-local loader for the experiment output files.
from plot_interactive_loader import load_experiment_data

# -----------------------------------------
# DATASET INITIALISATION
# -----------------------------------------
# Mirror the "medwa126/gardensgp" Hugging Face dataset into data/gardensgp
# before the app is built. Authentication comes from the HF_TOKEN
# environment variable; if it is unset, os.environ[...] raises KeyError at
# import time.
hf_token = os.environ["HF_TOKEN"]

Path("data/gardensgp").mkdir(parents=True, exist_ok=True)
snapshot_download(
    repo_id="medwa126/gardensgp",
    repo_type="dataset",
    local_dir="data/gardensgp",
    token=hf_token,
)
print("Dataset loaded")

# -----------------------------------------
# CONFIGURATION CONSTANTS
# -----------------------------------------
# Directory handed to load_experiment_data().
RESULTS_DIR = "experiment_output"

# TPR classification label -> plot styling (marker colour, marker symbol,
# and the human-readable legend text).
TPR_MAPPING = {
    label: {'color': color, 'marker': marker, 'legend': legend}
    for label, color, marker, legend in (
        ("TP", "green", "circle", "True Positive (Correct)"),
        ("FP", "blue", "triangle-up", "False Positive (Wrong Match)"),
        ("FN", "red", "x", "False Negative (No Match)"),
        ("TN", "gray", "square", "True Negative (Correctly Ignored)"),
        ("N/A", "black", "circle", "N/A"),
    )
}

# -----------------------------------------
# GPS HELPERS
# -----------------------------------------
# Haversine distance between two GPS coordinates.
# Used to convert lat/lon to metre-scale distances.
def calculate_gps_distance_meters(lat1, lon1, lat2, lon2):
    """Return the haversine (great-circle) distance in metres between two points.

    Args:
        lat1, lon1: First point, decimal degrees.
        lat2, lon2: Second point, decimal degrees.

    Returns:
        Distance in metres on a sphere of radius 6,371,000 m.
    """
    R = 6371000  # Earth radius used by this app, in metres
    lat1_rad = math.radians(lat1)
    lon1_rad = math.radians(lon1)
    lat2_rad = math.radians(lat2)
    lon2_rad = math.radians(lon2)

    dlat = lat2_rad - lat1_rad
    dlon = lon2_rad - lon1_rad

    a = math.sin(dlat / 2) ** 2 + math.cos(lat1_rad) * math.cos(lat2_rad) * math.sin(dlon / 2) ** 2
    # Rounding can push `a` marginally above 1 for (near-)antipodal points,
    # which would make math.asin(math.sqrt(a)) raise ValueError.
    c = 2 * math.asin(math.sqrt(min(1.0, a)))
    return R * c


# Convert full GPS tracks to local x/y metre coordinates.
def to_meters_coords(coords_list, origin_lat, origin_lon):
    """Project (lat, lon) pairs onto a local signed east/north metre grid.

    y is the haversine distance from the origin to (lat, origin_lon) and x is
    the haversine distance from the origin to (origin_lat, lon). Points south
    of the origin get a negative y and points west of it a negative x.

    Args:
        coords_list: Iterable of (lat, lon) pairs in decimal degrees.
        origin_lat, origin_lon: Grid origin in decimal degrees.

    Returns:
        Tuple ``(x_meters, y_meters)`` of two lists aligned with ``coords_list``.
    """
    x_meters, y_meters = [], []
    for lat, lon in coords_list:
        y = calculate_gps_distance_meters(origin_lat, origin_lon, lat, origin_lon)
        x = calculate_gps_distance_meters(origin_lat, origin_lon, origin_lat, lon)
        if lat < origin_lat:
            y = -y
        if lon < origin_lon:
            x = -x
        x_meters.append(x)
        y_meters.append(y)
    return x_meters, y_meters


# -----------------------------------------
# IMAGE SAFETY WRAPPER
# -----------------------------------------
# Ensures missing or corrupted images don't break the UI.
def load_image_safe(image_path):
    """Open an image for display, returning None instead of raising.

    Args:
        image_path: Filesystem path; may be empty or None.

    Returns:
        A decoded ``PIL.Image.Image``, or None if the path is empty, does not
        exist, or the file cannot be opened/decoded.
    """
    if not image_path or not os.path.exists(image_path):
        return None
    try:
        img = Image.open(image_path)
        # Image.open is lazy (reads only the header). Decode the pixel data
        # now so truncated/corrupt files are caught by this handler instead of
        # failing later when Gradio serialises the image.
        img.load()
        return img
    except Exception as e:
        print(f"[WARNING] Could not load image {image_path}: {e}")
        return None


# -----------------------------------------
# MAIN GPS PLOT BUILDER
# -----------------------------------------
def _error_figure(message):
    """Return an empty Plotly figure showing ``message`` as a centred red banner."""
    return go.Figure().add_annotation(
        text=message,
        xref="paper", yref="paper",
        x=0.5, y=0.5,
        showarrow=False,
        font=dict(size=20, color="red"),
    )


def _method_name(method_data):
    """Return a method's short name: the first space-separated word of its config description."""
    return method_data['config']['description'].split(' ')[0]


def create_gps_plot(selected_method, predictions_store):
    """Build the interactive GPS accuracy plot for one method.

    Loads experiment output from ``RESULTS_DIR`` and draws:
      - faint background markers for every query and database (predicted /
        ground-truth) GPS point found across *all* methods;
      - one marker per prediction of ``selected_method`` at its query location,
        coloured by correctness (``distance_error < tolerance``) and shaped by
        its TPR class;
      - legend-only traces with per-category counts.

    Args:
        selected_method: Short method name (see ``_method_name``) to plot.
        predictions_store: Not used; kept so existing callers keep working.

    Returns:
        ``(figure, predictions_dict)``. ``predictions_dict`` maps
        ``str(query_index)`` to image paths and metrics used by the image
        viewer. If loading fails, the figure is an error banner and the dict
        is empty.
    """
    print(f"\n[DEBUG] Creating plot for method: {selected_method}")

    # Load all experiment output (predictions, config, tolerance etc.)
    data_dict = load_experiment_data(RESULTS_DIR)
    if data_dict is None:
        return _error_figure("❌ Failed to load experiment data"), {}

    methods_data = data_dict.get('method_data', [])
    if not methods_data:
        return _error_figure("❌ No method data found"), {}

    tolerance = data_dict.get('tolerance', 10)
    distance_unit = data_dict.get('distance_unit', 'meters')

    # -------------------------------------------------------------
    # Background coordinates gathered across all methods
    # -------------------------------------------------------------
    all_db_coords = set()
    all_query_coords = set()
    for method_data in methods_data:
        for pred in method_data.get('predictions', []):
            gps = pred.get('gps_coordinates', {})
            if not gps:
                continue
            all_query_coords.add((gps['query_lat'], gps['query_lon']))
            all_db_coords.add((gps['predicted_lat'], gps['predicted_lon']))
            if 'ground_truth_lat' in gps:
                all_db_coords.add((gps['ground_truth_lat'], gps['ground_truth_lon']))

    if not all_query_coords:
        return _error_figure("❌ No GPS coordinates found"), {}

    # The (lat, lon)-smallest query coordinate is the local grid origin.
    origin_lat, origin_lon = min(all_query_coords)
    query_coords_m = to_meters_coords(sorted(all_query_coords), origin_lat, origin_lon)
    db_coords_m = to_meters_coords(sorted(all_db_coords), origin_lat, origin_lon)

    # -------------------------------------------------------------
    # Selected method's predictions
    # -------------------------------------------------------------
    selected_method_data = next(
        (m for m in methods_data if _method_name(m) == selected_method), None
    )
    if not selected_method_data:
        return _error_figure(f"❌ Method '{selected_method}' not found"), {}

    predictions = selected_method_data.get('predictions', [])
    if not predictions:
        return _error_figure(f"❌ No predictions for {selected_method}"), {}

    # -------------------------------------------------------------
    # Plot rows + image lookup store
    # -------------------------------------------------------------
    predictions_dict = {}
    plot_data = []
    for pred in predictions:
        gps = pred.get('gps_coordinates')
        if not gps:
            # Same rule as the background pass above: a prediction without GPS
            # cannot be placed on the map, so skip it instead of crashing.
            print(f"[WARNING] Prediction {pred.get('query_index')} has no GPS coordinates; skipped")
            continue

        # Without a recorded ground-truth location, fall back to the prediction.
        gt_lat = gps.get('ground_truth_lat', gps['predicted_lat'])
        gt_lon = gps.get('ground_truth_lon', gps['predicted_lon'])

        plot_row = pred.copy()
        plot_row.update({
            'Query_Longitude': gps['query_lon'],
            'Query_Latitude': gps['query_lat'],
            'Predicted_Longitude': gps['predicted_lon'],
            'Predicted_Latitude': gps['predicted_lat'],
            'GroundTruth_Longitude': gt_lon,
            'GroundTruth_Latitude': gt_lat,
        })
        plot_data.append(plot_row)

        # Lookup used by the image viewer; keys are strings so the dict
        # survives being held in gr.State.
        query_idx = pred['query_index']
        predictions_dict[str(query_idx)] = {
            'query_image': pred.get('query_image_path', pred.get('query_image', '')),
            'gt_image': pred.get('gt_image_path', pred.get('ground_truth_image', '')),
            'pred_image': pred.get('predicted_image_path', pred.get('predicted_image', '')),
            'query_idx': query_idx,
            'predicted_idx': pred['predicted_index'],
            'ground_truth_idx': pred.get('ground_truth_index', -1),
            'distance_error': pred['distance_error'],
            'is_correct': pred['distance_error'] < tolerance,
            'tpr_classification': pred.get('tpr_classification', 'N/A'),
        }

    if not plot_data:
        return _error_figure(f"❌ No GPS coordinates for {selected_method}"), {}

    plot_df = pd.DataFrame(plot_data)

    plot_df['Query_X_Meters'], plot_df['Query_Y_Meters'] = to_meters_coords(
        plot_df[['Query_Latitude', 'Query_Longitude']].values.tolist(),
        origin_lat, origin_lon,
    )

    # -------------------------------------------------------------
    # Correctness categories and TPR labels
    # -------------------------------------------------------------
    # NOTE(review): any unit other than "meters" gets suffix "f" (presumably
    # feet) — confirm against the loader's possible `distance_unit` values.
    unit_suffix = "m" if distance_unit == "meters" else "f"
    tolerance_format = f"{tolerance:.1f}" if distance_unit == "meters" else f"{tolerance:.0f}"
    correct_label = f'Correct (< {tolerance_format}{unit_suffix})'
    incorrect_label = f'Incorrect (≥ {tolerance_format}{unit_suffix})'

    plot_df['Accuracy_Category'] = np.where(
        plot_df['distance_error'] < tolerance, correct_label, incorrect_label
    )
    if 'tpr_classification' in plot_df.columns:
        # Rows missing the field come through as NaN; show them as 'N/A'.
        plot_df['TPR_Classification'] = plot_df['tpr_classification'].fillna('N/A')
    else:
        plot_df['TPR_Classification'] = 'N/A'

    # -----------------------------------------
    # CREATE MAIN PLOTLY FIGURE
    # -----------------------------------------
    fig = go.Figure()

    fig.add_trace(go.Scatter(
        x=db_coords_m[0], y=db_coords_m[1],
        mode='markers', name='Database Points',
        marker=dict(color='blue', size=4),
        opacity=0.3, hoverinfo='skip',
    ))
    fig.add_trace(go.Scatter(
        x=query_coords_m[0], y=query_coords_m[1],
        mode='markers', name='Query Path',
        marker=dict(size=4, color='red'),
        opacity=0.3, hoverinfo='skip',
    ))

    # -------------------------------------------------------------
    # Prediction markers: colour = correctness, shape = TPR class
    # -------------------------------------------------------------
    color_map = {correct_label: 'green', incorrect_label: 'red'}

    hover_text, marker_symbols, marker_colors, custom_data = [], [], [], []
    for _, row in plot_df.iterrows():
        query_idx = row['query_index']
        tpr_class = row['TPR_Classification']
        tpr_info = TPR_MAPPING.get(tpr_class, {})
        status_symbol = "✓" if row['distance_error'] < tolerance else "✗"

        # Plotly renders <br> as a line break in hover text.
        hover_text.append(
            f"Query Index: {query_idx}<br>"
            f"Method: {selected_method}<br>"
            f"Predicted Index: {row['predicted_index']}<br>"
            f"Error: {row['distance_error']:.1f}{unit_suffix}<br>"
            f"Status: {status_symbol} {row['Accuracy_Category']}<br>"
            f"TPR: {tpr_info.get('legend', tpr_class)}"
        )
        # Shapes come from TPR_MAPPING so markers and legend always agree.
        marker_symbols.append(tpr_info.get('marker', 'circle'))
        marker_colors.append(color_map[row['Accuracy_Category']])
        custom_data.append([query_idx])

    fig.add_trace(go.Scatter(
        x=plot_df['Query_X_Meters'],
        y=plot_df['Query_Y_Meters'],
        mode='markers',
        name='Predictions',
        marker=dict(
            size=10,
            color=marker_colors,
            symbol=marker_symbols,
            line=dict(width=1, color='Black'),
        ),
        text=hover_text,
        hoverinfo='text',
        customdata=custom_data,
        showlegend=False,
    ))

    # -------------------------------------------------------------
    # Legend-only traces (x/y = None) carrying per-category counts
    # -------------------------------------------------------------
    for category, color in color_map.items():
        count = int((plot_df['Accuracy_Category'] == category).sum())
        if count > 0:
            fig.add_trace(go.Scatter(
                x=[None], y=[None], mode='markers',
                marker=dict(size=10, color=color, symbol='circle',
                            line=dict(width=1, color='Black')),
                name=f'{category} ({count})',
                hoverinfo='none',
            ))

    for tpr_class, tpr_info in TPR_MAPPING.items():
        tpr_count = int((plot_df['TPR_Classification'] == tpr_class).sum())
        if tpr_count > 0:
            fig.add_trace(go.Scatter(
                x=[None], y=[None], mode='markers',
                marker=dict(size=10, color=tpr_info['color'], symbol=tpr_info['marker'],
                            line=dict(width=1, color='Black')),
                name=f'{tpr_info["legend"]} ({tpr_count})',
                hoverinfo='none',
            ))

    fig.update_layout(
        title=f'GPS Accuracy Heatmap - {selected_method}<br>'
              f'Tolerance: {tolerance}{unit_suffix}',
        xaxis_title='East-West Distance (Meters)',
        yaxis_title='North-South Distance (Meters)',
        yaxis=dict(scaleanchor="x", scaleratio=1),  # keep metres square
        legend_title="Accuracy / TPR",
        hoverlabel=dict(bgcolor="white", font_size=12),
        height=700,
        template="plotly_white",
        hovermode='closest',
    )

    print(f"[DEBUG] Plot created with {len(predictions_dict)} clickable points")
    return fig, predictions_dict


# -----------------------------------------
# CLICK HANDLER
# -----------------------------------------
def handle_plot_click(evt: gr.SelectData, predictions_store):
    """Show images after a plot click (experimental).

    The clicked point is not resolved from ``evt``; the handler always shows
    the lowest available query index.

    Returns:
        ``(query_img, gt_img, pred_img, info_html, image_row_update, query_idx)``.
    """
    print(f"[DEBUG] Plot clicked! Event data: {evt}")

    if not predictions_store:
        return None, None, None, "No data loaded. Please select a method first.", gr.update(visible=False), 0

    # Fallback: the store is non-empty here, so a minimum always exists.
    query_idx = min(int(k) for k in predictions_store)
    print(f"[DEBUG] Showing query index: {query_idx}")

    pred_data = predictions_store[str(query_idx)]
    query_img = load_image_safe(pred_data['query_image'])
    gt_img = load_image_safe(pred_data['gt_image'])
    pred_img = load_image_safe(pred_data['pred_image'])

    status = "✓ CORRECT" if pred_data['is_correct'] else "✗ INCORRECT"
    status_color = "green" if pred_data['is_correct'] else "red"

    info_html = f"""
<div style="padding: 10px;">
  <h3>Image Comparison - Query #{query_idx}</h3>
  <p>Status: <span style="color: {status_color};">{status}</span></p>
  <p>Error: {pred_data['distance_error']:.1f}m</p>
  <p><i>Note: Click detection is experimental.</i></p>
</div>
"""
    return query_img, gt_img, pred_img, info_html, gr.update(visible=True), query_idx


# -----------------------------------------
# MANUAL QUERY SELECTION HANDLER
# -----------------------------------------
def handle_manual_selection(query_idx, predictions_store):
    """Load the query / ground-truth / prediction images for one query index.

    This is the primary, reliable way users inspect images.

    Args:
        query_idx: Query index; anything ``int()`` accepts. Invalid values
            (including None from an empty Number box) are reported in the
            info panel.
        predictions_store: Dict returned by ``create_gps_plot``.

    Returns:
        ``(query_img, gt_img, pred_img, info_html, image_row_update)``.
    """
    if not predictions_store:
        return None, None, None, "No data loaded. Please select a method first.", gr.update(visible=False)

    # Broad catch on purpose: this is a UI event boundary, so failures are
    # shown in the info panel instead of surfacing as a Gradio error.
    try:
        query_idx = int(query_idx)
        print(f"[DEBUG] Viewing query index: {query_idx}")

        pred_data = predictions_store.get(str(query_idx))
        if not pred_data:
            available_indices = [int(k) for k in predictions_store]
            return None, None, None, (
                f"❌ Query index {query_idx} not found. "
                f"Valid range: {min(available_indices)}-{max(available_indices)}"
            ), gr.update(visible=False)

        query_img = load_image_safe(pred_data['query_image'])
        gt_img = load_image_safe(pred_data['gt_image'])
        pred_img = load_image_safe(pred_data['pred_image'])

        status = "✓ CORRECT" if pred_data['is_correct'] else "✗ INCORRECT"
        status_color = "green" if pred_data['is_correct'] else "red"
        tpr_class = pred_data['tpr_classification']

        info_html = f"""
<div style="padding: 10px;">
  <b>Query #{pred_data['query_idx']}</b> <span style="color: {status_color};">{status}</span><br>
  GT: #{pred_data['ground_truth_idx']}<br>
  Pred: #{pred_data['predicted_idx']}<br>
  Error: {pred_data['distance_error']:.1f}m<br>
  TPR: {tpr_class}
</div>
"""
        return query_img, gt_img, pred_img, info_html, gr.update(visible=True)

    except Exception as e:
        print(f"[ERROR] Error loading images: {e}")
        return None, None, None, f"❌ Error: {str(e)}", gr.update(visible=False)


# -----------------------------------------
# NAVIGATION HELPERS
# -----------------------------------------
def navigate_query(current_idx, direction, predictions_store):
    """Return the neighbouring available query index.

    Args:
        current_idx: Current index (anything ``int()`` accepts).
        direction: ``"next"`` for the next larger index; anything else moves
            to the previous smaller one.
        predictions_store: Dict whose keys are stringified query indices.

    Returns:
        The neighbouring index, clamped to the last/first available index at
        the ends. Returns ``current_idx`` unchanged if the store is empty or
        the index cannot be converted.
    """
    if not predictions_store:
        return current_idx
    try:
        current_idx = int(current_idx)
        available_indices = sorted(int(k) for k in predictions_store)
    except (TypeError, ValueError, OverflowError):
        return current_idx

    if direction == "next":
        return next((i for i in available_indices if i > current_idx), available_indices[-1])
    return next((i for i in reversed(available_indices) if i < current_idx), available_indices[0])


# -----------------------------------------
# DISCOVER AVAILABLE METHODS
# -----------------------------------------
def get_available_methods():
    """Return short names of all methods in RESULTS_DIR, or a placeholder list."""
    data_dict = load_experiment_data(RESULTS_DIR)
    if data_dict and data_dict.get('method_data'):
        return [_method_name(m) for m in data_dict['method_data']]
    return ["No methods available"]


# -----------------------------------------
# GRADIO INTERFACE
# -----------------------------------------
# Layout: method selector, GPS plot, query navigator, image comparison panel.
with gr.Blocks(title="VPR Interactive GPS Analysis", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🧭 VPR Interactive GPS Analysis\nExplore GPS accuracy heatmaps and view image comparisons interactively.")

    predictions_state = gr.State({})

    # Read the method list once: every call re-loads the experiment data.
    _initial_methods = get_available_methods()

    # Method selector + refresh button
    with gr.Row():
        method_dropdown = gr.Dropdown(
            choices=_initial_methods,
            value=_initial_methods[0],
            label="🔽 Select Method",
            interactive=True,
            scale=9,
        )
        refresh_btn = gr.Button("🔄", size="sm", scale=1, min_width=50)

    # Main layout: plot on the left, image viewer on the right
    with gr.Row():
        with gr.Column(scale=3):
            plot = gr.Plot(label="📊 GPS Accuracy Heatmap")

            # Query navigation controls
            with gr.Row():
                with gr.Column(scale=2):
                    query_idx_input = gr.Number(
                        label="🔢 Query Index",
                        value=0,
                        precision=0,
                        minimum=0,
                        info="Hover over markers to see their index",
                    )
                with gr.Column(scale=1):
                    with gr.Row():
                        prev_btn = gr.Button("⬅️", size="sm")
                        next_btn = gr.Button("➡️", size="sm")
                    view_btn = gr.Button("👁️ View Images", variant="primary", size="lg")

        # Right panel: image viewer
        with gr.Column(scale=2):
            image_info = gr.HTML("<div style='text-align: center; padding: 20px;'><p>👈 Select query index</p></div>")
            with gr.Column(visible=False) as image_row:
                query_image = gr.Image(label="🔍 Query", type="pil", height=200)
                gt_image = gr.Image(label="✅ Ground Truth", type="pil", height=200)
                pred_image = gr.Image(label="🎯 Prediction", type="pil", height=200)

    # Usage instructions
    gr.Markdown("""
---
### 💡 How to Use
1. Select a method to load its predictions
2. Hover markers to read metadata
3. Enter index or use ⬅️ ➡️
4. Click View Images to compare
""")

    # -------------------------------------------------------------
    # CALLBACK BINDINGS
    # -------------------------------------------------------------
    def refresh_methods():
        """Re-scan RESULTS_DIR and reset the dropdown to the first method."""
        methods = get_available_methods()
        return gr.Dropdown(choices=methods, value=methods[0])

    def update_plot_and_state(method):
        """Rebuild the plot and the prediction lookup for ``method``."""
        return create_gps_plot(method, {})

    method_dropdown.change(
        fn=update_plot_and_state,
        inputs=[method_dropdown],
        outputs=[plot, predictions_state],
    )

    refresh_btn.click(
        fn=refresh_methods,
        inputs=[],
        outputs=[method_dropdown],
    )

    # Bound directly: handle_manual_selection converts the index inside its
    # own try, so an empty Number box (None) is reported, not raised.
    view_btn.click(
        fn=handle_manual_selection,
        inputs=[query_idx_input, predictions_state],
        outputs=[query_image, gt_image, pred_image, image_info, image_row],
    )

    def navigate_and_view(current_idx, direction, store):
        """Step to the neighbouring index and load its images."""
        new_idx = navigate_query(current_idx, direction, store)
        results = handle_manual_selection(new_idx, store)
        return (new_idx,) + results

    prev_btn.click(
        fn=lambda idx, store: navigate_and_view(idx, "prev", store),
        inputs=[query_idx_input, predictions_state],
        outputs=[query_idx_input, query_image, gt_image, pred_image, image_info, image_row],
    )

    next_btn.click(
        fn=lambda idx, store: navigate_and_view(idx, "next", store),
        inputs=[query_idx_input, predictions_state],
        outputs=[query_idx_input, query_image, gt_image, pred_image, image_info, image_row],
    )

    demo.load(
        fn=update_plot_and_state,
        inputs=[method_dropdown],
        outputs=[plot, predictions_state],
    )


# -----------------------------------------
# RUN APP
# -----------------------------------------
if __name__ == "__main__":
    print("🚀 Starting VPR GPS Analysis App...")
    print(f"📁 Working directory: {os.getcwd()}")
    print(f"📈 Results directory: {os.path.abspath(RESULTS_DIR)}")
    demo.launch(server_name="0.0.0.0", server_port=7860)