fafraob committed 36effdc (0 parents)

initial test commit
.gitattributes ADDED
@@ -0,0 +1,2 @@
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,9 @@
+ __pycache__
+ lightning_logs*
+ logs_*
+ .shapenet
+ .vscode
+ *.ipynb
+ *.ini
+ scans
+ *.venv
Dockerfile ADDED
@@ -0,0 +1,34 @@
+ FROM pytorch/pytorch:2.5.1-cuda12.1-cudnn9-devel
+
+ ENV DEBIAN_FRONTEND=noninteractive
+ ENV PIP_ROOT_USER_ACTION=ignore
+
+ RUN apt-get update -qq && \
+     apt-get install -y zip git git-lfs vim libgtk2.0-dev ffmpeg libsm6 libxext6 && \
+     rm -rf /var/lib/apt/lists/*
+
+ COPY requirements.txt /workspace
+
+ # Activate conda environment and install packages
+ RUN conda init bash && \
+     echo "conda activate base" >> ~/.bashrc
+
+ SHELL ["conda", "run", "-n", "base", "/bin/bash", "-c"]
+
+ RUN pip --no-cache-dir install -r /workspace/requirements.txt
+
+ ARG USERNAME=user
+ ARG USER_UID=1000
+ ARG USER_GID=$USER_UID
+
+ # Create the user
+ RUN groupadd --gid $USER_GID $USERNAME \
+     && useradd --uid $USER_UID --gid $USER_GID -m $USERNAME -s /bin/bash \
+     #
+     # [Optional] Add sudo support. Omit if you don't need to install software after connecting.
+     && apt-get update \
+     && apt-get install -y sudo \
+     && echo $USERNAME ALL=\(root\) NOPASSWD:ALL > /etc/sudoers.d/$USERNAME \
+     && chmod 0440 /etc/sudoers.d/$USERNAME
+
+ WORKDIR /workspaces/pi3detr
LICENSE ADDED
@@ -0,0 +1,131 @@
1
+ # PolyForm Noncommercial License 1.0.0
2
+
3
+ <https://polyformproject.org/licenses/noncommercial/1.0.0>
4
+
5
+ ## Acceptance
6
+
7
+ In order to get any license under these terms, you must agree
8
+ to them as both strict obligations and conditions to all
9
+ your licenses.
10
+
11
+ ## Copyright License
12
+
13
+ The licensor grants you a copyright license for the
14
+ software to do everything you might do with the software
15
+ that would otherwise infringe the licensor's copyright
16
+ in it for any permitted purpose. However, you may
17
+ only distribute the software according to [Distribution
18
+ License](#distribution-license) and make changes or new works
19
+ based on the software according to [Changes and New Works
20
+ License](#changes-and-new-works-license).
21
+
22
+ ## Distribution License
23
+
24
+ The licensor grants you an additional copyright license
25
+ to distribute copies of the software. Your license
26
+ to distribute covers distributing the software with
27
+ changes and new works permitted by [Changes and New Works
28
+ License](#changes-and-new-works-license).
29
+
30
+ ## Notices
31
+
32
+ You must ensure that anyone who gets a copy of any part of
33
+ the software from you also gets a copy of these terms or the
34
+ URL for them above, as well as copies of any plain-text lines
35
+ beginning with `Required Notice:` that the licensor provided
36
+ with the software. For example:
37
+
38
+ > Required Notice: Copyright Yoyodyne, Inc. (http://example.com)
39
+
40
+ ## Changes and New Works License
41
+
42
+ The licensor grants you an additional copyright license to
43
+ make changes and new works based on the software for any
44
+ permitted purpose.
45
+
46
+ ## Patent License
47
+
48
+ The licensor grants you a patent license for the software that
49
+ covers patent claims the licensor can license, or becomes able
50
+ to license, that you would infringe by using the software.
51
+
52
+ ## Noncommercial Purposes
53
+
54
+ Any noncommercial purpose is a permitted purpose.
55
+
56
+ ## Personal Uses
57
+
58
+ Personal use for research, experiment, and testing for
59
+ the benefit of public knowledge, personal study, private
60
+ entertainment, hobby projects, amateur pursuits, or religious
61
+ observance, without any anticipated commercial application,
62
+ is use for a permitted purpose.
63
+
64
+ ## Noncommercial Organizations
65
+
66
+ Use by any charitable organization, educational institution,
67
+ public research organization, public safety or health
68
+ organization, environmental protection organization,
69
+ or government institution is use for a permitted purpose
70
+ regardless of the source of funding or obligations resulting
71
+ from the funding.
72
+
73
+ ## Fair Use
74
+
75
+ You may have "fair use" rights for the software under the
76
+ law. These terms do not limit them.
77
+
78
+ ## No Other Rights
79
+
80
+ These terms do not allow you to sublicense or transfer any of
81
+ your licenses to anyone else, or prevent the licensor from
82
+ granting licenses to anyone else. These terms do not imply
83
+ any other licenses.
84
+
85
+ ## Patent Defense
86
+
87
+ If you make any written claim that the software infringes or
88
+ contributes to infringement of any patent, your patent license
89
+ for the software granted under these terms ends immediately. If
90
+ your company makes such a claim, your patent license ends
91
+ immediately for work on behalf of your company.
92
+
93
+ ## Violations
94
+
95
+ The first time you are notified in writing that you have
96
+ violated any of these terms, or done anything with the software
97
+ not covered by your licenses, your licenses can nonetheless
98
+ continue if you come into full compliance with these terms,
99
+ and take practical steps to correct past violations, within
100
+ 32 days of receiving notice. Otherwise, all your licenses
101
+ end immediately.
102
+
103
+ ## No Liability
104
+
105
+ ***As far as the law allows, the software comes as is, without
106
+ any warranty or condition, and the licensor will not be liable
107
+ to you for any damages arising out of these terms or the use
108
+ or nature of the software, under any kind of legal claim.***
109
+
110
+ ## Definitions
111
+
112
+ The **licensor** is the individual or entity offering these
113
+ terms, and the **software** is the software the licensor makes
114
+ available under these terms.
115
+
116
+ **You** refers to the individual or entity agreeing to these
117
+ terms.
118
+
119
+ **Your company** is any legal entity, sole proprietorship,
120
+ or other kind of organization that you work for, plus all
121
+ organizations that have control over, are under the control of,
122
+ or are under common control with that organization. **Control**
123
+ means ownership of substantially all the assets of an entity,
124
+ or the power to direct its management and policies by vote,
125
+ contract, or otherwise. Control can be direct or indirect.
126
+
127
+ **Your licenses** are all the licenses granted to you for the
128
+ software under these terms.
129
+
130
+ **Use** means anything you do with the software requiring one
131
+ of your licenses.
app.py ADDED
@@ -0,0 +1,1135 @@
1
+ """
2
+ Gradio + Plotly point cloud viewer for .xyz, .ply and .obj files with PI3DETR model integration.
3
+
4
+ Features:
5
+ - Upload .xyz (ASCII): one point per line: "x y z" (extra columns are ignored).
6
+ - Upload .ply: Standard PLY format point clouds.
7
+ - Upload .obj: OBJ format with vertices and faces (triangles).
8
+ - Interactive 3D view: orbit, pan, zoom with mouse.
9
+ - Optional: downsample for speed, normalize to unit cube, toggle axes, set point size.
10
+ - Unified view: input point cloud and model predictions in a single interactive plot.
11
+ - PI3DETR model integration for curve detection.
12
+ - Immediate point cloud rendering on upload.
13
+ """
14
+
15
+ import io
16
+ import os
17
+ from typing import List, Dict, Optional
18
+
19
+ import gradio as gr
20
+ import numpy as np
21
+ import plotly.graph_objects as go
22
+ from plyfile import PlyData
23
+ import pandas
24
+ import torch
25
+ from torch_geometric.data import Data
26
+ import fpsample
27
+ import trimesh # NEW: for robust mesh loading & surface sampling
28
+
29
+ # Import PI3DETR modules
30
+ from pi3detr import (
31
+ build_model,
32
+ build_model_config,
33
+ load_args,
34
+ load_weights,
35
+ )
36
+ from pi3detr.dataset import normalize_and_scale
37
+
38
+ # Global model cache
39
+ PI3DETR_MODEL = None
40
+ MODEL_STATUS = {"loaded": False, "message": "Model not loaded"}
41
+
42
+ HOVER_FONT_SIZE = 16 # enlarged hover text size
43
+ FIG_TEMPLATE = "plotly_white" # global figure template
44
+ PLOT_HEIGHT = 800 # NEW: desired plot height (adjust as needed)
45
+
46
+ # NEW: demo point cloud file paths (fill these with real .xyz/.ply paths)
47
+ DEMO_POINTCLOUDS = {
48
+ "Demo 1": "demo_inputs/demo1.xyz",
49
+ "Demo 2": "demo_inputs/demo2.xyz",
50
+ "Demo 3": "demo_inputs/demo3.xyz",
51
+ "Demo 4": "demo_inputs/demo4.xyz",
52
+ "Demo 5": "demo_inputs/demo5.xyz",
53
+ }
54
+
55
+
56
+ def initialize_model(
57
+ checkpoint_path="checkpoint.ckpt", config_path="configs/pi3detr.yaml"
58
+ ):
59
+ """Initialize the model at startup and store it in the global cache."""
60
+ global PI3DETR_MODEL, MODEL_STATUS
61
+ try:
62
+ args = load_args(config_path) if config_path else {}
63
+ model_config = build_model_config(args)
64
+ model = build_model(model_config)
65
+ load_weights(model, checkpoint_path)
66
+ model.eval()
67
+
68
+ PI3DETR_MODEL = model
69
+ MODEL_STATUS = {"loaded": True, "message": "Model loaded successfully"}
70
+ print("PI3DETR model initialized successfully")
71
+ return True
72
+ except Exception as e:
73
+ MODEL_STATUS = {"loaded": False, "message": f"Error loading model: {str(e)}"}
74
+ print(f"Error initializing PI3DETR model: {e}")
75
+ return False
76
+
77
+
78
+ def read_xyz(file_obj: io.BytesIO) -> np.ndarray:
79
+ """
80
+ Parse a .xyz text file from bytes and return Nx3 float32 array.
81
+ Lines with fewer than 3 numeric values are skipped.
82
+ Only the first three numeric columns are used.
83
+ """
84
+ if file_obj is None:
85
+ return np.zeros((0, 3), dtype=np.float32)
86
+
87
+ # Read bytes → text
88
+ raw = file_obj.read()
89
+ try:
90
+ text = raw.decode("utf-8", errors="ignore")
91
+ except Exception:
92
+ text = raw.decode("latin-1", errors="ignore")
93
+
94
+ pts = []
95
+ for line in text.splitlines():
96
+ line = line.strip()
97
+ if not line or line.startswith("#"):
98
+ continue
99
+ parts = line.replace(",", " ").split()
100
+ nums = []
101
+ for p in parts:
102
+ try:
103
+ nums.append(float(p))
104
+ except ValueError:
105
+ # skip non-numeric tokens
106
+ pass
107
+ if len(nums) == 3:
108
+ break
109
+ if len(nums) >= 3:
110
+ pts.append(nums[:3])
111
+
112
+ if not pts:
113
+ return np.zeros((0, 3), dtype=np.float32)
114
+
115
+ return np.asarray(pts, dtype=np.float32)
116
+
117
+
118
+ def read_ply(file_obj: io.BytesIO) -> np.ndarray:
119
+ """
120
+ Parse a .ply file from bytes and return Nx3 float32 array of points.
121
+ """
122
+ if file_obj is None:
123
+ return np.zeros((0, 3), dtype=np.float32)
124
+
125
+ try:
126
+ ply_data = PlyData.read(file_obj)
127
+ vertex = ply_data["vertex"]
128
+
129
+ x = np.asarray(vertex["x"])
130
+ y = np.asarray(vertex["y"])
131
+ z = np.asarray(vertex["z"])
132
+
133
+ points = np.column_stack([x, y, z]).astype(np.float32)
134
+ return points
135
+ except Exception as e:
136
+ print(f"Error reading PLY file: {e}")
137
+ return np.zeros((0, 3), dtype=np.float32)
138
+
139
+
140
+ def read_obj_and_sample(file_obj: io.BytesIO, display_max_points: int):
141
+ """Parse OBJ via trimesh and sample up to display_max_points uniformly over the surface."""
142
+ raw = file_obj.read()
143
+ # Rewind not strictly needed after read since we don't reuse file_obj
144
+ try:
145
+ mesh = trimesh.load(io.BytesIO(raw), file_type="obj", force="mesh")
146
+ except Exception as e:
147
+ print(f"trimesh load error: {e}")
148
+ return (
149
+ np.zeros((0, 3), dtype=np.float32),
150
+ np.zeros((0, 3), dtype=np.float32),
151
+ "OBJ load failure",
152
+ )
153
+ # Handle scenes by merging
154
+ if isinstance(mesh, trimesh.Scene):
155
+ mesh = trimesh.util.concatenate(tuple(g for g in mesh.geometry.values()))
156
+ if mesh.is_empty or mesh.vertices.shape[0] == 0:
157
+ return (
158
+ np.zeros((0, 3), dtype=np.float32),
159
+ np.zeros((0, 3), dtype=np.float32),
160
+ "OBJ: empty mesh",
161
+ )
162
+ sample_n = min(display_max_points, max(1, display_max_points))
163
+ try:
164
+ sampled = mesh.sample(sample_n)
165
+ except Exception as e:
166
+ print(f"Sampling error: {e}")
167
+ sampled = mesh.vertices
168
+ if sampled.shape[0] > sample_n:
169
+ sampled = sampled[:sample_n]
170
+ sampled = np.asarray(sampled, dtype=np.float32)
171
+ info = f"OBJ: {mesh.vertices.shape[0]} verts, {len(mesh.faces) if mesh.faces is not None else 0} tris | Surface sampled: {sampled.shape[0]} pts"
172
+ model_pts = sampled.copy()
173
+ return model_pts, sampled, info
174
+
175
+
176
+ def downsample(pts: np.ndarray, max_points: int) -> np.ndarray:
177
+ if pts.shape[0] <= max_points:
178
+ return pts
179
+ rng = np.random.default_rng(42)
180
+ idx = rng.choice(pts.shape[0], size=max_points, replace=False)
181
+ return pts[idx]
182
+
183
+
184
+ def make_figure(
185
+ pts: np.ndarray,
186
+ point_size: int = 2,
187
+ show_axes: bool = True,
188
+ title: str = "",
189
+ polylines: Optional[List[Dict]] = None,
190
+ ) -> go.Figure:
191
+ """
192
+ Build a Plotly 3D scatter figure with equal aspect ratio.
193
+ Optionally includes polylines from model predictions.
194
+ """
195
+ if pts.size == 0 and (polylines is None or len(polylines) == 0):
196
+ fig = go.Figure()
197
+ fig.update_layout(
198
+ title="No data to display",
199
+ template=FIG_TEMPLATE,
200
+ scene=dict(
201
+ xaxis_visible=False,
202
+ yaxis_visible=False,
203
+ zaxis_visible=False,
204
+ ),
205
+ margin=dict(l=0, r=0, t=40, b=0),
206
+ )
207
+ return fig
208
+
209
+ fig = go.Figure()
210
+
211
+ # Add point cloud if available
212
+ if pts.size > 0:
213
+ x, y, z = pts[:, 0], pts[:, 1], pts[:, 2]
214
+ fig.add_trace(
215
+ go.Scatter3d(
216
+ x=x,
217
+ y=y,
218
+ z=z,
219
+ mode="markers",
220
+ marker=dict(
221
+ size=max(1, int(point_size)), color="darkgray", opacity=0.2
222
+ ),
223
+ hoverinfo="skip",
224
+ name="Curves",
225
+ showlegend=False, # legend hidden
226
+ )
227
+ )
228
+
229
+ # Define colors for each curve type
230
+ curve_colors = {
231
+ "Line": "blue",
232
+ "Circle": "green",
233
+ "Arc": "red",
234
+ "BSpline": "purple",
235
+ }
236
+
237
+ # Add polylines from model predictions if available
238
+ if polylines:
239
+ for curve in polylines:
240
+ points = np.array(curve["points"])
241
+ if len(points) < 2:
242
+ continue
243
+
244
+ curve_type = curve["type"]
245
+ curve_id = curve["id"]
246
+ score = curve["score"]
247
+
248
+ # NEW: allow override color if provided (e.g., threshold filtered)
249
+ color = curve.get("display_color") or curve_colors.get(curve_type, "orange")
250
+
251
+ # NEW: support hidden-by-default via legendonly
252
+ fig.add_trace(
253
+ go.Scatter3d(
254
+ x=points[:, 0],
255
+ y=points[:, 1],
256
+ z=points[:, 2],
257
+ mode="lines",
258
+ line=dict(color=color, width=5),
259
+ name=f"{curve_type} #{curve_id} ({score:.2f})",
260
+ visible=curve.get("visible_state", True),
261
+ hoverinfo="text",
262
+ text=f"{curve_type} #{curve_id} ({score:.4f})",
263
+ showlegend=False, # hide individual curve legend entries
264
+ )
265
+ )
266
+
267
+ # Equal aspect ratio using data ranges
268
+ if pts.size > 0:
269
+ mins = pts.min(axis=0)
270
+ maxs = pts.max(axis=0)
271
+ elif polylines and len(polylines) > 0:
272
+ # If we only have polylines, calculate range from them
273
+ all_points = np.vstack([np.array(curve["points"]) for curve in polylines])
274
+ mins = all_points.min(axis=0)
275
+ maxs = all_points.max(axis=0)
276
+ else:
277
+ mins = np.array([-1, -1, -1])
278
+ maxs = np.array([1, 1, 1])
279
+
280
+ centers = (mins + maxs) / 2.0
281
+ span = (maxs - mins).max()
282
+ if span <= 0:
283
+ span = 1.0
284
+ half = span / 2.0
285
+ xrange = [centers[0] - half, centers[0] + half]
286
+ yrange = [centers[1] - half, centers[1] + half]
287
+ zrange = [centers[2] - half, centers[2] + half]
288
+
289
+ scene_axes = dict(
290
+ xaxis=dict(range=xrange, visible=show_axes, title="x" if show_axes else ""),
291
+ yaxis=dict(range=yrange, visible=show_axes, title="y" if show_axes else ""),
292
+ zaxis=dict(range=zrange, visible=show_axes, title="z" if show_axes else ""),
293
+ aspectmode="cube",
294
+ )
295
+
296
+ fig.update_layout(
297
+ title=title,
298
+ template=FIG_TEMPLATE,
299
+ showlegend=False,
300
+ scene=scene_axes,
301
+ margin=dict(l=0, r=0, t=40, b=0),
302
+ hoverlabel=dict(font=dict(size=HOVER_FONT_SIZE)),
303
+ height=PLOT_HEIGHT, # NEW
304
+ )
305
+ return fig
306
+
307
+
308
+ def process_model_predictions(data: Data) -> list:
309
+ """
310
+ Process model outputs into a format suitable for visualization.
311
+ """
312
+ class_names = ["None", "BSpline", "Line", "Circle", "Arc"]
313
+ polylines = data.polylines.cpu().numpy()
314
+ curves = []
315
+
316
+ # Process detected polylines
317
+ for i, polyline in enumerate(polylines):
318
+ cls = data.polyline_class[i].item()
319
+ score = data.polyline_score[i].item()
320
+ cls_name = class_names[cls]
321
+
322
+ # Skip low-confidence or "None" class predictions
323
+ if cls == 0:
324
+ continue
325
+
326
+ # Add curve data to results with unique ID
327
+ curve_data = {
328
+ "type": cls_name,
329
+ "id": i + 1, # 1-based ID for better user experience
330
+ "index": i,
331
+ "score": score,
332
+ "points": polyline,
333
+ }
334
+ curves.append(curve_data)
335
+
336
+ return curves
337
+
338
+
339
+ def process_data_for_model(
340
+ points: np.ndarray,
341
+ sample: int = 32768,
342
+ sample_mode: str = "fps",
343
+ ) -> Data: # CHANGED: removed reduction param
344
+ """
345
+ Process and subsample point cloud data using the same approach as predict_pi3detr.py.
346
+
347
+ Args:
348
+ points: Input point cloud as numpy array
349
+ sample: Number of points to sample
350
+ sample_mode: Sampling method ("fps", "random", "uniform", "all")
351
+
352
+ Returns:
353
+ Data object ready for model inference
354
+ """
355
+ # Convert to torch tensor
356
+ pos = torch.tensor(points, dtype=torch.float32)
357
+
358
+ # Apply sampling strategy
359
+ if sample_mode == "random":
360
+ if pos.size(0) > sample:
361
+ indices = torch.randperm(pos.size(0))[:sample]
362
+ pos = pos[indices]
363
+
364
+ elif sample_mode == "fps":
365
+ if pos.size(0) > sample:
366
+ indices = fpsample.bucket_fps_kdline_sampling(pos, sample, h=6)
367
+ pos = pos[indices]
368
+
369
+ elif sample_mode == "uniform":
370
+ if pos.size(0) > sample:
371
+ step = max(1, pos.size(0) // sample)
372
+ pos = pos[::step][:sample]
373
+
374
+ elif sample_mode == "all":
375
+ pass # Keep all points
376
+
377
+ # Create Data object
378
+ data = Data(pos=pos)
379
+
380
+ # Add batch information for single point cloud BEFORE normalization
381
+ data.batch = torch.zeros(data.pos.size(0), dtype=torch.long)
382
+ data.batch_size = 1
383
+
384
+ # Normalize and scale using PI3DETR's method
385
+ data = normalize_and_scale(data)
386
+
387
+ # Ensure scale and center are proper batch tensors
388
+ if hasattr(data, "scale") and data.scale.dim() == 0:
389
+ data.scale = data.scale.unsqueeze(0)
390
+ if hasattr(data, "center") and data.center.dim() == 1:
391
+ data.center = data.center.unsqueeze(0)
392
+
393
+ return data
394
+
395
+
396
+ @torch.no_grad()
397
+ def run_model_inference(
398
+ model,
399
+ points: np.ndarray,
400
+ max_points: int = 32768,
401
+ sample_mode: str = "fps",
402
+ num_queries: int = 256, # NEW
403
+ snap_and_fit: bool = False, # NEW
404
+ iou_filter: bool = False, # NEW
405
+ ) -> list:
406
+ """Run model inference on the given point cloud (extended with num_queries, snap_and_fit, iou_filter)."""
407
+ global PI3DETR_MODEL
408
+ if model is None:
409
+ model = PI3DETR_MODEL
410
+ if model is None:
411
+ return []
412
+ try:
413
+ data = process_data_for_model(
414
+ points, sample=max_points, sample_mode=sample_mode
415
+ )
416
+ device = next(model.parameters()).device
417
+ data = data.to(device)
418
+
419
+ if model.num_preds != num_queries:
420
+ model.set_num_preds(num_queries)
421
+
422
+ output = model.predict_step(
423
+ data,
424
+ reverse_norm=True,
425
+ thresholds=None,
426
+ snap_and_fit=snap_and_fit, # CHANGED
427
+ iou_filter=iou_filter, # CHANGED
428
+ )
429
+ result = output[0]
430
+ curves = process_model_predictions(result)
431
+ return curves
432
+ except Exception as e:
433
+ print(f"Error in model inference: {e}")
434
+ return []
435
+
436
+
437
+ def load_and_process_pointcloud(
438
+ file: gr.File,
439
+ max_points: int,
440
+ point_size: int,
441
+ show_axes: bool,
442
+ ):
443
+ """
444
+ Load and process a point cloud from .xyz or .ply file
445
+ """
446
+ if file is None:
447
+ empty_fig = make_figure(np.zeros((0, 3)))
448
+ return empty_fig, None, None, os.path.basename(file.name) if file else ""
449
+
450
+ # Determine file type and read accordingly
451
+ file_ext = os.path.splitext(file.name)[1].lower()
452
+
453
+ # Read file based on extension
454
+ with open(file.name, "rb") as f:
455
+ if file_ext == ".xyz":
456
+ pts = read_xyz(f)
457
+ mode = "XYZ"
458
+ elif file_ext == ".ply":
459
+ pts = read_ply(f)
460
+ mode = "PLY"
461
+ elif file_ext == ".obj":
462
+ model_pts, display_pts, _ = read_obj_and_sample(f, max_points)
463
+ fig = make_figure(
464
+ display_pts,
465
+ point_size=point_size,
466
+ show_axes=show_axes,
467
+ title=f"{os.path.basename(file.name)}",
468
+ )
469
+ return fig, model_pts, display_pts, os.path.basename(file.name)
470
+ else:
471
+ empty_fig = make_figure(np.zeros((0, 3)))
472
+ return (
473
+ empty_fig,
474
+ None,
475
+ None,
476
+ "Unsupported file type. Please use .xyz, .ply or .obj.",
477
+ "",
478
+ )
479
+
480
+ original_n = pts.shape[0]
481
+
482
+ # Keep original points for model if normalizing for display
483
+ model_pts = pts.copy()
484
+
485
+ pts = downsample(pts, max_points=max_points)
486
+ displayed_n = pts.shape[0]
487
+
488
+ fig = make_figure(
489
+ pts,
490
+ point_size=point_size,
491
+ show_axes=show_axes,
492
+ title=f"{os.path.basename(file.name)}",
493
+ )
494
+
495
+ info = f"Loaded ({mode}): {original_n} points" # | Displayed: {displayed_n} points"
496
+
497
+ # RETURN single figure + model/full points + displayed subset
498
+ return fig, model_pts, pts, os.path.basename(file.name) # ADDED filename
499
+
500
+
501
+ def run_model_prediction(
502
+ model_pts: np.ndarray,
503
+ point_size: int,
504
+ show_axes: bool,
505
+ model_max_points: int,
506
+ sample_mode: str,
507
+ th_bspline: float,
508
+ th_line: float,
509
+ th_circle: float,
510
+ th_arc: float,
511
+ num_queries: int = 256,
512
+ snap_and_fit: bool = False,
513
+ iou_filter: bool = False,
514
+ ): # CHANGED: removed reduction
515
+ # NOTE: display points now handled outside; keep signature (called before adding display pts state)
516
+ # (This wrapper kept for backwards compatibility if needed – we adapt below in new unified version)
517
+ return run_model_prediction_unified( # type: ignore
518
+ model_pts,
519
+ None,
520
+ point_size,
521
+ show_axes,
522
+ model_max_points,
523
+ sample_mode,
524
+ th_bspline,
525
+ th_line,
526
+ th_circle,
527
+ th_arc,
528
+ "",
529
+ num_queries,
530
+ snap_and_fit,
531
+ iou_filter,
532
+ )
533
+
534
+
535
+ def run_model_prediction_unified(
536
+ model_pts: np.ndarray,
537
+ display_pts: Optional[np.ndarray],
538
+ point_size: int,
539
+ show_axes: bool,
540
+ model_max_points: int,
541
+ sample_mode: str,
542
+ th_bspline: float,
543
+ th_line: float,
544
+ th_circle: float,
545
+ th_arc: float,
546
+ file_name: str = "",
547
+ num_queries: int = 256,
548
+ snap_and_fit: bool = False,
549
+ iou_filter: bool = False,
550
+ ):
551
+ """
552
+ Run model inference and apply initial threshold-based coloring.
553
+ """
554
+ global PI3DETR_MODEL, MODEL_STATUS
555
+ if model_pts is None:
556
+ empty_fig = make_figure(np.zeros((0, 3)))
557
+ return empty_fig, []
558
+
559
+ # Run model inference using cached model
560
+ curves = []
561
+ try:
562
+ if PI3DETR_MODEL is None and not MODEL_STATUS["loaded"]:
563
+ # Try to initialize model if not already loaded
564
+ initialize_model()
565
+
566
+ if PI3DETR_MODEL is not None:
567
+ # Run inference with the same settings as predict_pi3detr.py
568
+ curves = run_model_inference(
569
+ PI3DETR_MODEL,
570
+ model_pts,
571
+ max_points=model_max_points,
572
+ sample_mode=sample_mode,
573
+ num_queries=num_queries, # NEW
574
+ snap_and_fit=snap_and_fit, # NEW
575
+ iou_filter=iou_filter, # NEW
576
+ )
577
+ except Exception:
578
+ pass
579
+
580
+ # NEW: apply thresholds for display (store raw curves separately)
581
+ thresholds = {
582
+ "BSpline": th_bspline,
583
+ "Line": th_line,
584
+ "Circle": th_circle,
585
+ "Arc": th_arc,
586
+ }
587
+ colored_curves = []
588
+ for c in curves:
589
+ c_disp = dict(c)
590
+ if c["score"] < thresholds.get(c["type"], 0.7):
591
+ c_disp["visible_state"] = "legendonly"
592
+ colored_curves.append(c_disp)
593
+
594
+ # Use existing displayed subset if provided; else derive lightweight subset
595
+ if display_pts is None:
596
+ display_pts = downsample(model_pts, max_points=100000)
597
+ title = f"{file_name} (curves)" if curves else f"{file_name} (no curves)"
598
+ fig = make_figure(
599
+ display_pts,
600
+ point_size=point_size,
601
+ show_axes=show_axes,
602
+ title=title,
603
+ polylines=colored_curves,
604
+ )
605
+ return fig, curves
606
+
607
+
608
+ def apply_pointcloud_display_settings(
609
+ model_pts: np.ndarray,
610
+ curves: List[Dict],
611
+ max_points: int,
612
+ point_size: int,
613
+ show_axes: bool,
614
+ th_bspline: float,
615
+ th_line: float,
616
+ th_circle: float,
617
+ th_arc: float,
618
+ file_name: str,
619
+ ):
620
+ """
621
+ Apply point cloud display settings without re-running inference.
622
+ Keeps existing detections and re-applies thresholds.
623
+ """
624
+ if model_pts is None:
625
+ empty_fig = make_figure(np.zeros((0, 3)))
626
+ return empty_fig, None
627
+ display_pts = downsample(model_pts, max_points=max_points)
628
+ if not curves:
629
+ fig = make_figure(
630
+ display_pts,
631
+ point_size=point_size,
632
+ show_axes=show_axes,
633
+ title=file_name or "Point Cloud",
634
+ )
635
+ return fig, display_pts
636
+ thresholds = {
637
+ "BSpline": th_bspline,
638
+ "Line": th_line,
639
+ "Circle": th_circle,
640
+ "Arc": th_arc,
641
+ }
642
+ colored_curves = []
643
+ for c in curves:
644
+ c_disp = dict(c)
645
+ if c["score"] < thresholds.get(c["type"], 0.7):
646
+ c_disp["visible_state"] = "legendonly"
647
+ colored_curves.append(c_disp)
648
+ fig = make_figure(
649
+ display_pts,
650
+ point_size=point_size,
651
+ show_axes=show_axes,
652
+ title=(file_name or "Point Cloud") + " (curves)",
653
+ polylines=colored_curves,
654
+ )
655
+ return fig, display_pts
656
+
657
+
658
+ def clear_curves(
659
+ curves: List[Dict],
660
+ display_pts: Optional[np.ndarray],
661
+ model_pts: Optional[np.ndarray],
662
+ point_size: int,
663
+ show_axes: bool,
664
+ file_name: str,
665
+ ):
666
+ """
667
+ Recolor already inferred curves based on updated thresholds (no re-inference).
668
+ """
669
+ if curves is None or model_pts is None or len(curves) == 0:
670
+ empty_fig = make_figure(
671
+ display_pts if display_pts is not None else np.zeros((0, 3))
672
+ )
673
+ return empty_fig, None
674
+
675
+ fig = make_figure(
676
+ display_pts if display_pts is not None else np.zeros((0, 3)),
677
+ point_size=point_size,
678
+ show_axes=show_axes,
679
+ title=file_name or "Point Cloud",
680
+ polylines=None,
681
+ )
682
+ return fig, None
683
+
684
+
685
+ def load_demo_pointcloud(
686
+ label: str,
687
+ max_points: int,
688
+ point_size: int,
689
+ show_axes: bool,
690
+ ):
691
+ """
692
+ Load one of the predefined demo point clouds.
693
+ Clears existing detected curves (curves_state -> None).
694
+ Also returns a value for the file upload component so the filename shows up.
695
+ """
696
+ path = DEMO_POINTCLOUDS.get(label, "")
697
+ if not path or not os.path.isfile(path):
698
+ empty_fig = make_figure(np.zeros((0, 3)))
699
+ return empty_fig, None, None, "", None, None
700
+ ext = os.path.splitext(path)[1].lower()
701
+ try:
702
+ with open(path, "rb") as f:
703
+ if ext == ".xyz":
704
+ pts = read_xyz(f)
705
+ elif ext == ".ply":
706
+ pts = read_ply(f)
707
+ elif ext == ".obj":
708
+ model_pts, display_pts, _ = read_obj_and_sample(
709
+ f, min(20000, max_points)
710
+ )
711
+ fig = make_figure(
712
+ display_pts,
713
+ point_size=1,
714
+ show_axes=show_axes,
715
+ title=f"{os.path.basename(path)} (demo)",
716
+ )
717
+ return fig, model_pts, display_pts, os.path.basename(path), None, path
718
+ else:
719
+ empty_fig = make_figure(np.zeros((0, 3)))
720
+ return empty_fig, None, None, "", None, None
721
+ except Exception:
722
+ empty_fig = make_figure(np.zeros((0, 3)))
723
+ return empty_fig, None, None, "", None, None
724
+ model_pts = pts.copy()
725
+ pts = downsample(pts, max_points=max_points)
726
+ fig = make_figure(
727
+ pts,
728
+ point_size=1,
729
+ show_axes=show_axes,
730
+ title=f"{os.path.basename(path)} (demo)",
731
+ )
732
+ return fig, model_pts, pts, os.path.basename(path), None, path
733
+
734
+
735
+ # Convenience wrappers for each demo (avoid lambdas for clarity)
736
+ def load_demo1(max_points, point_size, show_axes):
737
+ return load_demo_pointcloud("Demo 1", max_points, point_size, show_axes)
738
+
739
+
740
+ def load_demo2(max_points, point_size, show_axes):
741
+ return load_demo_pointcloud("Demo 2", max_points, point_size, show_axes)
742
+
743
+
744
+ def load_demo3(max_points, point_size, show_axes):
745
+ return load_demo_pointcloud("Demo 3", max_points, point_size, show_axes)
746
+
747
+
748
+ def load_demo4(max_points, point_size, show_axes): # NEW
749
+ return load_demo_pointcloud("Demo 4", max_points, point_size, show_axes)
750
+
751
+
752
+ def load_demo5(max_points, point_size, show_axes): # NEW
753
+ return load_demo_pointcloud("Demo 5", max_points, point_size, show_axes)
754
+
755
+
756
+ def build_demo_preview(label: str, max_pts: int = 20000) -> go.Figure:
757
+ """Create a small preview figure for a demo point cloud (no curves)."""
758
+ path = DEMO_POINTCLOUDS.get(label, "")
759
+ if not path or not os.path.isfile(path):
760
+ return make_figure(np.zeros((0, 3)), title=f"{label}: (missing)")
761
+ try:
762
+ ext = os.path.splitext(path)[1].lower()
763
+ with open(path, "rb") as f:
764
+ if ext == ".xyz":
765
+ pts = read_xyz(f)
766
+ elif ext == ".ply":
767
+ pts = read_ply(f)
768
+ elif ext == ".obj": # UPDATED
769
+ _, pts, _ = read_obj_and_sample(f, max_pts)
770
+ else:
771
+ return make_figure(np.zeros((0, 3)), title=f"{label}: (unsupported)")
772
+ pts = downsample(pts, max_pts)
773
+ return make_figure(pts, point_size=1, show_axes=False, title=f"{label} preview")
774
+ except Exception as e:
775
+ return make_figure(np.zeros((0, 3)), title=f"{label}: error")
776
+
777
+
778
+ def run_model_with_display(
779
+ model_pts: np.ndarray,
780
+ max_points: int,
781
+ point_size: int,
782
+ show_axes: bool,
783
+ model_max_points: int,
784
+ sample_mode: str,
785
+ th_bspline: float,
786
+ th_line: float,
787
+ th_circle: float,
788
+ th_arc: float,
789
+ file_name: str = "",
790
+ num_queries: int = 256,
791
+ snap_and_fit: bool = False,
792
+ iou_filter: bool = False,
793
+ ): # CHANGED: removed reduction
794
+ """
795
+ Run inference (if model_pts present) then immediately apply current display
796
+ (max_points/point_size/show_axes) and thresholds. Returns:
797
+ figure, curves (list), display_pts
798
+ """
799
+ if model_pts is None:
800
+ empty = make_figure(np.zeros((0, 3)))
801
+ return empty, None, None
802
+
803
+ # Inference first (no display subset passed so it builds from model_pts)
804
+ fig_infer, curves = run_model_prediction_unified(
805
+ model_pts,
806
+ None,
807
+ point_size,
808
+ show_axes,
809
+ model_max_points,
810
+ sample_mode,
811
+ th_bspline,
812
+ th_line,
813
+ th_circle,
814
+ th_arc,
815
+ file_name,
816
+ num_queries, # NEW
817
+ snap_and_fit, # NEW
818
+ iou_filter, # NEW
819
+ ) # CHANGED: removed reduction
820
+
821
+ # Now apply current display settings & thresholds without re-inference
822
+ fig_final, display_pts = apply_pointcloud_display_settings(
823
+ model_pts,
824
+ curves,
825
+ max_points,
826
+ point_size,
827
+ show_axes,
828
+ th_bspline,
829
+ th_line,
830
+ th_circle,
831
+ th_arc,
832
+ file_name,
833
+ )
834
+ return fig_final, curves, display_pts
835
+
836
+
837
+ with gr.Blocks(title="PI3DETR") as demo:
838
+ gr.Markdown(
839
+ "# 🥧 PI3DETR: 3D Parametric Curve Inference [CPU-PREVIEW]\n"
840
+ "An end-to-end deep learning model for **parametric curve inference** in **3D point clouds** and **meshes**.\n"
841
+ "Upload a `.xyz`, `.ply`, or `.obj` file to explore curve detection."
842
+ )
843
+
844
+ with gr.Row():
845
+ with gr.Column():
846
+ gr.Markdown(
847
+ "### 🧩 Supported Inputs\n"
848
+ "- **Point Clouds:** `.xyz`, `.ply`; **Meshes:** `.obj`\n"
849
+ "- `Mesh` is surface-sampled using **Max Points (display)** slider."
850
+ )
851
+ with gr.Column():
852
+ gr.Markdown(
853
+ "### ⚙️ Point Cloud Settings\n"
854
+ "- Adjust **Max Points**, **point size**, and **axes visibility**.\n"
855
+ "- Controls visualization of point cloud."
856
+ )
857
+ with gr.Column():
858
+ gr.Markdown(
859
+ "### 🎯 Confidence Thresholds\n"
860
+ "- Hover to inspect scores\n."
861
+ "- Filter curves by **class confidence** interactively"
862
+ )
863
+ with gr.Row():
864
+ with gr.Column():
865
+ gr.Markdown(
866
+ "### 🧠 Model Settings\n"
867
+ "- **Sampling Mode:** Choose downsampling strategy.\n"
868
+ "- **Model Input Size:** Number of model input points.\n"
869
+ "- **Queries:** Transformer decoder queries (max. output curves).\n"
870
+ "- Optional: *Snap&Fit* / *IOU-Filter* post-processing."
871
+ )
872
+ with gr.Column():
873
+ gr.Markdown(
874
+ "### ⚡ Performance Notes\n"
875
+ "- Trained on **human-made objects**.\n"
876
+ "- Optimized for **GPU**; this demo runs on **CPU**.\n"
877
+ "- For full qualitative performance: \n"
878
+ "[GitHub → PI3DETR](https://github.com/fafraob/pi3detr)"
879
+ )
880
+ with gr.Column():
881
+ gr.Markdown(
882
+ "### ▶️ Run Inference\n"
883
+ "- Click on demo point clouds (from test set) below.\n"
884
+ "- Press **Run PI3DETR** to execute inference and visualize results."
885
+ )
886
+
887
+ model_pts_state = gr.State(None)
888
+ display_pts_state = gr.State(None)
889
+ curves_state = gr.State(None)
890
+ file_name_state = gr.State("demo_inputs/demo2.xyz")
891
+ with gr.Row():
892
+ file_in = gr.File(
893
+ label="Upload Point Cloud (auto-renders)",
894
+ file_types=[".xyz", ".ply", ".obj"],
895
+ type="filepath",
896
+ )
897
+ with gr.Row():
898
+ with gr.Column(scale=1):
899
+ gr.Markdown("### Point Cloud Settings")
900
+ max_points = gr.Slider(
901
+ 0,
902
+ 500_000,
903
+ value=200_000,
904
+ step=1_000,
905
+ label="Max points (display)",
906
+ )
907
+ point_size = gr.Slider(1, 8, value=1, step=1, label="Point size")
908
+ show_axes = gr.Checkbox(value=False, label="Show axes")
909
+
910
+ gr.Markdown("### Model Settings")
911
+ sample_mode = gr.Radio(
912
+ ["fps", "random", "all"],
913
+ value="fps",
914
+ label="Main Sampling Method",
915
+ )
916
+ model_max_points = gr.Slider(
917
+ 1_000,
918
+ 100_000,
919
+ value=32768,
920
+ step=500,
921
+ label="Downsample to Model Input Size",
922
+ )
923
+ num_queries = gr.Slider( # NEW
924
+ 32,
925
+ 512,
926
+ value=128,
927
+ step=1,
928
+ label="Number of Queries",
929
+ )
930
+ with gr.Row():
931
+ snap_and_fit_chk = gr.Checkbox(value=True, label="Snap&Fit")
932
+ iou_filter_chk = gr.Checkbox(value=False, label="IOU-Filter")
933
+
934
+ # Threshold sliders (no auto-change triggers)
935
+ gr.Markdown("#### Confidence Thresholds (per class)")
936
+ th_bspline = gr.Slider(0.0, 1.0, value=0.7, step=0.01, label="BSpline ≥")
937
+ th_line = gr.Slider(0.0, 1.0, value=0.7, step=0.01, label="Line ≥")
938
+ th_circle = gr.Slider(0.0, 1.0, value=0.7, step=0.01, label="Circle ≥")
939
+ th_arc = gr.Slider(0.0, 1.0, value=0.7, step=0.01, label="Arc ≥")
940
+
941
+ with gr.Column(scale=1):
942
+ main_plot = gr.Plot(
943
+ label="Point Cloud & Curves"
944
+ ) # height from fig.update_layout(PLOT_HEIGHT)
945
+
946
+ run_model_button = gr.Button("Run PI3DETR", variant="primary")
947
+ clear_curves_button = gr.Button("Clear Curves", variant="secondary")
948
+
949
+ # Auto-render point cloud when file is uploaded
950
+ file_in.change(
951
+ load_and_process_pointcloud,
952
+ inputs=[file_in, max_points, point_size, show_axes],
953
+ outputs=[
954
+ main_plot,
955
+ model_pts_state,
956
+ display_pts_state,
957
+ file_name_state,
958
+ ],
959
+ )
960
+
961
+ run_model_button.click(
962
+ run_model_with_display,
963
+ inputs=[
964
+ model_pts_state,
965
+ max_points,
966
+ point_size,
967
+ show_axes,
968
+ model_max_points,
969
+ sample_mode,
970
+ th_bspline,
971
+ th_line,
972
+ th_circle,
973
+ th_arc,
974
+ file_name_state,
975
+ num_queries,
976
+ snap_and_fit_chk,
977
+ iou_filter_chk,
978
+ ],
979
+ outputs=[main_plot, curves_state, display_pts_state],
980
+ )
981
+
982
+ # NEW: auto-apply display & thresholds on interaction (no inference)
983
+ def _apply_display_wrapper(
984
+ model_pts,
985
+ curves,
986
+ max_points,
987
+ point_size,
988
+ show_axes,
989
+ th_bspline,
990
+ th_line,
991
+ th_circle,
992
+ th_arc,
993
+ file_name,
994
+ display_pts_state_value,
995
+ ):
996
+ fig, display_pts = apply_pointcloud_display_settings(
997
+ model_pts,
998
+ curves,
999
+ max_points,
1000
+ point_size,
1001
+ show_axes,
1002
+ th_bspline,
1003
+ th_line,
1004
+ th_circle,
1005
+ th_arc,
1006
+ file_name,
1007
+ )
1008
+ return fig, display_pts
1009
+
1010
+ # Point cloud sliders (release) & checkbox (change)
1011
+ for slider in [max_points, point_size]:
1012
+ slider.release(
1013
+ _apply_display_wrapper,
1014
+ inputs=[
1015
+ model_pts_state,
1016
+ curves_state,
1017
+ max_points,
1018
+ point_size,
1019
+ show_axes,
1020
+ th_bspline,
1021
+ th_line,
1022
+ th_circle,
1023
+ th_arc,
1024
+ file_name_state,
1025
+ display_pts_state,
1026
+ ],
1027
+ outputs=[main_plot, display_pts_state],
1028
+ )
1029
+
1030
+ show_axes.change(
1031
+ _apply_display_wrapper,
1032
+ inputs=[
1033
+ model_pts_state,
1034
+ curves_state,
1035
+ max_points,
1036
+ point_size,
1037
+ show_axes,
1038
+ th_bspline,
1039
+ th_line,
1040
+ th_circle,
1041
+ th_arc,
1042
+ file_name_state,
1043
+ display_pts_state,
1044
+ ],
1045
+ outputs=[main_plot, display_pts_state],
1046
+ )
1047
+
1048
+ # Threshold sliders (apply on release)
1049
+ for th in [th_bspline, th_line, th_circle, th_arc]:
1050
+ th.release(
1051
+ _apply_display_wrapper,
1052
+ inputs=[
1053
+ model_pts_state,
1054
+ curves_state,
1055
+ max_points,
1056
+ point_size,
1057
+ show_axes,
1058
+ th_bspline,
1059
+ th_line,
1060
+ th_circle,
1061
+ th_arc,
1062
+ file_name_state,
1063
+ display_pts_state,
1064
+ ],
1065
+ outputs=[main_plot, display_pts_state],
1066
+ )
1067
+
1068
+ clear_curves_button.click(
1069
+ clear_curves,
1070
+ inputs=[
1071
+ curves_state,
1072
+ display_pts_state,
1073
+ model_pts_state,
1074
+ point_size,
1075
+ show_axes,
1076
+ file_name_state,
1077
+ ],
1078
+ outputs=[main_plot, curves_state],
1079
+ )
1080
+
1081
+ # REPLACED demo preview plots + buttons WITH clickable images
1082
+ with gr.Row():
1083
+ gr.Markdown("### Demo Point Clouds (click an image to load)")
1084
+ with gr.Row():
1085
+ # CLEANUP: generate images dynamically for all demos
1086
+ demo_image_components = {}
1087
+ for label in ["Demo 1", "Demo 2", "Demo 3", "Demo 4", "Demo 5"]: # UPDATED
1088
+ png_path = f"demo_inputs/{label.lower().replace(' ', '')}.png"
1089
+ demo_image_components[label] = gr.Image(
1090
+ value=png_path if os.path.isfile(png_path) else None,
1091
+ label=label,
1092
+ interactive=False,
1093
+ )
1094
+
1095
+ # CLEANUP: map labels to loader functions & attach select handlers
1096
+ _demo_loaders = {
1097
+ "Demo 1": load_demo1,
1098
+ "Demo 2": load_demo2,
1099
+ "Demo 3": load_demo3,
1100
+ "Demo 4": load_demo4,
1101
+ "Demo 5": load_demo5, # NEW
1102
+ }
1103
+ for label, comp in demo_image_components.items():
1104
+ comp.select(
1105
+ _demo_loaders[label],
1106
+ inputs=[max_points, point_size, show_axes],
1107
+ outputs=[
1108
+ main_plot,
1109
+ model_pts_state,
1110
+ display_pts_state,
1111
+ file_name_state,
1112
+ curves_state,
1113
+ file_in,
1114
+ ],
1115
+ )
1116
+
1117
+ # NEW: auto-load Demo 2 on app start
1118
+ demo.load(
1119
+ load_demo2,
1120
+ inputs=[max_points, point_size, show_axes],
1121
+ outputs=[
1122
+ main_plot,
1123
+ model_pts_state,
1124
+ display_pts_state,
1125
+ file_name_state,
1126
+ curves_state,
1127
+ file_in,
1128
+ ],
1129
+ )
1130
+
1131
+
1132
+ if __name__ == "__main__":
1133
+ # Initialize model at startup
1134
+ initialize_model()
1135
+ demo.launch()
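
The helpers defined in `app.py` can also be exercised without launching the Gradio UI. A minimal sketch, assuming it is run from the repository root so that `checkpoint.ckpt`, `configs/pi3detr.yaml`, and the demo inputs referenced above are on disk:

```python
# Hedged sketch: drive app.py's helpers headlessly (not an entry point of this commit).
# Assumes checkpoint.ckpt, configs/pi3detr.yaml and demo_inputs/ exist as referenced above.
import app

app.initialize_model()  # fills the global PI3DETR_MODEL cache

with open("demo_inputs/demo1.xyz", "rb") as f:
    pts = app.read_xyz(f)  # Nx3 float32 array

# model=None falls back to the cached global model, mirroring the UI path.
curves = app.run_model_inference(None, pts, max_points=32768, sample_mode="fps")
for c in curves:
    print(c["type"], c["id"], round(c["score"], 3), len(c["points"]))
```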
configs/pi3detr.yaml ADDED
@@ -0,0 +1,77 @@
+ ### Training parameters
+ epochs: 1715
+ lr_step: 1250
+ lr_warmup_epochs: 15
+ lr_warmup_start_factor: 1.0e-6
+ lr: 1.0e-4
+ batch_size: 8
+ batch_size_val: 8
+ accumulate_grad_batches: 16
+ grad_clip_val: 0.2 # max gradient norm
+ to_monitor: "val_seg_iou"
+ monitor_mode: "max"
+ val_interval: 1
+
+ ### Model parameters
+ model: "pi3detr"
+ preencoder_type: "samodule"
+ num_features: 0
+ weights: ""
+ preencoder_lr: 1.0e-4
+ freeze_backbone: false
+ encoder_dim: 768
+ decoder_dim: 768
+ num_encoder_layers: 3
+ num_decoder_layers: 9
+ encoder_dropout: 0.1 # dropout in encoder
+ decoder_dropout: 0.1 # dropout in decoder
+ num_attn_heads: 8 # number of attention heads
+ enc_dim_feedforward: 2048 # dimension of feedforward in encoder
+ dec_dim_feedforward: 2048 # dimension of feedforward in decoder
+ mlp_dropout: 0.0 # dropout in MLP heads
+ num_preds: 128 # num outputs of transformer
+ num_classes: 5
+ auxiliary_loss: true
+ max_points_in_param: 4
+ num_transformer_points: 2048 # number of transformer points (needed for some preencoders)
+ query_type: "point_fps"
+ pos_embed_type: "sine" # Options: "fourier", "sine"
+ class_loss_type: "cross_entropy"
+ class_loss_weights: [0.04834912, 0.40329467, 0.09588135, 0.23071379, 0.22176106]
+
+ ### Curve and validation parameters
+ num_curve_points: 64 # must be same as points_per_curve
+ num_curve_points_val: 256 # validation curve points
+
+ ### Loss weights
+ loss_weights:
+   loss_class: 1
+   loss_bspline: 1
+   loss_bspline_chamfer: 1
+   loss_line_position: 1
+   loss_line_length: 1
+   loss_line_chamfer: 1
+   loss_circle_position: 1
+   loss_circle_radius: 1
+   loss_circle_chamfer: 1
+   loss_arc: 1
+   loss_arc_chamfer: 1
+   loss_seg: 1
+
+ ### Cost weights
+ cost_weights:
+   cost_class: 1
+   cost_curve: 1
+
+ ### Dataset parameters
+ dataset: "abc_dataset"
+ num_workers: 8
+ data_root: "/dataset/train"
+ data_val_root: "/dataset/val"
+ data_test_root: "/dataset/test"
+ augment: true
+ random_rotate_prob: 1.0
+ random_sample_prob: 0.85
+ random_sample_bounds: [1.0, 0.2] # [max, min] fraction of points to keep
+ noise_prob: 0.0
+ noise_scale: 0.0
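
This file is consumed through `load_args` / `build_model_config` (see `pi3detr/__init__.py` later in this commit), but it is also plain YAML. A quick sanity-check sketch, assuming PyYAML is installed:

```python
# Minimal sketch (assumes PyYAML): inspect a few hyperparameters from the config above.
import yaml

with open("configs/pi3detr.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["num_preds"])  # 128 (the k256 variant below uses 256)
print(cfg["encoder_dim"], cfg["num_encoder_layers"], cfg["num_decoder_layers"])  # 768 3 9
print(cfg["loss_weights"]["loss_seg"], cfg["cost_weights"]["cost_curve"])  # 1 1
```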
configs/pi3detr_k256.yaml ADDED
@@ -0,0 +1,77 @@
+ ### Training parameters
+ epochs: 1715
+ lr_step: 1250
+ lr_warmup_epochs: 15
+ lr_warmup_start_factor: 1.0e-6
+ lr: 1.0e-4
+ batch_size: 8
+ batch_size_val: 8
+ accumulate_grad_batches: 16
+ grad_clip_val: 0.2 # max gradient norm
+ to_monitor: "val_seg_iou"
+ monitor_mode: "max"
+ val_interval: 1
+
+ ### Model parameters
+ model: "pi3detr"
+ preencoder_type: "samodule"
+ num_features: 0
+ weights: ""
+ preencoder_lr: 1.0e-4
+ freeze_backbone: false
+ encoder_dim: 768
+ decoder_dim: 768
+ num_encoder_layers: 3
+ num_decoder_layers: 9
+ encoder_dropout: 0.1 # dropout in encoder
+ decoder_dropout: 0.1 # dropout in decoder
+ num_attn_heads: 8 # number of attention heads
+ enc_dim_feedforward: 2048 # dimension of feedforward in encoder
+ dec_dim_feedforward: 2048 # dimension of feedforward in decoder
+ mlp_dropout: 0.0 # dropout in MLP heads
+ num_preds: 256 # num outputs of transformer
+ num_classes: 5
+ auxiliary_loss: true
+ max_points_in_param: 4
+ num_transformer_points: 2048 # number of transformer points (needed for some preencoders)
+ query_type: "point_fps"
+ pos_embed_type: "sine" # Options: "fourier", "sine"
+ class_loss_type: "cross_entropy"
+ class_loss_weights: [0.04834912, 0.40329467, 0.09588135, 0.23071379, 0.22176106]
+
+ ### Curve and validation parameters
+ num_curve_points: 64 # must be same as points_per_curve
+ num_curve_points_val: 256 # validation curve points
+
+ ### Loss weights
+ loss_weights:
+   loss_class: 1
+   loss_bspline: 1
+   loss_bspline_chamfer: 1
+   loss_line_position: 1
+   loss_line_length: 1
+   loss_line_chamfer: 1
+   loss_circle_position: 1
+   loss_circle_radius: 1
+   loss_circle_chamfer: 1
+   loss_arc: 1
+   loss_arc_chamfer: 1
+   loss_seg: 1
+
+ ### Cost weights
+ cost_weights:
+   cost_class: 1
+   cost_curve: 1
+
+ ### Dataset parameters
+ dataset: "abc_dataset"
+ num_workers: 8
+ data_root: "/dataset/train"
+ data_val_root: "/dataset/val"
+ data_test_root: "/dataset/test"
+ augment: true
+ random_rotate_prob: 1.0
+ random_sample_prob: 0.85
+ random_sample_bounds: [1.0, 0.2] # [max, min] fraction of points to keep
+ noise_prob: 0.0
+ noise_scale: 0.0
demo_inputs/demo1.png ADDED
demo_inputs/demo1.xyz ADDED
The diff for this file is too large to render. See raw diff
 
demo_inputs/demo2.png ADDED
demo_inputs/demo2.xyz ADDED
The diff for this file is too large to render. See raw diff
 
demo_inputs/demo3.png ADDED
demo_inputs/demo3.xyz ADDED
The diff for this file is too large to render. See raw diff
 
demo_inputs/demo4.png ADDED
demo_inputs/demo4.xyz ADDED
The diff for this file is too large to render. See raw diff
 
demo_inputs/demo5.png ADDED
demo_inputs/demo5.xyz ADDED
The diff for this file is too large to render. See raw diff
 
pi3detr/__init__.py ADDED
@@ -0,0 +1,76 @@
1
+ import argparse
2
+ from typing import Union, Optional
3
+ import torch.nn as nn
4
+ from torch_geometric.data import Dataset
5
+ import inspect
6
+ from .models import ModelConfig
7
+ from .utils import load_args, load_weights
8
+ from .models import PI3DETR
9
+ from .dataset import DatasetConfig, ABCDataset
10
+
11
+
12
+ def build_model_config(args: Union[argparse.Namespace, str]) -> ModelConfig:
13
+ if isinstance(args, str):
14
+ args = load_args(args)
15
+
16
+ # Get required parameters from ModelConfig constructor
17
+ model_config_signature = inspect.signature(ModelConfig.__init__)
18
+ required_params = [
19
+ param for param in model_config_signature.parameters.keys() if param != "self"
20
+ ]
21
+
22
+ for param in required_params:
23
+ if not hasattr(args, param):
24
+ print(f"ERROR: Parameter '{param}' has to be specified in the arguments")
25
+ raise ValueError(f"Missing required parameter: {param}")
26
+
27
+ # Create model config with all parameters from args
28
+ model_config_args = {param: getattr(args, param) for param in required_params}
29
+ model_config = ModelConfig(**model_config_args)
30
+
31
+ print(model_config)
32
+ return model_config
33
+
34
+
35
+ def build_dataset_config(
36
+ args: Union[argparse.Namespace, str], data_root: str, augment: bool
37
+ ) -> DatasetConfig:
38
+ if isinstance(args, str):
39
+ args = load_args(args)
40
+
41
+ # Get required parameters from DatasetConfig constructor (excluding root and augment)
42
+ dataset_config_signature = inspect.signature(DatasetConfig.__init__)
43
+ required_params = [
44
+ param
45
+ for param in dataset_config_signature.parameters.keys()
46
+ if param not in ["self", "root", "augment"]
47
+ ]
48
+
49
+ for param in required_params:
50
+ if not hasattr(args, param):
51
+ print(f"ERROR: Parameter '{param}' has to be specified in the arguments")
52
+ raise ValueError(f"Missing required parameter: {param}")
53
+
54
+ # Create dataset config with parameters from args plus root and augment
55
+ dataset_config_args = {param: getattr(args, param) for param in required_params}
56
+ dataset_config = DatasetConfig(
57
+ root=data_root, augment=augment, **dataset_config_args
58
+ )
59
+
60
+ print(dataset_config)
61
+ return dataset_config
62
+
63
+
64
+ def build_dataset(config: DatasetConfig) -> Dataset:
65
+ if config.dataset == "abc_dataset":
66
+ return ABCDataset(config)
67
+ else:
68
+ raise ValueError(f"Unknown dataset {config.dataset}")
69
+
70
+
71
+ def build_model(config: ModelConfig) -> nn.Module:
72
+ print(f"Model: {config.model}")
73
+ if config.model == "pi3detr":
74
+ return PI3DETR(config)
75
+ else:
76
+ raise ValueError(f"Unknown model {config.model}")
pi3detr/dataset/__init__.py ADDED
@@ -0,0 +1,7 @@
+ from .abc_dataset import ABCDataset, DatasetConfig
+ from .point_cloud_transforms import (
+     normalize_and_scale,
+     normalize_and_scale_with_params,
+     reverse_normalize_and_scale,
+     reverse_normalize_and_scale_with_params,
+ )
pi3detr/dataset/abc_dataset.py ADDED
@@ -0,0 +1,159 @@
1
+ import random
2
+ import torch
3
+ import numpy as np
4
+ from typing import Union
5
+ from pathlib import Path
6
+ from torch_geometric.data import Dataset
7
+ from torch_geometric.data.data import BaseData
8
+ from dataclasses import dataclass
9
+ from typing import Callable
10
+ import torch_geometric.transforms as T
11
+
12
+ from .point_cloud_transforms import (
13
+ random_rotate,
14
+ normalize_and_scale,
15
+ add_noise,
16
+ subsample,
17
+ )
18
+
19
+
20
+ @dataclass
21
+ class DatasetConfig:
22
+ dataset: str
23
+ root: str
24
+ augment: bool = False
25
+ random_rotate_prob: float = 1
26
+ random_sample_prob: float = 0.5
27
+ random_sample_bounds: tuple[float, float] = (1, 0.5)
28
+ noise_prob: float = 0
29
+ noise_scale: float = 0
30
+
31
+
32
+ class ABCDataset(Dataset):
33
+ def __init__(
34
+ self,
35
+ config: DatasetConfig,
36
+ ) -> None:
37
+ self.file_names = self._read_file_names(config.root)
38
+ self.config = config
39
+ super().__init__(
40
+ config.root,
41
+ None,
42
+ None,
43
+ None,
44
+ )
45
+
46
+ @property
47
+ def raw_dir(self) -> str:
48
+ return self.root
49
+
50
+ @property
51
+ def raw_file_names(self) -> Union[str, list[str], tuple]:
52
+ return self.file_names
53
+
54
+ @property
55
+ def processed_file_names(self) -> Union[str, list[str], tuple]:
56
+ return [f"{file_name}.pt" for file_name in self.file_names]
57
+
58
+ def process(self) -> None:
59
+ print("Should already be processed.")
60
+
61
+ def get(self, idx: int) -> BaseData:
62
+
63
+ data = torch.load(
64
+ Path(self.processed_dir) / f"{self.raw_file_names[idx]}.pt",
65
+ weights_only=False,
66
+ )
67
+ data["pos"] = data["pos"].to(torch.float32)
68
+
69
+ augment = self.config.augment
70
+ if augment and random.random() < self.config.noise_prob:
71
+ sigma = (
72
+ np.max(
73
+ np.max(data.pos.cpu().numpy(), axis=0)
74
+ - np.min(data.pos.cpu().numpy(), axis=0)
75
+ )
76
+ / self.config.noise_scale
77
+ )
78
+ noise = torch.tensor(
79
+ np.random.normal(loc=0, scale=sigma, size=data.pos.shape),
80
+ dtype=data.pos.dtype,
81
+ device=data.pos.device,
82
+ )
83
+ data.pos += noise
84
+
85
+ if not hasattr(data, "real_scale") or not hasattr(data, "real_center"):
86
+ data.real_center = torch.zeros(3)
87
+ data.real_scale = torch.tensor(1.0)
88
+
89
+ if augment and random.random() < self.config.random_sample_prob:
90
+ data = subsample(
91
+ data,
92
+ *self.config.random_sample_bounds,
93
+ max_points=None,
94
+ extra_fields=["y_seg", "y_seg_cls"],
95
+ )
96
+
97
+ extra_fields = [
98
+ "y_curve_64",
99
+ "bspline_params",
100
+ "line_params",
101
+ "circle_params",
102
+ "arc_params",
103
+ ]
104
+ if augment and random.random() < self.config.random_rotate_prob:
105
+ data = random_rotate(data, 180, axis=0, extra_fields=extra_fields)
106
+ data = random_rotate(data, 180, axis=1, extra_fields=extra_fields)
107
+ data = random_rotate(data, 180, axis=2, extra_fields=extra_fields)
108
+
109
+ line_direction = data.line_params[:, 1]
110
+ circle_normal = data.circle_params[:, 1]
111
+
112
+ data = normalize_and_scale(
113
+ data,
114
+ extra_fields=extra_fields,
115
+ )
116
+ # normal vectors shouldn't change
117
+ data.line_params[:, 1] = line_direction
118
+ data.circle_params[:, 1] = circle_normal
119
+ # manually adjust length and radius
120
+ data.line_length = data.line_length * data.scale
121
+ data.circle_radius = data.circle_radius * data.scale
122
+
123
+ data.y_params = torch.zeros(data.num_curves, 12, dtype=torch.float32)
124
+ for i in range(data.num_curves):
125
+ if data.y_cls[i] == 1:
126
+ # B-spline
127
+ # P0, P1, P2, P3
128
+ data.y_params[i][:12] = data.bspline_params[i].reshape(-1)
129
+ elif data.y_cls[i] == 2:
130
+ # Line
131
+ # midpoint, normal, length
132
+ data.y_params[i][:3] = data.line_params[i][0].reshape(-1)
133
+ data.y_params[i][3:6] = line_direction[i].reshape(-1)
134
+ data.y_params[i][6] = data.line_length[i] # already adjusted above
135
+ elif data.y_cls[i] == 3:
136
+ # Circle
137
+ # center, normal, radius
138
+ data.y_params[i][:3] = data.circle_params[i][0].reshape(-1)
139
+ data.y_params[i][3:6] = circle_normal[i].reshape(-1)
140
+ data.y_params[i][6] = data.circle_radius[i] # already adjusted above
141
+ elif data.y_cls[i] == 4:
142
+ # Arc
143
+ # midpoint, start, end
144
+ data.y_params[i][:9] = data.arc_params[i].reshape(-1)
145
+ data.filename = self.raw_file_names[idx]
146
+
147
+ return data
148
+
149
+ def len(self) -> int:
150
+ return len(self.processed_file_names)
151
+
152
+ def _read_file_names(self, root: str) -> list[Path]:
153
+ return sorted(
154
+ [
155
+ fp.stem
156
+ for fp in Path(root).joinpath("processed").glob("*.pt")
157
+ if "pre_" not in fp.stem
158
+ ]
159
+ )
pi3detr/dataset/point_cloud_transforms.py ADDED
@@ -0,0 +1,259 @@
1
+ import math
2
+ import random
3
+ import torch
4
+ import numpy as np
5
+ from typing import Optional
6
+ import torch_geometric.transforms as T
7
+ from torch_geometric.data import Data
8
+
9
+
10
+ def subsample(
11
+ data: Data,
12
+ upper_bound: float = 1.0,
13
+ lower_bound: float = 0.5,
14
+ max_points: Optional[int] = None,
15
+ extra_fields: list[str] = [],
16
+ ) -> Data:
17
+ r"""Subsamples the point cloud to a random number of points within the
18
+ range :obj:`[lower_bound, upper_bound]` (functional name: :obj:`subsample`).
19
+ """
20
+ if data.pos.size(0) == 0:
21
+ return data
22
+ num_points = int(random.uniform(lower_bound, upper_bound) * data.pos.size(0))
23
+ if max_points is not None:
24
+ num_points = min(num_points, max_points)
25
+ idx = torch.randperm(data.pos.size(0))[:num_points]
26
+ data.pos = data.pos[idx]
27
+ for field in extra_fields:
28
+ if hasattr(data, field):
29
+ setattr(data, field, getattr(data, field)[idx])
30
+ return data
31
+
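+ # Illustrative usage (sketch, assuming a torch_geometric Data object with an [N, 3] `pos`):
+ #   data = subsample(data, upper_bound=1.0, lower_bound=0.5, extra_fields=["y_seg"])
+ # keeps a random 50-100% of the points and indexes the listed per-point fields the same way.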
32
+
33
+ def numpy_normalize_and_scale(xyz: np.ndarray) -> tuple[np.ndarray, float, float]:
34
+ r"""Normalizes the point cloud in such a way that the points are centered
35
+ around the origin and are within the interval :math:`[-1, 1]` (functional
36
+ name: :obj:`normalize`).
37
+ """
38
+ center = xyz.mean(0)
39
+ scale = (1 / np.max(np.abs(xyz - center))) * 0.999999
40
+ xyz = numpy_normalize_and_scale_with_params(xyz, center, scale)
41
+ return xyz, center, scale
42
+
43
+
44
+ def numpy_normalize_and_scale_with_params(
45
+ xyz: np.ndarray, center: np.ndarray, scale: float
46
+ ) -> np.ndarray:
47
+ r"""Normalizes the point cloud in such a way that the points are centered
48
+ around the origin and are within the interval :math:`[-1, 1]` (functional
49
+ name: :obj:`normalize`).
50
+ """
51
+ if xyz.size == 0:
52
+ return xyz
53
+ shape = xyz.shape
54
+ return ((xyz.reshape(-1, shape[-1]) - center) * scale).reshape(shape)
55
+
56
+
57
+ def normalize_and_scale(data: Data, extra_fields: list[str] = []) -> Data:
58
+ r"""Centers and normalizes the given fields to the interval :math:`[-1, 1]`
59
+ (functional name: :obj:`normalize_scale`).
60
+ """
61
+ if data.pos.size(0) == 0:
62
+ data.center = torch.empty(0)
63
+ data.scale = torch.empty(0)
64
+ return data
65
+ # center the pos points
66
+ center = data.pos.mean(dim=-2, keepdim=True)
67
+ # scale the pos points
68
+ scale = (1 / (data.pos - center).abs().max()) * 0.999999
69
+
70
+ return normalize_and_scale_with_params(data, center, scale, extra_fields)
71
+
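+ # Note (added for clarity): the factor 0.999999 keeps the scaled coordinates strictly
+ # inside (-1, 1); `center` and `scale` are stored on the Data object by
+ # normalize_and_scale_with_params so the transform can later be undone with
+ # reverse_normalize_and_scale.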
72
+
73
+ def reverse_normalize_and_scale(data: Data, extra_fields: list[str] = []) -> Data:
74
+ r"""Reverses the centering and normalization of the given fields
75
+ (functional name: :obj:`reverse_normalize_scale`).
76
+ """
77
+ assert hasattr(data, "center") and hasattr(
78
+ data, "scale"
79
+ ), "Data object does not contain the center and scale attributes."
80
+ return reverse_normalize_and_scale_with_params(
81
+ data, data.center, data.scale, extra_fields
82
+ )
83
+
84
+
85
+ def normalize_and_scale_with_params(
86
+ data: Data, center: torch.Tensor, scale: torch.Tensor, extra_fields: list[str] = []
87
+ ) -> Data:
88
+ if data.pos.size(0) == 0:
89
+ data.center = torch.empty(0)
90
+ data.scale = torch.empty(0)
91
+ return data
92
+ data.pos = (data.pos - center) * scale
93
+ for field in extra_fields:
94
+ if hasattr(data, field):
95
+ shape = getattr(data, field).size()
96
+ setattr(
97
+ data,
98
+ field,
99
+ (getattr(data, field).reshape(-1, shape[-1]) - center) * scale,
100
+ )
101
+ setattr(data, field, getattr(data, field).reshape(shape))
102
+ data.center = center
103
+ data.scale = scale
104
+ return data
105
+
106
+
107
+ def reverse_normalize_and_scale_with_params(
108
+ data: Data, center: torch.Tensor, scale: torch.Tensor, extra_fields: list[str] = []
109
+ ) -> Data:
110
+ r"""Reverses the centering and normalization of the given fields
111
+ (functional name: :obj:`reverse_normalize_scale`).
112
+ """
113
+ # Reverse the scaling and centering of the pos points
114
+ data.pos = data.pos / scale + center
115
+
116
+ for field in extra_fields:
117
+ if hasattr(data, field):
118
+ shape = getattr(data, field).size()
119
+ setattr(
120
+ data,
121
+ field,
122
+ (getattr(data, field).reshape(-1, shape[-1]) / scale) + center,
123
+ )
124
+ setattr(data, field, getattr(data, field).reshape(shape))
125
+ data.center = torch.empty(0)
126
+ data.scale = torch.empty(0)
127
+ return data
128
+
129
+
153
+ def random_rotate(
154
+ data: Data, degrees: float, axis: int, extra_fields: list[str] = []
155
+ ) -> Data:
156
+ r"""Rotates the object around the origin by a random angle within the
157
+ range :obj:`[-degrees, degrees]` (functional name: :obj:`random_rotate`).
158
159
+ """
160
+ if data.pos.size(0) == 0:
161
+ return data
162
+ return rotate_with_params(
163
+ data, random.uniform(-degrees, degrees), axis, extra_fields
164
+ )
165
+
166
+
167
+ def rotate_with_params(
168
+ data: Data, degrees: float, axis: int = 0, extra_fields: list[str] = []
169
+ ) -> Data:
170
+ r"""Rotates the object around the origin by a given angle
171
+ (functional name: :obj:`rotate`).
172
+ """
173
+ angle = math.pi * degrees / 180.0
174
+ if data.pos.size(0) == 0:
175
+ return data
176
+ sin, cos = math.sin(angle), math.cos(angle)
177
+ if data.pos.size(-1) == 2:
178
+ matrix = torch.tensor([[cos, sin], [-sin, cos]])
179
+ else:
180
+ if axis == 0:
181
+ matrix = torch.tensor([[1, 0, 0], [0, cos, sin], [0, -sin, cos]])
182
+ elif axis == 1:
183
+ matrix = torch.tensor([[cos, 0, -sin], [0, 1, 0], [sin, 0, cos]])
184
+ else:
185
+ matrix = torch.tensor([[cos, sin, 0], [-sin, cos, 0], [0, 0, 1]])
186
+
187
+ # cast the rotation matrix to the dtype of the point positions
188
+ matrix = matrix.to(data.pos.dtype)
189
+
190
+ data.pos = data.pos @ matrix.t()
191
+ for field in extra_fields:
192
+ if hasattr(data, field):
193
+ shape = getattr(data, field).size()
194
+ # get dtype of field
195
+ dtype = getattr(data, field).dtype
196
+
197
+ matrix_dtype = matrix.to(dtype)
198
+ setattr(data, field, getattr(data, field) @ matrix_dtype.t())
199
+ setattr(data, field, getattr(data, field).reshape(shape))
200
+ setattr(data, f"rotated_{axis}", degrees)
201
+ return data
202
+
203
+
204
+ def reverse_rotate(data: Data, axis: int = 0, extra_fields: list[str] = []) -> Data:
205
+ r"""Reverses the rotation of the object around the origin
206
+ (functional name: :obj:`reverse_rotate`).
207
+ """
208
+ if not hasattr(data, f"rotated_{axis}"):
209
+ return data
210
+ return rotate_with_params(
211
+ data, -getattr(data, f"rotated_{axis}"), axis, extra_fields
212
+ )
213
+
214
+
215
+ def add_noise(data: Data, std: float) -> Data:
216
+ r"""Adds Gaussian noise to the node features (functional name:
217
+ :obj:`add_noise`).
218
+ """
219
+ if data.pos.size(0) == 0:
220
+ return data
221
+ noise = torch.randn_like(data.pos) * std
222
+ data.pos = data.pos + noise
223
+ data.noise = noise
224
+ return data
225
+
226
+
227
+ def remove_noise(data: Data) -> Data:
228
+ r"""Removes the noise from the node features (functional name:
229
+ :obj:`remove_noise`).
230
+ """
231
+ assert hasattr(data, "noise"), "Data object does not contain the noise attribute."
232
+ data.pos = data.pos - data.noise
233
+ del data.noise
234
+ return data
235
+
236
+
237
+ def custom_normalize_and_scale(
238
+ data: Data, p1: torch.Tensor, p2: torch.Tensor, extra_fields: list[str] = []
239
+ ) -> Data:
240
+ r"""Normalizes the point cloud in such a way that after the transformation
241
+ p1 is at (0,0,0) and p2 at (1,1,1)` (functional name:
242
+ :obj:`normalize`).
243
+ """
244
+ assert p1.size() == p2.size() == (3,), "Invalid interval."
245
+ if data.pos.size(0) == 0:
246
+ return data
247
+ data.pos = (data.pos - p1) / (p2 - p1)
248
+ for field in extra_fields:
249
+ if hasattr(data, field):
250
+ shape = getattr(data, field).size()
251
+ setattr(
252
+ data,
253
+ field,
254
+ (getattr(data, field).reshape(-1, shape[-1]) - p1) / (p2 - p1),
255
+ )
256
+ setattr(data, field, getattr(data, field).reshape(shape))
257
+ data.p1 = p1
258
+ data.p2 = p2
259
+ return data
pi3detr/dataset/utils.py ADDED
@@ -0,0 +1,75 @@
1
+ import numpy as np
2
+ import json
3
+ import open3d as o3d
4
+
5
+
6
+ def read_xyz_file(file_path: str, column_idxs: list[int] = [0, 1, 2]) -> np.ndarray:
7
+ """Reads a point cloud from a .xyz file.
8
+
9
+ Args:
10
+ file_path (str): Path to the .xyz file.
11
+ column_idxs (list[int], optional): Indices of the columns to read. Defaults to [0,1,2].
12
+
13
+ Returns:
14
+ np.ndarray: Point cloud as a numpy array.
15
+ """
16
+ return np.loadtxt(file_path, usecols=column_idxs)
17
+
18
+
19
+ def read_curve_file(file_path: str) -> tuple[np.ndarray]:
20
+ with open(file_path, "r") as f:
21
+ data = json.load(f)
22
+ return np.array(data["linear"]), np.array(data["bezier"])
23
+
24
+
25
+ def read_polyline_file(file_path: str, sep: str = ",") -> np.ndarray:
26
+ polylines = []
27
+ with open(file_path, "r") as f:
28
+ polyline = []
29
+ for line in f:
30
+ if line == "\n":
31
+ polylines.append(polyline)
32
+ polyline = []
33
+ else:
34
+ point = [float(x) for x in line.split(sep)]
35
+ polyline.append(point)
36
+ if polyline:
37
+ polylines.append(polyline)
38
+ return np.array(polylines)
39
+
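+ # Note (added for clarity): if the file contains polylines of different lengths,
+ # np.array on the ragged list yields an object array (and newer NumPy versions may
+ # raise unless dtype=object is given), so callers may prefer to keep the Python list.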
40
+
41
+ def voxel_down_sample(xyz: np.ndarray, voxel_size: float) -> np.ndarray:
42
+ """Downsamples a point cloud using voxel grid downsampling.
43
+
44
+ Args:
45
+ xyz (np.ndarray): Point cloud.
46
+ voxel_size (float): Voxel size.
47
+
48
+ Returns:
49
+ np.ndarray: Downsampled point cloud.
50
+ """
51
+ pcd = o3d.geometry.PointCloud()
52
+ pcd.points = o3d.utility.Vector3dVector(xyz)
53
+ downpcd = pcd.voxel_down_sample(voxel_size=voxel_size)
54
+ return np.asarray(downpcd.points)
55
+
56
+
57
+ def filter_normals(
58
+ xyz: np.ndarray, radius: float, max_nn: float, threshold: float
59
+ ) -> np.ndarray:
60
+ """Filters normals of a point cloud.
61
+
62
+ Args:
63
+ xyz (np.ndarray): Point cloud.
64
+ threshold (float, optional): Threshold for filtering normals.
65
+
66
+ Returns:
67
+ np.ndarray: Filtered point cloud.
68
+ """
69
+ pcd = o3d.geometry.PointCloud()
70
+ pcd.points = o3d.utility.Vector3dVector(xyz)
71
+ pcd.estimate_normals(
72
+ search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=radius, max_nn=max_nn)
73
+ )
74
+ new_pts = np.asarray(pcd.points)[np.abs(np.asarray(pcd.normals)[:, 2]) < threshold]
75
+ return new_pts
pi3detr/evaluation/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .abc_metrics import ChamferMAP, ChamferIntervalMetric
pi3detr/evaluation/abc_metrics.py ADDED
@@ -0,0 +1,349 @@
1
+ import torch
2
+ import numpy as np
3
+ from torchmetrics import Metric
4
+ from torchmetrics.functional import average_precision
5
+ from scipy.spatial import KDTree
6
+
7
+
8
+ def calc_chamfer_distance(pred_points, gt_points):
9
+ """Calculate chamfer distance and bidirectional hausdorff distance."""
10
+ if len(pred_points) == 0 or len(gt_points) == 0:
11
+ return float("inf"), float("inf")
12
+
13
+ tree_pred = KDTree(pred_points)
14
+ tree_gt = KDTree(gt_points)
15
+
16
+ dist_pred2gt, _ = tree_gt.query(pred_points)
17
+ dist_gt2pred, _ = tree_pred.query(gt_points)
18
+
19
+ chamfer_dist = np.mean(dist_pred2gt**2) + np.mean(dist_gt2pred**2)
20
+ bhausdorff_dist = (dist_pred2gt.max() + dist_gt2pred.max()) / 2
21
+
22
+ return chamfer_dist, bhausdorff_dist
23
+
24
+
25
+ class ChamferMAP(Metric):
26
+ def __init__(self, chamfer_thresh=0.05, dist_sync_on_step=False):
27
+ super().__init__(dist_sync_on_step=dist_sync_on_step)
28
+ self.chamfer_thresh = chamfer_thresh
29
+ self.class_names = {
30
+ 1: "mAP_bspline",
31
+ 2: "mAP_line",
32
+ 3: "mAP_circle",
33
+ 4: "mAP_arc",
34
+ }
35
+
36
+ self.add_state("all_scores", default=[], dist_reduce_fx="cat")
37
+ self.add_state("all_matches", default=[], dist_reduce_fx="cat")
38
+ self.add_state("all_classes", default=[], dist_reduce_fx="cat")
39
+
40
+ def pairwise_chamfer_distance_batch(self, pred, gt):
41
+ """
42
+ pred: [P, 64, 3]
43
+ gt: [G, 64, 3]
44
+ returns: [P, G] chamfer distances
45
+ """
46
+ P, G = pred.size(0), gt.size(0)
47
+
48
+ # Reshape for pairwise comparison
49
+ pred_exp = pred.unsqueeze(1) # [P, 1, 64, 3]
50
+ gt_exp = gt.unsqueeze(0) # [1, G, 64, 3]
51
+
52
+ # Compute pairwise distances between points
53
+ dists = torch.cdist(pred_exp, gt_exp, p=2) # [P, G, 64, 64]
54
+
55
+ a2b = dists.min(dim=3).values.mean(dim=2) # [P, G]
56
+ b2a = dists.min(dim=2).values.mean(dim=2) # [P, G]
57
+
58
+ return a2b + b2a # [P, G]
59
+
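+ # Note (added for clarity): torch.cdist above materializes a [P, G, 64, 64] distance
+ # tensor, so memory grows with the product of prediction and ground-truth counts.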
60
+ def update(self, outputs, batch):
61
+ B = outputs["pred_class"].shape[0]
62
+ y_curves = batch.y_curve_64 # [total_gt, 64, 3]
63
+ y_cls = batch.y_cls # [total_gt]
64
+ num_curves_per_batch = batch.num_curves.tolist()
65
+
66
+ gt_splits = torch.split(y_curves, num_curves_per_batch, dim=0)
67
+ cls_splits = torch.split(y_cls, num_curves_per_batch, dim=0)
68
+
69
+ pred_classes_all = outputs["pred_class"].softmax(dim=-1) # [B, N, C]
70
+ for b in range(B):
71
+ pred_classes = pred_classes_all[b] # [N, C]
72
+
73
+ preds_all = {
74
+ 1: outputs["pred_bspline_points"][b], # [N, 64, 3]
75
+ 2: outputs["pred_line_points"][b],
76
+ 3: outputs["pred_circle_points"][b],
77
+ 4: outputs["pred_arc_points"][b],
78
+ }
79
+
80
+ for cls in self.class_names.keys():
81
+ pred_points = preds_all[cls] # [P, 64, 3]
82
+ scores = pred_classes[:, cls] # [P]
83
+
84
+ gt_points = gt_splits[b][cls_splits[b] == cls] # [G, 64, 3]
85
+ if gt_points.size(0) == 0:
86
+ self.all_scores.append(scores)
87
+ self.all_matches.append(torch.zeros_like(scores))
88
+ self.all_classes.append(torch.full_like(scores, cls))
89
+ continue
90
+
91
+ chamfer = self.pairwise_chamfer_distance_batch(pred_points, gt_points)
92
+ used_gt = torch.zeros(
93
+ gt_points.size(0), dtype=torch.bool, device=pred_points.device
94
+ )
95
+ matches = torch.zeros(pred_points.size(0), device=pred_points.device)
96
+
97
+ for i in range(pred_points.size(0)):
98
+ dists = chamfer[i]
99
+ min_dist, min_idx = dists.min(0)
100
+ if min_dist < self.chamfer_thresh and not used_gt[min_idx]:
101
+ matches[i] = 1.0
102
+ used_gt[min_idx] = True
103
+
104
+ self.all_scores.append(scores)
105
+ self.all_matches.append(matches)
106
+ self.all_classes.append(torch.full_like(matches, cls))
107
+
108
+ def compute(self):
109
+ if not self.all_scores:
110
+ return {cls_name: 0.0 for cls_name in self.class_names.values()}
111
+
112
+ scores = torch.cat(self.all_scores)
113
+ matches = torch.cat(self.all_matches)
114
+ classes = torch.cat(self.all_classes)
115
+
116
+ result = {}
117
+ ap_values = []
118
+
119
+ for cls in self.class_names.keys():
120
+ mask = classes == cls
121
+ if mask.sum() == 0 or torch.sum(matches[mask]) == 0:
122
+ ap = torch.tensor(0.0, device=self.device)
123
+ else:
124
+ ap = average_precision(
125
+ scores[mask], matches[mask].to(torch.int32), task="binary"
126
+ )
127
+ result[self.class_names[cls]] = ap.item()
128
+ ap_values.append(ap)
129
+
130
+ result["mAP"] = torch.stack(ap_values).mean().item()
131
+ return result
132
+
133
+
134
+ class ChamferIntervalMetric(Metric):
135
+
136
+ def __init__(self, interval=0.01, map_cd_thresh=0.005, dist_sync_on_step=False):
137
+ super().__init__(dist_sync_on_step=dist_sync_on_step)
138
+ self.interval = interval
139
+ self.add_state("total_cd", default=torch.tensor(0.0), dist_reduce_fx="sum")
140
+ self.add_state("total_cd_sq", default=torch.tensor(0.0), dist_reduce_fx="sum")
141
+ self.add_state("total_bhd", default=torch.tensor(0.0), dist_reduce_fx="sum")
142
+ self.add_state("total_bhd_sq", default=torch.tensor(0.0), dist_reduce_fx="sum")
143
+ self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")
144
+ self.add_state("valid_count", default=torch.tensor(0), dist_reduce_fx="sum")
145
+
146
+ self.map_cd_thresh = map_cd_thresh
147
+ self.map_cls_names = {
148
+ 1: "mAP_bspline",
149
+ 2: "mAP_line",
150
+ 3: "mAP_circle",
151
+ 4: "mAP_arc",
152
+ }
153
+
154
+ self.add_state("map_all_scores", default=[], dist_reduce_fx="cat")
155
+ self.add_state("map_all_matches", default=[], dist_reduce_fx="cat")
156
+ self.add_state("map_all_classes", default=[], dist_reduce_fx="cat")
157
+
158
+ def sample_curve_by_interval(self, points, interval, force_last=False):
159
+ """Sample points along the curve at fixed length `interval`."""
160
+ if len(points) < 2:
161
+ return points
162
+
163
+ edges = np.array([[j, j + 1] for j in range(len(points) - 1)])
164
+ edge_lengths = np.linalg.norm(points[edges[:, 1]] - points[edges[:, 0]], axis=1)
165
+
166
+ samples = [points[0]]
167
+ distance_accum = 0.0
168
+ next_sample_dist = interval
169
+ edge_index = 0
170
+
171
+ while edge_index < len(edges):
172
+ p0 = points[edges[edge_index, 0]]
173
+ p1 = points[edges[edge_index, 1]]
174
+ edge_vec = p1 - p0
175
+ edge_len = np.linalg.norm(edge_vec)
176
+ if edge_len == 0:
177
+ edge_index += 1
178
+ continue
179
+
180
+ while distance_accum + edge_len >= next_sample_dist:
181
+ t = (next_sample_dist - distance_accum) / edge_len
182
+ sample = p0 + t * edge_vec
183
+ samples.append(sample)
184
+ next_sample_dist += interval
185
+
186
+ distance_accum += edge_len
187
+ edge_index += 1
188
+
189
+ if force_last and not np.allclose(samples[-1], points[-1]):
190
+ samples.append(points[-1])
191
+
192
+ return np.array(samples)
193
+
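+ # Illustrative example (sketch): for a straight segment of length 1.0 and interval=0.25,
+ # this returns samples at arc lengths 0.0, 0.25, 0.5, 0.75 and 1.0; force_last=True
+ # only appends the final vertex when it was not already hit exactly.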
194
+ def update(self, data, batch):
195
+
196
+ # Get ground truth curves
197
+ y_curves = batch.y_curve_64.cpu().numpy() # [total_gt, 64, 3]
198
+ num_curves_per_batch = batch.num_curves.tolist()
199
+
200
+ # This metric assumes a batch size of 1.
201
+ B = 1
202
+
203
+ # Sample ground truth curves
204
+ gt_points_list = []
205
+ gt_cls_list = []
206
+ for i, gt_curve in enumerate(y_curves):
207
+ if len(gt_curve) < 2 or np.any(np.isnan(gt_curve)):
208
+ continue
209
+ sampled_gt = self.sample_curve_by_interval(
210
+ gt_curve, self.interval, force_last=True
211
+ )
212
+ if len(sampled_gt) > 0 and np.all(np.isfinite(sampled_gt)):
213
+ gt_points_list.append(sampled_gt)
214
+ gt_cls_list.append(batch.y_cls[i].cpu().item())
215
+
216
+ # Sample predicted curves from post-processed data
217
+ pred_points_list = []
218
+ pred_cls_list = []
219
+ pred_score_list = []
220
+ for i, (polyline, cls) in enumerate(
221
+ zip(data.polylines.cpu().numpy(), data.polyline_class.cpu().numpy())
222
+ ):
223
+ if cls == 0: # Skip background class
224
+ continue
225
+
226
+ if len(polyline) < 2 or np.any(np.isnan(polyline)):
227
+ continue
228
+
229
+ sampled_pred = self.sample_curve_by_interval(
230
+ polyline, self.interval, force_last=True
231
+ )
232
+ if len(sampled_pred) > 0 and np.all(np.isfinite(sampled_pred)):
233
+ pred_points_list.append(sampled_pred)
234
+ pred_cls_list.append(int(cls))
235
+ pred_score_list.append(data.polyline_score[i].cpu().item())
236
+
237
+ if len(gt_points_list) == 0 and len(pred_points_list) == 0:
238
+ # No ground truth and no predictions, no penalty
239
+ self.count += 1
240
+ return
241
+ elif len(gt_points_list) == 0:
242
+ # Predictions but no ground truth: count them all as false positives
243
+ self.count += 1
244
+ scores = torch.tensor(pred_score_list)
245
+ self.map_all_scores.append(scores)
246
+ self.map_all_matches.append(torch.zeros_like(scores))
247
+ self.map_all_classes.append(torch.tensor(pred_cls_list))
248
+ return
249
+ elif len(pred_points_list) == 0:
250
+ # Ground truth but no predictions: count every ground-truth curve as missed
251
+ self.count += 1
252
+ cls_list = torch.tensor(gt_cls_list, dtype=torch.float32)
253
+ self.map_all_scores.append(torch.zeros_like(cls_list))
254
+ self.map_all_matches.append(torch.ones_like(cls_list))
255
+ self.map_all_classes.append(cls_list.clone())
256
+ return
257
+
258
+ # calculate mAP
259
+ for cls in self.map_cls_names.keys():
260
+ mask = torch.tensor(pred_cls_list) == cls
261
+ pred_curves = [curve for i, curve in enumerate(pred_points_list) if mask[i]]
262
+ pred_scores = torch.tensor(pred_score_list)[mask]
263
+ gt_curves = [
264
+ curve for i, curve in enumerate(gt_points_list) if cls == gt_cls_list[i]
265
+ ]
266
+ if len(pred_curves) == 0 and len(gt_curves) != 0:
267
+ scores = torch.zeros(len(gt_curves))
268
+ self.map_all_scores.append(scores)
269
+ self.map_all_matches.append(torch.zeros_like(scores))
270
+ self.map_all_classes.append(torch.full_like(scores, cls))
271
+ continue
272
+ if len(gt_curves) == 0:
273
+ self.map_all_scores.append(pred_scores)
274
+ self.map_all_matches.append(torch.zeros_like(pred_scores))
275
+ self.map_all_classes.append(torch.full_like(pred_scores, cls))
276
+ continue
277
+
278
+ # get [P, G] matrix of chamfer distances
279
+ cd_matrix = torch.ones((len(pred_curves), len(gt_curves))) * float("inf")
280
+ for i, pred_curve in enumerate(pred_curves):
281
+ for j, gt_curve in enumerate(gt_curves):
282
+ cd_matrix[i, j] = calc_chamfer_distance(pred_curve, gt_curve)[0]
283
+
284
+ used_gt = set()
285
+ matches = torch.zeros(len(pred_curves))
286
+ for i in range(len(pred_curves)):
287
+ dists = cd_matrix[i]
288
+ min_dist, min_idx = dists.min(0)
289
+ if min_dist < self.map_cd_thresh and min_idx not in used_gt:
290
+ matches[i] = 1.0
291
+ used_gt.add(min_idx)
292
+
293
+ self.map_all_scores.append(pred_scores)
294
+ self.map_all_matches.append(matches)
295
+ self.map_all_classes.append(torch.full_like(pred_scores, cls))
296
+
297
+ pred_points = np.concatenate(pred_points_list, axis=0)
298
+ gt_points = np.concatenate(gt_points_list, axis=0)
299
+ # Calculate distances
300
+ cd, bhd = calc_chamfer_distance(pred_points, gt_points)
301
+
302
+ self.total_cd += torch.tensor(cd)
303
+ self.total_cd_sq += torch.tensor(cd**2)
304
+ self.total_bhd += torch.tensor(bhd)
305
+ self.total_bhd_sq += torch.tensor(bhd**2)
306
+ self.count += 1
307
+ self.valid_count += 1 if len(pred_points) > 0 else 0
308
+
309
+ def compute(self):
310
+ if not self.map_all_scores:
311
+ return {cls_name: 0.0 for cls_name in self.map_cls_names.values()}
312
+ scores = torch.cat(self.map_all_scores)
313
+ matches = torch.cat(self.map_all_matches)
314
+ classes = torch.cat(self.map_all_classes)
315
+ map_result = {}
316
+ ap_values = []
317
+
318
+ for cls in self.map_cls_names.keys():
319
+ mask = classes == cls
320
+ if mask.sum() == 0 or torch.sum(matches[mask]) == 0:
321
+ ap = torch.tensor(0.0, device=self.device)
322
+ else:
323
+ ap = average_precision(
324
+ scores[mask], matches[mask].to(torch.int32), task="binary"
325
+ )
326
+ map_result[self.map_cls_names[cls]] = ap.item()
327
+ ap_values.append(ap)
328
+
329
+ map_result["mAP"] = torch.stack(ap_values).mean().item()
330
+
331
+ if self.count == 0:
332
+ return {"chamfer_distance": 0.0, "bidirectional_hausdorff": 0.0}
333
+
334
+ mean_cd = (self.total_cd / self.valid_count).item()
335
+ mean_bhd = (self.total_bhd / self.valid_count).item()
336
+ results = {
337
+ "chamfer_distance": mean_cd,
338
+ "chamfer_distance_std": (self.total_cd_sq / self.valid_count - (mean_cd**2))
339
+ .sqrt()
340
+ .item(),
341
+ "bidirectional_hausdorff": mean_bhd,
342
+ "bidirectional_hausdorff_std": (
343
+ self.total_bhd_sq / self.valid_count - (mean_bhd**2)
344
+ )
345
+ .sqrt()
346
+ .item(),
347
+ }
348
+ results.update(map_result)
349
+ return results
pi3detr/models/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .model_config import ModelConfig
2
+ from .pi3detr import PI3DETR
pi3detr/models/losses/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from .losses import (
2
+ chamfer_distance_batch,
3
+ LossParams,
4
+ ParametricLoss,
5
+ )
pi3detr/models/losses/losses.py ADDED
@@ -0,0 +1,399 @@
1
+ from abc import abstractmethod
2
+ import torch
3
+ from torch import nn
4
+ from dataclasses import dataclass, field
5
+ import torch.nn.functional as F
6
+ from torch_geometric.data.data import Data
7
+ from kornia.losses import focal_loss
8
+ from .matcher import *
9
+
10
+
11
+ def chamfer_distance_batch(pts1: torch.Tensor, pts2: torch.Tensor) -> torch.Tensor:
12
+ assert len(pts1.shape) == 3 and len(pts2.shape) == 3
13
+ if pts1.nelement() == 0 or pts2.nelement() == 0:
14
+ return torch.tensor(0.0, device=pts1.device, requires_grad=True)
15
+ dist_matrix = torch.cdist(
16
+ pts1, pts2, p=2
17
+ ) # shape: (batch_size, num_points, num_points)
18
+ dist1 = dist_matrix.min(dim=2).values.mean(dim=1) # min over pts2, mean over pts1
19
+ dist2 = dist_matrix.min(dim=1).values.mean(dim=1) # min over pts1, mean over pts2
20
+ return (dist1 + dist2) / 2
21
+
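+ # Illustrative usage (sketch): for two batches of sampled curves
+ #   a, b = torch.rand(8, 64, 3), torch.rand(8, 64, 3)
+ #   chamfer_distance_batch(a, b)  # -> per-pair symmetric chamfer distance, shape [8]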
22
+
23
+ @torch.no_grad()
24
+ def accuracy(output, target, topk=(1,)):
25
+ """Computes the precision@k for the specified values of k"""
26
+ if target.numel() == 0:
27
+ return [torch.zeros([], device=output.device)]
28
+ maxk = max(topk)
29
+ batch_size = target.size(0)
30
+
31
+ _, pred = output.topk(maxk, 1, True, True)
32
+ pred = pred.t()
33
+ correct = pred.eq(target.view(1, -1).expand_as(pred))
34
+
35
+ res = []
36
+ for k in topk:
37
+ correct_k = correct[:k].view(-1).float().sum(0)
38
+ res.append(correct_k.mul_(100.0 / batch_size))
39
+ return res
40
+
41
+
42
+ @torch.no_grad()
43
+ def f1_score(output, target, threshold=0.5):
44
+ output = output > threshold
45
+ target = target > threshold
46
+ tp = (output & target).sum()
47
+ tn = (~output & ~target).sum()
48
+ fp = (output & ~target).sum()
49
+ fn = (~output & target).sum()
50
+ precision = tp / (tp + fp + 1e-8)
51
+ recall = tp / (tp + fn + 1e-8)
52
+ return 2 * (precision * recall) / (precision + recall + 1e-8), precision, recall
53
+
54
+
55
+ @dataclass
56
+ class LossParams:
57
+ num_classes: int
58
+ cost_class: int = 1
59
+ cost_curve: int = 1
60
+ class_loss_type: str = "cross_entropy" # or "focal"
61
+ class_loss_weights: list[float] = field(
62
+ default_factory=lambda: [
63
+ 0.04834912,
64
+ 0.40329467,
65
+ 0.09588135,
66
+ 0.23071379,
67
+ 0.22176106,
68
+ ]
69
+ )
70
+ # NOTE: Weights calculated based on the dataset
71
+ # bezier, line, circle, arc, empty
72
+ # counts = np.array([11347, 200751, 34672, 37528])
73
+ # counts = np.append(counts, total_pred - counts.sum())
74
+ # weights = 1 / np.sqrt(counts)
75
+ # weights = weights / weights.sum()
76
+
77
+
78
+ class Loss(nn.Module):
79
+
80
+ def __init__(self, params: LossParams) -> None:
81
+ super().__init__()
82
+ self.matcher = ParametricMatcher(params.cost_class, params.cost_curve)
83
+ self.num_classes = params.num_classes
84
+ self.class_loss_type = params.class_loss_type
85
+ class_weights = torch.tensor(
86
+ params.class_loss_weights,
87
+ )
88
+
89
+ self.register_buffer("class_weights", class_weights)
90
+
91
+ def forward(
92
+ self, outputs: dict[str, torch.Tensor], data: Data
93
+ ) -> dict[str, torch.Tensor]:
94
+ indices = self.matcher(outputs, data)
95
+ losses = {}
96
+ losses.update(self._loss_class(outputs, data, indices))
97
+ losses.update(self._loss_polyline(outputs, data, indices))
98
+ # In case of auxiliary losses, we repeat this process with the output
99
+ # of each intermediate layer.
100
+ if "aux_outputs" in outputs:
101
+ for i, aux_outputs in enumerate(outputs["aux_outputs"]):
102
+ indices = self.matcher(aux_outputs, data)
103
+ l_dict = self._loss_class(aux_outputs, data, indices, False)
104
+ l_dict.update(self._loss_polyline(aux_outputs, data, indices))
105
+ l_dict = {k + f"_{i}": v for k, v in l_dict.items()}
106
+ losses.update(l_dict)
107
+ return losses
108
+
109
+ @abstractmethod
110
+ def _loss_polyline(
111
+ self,
112
+ outputs: dict[str, torch.Tensor],
113
+ data: Data,
114
+ indices: list[tuple[torch.Tensor, torch.Tensor]],
115
+ ) -> dict[str, torch.Tensor]:
116
+ """Compute the polyline loss."""
117
+ pass
118
+
119
+ def _loss_class(
120
+ self,
121
+ outputs: dict[str, torch.Tensor],
122
+ data: Data,
123
+ indices: list[tuple[torch.Tensor, torch.Tensor]],
124
+ log: bool = True,
125
+ ) -> torch.Tensor:
126
+ num_targets = (
127
+ data.num_polylines.tolist()
128
+ if hasattr(data, "num_polylines")
129
+ else data.num_curves.tolist()
130
+ )
131
+ src_logits = outputs["pred_class"]
132
+ idx = self._get_src_permutation_idx(indices)
133
+ target_classes_o = torch.cat(
134
+ [
135
+ target[J]
136
+ for target, (_, J) in zip(
137
+ data.y_cls.split_with_sizes(num_targets), indices
138
+ )
139
+ ]
140
+ )
141
+ target_classes = torch.full(
142
+ src_logits.shape[:2], 0, dtype=torch.int64, device=src_logits.device
143
+ ) # 0: empty class
144
+ target_classes[idx] = target_classes_o
145
+ losses = {}
146
+
147
+ if self.class_loss_type == "cross_entropy":
148
+ loss_class = F.cross_entropy(
149
+ src_logits.transpose(1, 2),
150
+ target_classes,
151
+ weight=self.class_weights.to(src_logits.device),
152
+ reduction="mean",
153
+ )
154
+ else:
155
+ loss_class = focal_loss(
156
+ src_logits.transpose(1, 2),
157
+ target_classes,
158
+ alpha=0.25,
159
+ gamma=2.0,
160
+ weight=self.class_weights.to(src_logits.device),
161
+ reduction="mean",
162
+ )
163
+ losses["loss_class"] = loss_class
164
+ if log:
165
+ losses["class_error"] = (
166
+ 100
167
+ - accuracy(
168
+ src_logits.reshape(-1, src_logits.size(-1)),
169
+ target_classes.flatten(),
170
+ )[0]
171
+ )
172
+ f1, _, _ = f1_score(
173
+ src_logits.reshape(-1, src_logits.size(-1)).softmax(-1).argmax(-1),
174
+ target_classes.flatten(),
175
+ threshold=0.5,
176
+ )
177
+ losses["class_f1_score"] = f1
178
+
179
+ return losses
180
+
181
+ def _get_src_permutation_idx(
182
+ self, indices: list[tuple[torch.Tensor, torch.Tensor]]
183
+ ) -> tuple[torch.Tensor, torch.Tensor]:
184
+ # permute predictions following indices
185
+ batch_idx = torch.cat(
186
+ [torch.full_like(src, i) for i, (src, _) in enumerate(indices)]
187
+ )
188
+ src_idx = torch.cat([src for (src, _) in indices])
189
+ return batch_idx, src_idx
190
+
191
+ def _get_tgt_permutation_idx(
192
+ self, indices: list[tuple[torch.Tensor, torch.Tensor]]
193
+ ) -> tuple[torch.Tensor, torch.Tensor]:
194
+ # permute targets following indices
195
+ batch_idx = torch.cat(
196
+ [torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]
197
+ )
198
+ tgt_idx = torch.cat([tgt for (_, tgt) in indices])
199
+ return batch_idx, tgt_idx
200
+
201
+
202
+ class ParametricLoss(Loss):
203
+ def __init__(self, params: LossParams) -> None:
204
+ super().__init__(params)
205
+
206
+ def _loss_polyline(
207
+ self,
208
+ outputs: dict[str, torch.Tensor],
209
+ data: Data,
210
+ indices: list[tuple[torch.Tensor, torch.Tensor]],
211
+ ) -> dict[str, torch.Tensor]:
212
+ idx = self._get_src_permutation_idx(indices)
213
+ src_bspline_params = outputs["pred_bspline_params"][idx]
214
+ src_bspline_points = outputs["pred_bspline_points"][idx]
215
+ src_line_params = outputs["pred_line_params"][idx]
216
+ src_line_length = outputs["pred_line_length"][idx]
217
+ src_line_points = outputs["pred_line_points"][idx]
218
+ src_circle_params = outputs["pred_circle_params"][idx]
219
+ src_circle_radius = outputs["pred_circle_radius"][idx]
220
+ src_circle_points = outputs["pred_circle_points"][idx]
221
+ src_arc_params = outputs["pred_arc_params"][idx]
222
+ src_arc_points = outputs["pred_arc_points"][idx]
223
+ target_params = torch.cat(
224
+ [
225
+ target[J]
226
+ for target, (_, J) in zip(
227
+ data.y_params.split_with_sizes(data.num_curves.tolist()), indices
228
+ )
229
+ ]
230
+ )
231
+ target_classes = torch.cat(
232
+ [
233
+ target[J]
234
+ for target, (_, J) in zip(
235
+ data.y_cls.split_with_sizes(data.num_curves.tolist()), indices
236
+ )
237
+ ]
238
+ )
239
+ target_curves = torch.cat(
240
+ [
241
+ target[J]
242
+ for target, (_, J) in zip(
243
+ data.y_curve_64.split_with_sizes(data.num_curves.tolist()), indices
244
+ )
245
+ ]
246
+ )
247
+
248
+ losses = {}
249
+
250
+ # Filter indices for each class
251
+ bspline_mask = target_classes == 1 # B-spline
252
+ line_mask = target_classes == 2 # Line
253
+ circle_mask = target_classes == 3 # Circle
254
+ arc_mask = target_classes == 4 # Arc
255
+
256
+ # Compute loss for B-splines
257
+ if bspline_mask.any():
258
+ bspline_order_l1 = torch.min(
259
+ F.l1_loss(
260
+ src_bspline_params[bspline_mask].flatten(-2, -1),
261
+ target_params[bspline_mask],
262
+ reduction="none",
263
+ ).mean(-1),
264
+ F.l1_loss(
265
+ src_bspline_params[bspline_mask].flip([1]).flatten(-2, -1),
266
+ target_params[bspline_mask],
267
+ reduction="none",
268
+ ).mean(-1),
269
+ ).mean()
270
+ losses["loss_bspline"] = bspline_order_l1
271
+ bspline_chamfer = chamfer_distance_batch(
272
+ src_bspline_points[bspline_mask], target_curves[bspline_mask]
273
+ )
274
+ losses["loss_bspline_chamfer"] = bspline_chamfer.mean()
275
+ else:
276
+ losses["loss_bspline"] = torch.tensor(0.0, device=src_bspline_params.device)
277
+ losses["loss_bspline_chamfer"] = torch.tensor(
278
+ 0.0, device=src_bspline_points.device
279
+ )
280
+
281
+ # Compute loss for Lines
282
+ if line_mask.any():
283
+ line_position_l1 = torch.min(
284
+ F.l1_loss(
285
+ src_line_params[line_mask].flatten(-2, -1),
286
+ target_params[line_mask, :6],
287
+ reduction="none",
288
+ ).mean(-1),
289
+ # also consider the negative direction
290
+ F.l1_loss(
291
+ (
292
+ src_line_params[line_mask]
293
+ * torch.tensor([1.0, -1.0])
294
+ .view(1, 2, 1)
295
+ .to(src_line_params.device)
296
+ ).flatten(-2, -1),
297
+ target_params[line_mask, :6],
298
+ reduction="none",
299
+ ).mean(-1),
300
+ ).mean()
301
+ line_length_loss = F.l1_loss(
302
+ src_line_length[line_mask],
303
+ target_params[line_mask, 6].unsqueeze(-1),
304
+ )
305
+ losses["loss_line_position"] = line_position_l1
306
+ losses["loss_line_length"] = line_length_loss
307
+ line_chamfer = chamfer_distance_batch(
308
+ src_line_points[line_mask], target_curves[line_mask]
309
+ )
310
+ losses["loss_line_chamfer"] = line_chamfer.mean()
311
+ else:
312
+ losses["loss_line_position"] = torch.tensor(
313
+ 0.0, device=src_line_params.device
314
+ )
315
+ losses["loss_line_length"] = torch.tensor(
316
+ 0.0, device=src_line_length.device
317
+ )
318
+ losses["loss_line_chamfer"] = torch.tensor(
319
+ 0.0, device=src_line_points.device
320
+ )
321
+
322
+ # Compute loss for Circles
323
+ if circle_mask.any():
324
+ circle_position_l1 = torch.min(
325
+ F.l1_loss(
326
+ src_circle_params[circle_mask].flatten(-2, -1),
327
+ target_params[circle_mask, :6],
328
+ reduction="none",
329
+ ).mean(-1),
330
+ # also consider the negative direction
331
+ F.l1_loss(
332
+ (
333
+ src_circle_params[circle_mask]
334
+ * torch.tensor([1.0, -1.0])
335
+ .view(1, 2, 1)
336
+ .to(src_circle_params.device)
337
+ ).flatten(-2, -1),
338
+ target_params[circle_mask, :6],
339
+ reduction="none",
340
+ ).mean(-1),
341
+ ).mean()
342
+ radius_loss = F.l1_loss(
343
+ src_circle_radius[circle_mask],
344
+ target_params[circle_mask, 6].unsqueeze(-1),
345
+ )
346
+ losses["loss_circle_position"] = circle_position_l1
347
+ losses["loss_circle_radius"] = radius_loss
348
+ circle_chamfer = chamfer_distance_batch(
349
+ src_circle_points[circle_mask], target_curves[circle_mask]
350
+ )
351
+ losses["loss_circle_chamfer"] = circle_chamfer.mean()
352
+ else:
353
+ losses["loss_circle_position"] = torch.tensor(
354
+ 0.0, device=src_circle_params.device
355
+ )
356
+ losses["loss_circle_radius"] = torch.tensor(
357
+ 0.0, device=src_circle_radius.device
358
+ )
359
+ losses["loss_circle_chamfer"] = torch.tensor(
360
+ 0.0, device=src_circle_points.device
361
+ )
362
+
363
+ # Compute loss for Arcs
364
+ if arc_mask.any():
365
+ arc_order_l1 = torch.min(
366
+ F.l1_loss(
367
+ src_arc_params[arc_mask].flatten(-2, -1),
368
+ target_params[arc_mask, :9],
369
+ reduction="none",
370
+ ).mean(-1),
371
+ F.l1_loss(
372
+ src_arc_params[arc_mask][:, [0, 2, 1]].flatten(-2, -1),
373
+ target_params[arc_mask, :9],
374
+ reduction="none",
375
+ ).mean(-1),
376
+ ).mean()
377
+ losses["loss_arc"] = arc_order_l1
378
+ arc_chamfer = chamfer_distance_batch(
379
+ src_arc_points[arc_mask], target_curves[arc_mask]
380
+ )
381
+ losses["loss_arc_chamfer"] = arc_chamfer.mean()
382
+ else:
383
+ losses["loss_arc"] = torch.tensor(0.0, device=src_arc_params.device)
384
+ losses["loss_arc_chamfer"] = torch.tensor(0.0, device=src_arc_points.device)
385
+
386
+ losses["total_curve"] = (
387
+ losses["loss_bspline"]
388
+ + losses["loss_line_position"]
389
+ + losses["loss_line_length"]
390
+ + losses["loss_circle_position"]
391
+ + losses["loss_circle_radius"]
392
+ + losses["loss_line_chamfer"]
393
+ + losses["loss_circle_chamfer"]
394
+ + losses["loss_bspline_chamfer"]
395
+ + losses["loss_arc"]
396
+ + losses["loss_arc_chamfer"]
397
+ )
398
+
399
+ return losses
pi3detr/models/losses/matcher.py ADDED
@@ -0,0 +1,135 @@
1
+ import torch
2
+ from torch import nn
3
+ from scipy.optimize import linear_sum_assignment
4
+ from torch_geometric.data.data import Data
5
+
6
+
7
+ class ParametricMatcher(nn.Module):
8
+ def __init__(self, cost_class: int = 1, cost_curve: int = 1) -> None:
9
+ super().__init__()
10
+ self.cost_class = cost_class
11
+ self.cost_curve = cost_curve
12
+
13
+ @torch.no_grad()
14
+ def forward(
15
+ self, outputs: dict[str, torch.Tensor], data: Data
16
+ ) -> list[tuple[torch.Tensor, torch.Tensor]]:
17
+ """
18
+ Compute the matching indices based on class costs and Chamfer distance.
19
+ """
20
+ bs, num_queries = outputs["pred_class"].shape[:2]
21
+
22
+ # Compute the classification cost
23
+ out_prob = (
24
+ outputs["pred_class"].flatten(0, 1).softmax(-1)
25
+ ) # [batch_size * num_queries, num_classes]
26
+ cost_class = -out_prob[:, data.y_cls]
27
+
28
+ pred_bspline_params = outputs["pred_bspline_params"].flatten(0, 1)
29
+ pred_line_params = outputs["pred_line_params"].flatten(0, 1)
30
+ pred_line_length = outputs["pred_line_length"].flatten(0, 1)
31
+ pred_circle_params = outputs["pred_circle_params"].flatten(0, 1)
32
+ pred_circle_radius = outputs["pred_circle_radius"].flatten(0, 1)
33
+ pred_arc_params = outputs["pred_arc_params"].flatten(0, 1)
34
+
35
+ # classes -> 1: bspline, 2: line, 3: circle, 4: arc
36
+ # NOTE: scaling done assuming points are in [-1, 1] range
37
+ bspline_costs = torch.min(
38
+ torch.cdist(
39
+ pred_bspline_params.flatten(-2, -1),
40
+ data.bspline_params.flatten(-2, -1),
41
+ p=1,
42
+ ),
43
+ torch.cdist(
44
+ pred_bspline_params.flip([1]).flatten(-2, -1),
45
+ data.bspline_params.flatten(-2, -1),
46
+ p=1,
47
+ ),
48
+ ) # [batch_size * num_queries, num_curves]
49
+ line_costs = torch.min(
50
+ torch.cdist(
51
+ pred_line_params.flatten(-2, -1),
52
+ data.line_params.flatten(-2, -1),
53
+ p=1,
54
+ ),
55
+ torch.cdist(
56
+ (
57
+ pred_line_params
58
+ * torch.tensor([1.0, -1.0])
59
+ .view(1, 2, 1)
60
+ .to(pred_line_params.device)
61
+ ).flatten(-2, -1),
62
+ data.line_params.flatten(-2, -1),
63
+ p=1,
64
+ ),
65
+ ) + torch.cdist(
66
+ pred_line_length,
67
+ data.line_length.unsqueeze(-1),
68
+ p=1,
69
+ ) # [batch_size * num_queries, num_curves]
70
+ circle_costs = torch.min(
71
+ torch.cdist(
72
+ pred_circle_params.flatten(-2, -1),
73
+ data.circle_params.flatten(-2, -1),
74
+ p=1,
75
+ ),
76
+ torch.cdist(
77
+ (
78
+ pred_circle_params
79
+ * torch.tensor([1.0, -1.0])
80
+ .view(1, 2, 1)
81
+ .to(pred_circle_params.device)
82
+ ).flatten(-2, -1),
83
+ data.circle_params.flatten(-2, -1),
84
+ p=1,
85
+ ),
86
+ ) + torch.cdist(
87
+ pred_circle_radius,
88
+ data.circle_radius.unsqueeze(-1),
89
+ p=1,
90
+ ) # [batch_size * num_queries, num_curves]
91
+ arc_costs = torch.min(
92
+ torch.cdist(
93
+ pred_arc_params.flatten(-2, -1),
94
+ data.arc_params.flatten(-2, -1),
95
+ p=1,
96
+ ),
97
+ # mid, start, end | start and end can be swapped
98
+ torch.cdist(
99
+ pred_arc_params[:, [0, 2, 1], :].flatten(-2, -1),
100
+ data.arc_params.flatten(-2, -1),
101
+ ),
102
+ )
103
+
104
+ cost_params = torch.stack(
105
+ [
106
+ torch.zeros_like(line_costs),
107
+ bspline_costs,
108
+ line_costs,
109
+ circle_costs,
110
+ arc_costs,
111
+ ],
112
+ dim=-1,
113
+ )
114
+ cost_params = cost_params[
115
+ torch.arange(cost_params.size(0))[:, None],
116
+ torch.arange(cost_params.size(1)),
117
+ data.y_cls,
118
+ ] # [num_queries, num_curves]
119
+
120
+ # Combine costs
121
+ C = self.cost_class * cost_class + self.cost_curve * cost_params
122
+ C = C.view(bs, num_queries, -1).cpu()
123
+
124
+ # Perform Hungarian matching
125
+ indices = [
126
+ linear_sum_assignment(c[i])
127
+ for i, c in enumerate(C.split(data.num_curves.cpu().tolist(), -1))
128
+ ]
129
+ return [
130
+ (
131
+ torch.as_tensor(i, dtype=torch.int64),
132
+ torch.as_tensor(j, dtype=torch.int64),
133
+ )
134
+ for i, j in indices
135
+ ]
pi3detr/models/model_config.py ADDED
@@ -0,0 +1,66 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import Optional
3
+ from torch_geometric.nn import (
4
+ MLP,
5
+ )
6
+ from .pointnetpp import SAModule
7
+
8
+
9
+ @dataclass
10
+ class ModelConfig:
11
+ model: str
12
+ num_features: int
13
+ epochs: int = 1700
14
+ lr: float = 1e-4
15
+ lr_warmup_epochs: int = 15
16
+ lr_warmup_start_factor: float = 1e-6
17
+ lr_step: int = 1230
18
+ batch_size: int = 8
19
+ batch_size_val: int = 8
20
+ loss_weights: Optional[dict[str, float]] = None
21
+ num_curve_points: Optional[int] = 64
22
+ num_curve_points_val: Optional[int] = 256
23
+ preencoder_type: Optional[str] = "samodule"
24
+ preencoder_lr: Optional[float] = 1e-4
25
+ freeze_backbone: bool = False
26
+ encoder_dim: Optional[int] = 768
27
+ decoder_dim: Optional[int] = 768
28
+ num_encoder_layers: Optional[int] = 3
29
+ num_decoder_layers: Optional[int] = 9
30
+ encoder_dropout: float = 0.1
31
+ decoder_dropout: float = 0.1
32
+ num_attn_heads: Optional[int] = 8
33
+ enc_dim_feedforward: Optional[int] = 2048
34
+ dec_dim_feedforward: Optional[int] = 2048
35
+ mlp_dropout: float = 0.0
36
+ num_preds: Optional[int] = 128
37
+ num_classes: Optional[int] = 5
38
+ cost_weights: Optional[dict[str, float]] = None
39
+ auxiliary_loss: bool = True
40
+ max_points_in_param: Optional[int] = 4
41
+ num_transformer_points: Optional[int] = 2048
42
+ query_type: str = "point_fps"
43
+ pos_embed_type: str = "sine"
44
+ class_loss_type: str = "cross_entropy" # or "focal"
45
+ class_loss_weights: list[float] = field(
46
+ default_factory=lambda: [
47
+ 0.04834912,
48
+ 0.40329467,
49
+ 0.09588135,
50
+ 0.23071379,
51
+ 0.22176106,
52
+ ]
53
+ )
54
+
55
+ def get_preencoder(self):
56
+ preencoder_type = self.preencoder_type
57
+ preencoder = None
58
+ if preencoder_type == "samodule":
59
+ preencoder = SAModule(
60
+ MLP([self.num_features + 3, 64, 128, self.encoder_dim]),
61
+ num_out_points=self.num_transformer_points,
62
+ )
63
+ preencoder.out_channels = self.encoder_dim
64
+ else:
65
+ raise ValueError(f"Unknown preencoder type: {self.preencoder_type}.")
66
+ return preencoder
pi3detr/models/pi3detr.py ADDED
@@ -0,0 +1,593 @@
1
+ import torch
2
+ from torch import nn, Tensor
3
+ import pytorch_lightning as pl
4
+ from pytorch_lightning.utilities.types import OptimizerLRScheduler
5
+ from torch_geometric.data.data import Data
6
+ from torch_geometric.nn import MLP
7
+ from torch_geometric.utils import to_dense_batch
8
+ from .model_config import ModelConfig
9
+ from .losses import LossParams, ParametricLoss
10
+ from .transformer import Transformer
11
+ from .positional_embedding import PositionEmbeddingCoordsSine
12
+ from .query_engine import build_query_engine
13
+ from pi3detr.dataset import reverse_normalize_and_scale
14
+ from ..utils.curve_fitter import (
15
+ torch_bezier_curve,
16
+ torch_line_points,
17
+ generate_points_on_circle_torch,
18
+ torch_arc_points,
19
+ )
20
+ from ..utils.postprocessing import (
21
+ snap_and_fit_curves,
22
+ filter_predictions,
23
+ iou_filter_point_based,
24
+ iou_filter_predictions,
25
+ )
26
+
27
+ from pi3detr.evaluation.abc_metrics import (
28
+ ChamferMAP,
29
+ )
30
+ from torchmetrics.classification import (
31
+ BinaryJaccardIndex,
32
+ BinaryPrecision,
33
+ BinaryRecall,
34
+ )
35
+
36
+
37
+ class PI3DETR(pl.LightningModule):
38
+ def __init__(self, config: ModelConfig):
39
+ super().__init__()
40
+ self.config = config
41
+ self.pc_preencoder = config.get_preencoder()
42
+ self.enc_dim = config.encoder_dim
43
+ self.dec_dim = config.decoder_dim
44
+ self.num_preds = config.num_preds
45
+ self.num_curve_points = config.num_curve_points
46
+ self.num_curve_points_val = config.num_curve_points_val
47
+ self.num_classes = config.num_classes
48
+ self.max_points_in_param = config.max_points_in_param
49
+ self.preenc_to_enc_proj = MLP(
50
+ [self.pc_preencoder.out_channels, self.enc_dim, self.enc_dim],
51
+ act="relu",
52
+ norm="layer_norm",
53
+ )
54
+ num_decoder_layers = config.num_decoder_layers
55
+ self.transformer = Transformer(
56
+ self.enc_dim,
57
+ self.dec_dim,
58
+ nhead=config.num_attn_heads,
59
+ num_encoder_layers=config.num_encoder_layers,
60
+ num_decoder_layers=num_decoder_layers,
61
+ enc_dim_feedforward=config.enc_dim_feedforward,
62
+ dec_dim_feedforward=config.dec_dim_feedforward,
63
+ enc_dropout=config.encoder_dropout,
64
+ dec_dropout=config.decoder_dropout,
65
+ return_intermediate_dec=True,
66
+ )
67
+ self.positional_embedding = PositionEmbeddingCoordsSine(
68
+ d_pos=self.dec_dim, pos_type=self.config.pos_embed_type
69
+ )
70
+ self.pos_embed_proj = MLP(
71
+ [self.dec_dim, self.dec_dim, self.dec_dim], act="relu", norm="layer_norm"
72
+ )
73
+ self.query_type = config.query_type
74
+ self.query_engine = build_query_engine(
75
+ self.query_type,
76
+ self.positional_embedding,
77
+ self.dec_dim,
78
+ self.max_points_in_param,
79
+ self.num_preds,
80
+ )
81
+
82
+ def make_mlp(out_dim, layers=4, base_dim=None, bias_last=True):
83
+ base_dim = base_dim or self.dec_dim
84
+ n_layers = layers - 1
85
+ return MLP(
86
+ channel_list=[base_dim] * n_layers + [out_dim],
87
+ bias=[False] * (n_layers - 1) + [bias_last],
88
+ dropout=self.config.mlp_dropout,
89
+ act="relu",
90
+ norm="layer_norm",
91
+ )
92
+
93
+ self.class_head = make_mlp(self.num_classes)
94
+ self.bspline_param_head = make_mlp(4 * 3)
95
+ self.line_param_head = make_mlp(2 * 3)
96
+ self.line_length_head = make_mlp(1, layers=3)
97
+ self.circle_param_head = make_mlp(2 * 3)
98
+ self.circle_radius_head = make_mlp(1, layers=3)
99
+ self.arc_param_head = make_mlp(3 * 3)
100
+
101
+ self.loss = ParametricLoss(
102
+ LossParams(
103
+ num_classes=self.num_classes - 1, # -1 for the EOS token
104
+ cost_class=config.cost_weights["cost_class"],
105
+ cost_curve=config.cost_weights["cost_curve"],
106
+ class_loss_type=config.class_loss_type,
107
+ class_loss_weights=config.class_loss_weights,
108
+ )
109
+ )
110
+ self.auxiliary_loss = self.config.auxiliary_loss
111
+ self.weight_dict = {
112
+ "loss_class": config.loss_weights["loss_class"],
113
+ "loss_bspline": config.loss_weights["loss_bspline"],
114
+ "loss_bspline_chamfer": config.loss_weights["loss_bspline_chamfer"],
115
+ "loss_line_position": config.loss_weights["loss_line_position"],
116
+ "loss_line_length": config.loss_weights["loss_line_length"],
117
+ "loss_line_chamfer": config.loss_weights["loss_line_chamfer"],
118
+ "loss_circle_position": config.loss_weights["loss_circle_position"],
119
+ "loss_circle_radius": config.loss_weights["loss_circle_radius"],
120
+ "loss_circle_chamfer": config.loss_weights["loss_circle_chamfer"],
121
+ "loss_arc": config.loss_weights["loss_arc"],
122
+ "loss_arc_chamfer": config.loss_weights["loss_arc_chamfer"],
123
+ }
124
+ # TODO this is a hack
125
+ self.aux_weight_dict = {}
126
+ if self.auxiliary_loss:
127
+ for i in range(num_decoder_layers - 1):
128
+ self.aux_weight_dict.update(
129
+ {k + f"_{i}": v for k, v in self.weight_dict.items()}
130
+ )
131
+ self.weight_dict.update(self.aux_weight_dict)
132
+
133
+ self.chamfer_map = ChamferMAP(chamfer_thresh=0.05)
134
+
135
+ # Torchmetrics for segmentation
136
+ self.seg_iou = BinaryJaccardIndex()
137
+ self.seg_precision = BinaryPrecision()
138
+ self.seg_recall = BinaryRecall()
139
+
140
+ def forward(self, data: Data) -> dict[str, Tensor]:
141
+ x, pos, batch = self.pc_preencoder(data)[-1]
142
+ x = self.preenc_to_enc_proj(x)
143
+ x, mask = to_dense_batch(x, batch)
144
+ pos_dense_batch, _ = to_dense_batch(pos, batch)
145
+ pos_embed = self.pos_embed_proj(
146
+ self.positional_embedding(
147
+ pos_dense_batch, num_channels=self.dec_dim
148
+ ).permute(0, 2, 1)
149
+ ).permute(0, 2, 1)
150
+ query_xyz, query_embed = self.query_engine(Data(pos=pos, batch=batch))
151
+ x = self.transformer(
152
+ x, # [batch_size, num_points, enc_dim]
153
+ # transformer expects 1s to be masked
154
+ ~mask if not torch.all(mask) else None, # [batch_size, num_points]
155
+ query_embed, # [batch_size, dec_dim, num_queries]
156
+ pos_embed, # [batch_size, dec_dim, num_points]
157
+ )
158
+ output_class = self.class_head(x)
159
+ output_bspline_params = self.bspline_param_head(x)
160
+ output_line_params = self.line_param_head(x)
161
+ output_line_length = self.line_length_head(x)
162
+ output_circle_params = self.circle_param_head(x)
163
+ output_circle_radius = self.circle_radius_head(x)
164
+ output_arc_params = self.arc_param_head(x)
165
+
166
+ pred_bspline_params = (
167
+ output_bspline_params[-1].reshape(data.batch_size, self.num_preds, 4, 3)
168
+ + query_xyz
169
+ )
170
+ pred_line_params = output_line_params[-1].reshape(
171
+ data.batch_size, self.num_preds, 2, 3
172
+ )
173
+ pred_line_params[:, :, 0, :] = (
174
+ pred_line_params[:, :, 0, :] + query_xyz[:, :, 0, :]
175
+ )
176
+
177
+ pred_circle_params = output_circle_params[-1].reshape(
178
+ data.batch_size, self.num_preds, 2, 3
179
+ )
180
+ pred_circle_params[:, :, 0, :] = (
181
+ pred_circle_params[:, :, 0, :] + query_xyz[:, :, 0, :]
182
+ )
183
+
184
+ pred_arc_params = (
185
+ output_arc_params[-1].reshape(data.batch_size, self.num_preds, 3, 3)
186
+ + query_xyz[:, :, :3, :]
187
+ )
188
+
189
+ out = {
190
+ "pred_class": output_class[-1],
191
+ "pred_bspline_params": pred_bspline_params,
192
+ "pred_line_params": pred_line_params,
193
+ "pred_line_length": output_line_length[-1],
194
+ "pred_circle_params": pred_circle_params,
195
+ "pred_circle_radius": output_circle_radius[-1],
196
+ "pred_arc_params": pred_arc_params,
197
+ "query_xyz": query_xyz,
198
+ }
199
+ if self.auxiliary_loss and self.training:
200
+ out["aux_outputs"] = self._set_aux_loss(
201
+ output_bspline_params,
202
+ output_line_params,
203
+ output_line_length,
204
+ output_circle_params,
205
+ output_circle_radius,
206
+ output_arc_params,
207
+ query_xyz,
208
+ output_class,
209
+ )
210
+ return out
211
+
212
+ @torch.jit.unused
213
+ def _set_aux_loss(
214
+ self,
215
+ output_bspline_params: torch.Tensor,
216
+ output_line_params: torch.Tensor,
217
+ output_line_length: torch.Tensor,
218
+ output_circle_params: torch.Tensor,
219
+ output_circle_radius: torch.Tensor,
220
+ output_arc_params: torch.Tensor,
221
+ query_xyz: torch.Tensor,
222
+ output_class: torch.Tensor,
223
+ ) -> list[dict[str, torch.Tensor]]:
224
+ # this is a workaround to make torchscript happy, as torchscript
225
+ # doesn't support dictionary with non-homogeneous values, such
226
+ # as a dict having both a Tensor and a list.
227
+ out_aux = []
228
+ for b, l, ll, c, cr, a, cl in zip(
229
+ output_bspline_params[:-1],
230
+ output_line_params[:-1],
231
+ output_line_length[:-1],
232
+ output_circle_params[:-1],
233
+ output_circle_radius[:-1],
234
+ output_arc_params[:-1],
235
+ output_class[:-1],
236
+ ):
237
+ pred_bspline_params = b.reshape(*b.shape[:2], 4, 3) + query_xyz
238
+ # second point is the direction vector
239
+ pred_line_params = l.reshape(*l.shape[:2], 2, 3)
240
+ pred_line_params_adjusted = pred_line_params.clone()
241
+ pred_line_params_adjusted[:, :, 0, :] = (
242
+ pred_line_params[:, :, 0, :] + query_xyz[:, :, 0, :]
243
+ )
244
+ pred_circle_params = c.reshape(*c.shape[:2], 2, 3)
245
+ pred_circle_params_adjusted = pred_circle_params.clone()
246
+ pred_circle_params_adjusted[:, :, 0, :] = (
247
+ pred_circle_params[:, :, 0, :] + query_xyz[:, :, 0, :]
248
+ )
249
+ pred_arc_params = a.reshape(*a.shape[:2], 3, 3) + query_xyz[:, :, :3, :]
250
+
251
+ layer_out = {
252
+ "pred_bspline_params": pred_bspline_params,
253
+ "pred_line_params": pred_line_params_adjusted,
254
+ "pred_line_length": ll,
255
+ "pred_circle_params": pred_circle_params_adjusted,
256
+ "pred_circle_radius": cr,
257
+ "pred_arc_params": pred_arc_params,
258
+ "pred_class": cl,
259
+ }
260
+ layer_out.update(
261
+ self._sample_curve_points(layer_out, self.num_curve_points)
262
+ )
263
+ out_aux.append(layer_out)
264
+
265
+ return out_aux
266
+
267
+ def _sample_curve_points(
268
+ self, out: dict[str, Tensor], num_points: int
269
+ ) -> dict[str, Tensor]:
270
+ batch_size, num_preds = out["pred_bspline_params"].shape[:2]
271
+ pred_line_params = out["pred_line_params"]
272
+ pred_line_length = out["pred_line_length"]
273
+ pred_line_start = (
274
+ pred_line_params[:, :, 0, :]
275
+ - pred_line_params[:, :, 1, :] * pred_line_length / 2.0
276
+ )
277
+ pred_line_end = (
278
+ pred_line_params[:, :, 0, :]
279
+ + pred_line_params[:, :, 1, :] * pred_line_length / 2.0
280
+ )
281
+ curves = {}
282
+ curves["pred_bspline_points"] = torch_bezier_curve(
283
+ out["pred_bspline_params"].reshape(-1, 4, 3), num_points
284
+ ).reshape(batch_size, num_preds, -1, 3)
285
+ curves["pred_line_points"] = torch_line_points(
286
+ pred_line_start.reshape(-1, 3),
287
+ pred_line_end.reshape(-1, 3),
288
+ num_points,
289
+ ).reshape(batch_size, num_preds, -1, 3)
290
+ curves["pred_circle_points"] = generate_points_on_circle_torch(
291
+ out["pred_circle_params"].reshape(-1, 2, 3)[:, 0],
292
+ out["pred_circle_params"].reshape(-1, 2, 3)[:, 1],
293
+ out["pred_circle_radius"].reshape(-1),
294
+ num_points,
295
+ ).reshape(batch_size, num_preds, -1, 3)
296
+ curves["pred_arc_points"] = torch_arc_points(
297
+ out["pred_arc_params"][:, :, 1, :].reshape(-1, 3),
298
+ out["pred_arc_params"][:, :, 0, :].reshape(-1, 3),
299
+ out["pred_arc_params"][:, :, 2, :].reshape(-1, 3),
300
+ num_points,
301
+ ).reshape(batch_size, num_preds, -1, 3)
302
+ return curves
303
+
304
+ def predict_step(
305
+ self,
306
+ batch: Data,
307
+ reverse_norm: bool = True,
308
+ thresholds: list[float] = None,
309
+ snap_and_fit: bool = True,
310
+ iou_filter: bool = False,
311
+ ) -> list[Data]:
312
+ preds = self(batch)
313
+ preds.update(self._sample_curve_points(preds, self.num_curve_points_val))
314
+
315
+ outputs = self.decode_predictions(batch, preds, reverse_norm)
316
+
317
+ if thresholds:
318
+ outputs = [filter_predictions(data, thresholds) for data in outputs]
319
+
320
+ if snap_and_fit:
321
+ outputs = [snap_and_fit_curves(data.clone()) for data in outputs]
322
+
323
+ if iou_filter:
324
+ outputs = [iou_filter_predictions(data) for data in outputs]
325
+
326
+ return outputs
327
+
328
+ def training_step(self, batch: Data, batch_idx: int) -> Tensor:
329
+ outputs = self(batch)
330
+ outputs.update(self._sample_curve_points(outputs, self.num_curve_points))
331
+ loss_dict = self.loss(outputs, batch)
332
+ for k, v in loss_dict.items():
333
+ weight = self.weight_dict[k] if "loss" in k else 1
334
+ self._default_log(k, v * weight)
335
+ # weigh losses and sum them for backpropagation
336
+ weighted_loss_dict = {
337
+ k: loss_dict[k] * self.weight_dict[k] for k in self.weight_dict.keys()
338
+ }
339
+ total_loss = sum(weighted_loss_dict.values())
340
+ self._default_log("loss_train", total_loss)
341
+ return total_loss
342
+
343
+ @torch.no_grad()
344
+ def validation_step(self, batch: Data, batch_idx: int) -> None:
345
+ outputs = self(batch)
346
+ # sample the training curve points for loss computation
347
+ outputs.update(self._sample_curve_points(outputs, self.num_curve_points))
348
+ loss_dict = self.loss(outputs, batch)
349
+ for k, v in loss_dict.items():
350
+ weight = self.weight_dict[k] if "loss" in k else 1
351
+ self._default_log(f"val_{k}", v * weight)
352
+ without_aux = {
353
+ k for k in self.weight_dict.keys() if k not in self.aux_weight_dict
354
+ }
355
+ self._default_log(
356
+ "loss_val", sum(loss_dict[k] * self.weight_dict[k] for k in without_aux)
357
+ )
358
+ # Sample curve points for validation
359
+ outputs.update(self._sample_curve_points(outputs, self.num_curve_points_val))
360
+ self._compute_metrics(batch, outputs)
361
+
362
+ @torch.no_grad()
363
+ def on_validation_epoch_end(self) -> None:
364
+ metrics = self.chamfer_map.compute()
365
+ for key, value in metrics.items():
366
+ self._default_log(f"val_{key}", value)
367
+ self.chamfer_map.reset()
368
+
369
+ # Log segmentation metrics at epoch end
370
+ self._default_log("val_seg_iou", self.seg_iou.compute())
371
+ self._default_log(
372
+ "val_seg_precision",
373
+ self.seg_precision.compute(),
374
+ )
375
+ self._default_log("val_seg_recall", self.seg_recall.compute())
376
+ self.seg_iou.reset()
377
+ self.seg_precision.reset()
378
+ self.seg_recall.reset()
379
+
380
+ def test_step(self, batch: Data, batch_idx: int) -> None:
381
+ outputs = self(batch)
382
+ outputs.update(self._sample_curve_points(outputs, self.num_curve_points_val))
383
+ self._compute_metrics(batch, outputs)
384
+
385
+ def on_test_epoch_end(self) -> None:
386
+ metrics = self.chamfer_map.compute()
387
+ self.chamfer_map.reset()
388
+ for key, value in metrics.items():
389
+ self.log(f"test_{key}", value, prog_bar=False)
390
+
391
+ # Log segmentation metrics at epoch end
392
+ self.log("test_seg_iou", self.seg_iou.compute(), prog_bar=False)
393
+ self.log("test_seg_precision", self.seg_precision.compute(), prog_bar=False)
394
+ self.log("test_seg_recall", self.seg_recall.compute(), prog_bar=False)
395
+ self.seg_iou.reset()
396
+ self.seg_precision.reset()
397
+ self.seg_recall.reset()
398
+
399
+ def _compute_metrics(self, batch: Data, preds: dict):
400
+ # segmentation metrics
401
+ outputs = self.decode_predictions(batch, preds, reverse_norm=True)
402
+ for i, output in enumerate(outputs):
403
+ self.seg_iou.update(output.segmentation, output.y_seg)
404
+ self.seg_precision.update(output.segmentation, output.y_seg)
405
+ self.seg_recall.update(output.segmentation, output.y_seg)
406
+ # chamfer metrics
407
+ self.chamfer_map.update(preds, batch)
408
+
409
+ def set_num_preds(self, num_preds: int) -> None:
410
+ if num_preds == self.num_preds:
411
+ return
412
+ self.num_preds = num_preds
413
+ old_state = (
414
+ self.query_engine.state_dict()
415
+ if isinstance(self.query_engine, nn.Module)
416
+ else None
417
+ )
418
+ new_engine = build_query_engine(
419
+ self.query_type,
420
+ self.positional_embedding,
421
+ self.dec_dim,
422
+ self.max_points_in_param,
423
+ self.num_preds,
424
+ )
425
+ if old_state is not None:
426
+ new_state = new_engine.state_dict()
427
+ for k, v in old_state.items():
428
+ assert k in new_state, f"Missing parameter in new query engine: {k}"
429
+ nv = new_state[k]
430
+ assert (
431
+ v.shape == nv.shape
432
+ ), f"Shape mismatch for {k}: {v.shape} != {nv.shape}"
433
+ nv.copy_(v.to(nv.device))
434
+ new_engine.load_state_dict(new_state, strict=True)
435
+ self.query_engine = new_engine.to(self.device)
436
+
437
+ @torch.no_grad()
438
+ def decode_predictions(
439
+ self, batch: Data, preds: Data, reverse_norm: bool = True
440
+ ) -> list[Data]:
441
+ outputs = []
442
+
443
+ # Vectorized class prediction and score
444
+ preds_class = preds["pred_class"].softmax(-1)
445
+ polyline_class = preds_class.argmax(-1) # (batch_size, num_preds)
446
+ polyline_score = preds_class.max(-1).values # (batch_size, num_preds)
447
+
448
+ # Prepare all possible polylines: (batch_size, num_preds, num_polypoints, 3)
449
+ bspline_points = preds["pred_bspline_points"]
450
+ line_points = preds["pred_line_points"]
451
+ circle_points = preds["pred_circle_points"]
452
+ arc_points = preds["pred_arc_points"]
453
+ zeros_points = torch.zeros_like(bspline_points) # EOS/empty
454
+
455
+ # Stack all types: (batch_size, num_preds, 4, num_polypoints, 3)
456
+ all_polylines = torch.stack(
457
+ [zeros_points, bspline_points, line_points, circle_points, arc_points],
458
+ dim=2,
459
+ )
460
+
461
+ # Gather correct polyline for each prediction
462
+ # polyline_class: (batch_size, num_preds)
463
+ # Need to expand to match all_polylines shape for gather
464
+ idx = (
465
+ polyline_class.unsqueeze(-1)
466
+ .unsqueeze(-1)
467
+ .unsqueeze(2) # shape: (batch_size, num_preds, 1, num_polypoints, 3)
468
+ .expand(-1, -1, 1, self.num_curve_points_val, 3)
469
+ )
470
+ polylines = torch.gather(all_polylines, 2, idx).squeeze(
471
+ 2
472
+ ) # (batch_size, num_preds, num_polypoints, 3)
473
+
474
+ batch_size = batch.batch_size
475
+ device = batch.pos.device
476
+ segmentations = []
477
+ for i in range(batch_size):
478
+ # If all predicted classes are zero (EOS), segmentation should be all zeros
479
+ if torch.all(polyline_class[i] == 0):
480
+ pc_pts = batch.pos[batch.batch == i]
481
+ segmentation = torch.zeros(
482
+ pc_pts.shape[0], dtype=torch.long, device=device
483
+ )
484
+ segmentations.append(segmentation)
485
+ continue
486
+
487
+ poly_pts = polylines[i, polyline_class[i] != 0].reshape(-1, 3)
488
+ pc_pts = batch.pos[batch.batch == i] # (num_points_in_cloud, 3)
489
+ dists = torch.cdist(poly_pts, pc_pts)
490
+ closest_idx = dists.argmin(dim=1)
491
+ segmentation = torch.zeros(pc_pts.shape[0], dtype=torch.long, device=device)
492
+ segmentation[closest_idx.unique()] = 1
493
+ segmentations.append(segmentation)
494
+
495
+ for i in range(batch.batch_size):
496
+ output = Data(
497
+ pos=batch.pos[batch.batch == i].clone(), # point cloud
498
+ bspline_points=bspline_points[i], # prediction of B-spline head
499
+ line_points=line_points[i], # prediction of line heads
500
+ circle_points=circle_points[i], # prediction of circle heads
501
+ arc_points=arc_points[i], # prediction of arc head
502
+ polyline_class=polyline_class[i], # class of each polyline
503
+ polyline_score=polyline_score[i], # score of polyline class
504
+ polylines=polylines[i], # polyline that matches polyline_class
505
+ segmentation=segmentations[
506
+ i
507
+ ], # curve segmentation for whole point cloud
508
+ query_xyz=preds["query_xyz"][i], # query points for the transformer
509
+ )
510
+ if hasattr(batch, "y_seg"):
511
+ output.y_seg = batch.y_seg[batch.batch == i]
512
+
513
+ if reverse_norm:
514
+ output.center = batch.center[i]
515
+ output.scale = batch.scale[i]
516
+
517
+ output = reverse_normalize_and_scale(
518
+ output,
519
+ extra_fields=[
520
+ "polylines",
521
+ "bspline_points",
522
+ "line_points",
523
+ "circle_points",
524
+ "arc_points",
525
+ "query_xyz",
526
+ ],
527
+ )
528
+ outputs.append(output)
529
+
530
+ return outputs
531
+
532
+ def _default_log(self, name: str, value: Tensor) -> None:
533
+ batch_size = (
534
+ self.config.batch_size if self.training else self.config.batch_size_val
535
+ )
536
+ self.log(
537
+ name,
538
+ value,
539
+ prog_bar=True,
540
+ on_epoch=True,
541
+ on_step=False,
542
+ sync_dist=True,
543
+ batch_size=batch_size,
544
+ )
545
+
546
+ def configure_optimizers(self) -> OptimizerLRScheduler:
547
+ param_dict = None
548
+ config = self.config
549
+ if config.freeze_backbone:
550
+ for param in self.pc_preencoder.parameters():
551
+ param.requires_grad = False
552
+ param_dict = self.parameters()
553
+ elif config.lr != config.preencoder_lr:
554
+ param_dict = [
555
+ {
556
+ "params": [
557
+ p for n, p in self.named_parameters() if "pc_encoder" not in n
558
+ ]
559
+ },
560
+ {
561
+ "params": [
562
+ p for n, p in self.named_parameters() if "pc_encoder" in n
563
+ ],
564
+ "lr": self.config.preencoder_lr,
565
+ },
566
+ ]
567
+ else:
568
+ param_dict = self.parameters()
569
+ # ----- OPTIMIZER -----
570
+ optimizer = torch.optim.AdamW(param_dict, lr=self.config.lr)
571
+ # ----- WARMUP SCHEDULER -----
572
+ warmup_scheduler = torch.optim.lr_scheduler.LinearLR(
573
+ optimizer,
574
+ start_factor=config.lr_warmup_start_factor, # start near zero
575
+ end_factor=1.0, # ramp up to base LR
576
+ total_iters=config.lr_warmup_epochs,
577
+ )
578
+ # ----- STEP SCHEDULER -----
579
+ # Drop LR by factor after (step_epoch - warmup_epochs) epochs
580
+ step_scheduler = torch.optim.lr_scheduler.StepLR(
581
+ optimizer,
582
+ step_size=config.lr_step - config.lr_warmup_epochs,
583
+ gamma=0.1, # drop LR to 10%
584
+ last_epoch=config.epochs,
585
+ )
586
+ # ----- COMBINE -----
587
+ scheduler = torch.optim.lr_scheduler.SequentialLR(
588
+ optimizer,
589
+ schedulers=[warmup_scheduler, step_scheduler],
590
+ milestones=[config.lr_warmup_epochs],
591
+ )
592
+
593
+ return [optimizer], {"scheduler": scheduler, "interval": "epoch"}
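Note on the schedule above: configure_optimizers ramps the learning rate up with LinearLR during warmup and then hands over to StepLR via SequentialLR. A minimal, standalone sketch of that combination (learning rate, warmup length and step epoch are placeholder values, not the project's config):

import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=1e-4)
warmup_epochs, lr_step = 10, 400
warmup = torch.optim.lr_scheduler.LinearLR(
    optimizer, start_factor=0.01, end_factor=1.0, total_iters=warmup_epochs
)
step = torch.optim.lr_scheduler.StepLR(
    optimizer, step_size=lr_step - warmup_epochs, gamma=0.1
)
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer, schedulers=[warmup, step], milestones=[warmup_epochs]
)
for epoch in range(20):
    optimizer.step()
    scheduler.step()  # LR ramps up for 10 epochs, then follows the step schedule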
pi3detr/models/pointnetpp.py ADDED
@@ -0,0 +1,324 @@
1
+ import torch
2
+ from torch import nn, Tensor
3
+ from typing import Optional, Tuple
4
+ import numpy as np
5
+ from scipy.spatial import cKDTree
6
+ from torch_geometric.nn import (
7
+ MLP,
8
+ PointNetConv,
9
+ fps,
10
+ knn_interpolate,
11
+ radius,
12
+ global_max_pool,
13
+ )
14
+ from torch_geometric.data.data import BaseData
15
+
16
+ TensorTriple = Tuple[Tensor, Tensor, Tensor]
17
+
18
+
19
+ def radius_cpu(
20
+ x: torch.Tensor,
21
+ y: torch.Tensor,
22
+ r: float,
23
+ batch_x: Optional[torch.Tensor] = None,
24
+ batch_y: Optional[torch.Tensor] = None,
25
+ max_num_neighbors: Optional[int] = None,
26
+ loop: bool = False,
27
+ sort_by_distance: bool = True,
28
+ ) -> Tuple[torch.LongTensor, torch.LongTensor]:
29
+ """
30
+ CPU replacement for torch_cluster.radius / torch_geometric.radius.
31
+
32
+ Semantics (matching torch_geometric.radius):
33
+ Returns (row, col) where
34
+ row: indices into `y` (centers) in the range [0, y.size(0))
35
+ col: indices into `x` (neighbors) in the range [0, x.size(0))
36
+
37
+ Thus, for y = x[idx]:
38
+ edge_index = torch.stack([col, row], dim=0)
39
+ edge_index[0] indexes the full set (source/neighbor),
40
+ edge_index[1] indexes the sampled centers.
41
+ """
42
+ # Basic checks
43
+ if x.device.type != "cpu" or y.device.type != "cpu":
44
+ raise ValueError("radius_cpu expects x and y to be on CPU.")
45
+ if x.ndim != 2 or y.ndim != 2:
46
+ raise ValueError("x and y must be 2D (N, D).")
47
+ if x.shape[1] != y.shape[1]:
48
+ raise ValueError("x and y must have same dimensionality D.")
49
+
50
+ N_x = x.shape[0]
51
+ N_y = y.shape[0]
52
+ if N_x == 0 or N_y == 0:
53
+ return torch.empty((0,), dtype=torch.long), torch.empty((0,), dtype=torch.long)
54
+
55
+ x_np = np.asarray(x)
56
+ y_np = np.asarray(y)
57
+
58
+ if batch_x is None:
59
+ batch_x = torch.zeros(N_x, dtype=torch.long)
60
+ else:
61
+ if batch_x.device.type != "cpu":
62
+ batch_x = batch_x.cpu()
63
+ batch_x = batch_x.long()
64
+
65
+ if batch_y is None:
66
+ batch_y = torch.zeros(N_y, dtype=torch.long)
67
+ else:
68
+ if batch_y.device.type != "cpu":
69
+ batch_y = batch_y.cpu()
70
+ batch_y = batch_y.long()
71
+
72
+ rows = []
73
+ cols = []
74
+
75
+ unique_batches = torch.unique(torch.cat([batch_x, batch_y])).tolist()
76
+ # iterate only over batches actually present in y to avoid unnecessary work
77
+ unique_batches = sorted(set(batch_y.tolist()))
78
+
79
+ for b in unique_batches:
80
+ # mask and maps from local->global indices
81
+ mask_x = (batch_x == b).numpy()
82
+ mask_y = (batch_y == b).numpy()
83
+ idxs_x = np.nonzero(mask_x)[0] # global indices in x
84
+ idxs_y = np.nonzero(mask_y)[0] # global indices in y
85
+
86
+ if idxs_y.size == 0 or idxs_x.size == 0:
87
+ continue
88
+
89
+ pts_x = x_np[mask_x]
90
+ pts_y = y_np[mask_y]
91
+
92
+ # build tree on source points (x) and query for each center in y
93
+ tree = cKDTree(pts_x)
94
+ # neighbors_list: for each center (local), a list of local indices into pts_x
95
+ neighbors_list = tree.query_ball_point(pts_y, r)
96
+
97
+ for local_center, neigh_locals in enumerate(neighbors_list):
98
+ if len(neigh_locals) == 0:
99
+ continue
100
+ neigh_locals = np.array(neigh_locals, dtype=int)
101
+
102
+ # remove self if requested AND x and y are the same set at same global indices
103
+ if not loop:
104
+ # If x and y refer to the same global indices and same coords, remove self-match
105
+ # we detect self by checking whether global index equals center global index
106
+ center_global = idxs_y[local_center]
107
+ # compute global neighbor indices
108
+ neigh_globals = idxs_x[neigh_locals]
109
+ # boolean mask for neighbors that are not self
110
+ not_self_mask = neigh_globals != center_global
111
+ neigh_locals = neigh_locals[not_self_mask]
112
+
113
+ if neigh_locals.size == 0:
114
+ continue
115
+
116
+ # apply max_num_neighbors: keep closest ones by distance if requested
117
+ if max_num_neighbors is not None and neigh_locals.size > max_num_neighbors:
118
+ if sort_by_distance:
119
+ dists = np.linalg.norm(
120
+ pts_x[neigh_locals] - pts_y[local_center], axis=1
121
+ )
122
+ order = np.argsort(dists)[:max_num_neighbors]
123
+ neigh_locals = neigh_locals[order]
124
+ else:
125
+ neigh_locals = np.sort(neigh_locals)[:max_num_neighbors]
126
+
127
+ # optionally sort by distance
128
+ if sort_by_distance and neigh_locals.size > 0:
129
+ dists = np.linalg.norm(
130
+ pts_x[neigh_locals] - pts_y[local_center], axis=1
131
+ )
132
+ order = np.argsort(dists)
133
+ neigh_locals = neigh_locals[order]
134
+
135
+ # convert to global indices and append
136
+ neigh_globals = idxs_x[neigh_locals].tolist()
137
+ center_global = int(idxs_y[local_center])
138
+ rows.extend(neigh_globals) # neighbor indices into x (row)
139
+ cols.extend(
140
+ [center_global] * len(neigh_globals)
141
+ ) # center indices into y (col)
142
+
143
+ if len(rows) == 0:
144
+ return torch.empty((0,), dtype=torch.long), torch.empty((0,), dtype=torch.long)
145
+
146
+ row_t = torch.tensor(rows, dtype=torch.long) # currently neighbors (x)
147
+ col_t = torch.tensor(cols, dtype=torch.long) # currently centers (y)
148
+
149
+ # Swap to enforce (row=center_indices_in_y, col=neighbor_indices_in_x)
150
+ return col_t, row_t
151
+
152
+
153
+ class SAModuleRatio(torch.nn.Module):
154
+ def __init__(
155
+ self, ratio: float, r: float, nn: nn.Module, max_num_neighbors: int = 64
156
+ ):
157
+ super().__init__()
158
+ self.ratio = ratio
159
+ self.r = r
160
+ self.conv = PointNetConv(nn, add_self_loops=False)
161
+ self.max_num_neighbors = max_num_neighbors
162
+
163
+ def forward(self, x: torch.Tensor, pos: torch.Tensor, batch: torch.Tensor):
164
+ idx = fps(pos, batch, ratio=self.ratio)
165
+ row, col = radius(
166
+ pos,
167
+ pos[idx],
168
+ self.r,
169
+ batch,
170
+ batch[idx],
171
+ max_num_neighbors=self.max_num_neighbors,
172
+ )
173
+ edge_index = torch.stack([col, row], dim=0)
174
+ x_dst = None if x is None else x[idx]
175
+ x = self.conv((x, x_dst), (pos, pos[idx]), edge_index)
176
+ pos, batch = pos[idx], batch[idx]
177
+ return x, pos, batch
178
+
179
+
180
+ class SAModule(torch.nn.Module):
181
+ def __init__(
182
+ self,
183
+ nn: nn.Module,
184
+ num_out_points: float = 2048,
185
+ r: float = 0.2,
186
+ max_num_neighbors: int = 64,
187
+ ):
188
+ super().__init__()
189
+ self.num_out_points = num_out_points
190
+ self.r = r
191
+ self.conv = PointNetConv(nn, add_self_loops=False)
192
+ self.max_num_neighbors = max_num_neighbors
193
+
194
+ def forward(self, data: BaseData) -> list[tuple[TensorTriple]]:
195
+ x, pos, batch = data.x, data.pos, data.batch
196
+ num_points_per_batch = torch.bincount(batch)
197
+ max_ratio = self.num_out_points / num_points_per_batch.min().item()
198
+ fps_idx = fps(pos, batch, ratio=max_ratio)
199
+ fps_batch = batch[fps_idx]
200
+ idx = torch.cat(
201
+ [
202
+ fps_idx[fps_batch == i][: self.num_out_points]
203
+ for i in range(batch.max().item() + 1)
204
+ ]
205
+ )
206
+ if pos.device == torch.device("cpu"):
207
+ row, col = radius_cpu(
208
+ pos,
209
+ pos[idx],
210
+ self.r,
211
+ batch,
212
+ batch[idx],
213
+ max_num_neighbors=self.max_num_neighbors,
214
+ sort_by_distance=False,
215
+ )
216
+ else: # GPU
217
+ row, col = radius(
218
+ pos,
219
+ pos[idx],
220
+ self.r,
221
+ batch,
222
+ batch[idx],
223
+ max_num_neighbors=self.max_num_neighbors,
224
+ )
225
+ edge_index = torch.stack([col, row], dim=0)
226
+ x_dst = None if x is None else x[idx]
227
+ x = self.conv((x, x_dst), (pos, pos[idx]), edge_index)
228
+ pos, batch = pos[idx], batch[idx]
229
+ return [(x, pos, batch)]
230
+
231
+
232
+ class GlobalSAModule(torch.nn.Module):
233
+ def __init__(self, nn):
234
+ super().__init__()
235
+ self.nn = nn
236
+
237
+ def forward(self, x: torch.Tensor, pos: torch.Tensor, batch: torch.Tensor):
238
+ x = self.nn(torch.cat([x, pos], dim=1))
239
+ x = global_max_pool(x, batch)
240
+ pos = pos.new_zeros((x.size(0), 3))
241
+ batch = torch.arange(x.size(0), device=batch.device)
242
+ return x, pos, batch
243
+
244
+
245
+ class FPModule(nn.Module):
246
+ def __init__(self, k, nn):
247
+ super().__init__()
248
+ self.k = k
249
+ self.nn = nn
250
+
251
+ def forward(
252
+ self,
253
+ x: torch.Tensor,
254
+ pos: torch.Tensor,
255
+ batch: torch.Tensor,
256
+ x_skip: torch.Tensor,
257
+ pos_skip: torch.Tensor,
258
+ batch_skip: torch.Tensor,
259
+ ):
260
+ x = knn_interpolate(x, pos, pos_skip, batch, batch_skip, k=self.k)
261
+ if x_skip is not None:
262
+ x = torch.cat([x, x_skip], dim=1)
263
+ x = self.nn(x)
264
+ return x, pos_skip, batch_skip
265
+
266
+
267
+ class PointNetPPEncoder(nn.Module):
268
+ def __init__(self, num_features: int = 3, out_channels: int = 512):
269
+ super().__init__()
270
+ self.out_channels = out_channels
271
+ # Input channels account for both `pos` and node features.
272
+ self.sa1_module = SAModuleRatio(
273
+ 0.5, 0.05, MLP([num_features + 3, 32, 32, 64]), 32
274
+ )
275
+ self.sa2_module = SAModuleRatio(0.5, 0.1, MLP([64 + 3, 64, 64, 128]), 32)
276
+ self.sa3_module = SAModuleRatio(0.5, 0.2, MLP([128 + 3, 128, 128, 256]), 32)
277
+ self.sa4_module = SAModuleRatio(
278
+ 0.5, 0.4, MLP([256 + 3, 256, 256, self.out_channels]), 32
279
+ )
280
+
281
+ def forward(self, data: BaseData) -> list[Tensor]:
282
+ sa0_out = (data.x, data.pos, data.batch)
283
+ sa1_out = self.sa1_module(*sa0_out)
284
+ sa2_out = self.sa2_module(*sa1_out)
285
+ sa3_out = self.sa3_module(*sa2_out)
286
+ sa4_out = self.sa4_module(*sa3_out)
287
+ return [sa0_out, sa1_out, sa2_out, sa3_out, sa4_out]
288
+
289
+
290
+ class PointNetPPDecoder(nn.Module):
291
+ def __init__(self, num_features: int = 3, out_channels: int = 256):
292
+ super().__init__()
293
+ self.out_channels = out_channels
294
+ self.fp4_module = FPModule(1, MLP([512 + 256, 256, 256]))
295
+ self.fp3_module = FPModule(3, MLP([256 + 128, 256, 256]))
296
+ self.fp2_module = FPModule(3, MLP([256 + 64, 256, 128]))
297
+ self.fp1_module = FPModule(3, MLP([128 + num_features, 128, self.out_channels]))
298
+
299
+ def forward(
300
+ self,
301
+ sa0_out: TensorTriple,
302
+ sa1_out: TensorTriple,
303
+ sa2_out: TensorTriple,
304
+ sa3_out: TensorTriple,
305
+ sa4_out: TensorTriple,
306
+ ) -> TensorTriple:
307
+ fp4_out = self.fp4_module(*sa4_out, *sa3_out)
308
+ fp3_out = self.fp3_module(*fp4_out, *sa2_out)
309
+ fp2_out = self.fp2_module(*fp3_out, *sa1_out)
310
+ x, pos, batch = self.fp1_module(*fp2_out, *sa0_out)
311
+ return [x, pos, batch]
312
+
313
+
314
+ class PointNetPP(nn.Module):
315
+ def __init__(self, num_features: int, dec_out_channels: int = 256):
316
+ super().__init__()
317
+ self.encoder = PointNetPPEncoder(num_features)
318
+ self.decoder = PointNetPPDecoder(num_features, dec_out_channels)
319
+ self.out_channels = self.decoder.out_channels
320
+
321
+ def forward(self, data: BaseData) -> TensorTriple:
322
+ x = self.encoder(data)
323
+ x, pos, batch = self.decoder(*x)
324
+ return [(x, pos, batch)]
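A minimal forward-pass sketch for the PointNet++ encoder/decoder defined above. It assumes the package layout of this commit (pi3detr.models.pointnetpp) and a torch_geometric install with torch_cluster (for fps/radius/knn_interpolate); sizes are illustrative only:

import torch
from torch_geometric.data import Data
from pi3detr.models.pointnetpp import PointNetPP  # assumed import path

model = PointNetPP(num_features=3, dec_out_channels=256)
pos = torch.rand(2048, 3)
data = Data(x=pos.clone(), pos=pos, batch=torch.zeros(2048, dtype=torch.long))
x, out_pos, out_batch = model(data)[0]
print(x.shape)  # torch.Size([2048, 256]) -- one feature vector per input point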
pi3detr/models/positional_embedding.py ADDED
@@ -0,0 +1,133 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+
3
+
4
+
5
+
6
+ """
7
+ Various positional encodings for the transformer.
8
+ """
9
+ import math
10
+ import torch
11
+ from torch import nn
12
+ import numpy as np
13
+
14
+ class PositionEmbeddingCoordsSine(nn.Module):
15
+ def __init__(
16
+ self,
17
+ temperature=10000,
18
+ normalize=False,
19
+ scale=None,
20
+ pos_type="fourier",
21
+ d_pos=None,
22
+ d_in=3,
23
+ gauss_scale=1.0,
24
+ ):
25
+ super().__init__()
26
+ self.temperature = temperature
27
+ self.normalize = normalize
28
+ if scale is not None and normalize is False:
29
+ raise ValueError("normalize should be True if scale is passed")
30
+ if scale is None:
31
+ scale = 2 * math.pi
32
+ assert pos_type in ["sine", "fourier"]
33
+ self.pos_type = pos_type
34
+ self.scale = scale
35
+ if pos_type == "fourier":
36
+ assert d_pos is not None
37
+ assert d_pos % 2 == 0
38
+ # define a gaussian matrix input_ch -> output_ch
39
+ B = torch.empty((d_in, d_pos // 2)).normal_()
40
+ B *= gauss_scale
41
+ self.register_buffer("gauss_B", B)
42
+ self.d_pos = d_pos
43
+
44
+ def get_sine_embeddings(self, xyz, num_channels):
45
+ # clone coords so that shift/scale operations do not affect original tensor
46
+ orig_xyz = xyz
47
+ xyz = orig_xyz.clone()
48
+
49
+ ndim = num_channels // xyz.shape[2]
50
+ if ndim % 2 != 0:
51
+ ndim -= 1
52
+ # automatically handle remainder by assiging it to the first dim
53
+ rems = num_channels - (ndim * xyz.shape[2])
54
+
55
+ assert (
56
+ ndim % 2 == 0
57
+ ), f"Cannot handle odd sized ndim={ndim} where num_channels={num_channels} and xyz={xyz.shape}"
58
+
59
+ final_embeds = []
60
+ prev_dim = 0
61
+
62
+ for d in range(xyz.shape[2]):
63
+ cdim = ndim
64
+ if rems > 0:
65
+ # add remainder in increments of two to maintain even size
66
+ cdim += 2
67
+ rems -= 2
68
+
69
+ if cdim != prev_dim:
70
+ dim_t = torch.arange(cdim, dtype=torch.float32, device=xyz.device)
71
+ dim_t = self.temperature ** (2 * (dim_t // 2) / cdim)
72
+
73
+ # create batch x cdim x nccords embedding
74
+ raw_pos = xyz[:, :, d]
75
+ if self.scale:
76
+ raw_pos *= self.scale
77
+ pos = raw_pos[:, :, None] / dim_t
78
+ pos = torch.stack(
79
+ (pos[:, :, 0::2].sin(), pos[:, :, 1::2].cos()), dim=3
80
+ ).flatten(2)
81
+ final_embeds.append(pos)
82
+ prev_dim = cdim
83
+
84
+ final_embeds = torch.cat(final_embeds, dim=2).permute(0, 2, 1)
85
+ return final_embeds
86
+
87
+ def get_fourier_embeddings(self, xyz, num_channels=None):
88
+ # Follows - https://people.eecs.berkeley.edu/~bmild/fourfeat/index.html
89
+
90
+ if num_channels is None:
91
+ num_channels = self.gauss_B.shape[1] * 2
92
+
93
+ bsize, npoints = xyz.shape[0], xyz.shape[1]
94
+ assert num_channels > 0 and num_channels % 2 == 0
95
+ d_in, max_d_out = self.gauss_B.shape[0], self.gauss_B.shape[1]
96
+ d_out = num_channels // 2
97
+ assert d_out <= max_d_out
98
+ assert d_in == xyz.shape[-1]
99
+
100
+ # clone coords so that shift/scale operations do not affect original tensor
101
+ orig_xyz = xyz
102
+ xyz = orig_xyz.clone()
103
+
104
+ xyz *= 2 * np.pi
105
+ xyz_proj = torch.mm(xyz.view(-1, d_in), self.gauss_B[:, :d_out]).view(
106
+ bsize, npoints, d_out
107
+ )
108
+ final_embeds = [xyz_proj.sin(), xyz_proj.cos()]
109
+
110
+ # return batch x d_pos x npoints embedding
111
+ final_embeds = torch.cat(final_embeds, dim=2).permute(0, 2, 1)
112
+ return final_embeds
113
+
114
+ def forward(self, xyz, num_channels=None):
115
+ assert isinstance(xyz, torch.Tensor)
116
+ assert xyz.ndim == 3
117
+ # xyz is batch x npoints x 3
118
+ if self.pos_type == "sine":
119
+ with torch.no_grad():
120
+ return self.get_sine_embeddings(xyz, num_channels)
121
+ elif self.pos_type == "fourier":
122
+ with torch.no_grad():
123
+ return self.get_fourier_embeddings(xyz, num_channels)
124
+ else:
125
+ raise ValueError(f"Unknown {self.pos_type}")
126
+
127
+ def extra_repr(self):
128
+ st = f"type={self.pos_type}, scale={self.scale}, normalize={self.normalize}"
129
+ if hasattr(self, "gauss_B"):
130
+ st += (
131
+ f", gaussB={self.gauss_B.shape}, gaussBsum={self.gauss_B.sum().item()}"
132
+ )
133
+ return st
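For orientation, the embedder above maps a (batch, npoints, 3) coordinate tensor to a (batch, d_pos, npoints) embedding. A hedged sketch of the Fourier variant, assuming the import path of this commit:

import torch
from pi3detr.models.positional_embedding import PositionEmbeddingCoordsSine  # assumed import path

embedder = PositionEmbeddingCoordsSine(pos_type="fourier", d_pos=256, d_in=3)
xyz = torch.rand(4, 1024, 3)   # batch x npoints x 3
pos_embed = embedder(xyz)      # batch x d_pos x npoints
print(pos_embed.shape)         # torch.Size([4, 256, 1024])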
pi3detr/models/query_engine.py ADDED
@@ -0,0 +1,108 @@
1
+ import torch
2
+ from torch import nn, Tensor
3
+ from abc import ABC, abstractmethod
4
+ from torch_geometric.nn import MLP, fps, knn
5
+ from torch_geometric.data.data import Data
6
+ from typing import Optional
7
+
8
+
9
+ class QueryEngine(nn.Module, ABC):
10
+ def __init__(
11
+ self,
12
+ pos_embedder: Optional[nn.Module],
13
+ feat_dim: int,
14
+ max_points_in_param: int,
15
+ num_queries: int,
16
+ ):
17
+ super().__init__()
18
+ self.pos_embedder = pos_embedder
19
+ self.feat_dim = feat_dim
20
+ self.max_points_in_param = max_points_in_param
21
+ self.num_queries = num_queries
22
+
23
+ @abstractmethod
24
+ def forward(self, data: Data) -> tuple[Tensor]:
25
+ pass
26
+
27
+
28
+ class PointFPSQueryEngine(QueryEngine):
29
+ def __init__(
30
+ self,
31
+ pos_embedder: nn.Module,
32
+ feat_dim: int,
33
+ max_points_in_param: int,
34
+ num_queries: int,
35
+ ):
36
+ super().__init__(pos_embedder, feat_dim, max_points_in_param, num_queries)
37
+ self.num_queries = num_queries
38
+ self.query_proj = MLP(
39
+ [self.feat_dim, self.feat_dim, self.feat_dim],
40
+ bias=False,
41
+ act="relu",
42
+ norm="layer_norm",
43
+ )
44
+
45
+ def forward(self, data: Data) -> tuple[Tensor]:
46
+ num_points_per_batch = torch.bincount(data.batch)
47
+ max_ratio = self.num_queries / num_points_per_batch.min().item()
48
+ fps_idx = fps(data.pos, data.batch, ratio=max_ratio)
49
+ fps_batch = data.batch[fps_idx]
50
+ query_xyz = torch.stack(
51
+ [
52
+ data.pos[fps_idx[fps_batch == i][: self.num_queries]]
53
+ for i in range(data.batch.max().item() + 1)
54
+ ]
55
+ )
56
+ query_pos = self.pos_embedder(query_xyz, num_channels=self.feat_dim)
57
+ query_embed = self.query_proj(query_pos.permute(0, 2, 1))[
58
+ :, : self.num_queries, :
59
+ ].permute(0, 2, 1)
60
+ return (
61
+ query_xyz.unsqueeze(2).expand(-1, -1, self.max_points_in_param, -1),
62
+ query_embed,
63
+ )
64
+
65
+
66
+ class LearnedQueryEngine(QueryEngine):
67
+
68
+ def __init__(
69
+ self,
70
+ pos_embedder: Optional[nn.Module],
71
+ feat_dim: int,
72
+ max_points_in_param: int,
73
+ num_queries: int,
74
+ ):
75
+ super().__init__(None, feat_dim, max_points_in_param, num_queries)
76
+ self.query_embed = nn.Embedding(self.num_queries, feat_dim)
77
+
78
+ def forward(self, data: Data) -> tuple[Tensor]:
79
+ return (
80
+ torch.zeros(
81
+ data.batch_size,
82
+ self.num_queries,
83
+ self.max_points_in_param,
84
+ 3,
85
+ device=data.pos.device,
86
+ requires_grad=False,
87
+ ),
88
+ self.query_embed.weight.unsqueeze(0)
89
+ .expand(data.batch_size, -1, -1)
90
+ .permute(0, 2, 1),
91
+ )
92
+
93
+
94
+ def build_query_engine(
95
+ query_type: str,
96
+ pos_embedder: Optional[nn.Module],
97
+ feat_dim: int,
98
+ max_points_in_param: int,
99
+ num_queries: int,
100
+ ) -> QueryEngine:
101
+ if query_type == "point_fps":
102
+ return PointFPSQueryEngine(
103
+ pos_embedder, feat_dim, max_points_in_param, num_queries
104
+ )
105
+ elif query_type == "learned":
106
+ return LearnedQueryEngine(None, feat_dim, max_points_in_param, num_queries)
107
+ else:
108
+ raise ValueError(f"Unknown query type {query_type}")
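The factory above selects between an FPS-based and a learned query engine. A hedged sketch of the learned variant (no positional embedder needed); the import path and sizes are assumptions:

import torch
from torch_geometric.data import Data
from pi3detr.models.query_engine import build_query_engine  # assumed import path

engine = build_query_engine("learned", None, feat_dim=256, max_points_in_param=4, num_queries=100)
data = Data(pos=torch.rand(2048, 3), batch=torch.zeros(2048, dtype=torch.long))
data.batch_size = 1  # the engine reads this attribute
query_xyz, query_embed = engine(data)
print(query_xyz.shape, query_embed.shape)  # (1, 100, 4, 3) and (1, 256, 100)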
pi3detr/models/transformer.py ADDED
@@ -0,0 +1,399 @@
1
+ """
2
+ Adapted code from Meta's DETR.
3
+ """
4
+
5
+ import copy
6
+ from typing import Optional, List
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from torch import nn, Tensor
10
+
11
+
12
+ class Transformer(nn.Module):
13
+
14
+ def __init__(
15
+ self,
16
+ enc_dim: int = 256,
17
+ dec_dim: int = 256,
18
+ nhead: int = 8,
19
+ num_encoder_layers: int = 6,
20
+ num_decoder_layers: int = 6,
21
+ enc_dim_feedforward: int = 2048,
22
+ dec_dim_feedforward: int = 2048,
23
+ enc_dropout: float = 0.1,
24
+ dec_dropout: float = 0.1,
25
+ activation: str = "relu",
26
+ normalize_before: bool = False,
27
+ return_intermediate_dec: bool = False,
28
+ ):
29
+ super().__init__()
30
+ encoder_layer = TransformerEncoderLayer(
31
+ enc_dim,
32
+ nhead,
33
+ enc_dim_feedforward,
34
+ enc_dropout,
35
+ activation,
36
+ normalize_before,
37
+ )
38
+ encoder_norm = nn.LayerNorm(enc_dim) if normalize_before else None
39
+ self.encoder = TransformerEncoder(
40
+ encoder_layer, num_encoder_layers, encoder_norm
41
+ )
42
+ if enc_dim != dec_dim:
43
+ self.enc_to_dec_proj = nn.Linear(enc_dim, dec_dim)
44
+ else:
45
+ self.enc_to_dec_proj = nn.Identity()
46
+ decoder_layer = TransformerDecoderLayer(
47
+ dec_dim,
48
+ nhead,
49
+ dec_dim_feedforward,
50
+ dec_dropout,
51
+ activation,
52
+ normalize_before,
53
+ )
54
+ decoder_norm = nn.LayerNorm(dec_dim)
55
+ self.decoder = TransformerDecoder(
56
+ decoder_layer,
57
+ num_decoder_layers,
58
+ decoder_norm,
59
+ return_intermediate=return_intermediate_dec,
60
+ )
61
+ self._reset_parameters()
62
+ self.d_model = dec_dim
63
+ self.nhead = nhead
64
+
65
+ def _reset_parameters(self):
66
+ for p in self.parameters():
67
+ if p.dim() > 1:
68
+ nn.init.xavier_uniform_(p)
69
+
70
+ def forward(
71
+ self,
72
+ src: Tensor,
73
+ mask: Optional[Tensor],
74
+ query_embed: Tensor,
75
+ pos_embed: Tensor = None,
76
+ ) -> Tensor:
77
+ bs, _, _ = src.shape
78
+
79
+ src = src.permute(1, 0, 2) # (bs, seq, feat) -> (seq, bs, feat)
80
+ if pos_embed is not None:
81
+ pos_embed = pos_embed.permute(2, 0, 1)
82
+
83
+ memory = self.encoder(src, mask=None, src_key_padding_mask=mask, pos=None)
84
+ memory = self.enc_to_dec_proj(memory)
85
+
86
+ query_embed = query_embed.permute(2, 0, 1)
87
+ tgt = torch.zeros_like(query_embed)
88
+ hs = self.decoder(
89
+ tgt,
90
+ memory,
91
+ tgt_mask=None,
92
+ memory_mask=None,
93
+ tgt_key_padding_mask=None,
94
+ memory_key_padding_mask=mask,
95
+ pos=pos_embed,
96
+ query_pos=query_embed,
97
+ )
98
+ return hs.transpose(1, 2)
99
+
100
+
101
+ class TransformerEncoder(nn.Module):
102
+
103
+ def __init__(self, encoder_layer, num_layers, norm=None):
104
+ super().__init__()
105
+ self.layers = _get_clones(encoder_layer, num_layers)
106
+ self.num_layers = num_layers
107
+ self.norm = norm
108
+
109
+ def forward(
110
+ self,
111
+ src,
112
+ mask: Optional[Tensor] = None,
113
+ src_key_padding_mask: Optional[Tensor] = None,
114
+ pos: Optional[Tensor] = None,
115
+ ):
116
+ output = src
117
+
118
+ for layer in self.layers:
119
+ output = layer(
120
+ output,
121
+ src_mask=mask,
122
+ src_key_padding_mask=src_key_padding_mask,
123
+ pos=pos,
124
+ )
125
+
126
+ if self.norm is not None:
127
+ output = self.norm(output)
128
+
129
+ return output
130
+
131
+
132
+ class TransformerEncoderLayer(nn.Module):
133
+
134
+ def __init__(
135
+ self,
136
+ d_model,
137
+ nhead,
138
+ dim_feedforward=2048,
139
+ dropout=0.1,
140
+ activation="relu",
141
+ normalize_before=False,
142
+ ):
143
+ super().__init__()
144
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
145
+ # Implementation of Feedforward model
146
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
147
+ self.dropout = nn.Dropout(dropout)
148
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
149
+
150
+ self.norm1 = nn.LayerNorm(d_model)
151
+ self.norm2 = nn.LayerNorm(d_model)
152
+ self.dropout1 = nn.Dropout(dropout)
153
+ self.dropout2 = nn.Dropout(dropout)
154
+
155
+ self.activation = _get_activation_fn(activation)
156
+ self.normalize_before = normalize_before
157
+
158
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
159
+ return tensor if pos is None else tensor + pos
160
+
161
+ def forward_post(
162
+ self,
163
+ src,
164
+ src_mask: Optional[Tensor] = None,
165
+ src_key_padding_mask: Optional[Tensor] = None,
166
+ pos: Optional[Tensor] = None,
167
+ ):
168
+ q = k = self.with_pos_embed(src, pos)
169
+ src2 = self.self_attn(
170
+ q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
171
+ )[0]
172
+ src = src + self.dropout1(src2)
173
+ src = self.norm1(src)
174
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
175
+ src = src + self.dropout2(src2)
176
+ src = self.norm2(src)
177
+ return src
178
+
179
+ def forward_pre(
180
+ self,
181
+ src,
182
+ src_mask: Optional[Tensor] = None,
183
+ src_key_padding_mask: Optional[Tensor] = None,
184
+ pos: Optional[Tensor] = None,
185
+ ):
186
+ src2 = self.norm1(src)
187
+ q = k = self.with_pos_embed(src2, pos)
188
+ src2 = self.self_attn(
189
+ q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
190
+ )[0]
191
+ src = src + self.dropout1(src2)
192
+ src2 = self.norm2(src)
193
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
194
+ src = src + self.dropout2(src2)
195
+ return src
196
+
197
+ def forward(
198
+ self,
199
+ src,
200
+ src_mask: Optional[Tensor] = None,
201
+ src_key_padding_mask: Optional[Tensor] = None,
202
+ pos: Optional[Tensor] = None,
203
+ ):
204
+ if self.normalize_before:
205
+ return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
206
+ return self.forward_post(src, src_mask, src_key_padding_mask, pos)
207
+
208
+
209
+ class TransformerDecoder(nn.Module):
210
+
211
+ def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
212
+ super().__init__()
213
+ self.layers = _get_clones(decoder_layer, num_layers)
214
+ self.num_layers = num_layers
215
+ self.norm = norm
216
+ self.return_intermediate = return_intermediate
217
+
218
+ def forward(
219
+ self,
220
+ tgt,
221
+ memory,
222
+ tgt_mask: Optional[Tensor] = None,
223
+ memory_mask: Optional[Tensor] = None,
224
+ tgt_key_padding_mask: Optional[Tensor] = None,
225
+ memory_key_padding_mask: Optional[Tensor] = None,
226
+ pos: Optional[Tensor] = None,
227
+ query_pos: Optional[Tensor] = None,
228
+ ):
229
+ output = tgt
230
+
231
+ intermediate = []
232
+
233
+ for layer in self.layers:
234
+ output = layer(
235
+ output,
236
+ memory,
237
+ tgt_mask=tgt_mask,
238
+ memory_mask=memory_mask,
239
+ tgt_key_padding_mask=tgt_key_padding_mask,
240
+ memory_key_padding_mask=memory_key_padding_mask,
241
+ pos=pos,
242
+ query_pos=query_pos,
243
+ )
244
+ if self.return_intermediate:
245
+ intermediate.append(self.norm(output))
246
+
247
+ if self.norm is not None:
248
+ output = self.norm(output)
249
+ if self.return_intermediate:
250
+ intermediate.pop()
251
+ intermediate.append(output)
252
+
253
+ if self.return_intermediate:
254
+ return torch.stack(intermediate)
255
+
256
+ return output.unsqueeze(0)
257
+
258
+
259
+ class TransformerDecoderLayer(nn.Module):
260
+
261
+ def __init__(
262
+ self,
263
+ d_model,
264
+ nhead,
265
+ dim_feedforward=2048,
266
+ dropout=0.1,
267
+ activation="relu",
268
+ normalize_before=False,
269
+ ):
270
+ super().__init__()
271
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
272
+ self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
273
+ # Implementation of Feedforward model
274
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
275
+ self.dropout = nn.Dropout(dropout)
276
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
277
+
278
+ self.norm1 = nn.LayerNorm(d_model)
279
+ self.norm2 = nn.LayerNorm(d_model)
280
+ self.norm3 = nn.LayerNorm(d_model)
281
+ self.dropout1 = nn.Dropout(dropout)
282
+ self.dropout2 = nn.Dropout(dropout)
283
+ self.dropout3 = nn.Dropout(dropout)
284
+
285
+ self.activation = _get_activation_fn(activation)
286
+ self.normalize_before = normalize_before
287
+
288
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
289
+ return tensor if pos is None else tensor + pos
290
+
291
+ def forward_post(
292
+ self,
293
+ tgt,
294
+ memory,
295
+ tgt_mask: Optional[Tensor] = None,
296
+ memory_mask: Optional[Tensor] = None,
297
+ tgt_key_padding_mask: Optional[Tensor] = None,
298
+ memory_key_padding_mask: Optional[Tensor] = None,
299
+ pos: Optional[Tensor] = None,
300
+ query_pos: Optional[Tensor] = None,
301
+ ):
302
+ q = k = self.with_pos_embed(tgt, query_pos)
303
+ tgt2 = self.self_attn(
304
+ q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
305
+ )[0]
306
+ tgt = tgt + self.dropout1(tgt2)
307
+ tgt = self.norm1(tgt)
308
+ tgt2 = self.multihead_attn(
309
+ query=self.with_pos_embed(tgt, query_pos),
310
+ key=self.with_pos_embed(memory, pos),
311
+ value=memory,
312
+ attn_mask=memory_mask,
313
+ key_padding_mask=memory_key_padding_mask,
314
+ )[0]
315
+ tgt = tgt + self.dropout2(tgt2)
316
+ tgt = self.norm2(tgt)
317
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
318
+ tgt = tgt + self.dropout3(tgt2)
319
+ tgt = self.norm3(tgt)
320
+ return tgt
321
+
322
+ def forward_pre(
323
+ self,
324
+ tgt,
325
+ memory,
326
+ tgt_mask: Optional[Tensor] = None,
327
+ memory_mask: Optional[Tensor] = None,
328
+ tgt_key_padding_mask: Optional[Tensor] = None,
329
+ memory_key_padding_mask: Optional[Tensor] = None,
330
+ pos: Optional[Tensor] = None,
331
+ query_pos: Optional[Tensor] = None,
332
+ ):
333
+ tgt2 = self.norm1(tgt)
334
+ q = k = self.with_pos_embed(tgt2, query_pos)
335
+ tgt2 = self.self_attn(
336
+ q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
337
+ )[0]
338
+ tgt = tgt + self.dropout1(tgt2)
339
+ tgt2 = self.norm2(tgt)
340
+ tgt2 = self.multihead_attn(
341
+ query=self.with_pos_embed(tgt2, query_pos),
342
+ key=self.with_pos_embed(memory, pos),
343
+ value=memory,
344
+ attn_mask=memory_mask,
345
+ key_padding_mask=memory_key_padding_mask,
346
+ )[0]
347
+ tgt = tgt + self.dropout2(tgt2)
348
+ tgt2 = self.norm3(tgt)
349
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
350
+ tgt = tgt + self.dropout3(tgt2)
351
+ return tgt
352
+
353
+ def forward(
354
+ self,
355
+ tgt,
356
+ memory,
357
+ tgt_mask: Optional[Tensor] = None,
358
+ memory_mask: Optional[Tensor] = None,
359
+ tgt_key_padding_mask: Optional[Tensor] = None,
360
+ memory_key_padding_mask: Optional[Tensor] = None,
361
+ pos: Optional[Tensor] = None,
362
+ query_pos: Optional[Tensor] = None,
363
+ ):
364
+ if self.normalize_before:
365
+ return self.forward_pre(
366
+ tgt,
367
+ memory,
368
+ tgt_mask,
369
+ memory_mask,
370
+ tgt_key_padding_mask,
371
+ memory_key_padding_mask,
372
+ pos,
373
+ query_pos,
374
+ )
375
+ return self.forward_post(
376
+ tgt,
377
+ memory,
378
+ tgt_mask,
379
+ memory_mask,
380
+ tgt_key_padding_mask,
381
+ memory_key_padding_mask,
382
+ pos,
383
+ query_pos,
384
+ )
385
+
386
+
387
+ def _get_clones(module, N):
388
+ return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
389
+
390
+
391
+ def _get_activation_fn(activation):
392
+ """Return an activation function given a string"""
393
+ if activation == "relu":
394
+ return F.relu
395
+ if activation == "gelu":
396
+ return F.gelu
397
+ if activation == "glu":
398
+ return F.glu
399
+ raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
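A shape-level sketch of how this module is driven: the encoder consumes (batch, num_points, enc_dim) features, the decoder attends from the query embeddings, and with return_intermediate_dec=True one output per decoder layer is returned. Import path and sizes are assumptions:

import torch
from pi3detr.models.transformer import Transformer  # assumed import path

model = Transformer(enc_dim=256, dec_dim=256, num_encoder_layers=3,
                    num_decoder_layers=3, return_intermediate_dec=True)
src = torch.rand(2, 1024, 256)         # (batch, num_points, enc_dim)
query_embed = torch.rand(2, 256, 100)  # (batch, dec_dim, num_queries)
hs = model(src, mask=None, query_embed=query_embed, pos_embed=None)
print(hs.shape)  # torch.Size([3, 2, 100, 256]) -> (num_layers, batch, num_queries, dec_dim)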
pi3detr/utils/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .config_reader import load_args
2
+ from .layer_utils import no_grad, load_weights
3
+ from .curve_fitter import torch_bezier_curve
pi3detr/utils/config_reader.py ADDED
@@ -0,0 +1,21 @@
1
+ from typing import Optional
2
+ from argparse import Namespace
3
+ from types import SimpleNamespace
4
+ import yaml
5
+
6
+
7
+ def load_yaml(file_path: str) -> yaml.YAMLObject:
8
+ with open(file_path, "r") as file:
9
+ try:
10
+ return yaml.safe_load(file)
11
+ except yaml.YAMLError as exc:
12
+ print(exc)
13
+ raise yaml.YAMLError("error reading yaml file") from exc
14
+
15
+
16
+ def load_args(config: str, parsed_args: Optional[Namespace] = None) -> SimpleNamespace:
17
+ args = load_yaml(config)
18
+ parsed_args = vars(parsed_args) if parsed_args else {}
19
+ args = args | parsed_args
20
+ args = SimpleNamespace(**args)
21
+ return args
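load_args merges the YAML dict with any parsed CLI arguments (CLI values win, via the dict union on Python 3.9+) and wraps the result in a SimpleNamespace. A hedged usage sketch; the config path is hypothetical:

from argparse import ArgumentParser
from pi3detr.utils.config_reader import load_args  # assumed import path

parser = ArgumentParser()
parser.add_argument("--lr", type=float, default=1e-4)
cli_args = parser.parse_args([])                     # empty list: take the defaults
args = load_args("configs/example.yaml", cli_args)   # hypothetical config file
print(args.lr)                                       # CLI value overrides the YAML entry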
pi3detr/utils/curve_fitter.py ADDED
@@ -0,0 +1,371 @@
1
+ import numpy as np
2
+ import torch
3
+ import math
4
+
5
+
6
+ def torch_arc_points(start, mid, end, num_points=100):
7
+ """
8
+ Sample points along a circular arc defined by 3 points in 3D, batched.
9
+ Inputs:
10
+ start, mid, end: tensors of shape [B, 3]
11
+ num_points: number of points sampled along the arc
12
+ Returns:
13
+ arc_points: tensor of shape [B, num_points, 3]
14
+ """
15
+ B = start.shape[0]
16
+
17
+ # 1. Compute circle center and normal vector for each batch
18
+ v1 = mid - start # [B,3]
19
+ v2 = end - start # [B,3]
20
+
21
+ normal = torch.cross(v1, v2, dim=1) # [B,3]
22
+ normal_norm = normal.norm(dim=1, keepdim=True).clamp(min=1e-8)
23
+ normal = normal / normal_norm # normalize
24
+
25
+ mid1 = (start + mid) / 2 # [B,3]
26
+ mid2 = (start + end) / 2 # [B,3]
27
+
28
+ # perpendicular directions in the plane
29
+ perp1 = torch.cross(normal, v1, dim=1) # [B,3]
30
+ perp2 = torch.cross(normal, v2, dim=1) # [B,3]
31
+
32
+ # Solve line intersection for each batch:
33
+ # Line 1: point mid1, direction perp1
34
+ # Line 2: point mid2, direction perp2
35
+ # Solve for t in mid1 + t * perp1 = mid2 + s * perp2
36
+
37
+ # Construct matrix A and vector b for least squares
38
+ A = torch.stack([perp1, -perp2], dim=2) # [B,3,2]
39
+ b = (mid2 - mid1).unsqueeze(2) # [B,3,1]
40
+
41
+ # Use torch.linalg.lstsq if available, fallback to pinv:
42
+ try:
43
+ t_s = torch.linalg.lstsq(A, b).solution # [B,2,1]
44
+ except:
45
+ # fallback
46
+ At = A.transpose(1, 2) # [B,2,3]
47
+ pinv = torch.linalg.pinv(A) # [B,2,3]
48
+ t_s = torch.bmm(pinv, b) # [B,2,1]
49
+
50
+ t = t_s[:, 0, 0] # [B]
51
+
52
+ center = mid1 + (perp1 * t.unsqueeze(1)) # [B,3]
53
+
54
+ radius = (start - center).norm(dim=1, keepdim=True) # [B,1]
55
+
56
+ # 2. Define local basis in the arc plane
57
+ x_axis = (start - center) / radius # [B,3]
58
+ y_axis = torch.cross(normal, x_axis, dim=1) # [B,3]
59
+
60
+ # 3. Compute angles function
61
+ def angle_from_vector(v):
62
+ x = (v * x_axis).sum(dim=1) # [B]
63
+ y = (v * y_axis).sum(dim=1) # [B]
64
+ angles = torch.atan2(y, x) # [-pi, pi]
65
+ angles = angles % (2 * math.pi)
66
+ return angles
67
+
68
+ theta_start = torch.zeros(B, device=start.device) # [B], 0 since x_axis is ref
69
+ theta_end = angle_from_vector(end - center) # [B]
70
+ theta_mid = angle_from_vector(mid - center) # [B]
71
+
72
+ # 4. Ensure arc goes the correct way (shortest arc through mid)
73
+ # Helper function vectorized:
74
+ def between(a, b, c):
75
+ # returns a boolean tensor: True where b lies between a and c going CCW (mod 2pi)
76
+ return ((a < b) & (b < c)) | ((c < a) & ((a < b) | (b < c)))
77
+
78
+ cond = between(theta_start, theta_mid, theta_end)
79
+
80
+ # If not cond, swap start/end angles by adding 2pi to one side
81
+ # We'll add 2pi to whichever angle is smaller to preserve direction
82
+ theta_start_new = torch.where(
83
+ cond,
84
+ theta_start,
85
+ torch.where(theta_start < theta_end, theta_start, theta_start + 2 * math.pi),
86
+ )
87
+ theta_end_new = torch.where(
88
+ cond,
89
+ theta_end,
90
+ torch.where(theta_end < theta_start, theta_end + 2 * math.pi, theta_end),
91
+ )
92
+
93
+ # 5. Sample angles
94
+ t_lin = (
95
+ torch.linspace(0, 1, steps=num_points, device=start.device)
96
+ .unsqueeze(0)
97
+ .repeat(B, 1)
98
+ ) # [B, num_points]
99
+
100
+ angles = theta_start_new.unsqueeze(1) + t_lin * (
101
+ theta_end_new - theta_start_new
102
+ ).unsqueeze(
103
+ 1
104
+ ) # [B, num_points]
105
+ angles = angles % (2 * math.pi)
106
+
107
+ # 6. Map back to 3D
108
+ cos_a = torch.cos(angles).unsqueeze(2) # [B, num_points, 1]
109
+ sin_a = torch.sin(angles).unsqueeze(2) # [B, num_points, 1]
110
+
111
+ points = center.unsqueeze(1) + radius.unsqueeze(1) * (
112
+ cos_a * x_axis.unsqueeze(1) + sin_a * y_axis.unsqueeze(1)
113
+ ) # [B, num_points, 3]
114
+
115
+ return points
116
+
117
+
118
+ def torch_circle_fitter(
119
+ points: torch.Tensor,
120
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
121
+ """
122
+ Fits a circle to an arbitrary number of 3D points using least squares.
123
+
124
+ Args:
125
+ points: Tensor of shape (B, N, 3), where B = batch size, N = number of points per batch.
126
+
127
+ Returns:
128
+ center_3d: (B, 3) tensor of circle centers in 3D
129
+ normal: (B, 3) tensor of normal vectors to the circle's plane
130
+ radius: (B,) tensor of circle radii
131
+ """
132
+ B, N, _ = points.shape
133
+ mean = points.mean(dim=1, keepdim=True)
134
+ centered = points - mean
135
+
136
+ # PCA via SVD
137
+ U, S, Vh = torch.linalg.svd(centered)
138
+ normal = Vh[
139
+ :, -1, :
140
+ ] # last singular vector corresponds to the smallest variance (plane normal)
141
+
142
+ # Project to plane
143
+ x_axis = Vh[:, 0, :]
144
+ y_axis = Vh[:, 1, :]
145
+ X = torch.einsum("bij,bj->bi", centered, x_axis) # (B, N)
146
+ Y = torch.einsum("bij,bj->bi", centered, y_axis) # (B, N)
147
+
148
+ # Fit circle in 2D: (x - xc)^2 + (y - yc)^2 = r^2
149
+ A = torch.stack([2 * X, 2 * Y, torch.ones_like(X)], dim=-1) # (B, N, 3)
150
+ b = (X**2 + Y**2).unsqueeze(-1) # (B, N, 1)
151
+
152
+ # Solve the least squares system: A @ [xc, yc, c] = b
153
+ AtA = A.transpose(1, 2) @ A
154
+ Atb = A.transpose(1, 2) @ b
155
+ sol = torch.linalg.solve(AtA, Atb).squeeze(-1) # (B, 3)
156
+
157
+ xc, yc, c = sol[:, 0], sol[:, 1], sol[:, 2]
158
+ radius = torch.sqrt(xc**2 + yc**2 + c)
159
+
160
+ # Reconstruct center in 3D
161
+ center_3d = mean.squeeze(1) + xc.unsqueeze(1) * x_axis + yc.unsqueeze(1) * y_axis
162
+
163
+ return center_3d, normal, radius
164
+
165
+
166
+ def generate_points_on_circle(center, normal, radius, num_points=100):
167
+ # Normalize the normal vector
168
+ normal = normal / np.linalg.norm(normal)
169
+
170
+ # Find two orthogonal vectors in the plane of the circle
171
+ if np.allclose(normal, [0, 0, 1]):
172
+ u = np.array([1, 0, 0])
173
+ else:
174
+ u = np.cross(normal, [0, 0, 1])
175
+ u = u / np.linalg.norm(u)
176
+ v = np.cross(normal, u)
177
+
178
+ # Generate points on the circle in the plane
179
+ theta = np.linspace(0, 2 * np.pi, num_points)
180
+ circle_points = (
181
+ center
182
+ + radius * np.outer(np.cos(theta), u)
183
+ + radius * np.outer(np.sin(theta), v)
184
+ )
185
+
186
+ return circle_points
187
+
188
+
189
+ def generate_points_on_circle_torch(
190
+ center, normal, radius, num_points=100
191
+ ) -> torch.Tensor:
192
+ """
193
+ Generate points on a circle in 3D space using PyTorch, supporting batching.
194
+
195
+ Args:
196
+ center: Tensor of shape (B, 3), circle centers.
197
+ normal: Tensor of shape (B, 3), normal vectors to the circle's plane.
198
+ radius: Tensor of shape (B,), radii of the circles.
199
+ num_points: Number of points to generate per circle.
200
+
201
+ Returns:
202
+ Tensor of shape (B, num_points, 3), points on the circles.
203
+ """
204
+ B = center.shape[0]
205
+ normal = normal / torch.norm(normal, dim=1, keepdim=True) # Normalize normals
206
+
207
+ # Find two orthogonal vectors in the plane of the circle
208
+ u = torch.linalg.cross(
209
+ normal,
210
+ torch.tensor([0, 0, 1], dtype=normal.dtype, device=normal.device).expand_as(
211
+ normal
212
+ ),
213
+ )
214
+ u = torch.where(
215
+ torch.norm(u, dim=1, keepdim=True) > 1e-6,
216
+ u,
217
+ torch.tensor([1, 0, 0], dtype=normal.dtype, device=normal.device).expand_as(
218
+ normal
219
+ ),
220
+ )
221
+ u = u / torch.norm(u, dim=1, keepdim=True)
222
+ v = torch.linalg.cross(normal, u)
223
+
224
+ # Generate points on the circle in the plane
225
+ theta = (
226
+ torch.linspace(0, 2 * torch.pi, num_points, device=center.device)
227
+ .unsqueeze(0)
228
+ .repeat(B, 1)
229
+ )
230
+ circle_points = (
231
+ center.unsqueeze(1)
232
+ + radius.unsqueeze(1).unsqueeze(2)
233
+ * torch.cos(theta).unsqueeze(2)
234
+ * u.unsqueeze(1)
235
+ + radius.unsqueeze(1).unsqueeze(2)
236
+ * torch.sin(theta).unsqueeze(2)
237
+ * v.unsqueeze(1)
238
+ )
239
+
240
+ return circle_points
241
+
242
+
243
+ def torch_bezier_curve(
244
+ control_points: torch.Tensor, num_points: int = 100
245
+ ) -> torch.Tensor:
246
+ control_points = control_points.float()
247
+ t = (torch.linspace(0, 1, num_points).unsqueeze(-1).unsqueeze(-1)).to(
248
+ control_points.device
249
+ ) # shape [1, num_points, 1]
250
+ B = (
251
+ (1 - t) ** 3 * control_points[:, 0]
252
+ + 3 * (1 - t) ** 2 * t * control_points[:, 1]
253
+ + 3 * (1 - t) * t**2 * control_points[:, 2]
254
+ + t**3 * control_points[:, 3]
255
+ )
256
+ # Transpose the first two dimensions to get the shape (batch_size, num_points, 3)
257
+ B = B.transpose(0, 1)
258
+
259
+ return B
260
+
261
+
262
+ def torch_line_points(
263
+ start_points: torch.Tensor, end_points: torch.Tensor, num_points: int = 100
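A quick check of the cubic Bézier evaluation above: at t=0 the curve equals the first control point and at t=1 the last, so the sampled endpoints must match the control polygon's endpoints (assuming the import path of this commit):

import torch
from pi3detr.utils.curve_fitter import torch_bezier_curve  # assumed import path

ctrl = torch.rand(5, 4, 3)                     # batch of 5 curves, 4 control points each
pts = torch_bezier_curve(ctrl, num_points=50)  # (5, 50, 3)
assert torch.allclose(pts[:, 0], ctrl[:, 0]) and torch.allclose(pts[:, -1], ctrl[:, 3])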
264
+ ) -> torch.Tensor:
265
+ weights = (
266
+ torch.linspace(0, 1, num_points)
267
+ .unsqueeze(0)
268
+ .unsqueeze(-1)
269
+ .to(start_points.device)
270
+ )
271
+ line_points = (1 - weights) * start_points.unsqueeze(
272
+ 1
273
+ ) + weights * end_points.unsqueeze(1)
274
+ return line_points
275
+
276
+
277
+ def fit_line(points: torch.Tensor, K: int = 100) -> torch.Tensor:
278
+ """
279
+ Fit a line to 3D points and sample K points along it.
280
+ """
281
+ assert points.ndim == 2 and points.shape[1] == 3, "Input must be [N, 3]"
282
+
283
+ # Step 1: Center the points
284
+ mean = points.mean(dim=0, keepdim=True)
285
+ centered = points - mean
286
+
287
+ # Step 2: SVD
288
+ U, S, Vh = torch.linalg.svd(centered, full_matrices=False)
289
+ direction = Vh[0] # First principal component
290
+
291
+ # Step 3: Project points onto the line to get min/max
292
+ projections = torch.matmul(centered, direction)
293
+ t_min, t_max = projections.min(), projections.max()
294
+
295
+ # Step 4: Sample along the line
296
+ t_vals = torch.linspace(t_min, t_max, K).to(points.device)
297
+ fitted_points = mean + t_vals[:, None] * direction
298
+
299
+ return fitted_points
300
+
301
+
302
+ def fit_cubic_bezier(points_3d: torch.Tensor) -> torch.Tensor:
303
+ """
304
+ Fit a cubic Bézier curve to 3D points while fixing the start and end points.
305
+
306
+ Args:
307
+ points_3d: (N, 3) Tensor of 3D arc points.
308
+
309
+ Returns:
310
+ bezier_pts: Tensor of 4 control points (P0, P1, P2, P3), shape (4, 3)
311
+ """
312
+ if not isinstance(points_3d, torch.Tensor):
313
+ points_3d = torch.tensor(points_3d, dtype=torch.float32)
314
+
315
+ n = len(points_3d)
316
+
317
+ if n < 4:
318
+ raise ValueError("At least 4 points are required to fit a cubic Bézier curve.")
319
+
320
+ device = points_3d.device
321
+
322
+ # Fixed start and end points
323
+ P0 = points_3d[0]
324
+ P3 = points_3d[-1]
325
+
326
+ # Normalize parameter t
327
+ t = torch.linspace(0, 1, n, device=device)
328
+
329
+ # Bernstein basis functions for cubic Bézier
330
+ def bernstein(t):
331
+ b0 = (1 - t) ** 3
332
+ b1 = 3 * (1 - t) ** 2 * t
333
+ b2 = 3 * (1 - t) * t**2
334
+ b3 = t**3
335
+ return torch.stack([b0, b1, b2, b3], dim=1) # (n, 4)
336
+
337
+ B = bernstein(t)
338
+
339
+ # Initial guess for P1 and P2 (based on tangents)
340
+ P1_init = P0 + (points_3d[1] - P0) * 1.5
341
+ P2_init = P3 + (points_3d[-2] - P3) * 1.5
342
+
343
+ # Optimization parameters - make them require gradients
344
+ P1 = P1_init.clone().detach().requires_grad_(True)
345
+ P2 = P2_init.clone().detach().requires_grad_(True)
346
+
347
+ # Optimizer
348
+ optimizer = torch.optim.LBFGS([P1, P2], max_iter=100, line_search_fn="strong_wolfe")
349
+
350
+ def closure():
351
+ optimizer.zero_grad()
352
+
353
+ # Compute Bézier curve
354
+ curve = (
355
+ B[:, 0].unsqueeze(1) * P0
356
+ + B[:, 1].unsqueeze(1) * P1
357
+ + B[:, 2].unsqueeze(1) * P2
358
+ + B[:, 3].unsqueeze(1) * P3
359
+ )
360
+
361
+ # Compute loss (mean squared error)
362
+ loss = torch.mean((curve - points_3d) ** 2)
363
+ loss.backward()
364
+ return loss
365
+
366
+ # Optimize
367
+ optimizer.step(closure)
368
+
369
+ # Return control points
370
+ with torch.no_grad():
371
+ return torch.stack([P0, P1, P2, P3])
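Taken together, the helpers in this file can be exercised roughly as follows; a minimal sketch with synthetic, slightly noisy samples (the import path is assumed from the package layout, pi3detr/utils/curve_fitter.py).

import torch
from pi3detr.utils.curve_fitter import fit_cubic_bezier, torch_bezier_curve, fit_line

# Synthetic samples along a segment, shape (N, 3), with a little noise
pts = torch.linspace(0, 1, 32).unsqueeze(1) * torch.tensor([1.0, 0.5, 0.0])
pts = pts + 0.01 * torch.randn_like(pts)

ctrl = fit_cubic_bezier(pts)                        # (4, 3) control points, endpoints fixed
curve = torch_bezier_curve(ctrl.unsqueeze(0), 64)   # (1, 64, 3) resampled Bézier curve
line = fit_line(pts, K=64)                          # (64, 3) least-squares line through pts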
pi3detr/utils/layer_utils.py ADDED
@@ -0,0 +1,22 @@
1
+ from typing import Optional
2
+ import torch
3
+ from torch import nn
4
+
5
+
6
+ def load_weights(
7
+ model: nn.Module, ckpt_path: str, dont_load: Optional[list[str]] = None
8
+ ) -> None:
9
+ ckpt = torch.load(ckpt_path, weights_only=False)
10
+ state_dict = {}
11
+ for k, v in ckpt["state_dict"].items():
12
+ if not any(dl in k for dl in (dont_load or [])):
13
+ state_dict[k] = v
14
+ else:
15
+ print(f"Didn't load {k}")
16
+ model.load_state_dict(state_dict, strict=False)
17
+ print(f"Loaded checkpoint: {ckpt_path}")
18
+
19
+
20
+ def no_grad(model: nn.Module) -> None:
21
+ for param in model.parameters():
22
+ param.requires_grad = False
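A sketch of how these two helpers might be combined when restoring a model from a Lightning-style checkpoint; the module, checkpoint path, and skipped key below are hypothetical.

from torch import nn
from pi3detr.utils.layer_utils import load_weights, no_grad

backbone = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8))   # stand-in module
load_weights(backbone, "checkpoints/last.ckpt", dont_load=["decoder"])     # skip keys containing "decoder"
no_grad(backbone)                                                           # freeze all parameters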
pi3detr/utils/postprocessing.py ADDED
@@ -0,0 +1,543 @@
1
+ import numpy as np
2
+ import torch
3
+ from torch import Tensor
4
+ from torch_geometric.data import Data
5
+ from scipy.spatial import cKDTree # faster than KDTree
6
+
7
+ from .curve_fitter import (
8
+ fit_cubic_bezier,
9
+ torch_bezier_curve,
10
+ torch_circle_fitter,
11
+ generate_points_on_circle_torch,
12
+ torch_arc_points,
13
+ fit_line,
14
+ )
15
+
16
+
17
+ def snap_and_fit_curves(
18
+ data: Data,
19
+ ) -> Data:
20
+ """
21
+ Snap polylines to nearest point cloud points and fit geometric curves based on predicted classes.
22
+
23
+ This function performs two main operations:
24
+ 1. Snaps each polyline vertex to its nearest neighbor in the point cloud
25
+ 2. Fits the appropriate geometric curve (line, circle, arc, or B-spline) based on the predicted class
26
+
27
+ Class mapping:
28
+ 0: Background (no processing, kept as-is)
29
+ 1: B-spline (cubic Bezier curve fitting)
30
+ 2: Line (linear regression fitting)
31
+ 3: Circle (3D circle fitting)
32
+ 4: Arc (circular arc through 3 points)
33
+
34
+ Args:
35
+ data (Data): PyTorch Geometric Data object containing:
36
+ - pos (Tensor): Point cloud coordinates [P, 3]
37
+ - polylines (Tensor): Raw polyline predictions [M, K, 3]
38
+ - polyline_class (Tensor): Predicted classes for each polyline [M]
39
+
40
+ Returns:
41
+ Data: Cloned Data object with fitted polylines replacing the original polylines.
42
+ All other attributes remain unchanged.
43
+
44
+ Note:
45
+ - Robust error handling: falls back to original polyline if fitting fails
46
+ - Validates output shapes and numerical stability (NaN/Inf checks)
47
+ - Requires minimum points per curve type (e.g., 3 for circles, 4 for B-splines)
48
+ """
49
+ point_cloud = data.pos
50
+ polylines = data.polylines
51
+ polyline_classes = data.polyline_class
52
+ M, K, _ = polylines.shape
53
+ snapped_and_fitted = torch.zeros_like(polylines)
54
+
55
+ for i, cls in enumerate(polyline_classes):
56
+ if cls == 0:
57
+ snapped_and_fitted[i] = polylines[i] # Keep original for class 0
58
+ continue
59
+
60
+ try:
61
+ # Snap the polyline to the nearest point in the point cloud
62
+ distances = torch.cdist(polylines[i], point_cloud)
63
+ nearest_idx = distances.argmin(dim=1)
64
+ nn_points = point_cloud[nearest_idx]
65
+
66
+ # Safety check: ensure we have valid points
67
+ if (
68
+ len(nn_points) == 0
69
+ or torch.any(torch.isnan(nn_points))
70
+ or torch.any(torch.isinf(nn_points))
71
+ ):
72
+ snapped_and_fitted[i] = polylines[i]
73
+ continue
74
+
75
+ new_curve = None
76
+
77
+ if cls == 1: # BSpline
78
+ try:
79
+ if len(nn_points) < 4:
80
+ # Not enough points for cubic Bezier, fallback to original
81
+ new_curve = polylines[i]
82
+ else:
83
+ bezier_pts = fit_cubic_bezier(nn_points)
84
+ new_curve = torch_bezier_curve(
85
+ bezier_pts.unsqueeze(0), K
86
+ ).squeeze(0)
87
+
88
+ # Validate output shape and values
89
+ if (
90
+ new_curve.shape != (K, 3)
91
+ or torch.any(torch.isnan(new_curve))
92
+ or torch.any(torch.isinf(new_curve))
93
+ ):
94
+ new_curve = polylines[i]
95
+
96
+ except Exception:
97
+ new_curve = polylines[i]
98
+
99
+ elif cls == 2: # Line
100
+ try:
101
+ if len(nn_points) < 2:
102
+ new_curve = polylines[i]
103
+ else:
104
+ new_curve = fit_line(nn_points, K)
105
+
106
+ # Validate output shape and values
107
+ if (
108
+ new_curve.shape != (K, 3)
109
+ or torch.any(torch.isnan(new_curve))
110
+ or torch.any(torch.isinf(new_curve))
111
+ ):
112
+ new_curve = polylines[i]
113
+ except Exception:
114
+ new_curve = polylines[i]
115
+
116
+ elif cls == 3: # Circle
117
+ try:
118
+ # Check if we have enough unique points for circle fitting
119
+ unique_points = torch.unique(nn_points, dim=0)
120
+ if len(unique_points) < 3:
121
+ new_curve = polylines[i]
122
+ else:
123
+ center, normal, radius = torch_circle_fitter(
124
+ nn_points.unsqueeze(0)
125
+ )
126
+
127
+ # Validate circle parameters
128
+ if (
129
+ torch.any(torch.isnan(center))
130
+ or torch.any(torch.isnan(normal))
131
+ or torch.any(torch.isnan(radius))
132
+ or torch.any(torch.isinf(center))
133
+ or torch.any(torch.isinf(normal))
134
+ or torch.any(torch.isinf(radius))
135
+ or radius <= 0
136
+ ):
137
+ new_curve = polylines[i]
138
+ else:
139
+ new_curve = generate_points_on_circle_torch(
140
+ center, normal, radius, K
141
+ ).squeeze(0)
142
+
143
+ # Validate output shape and values
144
+ if (
145
+ new_curve.shape != (K, 3)
146
+ or torch.any(torch.isnan(new_curve))
147
+ or torch.any(torch.isinf(new_curve))
148
+ ):
149
+ new_curve = polylines[i]
150
+ except Exception:
151
+ new_curve = polylines[i]
152
+
153
+ elif cls == 4: # Arc
154
+ try:
155
+ if len(nn_points) < 3:
156
+ new_curve = polylines[i]
157
+ else:
158
+ start_pt = nn_points[0].unsqueeze(0)
159
+ mid_pt = nn_points[len(nn_points) // 2].unsqueeze(0)
160
+ end_pt = nn_points[-1].unsqueeze(0)
161
+
162
+ new_curve = torch_arc_points(
163
+ start_pt, mid_pt, end_pt, K
164
+ ).squeeze(0)
165
+
166
+ # Validate output shape and values
167
+ if (
168
+ new_curve.shape != (K, 3)
169
+ or torch.any(torch.isnan(new_curve))
170
+ or torch.any(torch.isinf(new_curve))
171
+ ):
172
+ new_curve = polylines[i]
173
+ except Exception:
174
+ new_curve = polylines[i]
175
+
176
+ else:
177
+ # Unknown class, keep original
178
+ new_curve = polylines[i]
179
+
180
+ # Final safety check
181
+ if new_curve is not None and new_curve.shape == (K, 3):
182
+ snapped_and_fitted[i] = new_curve
183
+ else:
184
+ snapped_and_fitted[i] = polylines[i]
185
+
186
+ except Exception:
187
+ # If anything goes wrong, fallback to original polyline
188
+ snapped_and_fitted[i] = polylines[i]
189
+
190
+ output = data.clone()
191
+ output.polylines = snapped_and_fitted
192
+ return output
193
+
194
+
195
+ def filter_predictions(pred_data: Data, thresholds: list[float]) -> Data:
196
+ """
197
+ Filter predictions based on class-specific confidence thresholds.
198
+
199
+ Removes polylines whose confidence scores fall below the specified threshold
200
+ for their predicted class. This is typically used as a post-processing step
201
+ to remove low-confidence predictions before further analysis.
202
+
203
+ Args:
204
+ pred_data (Data): PyTorch Geometric Data object containing:
205
+ - pos (Tensor): Point cloud coordinates [P, 3]
206
+ - polyline_class (Tensor): Predicted classes [N]
207
+ - polyline_score (Tensor): Confidence scores [N]
208
+ - polylines (Tensor): Polyline coordinates [N, K, 3]
209
+ - query_xyz (Tensor, optional): Query coordinates [N, 3]
210
+ thresholds (list[float]): Confidence thresholds for each class.
211
+ Length must match the number of classes.
212
+ thresholds[i] is the minimum confidence for class i.
213
+
214
+ Returns:
215
+ Data: Filtered Data object containing only polylines that meet their
216
+ class-specific confidence thresholds. Maintains the same structure
217
+ as input but with potentially fewer polylines.
218
+
219
+ Example:
220
+ # Keep only polylines with confidence >= 0.5 for class 0, >= 0.7 for class 1, etc. (one threshold per class)
221
+ filtered = filter_predictions(data, [0.5, 0.7, 0.6, 0.8, 0.8])
222
+ """
223
+ mask = (
224
+ pred_data.polyline_score
225
+ >= torch.tensor(thresholds, device=pred_data.pos.device)[
226
+ pred_data.polyline_class
227
+ ]
228
+ )
229
+ filtered_data = Data(
230
+ pos=pred_data.pos,
231
+ polyline_class=pred_data.polyline_class[mask],
232
+ polyline_score=pred_data.polyline_score[mask],
233
+ polylines=pred_data.polylines[mask],
234
+ query_xyz=(
235
+ pred_data.query_xyz[mask] if hasattr(pred_data, "query_xyz") else None
236
+ ),
237
+ )
238
+
239
+ return filtered_data
240
+
241
+
242
+ def iou_filter_point_based(
243
+ pred_data,
244
+ iou_threshold: float = 0.6,
245
+ background_class: int = 0,
246
+ ):
247
+ """
248
+ Efficient per-class Non-Maximum Suppression using IoU computed on point cloud indices.
249
+
250
+ This optimized NMS implementation:
251
+ 1. Snaps all polyline vertices to nearest point cloud neighbors
252
+ 2. Computes IoU based on overlapping point cloud indices (not 3D distances)
253
+ 3. Applies greedy NMS within each class, keeping highest-scoring polylines
254
+ 4. Uses optimized data structures (cKDTree, sorted arrays) for speed
255
+
256
+ Algorithm details:
257
+ - Single batched nearest neighbor query for all valid vertices
258
+ - IoU = |intersection| / |union| of snapped point indices
259
+ - Polylines ordered by: score (desc) → #snapped_points (desc) → index (asc)
260
+ - Background class polylines are never removed
261
+ - Polylines with no valid snapped points are dropped
262
+
263
+ Args:
264
+ pred_data (Data): PyTorch Geometric Data object with polyline predictions
265
+ iou_threshold (float, optional): IoU threshold for suppression. Default: 0.6
266
+ background_class (int, optional): Class ID to exclude from NMS. Default: 0
267
+
268
+ Returns:
269
+ Data: Filtered Data object with overlapping polylines removed per class.
270
+ Maintains same structure with potentially fewer polylines.
271
+
272
+ Performance:
273
+ Significantly faster than distance-based methods due to:
274
+ - Batched spatial queries (cKDTree)
275
+ - Integer set operations (np.intersect1d)
276
+ - Minimal Python loops
277
+ """
278
+ data = pred_data.clone()
279
+
280
+ polylines: torch.Tensor = data.polylines # (N, M, 3)
281
+ classes: torch.Tensor = data.polyline_class # (N,)
282
+ pc: torch.Tensor = data.pos # (P, 3)
283
+ scores = getattr(data, "polyline_score", None)
284
+
285
+ device = polylines.device
286
+ N = polylines.shape[0]
287
+ if N == 0 or pc.shape[0] == 0:
288
+ return data
289
+
290
+ # ---- helpers ----
291
+ def valid_mask(pts_t: torch.Tensor) -> torch.Tensor:
292
+ finite = torch.isfinite(pts_t).all(dim=-1)
293
+ non_zero = pts_t.abs().sum(dim=-1) > 0
294
+ return finite & non_zero
295
+
296
+ # ---- gather all valid vertices once (batched) ----
297
+ # We'll collect (poly_idx, vertex_xyz) over all non-background curves.
298
+ poly_indices_list = []
299
+ all_vertices = []
300
+
301
+ bg = int(background_class)
302
+ for i in range(N):
303
+ if int(classes[i].item()) == bg:
304
+ continue
305
+ vm = valid_mask(polylines[i])
306
+ if vm.any():
307
+ pts = polylines[i][vm].detach().cpu().numpy()
308
+ if pts.size > 0:
309
+ all_vertices.append(pts)
310
+ poly_indices_list.append(np.full((pts.shape[0],), i, dtype=np.int32))
311
+
312
+ if len(all_vertices) == 0:
313
+ # nothing to snap; everything gets dropped
314
+ keep_mask = torch.zeros(N, dtype=torch.bool, device=device)
315
+ data.polylines = data.polylines[keep_mask]
316
+ data.polyline_class = data.polyline_class[keep_mask]
317
+ if hasattr(data, "polyline_score") and data.polyline_score is not None:
318
+ data.polyline_score = data.polyline_score[keep_mask]
319
+ if hasattr(data, "query_xyz") and data.query_xyz is not None:
320
+ data.query_xyz = data.query_xyz[keep_mask]
321
+ return data
322
+
323
+ all_vertices = np.concatenate(all_vertices, axis=0) # (T, 3)
324
+ owner_poly = np.concatenate(poly_indices_list, axis=0) # (T,)
325
+
326
+ # ---- one cKDTree query for all vertices ----
327
+ pc_np = pc.detach().cpu().numpy()
328
+ tree = cKDTree(pc_np)
329
+ # Use parallel workers if SciPy supports it (falls back silently otherwise)
330
+ nn_idx = tree.query(all_vertices, workers=-1)[1].astype(np.int64) # (T,)
331
+
332
+ # ---- split back to per-curve snapped unique index arrays (sorted) ----
333
+ snapped_arrays = [None] * N
334
+ set_sizes = torch.zeros(N, dtype=torch.long, device=device)
335
+
336
+ # group indices by polyline using numpy argsort
337
+ order = np.argsort(owner_poly, kind="mergesort")
338
+ owner_sorted = owner_poly[order]
339
+ nn_sorted = nn_idx[order]
340
+
341
+ # find segment starts for each unique polyline id
342
+ unique_ids, starts = np.unique(owner_sorted, return_index=True)
343
+ # append end sentinel
344
+ starts = np.append(starts, owner_sorted.shape[0])
345
+
346
+ for k in range(len(unique_ids)):
347
+ i = int(unique_ids[k])
348
+ seg = nn_sorted[starts[k] : starts[k + 1]]
349
+ if seg.size == 0:
350
+ snapped_arrays[i] = np.empty((0,), dtype=np.int64)
351
+ continue
352
+ uniq = np.unique(seg) # already sorted
353
+ snapped_arrays[i] = uniq
354
+ set_sizes[i] = uniq.size
355
+
356
+ # For background curves or curves with no valid vertices, ensure empty arrays
357
+ for i in range(N):
358
+ if snapped_arrays[i] is None:
359
+ snapped_arrays[i] = np.empty((0,), dtype=np.int64)
360
+
361
+ # fallback scores: prefer more snapped support
362
+ if scores is None:
363
+ scores = set_sizes.to(torch.float)
364
+
365
+ keep_mask = torch.ones(N, dtype=torch.bool, device=device)
366
+
367
+ # ---- per-class greedy NMS (IoU via fast array intersection) ----
368
+ target_classes = torch.unique(classes[classes != background_class]).tolist()
369
+ for cls in target_classes:
370
+ cls_inds = torch.where(classes == cls)[0].tolist()
371
+ if not cls_inds:
372
+ continue
373
+
374
+ # order by (score desc, size desc, index asc)
375
+ cls_order = sorted(
376
+ cls_inds,
377
+ key=lambda idx: (
378
+ -float(scores[idx].item()),
379
+ -int(set_sizes[idx].item()),
380
+ idx,
381
+ ),
382
+ )
383
+
384
+ suppressed = set()
385
+ for i_idx in cls_order:
386
+ if i_idx in suppressed:
387
+ continue
388
+
389
+ A = snapped_arrays[i_idx]
390
+ if A.size == 0:
391
+ suppressed.add(i_idx)
392
+ continue
393
+
394
+ lenA = A.size
395
+ for j_idx in cls_order:
396
+ if j_idx <= i_idx or j_idx in suppressed:
397
+ continue
398
+ B = snapped_arrays[j_idx]
399
+ if B.size == 0:
400
+ suppressed.add(j_idx)
401
+ continue
402
+
403
+ # fast intersection of two sorted int arrays
404
+ inter = np.intersect1d(A, B, assume_unique=True).size
405
+ union = lenA + B.size - inter
406
+ if union == 0:
407
+ continue
408
+ if (inter / union) > iou_threshold:
409
+ suppressed.add(j_idx)
410
+
411
+ if suppressed:
412
+ keep_mask[list(suppressed)] = False
413
+
414
+ # ---- filter aligned fields ----
415
+ data.polylines = data.polylines[keep_mask]
416
+ data.polyline_class = data.polyline_class[keep_mask]
417
+ if hasattr(data, "polyline_score") and data.polyline_score is not None:
418
+ data.polyline_score = data.polyline_score[keep_mask]
419
+ if hasattr(data, "query_xyz") and data.query_xyz is not None:
420
+ data.query_xyz = data.query_xyz[keep_mask]
421
+
422
+ return data
423
+
424
+
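For intuition, the IoU used in the loop above reduces to a set overlap on snapped point-cloud indices; a small numeric illustration with hypothetical index sets:

import numpy as np

A = np.array([3, 7, 9, 12])                              # indices snapped from polyline i
B = np.array([7, 9, 12, 20, 21])                         # indices snapped from polyline j
inter = np.intersect1d(A, B, assume_unique=True).size    # 3 shared points
union = A.size + B.size - inter                          # 4 + 5 - 3 = 6
iou = inter / union                                      # 0.5, below the 0.6 default, so nothing is suppressed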
425
+ def iou_filter_predictions(
426
+ data: Data,
427
+ iou_threshold: float = 0.6,
428
+ tol: float = 1e-2,
429
+ ) -> Data:
430
+ """
431
+ Remove overlapping polylines within each class using point-to-point distance IoU.
432
+
433
+ Performs class-wise Non-Maximum Suppression to eliminate redundant predictions:
434
+ 1. Filters out invalid points (NaN, Inf, near-zero)
435
+ 2. Computes pairwise point distances between polylines of the same class
436
+ 3. Calculates IoU based on points within distance tolerance
437
+ 4. Removes lower-scoring polylines when IoU exceeds threshold
438
+ 5. Protects "lonely" polylines (minimal overlap) from removal
439
+
440
+ IoU Calculation:
441
+ - overlap_i = number of points in polyline_i within tolerance of polyline_j
442
+ - overlap_j = number of points in polyline_j within tolerance of polyline_i
443
+ - intersection = min(overlap_i, overlap_j)
444
+ - union = len(polyline_i) + len(polyline_j) - intersection
445
+ - IoU = intersection / union
446
+
447
+ Args:
448
+ data (Data): PyTorch Geometric Data object containing:
449
+ - polylines (Tensor): Polyline coordinates [N, P, 3]
450
+ - polyline_class (Tensor): Class predictions [N]
451
+ - polyline_score (Tensor): Confidence scores [N]
452
+ - query_xyz (Tensor, optional): Query coordinates [N, 3]
453
+ iou_threshold (float, optional): IoU threshold for duplicate removal. Default: 0.6
454
+ tol (float, optional): Distance tolerance for point overlap detection. Default: 1e-2
455
+
456
+ Returns:
457
+ Data: Filtered Data object with overlapping polylines removed.
458
+ Background class (0) polylines are never removed.
459
+
460
+ Note:
461
+ - Processes polylines in descending score order for deterministic results
462
+ - Requires significant overlap (≥2 points, ≥10% of smaller polyline) before considering removal
463
+ - More computationally expensive than index-based methods but handles arbitrary point clouds
464
+ """
465
+ polylines = data.polylines
466
+ polyline_class = data.polyline_class
467
+ scores = data.polyline_score
468
+
469
+ # Precompute valid points for all polylines
470
+ valid_pts = []
471
+ for poly in polylines:
472
+ mask = ~torch.isnan(poly).any(dim=1) & (
473
+ ~torch.isinf(poly).any(dim=1) & (torch.norm(poly, dim=1) > 1e-6)
474
+ )
475
+ valid_pts.append(poly[mask])
476
+
477
+ remove_set = set()
478
+
479
+ # Process each class independently
480
+ for cls in torch.unique(polyline_class):
481
+ if cls == 0: # Skip background
482
+ continue
483
+ # Get indices for this class
484
+ class_mask = polyline_class == cls
485
+ class_indices = torch.where(class_mask)[0]
486
+ if len(class_indices) < 2:
487
+ continue
488
+
489
+ # Sort by score descending, then index ascending for determinism
490
+ sorted_indices = sorted(
491
+ class_indices.tolist(), key=lambda idx: (-scores[idx].item(), idx)
492
+ )
493
+
494
+ # Compare pairs in sorted order
495
+ for i, idx_i in enumerate(sorted_indices):
496
+ if idx_i in remove_set:
497
+ continue
498
+ pts_i = valid_pts[idx_i]
499
+ if len(pts_i) == 0:
500
+ continue
501
+ for j in range(i + 1, len(sorted_indices)):
502
+ idx_j = sorted_indices[j]
503
+ if idx_j in remove_set:
504
+ continue
505
+ pts_j = valid_pts[idx_j]
506
+ if len(pts_j) == 0:
507
+ continue
508
+
509
+ # Compute point-wise distances
510
+ dists = torch.cdist(pts_i, pts_j)
511
+ # Calculate overlaps
512
+ overlap_i = (dists.min(dim=1).values < tol).sum().item()
513
+ overlap_j = (dists.min(dim=0).values < tol).sum().item()
514
+ min_points = min(len(pts_i), len(pts_j))
515
+ # Skip if not significant overlap
516
+ if (
517
+ overlap_i < 2
518
+ or overlap_j < 2
519
+ or min(overlap_i, overlap_j) < 0.1 * min_points
520
+ ):
521
+ continue
522
+
523
+ # Calculate IoU
524
+ intersection = min(overlap_i, overlap_j)
525
+ union = len(pts_i) + len(pts_j) - intersection
526
+ iou = intersection / union if union > 0 else 0.0
527
+ if iou > iou_threshold:
528
+ # Always remove lower-scoring polyline
529
+ remove_set.add(idx_j)
530
+
531
+ # Build the keep mask from the removal set; low-overlap polylines were never added to it
532
+ keep_mask = torch.ones(len(polylines), dtype=torch.bool, device=polylines.device)
533
+ for idx in remove_set:
534
+ keep_mask[idx] = False
535
+
536
+ # Apply filtering
537
+ data.polylines = polylines[keep_mask]
538
+ data.polyline_class = polyline_class[keep_mask]
539
+ data.polyline_score = scores[keep_mask]
540
+ if hasattr(data, "query_xyz"):
541
+ data.query_xyz = data.query_xyz[keep_mask]
542
+
543
+ return data
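The three post-processing stages above are typically chained; a rough sketch with dummy tensors (threshold values and the exact ordering are illustrative, not prescribed by the module):

import torch
from torch_geometric.data import Data
from pi3detr.utils.postprocessing import (
    filter_predictions,
    iou_filter_point_based,
    snap_and_fit_curves,
)

pred = Data(
    pos=torch.rand(2048, 3),                    # point cloud [P, 3]
    polylines=torch.rand(16, 64, 3),            # raw polyline predictions [M, K, 3]
    polyline_class=torch.randint(0, 5, (16,)),  # classes 0-4 (0 = background)
    polyline_score=torch.rand(16),              # confidence per polyline [M]
)

pred = filter_predictions(pred, thresholds=[1.1, 0.5, 0.5, 0.5, 0.5])  # drop low-confidence curves
pred = iou_filter_point_based(pred, iou_threshold=0.6)                 # per-class NMS on snapped indices
pred = snap_and_fit_curves(pred)                                       # snap to the cloud and refit per class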
pi3detr/utils/viz.py ADDED
@@ -0,0 +1,12 @@
1
+ def hex_to_rgb(hex_color: str) -> tuple:
2
+ """
3
+ Convert hex color string to RGB tuple.
4
+
5
+ Args:
6
+ hex_color: Hex color string (e.g., "#FF5733")
7
+
8
+ Returns:
9
+ Tuple of RGB values (0-1 range)
10
+ """
11
+ hex_color = hex_color.lstrip("#")
12
+ return tuple(int(hex_color[i : i + 2], 16) / 255.0 for i in (0, 2, 4))
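Quick check of the conversion (values rounded):

from pi3detr.utils.viz import hex_to_rgb

hex_to_rgb("#FF5733")   # -> (1.0, 0.3412, 0.2)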
requirements.txt ADDED
@@ -0,0 +1,46 @@
1
+ # ==== Repos for CUDA 12.1 PyTorch + PyG wheels ====
2
+ --index-url https://download.pytorch.org/whl/cu121
3
+ --extra-index-url https://pypi.org/simple
4
+ --find-links https://data.pyg.org/whl/torch-2.5.1%2Bcu121.html
5
+
6
+ # ==== Core: PyTorch CUDA 12.1 ====
7
+ torch==2.5.1+cu121
8
+
9
+ # ==== PyTorch Geometric stack (built for torch 2.5 + cu121) ====
10
+ pyg-lib==0.4.0+pt25cu121
11
+ torch-scatter==2.1.2+pt25cu121
12
+ torch-sparse==0.6.18+pt25cu121
13
+ torch-cluster==1.6.3+pt25cu121
14
+ torch-spline-conv==1.2.2+pt25cu121
15
+ torch-geometric==2.6.1
16
+
17
+ # ==== Vision / geometry extras ====
18
+ kornia==0.8.0
19
+ opencv-python==4.11.0.86
20
+ open3d==0.19.0
21
+ polyscope==2.3.0
22
+ trimesh==4.6.8
23
+ timm==1.0.14
24
+ spconv-cu121==2.3.8
25
+ fpsample==0.3.3
26
+
27
+ # ==== SciPy stack ====
28
+ # Use NumPy >=2.0 for SciPy 1.15.x; leave upper bound open for compatibility with CUDA wheels.
29
+ numpy>=2.0
30
+ scipy==1.15.2
31
+ scikit-learn==1.6.1
32
+ matplotlib==3.10.0
33
+
34
+ # ==== Jupyter / developer tools ====
35
+ ipython>=8.20
36
+ ipykernel==6.29.5
37
+ ipywidgets==8.1.5
38
+ black==25.1.0
39
+ tqdm>=4.66
40
+ tensorboard==2.19.0
41
+ pytorch-lightning==2.5.0.post0
42
+
43
+ # ==== Hugging Face Space ====
44
+ gradio==5.49.1
45
+ plotly==6.3.1
46
+ plyfile==1.1.2