import re
from contextlib import asynccontextmanager

import onnxruntime as ort
import torch
from fastapi import FastAPI
from pydantic import BaseModel

from config import *  # provides MODELS, among other shared settings
from inference_onnx import get_transcription

# Load heavy resources once at startup and share them via app.state.
@asynccontextmanager
async def lifespan(app: FastAPI):
    print("🔧 Loading model...")

    model_name = "./distilbert-base-multilingual-cased"
    app.state.device = torch.device('cpu')
    app.state.tokenizer = MODELS[model_name][1].from_pretrained(model_name)
    app.state.token_style = MODELS[model_name][3]

    onnx_model_path = "./poc_onnx_model_punctuation_batch.onnx"
    # Swap in ['CUDAExecutionProvider', 'CPUExecutionProvider'] to prefer GPU
    # with a CPU fallback.
    providers = ['CPUExecutionProvider']
    sess_options = ort.SessionOptions()
    app.state.session = ort.InferenceSession(
        onnx_model_path, sess_options=sess_options, providers=providers
    )

    print("✅ ONNX model loaded into memory.")
    yield
    print("🧹 Shutting down...")

app = FastAPI(lifespan=lifespan)
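
# A plausible shape for the MODELS mapping used in lifespan(); the real
# definition lives in config.py. Hypothetical sketch only, assuming the common
# (model_class, tokenizer_class, hidden_size, token_style) convention:
#
#   from transformers import DistilBertModel, DistilBertTokenizer
#   MODELS = {
#       "./distilbert-base-multilingual-cased":
#           (DistilBertModel, DistilBertTokenizer, 768, 'bert'),
#   }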

# Punctuation marks the model can restore; '।' is the Bangla danda ("dari").
punc_dict = {
    '!': 'EXCLAMATION',
    '?': 'QUESTION',
    ',': 'COMMA',
    ';': 'SEMICOLON',
    ':': 'COLON',
    '-': 'HYPHEN',
    '।': 'DARI',
}
allowed_punctuations = set(punc_dict.keys())

def clean_and_normalize_text(text, remove_punctuations=False):
    """Clean and normalize Bangla text with correct spacing.

    With remove_punctuations=True, the allowed punctuation marks are stripped
    as well, leaving bare words for the model to re-punctuate.
    """
    if remove_punctuations:
        # Remove all allowed punctuations
        cleaned_text = re.sub(f"[{re.escape(''.join(allowed_punctuations))}]", "", text)
        # Normalize spaces
        cleaned_text = re.sub(r'\s+', ' ', cleaned_text).strip()
        return cleaned_text
    else:
        # Keep only allowed punctuations and Bangla letters/digits
        chunks = re.split(f"([{re.escape(''.join(allowed_punctuations))}])", text)
        filtered_chunks = []

        for chunk in chunks:
            if chunk in allowed_punctuations:
                filtered_chunks.append(chunk)
            else:
                # Clean text and preserve word boundaries
                # Keep only Bangla characters and whitespace; the Bangla block
                # U+0980-U+09FF already covers the digits U+09E6-U+09EF
                clean_chunk = re.sub(r"[^\u0980-\u09FF\s]", "", chunk)
                clean_chunk = re.sub(r'\s+', ' ', clean_chunk)  # Normalize internal spacing
                clean_chunk = clean_chunk.strip()
                if clean_chunk:
                    filtered_chunks.append(' ' + clean_chunk)  # Add space before word chunks

        # Join and clean up spacing
        result = ''.join(filtered_chunks)
        result = re.sub(r'\s+', ' ', result).strip()
        return result
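
# Illustrative behaviour of the two modes (hypothetical inputs, not from a test suite):
#   clean_and_normalize_text("কেমন   আছো?  ভালো আছি।")
#       -> "কেমন আছো? ভালো আছি।"   (spacing normalized, punctuation kept)
#   clean_and_normalize_text("কেমন আছো? ভালো আছি।", remove_punctuations=True)
#       -> "কেমন আছো ভালো আছি"     (allowed punctuation stripped as well)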

class TextInput(BaseModel):
    text: str

@app.post("/punctuate")
async def punctuate_text(data: TextInput):
    # First pass keeps Bangla text and allowed punctuation with clean spacing;
    # the second strips that punctuation so the model re-predicts it from scratch.
    input_normalized = clean_and_normalize_text(data.text)
    input_normalized = clean_and_normalize_text(input_normalized, remove_punctuations=True)
    restored_text = get_transcription(input_normalized, app.state.session, app.state.tokenizer, app.state.device, app.state.token_style)
    return {"restored_text": restored_text}

if __name__ == "__main__":
    import uvicorn
    uvicorn.run("api_onnx:app", host="0.0.0.0", port=5685, workers=1)
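
# Alternatively, launch with the uvicorn CLI (equivalent to the call above):
#   uvicorn api_onnx:app --host 0.0.0.0 --port 5685 --workers 1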