import base64
import os
import urllib.parse
from io import BytesIO

# Third-party imports: relative order kept exactly as in the original file.
import gradio as gr
import torch
import faiss
import pandas as pd
import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download
import torchvision.transforms as tfm
import torchvision.transforms.v2 as v2
import requests
from sklearn.decomposition import PCA

# --- Import your model definitions ---
from dinov2 import DINOv2FeatureExtractor
from dinov3 import DINOv3FeatureExtractor

# --- Constants & Configuration ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
HF_USERNAME = "pawlo2013"
DEFAULT_DATASET = "Cars196"
DEFAULT_VERSION = "3"
DEFAULT_SIZE = "b"

# Size letter -> checkpoint identifier passed to DINOv3FeatureExtractor(model_type=...).
DINOV3_CHECKPOINTS = {
    "s": "facebook/dinov3-vits16-pretrain-lvd1689m",
    "b": "facebook/dinov3-vitb16-pretrain-lvd1689m",
    "l": "facebook/dinov3-vitl16-pretrain-lvd1689m",
}


class GlobalState:
    """Mutable holder for the resources shared by every request handler."""

    model = None                # feature extractor built by load_resources
    index = None                # FAISS index read by load_resources
    mapping_df = None           # mapping CSV as a DataFrame, plus an 'image_url' column
    transform = None            # preprocessing pipeline returned by get_transforms
    current_config = {}         # {"key": config_key} of the configuration currently loaded
    current_results_text = ""   # markdown shown in the stats panel
    pca_model = None            # Cache for the PCA transformation


state = GlobalState()


# ==========================================
# 1. HELPER FUNCTIONS
# ==========================================

def extract_class_name(url):
    """Return the parent-folder name of *url* with underscores turned into spaces.

    The value is percent-decoded first. Returns "Unknown" when the decoded
    string has no '/' separator, and "N/A" when decoding/splitting raises
    (for example when *url* is not a string).
    """
    try:
        decoded_url = urllib.parse.unquote(url)
        parts = decoded_url.split('/')
        if len(parts) >= 2:
            return parts[-2].replace('_', ' ')
        return "Unknown"
    except Exception:
        return "N/A"


def get_transforms(dino_version):
    """Build the preprocessing pipeline for the given DINO version.

    Images are converted to RGB, resized to 518x518 when *dino_version* is
    "2" and to 512x512 otherwise, turned into a tensor and normalized with
    the fixed mean/std below.
    """
    width, height = (518, 518) if dino_version == "2" else (512, 512)
    return tfm.Compose([
        v2.RGB(),
        tfm.Resize(size=(width, height), antialias=True),
        tfm.ToTensor(),
        tfm.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ])


def construct_image_url(file_path, dataset_name):
    """Map a mapping-CSV *file_path* to its Hugging Face dataset download URL.

    Backslashes are normalized to '/', a leading ``data/<dataset_name>/`` (or
    plain ``data/``) prefix is stripped, and StanfordOnlineProducts paths are
    placed under ``Stanford_Online_Products/`` when not already there.
    """
    image_repo_id = f"{HF_USERNAME}/{dataset_name}"
    clean_path = file_path.replace("\\", "/")

    prefix = f"data/{dataset_name}/"
    if clean_path.startswith(prefix):
        clean_path = clean_path[len(prefix):]
    elif clean_path.startswith("data/"):
        clean_path = clean_path[len("data/"):]

    if dataset_name == "StanfordOnlineProducts":
        if not clean_path.startswith("Stanford_Online_Products"):
            clean_path = f"Stanford_Online_Products/{clean_path}"

    return f"https://huggingface.co/datasets/{image_repo_id}/resolve/main/{clean_path}"


def construct_repo_id(dataset, version, size, finetune):
    """Return ``(index_repo_id, model_repo_id)`` for a run configuration.

    The run name is ``<dataset>_dino<version><size>``, with ``_finetune_``
    inserted before the size when *finetune* is truthy; the model repo is the
    index repo with a ``-model`` suffix.
    """
    run_name = (
        f"{dataset}_dino{version}"
        f"{'_finetune_' if finetune else ''}"
        f"{size}"
    )
    index_repo_id = f"{HF_USERNAME}/{run_name}"
    model_repo_id = f"{HF_USERNAME}/{run_name}-model"
    return index_repo_id, model_repo_id


