import torch
import torch.nn.functional as F


def grpo_step(
    policy_model,
    gen_model,
    tokenizer,
    prompt,
    reward_fn,
    beta: float = 1e-3,
    eps_clip: float = 0.2,
    group_size: int = 4,
):
    """Run one GRPO update for a single prompt.

    Samples `group_size` completions with `gen_model` (kept on CPU), scores them
    with `reward_fn`, normalizes the rewards within the group to get advantages,
    and takes one clipped-surrogate policy-gradient step on `policy_model`. The
    KL-style penalty (weight `beta`) is computed against a detached forward pass
    of the same policy. Returns the highest-reward completion plus summary stats.
    """
    device = policy_model.device

    # 1) Tokenize the prompt and replicate it `group_size` times so one batch
    #    yields a whole group of completions.
    inputs = tokenizer(prompt, return_tensors="pt")
    inputs_gpu = {k: v.to(device) for k, v in inputs.items()}
    input_ids_gpu = inputs_gpu["input_ids"]
    attn_gpu = inputs_gpu.get("attention_mask", None)

    input_ids_gpu = input_ids_gpu.repeat_interleave(group_size, dim=0)
    if attn_gpu is not None:
        attn_gpu = attn_gpu.repeat_interleave(group_size, dim=0)

    # CPU copy for gen_model
    input_ids_cpu = input_ids_gpu.cpu()
    attn_cpu = attn_gpu.cpu() if attn_gpu is not None else None
    gen_inputs = {"input_ids": input_ids_cpu}
    if attn_cpu is not None:
        gen_inputs["attention_mask"] = attn_cpu

    # 2) Generate on CPU
    with torch.no_grad():
        gen_output = gen_model.generate(
            **gen_inputs,
            max_new_tokens=64,
            do_sample=True,
            top_p=0.9,
            top_k=50,
            temperature=1.0,
            pad_token_id=tokenizer.eos_token_id,
            return_dict_in_generate=True,
            output_scores=False,
        )

    sequences_cpu = gen_output.sequences
    sequences = sequences_cpu.to(device)
    texts = [tokenizer.decode(seq, skip_special_tokens=True) for seq in sequences_cpu]

    # Score each completion and clamp rewards to a bounded range.
    rewards = torch.tensor(
        [reward_fn(text) for text in texts],
        device=device,
        dtype=torch.float32,
    ).clamp(-2.0, 2.0)

    # 3) Group-normalized advantages
    group_mean = rewards.mean()
    group_std = rewards.std(unbiased=False) + 1e-8
    advantages = (rewards - group_mean) / group_std
    advantages = advantages.clamp(-5.0, 5.0)

    orig_len = inputs["input_ids"].shape[1]

    # 4) Reference logprobs (no grad): a detached forward pass of the current
    #    policy, used as the "old policy" for the ratio and KL terms below.
    with torch.no_grad():
        ref_out = policy_model(sequences)
        ref_logits = ref_out.logits[:, :-1, :]
        ref_logprobs = F.log_softmax(ref_logits, dim=-1)
        ref_lp_all = ref_logprobs.gather(-1, sequences[:, 1:].unsqueeze(-1)).squeeze(-1)
        ref_lp_gen = ref_lp_all[:, orig_len - 1 :]

    # 5) Current policy logprobs (with grad)
    out = policy_model(sequences)
    logits = out.logits[:, :-1, :]
    logprobs = F.log_softmax(logits, dim=-1)
    lp_all = logprobs.gather(-1, sequences[:, 1:].unsqueeze(-1)).squeeze(-1)
    lp_gen = lp_all[:, orig_len - 1 :]
    # NOTE: sequences are padded with EOS to a common length, so padded
    # positions are included in the per-token means below.

    if (
        not torch.isfinite(lp_gen).all()
        or not torch.isfinite(ref_lp_gen).all()
    ):
        # Bail out without an update if any logprob is NaN/inf.
        best_idx = int(torch.argmax(rewards).item())
        return {
            "text": texts[best_idx],
            "reward": float(rewards.mean().item()),
            "kl": 0.0,
            "loss": 0.0,
        }

    # 6) Ratios, KL, loss (no in-place ops)
    log_ratio = (lp_gen - ref_lp_gen).mean(dim=1)
    log_ratio = log_ratio.clamp(-10.0, 10.0)
    ratio = torch.exp(log_ratio)
    ratio = ratio.clamp(0.0, 10.0)

    kl_per_sample = (lp_gen - ref_lp_gen).mean(dim=1)
    kl_per_sample = kl_per_sample.clamp(-10.0, 10.0)
    kl_scalar = kl_per_sample.abs().mean()

    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1.0 - eps_clip, 1.0 + eps_clip) * advantages
    policy_loss = -torch.min(surr1, surr2).mean()

    kl_loss = beta * kl_scalar
    loss = policy_loss + kl_loss

    # 7) Optimizer step, if an optimizer is attached to the policy model.
    if hasattr(policy_model, "optimizer") and policy_model.optimizer is not None:
        if not torch.isfinite(loss):
            best_idx = int(torch.argmax(rewards).item())
            return {
                "text": texts[best_idx],
                "reward": float(rewards.mean().item()),
                "kl": float(kl_scalar.item()),
                "loss": float(loss.detach().cpu().item()),
            }

        policy_model.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(policy_model.parameters(), 0.5)
        policy_model.optimizer.step()

    best_idx = int(torch.argmax(rewards).item())
    best_text = texts[best_idx]

    return {
        "text": best_text,
        "reward": float(rewards.mean().item()),
        "kl": float(kl_scalar.item()),
        "loss": float(loss.detach().cpu().item()),
    }
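

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of grpo_step): wires the function up
# with a small Hugging Face causal LM. The checkpoint name, the toy reward
# function, and attaching the optimizer as `policy_model.optimizer` are all
# assumptions made for this example; swap in your own models and reward model.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "sshleifer/tiny-gpt2"  # assumed tiny checkpoint, illustration only
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    policy_model = AutoModelForCausalLM.from_pretrained(model_name)  # model being trained
    gen_model = AutoModelForCausalLM.from_pretrained(model_name)     # CPU copy used only for sampling
    gen_model.eval()

    # grpo_step looks for the optimizer on the policy model itself.
    policy_model.optimizer = torch.optim.AdamW(policy_model.parameters(), lr=1e-5)

    def reward_fn(text: str) -> float:
        # Toy stand-in for a real reward model: favor longer completions,
        # capped to stay inside the [-2, 2] clamp applied in grpo_step.
        return min(len(text.split()) / 32.0, 2.0)

    stats = grpo_step(
        policy_model,
        gen_model,
        tokenizer,
        prompt="Explain gradient clipping in one sentence.",
        reward_fn=reward_fn,
        group_size=4,
    )
    print(stats)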