| ο»Ώ |
| import os |
| import sys |
| import glob |
| import numpy as np |
| import trimesh |
| import plotly.graph_objs as go |
| import gradio as gr |
| from plyfile import PlyData |
|
|
| |
# Locate the compiled pyregister extension. If it is missing, run the build
# script once and look again before giving up. All paths are relative to the
# current working directory.
_PYREGISTER_SO_GLOB = "build/python/pyregister*.so"

so_candidates = glob.glob(_PYREGISTER_SO_GLOB)
if not so_candidates:
    os.system("bash build.sh")
    so_candidates = glob.glob(_PYREGISTER_SO_GLOB)
if not so_candidates:
    raise FileNotFoundError("pyregister build failed, no .so found in build/python")

# The extension is not installed as a package, so expose its build directory
# on the import path before importing it.
sys.path.append("build/python")

import pyregister  # noqa: E402 -- must follow the sys.path change above
|
|
| |
# Bundled demo inputs for the "Load Example" button. Keys double as the
# registration method name the example is meant for.
EXAMPLES = dict(
    Rigid=dict(
        target="./examples/data/fricp/target.ply",
        source="./examples/data/fricp/source.ply",
    ),
    NonRigid=dict(
        target="./examples/data/spare/target.obj",
        source="./examples/data/spare/source.obj",
    ),
)
|
|
# Module-wide, mutable parameter stores. UI callbacks overwrite these values,
# and registration forwards them positionally to pyregister's Paras_init.

# Rigid registration (FRICP) settings.
FRICP_PARAS = dict(
    useinit=False,  # start from a user-supplied initial transform?
    fileinit="",    # path of that transform file ("" when unused)
    maxiter=100,
    stop=1e-5,
)

# Non-rigid registration (SPARE) settings; "src"/"tar" hold landmark vertex ids.
SPARE_PARAS = dict(
    iters=30,
    stopcoarse=1e-3,
    stopfine=1e-4,
    use_landmark=False,
    src=[],
    tar=[],
)
|
|
def load_example(example_type):
    """Return ``(source_path, target_path, method)`` for a bundled example.

    The example name is also returned as the registration method, so the
    method dropdown follows the chosen example.

    Raises:
        KeyError: If ``example_type`` is not a key of ``EXAMPLES``.
    """
    paths = EXAMPLES[example_type]
    return paths["source"], paths["target"], example_type
|
|
def read_mesh(file_path):
    """Load vertex positions and optional triangle faces from a .ply or .obj file.

    Args:
        file_path: Path to a mesh or point cloud. The extension is matched
            case-insensitively.

    Returns:
        ``(vertices, faces)``. ``vertices`` is an ``(N, 3)`` array.
        ``faces`` is an ``(M, 3)`` integer array of triangle vertex indices,
        or ``None`` when the file has no faces (point cloud). PLY polygons
        with more than three corners are fan-triangulated so a triangle-only
        renderer draws the whole face.

    Raises:
        ValueError: If the extension is neither ``.ply`` nor ``.obj``.
    """
    ext = os.path.splitext(file_path)[1].lower()
    if ext == ".ply":
        plydata = PlyData.read(file_path)
        vertex = plydata["vertex"]
        vertices = np.stack([vertex["x"], vertex["y"], vertex["z"]], axis=-1)
        faces = None
        if "face" in plydata and len(plydata["face"].data) > 0:
            # Assumes the first face property holds the per-face vertex-index
            # list (e.g. "vertex_indices") -- TODO confirm for unusual layouts.
            faces = _fan_triangulate(f[0] for f in plydata["face"].data)
        return vertices, faces
    if ext == ".obj":
        mesh = trimesh.load(file_path, process=False)
        vertices = np.asarray(mesh.vertices)
        # A vertex-only OBJ may load as a point cloud that has no ``faces``
        # attribute; treat that the same as an empty face list.
        faces = getattr(mesh, "faces", None)
        if faces is None or np.asarray(faces).size == 0:
            return vertices, None
        return vertices, np.asarray(faces)
    raise ValueError(f"Unsupported file type: {ext}")


def _fan_triangulate(polygons):
    """Convert polygon index lists into an ``(M, 3)`` int64 triangle array.

    Each polygon ``[v0, v1, ..., vk]`` becomes the triangle fan
    ``(v0, vi, vi+1)``. Polygons with fewer than three indices are dropped,
    except that a list holding exactly one scalar index per entry is
    regrouped into consecutive triples. Returns ``None`` when no triangle
    results.
    """
    polys = [np.asarray(p, dtype=np.int64).ravel() for p in polygons]
    if not polys:
        return None
    sizes = {p.size for p in polys}
    if sizes == {3}:
        # Common case: already a pure triangle mesh.
        return np.stack(polys)
    if sizes == {1}:
        # Flat index list: regroup into triples (raises if not divisible by 3).
        return np.concatenate(polys).reshape(-1, 3)
    # Mixed or larger polygons (quads, n-gons): fan from each polygon's first corner.
    tris = [(p[0], p[k], p[k + 1]) for p in polys for k in range(1, p.size - 1)]
    return np.asarray(tris, dtype=np.int64) if tris else None
