[train] Vectorize controller-side training-batch collation (SFT + RL) #1808
dyurk-lila wants to merge 1 commit
# What does this PR do?

The controller builds every training batch on the main process before dispatching it to the workers. All three collation paths did so with per-token / per-sample Python loops, which dominate the controller-side collate wall-time and serially block the GPU at large batch sizes. This PR replaces those loops with NumPy slice assignments and broadcast comparisons. **Outputs are bit-identical** (same dtypes, same layout) for all inputs produced in practice: this is a pure CPU-side latency optimization, not a behavior change.

## Changes

- **`PackedDataCollator` (Megatron SFT FFD packing), `skyrl/train/dataset/collators.py`**: the per-bin packed row tensors (`sequences` / `attention_mask` / `loss_mask`) were built with a per-token Python loop over the reconstructed full loss mask. Each sub-sequence is now written with one C-level copy, and `total_nonpad` is a single vectorized reduction.
- **`collate_sft_batch` / `DefaultCollator` (unpacked SFT), `skyrl/train/sft_trainer.py`**: each left-padded row is written with a single slice assignment into a preallocated array instead of building a per-example padded Python list.
- **`convert_prompts_responses_to_batch_tensors` (RL), `skyrl/train/dataset/preprocess.py`**: the left-padded `sequences` are built with two slice copies per row, and the fixed-width `attention_mask` / `action_mask` / `loss_mask` / `rewards` / `logprobs` tensors are produced with broadcast comparisons / slice writes instead of per-token Python loops.

This covers the SFT (packed + unpacked) and RL training-batch construction paths; the RL and SFT data paths are separate functions, so each is vectorized independently. The RL change is inherited unchanged by all `RayPPOTrainer` subclasses (sync / async / full-context / agentic), since none override `convert_to_training_input`.
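As an illustration of the packed-row change, here is a minimal sketch of writing each sub-sequence with one slice copy and computing `total_nonpad` as a single reduction. The helper name `pack_row`, the `pad_id` value, and the unshifted loss-mask layout are assumptions made for the example; the real `PackedDataCollator` (including its loss-mask window and clamp) is not reproduced here.

```python
import numpy as np


def pack_row(sub_seqs, loss_masks, max_packed_len, pad_id=0):
    """Pack several token sequences into one fixed-width row (hypothetical helper).

    Assumes the sub-sequences fit within max_packed_len.
    """
    sequences = np.full(max_packed_len, pad_id, dtype=np.int64)
    attention_mask = np.zeros(max_packed_len, dtype=np.int64)
    loss_mask = np.zeros(max_packed_len, dtype=np.float32)

    offset = 0
    for tokens, mask in zip(sub_seqs, loss_masks):
        n = len(tokens)
        # One slice assignment per sub-sequence instead of one Python
        # iteration per token.
        sequences[offset:offset + n] = tokens
        attention_mask[offset:offset + n] = 1
        loss_mask[offset:offset + n] = mask
        offset += n
    return sequences, attention_mask, loss_mask


seqs, attn, lm = pack_row(
    sub_seqs=[[5, 6, 7], [8, 9]],
    loss_masks=[[0, 1, 1], [0, 1]],
    max_packed_len=8,
)
# A single vectorized reduction replaces a running per-token counter.
total_nonpad = int(attn.sum())
```

The remaining Python loop runs once per sub-sequence rather than once per token, which is where a speedup of this kind would come from.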
### Intentionally out of scope

- **MoE router-replay (`rollout_expert_indices`).** The optional `rollout_expert_indices` branch in `convert_prompts_responses_to_batch_tensors` is left exactly as-is; only the dense per-token batch tensors are vectorized. That branch is byte-identical to the prior implementation (zero correctness/regression risk). It is a narrow MoE-only path, so its residual per-sample loop is left for a follow-up rather than folded into this CPU-latency change. The new equivalence suite therefore does not exercise it (the oracle compares the six dense outputs; the 7th return value is intentionally discarded).
- **Eval path.** Packing only fires on the training-step batch (`batch_size == self.batch_size`); on the eval path `PackedDataCollator` delegates to the un-packed `DefaultCollator`, so eval collation is unchanged by this PR.

Notes on the bit-identical claim:

- dtypes are preserved exactly: `int64` / `torch.long` for `sequences` / masks (incl. `action_mask`), `float32` / `torch.float` for `loss_mask` / `rewards` / `logprobs`. `dtype=np.int64` is pinned explicitly (NumPy's platform-default int is `int32` on Windows before NumPy 2.0).
- The RL reward path accepts Python lists and `float32` reward tensors (what the reward postprocessing produces today). A `requires_grad`, CUDA, or `bfloat16` reward tensor is not accepted; no reward producer in the repo emits those.
- The `PackedDataCollator` loss-mask write window keeps the original `row_p < max_packed_len - 1` clamp. That `min()` is a defensive no-op, because `max_packed_len` is `>=` every bin's packed length by construction, so the clamp never bites today. It is retained (and now commented) to preserve the original behavior exactly.
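The dtype note can be illustrated with a short sketch. The array names are placeholders, not the PR's variables; the point is only that NumPy's defaults (`float64` for floats, a platform-dependent width for `int`) differ from the `int64` / `float32` layout described above, so the preallocated buffers need explicit dtypes.

```python
import numpy as np

# NumPy's default float is float64, so float buffers must be pinned to float32.
default_float = np.zeros(4).dtype

# Preallocated buffers with explicitly pinned dtypes (placeholder names).
sequences = np.zeros((2, 4), dtype=np.int64)
loss_mask = np.zeros((2, 4), dtype=np.float32)

# Slice writes cast the incoming values to the buffer's dtype, so writing
# Python ints into the float32 buffer does not change its dtype.
loss_mask[0, 2:] = [0, 1]
sequences[1, 1:] = [7, 8, 9]
```

Because the buffer's dtype is fixed at allocation time, later slice writes cannot silently widen or narrow it.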
## Benchmarks

Controller-side collate, single process, CPU, batch of 1024 (varying sequence lengths):

| Path | Before | After | Speedup |
|------|--------|-------|---------|
| `PackedDataCollator` (FFD, dp=8) | 288.5 ms | 8.4 ms | ~34x |
| `convert_prompts_responses_to_batch_tensors` (RL) | 92.8 ms | 14.5 ms | ~6.4x |
| `collate_sft_batch` (unpacked SFT) | 98.4 ms | 16.2 ms | ~6x |

## Test plan

- [x] New `tests/train/test_collation_vectorization_equivalence.py`: pins a faithful reference of each *original* loop and fuzzes the vectorized output against it with `torch.equal` plus explicit per-tensor `dtype` assertions. It covers RL (with/without logprobs, list and `float32`-tensor rewards), unpacked SFT, and packed SFT across TP/PP/CP/DP configs. Because `torch.equal` is dtype-insensitive on matching values, the integer/float dtypes (`action_mask`/`attention_mask` `int64`/`long`, `loss_mask`/`rewards` `float32`) are pinned with explicit `.dtype` assertions. The packed test re-derives the FFD / DP-shard / `max_packed_len` decision inline as its own oracle, so any production drift surfaces as a `torch.equal` mismatch. (Mutation-checked: an injected off-by-one in any vectorized path fails the suite. Note that the unreachable `loss_mask` clamp is the one spot where a localized off-by-one would not be caught, since it never fires under any in-practice input.)
- [x] Existing `tests/train/dataset/test_preprocess.py`, `tests/train/test_sft_packing_collate.py`, `tests/train/test_packing_round_trip.py`, `tests/train/test_sft_tokenization.py` pass unchanged.
- [x] `ruff` + `black` clean.
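The equivalence-testing approach can be sketched in NumPy alone. This is not the PR's test file: `reference_action_mask`, `vectorized_action_mask`, and `fuzz` are hypothetical names, and a right-aligned action mask stands in for the real outputs. `np.array_equal` compares values only, so the dtype is asserted separately, mirroring the value-plus-dtype pairing described in the test plan.

```python
import random

import numpy as np


def reference_action_mask(resp_lens, max_response):
    """Per-token Python loop, standing in for an original implementation."""
    out = np.zeros((len(resp_lens), max_response), dtype=np.int64)
    for i, n in enumerate(resp_lens):
        for j in range(max_response - n, max_response):
            out[i, j] = 1
    return out


def vectorized_action_mask(resp_lens, max_response):
    """Broadcast comparison producing the same right-aligned mask."""
    lens = np.asarray(resp_lens, dtype=np.int64)[:, None]
    pos = np.arange(max_response, dtype=np.int64)[None, :]
    return (pos >= max_response - lens).astype(np.int64)


def fuzz(trials=200, seed=0):
    """Compare both implementations on random batches; return the trial count."""
    rng = random.Random(seed)
    for _ in range(trials):
        resp_lens = [rng.randint(1, 16) for _ in range(rng.randint(1, 8))]
        max_response = max(resp_lens)
        ref = reference_action_mask(resp_lens, max_response)
        out = vectorized_action_mask(resp_lens, max_response)
        # Value equality and dtype equality are checked separately.
        assert np.array_equal(ref, out)
        assert out.dtype == ref.dtype == np.int64
    return trials


completed = fuzz()
```

Keeping the reference as a deliberately naive loop makes it easy to audit by eye, which is what gives the comparison its value as an oracle.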
```bash
uv run --isolated --extra dev --extra megatron -- pytest \
    tests/train/test_collation_vectorization_equivalence.py \
    tests/train/dataset/test_preprocess.py \
    tests/train/test_sft_packing_collate.py \
    tests/train/test_packing_round_trip.py
```

> Heads-up for reviewers: this overlaps open PR NovaSky-AI#1752 ([train] VLM SFT on Megatron), which edits the same `collate_sft_batch` loop and `TrainingInputBatch` dict to collect `pixel_values` / `image_grid_thw`. Whichever lands second needs a small rebase; if NovaSky-AI#1752 lands first, its per-sample VLM tensor collection should be reinstated inside the vectorized `for i, ex in enumerate(examples):` loop and its two keys re-added to the `from_numpy` batch dict.

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Code Review
This pull request vectorizes the collation and preprocessing pipelines in SkyRL by replacing per-token and per-sample Python loops with NumPy slice assignments and broadcast operations, followed by a single conversion to PyTorch tensors. It also adds a comprehensive test suite to verify bit-identical equivalence with the original implementations. The review feedback highlights potential runtime errors when inputs like `rewards`, `logprobs`, or `loss_masks` are PyTorch tensors (especially if they reside on GPU or require gradients), and suggests defensively detaching and moving them to CPU before converting to NumPy arrays.
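For readers unfamiliar with the pattern, a minimal sketch of the slice-assignment and broadcast approach follows. It assumes the layout implied by the right-aligned `max_response - len(...)` writes in the diff (whole sequences left-padded, per-response tensors right-aligned); `build_left_padded_batch` and `pad_id` are hypothetical names, not the PR's code.

```python
import numpy as np


def build_left_padded_batch(prompts, responses, pad_id=0):
    """Hypothetical sketch: left-pad prompt+response rows and derive masks."""
    prompt_lens = np.array([len(p) for p in prompts], dtype=np.int64)
    resp_lens = np.array([len(r) for r in responses], dtype=np.int64)
    total_lens = prompt_lens + resp_lens
    width = int(total_lens.max())
    max_response = int(resp_lens.max())

    sequences = np.full((len(prompts), width), pad_id, dtype=np.int64)
    for i, (p, r) in enumerate(zip(prompts, responses)):
        # Two slice copies per row: the prompt, then the response flush
        # against the right edge.
        sequences[i, width - len(p) - len(r):width - len(r)] = p
        sequences[i, width - len(r):] = r

    # Broadcast a (1, width) position row against (batch, 1) start offsets.
    pos = np.arange(width, dtype=np.int64)[None, :]
    attention_mask = (pos >= (width - total_lens)[:, None]).astype(np.int64)

    # Same idea over the response window only.
    resp_pos = np.arange(max_response, dtype=np.int64)[None, :]
    action_mask = (resp_pos >= (max_response - resp_lens)[:, None]).astype(np.int64)
    return sequences, attention_mask, action_mask


sequences, attention_mask, action_mask = build_left_padded_batch(
    prompts=[[1, 2, 3], [4]],
    responses=[[10, 11], [12, 13, 14]],
)
```

In a PyTorch setting, each finished array could then be converted once with `torch.from_numpy`, which keeps the NumPy dtype.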
    for i, custom_reward in enumerate(rewards):
        if isinstance(custom_reward, list):
            custom_reward = torch.tensor(custom_reward)
        ret_rewards[i, max_response - len(custom_reward) :] = custom_reward
        reward_arr = np.asarray(custom_reward, dtype=np.float32)
        ret_rewards_np[i, max_response - reward_arr.shape[0] :] = reward_arr
In `convert_prompts_responses_to_batch_tensors`, the `rewards` input can contain PyTorch tensors (e.g., from a reward model running on GPU or requiring gradients). Converting them directly via `np.asarray` will raise a `TypeError` if they are on GPU, or a `RuntimeError` if they require gradients.
To ensure robustness and prevent runtime crashes, we should defensively check if each reward is a PyTorch tensor, and if so, detach it and move it to CPU before converting to a NumPy array.
Suggested change:

    for i, custom_reward in enumerate(rewards):
        if isinstance(custom_reward, torch.Tensor):
            reward_arr = custom_reward.detach().cpu().numpy().astype(np.float32)
        else:
            reward_arr = np.asarray(custom_reward, dtype=np.float32)
        ret_rewards_np[i, max_response - reward_arr.shape[0] :] = reward_arr
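The defensive conversion discussed in this comment could be factored into one helper shared by the rewards, logprobs, and loss-mask loops. The sketch below is an assumption-laden illustration, not code from the PR: `to_float32_array` and `write_right_aligned` are hypothetical names, the `torch` import is optional so the list path runs without PyTorch, and the `.float()` call is an addition beyond the reviewer's suggestion, included because `Tensor.numpy()` does not support `bfloat16`.

```python
import numpy as np

try:
    import torch
except ImportError:  # the list path below still works without PyTorch
    torch = None


def to_float32_array(values):
    """Hypothetical helper: normalize a list or tensor to a CPU float32 array."""
    if torch is not None and isinstance(values, torch.Tensor):
        # detach() drops autograd tracking, cpu() moves data off an accelerator,
        # float() upcasts dtypes such as bfloat16 that NumPy cannot represent.
        return values.detach().cpu().float().numpy()
    return np.asarray(values, dtype=np.float32)


def write_right_aligned(rows, max_response):
    """Write each variable-length row flush right into a zero-filled buffer."""
    out = np.zeros((len(rows), max_response), dtype=np.float32)
    for i, row in enumerate(rows):
        arr = to_float32_array(row)
        out[i, max_response - arr.shape[0]:] = arr
    return out


rewards = write_right_aligned([[0.5], [1.0, 2.0, 3.0]], max_response=3)
```

Whether widening the accepted inputs this way is desirable is a design question for the PR: the description states that such tensors are not produced in the repo today, so the helper trades a stricter contract for robustness.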
    for i, sample_logprobs in enumerate(logprobs):
        lp = torch.tensor(sample_logprobs, dtype=torch.float)
        logprobs_tensor[i, max_response - len(sample_logprobs) :] = lp
        logprobs_np[i, max_response - len(sample_logprobs) :] = sample_logprobs
Similarly to the rewards handling, `logprobs` are typically generated by the policy model on GPU and are represented as PyTorch tensors. Assigning a PyTorch GPU tensor directly to a NumPy array slice will raise a `TypeError`.
We should defensively handle PyTorch tensors for `logprobs` by detaching them and moving them to CPU before converting to NumPy.
Suggested change:

    for i, sample_logprobs in enumerate(logprobs):
        if isinstance(sample_logprobs, torch.Tensor):
            lp_arr = sample_logprobs.detach().cpu().numpy().astype(np.float32)
        else:
            lp_arr = np.asarray(sample_logprobs, dtype=np.float32)
        logprobs_np[i, max_response - lp_arr.shape[0] :] = lp_arr
    for i, lm in enumerate(loss_masks):
        ret_loss_masks[i, max_response - len(lm) :] = torch.tensor(lm, dtype=torch.float)
        ret_loss_masks_np[i, max_response - len(lm) :] = lm
For consistency and robustness, we should defensively handle the case where `loss_masks` elements are PyTorch tensors (e.g., if they are loaded or processed as tensors). Detaching and moving them to CPU before converting to NumPy prevents potential runtime errors.
Suggested change:

    for i, lm in enumerate(loss_masks):
        if isinstance(lm, torch.Tensor):
            lm_arr = lm.detach().cpu().numpy().astype(np.float32)
        else:
            lm_arr = np.asarray(lm, dtype=np.float32)
        ret_loss_masks_np[i, max_response - lm_arr.shape[0] :] = lm_arr