cdancette committed on
Commit
b49f319
·
1 Parent(s): a0a2528

contour instead of overlay

Browse files
Files changed (2) hide show
  1. app.py +24 -8
  2. requirements.txt +1 -1
app.py CHANGED
@@ -19,6 +19,7 @@ import random
19
  from functools import lru_cache
20
  from typing import Any, Dict, List, Optional, Tuple, Union
21
 
 
22
  import gradio as gr
23
  import numpy as np
24
  import pandas as pd
@@ -29,7 +30,6 @@ from transformers import (
29
  AutoImageProcessor,
30
  AutoModelForImageClassification,
31
  )
32
- from torchvision.utils import draw_segmentation_masks
33
 
34
 
35
  HF_REPO_ID = "raidium/curia"
@@ -214,15 +214,31 @@ def prepare_mask_tensor(mask: Any, height: int, width: int) -> Optional[torch.Te
214
  return stacked
215
 
216
 
217
- def apply_mask_overlay(image: np.ndarray, mask: Any) -> np.ndarray:
 
 
 
 
 
 
218
  height, width = image.shape[:2]
219
  mask_tensor = prepare_mask_tensor(mask, height, width)
220
  if mask_tensor is None:
221
  return image
222
 
223
- img_tensor = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
224
- overlaid = draw_segmentation_masks(img_tensor, mask_tensor, colors=[(255, 0, 0)], alpha=0.4)
225
- output = overlaid.permute(1, 2, 0).mul(255).clamp(0, 255).byte().numpy()
 
 
 
 
 
 
 
 
 
 
226
  return output
227
 
228
 
@@ -232,7 +248,7 @@ def render_image_with_mask_info(image: np.ndarray, mask: Any) -> Tuple[np.ndarra
232
  return display, None
233
 
234
  try:
235
- overlaid = apply_mask_overlay(display, mask)
236
  return overlaid, ""
237
  except Exception:
238
  return display, "Mask provided but could not be visualised."
@@ -552,8 +568,8 @@ def build_demo() -> gr.Blocks:
552
 
553
  - Configure the `HF_TOKEN` secret in your Space to load private checkpoints
554
  and datasets from the `raidium` organisation.
555
- - When masks are available in the dataset sample, they are overlaid on the
556
- image for visual reference (courtesy of `torchvision.utils.draw_segmentation_masks`).
557
  - Uploaded images must be single-channel arrays. Multi-channel inputs are
558
  converted to grayscale automatically.
559
  """
 
19
  from functools import lru_cache
20
  from typing import Any, Dict, List, Optional, Tuple, Union
21
 
22
+ import cv2
23
  import gradio as gr
24
  import numpy as np
25
  import pandas as pd
 
30
  AutoImageProcessor,
31
  AutoModelForImageClassification,
32
  )
 
33
 
34
 
35
  HF_REPO_ID = "raidium/curia"
 
214
  return stacked
215
 
216
 
217
def apply_contour_overlay(
    image: np.ndarray,
    mask: Any,
    thickness: int = 1,
    color: Tuple[int, int, int] = (255, 0, 0),
) -> np.ndarray:
    """Overlay segmentation-mask outlines (rather than filled regions) on *image*.

    Each mask channel is traced with ``cv2.findContours`` and its external
    contours are drawn onto a copy of the input, so the underlying pixels
    stay visible inside the segmented regions.

    Args:
        image: H x W (x C) array to annotate; it is never modified in place.
        mask: Raw mask data; ``prepare_mask_tensor`` normalises it into a
            stacked boolean tensor (one slice per mask).
        thickness: Contour line width in pixels.
        color: Contour colour as an (R, G, B) tuple — red by default.

    Returns:
        The annotated copy, or the original ``image`` unchanged when no
        usable mask could be derived.
    """
    h, w = image.shape[:2]
    stacked = prepare_mask_tensor(mask, h, w)
    if stacked is None:
        # Nothing to draw — hand back the untouched input.
        return image

    # Draw onto a copy so the caller's array is left intact.
    canvas = image.copy()

    # Trace and render each mask slice independently.
    for single_mask in stacked:
        binary = single_mask.numpy().astype(np.uint8)
        found, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        # -1 draws every contour returned for this mask.
        cv2.drawContours(canvas, found, -1, color, thickness)

    return canvas
243
 
244
 
 
248
  return display, None
249
 
250
  try:
251
+ overlaid = apply_contour_overlay(display, mask)
252
  return overlaid, ""
253
  except Exception:
254
  return display, "Mask provided but could not be visualised."
 
568
 
569
  - Configure the `HF_TOKEN` secret in your Space to load private checkpoints
570
  and datasets from the `raidium` organisation.
571
+ - When masks are available in the dataset sample, their contours are drawn on the
572
+ image for visual reference using OpenCV.
573
  - Uploaded images must be single-channel arrays. Multi-channel inputs are
574
  converted to grayscale automatically.
575
  """
requirements.txt CHANGED
@@ -2,8 +2,8 @@ gradio>=4.44.0
2
  transformers>=4.41.0
3
  datasets>=2.19.0
4
  torch>=2.2.0
5
- torchvision>=0.17.0
6
  pandas>=2.2.0
7
  numpy>=1.26.0
8
  pillow>=10.2.0
 
9
 
 
2
  transformers>=4.41.0
3
  datasets>=2.19.0
4
  torch>=2.2.0
 
5
  pandas>=2.2.0
6
  numpy>=1.26.0
7
  pillow>=10.2.0
8
+ opencv-python>=4.8.0
9