Skip to content

[train] VLM SFT support on Megatron backend (Qwen3-VL)#1752

Open
s-chundi wants to merge 9 commits into
NovaSky-AI:mainfrom
s-chundi:vlm-sft-megatron
Open

[train] VLM SFT support on Megatron backend (Qwen3-VL)#1752
s-chundi wants to merge 9 commits into
NovaSky-AI:mainfrom
s-chundi:vlm-sft-megatron

Conversation

@s-chundi

@s-chundi s-chundi commented Jun 4, 2026

Copy link
Copy Markdown

What

Adds vision-language-model (VLM) support to the SFT trainer on the Megatron backend, starting with Qwen3-VL. The Megatron path previously supported text-only SFT; the FSDP path already had VLM support. This PR brings the Megatron side up to parity for SFT by reusing the existing VLM data infrastructure (TensorList, Experience.pixel_values/image_grid_thw, replay-buffer image fields).

The vision tower runs on the first pipeline stage; pixel_values / image_grid_thw flow through the same TrainingInputBatchExperience path the FSDP backend already uses, and are passed into GPTModel.forward as Qwen-style kwargs.

Changes

  • skyrl/utils/tok.pycheck_is_vlm() and get_processor() helpers.
  • skyrl/train/sft_trainer.py — detect VLMs in setup(), build an HF processor, thread it through tokenize_chat_example / _tokenize_chat_last_assistant (now return_dict=True with a small _unbatch helper), emit image tensors as a TensorList in collate_sft_batch, force sequential tokenization for VLMs (the processor does not round-trip through the spawn worker pool), and disable sequence packing / microbatch-padding removal.
  • megatron_model_wrapper.pyis_vlm flag, VLM constraint asserts, and pixel_values/image_grid_thw passthrough in both forward and forward_backward_mini_batch.
  • megatron_worker.py — VLM detection off the HF config, plumb is_vlm into the wrapper, and forward image tensors through the micro-batch builders (policy forward/train + ref forward).
  • teststest_megatron_vlm_init.py (GPU: Megatron VLM forward vs HF reference, and a 5-step SFTTrainer train loop asserting loss decreases) and VLM cases in test_sft_tokenization.py (CPU: tokenization + collation).

(a) Constraints carried over from the FSDP VLM path

3D RoPE and multimodal token positions mean the following are unsupported and asserted/forced off:

  • sample / microbatch packing (remove_microbatch_padding, use_sequence_packing)
  • context parallelism (CP)
  • sequence parallelism (SP)
  • mixed text+image microbatches (batches are assumed homogeneously imaged)

(b) Parity / correctness

test_megatron_vlm_forward checks the Megatron VLM forward against a plain HuggingFace AutoModelForImageTextToText forward over the same tokens + images (finite, aligned log-probs). test_vlm_train runs 5 SFT steps and asserts the loss decreases.

(c) Parallelism combos exercised in tests

  • TP=2, PP=1 (forward parity)
  • TP=1, PP=2 (5-step train loop)

(d) Not supported

CP, SP, sample/microbatch packing, mixed text+image batches, MoE+VLM.

(e) End-to-end example / smoke test

A runnable example exercises the full data -> train path on a small public dataset, so the changes can be smoke-tested without internal data:

  • examples/train/sft/prepare_cauldron_vlm.py — converts HuggingFaceM4/the_cauldron (e.g. the ai2d subset) into the OpenAI messages format SFTTrainer expects, encoding images as base64 data URIs.
  • examples/train/sft/run_sft_megatron_vlm.sh — runs Megatron VLM SFT of Qwen/Qwen3-VL-2B-Instruct on that subset (DP=4 on 4 GPUs by default; auto-generates the dataset on first run). This validates tokenization, image-tensor collation, the first-pipeline-stage vision tower, and a real training loop end-to-end.

Quick smoke test:

uv run examples/train/sft/prepare_cauldron_vlm.py --output-dir $HOME/data/cauldron-vlm --num-rows 512
bash examples/train/sft/run_sft_megatron_vlm.sh num_steps=20 batch_size=8

s-chundi and others added 4 commits June 4, 2026 21:09
Adds vision-language-model support to the Megatron SFT path, mirroring the
existing FSDP VLM data plumbing. The vision tower runs on the first pipeline
stage; image tensors (pixel_values / image_grid_thw) flow through the same
TrainingInputBatch -> Experience path the FSDP backend already uses.

Changes:
- skyrl/utils/tok.py: add check_is_vlm and get_processor helpers.
- skyrl/train/sft_trainer.py: detect VLMs in setup(), build an HF processor,
  thread it through tokenize_chat_example / _tokenize_chat_last_assistant,
  emit image tensors as a TensorList in collate_sft_batch, force sequential
  tokenization for VLMs, and disable sequence packing / microbatch padding
  removal (unsupported for VLMs).
- megatron_model_wrapper.py: is_vlm flag, VLM constraint asserts (no packing,
  CP, or SP), and pixel_values/image_grid_thw passthrough in both forward and
  forward_backward_mini_batch.
- megatron_worker.py: VLM detection off the HF config, plumb is_vlm into the
  wrapper, and forward image tensors through the micro-batch builders.
- tests: GPU init/forward + 5-step train test (test_megatron_vlm_init.py),
  CPU tokenization/collation cases (test_sft_tokenization.py).

Constraints carried over from the FSDP VLM path: no sample/microbatch packing,
no context parallelism, no sequence parallelism, and homogeneous (all-image)
batches. MoE+VLM is out of scope.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@s-chundi s-chundi marked this pull request as ready for review June 10, 2026 15:53
@gemini-code-assist

Copy link
Copy Markdown
Contributor

Warning

Gemini encountered an error creating the review. You can try again by commenting /gemini review.

_load_and_tokenize now reads self.is_vlm and self.processor, but
_build_trainer constructs the trainer via object.__new__ (bypassing
__init__). Set both on the bare trainer so the non-VLM tokenization
tests exercise the same path as before.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
s-chundi added 2 commits June 13, 2026 06:27
# Conflicts:
#	skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py
dyurk-lila added a commit to dyurk-lila/SkyRL that referenced this pull request Jun 16, 2026
…r does not consume them

## Summary

