Skip to content

[megatron] Stream ChunkedDistributedLogprob.backward into a preallocated buffer (lower peak memory)#1806

Open
dyurk-lila wants to merge 1 commit into
NovaSky-AI:mainfrom
dyurk-lila:perf/streaming-chunked-logprob-backward
Open

[megatron] Stream ChunkedDistributedLogprob.backward into a preallocated buffer (lower peak memory)#1806
dyurk-lila wants to merge 1 commit into
NovaSky-AI:mainfrom
dyurk-lila:perf/streaming-chunked-logprob-backward

Conversation

@dyurk-lila

Copy link
Copy Markdown

Summary

Refactor ChunkedDistributedLogprob.backward (the vocab-parallel chunked-logprob autograd function used by the Megatron worker for SFT cross-entropy and RL policy/ref losses) to stream each chunk's gradient into a single preallocated fp32 buffer instead of appending to a Python list and concatenating with torch.cat at the end. This lowers peak activation memory on the chunked backward path with no change to numerics.

What changed

skyrl/backends/skyrl_train/distributed/megatron/model_utils.py, ChunkedDistributedLogprob.backward:

  • Removed all_grad_input = [] and the final grad_input = torch.cat(all_grad_input, dim=1).
  • Preallocate once before the loop:
    grad_input = torch.empty((batch_size, seq_size, partition_vocab_size), dtype=torch.float32, device=vocab_parallel_logits.device).
  • Renamed the per-chunk grad to chunk_grad_input and, after the scatter_add_, write it into its sequence slice: grad_input[:, chunk_start:chunk_end, :] = chunk_grad_input.