|
|
def plot_source_target_mesh(target_file, source_file, alpha_target=1.0, alpha_source=1.0, scatter_mode=False):
    """Draw the target and source inputs together in one 3-D plotly figure.

    Args:
        target_file: Path of the fixed (target) mesh/point cloud, read with ``read_mesh``.
        source_file: Path of the moving (source) mesh/point cloud.
        alpha_target: Target opacity in the default (surface) mode.
        alpha_source: Source opacity in the default (surface) mode.
        scatter_mode: If True, draw both inputs as per-vertex markers whose
            hover text shows each vertex id and coordinates (used for landmark
            picking). The alpha arguments are not used in this mode.

    Returns:
        A ``go.Figure`` with the target trace first and the source trace second.
    """
    tgt_v, tgt_f = read_mesh(target_file)
    src_v, src_f = read_mesh(source_file)

    shading = dict(ambient=0.5, diffuse=0.8, specular=0.6, roughness=0.25)
    lamp = dict(x=100, y=200, z=50)

    def xyz(verts):
        # Split an (N, 3) array into plotly's x/y/z keyword arguments.
        return dict(x=verts[:, 0], y=verts[:, 1], z=verts[:, 2])

    def id_labels(verts, title, title_color):
        # Rich hover text: vertex index in bold, coordinates underneath.
        return [
            f"<b><span style='color:{title_color};'>{title} ID:</span></b> {idx}"
            f"<br><span style='font-size:12px;'>x={x:.3f}, y={y:.3f}, z={z:.3f}</span>"
            for idx, (x, y, z) in enumerate(verts)
        ]

    def picking_cloud(verts, title, dot_color, title_color, border_color):
        # Opaque per-vertex markers with id hover labels.
        return go.Scatter3d(
            **xyz(verts),
            mode="markers",
            marker=dict(size=0.6, color=dot_color, opacity=1.0),
            name=f"{title} Points",
            hoverinfo="text",
            text=id_labels(verts, title, title_color),
            hoverlabel=dict(
                bgcolor="white",
                bordercolor=border_color,
                font=dict(color=title_color, size=13, family="Arial", weight="bold"),
            ),
        )

    def surface_or_cloud(verts, faces, title, mesh_color, dot_color, alpha):
        # Shaded triangle surface when faces exist, plain markers otherwise.
        if faces is None:
            return go.Scatter3d(
                **xyz(verts),
                mode="markers",
                marker=dict(size=0.6, color=dot_color, opacity=alpha),
                name=f"{title} Points",
            )
        return go.Mesh3d(
            **xyz(verts),
            i=faces[:, 0], j=faces[:, 1], k=faces[:, 2],
            color=mesh_color, opacity=alpha, name=f"{title} Mesh",
            lighting=shading, lightposition=lamp,
            hovertemplate=f"<b>{title} Surface</b><br>x=%{{x:.3f}}<br>y=%{{y:.3f}}<br>z=%{{z:.3f}}<extra></extra>",
        )

    if scatter_mode:
        traces = [
            picking_cloud(tgt_v, "Target", "crimson", "black", "black"),
            picking_cloud(src_v, "Source", "limegreen", "red", "limegreen"),
        ]
    else:
        traces = [
            surface_or_cloud(tgt_v, tgt_f, "Target", "khaki", "crimson", alpha_target),
            surface_or_cloud(src_v, src_f, "Source", "darkseagreen", "limegreen", alpha_source),
        ]

    fig = go.Figure()
    for trace in traces:
        fig.add_trace(trace)
    fig.update_layout(
        height=600,
        width=600,
        margin=dict(l=0, r=0, t=40, b=0),
        scene=dict(aspectmode="data"),
        showlegend=True,
    )
    return fig
