#!/usr/bin/env python

import ast
import copy
import tempfile

import gradio as gr
import pandas as pd
from huggingface_hub import CommitOperationAdd, HfApi

from papers import PaperList

REPO_ID = "CVPR2024/CVPR2024-papers"
FILENAME = "data.csv"

api = HfApi()
paper_list = PaperList()

path = api.hf_hub_download(repo_id=REPO_ID, filename=FILENAME, repo_type="dataset")
# Read every cell as text and keep empty cells as "" so that an edited copy can be
# written back without altering any other row: with type inference, pandas would
# parse arXiv IDs such as "2404.10000" as floats and they would round-trip as "2404.1".
actual_df = pd.read_csv(path, dtype=str, keep_default_na=False)
# Paper ID (as text) -> row *position* in actual_df (load_data reads it with .iloc).
paper_id_to_index = {pid: i for i, pid in enumerate(actual_df["id"].astype(str))}

with gr.Blocks() as demo_search:
    with gr.Group():
        search_title = gr.Textbox(label="Search title")
        search_author = gr.Textbox(label="Search author")

    df = gr.Dataframe(
        value=paper_list.df_prettified,
        datatype=paper_list.get_column_datatypes(paper_list.get_column_names()),
        type="pandas",
        row_count=(0, "dynamic"),
        interactive=False,
        height=1000,
        elem_id="table",
        wrap=True,
    )

    inputs = [
        search_title,
        search_author,
    ]
    gr.on(
        triggers=[
            search_title.submit,
            search_author.submit,
        ],
        fn=paper_list.search,
        inputs=inputs,
        outputs=df,
        queue=False,
        api_name=False,
    )
    demo_search.load(
        fn=paper_list.search,
        inputs=inputs,
        outputs=df,
        queue=False,
        api_name=False,
    )


def _cell_text(value: object) -> str:
    """Return a CSV cell as text, mapping a missing value (NaN) to ""."""
    return "" if pd.isna(value) else str(value)


def _parse_list_cell(value: object) -> list[str]:
    """Decode a list-valued CSV cell such as "['a', 'b']" into a list of strings.

    Missing/empty cells and "[]" give an empty list. A cell that is not a valid
    list literal is returned as a single entry so its content is not dropped.
    """
    if not isinstance(value, str):
        return []  # NaN for an empty cell when the CSV is read with default NA handling
    text = value.strip()
    if not text:
        return []
    try:
        # literal_eval only evaluates literals, so it is safe on dataset content.
        parsed = ast.literal_eval(text)
    except (ValueError, SyntaxError):
        return [text]
    if isinstance(parsed, (list, tuple)):
        return [str(item).strip() for item in parsed if str(item).strip()]
    return [text]


def load_data(paper_id: str) -> tuple[str, str, str, str, str, str, str, str]:
    """Return the editable fields of one paper for the edit form.

    Args:
        paper_id: The paper ID typed by the user; surrounding whitespace is ignored.

    Returns:
        (id, title, authors, arxiv_id, GitHub links, Space IDs, Model IDs, Dataset IDs),
        in the order of the edit form's output components. The four list fields are
        raw values joined with newlines, the format create_pr splits back apart.

    Raises:
        gr.Error: If no paper with this ID exists.
    """
    paper_id = paper_id.strip()
    try:
        index = paper_id_to_index[paper_id]
    except KeyError:
        raise gr.Error(f"Paper ID {paper_id} not found.") from None

    paper = actual_df.iloc[index]
    return (
        _cell_text(paper["id"]),
        _cell_text(paper["title"]),
        _cell_text(paper["authors"]),
        _cell_text(paper["arxiv_id"]),
        "\n".join(_parse_list_cell(paper["GitHub"])),
        "\n".join(_parse_list_cell(paper["Space"])),
        "\n".join(_parse_list_cell(paper["Model"])),
        "\n".join(_parse_list_cell(paper["Dataset"])),
    )
def split_and_strip(s: str) -> list[str]:
    """Return the non-blank lines of *s*, each with surrounding whitespace removed."""
    return [line.strip() for line in (s or "").splitlines() if line.strip()]


