import json
import math
import os
import re
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import pandas as pd


class DataManager:
    """Loads benchmark results (old .xlsx and new .json layouts) into a single
    master DataFrame and answers filtered queries over it."""

    def __init__(self, data_dir: str):
        self.data_dir = Path(data_dir)
        self.master_df = self._load_all_data()

    def _load_old_format_folder(self, period_dir: Path) -> List[pd.DataFrame]:
        """Read every sheet of every .xlsx file in *period_dir*.

        Each non-empty sheet becomes one cleaned DataFrame tagged with
        Period / Metric / Model Group columns.

        Returns:
            A (possibly empty) list of DataFrames.
        """
        all_xlsx_data: List[pd.DataFrame] = []
        period = period_dir.name
        for file_path in period_dir.iterdir():
            if file_path.suffix != ".xlsx":
                continue
            model_group = file_path.stem  # file name encodes the model size group
            xls = pd.read_excel(file_path, sheet_name=None)  # dict: sheet name -> frame
            for sheet_name, df in xls.items():
                if df.empty:
                    continue
                df = self._clean_dataframe(df)
                df["Period"] = period
                df["Metric"] = sheet_name  # sheet name encodes the metric (cr/bpc/bpb)
                df["Model Group"] = model_group
                all_xlsx_data.append(df)
        return all_xlsx_data

    def _load_new_format_folder(self, period_dir: Path) -> List[pd.DataFrame]:
        """Read per-run JSON result files under *period_dir* (recursively) and
        pivot them into one wide DataFrame with one column per data source.

        Returns:
            A single-element list holding the wide frame, or an empty list
            when no records were found.
        """
        raw_records = []
        period = period_dir.name
        for file_path in period_dir.rglob("*.json"):
            with open(file_path, "r", encoding="utf-8") as f:
                data = json.load(f)
            full_path = data["model_name_or_path"]
            model_name = full_path.split("/")[-1].replace(".pth", "")
            params = data["parameters count"]
            data_path = data["data_path"]
            source_col = data_path.split("-")[-1]
            neg_log_prob = data["neg_log_prob_sum"]
            avg_char = data["avg character count"]
            avg_bytes = data["avg bytes"]
            # Derive the three metrics; bpc/bpb convert nats to bits via 1/ln(2).
            metrics = {
                "cr": data["compression_rate"],
                "bpc": (neg_log_prob / avg_char) * (1 / math.log(2)),
                "bpb": (neg_log_prob / avg_bytes) * (1 / math.log(2)),
            }
            for metric_type, value in metrics.items():
                if value is not None:
                    raw_records.append(
                        {
                            "Name": model_name,
                            "Params (B)": params,
                            "Period": period,
                            "Metric": metric_type,
                            "Model Group": "other",  # reassigned from params below
                            "Source": source_col,
                            "Value": value,
                        }
                    )
        if not raw_records:
            return []
        df_long = pd.DataFrame(raw_records)
        df_wide = df_long.pivot_table(
            index=["Name", "Params (B)", "Period", "Metric", "Model Group"],
            columns="Source",
            values="Value",
        ).reset_index()
        df_wide.columns.name = None

        def assign_group(p):
            # Bucket models into size groups by parameter count (billions).
            if p >= 12:
                return "14b"
            if p >= 9:
                return "9b"
            if p >= 6:
                return "7b"
            if p >= 2.5:
                return "3b"
            if p >= 1:
                return "1b5"
            return "other"

        df_wide["Model Group"] = df_wide["Params (B)"].apply(assign_group)
        metadata_cols = ["Name", "Params (B)", "Period", "Metric", "Model Group"]
        # Humanize source column names: underscores -> spaces.
        new_columns = {
            col: col.replace("_", " ")
            for col in df_wide.columns
            if col not in metadata_cols
        }
        df_wide = df_wide.rename(columns=new_columns)
        return [df_wide]

    def _load_all_data(self) -> pd.DataFrame:
        """Walk all YYYY-MM period folders and concatenate their records.

        Folders up to and including 2025-11 use the old xlsx layout; later
        ones use the new JSON layout.  All non-metadata columns are coerced
        to numeric (invalid values become NaN).
        """
        all_records = []
        if not self.data_dir.exists():
            print(f"Warning: Directory {self.data_dir} does not exist.")
            return pd.DataFrame()
        period_dirs = [
            d
            for d in self.data_dir.iterdir()
            if d.is_dir() and re.match(r"^\d{4}-\d{2}$", d.name)
        ]
        for period_dir in period_dirs:
            # Lexicographic comparison is safe for zero-padded YYYY-MM names.
            if period_dir.name <= "2025-11":
                all_records.extend(self._load_old_format_folder(period_dir))
            else:
                all_records.extend(self._load_new_format_folder(period_dir))
        if not all_records:
            return pd.DataFrame()
        final_df = pd.concat(all_records, ignore_index=True)
        exclude_cols = ["Name", "Period", "Metric", "Model Group"]
        numeric_cols = [c for c in final_df.columns if c not in exclude_cols]
        for col in numeric_cols:
            final_df[col] = pd.to_numeric(final_df[col], errors="coerce")
        return final_df

    def _clean_dataframe(self, df: pd.DataFrame) -> pd.DataFrame:
        """Normalize spreadsheet column names.

        Drops all-NaN columns, renames the parameter and average columns to
        their canonical names, strips trailing ``_202xxxxx`` date suffixes
        and zero-width spaces, and replaces underscores with spaces.
        """
        df = df.dropna(axis=1, how="all")
        new_columns = []
        for col in df.columns:
            col_str = str(col)
            if "Parameters" in col_str:
                new_columns.append("Params (B)")
            elif col_str == "Average (The lower the better)":
                new_columns.append("Average (lower=better)")
            else:
                # Strip trailing _202xxxxx suffix and zero-width spaces.
                new_columns.append(
                    col_str.rsplit("_", maxsplit=1)[0].replace("\u200b", "")
                )
        df.columns = new_columns
        column_mapping = {col: col.replace("_", " ") for col in df.columns}
        df = df.rename(columns=column_mapping)
        return df

    def get_available_periods(self) -> List[str]:
        """Return all available time periods, sorted from oldest to newest."""
        if self.master_df.empty:
            return []
        return sorted(self.master_df["Period"].unique().tolist())

    def get_available_columns(self, period: str) -> List[str]:
        """Return the data columns for *period*, excluding metadata columns
        and columns that are entirely NaN."""
        if self.master_df.empty:
            return []
        subset = self.master_df[self.master_df["Period"] == period]
        if subset.empty:
            return []
        metadata_cols = [
            "Name",
            "Params (B)",
            "Period",
            "Metric",
            "Model Group",
            "Average (lower=better)",
        ]
        return [
            c
            for c in subset.columns
            if c not in metadata_cols and not subset[c].isna().all()
        ]

    def query(
        self,
        period: str,
        metric_code: str,
        param_range: Tuple[float, float],
        model_groups: Optional[List[str]] = None,
        visible_columns: Optional[List[str]] = None,
    ) -> pd.DataFrame:
        """
        Unified query interface.

        Args:
            period: time period, e.g. "2025-12"
            metric_code: bpc, bpb or cr
            param_range: (min, max) parameter-count range in billions
            model_groups: optional list of group names, e.g. ['14b', '7b']
            visible_columns: optional list of column names that take part in
                the average computation

        Returns:
            A DataFrame sorted ascending by the recomputed
            "Average (lower=better)" column, metadata columns first.
        """
        mask = (
            (self.master_df["Period"] == period)
            & (self.master_df["Metric"] == metric_code)
            & (self.master_df["Params (B)"].between(param_range[0], param_range[1]))
        )
        if model_groups is not None:
            if len(model_groups) == 0:
                return pd.DataFrame()
            mask = mask & (self.master_df["Model Group"].isin(model_groups))
        filtered_df = self.master_df.loc[mask].copy()
        if filtered_df.empty:
            return filtered_df
        # Columns that are always dropped from the result.
        exclude_cols = ["Period", "Metric", "Model Group"]
        # Columns that are always kept.
        metadata_cols = ["Name", "Params (B)", "Average (lower=better)"]
        if visible_columns is not None:
            valid_visible_cols = [c for c in visible_columns if c in filtered_df.columns]
            columns_to_keep = metadata_cols + valid_visible_cols
            cols_for_average = valid_visible_cols
        else:
            all_cols = [c for c in filtered_df.columns if c not in exclude_cols]
            columns_to_keep = all_cols
            cols_for_average = [c for c in all_cols if c not in metadata_cols]
        if not cols_for_average:
            return pd.DataFrame()
        filtered_df["Average (lower=better)"] = (
            filtered_df[cols_for_average].mean(axis=1).round(3)
        )
        columns_to_keep = [c for c in columns_to_keep if c not in exclude_cols]
        # Bug fix: when the source data had no pre-existing average column
        # (all new-format JSON data), the freshly computed one must still
        # survive the column selection below; otherwise sort_values on it
        # would raise KeyError.
        if "Average (lower=better)" not in columns_to_keep:
            columns_to_keep.append("Average (lower=better)")
        columns_to_keep = list(dict.fromkeys(columns_to_keep))  # de-dupe, keep order
        filtered_df = filtered_df[columns_to_keep]
        if "Name" in filtered_df.columns:
            filtered_df["Name"] = filtered_df["Name"].apply(
                lambda x: x.replace(".pth", "")
            )
        filtered_df = filtered_df.sort_values(
            by="Average (lower=better)",
            ascending=True,
            kind="mergesort",  # stable: ties keep their input order
            na_position="last",
        ).reset_index(drop=True)
        fixed_cols = ["Name", "Params (B)", "Average (lower=better)"]
        column_priority = [
            # Code
            "github cpp",
            "github python",
            "github javascript",
            # Research
            "arxiv physics",
            "arxiv cs",
            "arxiv math",
            # Writing
            "ao3 english",
            "github markdown",
            # World knowledge
            "bbc news",
            "wikipedia english",
        ]
        existing_cols = filtered_df.columns.tolist()
        ordered_cols = []
        for col in fixed_cols:
            if col in existing_cols:
                ordered_cols.append(col)
        for col in column_priority:
            if col in existing_cols and col not in ordered_cols:
                ordered_cols.append(col)
        for col in existing_cols:
            if col not in ordered_cols:
                ordered_cols.append(col)
        filtered_df = filtered_df[ordered_cols]
        return filtered_df