|
|
|
|
|
|
def plot_result_target_mesh(target_file, result_file):
    """Overlay the registration result on the target for visual inspection.

    Each input is drawn as a shaded mesh when it has faces, otherwise as
    small markers: the target is khaki (mesh) or crimson (points), and the
    result is light blue (mesh) or royal blue (points).

    Returns:
        A ``go.Figure`` with the target trace first and the result trace second.
    """
    tgt_v, tgt_f = read_mesh(target_file)
    res_v, res_f = read_mesh(result_file)

    shading = dict(ambient=0.5, diffuse=0.8, specular=0.6, roughness=0.25)
    lamp = dict(x=100, y=200, z=50)

    def layer(verts, faces, mesh_color, mesh_name, dot_color, dot_name):
        # Build one trace: a lit Mesh3d when faces exist, else a Scatter3d cloud.
        coords = dict(x=verts[:, 0], y=verts[:, 1], z=verts[:, 2])
        if faces is not None:
            return go.Mesh3d(
                **coords,
                i=faces[:, 0], j=faces[:, 1], k=faces[:, 2],
                color=mesh_color, opacity=1.0, name=mesh_name,
                lighting=shading, lightposition=lamp,
            )
        return go.Scatter3d(
            **coords,
            mode="markers",
            marker=dict(size=0.6, color=dot_color, opacity=1.0),
            name=dot_name,
        )

    fig = go.Figure()
    fig.add_trace(layer(tgt_v, tgt_f, "khaki", "Target Mesh", "crimson", "Target Points (Red)"))
    fig.add_trace(layer(res_v, res_f, "lightblue", "Result Mesh", "royalblue", "Result Points (Blue)"))
    fig.update_layout(
        height=600,
        width=600,
        margin=dict(l=0, r=0, t=40, b=0),
        scene=dict(aspectmode="data"),
        showlegend=True,
    )
    return fig
|
|
| |
def register_and_visualize_with_zip(target_file, source_file, output_dir, method,
                                    useinit=False, fileinit=None, maxiter=100, stop=1e-5,
                                    iters=30, stopcoarse=1e-3, stopfine=1e-4):
    """Run rigid (FRICP) or non-rigid (SPARE) registration and package the output.

    The chosen settings are written into the module-level ``FRICP_PARAS`` /
    ``SPARE_PARAS`` stores and then passed to ``Paras_init``. The landmark
    entries already staged in ``SPARE_PARAS`` are used unchanged.

    Args:
        target_file, source_file: Uploaded Gradio files (``.name`` is the path).
        output_dir: Directory the registration writes into; created if missing.
        method: ``"Rigid"`` selects FRICP; any other value selects SPARE.
        useinit, fileinit, maxiter, stop: FRICP settings. ``fileinit`` is used
            only when ``useinit`` is truthy.
        iters, stopcoarse, stopfine: SPARE settings.

    Returns:
        ``(figure, zip_path)``: a result-vs-target plot, and the path of
        ``results.zip``, which holds every ``.ply`` file found in ``output_dir``.

    Raises:
        gr.Error: If an input file or the output directory is missing, if a
            numeric field used by the selected method is empty, or if the
            expected result file was not produced.
    """
    import zipfile

    def _required(value, cast, label):
        # A cleared gr.Number arrives as None; int(None)/float(None) would
        # otherwise surface as an opaque TypeError.
        if value is None:
            raise gr.Error(f"Parameter '{label}' must not be empty!")
        return cast(value)

    if target_file is None or source_file is None:
        raise gr.Error("Please upload both target and source point cloud files first!")
    if not output_dir or not str(output_dir).strip():
        # os.makedirs("") would raise FileNotFoundError.
        raise gr.Error("Please specify an output directory!")

    # NOTE(review): output_dir comes straight from a UI textbox and the app
    # listens on 0.0.0.0, so any client can choose where files are written.
    # Consider restricting it to a fixed base directory.
    target_path = target_file.name
    source_path = source_file.name
    os.makedirs(output_dir, exist_ok=True)

    if method == "Rigid":
        # Keyword arguments are evaluated before update() runs, so a
        # validation error leaves FRICP_PARAS untouched.
        FRICP_PARAS.update(
            useinit=bool(useinit),
            fileinit=fileinit.name if useinit and fileinit else "",
            maxiter=_required(maxiter, int, "Max iterations"),
            stop=_required(stop, float, "Stop threshold"),
        )
        reg = pyregister.RigidFricpRegistration()
        reg.Paras_init(FRICP_PARAS["useinit"],
                       FRICP_PARAS["fileinit"],
                       FRICP_PARAS["maxiter"],
                       FRICP_PARAS["stop"])
        output_file = "FRICP_res.ply"
    else:
        SPARE_PARAS.update(
            iters=_required(iters, int, "Iterations"),
            stopcoarse=_required(stopcoarse, float, "Stop Coarse"),
            stopfine=_required(stopfine, float, "Stop Fine"),
        )
        reg = pyregister.NonrigidSpareRegistration()
        reg.Paras_init(
            SPARE_PARAS["iters"],
            SPARE_PARAS["stopcoarse"],
            SPARE_PARAS["stopfine"],
            SPARE_PARAS["use_landmark"],
            SPARE_PARAS["src"],
            SPARE_PARAS["tar"],
        )
        output_file = "spare_res.ply"

    reg.Reg(target_path, source_path, output_dir)

    result_path = os.path.join(output_dir, output_file)
    if not os.path.isfile(result_path):
        raise gr.Error(f"Registration did not produce {output_file} in {output_dir}.")
    fig_result = plot_result_target_mesh(target_path, result_path)

    # Bundle every .ply in the output directory; sorted for a stable archive order.
    zip_path = os.path.join(output_dir, "results.zip")
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zipf:
        for name in sorted(os.listdir(output_dir)):
            if name.endswith(".ply"):
                zipf.write(os.path.join(output_dir, name), arcname=name)

    return fig_result, zip_path
