Sneha7 committed
Commit 55bd1b0 · verified · 1 Parent(s): c966873

Update app.py

Files changed (1):
  app.py: +4 −7
app.py CHANGED

@@ -5,12 +5,10 @@ from policy import load_policy_model, save_checkpoint
 from reward_fn import reward_fn
 from grpo_train import grpo_step
 
-
-# Load models
-policy_model, gen_model, ref_model, tokenizer = load_policy_model()
+policy_model, gen_model, tokenizer = load_policy_model()
 
 reward_history = []
-global_step = 0  # simple in-memory step counter
+global_step = 0
 
 
 def plot_rewards(history):
@@ -29,7 +27,6 @@ def run_step(prompt):
     result = grpo_step(
         policy_model=policy_model,
         gen_model=gen_model,
-        ref_model=ref_model,
         tokenizer=tokenizer,
         prompt=prompt,
         reward_fn=reward_fn,
@@ -38,8 +35,8 @@ def run_step(prompt):
     reward_history.append(float(result["reward"]))
     reward_plot = plot_rewards(reward_history)
 
-    # Save checkpoint
-    save_checkpoint(policy_model, global_step)
+    if global_step % 10 == 0:
+        save_checkpoint(policy_model, global_step)
 
     return result["text"], result["reward"], result["kl"], result["loss"], reward_plot
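A note on the new checkpoint gate: within the hunks shown, global_step is initialized to 0 at module level but never incremented, so unless it is advanced somewhere in the unchanged lines between hunks, global_step % 10 == 0 holds on every call and save_checkpoint still fires on every step, same as the code it replaces. A minimal sketch of run_step under the apparent intent (checkpoint every 10th step); the two lines marked "hypothetical" are assumptions, everything else is as in the diff above:

# Sketch only: the imports, load_policy_model() call, reward_history list, and
# plot_rewards definition are as in app.py above.
global_step = 0  # in-memory step counter, reset whenever the app restarts

def run_step(prompt):
    global global_step  # hypothetical: rebind the module-level counter
    result = grpo_step(
        policy_model=policy_model,
        gen_model=gen_model,
        tokenizer=tokenizer,
        prompt=prompt,
        reward_fn=reward_fn,
    )
    reward_history.append(float(result["reward"]))
    reward_plot = plot_rewards(reward_history)

    global_step += 1  # hypothetical: without this the modulo check below is always true
    if global_step % 10 == 0:  # checkpoint every 10th step instead of every step
        save_checkpoint(policy_model, global_step)

    return result["text"], result["reward"], result["kl"], result["loss"], reward_plot

The commit also drops ref_model from both load_policy_model() and the grpo_step call while run_step still returns result["kl"], so presumably grpo_train now computes (or zeroes) the KL term without a separate reference model.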