Threads a per-request `return_per_token_outputs` flag (default `True`) carried on the already-plumbed `loss_fn_config` dict to gate the per-token `loss_fn_outputs` build on the `cross_entropy` branch of **both** backends and **both** the train and eval/forward paths. When the flag is `False`, the per-token NLL, the two detached `[mb, seq]` D2H copies (logprobs + elementwise loss), and the `.tolist()` loop are skipped; each sequence gets an empty dict instead. The `loss` / `response_length` metrics and the `loss_fn_output_type` tag are unchanged.

SkyRL's own `SFTTrainer` reads only `output.metrics` (loss / response_length), never `output.loss_fn_outputs`, so it now opts out — eliminating dead work on the SFT train + eval hot path. RL and Tinker callers pass no flag and keep the existing contract (default `True`).

## What changed

- **`worker_utils.py`**: new module-level helper `pop_return_per_token_outputs(loss_fn_config) -> (Optional[dict], bool)` that copies the dict before popping the flag (callers' dicts are never mutated) and returns the default `True` for `None`. Its docstring is the single authoritative source for the flag's contract and the why-popped-before-merge rationale.
- **Megatron** (`megatron/megatron_model_wrapper.py`): pop the flag once in `forward_backward_mini_batch` and gate the per-token build inside the shared `loss_func` (covers train + `forward_only` eval).
- **FSDP** (`workers/worker.py`): pop + gate in `_forward_backward_micro` (train) and `_forward_micro_with_loss` (eval).
- **Caller wiring** (`train/sft_trainer.py`): `train_step` (forward_backward) and `run_eval` (forward) now pass `loss_fn_config={"return_per_token_outputs": False}`.

The flag is popped from a per-call **copy** of `loss_fn_config` **before** the `AlgorithmConfig` merge — it is not an `AlgorithmConfig` field, so leaving it in would trip the key validation in `build_nested_dataclass` (reached via `from_dict_config`). The merge guard was changed from `if loss_fn_config is not None:` to `if loss_fn_config:` so an empty dict after the pop skips the merge exactly as today.

`return_per_token_outputs` is documented in the `pop_return_per_token_outputs` helper docstring and at each pop/gate site; the `loss_fn_config` docstrings on the three consuming worker methods that read the flag — `_forward_backward_micro`, `_forward_micro_with_loss` (FSDP), and `forward_backward_mini_batch` (Megatron) — also note the reserved key and its default. The per-build-site rationale comments were collapsed to a single line each (pointing at the helper docstring) to avoid hand-sync drift across the three near-identical sites.

## Numerical equivalence / safety

Byte-identical for all existing callers:

- When the flag is absent/`None`, the pop never runs and the full prior per-token build executes unchanged. Default `True` reproduces the exact prior code paths.
- The merge-guard change (`is not None` → truthy) is a strict no-op: no existing caller ever passed `{}`, a non-empty override dict is still truthy and still merges, and `OmegaConf.merge(base, {})` is itself a no-op.
- `loss`, the backward pass, and all consumed scalar metrics (`loss`, `response_length`, `lr`) are computed before/independent of the gated block, so they are identical whether per-token outputs are kept or skipped.
- `loss_fn_output_type` is a `WorkerOutput` field that always defaults to `"scalar"` (never set explicitly), so the type tag survives automatically — only the arrays become empty. The empty-dict-with-`"scalar"`-tag combination is only reachable behind the explicit opt-out whose sole caller ignores the payload.
- RL's separate (non-`cross_entropy`) `loss_fn_outputs` else-branch is untouched; the RL trainer and Tinker backend pass no flag, so their contracts hold. The Tinker public API whitelists `loss_fn_config` keys (empty allowed-key set for `cross_entropy`), so the flag cannot be injected by users; and even if present it is popped before any merge.

Note on the eval path: `run_eval` already iterates eval batches serially and reads only `output.metrics["loss"]`, so opting out there removes per-token work without changing any reported eval metric.

## Test plan

- `tests/backends/skyrl_train/workers/test_sft_loss_fn_outputs_gate.py` (new, CPU): drives the real FSDP `_forward_backward_micro` / `_forward_micro_with_loss` `cross_entropy` builds on CPU; asserts default/explicit-True populate `logprobs` + `elementwise_loss`, `False` yields empty dicts, `loss`/`response_length`/`lr` are identical across the flag, the flag is popped before the `AlgorithmConfig` merge (a real `eps_clip_low` override alongside the flag still merges without raising), and the caller dict is not mutated. Adds an RL-path test confirming the non-`cross_entropy` else-branch is ungated (logprobs still built; outputs + loss identical across the flag).
- `tests/backends/skyrl_train/workers/test_worker_utils.py` (extended): unit tests for `pop_return_per_token_outputs` — `None`→`(None, True)`, absent-flag→`True` with config preserved, explicit `False`/`True` popped leaving legitimate overrides, and no caller-dict mutation.
- `tests/train/test_sft_callbacks.py` (extended): assert `SFTTrainer.train_step` and `run_eval` pass `loss_fn="cross_entropy"` and `loss_fn_config={"return_per_token_outputs": False}` to the dispatch.
- `tests/backends/skyrl_train/gpu/gpu_ci/test_training_step.py` (extended, GPU): parametrized over FSDP + Megatron (the Megatron leg carries `@pytest.mark.megatron` like the sibling tests), DP=2; runs the real worker `forward_backward`/`forward` twice on the same dummy batch (flag default-True vs explicit-False) and asserts `loss` + `response_length` identical and `loss_fn_output_type == "scalar"` in both, with per-token outputs populated when kept vs empty when skipped.

Run locally:

```bash
uv run --isolated --extra dev pytest tests/backends/skyrl_train/workers/test_sft_loss_fn_outputs_gate.py tests/backends/skyrl_train/workers/test_worker_utils.py tests/train/test_sft_callbacks.py
```

The CPU suite for the new/extended tests runs automatically under the standard CPU CI job (`tests/backends/skyrl_train/` + `tests/train/`); confirm that check is green before merge.

## Generality & follow-ups

Covered: both backends (Megatron `loss_func`; FSDP micro methods) and both the train (`forward_backward`) and eval/forward (`forward(loss_fn="cross_entropy")`) paths. RL and Tinker contracts preserved (default `True`).

Intentionally out of scope:
- The RL `loss_fn_outputs` build (separate else-branch, not `cross_entropy`) is untouched; the `forward(loss_fn=None)` pure-inference path is untouched.
- The JAX backend builds `loss_fn_outputs` in a jit-traced path driven by a structured `LossFnConfig` dataclass rather than the runtime `loss_fn_config` dict, so the dict-borne flag does not reach it — a possible follow-up, not a regression.
- No config schema field is added — the flag rides the runtime `loss_fn_config` dict only, preferring a function-arg/flag with a safe default over a global switch.

Doc follow-up: the caller-facing dispatch docstrings `WorkerDispatch.forward` / `WorkerDispatch.forward_backward` (the API the SFT trainer / Tinker call into) are outside this PR's touched files and so are left unchanged here; a one-line note about the reserved key could be added there in a follow-up so future callers can discover the opt-out from the dispatch layer.

## Relationship to open PRs

This gate is orthogonal/complementary to several in-flight efforts touching the same files:

- **NovaSky-AI#1513** (SFT loss-aggregation rewrite of the same FSDP `cross_entropy` branch) — note it renames the SFT status key `loss`→`sft_loss`, so merge ordering matters; only the test coupling to the literal `loss` metric key would need a touch-up if it lands first.
- **NovaSky-AI#1752** (VLM SFT on Megatron) — disjoint regions in the shared files.
- **NovaSky-AI#1769** (qwen3.5 sequence packing) — disjoint Megatron region.
- **NovaSky-AI#1534** (preserve staged `forward_backward` loss_fn_outputs across DP ranks; `worker_dispatch.py`) — no overlapping file.

The gate logic itself is composable with all of these.
dyurk-lila added a commit to dyurk-lila/SkyRL that referenced this pull request Jun 16, 2026
# What does this PR do?

The controller builds every training batch on the main process before dispatching it to the workers. All three collation paths did so with per-token / per-sample Python loops that dominate the controller-side collate wall-time and serially block the GPU at large batch sizes. This PR replaces those loops with NumPy slice-assignments and broadcast comparisons. **Outputs are bit-identical** (same dtypes, same layout) for all inputs produced in practice — this is a pure CPU-side latency optimization, not a behavior change.

## Changes

- **`PackedDataCollator` (Megatron SFT FFD packing), `skyrl/train/dataset/collators.py`** — the per-bin packed row tensors (`sequences` / `attention_mask` / `loss_mask`) were built with a per-token Python loop over the reconstructed full loss mask. Each sub-seq is now written with one C-level copy, and `total_nonpad` is a single vectorized reduction.
- **`collate_sft_batch` / `DefaultCollator` (unpacked SFT), `skyrl/train/sft_trainer.py`** — each left-padded row is written with a single slice assignment into a preallocated array instead of building a per-example padded Python list.
- **`convert_prompts_responses_to_batch_tensors` (RL), `skyrl/train/dataset/preprocess.py`** — the left-padded `sequences` are built with two slice copies per row, and the fixed-width `attention_mask` / `action_mask` / `loss_mask` / `rewards` / `logprobs` tensors are produced with broadcast comparisons / slice writes instead of per-token Python loops.

This covers the SFT (packed + unpacked) and RL training-batch construction paths; the RL and SFT data paths are separate functions, so each is vectorized independently. The RL change is inherited unchanged by all `RayPPOTrainer` subclasses (sync / async / full-context / agentic), since none override `convert_to_training_input`.

### Intentionally out of scope

- **MoE router-replay (`rollout_expert_indices`).** The optional `rollout_expert_indices` branch in `convert_prompts_responses_to_batch_tensors` is left exactly as-is — only the dense per-token batch tensors are vectorized. That branch is byte-identical to the prior implementation (zero correctness/regression risk); it is a narrow MoE-only path, so its residual per-sample loop is left for a follow-up rather than folded into this CPU-latency change. The new equivalence suite therefore does not exercise it (the oracle compares the six dense outputs; the 7th return value is intentionally discarded).
- **Eval path.** Packing only fires on the training-step batch (`batch_size == self.batch_size`); on the eval path `PackedDataCollator` delegates to the un-packed `DefaultCollator`, so eval collation is unchanged by this PR.

Notes on the bit-identical claim:
- dtypes are preserved exactly: `int64` / `torch.long` for `sequences` / masks (incl. `action_mask`), `float32` / `torch.float` for `loss_mask` / `rewards` / `logprobs`. `dtype=np.int64` is pinned explicitly (NumPy's platform-default int is `int32` on Windows).
- The RL reward path accepts Python lists and `float32` reward tensors (what the reward postprocessing produces today). A `requires_grad`, CUDA, or `bfloat16` reward tensor is not accepted; no reward producer in the repo emits those.
- The `PackedDataCollator` loss-mask write window keeps the original `row_p < max_packed_len - 1` clamp. That `min()` is a defensive no-op — `max_packed_len` is `>=` every bin's packed length by construction — so the clamp never bites today; it is retained (and now commented) to preserve the original behavior exactly.

## Benchmarks

Controller-side collate, single process, CPU, batch of 1024 (varying sequence lengths):

| Path | Before | After | Speedup |
|------|--------|-------|---------|
| `PackedDataCollator` (FFD, dp=8) | 288.5 ms | 8.4 ms | ~34x |
| `convert_prompts_responses_to_batch_tensors` (RL) | 92.8 ms | 14.5 ms | ~6.4x |
| `collate_sft_batch` (unpacked SFT) | 98.4 ms | 16.2 ms | ~6x |

## Test plan

- [x] New `tests/train/test_collation_vectorization_equivalence.py`: pins a faithful reference of each *original* loop and fuzzes the vectorized output against it with `torch.equal` plus explicit per-tensor `dtype` assertions — RL (with/without logprobs, list and `float32`-tensor rewards), unpacked SFT, and packed SFT across TP/PP/CP/DP configs. Because `torch.equal` is dtype-insensitive on matching values, the integer/float dtypes (`action_mask`/`attention_mask` `int64`/`long`, `loss_mask`/`rewards` `float32`) are pinned with explicit `.dtype` assertions. The packed test re-derives the FFD / DP-shard / `max_packed_len` decision inline as its own oracle, so any production drift surfaces as a `torch.equal` mismatch. (Mutation-checked: an injected off-by-one in any vectorized path fails the suite — note the unreachable `loss_mask` clamp is the one spot a localized off-by-one would not be caught, since it never fires under any in-practice input.)
- [x] Existing `tests/train/dataset/test_preprocess.py`, `tests/train/test_sft_packing_collate.py`, `tests/train/test_packing_round_trip.py`, `tests/train/test_sft_tokenization.py` pass unchanged.
- [x] `ruff` + `black` clean.

```bash
uv run --isolated --extra dev --extra megatron -- pytest \
  tests/train/test_collation_vectorization_equivalence.py \
  tests/train/dataset/test_preprocess.py \
  tests/train/test_sft_packing_collate.py \
  tests/train/test_packing_round_trip.py
```

> Heads-up for reviewers: this overlaps open PR NovaSky-AI#1752 ([train] VLM SFT on Megatron), which edits the same `collate_sft_batch` loop and `TrainingInputBatch` dict to collect `pixel_values` / `image_grid_thw`. Whichever lands second needs a small rebase; if NovaSky-AI#1752 lands first, its per-sample VLM tensor collection should be reinstated inside the vectorized `for i, ex in enumerate(examples):` loop and its two keys re-added to the `from_numpy` batch dict.

🤖 Generated with [Claude Code](https://claude.com/claude-code)
@s-chundi

Copy link
Copy Markdown
Author

@SumanthRH Could I get a review here if you get the chance? I see you've reviewed multimodal contributions in the past. I'm concerned the changes might get stale, had to update a few times alr.

Apologies if this is not the correct way to request a review

dyurk-lila added a commit to dyurk-lila/SkyRL that referenced this pull request Jun 18, 2026
…r does not consume them

## Summary

Threads a per-request `return_per_token_outputs` flag (default `True`) carried on the already-plumbed `loss_fn_config` dict to gate the per-token `loss_fn_outputs` build on the `cross_entropy` branch of **both** backends and **both** the train and eval/forward paths. When the flag is `False`, the per-token NLL, the two detached `[mb, seq]` D2H copies (logprobs + elementwise loss), and the `.tolist()` loop are skipped; each sequence gets an empty dict instead. The `loss` / `response_length` metrics and the `loss_fn_output_type` tag are unchanged.

SkyRL's own `SFTTrainer` reads only `output.metrics` (loss / response_length), never `output.loss_fn_outputs`, so it now opts out — eliminating dead work on the SFT train + eval hot path. RL and Tinker callers pass no flag and keep the existing contract (default `True`).

## What changed

- **`worker_utils.py`**: new module-level helper `pop_return_per_token_outputs(loss_fn_config) -> (Optional[dict], bool)` that copies the dict before popping the flag (callers' dicts are never mutated) and returns the default `True` for `None`. Its docstring is the single authoritative source for the flag's contract and the why-popped-before-merge rationale.
- **Megatron** (`megatron/megatron_model_wrapper.py`): pop the flag once in `forward_backward_mini_batch` and gate the per-token build inside the shared `loss_func` (covers train + `forward_only` eval).
- **FSDP** (`workers/worker.py`): pop + gate in `_forward_backward_micro` (train) and `_forward_micro_with_loss` (eval).
- **Caller wiring** (`train/sft_trainer.py`): `train_step` (forward_backward) and `run_eval` (forward) now pass `loss_fn_config={"return_per_token_outputs": False}`.

The flag is popped from a per-call **copy** of `loss_fn_config` **before** the `AlgorithmConfig` merge — it is not an `AlgorithmConfig` field, so leaving it in would trip the key validation in `build_nested_dataclass` (reached via `from_dict_config`). The merge guard was changed from `if loss_fn_config is not None:` to `if loss_fn_config:` so an empty dict after the pop skips the merge exactly as today.

`return_per_token_outputs` is documented in the `pop_return_per_token_outputs` helper docstring and at each pop/gate site; the `loss_fn_config` docstrings on the three consuming worker methods that read the flag — `_forward_backward_micro`, `_forward_micro_with_loss` (FSDP), and `forward_backward_mini_batch` (Megatron) — also note the reserved key and its default. The per-build-site rationale comments were collapsed to a single line each (pointing at the helper docstring) to avoid hand-sync drift across the three near-identical sites.

## Numerical equivalence / safety

Byte-identical for all existing callers:

- When the flag is absent/`None`, the pop never runs and the full prior per-token build executes unchanged. Default `True` reproduces the exact prior code paths.
- The merge-guard change (`is not None` → truthy) is a strict no-op: no existing caller ever passed `{}`, a non-empty override dict is still truthy and still merges, and `OmegaConf.merge(base, {})` is itself a no-op.
- `loss`, the backward pass, and all consumed scalar metrics (`loss`, `response_length`, `lr`) are computed before/independent of the gated block, so they are identical whether per-token outputs are kept or skipped.
- `loss_fn_output_type` is a `WorkerOutput` field that always defaults to `"scalar"` (never set explicitly), so the type tag survives automatically — only the arrays become empty. The empty-dict-with-`"scalar"`-tag combination is only reachable behind the explicit opt-out whose sole caller ignores the payload.
- RL's separate (non-`cross_entropy`) `loss_fn_outputs` else-branch is untouched; the RL trainer and Tinker backend pass no flag, so their contracts hold. The Tinker public API whitelists `loss_fn_config` keys (empty allowed-key set for `cross_entropy`), so the flag cannot be injected by users; and even if present it is popped before any merge.

Note on the eval path: `run_eval` already iterates eval batches serially and reads only `output.metrics["loss"]`, so opting out there removes per-token work without changing any reported eval metric.

## Test plan

- `tests/backends/skyrl_train/workers/test_sft_loss_fn_outputs_gate.py` (new, CPU): drives the real FSDP `_forward_backward_micro` / `_forward_micro_with_loss` `cross_entropy` builds on CPU; asserts default/explicit-True populate `logprobs` + `elementwise_loss`, `False` yields empty dicts, `loss`/`response_length`/`lr` are identical across the flag, the flag is popped before the `AlgorithmConfig` merge (a real `eps_clip_low` override alongside the flag still merges without raising), and the caller dict is not mutated. Adds an RL-path test confirming the non-`cross_entropy` else-branch is ungated (logprobs still built; outputs + loss identical across the flag).
- `tests/backends/skyrl_train/workers/test_worker_utils.py` (extended): unit tests for `pop_return_per_token_outputs` — `None`→`(None, True)`, absent-flag→`True` with config preserved, explicit `False`/`True` popped leaving legitimate overrides, and no caller-dict mutation.
- `tests/train/test_sft_callbacks.py` (extended): assert `SFTTrainer.train_step` and `run_eval` pass `loss_fn="cross_entropy"` and `loss_fn_config={"return_per_token_outputs": False}` to the dispatch.
- `tests/backends/skyrl_train/gpu/gpu_ci/test_training_step.py` (extended, GPU): parametrized over FSDP + Megatron (the Megatron leg carries `@pytest.mark.megatron` like the sibling tests), DP=2; runs the real worker `forward_backward`/`forward` twice on the same dummy batch (flag default-True vs explicit-False) and asserts `loss` + `response_length` identical and `loss_fn_output_type == "scalar"` in both, with per-token outputs populated when kept vs empty when skipped.

Run locally:

```bash
uv run --isolated --extra dev pytest tests/backends/skyrl_train/workers/test_sft_loss_fn_outputs_gate.py tests/backends/skyrl_train/workers/test_worker_utils.py tests/train/test_sft_callbacks.py
```

The CPU suite for the new/extended tests runs automatically under the standard CPU CI job (`tests/backends/skyrl_train/` + `tests/train/`); confirm that check is green before merge.

## Generality & follow-ups

Covered: both backends (Megatron `loss_func`; FSDP micro methods) and both the train (`forward_backward`) and eval/forward (`forward(loss_fn="cross_entropy")`) paths. RL and Tinker contracts preserved (default `True`).

Intentionally out of scope:
- The RL `loss_fn_outputs` build (separate else-branch, not `cross_entropy`) is untouched; the `forward(loss_fn=None)` pure-inference path is untouched.
- The JAX backend builds `loss_fn_outputs` in a jit-traced path driven by a structured `LossFnConfig` dataclass rather than the runtime `loss_fn_config` dict, so the dict-borne flag does not reach it — a possible follow-up, not a regression.
- No config schema field is added — the flag rides the runtime `loss_fn_config` dict only, preferring a function-arg/flag with a safe default over a global switch.

Doc follow-up: the caller-facing dispatch docstrings `WorkerDispatch.forward` / `WorkerDispatch.forward_backward` (the API the SFT trainer / Tinker call into) are outside this PR's touched files and so are left unchanged here; a one-line note about the reserved key could be added there in a follow-up so future callers can discover the opt-out from the dispatch layer.

## Relationship to open PRs

This gate is orthogonal/complementary to several in-flight efforts touching the same files:

- **NovaSky-AI#1513** (SFT loss-aggregation rewrite of the same FSDP `cross_entropy` branch) — note it renames the SFT status key `loss`→`sft_loss`, so merge ordering matters; only the test coupling to the literal `loss` metric key would need a touch-up if it lands first.
- **NovaSky-AI#1752** (VLM SFT on Megatron) — disjoint regions in the shared files.
- **NovaSky-AI#1769** (qwen3.5 sequence packing) — disjoint Megatron region.
- **NovaSky-AI#1534** (preserve staged `forward_backward` loss_fn_outputs across DP ranks; `worker_dispatch.py`) — no overlapping file.

The gate logic itself is composable with all of these.
dyurk-lila added a commit to dyurk-lila/SkyRL that referenced this pull request Jun 18, 2026
…r does not consume them

## Summary

Threads a per-request `return_per_token_outputs` flag (default `True`) carried on the already-plumbed `loss_fn_config` dict to gate the per-token `loss_fn_outputs` build on the `cross_entropy` branch of **both** backends and **both** the train and eval/forward paths. When the flag is `False`, the per-token NLL, the two detached `[mb, seq]` D2H copies (logprobs + elementwise loss), and the `.tolist()` loop are skipped; each sequence gets an empty dict instead. The `loss` / `response_length` metrics and the `loss_fn_output_type` tag are unchanged.

SkyRL's own `SFTTrainer` reads only `output.metrics` (loss / response_length), never `output.loss_fn_outputs`, so it now opts out — eliminating dead work on the SFT train + eval hot path. RL and Tinker callers pass no flag and keep the existing contract (default `True`).

## What changed

- **`worker_utils.py`**: new module-level helper `pop_return_per_token_outputs(loss_fn_config) -> (Optional[dict], bool)` that copies the dict before popping the flag (callers' dicts are never mutated) and returns the default `True` for `None`. Its docstring is the single authoritative source for the flag's contract and the why-popped-before-merge rationale.
- **Megatron** (`megatron/megatron_model_wrapper.py`): pop the flag once in `forward_backward_mini_batch` and gate the per-token build inside the shared `loss_func` (covers train + `forward_only` eval).
- **FSDP** (`workers/worker.py`): pop + gate in `_forward_backward_micro` (train) and `_forward_micro_with_loss` (eval).
- **Caller wiring** (`train/sft_trainer.py`): `train_step` (forward_backward) and `run_eval` (forward) now pass `loss_fn_config={"return_per_token_outputs": False}`.

The flag is popped from a per-call **copy** of `loss_fn_config` **before** the `AlgorithmConfig` merge — it is not an `AlgorithmConfig` field, so leaving it in would trip the key validation in `build_nested_dataclass` (reached via `from_dict_config`). The merge guard was changed from `if loss_fn_config is not None:` to `if loss_fn_config:` so an empty dict after the pop skips the merge exactly as today.

`return_per_token_outputs` is documented in the `pop_return_per_token_outputs` helper docstring and at each pop/gate site; the `loss_fn_config` docstrings on the three consuming worker methods that read the flag — `_forward_backward_micro`, `_forward_micro_with_loss` (FSDP), and `forward_backward_mini_batch` (Megatron) — also note the reserved key and its default. The per-build-site rationale comments were collapsed to a single line each (pointing at the helper docstring) to avoid hand-sync drift across the three near-identical sites.

## Numerical equivalence / safety

Byte-identical for all existing callers:

- When the flag is absent/`None`, the pop never runs and the full prior per-token build executes unchanged. Default `True` reproduces the exact prior code paths.
- The merge-guard change (`is not None` → truthy) is a strict no-op: no existing caller ever passed `{}`, a non-empty override dict is still truthy and still merges, and `OmegaConf.merge(base, {})` is itself a no-op.
- `loss`, the backward pass, and all consumed scalar metrics (`loss`, `response_length`, `lr`) are computed before/independent of the gated block, so they are identical whether per-token outputs are kept or skipped.
- `loss_fn_output_type` is a `WorkerOutput` field that always defaults to `"scalar"` (never set explicitly), so the type tag survives automatically — only the arrays become empty. The empty-dict-with-`"scalar"`-tag combination is only reachable behind the explicit opt-out whose sole caller ignores the payload.
- RL's separate (non-`cross_entropy`) `loss_fn_outputs` else-branch is untouched; the RL trainer and Tinker backend pass no flag, so their contracts hold. The Tinker public API whitelists `loss_fn_config` keys (empty allowed-key set for `cross_entropy`), so the flag cannot be injected by users; and even if present it is popped before any merge.

Note on the eval path: `run_eval` already iterates eval batches serially and reads only `output.metrics["loss"]`, so opting out there removes per-token work without changing any reported eval metric.

## Test plan

- `tests/backends/skyrl_train/workers/test_sft_loss_fn_outputs_gate.py` (new, CPU): drives the real FSDP `_forward_backward_micro` / `_forward_micro_with_loss` `cross_entropy` builds on CPU; asserts default/explicit-True populate `logprobs` + `elementwise_loss`, `False` yields empty dicts, `loss`/`response_length`/`lr` are identical across the flag, the flag is popped before the `AlgorithmConfig` merge (a real `eps_clip_low` override alongside the flag still merges without raising), and the caller dict is not mutated. Adds an RL-path test confirming the non-`cross_entropy` else-branch is ungated (logprobs still built; outputs + loss identical across the flag).
- `tests/backends/skyrl_train/workers/test_worker_utils.py` (extended): unit tests for `pop_return_per_token_outputs` — `None`→`(None, True)`, absent-flag→`True` with config preserved, explicit `False`/`True` popped leaving legitimate overrides, and no caller-dict mutation.
- `tests/train/test_sft_callbacks.py` (extended): assert `SFTTrainer.train_step` and `run_eval` pass `loss_fn="cross_entropy"` and `loss_fn_config={"return_per_token_outputs": False}` to the dispatch.
- `tests/backends/skyrl_train/gpu/gpu_ci/test_training_step.py` (extended, GPU): parametrized over FSDP + Megatron (the Megatron leg carries `@pytest.mark.megatron` like the sibling tests), DP=2; runs the real worker `forward_backward`/`forward` twice on the same dummy batch (flag default-True vs explicit-False) and asserts `loss` + `response_length` identical and `loss_fn_output_type == "scalar"` in both, with per-token outputs populated when kept vs empty when skipped.

Run locally:

```bash
uv run --isolated --extra dev pytest tests/backends/skyrl_train/workers/test_sft_loss_fn_outputs_gate.py tests/backends/skyrl_train/workers/test_worker_utils.py tests/train/test_sft_callbacks.py
```

The CPU suite for the new/extended tests runs automatically under the standard CPU CI job (`tests/backends/skyrl_train/` + `tests/train/`); confirm that check is green before merge.

## Generality & follow-ups

Covered: both backends (Megatron `loss_func`; FSDP micro methods) and both the train (`forward_backward`) and eval/forward (`forward(loss_fn="cross_entropy")`) paths. RL and Tinker contracts preserved (default `True`).

Intentionally out of scope:
- The RL `loss_fn_outputs` build (separate else-branch, not `cross_entropy`) is untouched; the `forward(loss_fn=None)` pure-inference path is untouched.
- The JAX backend builds `loss_fn_outputs` in a jit-traced path driven by a structured `LossFnConfig` dataclass rather than the runtime `loss_fn_config` dict, so the dict-borne flag does not reach it — a possible follow-up, not a regression.
- No config schema field is added — the flag rides the runtime `loss_fn_config` dict only, preferring a function-arg/flag with a safe default over a global switch.

Doc follow-up: the caller-facing dispatch docstrings `WorkerDispatch.forward` / `WorkerDispatch.forward_backward` (the API the SFT trainer / Tinker call into) are outside this PR's touched files and so are left unchanged here; a one-line note about the reserved key could be added there in a follow-up so future callers can discover the opt-out from the dispatch layer.

## Relationship to open PRs

This gate is orthogonal/complementary to several in-flight efforts touching the same files:

- **NovaSky-AI#1513** (SFT loss-aggregation rewrite of the same FSDP `cross_entropy` branch) — note it renames the SFT status key `loss`→`sft_loss`, so merge ordering matters; only the test coupling to the literal `loss` metric key would need a touch-up if it lands first.
- **NovaSky-AI#1752** (VLM SFT on Megatron) — disjoint regions in the shared files.
- **NovaSky-AI#1769** (qwen3.5 sequence packing) — disjoint Megatron region.
- **NovaSky-AI#1534** (preserve staged `forward_backward` loss_fn_outputs across DP ranks; `worker_dispatch.py`) — no overlapping file.

The gate logic itself is composable with all of these.
dyurk-lila added a commit to dyurk-lila/SkyRL that referenced this pull request Jun 18, 2026
…r does not consume them

## Summary

Threads a per-request `return_per_token_outputs` flag (default `True`) carried on the already-plumbed `loss_fn_config` dict to gate the per-token `loss_fn_outputs` build on the `cross_entropy` branch of **both** backends and **both** the train and eval/forward paths. When the flag is `False`, the per-token NLL, the two detached `[mb, seq]` D2H copies (logprobs + elementwise loss), and the `.tolist()` loop are skipped; each sequence gets an empty dict instead. The `loss` / `response_length` metrics and the `loss_fn_output_type` tag are unchanged.

SkyRL's own `SFTTrainer` reads only `output.metrics` (loss / response_length), never `output.loss_fn_outputs`, so it now opts out — eliminating dead work on the SFT train + eval hot path. RL and Tinker callers pass no flag and keep the existing contract (default `True`).

## What changed

- **`worker_utils.py`**: new module-level helper `pop_return_per_token_outputs(loss_fn_config) -> (Optional[dict], bool)` that copies the dict before popping the flag (callers' dicts are never mutated) and returns the default `True` for `None`. Its docstring is the single authoritative source for the flag's contract and the why-popped-before-merge rationale.
- **Megatron** (`megatron/megatron_model_wrapper.py`): pop the flag once in `forward_backward_mini_batch` and gate the per-token build inside the shared `loss_func` (covers train + `forward_only` eval).
- **FSDP** (`workers/worker.py`): pop + gate in `_forward_backward_micro` (train) and `_forward_micro_with_loss` (eval).
- **Caller wiring** (`train/sft_trainer.py`): `train_step` (forward_backward) and `run_eval` (forward) now pass `loss_fn_config={"return_per_token_outputs": False}`.

The flag is popped from a per-call **copy** of `loss_fn_config` **before** the `AlgorithmConfig` merge — it is not an `AlgorithmConfig` field, so leaving it in would trip the key validation in `build_nested_dataclass` (reached via `from_dict_config`). The merge guard was changed from `if loss_fn_config is not None:` to `if loss_fn_config:` so an empty dict after the pop skips the merge exactly as today.

`return_per_token_outputs` is documented in the `pop_return_per_token_outputs` helper docstring and at each pop/gate site; the `loss_fn_config` docstrings on the three consuming worker methods that read the flag — `_forward_backward_micro`, `_forward_micro_with_loss` (FSDP), and `forward_backward_mini_batch` (Megatron) — also note the reserved key and its default. The per-build-site rationale comments were collapsed to a single line each (pointing at the helper docstring) to avoid hand-sync drift across the three near-identical sites.

## Numerical equivalence / safety

Byte-identical for all existing callers:

- When the flag is absent/`None`, the pop never runs and the full prior per-token build executes unchanged. Default `True` reproduces the exact prior code paths.
- The merge-guard change (`is not None` → truthy) is a strict no-op: no existing caller ever passed `{}`, a non-empty override dict is still truthy and still merges, and `OmegaConf.merge(base, {})` is itself a no-op.
- `loss`, the backward pass, and all consumed scalar metrics (`loss`, `response_length`, `lr`) are computed before/independent of the gated block, so they are identical whether per-token outputs are kept or skipped.
- `loss_fn_output_type` is a `WorkerOutput` field that always defaults to `"scalar"` (never set explicitly), so the type tag survives automatically — only the arrays become empty. The empty-dict-with-`"scalar"`-tag combination is only reachable behind the explicit opt-out whose sole caller ignores the payload.
- RL's separate (non-`cross_entropy`) `loss_fn_outputs` else-branch is untouched; the RL trainer and Tinker backend pass no flag, so their contracts hold. The Tinker public API whitelists `loss_fn_config` keys (empty allowed-key set for `cross_entropy`), so the flag cannot be injected by users; and even if present it is popped before any merge.

Note on the eval path: `run_eval` already iterates eval batches serially and reads only `output.metrics["loss"]`, so opting out there removes per-token work without changing any reported eval metric.

## Test plan

- `tests/backends/skyrl_train/workers/test_sft_loss_fn_outputs_gate.py` (new, CPU): drives the real FSDP `_forward_backward_micro` / `_forward_micro_with_loss` `cross_entropy` builds on CPU; asserts default/explicit-True populate `logprobs` + `elementwise_loss`, `False` yields empty dicts, `loss`/`response_length`/`lr` are identical across the flag, the flag is popped before the `AlgorithmConfig` merge (a real `eps_clip_low` override alongside the flag still merges without raising), and the caller dict is not mutated. Adds an RL-path test confirming the non-`cross_entropy` else-branch is ungated (logprobs still built; outputs + loss identical across the flag); that test disables `use_kl_loss`/`use_entropy_loss` explicitly (rather than relying on a default) so it isolates the gate from the KL/entropy terms. CPU run: 12 passed.
- `tests/backends/skyrl_train/workers/test_worker_utils.py` (extended): unit tests for `pop_return_per_token_outputs` — `None`→`(None, True)`, absent-flag→`True` with config preserved, explicit `False`/`True` popped leaving legitimate overrides, and no caller-dict mutation.
- `tests/train/test_sft_callbacks.py` (extended): assert `SFTTrainer.train_step` and `run_eval` pass `loss_fn="cross_entropy"` and `loss_fn_config={"return_per_token_outputs": False}` to the dispatch.
- `tests/backends/skyrl_train/gpu/gpu_ci/test_training_step.py` (extended, GPU): parametrized over FSDP + Megatron (the Megatron leg carries `@pytest.mark.megatron` like the sibling tests), DP=2; runs the real worker `forward_backward`/`forward` twice on the same dummy batch (flag default-True vs explicit-False) and asserts `loss` + `response_length` identical and `loss_fn_output_type == "scalar"` in both, with per-token outputs populated when kept vs empty when skipped.

Run locally:

```bash
uv run --isolated --extra skyrl-train --extra dev pytest tests/backends/skyrl_train/workers/test_sft_loss_fn_outputs_gate.py tests/backends/skyrl_train/workers/test_worker_utils.py tests/train/test_sft_callbacks.py
```

The CPU suite for the new/extended tests runs automatically under the standard CPU CI job (`tests/backends/skyrl_train/` + `tests/train/`); confirm that check is green before merge.

## Generality & follow-ups

Covered: both backends (Megatron `loss_func`; FSDP micro methods) and both the train (`forward_backward`) and eval/forward (`forward(loss_fn="cross_entropy")`) paths. RL and Tinker contracts preserved (default `True`).

Intentionally out of scope:
- The RL `loss_fn_outputs` build (separate else-branch, not `cross_entropy`) is untouched; the `forward(loss_fn=None)` pure-inference path is untouched.
- The JAX backend builds `loss_fn_outputs` in a jit-traced path driven by a structured `LossFnConfig` dataclass rather than the runtime `loss_fn_config` dict, so the dict-borne flag does not reach it — a possible follow-up, not a regression.
- No config schema field is added — the flag rides the runtime `loss_fn_config` dict only, preferring a function-arg/flag with a safe default over a global switch.

Doc follow-up: the caller-facing dispatch docstrings `WorkerDispatch.forward` / `WorkerDispatch.forward_backward` (the API the SFT trainer / Tinker call into) are outside this PR's touched files and so are left unchanged here; a one-line note about the reserved key could be added there in a follow-up so future callers can discover the opt-out from the dispatch layer.

## Relationship to open PRs

This gate is orthogonal/complementary to several in-flight efforts touching the same files:

- **NovaSky-AI#1513** (SFT loss-aggregation rewrite of the same FSDP `cross_entropy` branch) — note it renames the SFT status key `loss`→`sft_loss`, so merge ordering matters; only the test coupling to the literal `loss` metric key would need a touch-up if it lands first.
- **NovaSky-AI#1752** (VLM SFT on Megatron) — disjoint regions in the shared files.
- **NovaSky-AI#1534** (preserve staged `forward_backward` loss_fn_outputs across DP ranks; `worker_dispatch.py`) — no overlapping file.

The gate logic itself is composable with all of these.
dyurk-lila added a commit to dyurk-lila/SkyRL that referenced this pull request Jun 18, 2026
# What does this PR do?

The controller builds every training batch on the main process before dispatching it to the workers. All three collation paths did so with per-token / per-sample Python loops that dominate the controller-side collate wall-time and serially block the GPU at large batch sizes. This PR replaces those loops with NumPy slice-assignments and broadcast comparisons. **Outputs are bit-identical** (same dtypes, same layout) for all inputs produced in practice — this is a pure CPU-side latency optimization, not a behavior change.

## Changes

- **`PackedDataCollator` (Megatron SFT FFD packing), `skyrl/train/dataset/collators.py`** — the per-bin packed row tensors (`sequences` / `attention_mask` / `loss_mask`) were built with a per-token Python loop over the reconstructed full loss mask. Each sub-seq is now written with one C-level copy, and `total_nonpad` is a single vectorized reduction.
- **`collate_sft_batch` / `DefaultCollator` (unpacked SFT), `skyrl/train/sft_trainer.py`** — each left-padded row is written with a single slice assignment into a preallocated array instead of building a per-example padded Python list.
- **`convert_prompts_responses_to_batch_tensors` (RL), `skyrl/train/dataset/preprocess.py`** — the left-padded `sequences` are built with two slice copies per row, and the fixed-width `attention_mask` / `action_mask` / `loss_mask` / `rewards` / `logprobs` tensors are produced with broadcast comparisons / slice writes instead of per-token Python loops.

This covers the SFT (packed + unpacked) and RL training-batch construction paths; the RL and SFT data paths are separate functions, so each is vectorized independently. The RL change is inherited unchanged by all `RayPPOTrainer` subclasses (sync / async / full-context / agentic), since none override `convert_to_training_input`.

### Intentionally out of scope

- **MoE router-replay (`rollout_expert_indices`).** The optional `rollout_expert_indices` branch in `convert_prompts_responses_to_batch_tensors` is left exactly as-is — only the dense per-token batch tensors are vectorized. That branch is byte-identical to the prior implementation (zero correctness/regression risk); it is a narrow MoE-only path, so its residual per-sample loop is left for a follow-up rather than folded into this CPU-latency change. The new equivalence suite therefore does not exercise it (the oracle compares the six dense outputs; the 7th return value is intentionally discarded).
- **Eval path.** Packing only fires on the training-step batch (`batch_size == self.batch_size`); on the eval path `PackedDataCollator` delegates to the un-packed `DefaultCollator`, so eval collation is unchanged by this PR.

Notes on the bit-identical claim:
- dtypes are preserved exactly: `int64` / `torch.long` for `sequences` / masks (incl. `action_mask`), `float32` / `torch.float` for `loss_mask` / `rewards` / `logprobs`. `dtype=np.int64` is pinned explicitly (NumPy's platform-default int is `int32` on Windows).
- The RL reward path accepts Python lists and `float32` reward tensors (what the reward postprocessing produces today). A `requires_grad`, CUDA, or `bfloat16` reward tensor is not accepted; no reward producer in the repo emits those.
- The `PackedDataCollator` loss-mask write window keeps the original `row_p < max_packed_len - 1` clamp. That `min()` is a defensive no-op — `max_packed_len` is `>=` every bin's packed length by construction — so the clamp never bites today; it is retained (and now commented) to preserve the original behavior exactly.

## Benchmarks

Controller-side collate, single process, CPU, batch of 1024 (varying sequence lengths):

| Path | Before | After | Speedup |
|------|--------|-------|---------|
| `PackedDataCollator` (FFD, dp=8) | 288.5 ms | 8.4 ms | ~34x |
| `convert_prompts_responses_to_batch_tensors` (RL) | 92.8 ms | 14.5 ms | ~6.4x |
| `collate_sft_batch` (unpacked SFT) | 98.4 ms | 16.2 ms | ~6x |

## Test plan

- [x] New `tests/train/test_collation_vectorization_equivalence.py`: pins a faithful reference of each *original* loop and fuzzes the vectorized output against it with `torch.equal` plus explicit per-tensor `dtype` assertions — RL (with/without logprobs, list and `float32`-tensor rewards), unpacked SFT, and packed SFT across TP/PP/CP/DP configs. Because `torch.equal` is dtype-insensitive on matching values, the integer/float dtypes (`action_mask`/`attention_mask` `int64`/`long`, `loss_mask`/`rewards` `float32`) are pinned with explicit `.dtype` assertions. The packed test re-derives the FFD / DP-shard / `max_packed_len` decision inline as its own oracle, so any production drift surfaces as a `torch.equal` mismatch. (Mutation-checked: an injected off-by-one in any vectorized path fails the suite — note the unreachable `loss_mask` clamp is the one spot a localized off-by-one would not be caught, since it never fires under any in-practice input.)
- [x] Existing `tests/train/dataset/test_preprocess.py`, `tests/train/test_sft_packing_collate.py`, `tests/train/test_packing_round_trip.py`, `tests/train/test_sft_tokenization.py` pass unchanged.
- [x] `ruff` + `black` clean.

```bash
uv run --isolated --extra dev --extra megatron -- pytest \
  tests/train/test_collation_vectorization_equivalence.py \
  tests/train/dataset/test_preprocess.py \
  tests/train/test_sft_packing_collate.py \
  tests/train/test_packing_round_trip.py
```

> Heads-up for reviewers: this overlaps open PR NovaSky-AI#1752 ([train] VLM SFT on Megatron), which edits the same `collate_sft_batch` loop and `TrainingInputBatch` dict to collect `pixel_values` / `image_grid_thw`. Whichever lands second needs a small rebase; if NovaSky-AI#1752 lands first, its per-sample VLM tensor collection should be reinstated inside the vectorized `for i, ex in enumerate(examples):` loop and its two keys re-added to the `from_numpy` batch dict.

🤖 Generated with [Claude Code](https://claude.com/claude-code)
@s-chundi

Copy link
Copy Markdown
Author

/gemini review

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces support for Vision-Language Model (VLM) Supervised Fine-Tuning (SFT) using the Megatron backend, including dataset preparation scripts, training configurations, and integration with Hugging Face processors. The changes enforce VLM-specific parallelism constraints and optimize memory usage by dropping vision inputs on non-first pipeline stages. The review feedback focuses on improving pipeline parallelism consistency by restricting image_grid_thw processing and transfer to the first pipeline stage, and making the image mode conversion in the dataset preparation script more robust by converting all non-RGB images to RGB.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment thread skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py
Comment thread examples/train/sft/prepare_cauldron_vlm.py Outdated
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant