[algorithm] add rollout KL loss + fix padding-microbatch field drops #1811
erictang000 wants to merge 1 commit into
Adds an optional rollout KL loss: a soft trust region that penalizes the train<->rollout logprob drift (the off-policy "mismatch" gap) on every token, in addition to the policy loss. Unlike the binary DPPO/clip masks (which only act on tokens crossing a threshold), it constrains the aggregate gap, which is the lever that controls drift in fully-async training. prime-rl uses the squared log importance ratio (log_probs - rollout_logprobs)**2; "k2" gives that up to a 1/2 factor folded into the coefficient.

Feature:
- AlgorithmConfig: use_rollout_kl_loss / rollout_kl_loss_coef / rollout_kl_estimator_type (default "k2").
- FSDP (worker.py) and Megatron (megatron_model_wrapper.py) loss paths compute the term against rollout_logprobs (mirroring the existing use_kl_loss term, which is KL vs the frozen reference model) and report a `rollout_kl` metric.
- validate_cfg auto-enables generator.sampling_params.logprobs, mirroring off_policy_correction.

Bug fixes (both surface as traceback-less multi-rank aborts: an exception raised inside the distributed forward_backward desyncs NCCL):
- worker_utils._create_padding_microbatch dropped optional per-token fields. Token-based batching appends fully-padding microbatches built from a hardcoded key set, so rollout_logprobs was absent on them -> the loss saw None and raised. Now mirrors rollout_logprobs when the real data has it (loss_mask=0 keeps the value irrelevant). Regression test added.
- The same drop left router-replay state stale on padding microbatches for MoE models: RouterReplay stays in REPLAY_FORWARD with the previous microbatch's target indices, so get_replay_topk() gathers mis-sized indices and raises. Both forward_step closures now clear the replay action on padding microbatches (gated on moe_enable_routing_replay).
- trainer.convert_to_training_input now fails fast on the driver (clear traceback) when rollout_logprobs are required but missing, instead of deferring to the per-rank loss.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
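As a rough illustration of the rollout KL term described in the commit message, here is a minimal pure-Python sketch (the PR itself operates on torch tensors). The helper names below, the sign conventions, and the exact form of each estimator are assumptions based on the commonly used k1/k2/k3 approximations, not code taken from this PR.

```python
import math


def rollout_kl_per_token(log_prob, rollout_log_prob, estimator="k2"):
    """Per-token penalty on the train-vs-rollout log-prob gap (illustrative only)."""
    d = log_prob - rollout_log_prob  # log importance ratio for one token
    if estimator == "k1":
        return d
    if estimator == "abs":
        return abs(d)
    if estimator == "k2":
        return 0.5 * d * d  # half the squared log ratio
    if estimator == "k3":
        return math.exp(-d) - 1.0 + d  # non-negative, zero when d == 0
    raise ValueError(f"unknown estimator: {estimator}")


def masked_mean(values, mask):
    """Mean of values over positions where mask is 1 (0.0 if nothing is unmasked)."""
    total = sum(v * m for v, m in zip(values, mask))
    count = sum(mask)
    return total / count if count else 0.0


# Toy data: one sequence of four tokens, the last one is padding.
log_probs = [-1.0, -2.0, -0.5, -3.0]
rollout_log_probs = [-1.2, -2.0, -0.1, -9.0]
loss_mask = [1, 1, 1, 0]  # the padded token must not contribute

k2_terms = [rollout_kl_per_token(a, b, "k2") for a, b in zip(log_probs, rollout_log_probs)]
squared_terms = [(a - b) ** 2 for a, b in zip(log_probs, rollout_log_probs)]

rollout_kl = masked_mean(k2_terms, loss_mask)      # k2-style term
squared_gap = masked_mean(squared_terms, loss_mask)  # squared log ratio
```

Under these assumptions `squared_gap` is exactly twice `rollout_kl`, so a k2 coefficient of 2c would correspond to a squared-log-ratio penalty with coefficient c.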
Code Review
This pull request introduces a rollout KL loss feature to penalize the train-rollout logprob drift in off-policy/async training. It adds configuration options, computes the rollout KL loss in both Megatron and standard workers, ensures padding microbatches carry rollout_logprobs to prevent desyncs, and adds validation checks. The feedback suggests initializing the fallback rollout KL loss tensor on the correct device in worker.py to avoid device mismatch warnings, and validating the rollout KL estimator type early in validate_cfg to fail fast.
    )
    rollout_kl_loss = masked_mean(rollout_kl_loss, loss_mask, dim=-1).mean()
else:
    rollout_kl_loss = torch.tensor(0.0)
Initializing rollout_kl_loss on the CPU (via torch.tensor(0.0)) when use_rollout_kl_loss is False can lead to device mismatch warnings or unnecessary host-to-device copies when added to the GPU-based policy_loss. It is safer and more efficient to initialize it directly on the active GPU device.
Suggested change
-     rollout_kl_loss = torch.tensor(0.0)
+     rollout_kl_loss = torch.tensor(0.0, device=action_log_probs.device)
if cfg.trainer.algorithm.use_rollout_kl_loss and cfg.generator.sampling_params.logprobs is None:
    logger.warning(
        "`generator.sampling_params.logprobs` is `None` but `use_rollout_kl_loss` is enabled."
        " Setting `logprobs` to `1`."
    )
    cfg.generator.sampling_params.logprobs = 1
It is recommended to validate rollout_kl_estimator_type early in validate_cfg to fail fast with a clear error message if an unsupported estimator is configured, rather than raising a runtime error deep inside the training loop.
Suggested change
- if cfg.trainer.algorithm.use_rollout_kl_loss and cfg.generator.sampling_params.logprobs is None:
-     logger.warning(
-         "`generator.sampling_params.logprobs` is `None` but `use_rollout_kl_loss` is enabled."
-         " Setting `logprobs` to `1`."
-     )
-     cfg.generator.sampling_params.logprobs = 1
+ if cfg.trainer.algorithm.use_rollout_kl_loss:
+     if cfg.generator.sampling_params.logprobs is None:
+         logger.warning(
+             "`generator.sampling_params.logprobs` is `None` but `use_rollout_kl_loss` is enabled."
+             " Setting `logprobs` to `1`."
+         )
+         cfg.generator.sampling_params.logprobs = 1
+     supported_estimators = {"k1", "abs", "k2", "k3"}
+     if cfg.trainer.algorithm.rollout_kl_estimator_type not in supported_estimators:
+         raise ValueError(
+             f"Invalid rollout_kl_estimator_type: {cfg.trainer.algorithm.rollout_kl_estimator_type}. "
+             f"Must be one of {supported_estimators}"
+         )
Summary
Adds an optional rollout KL loss: a soft trust region that penalizes the train↔rollout logprob drift (the off-policy "mismatch" gap) on every token, in addition to the policy loss. Unlike the binary DPPO/clip masks (which only act on tokens crossing a threshold), it constrains the aggregate gap, which is the lever that controls drift in fully-async training. Motivated by growing logprob diff observed with DPPO. Follows prime-rl, which uses the squared log importance ratio `(log_probs - rollout_logprobs)**2`; the `"k2"` estimator gives that up to a 1/2 factor folded into the coefficient.

This composes with any `policy_loss_type` and is independent of `use_kl_loss` (which is KL vs the frozen reference model; this is KL vs the rollout/behavior policy).

Feature
- `AlgorithmConfig`: `use_rollout_kl_loss` (default `False`), `rollout_kl_loss_coef` (default `0.0`), `rollout_kl_estimator_type` (default `"k2"`).
- FSDP (`worker.py`) and Megatron (`megatron_model_wrapper.py`) loss paths compute the term against `rollout_logprobs`, mirroring the existing `use_kl_loss` term, and report a `rollout_kl` metric.
- `validate_cfg` auto-enables `generator.sampling_params.logprobs`, mirroring the existing `off_policy_correction` behavior.

Usage:
Bug fixes (included)
While validating the feature, two latent bugs surfaced. Both manifest as traceback-less multi-rank aborts: an exception raised inside the distributed `forward_backward` desyncs NCCL and kills all ranks with no Python traceback (in sync training the same path raises a clean `RayTaskError`).
- `_create_padding_microbatch` dropped optional per-token fields. Token-based batching appends fully-padding microbatches built from a hardcoded key set, so `rollout_logprobs` was absent on them → the loss saw `None` and raised. Now mirrors `rollout_logprobs` when the real data has it (`loss_mask=0` keeps the value irrelevant). Regression test added.
- The same drop left router-replay state stale on padding microbatches for MoE models. `RouterReplay` stays in `REPLAY_FORWARD` with the previous microbatch's target indices, so `get_replay_topk()` gathers mis-sized indices against the padding microbatch's scores and raises. Both `forward_step` closures now clear the replay action on padding microbatches (gated on `moe_enable_routing_replay`).
- `trainer.convert_to_training_input` now fails fast on the driver (clear traceback) when `rollout_logprobs` are required but missing, instead of deferring to the per-rank loss inside the collective.

Testing
- `tests/backends/skyrl_train/test_token_based_batching_utils.py`: added `test_padding_microbatch_carries_rollout_logprobs` (both directions). Full file passes (17/17) via `uv run --isolated --extra dev --extra fsdp pytest`.
- Router-replay fix: reasoned from the `RouterReplay` source, not yet validated on an MoE run. Recommend a smoke test on the 30B-A3B recipe (token batching + `moe_enable_routing_replay=true`) past the first uneven-DP step before relying on it.

🤖 Generated with Claude Code
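The padding-microbatch fix and the driver-side fail-fast check described under the bug fixes can be sketched in plain Python. This is a simplified illustration, not the PR's code: microbatches are modeled as dicts of nested lists, and `make_padding_microbatch` and `require_rollout_logprobs` are hypothetical names.

```python
def make_padding_microbatch(real_batch, seq_len, pad_token_id=0):
    """Build a fully masked microbatch that carries the same optional fields as real data."""
    padding = {
        "sequences": [[pad_token_id] * seq_len],
        "loss_mask": [[0] * seq_len],  # every token masked out of the loss
    }
    # Mirror the optional per-token field only when the real data has it.
    # The values are irrelevant because loss_mask is all zeros.
    if real_batch.get("rollout_logprobs") is not None:
        padding["rollout_logprobs"] = [[0.0] * seq_len]
    return padding


def require_rollout_logprobs(batch, use_rollout_kl_loss):
    """Fail on the driver, with a normal traceback, before any collective runs."""
    if use_rollout_kl_loss and batch.get("rollout_logprobs") is None:
        raise ValueError("use_rollout_kl_loss is enabled but rollout_logprobs is missing")


with_field = {"sequences": [[5, 6, 7]], "loss_mask": [[1, 1, 0]], "rollout_logprobs": [[-0.1, -0.2, 0.0]]}
without_field = {"sequences": [[5, 6, 7]], "loss_mask": [[1, 1, 0]]}

pad_with = make_padding_microbatch(with_field, seq_len=3)
pad_without = make_padding_microbatch(without_field, seq_len=3)
```

Checking both directions (field present and field absent) matches the intent of the regression test named above: the padding microbatch should neither drop the field nor invent it.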