def _fetch_results_markdown(index_repo):
    """Return ``results.txt`` from *index_repo* as a fenced text block, or a warning string."""
    try:
        results_path = hf_hub_download(repo_id=index_repo, filename="results.txt", repo_type="dataset")
        with open(results_path, 'r', encoding='utf-8') as f:
            raw_text = f.read()
    except Exception:
        # The stats file is optional: any failure here must not block loading.
        return "⚠️ `results.txt` not found."
    return f"```text\n{raw_text}\n```"


def _build_dinov3_model(dino_size, model_repo, is_finetuned):
    """Instantiate the DINOv3 extractor, optionally load fine-tuned weights, move to DEVICE, set eval mode."""
    model = DINOv3FeatureExtractor(model_type=DINOV3_CHECKPOINTS[dino_size])
    if is_finetuned:
        weights_path = hf_hub_download(repo_id=model_repo, filename="best_model.pth", repo_type="model")
        model.load_state_dict(torch.load(weights_path, map_location=DEVICE, weights_only=True))
    model.to(DEVICE)
    model.eval()
    return model


def load_resources(dataset, dino_version, dino_size, is_finetuned):
    """Load the FAISS index, mapping CSV and model for a configuration into ``state``.

    Args:
        dataset: dataset name used to build repo ids and image URLs.
        dino_version: DINO version string; only "3" has a model branch here.
        dino_size: key of ``DINOV3_CHECKPOINTS`` ("s", "b" or "l").
        is_finetuned: when truthy, ``best_model.pth`` is downloaded from the
            model repo and loaded into the extractor.

    Returns:
        ``(status_message, results_markdown)``. If the same configuration is
        already loaded, nothing is reloaded. On any failure the status starts
        with "❌ Error:" and ``state`` is left exactly as it was.
    """
    config_key = f"{dataset}_{dino_version}_{dino_size}_{is_finetuned}"
    if state.current_config.get("key") == config_key:
        return (f"Resources already loaded for {config_key}!", state.current_results_text)

    index_repo, model_repo = construct_repo_id(dataset, dino_version, dino_size, is_finetuned)

    try:
        # Validate before downloading anything: previously an unsupported
        # version fell through without building a model, so the new index was
        # paired with a stale model (or crashed on ``None.to``).
        if dino_version != "3":
            raise ValueError(
                f"Unsupported DINO version {dino_version!r}; only version '3' can be loaded."
            )
        if dino_size not in DINOV3_CHECKPOINTS:
            raise ValueError(
                f"Unsupported DINO size {dino_size!r}; expected one of {sorted(DINOV3_CHECKPOINTS)}."
            )

        results_display = _fetch_results_markdown(index_repo)

        index_path = hf_hub_download(repo_id=index_repo, filename="faiss_index.bin", repo_type="dataset")
        csv_path = hf_hub_download(repo_id=index_repo, filename="faiss_index_mapping.csv", repo_type="dataset")

        # Everything is built into locals first and committed at the end, so a
        # failure half-way cannot leave a new index next to an old model.
        # NOTE: the previous model stays referenced until the commit below, so
        # two models may briefly coexist in memory.
        new_index = faiss.read_index(index_path)
        new_mapping = pd.read_csv(csv_path)
        new_mapping['image_url'] = new_mapping['file_path'].apply(
            lambda x: construct_image_url(x, dataset)
        )
        new_model = _build_dinov3_model(dino_size, model_repo, is_finetuned)
        new_transform = get_transforms(dino_version)

        state.index = new_index
        state.mapping_df = new_mapping
        state.model = new_model
        state.transform = new_transform
        # A PCA fitted on vectors of the previous index does not apply to the new one.
        state.pca_model = None
        state.current_config = {"key": config_key}
        state.current_results_text = results_display
        return f"✅ Successfully loaded {dataset}", results_display
    except Exception as e:
        return f"❌ Error: {str(e)}", "Error loading stats."
def pil_to_base64(pil_img):
    """Converts a PIL Image to a base64 data URI string."""
    img_buffer = BytesIO()
    pil_img.convert("RGB").save(img_buffer, format="JPEG")
    base64_str = base64.b64encode(img_buffer.getvalue()).decode("utf-8")
    return f"data:image/jpeg;base64,{base64_str}"


def fetch_image_from_url(url):
    """Return the image behind *url* as an RGB PIL image.

    Accepts ``data:image...`` URIs (decoded locally) and regular URLs
    (fetched with a 5 second timeout). Any failure yields a 224x224 red
    placeholder image instead of raising.
    """
    try:
        if url.startswith("data:image"):
            _, encoded = url.split(",", 1)
            return Image.open(BytesIO(base64.b64decode(encoded))).convert("RGB")

        headers = {'User-Agent': 'Mozilla/5.0'}
        response = requests.get(url, headers=headers, timeout=5)
        response.raise_for_status()
        return Image.open(BytesIO(response.content)).convert("RGB")
    except Exception:
        # Deliberate best effort: a broken image must not break a whole gallery.
        return Image.new("RGB", (224, 224), color="red")