def create_pr(
    paper_id: str,
    title: str,
    authors: str,
    arxiv_id: str,
    github_links: str,
    space_ids: str,
    model_ids: str,
    dataset_ids: str,
    oauth_token: gr.OAuthToken | None,
) -> str:
    """Open a PR on the dataset repo that updates one paper's row in the CSV.

    The parameters after ``paper_id`` are the edit-form values. The multi-line
    fields hold one URL/repo ID per line. ``oauth_token`` is injected by Gradio
    from the login button and is not an input component.

    Returns:
        The URL of the created PR, or a message asking the user to log in.

    Raises:
        gr.Error: If no paper with ``paper_id`` exists.
    """
    if oauth_token is None:
        return "Please log in first."

    paper_id = paper_id.strip()
    try:
        index = paper_id_to_index[paper_id]
    except KeyError:
        raise gr.Error(f"Paper ID {paper_id} not found.") from None

    # Edit a copy of the dataset (not the UI component). Object dtype lets any
    # column accept the new string values without a dtype-mismatch error.
    data = actual_df.astype(object)
    row = data.index[index]  # paper_id_to_index holds a row position
    updates = {
        "title": title.strip(),
        "authors": authors,
        "arxiv_id": arxiv_id.strip(),
        # Stored as a Python list literal ("[]" when empty), which is what
        # load_data decodes. NOTE(review): confirm this matches the existing
        # quoting style of these columns in data.csv.
        "GitHub": str(split_and_strip(github_links)),
        "Space": str(split_and_strip(space_ids)),
        "Model": str(split_and_strip(model_ids)),
        "Dataset": str(split_and_strip(dataset_ids)),
    }
    for column, value in updates.items():
        data.at[row, column] = value

    # Upload the CSV from memory: no temporary file to leak, and index=False
    # keeps pandas from adding an extra unnamed index column.
    commit = CommitOperationAdd(
        path_in_repo=FILENAME,
        path_or_fileobj=data.to_csv(index=False).encode("utf-8"),
    )
    res = api.create_commit(
        repo_id=REPO_ID,
        operations=[commit],
        commit_message=f"Update {paper_id}",
        repo_type="dataset",
        create_pr=True,
        token=oauth_token.token,
    )
    return res.pr_url


with gr.Blocks() as demo_edit:
    with gr.Group():
        paper_id = gr.Textbox(label="ID", max_lines=1)
        load_button = gr.Button("Load")
    with gr.Group():
        title = gr.Textbox(label="Title", max_lines=1)
        authors = gr.Textbox(label="Authors", lines=5)
        arxiv_id = gr.Textbox(label="arXiv ID", max_lines=1, placeholder="2404.00000")
        github_links = gr.Textbox(
            label="GitHub links",
            lines=5,
            placeholder="https://github.com/aaa/bbb\nhttps://github.com/ccc/ddd",
        )
        space_ids = gr.Textbox(label="Space IDs", lines=5, placeholder="org_name1/repo_name1\norg_name2/repo_name2")
        model_ids = gr.Textbox(label="Model IDs", lines=5, placeholder="org_name1/repo_name1\norg_name2/repo_name2")
        dataset_ids = gr.Textbox(
            label="Dataset IDs", lines=5, placeholder="org_name1/repo_name1\norg_name2/repo_name2"
        )
    create_pr_button = gr.Button("Create PR")
    result = gr.Textbox(label="Result", max_lines=1)

    # One list shared by both handlers. Its order must match load_data's return
    # tuple and create_pr's positional parameters.
    edit_fields = [
        paper_id,
        title,
        authors,
        arxiv_id,
        github_links,
        space_ids,
        model_ids,
        dataset_ids,
    ]

    gr.on(
        triggers=[
            paper_id.submit,
            load_button.click,
        ],
        fn=load_data,
        inputs=paper_id,
        outputs=edit_fields,
        queue=False,
        api_name=False,
    )
    create_pr_button.click(
        fn=create_pr,
        inputs=edit_fields,
        outputs=result,
        queue=False,
        api_name=False,
    )

with gr.Blocks(css="style.css") as demo:
    gr.Markdown(
        "You can create PRs to update the CSV files in the [CVPR2024-papers repo](https://huggingface.co/datasets/CVPR2024/CVPR2024-papers) with this Space."
    )
    with gr.Tabs():
        with gr.Tab(label="Step 1: Login"):
            gr.Markdown("To create a PR, you first need to log in. Please press the login button below.")
            gr.LoginButton()
        with gr.Tab(label="Step 2: Search for paper ID"):
            gr.Markdown("Search for the paper you would like to update and find its paper ID.")
            demo_search.render()
        with gr.Tab(label="Step 3: Edit and create PR"):
            gr.Markdown("Enter the paper ID in the field below and press the Load button.")
            gr.Markdown("After making the necessary changes, press the Create PR button.")
            demo_edit.render()

if __name__ == "__main__":
    demo.queue(api_open=False).launch(show_api=False, debug=True)