Spaces:
Running
Running
| import pandas as pd | |
| import re | |
| from pathlib import Path | |
| from typing import List, Tuple, Optional | |
| import json | |
| import math | |
| import os | |
| from typing import Dict | |
class DataManager:
    """Load evaluation results from ``data_dir`` and serve filtered leaderboard tables.

    ``data_dir`` holds one sub-directory per period, named ``YYYY-MM``. Periods up
    to and including ``2025-11`` are read from ``.xlsx`` workbooks (file stem =
    model group, sheet name = metric). Later periods are read from ``*.json``
    result files found recursively under the period directory. Every period is
    concatenated into ``self.master_df``.
    """

    # Columns that describe a row rather than hold a per-source score.
    _METADATA_COLS = ["Name", "Params (B)", "Period", "Metric", "Model Group"]
    # Name of the per-row mean over the selected score columns.
    _AVERAGE_COL = "Average (lower=better)"
    # Preferred display order of score columns in query results; other columns follow.
    _COLUMN_PRIORITY = [
        # Code
        "github cpp",
        "github python",
        "github javascript",
        # Research
        "arxiv physics",
        "arxiv cs",
        "arxiv math",
        # Writing
        "ao3 english",
        "github markdown",
        # World knowledge
        "bbc news",
        "wikipedia english",
    ]

    def __init__(self, data_dir: str):
        self.data_dir = Path(data_dir)
        self.master_df = self._load_all_data()

    def _load_old_format_folder(self, period_dir: Path) -> List[pd.DataFrame]:
        """Read every ``.xlsx`` workbook in an old-format period directory.

        Each non-empty sheet becomes one frame, tagged with the period (directory
        name), the metric (sheet name) and the model group (workbook file stem).
        """
        frames: List[pd.DataFrame] = []
        period = period_dir.name
        # Sorted for a reproducible row order; "~$" files are Excel lock files.
        for file_path in sorted(period_dir.iterdir()):
            if file_path.suffix != ".xlsx" or file_path.name.startswith("~$"):
                continue
            model_group = file_path.stem
            sheets = pd.read_excel(file_path, sheet_name=None)
            for sheet_name, df in sheets.items():
                if df.empty:
                    continue
                df = self._clean_dataframe(df)
                df["Period"] = period
                df["Metric"] = sheet_name
                df["Model Group"] = model_group
                frames.append(df)
        return frames

    @staticmethod
    def _bits_per_unit(neg_log_prob, unit_count) -> Optional[float]:
        """Convert a summed negative log-probability into bits per unit.

        The division by ln(2) assumes ``neg_log_prob`` is in natural-log units
        (nats) — NOTE(review): confirm against the evaluation script.
        Returns None when an input is missing or the unit count is zero, so
        that metric is skipped instead of raising.
        """
        if neg_log_prob is None or not unit_count:
            return None
        return (neg_log_prob / unit_count) * (1 / math.log(2))

    @staticmethod
    def _assign_group(params) -> str:
        """Bucket a parameter count (in billions) into a model-group label.

        NaN compares False everywhere, so it falls through to "other".
        """
        if params >= 12:
            return "14b"
        if params >= 9:
            return "9b"
        if params >= 6:
            return "7b"
        if params >= 2.5:
            return "3b"
        if params >= 1:
            return "1b5"
        return "other"

    def _load_new_format_folder(self, period_dir: Path) -> List[pd.DataFrame]:
        """Read the JSON result files under a new-format period directory.

        Each file contributes ``cr`` (its ``compression_rate``), ``bpc`` and ``bpb``
        values for one (model, source) pair. The records are pivoted into one row
        per (model, metric) and one column per source. Underscores in the source
        name are replaced by spaces, e.g. ``github_cpp`` -> ``github cpp``.

        Returns:
            A one-element list holding the pivoted frame, or ``[]`` when no
            metric values were found.
        """
        raw_records = []
        period = period_dir.name
        for file_path in sorted(period_dir.rglob("*.json")):
            with open(file_path, "r", encoding="utf-8") as f:
                data = json.load(f)
            model_name = data["model_name_or_path"].split("/")[-1].replace(".pth", "")
            params = data["parameters count"]
            # The source name is the last "-"-separated token of the data path.
            source_col = data["data_path"].split("-")[-1]
            neg_log_prob = data["neg_log_prob_sum"]
            metrics = {
                "cr": data["compression_rate"],
                "bpc": self._bits_per_unit(neg_log_prob, data["avg character count"]),
                "bpb": self._bits_per_unit(neg_log_prob, data["avg bytes"]),
            }
            for metric_type, value in metrics.items():
                if value is None:
                    continue
                raw_records.append(
                    {
                        "Name": model_name,
                        "Params (B)": params,
                        "Period": period,
                        "Metric": metric_type,
                        # Placeholder; replaced below from the parameter count.
                        "Model Group": "other",
                        "Source": source_col,
                        "Value": value,
                    }
                )
        if not raw_records:
            return []

        df_wide = (
            pd.DataFrame(raw_records)
            .pivot_table(index=list(self._METADATA_COLS), columns="Source", values="Value")
            .reset_index()
        )
        df_wide.columns.name = None
        df_wide["Model Group"] = df_wide["Params (B)"].apply(self._assign_group)
        renamed = {c: c.replace("_", " ") for c in df_wide.columns if c not in self._METADATA_COLS}
        return [df_wide.rename(columns=renamed)]

    def _load_all_data(self) -> pd.DataFrame:
        """Load every ``YYYY-MM`` period directory and merge them into one frame.

        Every column except the textual metadata is coerced to numeric;
        unparsable cells become NaN. Returns an empty frame if nothing was loaded.
        """
        if not self.data_dir.exists():
            print(f"Warning: Directory {self.data_dir} does not exist.")
            return pd.DataFrame()
        period_pattern = re.compile(r"^\d{4}-\d{2}$")
        # Sorted so periods are concatenated in chronological order.
        period_dirs = sorted(d for d in self.data_dir.iterdir() if d.is_dir() and period_pattern.match(d.name))
        frames: List[pd.DataFrame] = []
        for period_dir in period_dirs:
            # "YYYY-MM" strings compare chronologically; 2025-11 is the last Excel-based period.
            if period_dir.name <= "2025-11":
                frames.extend(self._load_old_format_folder(period_dir))
            else:
                frames.extend(self._load_new_format_folder(period_dir))
        if not frames:
            return pd.DataFrame()

        final_df = pd.concat(frames, ignore_index=True)
        text_cols = {"Name", "Period", "Metric", "Model Group"}
        for col in [c for c in final_df.columns if c not in text_cols]:
            final_df[col] = pd.to_numeric(final_df[col], errors="coerce")
        return final_df

    def _clean_dataframe(self, df: pd.DataFrame) -> pd.DataFrame:
        """Normalize the column headers of one Excel sheet.

        The following changes are applied:
        - Drop columns that are entirely empty.
        - Map any header containing "Parameters" to ``Params (B)``.
        - Shorten the average header.
        - Strip zero-width spaces.
        - Strip a trailing numeric ``_<date>`` suffix (e.g. ``_20251101``).
        - Replace the remaining underscores with spaces.
        """
        df = df.dropna(axis=1, how="all")
        new_columns = []
        for col in df.columns:
            name = str(col).replace("\u200b", "")
            if "Parameters" in name:
                name = "Params (B)"
            elif name == "Average (The lower the better)":
                name = self._AVERAGE_COL
            else:
                # Strip only a numeric date suffix; names without one (e.g.
                # "bbc_news") must keep their last word.
                name = re.sub(r"_\d[\d-]*$", "", name)
            new_columns.append(name.replace("_", " "))
        df.columns = new_columns
        return df

    def get_available_periods(self) -> List[str]:
        """Return all available periods, sorted from oldest to newest."""
        if self.master_df.empty:
            return []
        return sorted(self.master_df["Period"].unique().tolist())

    def get_available_columns(self, period: str) -> List[str]:
        """Return the score columns that have at least one value in ``period``.

        Metadata columns and the average column are excluded.
        """
        if self.master_df.empty:
            return []
        subset = self.master_df[self.master_df["Period"] == period]
        if subset.empty:
            return []
        hidden = set(self._METADATA_COLS) | {self._AVERAGE_COL}
        return [c for c in subset.columns if c not in hidden and subset[c].notna().any()]

    def _order_columns(self, columns) -> List[str]:
        """Order columns: fixed columns, then ``_COLUMN_PRIORITY``, then the rest in their existing order."""
        existing = list(columns)
        preferred = ["Name", "Params (B)", self._AVERAGE_COL] + self._COLUMN_PRIORITY
        head = [c for c in preferred if c in existing]
        return list(dict.fromkeys(head + existing))

    def query(
        self,
        period: str,
        metric_code: str,
        param_range: Tuple[float, float],
        model_groups: Optional[List[str]] = None,
        visible_columns: Optional[List[str]] = None,
    ) -> pd.DataFrame:
        """Return the leaderboard rows for one period and metric.

        Args:
            period: Time period, e.g. "2025-12".
            metric_code: Metric to select: "bpc", "bpb" or "cr".
            param_range: Inclusive ``(min, max)`` range for ``Params (B)``.
            model_groups: Optional model-group labels, e.g. ['14b', '7b'].
                An empty list yields an empty result.
            visible_columns: Optional score columns to show and average over.
                Unknown names and metadata columns are ignored. When None, all
                score columns are used.

        Returns:
            A frame with ``Name``, ``Params (B)``, ``Average (lower=better)`` and
            the selected score columns. The average is the row mean of the score
            columns, rounded to 3 decimals. Rows are sorted by the average in
            ascending order, with NaN last. The frame is empty when no data is
            loaded, nothing matches, or no score column is selected.
        """
        if self.master_df.empty:
            return pd.DataFrame()
        df = self.master_df
        low, high = param_range
        mask = (df["Period"] == period) & (df["Metric"] == metric_code) & df["Params (B)"].between(low, high)
        if model_groups is not None:
            if not model_groups:
                return pd.DataFrame()
            mask &= df["Model Group"].isin(model_groups)
        filtered_df = df.loc[mask].copy()
        if filtered_df.empty:
            return filtered_df

        avg_col = self._AVERAGE_COL
        # Used only for filtering; never shown.
        exclude_cols = ["Period", "Metric", "Model Group"]
        # Always shown, never averaged. The average is kept unconditionally so the sort below can use it.
        fixed_cols = ["Name", "Params (B)", avg_col]
        non_score_cols = set(exclude_cols) | set(fixed_cols)

        if visible_columns is not None:
            candidates = [c for c in visible_columns if c in filtered_df.columns]
        else:
            candidates = list(filtered_df.columns)
        score_cols = list(dict.fromkeys(c for c in candidates if c not in non_score_cols))
        if not score_cols:
            return pd.DataFrame()

        filtered_df[avg_col] = filtered_df[score_cols].mean(axis=1).round(3)
        if "Name" in filtered_df.columns:
            filtered_df["Name"] = filtered_df["Name"].apply(
                lambda x: x.replace(".pth", "") if isinstance(x, str) else x
            )
        keep = [c for c in fixed_cols if c in filtered_df.columns] + score_cols
        filtered_df = (
            filtered_df[keep]
            .sort_values(by=avg_col, ascending=True, kind="mergesort", na_position="last")
            .reset_index(drop=True)
        )
        return filtered_df[self._order_columns(filtered_df.columns)]
