daneigh committed
Commit eec9b07 · verified · 1 Parent(s): a8f735c

Update test_mode.py

Files changed (1)
  1. test_mode.py +433 -433
test_mode.py CHANGED
@@ -1,434 +1,434 @@
-MODEL_PATH = os.path.join(BASE_DIR, "model", "best_multimodal_v3.pth")
+MODEL_PATH = os.path.join(BASE_DIR, "model", "best_multimodal_v4.pth")

test_mode.py (updated):

import torch
import torch.nn as nn
from torchvision import models, transforms
import torch.nn.functional as F
import math
from transformers import AutoModel, AutoTokenizer
from PIL import Image
import matplotlib.pyplot as plt
import easyocr
import numpy as np
import re
import os
import io
import cv2


BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_PATH = os.path.join(BASE_DIR, "model", "best_multimodal_v4.pth")

# =========================
# 1. Text Preprocessing
# =========================
def preprocess_text(text):
    emoji_pattern = re.compile(
        "["
        "\U0001F600-\U0001F64F"  # emoticons
        "\U0001F300-\U0001F5FF"  # symbols & pictographs
        "\U0001F680-\U0001F6FF"  # transport & map symbols
        "\U0001F1E0-\U0001F1FF"  # flags
        "\U00002700-\U000027BF"  # dingbats
        "\U0001F900-\U0001F9FF"  # supplemental symbols
        "\U00002600-\U000026FF"  # misc symbols
        "\U00002B00-\U00002BFF"  # arrows, etc.
        "\U0001FA70-\U0001FAFF"  # extended symbols
        "]+",
        flags=re.UNICODE
    )
    # Remove emojis
    text = emoji_pattern.sub(r'', text)
    # Lowercase and strip
    text = text.lower().strip()
    # Keep letters (including accented) and spaces
    text = re.sub(r'[^a-zñÑéíóúü\s]', '', text)
    # Normalize whitespace
    text = re.sub(r'\s+', ' ', text)

    return text

# =========================
# 2. OCR Extraction
# =========================
def ocr_extract_text(image_path, confidence_threshold=0.6):
    reader = easyocr.Reader(['en', 'tl'], gpu=torch.cuda.is_available())
    # # preprocess image for ocr
    # image = cv2.cvtColor(image_path, cv2.COLOR_RGB2GRAY)
    # # image = cv2.GaussianBlur(image,(5,5),0)

    # result = reader.readtext(image, detail=1, paragraph=False, width_ths=0.7, height_ths=0.7)

    # # Extract text and confidence scores
    # texts = []
    # confidences = []

    # for detection in result:
    #     bbox, text, confidence = detection
    #     texts.append(text)
    #     confidences.append(confidence)
    # final_text = " ".join(texts)
    # preprocess_txt = preprocess_text(final_text)
    # avg_confidence = sum(confidences) / len(confidences) if confidences else 0.0
    # return final_text, preprocess_txt, avg_confidence

    # Convert to grayscale
    gray = cv2.cvtColor(image_path, cv2.COLOR_RGB2GRAY)

    # First pass: OCR on raw grayscale
    result = reader.readtext(gray, detail=1, paragraph=False, width_ths=0.7, height_ths=0.7)
    texts, confidences = [], []

    for detection in result:
        if len(detection) == 3:
            _, text, conf = detection
        else:
            text, conf = detection

        if isinstance(text, list):
            text = " ".join([str(t) for t in text if isinstance(t, str)])

        texts.append(text)
        try:
            confidences.append(float(conf))
        except (ValueError, TypeError):
            confidences.append(0.0)

    final_text = " ".join(texts)
    avg_conf = sum(confidences) / len(confidences) if confidences else 0.0

    # If confidence is low, retry with Gaussian blur
    if avg_conf < confidence_threshold:
        texts, confidences = [], []
        gauss_img = cv2.GaussianBlur(gray, (5, 5), 0)
        result = reader.readtext(gauss_img, detail=1, paragraph=False, width_ths=0.7, height_ths=0.7)

        for detection in result:
            if len(detection) == 3:
                _, text, conf = detection
            else:
                text, conf = detection

            if isinstance(text, list):
                text = " ".join([str(t) for t in text if isinstance(t, str)])

            texts.append(text)
            try:
                confidences.append(float(conf))
            except (ValueError, TypeError):
                confidences.append(0.0)

        final_text_gauss = " ".join(texts)
        avg_conf_gauss = sum(confidences) / len(confidences) if confidences else 0.0

        # Keep the version with higher confidence
        if avg_conf_gauss > avg_conf:
            final_text, avg_conf = final_text_gauss, avg_conf_gauss

    if not final_text:
        return "", "", 0.0

    preprocess_txt = preprocess_text(final_text)
    return final_text, preprocess_txt, avg_conf

# =========================
# 3. Image Preprocessing
# =========================
def resize_normalize_image(image_path, target_size=(224, 224)):

    preprocess_image = transforms.Compose([
        transforms.Resize(target_size, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])

    image_tensor = preprocess_image(image_path).unsqueeze(0)  # Add batch dimension
    return image_tensor

# =========================
# 4. Model Definitions
# =========================
class CrossAttentionModule(nn.Module):
    def __init__(self, query_dim, key_value_dim, hidden_dim=256, num_heads=8, dropout=0.1):
        super(CrossAttentionModule, self).__init__()

        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        self.scale = math.sqrt(self.head_dim)  # √dk

        assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"

        # Query projection for H (image features)
        self.query_proj = nn.Linear(query_dim, hidden_dim)

        # Key and Value projections for G (text features)
        self.key_proj = nn.Linear(key_value_dim, hidden_dim)
        self.value_proj = nn.Linear(key_value_dim, hidden_dim)

        # Output projection WO
        self.out_proj = nn.Linear(hidden_dim, query_dim)

        # Layer normalization
        self.norm1 = nn.LayerNorm(query_dim)
        self.norm2 = nn.LayerNorm(query_dim)

        # MLP for final transformation
        self.mlp = nn.Sequential(
            nn.Linear(query_dim, query_dim * 4),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(query_dim * 4, query_dim),
            nn.Dropout(dropout)
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, H, G):
        """
        Args:
            H: Query features [batch_size, seq_len_h, query_dim] (e.g., image patches)
            G: Key/Value features [batch_size, seq_len_g, key_value_dim] (e.g., text tokens)
        """
        batch_size, seq_len_h, _ = H.shape
        seq_len_g = G.shape[1]

        # Step 1: Project to Q, K, V
        Q = self.query_proj(H)  # WiQ H
        K = self.key_proj(G)  # WiK G
        V = self.value_proj(G)  # WiV G

        # Step 2: Reshape for multi-head attention
        Q = Q.view(batch_size, seq_len_h, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, seq_len_g, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, seq_len_g, self.num_heads, self.head_dim).transpose(1, 2)

        # Step 3: Compute attention ATTi(H,G) = softmax((WiQ H)T(WiK G)/√dk)(WiV G)T
        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
        attention_weights = F.softmax(attention_scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        attention_output = torch.matmul(attention_weights, V)

        # Step 4: Concatenate heads and apply output projection
        attention_output = attention_output.transpose(1, 2).contiguous().view(
            batch_size, seq_len_h, self.hidden_dim
        )

        # MATT(H,G) = [ATT1...ATTh]WO
        matt_output = self.out_proj(attention_output)

        # Step 5: Z = LN(H + MATT(H,G))
        Z = self.norm1(H + matt_output)

        # Step 6: TIM(H,G) = LN(Z + MLP(Z))
        mlp_output = self.mlp(Z)
        tim_output = self.norm2(Z + mlp_output)

        return tim_output

class MultimodalClassifier(nn.Module):
    def __init__(self, num_classes=2, model_name='jcblaise/roberta-tagalog-base'):
        super(MultimodalClassifier, self).__init__()

        # Image encoder (ResNet-18, keep spatial features)
        resnet = models.resnet18(pretrained=True)
        modules = list(resnet.children())[:-2]  # keep until last conv (before avgpool)
        self.image_encoder = nn.Sequential(*modules)  # output: (B, 512, 7, 7)

        # Text encoder
        self.text_encoder = AutoModel.from_pretrained(model_name)

        # Cross-attention using paper formula
        # Image attends to text
        self.img_to_text_attention = CrossAttentionModule(
            query_dim=512,
            key_value_dim=self.text_encoder.config.hidden_size,
            hidden_dim=256,
            num_heads=8
        )

        # Text attends to image
        self.text_to_img_attention = CrossAttentionModule(
            query_dim=self.text_encoder.config.hidden_size,
            key_value_dim=512,
            hidden_dim=256,
            num_heads=8
        )

        # Fusion & classifier
        self.fusion = nn.Sequential(
            nn.Linear(512 + self.text_encoder.config.hidden_size, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes)
        )

    def forward(self, images, input_ids, attention_mask):
        # Extract image features
        batch_size = images.size(0)
        img_feats = self.image_encoder(images)  # (B, 512, 7, 7)
        img_feats = img_feats.flatten(2).permute(0, 2, 1)  # (B, 49, 512)

        # Extract text features
        text_outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
        txt_feats = text_outputs.last_hidden_state  # (B, seq_len, H)

        # Cross-attention following paper formula
        # TIM(img_feats, txt_feats) and TIM(txt_feats, img_feats)
        attended_img = self.img_to_text_attention(img_feats, txt_feats)
        attended_txt = self.text_to_img_attention(txt_feats, img_feats)

        # Pool attended outputs
        img_repr = attended_img.mean(dim=1)  # (B, 512)
        txt_repr = attended_txt[:, 0, :]  # CLS token (B, hidden_size)

        # Fusion
        fused = torch.cat([img_repr, txt_repr], dim=1)
        return self.fusion(fused)

# =========================
# 5. Load Model & Tokenizer
# =========================
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = MultimodalClassifier(num_classes=2)
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
model.to(device)
model.eval()

tokenizer = AutoTokenizer.from_pretrained("jcblaise/roberta-tagalog-base")

# =========================
# 6. Inference Function
# =========================
def run_inference(image_path):
    # Convert bytes → PIL image
    if isinstance(image_path, (bytes, bytearray)):
        pil_img = Image.open(io.BytesIO(image_path)).convert("RGB")
    elif isinstance(image_path, str):
        pil_img = Image.open(image_path).convert("RGB")
    elif isinstance(image_path, Image.Image):
        pil_img = image_path.convert("RGB")
    else:
        raise TypeError(f"Unsupported input type: {type(image_path)}")

    # OCR
    np_image = np.array(pil_img)
    raw_text, clean_text, confidence = ocr_extract_text(np_image)

    if clean_text == "":
        return {
            "error": "This is not a meme. Upload a valid meme image with text.",
        }

    elif len(clean_text.split()) < 3:
        return {
            "error": "Insufficient text detected in the meme. Please upload a meme with more text. Minimum is 3 words.",
            "clean_text": clean_text,
            "raw_text": raw_text,
            "confidence": confidence
        }

    # Image
    img_tensor = resize_normalize_image(pil_img).to(device)

    # Tokenize text
    encoding = tokenizer(
        clean_text, return_tensors='pt',
        padding=True, truncation=True, max_length=128
    )
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    # Forward pass
    with torch.no_grad():
        logits = model(img_tensor, input_ids, attention_mask)
        probs = torch.softmax(logits, dim=1)
        pred_class = torch.argmax(probs, dim=1).item()
        pred_class = 'sexual' if pred_class == 1 else 'non-sexual'

    return {
        'original_size': pil_img.size,
        'prediction': pred_class,
        'probabilities': probs.cpu().numpy().tolist(),
        'raw_text': raw_text,
        'clean_text': clean_text,
        'confidence': confidence
    }

# =========================
# 7. Run as main
# =========================
# if __name__ == "__main__":
#     # Example: load image from path
#     IMAGE_PATH = "backend/OIP (1).jfif"

#     # test_dimension_sensitivity(IMAGE_PATH)

#     result = run_inference(IMAGE_PATH)

#     # Print results
#     print("Original Image Size:", result['original_size'])
#     print("Prediction:", result['prediction'])
#     print("Probabilities:", result['probabilities'])
#     print("Raw OCR Text:", result['raw_text'])
#     print("Clean OCR Text:", result['clean_text'])
#     print("OCR Confidence:", result['confidence'])


#     # Preprocess image
#     pil_img = Image.open(IMAGE_PATH).convert("RGB")
#     img_tensor = resize_normalize_image(pil_img).to(device)

#     # -----------------------------
#     # Generate ResNet heatmap
#     # -----------------------------
#     features = {}
#     def hook_fn(module, input, output):
#         features['value'] = output.detach()

#     last_conv = model.image_encoder[-1]
#     hook_handle = last_conv.register_forward_hook(hook_fn)

#     with torch.no_grad():
#         _ = model(img_tensor,
#                   input_ids=torch.zeros(1, 1, dtype=torch.long, device=device),
#                   attention_mask=torch.ones(1, 1, dtype=torch.long, device=device))

#     hook_handle.remove()

#     feat_tensor = features['value']
#     heatmap_grid = feat_tensor[0].mean(dim=0).cpu().numpy()
#     heatmap_grid = (heatmap_grid - heatmap_grid.min()) / (heatmap_grid.max() - heatmap_grid.min())
#     heatmap_resized = np.uint8(255 * heatmap_grid)
#     heatmap_resized = Image.fromarray(heatmap_resized).resize(pil_img.size, Image.NEAREST)
#     heatmap_resized = np.array(heatmap_resized)

#     probs = result['probabilities'][0]
#     prob_text = f"non-sexual: {probs[0]:.2f}, sexual: {probs[1]:.2f}"

#     # -----------------------------
#     # Plot everything in one figure
#     # -----------------------------
#     fig, ax = plt.subplots(figsize=(6, 6))

#     ax.imshow(pil_img)  # original image
#     ax.imshow(heatmap_resized, cmap='jet', alpha=0.4, interpolation='nearest')  # overlay heatmap
#     ax.axis('off')
#     ax.set_title(f"{result['prediction']} ({prob_text})", fontsize=14, color='blue')

#     # Add colorbar
#     sm = plt.cm.ScalarMappable(cmap='jet', norm=plt.Normalize(vmin=0, vmax=1))
#     sm.set_array([])
#     cbar = fig.colorbar(sm, ax=ax, fraction=0.046, pad=0.04)
#     cbar.set_label('Feature Intensity')

#     # plt.show()
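
A minimal sketch of how the updated script could be driven from another module, for reference only: importing test_mode loads the v4 checkpoint from model/best_multimodal_v4.pth at import time, and "meme.jpg" below is a placeholder path used purely for illustration.

# Usage sketch (illustrative; "meme.jpg" is a placeholder path)
from test_mode import run_inference  # importing loads the checkpoint and puts the model in eval mode

result = run_inference("meme.jpg")  # also accepts raw image bytes or a PIL.Image
if "error" in result:
    print(result["error"])
else:
    print(result["prediction"], result["probabilities"], result["confidence"])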