|
|
|
|
| |
def reuse_last_result_as_source(target_file, output_dir):
    """Offer the newest registration result in ``output_dir`` as the next source.

    Looks for ``FRICP_res.ply`` and ``spare_res.ply`` and picks the most
    recently modified one. On a tie, ``FRICP_res.ply`` wins.

    Returns:
        ``(result_path, figure)``: the chosen result file and a preview of it
        against the target.

    Raises:
        gr.Error: If the target is missing, ``output_dir`` does not exist, or
            no result file is present.
    """
    if target_file is None:
        raise gr.Error("Please upload the target point cloud file first!")
    if not os.path.exists(output_dir):
        raise gr.Error("Output directory does not exist!")

    paths = [os.path.join(output_dir, name) for name in ("FRICP_res.ply", "spare_res.ply")]
    existing = list(filter(os.path.exists, paths))
    if not existing:
        raise gr.Error("No previous registration result found. Please run registration first!")

    newest = max(existing, key=os.path.getmtime)
    preview = plot_source_target_mesh(target_file.name, newest)
    return newest, preview
|
|
| |
def reset_parameters():
    """Restore default registration settings and reset the matching widgets.

    Resets ``FRICP_PARAS`` and ``SPARE_PARAS`` to their defaults, which also
    turns ``use_landmark`` off and clears the staged landmark ids, then logs
    both dicts.

    Returns:
        A tuple of widget values, in this order: FRICP useinit, init file
        (``None``), maxiter, stop; SPARE iters, stopcoarse, stopfine; an
        empty source-id box, an empty target-id box, an unchecked
        "Use Landmarks" box, and a hidden landmark group.
    """
    FRICP_PARAS.update(useinit=False, fileinit="", maxiter=100, stop=1e-5)
    SPARE_PARAS.update(
        iters=30,
        stopcoarse=1e-3,
        stopfine=1e-4,
        use_landmark=False,
        src=[],
        tar=[],
    )

    print("[reset] All parameters reset.")
    print("[reset] FRICP_PARAS:", FRICP_PARAS)
    print("[reset] SPARE_PARAS:", SPARE_PARAS)

    rigid_values = (FRICP_PARAS["useinit"], None, FRICP_PARAS["maxiter"], FRICP_PARAS["stop"])
    nonrigid_values = (SPARE_PARAS["iters"], SPARE_PARAS["stopcoarse"], SPARE_PARAS["stopfine"])
    landmark_updates = (
        gr.update(value=""),
        gr.update(value=""),
        gr.update(value=False),
        gr.update(visible=False),
    )
    return rigid_values + nonrigid_values + landmark_updates
|
|
|
|
def clear_all():
    """Reset the whole workspace: uploads, plots, parameters and landmarks.

    Restores the default ``FRICP_PARAS`` / ``SPARE_PARAS`` values, which
    disables landmarks and empties their id lists.

    Returns:
        A tuple of widget values, in this order: ``None`` for the target file,
        source file, upload plot and result plot; then FRICP useinit, init
        file (``None``), maxiter, stop; SPARE iters, stopcoarse, stopfine;
        empty source/target id boxes, an unchecked "Use Landmarks" box, and a
        hidden landmark group.
    """
    FRICP_PARAS.update(useinit=False, fileinit="", maxiter=100, stop=1e-5)
    SPARE_PARAS.update(
        iters=30,
        stopcoarse=1e-3,
        stopfine=1e-4,
        use_landmark=False,
        src=[],
        tar=[],
    )

    print("[clear] All files, parameters, and landmarks cleared.")

    emptied_inputs = (None, None, None, None)
    rigid_values = (FRICP_PARAS["useinit"], None, FRICP_PARAS["maxiter"], FRICP_PARAS["stop"])
    nonrigid_values = (SPARE_PARAS["iters"], SPARE_PARAS["stopcoarse"], SPARE_PARAS["stopfine"])
    landmark_updates = (
        gr.update(value=""),
        gr.update(value=""),
        gr.update(value=False),
        gr.update(visible=False),
    )
    return emptied_inputs + rigid_values + nonrigid_values + landmark_updates
