fafraob commited on
Commit
925407b
Β·
1 Parent(s): 9948732

add links & remove post-processing selection

Browse files
Files changed (1) hide show
  1. app.py +53 -43
app.py CHANGED
@@ -53,9 +53,7 @@ DEMO_POINTCLOUDS = {
53
  }
54
 
55
 
56
- def initialize_model(
57
- checkpoint_path="model.ckpt", config_path="configs/pi3detr.yaml"
58
- ):
59
  """Initialize the model at startup and store it in the global cache."""
60
  global PI3DETR_MODEL, MODEL_STATUS
61
  try:
@@ -255,7 +253,7 @@ def make_figure(
255
  y=points[:, 1],
256
  z=points[:, 2],
257
  mode="lines",
258
- line=dict(color=color, width=5),
259
  name=f"{curve_type} #{curve_id} ({score:.2f})",
260
  visible=curve.get("visible_state", True),
261
  hoverinfo="text",
@@ -399,11 +397,9 @@ def run_model_inference(
399
  points: np.ndarray,
400
  max_points: int = 32768,
401
  sample_mode: str = "fps",
402
- num_queries: int = 256, # NEW
403
- snap_and_fit: bool = False, # NEW
404
- iou_filter: bool = False, # NEW
405
  ) -> list:
406
- """Run model inference on the given point cloud (extended with num_queries, snap_and_fit, iou_filter)."""
407
  global PI3DETR_MODEL
408
  if model is None:
409
  model = PI3DETR_MODEL
@@ -423,8 +419,6 @@ def run_model_inference(
423
  data,
424
  reverse_norm=True,
425
  thresholds=None,
426
- snap_and_fit=snap_and_fit, # CHANGED
427
- iou_filter=iou_filter, # CHANGED
428
  )
429
  result = output[0]
430
  curves = process_model_predictions(result)
@@ -509,9 +503,7 @@ def run_model_prediction(
509
  th_circle: float,
510
  th_arc: float,
511
  num_queries: int = 256,
512
- snap_and_fit: bool = False,
513
- iou_filter: bool = False,
514
- ): # CHANGED: removed reduction
515
  # NOTE: display points now handled outside; keep signature (called before adding display pts state)
516
  # (This wrapper kept for backwards compatibility if needed – we adapt below in new unified version)
517
  return run_model_prediction_unified( # type: ignore
@@ -527,8 +519,6 @@ def run_model_prediction(
527
  th_arc,
528
  "",
529
  num_queries,
530
- snap_and_fit,
531
- iou_filter,
532
  )
533
 
534
 
@@ -545,8 +535,6 @@ def run_model_prediction_unified(
545
  th_arc: float,
546
  file_name: str = "",
547
  num_queries: int = 256,
548
- snap_and_fit: bool = False,
549
- iou_filter: bool = False,
550
  ):
551
  """
552
  Run model inference and apply initial threshold-based coloring.
@@ -570,9 +558,7 @@ def run_model_prediction_unified(
570
  model_pts,
571
  max_points=model_max_points,
572
  sample_mode=sample_mode,
573
- num_queries=num_queries, # NEW
574
- snap_and_fit=snap_and_fit, # NEW
575
- iou_filter=iou_filter, # NEW
576
  )
577
  except Exception:
578
  pass
@@ -788,9 +774,7 @@ def run_model_with_display(
788
  th_arc: float,
789
  file_name: str = "",
790
  num_queries: int = 256,
791
- snap_and_fit: bool = False,
792
- iou_filter: bool = False,
793
- ): # CHANGED: removed reduction
794
  """
795
  Run inference (if model_pts present) then immediately apply current display
796
  (max_points/point_size/show_axes) and thresholds. Returns:
@@ -813,10 +797,8 @@ def run_model_with_display(
813
  th_circle,
814
  th_arc,
815
  file_name,
816
- num_queries, # NEW
817
- snap_and_fit, # NEW
818
- iou_filter, # NEW
819
- ) # CHANGED: removed reduction
820
 
821
  # Now apply current display settings & thresholds without re-inference
822
  fig_final, display_pts = apply_pointcloud_display_settings(
@@ -836,9 +818,43 @@ def run_model_with_display(
836
 
837
  with gr.Blocks(title="PI3DETR") as demo:
838
  gr.Markdown(
839
- "# πŸ₯§ PI3DETR: Detection of Sharp 3D CAD Edges [CPU-PREVIEW]\n"
840
- "An end-to-end deep learning model for **parametric curve inference** in **3D point clouds** and **meshes**.\n"
841
- "Upload a `.xyz`, `.ply`, or `.obj` file to explore curve detection."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
842
  )
843
 
844
  with gr.Row():
@@ -857,7 +873,7 @@ with gr.Blocks(title="PI3DETR") as demo:
857
  with gr.Column():
858
  gr.Markdown(
859
  "### 🎯 Confidence Thresholds\n"
860
- "- Hover to inspect scores\n."
861
  "- Filter curves by **class confidence** interactively"
862
  )
863
  with gr.Row():
@@ -866,15 +882,14 @@ with gr.Blocks(title="PI3DETR") as demo:
866
  "### 🧠 Model Settings\n"
867
  "- **Sampling Mode:** Choose downsampling strategy.\n"
868
  "- **Model Input Size:** Number of model input points.\n"
869
- "- **Queries:** Transformer decoder queries (max. output curves).\n"
870
- "- Optional: *Snap&Fit* / *IOU-Filter* post-processing."
871
  )
872
  with gr.Column():
873
  gr.Markdown(
874
  "### ⚑ Performance Notes\n"
875
  "- Trained on **human-made objects**.\n"
876
  "- Optimized for **GPU**; this demo runs on **CPU**.\n"
877
- "- For full qualitative performance: \n"
878
  "[GitHub β†’ PI3DETR](https://github.com/fafraob/pi3detr)"
879
  )
880
  with gr.Column():
@@ -927,16 +942,13 @@ with gr.Blocks(title="PI3DETR") as demo:
927
  step=1,
928
  label="Number of Queries",
929
  )
930
- with gr.Row():
931
- snap_and_fit_chk = gr.Checkbox(value=True, label="Snap&Fit")
932
- iou_filter_chk = gr.Checkbox(value=False, label="IOU-Filter")
933
 
934
  # Threshold sliders (no auto-change triggers)
935
  gr.Markdown("#### Confidence Thresholds (per class)")
936
- th_bspline = gr.Slider(0.0, 1.0, value=0.7, step=0.01, label="BSpline β‰₯")
937
- th_line = gr.Slider(0.0, 1.0, value=0.7, step=0.01, label="Line β‰₯")
938
- th_circle = gr.Slider(0.0, 1.0, value=0.7, step=0.01, label="Circle β‰₯")
939
- th_arc = gr.Slider(0.0, 1.0, value=0.7, step=0.01, label="Arc β‰₯")
940
 
941
  with gr.Column(scale=1):
942
  main_plot = gr.Plot(
@@ -973,8 +985,6 @@ with gr.Blocks(title="PI3DETR") as demo:
973
  th_arc,
974
  file_name_state,
975
  num_queries,
976
- snap_and_fit_chk,
977
- iou_filter_chk,
978
  ],
979
  outputs=[main_plot, curves_state, display_pts_state],
980
  )
 
53
  }
