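"""Gradio demo for onnx-asr: compare ASR models on short clips and transcribe long audio with VAD."""
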
import logging
from importlib.metadata import version
from timeit import default_timer as timer

import gradio as gr
import numpy as np

import onnx_asr

logging.basicConfig(format="%(asctime)s %(levelname)s %(message)s", level=logging.WARNING)
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logger.info("onnx_asr version: %s", version("onnx_asr"))

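# Silero VAD splits long recordings into speech segments for the long-form tab.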
vad = onnx_asr.load_vad("silero")

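# Multilingual models; the larger ones are loaded with int8 quantization to cut download size and memory.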
models_multilang = {name: onnx_asr.load_model(name) for name in ["whisper-base"]} | {
    name: onnx_asr.load_model(name, quantization="int8") for name in ["nemo-parakeet-tdt-0.6b-v3", "nemo-canary-1b-v2"]
}

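# Russian-only models.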
models_ru = {
    name: onnx_asr.load_model(name)
    for name in [
        "gigaam-v3-ctc",
        "gigaam-v3-rnnt",
        "gigaam-v3-e2e-ctc",
        "gigaam-v3-e2e-rnnt",
        "nemo-fastconformer-ru-ctc",
        "nemo-fastconformer-ru-rnnt",
        "alphacep/vosk-model-ru",
        "alphacep/vosk-model-small-ru",
        "t-tech/t-one",
    ]
}

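# English-only models (int8-quantized).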
models_en = {
    name: onnx_asr.load_model(name, quantization="int8")
    for name in [
        "nemo-parakeet-tdt-0.6b-v2",
    ]
}

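# All of the models are selectable in the long-form (VAD) tab.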
models_vad = models_multilang | models_ru | models_en


def recognize(audio: tuple[int, np.ndarray], models, language: str):
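    """Run each given model on the clip and return [model, transcript] rows for the Dataframe."""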
    if audio is None:
        return None

    sample_rate, waveform = audio
    length = waveform.shape[0] / sample_rate
    if not 1 <= length <= 30:
        raise gr.Error(f"Audio must be between 1 and 30 seconds long (got {length:.1f} s).")
    logger.debug("recognize: length %.3f, sample_rate %s, waveform.shape %s.", length, sample_rate, waveform.shape)
    try:
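        # Convert integer PCM samples to float32 in [-1, 1] and downmix stereo to mono.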
        waveform = waveform.astype(np.float32) / 2 ** (8 * waveform.itemsize - 1)
        if waveform.ndim == 2:
            waveform = waveform.mean(axis=1)

        results = []
        for name, model in models.items():
            if length > 20 and name == "alphacep/vosk-model-small-ru":
                gr.Warning(f"Model {name} only supports audio no longer than 20 s.")
                continue

            start = timer()
            result = model.recognize(waveform, sample_rate=sample_rate, language=language)
            time = timer() - start
            logger.debug("recognized by %s: result '%s', time %.3f s.", name, result, time)
            results.append([name, result])

    except Exception as e:
        raise gr.Error(f"{e} Audio: sample_rate: {sample_rate}, waveform.shape: {waveform.shape}.") from e
    else:
        return results


def recognize_ru(audio: tuple[int, np.ndarray]):
    return recognize(audio, models_ru | models_multilang, "ru")


def recognize_en(audio: tuple[int, np.ndarray]):
    return recognize(audio, models_en | models_multilang, "en")


def recognize_with_vad(audio: tuple[int, np.ndarray], name: str):
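    """Recognize a long recording with VAD, yielding the transcript as segments arrive."""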
    if audio is None:
        return None

    sample_rate, waveform = audio
    length = waveform.shape[0] / sample_rate
    if not 1 <= length <= 600:
        raise gr.Error(f"Audio must be between 1 second and 10 minutes long (got {length:.1f} s).")
    logger.debug("recognize_with_vad: length %.3f, sample_rate %s, waveform.shape %s.", length, sample_rate, waveform.shape)
    try:
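        # Convert integer PCM samples to float32 in [-1, 1] and downmix stereo to mono.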
        waveform = waveform.astype(np.float32) / 2 ** (8 * waveform.itemsize - 1)
        if waveform.ndim == 2:
            waveform = waveform.mean(axis=1)

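        # Wrap the model with VAD so speech segments are recognized (and streamed) one by one.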
        model = models_vad[name].with_vad(vad, batch_size=1)
        results = ""
        for res in model.recognize(waveform, sample_rate=sample_rate):
            logger.debug("recognized by %s: result '%s'.", name, res)
            results += f"[{res.start:5.1f}, {res.end:5.1f}]: {res.text}\n"
            yield results

    except Exception as e:
        raise gr.Error(f"{e} Audio: sample_rate: {sample_rate}, waveform.shape: {waveform.shape}.") from e


with gr.Blocks() as recognize_short:
    audio = gr.Audio()
    with gr.Row():
        btn_ru = gr.Button("Recognize (ru)", variant="primary")
        btn_en = gr.Button("Recognize (en)", variant="primary")
    output = gr.Dataframe(headers=["model", "result"], wrap=True)
    btn_ru.click(fn=recognize_ru, inputs=audio, outputs=output)
    btn_en.click(fn=recognize_en, inputs=audio, outputs=output)


with gr.Blocks() as recognize_long:
    gr.Markdown("The default VAD parameters are used. For best results, you should adjust the VAD parameters in your app.")
    name = gr.Dropdown(sorted(models_vad.keys()), value="nemo-parakeet-tdt-0.6b-v3", label="Model")
    audio = gr.Audio()
    with gr.Row():
        btn = gr.Button("Recognize", variant="primary")
    output = gr.TextArea(label="result")
    btn.click(fn=recognize_with_vad, inputs=[audio, name], outputs=output)

    def on_model_change(name: str):
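        """Show a language hint on the audio input when a monolingual model is selected."""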
        if name in models_ru:
            label = f"Model {name} support only Russian language"
        elif name in models_en:
            label = f"Model {name} support only English language"
        else:
            label = None
        return gr.Audio(label=label)

    name.change(on_model_change, inputs=name, outputs=audio)

with gr.Blocks(title="onnx-asr demo") as demo:
    gr.Markdown("""
    # ASR demo using onnx-asr
    **[onnx-asr](https://github.com/istupakov/onnx-asr)** is a Python package for Automatic Speech Recognition using ONNX models.
    It's written in pure Python with minimal dependencies (no PyTorch, Transformers, or FFmpeg required).

    Supports **Parakeet v2 (En) / v3 (Multilingual)**, **Canary v2 (Multilingual)** and **GigaAM v2/v3 (Ru)** models
    (and many other modern [models](https://github.com/istupakov/onnx-asr?tab=readme-ov-file#supported-model-names)).   
    You can also use it with your own model if it has a supported architecture.
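
    A minimal usage sketch (the model name is just one of those used in this demo; see the project README for details,
    `recognize` also accepts raw numpy waveforms):
    ```python
    import onnx_asr

    model = onnx_asr.load_model("nemo-parakeet-tdt-0.6b-v3")
    print(model.recognize("test.wav"))
    ```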
    """)
    gr.TabbedInterface(
        [recognize_short, recognize_long],
        [
            "Recognition of a short phrase (up to 30 sec.)",
            "Recognition of a long phrase with VAD (up to 10 min.)",
        ],
    )
    with gr.Accordion("Models used in this demo:", open=False):
        gr.Markdown("""
        ## Russian ASR models
        * `gigaam-v3-ctc` - Sber GigaAM v3 CTC ([origin](https://huggingface.co/ai-sage/GigaAM-v3), [onnx](https://huggingface.co/istupakov/gigaam-v3-onnx))
        * `gigaam-v3-rnnt` - Sber GigaAM v3 RNN-T ([origin](https://huggingface.co/ai-sage/GigaAM-v3), [onnx](https://huggingface.co/istupakov/gigaam-v3-onnx))
        * `gigaam-v3-e2e-ctc` - Sber GigaAM v3 E2E CTC ([origin](https://huggingface.co/ai-sage/GigaAM-v3), [onnx](https://huggingface.co/istupakov/gigaam-v3-onnx))
        * `gigaam-v3-e2e-rnnt` - Sber GigaAM v3 E2E RNN-T ([origin](https://huggingface.co/ai-sage/GigaAM-v3), [onnx](https://huggingface.co/istupakov/gigaam-v3-onnx))
        * `nemo-fastconformer-ru-ctc` - Nvidia FastConformer-Hybrid Large (ru) with CTC decoder ([origin](https://huggingface.co/nvidia/stt_ru_fastconformer_hybrid_large_pc), [onnx](https://huggingface.co/istupakov/stt_ru_fastconformer_hybrid_large_pc_onnx))
        * `nemo-fastconformer-ru-rnnt` - Nvidia FastConformer-Hybrid Large (ru) with RNN-T decoder ([origin](https://huggingface.co/nvidia/stt_ru_fastconformer_hybrid_large_pc), [onnx](https://huggingface.co/istupakov/stt_ru_fastconformer_hybrid_large_pc_onnx))
        * `nemo-parakeet-tdt-0.6b-v3` - Nvidia Parakeet TDT 0.6B v3 (multilingual) ([origin](https://huggingface.co/nvidia/parakeet-tdt-0.6b-v3), [onnx](https://huggingface.co/istupakov/parakeet-tdt-0.6b-v3-onnx))
        * `nemo-canary-1b-v2` - Nvidia Canary 1B v2 (multilingual) ([origin](https://huggingface.co/nvidia/canary-1b-v2), [onnx](https://huggingface.co/istupakov/canary-1b-v2-onnx))
        * `whisper-base` - OpenAI Whisper Base exported with onnxruntime ([origin](https://huggingface.co/openai/whisper-base), [onnx](https://huggingface.co/istupakov/whisper-base-onnx))
        * `alphacep/vosk-model-ru` - Alpha Cephei Vosk 0.54-ru ([origin](https://huggingface.co/alphacep/vosk-model-ru))
        * `alphacep/vosk-model-small-ru` - Alpha Cephei Vosk 0.52-small-ru ([origin](https://huggingface.co/alphacep/vosk-model-small-ru))
        * `t-tech/t-one` - T-Tech T-one ([origin](https://huggingface.co/t-tech/T-one))
        ## English ASR models
        * `nemo-parakeet-tdt-0.6b-v2` - Nvidia Parakeet TDT 0.6B v2 (en) ([origin](https://huggingface.co/nvidia/parakeet-tdt-0.6b-v2), [onnx](https://huggingface.co/istupakov/parakeet-tdt-0.6b-v2-onnx))
        * `nemo-parakeet-tdt-0.6b-v3` - Nvidia Parakeet TDT 0.6B v3 (multilingual) ([origin](https://huggingface.co/nvidia/parakeet-tdt-0.6b-v3), [onnx](https://huggingface.co/istupakov/parakeet-tdt-0.6b-v3-onnx))
        * `nemo-canary-1b-v2` - Nvidia Canary 1B v2 (multilingual) ([origin](https://huggingface.co/nvidia/canary-1b-v2), [onnx](https://huggingface.co/istupakov/canary-1b-v2-onnx))
        * `whisper-base` - OpenAI Whisper Base exported with onnxruntime ([origin](https://huggingface.co/openai/whisper-base), [onnx](https://huggingface.co/istupakov/whisper-base-onnx))
        ## VAD models
        * `silero` - Silero VAD ([origin](https://github.com/snakers4/silero-vad), [onnx](https://huggingface.co/onnx-community/silero-vad))
        """)

demo.launch()