from fastapi import FastAPI, UploadFile, File
import torch
import os
from pathlib import Path
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoTokenizer, AutoModelForTokenClassification, AutoModelForSequenceClassification
from pydantic import BaseModel
import tempfile
import hashlib
import json
from typing import Optional
import httpx  # HTTP client used for the Groq transcription API
from dotenv import load_dotenv
load_dotenv()

# Define input model


class TextInput(BaseModel):
    text: str


# Initialize FastAPI
app = FastAPI()

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    # You can restrict this to your specific frontend origin
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
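
# Hypothetical production hardening (sketch, not part of the running config):
# replace the wildcard origin with the known frontend URL, e.g.
#   app.add_middleware(
#       CORSMiddleware,
#       allow_origins=["https://your-frontend.example.com"],  # assumed URL
#       allow_credentials=True,
#       allow_methods=["*"],
#       allow_headers=["*"],
#   )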

# Get base directory
base_dir = Path(__file__).parent.absolute()

# Your Hugging Face Hub username
HF_USERNAME = "YassineJedidi"  # Replace with your actual username

# Valid entity labels for each predicted type
entites_valides = {
    "Tâche": {"TITRE", "DELAI", "PRIORITE"},
    "Événement": {"TITRE", "DATE_HEURE"},
}

# Try to load models from Hugging Face Hub
try:
    print("Loading models from Hugging Face Hub")

    # Model repositories on Hugging Face
    ner_model_repo = f"{HF_USERNAME}/plangenieai-ner"
    type_model_repo = f"{HF_USERNAME}/plangenieai-type"

    print(f"Loading NER model (and tokenizer) from: {ner_model_repo}")
    print(f"Loading type model (and tokenizer) from: {type_model_repo}")

    # Load NER model and tokenizer from the same repo
    ner_tokenizer = AutoTokenizer.from_pretrained(ner_model_repo)
    ner_model = AutoModelForTokenClassification.from_pretrained(ner_model_repo)
    # Load type model and tokenizer from the same repo
    type_tokenizer = AutoTokenizer.from_pretrained(type_model_repo)
    type_model = AutoModelForSequenceClassification.from_pretrained(
        type_model_repo)

except Exception as e:
    print(f"Error loading models from Hugging Face Hub: {e}")

    # Fallback to local files if available
    try:
        # Convert paths to strings with forward slashes
        ner_model_path = str(base_dir / "models" /
                             "plangenieai-ner").replace("\\", "/")
        type_model_path = str(base_dir / "models" /
                              "plangenieai-type").replace("\\", "/")

        print(f"Falling back to local models")
        print(f"Loading NER model (and tokenizer) from: {ner_model_path}")
        print(f"Loading type model (and tokenizer) from: {type_model_path}")

        # Load NER model and tokenizer from local files
        ner_tokenizer = AutoTokenizer.from_pretrained(
            ner_model_path, local_files_only=True)
        ner_model = AutoModelForTokenClassification.from_pretrained(
            ner_model_path, local_files_only=True)
        # Load type model and tokenizer from local files
        type_tokenizer = AutoTokenizer.from_pretrained(
            type_model_path, local_files_only=True)
        type_model = AutoModelForSequenceClassification.from_pretrained(
            type_model_path, local_files_only=True)

    except Exception as e:
        print(f"Error loading local models: {e}")
        # Fallback to base CamemBERT model from HuggingFace Hub
        print("Falling back to base CamemBERT model from HuggingFace Hub")
        ner_tokenizer = AutoTokenizer.from_pretrained("camembert-base")
        ner_model = AutoModelForTokenClassification.from_pretrained(
            "camembert-base")
        type_tokenizer = AutoTokenizer.from_pretrained("camembert-base")
        type_model = AutoModelForSequenceClassification.from_pretrained(
            "camembert-base")

# Helper functions for tokenization


def clean_text(text):
    if isinstance(text, str):
        return text.strip()
    return ""


def find_all_occurrences(text, substring):
    start_positions = []
    start = 0
    if not substring or not isinstance(substring, str):
        return start_positions
    text_lower = text.lower()
    substring_lower = substring.lower()
    while True:
        start = text_lower.find(substring_lower, start)
        if start == -1:
            break
        is_beginning = start == 0 or not text_lower[start-1].isalnum()
        is_ending = (start + len(substring_lower) == len(text_lower) or
                     not text_lower[start + len(substring_lower)].isalnum())
        if is_beginning and is_ending:
            original_substring = text[start:start + len(substring_lower)]
            start_positions.append(
                (start, start + len(substring_lower), original_substring))
        start += 1
    return start_positions
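
# Illustrative usage (hand-worked example; this helper is not called
# elsewhere in this file):
#   find_all_occurrences("Call Bob, then call Alice", "call")
#   -> [(0, 4, 'Call'), (15, 19, 'call')]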


def tokenize_text_with_positions(text, tokenizer):
    """Tokenize text and return tokens with their positions"""
    # Use CamemBERT tokenizer
    tokens = tokenizer.tokenize(text)

    # Clean tokens and get positions
    clean_tokens = []
    token_positions = []
    current_pos = 0

    for token in tokens:
        # Clean the token (remove special characters from tokenizer)
        clean_token = token.replace('▁', '').replace('##', '')
        clean_tokens.append(clean_token)

        if clean_token:
            pos = text.find(clean_token, current_pos)
            if pos != -1:
                token_positions.append((pos, pos + len(clean_token)))
                current_pos = pos + len(clean_token)
            else:
                token_positions.append(
                    (current_pos, current_pos + len(clean_token)))
                current_pos += len(clean_token)
        else:
            token_positions.append((current_pos, current_pos))

    return clean_tokens, token_positions
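
# Illustrative behavior (a sketch: the actual token splits depend on the
# loaded tokenizer's vocabulary):
#   tokenize_text_with_positions("Acheter du pain", ner_tokenizer)
#   -> (['Acheter', 'du', 'pain'], [(0, 7), (8, 10), (11, 15)])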


# Set device (CPU or GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Groq API configuration (key loaded from the environment via python-dotenv)
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
GROQ_API_URL = "https://api.groq.com/openai/v1/audio/transcriptions"
if not GROQ_API_KEY:
    print("Warning: GROQ_API_KEY is not set; /transcribe/ requests will fail")

ner_model = ner_model.to(device)
type_model = type_model.to(device)

# Retrieve label mappings
id2label = ner_model.config.id2label
id2type = type_model.config.id2label

# Cache directory for transcriptions
CACHE_DIR = Path("transcription_cache")
CACHE_DIR.mkdir(exist_ok=True)


def get_cache_path(audio_data: bytes) -> Path:
    """Generate a cache file path based on the audio content hash."""
    hash_md5 = hashlib.md5(audio_data).hexdigest()
    return CACHE_DIR / f"{hash_md5}.json"
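
# Example: identical audio bytes always map to the same cache file,
# e.g. get_cache_path(audio_bytes) -> transcription_cache/<md5 hex>.json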


def get_cached_transcription(audio_data: bytes) -> Optional[str]:
    """Get cached transcription if it exists."""
    cache_path = get_cache_path(audio_data)
    if cache_path.exists():
        try:
            with open(cache_path, 'r') as f:
                return json.load(f)['transcription']
        except Exception:
            return None
    return None


def save_transcription_to_cache(audio_data: bytes, transcription: str):
    """Save transcription to cache."""
    cache_path = get_cache_path(audio_data)
    try:
        with open(cache_path, 'w') as f:
            json.dump({'transcription': transcription}, f)
    except Exception:
        pass  # Silently fail if cache write fails


@app.get("/")
def root():
    return {"message": "FastAPI NLP Model is running!"}


@app.post("/predict-type/")
async def predict_type(input_data: TextInput):
    text = input_data.text
    inputs = type_tokenizer(text, return_tensors="pt",
                            truncation=True, padding=True).to(device)
    with torch.no_grad():
        outputs = type_model(**inputs)

    predicted_class_id = outputs.logits.argmax().item()
    predicted_type = id2type[predicted_class_id]
    confidence = torch.softmax(outputs.logits, dim=1).max().item()

    return {"type": predicted_type, "confidence": confidence}


@app.post("/extract-entities/")
async def extract_entities(input_data: TextInput):
    text = input_data.text

    # Use the model's tokenizer for tokenization
    clean_tokens, token_positions = tokenize_text_with_positions(
        text, ner_tokenizer)

    # Tokenize for NER prediction
    inputs = ner_tokenizer(clean_tokens, is_split_into_words=True,
                           return_tensors="pt", truncation=True, padding=True).to(device)
    with torch.no_grad():
        outputs = ner_model(**inputs)

    predictions = outputs.logits.argmax(dim=2)
    entities = {}
    current_entity = None
    current_start = None
    current_end = None

    word_ids = inputs.word_ids(0)
    for idx, word_idx in enumerate(word_ids):
        if word_idx is None:
            continue
        if idx > 0 and word_ids[idx-1] == word_idx:
            continue

        prediction = predictions[0, idx].item()
        predicted_label = id2label[prediction]

        if predicted_label.startswith("B-"):
            if current_entity is not None:
                entity_type = current_entity[2:]
                if entity_type not in entities:
                    entities[entity_type] = [text[current_start:current_end]]
                current_entity = None
                current_start = None
                current_end = None

            current_entity = predicted_label
            current_start, current_end = token_positions[word_idx]

        elif predicted_label.startswith("I-") and current_entity and predicted_label[2:] == current_entity[2:]:
            # Extend the end position to include this token
            _, token_end = token_positions[word_idx]
            current_end = token_end

        else:
            if current_entity is not None:
                entity_type = current_entity[2:]
                if entity_type not in entities:
                    entities[entity_type] = [text[current_start:current_end]]
                current_entity = None
                current_start = None
                current_end = None

    if current_entity is not None:
        entity_type = current_entity[2:]
        if entity_type not in entities:
            entities[entity_type] = [text[current_start:current_end]]
        # Only keep the first detection, do nothing if already present

    return {"entities": entities}


@app.post("/analyze-text/")
async def analyze_text(input_data: TextInput):
    type_result = await predict_type(input_data)
    text_type = type_result["type"]
    confidence = type_result["confidence"]
    raw_entities = (await extract_entities(input_data))["entities"]

    # Filter entities according to the detected type
    allowed = entites_valides.get(text_type, set())
    filtered_entities = {k: v for k, v in raw_entities.items() if k in allowed}

    return {
        "type": text_type,
        "confidence": confidence,
        "entities": filtered_entities
    }
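
# Worked sketch of the filtering step above: if the type model predicts
# "Événement", only TITRE and DATE_HEURE survive, so a stray PRIORITE
# entity would be dropped:
#   raw_entities = {"TITRE": ["Réunion"], "PRIORITE": ["haute"]}
#   allowed = entites_valides["Événement"]   # {"TITRE", "DATE_HEURE"}
#   filtered_entities == {"TITRE": ["Réunion"]}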


@app.post("/transcribe/")
async def transcribe_audio(file: UploadFile = File(...)):
    try:
        # Read the file content
        audio_data = await file.read()

        # Check cache first
        cached_transcription = get_cached_transcription(audio_data)
        if cached_transcription:
            return {"transcription": cached_transcription, "cached": True}

        # Save audio to a temporary file (Groq expects multipart/form-data)
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
            tmp.write(audio_data)
            tmp_path = tmp.name

        # Prepare request to Groq API
        headers = {"Authorization": f"Bearer {GROQ_API_KEY}"}
        data = {
            "model": "whisper-large-v3-turbo",
            "response_format": "json"
        }

        try:
            # Open the temp file in a context manager so the handle is
            # closed before the file is deleted (the original left it open)
            with open(tmp_path, "rb") as audio_file:
                files = {
                    "file": (os.path.basename(tmp_path), audio_file, "audio/wav")
                }
                async with httpx.AsyncClient() as client:
                    response = await client.post(
                        GROQ_API_URL, headers=headers, data=data,
                        files=files, timeout=60)
        finally:
            # Remove the temp file even if the request raised
            os.remove(tmp_path)

        if response.status_code == 200:
            result = response.json()
            transcription = result.get("text", "")
            # Save to cache
            save_transcription_to_cache(audio_data, transcription)
            return {"transcription": transcription, "cached": False}
        else:
            print(f"Groq API error: {response.status_code} {response.text}")
            return {"error": "Transcription failed", "details": response.text}
    except Exception as e:
        print(f"Transcription error: {str(e)}")
        return {"error": "Transcription failed", "details": str(e)}