|
|
|
|
def visualize_and_store(target_file, source_file, scatter_mode=False):
    """Preview source vs target and hand both vertex arrays back for storage.

    Returns:
        ``(figure, source_vertices, target_vertices)``, or three ``None``
        values while either file is still missing.
    """
    if any(f is None for f in (target_file, source_file)):
        return None, None, None

    tgt_path, src_path = target_file.name, source_file.name
    target_vertices = read_mesh(tgt_path)[0]
    source_vertices = read_mesh(src_path)[0]
    figure = plot_source_target_mesh(tgt_path, src_path, scatter_mode=scatter_mode)
    return figure, source_vertices, target_vertices
|
|
def clear_landmarks():
    """Return fresh, empty landmark state: two id lists and two blank text boxes."""
    src_ids, tar_ids = [], []
    src_text = tar_text = ""
    return src_ids, tar_ids, src_text, tar_text
def highlight_landmarks_on_mesh(target_file, source_file, src_text, tar_text):
    """Re-draw source/target and overlay the landmark vertices typed in the UI.

    Args:
        target_file, source_file: Uploaded Gradio files (``.name`` is the path).
        src_text, tar_text: Comma-separated vertex ids, e.g. ``"12,45,89"``.
            Tokens that are not plain non-negative decimal integers are ignored.

    Returns:
        A plotly figure: the default source/target view plus id-labelled
        markers for source landmarks (lime green) and target landmarks
        (crimson), with the axes hidden.

    Raises:
        gr.Error: If either file is missing, or an id is not a valid vertex
            index for its mesh.
    """
    if target_file is None or source_file is None:
        raise gr.Error("Please upload both Source and Target first!")

    src_v, _ = read_mesh(source_file.name)
    tar_v, _ = read_mesh(target_file.name)
    src_ids = _parse_landmark_ids(src_text)
    tar_ids = _parse_landmark_ids(tar_text)
    # Validate up front so a typo gives a readable message instead of an IndexError.
    _require_ids_in_range(src_ids, len(src_v), "Source")
    _require_ids_in_range(tar_ids, len(tar_v), "Target")

    fig = plot_source_target_mesh(target_file.name, source_file.name)
    if src_ids:
        fig.add_trace(_landmark_trace(src_v, src_ids, "Source", "limegreen"))
    if tar_ids:
        fig.add_trace(_landmark_trace(tar_v, tar_ids, "Target", "crimson"))

    fig.update_layout(
        scene=dict(
            xaxis=dict(visible=False),
            yaxis=dict(visible=False),
            zaxis=dict(visible=False),
        ),
        legend=dict(bgcolor="rgba(255,255,255,0.6)"),
        margin=dict(l=0, r=0, t=40, b=0),
        title="Landmark Highlight View",
    )
    return fig


def _parse_landmark_ids(text):
    """Parse ``"12, 45,89"`` into ``[12, 45, 89]``, skipping non-numeric tokens.

    ``str.isdecimal`` is used because ``str.isdigit`` also accepts characters
    such as "²" that ``int()`` rejects.
    """
    return [int(tok) for tok in (text or "").split(",") if tok.strip().isdecimal()]


def _require_ids_in_range(ids, vertex_count, side):
    """Raise gr.Error if any landmark id does not index a vertex of the mesh."""
    bad = [i for i in ids if i >= vertex_count]
    if bad:
        raise gr.Error(
            f"{side} landmark id(s) {bad} out of range: the mesh has {vertex_count} vertices."
        )


def _landmark_trace(vertices, ids, side, color):
    """Build an id-labelled marker trace for the given vertices of one mesh."""
    pts = vertices[ids]
    return go.Scatter3d(
        x=pts[:, 0], y=pts[:, 1], z=pts[:, 2],
        mode="markers+text",
        text=[str(i) for i in ids],
        textposition="top center",
        textfont=dict(color="black", size=10, family="Arial"),
        marker=dict(
            size=6,
            color=color,
            opacity=0.9,
            line=dict(width=2, color="white"),
            symbol="circle",
        ),
        name=f"{side} Landmarks",
        hoverinfo="text",
        hovertext=[
            f"<b>{side} ID:</b> {i}<br>x={x:.3f}, y={y:.3f}, z={z:.3f}"
            for i, (x, y, z) in zip(ids, pts)
        ],
    )
|
|
def start_landmark_selection(target_file, source_file):
    """Switch the preview to per-vertex markers so landmark ids show on hover.

    Raises:
        gr.Error: If either file has not been uploaded yet.
    """
    both_uploaded = target_file is not None and source_file is not None
    if not both_uploaded:
        raise gr.Error("Please upload both Source and Target first!")
    return plot_source_target_mesh(target_file.name, source_file.name, scatter_mode=True)
