import torch
import torch.nn.functional as F


def grpo_step(
    policy_model,
    gen_model,
    tokenizer,
    prompt,
    reward_fn,
    beta: float = 1e-3,
    eps_clip: float = 0.2,
    group_size: int = 4,
):
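    """One GRPO update step for a single prompt.

    Samples `group_size` completions of `prompt` with `gen_model`, scores them
    with `reward_fn`, normalizes the rewards within the group to get advantages,
    and applies one clipped-surrogate policy-gradient step (plus a small KL
    penalty scaled by `beta`) to `policy_model` if an optimizer is attached.
    Returns the highest-reward completion together with the group's mean
    reward, an approximate KL, and the loss.
    """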
    device = policy_model.device

    # 1) Tokenize the prompt and repeat it group_size times
    inputs = tokenizer(prompt, return_tensors="pt")
    inputs_gpu = {k: v.to(device) for k, v in inputs.items()}
    input_ids_gpu = inputs_gpu["input_ids"]
    attn_gpu = inputs_gpu.get("attention_mask", None)
    input_ids_gpu = input_ids_gpu.repeat_interleave(group_size, dim=0)
    if attn_gpu is not None:
        attn_gpu = attn_gpu.repeat_interleave(group_size, dim=0)

    # CPU copies of the inputs for gen_model, which generates on CPU
    input_ids_cpu = input_ids_gpu.cpu()
    attn_cpu = attn_gpu.cpu() if attn_gpu is not None else None
    gen_inputs = {"input_ids": input_ids_cpu}
    if attn_cpu is not None:
        gen_inputs["attention_mask"] = attn_cpu
    # 2) Generate on CPU
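    #    Each repeated row is sampled independently, so the batch yields
    #    group_size completions of the same prompt (the group GRPO compares).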
    with torch.no_grad():
        gen_output = gen_model.generate(
            **gen_inputs,
            max_new_tokens=64,
            do_sample=True,
            top_p=0.9,
            top_k=50,
            temperature=1.0,
            pad_token_id=tokenizer.eos_token_id,
            return_dict_in_generate=True,
            output_scores=False,
        )
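    # Decode each sampled completion and score it with the task-specific
    # reward function; rewards are clamped to keep the update numerically tame.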
    sequences_cpu = gen_output.sequences
    sequences = sequences_cpu.to(device)
    texts = [tokenizer.decode(seq, skip_special_tokens=True) for seq in sequences_cpu]
    rewards = torch.tensor(
        [reward_fn(text) for text in texts],
        device=device,
        dtype=torch.float32,
    ).clamp(-2.0, 2.0)

    # 3) Group-normalized advantages
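    #    GRPO uses no value network: each sample's advantage is its reward
    #    standardized within the group, A_i = (r_i - mean(r)) / (std(r) + eps).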
    group_mean = rewards.mean()
    group_std = rewards.std(unbiased=False) + 1e-8
    advantages = (rewards - group_mean) / group_std
    advantages = advantages.clamp(-5.0, 5.0)

    orig_len = inputs["input_ids"].shape[1]  # prompt length in tokens

    # 4) Old-policy (reference) logprobs, computed without gradients
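    #    Logits at position t predict token t + 1, so the logits are shifted by
    #    one and the per-token logprobs of the generated part start at
    #    orig_len - 1. policy_model itself (before this update) provides the
    #    reference logprobs, so for this single step the importance ratio below
    #    stays at (or very near) 1 while gradients still flow through lp_gen.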
    with torch.no_grad():
        ref_out = policy_model(sequences)
        ref_logits = ref_out.logits[:, :-1, :]
        ref_logprobs = F.log_softmax(ref_logits, dim=-1)
        ref_lp_all = ref_logprobs.gather(-1, sequences[:, 1:].unsqueeze(-1)).squeeze(-1)
        ref_lp_gen = ref_lp_all[:, orig_len - 1 :]

    # 5) Current policy logprobs (with grad)
    out = policy_model(sequences)
    logits = out.logits[:, :-1, :]
    logprobs = F.log_softmax(logits, dim=-1)
    lp_all = logprobs.gather(-1, sequences[:, 1:].unsqueeze(-1)).squeeze(-1)
    lp_gen = lp_all[:, orig_len - 1 :]
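    # Numerical-safety fallback: if any logprob is non-finite, skip the update
    # and just report the best completion from this group.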
    if (
        not torch.isfinite(lp_gen).all()
        or not torch.isfinite(ref_lp_gen).all()
    ):
        best_idx = int(torch.argmax(rewards).item())
        return {
            "text": texts[best_idx],
            "reward": float(rewards.mean().item()),
            "kl": 0.0,
            "loss": 0.0,
        }

    # 6) Ratios, KL, loss (no in-place ops)
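    #    Sequence-level importance ratio from the mean per-token log-ratio, and
    #    a crude KL estimate from the same quantity, both clamped for stability;
    #    the loss is the PPO-style clipped surrogate plus beta * KL.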
    log_ratio = (lp_gen - ref_lp_gen).mean(dim=1)
    log_ratio = log_ratio.clamp(-10.0, 10.0)
    ratio = torch.exp(log_ratio)
    ratio = ratio.clamp(0.0, 10.0)

    kl_per_sample = (lp_gen - ref_lp_gen).mean(dim=1)
    kl_per_sample = kl_per_sample.clamp(-10.0, 10.0)
    kl_scalar = kl_per_sample.abs().mean()

    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1.0 - eps_clip, 1.0 + eps_clip) * advantages
    policy_loss = -torch.min(surr1, surr2).mean()
    kl_loss = beta * kl_scalar
    loss = policy_loss + kl_loss
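    # 7) Take a gradient step only when an optimizer has been attached to the
    #    model (e.g. policy_model.optimizer = torch.optim.AdamW(...));
    #    otherwise the function only reports statistics for this group.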
    if hasattr(policy_model, "optimizer") and policy_model.optimizer is not None:
        if not torch.isfinite(loss):
            best_idx = int(torch.argmax(rewards).item())
            return {
                "text": texts[best_idx],
                "reward": float(rewards.mean().item()),
                "kl": float(kl_scalar.item()),
                "loss": float(loss.detach().cpu().item()),
            }
        policy_model.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(policy_model.parameters(), 0.5)
        policy_model.optimizer.step()

    best_idx = int(torch.argmax(rewards).item())
    best_text = texts[best_idx]
    return {
        "text": best_text,
        "reward": float(rewards.mean().item()),
        "kl": float(kl_scalar.item()),
        "loss": float(loss.detach().cpu().item()),
    }
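

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only): the model name, optimizer choice
# and reward function below are assumptions, not part of grpo_step itself.
# gen_model here is simply a second, frozen copy of the same checkpoint used
# for CPU sampling.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "sshleifer/tiny-gpt2"  # any causal LM; a tiny model keeps the demo cheap
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    policy_model = AutoModelForCausalLM.from_pretrained(model_name)
    policy_model.train()
    policy_model.optimizer = torch.optim.AdamW(policy_model.parameters(), lr=1e-5)

    gen_model = AutoModelForCausalLM.from_pretrained(model_name)
    gen_model.eval()

    def reward_fn(text: str) -> float:
        # Toy reward: prefer completions of roughly 20 words.
        return -abs(len(text.split()) - 20) / 20.0

    result = grpo_step(
        policy_model,
        gen_model,
        tokenizer,
        prompt="Write a short tip about unit testing:",
        reward_fn=reward_fn,
    )
    print(result["reward"], result["kl"], result["loss"])
    print(result["text"])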