import math
import os
import argparse
import sqlite3
import shutil
import uuid

from datasets import Dataset, concatenate_datasets
import gradio as gr
import torch

from storing.createdb import create_db
from preprocessing.youtubevideopreprocessor import YoutubeVideoPreprocessor
from loading.serialization import JsonSerializer
from utils import nest_list, is_google_colab
from datapipeline import create_hardcoded_data_pipeline
from threadeddatapipeline import ThreadedDataPipeline
from dataset.hf_dataset import HFDataset
from huggingface_hub import DatasetCard
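
# NUM_THREADS controls how many data pipelines are built and how the list of video
# paths is split into per-thread batches below; with 1, the pipeline runs serially.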
NUM_THREADS = 1

# Detect if code is running in Colab
is_colab = is_google_colab()

colab_instruction = "" if is_colab else """
<p>You can skip the queue using Colab:
<a href="https://colab.research.google.com/drive/1zNRnX1lXjlGtBMW8U8S9t4eY1cA0D6lm?usp=sharing">
<img data-canonical-src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg"></a></p>"""

device_print = "GPU 🔥" if torch.cuda.is_available() else "CPU 🥶"


def numvideos_type(x):
    x = int(x)
    if x > 12:
        raise argparse.ArgumentTypeError("Maximum number of videos is 12")
    if x < 1:
        raise argparse.ArgumentTypeError("Minimum number of videos is 1")
    return x


def parse_args():
    parser = argparse.ArgumentParser(usage="[arguments] --channel_name --num_videos",
                                     description="Program to transcribe YouTube videos.")
    parser.add_argument("--channel_name",
                        type=str,
                        required=True,
                        help="Name of the channel whose videos will be transcribed")
    parser.add_argument("--num_videos",
                        type=numvideos_type,
                        required=True,
                        help="Number of videos (min. 1 - max. 12) to transcribe from --channel_name")
    parser.add_argument("--hf_token",
                        type=str,
                        required=True,
                        help="Token of your HF account. You need a HF account to upload the dataset")
    parser.add_argument("--hf_dataset_identifier",
                        type=str,
                        required=True,
                        help="The ID of the repository to push to in the following format: <user>/<dataset_name> or <org>/<dataset_name>. Also accepts <dataset_name>, which will default to the namespace of the logged-in user.")
    parser.add_argument("--whisper_model",
                        type=str,
                        required=True,
                        help="Select one of the available Whisper models",
                        choices=["tiny", "base", "small", "medium", "large"])
    args = parser.parse_args()
    return args
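

# Hypothetical CLI invocation for the parser above (the script name "app.py" is an assumption):
#   python app.py --channel_name <channel> --num_videos 4 --hf_token <token> \
#       --hf_dataset_identifier <user>/<dataset_name> --whisper_model base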


def transcribe(mode: str,
               channel_name: str,
               num_videos: int,
               hf_token: str,
               hf_dataset_identifier: str,
               whisper_model: str) -> str:
    # Create a unique name for the database
    unique_filename = str(uuid.uuid4())
    database_name = unique_filename + ".db"
    create_db(database_name)
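    # A per-request UUID keeps concurrent jobs from clobbering each other's SQLite
    # files; the database is deleted again at the end of this function.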

    # Create necessary resources
    yt_video_processor = YoutubeVideoPreprocessor(mode=mode,
                                                  serializer=JsonSerializer())  # TODO: Let user select serializer
    hf_dataset = HFDataset(hf_dataset_identifier)
    videos_downloaded = hf_dataset.list_of_ids
    paths, dataset_folder = yt_video_processor.preprocess(channel_name,
                                                          num_videos,
                                                          videos_downloaded)

    nested_listed_length = math.ceil(len(paths) / NUM_THREADS)
    nested_paths = nest_list(paths, nested_listed_length)
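    # nest_list splits `paths` into chunks of ceil(len(paths) / NUM_THREADS) items,
    # i.e. at most one batch of video paths per worker thread.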

    data_pipelines = [create_hardcoded_data_pipeline(database_name, whisper_model) for _ in range(NUM_THREADS)]

    # Run pipelines in multiple threads
    threads = []
    for data_pipeline, thread_paths in zip(data_pipelines, nested_paths):
        threads.append(ThreadedDataPipeline(data_pipeline, thread_paths))
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()

    # Fetch the newly transcribed rows to count them
    connection = sqlite3.connect(database_name)
    cursor = connection.cursor()
    cursor.execute("SELECT CHANNEL_NAME, URL, TITLE, DESCRIPTION, TRANSCRIPTION, SEGMENTS FROM VIDEO")
    videos = cursor.fetchall()
    num_new_videos = len(videos)

    dataset = Dataset.from_sql("SELECT CHANNEL_NAME, URL, TITLE, DESCRIPTION, TRANSCRIPTION, SEGMENTS FROM VIDEO", connection)
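    # Dataset.from_sql executes the query against the open SQLite connection and
    # materializes the result as a Hugging Face Dataset.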

    # Append to the existing train split if the dataset already has data; otherwise push the new rows as-is
    if hf_dataset.exist and not hf_dataset.is_empty:
        dataset_to_upload = concatenate_datasets([hf_dataset.dataset["train"], dataset])
    else:
        dataset_to_upload = dataset
    dataset_to_upload.push_to_hub(hf_dataset_identifier, token=hf_token)
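
    # Refresh the dataset card so the repo is tagged and discoverable on the Hub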
    card = DatasetCard.load(hf_dataset_identifier)
    card.data.tags = ["whisper", "whispering", whisper_model]
    card.data.task_categories = ["automatic-speech-recognition"]
    card.push_to_hub(hf_dataset_identifier, token=hf_token)

    # Close connection
    connection.close()

    # Remove the temporary database and the downloaded files
    os.remove(database_name)
    try:
        shutil.rmtree(dataset_folder)
    except OSError as e:
        print(f"Error: {dataset_folder} : {e.strerror}")

    return f"Dataset created or updated at {hf_dataset_identifier}. {num_new_videos} samples were added"


with gr.Blocks() as demo:
    md = """# Use Whisper to create a HF dataset from YouTube videos

This space will let you create a HF dataset by transcribing videos from YouTube.
Enter the name of the YouTube channel or the URL of a YouTube playlist (in the form https://www.youtube.com/playlist?list=****),
and the repo_id of the dataset (you need a Hugging Face account).
If the dataset already exists, it will only transcribe videos that are not yet in the dataset.
If it does not exist, it will create the dataset. To use this demo, you need a
[Hugging Face token](https://huggingface.co/settings/tokens) with write role. Learn more about [tokens](https://huggingface.co/docs/hub/security-tokens).
"""
    gr.Markdown(md)
    gr.HTML(
        f"""
        <p style="margin-bottom: 10px; font-size: 94%">
        Running on <b>{device_print}</b>{(" in a <b>Google Colab</b>." if is_colab else "")}
        </p>
        """
    )

    with gr.Row():
        with gr.Column():
            whisper_model = gr.Radio(["tiny", "base", "small", "medium", "large"],
                                     label="Whisper model", value="base")
            mode = gr.Radio(["channel_name", "playlist"],
                            label="Get the videos from:", value="channel_name")
            channel_name = gr.Textbox(label="YouTube Channel or Playlist URL",
                                      placeholder="Enter the name of the YouTube channel or the URL of the playlist")
            num_videos = gr.Slider(1, 20000, value=4, step=1, label="Number of videos")
            hf_token = gr.Textbox(placeholder="Your HF write access token", type="password")
            hf_dataset_identifier = gr.Textbox(label="Dataset Name",
                                               placeholder="Enter in the format <username>/<repo_name>")
            submit_btn = gr.Button("Submit")
        with gr.Column():
            output = gr.Text()
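
    # The inputs list must match the positional order of transcribe()'s parameters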
    submit_btn.click(fn=transcribe,
                     inputs=[mode,
                             channel_name,
                             num_videos,
                             hf_token,
                             hf_dataset_identifier,
                             whisper_model],
                     outputs=[output])

    gr.Markdown('''
![visitors](https://visitor-badge.glitch.me/badge?page_id=Whispering-GPT)
''')

if not is_colab:
    demo.queue(concurrency_count=1)

demo.launch(debug=True, share=is_colab)