Yanlin Zhang committed on
Commit ccd869c · 1 Parent(s): a5e9f68

use sam3 pipeline

Files changed (1):
  1. app.py (+113, -68)
app.py CHANGED
@@ -20,7 +20,7 @@ import gradio as gr
 import numpy as np
 from PIL import Image
 import torch
-from transformers import AutoImageProcessor, AutoModel
+from transformers import pipeline

 # -----------------------------------------------------------------------------
 # Configuration
@@ -39,15 +39,18 @@ CLASS_COLORS: Dict[str, Tuple[int, int, int]] = {
 }

 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

 # -----------------------------------------------------------------------------
 # Model + processor
 # -----------------------------------------------------------------------------

-processor = AutoImageProcessor.from_pretrained(MODEL_ID)
-model = AutoModel.from_pretrained(MODEL_ID, torch_dtype=DTYPE).to(DEVICE)
-model.eval()
+# Use the mask-generation pipeline, as shown in the Hugging Face guidance,
+# then extract the model and processor for text-prompt support.
+mask_pipe = pipeline("mask-generation", model=MODEL_ID, device=0 if DEVICE == "cuda" else -1)
+
+# Prefer the image processor; `feature_extractor` exists on pipelines but may be None.
+model = mask_pipe.model
+processor = mask_pipe.image_processor if mask_pipe.image_processor is not None else mask_pipe.feature_extractor


 # -----------------------------------------------------------------------------
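Note on the extraction above: transformers pipelines expose their loaded components as attributes (`model`, `image_processor`, and the legacy `feature_extractor`), which is what makes this reuse possible. A minimal inspection sketch, assuming only that `MODEL_ID` loads under the mask-generation task:

    from transformers import pipeline

    pipe = pipeline("mask-generation", model=MODEL_ID)  # MODEL_ID as defined in app.py
    print(type(pipe.model).__name__)            # the underlying segmentation model
    print(type(pipe.image_processor).__name__)  # the image (pre)processor
    # Whether the processor accepts a `text` kwarg is model-specific; plain SAM
    # processors do not, which is why _extract_detections below wraps the direct
    # call in a try/except and falls back to the pipeline.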
@@ -63,77 +66,119 @@ class Track:
     score: float | None


-def _post_process(outputs, height: int, width: int):
-    target_sizes = [(height, width)]
-
-    if hasattr(processor, "post_process_instance_segmentation"):
-        return processor.post_process_instance_segmentation(
-            outputs=outputs,
-            target_sizes=target_sizes,
-            threshold=0.35,
-            mask_threshold=0.4,
-            overlap_mask_area_threshold=0.5,
-        )[0]
-
-    if hasattr(processor, "post_process_semantic_segmentation"):
-        segmentation = processor.post_process_semantic_segmentation(
-            outputs=outputs,
-            target_sizes=target_sizes,
-        )[0]
-        return {
-            "masks": segmentation.unsqueeze(0),
-            "scores": torch.ones(1),
-            "labels": torch.zeros(1, dtype=torch.int64),
-        }
-
-    raise gr.Error(
-        "This version of transformers does not expose SAM3 post-processing helpers. "
-        "Please ensure transformers>=4.46.0 is installed."
-    )
-
-
 def _extract_detections(frame_rgb: np.ndarray) -> List[Dict]:
     pil_image = Image.fromarray(frame_rgb)
     detections: List[Dict] = []

     for label in TEXT_PROMPTS:
-        inputs = processor(images=pil_image, text=label, return_tensors="pt")
-        inputs = {
-            k: (v.to(DEVICE) if isinstance(v, torch.Tensor) else v)
-            for k, v in inputs.items()
-        }
-
-        with torch.inference_mode():
-            outputs = model(**inputs)
-
-        processed = _post_process(outputs, pil_image.height, pil_image.width)
-        masks = processed.get("masks", [])
-        scores = processed.get("scores", [None] * len(masks))
-
-        for mask_tensor, score in zip(masks, scores):
-            mask_np = mask_tensor.squeeze().detach().cpu().numpy()
-            if mask_np.ndim == 3:
-                mask_np = mask_np[0]
-
-            binary_mask = mask_np > 0.5
-            area = int(binary_mask.sum())
-            if area < MIN_MASK_PIXELS:
-                continue
-
-            ys, xs = np.nonzero(binary_mask)
-            if len(xs) == 0:
-                continue
-
-            centroid = (float(xs.mean()), float(ys.mean()))
-            detections.append(
-                {
-                    "label": label,
-                    "mask": binary_mask,
-                    "score": float(score) if score is not None else None,
-                    "centroid": centroid,
-                    "area": area,
-                }
-            )
+        # Use the processor and model directly with a text prompt.
+        try:
+            inputs = processor(images=pil_image, text=label, return_tensors="pt")
+            inputs = {
+                k: (v.to(DEVICE) if isinstance(v, torch.Tensor) else v)
+                for k, v in inputs.items()
+            }
+
+            with torch.inference_mode():
+                outputs = model(**inputs)
+
+            # Extract masks from the outputs; the SAM3 output structure may vary.
+            if hasattr(outputs, "pred_masks"):
+                masks = outputs.pred_masks
+            elif hasattr(outputs, "masks"):
+                masks = outputs.masks
+            elif isinstance(outputs, dict):
+                masks = outputs.get("pred_masks", outputs.get("masks"))
+            else:
+                masks = outputs
+
+            if masks is None:
+                continue
+
+            # Handle different mask formats.
+            if isinstance(masks, torch.Tensor):
+                if masks.ndim == 4:  # [batch, num_masks, H, W]
+                    masks = masks[0]  # Remove the batch dimension.
+                elif masks.ndim == 3:  # [num_masks, H, W]
+                    pass
+                else:
+                    continue
+
+            for mask_tensor in masks:
+                mask_np = mask_tensor.squeeze().detach().cpu().numpy()
+                if mask_np.ndim == 3:
+                    mask_np = mask_np[0]
+
+                binary_mask = mask_np > 0.5
+                area = int(binary_mask.sum())
+                if area < MIN_MASK_PIXELS:
+                    continue
+
+                ys, xs = np.nonzero(binary_mask)
+                if len(xs) == 0:
+                    continue
+
+                centroid = (float(xs.mean()), float(ys.mean()))
+                detections.append(
+                    {
+                        "label": label,
+                        "mask": binary_mask,
+                        "score": None,
+                        "centroid": centroid,
+                        "area": area,
+                    }
+                )
+        except Exception:
+            # Fall back to the pipeline if direct access fails.
+            try:
+                results = mask_pipe(pil_image)
+                if not isinstance(results, list):
+                    results = [results]
+
+                for result in results:
+                    if isinstance(result, dict):
+                        mask = result.get("mask")
+                        score = result.get("score")
+                    else:
+                        mask = result
+                        score = None
+
+                    if isinstance(mask, Image.Image):
+                        mask_np = np.array(mask.convert("L"))
+                    elif isinstance(mask, torch.Tensor):
+                        mask_np = mask.squeeze().detach().cpu().numpy()
+                    elif isinstance(mask, np.ndarray):
+                        mask_np = mask
+                    else:
+                        continue
+
+                    if mask_np.ndim == 3:
+                        mask_np = mask_np[:, :, 0] if mask_np.shape[2] == 1 else mask_np.max(axis=2)
+
+                    if mask_np.max() > 1.0:
+                        mask_np = mask_np / 255.0
+
+                    binary_mask = mask_np > 0.5
+                    area = int(binary_mask.sum())
+                    if area < MIN_MASK_PIXELS:
+                        continue
+
+                    ys, xs = np.nonzero(binary_mask)
+                    if len(xs) == 0:
+                        continue
+
+                    centroid = (float(xs.mean()), float(ys.mean()))
+                    detections.append(
+                        {
+                            "label": label,
+                            "mask": binary_mask,
+                            "score": float(score) if score is not None else None,
+                            "centroid": centroid,
+                            "area": area,
+                        }
+                    )
+            except Exception as e2:
+                raise gr.Error(f"Both direct model access and pipeline failed: {e2}")

     return detections
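One caveat on the fallback branch: for SAM-family checkpoints, the mask-generation pipeline typically returns a single dict of parallel lists rather than one dict per mask, in which case the per-result `mask`/`score` lookups above would silently skip everything. A hedged unpacking sketch, assuming the documented `{"masks": [...], "scores": [...]}` shape:

    results = mask_pipe(pil_image)
    if isinstance(results, dict) and "masks" in results:
        masks = results["masks"]
        scores = results.get("scores", [None] * len(masks))
        pairs = list(zip(masks, scores))
    else:
        # Defensive handling for a list-of-dicts shape.
        pairs = [(r.get("mask"), r.get("score")) for r in results]
    # `pairs` then feeds the same thresholding and centroid logic as above.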
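For a quick smoke test of the new code path, the detection helper can be exercised on a single frame; a minimal sketch, with an illustrative image path:

    import numpy as np
    from PIL import Image

    frame_rgb = np.array(Image.open("sample_frame.jpg").convert("RGB"))  # hypothetical file
    for det in _extract_detections(frame_rgb):
        print(det["label"], det["score"], det["area"], det["centroid"])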