def get_example_images(num_examples=10):
    """Return up to *num_examples* random ``(image, image_url)`` pairs for the gallery.

    Rows whose 'split' column equals 'test' are preferred; when that column
    is absent or no such rows exist, the whole mapping is sampled. Returns an
    empty list while no mapping is loaded.
    """
    df = state.mapping_df
    if df is None:
        return []

    # A mapping without a 'split' column used to raise KeyError here.
    test_df = df[df['split'] == 'test'] if 'split' in df.columns else df
    if test_df.empty:
        test_df = df

    sample = test_df.sample(n=min(len(test_df), num_examples))
    return [(fetch_image_from_url(url), url) for url in sample['image_url']]


def _coerce_to_pil(image_input):
    """Turn a URL/data-URI string or a numpy array into a PIL image; other values pass through."""
    if isinstance(image_input, str):
        return fetch_image_from_url(image_input)
    if isinstance(image_input, np.ndarray):
        return Image.fromarray(image_input)
    return image_input


def _embed_query(query_img):
    """Run *query_img* through ``state.transform`` and ``state.model``; return an L2-normalized float32 array."""
    img_tensor = state.transform(query_img).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        embedding = state.model(img_tensor).cpu().numpy().astype(np.float32)
    faiss.normalize_L2(embedding)
    return embedding


def process_image(image_input, k_neighbors):
    """Search the loaded index for the nearest neighbours of a query image.

    Args:
        image_input: URL/data-URI string, numpy array, or PIL image.
        k_neighbors: number of neighbours requested (converted with ``int``).

    Returns:
        ``(results, status)`` where *results* is a list of
        ``(PIL image, caption)`` tuples and *status* is a user-facing message.
        Errors are reported through *status* rather than raised.
    """
    if state.model is None or state.index is None:
        return [], "⚠️ Please wait for model to load..."
    if image_input is None:
        return [], "⚠️ Please provide a query image first."

    try:
        k = int(k_neighbors)
        embedding = _embed_query(_coerce_to_pil(image_input))
        distances, indices = state.index.search(embedding, k)

        results = []
        for dist, idx in zip(distances[0], indices[0]):
            # Skip ids that do not correspond to a mapping row.
            if idx < 0 or idx >= len(state.mapping_df):
                continue
            url = state.mapping_df.iloc[idx]['image_url']
            caption = f"Class: {extract_class_name(url)}\nSim: {dist:.3f}"
            results.append((fetch_image_from_url(url), caption))

        # Report what is actually returned, which can be fewer than k.
        return results, f"✅ Found {len(results)} matches."
    except Exception as e:
        return [], f"❌ Search failed: {str(e)}"


# ==========================================
# 2. HEADLESS API FUNCTIONS
# ==========================================