def update_landmark_ids(src_text, tar_text):
    """Parse the landmark id text boxes into integer lists.

    Each box holds comma-separated vertex ids such as ``"12,45,89"``. Tokens
    that are not plain non-negative decimal integers (after trimming
    whitespace) are skipped. ``str.isdecimal`` is used instead of
    ``str.isdigit`` because the latter also accepts characters such as "²"
    that ``int()`` rejects.

    Returns:
        ``(src_ids, tar_ids, src_text, tar_text)``. The texts are echoed back,
        with ``None`` normalised to ``""``, so the text boxes keep their content.
    """
    def parse(text):
        return [int(tok) for tok in text.split(",") if tok.strip().isdecimal()]

    src_text = src_text or ""
    tar_text = tar_text or ""
    return parse(src_text), parse(tar_text), src_text, tar_text
|
|
|
|
def set_landmarks_to_registration(src_ids, tar_ids):
    """Stage confirmed landmark ids for the next non-rigid (SPARE) run.

    The ids are copied as ints into ``SPARE_PARAS["src"]`` / ``["tar"]`` and
    ``use_landmark`` is switched on. Equal counts are enforced, presumably
    because ids are paired by position -- confirm against pyregister.

    Returns:
        A status message for the "Landmark Setup Info" text box.

    Raises:
        gr.Error: If either list is empty or the two lists differ in length.
    """
    if not src_ids or not tar_ids:
        raise gr.Error("Please select both Source and Target points first!")
    if len(src_ids) != len(tar_ids):
        raise gr.Error(f"Landmark count mismatch: Source={len(src_ids)} vs Target={len(tar_ids)}")

    SPARE_PARAS["use_landmark"] = True
    SPARE_PARAS["src"] = [int(i) for i in src_ids]
    SPARE_PARAS["tar"] = [int(i) for i in tar_ids]

    # The literal previously held mis-encoded emoji bytes (rendered as "β"
    # plus a stray U+0085 line break); keep this file saved as UTF-8.
    return (
        f"✅ Landmarks staged: {len(src_ids)} Source ↔ {len(tar_ids)} Target. "
        "Will be used in NonRigid registration."
    )
