ChinnaVemareddy23 commited on
Commit
a25ea49
·
verified ·
1 Parent(s): 4972899

Update src/visual_cues.py

Browse files
Files changed (1) hide show
  1. src/visual_cues.py +145 -42
src/visual_cues.py CHANGED
@@ -1,4 +1,95 @@
1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import io
3
  import base64
4
  from typing import List, Dict, Tuple
@@ -12,78 +103,90 @@ from src.config import LOGO_DETECTION_MODEL
12
  # --------------------------------------------------
13
  # MODEL INITIALIZATION (LOAD ONCE)
14
  # --------------------------------------------------
15
- # Object detection pipeline for logo / seal detection
16
  detector = pipeline(
17
  task="object-detection",
18
  model=LOGO_DETECTION_MODEL,
19
- device=-1 # CPU
20
  )
21
 
22
 
23
  # --------------------------------------------------
24
- # LOGO DETECTION
25
  # --------------------------------------------------
26
  def detect_logos_from_bytes(
27
  image_bytes: bytes,
28
  resize: Tuple[int, int] = (1024, 1024),
29
- max_logos: int = 3
 
30
  ) -> List[Dict[str, str | float]]:
31
  """
32
  Detect logos or visual emblems from raw image bytes.
33
 
34
- The function resizes the image for faster inference,
35
- detects logo regions, crops them, and returns the
36
- cropped logo images encoded in base64 along with
37
- confidence scores.
38
-
39
- Parameters
40
- ----------
41
- image_bytes : bytes
42
- Raw image data.
43
- resize : tuple[int, int], optional
44
- Maximum image size for inference (default: 1024x1024).
45
- max_logos : int, optional
46
- Maximum number of detected logos to return.
47
-
48
- Returns
49
- -------
50
- list[dict]
51
- List of detected logos with:
52
- - confidence: float
53
- - image_base64: str
54
  """
55
 
