import copy  # NOTE(review): unused here; kept in case another chunk of this file relies on it
import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_NAME = "microsoft/phi-2"
CHECKPOINT_DIR = "checkpoints"


def load_policy_model(lr: float = 1e-6, device: str = "cuda"):
    """Load the trainable policy model plus a frozen CPU copy for generation.

    Only the ``lm_head`` parameters of the policy model are trainable; an
    AdamW optimizer over those parameters is attached to the model as
    ``policy_model.optimizer`` so that ``save_checkpoint`` can persist its
    state alongside the weights.

    Args:
        lr: Learning rate for the AdamW optimizer.
        device: Device for the policy model. Defaults to ``"cuda"`` (the
            previous hard-coded behavior); pass ``"cpu"`` on GPU-less hosts.

    Returns:
        Tuple ``(policy_model, gen_model, tokenizer)`` where ``policy_model``
        is on ``device`` in train mode, ``gen_model`` is a frozen copy left
        on CPU in eval mode, and ``tokenizer`` is the matching tokenizer.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    # Trainable policy model, moved to the requested device.
    policy_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
    policy_model.to(device)
    policy_model.train()

    # Freeze everything except the LM head.
    for name, param in policy_model.named_parameters():
        param.requires_grad = ("lm_head" in name)

    optimizer = torch.optim.AdamW(
        (p for p in policy_model.parameters() if p.requires_grad),
        lr=lr,
    )
    # Attached so save_checkpoint() can reach the optimizer state.
    policy_model.optimizer = optimizer

    # Frozen generation model: a second full copy, deliberately left on CPU.
    gen_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
    gen_model.eval()
    for p in gen_model.parameters():
        p.requires_grad_(False)

    return policy_model, gen_model, tokenizer


def save_checkpoint(policy_model, step: int, ckpt_dir: str = CHECKPOINT_DIR):
    """Write a training checkpoint to ``ckpt_dir/step_{step}.pt``.

    The checkpoint contains the step number, the model's full
    ``state_dict``, and the optimizer state if ``load_policy_model``
    attached one (``None`` otherwise).

    Args:
        policy_model: Model to checkpoint; may carry an ``optimizer`` attribute.
        step: Training step number, used in the file name.
        ckpt_dir: Destination directory; created if it does not exist.
    """
    os.makedirs(ckpt_dir, exist_ok=True)
    path = os.path.join(ckpt_dir, f"step_{step}.pt")

    # Optimizer state is optional: models not built via load_policy_model
    # simply get None here instead of raising.
    optimizer = getattr(policy_model, "optimizer", None)
    torch.save(
        {
            "step": step,
            "model_state_dict": policy_model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict() if optimizer is not None else None,
        },
        path,
    )
    print(f"[CKPT] Saved checkpoint at {path}")