[megatron] Stream ChunkedDistributedLogprob.backward into a preallocated buffer (lower peak memory) #1806
## Summary
Refactor `ChunkedDistributedLogprob.backward` (the vocab-parallel chunked-logprob autograd function used by the Megatron worker for SFT cross-entropy and RL policy/ref losses) to stream each chunk's gradient into a single preallocated fp32 buffer instead of appending to a Python list and concatenating with `torch.cat` at the end. This lowers peak activation memory on the chunked backward path with **no change to numerics**.
## What changed
`skyrl/backends/skyrl_train/distributed/megatron/model_utils.py`, `ChunkedDistributedLogprob.backward`:
- Removed `all_grad_input = []` and the final `grad_input = torch.cat(all_grad_input, dim=1)`.
- Preallocate once before the loop:
`grad_input = torch.empty((batch_size, seq_size, partition_vocab_size), dtype=torch.float32, device=vocab_parallel_logits.device)`.
- Renamed the per-chunk grad to `chunk_grad_input` and, after the `scatter_add_`, write it into its sequence slice: `grad_input[:, chunk_start:chunk_end, :] = chunk_grad_input`.
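The storage change above can be sketched in isolation. This is a minimal single-process stand-in, not the code in `model_utils.py`: it assumes no tensor parallelism (plain `torch.log_softmax` instead of `_compute_distributed_log_softmax`), and the names `chosen_logprob_grad`, `backward_cat`, and `backward_streamed` are hypothetical, introduced only for illustration.

```python
import torch


def chosen_logprob_grad(logits, target, grad_out):
    """Gradient of the chosen-token log-prob w.r.t. the logits, in fp32.

    Single-process stand-in (assumption: no vocab sharding):
    d log_softmax(z)[t] / dz = onehot(t) - softmax(z), scaled by grad_out.
    """
    grad = torch.log_softmax(logits.float(), dim=-1).exp().neg_()  # -softmax
    index = target.unsqueeze(-1)
    grad.scatter_add_(-1, index, torch.ones_like(index, dtype=grad.dtype))  # + onehot
    return grad.mul_(grad_out.unsqueeze(-1))


def backward_cat(logits, target, grad_out, chunk_size):
    """Old form: keep every chunk grad in a list, then concatenate."""
    seq = logits.size(1)
    pieces = []
    for start in range(0, seq, chunk_size):
        end = min(seq, start + chunk_size)
        pieces.append(
            chosen_logprob_grad(logits[:, start:end], target[:, start:end], grad_out[:, start:end])
        )
    return torch.cat(pieces, dim=1)


def backward_streamed(logits, target, grad_out, chunk_size):
    """New form: write each chunk grad into one preallocated fp32 buffer."""
    batch, seq, vocab = logits.shape
    grad_input = torch.empty((batch, seq, vocab), dtype=torch.float32, device=logits.device)
    for start in range(0, seq, chunk_size):
        end = min(seq, start + chunk_size)
        chunk_grad_input = chosen_logprob_grad(
            logits[:, start:end], target[:, start:end], grad_out[:, start:end]
        )
        grad_input[:, start:end, :] = chunk_grad_input  # same-dtype copy, no cast
    return grad_input


torch.manual_seed(0)
logits = torch.randn(2, 17, 11)          # [batch, seq, vocab], ragged last chunk
target = torch.randint(0, 11, (2, 17))
grad_out = torch.randn(2, 17)

streamed = backward_streamed(logits, target, grad_out, chunk_size=5)
identical = torch.equal(backward_cat(logits, target, grad_out, chunk_size=5), streamed)
```

Both functions run the same per-chunk computation; only where the result is stored differs, which is why an exact `torch.equal` comparison is the appropriate check here.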
The change is **unconditional** (no flag) because it is numerically byte-identical. It only engages on the chunked dispatch path (i.e. when `chunk_size < seq_len_local`). The non-chunked `DistributedLogprob`, `forward()`, and the vendored Triton fused-LCE path are untouched. `forward()` retains the list+`cat` form intentionally: its accumulator holds per-chunk `[batch_size, chunk_len]` log-prob tensors (tiny), not the `[batch_size, chunk_len, V//TP]` fp32 grads that make the backward `cat` expensive, so streaming it would not meaningfully lower peak.
The old list-then-`cat` form kept every per-chunk `[B, chunk_len, V//TP]` fp32 grad alive **and** allocated the full concatenated output at the cat moment, so peak was ~2x the full `[B, seq, V//TP]` fp32 grad. Streaming drops peak to full-buffer + one live chunk = ~`(1 + 1/num_chunks)`x of the full grad. The win scales with chunk count; it is **not** a flat halving.
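As a rough worked example of the peak estimate above, the following sketch counts only the gradient tensors themselves and ignores per-chunk temporaries. The shapes (`batch=1`, `seq=8192`, `V//TP=32768`, `chunk_size=1024`) are hypothetical values chosen for round numbers, and the function names are illustrative.

```python
import math

BYTES_FP32 = 4


def full_grad_bytes(batch, seq, vocab_per_rank):
    """Size of the full [B, seq, V//TP] fp32 gradient."""
    return batch * seq * vocab_per_rank * BYTES_FP32


def peak_bytes_cat(batch, seq, vocab_per_rank):
    """List + cat: all chunk grads stay alive while the cat output is allocated."""
    return 2 * full_grad_bytes(batch, seq, vocab_per_rank)


def peak_bytes_streamed(batch, seq, vocab_per_rank, chunk_size):
    """Preallocated buffer plus one live chunk grad."""
    live_chunk = batch * min(chunk_size, seq) * vocab_per_rank * BYTES_FP32
    return full_grad_bytes(batch, seq, vocab_per_rank) + live_chunk


# Hypothetical shapes, chosen only to make the arithmetic easy to follow.
batch, seq, vocab_per_rank, chunk_size = 1, 8192, 32768, 1024
num_chunks = math.ceil(seq / chunk_size)                                   # 8

full = full_grad_bytes(batch, seq, vocab_per_rank)                         # 1 GiB
cat_peak = peak_bytes_cat(batch, seq, vocab_per_rank)                      # 2 GiB
streamed_peak = peak_bytes_streamed(batch, seq, vocab_per_rank, chunk_size)  # 1.125 GiB
ratio = streamed_peak / full                                               # 1 + 1/num_chunks
```

With 8 chunks the streamed peak is 1.125x the full gradient instead of 2x; with only 2 chunks it would be 1.5x, which is why the saving depends on the chunk count.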
## Numerical equivalence / safety
Byte-identical by construction:
- The per-chunk math is unchanged: same `_compute_distributed_log_softmax` on the same fp32 slice, same `.exp()`, same `neg_`/`mul_`/`scatter_add_` formulation. `chunk_grad_input` is a fresh fp32 tensor, so the slice write is a same-dtype copy with no cast.
- The chunks tile `[0, seq_size)` exactly and contiguously: chunk `i` covers `[i*chunk_size, min(seq_size, (i+1)*chunk_size))`, consecutive chunks meet with no gap/overlap, and `num_chunks = ceil(seq_size/chunk_size)`. `torch.cat(dim=1)` placed chunk `i`'s sequence positions at those same `[chunk_start:chunk_end]` offsets, so values land at identical positions and each slice is written exactly once.
- The new buffer is contiguous fp32 of shape `[B, seq_size, V//TP]` — identical shape/dtype/contiguity to the previous `torch.cat` output — and is fully overwritten by the full-coverage tiling, so no uninitialized `torch.empty` memory survives. Full-coverage tiling is a load-bearing invariant for the `torch.empty` buffer (a partial write would leak garbage silently); the prime-length coverage test below is the regression guard.
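The tiling invariant in the bullets above can be checked with a few lines of plain Python. This is an illustrative sketch, not the test in the PR; `chunk_bounds` and `tiles_exactly_once` are hypothetical helper names.

```python
import math


def chunk_bounds(seq_size, chunk_size):
    """Half-open [start, end) ranges, one per chunk, as described above."""
    num_chunks = math.ceil(seq_size / chunk_size)
    return [
        (i * chunk_size, min(seq_size, (i + 1) * chunk_size))
        for i in range(num_chunks)
    ]


def tiles_exactly_once(bounds, seq_size):
    """True if every position in [0, seq_size) is covered by exactly one chunk."""
    writes = [0] * seq_size
    for start, end in bounds:
        for pos in range(start, end):
            writes[pos] += 1
    return all(count == 1 for count in writes)


# Prime sequence length with a ragged final chunk (seq_len=17, chunk_size=5).
bounds = chunk_bounds(17, 5)
covered = tiles_exactly_once(bounds, 17)
```

For `seq_len=17` and `chunk_size=5` the ranges are `(0, 5)`, `(5, 10)`, `(10, 15)`, `(15, 17)`, so the last chunk is shorter than the others but the coverage is still complete.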
The **separate fp32 buffer** (rather than writing in place into the autograd-saved `vocab_parallel_logits`) is a deliberate choice: it keeps the gradient in fp32 regardless of the saved logits' dtype, does not mutate an autograd-saved tensor (avoiding version-counter / double-backward hazards), and costs only ~`(1/num_chunks)`x extra memory for that safety.
## Test plan
> Written to SkyRL CI conventions; `python -m py_compile` passes and lint is clean (`ruff`: all checks passed; `black --line-length 120`: unchanged) on all three files.
- `tests/backends/skyrl_train/distributed/test_chunked_logprob_backward_streaming.py` (CPU lane): stubs `megatron.core.parallel_state` into `sys.modules` via a module-scoped autouse save/restore fixture (mirroring `test_preprocess_packed_seqs_cp.py`), uses a gloo `world_size=1` TP group (every `all_reduce` is the identity), and asserts `torch.equal` between the chunked-streamed grad and the single-shot `DistributedLogprob` grad across `chunk_size in {1,3,7,16,32,64}` x with/without OOV targets, plus edge cases (seq_len=1, all-in/all-out mask, a prime/ragged tiny-vocab config). This is the **byte-identity gate for the storage refactor**: the log-softmax + scatter-add backward math is purely per-position (reductions only over the vocab dim), so chunk boundaries cannot change any value, and the world_size=1 `torch.equal` fully validates the device/dtype-agnostic slice-write. The prime-length (`seq_len=17`, `chunk_size=5`) coverage test additionally pins that every sequence slice of the preallocated buffer is written (an unwritten `torch.empty` slice cannot coincidentally match the reference). Collected by the SkyRL-Train-CPU lane (`pytest tests/backends/skyrl_train/ --ignore=.../gpu`).
Run with:
```
uv run --isolated --extra dev -- pytest -s \
tests/backends/skyrl_train/distributed/test_chunked_logprob_backward_streaming.py
```
- `tests/backends/skyrl_train/gpu/gpu_ci/megatron/test_chunked_logprob_backward_tp.py` (`@pytest.mark.megatron`): spawns TP NCCL ranks (parametrized `tp_size in {2, 4}`, guarded by a `device_count()` skip), runs `mpu.initialize_model_parallel`, shards the vocab across ranks, runs fwd+bwd on each rank's slice, and asserts (a) the rank vocab slices tile `[0, vocab)` exactly once via `torch.equal` on a coverage counter, and (b) each rank's local streamed grad matches the full-vocab single-process autograd reference columns **within fp32 tolerance** via `torch.testing.assert_close(atol=1e-5, rtol=1e-4)` — a tolerance is correct here because the cross-rank all-reduce reorders the fp32 reduction vs. the single-tensor reference (matching the existing GPU test's grad tolerance), while the no-overlap tiling check uses exact equality. Spawned ranks set the conftest-mandated NCCL env (`NCCL_CUMEM_ENABLE=0`, etc.) before `init_process_group`, since `mp.spawn` children do not inherit the runtime env set by the CI conftest.
Run with (>=2 free GPUs; the `tp_size=4` case is skipped unless 4 are present):
```
uv run --isolated --extra dev --extra megatron -- \
pytest -s tests/backends/skyrl_train/gpu/gpu_ci/megatron/test_chunked_logprob_backward_tp.py
```
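Why the TP>1 comparison uses a tolerance rather than exact equality can be illustrated with ordinary floating-point addition. The sketch below uses Python floats (float64, not fp32) and a hand-picked split into two "ranks", so it only illustrates the reordering effect and is not the GPU test itself; `reduce_in_order` is a hypothetical helper.

```python
import math


def reduce_in_order(values):
    """Plain left-to-right accumulation (no compensated summation)."""
    total = 0.0
    for value in values:
        total += value
    return total


values = [0.1, 0.2, 0.3]

# Single-tensor reference: one pass over all values.
single = reduce_in_order(values)

# Sharded: each "rank" reduces its own slice, then the partials are summed,
# which changes the order of the additions.
partials = [reduce_in_order(values[:1]), reduce_in_order(values[1:])]
sharded = reduce_in_order(partials)

exactly_equal = single == sharded
close_enough = math.isclose(single, sharded, rel_tol=1e-4, abs_tol=1e-5)
```

The two sums differ in the last bit, so an exact comparison fails while a tolerance-based one passes. Note that `math.isclose` combines its tolerances differently from `torch.testing.assert_close`; it is used here only as an analogous stdlib check.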
## Scope & follow-ups
- **Backends covered:** Megatron only, by nature — `ChunkedDistributedLogprob` is the vocab-parallel (TP/CP) chunked-logprob used on the Megatron worker, reached through the single `forward_backward_mini_batch` train-step chokepoint, so all Megatron-backed training pathways (SFT, sync/async/fully-async RL, full-context) inherit the optimization with no per-pathway edit. FSDP has no vocab-parallel chunked-logprob equivalent, so it is intentionally out of scope.
- **Scope:** Limited to the single `backward` method; `DistributedLogprob`, `forward()`, and the fused-LCE path are unchanged.
- **Deferred:** A direct streaming-vs-old-`cat` comparison at TP>1 was intentionally not added, since it would require shipping the pre-refactor `cat`-based backward as dead reference code. The byte-identity gate for the refactor itself lives on the CPU `world_size=1` lane (`torch.equal`); the GPU lane validates the distributed math against the reference (with tolerance) plus exact no-overlap tiling.
## Related PRs (merge ordering)
- **NovaSky-AI#1765** (fused LM-head log-prob + entropy) edits the same `ChunkedDistributedLogprob.backward` and **keeps** the `all_grad_input = []` / `torch.cat` form. It refactors the per-chunk chosen-token scatter-add into a shared `_add_chosen_token_grad` helper — touching the exact lines this PR renames to `chunk_grad_input` and streams — so the two **will textually collide on this hunk**. NovaSky-AI#1765 does **not** subsume this PR's peak-memory win (it preserves the list+`cat`), and NovaSky-AI#1765 is otherwise complementary in intent. Whichever lands first forces a rebase of the other. **Recommended ordering:** land NovaSky-AI#1765 first, then rebase this streaming change on top of it (the preallocated-buffer write replaces the `append`+`cat` NovaSky-AI#1765 keeps); if the two are reviewed together they can be folded into one change.
- **NovaSky-AI#1543** (WIP) independently removes the same `torch.cat` 2x-peak, but does so **in place** into the autograd-saved logits with no extra buffer. This PR deliberately does **not** take that approach: the in-place variant downcasts the fp32 grad to the saved logits' dtype (breaking byte-identity) and mutates an autograd-saved tensor (the version-counter / double-backward hazard). This PR's separate fp32 buffer preserves numeric fidelity and autograd-safety at the cost of ~`(1/num_chunks)`x extra memory. NovaSky-AI#1543 is also built on pre-merge code and currently conflicts with main, so it is not a viable replacement for this PR.
Code Review
This pull request refactors ChunkedDistributedLogprob.backward to stream each chunk's gradient directly into a single preallocated buffer instead of accumulating them in a list and concatenating them, which significantly reduces peak memory usage. It also adds comprehensive CPU and GPU TP>1 parity tests to verify the correctness of the streamed-buffer implementation. The reviewer suggested a safer teardown mechanism in the tp_group test fixture to avoid unconditionally destroying the distributed process group if it was initialized outside of the fixture.

    @pytest.fixture(scope="module")
    def tp_group():
        """Single-rank TP process group shared by both autograd functions.

        Uses the gloo backend because the world size is 1, so every ``all_reduce``
        inside ``_compute_distributed_log_softmax`` is the identity. This isolates
        the streamed-buffer refactor from the (separately tested) TP reduction.
        """
        if not dist.is_initialized():
            os.environ["MASTER_ADDR"] = "localhost"
            os.environ["MASTER_PORT"] = str(get_free_port())
            os.environ["RANK"] = "0"
            os.environ["WORLD_SIZE"] = "1"
            dist.init_process_group(backend="gloo", rank=0, world_size=1)
        yield dist.group.WORLD
        if dist.is_initialized():
            dist.destroy_process_group()

The tp_group fixture unconditionally calls dist.destroy_process_group() during teardown. If the default process group was already initialized before this test module ran (e.g., by another test in the same session), destroying it here will interfere with subsequent tests that rely on it. It is safer to only destroy the process group if it was initialized by this fixture.
Suggested change:

    @pytest.fixture(scope="module")
    def tp_group():
        """Single-rank TP process group shared by both autograd functions.

        Uses the gloo backend because the world size is 1, so every ``all_reduce``
        inside ``_compute_distributed_log_softmax`` is the identity. This isolates
        the streamed-buffer refactor from the (separately tested) TP reduction.
        """
        initialized_here = False
        if not dist.is_initialized():
            os.environ["MASTER_ADDR"] = "localhost"
            os.environ["MASTER_PORT"] = str(get_free_port())
            os.environ["RANK"] = "0"
            os.environ["WORLD_SIZE"] = "1"
            dist.init_process_group(backend="gloo", rank=0, world_size=1)
            initialized_here = True
        yield dist.group.WORLD
        if initialized_here and dist.is_initialized():
            dist.destroy_process_group()
