Sneha7 committed on
Commit
cdc84bc
·
verified ·
1 Parent(s): e354192

Update policy.py

Browse files
Files changed (1) hide show
  1. policy.py +4 -19
policy.py CHANGED
@@ -10,12 +10,12 @@ CHECKPOINT_DIR = "checkpoints"
10
  def load_policy_model(lr: float = 1e-6):
11
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
12
 
13
- # Trainable policy model
14
  policy_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
15
  policy_model.to("cuda")
16
  policy_model.train()
17
 
18
- # Only train lm_head
19
  for name, param in policy_model.named_parameters():
20
  param.requires_grad = ("lm_head" in name)
21
 
@@ -25,20 +25,13 @@ def load_policy_model(lr: float = 1e-6):
25
  )
26
  policy_model.optimizer = optimizer
27
 
28
- # Frozen generation model
29
  gen_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
30
- gen_model.to("cuda")
31
  gen_model.eval()
32
  for p in gen_model.parameters():
33
  p.requires_grad_(False)
34
 
35
- # Frozen reference model (can just deepcopy gen_model)
36
- ref_model = copy.deepcopy(gen_model)
37
- ref_model.eval()
38
- for p in ref_model.parameters():
39
- p.requires_grad_(False)
40
-
41
- return policy_model, gen_model, ref_model, tokenizer
42
 
43
 
44
  def save_checkpoint(policy_model, step: int, ckpt_dir: str = CHECKPOINT_DIR):
@@ -55,11 +48,3 @@ def save_checkpoint(policy_model, step: int, ckpt_dir: str = CHECKPOINT_DIR):
55
  path,
56
  )
57
  print(f"[CKPT] Saved checkpoint at {path}")
58
-
59
-
60
- def load_checkpoint(policy_model, optimizer, ckpt_path: str):
61
- ckpt = torch.load(ckpt_path, map_location="cuda")
62
- policy_model.load_state_dict(ckpt["model_state_dict"])
63
- if optimizer is not None and ckpt.get("optimizer_state_dict") is not None:
64
- optimizer.load_state_dict(ckpt["optimizer_state_dict"])
65
- print(f"[CKPT] Loaded checkpoint from {ckpt_path} at step={ckpt.get('step')}")
 
10
  def load_policy_model(lr: float = 1e-6):
11
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
12
 
13
+ # Trainable policy model on GPU
14
  policy_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
15
  policy_model.to("cuda")
16
  policy_model.train()
17
 
18
+ # Train only lm_head
19
  for name, param in policy_model.named_parameters():
20
  param.requires_grad = ("lm_head" in name)
21
 
 
25
  )
26
  policy_model.optimizer = optimizer
27
 
28
+ # Frozen generation model on CPU (no .to("cuda"))
29
  gen_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
 
30
  gen_model.eval()
31
  for p in gen_model.parameters():
32
  p.requires_grad_(False)
33
 
34
+ return policy_model, gen_model, tokenizer
 
 
 
 
 
 
35
 
36
 
37
  def save_checkpoint(policy_model, step: int, ckpt_dir: str = CHECKPOINT_DIR):
 
48
  path,
49
  )
50
  print(f"[CKPT] Saved checkpoint at {path}")