"""Gradio demo: 3D part segmentation with HierarchicalFPSCliffordNet.

An uploaded mesh is reduced to a fixed-size point cloud (randomly sampled
vertices). The multi-level FPS / radius / kNN graph inputs the model expects
are built with pure-PyTorch helpers, and the per-point predicted labels are
shown in a Plotly 3D scatter.
"""

import gradio as gr
import numpy as np
import plotly.graph_objects as go
import torch
import trimesh
from torch_geometric.utils import to_dense_batch

from model import HierarchicalFPSCliffordNet

# --- Pure-PyTorch fallbacks for compiled (C++) graph ops ---
# NOTE(review): these appear to stand in for torch_cluster's fps / knn /
# radius_graph. Their index conventions are whatever the model was trained
# with, so confirm before swapping them for the library versions.


def knn_pure(x, y, k=3):
    """Return edges from every point of ``x`` to its ``k`` nearest points in ``y``.

    Args:
        x: (N, D) query points.
        y: (M, D) candidate points.
        k: neighbours per query point; clamped to ``M`` so small ``y`` works.

    Returns:
        (2, N * k) long tensor. Row 0 holds indices into ``x``; row 1 holds
        the matching neighbour indices into ``y``.
    """
    k = min(k, y.size(0))
    dist = torch.cdist(x, y)
    _, nbr_idx = torch.topk(dist, k, dim=1, largest=False)
    query_idx = torch.arange(x.size(0), device=x.device).repeat_interleave(k)
    return torch.stack([query_idx, nbr_idx.reshape(-1)], dim=0)


def radius_graph_pure(pos, r, max_n=16):
    """Return edges between each point and its close neighbours.

    Each point keeps at most ``max_n`` neighbours, namely its nearest ones.
    Of those, only the ones strictly closer than ``r`` are kept. A point is
    normally among its own nearest neighbours (distance 0), so self-loops are
    included.

    Returns:
        (2, E) long tensor. Row 0 is the centre point; row 1 is the neighbour.
    """
    n_nbrs = min(max_n, pos.size(0))
    dist = torch.cdist(pos, pos)
    # topk already returns the distances of the selected neighbours.
    nbr_dist, target = torch.topk(dist, n_nbrs, dim=1, largest=False)
    source = torch.arange(pos.size(0), device=pos.device).view(-1, 1).expand(-1, n_nbrs)
    mask = nbr_dist < r
    return torch.stack([source[mask], target[mask]], dim=0)


def fps_pure(pos, ratio=0.5):
    """Farthest-point sampling: return indices of ``max(1, int(N * ratio))`` points.

    Deterministic. It always starts from point 0, then repeatedly selects the
    point whose distance to the already-selected set is largest.
    """
    num_samples = max(1, int(pos.size(0) * ratio))
    idx = torch.zeros(num_samples, dtype=torch.long, device=pos.device)
    min_dist = torch.full((pos.size(0),), 1e10, device=pos.device)
    farthest = 0
    for i in range(num_samples):
        idx[i] = farthest
        # Running minimum = distance from every point to its nearest selected point.
        d = torch.cdist(pos, pos[farthest].view(1, -1)).squeeze(1)
        min_dist = torch.minimum(min_dist, d)
        farthest = int(torch.argmax(min_dist))
    return idx


# --- Settings ---
CATEGORIES = {'Airplane': 0, 'Bag': 1, 'Cap': 2, 'Car': 3, 'Chair': 4, 'Earphone': 5,
              'Guitar': 6, 'Knife': 7, 'Lamp': 8, 'Laptop': 9, 'Motorbike': 10, 'Mug': 11,
              'Pistol': 12, 'Rocket': 13, 'Skateboard': 14, 'Table': 15}
# Part-label ranges per category. This is a sample, restricted list (only some
# categories) and is not used by predict().
SEG_CLASSES = {'Airplane': list(range(0, 4)), 'Chair': list(range(12, 16)),
               'Guitar': list(range(19, 22)), 'Laptop': list(range(28, 30))}
NUM_POINTS = 1024  # points sampled from every uploaded mesh
CHECKPOINT_PATH = "best_all_categories_clifford.pt"

device = torch.device("cpu")
model = HierarchicalFPSCliffordNet(base_channels=12).to(device)
try:
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=device, weights_only=True)
    # Strip the "_orig_mod." prefix that torch.compile-wrapped modules add to
    # state-dict keys, so the keys match the plain (uncompiled) model.
    state_dict = {k.replace("_orig_mod.", ""): v for k, v in checkpoint.items()}
    incompatible = model.load_state_dict(state_dict, strict=False)
except FileNotFoundError:
    print("Ağırlıklar bulunamadı.")
except Exception as err:  # best-effort: keep the demo running with untrained weights
    print(f"Ağırlıklar yüklenemedi: {err!r}")
else:
    # strict=False silently ignores key mismatches; surface them instead.
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print(f"Uyarı: eksik anahtarlar: {incompatible.missing_keys}, "
              f"beklenmeyen anahtarlar: {incompatible.unexpected_keys}")
# Always switch to inference mode, even when loading failed.
model.eval()


def _load_points(file_path, num_points=NUM_POINTS):
    """Load a mesh and return ``num_points`` of its vertices as a NumPy array.

    Vertices are sampled at random. Sampling is with replacement only when the
    mesh has fewer than ``num_points`` vertices.

    Raises:
        gr.Error: if the mesh has no vertices.
    """
    mesh = trimesh.load(file_path, force='mesh')
    vertices = np.asarray(mesh.vertices)
    if len(vertices) == 0:
        raise gr.Error("Yüklenen modelde hiç nokta (vertex) bulunamadı.")
    choice = np.random.choice(len(vertices), num_points, replace=len(vertices) < num_points)
    return vertices[choice]


def _normalize(points):
    """Centre a point cloud on its centroid and scale it into the unit ball."""
    pos = torch.as_tensor(points, dtype=torch.float32)
    centered = pos - pos.mean(0)
    # Use the radius of the *centred* cloud. The raw norm would shrink any mesh
    # that is not already centred at the origin and distort the fixed radii
    # used in _build_model_inputs.
    return centered / centered.norm(dim=1).max().clamp(min=1e-8)


def _build_model_inputs(pos, category_name):
    """Return the positional argument tuple for ``model`` built from a normalized cloud.

    Level 1 is the full cloud. Level 2 is an FPS subset of it (ratio 0.5), and
    level 3 is an FPS subset of level 2 (ratio 0.25). Every point belongs to a
    single sample (batch index 0).
    """
    f1 = fps_pure(pos, 0.5)
    p2 = pos[f1]
    f2 = fps_pure(p2, 0.25)
    p3 = p2[f2]
    e1 = radius_graph_pure(pos, 0.15)
    e2 = radius_graph_pure(p2, 0.30)
    a32 = knn_pure(p3, p2)
    a21 = knn_pure(p2, pos)
    batch1 = torch.zeros(pos.size(0), dtype=torch.long)
    batch2 = torch.zeros(p2.size(0), dtype=torch.long)
    batch3 = torch.zeros(p3.size(0), dtype=torch.long)
    _, m = to_dense_batch(torch.zeros(p3.size(0), 1), batch3)
    category = torch.tensor([CATEGORIES[category_name]])
    return (pos, batch1, category, f1, f2, p2, batch2, p3, batch3, e1, e2, a32, a21, m)


def _make_figure(points, labels):
    """Return a Plotly 3D scatter of ``points`` coloured by ``labels``, with axes hidden."""
    fig = go.Figure(data=[go.Scatter3d(
        x=points[:, 0], y=points[:, 1], z=points[:, 2], mode='markers',
        marker=dict(size=4, color=labels, colorscale='Viridis', opacity=0.8),
    )])
    fig.update_layout(
        margin=dict(l=0, r=0, b=0, t=0),
        scene=dict(xaxis=dict(visible=False), yaxis=dict(visible=False),
                   zaxis=dict(visible=False)),
    )
    return fig


def predict(file, category_name):
    """Segment an uploaded mesh and return a Plotly figure of the predicted labels.

    Args:
        file: value of the gr.Model3D input. This is either a path string or an
            object with a ``.name`` path attribute.
        category_name: a key of CATEGORIES, selected in the dropdown.

    Returns:
        A plotly Figure, or None when no file was uploaded.

    Raises:
        gr.Error: if no valid category is selected or the mesh has no vertices.
    """
    if not file:
        return None
    if category_name not in CATEGORIES:
        raise gr.Error("Lütfen bir kategori seçin.")
    # Per the original note, gr.Model3D returns a plain path string in Gradio
    # 6.0. The type check keeps a fallback for file-like values.
    file_path = file if isinstance(file, str) else file.name
    points = _load_points(file_path)
    pos = _normalize(points)
    inputs = _build_model_inputs(pos, category_name)
    with torch.no_grad():
        logits = model(*inputs)
    labels = logits.argmax(-1).numpy()
    # Plot the raw (un-normalized) vertices; only the model sees normalized coordinates.
    return _make_figure(points, labels)


# Plain gr.Blocks() (no theme argument) to avoid the theme warning.
with gr.Blocks() as demo:
    gr.Markdown("# 🚀 Clifford-Dirac 3D Segmentation (0.15M Params)")
    with gr.Row():
        with gr.Column():
            inp = gr.Model3D(label="Model Yükle (.glb, .obj)")
            cat = gr.Dropdown(choices=list(CATEGORIES.keys()), label="Kategori", value="Chair")
            btn = gr.Button("Tahmin Et", variant="primary")
        out = gr.Plot(label="Sonuç")
    btn.click(predict, [inp, cat], out)

demo.launch()