54
 
55
 
56
+ def initialize_model(checkpoint_path="model.ckpt", config_path="configs/pi3detr.yaml"):
 
 
57
  """Initialize the model at startup and store it in the global cache."""
58
  global PI3DETR_MODEL, MODEL_STATUS
59
  try:
 
253
  y=points[:, 1],
254
  z=points[:, 2],
255
  mode="lines",
256
+ line=dict(color=color, width=8), # CHANGED: increased from 5 to 8
257
  name=f"{curve_type} #{curve_id} ({score:.2f})",
258
  visible=curve.get("visible_state", True),
259
  hoverinfo="text",
 
397
  points: np.ndarray,
398
  max_points: int = 32768,
399
  sample_mode: str = "fps",
400
+ num_queries: int = 256,
 
 
401
  ) -> list:
402
+ """Run model inference on the given point cloud."""
403
  global PI3DETR_MODEL
404
  if model is None:
405
  model = PI3DETR_MODEL
 
419
  data,
420
  reverse_norm=True,
421
  thresholds=None,
 
 
422
  )
423
  result = output[0]
424
  curves = process_model_predictions(result)
 
503
  th_circle: float,
504
  th_arc: float,
505
  num_queries: int = 256,
506
+ ):
 
 
507
  # NOTE: display points now handled outside; keep signature (called before adding display pts state)
508
  # (This wrapper kept for backwards compatibility if needed – we adapt below in new unified version)
