|
|
import os |
|
|
import re |
|
|
from fastapi import FastAPI, Request |
|
|
from pydantic import BaseModel |
|
|
from inference_onnx import get_transcription |
|
|
import torch |
|
|
import onnxruntime as ort |
|
|
from config import * |
|
|
from contextlib import asynccontextmanager |
|
|
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the tokenizer and ONNX punctuation model once at startup.

    Shared inference state is stored on ``app.state`` so request handlers
    reuse a single session instead of reloading the model per request.
    Yields control to the running application, then logs on shutdown.
    """
    print("🔧 Loading model...")

    # Inference runs on CPU only (matches the CPUExecutionProvider below).
    app.state.device = torch.device('cpu')

    # MODELS comes from `config` and maps a model name to a tuple; index 1 is
    # assumed to be the tokenizer class and index 3 the token style —
    # TODO(review): confirm the tuple layout against config.py.
    model_name = "./distilbert-base-multilingual-cased"
    app.state.tokenizer = MODELS[model_name][1].from_pretrained(model_name)
    app.state.token_style = MODELS[model_name][3]

    onnx_model_path = "./poc_onnx_model_punctuation_batch.onnx"
    providers = ['CPUExecutionProvider']

    # BUG FIX: sess_options was previously constructed but never passed to the
    # session, so any options configured on it were silently ignored.
    sess_options = ort.SessionOptions()
    app.state.session = ort.InferenceSession(
        onnx_model_path, sess_options=sess_options, providers=providers
    )

    print("✅ ONNX model loaded into memory.")
    yield
    print("🧹 Shutting down...")
|
|
|
|
|
app = FastAPI(lifespan=lifespan) |
|
|
|
|
|
# Punctuation marks handled by the service, mapped to their label names.
# '।' is the Bangla dari (sentence-ending full stop).
punc_dict = {
    '!': 'EXCLAMATION',
    '?': 'QUESTION',
    ',': 'COMMA',
    ';': 'SEMICOLON',
    ':': 'COLON',
    '-': 'HYPHEN',
    '।': 'DARI',
}

# Set of the characters above, for O(1) membership tests during cleaning.
allowed_punctuations = set(punc_dict)
|
|
|
|
|
def clean_and_normalize_text(text, remove_punctuations=False, punctuations=None):
    """Clean and normalize Bangla text with correct spacing.

    Args:
        text: Raw input string.
        remove_punctuations: When True, strip every allowed punctuation mark
            and collapse whitespace. When False, keep allowed punctuation
            (attached directly to the preceding word) and drop every character
            outside the Bangla Unicode block, normalizing spacing.
        punctuations: Optional iterable of punctuation characters to use
            instead of the module-level ``allowed_punctuations`` set
            (backward-compatible generalization; default behavior unchanged).

    Returns:
        The cleaned, whitespace-normalized string.
    """
    punct_set = allowed_punctuations if punctuations is None else set(punctuations)
    # Build the character class once instead of re-deriving it per branch.
    punct_class = f"[{re.escape(''.join(punct_set))}]"

    if remove_punctuations:
        cleaned_text = re.sub(punct_class, "", text)
        return re.sub(r'\s+', ' ', cleaned_text).strip()

    # Split on punctuation while keeping the marks themselves (capturing group).
    chunks = re.split(f"({punct_class})", text)
    filtered_chunks = []
    for chunk in chunks:
        if chunk in punct_set:
            # Punctuation attaches directly to the preceding word (no space).
            filtered_chunks.append(chunk)
        else:
            # Keep only Bangla letters/digits and whitespace.
            # (\u09E6-\u09EF lies inside \u0980-\u09FF; retained for clarity.)
            clean_chunk = re.sub(r"[^\u0980-\u09FF\u09E6-\u09EF\s]", "", chunk)
            clean_chunk = re.sub(r'\s+', ' ', clean_chunk).strip()
            if clean_chunk:
                # Leading space separates this word run from what came before.
                filtered_chunks.append(' ' + clean_chunk)

    result = ''.join(filtered_chunks)
    return re.sub(r'\s+', ' ', result).strip()
|
|
|
|
|
class TextInput(BaseModel):
    """Request body for /punctuate: the raw text to restore punctuation in."""

    # Raw (possibly unpunctuated) Bangla text.
    text: str
|
|
|
|
|
@app.post("/punctuate")
async def punctuate_text(data: TextInput):
    """Restore punctuation in Bangla text using the preloaded ONNX model."""
    # First pass: drop non-Bangla characters while keeping allowed punctuation.
    normalized = clean_and_normalize_text(data.text)
    # Second pass: strip remaining punctuation so the model predicts it fresh.
    normalized = clean_and_normalize_text(normalized, remove_punctuations=True)
    restored_text = get_transcription(
        normalized,
        app.state.session,
        app.state.tokenizer,
        app.state.device,
        app.state.token_style,
    )
    return {"restored_text": restored_text}
|
|
|
|
|
if __name__ == "__main__":
    import uvicorn

    # Single worker on purpose: the ONNX session lives in this process's memory.
    uvicorn.run("api_onnx:app", host="0.0.0.0", port=5685, workers=1)
|
|
|