import gradio as gr
import plotly.graph_objects as go
import os
import json
import math
import numpy as np
import pandas as pd
from pathlib import Path
from PIL import Image
from huggingface_hub import snapshot_download
# Import your data loader
from plot_interactive_loader import load_experiment_data
# -----------------------------------------
# DATASET INITIALISATION
# -----------------------------------------
# Loads a HuggingFace dataset required for the demo.
# Ensures local folder is created and mirrors the HF snapshot.
# Requires HF_TOKEN environment variable for authentication.
# Read the token eagerly: a missing HF_TOKEN raises KeyError at import time.
hf_token = os.environ["HF_TOKEN"]
_GARDENSGP_DIR = "data/gardensgp"
# Create the local mirror directory (no-op if it already exists).
Path(_GARDENSGP_DIR).mkdir(parents=True, exist_ok=True)
snapshot_download(
    repo_id="medwa126/gardensgp",
    repo_type="dataset",
    local_dir=_GARDENSGP_DIR,
    token=hf_token,
)
print("Dataset loaded")
# -----------------------------------------
# CONFIGURATION CONSTANTS
# -----------------------------------------
RESULTS_DIR = "experiment_output"
# TPR classification code -> plot styling: marker colour, marker symbol and
# the human-readable label shown in the legend / hover card.
TPR_MAPPING = {
    code: {'color': color, 'marker': marker, 'legend': legend}
    for code, color, marker, legend in (
        ("TP", 'green', 'circle', 'True Positive (Correct)'),
        ("FP", 'blue', 'triangle-up', 'False Positive (Wrong Match)'),
        ("FN", 'red', 'x', 'False Negative (No Match)'),
        ("TN", 'gray', 'square', 'True Negative (Correctly Ignored)'),
        ("N/A", 'black', 'circle', 'N/A'),
    )
}
# -----------------------------------------
# GPS HELPERS
# -----------------------------------------
# Haversine distance between two GPS coordinates.
# Used to convert lat/lon to metre-scale distances.
def calculate_gps_distance_meters(lat1, lon1, lat2, lon2, radius=6371000):
    """Return the great-circle (haversine) distance between two GPS points.

    Args:
        lat1, lon1: First point, in decimal degrees.
        lat2, lon2: Second point, in decimal degrees.
        radius: Sphere radius; defaults to the mean Earth radius in metres
            (6 371 000), so the result is in metres unless overridden.

    Returns:
        Distance in the same unit as ``radius``.
    """
    phi1 = math.radians(lat1)
    phi2 = math.radians(lat2)
    dphi = phi2 - phi1
    dlambda = math.radians(lon2) - math.radians(lon1)
    a = math.sin(dphi / 2) ** 2 + math.cos(phi1) * math.cos(phi2) * math.sin(dlambda / 2) ** 2
    # Floating-point rounding can push `a` fractionally above 1 for
    # (near-)antipodal points, which makes math.sqrt/asin raise ValueError.
    c = 2 * math.asin(math.sqrt(min(1.0, a)))
    return radius * c
# Convert full GPS tracks to local x/y metre coordinates.
def to_meters_coords(coords_list, origin_lat, origin_lon):
    """Project (lat, lon) pairs onto a local east/north metre grid.

    The two axes are measured independently with the haversine helper:
    the north offset is the distance from the origin to (lat, origin_lon),
    and the east offset is the distance from the origin to (origin_lat, lon).
    Points south or west of the origin get negative offsets.

    Args:
        coords_list: Iterable of ``(lat, lon)`` pairs in decimal degrees.
        origin_lat, origin_lon: Grid origin in decimal degrees.

    Returns:
        Tuple ``(x_meters, y_meters)`` of equally long lists (east, north).
    """
    east_offsets = []
    north_offsets = []
    for point_lat, point_lon in coords_list:
        north = calculate_gps_distance_meters(origin_lat, origin_lon, point_lat, origin_lon)
        east = calculate_gps_distance_meters(origin_lat, origin_lon, origin_lat, point_lon)
        north_offsets.append(-north if point_lat < origin_lat else north)
        east_offsets.append(-east if point_lon < origin_lon else east)
    return east_offsets, north_offsets
# -----------------------------------------
# IMAGE SAFETY WRAPPER
# -----------------------------------------
# Ensures missing or corrupted images don’t break the UI.
def load_image_safe(image_path):
    """Load an image from disk, returning None instead of raising.

    Args:
        image_path: Path to the image file; falsy values count as missing.

    Returns:
        A fully decoded PIL image, or None when the path is empty, does not
        exist, or the file cannot be opened/decoded (a warning is printed).
    """
    if not image_path or not os.path.exists(image_path):
        return None
    try:
        # Image.open() is lazy: it only parses the header and keeps the file
        # handle open. Decode eagerly here so truncated/corrupt pixel data is
        # caught by this handler rather than later inside the UI, and return a
        # copy so the `with` block can close the underlying file.
        with Image.open(image_path) as img:
            img.load()
            return img.copy()
    except Exception as e:
        # Deliberate best-effort boundary for the UI: report and show nothing.
        print(f"[WARNING] Could not load image {image_path}: {e}")
        return None
# -----------------------------------------
# MAIN GPS PLOT BUILDER
# -----------------------------------------
# Loads all experiment data (every method contributes to the background GPS
# paths) and constructs an interactive Plotly map for the selected method with:
# - Query / database GPS paths
# - Per-query markers
# - Hover cards showing error + metadata
# - TPR-based styling
# Returns: (plotly_figure, predictions_dict)
def create_gps_plot(selected_method, predictions_store):
    """Build the interactive GPS accuracy plot for one method.

    All experiment output is loaded from ``RESULTS_DIR`` via
    ``load_experiment_data``. Query/database GPS coordinates from *every*
    method are projected onto a shared local metre grid and drawn as
    background paths. Each prediction of the selected method is then drawn
    as a marker. The marker is coloured by correctness
    (``distance_error < tolerance``) and shaped by its TPR classification.

    Args:
        selected_method: Method name, matched against the first
            space-separated word of each method's ``config['description']``.
        predictions_store: Unused; kept for compatibility with callers.

    Returns:
        ``(figure, predictions_dict)``. ``predictions_dict`` maps
        ``str(query_index)`` to image paths and error metadata for the image
        viewer. On failure the figure only carries an error annotation and
        the dict is empty.
    """
    print(f"\n[DEBUG] Creating plot for method: {selected_method}")

    # Load all experiment output (predictions, config, tolerance etc.).
    data_dict = load_experiment_data(RESULTS_DIR)
    if data_dict is None:
        return _error_figure("❌ Failed to load experiment data"), {}

    methods_data = data_dict.get('method_data', [])
    if not methods_data:
        return _error_figure("❌ No method data found"), {}

    tolerance = data_dict.get('tolerance', 10)
    distance_unit = data_dict.get('distance_unit', 'meters')

    # -------------------------------------------------------------
    # Collect GPS coordinates across all methods for the background paths
    # -------------------------------------------------------------
    all_db_coords = set()
    all_query_coords = set()
    for method_data in methods_data:
        for pred in method_data.get('predictions', []):
            gps = pred.get('gps_coordinates', {})
            if not gps:
                continue
            all_query_coords.add((gps['query_lat'], gps['query_lon']))
            all_db_coords.add((gps['predicted_lat'], gps['predicted_lon']))
            if 'ground_truth_lat' in gps:
                all_db_coords.add((gps['ground_truth_lat'], gps['ground_truth_lon']))

    if not all_query_coords:
        return _error_figure("❌ No GPS coordinates found"), {}

    # Convert global GPS to a local metre grid. The origin is the smallest
    # (lat, lon) query coordinate, so every method shares the same frame.
    sorted_query_coords = sorted(all_query_coords)
    origin_lat, origin_lon = sorted_query_coords[0]
    query_coords_m = to_meters_coords(sorted_query_coords, origin_lat, origin_lon)
    db_coords_m = to_meters_coords(sorted(all_db_coords), origin_lat, origin_lon)

    # -------------------------------------------------------------
    # Identify the selected method's prediction output
    # -------------------------------------------------------------
    selected_method_data = next(
        (
            method_data for method_data in methods_data
            if method_data['config']['description'].split(' ')[0] == selected_method
        ),
        None,
    )
    if selected_method_data is None:
        return _error_figure(f"❌ Method '{selected_method}' not found"), {}

    predictions = selected_method_data.get('predictions', [])
    if not predictions:
        return _error_figure(f"❌ No predictions for {selected_method}"), {}

    # -------------------------------------------------------------
    # Build plotting rows + prediction lookup dictionary
    # -------------------------------------------------------------
    predictions_dict = {}
    plot_data = []
    for pred in predictions:
        # Lookup store used when showing clicked/selected images.
        query_idx = pred['query_index']
        predictions_dict[str(query_idx)] = {
            'query_image': pred.get('query_image_path', pred.get('query_image', '')),
            'gt_image': pred.get('gt_image_path', pred.get('ground_truth_image', '')),
            'pred_image': pred.get('predicted_image_path', pred.get('predicted_image', '')),
            'query_idx': query_idx,
            'predicted_idx': pred['predicted_index'],
            'ground_truth_idx': pred.get('ground_truth_index', -1),
            'distance_error': pred['distance_error'],
            'is_correct': pred['distance_error'] < tolerance,
            'tpr_classification': pred.get('tpr_classification', 'N/A'),
        }

        gps = pred.get('gps_coordinates')
        if not gps:
            # Same rule as the scan above: without GPS the prediction cannot be
            # placed on the map, but its images remain viewable via the store.
            continue
        # Fall back to the predicted position when no ground truth is recorded.
        gt_lat = gps.get('ground_truth_lat', gps['predicted_lat'])
        gt_lon = gps.get('ground_truth_lon', gps['predicted_lon'])
        plot_row = pred.copy()
        plot_row.update({
            'Query_Longitude': gps['query_lon'],
            'Query_Latitude': gps['query_lat'],
            'Predicted_Longitude': gps['predicted_lon'],
            'Predicted_Latitude': gps['predicted_lat'],
            'GroundTruth_Longitude': gt_lon,
            'GroundTruth_Latitude': gt_lat,
        })
        plot_data.append(plot_row)

    if not plot_data:
        return _error_figure(f"❌ No GPS coordinates for {selected_method}"), {}

    plot_df = pd.DataFrame(plot_data)

    # Convert each prediction's query position to metre coordinates.
    plot_df['Query_X_Meters'], plot_df['Query_Y_Meters'] = to_meters_coords(
        plot_df[['Query_Latitude', 'Query_Longitude']].values.tolist(), origin_lat, origin_lon
    )

    # -------------------------------------------------------------
    # Correctness categories + TPR labels
    # -------------------------------------------------------------
    unit_suffix = "m" if distance_unit == "meters" else "f"
    tolerance_format = f"{tolerance:.1f}" if distance_unit == "meters" else f"{tolerance:.0f}"
    correct_label = f'Correct (< {tolerance_format}{unit_suffix})'
    incorrect_label = f'Incorrect (≥ {tolerance_format}{unit_suffix})'
    plot_df['Accuracy_Category'] = np.where(
        plot_df['distance_error'] < tolerance, correct_label, incorrect_label
    )
    # Predictions lacking a TPR label (missing key or null value) become 'N/A'.
    if 'tpr_classification' in plot_df.columns:
        plot_df['TPR_Classification'] = plot_df['tpr_classification'].fillna('N/A')
    else:
        plot_df['TPR_Classification'] = 'N/A'

    # -----------------------------------------
    # CREATE MAIN PLOTLY FIGURE
    # -----------------------------------------
    fig = go.Figure()

    # Background database path
    fig.add_trace(go.Scatter(
        x=db_coords_m[0], y=db_coords_m[1],
        mode='markers', name='Database Points',
        marker=dict(color='blue', size=4),
        opacity=0.3, hoverinfo='skip'
    ))

    # Query path
    fig.add_trace(go.Scatter(
        x=query_coords_m[0], y=query_coords_m[1],
        mode='markers', name='Query Path',
        marker=dict(size=4, color='red'),
        opacity=0.3, hoverinfo='skip'
    ))

    # -------------------------------------------------------------
    # Prediction markers: TPR shape + correctness colour
    # -------------------------------------------------------------
    color_map = {correct_label: 'green', incorrect_label: 'red'}
    # Single source of truth for marker shapes.
    tpr_symbol_map = {code: style['marker'] for code, style in TPR_MAPPING.items()}

    hover_text = []
    marker_symbols = []
    marker_colors = []
    custom_data = []
    for _, row in plot_df.iterrows():
        query_idx = row['query_index']
        status_symbol = "✓" if row['distance_error'] < tolerance else "✗"
        tpr_class = row['TPR_Classification']
        # Plotly hover text uses HTML <br> for line breaks.
        hover_text.append(
            f"Query Index: {query_idx}<br>"
            f"Method: {selected_method}<br>"
            f"Predicted Index: {row['predicted_index']}<br>"
            f"Error: {row['distance_error']:.1f}{unit_suffix}<br>"
            f"Status: {status_symbol} {row['Accuracy_Category']}<br>"
            f"TPR: {TPR_MAPPING.get(tpr_class, {}).get('legend', tpr_class)}"
        )
        marker_symbols.append(tpr_symbol_map.get(tpr_class, 'circle'))
        marker_colors.append(color_map[row['Accuracy_Category']])
        custom_data.append([query_idx])

    # Main prediction trace (legend handled by the dummy traces below)
    fig.add_trace(go.Scatter(
        x=plot_df['Query_X_Meters'],
        y=plot_df['Query_Y_Meters'],
        mode='markers',
        name='Predictions',
        marker=dict(
            size=10,
            color=marker_colors,
            symbol=marker_symbols,
            line=dict(width=1, color='Black')
        ),
        text=hover_text,
        hoverinfo='text',
        customdata=custom_data,
        showlegend=False
    ))

    # Correctness legend entries (invisible points that only feed the legend)
    for category, color in color_map.items():
        count = int((plot_df['Accuracy_Category'] == category).sum())
        if count > 0:
            fig.add_trace(go.Scatter(
                x=[None], y=[None], mode='markers',
                marker=dict(size=10, color=color, symbol='circle', line=dict(width=1, color='Black')),
                name=f'{category} ({count})',
                hoverinfo='none'
            ))

    # TPR legend entries
    for tpr_class, tpr_info in TPR_MAPPING.items():
        tpr_count = int((plot_df['TPR_Classification'] == tpr_class).sum())
        if tpr_count > 0:
            fig.add_trace(go.Scatter(
                x=[None], y=[None], mode='markers',
                marker=dict(size=10, color=tpr_info['color'],
                            symbol=tpr_info['marker'],
                            line=dict(width=1, color='Black')),
                name=f'{tpr_info["legend"]} ({tpr_count})',
                hoverinfo='none'
            ))

    # Layout styling
    fig.update_layout(
        title=(
            f'GPS Accuracy Heatmap - {selected_method}<br>'
            f'Tolerance: {tolerance}{unit_suffix}'
        ),
        xaxis_title='East-West Distance (Meters)',
        yaxis_title='North-South Distance (Meters)',
        yaxis=dict(scaleanchor="x", scaleratio=1),
        legend_title="Accuracy / TPR",
        hoverlabel=dict(bgcolor="white", font_size=12),
        height=700,
        template="plotly_white",
        hovermode='closest'
    )

    print(f"[DEBUG] Plot created with {len(predictions_dict)} clickable points")
    return fig, predictions_dict


def _error_figure(message):
    """Return an empty Plotly figure showing *message* as a centred red annotation."""
    return go.Figure().add_annotation(
        text=message,
        xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False,
        font=dict(size=20, color="red")
    )
# -----------------------------------------
# CLICK HANDLER
# -----------------------------------------
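# Experimental: the click event is only logged, not used to pick a point.
# The handler always displays the prediction with the lowest query index
# in predictions_store.
# NOTE(review): the info_html triple-quoted string opened in this function is
# not closed until the `gr.Markdown(` call on the collapsed line after it, so
# this function's return and the UI-construction code that preceded the
# image column appear to have been lost. handle_manual_selection,
# navigate_query, get_available_methods and the gr.Blocks context it
# references are not defined anywhere visible here. Restore this section
# from version control.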
def handle_plot_click(evt: gr.SelectData, predictions_store):
print(f"[DEBUG] Plot clicked! Event data: {evt}")
if predictions_store is None or not predictions_store:
return None, None, None, "No data loaded. Please select a method first.", gr.update(visible=False), 0
# Fallback strategy due to limited click metadata from Plotly
available_indices = sorted([int(k) for k in predictions_store.keys()])
if not available_indices:
return None, None, None, "No predictions available", gr.update(visible=False), 0
query_idx = available_indices[0]
print(f"[DEBUG] Showing query index: {query_idx}")
pred_data = predictions_store[str(query_idx)]
query_img = load_image_safe(pred_data['query_image'])
gt_img = load_image_safe(pred_data['gt_image'])
pred_img = load_image_safe(pred_data['pred_image'])
status = "✓ CORRECT" if pred_data['is_correct'] else "✗ INCORRECT"
status_color = "green" if pred_data['is_correct'] else "red"
info_html = f"""
Status: {status}
Error: {pred_data['distance_error']:.1f}m
Note: Click detection is experimental.
👈 Select query index
") with gr.Column(visible=False) as image_row: query_image = gr.Image(label="🔍 Query", type="pil", height=200) gt_image = gr.Image(label="✅ Ground Truth", type="pil", height=200) pred_image = gr.Image(label="🎯 Prediction", type="pil", height=200) # Helpful instructions block gr.Markdown(""" --- ### 💡 How to Use 1. Select a method to load its predictions 2. Hover markers to read metadata 3. Enter index or use ⬅️ ➡️ 4. Click View Images to compare """) # ------------------------------------------------------------- # CALLBACK BINDINGS # ------------------------------------------------------------- def refresh_methods(): methods = get_available_methods() return gr.Dropdown(choices=methods, value=methods[0]) def update_plot_and_state(method): return create_gps_plot(method, {}) method_dropdown.change( fn=update_plot_and_state, inputs=[method_dropdown], outputs=[plot, predictions_state] ) refresh_btn.click( fn=refresh_methods, inputs=[], outputs=[method_dropdown] ) view_btn.click( fn=lambda idx, store: handle_manual_selection(int(idx), store), inputs=[query_idx_input, predictions_state], outputs=[query_image, gt_image, pred_image, image_info, image_row] ) def navigate_and_view(current_idx, direction, store): new_idx = navigate_query(current_idx, direction, store) results = handle_manual_selection(new_idx, store) return (new_idx,) + results prev_btn.click( fn=lambda idx, store: navigate_and_view(idx, "prev", store), inputs=[query_idx_input, predictions_state], outputs=[query_idx_input, query_image, gt_image, pred_image, image_info, image_row] ) next_btn.click( fn=lambda idx, store: navigate_and_view(idx, "next", store), inputs=[query_idx_input, predictions_state], outputs=[query_idx_input, query_image, gt_image, pred_image, image_info, image_row] ) demo.load( fn=update_plot_and_state, inputs=[method_dropdown], outputs=[plot, predictions_state] ) # ----------------------------------------- # RUN APP # ----------------------------------------- if __name__ == "__main__": 
print("🚀 Starting VPR GPS Analysis App...") print(f"📁 Working directory: {os.getcwd()}") print(f"📈 Results directory: {os.path.abspath(RESULTS_DIR)}") demo.launch(server_name="0.0.0.0", server_port=7860)