Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -5,12 +5,10 @@ from policy import load_policy_model, save_checkpoint
|
|
| 5 |
from reward_fn import reward_fn
|
| 6 |
from grpo_train import grpo_step
|
| 7 |
|
| 8 |
-
|
| 9 |
-
# Load models
|
| 10 |
-
policy_model, gen_model, ref_model, tokenizer = load_policy_model()
|
| 11 |
|
| 12 |
reward_history = []
|
| 13 |
-
global_step = 0
|
| 14 |
|
| 15 |
|
| 16 |
def plot_rewards(history):
|
|
@@ -29,7 +27,6 @@ def run_step(prompt):
|
|
| 29 |
result = grpo_step(
|
| 30 |
policy_model=policy_model,
|
| 31 |
gen_model=gen_model,
|
| 32 |
-
ref_model=ref_model,
|
| 33 |
tokenizer=tokenizer,
|
| 34 |
prompt=prompt,
|
| 35 |
reward_fn=reward_fn,
|
|
@@ -38,8 +35,8 @@ def run_step(prompt):
|
|
| 38 |
reward_history.append(float(result["reward"]))
|
| 39 |
reward_plot = plot_rewards(reward_history)
|
| 40 |
|
| 41 |
-
|
| 42 |
-
|
| 43 |
|
| 44 |
return result["text"], result["reward"], result["kl"], result["loss"], reward_plot
|
| 45 |
|
|
|
|
| 5 |
from reward_fn import reward_fn
|
| 6 |
from grpo_train import grpo_step
|
| 7 |
|
| 8 |
+
policy_model, gen_model, tokenizer = load_policy_model()
|
|
|
|
|
|
|
| 9 |
|
| 10 |
reward_history = []
|
| 11 |
+
global_step = 0
|
| 12 |
|
| 13 |
|
| 14 |
def plot_rewards(history):
|
|
|
|
| 27 |
result = grpo_step(
|
| 28 |
policy_model=policy_model,
|
| 29 |
gen_model=gen_model,
|
|
|
|
| 30 |
tokenizer=tokenizer,
|
| 31 |
prompt=prompt,
|
| 32 |
reward_fn=reward_fn,
|
|
|
|
| 35 |
reward_history.append(float(result["reward"]))
|
| 36 |
reward_plot = plot_rewards(reward_history)
|
| 37 |
|
| 38 |
+
if global_step % 10 == 0:
|
| 39 |
+
save_checkpoint(policy_model, global_step)
|
| 40 |
|
| 41 |
return result["text"], result["reward"], result["kl"], result["loss"], reward_plot
|
| 42 |
|