Spaces:
Sleeping
Sleeping
| from datasets import load_dataset | |
| import pandas as pd | |
| import duckdb | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns # Import Seaborn | |
| import plotly.express as px # Added for Plotly | |
| import plotly.graph_objects as go # Added for Plotly error figure | |
| import gradio as gr | |
| import os | |
| from huggingface_hub import login | |
| from datetime import datetime, timedelta | |
| import sys # Added for error logging | |
# Authenticate against the Hugging Face Hub with a token taken from the
# environment; startup is aborted when the variable is unset or empty.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN in (None, ""):
    raise ValueError("Please set the HF_TOKEN environment variable")
login(token=HF_TOKEN)

# Global Seaborn styling.
# NOTE(review): this configures matplotlib defaults; the charts in this file
# are built with Plotly, so it likely has no visible effect on them.
sns.set_theme(style="whitegrid")
sns.set_context("notebook")
# Load the trending-repos dataset once at startup and expose it to DuckDB as
# a table named 'models', which the SQL in get_retention_data queries.
try:
    dataset = load_dataset("reach-vb/trending-repos", split="models")
    df = dataset.to_pandas()
    # Register on DuckDB's default module-level connection, the same one
    # duckdb.query uses later, so 'FROM models' resolves to this DataFrame.
    duckdb.register('models', df)
except Exception as e:
    # Report on stderr like the other error paths in this file, then abort
    # startup: nothing else in the app can work without the data.
    print(f"Error loading dataset: {e}", file=sys.stderr)
    raise
def get_retention_data(start_date: str, end_date: str) -> pd.DataFrame:
    """Compute day-over-day retention of trending models.

    For every collection day, counts the distinct models present that day and
    how many of them were also present on the previous calendar day, and
    reports that carried-over share as a percentage. Retention is computed
    over the whole table before the date filter is applied, so the first day
    in the range is still compared against the day before it.

    Args:
        start_date: Inclusive lower bound, formatted ``YYYY-MM-DD``
            (surrounding whitespace is ignored).
        end_date: Inclusive upper bound, formatted ``YYYY-MM-DD``
            (surrounding whitespace is ignored).

    Returns:
        A DataFrame with columns ``collection_day``, ``total_models_today``,
        ``carried_over_models`` and ``percent_retained``, ordered by day.
        On any failure (including a malformed date), a one-row DataFrame with
        a single ``Error`` column holding the message; plot_retention_data
        recognises that shape and renders the message.
    """
    try:
        # Validate the free-text inputs up front so a typo yields a clear
        # message, and bind real dates instead of raw strings.
        start, end = (
            datetime.strptime(value.strip(), "%Y-%m-%d").date()
            for value in (start_date, end_date)
        )
        query = """
        WITH model_presence AS (
            -- DISTINCT: a model collected several times on one day must count
            -- once; otherwise the self-join below multiplies rows and the
            -- retention percentage can exceed 100%.
            SELECT DISTINCT
                id AS model_id,
                collected_at::DATE AS collection_day
            FROM models
        ),
        daily_model_counts AS (
            SELECT
                collection_day,
                COUNT(*) AS total_models_today
            FROM model_presence
            GROUP BY collection_day
        ),
        retained_models AS (
            -- Models present on a given day that were also present the day before.
            SELECT
                a.collection_day,
                COUNT(*) AS previously_existed_count
            FROM model_presence a
            JOIN model_presence b
                ON a.model_id = b.model_id
                AND a.collection_day = b.collection_day + INTERVAL '1 day'
            GROUP BY a.collection_day
        )
        SELECT
            d.collection_day,
            d.total_models_today,
            COALESCE(r.previously_existed_count, 0) AS carried_over_models,
            CASE
                WHEN d.total_models_today = 0 THEN NULL
                ELSE ROUND(COALESCE(r.previously_existed_count, 0) * 100.0 / d.total_models_today, 2)
            END AS percent_retained
        FROM daily_model_counts d
        LEFT JOIN retained_models r ON d.collection_day = r.collection_day
        WHERE d.collection_day BETWEEN ? AND ?
        ORDER BY d.collection_day
        """
        # Parameters are bound, never interpolated, so user input cannot
        # alter the SQL text.
        result = duckdb.query(query, params=[start, end]).to_df()
        print("SQL Query Result:")  # Log the result
        print(result)  # Log the result
        return result
    except Exception as e:
        # Log to stderr and hand an error-signal frame to the plotting step.
        print(f"Error in get_retention_data: {e}", file=sys.stderr)
        return pd.DataFrame({"Error": [str(e)]})
def _message_figure(text: str) -> "go.Figure":
    """Return a blank Plotly figure with *text* centred on it.

    Used for the error and "no data" placeholders so the Gradio Plot output
    always receives a figure.
    """
    fig = go.Figure()
    fig.add_annotation(
        text=text,
        xref="paper", yref="paper",
        x=0.5, y=0.5, showarrow=False,
        font=dict(size=16),
    )
    return fig


def plot_retention_data(dataframe: pd.DataFrame) -> "go.Figure":
    """Render the output of get_retention_data as a Plotly bar chart.

    Args:
        dataframe: Either the retention table (needs ``collection_day`` and
            ``percent_retained`` columns) or a non-empty frame with an
            ``Error`` column, whose first message is shown instead of a chart.
            The frame passed in is not modified.

    Returns:
        A bar chart of daily retention with a dashed line at the average, or
        a placeholder figure carrying an error / "no data" message. Failures
        while preparing or plotting the data are reported as message figures
        rather than raised.
    """
    print("DataFrame received by plot_retention_data (first 5 rows):")
    print(dataframe.head())
    print("\nData types in plot_retention_data before any conversion:")
    print(dataframe.dtypes)
    # get_retention_data signals failure with a frame holding an 'Error'
    # column; show that message instead of attempting a plot.
    if "Error" in dataframe.columns and not dataframe.empty:
        error_message = dataframe['Error'].iloc[0]
        print(f"Error DataFrame received: {error_message}", file=sys.stderr)
        return _message_figure(f"Error from data generation: {error_message}")
    try:
        if 'percent_retained' not in dataframe.columns:
            raise ValueError("'percent_retained' column is missing from the DataFrame.")
        if 'collection_day' not in dataframe.columns:
            raise ValueError("'collection_day' column is missing from the DataFrame.")
        # Work on a copy so the caller's DataFrame is left untouched.
        data = dataframe.copy()
        # Coerce both columns; unparseable values become NaN/NaT and are
        # dropped below instead of aborting the whole plot.
        data['percent_retained'] = pd.to_numeric(data['percent_retained'], errors='coerce')
        data['collection_day'] = pd.to_datetime(data['collection_day'], errors='coerce')
        data = data.dropna(subset=['percent_retained', 'collection_day'])
        print("\n'percent_retained' column after pd.to_numeric (first 5 values):")
        print(data['percent_retained'].head())
        print("'percent_retained' dtype after pd.to_numeric:", data['percent_retained'].dtype)
        print("\n'collection_day' column after pd.to_datetime (first 5 values):")
        print(data['collection_day'].head())
        print("'collection_day' dtype after pd.to_datetime:", data['collection_day'].dtype)
        if data.empty:
            return _message_figure("No data available to plot after processing.")
        fig = px.bar(
            data,
            x='collection_day',
            y='percent_retained',
            title='Previous Day Top 200 Trending Model Retention %',
            labels={'collection_day': 'Date', 'percent_retained': 'Retention Rate (%)'},
            text='percent_retained',  # Bar labels come from the retention value
        )
        # Format the labels inside the bars and the hover tooltip.
        fig.update_traces(
            texttemplate='%{text:.2f}%',
            textposition='inside',
            insidetextanchor='middle',  # Anchor text to the middle of the bar
            textfont_color='white',
            textfont_size=10,  # Small enough to fit inside narrow bars
            hovertemplate='<b>Date</b>: %{x|%Y-%m-%d}<br>'
                          '<b>Retention</b>: %{y:.2f}%<extra></extra>',
        )
        # Dashed reference line at the mean retention of the plotted days;
        # `data` is non-empty here, so the mean is defined.
        average_retention = data['percent_retained'].mean()
        fig.add_hline(
            y=average_retention,
            line_dash="dash",
            line_color="red",
            annotation_text=f"Average: {average_retention:.2f}%",
            annotation_position="bottom right",
        )
        fig.update_xaxes(tickangle=45)
        fig.update_layout(
            title_x=0.5,  # Center title
            xaxis_title="Date",
            yaxis_title="Retention Rate (%)",
            plot_bgcolor='white',  # White background, similar to seaborn whitegrid
            bargap=0.2,  # Gap between bars of different categories
        )
        return fig
    except Exception as e:
        print(f"Error during plot_retention_data: {e}", file=sys.stderr)
        return _message_figure(f"Plotting Error: {str(e)}")
def interface_fn(start_date, end_date):
    """Gradio callback: query retention for the given date range and plot it.

    Both arguments are passed through unchanged to get_retention_data; the
    resulting frame (data or error signal) is rendered by plot_retention_data.
    """
    return plot_retention_data(get_retention_data(start_date, end_date))
# Earliest and latest collection dates in the dataset, used as the default
# values of the date inputs. pd.to_datetime accepts both ISO-8601 strings and
# datetime-typed columns, whereas datetime.fromisoformat only accepts strings
# (and, before Python 3.11, not every ISO-8601 form, e.g. a trailing "Z").
_collected_at = pd.to_datetime(df['collected_at'])
min_date = _collected_at.min().date()
max_date = _collected_at.max().date()
# Gradio UI: two free-text date fields, pre-filled with the dataset's full
# date span, feeding a single retention chart.
iface = gr.Interface(
    fn=interface_fn,
    inputs=[
        gr.Textbox(
            label="Start Date (YYYY-MM-DD)",
            value=f"{min_date:%Y-%m-%d}",
        ),
        gr.Textbox(
            label="End Date (YYYY-MM-DD)",
            value=f"{max_date:%Y-%m-%d}",
        ),
    ],
    outputs=gr.Plot(label="Model Retention Visualization"),
    title="Model Retention Analysis",
    description=(
        "Visualize model retention rates over time. "
        "Enter dates in YYYY-MM-DD format."
    ),
)
# Launch the app; this runs at module level, as soon as the file is executed.
iface.launch()