class LongContextDataManager:
    """Index long-context evaluation JSON files by period, model and dataset.

    A file's period is the name of the directory that directly contains it.
    Only file paths are stored; the JSON payloads are read once during the scan
    to extract the model and dataset names.
    """

    def __init__(self, data_dir: str):
        self.data_dir = data_dir
        # { period: { "<model>-<dataset>" display label: full_path } }
        # A later file with the same label in the same period replaces an earlier one.
        self.period_file_map = {}
        # { period: { model_name: [path, ...] } }
        self.period_model_map = {}
        # { period: { dataset_name: [path, ...] } }
        self.period_dataset_map = {}
        # { period: { model_name: { dataset_name: [path, ...] } } }
        self.period_model_dataset_map = {}
        self._scan_directories()

    @staticmethod
    def _read_meta(path):
        """Return the JSON object stored at ``path``, or None (with a warning) if it is unusable."""
        try:
            with open(path, "r", encoding="utf-8") as f:
                meta = json.load(f)
        except (OSError, ValueError) as err:  # ValueError covers JSON and Unicode decode errors
            print(f"Warning: skipping {path}: {err}")
            return None
        if not isinstance(meta, dict):
            print(f"Warning: skipping {path}: expected a JSON object")
            return None
        return meta

    def _scan_directories(self):
        """Walk ``data_dir`` and populate the lookup maps from every ``*.json`` file."""
        for root, dirs, files in os.walk(self.data_dir):
            dirs.sort()  # in-place sort makes os.walk visit sub-directories deterministically
            json_files = sorted(f for f in files if f.endswith(".json"))
            if not json_files:
                continue
            period = os.path.basename(root)
            file_map = self.period_file_map.setdefault(period, {})
            model_map = self.period_model_map.setdefault(period, {})
            dataset_map = self.period_dataset_map.setdefault(period, {})
            model_dataset_map = self.period_model_dataset_map.setdefault(period, {})

            for jf in json_files:
                full_path = os.path.join(root, jf)
                meta = self._read_meta(full_path)
                if meta is None:
                    continue
                # Fall back to the file name when the model path is absent or null.
                raw_model = meta.get("model_name_or_path") or jf.replace(".json", "")
                model_name = raw_model.split("/")[-1].replace(".pth", "")
                data_path = (meta.get("data_path") or "").replace("UncheatableEval-", "")
                dataset_name = data_path.split("/")[-1] if data_path else "Unknown"

                file_map[f"{model_name}-{dataset_name}"] = full_path
                model_map.setdefault(model_name, []).append(full_path)
                dataset_map.setdefault(dataset_name, []).append(full_path)
                model_dataset_map.setdefault(model_name, {}).setdefault(dataset_name, []).append(full_path)

    def get_available_periods(self):
        """Return all indexed periods, sorted ascending."""
        return sorted(self.period_file_map)

    def get_file_choices(self, period):
        """Return ``[(display label, full path), ...]`` for ``period``; empty if unknown."""
        return list(self.period_file_map.get(period, {}).items())

    def get_model_choices(self, period):
        """Return ``[(model name, model name), ...]`` for ``period``; empty if unknown."""
        return [(name, name) for name in self.period_model_map.get(period, {})]

    def get_paths_for_model(self, period, model_name):
        """Return a copy of the file paths recorded for ``model_name`` in ``period``."""
        return list(self.period_model_map.get(period, {}).get(model_name, []))

    def get_dataset_choices(self, period):
        """Return ``[(dataset name, dataset name), ...]`` for ``period``, sorted by name."""
        return [(name, name) for name in sorted(self.period_dataset_map.get(period, {}))]

    def get_paths_for_model_and_datasets(self, period, model_name, dataset_names):
        """Return the file paths for ``model_name`` across ``dataset_names``, in the given order.

        Unknown periods, models or datasets contribute nothing; ``None`` is
        treated as an empty selection.
        """
        by_dataset = self.period_model_dataset_map.get(period, {}).get(model_name, {})
        paths = []
        for dataset_name in dataset_names or []:
            paths.extend(by_dataset.get(dataset_name, []))
        return paths
if __name__ == "__main__":
    # Manual smoke test: index the long-context results and summarize what was found.
    #
    # Usage example for the main leaderboard data (expects a "data" directory):
    #   dm = DataManager("data")
    #   print(dm.get_available_periods())
    #   print(dm.get_available_columns("2025-11"))
    #   print(dm.query(period="2025-11", metric_code="cr", param_range=(0, 20),
    #                  model_groups=["7b"], visible_columns=["wikipedia english"]).head(20))
    # Column names are normalized to use spaces, so "wikipedia_english" would match nothing.
    lcm = LongContextDataManager("longctx_data")
    for period in lcm.get_available_periods():
        print(f"{period}: {len(lcm.get_file_choices(period))} entries")