56
- # Load image from bytes
57
- image: Image.Image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
58
-
59
- # Resize image for performance optimization
60
- image.thumbnail(resize)
61
-
62
- # Run object detection
63
- detections = detector(image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
  results: List[Dict[str, str | float]] = []
66
 
67
- # Process top detections only
 
 
68
  for det in detections[:max_logos]:
69
  box = det["box"]
70
- score: float = float(det["score"])
 
 
 
 
 
71
 
72
- xmin: int = int(box["xmin"])
73
- ymin: int = int(box["ymin"])
74
- xmax: int = int(box["xmax"])
75
- ymax: int = int(box["ymax"])
76
 
77
- # Crop detected logo region
78
  cropped = image.crop((xmin, ymin, xmax, ymax))
79
 
80
- # Convert cropped logo to base64
81
  buffer = io.BytesIO()
82
  cropped.save(buffer, format="PNG")
83
 
84
  results.append({
85
  "confidence": round(score, 3),
86
- "image_base64": base64.b64encode(buffer.getvalue()).decode()
87
  })
88
 
89
- return results
 
1
 
2
+ # import io
3
+ # import base64
4
+ # from typing import List, Dict, Tuple
5
+
6
+ # from PIL import Image
7
+ # from transformers import pipeline
8
+
9
+ # from src.config import LOGO_DETECTION_MODEL
10
+
11
+
12
+ # # --------------------------------------------------
13
+ # # MODEL INITIALIZATION (LOAD ONCE)
14
+ # # --------------------------------------------------
15
+ # # Object detection pipeline for logo / seal detection
16
+ # detector = pipeline(
17
+ # task="object-detection",
18
+ # model=LOGO_DETECTION_MODEL,
19
+ # device=-1 # CPU
20
+ # )
21
+
22
+
23
+ # # --------------------------------------------------
24
+ # # LOGO DETECTION
25
+ # # --------------------------------------------------
26
+ # def detect_logos_from_bytes(
27
+ # image_bytes: bytes,
28
+ # resize: Tuple[int, int] = (1024, 1024),
29
+ # max_logos: int = 3
30
+ # ) -> List[Dict[str, str | float]]:
31
+ # """
32
+ # Detect logos or visual emblems from raw image bytes.
33
+
34
+ # The function resizes the image for faster inference,
35
+ # detects logo regions, crops them, and returns the
36
+ # cropped logo images encoded in base64 along with
37
+ # confidence scores.
38
+
39
+ # Parameters
40
+ # ----------
41
+ # image_bytes : bytes
42
+ # Raw image data.
43
+ # resize : tuple[int, int], optional
44
+ # Maximum image size for inference (default: 1024x1024).
45
+ # max_logos : int, optional
46
+ # Maximum number of detected logos to return.
47
+
48
+ # Returns
49
+ # -------
50
+ # list[dict]
51
+ # List of detected logos with:
52
+ # - confidence: float
53
+ # - image_base64: str
54
+ # """
55
+
56
+ # # Load image from bytes
57
+ # image: Image.Image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
58
+
59
+ # # Resize image for performance optimization
60
+ # image.thumbnail(resize)
61
+
62
+ # # Run object detection
63
+ # detections = detector(image)
64
+
65
+ # results: List[Dict[str, str | float]] = []
66
+
67
+ # # Process top detections only
68
+ # for det in detections[:max_logos]:
69
+ # box = det["box"]
70
+ # score: float = float(det["score"])
71
+
72
+ # xmin: int = int(box["xmin"])
73
+ # ymin: int = int(box["ymin"])
74
+ # xmax: int = int(box["xmax"])
75
+ # ymax: int = int(box["ymax"])
76
+
77
+ # # Crop detected logo region
78
+ # cropped = image.crop((xmin, ymin, xmax, ymax))
79
+
80
+ # # Convert cropped logo to base64
81
+ # buffer = io.BytesIO()
82
+ # cropped.save(buffer, format="PNG")
83
+
84
+ # results.append({
85
+ # "confidence": round(score, 3),
86
+ # "image_base64": base64.b64encode(buffer.getvalue()).decode()
87
+ # })
88
+
89
+ # return results
90
+
91
+
92
+
93
  import io
94
  import base64
95
  from typing import List, Dict, Tuple
 
103
  # --------------------------------------------------
104
  # MODEL INITIALIZATION (LOAD ONCE)
105
  # --------------------------------------------------
 
106
  detector = pipeline(
107
  task="object-detection",
108
  model=LOGO_DETECTION_MODEL,
109
+ device=-1 # CPU (HF Spaces safe)
110
  )
111
 
112
 
113
  # --------------------------------------------------
114
+ # LOGO DETECTION FUNCTION
115
  # --------------------------------------------------
116
  def detect_logos_from_bytes(
117
  image_bytes: bytes,
118
  resize: Tuple[int, int] = (1024, 1024),
119
+ max_logos: int = 4,
120
+ threshold: float = 0.2
121
  ) -> List[Dict[str, str | float]]:
122
  """
123
  Detect logos or visual emblems from raw image bytes.
124
 
125
+ Returns cropped logo images (base64) with confidence scores.
126
+ Works consistently on local & Hugging Face Spaces.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  """
128
 
129
+ # -------------------------------
130
+ # Load image (deterministic)
131
+ # -------------------------------
132
+ image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
133
+
134
+ # Deterministic resize (NO thumbnail)
135
+ image = image.resize(
136
+ (
137
+ min(image.width, resize[0]),
138
+ min(image.height, resize[1])
139
+ )
140
+ )
141
+
142
+ # -------------------------------
143
+ # Object detection (EXPLICIT threshold)
144
+ # -------------------------------
145
+ detections = detector(
146
+ image,
147
+ threshold=threshold
148
+ )
149
+
150
+ if not detections:
151
+ return []
152
+
153
+ # -------------------------------
154
+ # Sort by confidence (IMPORTANT)
155
+ # -------------------------------
156
+ detections = sorted(
157
+ detections,
158
+ key=lambda x: x["score"],
159
+ reverse=True
160
+ )
161
 
162
  results: List[Dict[str, str | float]] = []
163
 
164
+ # -------------------------------
165
+ # Process top detections
166
+ # -------------------------------
167
  for det in detections[:max_logos]:
168
  box = det["box"]
169
+ score = float(det["score"])
170
+
171
+ xmin = max(0, int(box["xmin"]))
172
+ ymin = max(0, int(box["ymin"]))
173
+ xmax = min(image.width, int(box["xmax"]))
174
+ ymax = min(image.height, int(box["ymax"]))
175
 
176
+ # Safety check
177
+ if xmax <= xmin or ymax <= ymin:
178
+ continue
 
179
 
180
+ # Crop logo region
181
  cropped = image.crop((xmin, ymin, xmax, ymax))
182
 
183
+ # Encode cropped logo to base64
184
  buffer = io.BytesIO()
185
  cropped.save(buffer, format="PNG")
186
 
187
  results.append({
188
  "confidence": round(score, 3),
189
+ "image_base64": base64.b64encode(buffer.getvalue()).decode("utf-8")
190
  })
191
 
192
+ return results