The change is unconditional (no flag) because it is numerically byte-identical. It only engages on the chunked dispatch path (i.e. when chunk_size < seq_len_local). The non-chunked DistributedLogprob, forward(), and the vendored Triton fused-LCE path are untouched. forward() retains the list+cat form intentionally: its accumulator holds per-chunk [batch_size, chunk_len] log-prob tensors (tiny), not the [batch_size, chunk_len, V//TP] fp32 grads that make the backward cat expensive, so streaming it would not meaningfully lower peak.

The old list-then-cat form kept every per-chunk [B, chunk_len, V//TP] fp32 grad alive and allocated the full concatenated output at the cat moment, so peak was ~2x the full [B, seq, V//TP] fp32 grad. Streaming drops peak to full-buffer + one live chunk = ~(1 + 1/num_chunks)x of the full grad. The win scales with chunk count; it is not a flat halving.

Numerical equivalence / safety

Byte-identical by construction:

  • The per-chunk math is unchanged: same _compute_distributed_log_softmax on the same fp32 slice, same .exp(), same neg_/mul_/scatter_add_ formulation. chunk_grad_input is a fresh fp32 tensor, so the slice write is a same-dtype copy with no cast.
  • The chunks tile [0, seq_size) exactly and contiguously: chunk i covers [i*chunk_size, min(seq_size, (i+1)*chunk_size)), consecutive chunks meet with no gap/overlap, and num_chunks = ceil(seq_size/chunk_size). torch.cat(dim=1) placed chunk i's columns at those same [chunk_start:chunk_end] offsets, so values land at identical positions and each slice is written exactly once.
  • The new buffer is contiguous fp32 of shape [B, seq_size, V//TP] — identical shape/dtype/contiguity to the previous torch.cat output — and is fully overwritten by the full-coverage tiling, so no uninitialized torch.empty memory survives. Full-coverage tiling is a load-bearing invariant for the torch.empty buffer (a partial write would leak garbage silently); the prime-length coverage test below is the regression guard.

A deliberate choice of a separate fp32 buffer (rather than writing in place into the autograd-saved vocab_parallel_logits): the separate buffer keeps the gradient in fp32 regardless of the saved logits' dtype, does not mutate an autograd-saved tensor (avoiding version-counter / double-backward hazards), and trades only ~(1/num_chunks)x extra memory for that safety.

Test plan

Written to SkyRL CI conventions; python -m py_compile passes and lint is clean (ruff: all checks passed; black --line-length 120: unchanged) on all three files.

  • tests/backends/skyrl_train/distributed/test_chunked_logprob_backward_streaming.py (CPU lane): stubs megatron.core.parallel_state into sys.modules via a module-scoped autouse save/restore fixture (mirroring test_preprocess_packed_seqs_cp.py), uses a gloo world_size=1 TP group (every all_reduce is the identity), and asserts torch.equal between the chunked-streamed grad and the single-shot DistributedLogprob grad across chunk_size in {1,3,7,16,32,64} x with/without OOV targets, plus edge cases (seq_len=1, all-in/all-out mask, a prime/ragged tiny-vocab config). This is the byte-identity gate for the storage refactor: the log-softmax + scatter-add backward math is purely per-position (reductions only over the vocab dim), so chunk boundaries cannot change any value, and the world_size=1 torch.equal fully validates the device/dtype-agnostic slice-write. The prime-length (seq_len=17, chunk_size=5) coverage test additionally pins that every sequence slice of the preallocated buffer is written (an unwritten torch.empty slice cannot coincidentally match the reference). Collected by the SkyRL-Train-CPU lane (pytest tests/backends/skyrl_train/ --ignore=.../gpu).

    Run with:

    uv run --isolated --extra dev -- pytest -s \
      tests/backends/skyrl_train/distributed/test_chunked_logprob_backward_streaming.py
    
  • tests/backends/skyrl_train/gpu/gpu_ci/megatron/test_chunked_logprob_backward_tp.py (@pytest.mark.megatron): spawns TP NCCL ranks (parametrized tp_size in {2, 4}, guarded by a device_count() skip), runs mpu.initialize_model_parallel, shards the vocab across ranks, runs fwd+bwd on each rank's slice, and asserts (a) the rank vocab slices tile [0, vocab) exactly once via torch.equal on a coverage counter, and (b) each rank's local streamed grad matches the full-vocab single-process autograd reference columns within fp32 tolerance via torch.testing.assert_close(atol=1e-5, rtol=1e-4) — a tolerance is correct here because the cross-rank all-reduce reorders the fp32 reduction vs. the single-tensor reference (matching the existing GPU test's grad tolerance), while the no-overlap tiling check uses exact equality. Spawned ranks set the conftest-mandated NCCL env (NCCL_CUMEM_ENABLE=0, etc.) before init_process_group, since mp.spawn children do not inherit the runtime env set by the CI conftest.

    Run with (>=2 free GPUs; the tp_size=4 case is skipped unless 4 are present):

    uv run --isolated --extra dev --extra megatron -- \
      pytest -s tests/backends/skyrl_train/gpu/gpu_ci/megatron/test_chunked_logprob_backward_tp.py
    

Scope & follow-ups

  • Backends covered: Megatron only, by nature — ChunkedDistributedLogprob is the vocab-parallel (TP/CP) chunked-logprob used on the Megatron worker, reached through the single forward_backward_mini_batch train-step chokepoint, so all Megatron-backed training pathways (SFT, sync/async/fully-async RL, full-context) inherit the optimization with no per-pathway edit. FSDP has no vocab-parallel chunked-logprob equivalent, so it is intentionally out of scope.
  • Scope: Limited to the single backward method; DistributedLogprob, forward(), and the fused-LCE path are unchanged.
  • Deferred: A direct streaming-vs-old-cat comparison at TP>1 was intentionally not added, since it would require shipping the pre-refactor cat-based backward as dead reference code. The byte-identity gate for the refactor itself lives on the CPU world_size=1 lane (torch.equal); the GPU lane validates the distributed math against the reference (with tolerance) plus exact no-overlap tiling.

Related PRs (merge ordering)

…ted buffer (lower peak memory)

## Summary

Refactor `ChunkedDistributedLogprob.backward` (the vocab-parallel chunked-logprob autograd function used by the Megatron worker for SFT cross-entropy and RL policy/ref losses) to stream each chunk's gradient into a single preallocated fp32 buffer instead of appending to a Python list and concatenating with `torch.cat` at the end. This lowers peak activation memory on the chunked backward path with **no change to numerics**.

## What changed

`skyrl/backends/skyrl_train/distributed/megatron/model_utils.py`, `ChunkedDistributedLogprob.backward`:

- Removed `all_grad_input = []` and the final `grad_input = torch.cat(all_grad_input, dim=1)`.
- Preallocate once before the loop:
  `grad_input = torch.empty((batch_size, seq_size, partition_vocab_size), dtype=torch.float32, device=vocab_parallel_logits.device)`.
- Renamed the per-chunk grad to `chunk_grad_input` and, after the `scatter_add_`, write it into its sequence slice: `grad_input[:, chunk_start:chunk_end, :] = chunk_grad_input`.

The change is **unconditional** (no flag) because it is numerically byte-identical. It only engages on the chunked dispatch path (i.e. when `chunk_size < seq_len_local`). The non-chunked `DistributedLogprob`, `forward()`, and the vendored Triton fused-LCE path are untouched. `forward()` retains the list+`cat` form intentionally: its accumulator holds per-chunk `[batch_size, chunk_len]` log-prob tensors (tiny), not the `[batch_size, chunk_len, V//TP]` fp32 grads that make the backward `cat` expensive, so streaming it would not meaningfully lower peak.

The old list-then-`cat` form kept every per-chunk `[B, chunk_len, V//TP]` fp32 grad alive **and** allocated the full concatenated output at the cat moment, so peak was ~2x the full `[B, seq, V//TP]` fp32 grad. Streaming drops peak to full-buffer + one live chunk = ~`(1 + 1/num_chunks)`x of the full grad. The win scales with chunk count; it is **not** a flat halving.

## Numerical equivalence / safety

Byte-identical by construction:

- The per-chunk math is unchanged: same `_compute_distributed_log_softmax` on the same fp32 slice, same `.exp()`, same `neg_`/`mul_`/`scatter_add_` formulation. `chunk_grad_input` is a fresh fp32 tensor, so the slice write is a same-dtype copy with no cast.
- The chunks tile `[0, seq_size)` exactly and contiguously: chunk `i` covers `[i*chunk_size, min(seq_size, (i+1)*chunk_size))`, consecutive chunks meet with no gap/overlap, and `num_chunks = ceil(seq_size/chunk_size)`. `torch.cat(dim=1)` placed chunk `i`'s columns at those same `[chunk_start:chunk_end]` offsets, so values land at identical positions and each slice is written exactly once.
- The new buffer is contiguous fp32 of shape `[B, seq_size, V//TP]` — identical shape/dtype/contiguity to the previous `torch.cat` output — and is fully overwritten by the full-coverage tiling, so no uninitialized `torch.empty` memory survives. Full-coverage tiling is a load-bearing invariant for the `torch.empty` buffer (a partial write would leak garbage silently); the prime-length coverage test below is the regression guard.

A deliberate choice of a **separate fp32 buffer** (rather than writing in place into the autograd-saved `vocab_parallel_logits`): the separate buffer keeps the gradient in fp32 regardless of the saved logits' dtype, does not mutate an autograd-saved tensor (avoiding version-counter / double-backward hazards), and trades only ~`(1/num_chunks)`x extra memory for that safety.

## Test plan

> Written to SkyRL CI conventions; `python -m py_compile` passes and lint is clean (`ruff`: all checks passed; `black --line-length 120`: unchanged) on all three files.

- `tests/backends/skyrl_train/distributed/test_chunked_logprob_backward_streaming.py` (CPU lane): stubs `megatron.core.parallel_state` into `sys.modules` via a module-scoped autouse save/restore fixture (mirroring `test_preprocess_packed_seqs_cp.py`), uses a gloo `world_size=1` TP group (every `all_reduce` is the identity), and asserts `torch.equal` between the chunked-streamed grad and the single-shot `DistributedLogprob` grad across `chunk_size in {1,3,7,16,32,64}` x with/without OOV targets, plus edge cases (seq_len=1, all-in/all-out mask, a prime/ragged tiny-vocab config). This is the **byte-identity gate for the storage refactor**: the log-softmax + scatter-add backward math is purely per-position (reductions only over the vocab dim), so chunk boundaries cannot change any value, and the world_size=1 `torch.equal` fully validates the device/dtype-agnostic slice-write. The prime-length (`seq_len=17`, `chunk_size=5`) coverage test additionally pins that every sequence slice of the preallocated buffer is written (an unwritten `torch.empty` slice cannot coincidentally match the reference). Collected by the SkyRL-Train-CPU lane (`pytest tests/backends/skyrl_train/ --ignore=.../gpu`).

  Run with:
  ```
  uv run --isolated --extra dev -- pytest -s \
    tests/backends/skyrl_train/distributed/test_chunked_logprob_backward_streaming.py
  ```

- `tests/backends/skyrl_train/gpu/gpu_ci/megatron/test_chunked_logprob_backward_tp.py` (`@pytest.mark.megatron`): spawns TP NCCL ranks (parametrized `tp_size in {2, 4}`, guarded by a `device_count()` skip), runs `mpu.initialize_model_parallel`, shards the vocab across ranks, runs fwd+bwd on each rank's slice, and asserts (a) the rank vocab slices tile `[0, vocab)` exactly once via `torch.equal` on a coverage counter, and (b) each rank's local streamed grad matches the full-vocab single-process autograd reference columns **within fp32 tolerance** via `torch.testing.assert_close(atol=1e-5, rtol=1e-4)` — a tolerance is correct here because the cross-rank all-reduce reorders the fp32 reduction vs. the single-tensor reference (matching the existing GPU test's grad tolerance), while the no-overlap tiling check uses exact equality. Spawned ranks set the conftest-mandated NCCL env (`NCCL_CUMEM_ENABLE=0`, etc.) before `init_process_group`, since `mp.spawn` children do not inherit the runtime env set by the CI conftest.

  Run with (>=2 free GPUs; the `tp_size=4` case is skipped unless 4 are present):
  ```
  uv run --isolated --extra dev --extra megatron -- \
    pytest -s tests/backends/skyrl_train/gpu/gpu_ci/megatron/test_chunked_logprob_backward_tp.py
  ```

## Scope & follow-ups

- **Backends covered:** Megatron only, by nature — `ChunkedDistributedLogprob` is the vocab-parallel (TP/CP) chunked-logprob used on the Megatron worker, reached through the single `forward_backward_mini_batch` train-step chokepoint, so all Megatron-backed training pathways (SFT, sync/async/fully-async RL, full-context) inherit the optimization with no per-pathway edit. FSDP has no vocab-parallel chunked-logprob equivalent, so it is intentionally out of scope.
- **Scope:** Limited to the single `backward` method; `DistributedLogprob`, `forward()`, and the fused-LCE path are unchanged.
- **Deferred:** A direct streaming-vs-old-`cat` comparison at TP>1 was intentionally not added, since it would require shipping the pre-refactor `cat`-based backward as dead reference code. The byte-identity gate for the refactor itself lives on the CPU `world_size=1` lane (`torch.equal`); the GPU lane validates the distributed math against the reference (with tolerance) plus exact no-overlap tiling.

## Related PRs (merge ordering)

- **NovaSky-AI#1765** (fused LM-head log-prob + entropy) edits the same `ChunkedDistributedLogprob.backward` and **keeps** the `all_grad_input = []` / `torch.cat` form. It refactors the per-chunk chosen-token scatter-add into a shared `_add_chosen_token_grad` helper — touching the exact lines this PR renames to `chunk_grad_input` and streams — so the two **will textually collide on this hunk**. NovaSky-AI#1765 does **not** subsume this PR's peak-memory win (it preserves the list+`cat`), and NovaSky-AI#1765 is otherwise complementary in intent. Whichever lands first forces a rebase of the other. **Recommended ordering:** land NovaSky-AI#1765 first, then rebase this streaming change on top of it (the preallocated-buffer write replaces the `append`+`cat` NovaSky-AI#1765 keeps); if the two are reviewed together they can be folded into one change.
- **NovaSky-AI#1543** (WIP) independently removes the same `torch.cat` 2x-peak, but does so **in place** into the autograd-saved logits with no extra buffer. This PR deliberately does **not** take that approach: the in-place variant downcasts the fp32 grad to the saved logits' dtype (breaking byte-identity) and mutates an autograd-saved tensor (the version-counter / double-backward hazard). This PR's separate fp32 buffer preserves numeric fidelity and autograd-safety at the cost of ~`(1/num_chunks)`x extra memory. NovaSky-AI#1543 is also built on pre-merge code and currently conflicts with main, so it is not a viable supersession.
@dyurk-lila dyurk-lila marked this pull request as ready for review June 18, 2026 16:34

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request refactors ChunkedDistributedLogprob.backward to stream each chunk's gradient directly into a single preallocated buffer instead of accumulating them in a list and concatenating them, which significantly reduces peak memory usage. It also adds comprehensive CPU and GPU TP>1 parity tests to verify the correctness of the streamed-buffer implementation. The reviewer suggested a safer teardown mechanism in the tp_group test fixture to avoid unconditionally destroying the distributed process group if it was initialized outside of the fixture.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment on lines +80 to +96
@pytest.fixture(scope="module")
def tp_group():
"""Single-rank TP process group shared by both autograd functions.

Uses the gloo backend because the world size is 1, so every ``all_reduce``
inside ``_compute_distributed_log_softmax`` is the identity. This isolates
the streamed-buffer refactor from the (separately tested) TP reduction.
"""
if not dist.is_initialized():
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(get_free_port())
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
dist.init_process_group(backend="gloo", rank=0, world_size=1)
yield dist.group.WORLD
if dist.is_initialized():
dist.destroy_process_group()

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The tp_group fixture unconditionally calls dist.destroy_process_group() during teardown. If the default process group was already initialized before this test module ran (e.g., by another test in the same session), destroying it here will interfere with subsequent tests that rely on it. It is safer to only destroy the process group if it was initialized by this fixture.

Suggested change
@pytest.fixture(scope="module")
def tp_group():
"""Single-rank TP process group shared by both autograd functions.
Uses the gloo backend because the world size is 1, so every ``all_reduce``
inside ``_compute_distributed_log_softmax`` is the identity. This isolates
the streamed-buffer refactor from the (separately tested) TP reduction.
"""
if not dist.is_initialized():
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(get_free_port())
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
dist.init_process_group(backend="gloo", rank=0, world_size=1)
yield dist.group.WORLD
if dist.is_initialized():
dist.destroy_process_group()
@pytest.fixture(scope="module")
def tp_group():
"""Single-rank TP process group shared by both autograd functions.
Uses the gloo backend because the world size is 1, so every ``all_reduce``
inside ``_compute_distributed_log_softmax`` is the identity. This isolates
the streamed-buffer refactor from the (separately tested) TP reduction.
"""
initialized_here = False
if not dist.is_initialized():
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(get_free_port())
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
dist.init_process_group(backend="gloo", rank=0, world_size=1)
initialized_here = True
yield dist.group.WORLD
if initialized_here and dist.is_initialized():
dist.destroy_process_group()

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant