Fix: Use keyword-based explanation instead of gradient (avoid CUDA errors)
app/services/ml_service.py  +31 -45  CHANGED
@@ -249,7 +249,7 @@ class MLPredictionService:
     def predict_with_explanation(self, text: str) -> Dict[str, Any]:
         """
         Predict rating with explanation (word importance scores)
-        Uses
+        Uses keyword-based importance for interpretability (safer than gradients)
         """
         # Lazy load model on first request
         self._load_model()
@@ -259,7 +259,6 @@ class MLPredictionService:

         # 1. Vietnamese preprocessing
         processed_text = self.preprocess(text)
-        words = processed_text.split()

         # 2. Tokenize
         encoded = self.tokenizer(
@@ -273,64 +272,51 @@ class MLPredictionService:
         # Move tensors to device
         encoded = {k: v.to(self.device) for k, v in encoded.items()}

-        # 3.
-
-
-
-        # 4. Forward pass with embeddings
-        with torch.enable_grad():
-            outputs = self.model.roberta.encoder(embeddings)
-            sequence_output = outputs.last_hidden_state
-            logits = self.model.classifier(sequence_output)
+        # 3. Standard inference (no gradients needed)
+        with torch.no_grad():
+            outputs = self.model(**encoded)
+            logits = outputs.logits
         probs = F.softmax(logits, dim=1)

         # Get predicted class
         predicted_class = torch.argmax(probs, dim=1).item()
         confidence = probs[0][predicted_class].item()
-
-        # Compute gradient for the predicted class
-        target_score = probs[0][predicted_class]
-        target_score.backward()
-
-        # Get gradient-based importance
-        gradients = embeddings.grad
-        importance = gradients.abs().sum(dim=-1).squeeze()

-        #
+        # 4. Keyword-based importance (more reliable than gradient-based)
         tokens = self.tokenizer.convert_ids_to_tokens(encoded['input_ids'][0])
-        token_importance = importance.detach().cpu().numpy()

-        #
-        if token_importance.max() > 0:
-            token_importance = token_importance / token_importance.max()
-
-        # Map to original words (simplified approach)
+        # Calculate importance based on keyword presence and position
         word_importance = []
         for i, token in enumerate(tokens):
-            if token not in ['<s>', '</s>', '<pad>']:
-                #
-
-
+            if token not in ['<s>', '</s>', '<pad>', '<unk>']:
+                # Clean token (remove BPE markers)
+                clean_token = token.replace('@@', '').replace('▁', '').strip()
+                if not clean_token:
+                    continue
+
+                # Check if token is a keyword
+                is_positive = any(kw in clean_token.lower() or clean_token.lower() in kw
+                                  for kw in self.keyword_analyzer.positive_words)
+                is_negative = any(kw in clean_token.lower() or clean_token.lower() in kw
+                                  for kw in self.keyword_analyzer.negative_words)

-
-
-
-
-
-            elif keyword_analysis['negative_count'] > 0:
-                word_importance.append({
-                    'word': token,
-                    'score': -base_score # Negative contribution
-                })
+                # Assign importance score
+                if is_positive:
+                    score = 0.8 + (0.2 * (1 - i / len(tokens))) # Decay by position
+                elif is_negative:
+                    score = -(0.8 + (0.2 * (1 - i / len(tokens))))
                 else:
-
-
-
-
+                    # Neutral words get small score based on prediction
+                    score = 0.2 if predicted_class >= 2 else -0.2
+
+                word_importance.append({
+                    'word': clean_token,
+                    'score': round(score, 3)
+                })

         rating = predicted_class + 1

-        #
+        # Get keyword analysis for the full text
         keyword_analysis = self.keyword_analyzer.analyze(text)

         return {
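
For reference, a minimal standalone sketch of the keyword-plus-position scoring introduced above, runnable outside the service class. The positive_words / negative_words sets and the sample tokens are hypothetical stand-ins for self.keyword_analyzer's word lists and the real PhoBERT tokenizer output; the scoring rules mirror the added lines in the diff.

from typing import Dict, List

# Hypothetical keyword sets; the real lists live in self.keyword_analyzer.
positive_words = {"ngon", "tốt", "tuyệt_vời"}
negative_words = {"tệ", "dở", "chậm"}

SPECIAL_TOKENS = {'<s>', '</s>', '<pad>', '<unk>'}

def keyword_importance(tokens: List[str], predicted_class: int) -> List[Dict]:
    """Score each token by keyword match, with a small positional decay."""
    word_importance = []
    for i, token in enumerate(tokens):
        if token in SPECIAL_TOKENS:
            continue
        # Strip BPE markers the tokenizer may attach (e.g. '@@' or '▁')
        clean = token.replace('@@', '').replace('▁', '').strip()
        if not clean:
            continue
        low = clean.lower()
        is_positive = any(kw in low or low in kw for kw in positive_words)
        is_negative = any(kw in low or low in kw for kw in negative_words)
        if is_positive:
            score = 0.8 + 0.2 * (1 - i / len(tokens))    # earlier tokens weigh a bit more
        elif is_negative:
            score = -(0.8 + 0.2 * (1 - i / len(tokens)))
        else:
            # Neutral tokens lean toward the predicted sentiment (class 0-4 -> rating 1-5)
            score = 0.2 if predicted_class >= 2 else -0.2
        word_importance.append({'word': clean, 'score': round(score, 3)})
    return word_importance

# Example: tokens for "đồ ăn rất ngon" with a 5-star prediction (class index 4)
print(keyword_importance(['<s>', 'đồ_ăn', 'rất', 'ngon', '</s>'], predicted_class=4))

Because this scoring needs no backward pass on the embeddings, the model forward pass can stay entirely under torch.no_grad(), which is how the new code sidesteps the gradient path referenced in the commit title.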