509
  return run_model_prediction_unified( # type: ignore
 
519
  th_arc,
520
  "",
521
  num_queries,
 
 
522
  )
523
 
524
 
 
535
  th_arc: float,
536
  file_name: str = "",
537
  num_queries: int = 256,
 
 
538
  ):
539
  """
540
  Run model inference and apply initial threshold-based coloring.
 
558
  model_pts,
559
  max_points=model_max_points,
560
  sample_mode=sample_mode,
561
+ num_queries=num_queries,
 
 
562
  )
563
  except Exception:
564
  pass
 
774
  th_arc: float,
775
  file_name: str = "",
776
  num_queries: int = 256,
777
+ ):
 
 
778
  """
779
  Run inference (if model_pts present) then immediately apply current display
780
  (max_points/point_size/show_axes) and thresholds. Returns:
 
797
  th_circle,
798
  th_arc,
799
  file_name,
800
+ num_queries,
801
+ )
 
 
802
 
803
  # Now apply current display settings & thresholds without re-inference
804
  fig_final, display_pts = apply_pointcloud_display_settings(
 
818
 
819
  with gr.Blocks(title="PI3DETR") as demo:
820
  gr.Markdown(
821
+ """
822
+ # πŸ₯§ PI3DETR: Detection of Sharp 3D CAD Edges [CPU-PREVIEW]
823
+
824
+ A novel end-to-end deep learning model for **parametric curve inference** in **3D point clouds** and **meshes**.
825
+
826
+ <div style="margin-top: 10px;">
827
+ <a href="https://arxiv.org/pdf/2509.03262" target="_blank" style="
828
+ display: inline-block;
829
+ background-color: #4CAF50;
830
+ color: white;
831
+ padding: 8px 16px;
832
+ text-decoration: none;
833
+ border-radius: 5px;
834
+ margin-right: 8px;
835
+ font-weight: bold;
836
+ ">πŸ“„ Paper</a>
837
+ <a href="https://fafraob.github.io/pi3detr/" target="_blank" style="
838
+ display: inline-block;
839
+ background-color: #2196F3;
840
+ color: white;
841
+ padding: 8px 16px;
842
+ text-decoration: none;
843
+ border-radius: 5px;
844
+ margin-right: 8px;
845
+ font-weight: bold;
846
+ ">🌐 Website</a>
847
+ <a href="https://github.com/fafraob/pi3detr" target="_blank" style="
848
+ display: inline-block;
849
+ background-color: #333;
850
+ color: white;
851
+ padding: 8px 16px;
852
+ text-decoration: none;
853
+ border-radius: 5px;
854
+ font-weight: bold;
855
+ ">πŸ™ GitHub</a>
856
+ </div>
857
+ """
858
  )
859
 
860
  with gr.Row():
 
873
  with gr.Column():
874
  gr.Markdown(
875
  "### 🎯 Confidence Thresholds\n"
876
+ "- Hover to inspect scores.\n"
877
  "- Filter curves by **class confidence** interactively"
878
  )
879
  with gr.Row():
 
882
  "### 🧠 Model Settings\n"
883
  "- **Sampling Mode:** Choose downsampling strategy.\n"
884
  "- **Model Input Size:** Number of model input points.\n"
885
+ "- **Queries:** Transformer decoder queries (max. output curves)."
 
886
  )
887
  with gr.Column():
888
  gr.Markdown(
889
  "### ⚑ Performance Notes\n"
890
  "- Trained on **human-made objects**.\n"
891
  "- Optimized for **GPU**; this demo runs on **CPU**.\n"
892
+ "- For **full qualitative performance**: \n"
893
  "[GitHub β†’ PI3DETR](https://github.com/fafraob/pi3detr)"
894
  )
895
  with gr.Column():
 
942
  step=1,
943
  label="Number of Queries",
944
  )
 
 
 
945
 
946
  # Threshold sliders (no auto-change triggers)
947
  gr.Markdown("#### Confidence Thresholds (per class)")
948
+ th_bspline = gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="BSpline β‰₯")
949
+ th_line = gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="Line β‰₯")
950
+ th_circle = gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="Circle β‰₯")
951
+ th_arc = gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="Arc β‰₯")
952
 
953
  with gr.Column(scale=1):
954
  main_plot = gr.Plot(
 
985
  th_arc,
986
  file_name_state,
987
  num_queries,
 
 
988
  ],
989
  outputs=[main_plot, curves_state, display_pts_state],
990
  )