Commit 36effdc (root commit, 0 parents)
initial test commit
Files changed:
- .gitattributes +2 -0
- .gitignore +9 -0
- Dockerfile +34 -0
- LICENSE +131 -0
- app.py +1135 -0
- configs/pi3detr.yaml +77 -0
- configs/pi3detr_k256.yaml +77 -0
- demo_inputs/demo1.png +0 -0
- demo_inputs/demo1.xyz +0 -0
- demo_inputs/demo2.png +0 -0
- demo_inputs/demo2.xyz +0 -0
- demo_inputs/demo3.png +0 -0
- demo_inputs/demo3.xyz +0 -0
- demo_inputs/demo4.png +0 -0
- demo_inputs/demo4.xyz +0 -0
- demo_inputs/demo5.png +0 -0
- demo_inputs/demo5.xyz +0 -0
- pi3detr/__init__.py +76 -0
- pi3detr/dataset/__init__.py +7 -0
- pi3detr/dataset/abc_dataset.py +159 -0
- pi3detr/dataset/point_cloud_transforms.py +259 -0
- pi3detr/dataset/utils.py +75 -0
- pi3detr/evaluation/__init__.py +1 -0
- pi3detr/evaluation/abc_metrics.py +349 -0
- pi3detr/models/__init__.py +2 -0
- pi3detr/models/losses/__init__.py +5 -0
- pi3detr/models/losses/losses.py +399 -0
- pi3detr/models/losses/matcher.py +135 -0
- pi3detr/models/model_config.py +66 -0
- pi3detr/models/pi3detr.py +593 -0
- pi3detr/models/pointnetpp.py +324 -0
- pi3detr/models/positional_embedding.py +133 -0
- pi3detr/models/query_engine.py +108 -0
- pi3detr/models/transformer.py +399 -0
- pi3detr/utils/__init__.py +3 -0
- pi3detr/utils/config_reader.py +21 -0
- pi3detr/utils/curve_fitter.py +371 -0
- pi3detr/utils/layer_utils.py +22 -0
- pi3detr/utils/postprocessing.py +543 -0
- pi3detr/utils/viz.py +12 -0
- requirements.txt +46 -0
.gitattributes
ADDED
@@ -0,0 +1,2 @@
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
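These are standard Git LFS tracking rules: every *.ckpt or *.pth file (such as the checkpoint.ckpt that app.py loads below) is stored as an LFS pointer instead of a raw blob. They are the same entries that `git lfs track "*.ckpt"` and `git lfs track "*.pth"` would append.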
.gitignore
ADDED
@@ -0,0 +1,9 @@
__pycache__
lightning_logs*
logs_*
.shapenet
.vscode
*.ipynb
*.ini
scans
*.venv
Dockerfile
ADDED
@@ -0,0 +1,34 @@
FROM pytorch/pytorch:2.5.1-cuda12.1-cudnn9-devel

ENV DEBIAN_FRONTEND=noninteractive
ENV PIP_ROOT_USER_ACTION=ignore

RUN apt-get update -qq && \
    apt-get install -y zip git git-lfs vim libgtk2.0-dev ffmpeg libsm6 libxext6 && \
    rm -rf /var/lib/apt/lists/*

COPY requirements.txt /workspace

# Activate conda environment and install packages
RUN conda init bash && \
    echo "conda activate base" >> ~/.bashrc

SHELL ["conda", "run", "-n", "base", "/bin/bash", "-c"]

RUN pip --no-cache-dir install -r /workspace/requirements.txt

ARG USERNAME=user
ARG USER_UID=1000
ARG USER_GID=$USER_UID

# Create the user
RUN groupadd --gid $USER_GID $USERNAME \
    && useradd --uid $USER_UID --gid $USER_GID -m $USERNAME -s /bin/bash \
    #
    # [Optional] Add sudo support. Omit if you don't need to install software after connecting.
    && apt-get update \
    && apt-get install -y sudo \
    && echo $USERNAME ALL=\(root\) NOPASSWD:ALL > /etc/sudoers.d/$USERNAME \
    && chmod 0440 /etc/sudoers.d/$USERNAME

WORKDIR /workspaces/pi3detr
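The image extends the official PyTorch 2.5.1 / CUDA 12.1 devel image and installs requirements.txt into the base conda environment. A typical build and interactive run (the image tag and bind mount are illustrative, not prescribed by the repo) would be `docker build -t pi3detr .` followed by `docker run --gpus all -it -v "$PWD":/workspaces/pi3detr pi3detr bash`.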
LICENSE
ADDED
@@ -0,0 +1,131 @@
# PolyForm Noncommercial License 1.0.0

<https://polyformproject.org/licenses/noncommercial/1.0.0>

## Acceptance

In order to get any license under these terms, you must agree
to them as both strict obligations and conditions to all
your licenses.

## Copyright License

The licensor grants you a copyright license for the
software to do everything you might do with the software
that would otherwise infringe the licensor's copyright
in it for any permitted purpose. However, you may
only distribute the software according to [Distribution
License](#distribution-license) and make changes or new works
based on the software according to [Changes and New Works
License](#changes-and-new-works-license).

## Distribution License

The licensor grants you an additional copyright license
to distribute copies of the software. Your license
to distribute covers distributing the software with
changes and new works permitted by [Changes and New Works
License](#changes-and-new-works-license).

## Notices

You must ensure that anyone who gets a copy of any part of
the software from you also gets a copy of these terms or the
URL for them above, as well as copies of any plain-text lines
beginning with `Required Notice:` that the licensor provided
with the software. For example:

> Required Notice: Copyright Yoyodyne, Inc. (http://example.com)

## Changes and New Works License

The licensor grants you an additional copyright license to
make changes and new works based on the software for any
permitted purpose.

## Patent License

The licensor grants you a patent license for the software that
covers patent claims the licensor can license, or becomes able
to license, that you would infringe by using the software.

## Noncommercial Purposes

Any noncommercial purpose is a permitted purpose.

## Personal Uses

Personal use for research, experiment, and testing for
the benefit of public knowledge, personal study, private
entertainment, hobby projects, amateur pursuits, or religious
observance, without any anticipated commercial application,
is use for a permitted purpose.

## Noncommercial Organizations

Use by any charitable organization, educational institution,
public research organization, public safety or health
organization, environmental protection organization,
or government institution is use for a permitted purpose
regardless of the source of funding or obligations resulting
from the funding.

## Fair Use

You may have "fair use" rights for the software under the
law. These terms do not limit them.

## No Other Rights

These terms do not allow you to sublicense or transfer any of
your licenses to anyone else, or prevent the licensor from
granting licenses to anyone else. These terms do not imply
any other licenses.

## Patent Defense

If you make any written claim that the software infringes or
contributes to infringement of any patent, your patent license
for the software granted under these terms ends immediately. If
your company makes such a claim, your patent license ends
immediately for work on behalf of your company.

## Violations

The first time you are notified in writing that you have
violated any of these terms, or done anything with the software
not covered by your licenses, your licenses can nonetheless
continue if you come into full compliance with these terms,
and take practical steps to correct past violations, within
32 days of receiving notice. Otherwise, all your licenses
end immediately.

## No Liability

***As far as the law allows, the software comes as is, without
any warranty or condition, and the licensor will not be liable
to you for any damages arising out of these terms or the use
or nature of the software, under any kind of legal claim.***

## Definitions

The **licensor** is the individual or entity offering these
terms, and the **software** is the software the licensor makes
available under these terms.

**You** refers to the individual or entity agreeing to these
terms.

**Your company** is any legal entity, sole proprietorship,
or other kind of organization that you work for, plus all
organizations that have control over, are under the control of,
or are under common control with that organization. **Control**
means ownership of substantially all the assets of an entity,
or the power to direct its management and policies by vote,
contract, or otherwise. Control can be direct or indirect.

**Your licenses** are all the licenses granted to you for the
software under these terms.

**Use** means anything you do with the software requiring one
of your licenses.
app.py
ADDED
@@ -0,0 +1,1135 @@
"""
Gradio + Plotly point cloud viewer for .xyz, .ply and .obj files with PI3DETR model integration.

Features:
- Upload .xyz (ASCII): one point per line: "x y z" (extra columns are ignored).
- Upload .ply: Standard PLY format point clouds.
- Upload .obj: OBJ format with vertices and faces (triangles).
- Interactive 3D view: orbit, pan, zoom with mouse.
- Optional: downsample for speed, normalize to unit cube, toggle axes, set point size.
- Dual view: Input point cloud and model predictions side-by-side.
- PI3DETR model integration for curve detection.
- Immediate point cloud rendering on upload.
"""

import io
import os
from typing import List, Dict, Optional

import gradio as gr
import numpy as np
import plotly.graph_objects as go
from plyfile import PlyData
import pandas
import torch
from torch_geometric.data import Data
import fpsample
import trimesh  # NEW: for robust mesh loading & surface sampling

# Import PI3DETR modules
from pi3detr import (
    build_model,
    build_model_config,
    load_args,
    load_weights,
)
from pi3detr.dataset import normalize_and_scale

# Global model cache
PI3DETR_MODEL = None
MODEL_STATUS = {"loaded": False, "message": "Model not loaded"}

HOVER_FONT_SIZE = 16  # enlarged hover text size
FIG_TEMPLATE = "plotly_white"  # global figure template
PLOT_HEIGHT = 800  # NEW: desired plot height (adjust as needed)

# NEW: demo point cloud file paths (fill these with real .xyz/.ply paths)
DEMO_POINTCLOUDS = {
    "Demo 1": "demo_inputs/demo1.xyz",
    "Demo 2": "demo_inputs/demo2.xyz",
    "Demo 3": "demo_inputs/demo3.xyz",
    "Demo 4": "demo_inputs/demo4.xyz",
    "Demo 5": "demo_inputs/demo5.xyz",
}


def initialize_model(
    checkpoint_path="checkpoint.ckpt", config_path="configs/pi3detr.yaml"
):
    """Initialize the model at startup and store it in the global cache."""
    global PI3DETR_MODEL, MODEL_STATUS
    try:
        args = load_args(config_path) if config_path else {}
        model_config = build_model_config(args)
        model = build_model(model_config)
        load_weights(model, checkpoint_path)
        model.eval()

        PI3DETR_MODEL = model
        MODEL_STATUS = {"loaded": True, "message": "Model loaded successfully"}
        print("PI3DETR model initialized successfully")
        return True
    except Exception as e:
        MODEL_STATUS = {"loaded": False, "message": f"Error loading model: {str(e)}"}
        print(f"Error initializing PI3DETR model: {e}")
        return False


def read_xyz(file_obj: io.BytesIO) -> np.ndarray:
    """
    Parse a .xyz text file from bytes and return Nx3 float32 array.
    Lines with fewer than 3 numeric values are skipped.
    Only the first three numeric columns are used.
    """
    if file_obj is None:
        return np.zeros((0, 3), dtype=np.float32)

    # Read bytes → text
    raw = file_obj.read()
    try:
        text = raw.decode("utf-8", errors="ignore")
    except Exception:
        text = raw.decode("latin-1", errors="ignore")

    pts = []
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        parts = line.replace(",", " ").split()
        nums = []
        for p in parts:
            try:
                nums.append(float(p))
            except ValueError:
                # skip non-numeric tokens
                pass
            if len(nums) == 3:
                break
        if len(nums) >= 3:
            pts.append(nums[:3])

    if not pts:
        return np.zeros((0, 3), dtype=np.float32)

    return np.asarray(pts, dtype=np.float32)

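For reference, a minimal sketch of what read_xyz tolerates (the import assumes app.py is on the Python path; the sample bytes are invented):

import io
from app import read_xyz  # hypothetical import of this module

sample = b"""# comment lines and non-numeric lines are skipped
0.0, 1.0, 2.0, 255 255 255
3.0 4.0 5.0
not a point
"""
pts = read_xyz(io.BytesIO(sample))
print(pts.shape, pts.dtype)  # (2, 3) float32; commas and extra columns are ignored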
def read_ply(file_obj: io.BytesIO) -> np.ndarray:
    """
    Parse a .ply file from bytes and return Nx3 float32 array of points.
    """
    if file_obj is None:
        return np.zeros((0, 3), dtype=np.float32)

    try:
        ply_data = PlyData.read(file_obj)
        vertex = ply_data["vertex"]

        x = np.asarray(vertex["x"])
        y = np.asarray(vertex["y"])
        z = np.asarray(vertex["z"])

        points = np.column_stack([x, y, z]).astype(np.float32)
        return points
    except Exception as e:
        print(f"Error reading PLY file: {e}")
        return np.zeros((0, 3), dtype=np.float32)


def read_obj_and_sample(file_obj: io.BytesIO, display_max_points: int):
    """Parse OBJ via trimesh and sample up to display_max_points uniformly over the surface."""
    raw = file_obj.read()
    # Rewind not strictly needed after read since we don't reuse file_obj
    try:
        mesh = trimesh.load(io.BytesIO(raw), file_type="obj", force="mesh")
    except Exception as e:
        print(f"trimesh load error: {e}")
        return (
            np.zeros((0, 3), dtype=np.float32),
            np.zeros((0, 3), dtype=np.float32),
            "OBJ load failure",
        )
    # Handle scenes by merging
    if isinstance(mesh, trimesh.Scene):
        mesh = trimesh.util.concatenate(tuple(g for g in mesh.geometry.values()))
    if mesh.is_empty or mesh.vertices.shape[0] == 0:
        return (
            np.zeros((0, 3), dtype=np.float32),
            np.zeros((0, 3), dtype=np.float32),
            "OBJ: empty mesh",
        )
    sample_n = min(display_max_points, max(1, display_max_points))
    try:
        sampled = mesh.sample(sample_n)
    except Exception as e:
        print(f"Sampling error: {e}")
        sampled = mesh.vertices
        if sampled.shape[0] > sample_n:
            sampled = sampled[:sample_n]
    sampled = np.asarray(sampled, dtype=np.float32)
    info = f"OBJ: {mesh.vertices.shape[0]} verts, {len(mesh.faces) if mesh.faces is not None else 0} tris | Surface sampled: {sampled.shape[0]} pts"
    model_pts = sampled.copy()
    return model_pts, sampled, info


def downsample(pts: np.ndarray, max_points: int) -> np.ndarray:
    if pts.shape[0] <= max_points:
        return pts
    rng = np.random.default_rng(42)
    idx = rng.choice(pts.shape[0], size=max_points, replace=False)
    return pts[idx]


def make_figure(
    pts: np.ndarray,
    point_size: int = 2,
    show_axes: bool = True,
    title: str = "",
    polylines: Optional[List[Dict]] = None,
) -> go.Figure:
    """
    Build a Plotly 3D scatter figure with equal aspect ratio.
    Optionally includes polylines from model predictions.
    """
    if pts.size == 0 and (polylines is None or len(polylines) == 0):
        fig = go.Figure()
        fig.update_layout(
            title="No data to display",
            template=FIG_TEMPLATE,
            scene=dict(
                xaxis_visible=False,
                yaxis_visible=False,
                zaxis_visible=False,
            ),
            margin=dict(l=0, r=0, t=40, b=0),
        )
        return fig

    fig = go.Figure()

    # Add point cloud if available
    if pts.size > 0:
        x, y, z = pts[:, 0], pts[:, 1], pts[:, 2]
        fig.add_trace(
            go.Scatter3d(
                x=x,
                y=y,
                z=z,
                mode="markers",
                marker=dict(
                    size=max(1, int(point_size)), color="darkgray", opacity=0.2
                ),
                hoverinfo="skip",
                name="Curves",
                showlegend=False,  # legend hidden
            )
        )

    # Define colors for each curve type
    curve_colors = {
        "Line": "blue",
        "Circle": "green",
        "Arc": "red",
        "BSpline": "purple",
    }

    # Add polylines from model predictions if available
    if polylines:
        for curve in polylines:
            points = np.array(curve["points"])
            if len(points) < 2:
                continue

            curve_type = curve["type"]
            curve_id = curve["id"]
            score = curve["score"]

            # NEW: allow override color if provided (e.g., threshold filtered)
            color = curve.get("display_color") or curve_colors.get(curve_type, "orange")

            # NEW: support hidden-by-default via legendonly
            fig.add_trace(
                go.Scatter3d(
                    x=points[:, 0],
                    y=points[:, 1],
                    z=points[:, 2],
                    mode="lines",
                    line=dict(color=color, width=5),
                    name=f"{curve_type} #{curve_id} ({score:.2f})",
                    visible=curve.get("visible_state", True),
                    hoverinfo="text",
                    text=f"{curve_type} #{curve_id} ({score:.4f})",
                    showlegend=False,  # hide individual curve legend entries
                )
            )

    # Equal aspect ratio using data ranges
    if pts.size > 0:
        mins = pts.min(axis=0)
        maxs = pts.max(axis=0)
    elif polylines and len(polylines) > 0:
        # If we only have polylines, calculate range from them
        all_points = np.vstack([np.array(curve["points"]) for curve in polylines])
        mins = all_points.min(axis=0)
        maxs = all_points.max(axis=0)
    else:
        mins = np.array([-1, -1, -1])
        maxs = np.array([1, 1, 1])

    centers = (mins + maxs) / 2.0
    span = (maxs - mins).max()
    if span <= 0:
        span = 1.0
    half = span / 2.0
    xrange = [centers[0] - half, centers[0] + half]
    yrange = [centers[1] - half, centers[1] + half]
    zrange = [centers[2] - half, centers[2] + half]

    scene_axes = dict(
        xaxis=dict(range=xrange, visible=show_axes, title="x" if show_axes else ""),
        yaxis=dict(range=yrange, visible=show_axes, title="y" if show_axes else ""),
        zaxis=dict(range=zrange, visible=show_axes, title="z" if show_axes else ""),
        aspectmode="cube",
    )

    fig.update_layout(
        title=title,
        template=FIG_TEMPLATE,
        showlegend=False,
        scene=scene_axes,
        margin=dict(l=0, r=0, t=40, b=0),
        hoverlabel=dict(font=dict(size=HOVER_FONT_SIZE)),
        height=PLOT_HEIGHT,  # NEW
    )
    return fig

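The polylines argument consumed above is a list of per-curve dicts. A minimal sketch of the expected keys (values invented for illustration; the import assumes app.py is on the Python path):

import numpy as np
from app import make_figure  # hypothetical import of this module

demo_curve = {
    "type": "Line",   # selects the trace color (blue/green/red/purple, else orange)
    "id": 1,          # shown in the hover text
    "score": 0.93,    # shown in the hover text
    "points": np.linspace([0, 0, 0], [1, 1, 1], num=16),  # (N, 3) vertices
    # optional: "display_color": "black", "visible_state": "legendonly"
}
fig = make_figure(np.random.rand(100, 3), polylines=[demo_curve])
fig.show()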
def process_model_predictions(data: Data) -> list:
    """
    Process model outputs into a format suitable for visualization.
    """
    class_names = ["None", "BSpline", "Line", "Circle", "Arc"]
    polylines = data.polylines.cpu().numpy()
    curves = []

    # Process detected polylines
    for i, polyline in enumerate(polylines):
        cls = data.polyline_class[i].item()
        score = data.polyline_score[i].item()
        cls_name = class_names[cls]

        # Skip "None" class predictions (confidence thresholds are applied later)
        if cls == 0:
            continue

        # Add curve data to results with unique ID
        curve_data = {
            "type": cls_name,
            "id": i + 1,  # 1-based ID for better user experience
            "index": i,
            "score": score,
            "points": polyline,
        }
        curves.append(curve_data)

    return curves


def process_data_for_model(
    points: np.ndarray,
    sample: int = 32768,
    sample_mode: str = "fps",
) -> Data:  # CHANGED: removed reduction param
    """
    Process and subsample point cloud data using the same approach as predict_pi3detr.py.

    Args:
        points: Input point cloud as numpy array
        sample: Number of points to sample
        sample_mode: Sampling method ("fps", "random", "uniform", "all")

    Returns:
        Data object ready for model inference
    """
    # Convert to torch tensor
    pos = torch.tensor(points, dtype=torch.float32)

    # Apply sampling strategy
    if sample_mode == "random":
        if pos.size(0) > sample:
            indices = torch.randperm(pos.size(0))[:sample]
            pos = pos[indices]

    elif sample_mode == "fps":
        if pos.size(0) > sample:
            indices = fpsample.bucket_fps_kdline_sampling(pos, sample, h=6)
            pos = pos[indices]

    elif sample_mode == "uniform":
        if pos.size(0) > sample:
            step = max(1, pos.size(0) // sample)
            pos = pos[::step][:sample]

    elif sample_mode == "all":
        pass  # Keep all points

    # Create Data object
    data = Data(pos=pos)

    # Add batch information for single point cloud BEFORE normalization
    data.batch = torch.zeros(data.pos.size(0), dtype=torch.long)
    data.batch_size = 1

    # Normalize and scale using PI3DETR's method
    data = normalize_and_scale(data)

    # Ensure scale and center are proper batch tensors
    if hasattr(data, "scale") and data.scale.dim() == 0:
        data.scale = data.scale.unsqueeze(0)
    if hasattr(data, "center") and data.center.dim() == 1:
        data.center = data.center.unsqueeze(0)

    return data

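A self-contained illustration of how the "random" and "uniform" branches above behave ("fps" delegates to fpsample as shown, and "all" is a no-op):

import torch

pos = torch.rand(100_000, 3)  # stand-in point cloud
sample = 32_768

# "random": shuffle indices, keep the first `sample`
rand_pts = pos[torch.randperm(pos.size(0))[:sample]]

# "uniform": fixed-stride slicing, then truncate
step = max(1, pos.size(0) // sample)
uni_pts = pos[::step][:sample]

print(rand_pts.shape, uni_pts.shape)  # torch.Size([32768, 3]) for both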
@torch.no_grad()
|
| 397 |
+
def run_model_inference(
|
| 398 |
+
model,
|
| 399 |
+
points: np.ndarray,
|
| 400 |
+
max_points: int = 32768,
|
| 401 |
+
sample_mode: str = "fps",
|
| 402 |
+
num_queries: int = 256, # NEW
|
| 403 |
+
snap_and_fit: bool = False, # NEW
|
| 404 |
+
iou_filter: bool = False, # NEW
|
| 405 |
+
) -> list:
|
| 406 |
+
"""Run model inference on the given point cloud (extended with num_queries, snap_and_fit, iou_filter)."""
|
| 407 |
+
global PI3DETR_MODEL
|
| 408 |
+
if model is None:
|
| 409 |
+
model = PI3DETR_MODEL
|
| 410 |
+
if model is None:
|
| 411 |
+
return []
|
| 412 |
+
try:
|
| 413 |
+
data = process_data_for_model(
|
| 414 |
+
points, sample=max_points, sample_mode=sample_mode
|
| 415 |
+
)
|
| 416 |
+
device = next(model.parameters()).device
|
| 417 |
+
data = data.to(device)
|
| 418 |
+
|
| 419 |
+
if model.num_preds != num_queries:
|
| 420 |
+
model.set_num_preds(num_queries)
|
| 421 |
+
|
| 422 |
+
output = model.predict_step(
|
| 423 |
+
data,
|
| 424 |
+
reverse_norm=True,
|
| 425 |
+
thresholds=None,
|
| 426 |
+
snap_and_fit=snap_and_fit, # CHANGED
|
| 427 |
+
iou_filter=iou_filter, # CHANGED
|
| 428 |
+
)
|
| 429 |
+
result = output[0]
|
| 430 |
+
curves = process_model_predictions(result)
|
| 431 |
+
return curves
|
| 432 |
+
except Exception as e:
|
| 433 |
+
print(f"Error in model inference: {e}")
|
| 434 |
+
return []
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
def load_and_process_pointcloud(
|
| 438 |
+
file: gr.File,
|
| 439 |
+
max_points: int,
|
| 440 |
+
point_size: int,
|
| 441 |
+
show_axes: bool,
|
| 442 |
+
):
|
| 443 |
+
"""
|
| 444 |
+
Load and process a point cloud from .xyz or .ply file
|
| 445 |
+
"""
|
| 446 |
+
if file is None:
|
| 447 |
+
empty_fig = make_figure(np.zeros((0, 3)))
|
| 448 |
+
return empty_fig, None, None, os.path.basename(file.name) if file else ""
|
| 449 |
+
|
| 450 |
+
# Determine file type and read accordingly
|
| 451 |
+
file_ext = os.path.splitext(file.name)[1].lower()
|
| 452 |
+
|
| 453 |
+
# Read file based on extension
|
| 454 |
+
with open(file.name, "rb") as f:
|
| 455 |
+
if file_ext == ".xyz":
|
| 456 |
+
pts = read_xyz(f)
|
| 457 |
+
mode = "XYZ"
|
| 458 |
+
elif file_ext == ".ply":
|
| 459 |
+
pts = read_ply(f)
|
| 460 |
+
mode = "PLY"
|
| 461 |
+
elif file_ext == ".obj":
|
| 462 |
+
model_pts, display_pts, _ = read_obj_and_sample(f, max_points)
|
| 463 |
+
fig = make_figure(
|
| 464 |
+
display_pts,
|
| 465 |
+
point_size=point_size,
|
| 466 |
+
show_axes=show_axes,
|
| 467 |
+
title=f"{os.path.basename(file.name)}",
|
| 468 |
+
)
|
| 469 |
+
return fig, model_pts, display_pts, os.path.basename(file.name)
|
| 470 |
+
else:
|
| 471 |
+
empty_fig = make_figure(np.zeros((0, 3)))
|
| 472 |
+
return (
|
| 473 |
+
empty_fig,
|
| 474 |
+
None,
|
| 475 |
+
None,
|
| 476 |
+
"Unsupported file type. Please use .xyz, .ply or .obj.",
|
| 477 |
+
"",
|
| 478 |
+
)
|
| 479 |
+
|
| 480 |
+
original_n = pts.shape[0]
|
| 481 |
+
|
| 482 |
+
# Keep original points for model if normalizing for display
|
| 483 |
+
model_pts = pts.copy()
|
| 484 |
+
|
| 485 |
+
pts = downsample(pts, max_points=max_points)
|
| 486 |
+
displayed_n = pts.shape[0]
|
| 487 |
+
|
| 488 |
+
fig = make_figure(
|
| 489 |
+
pts,
|
| 490 |
+
point_size=point_size,
|
| 491 |
+
show_axes=show_axes,
|
| 492 |
+
title=f"{os.path.basename(file.name)}",
|
| 493 |
+
)
|
| 494 |
+
|
| 495 |
+
info = f"Loaded ({mode}): {original_n} points" # | Displayed: {displayed_n} points"
|
| 496 |
+
|
| 497 |
+
# RETURN single figure + model/full points + displayed subset
|
| 498 |
+
return fig, model_pts, pts, os.path.basename(file.name) # ADDED filename
|
| 499 |
+
|
| 500 |
+
|
| 501 |
+
def run_model_prediction(
|
| 502 |
+
model_pts: np.ndarray,
|
| 503 |
+
point_size: int,
|
| 504 |
+
show_axes: bool,
|
| 505 |
+
model_max_points: int,
|
| 506 |
+
sample_mode: str,
|
| 507 |
+
th_bspline: float,
|
| 508 |
+
th_line: float,
|
| 509 |
+
th_circle: float,
|
| 510 |
+
th_arc: float,
|
| 511 |
+
num_queries: int = 256,
|
| 512 |
+
snap_and_fit: bool = False,
|
| 513 |
+
iou_filter: bool = False,
|
| 514 |
+
): # CHANGED: removed reduction
|
| 515 |
+
# NOTE: display points now handled outside; keep signature (called before adding display pts state)
|
| 516 |
+
# (This wrapper kept for backwards compatibility if needed – we adapt below in new unified version)
|
| 517 |
+
return run_model_prediction_unified( # type: ignore
|
| 518 |
+
model_pts,
|
| 519 |
+
None,
|
| 520 |
+
point_size,
|
| 521 |
+
show_axes,
|
| 522 |
+
model_max_points,
|
| 523 |
+
sample_mode,
|
| 524 |
+
th_bspline,
|
| 525 |
+
th_line,
|
| 526 |
+
th_circle,
|
| 527 |
+
th_arc,
|
| 528 |
+
"",
|
| 529 |
+
num_queries,
|
| 530 |
+
snap_and_fit,
|
| 531 |
+
iou_filter,
|
| 532 |
+
)
|
| 533 |
+
|
| 534 |
+
|
| 535 |
+
def run_model_prediction_unified(
|
| 536 |
+
model_pts: np.ndarray,
|
| 537 |
+
display_pts: Optional[np.ndarray],
|
| 538 |
+
point_size: int,
|
| 539 |
+
show_axes: bool,
|
| 540 |
+
model_max_points: int,
|
| 541 |
+
sample_mode: str,
|
| 542 |
+
th_bspline: float,
|
| 543 |
+
th_line: float,
|
| 544 |
+
th_circle: float,
|
| 545 |
+
th_arc: float,
|
| 546 |
+
file_name: str = "",
|
| 547 |
+
num_queries: int = 256,
|
| 548 |
+
snap_and_fit: bool = False,
|
| 549 |
+
iou_filter: bool = False,
|
| 550 |
+
):
|
| 551 |
+
"""
|
| 552 |
+
Run model inference and apply initial threshold-based coloring.
|
| 553 |
+
"""
|
| 554 |
+
global PI3DETR_MODEL, MODEL_STATUS
|
| 555 |
+
if model_pts is None:
|
| 556 |
+
empty_fig = make_figure(np.zeros((0, 3)))
|
| 557 |
+
return empty_fig, []
|
| 558 |
+
|
| 559 |
+
# Run model inference using cached model
|
| 560 |
+
curves = []
|
| 561 |
+
try:
|
| 562 |
+
if PI3DETR_MODEL is None and not MODEL_STATUS["loaded"]:
|
| 563 |
+
# Try to initialize model if not already loaded
|
| 564 |
+
initialize_model()
|
| 565 |
+
|
| 566 |
+
if PI3DETR_MODEL is not None:
|
| 567 |
+
# Run inference with the same settings as predict_pi3detr.py
|
| 568 |
+
curves = run_model_inference(
|
| 569 |
+
PI3DETR_MODEL,
|
| 570 |
+
model_pts,
|
| 571 |
+
max_points=model_max_points,
|
| 572 |
+
sample_mode=sample_mode,
|
| 573 |
+
num_queries=num_queries, # NEW
|
| 574 |
+
snap_and_fit=snap_and_fit, # NEW
|
| 575 |
+
iou_filter=iou_filter, # NEW
|
| 576 |
+
)
|
| 577 |
+
except Exception:
|
| 578 |
+
pass
|
| 579 |
+
|
| 580 |
+
# NEW: apply thresholds for display (store raw curves separately)
|
| 581 |
+
thresholds = {
|
| 582 |
+
"BSpline": th_bspline,
|
| 583 |
+
"Line": th_line,
|
| 584 |
+
"Circle": th_circle,
|
| 585 |
+
"Arc": th_arc,
|
| 586 |
+
}
|
| 587 |
+
colored_curves = []
|
| 588 |
+
for c in curves:
|
| 589 |
+
c_disp = dict(c)
|
| 590 |
+
if c["score"] < thresholds.get(c["type"], 0.7):
|
| 591 |
+
c_disp["visible_state"] = "legendonly"
|
| 592 |
+
colored_curves.append(c_disp)
|
| 593 |
+
|
| 594 |
+
# Use existing displayed subset if provided; else derive lightweight subset
|
| 595 |
+
if display_pts is None:
|
| 596 |
+
display_pts = downsample(model_pts, max_points=100000)
|
| 597 |
+
title = f"{file_name} (curves)" if curves else f"{file_name} (no curves)"
|
| 598 |
+
fig = make_figure(
|
| 599 |
+
display_pts,
|
| 600 |
+
point_size=point_size,
|
| 601 |
+
show_axes=show_axes,
|
| 602 |
+
title=title,
|
| 603 |
+
polylines=colored_curves,
|
| 604 |
+
)
|
| 605 |
+
return fig, curves
|
| 606 |
+
|
| 607 |
+
|
| 608 |
+
def apply_pointcloud_display_settings(
|
| 609 |
+
model_pts: np.ndarray,
|
| 610 |
+
curves: List[Dict],
|
| 611 |
+
max_points: int,
|
| 612 |
+
point_size: int,
|
| 613 |
+
show_axes: bool,
|
| 614 |
+
th_bspline: float,
|
| 615 |
+
th_line: float,
|
| 616 |
+
th_circle: float,
|
| 617 |
+
th_arc: float,
|
| 618 |
+
file_name: str,
|
| 619 |
+
):
|
| 620 |
+
"""
|
| 621 |
+
Apply point cloud display settings without re-running inference.
|
| 622 |
+
Keeps existing detections and re-applies thresholds.
|
| 623 |
+
"""
|
| 624 |
+
if model_pts is None:
|
| 625 |
+
empty_fig = make_figure(np.zeros((0, 3)))
|
| 626 |
+
return empty_fig, None
|
| 627 |
+
display_pts = downsample(model_pts, max_points=max_points)
|
| 628 |
+
if not curves:
|
| 629 |
+
fig = make_figure(
|
| 630 |
+
display_pts,
|
| 631 |
+
point_size=point_size,
|
| 632 |
+
show_axes=show_axes,
|
| 633 |
+
title=file_name or "Point Cloud",
|
| 634 |
+
)
|
| 635 |
+
return fig, display_pts
|
| 636 |
+
thresholds = {
|
| 637 |
+
"BSpline": th_bspline,
|
| 638 |
+
"Line": th_line,
|
| 639 |
+
"Circle": th_circle,
|
| 640 |
+
"Arc": th_arc,
|
| 641 |
+
}
|
| 642 |
+
colored_curves = []
|
| 643 |
+
for c in curves:
|
| 644 |
+
c_disp = dict(c)
|
| 645 |
+
if c["score"] < thresholds.get(c["type"], 0.7):
|
| 646 |
+
c_disp["visible_state"] = "legendonly"
|
| 647 |
+
colored_curves.append(c_disp)
|
| 648 |
+
fig = make_figure(
|
| 649 |
+
display_pts,
|
| 650 |
+
point_size=point_size,
|
| 651 |
+
show_axes=show_axes,
|
| 652 |
+
title=(file_name or "Point Cloud") + " (curves)",
|
| 653 |
+
polylines=colored_curves,
|
| 654 |
+
)
|
| 655 |
+
return fig, display_pts
|
| 656 |
+
|
| 657 |
+
|
| 658 |
+
def clear_curves(
|
| 659 |
+
curves: List[Dict],
|
| 660 |
+
display_pts: Optional[np.ndarray],
|
| 661 |
+
model_pts: Optional[np.ndarray],
|
| 662 |
+
point_size: int,
|
| 663 |
+
show_axes: bool,
|
| 664 |
+
file_name: str,
|
| 665 |
+
):
|
| 666 |
+
"""
|
| 667 |
+
Recolor already inferred curves based on updated thresholds (no re-inference).
|
| 668 |
+
"""
|
| 669 |
+
if curves is None or model_pts is None or len(curves) == 0:
|
| 670 |
+
empty_fig = make_figure(
|
| 671 |
+
display_pts if display_pts is not None else np.zeros((0, 3))
|
| 672 |
+
)
|
| 673 |
+
return empty_fig, None
|
| 674 |
+
|
| 675 |
+
fig = make_figure(
|
| 676 |
+
display_pts if display_pts is not None else np.zeros((0, 3)),
|
| 677 |
+
point_size=point_size,
|
| 678 |
+
show_axes=show_axes,
|
| 679 |
+
title=file_name or "Point Cloud",
|
| 680 |
+
polylines=None,
|
| 681 |
+
)
|
| 682 |
+
return fig, None
|
| 683 |
+
|
| 684 |
+
|
| 685 |
+
def load_demo_pointcloud(
|
| 686 |
+
label: str,
|
| 687 |
+
max_points: int,
|
| 688 |
+
point_size: int,
|
| 689 |
+
show_axes: bool,
|
| 690 |
+
):
|
| 691 |
+
"""
|
| 692 |
+
Load one of the predefined demo point clouds.
|
| 693 |
+
Clears existing detected curves (curves_state -> None).
|
| 694 |
+
Also returns a value for the file upload component so the filename shows up.
|
| 695 |
+
"""
|
| 696 |
+
path = DEMO_POINTCLOUDS.get(label, "")
|
| 697 |
+
if not path or not os.path.isfile(path):
|
| 698 |
+
empty_fig = make_figure(np.zeros((0, 3)))
|
| 699 |
+
return empty_fig, None, None, "", None, None
|
| 700 |
+
ext = os.path.splitext(path)[1].lower()
|
| 701 |
+
try:
|
| 702 |
+
with open(path, "rb") as f:
|
| 703 |
+
if ext == ".xyz":
|
| 704 |
+
pts = read_xyz(f)
|
| 705 |
+
elif ext == ".ply":
|
| 706 |
+
pts = read_ply(f)
|
| 707 |
+
elif ext == ".obj":
|
| 708 |
+
model_pts, display_pts, _ = read_obj_and_sample(
|
| 709 |
+
f, min(20000, max_points)
|
| 710 |
+
)
|
| 711 |
+
fig = make_figure(
|
| 712 |
+
display_pts,
|
| 713 |
+
point_size=1,
|
| 714 |
+
show_axes=show_axes,
|
| 715 |
+
title=f"{os.path.basename(path)} (demo)",
|
| 716 |
+
)
|
| 717 |
+
return fig, model_pts, display_pts, os.path.basename(path), None, path
|
| 718 |
+
else:
|
| 719 |
+
empty_fig = make_figure(np.zeros((0, 3)))
|
| 720 |
+
return empty_fig, None, None, "", None, None
|
| 721 |
+
except Exception:
|
| 722 |
+
empty_fig = make_figure(np.zeros((0, 3)))
|
| 723 |
+
return empty_fig, None, None, "", None, None
|
| 724 |
+
model_pts = pts.copy()
|
| 725 |
+
pts = downsample(pts, max_points=max_points)
|
| 726 |
+
fig = make_figure(
|
| 727 |
+
pts,
|
| 728 |
+
point_size=1,
|
| 729 |
+
show_axes=show_axes,
|
| 730 |
+
title=f"{os.path.basename(path)} (demo)",
|
| 731 |
+
)
|
| 732 |
+
return fig, model_pts, pts, os.path.basename(path), None, path
|
| 733 |
+
|
| 734 |
+
|
| 735 |
+
# Convenience wrappers for each demo (avoid lambdas for clarity)
|
| 736 |
+
def load_demo1(max_points, point_size, show_axes):
|
| 737 |
+
return load_demo_pointcloud("Demo 1", max_points, point_size, show_axes)
|
| 738 |
+
|
| 739 |
+
|
| 740 |
+
def load_demo2(max_points, point_size, show_axes):
|
| 741 |
+
return load_demo_pointcloud("Demo 2", max_points, point_size, show_axes)
|
| 742 |
+
|
| 743 |
+
|
| 744 |
+
def load_demo3(max_points, point_size, show_axes):
|
| 745 |
+
return load_demo_pointcloud("Demo 3", max_points, point_size, show_axes)
|
| 746 |
+
|
| 747 |
+
|
| 748 |
+
def load_demo4(max_points, point_size, show_axes): # NEW
|
| 749 |
+
return load_demo_pointcloud("Demo 4", max_points, point_size, show_axes)
|
| 750 |
+
|
| 751 |
+
|
| 752 |
+
def load_demo5(max_points, point_size, show_axes): # NEW
|
| 753 |
+
return load_demo_pointcloud("Demo 5", max_points, point_size, show_axes)
|
| 754 |
+
|
| 755 |
+
|
| 756 |
+
def build_demo_preview(label: str, max_pts: int = 20000) -> go.Figure:
|
| 757 |
+
"""Create a small preview figure for a demo point cloud (no curves)."""
|
| 758 |
+
path = DEMO_POINTCLOUDS.get(label, "")
|
| 759 |
+
if not path or not os.path.isfile(path):
|
| 760 |
+
return make_figure(np.zeros((0, 3)), title=f"{label}: (missing)")
|
| 761 |
+
try:
|
| 762 |
+
ext = os.path.splitext(path)[1].lower()
|
| 763 |
+
with open(path, "rb") as f:
|
| 764 |
+
if ext == ".xyz":
|
| 765 |
+
pts = read_xyz(f)
|
| 766 |
+
elif ext == ".ply":
|
| 767 |
+
pts = read_ply(f)
|
| 768 |
+
elif ext == ".obj": # UPDATED
|
| 769 |
+
_, pts, _ = read_obj_and_sample(f, max_pts)
|
| 770 |
+
else:
|
| 771 |
+
return make_figure(np.zeros((0, 3)), title=f"{label}: (unsupported)")
|
| 772 |
+
pts = downsample(pts, max_pts)
|
| 773 |
+
return make_figure(pts, point_size=1, show_axes=False, title=f"{label} preview")
|
| 774 |
+
except Exception as e:
|
| 775 |
+
return make_figure(np.zeros((0, 3)), title=f"{label}: error")
|
| 776 |
+
|
| 777 |
+
|
| 778 |
+
def run_model_with_display(
|
| 779 |
+
model_pts: np.ndarray,
|
| 780 |
+
max_points: int,
|
| 781 |
+
point_size: int,
|
| 782 |
+
show_axes: bool,
|
| 783 |
+
model_max_points: int,
|
| 784 |
+
sample_mode: str,
|
| 785 |
+
th_bspline: float,
|
| 786 |
+
th_line: float,
|
| 787 |
+
th_circle: float,
|
| 788 |
+
th_arc: float,
|
| 789 |
+
file_name: str = "",
|
| 790 |
+
num_queries: int = 256,
|
| 791 |
+
snap_and_fit: bool = False,
|
| 792 |
+
iou_filter: bool = False,
|
| 793 |
+
): # CHANGED: removed reduction
|
| 794 |
+
"""
|
| 795 |
+
Run inference (if model_pts present) then immediately apply current display
|
| 796 |
+
(max_points/point_size/show_axes) and thresholds. Returns:
|
| 797 |
+
figure, info_text, curves(list), display_pts
|
| 798 |
+
"""
|
| 799 |
+
if model_pts is None:
|
| 800 |
+
empty = make_figure(np.zeros((0, 3)))
|
| 801 |
+
return empty, None, None
|
| 802 |
+
|
| 803 |
+
# Inference first (no display subset passed so it builds from model_pts)
|
| 804 |
+
fig_infer, curves = run_model_prediction_unified(
|
| 805 |
+
model_pts,
|
| 806 |
+
None,
|
| 807 |
+
point_size,
|
| 808 |
+
show_axes,
|
| 809 |
+
model_max_points,
|
| 810 |
+
sample_mode,
|
| 811 |
+
th_bspline,
|
| 812 |
+
th_line,
|
| 813 |
+
th_circle,
|
| 814 |
+
th_arc,
|
| 815 |
+
file_name,
|
| 816 |
+
num_queries, # NEW
|
| 817 |
+
snap_and_fit, # NEW
|
| 818 |
+
iou_filter, # NEW
|
| 819 |
+
) # CHANGED: removed reduction
|
| 820 |
+
|
| 821 |
+
# Now apply current display settings & thresholds without re-inference
|
| 822 |
+
fig_final, display_pts = apply_pointcloud_display_settings(
|
| 823 |
+
model_pts,
|
| 824 |
+
curves,
|
| 825 |
+
max_points,
|
| 826 |
+
point_size,
|
| 827 |
+
show_axes,
|
| 828 |
+
th_bspline,
|
| 829 |
+
th_line,
|
| 830 |
+
th_circle,
|
| 831 |
+
th_arc,
|
| 832 |
+
file_name,
|
| 833 |
+
)
|
| 834 |
+
return fig_final, curves, display_pts
|
| 835 |
+
|
| 836 |
+
|
| 837 |
+
with gr.Blocks(title="PI3DETR") as demo:
|
| 838 |
+
gr.Markdown(
|
| 839 |
+
"# 🥧 PI3DETR: 3D Parametric Curve Inference [CPU-PREVIEW]\n"
|
| 840 |
+
"An end-to-end deep learning model for **parametric curve inference** in **3D point clouds** and **meshes**.\n"
|
| 841 |
+
"Upload a `.xyz`, `.ply`, or `.obj` file to explore curve detection."
|
| 842 |
+
)
|
| 843 |
+
|
| 844 |
+
with gr.Row():
|
| 845 |
+
with gr.Column():
|
| 846 |
+
gr.Markdown(
|
| 847 |
+
"### 🧩 Supported Inputs\n"
|
| 848 |
+
"- **Point Clouds:** `.xyz`, `.ply`; **Meshes:** `.obj`\n"
|
| 849 |
+
"- `Mesh` is surface-sampled using **Max Points (display)** slider."
|
| 850 |
+
)
|
| 851 |
+
with gr.Column():
|
| 852 |
+
gr.Markdown(
|
| 853 |
+
"### ⚙️ Point Cloud Settings\n"
|
| 854 |
+
"- Adjust **Max Points**, **point size**, and **axes visibility**.\n"
|
| 855 |
+
"- Controls visualization of point cloud."
|
| 856 |
+
)
|
| 857 |
+
with gr.Column():
|
| 858 |
+
gr.Markdown(
|
| 859 |
+
"### 🎯 Confidence Thresholds\n"
|
| 860 |
+
"- Hover to inspect scores\n."
|
| 861 |
+
"- Filter curves by **class confidence** interactively"
|
| 862 |
+
)
|
| 863 |
+
with gr.Row():
|
| 864 |
+
with gr.Column():
|
| 865 |
+
gr.Markdown(
|
| 866 |
+
"### 🧠 Model Settings\n"
|
| 867 |
+
"- **Sampling Mode:** Choose downsampling strategy.\n"
|
| 868 |
+
"- **Model Input Size:** Number of model input points.\n"
|
| 869 |
+
"- **Queries:** Transformer decoder queries (max. output curves).\n"
|
| 870 |
+
"- Optional: *Snap&Fit* / *IOU-Filter* post-processing."
|
| 871 |
+
)
|
| 872 |
+
with gr.Column():
|
| 873 |
+
gr.Markdown(
|
| 874 |
+
"### ⚡ Performance Notes\n"
|
| 875 |
+
"- Trained on **human-made objects**.\n"
|
| 876 |
+
"- Optimized for **GPU**; this demo runs on **CPU**.\n"
|
| 877 |
+
"- For full qualitative performance: \n"
|
| 878 |
+
"[GitHub → PI3DETR](https://github.com/fafraob/pi3detr)"
|
| 879 |
+
)
|
| 880 |
+
with gr.Column():
|
| 881 |
+
gr.Markdown(
|
| 882 |
+
"### ▶️ Run Inference\n"
|
| 883 |
+
"- Click on demo point clouds (from test set) below.\n"
|
| 884 |
+
"- Press **Run PI3DETR** to execute inference and visualize results."
|
| 885 |
+
)
|
| 886 |
+
|
| 887 |
+
model_pts_state = gr.State(None)
|
| 888 |
+
display_pts_state = gr.State(None)
|
| 889 |
+
curves_state = gr.State(None)
|
| 890 |
+
file_name_state = gr.State("demo_inputs/demo2.xyz")
|
| 891 |
+
with gr.Row():
|
| 892 |
+
file_in = gr.File(
|
| 893 |
+
label="Upload Point Cloud (auto-renders)",
|
| 894 |
+
file_types=[".xyz", ".ply", ".obj"],
|
| 895 |
+
type="filepath",
|
| 896 |
+
)
|
| 897 |
+
with gr.Row():
|
| 898 |
+
with gr.Column(scale=1):
|
| 899 |
+
gr.Markdown("### Point Cloud Settings")
|
| 900 |
+
max_points = gr.Slider(
|
| 901 |
+
0,
|
| 902 |
+
500_000,
|
| 903 |
+
value=200_000,
|
| 904 |
+
step=1_000,
|
| 905 |
+
label="Max points (display)",
|
| 906 |
+
)
|
| 907 |
+
point_size = gr.Slider(1, 8, value=1, step=1, label="Point size")
|
| 908 |
+
show_axes = gr.Checkbox(value=False, label="Show axes")
|
| 909 |
+
|
| 910 |
+
gr.Markdown("### Model Settings")
|
| 911 |
+
sample_mode = gr.Radio(
|
| 912 |
+
["fps", "random", "all"],
|
| 913 |
+
value="fps",
|
| 914 |
+
label="Main Sampling Method",
|
| 915 |
+
)
|
| 916 |
+
model_max_points = gr.Slider(
|
| 917 |
+
1_000,
|
| 918 |
+
100_000,
|
| 919 |
+
value=32768,
|
| 920 |
+
step=500,
|
| 921 |
+
label="Downsample to Model Input Size",
|
| 922 |
+
)
|
| 923 |
+
num_queries = gr.Slider( # NEW
|
| 924 |
+
32,
|
| 925 |
+
512,
|
| 926 |
+
value=128,
|
| 927 |
+
step=1,
|
| 928 |
+
label="Number of Queries",
|
| 929 |
+
)
|
| 930 |
+
with gr.Row():
|
| 931 |
+
snap_and_fit_chk = gr.Checkbox(value=True, label="Snap&Fit")
|
| 932 |
+
iou_filter_chk = gr.Checkbox(value=False, label="IOU-Filter")
|
| 933 |
+
|
| 934 |
+
# Threshold sliders (no auto-change triggers)
|
| 935 |
+
gr.Markdown("#### Confidence Thresholds (per class)")
|
| 936 |
+
th_bspline = gr.Slider(0.0, 1.0, value=0.7, step=0.01, label="BSpline ≥")
|
| 937 |
+
th_line = gr.Slider(0.0, 1.0, value=0.7, step=0.01, label="Line ≥")
|
| 938 |
+
th_circle = gr.Slider(0.0, 1.0, value=0.7, step=0.01, label="Circle ≥")
|
| 939 |
+
th_arc = gr.Slider(0.0, 1.0, value=0.7, step=0.01, label="Arc ≥")
|
| 940 |
+
|
| 941 |
+
with gr.Column(scale=1):
|
| 942 |
+
main_plot = gr.Plot(
|
| 943 |
+
label="Point Cloud & Curves"
|
| 944 |
+
) # height from fig.update_layout(PLOT_HEIGHT)
|
| 945 |
+
|
| 946 |
+
run_model_button = gr.Button("Run PI3DETR", variant="primary")
|
| 947 |
+
clear_curves_button = gr.Button("Clear Curves", variant="secondary")
|
| 948 |
+
|
| 949 |
+
# Auto-render point cloud when file is uploaded
|
| 950 |
+
file_in.change(
|
| 951 |
+
load_and_process_pointcloud,
|
| 952 |
+
inputs=[file_in, max_points, point_size, show_axes],
|
| 953 |
+
outputs=[
|
| 954 |
+
main_plot,
|
| 955 |
+
model_pts_state,
|
| 956 |
+
display_pts_state,
|
| 957 |
+
file_name_state,
|
| 958 |
+
],
|
| 959 |
+
)
|
| 960 |
+
|
| 961 |
+
run_model_button.click(
|
| 962 |
+
run_model_with_display,
|
| 963 |
+
inputs=[
|
| 964 |
+
model_pts_state,
|
| 965 |
+
max_points,
|
| 966 |
+
point_size,
|
| 967 |
+
show_axes,
|
| 968 |
+
model_max_points,
|
| 969 |
+
sample_mode,
|
| 970 |
+
th_bspline,
|
| 971 |
+
th_line,
|
| 972 |
+
th_circle,
|
| 973 |
+
th_arc,
|
| 974 |
+
file_name_state,
|
| 975 |
+
num_queries,
|
| 976 |
+
snap_and_fit_chk,
|
| 977 |
+
iou_filter_chk,
|
| 978 |
+
],
|
| 979 |
+
            outputs=[main_plot, curves_state, display_pts_state],
        )

        # NEW: auto-apply display & thresholds on interaction (no inference)
        def _apply_display_wrapper(
            model_pts,
            curves,
            max_points,
            point_size,
            show_axes,
            th_bspline,
            th_line,
            th_circle,
            th_arc,
            file_name,
            display_pts_state_value,
        ):
            fig, display_pts = apply_pointcloud_display_settings(
                model_pts,
                curves,
                max_points,
                point_size,
                show_axes,
                th_bspline,
                th_line,
                th_circle,
                th_arc,
                file_name,
            )
            return fig, display_pts

        # Point cloud sliders (release) & checkbox (change)
        for slider in [max_points, point_size]:
            slider.release(
                _apply_display_wrapper,
                inputs=[
                    model_pts_state,
                    curves_state,
                    max_points,
                    point_size,
                    show_axes,
                    th_bspline,
                    th_line,
                    th_circle,
                    th_arc,
                    file_name_state,
                    display_pts_state,
                ],
                outputs=[main_plot, display_pts_state],
            )

        show_axes.change(
            _apply_display_wrapper,
            inputs=[
                model_pts_state,
                curves_state,
                max_points,
                point_size,
                show_axes,
                th_bspline,
                th_line,
                th_circle,
                th_arc,
                file_name_state,
                display_pts_state,
            ],
            outputs=[main_plot, display_pts_state],
        )

        # Threshold sliders (apply on release)
        for th in [th_bspline, th_line, th_circle, th_arc]:
            th.release(
                _apply_display_wrapper,
                inputs=[
                    model_pts_state,
                    curves_state,
                    max_points,
                    point_size,
                    show_axes,
                    th_bspline,
                    th_line,
                    th_circle,
                    th_arc,
                    file_name_state,
                    display_pts_state,
                ],
                outputs=[main_plot, display_pts_state],
            )

        clear_curves_button.click(
            clear_curves,
            inputs=[
                curves_state,
                display_pts_state,
                model_pts_state,
                point_size,
                show_axes,
                file_name_state,
            ],
            outputs=[main_plot, curves_state],
        )

        # REPLACED demo preview plots + buttons WITH clickable images
        with gr.Row():
            gr.Markdown("### Demo Point Clouds (click an image to load)")
        with gr.Row():
            # CLEANUP: generate images dynamically for all demos
            demo_image_components = {}
            for label in ["Demo 1", "Demo 2", "Demo 3", "Demo 4", "Demo 5"]:  # UPDATED
                png_path = f"demo_inputs/{label.lower().replace(' ', '')}.png"
                demo_image_components[label] = gr.Image(
                    value=png_path if os.path.isfile(png_path) else None,
                    label=label,
                    interactive=False,
                )

        # CLEANUP: map labels to loader functions & attach select handlers
        _demo_loaders = {
            "Demo 1": load_demo1,
            "Demo 2": load_demo2,
            "Demo 3": load_demo3,
            "Demo 4": load_demo4,
            "Demo 5": load_demo5,  # NEW
        }
        for label, comp in demo_image_components.items():
            comp.select(
                _demo_loaders[label],
                inputs=[max_points, point_size, show_axes],
                outputs=[
                    main_plot,
                    model_pts_state,
                    display_pts_state,
                    file_name_state,
                    curves_state,
                    file_in,
                ],
            )

        # NEW: auto-load Demo 2 on app start
        demo.load(
            load_demo2,
            inputs=[max_points, point_size, show_axes],
            outputs=[
                main_plot,
                model_pts_state,
                display_pts_state,
                file_name_state,
                curves_state,
                file_in,
            ],
        )


if __name__ == "__main__":
    # Initialize model at startup
    initialize_model()
    demo.launch()
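For context, the wiring above follows a common Gradio pattern: cached arrays live in gr.State, and slider release events re-render the plot without re-running inference. A minimal, self-contained sketch of that pattern (every name here is illustrative, none is taken from app.py):

# Sketch only: slider -> gr.State -> plot, with no model call on redraw.
import gradio as gr
import plotly.graph_objects as go

def redraw(points, size):
    # Re-render from cached state; the expensive inference never reruns here.
    fig = go.Figure(go.Scatter3d(
        x=[p[0] for p in points], y=[p[1] for p in points],
        z=[p[2] for p in points], mode="markers",
        marker=dict(size=size),
    ))
    return fig

with gr.Blocks() as demo_sketch:
    pts_state = gr.State([(0, 0, 0), (1, 1, 1)])  # cached point cloud
    size = gr.Slider(1, 10, value=3, label="Point size")
    plot = gr.Plot()
    # Fire only when the user releases the slider handle.
    size.release(redraw, inputs=[pts_state, size], outputs=plot)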
configs/pi3detr.yaml
ADDED
@@ -0,0 +1,77 @@
### Training parameters
epochs: 1715
lr_step: 1250
lr_warmup_epochs: 15
lr_warmup_start_factor: 1.0e-6
lr: 1.0e-4
batch_size: 8
batch_size_val: 8
accumulate_grad_batches: 16
grad_clip_val: 0.2  # max gradient norm
to_monitor: "val_seg_iou"
monitor_mode: "max"
val_interval: 1

### Model parameters
model: "pi3detr"
preencoder_type: "samodule"
num_features: 0
weights: ""
preencoder_lr: 1.0e-4
freeze_backbone: false
encoder_dim: 768
decoder_dim: 768
num_encoder_layers: 3
num_decoder_layers: 9
encoder_dropout: 0.1  # dropout in encoder
decoder_dropout: 0.1  # dropout in decoder
num_attn_heads: 8  # number of attention heads
enc_dim_feedforward: 2048  # dimension of feedforward in encoder
dec_dim_feedforward: 2048  # dimension of feedforward in decoder
mlp_dropout: 0.0  # dropout in MLP heads
num_preds: 128  # num outputs of transformer
num_classes: 5
auxiliary_loss: true
max_points_in_param: 4
num_transformer_points: 2048  # number of transformer points (needed for some preencoders)
query_type: "point_fps"
pos_embed_type: "sine"  # Options: "fourier", "sine"
class_loss_type: "cross_entropy"
class_loss_weights: [0.04834912, 0.40329467, 0.09588135, 0.23071379, 0.22176106]

### Curve and validation parameters
num_curve_points: 64  # must be the same as points_per_curve
num_curve_points_val: 256  # validation curve points

### Loss weights
loss_weights:
  loss_class: 1
  loss_bspline: 1
  loss_bspline_chamfer: 1
  loss_line_position: 1
  loss_line_length: 1
  loss_line_chamfer: 1
  loss_circle_position: 1
  loss_circle_radius: 1
  loss_circle_chamfer: 1
  loss_arc: 1
  loss_arc_chamfer: 1
  loss_seg: 1

### Cost weights
cost_weights:
  cost_class: 1
  cost_curve: 1

### Dataset parameters
dataset: "abc_dataset"
num_workers: 8
data_root: "/dataset/train"
data_val_root: "/dataset/val"
data_test_root: "/dataset/test"
augment: true
random_rotate_prob: 1.0
random_sample_prob: 0.85
random_sample_bounds: [1.0, 0.2]  # [max, min] fraction of points to keep
noise_prob: 0.0
noise_scale: 0.0
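A config like this is consumed as a flat namespace; the repo's own reader lives in pi3detr/utils/config_reader.py (not shown in this section), so the following is only a sketch of plausible loader behavior, assuming PyYAML is available:

# Sketch: turn the YAML above into an argparse-style namespace.
# This assumes the loader's behavior; it is not a copy of config_reader.py.
import argparse
import yaml

def load_yaml_args(path: str) -> argparse.Namespace:
    with open(path) as f:
        cfg = yaml.safe_load(f)
    return argparse.Namespace(**cfg)

args = load_yaml_args("configs/pi3detr.yaml")
print(args.lr, args.num_preds)  # 0.0001 128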
configs/pi3detr_k256.yaml
ADDED
@@ -0,0 +1,77 @@
### Training parameters
epochs: 1715
lr_step: 1250
lr_warmup_epochs: 15
lr_warmup_start_factor: 1.0e-6
lr: 1.0e-4
batch_size: 8
batch_size_val: 8
accumulate_grad_batches: 16
grad_clip_val: 0.2  # max gradient norm
to_monitor: "val_seg_iou"
monitor_mode: "max"
val_interval: 1

### Model parameters
model: "pi3detr"
preencoder_type: "samodule"
num_features: 0
weights: ""
preencoder_lr: 1.0e-4
freeze_backbone: false
encoder_dim: 768
decoder_dim: 768
num_encoder_layers: 3
num_decoder_layers: 9
encoder_dropout: 0.1  # dropout in encoder
decoder_dropout: 0.1  # dropout in decoder
num_attn_heads: 8  # number of attention heads
enc_dim_feedforward: 2048  # dimension of feedforward in encoder
dec_dim_feedforward: 2048  # dimension of feedforward in decoder
mlp_dropout: 0.0  # dropout in MLP heads
num_preds: 256  # num outputs of transformer
num_classes: 5
auxiliary_loss: true
max_points_in_param: 4
num_transformer_points: 2048  # number of transformer points (needed for some preencoders)
query_type: "point_fps"
pos_embed_type: "sine"  # Options: "fourier", "sine"
class_loss_type: "cross_entropy"
class_loss_weights: [0.04834912, 0.40329467, 0.09588135, 0.23071379, 0.22176106]

### Curve and validation parameters
num_curve_points: 64  # must be the same as points_per_curve
num_curve_points_val: 256  # validation curve points

### Loss weights
loss_weights:
  loss_class: 1
  loss_bspline: 1
  loss_bspline_chamfer: 1
  loss_line_position: 1
  loss_line_length: 1
  loss_line_chamfer: 1
  loss_circle_position: 1
  loss_circle_radius: 1
  loss_circle_chamfer: 1
  loss_arc: 1
  loss_arc_chamfer: 1
  loss_seg: 1

### Cost weights
cost_weights:
  cost_class: 1
  cost_curve: 1

### Dataset parameters
dataset: "abc_dataset"
num_workers: 8
data_root: "/dataset/train"
data_val_root: "/dataset/val"
data_test_root: "/dataset/test"
augment: true
random_rotate_prob: 1.0
random_sample_prob: 0.85
random_sample_bounds: [1.0, 0.2]  # [max, min] fraction of points to keep
noise_prob: 0.0
noise_scale: 0.0
demo_inputs/demo1.png
ADDED
demo_inputs/demo1.xyz
ADDED
The diff for this file is too large to render.
demo_inputs/demo2.png
ADDED
demo_inputs/demo2.xyz
ADDED
The diff for this file is too large to render.
demo_inputs/demo3.png
ADDED
demo_inputs/demo3.xyz
ADDED
The diff for this file is too large to render.
demo_inputs/demo4.png
ADDED
demo_inputs/demo4.xyz
ADDED
The diff for this file is too large to render.
demo_inputs/demo5.png
ADDED
demo_inputs/demo5.xyz
ADDED
The diff for this file is too large to render.
pi3detr/__init__.py
ADDED
@@ -0,0 +1,76 @@
import argparse
from typing import Union, Optional
import torch.nn as nn
from torch_geometric.data import Dataset
import inspect
from .models import ModelConfig
from .utils import load_args, load_weights
from .models import PI3DETR
from .dataset import DatasetConfig, ABCDataset


def build_model_config(args: Union[argparse.Namespace, str]) -> ModelConfig:
    if isinstance(args, str):
        args = load_args(args)

    # Get required parameters from ModelConfig constructor
    model_config_signature = inspect.signature(ModelConfig.__init__)
    required_params = [
        param for param in model_config_signature.parameters.keys() if param != "self"
    ]

    for param in required_params:
        if not hasattr(args, param):
            print(f"ERROR: Parameter '{param}' has to be specified in the arguments")
            raise ValueError(f"Missing required parameter: {param}")

    # Create model config with all parameters from args
    model_config_args = {param: getattr(args, param) for param in required_params}
    model_config = ModelConfig(**model_config_args)

    print(model_config)
    return model_config


def build_dataset_config(
    args: Union[argparse.Namespace, str], data_root: str, augment: bool
) -> DatasetConfig:
    if isinstance(args, str):
        args = load_args(args)

    # Get required parameters from DatasetConfig constructor (excluding root and augment)
    dataset_config_signature = inspect.signature(DatasetConfig.__init__)
    required_params = [
        param
        for param in dataset_config_signature.parameters.keys()
        if param not in ["self", "root", "augment"]
    ]

    for param in required_params:
        if not hasattr(args, param):
            print(f"ERROR: Parameter '{param}' has to be specified in the arguments")
            raise ValueError(f"Missing required parameter: {param}")

    # Create dataset config with parameters from args plus root and augment
    dataset_config_args = {param: getattr(args, param) for param in required_params}
    dataset_config = DatasetConfig(
        root=data_root, augment=augment, **dataset_config_args
    )

    print(dataset_config)
    return dataset_config


def build_dataset(config: DatasetConfig) -> Dataset:
    if config.dataset == "abc_dataset":
        return ABCDataset(config)
    else:
        raise ValueError(f"Unknown dataset {config.dataset}")


def build_model(config: ModelConfig) -> nn.Module:
    print(f"Model: {config.model}")
    if config.model == "pi3detr":
        return PI3DETR(config)
    else:
        raise ValueError(f"Unknown model {config.model}")
pi3detr/dataset/__init__.py
ADDED
@@ -0,0 +1,7 @@
from .abc_dataset import ABCDataset, DatasetConfig
from .point_cloud_transforms import (
    normalize_and_scale,
    normalize_and_scale_with_params,
    reverse_normalize_and_scale,
    reverse_normalize_and_scale_with_params,
)
pi3detr/dataset/abc_dataset.py
ADDED
@@ -0,0 +1,159 @@
import random
import torch
import numpy as np
from typing import Union
from pathlib import Path
from torch_geometric.data import Dataset
from torch_geometric.data.data import BaseData
from dataclasses import dataclass
from typing import Callable
import torch_geometric.transforms as T

from .point_cloud_transforms import (
    random_rotate,
    normalize_and_scale,
    add_noise,
    subsample,
)


@dataclass
class DatasetConfig:
    dataset: str
    root: str
    augment: bool = False
    random_rotate_prob: float = 1
    random_sample_prob: float = 0.5
    random_sample_bounds: tuple[float, float] = (1, 0.5)
    noise_prob: float = 0
    noise_scale: float = 0


class ABCDataset(Dataset):
    def __init__(
        self,
        config: DatasetConfig,
    ) -> None:
        self.file_names = self._read_file_names(config.root)
        self.config = config
        super().__init__(
            config.root,
            None,
            None,
            None,
        )

    @property
    def raw_dir(self) -> str:
        return self.root

    @property
    def raw_file_names(self) -> Union[str, list[str], tuple]:
        return self.file_names

    @property
    def processed_file_names(self) -> Union[str, list[str], tuple]:
        return [f"{file_name}.pt" for file_name in self.file_names]

    def process(self) -> None:
        print("Should already be processed.")

    def get(self, idx: int) -> BaseData:
        data = torch.load(
            Path(self.processed_dir) / f"{self.raw_file_names[idx]}.pt",
            weights_only=False,
        )
        data["pos"] = data["pos"].to(torch.float32)

        augment = self.config.augment
        if augment and random.random() < self.config.noise_prob:
            sigma = (
                np.max(
                    np.max(data.pos.cpu().numpy(), axis=0)
                    - np.min(data.pos.cpu().numpy(), axis=0)
                )
                / self.config.noise_scale
            )
            noise = torch.tensor(
                np.random.normal(loc=0, scale=sigma, size=data.pos.shape),
                dtype=data.pos.dtype,
                device=data.pos.device,
            )
            data.pos += noise

        if not hasattr(data, "real_scale") or not hasattr(data, "real_center"):
            data.real_center = torch.zeros(3)
            data.real_scale = torch.tensor(1.0)

        if augment and random.random() < self.config.random_sample_prob:
            data = subsample(
                data,
                *self.config.random_sample_bounds,
                max_points=None,
                extra_fields=["y_seg", "y_seg_cls"],
            )

        extra_fields = [
            "y_curve_64",
            "bspline_params",
            "line_params",
            "circle_params",
            "arc_params",
        ]
        if augment and random.random() < self.config.random_rotate_prob:
            data = random_rotate(data, 180, axis=0, extra_fields=extra_fields)
            data = random_rotate(data, 180, axis=1, extra_fields=extra_fields)
            data = random_rotate(data, 180, axis=2, extra_fields=extra_fields)

        line_direction = data.line_params[:, 1]
        circle_normal = data.circle_params[:, 1]

        data = normalize_and_scale(
            data,
            extra_fields=extra_fields,
        )
        # normal vectors shouldn't change
        data.line_params[:, 1] = line_direction
        data.circle_params[:, 1] = circle_normal
        # manually adjust length and radius
        data.line_length = data.line_length * data.scale
        data.circle_radius = data.circle_radius * data.scale

        data.y_params = torch.zeros(data.num_curves, 12, dtype=torch.float32)
        for i in range(data.num_curves):
            if data.y_cls[i] == 1:
                # B-spline: control points P0, P1, P2, P3
                data.y_params[i][:12] = data.bspline_params[i].reshape(-1)
            elif data.y_cls[i] == 2:
                # Line: midpoint, direction, length
                data.y_params[i][:3] = data.line_params[i][0].reshape(-1)
                data.y_params[i][3:6] = line_direction[i].reshape(-1)
                data.y_params[i][6] = data.line_length[i]  # already adjusted above
            elif data.y_cls[i] == 3:
                # Circle: center, normal, radius
                data.y_params[i][:3] = data.circle_params[i][0].reshape(-1)
                data.y_params[i][3:6] = circle_normal[i].reshape(-1)
                data.y_params[i][6] = data.circle_radius[i]  # already adjusted above
            elif data.y_cls[i] == 4:
                # Arc: midpoint, start, end
                data.y_params[i][:9] = data.arc_params[i].reshape(-1)
        data.filename = self.raw_file_names[idx]

        return data

    def len(self) -> int:
        return len(self.processed_file_names)

    def _read_file_names(self, root: str) -> list[str]:
        return sorted(
            fp.stem
            for fp in Path(root).joinpath("processed").glob("*.pt")
            if "pre_" not in fp.stem
        )
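The 12-float y_params vector built in get() packs a different parameterization per curve class. The following sketch summarizes that encoding and shows a tiny decoder derived directly from the branches above (the helper itself is illustrative, not part of the repo):

# Layout of data.y_params[i] (12 floats), keyed by data.y_cls[i]:
#   1 (B-spline): 4 control points * 3      -> slots [0:12]
#   2 (Line):     midpoint, direction, len  -> slots [0:3], [3:6], [6]; rest zero
#   3 (Circle):   center, normal, radius    -> slots [0:3], [3:6], [6]; rest zero
#   4 (Arc):      midpoint, start, end      -> slots [0:9]; rest zero
import torch

def decode_y_params(cls_id: int, p: torch.Tensor) -> dict:
    """Illustrative decoder mirroring the encoding built in ABCDataset.get()."""
    if cls_id == 1:
        return {"control_points": p[:12].reshape(4, 3)}
    if cls_id == 2:
        return {"midpoint": p[:3], "direction": p[3:6], "length": p[6]}
    if cls_id == 3:
        return {"center": p[:3], "normal": p[3:6], "radius": p[6]}
    if cls_id == 4:
        return {"midpoint": p[:3], "start": p[3:6], "end": p[6:9]}
    raise ValueError(f"unknown curve class {cls_id}")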
pi3detr/dataset/point_cloud_transforms.py
ADDED
@@ -0,0 +1,259 @@
import math
import random
import torch
import numpy as np
from typing import Optional
import torch_geometric.transforms as T
from torch_geometric.data import Data


def subsample(
    data: Data,
    upper_bound: float = 1.0,
    lower_bound: float = 0.5,
    max_points: Optional[int] = None,
    extra_fields: list[str] = [],
) -> Data:
    r"""Subsamples the point cloud to a random number of points within the
    range :obj:`[lower_bound, upper_bound]` (functional name: :obj:`subsample`).
    """
    if data.pos.size(0) == 0:
        return data
    num_points = int(random.uniform(lower_bound, upper_bound) * data.pos.size(0))
    if max_points is not None:
        num_points = min(num_points, max_points)
    idx = torch.randperm(data.pos.size(0))[:num_points]
    data.pos = data.pos[idx]
    for field in extra_fields:
        if hasattr(data, field):
            setattr(data, field, getattr(data, field)[idx])
    return data


def numpy_normalize_and_scale(xyz: np.ndarray) -> tuple[np.ndarray, float, float]:
    r"""Normalizes the point cloud so that the points are centered around the
    origin and lie within the interval :math:`[-1, 1]` (functional name:
    :obj:`normalize`).
    """
    center = xyz.mean(0)
    scale = (1 / np.max(np.abs(xyz - center))) * 0.999999
    xyz = numpy_normalize_and_scale_with_params(xyz, center, scale)
    return xyz, center, scale


def numpy_normalize_and_scale_with_params(
    xyz: np.ndarray, center: np.ndarray, scale: float
) -> np.ndarray:
    r"""Applies a given centering and scaling to the point cloud."""
    if xyz.size == 0:
        return xyz
    shape = xyz.shape
    return ((xyz.reshape(-1, shape[-1]) - center) * scale).reshape(shape)


def normalize_and_scale(data: Data, extra_fields: list[str] = []) -> Data:
    r"""Centers and normalizes the given fields to the interval :math:`[-1, 1]`
    (functional name: :obj:`normalize_scale`).
    """
    if data.pos.size(0) == 0:
        data.center = torch.empty(0)
        data.scale = torch.empty(0)
        return data
    # center the pos points
    center = data.pos.mean(dim=-2, keepdim=True)
    # scale the pos points
    scale = (1 / (data.pos - center).abs().max()) * 0.999999

    return normalize_and_scale_with_params(data, center, scale, extra_fields)


def reverse_normalize_and_scale(data: Data, extra_fields: list[str] = []) -> Data:
    r"""Reverses the centering and normalization of the given fields
    (functional name: :obj:`reverse_normalize_scale`).
    """
    assert hasattr(data, "center") and hasattr(
        data, "scale"
    ), "Data object does not contain the center and scale attributes."
    return reverse_normalize_and_scale_with_params(
        data, data.center, data.scale, extra_fields
    )


def normalize_and_scale_with_params(
    data: Data, center: torch.Tensor, scale: torch.Tensor, extra_fields: list[str] = []
) -> Data:
    if data.pos.size(0) == 0:
        data.center = torch.empty(0)
        data.scale = torch.empty(0)
        return data
    data.pos = (data.pos - center) * scale
    for field in extra_fields:
        if hasattr(data, field):
            shape = getattr(data, field).size()
            setattr(
                data,
                field,
                (getattr(data, field).reshape(-1, shape[-1]) - center) * scale,
            )
            setattr(data, field, getattr(data, field).reshape(shape))
    data.center = center
    data.scale = scale
    return data


def reverse_normalize_and_scale_with_params(
    data: Data, center: torch.Tensor, scale: torch.Tensor, extra_fields: list[str] = []
) -> Data:
    r"""Reverses the centering and normalization of the given fields
    (functional name: :obj:`reverse_normalize_scale`).
    """
    # Reverse the scaling and centering of the pos points
    data.pos = data.pos / scale + center

    for field in extra_fields:
        if hasattr(data, field):
            shape = getattr(data, field).size()
            setattr(
                data,
                field,
                (getattr(data, field).reshape(-1, shape[-1]) / scale) + center,
            )
            setattr(data, field, getattr(data, field).reshape(shape))
    data.center = torch.empty(0)
    data.scale = torch.empty(0)
    return data


def random_rotate(
    data: Data, degrees: float, axis: int, extra_fields: list[str] = []
) -> Data:
    r"""Rotates the object around the origin by a random angle within the
    range :obj:`[-degrees, degrees]` (functional name: :obj:`random_rotate`).
    """
    if data.pos.size(0) == 0:
        return data
    return rotate_with_params(
        data, random.uniform(-degrees, degrees), axis, extra_fields
    )


def rotate_with_params(
    data: Data, degrees: float, axis: int = 0, extra_fields: list[str] = []
) -> Data:
    r"""Rotates the object around the origin by a given angle
    (functional name: :obj:`rotate`).
    """
    angle = math.pi * degrees / 180.0
    if data.pos.size(0) == 0:
        return data
    sin, cos = math.sin(angle), math.cos(angle)
    if data.pos.size(-1) == 2:
        matrix = torch.tensor([[cos, sin], [-sin, cos]])
    else:
        if axis == 0:
            matrix = torch.tensor([[1, 0, 0], [0, cos, sin], [0, -sin, cos]])
        elif axis == 1:
            matrix = torch.tensor([[cos, 0, -sin], [0, 1, 0], [sin, 0, cos]])
        else:
            matrix = torch.tensor([[cos, sin, 0], [-sin, cos, 0], [0, 0, 1]])

    matrix = matrix.to(data.pos.dtype)

    data.pos = data.pos @ matrix.t()
    for field in extra_fields:
        if hasattr(data, field):
            shape = getattr(data, field).size()
            # rotate each extra field in its own dtype
            dtype = getattr(data, field).dtype
            matrix_dtype = matrix.to(dtype)
            setattr(data, field, getattr(data, field) @ matrix_dtype.t())
            setattr(data, field, getattr(data, field).reshape(shape))
    setattr(data, f"rotated_{axis}", degrees)
    return data


def reverse_rotate(data: Data, axis: int = 0, extra_fields: list[str] = []) -> Data:
    r"""Reverses the rotation of the object around the origin
    (functional name: :obj:`reverse_rotate`).
    """
    if not hasattr(data, f"rotated_{axis}"):
        return data
    return rotate_with_params(
        data, -getattr(data, f"rotated_{axis}"), axis, extra_fields
    )


def add_noise(data: Data, std: float) -> Data:
    r"""Adds Gaussian noise to the node positions (functional name:
    :obj:`add_noise`).
    """
    if data.pos.size(0) == 0:
        return data
    noise = torch.randn_like(data.pos) * std
    data.pos = data.pos + noise
    data.noise = noise
    return data


def remove_noise(data: Data) -> Data:
    r"""Removes the noise from the node positions (functional name:
    :obj:`remove_noise`).
    """
    assert hasattr(data, "noise"), "Data object does not contain the noise attribute."
    data.pos = data.pos - data.noise
    del data.noise
    return data


def custom_normalize_and_scale(
    data: Data, p1: torch.Tensor, p2: torch.Tensor, extra_fields: list[str] = []
) -> Data:
    r"""Normalizes the point cloud so that after the transformation p1 is at
    (0, 0, 0) and p2 at (1, 1, 1) (functional name: :obj:`normalize`).
    """
    assert p1.size() == p2.size() == (3,), "Invalid interval."
    if data.pos.size(0) == 0:
        return data
    data.pos = (data.pos - p1) / (p2 - p1)
    for field in extra_fields:
        if hasattr(data, field):
            shape = getattr(data, field).size()
            setattr(
                data,
                field,
                (getattr(data, field).reshape(-1, shape[-1]) - p1) / (p2 - p1),
            )
            setattr(data, field, getattr(data, field).reshape(shape))
    data.p1 = p1
    data.p2 = p2
    return data
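Note: the original file defined reverse_normalize_and_scale_with_params twice, verbatim; the dead duplicate is dropped above. A quick round-trip sketch of the normalize/reverse pair, using only the functions shown (the toy point cloud is an assumption):

# Round-trip check for normalize_and_scale / reverse_normalize_and_scale.
import torch
from torch_geometric.data import Data
from pi3detr.dataset import normalize_and_scale, reverse_normalize_and_scale

data = Data(pos=torch.rand(100, 3) * 50.0)   # arbitrary un-normalized cloud
orig = data.pos.clone()
data = normalize_and_scale(data)             # pos now within [-1, 1]
assert data.pos.abs().max() <= 1.0
data = reverse_normalize_and_scale(data)     # undo via stored center/scale
assert torch.allclose(data.pos, orig, atol=1e-5)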
pi3detr/dataset/utils.py
ADDED
@@ -0,0 +1,75 @@
import numpy as np
import json
import open3d as o3d


def read_xyz_file(file_path: str, column_idxs: list[int] = [0, 1, 2]) -> np.ndarray:
    """Reads a point cloud from a .xyz file.

    Args:
        file_path (str): Path to the .xyz file.
        column_idxs (list[int], optional): Indices of the columns to read. Defaults to [0, 1, 2].

    Returns:
        np.ndarray: Point cloud as a numpy array.
    """
    return np.loadtxt(file_path, usecols=column_idxs)


def read_curve_file(file_path: str) -> tuple[np.ndarray, np.ndarray]:
    with open(file_path, "r") as f:
        data = json.load(f)
    return np.array(data["linear"]), np.array(data["bezier"])


def read_polyline_file(file_path: str, sep: str = ",") -> np.ndarray:
    polylines = []
    with open(file_path, "r") as f:
        polyline = []
        for line in f:
            if line == "\n":
                polylines.append(polyline)
                polyline = []
            else:
                point = [float(x) for x in line.split(sep)]
                polyline.append(point)
        if polyline:
            polylines.append(polyline)
    return np.array(polylines)


def voxel_down_sample(xyz: np.ndarray, voxel_size: float) -> np.ndarray:
    """Downsamples a point cloud using voxel grid downsampling.

    Args:
        xyz (np.ndarray): Point cloud.
        voxel_size (float): Voxel size.

    Returns:
        np.ndarray: Downsampled point cloud.
    """
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(xyz)
    downpcd = pcd.voxel_down_sample(voxel_size=voxel_size)
    return np.asarray(downpcd.points)


def filter_normals(
    xyz: np.ndarray, radius: float, max_nn: int, threshold: float
) -> np.ndarray:
    """Filters points of a point cloud by the z component of their estimated normals.

    Args:
        xyz (np.ndarray): Point cloud.
        radius (float): Search radius for normal estimation.
        max_nn (int): Maximum number of neighbors for normal estimation.
        threshold (float): Keep points whose |normal_z| is below this value.

    Returns:
        np.ndarray: Filtered point cloud.
    """
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(xyz)
    pcd.estimate_normals(
        search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=radius, max_nn=max_nn)
    )
    new_pts = np.asarray(pcd.points)[np.abs(np.asarray(pcd.normals)[:, 2]) < threshold]
    return new_pts
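These helpers pair naturally with the demo inputs shipped in this commit. A short usage sketch (the voxel size is an assumption, not a value from the repo):

# Sketch: load a demo .xyz cloud and voxel-downsample it.
from pi3detr.dataset.utils import read_xyz_file, voxel_down_sample

pts = read_xyz_file("demo_inputs/demo1.xyz")    # (N, 3) array
down = voxel_down_sample(pts, voxel_size=0.02)  # voxel size chosen arbitrarily
print(pts.shape, "->", down.shape)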
pi3detr/evaluation/__init__.py
ADDED
@@ -0,0 +1 @@
from .abc_metrics import ChamferMAP, ChamferIntervalMetric
pi3detr/evaluation/abc_metrics.py
ADDED
@@ -0,0 +1,349 @@
import torch
import numpy as np
from torchmetrics import Metric
from torchmetrics.functional import average_precision
from scipy.spatial import KDTree


def calc_chamfer_distance(pred_points, gt_points):
    """Calculate chamfer distance and bidirectional Hausdorff distance."""
    if len(pred_points) == 0 or len(gt_points) == 0:
        return float("inf"), float("inf")

    tree_pred = KDTree(pred_points)
    tree_gt = KDTree(gt_points)

    dist_pred2gt, _ = tree_gt.query(pred_points)
    dist_gt2pred, _ = tree_pred.query(gt_points)

    chamfer_dist = np.mean(dist_pred2gt**2) + np.mean(dist_gt2pred**2)
    bhausdorff_dist = (dist_pred2gt.max() + dist_gt2pred.max()) / 2

    return chamfer_dist, bhausdorff_dist


class ChamferMAP(Metric):
    def __init__(self, chamfer_thresh=0.05, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.chamfer_thresh = chamfer_thresh
        self.class_names = {
            1: "mAP_bspline",
            2: "mAP_line",
            3: "mAP_circle",
            4: "mAP_arc",
        }

        self.add_state("all_scores", default=[], dist_reduce_fx="cat")
        self.add_state("all_matches", default=[], dist_reduce_fx="cat")
        self.add_state("all_classes", default=[], dist_reduce_fx="cat")

    def pairwise_chamfer_distance_batch(self, pred, gt):
        """
        pred: [P, 64, 3]
        gt:   [G, 64, 3]
        returns: [P, G] chamfer distances
        """
        P, G = pred.size(0), gt.size(0)

        # Reshape for pairwise comparison
        pred_exp = pred.unsqueeze(1)  # [P, 1, 64, 3]
        gt_exp = gt.unsqueeze(0)      # [1, G, 64, 3]

        # Compute pairwise distances between points
        dists = torch.cdist(pred_exp, gt_exp, p=2)  # [P, G, 64, 64]

        a2b = dists.min(dim=3).values.mean(dim=2)  # [P, G]
        b2a = dists.min(dim=2).values.mean(dim=2)  # [P, G]

        return a2b + b2a  # [P, G]

    def update(self, outputs, batch):
        B = outputs["pred_class"].shape[0]
        y_curves = batch.y_curve_64  # [total_gt, 64, 3]
        y_cls = batch.y_cls          # [total_gt]
        num_curves_per_batch = batch.num_curves.tolist()

        gt_splits = torch.split(y_curves, num_curves_per_batch, dim=0)
        cls_splits = torch.split(y_cls, num_curves_per_batch, dim=0)

        pred_classes_all = outputs["pred_class"].softmax(dim=-1)  # [B, N, C]
        for b in range(B):
            pred_classes = pred_classes_all[b]  # [N, C]

            preds_all = {
                1: outputs["pred_bspline_points"][b],  # [N, 64, 3]
                2: outputs["pred_line_points"][b],
                3: outputs["pred_circle_points"][b],
                4: outputs["pred_arc_points"][b],
            }

            for cls in self.class_names.keys():
                pred_points = preds_all[cls]   # [P, 64, 3]
                scores = pred_classes[:, cls]  # [P]

                gt_points = gt_splits[b][cls_splits[b] == cls]  # [G, 64, 3]
                if gt_points.size(0) == 0:
                    self.all_scores.append(scores)
                    self.all_matches.append(torch.zeros_like(scores))
                    self.all_classes.append(torch.full_like(scores, cls))
                    continue

                chamfer = self.pairwise_chamfer_distance_batch(pred_points, gt_points)
                used_gt = torch.zeros(
                    gt_points.size(0), dtype=torch.bool, device=pred_points.device
                )
                matches = torch.zeros(pred_points.size(0), device=pred_points.device)

                for i in range(pred_points.size(0)):
                    dists = chamfer[i]
                    min_dist, min_idx = dists.min(0)
                    if min_dist < self.chamfer_thresh and not used_gt[min_idx]:
                        matches[i] = 1.0
                        used_gt[min_idx] = True

                self.all_scores.append(scores)
                self.all_matches.append(matches)
                self.all_classes.append(torch.full_like(matches, cls))

    def compute(self):
        if not self.all_scores:
            return {cls_name: 0.0 for cls_name in self.class_names.values()}

        scores = torch.cat(self.all_scores)
        matches = torch.cat(self.all_matches)
        classes = torch.cat(self.all_classes)

        result = {}
        ap_values = []

        for cls in self.class_names.keys():
            mask = classes == cls
            if mask.sum() == 0 or torch.sum(matches[mask]) == 0:
                ap = torch.tensor(0.0, device=self.device)
            else:
                ap = average_precision(
                    scores[mask], matches[mask].to(torch.int32), task="binary"
                )
            result[self.class_names[cls]] = ap.item()
            ap_values.append(ap)

        result["mAP"] = torch.stack(ap_values).mean().item()
        return result


class ChamferIntervalMetric(Metric):

    def __init__(self, interval=0.01, map_cd_thresh=0.005, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.interval = interval
        self.add_state("total_cd", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total_cd_sq", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total_bhd", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total_bhd_sq", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("valid_count", default=torch.tensor(0), dist_reduce_fx="sum")

        self.map_cd_thresh = map_cd_thresh
        self.map_cls_names = {
            1: "mAP_bspline",
            2: "mAP_line",
            3: "mAP_circle",
            4: "mAP_arc",
        }

        self.add_state("map_all_scores", default=[], dist_reduce_fx="cat")
        self.add_state("map_all_matches", default=[], dist_reduce_fx="cat")
        self.add_state("map_all_classes", default=[], dist_reduce_fx="cat")

    def sample_curve_by_interval(self, points, interval, force_last=False):
        """Sample points along the curve at fixed arc-length `interval`."""
        if len(points) < 2:
            return points

        edges = np.array([[j, j + 1] for j in range(len(points) - 1)])
        edge_lengths = np.linalg.norm(points[edges[:, 1]] - points[edges[:, 0]], axis=1)

        samples = [points[0]]
        distance_accum = 0.0
        next_sample_dist = interval
        edge_index = 0

        while edge_index < len(edges):
            p0 = points[edges[edge_index, 0]]
            p1 = points[edges[edge_index, 1]]
            edge_vec = p1 - p0
            edge_len = np.linalg.norm(edge_vec)
            if edge_len == 0:
                edge_index += 1
                continue

            while distance_accum + edge_len >= next_sample_dist:
                t = (next_sample_dist - distance_accum) / edge_len
                sample = p0 + t * edge_vec
                samples.append(sample)
                next_sample_dist += interval

            distance_accum += edge_len
            edge_index += 1

        if force_last and not np.allclose(samples[-1], points[-1]):
            samples.append(points[-1])

        return np.array(samples)

    def update(self, data, batch):
        # Get ground truth curves (this metric assumes batch size 1)
        y_curves = batch.y_curve_64.cpu().numpy()  # [total_gt, 64, 3]
        num_curves_per_batch = batch.num_curves.tolist()

        # Sample ground truth curves
        gt_points_list = []
        gt_cls_list = []
        for i, gt_curve in enumerate(y_curves):
            if len(gt_curve) < 2 or np.any(np.isnan(gt_curve)):
                continue
            sampled_gt = self.sample_curve_by_interval(
                gt_curve, self.interval, force_last=True
            )
            if len(sampled_gt) > 0 and np.all(np.isfinite(sampled_gt)):
                gt_points_list.append(sampled_gt)
                gt_cls_list.append(batch.y_cls[i].cpu().item())

        # Sample predicted curves from post-processed data
        pred_points_list = []
        pred_cls_list = []
        pred_score_list = []
        for i, (polyline, cls) in enumerate(
            zip(data.polylines.cpu().numpy(), data.polyline_class.cpu().numpy())
        ):
            if cls == 0:  # Skip background class
                continue

            if len(polyline) < 2 or np.any(np.isnan(polyline)):
                continue

            sampled_pred = self.sample_curve_by_interval(
                polyline, self.interval, force_last=True
            )
            if len(sampled_pred) > 0 and np.all(np.isfinite(sampled_pred)):
                pred_points_list.append(sampled_pred)
                pred_cls_list.append(int(cls))
                pred_score_list.append(data.polyline_score[i].cpu().item())

        if len(gt_points_list) == 0 and len(pred_points_list) == 0:
            # No ground truth and no predictions, no penalty
            self.count += 1
            return
        elif len(gt_points_list) == 0:
            # Predictions without ground truth count as false positives
            self.count += 1
            scores = torch.tensor(pred_score_list)
            self.map_all_scores.append(scores)
            self.map_all_matches.append(torch.zeros_like(scores))
            self.map_all_classes.append(torch.tensor(pred_cls_list))
            return
        elif len(pred_points_list) == 0:
            # Ground truth without predictions counts as missed detections
            self.count += 1
            cls_list = torch.tensor(gt_cls_list, dtype=torch.float32)
            self.map_all_scores.append(torch.zeros_like(cls_list))
            self.map_all_matches.append(torch.ones_like(cls_list))
            self.map_all_classes.append(cls_list)
            return

        # calculate mAP
        for cls in self.map_cls_names.keys():
            mask = torch.tensor(pred_cls_list) == cls
            pred_curves = [curve for i, curve in enumerate(pred_points_list) if mask[i]]
            pred_scores = torch.tensor(pred_score_list)[mask]
            gt_curves = [
                curve for i, curve in enumerate(gt_points_list) if cls == gt_cls_list[i]
            ]
            if len(pred_curves) == 0 and len(gt_curves) != 0:
                scores = torch.zeros(len(gt_curves))
                self.map_all_scores.append(scores)
                self.map_all_matches.append(torch.zeros_like(scores))
                self.map_all_classes.append(torch.full_like(scores, cls))
                continue
            if len(gt_curves) == 0:
                self.map_all_scores.append(pred_scores)
                self.map_all_matches.append(torch.zeros_like(pred_scores))
                self.map_all_classes.append(torch.full_like(pred_scores, cls))
                continue

            # get [P, G] matrix of chamfer distances
            cd_matrix = torch.ones((len(pred_curves), len(gt_curves))) * float("inf")
            for i, pred_curve in enumerate(pred_curves):
                for j, gt_curve in enumerate(gt_curves):
                    cd_matrix[i, j] = calc_chamfer_distance(pred_curve, gt_curve)[0]

            used_gt = set()
            matches = torch.zeros(len(pred_curves))
            for i in range(len(pred_curves)):
                dists = cd_matrix[i]
                min_dist, min_idx = dists.min(0)
                if min_dist < self.map_cd_thresh and min_idx not in used_gt:
                    matches[i] = 1.0
                    used_gt.add(min_idx)

            self.map_all_scores.append(pred_scores)
            self.map_all_matches.append(matches)
            self.map_all_classes.append(torch.full_like(pred_scores, cls))

        pred_points = np.concatenate(pred_points_list, axis=0)
        gt_points = np.concatenate(gt_points_list, axis=0)
        # Calculate distances
        cd, bhd = calc_chamfer_distance(pred_points, gt_points)

        self.total_cd += torch.tensor(cd)
        self.total_cd_sq += torch.tensor(cd**2)
        self.total_bhd += torch.tensor(bhd)
        self.total_bhd_sq += torch.tensor(bhd**2)
        self.count += 1
        self.valid_count += 1 if len(pred_points) > 0 else 0

    def compute(self):
        if not self.map_all_scores:
            return {cls_name: 0.0 for cls_name in self.map_cls_names.values()}
        scores = torch.cat(self.map_all_scores)
        matches = torch.cat(self.map_all_matches)
        classes = torch.cat(self.map_all_classes)
        map_result = {}
        ap_values = []

        for cls in self.map_cls_names.keys():
            mask = classes == cls
            if mask.sum() == 0 or torch.sum(matches[mask]) == 0:
                ap = torch.tensor(0.0, device=self.device)
            else:
                ap = average_precision(
                    scores[mask], matches[mask].to(torch.int32), task="binary"
                )
            map_result[self.map_cls_names[cls]] = ap.item()
            ap_values.append(ap)

        map_result["mAP"] = torch.stack(ap_values).mean().item()

        if self.count == 0:
            return {"chamfer_distance": 0.0, "bidirectional_hausdorff": 0.0}

        mean_cd = (self.total_cd / self.valid_count).item()
        mean_bhd = (self.total_bhd / self.valid_count).item()
        results = {
            "chamfer_distance": mean_cd,
            "chamfer_distance_std": (self.total_cd_sq / self.valid_count - (mean_cd**2))
            .sqrt()
            .item(),
            "bidirectional_hausdorff": mean_bhd,
            "bidirectional_hausdorff_std": (
                self.total_bhd_sq / self.valid_count - (mean_bhd**2)
            )
            .sqrt()
            .item(),
        }
        results.update(map_result)
        return results
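Note: the original prediction loop in ChamferIntervalMetric.update indexed data.polyline_score with a stale loop variable from the ground-truth loop; the reconstruction above uses enumerate so the score matches the current polyline. For reference, a hedged sketch of how ChamferMAP is typically driven (the model and loader are placeholders, assumed to follow the interfaces above):

# Sketch: evaluate ChamferMAP over a validation loader. The loader is assumed
# to yield torch_geometric batches carrying y_curve_64 / y_cls / num_curves,
# and the model to return the pred_class / pred_*_points dict used in update().
import torch
from pi3detr.evaluation import ChamferMAP

def evaluate_chamfer_map(model: torch.nn.Module, val_loader) -> dict:
    metric = ChamferMAP(chamfer_thresh=0.05)
    model.eval()
    with torch.no_grad():
        for batch in val_loader:
            outputs = model(batch)
            metric.update(outputs, batch)
    return metric.compute()  # per-class APs plus overall "mAP"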
pi3detr/models/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .model_config import ModelConfig
from .pi3detr import PI3DETR
pi3detr/models/losses/__init__.py
ADDED
@@ -0,0 +1,5 @@
from .losses import (
    chamfer_distance_batch,
    LossParams,
    ParametricLoss,
)
pi3detr/models/losses/losses.py
ADDED
@@ -0,0 +1,399 @@
from abc import abstractmethod
import torch
from torch import nn
from dataclasses import dataclass, field
import torch.nn.functional as F
from torch_geometric.data.data import Data
from kornia.losses import focal_loss
from .matcher import *


def chamfer_distance_batch(pts1: torch.Tensor, pts2: torch.Tensor) -> torch.Tensor:
    assert len(pts1.shape) == 3 and len(pts2.shape) == 3
    if pts1.nelement() == 0 or pts2.nelement() == 0:
        return torch.tensor(0.0, device=pts1.device, requires_grad=True)
    dist_matrix = torch.cdist(
        pts1, pts2, p=2
    )  # shape: (batch_size, num_points, num_points)
    dist1 = dist_matrix.min(dim=2).values.mean(dim=1)  # min over pts2, mean over pts1
    dist2 = dist_matrix.min(dim=1).values.mean(dim=1)  # min over pts1, mean over pts2
    return (dist1 + dist2) / 2


@torch.no_grad()
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k."""
    if target.numel() == 0:
        return [torch.zeros([], device=output.device)]
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


@torch.no_grad()
def f1_score(output, target, threshold=0.5):
    output = output > threshold
    target = target > threshold
    tp = (output & target).sum()
    tn = (~output & ~target).sum()
    fp = (output & ~target).sum()
    fn = (~output & target).sum()
    precision = tp / (tp + fp + 1e-8)
    recall = tp / (tp + fn + 1e-8)
    return 2 * (precision * recall) / (precision + recall + 1e-8), precision, recall


@dataclass
class LossParams:
    num_classes: int
    cost_class: int = 1
    cost_curve: int = 1
    class_loss_type: str = "cross_entropy"  # or "focal"
    class_loss_weights: list[float] = field(
        default_factory=lambda: [
            0.04834912,
            0.40329467,
            0.09588135,
            0.23071379,
            0.22176106,
        ]
    )
    # NOTE: Weights calculated based on the dataset
    # bezier, line, circle, arc, empty
    # counts = np.array([11347, 200751, 34672, 37528])
    # counts = np.append(counts, total_pred - counts.sum())
    # weights = 1 / np.sqrt(counts)
    # weights = weights / weights.sum()


class Loss(nn.Module):

    def __init__(self, params: LossParams) -> None:
        super().__init__()
        self.matcher = ParametricMatcher(params.cost_class, params.cost_curve)
        self.num_classes = params.num_classes
        self.class_loss_type = params.class_loss_type
        class_weights = torch.tensor(
            params.class_loss_weights,
        )

        self.register_buffer("class_weights", class_weights)

    def forward(
        self, outputs: dict[str, torch.Tensor], data: Data
    ) -> dict[str, torch.Tensor]:
        indices = self.matcher(outputs, data)
        losses = {}
        losses.update(self._loss_class(outputs, data, indices))
        losses.update(self._loss_polyline(outputs, data, indices))
        # In case of auxiliary losses, we repeat this process with the output
        # of each intermediate layer.
        if "aux_outputs" in outputs:
            for i, aux_outputs in enumerate(outputs["aux_outputs"]):
                indices = self.matcher(aux_outputs, data)
                l_dict = self._loss_class(aux_outputs, data, indices, False)
                l_dict.update(self._loss_polyline(aux_outputs, data, indices))
                l_dict = {k + f"_{i}": v for k, v in l_dict.items()}
                losses.update(l_dict)
        return losses

    @abstractmethod
    def _loss_polyline(
        self,
        outputs: dict[str, torch.Tensor],
        data: Data,
        indices: list[tuple[torch.Tensor, torch.Tensor]],
    ) -> dict[str, torch.Tensor]:
        """Compute the polyline loss."""
        pass

    def _loss_class(
        self,
        outputs: dict[str, torch.Tensor],
        data: Data,
        indices: list[tuple[torch.Tensor, torch.Tensor]],
        log: bool = True,
    ) -> torch.Tensor:
        num_targets = (
            data.num_polylines.tolist()
            if hasattr(data, "num_polylines")
            else data.num_curves.tolist()
        )
        src_logits = outputs["pred_class"]
        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat(
            [
                target[J]
                for target, (_, J) in zip(
                    data.y_cls.split_with_sizes(num_targets), indices
                )
            ]
        )
        target_classes = torch.full(
            src_logits.shape[:2], 0, dtype=torch.int64, device=src_logits.device
        )  # 0: empty class
        target_classes[idx] = target_classes_o
        losses = {}

        if self.class_loss_type == "cross_entropy":
            loss_class = F.cross_entropy(
                src_logits.transpose(1, 2),
                target_classes,
                weight=self.class_weights.to(src_logits.device),
                reduction="mean",
            )
        else:
            loss_class = focal_loss(
                src_logits.transpose(1, 2),
                target_classes,
                alpha=0.25,
                gamma=2.0,
|
| 160 |
+
weight=self.class_weights.to(src_logits.device),
|
| 161 |
+
reduction="mean",
|
| 162 |
+
)
|
| 163 |
+
losses["loss_class"] = loss_class
|
| 164 |
+
if log:
|
| 165 |
+
losses["class_error"] = (
|
| 166 |
+
100
|
| 167 |
+
- accuracy(
|
| 168 |
+
src_logits.reshape(-1, src_logits.size(-1)),
|
| 169 |
+
target_classes.flatten(),
|
| 170 |
+
)[0]
|
| 171 |
+
)
|
| 172 |
+
f1, _, _ = f1_score(
|
| 173 |
+
src_logits.reshape(-1, src_logits.size(-1)).softmax(-1).argmax(-1),
|
| 174 |
+
target_classes.flatten(),
|
| 175 |
+
threshold=0.5,
|
| 176 |
+
)
|
| 177 |
+
losses["class_f1_score"] = f1
|
| 178 |
+
|
| 179 |
+
return losses
|
| 180 |
+
|
| 181 |
+
def _get_src_permutation_idx(
|
| 182 |
+
self, indices: list[tuple[torch.Tensor, torch.Tensor]]
|
| 183 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 184 |
+
# permute predictions following indices
|
| 185 |
+
batch_idx = torch.cat(
|
| 186 |
+
[torch.full_like(src, i) for i, (src, _) in enumerate(indices)]
|
| 187 |
+
)
|
| 188 |
+
src_idx = torch.cat([src for (src, _) in indices])
|
| 189 |
+
return batch_idx, src_idx
|
| 190 |
+
|
| 191 |
+
def _get_tgt_permutation_idx(
|
| 192 |
+
self, indices: list[tuple[torch.Tensor, torch.Tensor]]
|
| 193 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 194 |
+
# permute targets following indices
|
| 195 |
+
batch_idx = torch.cat(
|
| 196 |
+
[torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]
|
| 197 |
+
)
|
| 198 |
+
tgt_idx = torch.cat([tgt for (_, tgt) in indices])
|
| 199 |
+
return batch_idx, tgt_idx
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
class ParametricLoss(Loss):
|
| 203 |
+
def __init__(self, params: LossParams) -> None:
|
| 204 |
+
super().__init__(params)
|
| 205 |
+
|
| 206 |
+
def _loss_polyline(
|
| 207 |
+
self,
|
| 208 |
+
outputs: dict[str, torch.Tensor],
|
| 209 |
+
data: Data,
|
| 210 |
+
indices: list[tuple[torch.Tensor, torch.Tensor]],
|
| 211 |
+
) -> dict[str, torch.Tensor]:
|
| 212 |
+
idx = self._get_src_permutation_idx(indices)
|
| 213 |
+
src_bspline_params = outputs["pred_bspline_params"][idx]
|
| 214 |
+
src_bspline_points = outputs["pred_bspline_points"][idx]
|
| 215 |
+
src_line_params = outputs["pred_line_params"][idx]
|
| 216 |
+
src_line_length = outputs["pred_line_length"][idx]
|
| 217 |
+
src_line_points = outputs["pred_line_points"][idx]
|
| 218 |
+
src_circle_params = outputs["pred_circle_params"][idx]
|
| 219 |
+
src_circle_radius = outputs["pred_circle_radius"][idx]
|
| 220 |
+
src_circle_points = outputs["pred_circle_points"][idx]
|
| 221 |
+
src_arc_params = outputs["pred_arc_params"][idx]
|
| 222 |
+
src_arc_points = outputs["pred_arc_points"][idx]
|
| 223 |
+
target_params = torch.cat(
|
| 224 |
+
[
|
| 225 |
+
target[J]
|
| 226 |
+
for target, (_, J) in zip(
|
| 227 |
+
data.y_params.split_with_sizes(data.num_curves.tolist()), indices
|
| 228 |
+
)
|
| 229 |
+
]
|
| 230 |
+
)
|
| 231 |
+
target_classes = torch.cat(
|
| 232 |
+
[
|
| 233 |
+
target[J]
|
| 234 |
+
for target, (_, J) in zip(
|
| 235 |
+
data.y_cls.split_with_sizes(data.num_curves.tolist()), indices
|
| 236 |
+
)
|
| 237 |
+
]
|
| 238 |
+
)
|
| 239 |
+
target_curves = torch.cat(
|
| 240 |
+
[
|
| 241 |
+
target[J]
|
| 242 |
+
for target, (_, J) in zip(
|
| 243 |
+
data.y_curve_64.split_with_sizes(data.num_curves.tolist()), indices
|
| 244 |
+
)
|
| 245 |
+
]
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
losses = {}
|
| 249 |
+
|
| 250 |
+
# Filter indices for each class
|
| 251 |
+
bspline_mask = target_classes == 1 # B-spline
|
| 252 |
+
line_mask = target_classes == 2 # Line
|
| 253 |
+
circle_mask = target_classes == 3 # Circle
|
| 254 |
+
arc_mask = target_classes == 4 # Arc
|
| 255 |
+
|
| 256 |
+
# Compute loss for B-splines
|
| 257 |
+
if bspline_mask.any():
|
| 258 |
+
bspline_order_l1 = torch.min(
|
| 259 |
+
F.l1_loss(
|
| 260 |
+
src_bspline_params[bspline_mask].flatten(-2, -1),
|
| 261 |
+
target_params[bspline_mask],
|
| 262 |
+
reduction="none",
|
| 263 |
+
).mean(-1),
|
| 264 |
+
F.l1_loss(
|
| 265 |
+
src_bspline_params[bspline_mask].flip([1]).flatten(-2, -1),
|
| 266 |
+
target_params[bspline_mask],
|
| 267 |
+
reduction="none",
|
| 268 |
+
).mean(-1),
|
| 269 |
+
).mean()
|
| 270 |
+
losses["loss_bspline"] = bspline_order_l1
|
| 271 |
+
bspline_chamfer = chamfer_distance_batch(
|
| 272 |
+
src_bspline_points[bspline_mask], target_curves[bspline_mask]
|
| 273 |
+
)
|
| 274 |
+
losses["loss_bspline_chamfer"] = bspline_chamfer.mean()
|
| 275 |
+
else:
|
| 276 |
+
losses["loss_bspline"] = torch.tensor(0.0, device=src_bspline_params.device)
|
| 277 |
+
losses["loss_bspline_chamfer"] = torch.tensor(
|
| 278 |
+
0.0, device=src_bspline_points.device
|
| 279 |
+
)
|
| 280 |
+
|
| 281 |
+
# Compute loss for Lines
|
| 282 |
+
if line_mask.any():
|
| 283 |
+
line_position_l1 = torch.min(
|
| 284 |
+
F.l1_loss(
|
| 285 |
+
src_line_params[line_mask].flatten(-2, -1),
|
| 286 |
+
target_params[line_mask, :6],
|
| 287 |
+
reduction="none",
|
| 288 |
+
).mean(-1),
|
| 289 |
+
# also consider the negative direction
|
| 290 |
+
F.l1_loss(
|
| 291 |
+
(
|
| 292 |
+
src_line_params[line_mask]
|
| 293 |
+
* torch.tensor([1.0, -1.0])
|
| 294 |
+
.view(1, 2, 1)
|
| 295 |
+
.to(src_line_params.device)
|
| 296 |
+
).flatten(-2, -1),
|
| 297 |
+
target_params[line_mask, :6],
|
| 298 |
+
reduction="none",
|
| 299 |
+
).mean(-1),
|
| 300 |
+
).mean()
|
| 301 |
+
line_length_loss = F.l1_loss(
|
| 302 |
+
src_line_length[line_mask],
|
| 303 |
+
target_params[line_mask, 6].unsqueeze(-1),
|
| 304 |
+
)
|
| 305 |
+
losses["loss_line_position"] = line_position_l1
|
| 306 |
+
losses["loss_line_length"] = line_length_loss
|
| 307 |
+
line_chamfer = chamfer_distance_batch(
|
| 308 |
+
src_line_points[line_mask], target_curves[line_mask]
|
| 309 |
+
)
|
| 310 |
+
losses["loss_line_chamfer"] = line_chamfer.mean()
|
| 311 |
+
else:
|
| 312 |
+
losses["loss_line_position"] = torch.tensor(
|
| 313 |
+
0.0, device=src_line_params.device
|
| 314 |
+
)
|
| 315 |
+
losses["loss_line_length"] = torch.tensor(
|
| 316 |
+
0.0, device=src_line_length.device
|
| 317 |
+
)
|
| 318 |
+
losses["loss_line_chamfer"] = torch.tensor(
|
| 319 |
+
0.0, device=src_line_points.device
|
| 320 |
+
)
|
| 321 |
+
|
| 322 |
+
# Compute loss for Circles
|
| 323 |
+
if circle_mask.any():
|
| 324 |
+
circle_position_l1 = torch.min(
|
| 325 |
+
F.l1_loss(
|
| 326 |
+
src_circle_params[circle_mask].flatten(-2, -1),
|
| 327 |
+
target_params[circle_mask, :6],
|
| 328 |
+
reduction="none",
|
| 329 |
+
).mean(-1),
|
| 330 |
+
# also consider the negative direction
|
| 331 |
+
F.l1_loss(
|
| 332 |
+
(
|
| 333 |
+
src_circle_params[circle_mask]
|
| 334 |
+
* torch.tensor([1.0, -1.0])
|
| 335 |
+
.view(1, 2, 1)
|
| 336 |
+
.to(src_circle_params.device)
|
| 337 |
+
).flatten(-2, -1),
|
| 338 |
+
target_params[circle_mask, :6],
|
| 339 |
+
reduction="none",
|
| 340 |
+
).mean(-1),
|
| 341 |
+
).mean()
|
| 342 |
+
radius_loss = F.l1_loss(
|
| 343 |
+
src_circle_radius[circle_mask],
|
| 344 |
+
target_params[circle_mask, 6].unsqueeze(-1),
|
| 345 |
+
)
|
| 346 |
+
losses["loss_circle_position"] = circle_position_l1
|
| 347 |
+
losses["loss_circle_radius"] = radius_loss
|
| 348 |
+
circle_chamfer = chamfer_distance_batch(
|
| 349 |
+
src_circle_points[circle_mask], target_curves[circle_mask]
|
| 350 |
+
)
|
| 351 |
+
losses["loss_circle_chamfer"] = circle_chamfer.mean()
|
| 352 |
+
else:
|
| 353 |
+
losses["loss_circle_position"] = torch.tensor(
|
| 354 |
+
0.0, device=src_circle_params.device
|
| 355 |
+
)
|
| 356 |
+
losses["loss_circle_radius"] = torch.tensor(
|
| 357 |
+
0.0, device=src_circle_radius.device
|
| 358 |
+
)
|
| 359 |
+
losses["loss_circle_chamfer"] = torch.tensor(
|
| 360 |
+
0.0, device=src_circle_points.device
|
| 361 |
+
)
|
| 362 |
+
|
| 363 |
+
# Compute loss for Arcs
|
| 364 |
+
if arc_mask.any():
|
| 365 |
+
arc_order_l1 = torch.min(
|
| 366 |
+
F.l1_loss(
|
| 367 |
+
src_arc_params[arc_mask].flatten(-2, -1),
|
| 368 |
+
target_params[arc_mask, :9],
|
| 369 |
+
reduction="none",
|
| 370 |
+
).mean(-1),
|
| 371 |
+
F.l1_loss(
|
| 372 |
+
src_arc_params[arc_mask][:, [0, 2, 1]].flatten(-2, -1),
|
| 373 |
+
target_params[arc_mask, :9],
|
| 374 |
+
reduction="none",
|
| 375 |
+
).mean(-1),
|
| 376 |
+
).mean()
|
| 377 |
+
losses["loss_arc"] = arc_order_l1
|
| 378 |
+
arc_chamfer = chamfer_distance_batch(
|
| 379 |
+
src_arc_points[arc_mask], target_curves[arc_mask]
|
| 380 |
+
)
|
| 381 |
+
losses["loss_arc_chamfer"] = arc_chamfer.mean()
|
| 382 |
+
else:
|
| 383 |
+
losses["loss_arc"] = torch.tensor(0.0, device=src_arc_params.device)
|
| 384 |
+
losses["loss_arc_chamfer"] = torch.tensor(0.0, device=src_arc_points.device)
|
| 385 |
+
|
| 386 |
+
losses["total_curve"] = (
|
| 387 |
+
losses["loss_bspline"]
|
| 388 |
+
+ losses["loss_line_position"]
|
| 389 |
+
+ losses["loss_line_length"]
|
| 390 |
+
+ losses["loss_circle_position"]
|
| 391 |
+
+ losses["loss_circle_radius"]
|
| 392 |
+
+ losses["loss_line_chamfer"]
|
| 393 |
+
+ losses["loss_circle_chamfer"]
|
| 394 |
+
+ losses["loss_bspline_chamfer"]
|
| 395 |
+
+ losses["loss_arc"]
|
| 396 |
+
+ losses["loss_arc_chamfer"]
|
| 397 |
+
)
|
| 398 |
+
|
| 399 |
+
return losses
|
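
For reference, the symmetric Chamfer term above is the workhorse of all four curve losses. A minimal standalone sketch of what chamfer_distance_batch computes (illustrative shapes only, not part of this commit):

import torch

pred = torch.rand(2, 64, 3)    # (batch, num_points, xyz), e.g. sampled predicted curves
target = torch.rand(2, 64, 3)  # ground-truth curve samples (cf. data.y_curve_64)
d = torch.cdist(pred, target, p=2)       # pairwise distances per batch element
fwd = d.min(dim=2).values.mean(dim=1)    # nearest target point for each predicted point
bwd = d.min(dim=1).values.mean(dim=1)    # nearest predicted point for each target point
chamfer = (fwd + bwd) / 2                # per-element distance, shape (2,)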
pi3detr/models/losses/matcher.py
ADDED
@@ -0,0 +1,135 @@
import torch
from torch import nn
from scipy.optimize import linear_sum_assignment
from torch_geometric.data.data import Data


class ParametricMatcher(nn.Module):
    def __init__(self, cost_class: int = 1, cost_curve: int = 1) -> None:
        super().__init__()
        self.cost_class = cost_class
        self.cost_curve = cost_curve

    @torch.no_grad()
    def forward(
        self, outputs: dict[str, torch.Tensor], data: Data
    ) -> list[tuple[torch.Tensor, torch.Tensor]]:
        """
        Compute the matching indices based on class costs and per-class
        curve parameter (L1) costs.
        """
        bs, num_queries = outputs["pred_class"].shape[:2]

        # Compute the classification cost
        out_prob = (
            outputs["pred_class"].flatten(0, 1).softmax(-1)
        )  # [batch_size * num_queries, num_classes]
        cost_class = -out_prob[:, data.y_cls]

        pred_bspline_params = outputs["pred_bspline_params"].flatten(0, 1)
        pred_line_params = outputs["pred_line_params"].flatten(0, 1)
        pred_line_length = outputs["pred_line_length"].flatten(0, 1)
        pred_circle_params = outputs["pred_circle_params"].flatten(0, 1)
        pred_circle_radius = outputs["pred_circle_radius"].flatten(0, 1)
        pred_arc_params = outputs["pred_arc_params"].flatten(0, 1)

        # classes -> 1: bspline, 2: line, 3: circle, 4: arc
        # NOTE: scaling done assuming points are in [-1, 1] range
        bspline_costs = torch.min(
            torch.cdist(
                pred_bspline_params.flatten(-2, -1),
                data.bspline_params.flatten(-2, -1),
                p=1,
            ),
            torch.cdist(
                pred_bspline_params.flip([1]).flatten(-2, -1),
                data.bspline_params.flatten(-2, -1),
                p=1,
            ),
        )  # [batch_size * num_queries, num_curves]
        line_costs = torch.min(
            torch.cdist(
                pred_line_params.flatten(-2, -1),
                data.line_params.flatten(-2, -1),
                p=1,
            ),
            torch.cdist(
                (
                    pred_line_params
                    * torch.tensor([1.0, -1.0])
                    .view(1, 2, 1)
                    .to(pred_line_params.device)
                ).flatten(-2, -1),
                data.line_params.flatten(-2, -1),
                p=1,
            ),
        ) + torch.cdist(
            pred_line_length,
            data.line_length.unsqueeze(-1),
            p=1,
        )  # [batch_size * num_queries, num_curves]
        circle_costs = torch.min(
            torch.cdist(
                pred_circle_params.flatten(-2, -1),
                data.circle_params.flatten(-2, -1),
                p=1,
            ),
            torch.cdist(
                (
                    pred_circle_params
                    * torch.tensor([1.0, -1.0])
                    .view(1, 2, 1)
                    .to(pred_circle_params.device)
                ).flatten(-2, -1),
                data.circle_params.flatten(-2, -1),
                p=1,
            ),
        ) + torch.cdist(
            pred_circle_radius,
            data.circle_radius.unsqueeze(-1),
            p=1,
        )  # [batch_size * num_queries, num_curves]
        arc_costs = torch.min(
            torch.cdist(
                pred_arc_params.flatten(-2, -1),
                data.arc_params.flatten(-2, -1),
                p=1,
            ),
            # mid, start, end | start and end can be swapped
            torch.cdist(
                pred_arc_params[:, [0, 2, 1], :].flatten(-2, -1),
                data.arc_params.flatten(-2, -1),
            ),
        )

        cost_params = torch.stack(
            [
                torch.zeros_like(line_costs),
                bspline_costs,
                line_costs,
                circle_costs,
                arc_costs,
            ],
            dim=-1,
        )
        cost_params = cost_params[
            torch.arange(cost_params.size(0))[:, None],
            torch.arange(cost_params.size(1)),
            data.y_cls,
        ]  # [num_queries, num_curves]

        # Combine costs
        C = self.cost_class * cost_class + self.cost_curve * cost_params
        C = C.view(bs, num_queries, -1).cpu()

        # Perform Hungarian matching
        indices = [
            linear_sum_assignment(c[i])
            for i, c in enumerate(C.split(data.num_curves.cpu().tolist(), -1))
        ]
        return [
            (
                torch.as_tensor(i, dtype=torch.int64),
                torch.as_tensor(j, dtype=torch.int64),
            )
            for i, j in indices
        ]
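
For reference, the final Hungarian step can be illustrated in isolation; a toy sketch (illustrative sizes, not from this commit) of how linear_sum_assignment turns a cost matrix into query/target index pairs:

import torch
from scipy.optimize import linear_sum_assignment

num_queries, num_curves = 4, 2
C = torch.rand(num_queries, num_curves)          # combined class + parameter cost
row_ind, col_ind = linear_sum_assignment(C.numpy())
# row_ind: matched query indices, col_ind: matched ground-truth curve indices.
# Queries left unmatched are later assigned the empty class (0) in Loss._loss_class.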
pi3detr/models/model_config.py
ADDED
@@ -0,0 +1,66 @@
from dataclasses import dataclass, field
from typing import Optional
from torch_geometric.nn import MLP
from .pointnetpp import SAModule


@dataclass
class ModelConfig:
    model: str
    num_features: int
    epochs: int = 1700
    lr: float = 1e-4
    lr_warmup_epochs: int = 15
    lr_warmup_start_factor: float = 1e-6
    lr_step: int = 1230
    batch_size: int = 8
    batch_size_val: int = 8
    loss_weights: Optional[dict[str, float]] = None
    num_curve_points: Optional[int] = 64
    num_curve_points_val: Optional[int] = 256
    preencoder_type: Optional[str] = "samodule"
    preencoder_lr: Optional[float] = 1e-4
    freeze_backbone: bool = False
    encoder_dim: Optional[int] = 768
    decoder_dim: Optional[int] = 768
    num_encoder_layers: Optional[int] = 3
    num_decoder_layers: Optional[int] = 9
    encoder_dropout: float = 0.1
    decoder_dropout: float = 0.1
    num_attn_heads: Optional[int] = 8
    enc_dim_feedforward: Optional[int] = 2048
    dec_dim_feedforward: Optional[int] = 2048
    mlp_dropout: float = 0.0
    num_preds: Optional[int] = 128
    num_classes: Optional[int] = 5
    cost_weights: Optional[dict[str, float]] = None
    auxiliary_loss: bool = True
    max_points_in_param: Optional[int] = 4
    num_transformer_points: Optional[int] = 2048
    query_type: str = "point_fps"
    pos_embed_type: str = "sine"
    class_loss_type: str = "cross_entropy"  # or "focal"
    class_loss_weights: list[float] = field(
        default_factory=lambda: [
            0.04834912,
            0.40329467,
            0.09588135,
            0.23071379,
            0.22176106,
        ]
    )

    def get_preencoder(self):
        preencoder_type = self.preencoder_type
        preencoder = None
        if preencoder_type == "samodule":
            preencoder = SAModule(
                MLP([self.num_features + 3, 64, 128, self.encoder_dim]),
                num_out_points=self.num_transformer_points,
            )
            preencoder.out_channels = self.encoder_dim
        else:
            raise ValueError(f"Unknown preencoder type: {self.preencoder_type}.")
        return preencoder
pi3detr/models/pi3detr.py
ADDED
@@ -0,0 +1,593 @@
import torch
from torch import nn, Tensor
from typing import Optional
import pytorch_lightning as pl
from pytorch_lightning.utilities.types import OptimizerLRScheduler
from torch_geometric.data.data import Data
from torch_geometric.nn import MLP
from torch_geometric.utils import to_dense_batch
from .model_config import ModelConfig
from .losses import LossParams, ParametricLoss
from .transformer import Transformer
from .positional_embedding import PositionEmbeddingCoordsSine
from .query_engine import build_query_engine
from pi3detr.dataset import reverse_normalize_and_scale
from ..utils.curve_fitter import (
    torch_bezier_curve,
    torch_line_points,
    generate_points_on_circle_torch,
    torch_arc_points,
)
from ..utils.postprocessing import (
    snap_and_fit_curves,
    filter_predictions,
    iou_filter_point_based,
    iou_filter_predictions,
)

from pi3detr.evaluation.abc_metrics import (
    ChamferMAP,
)
from torchmetrics.classification import (
    BinaryJaccardIndex,
    BinaryPrecision,
    BinaryRecall,
)


class PI3DETR(pl.LightningModule):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.pc_preencoder = config.get_preencoder()
        self.enc_dim = config.encoder_dim
        self.dec_dim = config.decoder_dim
        self.num_preds = config.num_preds
        self.num_curve_points = config.num_curve_points
        self.num_curve_points_val = config.num_curve_points_val
        self.num_classes = config.num_classes
        self.max_points_in_param = config.max_points_in_param
        self.preenc_to_enc_proj = MLP(
            [self.pc_preencoder.out_channels, self.enc_dim, self.enc_dim],
            act="relu",
            norm="layer_norm",
        )
        num_decoder_layers = config.num_decoder_layers
        self.transformer = Transformer(
            self.enc_dim,
            self.dec_dim,
            nhead=config.num_attn_heads,
            num_encoder_layers=config.num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            enc_dim_feedforward=config.enc_dim_feedforward,
            dec_dim_feedforward=config.dec_dim_feedforward,
            enc_dropout=config.encoder_dropout,
            dec_dropout=config.decoder_dropout,
            return_intermediate_dec=True,
        )
        self.positional_embedding = PositionEmbeddingCoordsSine(
            d_pos=self.dec_dim, pos_type=self.config.pos_embed_type
        )
        self.pos_embed_proj = MLP(
            [self.dec_dim, self.dec_dim, self.dec_dim], act="relu", norm="layer_norm"
        )
        self.query_type = config.query_type
        self.query_engine = build_query_engine(
            self.query_type,
            self.positional_embedding,
            self.dec_dim,
            self.max_points_in_param,
            self.num_preds,
        )

        def make_mlp(out_dim, layers=4, base_dim=None, bias_last=True):
            base_dim = base_dim or self.dec_dim
            n_layers = layers - 1
            return MLP(
                channel_list=[base_dim] * n_layers + [out_dim],
                bias=[False] * (n_layers - 1) + [bias_last],
                dropout=self.config.mlp_dropout,
                act="relu",
                norm="layer_norm",
            )

        self.class_head = make_mlp(self.num_classes)
        self.bspline_param_head = make_mlp(4 * 3)
        self.line_param_head = make_mlp(2 * 3)
        self.line_length_head = make_mlp(1, layers=3)
        self.circle_param_head = make_mlp(2 * 3)
        self.circle_radius_head = make_mlp(1, layers=3)
        self.arc_param_head = make_mlp(3 * 3)

        self.loss = ParametricLoss(
            LossParams(
                num_classes=self.num_classes - 1,  # -1 for the EOS token
                cost_class=config.cost_weights["cost_class"],
                cost_curve=config.cost_weights["cost_curve"],
                class_loss_type=config.class_loss_type,
                class_loss_weights=config.class_loss_weights,
            )
        )
        self.auxiliary_loss = self.config.auxiliary_loss
        self.weight_dict = {
            "loss_class": config.loss_weights["loss_class"],
            "loss_bspline": config.loss_weights["loss_bspline"],
            "loss_bspline_chamfer": config.loss_weights["loss_bspline_chamfer"],
            "loss_line_position": config.loss_weights["loss_line_position"],
            "loss_line_length": config.loss_weights["loss_line_length"],
            "loss_line_chamfer": config.loss_weights["loss_line_chamfer"],
            "loss_circle_position": config.loss_weights["loss_circle_position"],
            "loss_circle_radius": config.loss_weights["loss_circle_radius"],
            "loss_circle_chamfer": config.loss_weights["loss_circle_chamfer"],
            "loss_arc": config.loss_weights["loss_arc"],
            "loss_arc_chamfer": config.loss_weights["loss_arc_chamfer"],
        }
        # TODO this is a hack
        self.aux_weight_dict = {}
        if self.auxiliary_loss:
            for i in range(num_decoder_layers - 1):
                self.aux_weight_dict.update(
                    {k + f"_{i}": v for k, v in self.weight_dict.items()}
                )
            self.weight_dict.update(self.aux_weight_dict)

        self.chamfer_map = ChamferMAP(chamfer_thresh=0.05)

        # Torchmetrics for segmentation
        self.seg_iou = BinaryJaccardIndex()
        self.seg_precision = BinaryPrecision()
        self.seg_recall = BinaryRecall()

    def forward(self, data: Data) -> dict[str, Tensor]:
        x, pos, batch = self.pc_preencoder(data)[-1]
        x = self.preenc_to_enc_proj(x)
        x, mask = to_dense_batch(x, batch)
        pos_dense_batch, _ = to_dense_batch(pos, batch)
        pos_embed = self.pos_embed_proj(
            self.positional_embedding(
                pos_dense_batch, num_channels=self.dec_dim
            ).permute(0, 2, 1)
        ).permute(0, 2, 1)
        query_xyz, query_embed = self.query_engine(Data(pos=pos, batch=batch))
        x = self.transformer(
            x,  # [batch_size, num_points, enc_dim]
            # transformer expects 1s to be masked
            ~mask if not torch.all(mask) else None,  # [batch_size, num_points]
            query_embed,  # [batch_size, dec_dim, num_queries]
            pos_embed,  # [batch_size, dec_dim, num_points]
        )
        output_class = self.class_head(x)
        output_bspline_params = self.bspline_param_head(x)
        output_line_params = self.line_param_head(x)
        output_line_length = self.line_length_head(x)
        output_circle_params = self.circle_param_head(x)
        output_circle_radius = self.circle_radius_head(x)
        output_arc_params = self.arc_param_head(x)

        pred_bspline_params = (
            output_bspline_params[-1].reshape(data.batch_size, self.num_preds, 4, 3)
            + query_xyz
        )
        pred_line_params = output_line_params[-1].reshape(
            data.batch_size, self.num_preds, 2, 3
        )
        pred_line_params[:, :, 0, :] = (
            pred_line_params[:, :, 0, :] + query_xyz[:, :, 0, :]
        )

        pred_circle_params = output_circle_params[-1].reshape(
            data.batch_size, self.num_preds, 2, 3
        )
        pred_circle_params[:, :, 0, :] = (
            pred_circle_params[:, :, 0, :] + query_xyz[:, :, 0, :]
        )

        pred_arc_params = (
            output_arc_params[-1].reshape(data.batch_size, self.num_preds, 3, 3)
            + query_xyz[:, :, :3, :]
        )

        out = {
            "pred_class": output_class[-1],
            "pred_bspline_params": pred_bspline_params,
            "pred_line_params": pred_line_params,
            "pred_line_length": output_line_length[-1],
            "pred_circle_params": pred_circle_params,
            "pred_circle_radius": output_circle_radius[-1],
            "pred_arc_params": pred_arc_params,
            "query_xyz": query_xyz,
        }
        if self.auxiliary_loss and self.training:
            out["aux_outputs"] = self._set_aux_loss(
                output_bspline_params,
                output_line_params,
                output_line_length,
                output_circle_params,
                output_circle_radius,
                output_arc_params,
                query_xyz,
                output_class,
            )
        return out

    @torch.jit.unused
    def _set_aux_loss(
        self,
        output_bspline_params: torch.Tensor,
        output_line_params: torch.Tensor,
        output_line_length: torch.Tensor,
        output_circle_params: torch.Tensor,
        output_circle_radius: torch.Tensor,
        output_arc_params: torch.Tensor,
        query_xyz: torch.Tensor,
        output_class: torch.Tensor,
    ) -> list[dict[str, torch.Tensor]]:
        # this is a workaround to make torchscript happy, as torchscript
        # doesn't support dictionary with non-homogeneous values, such
        # as a dict having both a Tensor and a list.
        out_aux = []
        for b, l, ll, c, cr, a, cl in zip(
            output_bspline_params[:-1],
            output_line_params[:-1],
            output_line_length[:-1],
            output_circle_params[:-1],
            output_circle_radius[:-1],
            output_arc_params[:-1],
            output_class[:-1],
        ):
            pred_bspline_params = b.reshape(*b.shape[:2], 4, 3) + query_xyz
            # second point is the direction vector
            pred_line_params = l.reshape(*l.shape[:2], 2, 3)
            pred_line_params_adjusted = pred_line_params.clone()
            pred_line_params_adjusted[:, :, 0, :] = (
                pred_line_params[:, :, 0, :] + query_xyz[:, :, 0, :]
            )
            pred_circle_params = c.reshape(*c.shape[:2], 2, 3)
            pred_circle_params_adjusted = pred_circle_params.clone()
            pred_circle_params_adjusted[:, :, 0, :] = (
                pred_circle_params[:, :, 0, :] + query_xyz[:, :, 0, :]
            )
            pred_arc_params = a.reshape(*a.shape[:2], 3, 3) + query_xyz[:, :, :3, :]

            layer_out = {
                "pred_bspline_params": pred_bspline_params,
                "pred_line_params": pred_line_params_adjusted,
                "pred_line_length": ll,
                "pred_circle_params": pred_circle_params_adjusted,
                "pred_circle_radius": cr,
                "pred_arc_params": pred_arc_params,
                "pred_class": cl,
            }
            layer_out.update(
                self._sample_curve_points(layer_out, self.num_curve_points)
            )
            out_aux.append(layer_out)

        return out_aux

    def _sample_curve_points(
        self, out: dict[str, Tensor], num_points: int
    ) -> dict[str, Tensor]:
        batch_size, num_preds = out["pred_bspline_params"].shape[:2]
        pred_line_params = out["pred_line_params"]
        pred_line_length = out["pred_line_length"]
        pred_line_start = (
            pred_line_params[:, :, 0, :]
            - pred_line_params[:, :, 1, :] * pred_line_length / 2.0
        )
        pred_line_end = (
            pred_line_params[:, :, 0, :]
            + pred_line_params[:, :, 1, :] * pred_line_length / 2.0
        )
        curves = {}
        curves["pred_bspline_points"] = torch_bezier_curve(
            out["pred_bspline_params"].reshape(-1, 4, 3), num_points
        ).reshape(batch_size, num_preds, -1, 3)
        curves["pred_line_points"] = torch_line_points(
            pred_line_start.reshape(-1, 3),
            pred_line_end.reshape(-1, 3),
            num_points,
        ).reshape(batch_size, num_preds, -1, 3)
        curves["pred_circle_points"] = generate_points_on_circle_torch(
            out["pred_circle_params"].reshape(-1, 2, 3)[:, 0],
            out["pred_circle_params"].reshape(-1, 2, 3)[:, 1],
            out["pred_circle_radius"].reshape(-1),
            num_points,
        ).reshape(batch_size, num_preds, -1, 3)
        curves["pred_arc_points"] = torch_arc_points(
            out["pred_arc_params"][:, :, 1, :].reshape(-1, 3),
            out["pred_arc_params"][:, :, 0, :].reshape(-1, 3),
            out["pred_arc_params"][:, :, 2, :].reshape(-1, 3),
            num_points,
        ).reshape(batch_size, num_preds, -1, 3)
        return curves

    def predict_step(
        self,
        batch: Data,
        reverse_norm: bool = True,
        thresholds: Optional[list[float]] = None,
        snap_and_fit: bool = True,
        iou_filter: bool = False,
    ) -> list[Data]:
        preds = self(batch)
        preds.update(self._sample_curve_points(preds, self.num_curve_points_val))

        outputs = self.decode_predictions(batch, preds, reverse_norm)

        if thresholds:
            outputs = [filter_predictions(data, thresholds) for data in outputs]

        if snap_and_fit:
            outputs = [snap_and_fit_curves(data.clone()) for data in outputs]

        if iou_filter:
            outputs = [iou_filter_predictions(data) for data in outputs]

        return outputs

    def training_step(self, batch: Data, batch_idx: int) -> Tensor:
        outputs = self(batch)
        outputs.update(self._sample_curve_points(outputs, self.num_curve_points))
        loss_dict = self.loss(outputs, batch)
        for k, v in loss_dict.items():
            weight = self.weight_dict[k] if "loss" in k else 1
            self._default_log(k, v * weight)
        # weigh losses and sum them for backpropagation
        weighted_loss_dict = {
            k: loss_dict[k] * self.weight_dict[k] for k in self.weight_dict.keys()
        }
        total_loss = sum(weighted_loss_dict.values())
        self._default_log("loss_train", total_loss)
        return total_loss

    @torch.no_grad()
    def validation_step(self, batch: Data, batch_idx: int) -> None:
        outputs = self(batch)
        # sample the training curve points for loss computation
        outputs.update(self._sample_curve_points(outputs, self.num_curve_points))
        loss_dict = self.loss(outputs, batch)
        for k, v in loss_dict.items():
            weight = self.weight_dict[k] if "loss" in k else 1
            self._default_log(f"val_{k}", v * weight)
        without_aux = {
            k for k in self.weight_dict.keys() if k not in self.aux_weight_dict
        }
        self._default_log(
            "loss_val", sum(loss_dict[k] * self.weight_dict[k] for k in without_aux)
        )
        # Sample curve points for validation
        outputs.update(self._sample_curve_points(outputs, self.num_curve_points_val))
        self._compute_metrics(batch, outputs)

    @torch.no_grad()
    def on_validation_epoch_end(self) -> None:
        metrics = self.chamfer_map.compute()
        for key, value in metrics.items():
            self._default_log(f"val_{key}", value)
        self.chamfer_map.reset()

        # Log segmentation metrics at epoch end
        self._default_log("val_seg_iou", self.seg_iou.compute())
        self._default_log(
            "val_seg_precision",
            self.seg_precision.compute(),
        )
        self._default_log("val_seg_recall", self.seg_recall.compute())
        self.seg_iou.reset()
        self.seg_precision.reset()
        self.seg_recall.reset()

    def test_step(self, batch: Data, batch_idx: int) -> None:
        outputs = self(batch)
        outputs.update(self._sample_curve_points(outputs, self.num_curve_points_val))
        self._compute_metrics(batch, outputs)

    def on_test_epoch_end(self) -> None:
        metrics = self.chamfer_map.compute()
        self.chamfer_map.reset()
        for key, value in metrics.items():
            self.log(f"test_{key}", value, prog_bar=False)

        # Log segmentation metrics at epoch end
        self.log("test_seg_iou", self.seg_iou.compute(), prog_bar=False)
        self.log("test_seg_precision", self.seg_precision.compute(), prog_bar=False)
        self.log("test_seg_recall", self.seg_recall.compute(), prog_bar=False)
        self.seg_iou.reset()
        self.seg_precision.reset()
        self.seg_recall.reset()

    def _compute_metrics(self, batch: Data, preds: dict):
        # segmentation metrics
        outputs = self.decode_predictions(batch, preds, reverse_norm=True)
        for output in outputs:
            self.seg_iou.update(output.segmentation, output.y_seg)
            self.seg_precision.update(output.segmentation, output.y_seg)
            self.seg_recall.update(output.segmentation, output.y_seg)
        # chamfer metrics
        self.chamfer_map.update(preds, batch)

    def set_num_preds(self, num_preds: int) -> None:
        if num_preds == self.num_preds:
            return
        self.num_preds = num_preds
        old_state = (
            self.query_engine.state_dict()
            if isinstance(self.query_engine, nn.Module)
            else None
        )
        new_engine = build_query_engine(
            self.query_type,
            self.positional_embedding,
            self.dec_dim,
            self.max_points_in_param,
            self.num_preds,
        )
        if old_state is not None:
            new_state = new_engine.state_dict()
            for k, v in old_state.items():
                assert k in new_state, f"Missing parameter in new query engine: {k}"
                nv = new_state[k]
                assert (
                    v.shape == nv.shape
                ), f"Shape mismatch for {k}: {v.shape} != {nv.shape}"
                nv.copy_(v.to(nv.device))
            new_engine.load_state_dict(new_state, strict=True)
        self.query_engine = new_engine.to(self.device)

    @torch.no_grad()
    def decode_predictions(
        self, batch: Data, preds: dict[str, Tensor], reverse_norm: bool = True
    ) -> list[Data]:
        outputs = []

        # Vectorized class prediction and score
        preds_class = preds["pred_class"].softmax(-1)
        polyline_class = preds_class.argmax(-1)  # (batch_size, num_preds)
        polyline_score = preds_class.max(-1).values  # (batch_size, num_preds)

        # Prepare all possible polylines: (batch_size, num_preds, num_polypoints, 3)
        bspline_points = preds["pred_bspline_points"]
        line_points = preds["pred_line_points"]
        circle_points = preds["pred_circle_points"]
        arc_points = preds["pred_arc_points"]
        zeros_points = torch.zeros_like(bspline_points)  # EOS/empty

        # Stack all types: (batch_size, num_preds, 4, num_polypoints, 3)
        all_polylines = torch.stack(
            [zeros_points, bspline_points, line_points, circle_points, arc_points],
            dim=2,
        )

        # Gather correct polyline for each prediction
        # polyline_class: (batch_size, num_preds)
        # Need to expand to match all_polylines shape for gather
        idx = (
            polyline_class.unsqueeze(-1)
            .unsqueeze(-1)
            .unsqueeze(2)  # shape: (batch_size, num_preds, 1, num_polypoints, 3)
            .expand(-1, -1, 1, self.num_curve_points_val, 3)
        )
        polylines = torch.gather(all_polylines, 2, idx).squeeze(
            2
        )  # (batch_size, num_preds, num_polypoints, 3)

        batch_size = batch.batch_size
        device = batch.pos.device
        segmentations = []
        for i in range(batch_size):
            # If all predicted classes are zero (EOS), segmentation should be all zeros
            if torch.all(polyline_class[i] == 0):
                pc_pts = batch.pos[batch.batch == i]
                segmentation = torch.zeros(
                    pc_pts.shape[0], dtype=torch.long, device=device
                )
                segmentations.append(segmentation)
                continue

            poly_pts = polylines[i, polyline_class[i] != 0].reshape(-1, 3)
            pc_pts = batch.pos[batch.batch == i]  # (num_points_in_cloud, 3)
            dists = torch.cdist(poly_pts, pc_pts)
            closest_idx = dists.argmin(dim=1)
            segmentation = torch.zeros(pc_pts.shape[0], dtype=torch.long, device=device)
            segmentation[closest_idx.unique()] = 1
            segmentations.append(segmentation)

        for i in range(batch.batch_size):
            output = Data(
                pos=batch.pos[batch.batch == i].clone(),  # point cloud
                bspline_points=bspline_points[i],  # prediction of B-spline head
                line_points=line_points[i],  # prediction of line heads
                circle_points=circle_points[i],  # prediction of circle heads
                arc_points=arc_points[i],  # prediction of arc head
                polyline_class=polyline_class[i],  # class of each polyline
                polyline_score=polyline_score[i],  # score of polyline class
                polylines=polylines[i],  # polyline that matches polyline_class
                segmentation=segmentations[i],  # curve segmentation for whole point cloud
                query_xyz=preds["query_xyz"][i],  # query points for the transformer
            )
            if hasattr(batch, "y_seg"):
                output.y_seg = batch.y_seg[batch.batch == i]

            if reverse_norm:
                output.center = batch.center[i]
                output.scale = batch.scale[i]

                output = reverse_normalize_and_scale(
                    output,
                    extra_fields=[
                        "polylines",
                        "bspline_points",
                        "line_points",
                        "circle_points",
                        "arc_points",
                        "query_xyz",
                    ],
                )
            outputs.append(output)

        return outputs

    def _default_log(self, name: str, value: Tensor) -> None:
        batch_size = (
            self.config.batch_size if self.training else self.config.batch_size_val
        )
        self.log(
            name,
            value,
            prog_bar=True,
            on_epoch=True,
            on_step=False,
            sync_dist=True,
            batch_size=batch_size,
        )

    def configure_optimizers(self) -> OptimizerLRScheduler:
        param_dict = None
        config = self.config
        if config.freeze_backbone:
            for param in self.pc_preencoder.parameters():
                param.requires_grad = False
            param_dict = self.parameters()
        elif config.lr != config.preencoder_lr:
            param_dict = [
                {
                    "params": [
                        p for n, p in self.named_parameters() if "pc_encoder" not in n
                    ]
                },
                {
                    "params": [
                        p for n, p in self.named_parameters() if "pc_encoder" in n
                    ],
                    "lr": self.config.preencoder_lr,
                },
            ]
        else:
            param_dict = self.parameters()
        # ----- OPTIMIZER -----
        optimizer = torch.optim.AdamW(param_dict, lr=self.config.lr)
        # ----- WARMUP SCHEDULER -----
        warmup_scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer,
            start_factor=config.lr_warmup_start_factor,  # start near zero
            end_factor=1.0,  # ramp up to base LR
            total_iters=config.lr_warmup_epochs,
        )
        # ----- STEP SCHEDULER -----
        # Drop LR by factor after (step_epoch - warmup_epochs) epochs
        step_scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer,
            step_size=config.lr_step - config.lr_warmup_epochs,
            gamma=0.1,  # drop LR to 10%
            last_epoch=config.epochs,
        )
        # ----- COMBINE -----
        scheduler = torch.optim.lr_scheduler.SequentialLR(
            optimizer,
            schedulers=[warmup_scheduler, step_scheduler],
            milestones=[config.lr_warmup_epochs],
        )

        return [optimizer], {"scheduler": scheduler, "interval": "epoch"}
pi3detr/models/pointnetpp.py
ADDED
@@ -0,0 +1,324 @@
import torch
from torch import nn, Tensor
from typing import Optional, Tuple
import numpy as np
from scipy.spatial import cKDTree
from torch_geometric.nn import (
    MLP,
    PointNetConv,
    fps,
    knn_interpolate,
    radius,
    global_max_pool,
)
from torch_geometric.data.data import BaseData

TensorTriple = Tuple[Tensor, Tensor, Tensor]


def radius_cpu(
    x: torch.Tensor,
    y: torch.Tensor,
    r: float,
    batch_x: Optional[torch.Tensor] = None,
    batch_y: Optional[torch.Tensor] = None,
    max_num_neighbors: Optional[int] = None,
    loop: bool = False,
    sort_by_distance: bool = True,
) -> Tuple[torch.LongTensor, torch.LongTensor]:
    """
    CPU replacement for torch_cluster.radius / torch_geometric.radius.

    Semantics (matching torch_geometric.radius):
      Returns (row, col) where
        row: indices into `y` (centers) in the range [0, y.size(0))
        col: indices into `x` (neighbors) in the range [0, x.size(0))

    Thus, for y = x[idx]:
        edge_index = torch.stack([col, row], dim=0)
        edge_index[0] indexes the full set (source/neighbor),
        edge_index[1] indexes the sampled centers.
    """
    # Basic checks
    if x.device.type != "cpu" or y.device.type != "cpu":
        raise ValueError("radius_cpu expects x and y to be on CPU.")
    if x.ndim != 2 or y.ndim != 2:
        raise ValueError("x and y must be 2D (N, D).")
    if x.shape[1] != y.shape[1]:
        raise ValueError("x and y must have same dimensionality D.")

    N_x = x.shape[0]
    N_y = y.shape[0]
    if N_x == 0 or N_y == 0:
        return torch.empty((0,), dtype=torch.long), torch.empty((0,), dtype=torch.long)

    x_np = np.asarray(x)
    y_np = np.asarray(y)

    if batch_x is None:
        batch_x = torch.zeros(N_x, dtype=torch.long)
    else:
        if batch_x.device.type != "cpu":
            batch_x = batch_x.cpu()
        batch_x = batch_x.long()

    if batch_y is None:
        batch_y = torch.zeros(N_y, dtype=torch.long)
    else:
        if batch_y.device.type != "cpu":
            batch_y = batch_y.cpu()
        batch_y = batch_y.long()

    rows = []
    cols = []

    # iterate only over batches actually present in y to avoid unnecessary work
    unique_batches = sorted(set(batch_y.tolist()))

    for b in unique_batches:
        # mask and maps from local->global indices
        mask_x = (batch_x == b).numpy()
        mask_y = (batch_y == b).numpy()
        idxs_x = np.nonzero(mask_x)[0]  # global indices in x
        idxs_y = np.nonzero(mask_y)[0]  # global indices in y

        if idxs_y.size == 0 or idxs_x.size == 0:
            continue

        pts_x = x_np[mask_x]
        pts_y = y_np[mask_y]

        # build tree on source points (x) and query for each center in y
        tree = cKDTree(pts_x)
        # neighbors_list: for each center (local), a list of local indices into pts_x
        neighbors_list = tree.query_ball_point(pts_y, r)

        for local_center, neigh_locals in enumerate(neighbors_list):
            if len(neigh_locals) == 0:
                continue
            neigh_locals = np.array(neigh_locals, dtype=int)

            # remove self if requested AND x and y are the same set at same global indices
            if not loop:
                # If x and y refer to the same global indices and same coords, remove self-match
                # we detect self by checking whether global index equals center global index
                center_global = idxs_y[local_center]
                # compute global neighbor indices
                neigh_globals = idxs_x[neigh_locals]
                # boolean mask for neighbors that are not self
                not_self_mask = neigh_globals != center_global
                neigh_locals = neigh_locals[not_self_mask]

            if neigh_locals.size == 0:
                continue

            # apply max_num_neighbors: keep closest ones by distance if requested
            if max_num_neighbors is not None and neigh_locals.size > max_num_neighbors:
                if sort_by_distance:
                    dists = np.linalg.norm(
                        pts_x[neigh_locals] - pts_y[local_center], axis=1
                    )
                    order = np.argsort(dists)[:max_num_neighbors]
                    neigh_locals = neigh_locals[order]
                else:
                    neigh_locals = np.sort(neigh_locals)[:max_num_neighbors]

            # optionally sort by distance
            if sort_by_distance and neigh_locals.size > 0:
                dists = np.linalg.norm(
                    pts_x[neigh_locals] - pts_y[local_center], axis=1
                )
                order = np.argsort(dists)
                neigh_locals = neigh_locals[order]

            # convert to global indices and append
            neigh_globals = idxs_x[neigh_locals].tolist()
            center_global = int(idxs_y[local_center])
            rows.extend(neigh_globals)  # neighbor indices into x (row)
            cols.extend(
                [center_global] * len(neigh_globals)
            )  # center indices into y (col)

    if len(rows) == 0:
        return torch.empty((0,), dtype=torch.long), torch.empty((0,), dtype=torch.long)

    row_t = torch.tensor(rows, dtype=torch.long)  # currently neighbors (x)
    col_t = torch.tensor(cols, dtype=torch.long)  # currently centers (y)

    # Swap to enforce (row=center_indices_in_y, col=neighbor_indices_in_x)
    return col_t, row_t


class SAModuleRatio(torch.nn.Module):
    def __init__(
        self, ratio: float, r: float, nn: nn.Module, max_num_neighbors: int = 64
    ):
        super().__init__()
        self.ratio = ratio
        self.r = r
        self.conv = PointNetConv(nn, add_self_loops=False)
        self.max_num_neighbors = max_num_neighbors

    def forward(self, x: torch.Tensor, pos: torch.Tensor, batch: torch.Tensor):
        idx = fps(pos, batch, ratio=self.ratio)
        row, col = radius(
            pos,
            pos[idx],
            self.r,
            batch,
            batch[idx],
            max_num_neighbors=self.max_num_neighbors,
        )
        edge_index = torch.stack([col, row], dim=0)
        x_dst = None if x is None else x[idx]
        x = self.conv((x, x_dst), (pos, pos[idx]), edge_index)
        pos, batch = pos[idx], batch[idx]
        return x, pos, batch


class SAModule(torch.nn.Module):
    def __init__(
        self,
        nn: nn.Module,
        num_out_points: int = 2048,
        r: float = 0.2,
        max_num_neighbors: int = 64,
    ):
        super().__init__()
        self.num_out_points = num_out_points
        self.r = r
        self.conv = PointNetConv(nn, add_self_loops=False)
        self.max_num_neighbors = max_num_neighbors

    def forward(self, data: BaseData) -> list[tuple[TensorTriple]]:
        x, pos, batch = data.x, data.pos, data.batch
        num_points_per_batch = torch.bincount(batch)
        max_ratio = self.num_out_points / num_points_per_batch.min().item()
        fps_idx = fps(pos, batch, ratio=max_ratio)
        fps_batch = batch[fps_idx]
        idx = torch.cat(
            [
                fps_idx[fps_batch == i][: self.num_out_points]
                for i in range(batch.max().item() + 1)
|
| 204 |
+
]
|
| 205 |
+
)
|
| 206 |
+
if pos.device == torch.device("cpu"):
|
| 207 |
+
row, col = radius_cpu(
|
| 208 |
+
pos,
|
| 209 |
+
pos[idx],
|
| 210 |
+
self.r,
|
| 211 |
+
batch,
|
| 212 |
+
batch[idx],
|
| 213 |
+
max_num_neighbors=self.max_num_neighbors,
|
| 214 |
+
sort_by_distance=False,
|
| 215 |
+
)
|
| 216 |
+
else: # GPU
|
| 217 |
+
row, col = radius(
|
| 218 |
+
pos,
|
| 219 |
+
pos[idx],
|
| 220 |
+
self.r,
|
| 221 |
+
batch,
|
| 222 |
+
batch[idx],
|
| 223 |
+
max_num_neighbors=self.max_num_neighbors,
|
| 224 |
+
)
|
| 225 |
+
edge_index = torch.stack([col, row], dim=0)
|
| 226 |
+
x_dst = None if x is None else x[idx]
|
| 227 |
+
x = self.conv((x, x_dst), (pos, pos[idx]), edge_index)
|
| 228 |
+
pos, batch = pos[idx], batch[idx]
|
| 229 |
+
return [(x, pos, batch)]
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
class GlobalSAModule(torch.nn.Module):
|
| 233 |
+
def __init__(self, nn):
|
| 234 |
+
super().__init__()
|
| 235 |
+
self.nn = nn
|
| 236 |
+
|
| 237 |
+
def forward(self, x: torch.Tensor, pos: torch.Tensor, batch: torch.Tensor):
|
| 238 |
+
x = self.nn(torch.cat([x, pos], dim=1))
|
| 239 |
+
x = global_max_pool(x, batch)
|
| 240 |
+
pos = pos.new_zeros((x.size(0), 3))
|
| 241 |
+
batch = torch.arange(x.size(0), device=batch.device)
|
| 242 |
+
return x, pos, batch
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
class FPModule(nn.Module):
|
| 246 |
+
def __init__(self, k, nn):
|
| 247 |
+
super().__init__()
|
| 248 |
+
self.k = k
|
| 249 |
+
self.nn = nn
|
| 250 |
+
|
| 251 |
+
def forward(
|
| 252 |
+
self,
|
| 253 |
+
x: torch.Tensor,
|
| 254 |
+
pos: torch.Tensor,
|
| 255 |
+
batch: torch.Tensor,
|
| 256 |
+
x_skip: torch.Tensor,
|
| 257 |
+
pos_skip: torch.Tensor,
|
| 258 |
+
batch_skip: torch.Tensor,
|
| 259 |
+
):
|
| 260 |
+
x = knn_interpolate(x, pos, pos_skip, batch, batch_skip, k=self.k)
|
| 261 |
+
if x_skip is not None:
|
| 262 |
+
x = torch.cat([x, x_skip], dim=1)
|
| 263 |
+
x = self.nn(x)
|
| 264 |
+
return x, pos_skip, batch_skip
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
class PointNetPPEncoder(nn.Module):
|
| 268 |
+
def __init__(self, num_features: int = 3, out_channels: int = 512):
|
| 269 |
+
super().__init__()
|
| 270 |
+
self.out_channels = out_channels
|
| 271 |
+
# Input channels account for both `pos` and node features.
|
| 272 |
+
self.sa1_module = SAModuleRatio(
|
| 273 |
+
0.5, 0.05, MLP([num_features + 3, 32, 32, 64]), 32
|
| 274 |
+
)
|
| 275 |
+
self.sa2_module = SAModuleRatio(0.5, 0.1, MLP([64 + 3, 64, 64, 128]), 32)
|
| 276 |
+
self.sa3_module = SAModuleRatio(0.5, 0.2, MLP([128 + 3, 128, 128, 256]), 32)
|
| 277 |
+
self.sa4_module = SAModuleRatio(
|
| 278 |
+
0.5, 0.4, MLP([256 + 3, 256, 256, self.out_channels]), 32
|
| 279 |
+
)
|
| 280 |
+
|
| 281 |
+
def forward(self, data: BaseData) -> list[Tensor]:
|
| 282 |
+
sa0_out = (data.x, data.pos, data.batch)
|
| 283 |
+
sa1_out = self.sa1_module(*sa0_out)
|
| 284 |
+
sa2_out = self.sa2_module(*sa1_out)
|
| 285 |
+
sa3_out = self.sa3_module(*sa2_out)
|
| 286 |
+
sa4_out = self.sa4_module(*sa3_out)
|
| 287 |
+
return [sa0_out, sa1_out, sa2_out, sa3_out, sa4_out]
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
class PointNetPPDecoder(nn.Module):
|
| 291 |
+
def __init__(self, num_features: int = 3, out_channels: int = 256):
|
| 292 |
+
super().__init__()
|
| 293 |
+
self.out_channels = out_channels
|
| 294 |
+
self.fp4_module = FPModule(1, MLP([512 + 256, 256, 256]))
|
| 295 |
+
self.fp3_module = FPModule(3, MLP([256 + 128, 256, 256]))
|
| 296 |
+
self.fp2_module = FPModule(3, MLP([256 + 64, 256, 128]))
|
| 297 |
+
self.fp1_module = FPModule(3, MLP([128 + num_features, 128, self.out_channels]))
|
| 298 |
+
|
| 299 |
+
def forward(
|
| 300 |
+
self,
|
| 301 |
+
sa0_out: TensorTriple,
|
| 302 |
+
sa1_out: TensorTriple,
|
| 303 |
+
sa2_out: TensorTriple,
|
| 304 |
+
sa3_out: TensorTriple,
|
| 305 |
+
sa4_out: TensorTriple,
|
| 306 |
+
) -> TensorTriple:
|
| 307 |
+
fp4_out = self.fp4_module(*sa4_out, *sa3_out)
|
| 308 |
+
fp3_out = self.fp3_module(*fp4_out, *sa2_out)
|
| 309 |
+
fp2_out = self.fp2_module(*fp3_out, *sa1_out)
|
| 310 |
+
x, pos, batch = self.fp1_module(*fp2_out, *sa0_out)
|
| 311 |
+
return [x, pos, batch]
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
class PointNetPP(nn.Module):
|
| 315 |
+
def __init__(self, num_features: int, dec_out_channels: int = 256):
|
| 316 |
+
super().__init__()
|
| 317 |
+
self.encoder = PointNetPPEncoder(num_features)
|
| 318 |
+
self.decoder = PointNetPPDecoder(num_features, dec_out_channels)
|
| 319 |
+
self.out_channels = self.decoder.out_channels
|
| 320 |
+
|
| 321 |
+
def forward(self, data: BaseData) -> TensorTriple:
|
| 322 |
+
x = self.encoder(data)
|
| 323 |
+
x, pos, batch = self.decoder(*x)
|
| 324 |
+
return [(x, pos, batch)]
|
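As a quick orientation on the radius_cpu contract, a hedged sanity-check sketch (toy values; it assumes the batch arguments default to None and that the positional order matches the SAModule call above):

import torch

pos = torch.tensor([[0.0, 0.0, 0.0], [0.1, 0.0, 0.0], [5.0, 0.0, 0.0]])
idx = torch.tensor([0])                      # one sampled center, y = pos[idx]
row, col = radius_cpu(pos, pos[idx], 0.2)    # row indexes pos[idx], col indexes pos
edge_index = torch.stack([col, row], dim=0)  # [2, E]: source point -> sampled center

This mirrors how SAModule and SAModuleRatio assemble edge_index before calling PointNetConv.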
pi3detr/models/positional_embedding.py
ADDED
@@ -0,0 +1,133 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved


"""
Various positional encodings for the transformer.
"""
import math
import torch
from torch import nn
import numpy as np


class PositionEmbeddingCoordsSine(nn.Module):
    def __init__(
        self,
        temperature=10000,
        normalize=False,
        scale=None,
        pos_type="fourier",
        d_pos=None,
        d_in=3,
        gauss_scale=1.0,
    ):
        super().__init__()
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        assert pos_type in ["sine", "fourier"]
        self.pos_type = pos_type
        self.scale = scale
        if pos_type == "fourier":
            assert d_pos is not None
            assert d_pos % 2 == 0
            # define a gaussian matrix input_ch -> output_ch
            B = torch.empty((d_in, d_pos // 2)).normal_()
            B *= gauss_scale
            self.register_buffer("gauss_B", B)
            self.d_pos = d_pos

    def get_sine_embeddings(self, xyz, num_channels):
        # clone coords so that shift/scale operations do not affect the original tensor
        orig_xyz = xyz
        xyz = orig_xyz.clone()

        ndim = num_channels // xyz.shape[2]
        if ndim % 2 != 0:
            ndim -= 1
        # automatically handle the remainder by assigning it to the first dims
        rems = num_channels - (ndim * xyz.shape[2])

        assert (
            ndim % 2 == 0
        ), f"Cannot handle odd sized ndim={ndim} where num_channels={num_channels} and xyz={xyz.shape}"

        final_embeds = []
        prev_dim = 0

        for d in range(xyz.shape[2]):
            cdim = ndim
            if rems > 0:
                # add remainder in increments of two to maintain even size
                cdim += 2
                rems -= 2

            if cdim != prev_dim:
                dim_t = torch.arange(cdim, dtype=torch.float32, device=xyz.device)
                dim_t = self.temperature ** (2 * (dim_t // 2) / cdim)

            # create batch x cdim x ncoords embedding
            raw_pos = xyz[:, :, d]
            if self.scale:
                raw_pos *= self.scale
            pos = raw_pos[:, :, None] / dim_t
            pos = torch.stack(
                (pos[:, :, 0::2].sin(), pos[:, :, 1::2].cos()), dim=3
            ).flatten(2)
            final_embeds.append(pos)
            prev_dim = cdim

        final_embeds = torch.cat(final_embeds, dim=2).permute(0, 2, 1)
        return final_embeds

    def get_fourier_embeddings(self, xyz, num_channels=None):
        # Follows - https://people.eecs.berkeley.edu/~bmild/fourfeat/index.html

        if num_channels is None:
            num_channels = self.gauss_B.shape[1] * 2

        bsize, npoints = xyz.shape[0], xyz.shape[1]
        assert num_channels > 0 and num_channels % 2 == 0
        d_in, max_d_out = self.gauss_B.shape[0], self.gauss_B.shape[1]
        d_out = num_channels // 2
        assert d_out <= max_d_out
        assert d_in == xyz.shape[-1]

        # clone coords so that shift/scale operations do not affect the original tensor
        orig_xyz = xyz
        xyz = orig_xyz.clone()

        xyz *= 2 * np.pi
        xyz_proj = torch.mm(xyz.view(-1, d_in), self.gauss_B[:, :d_out]).view(
            bsize, npoints, d_out
        )
        final_embeds = [xyz_proj.sin(), xyz_proj.cos()]

        # return batch x d_pos x npoints embedding
        final_embeds = torch.cat(final_embeds, dim=2).permute(0, 2, 1)
        return final_embeds

    def forward(self, xyz, num_channels=None):
        assert isinstance(xyz, torch.Tensor)
        assert xyz.ndim == 3
        # xyz is batch x npoints x 3
        if self.pos_type == "sine":
            with torch.no_grad():
                return self.get_sine_embeddings(xyz, num_channels)
        elif self.pos_type == "fourier":
            with torch.no_grad():
                return self.get_fourier_embeddings(xyz, num_channels)
        else:
            raise ValueError(f"Unknown {self.pos_type}")

    def extra_repr(self):
        st = f"type={self.pos_type}, scale={self.scale}, normalize={self.normalize}"
        if hasattr(self, "gauss_B"):
            st += (
                f", gaussB={self.gauss_B.shape}, gaussBsum={self.gauss_B.sum().item()}"
            )
        return st
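For orientation, a hedged usage sketch of the Fourier variant (shapes only; the constructor arguments here are illustrative, not the repo's configured values):

import torch

embedder = PositionEmbeddingCoordsSine(pos_type="fourier", d_pos=256, d_in=3)
xyz = torch.rand(2, 1024, 3)           # batch x npoints x 3
pe = embedder(xyz, num_channels=256)   # -> [2, 256, 1024], i.e. batch x d_pos x npoints

Note the channel-first output: callers such as the query engine below permute it back to channel-last before feeding MLPs.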
pi3detr/models/query_engine.py
ADDED
@@ -0,0 +1,108 @@
import torch
from torch import nn, Tensor
from abc import ABC, abstractmethod
from torch_geometric.nn import MLP, fps, knn
from torch_geometric.data.data import Data
from typing import Optional


class QueryEngine(nn.Module, ABC):
    def __init__(
        self,
        pos_embedder: Optional[nn.Module],
        feat_dim: int,
        max_points_in_param: int,
        num_queries: int,
    ):
        super().__init__()
        self.pos_embedder = pos_embedder
        self.feat_dim = feat_dim
        self.max_points_in_param = max_points_in_param
        self.num_queries = num_queries

    @abstractmethod
    def forward(self, data: Data) -> tuple[Tensor, Tensor]:
        pass


class PointFPSQueryEngine(QueryEngine):
    def __init__(
        self,
        pos_embedder: nn.Module,
        feat_dim: int,
        max_points_in_param: int,
        num_queries: int,
    ):
        super().__init__(pos_embedder, feat_dim, max_points_in_param, num_queries)
        self.num_queries = num_queries
        self.query_proj = MLP(
            [self.feat_dim, self.feat_dim, self.feat_dim],
            bias=False,
            act="relu",
            norm="layer_norm",
        )

    def forward(self, data: Data) -> tuple[Tensor, Tensor]:
        num_points_per_batch = torch.bincount(data.batch)
        max_ratio = self.num_queries / num_points_per_batch.min().item()
        fps_idx = fps(data.pos, data.batch, ratio=max_ratio)
        fps_batch = data.batch[fps_idx]
        query_xyz = torch.stack(
            [
                data.pos[fps_idx[fps_batch == i][: self.num_queries]]
                for i in range(data.batch.max().item() + 1)
            ]
        )
        query_pos = self.pos_embedder(query_xyz, num_channels=self.feat_dim)
        query_embed = self.query_proj(query_pos.permute(0, 2, 1))[
            :, : self.num_queries, :
        ].permute(0, 2, 1)
        return (
            query_xyz.unsqueeze(2).expand(-1, -1, self.max_points_in_param, -1),
            query_embed,
        )


class LearnedQueryEngine(QueryEngine):

    def __init__(
        self,
        pos_embedder: Optional[nn.Module],
        feat_dim: int,
        max_points_in_param: int,
        num_queries: int,
    ):
        super().__init__(None, feat_dim, max_points_in_param, num_queries)
        self.query_embed = nn.Embedding(self.num_queries, feat_dim)

    def forward(self, data: Data) -> tuple[Tensor, Tensor]:
        return (
            torch.zeros(
                data.batch_size,
                self.num_queries,
                self.max_points_in_param,
                3,
                device=data.pos.device,
                requires_grad=False,
            ),
            self.query_embed.weight.unsqueeze(0)
            .expand(data.batch_size, -1, -1)
            .permute(0, 2, 1),
        )


def build_query_engine(
    query_type: str,
    pos_embedder: Optional[nn.Module],
    feat_dim: int,
    max_points_in_param: int,
    num_queries: int,
) -> QueryEngine:
    if query_type == "point_fps":
        return PointFPSQueryEngine(
            pos_embedder, feat_dim, max_points_in_param, num_queries
        )
    elif query_type == "learned":
        return LearnedQueryEngine(None, feat_dim, max_points_in_param, num_queries)
    else:
        raise ValueError(f"Unknown query type {query_type}")
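A hedged wiring sketch for the factory (the numeric values are placeholders, and the PositionEmbeddingCoordsSine import path is assumed from the file layout of this commit):

from pi3detr.models.positional_embedding import PositionEmbeddingCoordsSine

engine = build_query_engine(
    query_type="point_fps",
    pos_embedder=PositionEmbeddingCoordsSine(pos_type="fourier", d_pos=256),
    feat_dim=256,
    max_points_in_param=4,
    num_queries=100,
)
# engine(data) returns:
#   query_xyz:   [B, num_queries, max_points_in_param, 3]
#   query_embed: [B, feat_dim, num_queries]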
pi3detr/models/transformer.py
ADDED
@@ -0,0 +1,399 @@
"""
Adapted code from Meta's DETR.
"""

import copy
from typing import Optional
import torch
import torch.nn.functional as F
from torch import nn, Tensor


class Transformer(nn.Module):

    def __init__(
        self,
        enc_dim: int = 256,
        dec_dim: int = 256,
        nhead: int = 8,
        num_encoder_layers: int = 6,
        num_decoder_layers: int = 6,
        enc_dim_feedforward: int = 2048,
        dec_dim_feedforward: int = 2048,
        enc_dropout: float = 0.1,
        dec_dropout: float = 0.1,
        activation: str = "relu",
        normalize_before: bool = False,
        return_intermediate_dec: bool = False,
    ):
        super().__init__()
        encoder_layer = TransformerEncoderLayer(
            enc_dim,
            nhead,
            enc_dim_feedforward,
            enc_dropout,
            activation,
            normalize_before,
        )
        encoder_norm = nn.LayerNorm(enc_dim) if normalize_before else None
        self.encoder = TransformerEncoder(
            encoder_layer, num_encoder_layers, encoder_norm
        )
        if enc_dim != dec_dim:
            self.enc_to_dec_proj = nn.Linear(enc_dim, dec_dim)
        else:
            self.enc_to_dec_proj = nn.Identity()
        decoder_layer = TransformerDecoderLayer(
            dec_dim,
            nhead,
            dec_dim_feedforward,
            dec_dropout,
            activation,
            normalize_before,
        )
        decoder_norm = nn.LayerNorm(dec_dim)
        self.decoder = TransformerDecoder(
            decoder_layer,
            num_decoder_layers,
            decoder_norm,
            return_intermediate=return_intermediate_dec,
        )
        self._reset_parameters()
        self.d_model = dec_dim
        self.nhead = nhead

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(
        self,
        src: Tensor,
        mask: Optional[Tensor],
        query_embed: Tensor,
        pos_embed: Tensor = None,
    ) -> Tensor:
        src = src.permute(1, 0, 2)  # (bs, seq, feat) -> (seq, bs, feat)
        if pos_embed is not None:
            pos_embed = pos_embed.permute(2, 0, 1)

        memory = self.encoder(src, mask=None, src_key_padding_mask=mask, pos=None)
        memory = self.enc_to_dec_proj(memory)

        query_embed = query_embed.permute(2, 0, 1)
        tgt = torch.zeros_like(query_embed)
        hs = self.decoder(
            tgt,
            memory,
            tgt_mask=None,
            memory_mask=None,
            tgt_key_padding_mask=None,
            memory_key_padding_mask=mask,
            pos=pos_embed,
            query_pos=query_embed,
        )
        return hs.transpose(1, 2)


class TransformerEncoder(nn.Module):

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(
        self,
        src,
        mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        output = src

        for layer in self.layers:
            output = layer(
                output,
                src_mask=mask,
                src_key_padding_mask=src_key_padding_mask,
                pos=pos,
            )

        if self.norm is not None:
            output = self.norm(output)

        return output


class TransformerEncoderLayer(nn.Module):

    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of the feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        q = k = self.with_pos_embed(src, pos)
        src2 = self.self_attn(
            q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
        )[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

    def forward_pre(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        src2 = self.norm1(src)
        q = k = self.with_pos_embed(src2, pos)
        src2 = self.self_attn(
            q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
        )[0]
        src = src + self.dropout1(src2)
        src2 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
        src = src + self.dropout2(src2)
        return src

    def forward(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        if self.normalize_before:
            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
        return self.forward_post(src, src_mask, src_key_padding_mask, pos)


class TransformerDecoder(nn.Module):

    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        output = tgt

        intermediate = []

        for layer in self.layers:
            output = layer(
                output,
                memory,
                tgt_mask=tgt_mask,
                memory_mask=memory_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask,
                pos=pos,
                query_pos=query_pos,
            )
            if self.return_intermediate:
                intermediate.append(self.norm(output))

        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                intermediate.pop()
                intermediate.append(output)

        if self.return_intermediate:
            return torch.stack(intermediate)

        return output.unsqueeze(0)


class TransformerDecoderLayer(nn.Module):

    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of the feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        q = k = self.with_pos_embed(tgt, query_pos)
        tgt2 = self.self_attn(
            q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
        )[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        tgt2 = self.multihead_attn(
            query=self.with_pos_embed(tgt, query_pos),
            key=self.with_pos_embed(memory, pos),
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward_pre(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        tgt2 = self.norm1(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(
            q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
        )[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt2 = self.norm2(tgt)
        tgt2 = self.multihead_attn(
            query=self.with_pos_embed(tgt2, query_pos),
            key=self.with_pos_embed(memory, pos),
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt

    def forward(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        if self.normalize_before:
            return self.forward_pre(
                tgt,
                memory,
                tgt_mask,
                memory_mask,
                tgt_key_padding_mask,
                memory_key_padding_mask,
                pos,
                query_pos,
            )
        return self.forward_post(
            tgt,
            memory,
            tgt_mask,
            memory_mask,
            tgt_key_padding_mask,
            memory_key_padding_mask,
            pos,
            query_pos,
        )


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
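To make the tensor layout concrete, a hedged shape sketch (illustrative sizes; not part of the commit):

import torch

model = Transformer(enc_dim=256, dec_dim=256, return_intermediate_dec=True)
src = torch.rand(2, 1024, 256)     # (bs, seq, feat) encoder input, batch-first
queries = torch.rand(2, 256, 100)  # (bs, feat, num_queries)
hs = model(src, mask=None, query_embed=queries)
# hs: (num_decoder_layers, bs, num_queries, dec_dim) since intermediates are kept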
pi3detr/utils/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .config_reader import load_args
from .layer_utils import no_grad, load_weights
from .curve_fitter import torch_bezier_curve
pi3detr/utils/config_reader.py
ADDED
@@ -0,0 +1,21 @@
from typing import Optional
from argparse import Namespace
from types import SimpleNamespace
import yaml


def load_yaml(file_path: str) -> yaml.YAMLObject:
    with open(file_path, "r") as file:
        try:
            return yaml.safe_load(file)
        except yaml.YAMLError as exc:
            print(exc)
            raise yaml.YAMLError("error reading yaml file") from exc


def load_args(config: str, parsed_args: Optional[Namespace] = None) -> SimpleNamespace:
    args = load_yaml(config)
    parsed_args = vars(parsed_args) if parsed_args else {}
    # Dict union: CLI arguments take precedence over keys from the YAML config.
    args = args | parsed_args
    args = SimpleNamespace(**args)
    return args
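Usage is straightforward; a hedged sketch (the config path and flag name are placeholders). Because the merge is `args | parsed_args`, CLI values override YAML keys of the same name:

from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument("--batch_size", type=int, default=8)
cli = parser.parse_args([])                    # empty CLI for illustration
args = load_args("configs/pi3detr.yaml", cli)  # SimpleNamespace, attribute access
print(args.batch_size)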
pi3detr/utils/curve_fitter.py
ADDED
@@ -0,0 +1,371 @@
import numpy as np
import torch
import math


def torch_arc_points(start, mid, end, num_points=100):
    """
    Sample points along a circular arc defined by 3 points in 3D, batched.
    Inputs:
        start, mid, end: tensors of shape [B, 3]
        num_points: number of points sampled along the arc
    Returns:
        arc_points: tensor of shape [B, num_points, 3]
    """
    B = start.shape[0]

    # 1. Compute circle center and normal vector for each batch
    v1 = mid - start  # [B,3]
    v2 = end - start  # [B,3]

    normal = torch.cross(v1, v2, dim=1)  # [B,3]
    normal_norm = normal.norm(dim=1, keepdim=True).clamp(min=1e-8)
    normal = normal / normal_norm  # normalize

    mid1 = (start + mid) / 2  # [B,3]
    mid2 = (start + end) / 2  # [B,3]

    # perpendicular directions in the plane
    perp1 = torch.cross(normal, v1, dim=1)  # [B,3]
    perp2 = torch.cross(normal, v2, dim=1)  # [B,3]

    # Solve the line intersection for each batch:
    # Line 1: point mid1, direction perp1
    # Line 2: point mid2, direction perp2
    # Solve for t in mid1 + t * perp1 = mid2 + s * perp2

    # Construct matrix A and vector b for least squares
    A = torch.stack([perp1, -perp2], dim=2)  # [B,3,2]
    b = (mid2 - mid1).unsqueeze(2)  # [B,3,1]

    # Use torch.linalg.lstsq if available, fall back to the pseudo-inverse:
    try:
        t_s = torch.linalg.lstsq(A, b).solution  # [B,2,1]
    except (AttributeError, RuntimeError):
        # fallback
        pinv = torch.linalg.pinv(A)  # [B,2,3]
        t_s = torch.bmm(pinv, b)  # [B,2,1]

    t = t_s[:, 0, 0]  # [B]

    center = mid1 + (perp1 * t.unsqueeze(1))  # [B,3]

    radius = (start - center).norm(dim=1, keepdim=True)  # [B,1]

    # 2. Define a local basis in the arc plane
    x_axis = (start - center) / radius  # [B,3]
    y_axis = torch.cross(normal, x_axis, dim=1)  # [B,3]

    # 3. Angle computation helper
    def angle_from_vector(v):
        x = (v * x_axis).sum(dim=1)  # [B]
        y = (v * y_axis).sum(dim=1)  # [B]
        angles = torch.atan2(y, x)  # [-pi, pi]
        angles = angles % (2 * math.pi)
        return angles

    theta_start = torch.zeros(B, device=start.device)  # [B], 0 since x_axis is the reference
    theta_end = angle_from_vector(end - center)  # [B]
    theta_mid = angle_from_vector(mid - center)  # [B]

    # 4. Ensure the arc goes the correct way (shortest arc through mid)
    # Helper function, vectorized:
    def between(a, b, c):
        # returns a bool tensor: True where b lies between a and c going CCW mod 2pi
        return ((a < b) & (b < c)) | ((c < a) & ((a < b) | (b < c)))

    cond = between(theta_start, theta_mid, theta_end)

    # If not cond, swap start/end angles by adding 2pi to one side.
    # We add 2pi to whichever angle is smaller to preserve direction.
    theta_start_new = torch.where(
        cond,
        theta_start,
        torch.where(theta_start < theta_end, theta_start, theta_start + 2 * math.pi),
    )
    theta_end_new = torch.where(
        cond,
        theta_end,
        torch.where(theta_end < theta_start, theta_end + 2 * math.pi, theta_end),
    )

    # 5. Sample angles
    t_lin = (
        torch.linspace(0, 1, steps=num_points, device=start.device)
        .unsqueeze(0)
        .repeat(B, 1)
    )  # [B, num_points]

    angles = theta_start_new.unsqueeze(1) + t_lin * (
        theta_end_new - theta_start_new
    ).unsqueeze(1)  # [B, num_points]
    angles = angles % (2 * math.pi)

    # 6. Map back to 3D
    cos_a = torch.cos(angles).unsqueeze(2)  # [B, num_points, 1]
    sin_a = torch.sin(angles).unsqueeze(2)  # [B, num_points, 1]

    points = center.unsqueeze(1) + radius.unsqueeze(1) * (
        cos_a * x_axis.unsqueeze(1) + sin_a * y_axis.unsqueeze(1)
    )  # [B, num_points, 3]

    return points


def torch_circle_fitter(
    points: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Fits a circle to an arbitrary number of 3D points using least squares.

    Args:
        points: Tensor of shape (B, N, 3), where B = batch size, N = number of points per batch.

    Returns:
        center_3d: (B, 3) tensor of circle centers in 3D
        normal: (B, 3) tensor of normal vectors to the circle's plane
        radius: (B,) tensor of circle radii
    """
    B, N, _ = points.shape
    mean = points.mean(dim=1, keepdim=True)
    centered = points - mean

    # PCA via SVD
    U, S, Vh = torch.linalg.svd(centered)
    # last right-singular vector corresponds to the smallest variance (plane normal)
    normal = Vh[:, -1, :]

    # Project to the plane
    x_axis = Vh[:, 0, :]
    y_axis = Vh[:, 1, :]
    X = torch.einsum("bij,bj->bi", centered, x_axis)  # (B, N)
    Y = torch.einsum("bij,bj->bi", centered, y_axis)  # (B, N)

    # Fit circle in 2D: (x - xc)^2 + (y - yc)^2 = r^2
    A = torch.stack([2 * X, 2 * Y, torch.ones_like(X)], dim=-1)  # (B, N, 3)
    b = (X**2 + Y**2).unsqueeze(-1)  # (B, N, 1)

    # Solve the least squares system: A @ [xc, yc, c] = b
    AtA = A.transpose(1, 2) @ A
    Atb = A.transpose(1, 2) @ b
    sol = torch.linalg.solve(AtA, Atb).squeeze(-1)  # (B, 3)

    xc, yc, c = sol[:, 0], sol[:, 1], sol[:, 2]
    radius = torch.sqrt(xc**2 + yc**2 + c)

    # Reconstruct the center in 3D
    center_3d = mean.squeeze(1) + xc.unsqueeze(1) * x_axis + yc.unsqueeze(1) * y_axis

    return center_3d, normal, radius


def generate_points_on_circle(center, normal, radius, num_points=100):
    # Normalize the normal vector
    normal = normal / np.linalg.norm(normal)

    # Find two orthogonal vectors in the plane of the circle
    if np.allclose(normal, [0, 0, 1]):
        u = np.array([1, 0, 0])
    else:
        u = np.cross(normal, [0, 0, 1])
        u = u / np.linalg.norm(u)
    v = np.cross(normal, u)

    # Generate points on the circle in the plane
    theta = np.linspace(0, 2 * np.pi, num_points)
    circle_points = (
        center
        + radius * np.outer(np.cos(theta), u)
        + radius * np.outer(np.sin(theta), v)
    )

    return circle_points


def generate_points_on_circle_torch(
    center, normal, radius, num_points=100
) -> torch.Tensor:
    """
    Generate points on a circle in 3D space using PyTorch, supporting batching.

    Args:
        center: Tensor of shape (B, 3), circle centers.
        normal: Tensor of shape (B, 3), normal vectors to the circle's plane.
        radius: Tensor of shape (B,), radii of the circles.
        num_points: Number of points to generate per circle.

    Returns:
        Tensor of shape (B, num_points, 3), points on the circles.
    """
    B = center.shape[0]
    normal = normal / torch.norm(normal, dim=1, keepdim=True)  # Normalize normals

    # Find two orthogonal vectors in the plane of the circle
    u = torch.linalg.cross(
        normal,
        torch.tensor([0, 0, 1], dtype=normal.dtype, device=normal.device).expand_as(
            normal
        ),
    )
    u = torch.where(
        torch.norm(u, dim=1, keepdim=True) > 1e-6,
        u,
        torch.tensor([1, 0, 0], dtype=normal.dtype, device=normal.device).expand_as(
            normal
        ),
    )
    u = u / torch.norm(u, dim=1, keepdim=True)
    v = torch.linalg.cross(normal, u)

    # Generate points on the circle in the plane
    theta = (
        torch.linspace(0, 2 * torch.pi, num_points, device=center.device)
        .unsqueeze(0)
        .repeat(B, 1)
    )
    circle_points = (
        center.unsqueeze(1)
        + radius.unsqueeze(1).unsqueeze(2)
        * torch.cos(theta).unsqueeze(2)
        * u.unsqueeze(1)
        + radius.unsqueeze(1).unsqueeze(2)
        * torch.sin(theta).unsqueeze(2)
        * v.unsqueeze(1)
    )

    return circle_points


def torch_bezier_curve(
    control_points: torch.Tensor, num_points: int = 100
) -> torch.Tensor:
    control_points = control_points.float()
    t = (torch.linspace(0, 1, num_points).unsqueeze(-1).unsqueeze(-1)).to(
        control_points.device
    )  # shape [num_points, 1, 1]
    B = (
        (1 - t) ** 3 * control_points[:, 0]
        + 3 * (1 - t) ** 2 * t * control_points[:, 1]
        + 3 * (1 - t) * t**2 * control_points[:, 2]
        + t**3 * control_points[:, 3]
    )
    # Transpose the first two dimensions to get the shape (batch_size, num_points, 3)
    B = B.transpose(0, 1)

    return B


def torch_line_points(
    start_points: torch.Tensor, end_points: torch.Tensor, num_points: int = 100
) -> torch.Tensor:
    weights = (
        torch.linspace(0, 1, num_points)
        .unsqueeze(0)
        .unsqueeze(-1)
        .to(start_points.device)
    )
    line_points = (1 - weights) * start_points.unsqueeze(
        1
    ) + weights * end_points.unsqueeze(1)
    return line_points


def fit_line(points: torch.Tensor, K: int = 100) -> torch.Tensor:
    """
    Fit a line to 3D points and sample K points along it.
    """
    assert points.ndim == 2 and points.shape[1] == 3, "Input must be [N, 3]"

    # Step 1: Center the points
    mean = points.mean(dim=0, keepdim=True)
    centered = points - mean

    # Step 2: SVD
    U, S, Vh = torch.linalg.svd(centered, full_matrices=False)
    direction = Vh[0]  # First principal component

    # Step 3: Project points onto the line to get min/max
    projections = torch.matmul(centered, direction)
    t_min, t_max = projections.min(), projections.max()

    # Step 4: Sample along the line
    t_vals = torch.linspace(t_min, t_max, K).to(points.device)
    fitted_points = mean + t_vals[:, None] * direction

    return fitted_points


def fit_cubic_bezier(points_3d: torch.Tensor) -> torch.Tensor:
    """
    Fit a cubic Bézier curve to 3D points while fixing the start and end points.

    Args:
        points_3d: (N, 3) Tensor of 3D arc points.

    Returns:
        bezier_pts: Tensor of 4 control points (P0, P1, P2, P3), shape (4, 3)
    """
    if not isinstance(points_3d, torch.Tensor):
        points_3d = torch.tensor(points_3d, dtype=torch.float32)

    n = len(points_3d)

    if n < 4:
        raise ValueError("At least 4 points are required to fit a cubic Bézier curve.")

    device = points_3d.device

    # Fixed start and end points
    P0 = points_3d[0]
    P3 = points_3d[-1]

    # Normalized parameter t
    t = torch.linspace(0, 1, n, device=device)

    # Bernstein basis functions for a cubic Bézier
    def bernstein(t):
        b0 = (1 - t) ** 3
        b1 = 3 * (1 - t) ** 2 * t
        b2 = 3 * (1 - t) * t**2
        b3 = t**3
        return torch.stack([b0, b1, b2, b3], dim=1)  # (n, 4)

    B = bernstein(t)

    # Initial guess for P1 and P2 (based on tangents)
    P1_init = P0 + (points_3d[1] - P0) * 1.5
    P2_init = P3 + (points_3d[-2] - P3) * 1.5

    # Optimization parameters - make them require gradients
    P1 = P1_init.clone().detach().requires_grad_(True)
    P2 = P2_init.clone().detach().requires_grad_(True)

    # Optimizer
    optimizer = torch.optim.LBFGS([P1, P2], max_iter=100, line_search_fn="strong_wolfe")

    def closure():
        optimizer.zero_grad()

        # Compute the Bézier curve
        curve = (
            B[:, 0].unsqueeze(1) * P0
            + B[:, 1].unsqueeze(1) * P1
            + B[:, 2].unsqueeze(1) * P2
            + B[:, 3].unsqueeze(1) * P3
        )

        # Compute loss (mean squared error)
        loss = torch.mean((curve - points_3d) ** 2)
        loss.backward()
        return loss

    # Optimize
    optimizer.step(closure)

    # Return the control points
    with torch.no_grad():
        return torch.stack([P0, P1, P2, P3])
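A hedged end-to-end sketch of the Bézier pair (synthetic points, illustrative only): fit control points to samples, then resample the curve densely.

import torch

t = torch.linspace(0, 1, 50)
pts = torch.stack([t, t**2, torch.zeros(50)], dim=1)           # (50, 3) parabola samples
ctrl = fit_cubic_bezier(pts)                                   # (4, 3) control points
curve = torch_bezier_curve(ctrl.unsqueeze(0), num_points=100)  # (1, 100, 3)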
pi3detr/utils/layer_utils.py
ADDED
@@ -0,0 +1,22 @@
from typing import Optional
import torch
from torch import nn


def load_weights(
    model: nn.Module, ckpt_path: str, dont_load: Optional[list[str]] = None
) -> None:
    dont_load = dont_load or []
    ckpt = torch.load(ckpt_path, weights_only=False)
    state_dict = {}
    for k, v in ckpt["state_dict"].items():
        if not any(dl in k for dl in dont_load):
            state_dict[k] = v
        else:
            print(f"Didn't load {k}")
    model.load_state_dict(state_dict, strict=False)
    print(f"Loaded checkpoint: {ckpt_path}")


def no_grad(model: nn.Module) -> None:
    for param in model.parameters():
        param.requires_grad = False
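A hedged usage sketch (the checkpoint path and the "head" substring are placeholders, not names from this repo):

model = PointNetPP(num_features=3)
load_weights(model, "some_checkpoint.ckpt", dont_load=["head"])  # skip matching keys
no_grad(model.encoder)  # freeze the encoder, e.g. for fine-tuning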
pi3detr/utils/postprocessing.py
ADDED
@@ -0,0 +1,543 @@
import numpy as np
import torch
from torch import Tensor
from torch_geometric.data import Data
from scipy.spatial import cKDTree  # faster than KDTree

from .curve_fitter import (
    fit_cubic_bezier,
    torch_bezier_curve,
    torch_circle_fitter,
    generate_points_on_circle_torch,
    torch_arc_points,
    fit_line,
)

def snap_and_fit_curves(
    data: Data,
) -> Data:
    """
    Snap polylines to nearest point cloud points and fit geometric curves based on predicted classes.

    This function performs two main operations:
    1. Snaps each polyline vertex to its nearest neighbor in the point cloud
    2. Fits the appropriate geometric curve (line, circle, arc, or B-spline) based on the predicted class

    Class mapping:
        0: Background (no processing, kept as-is)
        1: B-spline (cubic Bezier curve fitting)
        2: Line (linear regression fitting)
        3: Circle (3D circle fitting)
        4: Arc (circular arc through 3 points)

    Args:
        data (Data): PyTorch Geometric Data object containing:
            - pos (Tensor): Point cloud coordinates [P, 3]
            - polylines (Tensor): Raw polyline predictions [M, K, 3]
            - polyline_class (Tensor): Predicted classes for each polyline [M]

    Returns:
        Data: Cloned Data object with fitted polylines replacing the original polylines.
            All other attributes remain unchanged.

    Note:
        - Robust error handling: falls back to original polyline if fitting fails
        - Validates output shapes and numerical stability (NaN/Inf checks)
        - Requires minimum points per curve type (e.g., 3 for circles, 4 for B-splines)
    """
    point_cloud = data.pos
    polylines = data.polylines
    polyline_classes = data.polyline_class
    M, K, _ = polylines.shape
    snapped_and_fitted = torch.zeros_like(polylines)

    for i, cls in enumerate(polyline_classes):
        if cls == 0:
            snapped_and_fitted[i] = polylines[i]  # Keep original for class 0
            continue

        try:
            # Snap the polyline to the nearest point in the point cloud
            distances = torch.cdist(polylines[i], point_cloud)
            nearest_idx = distances.argmin(dim=1)
            nn_points = point_cloud[nearest_idx]

            # Safety check: ensure we have valid points
            if (
                len(nn_points) == 0
                or torch.any(torch.isnan(nn_points))
                or torch.any(torch.isinf(nn_points))
            ):
                snapped_and_fitted[i] = polylines[i]
                continue

            new_curve = None

            if cls == 1:  # BSpline
                try:
                    if len(nn_points) < 4:
                        # Not enough points for cubic Bezier, fallback to original
                        new_curve = polylines[i]
                    else:
                        bezier_pts = fit_cubic_bezier(nn_points)
                        new_curve = torch_bezier_curve(
                            bezier_pts.unsqueeze(0), K
                        ).squeeze(0)

                        # Validate output shape and values
                        if (
                            new_curve.shape != (K, 3)
                            or torch.any(torch.isnan(new_curve))
                            or torch.any(torch.isinf(new_curve))
                        ):
                            new_curve = polylines[i]

                except Exception:
                    new_curve = polylines[i]

            elif cls == 2:  # Line
                try:
                    if len(nn_points) < 2:
                        new_curve = polylines[i]
                    else:
                        new_curve = fit_line(nn_points, K)

                        # Validate output shape and values
                        if (
                            new_curve.shape != (K, 3)
                            or torch.any(torch.isnan(new_curve))
                            or torch.any(torch.isinf(new_curve))
                        ):
                            new_curve = polylines[i]
                except Exception:
                    new_curve = polylines[i]

            elif cls == 3:  # Circle
                try:
                    # Check if we have enough unique points for circle fitting
                    unique_points = torch.unique(nn_points, dim=0)
                    if len(unique_points) < 3:
                        new_curve = polylines[i]
                    else:
                        center, normal, radius = torch_circle_fitter(
                            nn_points.unsqueeze(0)
                        )

                        # Validate circle parameters
                        if (
                            torch.any(torch.isnan(center))
                            or torch.any(torch.isnan(normal))
                            or torch.any(torch.isnan(radius))
                            or torch.any(torch.isinf(center))
                            or torch.any(torch.isinf(normal))
                            or torch.any(torch.isinf(radius))
                            or radius <= 0
                        ):
                            new_curve = polylines[i]
                        else:
                            new_curve = generate_points_on_circle_torch(
                                center, normal, radius, K
                            ).squeeze(0)

                            # Validate output shape and values
                            if (
                                new_curve.shape != (K, 3)
                                or torch.any(torch.isnan(new_curve))
                                or torch.any(torch.isinf(new_curve))
                            ):
                                new_curve = polylines[i]
                except Exception:
                    new_curve = polylines[i]

            elif cls == 4:  # Arc
                try:
                    if len(nn_points) < 3:
                        new_curve = polylines[i]
                    else:
                        start_pt = nn_points[0].unsqueeze(0)
                        mid_pt = nn_points[len(nn_points) // 2].unsqueeze(0)
                        end_pt = nn_points[-1].unsqueeze(0)

                        new_curve = torch_arc_points(
                            start_pt, mid_pt, end_pt, K
                        ).squeeze(0)

                        # Validate output shape and values
                        if (
                            new_curve.shape != (K, 3)
                            or torch.any(torch.isnan(new_curve))
                            or torch.any(torch.isinf(new_curve))
                        ):
                            new_curve = polylines[i]
                except Exception:
                    new_curve = polylines[i]

            else:
                # Unknown class, keep original
                new_curve = polylines[i]

            # Final safety check
            if new_curve is not None and new_curve.shape == (K, 3):
                snapped_and_fitted[i] = new_curve
            else:
                snapped_and_fitted[i] = polylines[i]

        except Exception:
            # If anything goes wrong, fallback to original polyline
            snapped_and_fitted[i] = polylines[i]

    output = data.clone()
    output.polylines = snapped_and_fitted
    return output

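A minimal usage sketch for snap_and_fit_curves on synthetic tensors. The shapes follow the docstring above; the random data, sizes, and import path are illustrative, not taken from the repo's own demos:

import torch
from torch_geometric.data import Data

from pi3detr.utils.postprocessing import snap_and_fit_curves

# Illustrative shapes only: P cloud points, M polylines with K vertices each.
P, M, K = 2048, 4, 32
data = Data(
    pos=torch.rand(P, 3),
    polylines=torch.rand(M, K, 3),
    polyline_class=torch.tensor([0, 1, 2, 3]),  # one polyline per class
)
fitted = snap_and_fit_curves(data)
assert fitted.polylines.shape == (M, K, 3)  # same shape; curves refit in place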
def filter_predictions(pred_data: Data, thresholds: list[float]) -> Data:
    """
    Filter predictions based on class-specific confidence thresholds.

    Removes polylines whose confidence scores fall below the specified threshold
    for their predicted class. This is typically used as a post-processing step
    to remove low-confidence predictions before further analysis.

    Args:
        pred_data (Data): PyTorch Geometric Data object containing:
            - pos (Tensor): Point cloud coordinates [P, 3]
            - polyline_class (Tensor): Predicted classes [N]
            - polyline_score (Tensor): Confidence scores [N]
            - polylines (Tensor): Polyline coordinates [N, K, 3]
            - query_xyz (Tensor, optional): Query coordinates [N, 3]
        thresholds (list[float]): Confidence thresholds for each class.
            Length must match the number of classes.
            thresholds[i] is the minimum confidence for class i.

    Returns:
        Data: Filtered Data object containing only polylines that meet their
            class-specific confidence thresholds. Maintains the same structure
            as input but with potentially fewer polylines.

    Example:
        # Keep only polylines with confidence > 0.5 for class 0, > 0.7 for class 1, etc.
        filtered = filter_predictions(data, [0.5, 0.7, 0.6, 0.8])
    """
    mask = (
        pred_data.polyline_score
        >= torch.tensor(thresholds, device=pred_data.pos.device)[
            pred_data.polyline_class
        ]
    )
    filtered_data = Data(
        pos=pred_data.pos,
        polyline_class=pred_data.polyline_class[mask],
        polyline_score=pred_data.polyline_score[mask],
        polylines=pred_data.polylines[mask],
        query_xyz=(
            pred_data.query_xyz[mask] if hasattr(pred_data, "query_xyz") else None
        ),
    )

    return filtered_data

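The mask above hinges on a single gather, torch.tensor(thresholds)[classes], which looks up each polyline's class-specific threshold in one indexing step. A small self-contained demonstration with illustrative values:

import torch

thresholds = torch.tensor([0.5, 0.7, 0.6, 0.8])
classes = torch.tensor([1, 3, 0, 1])           # predicted class per polyline
scores = torch.tensor([0.75, 0.65, 0.9, 0.6])
mask = scores >= thresholds[classes]           # per-polyline threshold lookup
print(mask)                                    # tensor([ True, False,  True, False])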
def iou_filter_point_based(
    pred_data,
    iou_threshold: float = 0.6,
    background_class: int = 0,
):
    """
    Efficient per-class Non-Maximum Suppression using IoU computed on point cloud indices.

    This optimized NMS implementation:
    1. Snaps all polyline vertices to nearest point cloud neighbors
    2. Computes IoU based on overlapping point cloud indices (not 3D distances)
    3. Applies greedy NMS within each class, keeping highest-scoring polylines
    4. Uses optimized data structures (cKDTree, sorted arrays) for speed

    Algorithm details:
    - Single batched nearest neighbor query for all valid vertices
    - IoU = |intersection| / |union| of snapped point indices
    - Polylines ordered by: score (desc) → #snapped_points (desc) → index (asc)
    - Background class polylines are never removed
    - Polylines with no valid snapped points are dropped

    Args:
        pred_data (Data): PyTorch Geometric Data object with polyline predictions
        iou_threshold (float, optional): IoU threshold for suppression. Default: 0.6
        background_class (int, optional): Class ID to exclude from NMS. Default: 0

    Returns:
        Data: Filtered Data object with overlapping polylines removed per class.
            Maintains same structure with potentially fewer polylines.

    Performance:
        Significantly faster than distance-based methods due to:
        - Batched spatial queries (cKDTree)
        - Integer set operations (np.intersect1d)
        - Minimal Python loops
    """
    data = pred_data.clone()

    polylines: torch.Tensor = data.polylines  # (N, M, 3)
    classes: torch.Tensor = data.polyline_class  # (N,)
    pc: torch.Tensor = data.pos  # (P, 3)
    scores = getattr(data, "polyline_score", None)

    device = polylines.device
    N = polylines.shape[0]
    if N == 0 or pc.shape[0] == 0:
        return data

    # ---- helpers ----
    def valid_mask(pts_t: torch.Tensor) -> torch.Tensor:
        finite = torch.isfinite(pts_t).all(dim=-1)
        non_zero = pts_t.abs().sum(dim=-1) > 0
        return finite & non_zero

    # ---- gather all valid vertices once (batched) ----
    # We'll collect (poly_idx, vertex_xyz) over all non-background curves.
    poly_indices_list = []
    all_vertices = []

    bg = int(background_class)
    for i in range(N):
        if int(classes[i].item()) == bg:
            continue
        vm = valid_mask(polylines[i])
        if vm.any():
            pts = polylines[i][vm].detach().cpu().numpy()
            if pts.size > 0:
                all_vertices.append(pts)
                poly_indices_list.append(np.full((pts.shape[0],), i, dtype=np.int32))

    if len(all_vertices) == 0:
        # nothing to snap; everything gets dropped
        keep_mask = torch.zeros(N, dtype=torch.bool, device=device)
        data.polylines = data.polylines[keep_mask]
        data.polyline_class = data.polyline_class[keep_mask]
        if hasattr(data, "polyline_score") and data.polyline_score is not None:
            data.polyline_score = data.polyline_score[keep_mask]
        if hasattr(data, "query_xyz") and data.query_xyz is not None:
            data.query_xyz = data.query_xyz[keep_mask]
        return data

    all_vertices = np.concatenate(all_vertices, axis=0)  # (T, 3)
    owner_poly = np.concatenate(poly_indices_list, axis=0)  # (T,)

    # ---- one cKDTree query for all vertices ----
    pc_np = pc.detach().cpu().numpy()
    tree = cKDTree(pc_np)
    # Use parallel workers if SciPy supports it (falls back silently otherwise)
    nn_idx = tree.query(all_vertices, workers=-1)[1].astype(np.int64)  # (T,)

    # ---- split back to per-curve snapped unique index arrays (sorted) ----
    snapped_arrays = [None] * N
    set_sizes = torch.zeros(N, dtype=torch.long, device=device)

    # group indices by polyline using numpy argsort
    order = np.argsort(owner_poly, kind="mergesort")
    owner_sorted = owner_poly[order]
    nn_sorted = nn_idx[order]

    # find segment starts for each unique polyline id
    unique_ids, starts = np.unique(owner_sorted, return_index=True)
    # append end sentinel
    starts = np.append(starts, owner_sorted.shape[0])

    for k in range(len(unique_ids)):
        i = int(unique_ids[k])
        seg = nn_sorted[starts[k] : starts[k + 1]]
        if seg.size == 0:
            snapped_arrays[i] = np.empty((0,), dtype=np.int64)
            continue
        uniq = np.unique(seg)  # already sorted
        snapped_arrays[i] = uniq
        set_sizes[i] = uniq.size

    # For background curves or curves with no valid vertices, ensure empty arrays
    for i in range(N):
        if snapped_arrays[i] is None:
            snapped_arrays[i] = np.empty((0,), dtype=np.int64)

    # fallback scores: prefer more snapped support
    if scores is None:
        scores = set_sizes.to(torch.float)

    keep_mask = torch.ones(N, dtype=torch.bool, device=device)

    # ---- per-class greedy NMS (IoU via fast array intersection) ----
    target_classes = torch.unique(classes[classes != background_class]).tolist()
    for cls in target_classes:
        cls_inds = torch.where(classes == cls)[0].tolist()
        if not cls_inds:
            continue

        # order by (score desc, size desc, index asc)
        cls_order = sorted(
            cls_inds,
            key=lambda idx: (
                -float(scores[idx].item()),
                -int(set_sizes[idx].item()),
                idx,
            ),
        )

        suppressed = set()
        for pos, i_idx in enumerate(cls_order):
            if i_idx in suppressed:
                continue

            A = snapped_arrays[i_idx]
            if A.size == 0:
                suppressed.add(i_idx)
                continue

            lenA = A.size
            # Only candidates ranked below the current keeper are eligible for
            # suppression. (Comparing raw indices, `j_idx <= i_idx`, would let a
            # lower-scoring curve with a smaller index suppress a higher-scoring
            # one whenever score order disagrees with index order.)
            for j_idx in cls_order[pos + 1 :]:
                if j_idx in suppressed:
                    continue
                B = snapped_arrays[j_idx]
                if B.size == 0:
                    suppressed.add(j_idx)
                    continue

                # fast intersection of two sorted int arrays
                inter = np.intersect1d(A, B, assume_unique=True).size
                union = lenA + B.size - inter
                if union == 0:
                    continue
                if (inter / union) > iou_threshold:
                    suppressed.add(j_idx)

        if suppressed:
            keep_mask[list(suppressed)] = False

    # ---- filter aligned fields ----
    data.polylines = data.polylines[keep_mask]
    data.polyline_class = data.polyline_class[keep_mask]
    if hasattr(data, "polyline_score") and data.polyline_score is not None:
        data.polyline_score = data.polyline_score[keep_mask]
    if hasattr(data, "query_xyz") and data.query_xyz is not None:
        data.query_xyz = data.query_xyz[keep_mask]

    return data

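The suppression criterion above reduces to an IoU of two sorted integer index sets, which np.intersect1d computes cheaply. A toy demonstration with illustrative values:

import numpy as np

A = np.unique(np.array([3, 7, 7, 12, 20]))  # snapped indices of curve A -> [3, 7, 12, 20]
B = np.unique(np.array([7, 12, 21]))        # snapped indices of curve B
inter = np.intersect1d(A, B, assume_unique=True).size
iou = inter / (A.size + B.size - inter)
print(iou)                                  # 2 / (4 + 3 - 2) = 0.4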
def iou_filter_predictions(
    data: Data,
    iou_threshold: float = 0.6,
    tol: float = 1e-2,
) -> Data:
    """
    Remove overlapping polylines within each class using point-to-point distance IoU.

    Performs class-wise Non-Maximum Suppression to eliminate redundant predictions:
    1. Filters out invalid points (NaN, Inf, near-zero)
    2. Computes pairwise point distances between polylines of the same class
    3. Calculates IoU based on points within distance tolerance
    4. Removes lower-scoring polylines when IoU exceeds threshold
    5. Protects "lonely" polylines (minimal overlap) from removal

    IoU Calculation:
        - overlap_i = number of points in polyline_i within tolerance of polyline_j
        - overlap_j = number of points in polyline_j within tolerance of polyline_i
        - intersection = min(overlap_i, overlap_j)
        - union = len(polyline_i) + len(polyline_j) - intersection
        - IoU = intersection / union

    Args:
        data (Data): PyTorch Geometric Data object containing:
            - polylines (Tensor): Polyline coordinates [N, P, 3]
            - polyline_class (Tensor): Class predictions [N]
            - polyline_score (Tensor): Confidence scores [N]
            - query_xyz (Tensor, optional): Query coordinates [N, 3]
        iou_threshold (float, optional): IoU threshold for duplicate removal. Default: 0.6
        tol (float, optional): Distance tolerance for point overlap detection. Default: 1e-2

    Returns:
        Data: Filtered Data object with overlapping polylines removed.
            Background class (0) polylines are never removed.

    Note:
        - Processes polylines in descending score order for deterministic results
        - Requires significant overlap (≥2 points, ≥10% of smaller polyline) before considering removal
        - More computationally expensive than index-based methods but handles arbitrary point clouds
    """
    polylines = data.polylines
    polyline_class = data.polyline_class
    scores = data.polyline_score

    # Precompute valid points for all polylines
    valid_pts = []
    for poly in polylines:
        mask = ~torch.isnan(poly).any(dim=1) & (
            ~torch.isinf(poly).any(dim=1) & (torch.norm(poly, dim=1) > 1e-6)
        )
        valid_pts.append(poly[mask])

    remove_set = set()

    # Process each class independently
    for cls in torch.unique(polyline_class):
        if cls == 0:  # Skip background
            continue
        # Get indices for this class
        class_mask = polyline_class == cls
        class_indices = torch.where(class_mask)[0]
        if len(class_indices) < 2:
            continue

        # Sort by score descending, then index ascending for determinism
        sorted_indices = sorted(
            class_indices.tolist(), key=lambda idx: (-scores[idx].item(), idx)
        )

        # Compare pairs in sorted order
        for i, idx_i in enumerate(sorted_indices):
            if idx_i in remove_set:
                continue
            pts_i = valid_pts[idx_i]
            if len(pts_i) == 0:
                continue
            for j in range(i + 1, len(sorted_indices)):
                idx_j = sorted_indices[j]
                if idx_j in remove_set:
                    continue
                pts_j = valid_pts[idx_j]
                if len(pts_j) == 0:
                    continue

                # Compute point-wise distances
                dists = torch.cdist(pts_i, pts_j)
                # Calculate overlaps
                overlap_i = (dists.min(dim=1).values < tol).sum().item()
                overlap_j = (dists.min(dim=0).values < tol).sum().item()
                min_points = min(len(pts_i), len(pts_j))
                # Skip if not significant overlap
                if (
                    overlap_i < 2
                    or overlap_j < 2
                    or min(overlap_i, overlap_j) < 0.1 * min_points
                ):
                    continue

                # Calculate IoU
                intersection = min(overlap_i, overlap_j)
                union = len(pts_i) + len(pts_j) - intersection
                iou = intersection / union if union > 0 else 0.0
                if iou > iou_threshold:
                    # Always remove lower-scoring polyline
                    remove_set.add(idx_j)

    # Create keep mask (protects lonely lines)
    keep_mask = torch.ones(len(polylines), dtype=torch.bool)
    for idx in remove_set:
        keep_mask[idx] = False

    # Apply filtering
    data.polylines = polylines[keep_mask]
    data.polyline_class = polyline_class[keep_mask]
    data.polyline_score = scores[keep_mask]
    if hasattr(data, "query_xyz"):
        data.query_xyz = data.query_xyz[keep_mask]

    return data
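Taken together, the three utilities form a natural post-processing chain. The sketch below is illustrative only: the ordering, threshold values, and the postprocess helper are assumptions, not taken from app.py or anywhere else in this commit.

from pi3detr.utils.postprocessing import (
    filter_predictions,
    iou_filter_point_based,
    snap_and_fit_curves,
)

def postprocess(pred_data, thresholds=(0.0, 0.5, 0.5, 0.5, 0.5)):
    # Threshold per class (index = class id, matching the 5-class mapping above),
    # then suppress duplicate detections, then fit analytic curves to survivors.
    out = filter_predictions(pred_data, list(thresholds))
    out = iou_filter_point_based(out, iou_threshold=0.6)
    return snap_and_fit_curves(out)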
pi3detr/utils/viz.py
ADDED
@@ -0,0 +1,12 @@
def hex_to_rgb(hex_color: str) -> tuple:
    """
    Convert hex color string to RGB tuple.

    Args:
        hex_color: Hex color string (e.g., "#FF5733")

    Returns:
        Tuple of RGB values (0-1 range)
    """
    hex_color = hex_color.lstrip("#")
    return tuple(int(hex_color[i : i + 2], 16) / 255.0 for i in (0, 2, 4))
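Example usage (the input color is illustrative; the import path follows the file location above):

from pi3detr.utils.viz import hex_to_rgb

# Each two-digit hex pair is parsed base-16 and normalized to [0, 1].
print(hex_to_rgb("#FF5733"))  # (1.0, 0.3411764705882353, 0.2)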
requirements.txt
ADDED
@@ -0,0 +1,46 @@
# ==== Repos for CUDA 12.1 PyTorch + PyG wheels ====
--index-url https://download.pytorch.org/whl/cu121
--extra-index-url https://pypi.org/simple
--find-links https://data.pyg.org/whl/torch-2.5.1%2Bcu121.html

# ==== Core: PyTorch CUDA 12.1 ====
torch==2.5.1+cu121

# ==== PyTorch Geometric stack (built for torch 2.5 + cu121) ====
pyg-lib==0.4.0+pt25cu121
torch-scatter==2.1.2+pt25cu121
torch-sparse==0.6.18+pt25cu121
torch-cluster==1.6.3+pt25cu121
torch-spline-conv==1.2.2+pt25cu121
torch-geometric==2.6.1

# ==== Vision / geometry extras ====
kornia==0.8.0
opencv-python==4.11.0.86
open3d==0.19.0
polyscope==2.3.0
trimesh==4.6.8
timm==1.0.14
spconv-cu121==2.3.8
fpsample==0.3.3

# ==== SciPy stack ====
# Use NumPy >=2.0 for SciPy 1.15.x; leave upper bound open for compatibility with CUDA wheels.
numpy>=2.0
scipy==1.15.2
scikit-learn==1.6.1
matplotlib==3.10.0

# ==== Jupyter / developer tools ====
ipython>=8.20
ipykernel==6.29.5
ipywidgets==8.1.5
black==25.1.0
tqdm>=4.66
tensorboard==2.19.0
pytorch-lightning==2.5.0.post0

# ==== Hugging Face Space ====
gradio==5.49.1
plotly==6.3.1
plyfile==1.1.2