# Source revision: commit f81cfe2 ("model added") by abdullahalmunem.
import os
import re
from fastapi import FastAPI, Request
from pydantic import BaseModel
from inference_onnx import get_transcription
import torch
import onnxruntime as ort
from config import *
from contextlib import asynccontextmanager
# Model artifacts are loaded once per process at startup and attached to
# ``app.state`` so request handlers can reuse them.
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the tokenizer and ONNX session on startup; log on shutdown.

    Populates ``app.state.device``, ``app.state.tokenizer``,
    ``app.state.token_style`` and ``app.state.session``, which the
    ``/punctuate`` handler passes to ``get_transcription``.
    """
    print("🔧 Loading model...")
    model_name = "./distilbert-base-multilingual-cased"
    # MODELS comes from `config` via star import. Index [1] is used as the
    # tokenizer class and [3] as the token style -- layout inferred from this
    # usage only; confirm against config.
    model_entry = MODELS[model_name]
    app.state.device = torch.device('cpu')
    app.state.tokenizer = model_entry[1].from_pretrained(model_name)
    app.state.token_style = model_entry[3]

    onnx_model_path = "./poc_onnx_model_punctuation_batch.onnx"
    # CPU-only inference. For GPU, use
    # ["CUDAExecutionProvider", "CPUExecutionProvider"] (CPU as fallback).
    providers = ['CPUExecutionProvider']
    sess_options = ort.SessionOptions()
    # Pass sess_options explicitly so any tuning applied to it takes effect.
    app.state.session = ort.InferenceSession(
        onnx_model_path, sess_options=sess_options, providers=providers
    )
    print("✅ ONNX model loaded into memory.")
    try:
        yield
    finally:
        # Runs even if serving ends with an exception thrown into the lifespan.
        print("🧹 Shutting down...")


app = FastAPI(lifespan=lifespan)
# Punctuation marks handled by the service, mapped to their label names.
punc_dict = {
    '!': 'EXCLAMATION',
    '?': 'QUESTION',
    ',': 'COMMA',
    ';': 'SEMICOLON',
    ':': 'COLON',
    '-': 'HYPHEN',
    '।': 'DARI',
}
allowed_punctuations = set(punc_dict.keys())

# Character class matching any allowed punctuation mark. re.escape makes '-'
# literal; sorting keeps the pattern text stable across runs (set order is not).
_PUNCT_CLASS = "[" + re.escape("".join(sorted(allowed_punctuations))) + "]"
_PUNCT_RE = re.compile(_PUNCT_CLASS)
# Capturing group so re.split keeps each punctuation mark as its own chunk.
_PUNCT_SPLIT_RE = re.compile(f"({_PUNCT_CLASS})")
# Anything outside the Bengali block U+0980-U+09FF (which already contains the
# Bengali digits U+09E6-U+09EF) that is also not whitespace.
_NON_BANGLA_RE = re.compile(r"[^\u0980-\u09FF\s]")
_WHITESPACE_RE = re.compile(r"\s+")


def clean_and_normalize_text(text, remove_punctuations=False):
    """Clean and normalize Bangla text with correct spacing.

    Args:
        text: Raw input string.
        remove_punctuations: If False (default), keep only the allowed
            punctuation marks and characters of the Bengali Unicode block,
            attaching each mark directly to the preceding word
            ("আমি , তুমি" -> "আমি, তুমি"). If True, replace the allowed
            punctuation marks with spaces and leave other characters as-is.

    Returns:
        The normalized string: runs of whitespace collapsed to a single space
        and no leading/trailing whitespace.
    """
    if remove_punctuations:
        # Substitute a space rather than "": deleting the mark outright would
        # glue neighbouring words together ("আমি,তুমি" -> "আমিতুমি").
        without_punct = _PUNCT_RE.sub(" ", text)
        return _WHITESPACE_RE.sub(" ", without_punct).strip()

    pieces = []
    for chunk in _PUNCT_SPLIT_RE.split(text):
        if chunk in allowed_punctuations:
            # Punctuation attaches directly to the preceding word.
            pieces.append(chunk)
            continue
        # Drop non-Bengali characters, then tidy the spacing inside the chunk.
        words = _WHITESPACE_RE.sub(" ", _NON_BANGLA_RE.sub("", chunk)).strip()
        if words:
            # Leading space separates this word run from preceding punctuation.
            pieces.append(" " + words)
    return _WHITESPACE_RE.sub(" ", "".join(pieces)).strip()
# Request body for POST /punctuate: a single raw text field.
class TextInput(BaseModel):
    text: str  # raw input text; the /punctuate handler normalizes it before inference
@app.post("/punctuate")
async def punctuate_text(data: TextInput):
    # Two-pass normalization: first keep only Bengali text plus the allowed
    # punctuation (with canonical spacing), then strip that punctuation so
    # the model receives bare, space-separated words.
    # NOTE(review): get_transcription is called synchronously inside an
    # `async def`, so inference runs on the event-loop thread and other
    # requests wait until it finishes.
    state = app.state
    bare_words = clean_and_normalize_text(
        clean_and_normalize_text(data.text),
        remove_punctuations=True,
    )
    restored = get_transcription(
        bare_words, state.session, state.tokenizer, state.device, state.token_style
    )
    return {"restored_text": restored}
if __name__ == "__main__":
    import uvicorn

    # Pass the app object instead of the "api_onnx:app" import string so the
    # script runs regardless of its file name and is not imported twice; an
    # import string is only required for reload or multiple workers.
    # Host/port can be overridden via environment variables; the defaults
    # match the previously hard-coded values.
    uvicorn.run(
        app,
        host=os.environ.get("PUNCT_API_HOST", "0.0.0.0"),
        port=int(os.environ.get("PUNCT_API_PORT", "5685")),
        workers=1,
    )