""" Gradio + Plotly point cloud viewer for .xyz, .ply and .obj files with PI3DETR model integration. Features: - Upload .xyz (ASCII): one point per line: "x y z" (extra columns are ignored). - Upload .ply: Standard PLY format point clouds. - Upload .obj: OBJ format with vertices and faces (triangles). - Interactive 3D view: orbit, pan, zoom with mouse. - Optional: downsample for speed, normalize to unit cube, toggle axes, set point size. - Dual view: Input point cloud and model predictions side-by-side. - PI3DETR model integration for curve detection. - Immediate point cloud rendering on upload. """ import io import os from typing import List, Dict, Optional import gradio as gr import numpy as np import plotly.graph_objects as go from plyfile import PlyData import pandas import torch from torch_geometric.data import Data import fpsample import trimesh # NEW: for robust mesh loading & surface sampling # Import PI3DETR modules from pi3detr import ( build_model, build_model_config, load_args, load_weights, ) from pi3detr.dataset import normalize_and_scale # Global model cache PI3DETR_MODEL = None MODEL_STATUS = {"loaded": False, "message": "Model not loaded"} HOVER_FONT_SIZE = 16 # enlarged hover text size FIG_TEMPLATE = "plotly_white" # global figure template PLOT_HEIGHT = 800 # NEW: desired plot height (adjust as needed) # NEW: demo point cloud file paths (fill these with real .xyz/.ply paths) DEMO_POINTCLOUDS = { "Demo 1": "demo_inputs/demo1.xyz", "Demo 2": "demo_inputs/demo2.xyz", "Demo 3": "demo_inputs/demo3.xyz", "Demo 4": "demo_inputs/demo4.xyz", "Demo 5": "demo_inputs/demo5.xyz", } def initialize_model(checkpoint_path="model.ckpt", config_path="configs/pi3detr.yaml"): """Initialize the model at startup and store it in the global cache.""" global PI3DETR_MODEL, MODEL_STATUS try: args = load_args(config_path) if config_path else {} model_config = build_model_config(args) model = build_model(model_config) load_weights(model, checkpoint_path) model.eval() PI3DETR_MODEL = model MODEL_STATUS = {"loaded": True, "message": "Model loaded successfully"} print("PI3DETR model initialized successfully") return True except Exception as e: MODEL_STATUS = {"loaded": False, "message": f"Error loading model: {str(e)}"} print(f"Error initializing PI3DETR model: {e}") return False def read_xyz(file_obj: io.BytesIO) -> np.ndarray: """ Parse a .xyz text file from bytes and return Nx3 float32 array. Lines with fewer than 3 numeric values are skipped. Only the first three numeric columns are used. """ if file_obj is None: return np.zeros((0, 3), dtype=np.float32) # Read bytes → text raw = file_obj.read() try: text = raw.decode("utf-8", errors="ignore") except Exception: text = raw.decode("latin-1", errors="ignore") pts = [] for line in text.splitlines(): line = line.strip() if not line or line.startswith("#"): continue parts = line.replace(",", " ").split() nums = [] for p in parts: try: nums.append(float(p)) except ValueError: # skip non-numeric tokens pass if len(nums) == 3: break if len(nums) >= 3: pts.append(nums[:3]) if not pts: return np.zeros((0, 3), dtype=np.float32) return np.asarray(pts, dtype=np.float32) def read_ply(file_obj: io.BytesIO) -> np.ndarray: """ Parse a .ply file from bytes and return Nx3 float32 array of points. """ if file_obj is None: return np.zeros((0, 3), dtype=np.float32) try: ply_data = PlyData.read(file_obj) vertex = ply_data["vertex"] x = np.asarray(vertex["x"]) y = np.asarray(vertex["y"]) z = np.asarray(vertex["z"]) points = np.column_stack([x, y, z]).astype(np.float32) return points except Exception as e: print(f"Error reading PLY file: {e}") return np.zeros((0, 3), dtype=np.float32) def read_obj_and_sample(file_obj: io.BytesIO, display_max_points: int): """Parse OBJ via trimesh and sample up to display_max_points uniformly over the surface.""" raw = file_obj.read() # Rewind not strictly needed after read since we don't reuse file_obj try: mesh = trimesh.load(io.BytesIO(raw), file_type="obj", force="mesh") except Exception as e: print(f"trimesh load error: {e}") return ( np.zeros((0, 3), dtype=np.float32), np.zeros((0, 3), dtype=np.float32), "OBJ load failure", ) # Handle scenes by merging if isinstance(mesh, trimesh.Scene): mesh = trimesh.util.concatenate(tuple(g for g in mesh.geometry.values())) if mesh.is_empty or mesh.vertices.shape[0] == 0: return ( np.zeros((0, 3), dtype=np.float32), np.zeros((0, 3), dtype=np.float32), "OBJ: empty mesh", ) sample_n = min(display_max_points, max(1, display_max_points)) try: sampled = mesh.sample(sample_n) except Exception as e: print(f"Sampling error: {e}") sampled = mesh.vertices if sampled.shape[0] > sample_n: sampled = sampled[:sample_n] sampled = np.asarray(sampled, dtype=np.float32) info = f"OBJ: {mesh.vertices.shape[0]} verts, {len(mesh.faces) if mesh.faces is not None else 0} tris | Surface sampled: {sampled.shape[0]} pts" model_pts = sampled.copy() return model_pts, sampled, info def downsample(pts: np.ndarray, max_points: int) -> np.ndarray: if pts.shape[0] <= max_points: return pts rng = np.random.default_rng(42) idx = rng.choice(pts.shape[0], size=max_points, replace=False) return pts[idx] def make_figure( pts: np.ndarray, point_size: int = 2, show_axes: bool = True, title: str = "", polylines: Optional[List[Dict]] = None, ) -> go.Figure: """ Build a Plotly 3D scatter figure with equal aspect ratio. Optionally includes polylines from model predictions. """ if pts.size == 0 and (polylines is None or len(polylines) == 0): fig = go.Figure() fig.update_layout( title="No data to display", template=FIG_TEMPLATE, scene=dict( xaxis_visible=False, yaxis_visible=False, zaxis_visible=False, ), margin=dict(l=0, r=0, t=40, b=0), ) return fig fig = go.Figure() # Add point cloud if available if pts.size > 0: x, y, z = pts[:, 0], pts[:, 1], pts[:, 2] fig.add_trace( go.Scatter3d( x=x, y=y, z=z, mode="markers", marker=dict( size=max(1, int(point_size)), color="darkgray", opacity=0.2 ), hoverinfo="skip", name="Curves", showlegend=False, # legend hidden ) ) # Define colors for each curve type curve_colors = { "Line": "blue", "Circle": "green", "Arc": "red", "BSpline": "purple", } # Add polylines from model predictions if available if polylines: for curve in polylines: points = np.array(curve["points"]) if len(points) < 2: continue curve_type = curve["type"] curve_id = curve["id"] score = curve["score"] # NEW: allow override color if provided (e.g., threshold filtered) color = curve.get("display_color") or curve_colors.get(curve_type, "orange") # NEW: support hidden-by-default via legendonly fig.add_trace( go.Scatter3d( x=points[:, 0], y=points[:, 1], z=points[:, 2], mode="lines", line=dict(color=color, width=8), # CHANGED: increased from 5 to 8 name=f"{curve_type} #{curve_id} ({score:.2f})", visible=curve.get("visible_state", True), hoverinfo="text", text=f"{curve_type} #{curve_id} ({score:.4f})", showlegend=False, # hide individual curve legend entries ) ) # Equal aspect ratio using data ranges if pts.size > 0: mins = pts.min(axis=0) maxs = pts.max(axis=0) elif polylines and len(polylines) > 0: # If we only have polylines, calculate range from them all_points = np.vstack([np.array(curve["points"]) for curve in polylines]) mins = all_points.min(axis=0) maxs = all_points.max(axis=0) else: mins = np.array([-1, -1, -1]) maxs = np.array([1, 1, 1]) centers = (mins + maxs) / 2.0 span = (maxs - mins).max() if span <= 0: span = 1.0 half = span / 2.0 xrange = [centers[0] - half, centers[0] + half] yrange = [centers[1] - half, centers[1] + half] zrange = [centers[2] - half, centers[2] + half] scene_axes = dict( xaxis=dict(range=xrange, visible=show_axes, title="x" if show_axes else ""), yaxis=dict(range=yrange, visible=show_axes, title="y" if show_axes else ""), zaxis=dict(range=zrange, visible=show_axes, title="z" if show_axes else ""), aspectmode="cube", ) fig.update_layout( title=title, template=FIG_TEMPLATE, showlegend=False, scene=scene_axes, margin=dict(l=0, r=0, t=40, b=0), hoverlabel=dict(font=dict(size=HOVER_FONT_SIZE)), height=PLOT_HEIGHT, # NEW ) return fig def process_model_predictions(data: Data) -> list: """ Process model outputs into a format suitable for visualization. """ class_names = ["None", "BSpline", "Line", "Circle", "Arc"] polylines = data.polylines.cpu().numpy() curves = [] # Process detected polylines for i, polyline in enumerate(polylines): cls = data.polyline_class[i].item() score = data.polyline_score[i].item() cls_name = class_names[cls] # Skip low-confidence or "None" class predictions if cls == 0: continue # Add curve data to results with unique ID curve_data = { "type": cls_name, "id": i + 1, # 1-based ID for better user experience "index": i, "score": score, "points": polyline, } curves.append(curve_data) return curves def process_data_for_model( points: np.ndarray, sample: int = 32768, sample_mode: str = "fps", ) -> Data: # CHANGED: removed reduction param """ Process and subsample point cloud data using the same approach as predict_pi3detr.py. Args: points: Input point cloud as numpy array sample: Number of points to sample sample_mode: Sampling method ("fps", "random", "uniform", "all") Returns: Data object ready for model inference """ # Convert to torch tensor pos = torch.tensor(points, dtype=torch.float32) # Apply sampling strategy if sample_mode == "random": if pos.size(0) > sample: indices = torch.randperm(pos.size(0))[:sample] pos = pos[indices] elif sample_mode == "fps": if pos.size(0) > sample: indices = fpsample.bucket_fps_kdline_sampling(pos, sample, h=6) pos = pos[indices] elif sample_mode == "uniform": if pos.size(0) > sample: step = max(1, pos.size(0) // sample) pos = pos[::step][:sample] elif sample_mode == "all": pass # Keep all points # Create Data object data = Data(pos=pos) # Add batch information for single point cloud BEFORE normalization data.batch = torch.zeros(data.pos.size(0), dtype=torch.long) data.batch_size = 1 # Normalize and scale using PI3DETR's method data = normalize_and_scale(data) # Ensure scale and center are proper batch tensors if hasattr(data, "scale") and data.scale.dim() == 0: data.scale = data.scale.unsqueeze(0) if hasattr(data, "center") and data.center.dim() == 1: data.center = data.center.unsqueeze(0) return data @torch.no_grad() def run_model_inference( model, points: np.ndarray, max_points: int = 32768, sample_mode: str = "fps", num_queries: int = 256, ) -> list: """Run model inference on the given point cloud.""" global PI3DETR_MODEL if model is None: model = PI3DETR_MODEL if model is None: return [] try: data = process_data_for_model( points, sample=max_points, sample_mode=sample_mode ) device = next(model.parameters()).device data = data.to(device) if model.num_preds != num_queries: model.set_num_preds(num_queries) output = model.predict_step( data, reverse_norm=True, thresholds=None, ) result = output[0] curves = process_model_predictions(result) return curves except Exception as e: print(f"Error in model inference: {e}") return [] def load_and_process_pointcloud( file: gr.File, max_points: int, point_size: int, show_axes: bool, ): """ Load and process a point cloud from .xyz or .ply file """ if file is None: empty_fig = make_figure(np.zeros((0, 3))) return empty_fig, None, None, os.path.basename(file.name) if file else "" # Determine file type and read accordingly file_ext = os.path.splitext(file.name)[1].lower() # Read file based on extension with open(file.name, "rb") as f: if file_ext == ".xyz": pts = read_xyz(f) mode = "XYZ" elif file_ext == ".ply": pts = read_ply(f) mode = "PLY" elif file_ext == ".obj": model_pts, display_pts, _ = read_obj_and_sample(f, max_points) fig = make_figure( display_pts, point_size=point_size, show_axes=show_axes, title=f"{os.path.basename(file.name)}", ) return fig, model_pts, display_pts, os.path.basename(file.name) else: empty_fig = make_figure(np.zeros((0, 3))) return ( empty_fig, None, None, "Unsupported file type. Please use .xyz, .ply or .obj.", "", ) original_n = pts.shape[0] # Keep original points for model if normalizing for display model_pts = pts.copy() pts = downsample(pts, max_points=max_points) displayed_n = pts.shape[0] fig = make_figure( pts, point_size=point_size, show_axes=show_axes, title=f"{os.path.basename(file.name)}", ) info = f"Loaded ({mode}): {original_n} points" # | Displayed: {displayed_n} points" # RETURN single figure + model/full points + displayed subset return fig, model_pts, pts, os.path.basename(file.name) # ADDED filename def run_model_prediction( model_pts: np.ndarray, point_size: int, show_axes: bool, model_max_points: int, sample_mode: str, th_bspline: float, th_line: float, th_circle: float, th_arc: float, num_queries: int = 256, ): # NOTE: display points now handled outside; keep signature (called before adding display pts state) # (This wrapper kept for backwards compatibility if needed – we adapt below in new unified version) return run_model_prediction_unified( # type: ignore model_pts, None, point_size, show_axes, model_max_points, sample_mode, th_bspline, th_line, th_circle, th_arc, "", num_queries, ) def run_model_prediction_unified( model_pts: np.ndarray, display_pts: Optional[np.ndarray], point_size: int, show_axes: bool, model_max_points: int, sample_mode: str, th_bspline: float, th_line: float, th_circle: float, th_arc: float, file_name: str = "", num_queries: int = 256, ): """ Run model inference and apply initial threshold-based coloring. """ global PI3DETR_MODEL, MODEL_STATUS if model_pts is None: empty_fig = make_figure(np.zeros((0, 3))) return empty_fig, [] # Run model inference using cached model curves = [] try: if PI3DETR_MODEL is None and not MODEL_STATUS["loaded"]: # Try to initialize model if not already loaded initialize_model() if PI3DETR_MODEL is not None: # Run inference with the same settings as predict_pi3detr.py curves = run_model_inference( PI3DETR_MODEL, model_pts, max_points=model_max_points, sample_mode=sample_mode, num_queries=num_queries, ) except Exception: pass # NEW: apply thresholds for display (store raw curves separately) thresholds = { "BSpline": th_bspline, "Line": th_line, "Circle": th_circle, "Arc": th_arc, } colored_curves = [] for c in curves: c_disp = dict(c) if c["score"] < thresholds.get(c["type"], 0.7): c_disp["visible_state"] = "legendonly" colored_curves.append(c_disp) # Use existing displayed subset if provided; else derive lightweight subset if display_pts is None: display_pts = downsample(model_pts, max_points=100000) title = f"{file_name} (curves)" if curves else f"{file_name} (no curves)" fig = make_figure( display_pts, point_size=point_size, show_axes=show_axes, title=title, polylines=colored_curves, ) return fig, curves def apply_pointcloud_display_settings( model_pts: np.ndarray, curves: List[Dict], max_points: int, point_size: int, show_axes: bool, th_bspline: float, th_line: float, th_circle: float, th_arc: float, file_name: str, ): """ Apply point cloud display settings without re-running inference. Keeps existing detections and re-applies thresholds. """ if model_pts is None: empty_fig = make_figure(np.zeros((0, 3))) return empty_fig, None display_pts = downsample(model_pts, max_points=max_points) if not curves: fig = make_figure( display_pts, point_size=point_size, show_axes=show_axes, title=file_name or "Point Cloud", ) return fig, display_pts thresholds = { "BSpline": th_bspline, "Line": th_line, "Circle": th_circle, "Arc": th_arc, } colored_curves = [] for c in curves: c_disp = dict(c) if c["score"] < thresholds.get(c["type"], 0.7): c_disp["visible_state"] = "legendonly" colored_curves.append(c_disp) fig = make_figure( display_pts, point_size=point_size, show_axes=show_axes, title=(file_name or "Point Cloud") + " (curves)", polylines=colored_curves, ) return fig, display_pts def clear_curves( curves: List[Dict], display_pts: Optional[np.ndarray], model_pts: Optional[np.ndarray], point_size: int, show_axes: bool, file_name: str, ): """ Recolor already inferred curves based on updated thresholds (no re-inference). """ if curves is None or model_pts is None or len(curves) == 0: empty_fig = make_figure( display_pts if display_pts is not None else np.zeros((0, 3)) ) return empty_fig, None fig = make_figure( display_pts if display_pts is not None else np.zeros((0, 3)), point_size=point_size, show_axes=show_axes, title=file_name or "Point Cloud", polylines=None, ) return fig, None def load_demo_pointcloud( label: str, max_points: int, point_size: int, show_axes: bool, ): """ Load one of the predefined demo point clouds. Clears existing detected curves (curves_state -> None). Also returns a value for the file upload component so the filename shows up. """ path = DEMO_POINTCLOUDS.get(label, "") if not path or not os.path.isfile(path): empty_fig = make_figure(np.zeros((0, 3))) return empty_fig, None, None, "", None, None ext = os.path.splitext(path)[1].lower() try: with open(path, "rb") as f: if ext == ".xyz": pts = read_xyz(f) elif ext == ".ply": pts = read_ply(f) elif ext == ".obj": model_pts, display_pts, _ = read_obj_and_sample( f, min(20000, max_points) ) fig = make_figure( display_pts, point_size=1, show_axes=show_axes, title=f"{os.path.basename(path)} (demo)", ) return fig, model_pts, display_pts, os.path.basename(path), None, path else: empty_fig = make_figure(np.zeros((0, 3))) return empty_fig, None, None, "", None, None except Exception: empty_fig = make_figure(np.zeros((0, 3))) return empty_fig, None, None, "", None, None model_pts = pts.copy() pts = downsample(pts, max_points=max_points) fig = make_figure( pts, point_size=1, show_axes=show_axes, title=f"{os.path.basename(path)} (demo)", ) return fig, model_pts, pts, os.path.basename(path), None, path # Convenience wrappers for each demo (avoid lambdas for clarity) def load_demo1(max_points, point_size, show_axes): return load_demo_pointcloud("Demo 1", max_points, point_size, show_axes) def load_demo2(max_points, point_size, show_axes): return load_demo_pointcloud("Demo 2", max_points, point_size, show_axes) def load_demo3(max_points, point_size, show_axes): return load_demo_pointcloud("Demo 3", max_points, point_size, show_axes) def load_demo4(max_points, point_size, show_axes): # NEW return load_demo_pointcloud("Demo 4", max_points, point_size, show_axes) def load_demo5(max_points, point_size, show_axes): # NEW return load_demo_pointcloud("Demo 5", max_points, point_size, show_axes) def build_demo_preview(label: str, max_pts: int = 20000) -> go.Figure: """Create a small preview figure for a demo point cloud (no curves).""" path = DEMO_POINTCLOUDS.get(label, "") if not path or not os.path.isfile(path): return make_figure(np.zeros((0, 3)), title=f"{label}: (missing)") try: ext = os.path.splitext(path)[1].lower() with open(path, "rb") as f: if ext == ".xyz": pts = read_xyz(f) elif ext == ".ply": pts = read_ply(f) elif ext == ".obj": # UPDATED _, pts, _ = read_obj_and_sample(f, max_pts) else: return make_figure(np.zeros((0, 3)), title=f"{label}: (unsupported)") pts = downsample(pts, max_pts) return make_figure(pts, point_size=1, show_axes=False, title=f"{label} preview") except Exception as e: return make_figure(np.zeros((0, 3)), title=f"{label}: error") def run_model_with_display( model_pts: np.ndarray, max_points: int, point_size: int, show_axes: bool, model_max_points: int, sample_mode: str, th_bspline: float, th_line: float, th_circle: float, th_arc: float, file_name: str = "", num_queries: int = 256, ): """ Run inference (if model_pts present) then immediately apply current display (max_points/point_size/show_axes) and thresholds. Returns: figure, info_text, curves(list), display_pts """ if model_pts is None: empty = make_figure(np.zeros((0, 3))) return empty, None, None # Inference first (no display subset passed so it builds from model_pts) fig_infer, curves = run_model_prediction_unified( model_pts, None, point_size, show_axes, model_max_points, sample_mode, th_bspline, th_line, th_circle, th_arc, file_name, num_queries, ) # Now apply current display settings & thresholds without re-inference fig_final, display_pts = apply_pointcloud_display_settings( model_pts, curves, max_points, point_size, show_axes, th_bspline, th_line, th_circle, th_arc, file_name, ) return fig_final, curves, display_pts with gr.Blocks(title="PI3DETR") as demo: gr.Markdown( """ # 🥧 PI3DETR: Detection of Sharp 3D CAD Edges [CPU-PREVIEW] A novel end-to-end deep learning model for **parametric curve inference** in **3D point clouds** and **meshes**.
""" ) with gr.Row(): with gr.Column(): gr.Markdown( "### 🧩 Supported Inputs\n" "- **Point Clouds:** `.xyz`, `.ply`; **Meshes:** `.obj`\n" "- `Mesh` is surface-sampled using **Max Points (display)** slider." ) with gr.Column(): gr.Markdown( "### ⚙️ Point Cloud Settings\n" "- Adjust **Max Points**, **point size**, and **axes visibility**.\n" "- Controls visualization of point cloud." ) with gr.Column(): gr.Markdown( "### 🎯 Confidence Thresholds\n" "- Hover to inspect scores.\n" "- Filter curves by **class confidence** interactively" ) with gr.Row(): with gr.Column(): gr.Markdown( "### 🧠 Model Settings\n" "- **Sampling Mode:** Choose downsampling strategy.\n" "- **Model Input Size:** Number of model input points.\n" "- **Queries:** Transformer decoder queries (max. output curves)." ) with gr.Column(): gr.Markdown( "### ⚡ Performance Notes\n" "- Trained on **human-made objects**.\n" "- Optimized for **GPU**; this demo runs on **CPU**.\n" "- For **full qualitative performance**: \n" "[GitHub → PI3DETR](https://github.com/fafraob/pi3detr)" ) with gr.Column(): gr.Markdown( "### ▶️ Run Inference\n" "- Click on demo point clouds (from test set) below.\n" "- Press **Run PI3DETR** to execute inference and visualize results." ) model_pts_state = gr.State(None) display_pts_state = gr.State(None) curves_state = gr.State(None) file_name_state = gr.State("demo_inputs/demo2.xyz") with gr.Row(): file_in = gr.File( label="Upload Point Cloud (auto-renders)", file_types=[".xyz", ".ply", ".obj"], type="filepath", ) with gr.Row(): with gr.Column(scale=1): gr.Markdown("### Point Cloud Settings") max_points = gr.Slider( 0, 500_000, value=200_000, step=1_000, label="Max points (display)", ) point_size = gr.Slider(1, 8, value=1, step=1, label="Point size") show_axes = gr.Checkbox(value=False, label="Show axes") gr.Markdown("### Model Settings") sample_mode = gr.Radio( ["fps", "random", "all"], value="fps", label="Main Sampling Method", ) model_max_points = gr.Slider( 1_000, 100_000, value=32768, step=500, label="Downsample to Model Input Size", ) num_queries = gr.Slider( # NEW 32, 512, value=128, step=1, label="Number of Queries", ) # Threshold sliders (no auto-change triggers) gr.Markdown("#### Confidence Thresholds (per class)") th_bspline = gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="BSpline ≥") th_line = gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="Line ≥") th_circle = gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="Circle ≥") th_arc = gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="Arc ≥") with gr.Column(scale=1): main_plot = gr.Plot( label="Point Cloud & Curves" ) # height from fig.update_layout(PLOT_HEIGHT) run_model_button = gr.Button("Run PI3DETR", variant="primary") clear_curves_button = gr.Button("Clear Curves", variant="secondary") # Auto-render point cloud when file is uploaded file_in.change( load_and_process_pointcloud, inputs=[file_in, max_points, point_size, show_axes], outputs=[ main_plot, model_pts_state, display_pts_state, file_name_state, ], ) run_model_button.click( run_model_with_display, inputs=[ model_pts_state, max_points, point_size, show_axes, model_max_points, sample_mode, th_bspline, th_line, th_circle, th_arc, file_name_state, num_queries, ], outputs=[main_plot, curves_state, display_pts_state], ) # NEW: auto-apply display & thresholds on interaction (no inference) def _apply_display_wrapper( model_pts, curves, max_points, point_size, show_axes, th_bspline, th_line, th_circle, th_arc, file_name, display_pts_state_value, ): fig, display_pts = apply_pointcloud_display_settings( model_pts, curves, max_points, point_size, show_axes, th_bspline, th_line, th_circle, th_arc, file_name, ) return fig, display_pts # Point cloud sliders (release) & checkbox (change) for slider in [max_points, point_size]: slider.release( _apply_display_wrapper, inputs=[ model_pts_state, curves_state, max_points, point_size, show_axes, th_bspline, th_line, th_circle, th_arc, file_name_state, display_pts_state, ], outputs=[main_plot, display_pts_state], ) show_axes.change( _apply_display_wrapper, inputs=[ model_pts_state, curves_state, max_points, point_size, show_axes, th_bspline, th_line, th_circle, th_arc, file_name_state, display_pts_state, ], outputs=[main_plot, display_pts_state], ) # Threshold sliders (apply on release) for th in [th_bspline, th_line, th_circle, th_arc]: th.release( _apply_display_wrapper, inputs=[ model_pts_state, curves_state, max_points, point_size, show_axes, th_bspline, th_line, th_circle, th_arc, file_name_state, display_pts_state, ], outputs=[main_plot, display_pts_state], ) clear_curves_button.click( clear_curves, inputs=[ curves_state, display_pts_state, model_pts_state, point_size, show_axes, file_name_state, ], outputs=[main_plot, curves_state], ) # REPLACED demo preview plots + buttons WITH clickable images with gr.Row(): gr.Markdown("### Demo Point Clouds (click an image to load)") with gr.Row(): # CLEANUP: generate images dynamically for all demos demo_image_components = {} for label in ["Demo 1", "Demo 2", "Demo 3", "Demo 4", "Demo 5"]: # UPDATED png_path = f"demo_inputs/{label.lower().replace(' ', '')}.png" demo_image_components[label] = gr.Image( value=png_path if os.path.isfile(png_path) else None, label=label, interactive=False, ) # CLEANUP: map labels to loader functions & attach select handlers _demo_loaders = { "Demo 1": load_demo1, "Demo 2": load_demo2, "Demo 3": load_demo3, "Demo 4": load_demo4, "Demo 5": load_demo5, # NEW } for label, comp in demo_image_components.items(): comp.select( _demo_loaders[label], inputs=[max_points, point_size, show_axes], outputs=[ main_plot, model_pts_state, display_pts_state, file_name_state, curves_state, file_in, ], ) # NEW: auto-load Demo 2 on app start demo.load( load_demo2, inputs=[max_points, point_size, show_axes], outputs=[ main_plot, model_pts_state, display_pts_state, file_name_state, curves_state, file_in, ], ) if __name__ == "__main__": # Initialize model at startup initialize_model() demo.launch()