vtdung23 committed on
Commit
dced78a
·
1 Parent(s): 92db116

Fix: Use keyword-based explanation instead of gradient (avoid CUDA errors)

Browse files
Files changed (1) hide show
  1. app/services/ml_service.py +31 -45
app/services/ml_service.py CHANGED
@@ -249,7 +249,7 @@ class MLPredictionService:
249
  def predict_with_explanation(self, text: str) -> Dict[str, Any]:
250
  """
251
  Predict rating with explanation (word importance scores)
252
- Uses gradient-based attribution for interpretability
253
  """
254
  # Lazy load model on first request
255
  self._load_model()
@@ -259,7 +259,6 @@ class MLPredictionService:
259
 
260
  # 1. Vietnamese preprocessing
261
  processed_text = self.preprocess(text)
262
- words = processed_text.split()
263
 
264
  # 2. Tokenize
265
  encoded = self.tokenizer(
@@ -273,64 +272,51 @@ class MLPredictionService:
273
  # Move tensors to device
274
  encoded = {k: v.to(self.device) for k, v in encoded.items()}
275
 
276
- # 3. Get embeddings and enable gradient computation
277
- embeddings = self.model.roberta.embeddings(encoded['input_ids'])
278
- embeddings.requires_grad_(True)
279
-
280
- # 4. Forward pass with embeddings
281
- with torch.enable_grad():
282
- outputs = self.model.roberta.encoder(embeddings)
283
- sequence_output = outputs.last_hidden_state
284
- logits = self.model.classifier(sequence_output)
285
  probs = F.softmax(logits, dim=1)
286
 
287
  # Get predicted class
288
  predicted_class = torch.argmax(probs, dim=1).item()
289
  confidence = probs[0][predicted_class].item()
290
-
291
- # Compute gradient for the predicted class
292
- target_score = probs[0][predicted_class]
293
- target_score.backward()
294
-
295
- # Get gradient-based importance
296
- gradients = embeddings.grad
297
- importance = gradients.abs().sum(dim=-1).squeeze()
298
 
299
- # 5. Map importance to words (simplified - using tokenizer alignment)
300
  tokens = self.tokenizer.convert_ids_to_tokens(encoded['input_ids'][0])
301
- token_importance = importance.detach().cpu().numpy()
302
 
303
- # Normalize importance scores
304
- if token_importance.max() > 0:
305
- token_importance = token_importance / token_importance.max()
306
-
307
- # Map to original words (simplified approach)
308
  word_importance = []
309
  for i, token in enumerate(tokens):
310
- if token not in ['<s>', '</s>', '<pad>']:
311
- # Assign positive/negative based on keyword analysis
312
- keyword_analysis = self.keyword_analyzer.analyze(token)
313
- base_score = float(token_importance[i])
 
 
 
 
 
 
 
314
 
315
- if keyword_analysis['positive_count'] > 0:
316
- word_importance.append({
317
- 'word': token,
318
- 'score': base_score # Positive contribution
319
- })
320
- elif keyword_analysis['negative_count'] > 0:
321
- word_importance.append({
322
- 'word': token,
323
- 'score': -base_score # Negative contribution
324
- })
325
  else:
326
- word_importance.append({
327
- 'word': token,
328
- 'score': base_score * (0.5 if predicted_class >= 2 else -0.5)
329
- })
 
 
 
330
 
331
  rating = predicted_class + 1
332
 
333
- # Also get keyword analysis for the full text
334
  keyword_analysis = self.keyword_analyzer.analyze(text)
335
 
336
  return {
 
249
  def predict_with_explanation(self, text: str) -> Dict[str, Any]:
250
  """
251
  Predict rating with explanation (word importance scores)
252
+ Uses keyword-based importance for interpretability (safer than gradients)
253
  """
254
  # Lazy load model on first request
255
  self._load_model()
 
259
 
260
  # 1. Vietnamese preprocessing
261
  processed_text = self.preprocess(text)
 
262
 
263
  # 2. Tokenize
264
  encoded = self.tokenizer(
 
272
  # Move tensors to device
273
  encoded = {k: v.to(self.device) for k, v in encoded.items()}
274
 
275
+ # 3. Standard inference (no gradients needed)
276
+ with torch.no_grad():
277
+ outputs = self.model(**encoded)
278
+ logits = outputs.logits
 
 
 
 
 
279
  probs = F.softmax(logits, dim=1)
280
 
281
  # Get predicted class
282
  predicted_class = torch.argmax(probs, dim=1).item()
283
  confidence = probs[0][predicted_class].item()
 
 
 
 
 
 
 
 
284
 
285
+ # 4. Keyword-based importance (more reliable than gradient-based)
286
  tokens = self.tokenizer.convert_ids_to_tokens(encoded['input_ids'][0])
 
287
 
288
+ # Calculate importance based on keyword presence and position
 
 
 
 
289
  word_importance = []
290
  for i, token in enumerate(tokens):
291
+ if token not in ['<s>', '</s>', '<pad>', '<unk>']:
292
+ # Clean token (remove BPE markers)
293
+ clean_token = token.replace('@@', '').replace('▁', '').strip()
294
+ if not clean_token:
295
+ continue
296
+
297
+ # Check if token is a keyword
298
+ is_positive = any(kw in clean_token.lower() or clean_token.lower() in kw
299
+ for kw in self.keyword_analyzer.positive_words)
300
+ is_negative = any(kw in clean_token.lower() or clean_token.lower() in kw
301
+ for kw in self.keyword_analyzer.negative_words)
302
 
303
+ # Assign importance score
304
+ if is_positive:
305
+ score = 0.8 + (0.2 * (1 - i / len(tokens))) # Decay by position
306
+ elif is_negative:
307
+ score = -(0.8 + (0.2 * (1 - i / len(tokens))))
 
 
 
 
 
308
  else:
309
+ # Neutral words get small score based on prediction
310
+ score = 0.2 if predicted_class >= 2 else -0.2
311
+
312
+ word_importance.append({
313
+ 'word': clean_token,
314
+ 'score': round(score, 3)
315
+ })
316
 
317
  rating = predicted_class + 1
318
 
319
+ # Get keyword analysis for the full text
320
  keyword_analysis = self.keyword_analyzer.analyze(text)
321
 
322
  return {