def get_faiss_samples(index_path, dataset_name, num_samples):
    """
    API endpoint function.
    Takes a FAISS index path, dataset name, and number of samples.
    Returns file path, class name, image URL (standard string), and 3D PCA coordinates.

    The mapping CSV is looked up next to *index_path* and the in-memory
    mapping is used when that file does not exist. The in-memory index is
    preferred over reading *index_path*. When at least three vectors are
    reconstructed, a 3-component PCA is fitted on them and cached in
    ``state.pca_model``; otherwise coordinates are zeros and the cache is
    cleared. Returns ``{"samples": [...]}`` or ``{"error": message}``.
    """
    try:
        index_path = index_path or ""

        if index_path.endswith('.bin'):
            # Swap only the trailing extension, not every '.bin' in the path.
            csv_path = index_path[:-len('.bin')] + '_mapping.csv'
        elif os.path.isdir(index_path):
            csv_path = os.path.join(index_path, 'faiss_index_mapping.csv')
        else:
            csv_path = index_path

        if os.path.exists(csv_path):
            df = pd.read_csv(csv_path)
        elif state.mapping_df is not None:
            df = state.mapping_df
        else:
            return {"error": f"Mapping file not found at {csv_path} and no active memory state."}

        if state.index is not None:
            faiss_idx = state.index
        else:
            if not os.path.exists(index_path):
                return {"error": f"FAISS index not found at {index_path} and not in memory."}
            faiss_idx = faiss.read_index(index_path)

        # Probe reconstruction; if it fails, try to build a direct map on
        # index types that offer make_direct_map.
        try:
            faiss_idx.reconstruct(0)
        except RuntimeError:
            try:
                faiss_idx.make_direct_map()
            except AttributeError:
                pass

        n = int(num_samples)
        sample_df = df.sample(n=min(n, len(df)))

        vectors = []
        valid_indices = []
        # NOTE(review): DataFrame index labels are used as FAISS ids — confirm
        # the mapping CSV is row-aligned with the index.
        for orig_idx in sample_df.index:
            try:
                vec = faiss_idx.reconstruct(int(orig_idx))
            except Exception:
                continue  # rows whose vector cannot be reconstructed are dropped
            vectors.append(vec)
            valid_indices.append(orig_idx)

        vectors = np.array(vectors)
        if len(vectors) >= 3:
            pca = PCA(n_components=3)
            pca_coords = pca.fit_transform(vectors)
            state.pca_model = pca  # <-- Cache the fitted PCA model
        else:
            pca_coords = np.zeros((len(vectors), 3))
            state.pca_model = None

        results = []
        for i, orig_idx in enumerate(valid_indices):
            row = sample_df.loc[orig_idx]
            file_path = str(row.get('file_path', ''))
            clean_path = file_path.replace('\\', '/')

            if 'image_url' in row and pd.notna(row['image_url']):
                img_url = row['image_url']
            else:
                img_url = construct_image_url(file_path, dataset_name)

            results.append({
                "file_path": clean_path,
                # Derived from the normalized path so backslash-separated
                # paths also yield their folder name.
                "class_name": extract_class_name(clean_path),
                "image_url": img_url,  # <-- Standard URL (No Base64 overhead)
                "pca_3d": pca_coords[i].tolist()
            })

        return {"samples": results}
    except Exception as e:
        return {"error": str(e)}


def _fit_fallback_pca(index_path):
    """Fit a 3-component PCA on up to 250 vectors reconstructed from the index.

    Uses ``state.index`` when loaded, otherwise reads *index_path* if it
    exists. Returns the fitted PCA, or ``None`` when no index is available,
    fewer than three vectors could be reconstructed, or fitting failed.
    """
    if state.index is not None:
        faiss_idx = state.index
    elif index_path and os.path.exists(index_path):
        faiss_idx = faiss.read_index(index_path)
    else:
        return None

    try:
        total_vectors = faiss_idx.ntotal
        sample_size = min(250, total_vectors)
        # Private seeded generator: fixed sample without reseeding the global
        # NumPy RNG (which would also affect later DataFrame.sample calls).
        rng = np.random.RandomState(42)
        sample_ids = rng.choice(total_vectors, sample_size, replace=False)

        fallback_vectors = []
        for orig_idx in sample_ids:
            try:
                fallback_vectors.append(faiss_idx.reconstruct(int(orig_idx)))
            except Exception:
                continue

        if len(fallback_vectors) < 3:
            return None
        pca = PCA(n_components=3)
        pca.fit(np.array(fallback_vectors))
        return pca
    except Exception as exc:
        # Best effort: the caller falls back to zero coordinates.
        print(f"PCA fallback failed: {exc}")
        return None


def embed_image_api(image_input, index_path, dataset_name, skip_pca=False):
    """
    API endpoint function.
    Embeds the Image using the model.
    If skip_pca is False, projects it into 3D using the cached PCA (or calculates it via index fallback).
    Returns it with the raw_vector and Base64 image.

    Args:
        image_input: URL/data-URI string, numpy array, or PIL image.
        index_path: on-disk FAISS index used for the PCA fallback only when no
            index is loaded in memory.
        dataset_name: accepted for interface compatibility; not used here.
        skip_pca: when true, ``pca_3d`` is ``[0.0, 0.0, 0.0]``.

    Returns:
        ``{"samples": [entry]}`` or ``{"error": message}``.
    """
    if state.model is None:
        return {"error": "Model not loaded. Please trigger 'Re-Load Resources' via UI or API first."}

    try:
        query_img = _coerce_to_pil(image_input)
        embedding = _embed_query(query_img)

        pca_3d = [0.0, 0.0, 0.0]
        if not skip_pca:
            # Ensure PCA is cached; if not, rebuild it from the FAISS index.
            # The in-memory index is usable even when index_path is empty.
            if state.pca_model is None:
                state.pca_model = _fit_fallback_pca(index_path)
            if state.pca_model is not None:
                pca_3d = state.pca_model.transform(embedding)[0].tolist()

        results = [{
            "file_path": "uploaded_query_image",
            "class_name": "Query",
            "image_url": pil_to_base64(query_img),
            "pca_3d": pca_3d,
            "raw_vector": embedding[0].tolist()
        }]
        return {"samples": results}
    except Exception as e:
        return {"error": str(e)}