class LongContextDataManager:
    """Indexes long-context evaluation JSON files by period, model and dataset."""

    def __init__(self, data_dir: str):
        self.data_dir = data_dir
        # { period: { "Display Name": "full_path" } }
        self.period_file_map: Dict[str, Dict[str, str]] = {}
        # { period: { "Model Name": ["path1", "path2"] } }
        self.period_model_map: Dict[str, Dict[str, List[str]]] = {}
        # { period: { "Dataset Name": ["path1", "path2"] } }
        self.period_dataset_map: Dict[str, Dict[str, List[str]]] = {}
        # { period: { "Model Name": { "Dataset Name": ["path1", "path2"] } } }
        self.period_model_dataset_map: Dict[str, Dict[str, Dict[str, List[str]]]] = {}
        self._scan_directories()

    def _scan_directories(self) -> None:
        """Walk *data_dir*; every folder directly containing .json files is
        treated as a period, and each file is indexed into the lookup maps."""
        for root, _dirs, files in os.walk(self.data_dir):
            json_files = [f for f in files if f.endswith(".json")]
            if not json_files:
                continue
            period = os.path.basename(root)
            if period not in self.period_file_map:
                self.period_file_map[period] = {}
                self.period_model_map[period] = {}
                self.period_dataset_map[period] = {}
                self.period_model_dataset_map[period] = {}
            for jf in json_files:
                full_path = os.path.join(root, jf)
                with open(full_path, "r", encoding="utf-8") as f:
                    meta = json.load(f)
                model_name = (
                    meta.get("model_name_or_path", jf.replace(".json", ""))
                    .split("/")[-1]
                    .replace(".pth", "")
                )
                data_path = meta.get("data_path", "").replace("UncheatableEval-", "")
                dataset_name = data_path.split("/")[-1] if data_path else "Unknown"
                file_display_label = f"{model_name}-{dataset_name}"
                self.period_file_map[period][file_display_label] = full_path
                self.period_model_map[period].setdefault(model_name, []).append(full_path)
                self.period_dataset_map[period].setdefault(dataset_name, []).append(full_path)
                self.period_model_dataset_map[period].setdefault(
                    model_name, {}
                ).setdefault(dataset_name, []).append(full_path)

    def get_available_periods(self) -> List[str]:
        """Return all indexed periods, sorted."""
        return sorted(self.period_file_map.keys())

    def get_file_choices(self, period: str) -> List[Tuple[str, str]]:
        """Return [(Display Name, Full Path), ...] for *period*."""
        if period not in self.period_file_map:
            return []
        return [(k, v) for k, v in self.period_file_map[period].items()]

    def get_model_choices(self, period: str) -> List[Tuple[str, str]]:
        """Return [(Model Name, Model Name), ...] for *period*."""
        if period not in self.period_model_map:
            return []
        return [(k, k) for k in self.period_model_map[period].keys()]

    def get_paths_for_model(self, period: str, model_name: str) -> List[str]:
        """Return every file path recorded for *model_name* in *period*."""
        return self.period_model_map.get(period, {}).get(model_name, [])

    def get_dataset_choices(self, period: str) -> List[Tuple[str, str]]:
        """Return all datasets under *period* as [(Dataset Name, Dataset Name), ...]."""
        if period not in self.period_dataset_map:
            return []
        return [(k, k) for k in sorted(self.period_dataset_map[period].keys())]

    def get_paths_for_model_and_datasets(
        self, period: str, model_name: str, dataset_names: List[str]
    ) -> List[str]:
        """Return the file paths matching *model_name* and each dataset listed
        in *dataset_names* within *period*."""
        if period not in self.period_model_dataset_map:
            return []
        if model_name not in self.period_model_dataset_map[period]:
            return []
        dataset_map = self.period_model_dataset_map[period][model_name]
        paths: List[str] = []
        for dataset_name in dataset_names:
            if dataset_name in dataset_map:
                paths.extend(dataset_map[dataset_name])
        return paths


if __name__ == "__main__":
    # dm = DataManager("data")
    # periods = dm.get_available_periods()
    # print(f"Total records loaded: {len(dm.master_df)}")
    # print(f"Available periods: {periods}")
    # print(f"Available columns: {dm.get_available_columns('2025-11')}")
    # result = dm.query(
    #     period="2025-11",
    #     metric_code="cr",
    #     param_range=(0, 20),
    #     model_groups=["7b"],
    #     visible_columns=["wikipedia_english"],
    # )
    # print(result.head(20))
    lcm = LongContextDataManager("longctx_data")