|
|
| |
# Gradio "Soft" theme tinted indigo/violet, applied to the whole app.
_THEME_HUES = dict(primary_hue="indigo", secondary_hue="violet", neutral_hue="neutral")
BLUE_VIOLET_THEME = gr.themes.Soft(**_THEME_HUES)
|
|
# Extra CSS used as the ``css=`` argument of the app's gr.Blocks. Buttons
# with the secondary class get a pink-to-violet gradient, dark purple text,
# a translucent violet border and a soft shadow. Hovering switches to a
# deeper gradient and a larger shadow.
# NOTE(review): confirm `.gr-button-secondary` still matches the button
# markup of the installed Gradio version.
SECONDARY_BUTTON_CSS = """
.gr-button-secondary {
background: linear-gradient(135deg, #f8c6ff, #dca4ff) !important;
color: #4a2560 !important;
border: 1px solid rgba(206, 135, 255, 0.6) !important;
box-shadow: 0 8px 20px rgba(206, 135, 255, 0.35);
}

.gr-button-secondary:hover {
background: linear-gradient(135deg, #f2a9ff, #c589ff) !important;
box-shadow: 0 12px 26px rgba(197, 137, 255, 0.45);
}
"""
|
|
# HTML header rendered with gr.Markdown at the top of the app: title,
# project/paper links and short usage instructions.
# NOTE(review): the emoji in this text appear mis-encoded (e.g. "π³"); check
# that the file is saved and read as UTF-8.
description = """
<p style='text-align: center;'>
<h1 style='text-align: center; font-size: 4em !important; font-weight: 900; margin: 20px 0;'>π³ Lite3DReg</h1>
<span style='font-size: 1.2em; color: #333;'>A lightweight registration library with modern algorithms and interactive visualization</span><br>
<br>
<a href='https://ustc3dv.github.io/Lite3DReg/' target='_blank'>π Project Page</a> |
<a href='https://github.com/USTC3DV/Lite3DReg' target='_blank'>π» GitHub</a> |
<a href='https://arxiv.org/abs/2007.07627' target='_blank'>π Rigid (Fast ICP)</a> |
<a href='https://arxiv.org/abs/2405.20188' target='_blank'>π Non-Rigid (SPARE)</a>
</p>

<div style='background-color: #f8f9fa; padding: 18px; border-radius: 12px; border: 1px solid #e9ecef; margin: 15px 0;'>
<strong style='font-size: 1.1em;'>π Instructions:</strong>
<ul style='line-height: 1.8; margin-top: 10px;'>
<li><b>1. Input Data:</b> Upload your meshes to <b>π₯ Source File</b> (Moving) and <b>π― Target File</b> (Fixed). <br>
<small style='color: #666;'><i>Preview the initial alignment in the <b>π°οΈ Source vs Target Mesh</b> plot.</i></small></li>
<li><b>2. Configure:</b> Select <b>π§ Registration Method</b> (Rigid/Non-Rigid) and adjust parameters in the accordions.</li>
<li><b>3. Landmarks (Optional):</b> Enable <b>π Use Landmarks</b> for non-rigid tasks to provide point-to-point guidance.</li>
<li><b>4. Execute:</b> Click <b>π Run Registration</b>. The results will appear in <b>π Result vs Target Mesh</b>.</li>
</ul>
<p style='font-size: 0.9em; color: #555; border-top: 1px solid #ddd; padding-top: 10px; margin-top: 10px;'>
π‘ <b>Note:</b> Supported formats: <code>.obj</code>, <code>.ply</code>. You can download the full result directory as a <code>.zip</code> file after processing.
</p>
</div>
"""
|
|
# Gradio UI: layout first, then event wiring.
# NOTE(review): the nesting below was reconstructed from a copy of this file
# that had lost its indentation -- confirm it against the intended layout.
# Several emoji labels also look mis-encoded (e.g. "π₯"); check that the
# file is saved and read as UTF-8.
with gr.Blocks(theme=BLUE_VIOLET_THEME, css=SECONDARY_BUTTON_CSS) as demo:
    gr.Markdown(description)

    # Per-session state. The id lists are filled through the "Confirm
    # Landmarks" flow; the vertex arrays are filled by visualize_and_store
    # whenever a file changes.
    # NOTE(review): src_vertices_state / tar_vertices_state are written but
    # not read by any handler in this file.
    src_ids_state = gr.State([])
    tar_ids_state = gr.State([])
    src_vertices_state = gr.State(None)
    tar_vertices_state = gr.State(None)

    with gr.Row():
        # Left column: inputs, parameters and action buttons.
        with gr.Column(scale=1):
            source_input = gr.File(label="π₯ Source File", file_types=[".ply", ".obj"])
            target_input = gr.File(label="π― Target File", file_types=[".ply", ".obj"])
            method_dropdown = gr.Dropdown(label="π§ Registration Method",
                                          choices=["Rigid", "NonRigid"], value="Rigid")

            # Rigid (FRICP) settings, passed to register_and_visualize_with_zip.
            with gr.Accordion("π§± Rigid Parameters", open=False):
                fricp_useinit = gr.Checkbox(label="β¨ Use initial transform?", value=False)
                fricp_fileinit = gr.File(label="π Initial transform file", file_types=[".txt"])
                fricp_maxiter = gr.Number(label="π Max iterations", value=100)
                fricp_stop = gr.Number(label="π― Stop threshold", value=1e-5)

            # Non-rigid (SPARE) settings plus optional landmark guidance.
            with gr.Accordion("π NonRigid Parameters", open=False):
                spare_iters = gr.Number(label="π Iterations", value=30)
                spare_stopcoarse = gr.Number(label="π§ Stop Coarse", value=1e-3)
                spare_stopfine = gr.Number(label="ποΈ Stop Fine", value=1e-4)

                # Mirrored into SPARE_PARAS["use_landmark"]; also shows/hides the group below.
                use_landmark_cb = gr.Checkbox(label="π Use Landmarks", value=False)

                # Hidden until "Use Landmarks" is ticked.
                with gr.Group(visible=False) as landmark_group:
                    gr.Markdown("### πΈ Landmark Point Selection")
                    selected_src = gr.Textbox(
                        label="Source IDs π",
                        placeholder="e.g. 12,45,89", interactive=True)
                    selected_tar = gr.Textbox(
                        label="Target IDs π―",
                        placeholder="e.g. 7,42,105", interactive=True)

                    with gr.Row():
                        # Created hidden; nothing in this file makes it visible.
                        start_landmark_btn = gr.Button("π― Start Landmark Selection", variant="primary",visible=False)
                        highlight_btn = gr.Button("β¨ Highlight Landmarks", variant="secondary")
                        clear_landmark_btn = gr.Button("π§Ό Clear Selections", variant="secondary")

                    # NOTE(review): this label contains a U+0085 (written here as an
                    # escape); it looks like a mis-decoded "✅" -- confirm the encoding.
                    select_button = gr.Button("β\x85 Confirm Landmarks for Registration", variant="primary")

            output_dir = gr.Textbox(label="πΎ Output Directory", value="./output/", placeholder="./output/")
            with gr.Row():
                run_button = gr.Button("π Run Registration", variant="primary")
                rerun_button = gr.Button("π Reregister (Use Last Result as Source)", variant="secondary")
                clear_button = gr.Button("π§½ Clear Workspace", variant="primary")
                reset_button = gr.Button("β»οΈ Reset Parameters", variant="secondary")

            with gr.Row():
                example_dropdown = gr.Dropdown(label="β¨ Load Example Data",
                                               choices=["Rigid", "NonRigid"], value="Rigid")
                example_button = gr.Button("π² Load Example", variant="primary")

            # Fill the file inputs and the method selector from EXAMPLES.
            example_button.click(fn=load_example,
                                 inputs=[example_dropdown],
                                 outputs=[source_input, target_input, method_dropdown])

        # Right column: previews and the downloadable result archive.
        with gr.Column(scale=2):
            upload_plot = gr.Plot(label="π°οΈ Source vs Target Mesh")
            result_plot = gr.Plot(label="π Result vs Target Mesh")
            download_button = gr.File(label="β¬οΈ Download Result Directory",
                                      file_types=[".zip"], interactive=False)

    # Restore default parameters and clear/hide the landmark widgets.
    reset_button.click(
        fn=reset_parameters,
        inputs=None,
        outputs=[
            fricp_useinit, fricp_fileinit, fricp_maxiter, fricp_stop,
            spare_iters, spare_stopcoarse, spare_stopfine,
            selected_src, selected_tar,
            use_landmark_cb, landmark_group
        ]
    )

    def toggle_landmark_visibility(use_landmark):
        """Copy the checkbox into SPARE_PARAS["use_landmark"] and show/hide the landmark group."""
        SPARE_PARAS["use_landmark"] = bool(use_landmark)
        print(f"[UI] use_landmark set to {SPARE_PARAS['use_landmark']}")
        return gr.update(visible=use_landmark)

    use_landmark_cb.change(fn=toggle_landmark_visibility,
                           inputs=[use_landmark_cb],
                           outputs=[landmark_group])

    # Clear uploads, plots and parameters.
    # NOTE(review): download_button and src_ids_state / tar_ids_state are not
    # in this outputs list, so they keep their previous values.
    clear_button.click(
        fn=clear_all,
        inputs=None,
        outputs=[
            target_input, source_input, upload_plot, result_plot,
            fricp_useinit, fricp_fileinit, fricp_maxiter, fricp_stop,
            spare_iters, spare_stopcoarse, spare_stopfine,
            selected_src, selected_tar,
            use_landmark_cb, landmark_group
        ]
    )

    # Re-draw the preview (and refresh the vertex state) whenever either file changes.
    target_input.change(fn=visualize_and_store,
                        inputs=[target_input, source_input],
                        outputs=[upload_plot, src_vertices_state, tar_vertices_state])
    source_input.change(fn=visualize_and_store,
                        inputs=[target_input, source_input],
                        outputs=[upload_plot, src_vertices_state, tar_vertices_state])

    start_landmark_btn.click(fn=start_landmark_selection,
                             inputs=[target_input, source_input],
                             outputs=[upload_plot])

    highlight_btn.click(fn=highlight_landmarks_on_mesh,
                        inputs=[target_input, source_input, selected_src, selected_tar],
                        outputs=[upload_plot])

    clear_landmark_btn.click(fn=clear_landmarks,
                             inputs=None,
                             outputs=[src_ids_state, tar_ids_state, selected_src, selected_tar])

    # Two steps: parse the id boxes into state, then stage the ids in
    # SPARE_PARAS. The status message goes to a text box created inline here.
    select_button.click(fn=update_landmark_ids,
                        inputs=[selected_src, selected_tar],
                        outputs=[src_ids_state, tar_ids_state, selected_src, selected_tar]
                        ).then(fn=set_landmarks_to_registration,
                               inputs=[src_ids_state, tar_ids_state],
                               outputs=[gr.Textbox(label="Landmark Setup Info")])

    run_button.click(fn=register_and_visualize_with_zip,
                     inputs=[target_input, source_input, output_dir, method_dropdown,
                             fricp_useinit, fricp_fileinit, fricp_maxiter, fricp_stop,
                             spare_iters, spare_stopcoarse, spare_stopfine],
                     outputs=[result_plot, download_button])

    # Replace the source input with the newest result file found in output_dir.
    rerun_button.click(fn=reuse_last_result_as_source,
                       inputs=[target_input, output_dir],
                       outputs=[source_input, upload_plot])
|
|
def _launch_app():
    """Serve the Gradio app on all network interfaces (0.0.0.0), port 7860."""
    demo.launch(server_name="0.0.0.0", server_port=7860)


if __name__ == "__main__":
    _launch_app()
|
|
|
|
|
|
|
|
|
|
|
|
|
|