[algorithm] add rollout KL loss + fix padding-microbatch field drops#1811

Open
erictang000 wants to merge 1 commit into
NovaSky-AI:mainfrom
erictang000:rollout-kl-loss

@erictang000

Collaborator

Summary

Adds an optional rollout KL loss: a soft trust region that penalizes the train↔rollout logprob drift (the off-policy "mismatch" gap) on every token, in addition to the policy loss. Unlike the binary DPPO/clip masks, which only act on tokens crossing a threshold, it constrains the aggregate gap, which is the lever that controls drift in fully-async training. Motivated by the growing logprob diff observed with DPPO. Follows prime-rl, which uses the squared log importance ratio (log_probs - rollout_logprobs)**2; the "k2" estimator gives the same term up to a factor of 1/2, which can be folded into the coefficient.

This composes with any policy_loss_type and is independent of use_kl_loss: use_kl_loss is KL against the frozen reference model, whereas this is KL against the rollout/behavior policy.
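
As a rough illustration of the term described above, here is a minimal pure-Python sketch for a single sequence. The function names (k2_rollout_kl, total_loss) and the sample values are hypothetical and not taken from the PR; the actual loss paths operate on torch tensors and use masked_mean over the token dimension followed by a mean over sequences.

```python
from typing import Sequence


def k2_rollout_kl(
    log_probs: Sequence[float],
    rollout_logprobs: Sequence[float],
    loss_mask: Sequence[int],
) -> float:
    """Masked mean of the k2 estimator, 0.5 * (log_probs - rollout_logprobs)**2."""
    total = 0.0
    count = 0
    for lp, rlp, m in zip(log_probs, rollout_logprobs, loss_mask):
        if m:  # tokens with loss_mask == 0 contribute nothing
            total += 0.5 * (lp - rlp) ** 2
            count += 1
    # Return 0.0 for a fully masked sequence (for example, a padding microbatch).
    return total / count if count else 0.0


def total_loss(policy_loss: float, rollout_kl: float, coef: float) -> float:
    """Add the rollout KL term on top of the policy loss (hypothetical helper)."""
    return policy_loss + coef * rollout_kl


# Illustrative values only; the last token is masked out.
log_probs = [-1.0, -2.0, -0.5, -3.0]
rollout_logprobs = [-1.5, -2.0, -1.5, 0.0]
loss_mask = [1, 1, 1, 0]

kl = k2_rollout_kl(log_probs, rollout_logprobs, loss_mask)
loss = total_loss(policy_loss=1.0, rollout_kl=kl, coef=0.1)
```

Because k2 carries the 1/2 factor, a coefficient c applied to the plain squared log ratio corresponds to rollout_kl_loss_coef = 2c here.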

Feature

  • AlgorithmConfig: use_rollout_kl_loss (default False), rollout_kl_loss_coef (default 0.0), rollout_kl_estimator_type (default "k2").
  • FSDP (worker.py) and Megatron (megatron_model_wrapper.py) loss paths compute the term against rollout_logprobs, mirroring the existing use_kl_loss term, and report a rollout_kl metric.
  • validate_cfg auto-enables generator.sampling_params.logprobs, mirroring the existing off_policy_correction behavior.

Usage:

trainer.algorithm.use_rollout_kl_loss=true
trainer.algorithm.rollout_kl_loss_coef=<tune, e.g. 1e-3 .. 1.0>

Bug fixes (included)

While validating the feature, two latent bugs surfaced (items 1 and 2 below); item 3 is a related fail-fast check. Both bugs manifest as traceback-less multi-rank aborts: an exception raised inside the distributed forward_backward desyncs NCCL and kills all ranks with no Python traceback (in sync training the same path raises a clean RayTaskError).

  1. _create_padding_microbatch dropped optional per-token fields. Token-based batching appends fully-padding microbatches built from a hardcoded key set, so rollout_logprobs was absent on them → the loss saw None and raised. Now mirrors rollout_logprobs when the real data has it (loss_mask=0 keeps the value irrelevant). Regression test added.

  2. The same drop left router-replay state stale on padding microbatches for MoE models. RouterReplay stays in REPLAY_FORWARD with the previous microbatch's target indices, so get_replay_topk() gathers mis-sized indices against the padding microbatch's scores and raises. Both forward_step closures now clear the replay action on padding microbatches (gated on moe_enable_routing_replay).

  3. trainer.convert_to_training_input now fails fast on the driver (clear traceback) when rollout_logprobs are required but missing, instead of deferring to the per-rank loss inside the collective.
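
The idea behind fix 1 can be sketched as follows. This is a simplified stand-in using plain dicts and lists, not the actual _create_padding_microbatch; the function name make_padding_microbatch and every key other than loss_mask and rollout_logprobs are assumptions made for illustration.

```python
from typing import Any, Dict


def make_padding_microbatch(reference: Dict[str, Any], seq_len: int) -> Dict[str, Any]:
    """Build a fully masked microbatch that mirrors optional fields of `reference`."""
    padding: Dict[str, Any] = {
        # Hypothetical required keys; the real key set may differ.
        "sequences": [[0] * seq_len],
        "attention_mask": [[0] * seq_len],
        # loss_mask = 0 everywhere, so no token contributes to the loss.
        "loss_mask": [[0] * seq_len],
    }
    # Mirror the optional per-token field only when the real data carries it.
    # The values are irrelevant because loss_mask is all zeros.
    if reference.get("rollout_logprobs") is not None:
        padding["rollout_logprobs"] = [[0.0] * seq_len]
    return padding


# Both directions: the field is present only when the real microbatch has it.
pad_with = make_padding_microbatch({"rollout_logprobs": [[-0.1, -0.2]]}, seq_len=4)
pad_without = make_padding_microbatch({"rollout_logprobs": None}, seq_len=4)
```

Mirroring the field rather than always adding it keeps padding microbatches shaped like the real ones, so the loss path sees the same set of keys on every rank.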

Testing

  • tests/backends/skyrl_train/test_token_based_batching_utils.py: added test_padding_microbatch_carries_rollout_logprobs, which covers both directions (field present and field absent). The full file passes (17/17) via uv run --isolated --extra dev --extra fsdp pytest.
  • The Megatron loss-combination and MoE router-replay paths are GPU-only; the router-replay fix is reasoned from the megatron-core RouterReplay source, not yet validated on an MoE run. Recommend a smoke test on the 30B-A3B recipe (token batching + moe_enable_routing_replay=true) past the first uneven-DP step before relying on it.

🤖 Generated with Claude Code

Adds an optional rollout KL loss: a soft trust region that penalizes the
train<->rollout logprob drift (the off-policy "mismatch" gap) on every token,
in addition to the policy loss. Unlike the binary DPPO/clip masks (which only
act on tokens crossing a threshold), it constrains the aggregate gap, which is
the lever that controls drift in fully-async training. prime-rl uses the
squared log importance ratio (log_probs - rollout_logprobs)**2; "k2" gives
that up to a 1/2 factor folded into the coefficient.

Feature:
- AlgorithmConfig: use_rollout_kl_loss / rollout_kl_loss_coef /
  rollout_kl_estimator_type (default "k2").
- FSDP (worker.py) and Megatron (megatron_model_wrapper.py) loss paths compute
  the term against rollout_logprobs (mirroring the existing use_kl_loss term,
  which is KL vs the frozen reference model) and report a `rollout_kl` metric.
- validate_cfg auto-enables generator.sampling_params.logprobs, mirroring
  off_policy_correction.

Bug fixes (both surface as traceback-less multi-rank aborts: an exception
raised inside the distributed forward_backward desyncs NCCL):
- worker_utils._create_padding_microbatch dropped optional per-token fields.
  Token-based batching appends fully-padding microbatches built from a
  hardcoded key set, so rollout_logprobs was absent on them -> the loss saw
  None and raised. Now mirrors rollout_logprobs when the real data has it
  (loss_mask=0 keeps the value irrelevant). Regression test added.
- The same drop left router-replay state stale on padding microbatches for MoE
  models: RouterReplay stays in REPLAY_FORWARD with the previous microbatch's
  target indices, so get_replay_topk() gathers mis-sized indices and raises.
  Both forward_step closures now clear the replay action on padding
  microbatches (gated on moe_enable_routing_replay).

- trainer.convert_to_training_input now fails fast on the driver (clear
  traceback) when rollout_logprobs are required but missing, instead of
  deferring to the per-rank loss.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request introduces a rollout KL loss feature to penalize the train-rollout logprob drift in off-policy/async training. It adds configuration options, computes the rollout KL loss in both Megatron and standard workers, ensures padding microbatches carry rollout_logprobs to prevent desyncs, and adds validation checks. The feedback suggests initializing the fallback rollout KL loss tensor on the correct device in worker.py to avoid device mismatch warnings, and validating the rollout KL estimator type early in validate_cfg to fail fast.

    )
    rollout_kl_loss = masked_mean(rollout_kl_loss, loss_mask, dim=-1).mean()
else:
    rollout_kl_loss = torch.tensor(0.0)

medium

Initializing rollout_kl_loss on the CPU (via torch.tensor(0.0)) when use_rollout_kl_loss is False can lead to device mismatch warnings or unnecessary host-to-device copies when added to the GPU-based policy_loss. It is safer and more efficient to initialize it directly on the active GPU device.

Suggested change
- rollout_kl_loss = torch.tensor(0.0)
+ rollout_kl_loss = torch.tensor(0.0, device=action_log_probs.device)

Comment on lines +382 to +387
if cfg.trainer.algorithm.use_rollout_kl_loss and cfg.generator.sampling_params.logprobs is None:
    logger.warning(
        "`generator.sampling_params.logprobs` is `None` but `use_rollout_kl_loss` is enabled."
        " Setting `logprobs` to `1`."
    )
    cfg.generator.sampling_params.logprobs = 1

medium

It is recommended to validate rollout_kl_estimator_type early in validate_cfg to fail fast with a clear error message if an unsupported estimator is configured, rather than raising a runtime error deep inside the training loop.

Suggested change
- if cfg.trainer.algorithm.use_rollout_kl_loss and cfg.generator.sampling_params.logprobs is None:
-     logger.warning(
-         "`generator.sampling_params.logprobs` is `None` but `use_rollout_kl_loss` is enabled."
-         " Setting `logprobs` to `1`."
-     )
-     cfg.generator.sampling_params.logprobs = 1
+ if cfg.trainer.algorithm.use_rollout_kl_loss:
+     if cfg.generator.sampling_params.logprobs is None:
+         logger.warning(
+             "`generator.sampling_params.logprobs` is `None` but `use_rollout_kl_loss` is enabled."
+             " Setting `logprobs` to `1`."
+         )
+         cfg.generator.sampling_params.logprobs = 1
+     supported_estimators = {"k1", "abs", "k2", "k3"}
+     if cfg.trainer.algorithm.rollout_kl_estimator_type not in supported_estimators:
+         raise ValueError(
+             f"Invalid rollout_kl_estimator_type: {cfg.trainer.algorithm.rollout_kl_estimator_type}. "
+             f"Must be one of {supported_estimators}"
+         )
