[train] Vectorize controller-side training-batch collation (SFT + RL)#1808

Open: dyurk-lila wants to merge 1 commit into NovaSky-AI:main from dyurk-lila:feat/vectorize-collation
# What does this PR do?

The controller builds every training batch on the main process before dispatching it to the workers. All three collation paths did so with per-token / per-sample Python loops that dominate the controller-side collate wall-time and serially block the GPU at large batch sizes. This PR replaces those loops with NumPy slice-assignments and broadcast comparisons. **Outputs are bit-identical** (same dtypes, same layout) for all inputs produced in practice — this is a pure CPU-side latency optimization, not a behavior change.

## Changes

- **`PackedDataCollator` (Megatron SFT FFD packing), `skyrl/train/dataset/collators.py`** — the per-bin packed row tensors (`sequences` / `attention_mask` / `loss_mask`) were built with a per-token Python loop over the reconstructed full loss mask. Each sub-seq is now written with one C-level copy, and `total_nonpad` is a single vectorized reduction.
- **`collate_sft_batch` / `DefaultCollator` (unpacked SFT), `skyrl/train/sft_trainer.py`** — each left-padded row is written with a single slice assignment into a preallocated array instead of building a per-example padded Python list.
- **`convert_prompts_responses_to_batch_tensors` (RL), `skyrl/train/dataset/preprocess.py`** — the left-padded `sequences` are built with two slice copies per row, and the fixed-width `attention_mask` / `action_mask` / `loss_mask` / `rewards` / `logprobs` tensors are produced with broadcast comparisons / slice writes instead of per-token Python loops.

This covers the SFT (packed + unpacked) and RL training-batch construction paths; the RL and SFT data paths are separate functions, so each is vectorized independently. The RL change is inherited unchanged by all `RayPPOTrainer` subclasses (sync / async / full-context / agentic), since none override `convert_to_training_input`.
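
The slice-assignment and broadcast-comparison pattern described above can be sketched as follows. This is a minimal illustration, not the code in the PR: `left_pad_batch` and its parameters are hypothetical names, and the real functions also build loss masks, rewards, and logprobs.

```python
import numpy as np


def left_pad_batch(rows, pad_token_id):
    """Left-pad variable-length token lists (hypothetical helper).

    Returns (sequences, attention_mask), both int64 with shape (batch, max_len).
    """
    lengths = np.fromiter((len(r) for r in rows), dtype=np.int64, count=len(rows))
    max_len = int(lengths.max()) if len(rows) else 0

    # Preallocate once with a pinned dtype instead of building padded Python lists.
    sequences = np.full((len(rows), max_len), pad_token_id, dtype=np.int64)
    for i, row in enumerate(rows):
        if row:
            # One slice assignment per row replaces the per-token loop.
            sequences[i, max_len - len(row):] = row

    # Broadcast comparison: column j holds a real token iff j >= max_len - length.
    positions = np.arange(max_len, dtype=np.int64)
    attention_mask = (positions[None, :] >= (max_len - lengths)[:, None]).astype(np.int64)
    return sequences, attention_mask


seqs, mask = left_pad_batch([[1, 2, 3], [4]], pad_token_id=0)
print(seqs.tolist())  # [[1, 2, 3], [0, 0, 4]]
print(mask.tolist())  # [[1, 1, 1], [0, 0, 1]]
```

The per-row loop remains, but each iteration is a single C-level copy, and the mask needs no Python-level loop at all.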

### Intentionally out of scope

- **MoE router-replay (`rollout_expert_indices`).** The optional `rollout_expert_indices` branch in `convert_prompts_responses_to_batch_tensors` is left exactly as-is — only the dense per-token batch tensors are vectorized. That branch is byte-identical to the prior implementation (zero correctness/regression risk); it is a narrow MoE-only path, so its residual per-sample loop is left for a follow-up rather than folded into this CPU-latency change. The new equivalence suite therefore does not exercise it (the oracle compares the six dense outputs; the 7th return value is intentionally discarded).
- **Eval path.** Packing only fires on the training-step batch (`batch_size == self.batch_size`); on the eval path `PackedDataCollator` delegates to the un-packed `DefaultCollator`, so eval collation is unchanged by this PR.

Notes on the bit-identical claim:
- dtypes are preserved exactly: `int64` / `torch.long` for `sequences` / masks (incl. `action_mask`), `float32` / `torch.float` for `loss_mask` / `rewards` / `logprobs`. `dtype=np.int64` is pinned explicitly (NumPy's platform-default int is `int32` on Windows before NumPy 2.0).
- The RL reward path accepts Python lists and `float32` reward tensors (what the reward postprocessing produces today). A `requires_grad`, CUDA, or `bfloat16` reward tensor is not accepted; no reward producer in the repo emits those.
- The `PackedDataCollator` loss-mask write window keeps the original `row_p < max_packed_len - 1` clamp. That `min()` is a defensive no-op — `max_packed_len` is `>=` every bin's packed length by construction — so the clamp never bites today; it is retained (and now commented) to preserve the original behavior exactly.
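
As a rough illustration of the dtype-pinning note, the sketch below preallocates buffers with explicit dtypes. `alloc_batch_buffers` is a hypothetical helper, not a function from the PR; the buffer names simply mirror the tensors listed above.

```python
import numpy as np


def alloc_batch_buffers(batch_size, width, pad_token_id):
    """Preallocate collation buffers with explicit dtypes (hypothetical helper)."""
    return {
        "sequences": np.full((batch_size, width), pad_token_id, dtype=np.int64),
        "attention_mask": np.zeros((batch_size, width), dtype=np.int64),
        "loss_mask": np.zeros((batch_size, width), dtype=np.float32),
        "rewards": np.zeros((batch_size, width), dtype=np.float32),
    }


buffers = alloc_batch_buffers(batch_size=2, width=4, pad_token_id=0)

# Without an explicit dtype, np.zeros allocates float64, which would silently
# change the dtype of the resulting tensor.
assert np.zeros(1).dtype == np.float64
assert buffers["sequences"].dtype == np.int64
assert buffers["loss_mask"].dtype == np.float32
```

`torch.from_numpy` keeps the NumPy dtype, so `int64` buffers become `torch.long` tensors and `float32` buffers become `torch.float` tensors without an extra cast.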

## Benchmarks

Controller-side collate, single process, CPU, batch of 1024 (varying sequence lengths):

| Path | Before | After | Speedup |
|------|--------|-------|---------|
| `PackedDataCollator` (FFD, dp=8) | 288.5 ms | 8.4 ms | ~34x |
| `convert_prompts_responses_to_batch_tensors` (RL) | 92.8 ms | 14.5 ms | ~6.4x |
| `collate_sft_batch` (unpacked SFT) | 98.4 ms | 16.2 ms | ~6x |
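
A comparison of this kind can be reproduced with a small timing harness such as the one below. It is a sketch under assumptions: the two padding functions are simplified stand-ins for the real collators, the batch shape is arbitrary, and the numbers it prints depend on the machine and are unrelated to the table above.

```python
import time

import numpy as np


def pad_with_lists(rows, pad_token_id):
    """Baseline: build one padded Python list per example."""
    max_len = max(len(r) for r in rows)
    padded = [[pad_token_id] * (max_len - len(r)) + list(r) for r in rows]
    return np.array(padded, dtype=np.int64)


def pad_with_slices(rows, pad_token_id):
    """Vectorized: one slice assignment per row into a preallocated array."""
    max_len = max(len(r) for r in rows)
    out = np.full((len(rows), max_len), pad_token_id, dtype=np.int64)
    for i, r in enumerate(rows):
        out[i, max_len - len(r):] = r
    return out


def best_of(fn, *args, repeats=5):
    """Return the fastest wall-clock time over several runs, in seconds."""
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args)
        best = min(best, time.perf_counter() - start)
    return best


rng = np.random.default_rng(0)
rows = [rng.integers(0, 32000, size=int(n)).tolist() for n in rng.integers(1, 2048, size=256)]
list_s = best_of(pad_with_lists, rows, 0)
slice_s = best_of(pad_with_slices, rows, 0)
print(f"lists: {list_s * 1e3:.1f} ms, slices: {slice_s * 1e3:.1f} ms")
```

Taking the minimum over several repeats reduces noise from other processes; a single run is usually too noisy to compare two CPU-bound functions.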

## Test plan

- [x] New `tests/train/test_collation_vectorization_equivalence.py`: pins a faithful reference of each *original* loop and fuzzes the vectorized output against it with `torch.equal` plus explicit per-tensor `dtype` assertions — RL (with/without logprobs, list and `float32`-tensor rewards), unpacked SFT, and packed SFT across TP/PP/CP/DP configs. Because `torch.equal` is dtype-insensitive on matching values, the integer/float dtypes (`action_mask`/`attention_mask` `int64`/`long`, `loss_mask`/`rewards` `float32`) are pinned with explicit `.dtype` assertions. The packed test re-derives the FFD / DP-shard / `max_packed_len` decision inline as its own oracle, so any production drift surfaces as a `torch.equal` mismatch. (Mutation-checked: an injected off-by-one in any vectorized path fails the suite — note the unreachable `loss_mask` clamp is the one spot a localized off-by-one would not be caught, since it never fires under any in-practice input.)
- [x] Existing `tests/train/dataset/test_preprocess.py`, `tests/train/test_sft_packing_collate.py`, `tests/train/test_packing_round_trip.py`, `tests/train/test_sft_tokenization.py` pass unchanged.
- [x] `ruff` + `black` clean.

```bash
uv run --isolated --extra dev --extra megatron -- pytest \
  tests/train/test_collation_vectorization_equivalence.py \
  tests/train/dataset/test_preprocess.py \
  tests/train/test_sft_packing_collate.py \
  tests/train/test_packing_round_trip.py
```
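
The equivalence strategy in the test plan (pin the original loop as a reference, fuzz the vectorized path against it, and assert dtypes separately) might look roughly like this. The sketch uses NumPy only and covers a single right-aligned `loss_mask`; both functions are simplified stand-ins, not the PR's actual reference or production code.

```python
import random

import numpy as np


def reference_loss_mask(loss_masks, max_response):
    """Per-token reference loop (stand-in for the original implementation)."""
    out = [[0.0] * max_response for _ in loss_masks]
    for i, lm in enumerate(loss_masks):
        offset = max_response - len(lm)
        for j, value in enumerate(lm):
            out[i][offset + j] = float(value)
    return np.array(out, dtype=np.float32)


def vectorized_loss_mask(loss_masks, max_response):
    """Right-align each mask with a single slice write per row."""
    out = np.zeros((len(loss_masks), max_response), dtype=np.float32)
    for i, lm in enumerate(loss_masks):
        if lm:
            out[i, max_response - len(lm):] = lm
    return out


def test_loss_mask_equivalence(trials=200, seed=0):
    rng = random.Random(seed)
    for _ in range(trials):
        max_response = rng.randint(1, 32)
        masks = [
            [rng.randint(0, 1) for _ in range(rng.randint(0, max_response))]
            for _ in range(rng.randint(1, 8))
        ]
        got = vectorized_loss_mask(masks, max_response)
        want = reference_loss_mask(masks, max_response)
        assert np.array_equal(got, want)
        # Value equality ignores dtype, so the dtype is asserted separately.
        assert got.dtype == want.dtype == np.float32


# Why the separate dtype assertion matters: these arrays compare equal by value.
assert np.array_equal(np.array([1, 2], dtype=np.int32), np.array([1.0, 2.0]))
test_loss_mask_equivalence()
```

Seeding the generator keeps any failure reproducible, which matters when a fuzzed input is the only witness to a mismatch.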

> Heads-up for reviewers: this overlaps open PR NovaSky-AI#1752 ([train] VLM SFT on Megatron), which edits the same `collate_sft_batch` loop and `TrainingInputBatch` dict to collect `pixel_values` / `image_grid_thw`. Whichever lands second needs a small rebase; if NovaSky-AI#1752 lands first, its per-sample VLM tensor collection should be reinstated inside the vectorized `for i, ex in enumerate(examples):` loop and its two keys re-added to the `from_numpy` batch dict.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request vectorizes the collation and preprocessing pipelines in SkyRL by replacing per-token and per-sample Python loops with NumPy slice assignments and broadcast operations, followed by a single conversion to PyTorch tensors. It also adds a comprehensive test suite to verify bit-identical equivalence with the original implementations. The review feedback highlights potential runtime errors when inputs like rewards, logprobs, or loss_masks are PyTorch tensors (especially if they reside on GPU or require gradients), and suggests defensively detaching and moving them to CPU before converting to NumPy arrays.

Comment on lines 164 to +166

     for i, custom_reward in enumerate(rewards):
    -    if isinstance(custom_reward, list):
    -        custom_reward = torch.tensor(custom_reward)
    -    ret_rewards[i, max_response - len(custom_reward) :] = custom_reward
    +    reward_arr = np.asarray(custom_reward, dtype=np.float32)
    +    ret_rewards_np[i, max_response - reward_arr.shape[0] :] = reward_arr

high

In convert_prompts_responses_to_batch_tensors, the rewards input can contain PyTorch tensors (e.g., from a reward model running on GPU or requiring gradients). Converting them directly via np.asarray will raise a TypeError if they are on GPU, or a RuntimeError if they require gradients.

To ensure robustness and prevent runtime crashes, we should defensively check if each reward is a PyTorch tensor, and if so, detach it and move it to CPU before converting to a NumPy array.

Suggested change

    for i, custom_reward in enumerate(rewards):
        if isinstance(custom_reward, torch.Tensor):
            reward_arr = custom_reward.detach().cpu().numpy().astype(np.float32)
        else:
            reward_arr = np.asarray(custom_reward, dtype=np.float32)
        ret_rewards_np[i, max_response - reward_arr.shape[0] :] = reward_arr

Comment on lines 175 to +176

     for i, sample_logprobs in enumerate(logprobs):
    -    lp = torch.tensor(sample_logprobs, dtype=torch.float)
    -    logprobs_tensor[i, max_response - len(sample_logprobs) :] = lp
    +    logprobs_np[i, max_response - len(sample_logprobs) :] = sample_logprobs

high

Similarly to the rewards handling, logprobs are typically generated by the policy model on GPU and are represented as PyTorch tensors. Assigning a PyTorch GPU tensor directly to a NumPy array slice will raise a TypeError.

We should defensively handle PyTorch tensors for logprobs by detaching them and moving them to CPU before converting to NumPy.

Suggested change

    for i, sample_logprobs in enumerate(logprobs):
        if isinstance(sample_logprobs, torch.Tensor):
            lp_arr = sample_logprobs.detach().cpu().numpy().astype(np.float32)
        else:
            lp_arr = np.asarray(sample_logprobs, dtype=np.float32)
        logprobs_np[i, max_response - lp_arr.shape[0] :] = lp_arr

Comment on lines 158 to +159

     for i, lm in enumerate(loss_masks):
    -    ret_loss_masks[i, max_response - len(lm) :] = torch.tensor(lm, dtype=torch.float)
    +    ret_loss_masks_np[i, max_response - len(lm) :] = lm

medium

For consistency and robustness, we should defensively handle the case where loss_masks elements are PyTorch tensors (e.g., if they are loaded or processed as tensors). Detaching and moving them to CPU before converting to NumPy prevents potential runtime errors.

Suggested change

    for i, lm in enumerate(loss_masks):
        if isinstance(lm, torch.Tensor):
            lm_arr = lm.detach().cpu().numpy().astype(np.float32)
        else:
            lm_arr = np.asarray(lm, dtype=np.float32)
        ret_loss_masks_np[i, max_response - lm_arr.shape[0] :] = lm_arr
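
The three suggestions repeat the same conversion, so one option is a single shared helper. The sketch below is hypothetical and not part of the PR or the review: `to_float32_array` and `right_align` are made-up names, and the tensor check is duck-typed on `.detach` so the example runs without PyTorch installed. Real code would more likely use `isinstance(values, torch.Tensor)`. It also casts with `.float()` before `.numpy()`, because calling `.numpy()` directly on a `bfloat16` tensor raises a `TypeError`.

```python
import numpy as np


def to_float32_array(values):
    """Convert a list, NumPy array, or tensor-like object to a float32 array."""
    if hasattr(values, "detach"):
        # Assumption: anything with .detach() behaves like a torch.Tensor.
        # detach() drops autograd history, cpu() moves off the GPU, and
        # float() casts to float32 (bfloat16 has no NumPy equivalent).
        values = values.detach().cpu().float().numpy()
    return np.asarray(values, dtype=np.float32)


def right_align(rows, width):
    """Write each row into the right-hand side of a zero-filled float32 array."""
    out = np.zeros((len(rows), width), dtype=np.float32)
    for i, row in enumerate(rows):
        arr = to_float32_array(row)
        if arr.shape[0]:
            out[i, width - arr.shape[0]:] = arr
    return out


print(right_align([[1.0, 2.0], [3.0]], width=3).tolist())
# [[0.0, 1.0, 2.0], [0.0, 0.0, 3.0]]
```

Widening `bfloat16` to `float32` is exact, so a helper like this would not change values for inputs the current code already accepts.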
