Spaces:
Runtime error
Runtime error
| import nltk | |
| import whisper | |
| from pytube import YouTube | |
| import streamlit as st | |
| from sentence_transformers import SentenceTransformer, util | |
# Fetch the NLTK 'punkt' resource when the script starts.
# NOTE(review): nothing else in the visible code calls nltk directly --
# presumably a dependency needs the tokenizer data; confirm before removing.
nltk.download('punkt')
def init_sentence_model(embedding_model):
    """Build the sentence-embedding model identified by ``embedding_model``.

    Args:
        embedding_model: Model name forwarded unchanged to
            ``SentenceTransformer``.

    Returns:
        The constructed ``SentenceTransformer`` instance.
    """
    sentence_model = SentenceTransformer(embedding_model)
    return sentence_model
def init_whisper(whisper_size):
    """Load the Whisper speech-recognition model of the requested size.

    Args:
        whisper_size: Model size name forwarded unchanged to
            ``whisper.load_model``.

    Returns:
        The model object returned by ``whisper.load_model``.
    """
    speech_model = whisper.load_model(whisper_size)
    return speech_model
def inference(link):
    """Download a YouTube video's audio and transcribe it with Whisper.

    Reads the module-level ``whisper_model``, which must be assigned before
    this function is called.

    Args:
        link: Video URL handed to ``pytube.YouTube``.

    Returns:
        The ``'segments'`` entry of the result returned by
        ``whisper_model.transcribe``.
    """
    yt = YouTube(link)
    # Use the first audio-only stream and save it as "audio.mp4"; the value
    # returned by download() is what gets passed to the transcriber.
    path = yt.streams.filter(only_audio=True)[0].download(filename="audio.mp4")
    # The previous version built a whisper.DecodingOptions object here that
    # was never passed to transcribe(); it had no effect and was removed.
    results = whisper_model.transcribe(path)
    return results['segments']
def get_embeddings(segments):
    """Encode the windowed transcript texts with the module-level ``model``.

    Args:
        segments: Mapping whose ``"text"`` entry holds the texts to encode.

    Returns:
        Whatever ``model.encode`` returns for those texts.
    """
    texts = segments["text"]
    return model.encode(texts)
def format_segments(segments, window=10):
    """Merge consecutive transcript segments into fixed-size windows.

    Args:
        segments: Sequence of mappings, each providing a ``'text'`` string
            and a ``'start'`` value.
        window: Number of consecutive segments merged into one entry. The
            final entry may contain fewer segments.

    Returns:
        A dict with two parallel lists: ``'text'`` holds the space-joined
        text of each window and ``'start'`` holds the ``'start'`` value of
        the first segment in each window.
    """
    # range() requires an int; UI number widgets may hand back a float.
    window = int(window)
    window_starts = range(0, len(segments), window)
    return {
        # Join the *whole* window. The slice used to be hard-coded to 5
        # segments, which silently dropped text whenever window > 5.
        'text': [
            " ".join(seg['text'] for seg in segments[i:i + window])
            for i in window_starts
        ],
        'start': [segments[i]['start'] for i in window_starts],
    }
# Page header: title plus a pointer to a Colab notebook for GPU execution.
_INTRO_MARKDOWN = """
# Youtube video transcription and search
You can run it on colab GPU for faster performance: [Link](https://colab.research.google.com/drive/1-6Lmvlfwxd5JAnKOBKtdR1YiooIm-rJf?usp=sharing)
"""
st.markdown(_INTRO_MARKDOWN)
# Input form: the widget values below are read by the search flow that
# follows, so the variable names are part of the script's contract.
with st.form(key="transcribe"):
    yt_link = st.text_input(label="Youtube link")
    whisper_size = st.selectbox(
        label="Whisper model size",
        options=("small", "base", "large"),
    )
    embedding_model = st.text_input(
        label="Embedding model name",
        value='all-mpnet-base-v2',
    )
    top_k = st.number_input(label="Number of query results", value=5)
    window = st.number_input(label="Number of segments per result", value=10)
    transcribe_submit = st.form_submit_button(label="Submit")
# Arm the search UI on the first form submission; the flag is kept in the
# session so the query box is still shown on later reruns.
if transcribe_submit and 'start_search' not in st.session_state:
    st.session_state.start_search = True

if 'start_search' in st.session_state:
    # Streamlit re-executes this script on every widget interaction, including
    # typing a query. Keep the embedding model, the windowed transcript and
    # its embeddings in the session and rebuild them only when the submitted
    # form values change, instead of re-downloading and re-transcribing the
    # video for every query.
    index_key = (yt_link, whisper_size, embedding_model, window)
    if ('search_index_key' not in st.session_state
            or st.session_state['search_index_key'] != index_key):
        # inference() and get_embeddings() read these two module-level names.
        model = init_sentence_model(embedding_model)
        whisper_model = init_whisper(whisper_size)
        segments = inference(yt_link)
        segments = format_segments(segments, window)
        embeddings = get_embeddings(segments)
        st.session_state['search_index'] = (model, segments, embeddings)
        # Record the key last so a failed build is retried on the next run.
        st.session_state['search_index_key'] = index_key
    model, segments, embeddings = st.session_state['search_index']

    query = st.text_input('Enter a query')
    if query:
        query_embedding = model.encode(query)
        results = util.semantic_search(query_embedding, embeddings, top_k=top_k)
        # Links without a query string (e.g. youtu.be short links) need "?"
        # rather than "&" in front of the time parameter.
        separator = '&' if '?' in yt_link else '?'
        hits = []
        for result in results[0]:
            idx = result['corpus_id']
            # Segment starts are fractional; the link uses whole seconds.
            seconds = int(segments['start'][idx])
            hits.append(
                segments['text'][idx]
                + f"... [Watch at timestamp]({yt_link}{separator}t={seconds}s)"
            )
        # The output is plain Markdown, so raw-HTML rendering is not enabled.
        st.markdown("\n\n".join(hits))