Spaces:
Sleeping
Sleeping
Commit
·
a2f004f
1
Parent(s):
c67b794
fix
Browse files- src/backend.py +15 -18
src/backend.py
CHANGED
|
@@ -133,33 +133,30 @@ def pattern_match(patterns, source_list):
|
|
| 133 |
|
| 134 |
def _backend_routine():
|
| 135 |
# List only the text classification models
|
| 136 |
-
rl_models =
|
| 137 |
logger.info(f"Found {len(rl_models)} RL models")
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
pending_models = list(set(rl_models) - set(evaluated_models))
|
| 143 |
-
pending_and_compatible_models = []
|
| 144 |
-
for repo_id, sha in pending_models:
|
| 145 |
-
try:
|
| 146 |
-
siblings = API.model_info(repo_id, revision="main").siblings
|
| 147 |
-
except Exception:
|
| 148 |
-
continue
|
| 149 |
-
filenames = [sib.rfilename for sib in siblings]
|
| 150 |
if "agent.pt" in filenames:
|
| 151 |
-
|
| 152 |
|
| 153 |
-
logger.info(f"Found {len(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
|
| 155 |
-
if len(
|
| 156 |
return None
|
| 157 |
|
| 158 |
# Shuffle the dataset
|
| 159 |
-
random.shuffle(
|
| 160 |
|
| 161 |
# Select a random model
|
| 162 |
-
repo_id, sha =
|
| 163 |
user_id, model_id = repo_id.split("/")
|
| 164 |
row = {"model_id": model_id, "user_id": user_id, "sha": sha}
|
| 165 |
|
|
|
|
| 133 |
|
| 134 |
def _backend_routine():
|
| 135 |
# List only the text classification models
|
| 136 |
+
rl_models = API.list_models(filter=["reinforcement-learning"])
|
| 137 |
logger.info(f"Found {len(rl_models)} RL models")
|
| 138 |
+
|
| 139 |
+
compatible_models = []
|
| 140 |
+
for model in rl_models:
|
| 141 |
+
filenames = [sib.rfilename for sib in model.siblings]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
if "agent.pt" in filenames:
|
| 143 |
+
compatible_models.append((model.modelId, model.sha))
|
| 144 |
|
| 145 |
+
logger.info(f"Found {len(compatible_models)} compatible models")
|
| 146 |
+
|
| 147 |
+
dataset = load_dataset(RESULTS_REPO, split="train", download_mode="force_redownload", verification_mode="no_checks")
|
| 148 |
+
evaluated_models = [("/".join([x["user_id"], x["model_id"]]), x["sha"]) for x in dataset]
|
| 149 |
+
pending_models = list(set(compatible_models) - set(evaluated_models))
|
| 150 |
+
logger.info(f"Found {len(pending_models)} pending models")
|
| 151 |
|
| 152 |
+
if len(pending_models) == 0:
|
| 153 |
return None
|
| 154 |
|
| 155 |
# Shuffle the dataset
|
| 156 |
+
random.shuffle(pending_models)
|
| 157 |
|
| 158 |
# Select a random model
|
| 159 |
+
repo_id, sha = pending_models.pop()
|
| 160 |
user_id, model_id = repo_id.split("/")
|
| 161 |
row = {"model_id": model_id, "user_id": user_id, "sha": sha}
|
| 162 |
|