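"""Single-prompt GRPO training step.

Samples a group of completions, scores them with a reward function, computes
group-normalized advantages, and applies a clipped policy-gradient update with
a KL penalty to the policy model.
"""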
import torch
import torch.nn.functional as F
def grpo_step(
    policy_model,
    gen_model,
    tokenizer,
    prompt,
    reward_fn,
    beta: float = 1e-3,
    eps_clip: float = 0.2,
    group_size: int = 4,
):
    """Run one GRPO update on a single prompt.

    Samples `group_size` completions with `gen_model` (on CPU), scores them
    with `reward_fn`, normalizes the rewards within the group, and applies a
    clipped policy-gradient update with a KL penalty (weight `beta`) to
    `policy_model`. Returns the best sampled text plus summary statistics.
    """
    device = policy_model.device
    # 1) Tokenize the prompt and repeat it group_size times (one row per sample)
    inputs = tokenizer(prompt, return_tensors="pt")
    inputs_gpu = {k: v.to(device) for k, v in inputs.items()}
    input_ids_gpu = inputs_gpu["input_ids"]
    attn_gpu = inputs_gpu.get("attention_mask", None)
    input_ids_gpu = input_ids_gpu.repeat_interleave(group_size, dim=0)
    if attn_gpu is not None:
        attn_gpu = attn_gpu.repeat_interleave(group_size, dim=0)

    # CPU copies of the batch for gen_model
    input_ids_cpu = input_ids_gpu.cpu()
    attn_cpu = attn_gpu.cpu() if attn_gpu is not None else None
    gen_inputs = {"input_ids": input_ids_cpu}
    if attn_cpu is not None:
        gen_inputs["attention_mask"] = attn_cpu
    # 2) Generate on CPU
    with torch.no_grad():
        gen_output = gen_model.generate(
            **gen_inputs,
            max_new_tokens=64,
            do_sample=True,
            top_p=0.9,
            top_k=50,
            temperature=1.0,
            pad_token_id=tokenizer.eos_token_id,
            return_dict_in_generate=True,
            output_scores=False,
        )
    sequences_cpu = gen_output.sequences
    sequences = sequences_cpu.to(device)
    texts = [tokenizer.decode(seq, skip_special_tokens=True) for seq in sequences_cpu]
    rewards = torch.tensor(
        [reward_fn(text) for text in texts],
        device=device,
        dtype=torch.float32,
    ).clamp(-2.0, 2.0)
    # 3) Group-normalized advantages: A_i = (r_i - mean(r)) / (std(r) + 1e-8)
    group_mean = rewards.mean()
    group_std = rewards.std(unbiased=False) + 1e-8
    advantages = (rewards - group_mean) / group_std
    advantages = advantages.clamp(-5.0, 5.0)

    orig_len = inputs["input_ids"].shape[1]
    # 4) Reference logprobs (no grad): here the current policy itself serves as
    #    the reference, so the importance ratio starts at 1 for this update.
    with torch.no_grad():
        ref_out = policy_model(sequences)
        ref_logits = ref_out.logits[:, :-1, :]
        ref_logprobs = F.log_softmax(ref_logits, dim=-1)
        ref_lp_all = ref_logprobs.gather(-1, sequences[:, 1:].unsqueeze(-1)).squeeze(-1)
        ref_lp_gen = ref_lp_all[:, orig_len - 1 :]
    # 5) Current policy logprobs (with grad)
    out = policy_model(sequences)
    logits = out.logits[:, :-1, :]
    logprobs = F.log_softmax(logits, dim=-1)
    lp_all = logprobs.gather(-1, sequences[:, 1:].unsqueeze(-1)).squeeze(-1)
    lp_gen = lp_all[:, orig_len - 1 :]

    # Bail out (no update) if any logprob is non-finite.
    if not torch.isfinite(lp_gen).all() or not torch.isfinite(ref_lp_gen).all():
        best_idx = int(torch.argmax(rewards).item())
        return {
            "text": texts[best_idx],
            "reward": float(rewards.mean().item()),
            "kl": 0.0,
            "loss": 0.0,
        }
    # 6) Ratios, KL, loss (no in-place ops). kl_scalar is the mean absolute
    #    token-averaged log-ratio, used as a simple KL proxy for the penalty.
    log_ratio = (lp_gen - ref_lp_gen).mean(dim=1)
    log_ratio = log_ratio.clamp(-10.0, 10.0)
    ratio = torch.exp(log_ratio)
    ratio = ratio.clamp(0.0, 10.0)
    kl_per_sample = (lp_gen - ref_lp_gen).mean(dim=1)
    kl_per_sample = kl_per_sample.clamp(-10.0, 10.0)
    kl_scalar = kl_per_sample.abs().mean()
    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1.0 - eps_clip, 1.0 + eps_clip) * advantages
    policy_loss = -torch.min(surr1, surr2).mean()
    kl_loss = beta * kl_scalar
    loss = policy_loss + kl_loss
    # 7) Optimizer step, if an optimizer is attached to the policy model.
    if hasattr(policy_model, "optimizer") and policy_model.optimizer is not None:
        if not torch.isfinite(loss):
            best_idx = int(torch.argmax(rewards).item())
            return {
                "text": texts[best_idx],
                "reward": float(rewards.mean().item()),
                "kl": float(kl_scalar.item()),
                "loss": float(loss.detach().cpu().item()),
            }
        policy_model.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(policy_model.parameters(), 0.5)
        policy_model.optimizer.step()

    best_idx = int(torch.argmax(rewards).item())
    best_text = texts[best_idx]
    return {
        "text": best_text,
        "reward": float(rewards.mean().item()),
        "kl": float(kl_scalar.item()),
        "loss": float(loss.detach().cpu().item()),
    }
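

# --- Minimal usage sketch (illustrative, not part of the original script) ---
# Assumes a small Hugging Face causal LM; the "sshleifer/tiny-gpt2" checkpoint,
# the learning rate, the prompt, and the length-based reward_fn below are
# placeholder choices. Attaching the optimizer to the model as
# `policy_model.optimizer` mirrors what grpo_step checks via
# hasattr(policy_model, "optimizer").
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "sshleifer/tiny-gpt2"  # placeholder checkpoint
    device = "cuda" if torch.cuda.is_available() else "cpu"

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    policy_model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
    gen_model = AutoModelForCausalLM.from_pretrained(model_name)  # CPU copy for sampling
    gen_model.eval()

    # grpo_step looks for an optimizer attached to the policy model.
    policy_model.optimizer = torch.optim.AdamW(policy_model.parameters(), lr=1e-5)

    def reward_fn(text: str) -> float:
        # Placeholder reward: prefer shorter outputs.
        return -0.01 * len(text)

    stats = grpo_step(
        policy_model,
        gen_model,
        tokenizer,
        prompt="Write a one-sentence summary of GRPO.",
        reward_fn=reward_fn,
    )
    print(stats["reward"], stats["kl"], stats["loss"], stats["text"][:80])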