# ==========================================
# 3. UI WRAPPERS & GRADIO UI
# ==========================================

def refresh_examples_wrapper():
    """Return a fresh set of ten example images for the gallery."""
    return get_example_images(10)


def on_select_example(evt: gr.SelectData, gallery_data, k):
    """Run a search for the gallery item the user clicked.

    The caption of each example holds its image URL (see get_example_images),
    so the selected item's caption is used as the query.
    """
    if not gallery_data:
        # This handler feeds two outputs, so always return a pair.
        return [], "⚠️ No examples available."
    url = gallery_data[evt.index][1]
    return process_image(url, k)


# NOTE(review): the source reached review with its indentation flattened; the
# layout nesting below is a reconstruction — confirm against the intended UI.
with gr.Blocks(title="DINO Image Retrieval") as demo:
    gr.Markdown("# 🦖 DINOv3 Image Retrieval System")

    with gr.Row():
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("### ⚙️ Configuration")
                inp_dataset = gr.Dropdown(label="Dataset", choices=["Cars196", "CUB", "StanfordOnlineProducts"], value=DEFAULT_DATASET)
                with gr.Row():
                    inp_ver = gr.Dropdown(label="Version", choices=["3"], value=DEFAULT_VERSION)
                    inp_size = gr.Dropdown(label="Size", choices=["s", "b"], value=DEFAULT_SIZE)
                inp_finetune = gr.Checkbox(label="Finetuned?", value=False)
                inp_k = gr.Slider(minimum=1, maximum=50, value=10, step=1, label="Top-K Matches")
                btn_load = gr.Button("Re-Load Resources", variant="secondary")

            out_status = gr.Textbox(label="Status", value="Initializing...", interactive=False)
            gr.Markdown("### 📊 Performance Stats")
            out_results = gr.Markdown(value="Stats will appear here.")

        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.TabItem("Select Example"):
                    btn_refresh_ex = gr.Button("🔄 Refresh Examples")
                    ex_gallery = gr.Gallery(label="Examples", columns=5, height="auto")
                with gr.TabItem("Upload Image"):
                    inp_img_upload = gr.Image(type="pil", label="Upload Query")
                    btn_search_upload = gr.Button("🔍 Search", variant="primary")

            gr.Markdown("### Matches")
            out_gallery = gr.Gallery(label="Results", columns=5, height="auto")

    # --- Hidden API Endpoint Routing ---
    # 1. /api/get_samples
    api_index_path = gr.Textbox(visible=False)
    api_dataset_name = gr.Textbox(visible=False)
    api_num_samples = gr.Number(visible=False)
    api_samples_output = gr.JSON(visible=False)
    api_samples_btn = gr.Button(visible=False)

    api_samples_btn.click(
        fn=get_faiss_samples,
        inputs=[api_index_path, api_dataset_name, api_num_samples],
        outputs=[api_samples_output],
        api_name="get_samples"
    )

    # 2. /api/embed
    api_embed_img_input = gr.Image(visible=False)
    api_skip_pca_input = gr.Checkbox(value=False, visible=False)  # <-- New hidden input
    api_embed_output = gr.JSON(visible=False)
    api_embed_btn = gr.Button(visible=False)

    api_embed_btn.click(
        fn=embed_image_api,
        inputs=[api_embed_img_input, api_index_path, api_dataset_name, api_skip_pca_input],
        outputs=[api_embed_output],
        api_name="embed"
    )

    # --- Standard UI Events ---
    btn_load.click(load_resources, [inp_dataset, inp_ver, inp_size, inp_finetune], [out_status, out_results]).then(refresh_examples_wrapper, outputs=[ex_gallery])
    btn_search_upload.click(process_image, [inp_img_upload, inp_k], [out_gallery, out_status])
    btn_refresh_ex.click(refresh_examples_wrapper, outputs=[ex_gallery])
    ex_gallery.select(on_select_example, [ex_gallery, inp_k], [out_gallery, out_status])

    demo.load(load_resources, [inp_dataset, inp_ver, inp_size, inp_finetune], [out_status, out_results], queue=False).then(refresh_examples_wrapper, outputs=[ex_gallery])


if __name__ == "__main__":
    demo.launch(theme=gr.themes.Soft())