Skip to content

[megatron] Accept dtype-string optimizer_config_kwargs (coerce exp_avg_dtype etc. to torch.dtype)#1805

Open
dyurk-lila wants to merge 1 commit into
NovaSky-AI:mainfrom
dyurk-lila:feat/optimizer-state-dtype-coercion
Open

[megatron] Accept dtype-string optimizer_config_kwargs (coerce exp_avg_dtype etc. to torch.dtype)#1805
dyurk-lila wants to merge 1 commit into
NovaSky-AI:mainfrom
dyurk-lila:feat/optimizer-state-dtype-coercion

Conversation

@dyurk-lila

Copy link
Copy Markdown

Summary

Megatron's precision-aware OptimizerConfig types its *_dtype fields (exp_avg_dtype, exp_avg_sq_dtype, main_params_dtype, params_dtype) as real torch.dtype, but optimizer_config_kwargs is forwarded verbatim from YAML/Hydra, which delivers plain strings (e.g. "bf16"). Such a string would reach TransformerEngine FusedAdam and crash. This adds a central str -> torch.dtype coercion at the single optimizer-construction choke point so low-precision optimizer state can be configured from YAML.

What changed

  • New torch-only module skyrl/backends/skyrl_train/distributed/megatron/optimizer_dtype.py holding coerce_optimizer_dtype_kwargs(dict) -> dict plus two mapping tables:
    • _DTYPE_NAME_TO_TORCH: canonical dtype-name -> torch.dtype. The short forms fp32/bf16/fp16/fp8 follow Megatron-LM's own dtype_map; common alias spellings (bfloat16, float16/half, float32/float, float8/uint8) are also accepted. fp8 -> torch.uint8 is a TransformerEngine convention (TE represents FP8 optimizer state as uint8) and is added here, not sourced from Megatron's dtype_map.
    • _LEGAL_FIELD_DTYPES: per-field legal sets for the fields actually forwarded to TE FusedAdam — main_params_dtype (master weights) restricted to {fp32, fp16}; exp_avg_dtype/exp_avg_sq_dtype allow {fp32, fp16, bf16, fp8}. Illegal values raise a clear ValueError before reaching FusedAdam.
  • optimizer.py: init_megatron_optim_config now calls coerce_optimizer_dtype_kwargs(optimizer_config_kwargs) in place of the raw .update(...). The coercion sits at the single shared construction point (sole caller serves SFT and the RL policy).
  • Docs: documented the new string-dtype support for optimizer_config_kwargs in docs/content/docs/configuration/config.mdx (next to the existing use_precision_aware_optimizer callout, with the full name/alias table) and in docs/content/docs/examples/megatron.mdx (concise note cross-linking to the table) — accepted names/aliases, the per-field legal sets, and the fp8 -> uint8 convention — noting these short forms differ from the full bfloat16/float16/float32 spellings accepted by str_to_torch_dtype elsewhere.

The helper is deliberately kept free of any megatron.core import so it can be unit-tested on the CPU CI lane (torch only). Coercion lives at the optimizer-construction choke point rather than MegatronConfig.__post_init__, which would replace the YAML strings with torch.dtype objects in the dataclass and break the serializable config path (asdict/yaml.dump).

Numerical equivalence / safety

Byte-identical to current behavior unless a *_dtype key is explicitly set in optimizer_config_kwargs. With no *_dtype overrides the coercion is a pure pass-through copy, so OptimizerConfig keeps its existing fp32 defaults for exp_avg_dtype/exp_avg_sq_dtype/main_params_dtype and the hardcoded params_dtype=torch.bfloat16 seed is unchanged. The default optimizer kwargs contain no *_dtype keys, so the default path is unchanged. The only intentional behavior change: a *_dtype string that previously would have reached FusedAdam and crashed now becomes the correct torch.dtype (enabling low-precision optimizer state), and an illegal main_params_dtype now fails fast with a clear message instead of a cryptic TE error. Values already torch.dtype and non-dtype kwargs pass through untouched; non-string/non-dtype *_dtype values (e.g. None) pass through so Megatron's own validation surfaces them. main_grads_dtype is coerced str->dtype but intentionally has no legal-set row: at the pinned megatron-core rev it is not forwarded to TE FusedAdam, so there is no TE-backed legal set to enforce; it is left for OptimizerConfig.__post_init__ to validate (mirroring how params_dtype is handled).

Generality & follow-ups

  • Covers all Megatron optimizer construction reachable via init_megatron_optim_config (SFT and RL policy). The critic worker does not construct an optimizer through this path.
  • The FSDP optimizer-state path is intentionally out of scope — it is a separate code path with its own mixed-precision config; no change made there.
  • A separate str -> torch.dtype helper already exists (str_to_torch_dtype / PrecisionType.to_dtype), but neither knows the fp8 -> uint8 mapping nor does per-field legal-set validation, both of which the precision-aware optimizer-state feature requires; consolidating the canonical name table is a possible follow-up.

Test plan

tests/backends/skyrl_train/distributed/test_optimizer_dtype_coercion.py:

  • TestCoerceOptimizerDtypeKwargs (CPU, no skip-guard — runs on the CPU lane): parametrized name->dtype coercion for all aliases, fp8->uint8, case/whitespace insensitivity, torch.dtype pass-through, main_params_dtype accepts fp32/fp16 and rejects bf16/fp8, params_dtype coercion, main_grads_dtype coerced-but-not-field-validated, unrecognized-name ValueError, unrelated kwargs untouched, None pass-through, input not mutated.
  • TestInitMegatronOptimConfigDtypeCoercion (megatron-gated via _has_megatron skip-guard, no GPU): end-to-end that string kwargs reach a real OptimizerConfig with coerced dtypes; params_dtype string override replaces the seeded default; default (no override) keeps fp32 defaults; and use_precision_aware_optimizer=False + non-fp32 state fast-fails with megatron's own AssertionError.

Run:

uv run --isolated --extra megatron --extra dev pytest \
    tests/backends/skyrl_train/distributed/test_optimizer_dtype_coercion.py -v

The CPU class runs on the CPU lane (torch only, megatron-core not required); the megatron-gated class runs wherever megatron-core is installed. No GPU is required by either.

…g_dtype etc. to torch.dtype)

## Summary

Megatron's precision-aware `OptimizerConfig` types its `*_dtype` fields (`exp_avg_dtype`, `exp_avg_sq_dtype`, `main_params_dtype`, `params_dtype`) as real `torch.dtype`, but `optimizer_config_kwargs` is forwarded verbatim from YAML/Hydra, which delivers plain strings (e.g. `"bf16"`). Such a string would reach TransformerEngine FusedAdam and crash. This adds a central `str -> torch.dtype` coercion at the single optimizer-construction choke point so low-precision optimizer state can be configured from YAML.

## What changed

- New torch-only module `skyrl/backends/skyrl_train/distributed/megatron/optimizer_dtype.py` holding `coerce_optimizer_dtype_kwargs(dict) -> dict` plus two mapping tables:
  - `_DTYPE_NAME_TO_TORCH`: canonical dtype-name -> `torch.dtype`. The short forms `fp32`/`bf16`/`fp16`/`fp8` follow Megatron-LM's own `dtype_map`; common alias spellings (`bfloat16`, `float16`/`half`, `float32`/`float`, `float8`/`uint8`) are also accepted. `fp8 -> torch.uint8` is a TransformerEngine convention (TE represents FP8 optimizer state as uint8) and is added here, not sourced from Megatron's `dtype_map`.
  - `_LEGAL_FIELD_DTYPES`: per-field legal sets for the fields actually forwarded to TE FusedAdam — `main_params_dtype` (master weights) restricted to `{fp32, fp16}`; `exp_avg_dtype`/`exp_avg_sq_dtype` allow `{fp32, fp16, bf16, fp8}`. Illegal values raise a clear `ValueError` before reaching FusedAdam.
- `optimizer.py`: `init_megatron_optim_config` now calls `coerce_optimizer_dtype_kwargs(optimizer_config_kwargs)` in place of the raw `.update(...)`. The coercion sits at the single shared construction point (sole caller serves SFT and the RL policy).
- Docs: documented the new string-dtype support for `optimizer_config_kwargs` in `docs/content/docs/configuration/config.mdx` (next to the existing `use_precision_aware_optimizer` callout, with the full name/alias table) and in `docs/content/docs/examples/megatron.mdx` (concise note cross-linking to the table) — accepted names/aliases, the per-field legal sets, and the `fp8 -> uint8` convention — noting these short forms differ from the full `bfloat16`/`float16`/`float32` spellings accepted by `str_to_torch_dtype` elsewhere.

The helper is deliberately kept free of any `megatron.core` import so it can be unit-tested on the CPU CI lane (torch only). Coercion lives at the optimizer-construction choke point rather than `MegatronConfig.__post_init__`, which would replace the YAML strings with `torch.dtype` objects in the dataclass and break the serializable config path (`asdict`/`yaml.dump`).

## Numerical equivalence / safety

Byte-identical to current behavior unless a `*_dtype` key is explicitly set in `optimizer_config_kwargs`. With no `*_dtype` overrides the coercion is a pure pass-through copy, so `OptimizerConfig` keeps its existing fp32 defaults for `exp_avg_dtype`/`exp_avg_sq_dtype`/`main_params_dtype` and the hardcoded `params_dtype=torch.bfloat16` seed is unchanged. The default optimizer kwargs contain no `*_dtype` keys, so the default path is unchanged. The only intentional behavior change: a `*_dtype` string that previously would have reached FusedAdam and crashed now becomes the correct `torch.dtype` (enabling low-precision optimizer state), and an illegal `main_params_dtype` now fails fast with a clear message instead of a cryptic TE error. Values already `torch.dtype` and non-dtype kwargs pass through untouched; non-string/non-dtype `*_dtype` values (e.g. `None`) pass through so Megatron's own validation surfaces them. `main_grads_dtype` is coerced str->dtype but intentionally has no legal-set row: at the pinned megatron-core rev it is not forwarded to TE FusedAdam, so there is no TE-backed legal set to enforce; it is left for `OptimizerConfig.__post_init__` to validate (mirroring how `params_dtype` is handled).

## Generality & follow-ups

- Covers all Megatron optimizer construction reachable via `init_megatron_optim_config` (SFT and RL policy). The critic worker does not construct an optimizer through this path.
- The FSDP optimizer-state path is intentionally out of scope — it is a separate code path with its own mixed-precision config; no change made there.
- A separate `str -> torch.dtype` helper already exists (`str_to_torch_dtype` / `PrecisionType.to_dtype`), but neither knows the `fp8 -> uint8` mapping nor does per-field legal-set validation, both of which the precision-aware optimizer-state feature requires; consolidating the canonical name table is a possible follow-up.

## Test plan

`tests/backends/skyrl_train/distributed/test_optimizer_dtype_coercion.py`:
- `TestCoerceOptimizerDtypeKwargs` (CPU, no skip-guard — runs on the CPU lane): parametrized name->dtype coercion for all aliases, `fp8->uint8`, case/whitespace insensitivity, `torch.dtype` pass-through, `main_params_dtype` accepts fp32/fp16 and rejects bf16/fp8, `params_dtype` coercion, `main_grads_dtype` coerced-but-not-field-validated, unrecognized-name `ValueError`, unrelated kwargs untouched, `None` pass-through, input not mutated.
- `TestInitMegatronOptimConfigDtypeCoercion` (megatron-gated via `_has_megatron` skip-guard, no GPU): end-to-end that string kwargs reach a real `OptimizerConfig` with coerced dtypes; `params_dtype` string override replaces the seeded default; default (no override) keeps fp32 defaults; and `use_precision_aware_optimizer=False` + non-fp32 state fast-fails with megatron's own `AssertionError`.

Run:

```bash
uv run --isolated --extra megatron --extra dev pytest \
    tests/backends/skyrl_train/distributed/test_optimizer_dtype_coercion.py -v
```

The CPU class runs on the CPU lane (torch only, megatron-core not required); the megatron-gated class runs wherever megatron-core is installed. No GPU is required by either.
@dyurk-lila dyurk-lila marked this pull request as ready for review June 18, 2026 16:34

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a mechanism to coerce Megatron optimizer *_dtype string arguments (such as 'bf16' or 'fp8' from YAML configurations) into real torch.dtype objects before constructing Megatron's OptimizerConfig. It adds a new helper module optimizer_dtype.py with validation logic, updates the optimizer initialization, adds comprehensive unit tests, and updates the documentation. The reviewer suggested adding a defensive check for None in coerce_optimizer_dtype_kwargs to prevent potential AttributeError crashes if the configuration is omitted.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

ValueError: if a ``*_dtype`` value is an unrecognized dtype name, or if a coerced
dtype is illegal for that specific field (e.g. bf16/fp8 for ``main_params_dtype``).
"""
coerced: Dict[str, Any] = {}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

If optimizer_config_kwargs is None (e.g., if it is omitted or set to null in the YAML configuration), calling .items() on it will raise an AttributeError. Adding a defensive None check at the beginning of the function ensures robustness and prevents runtime crashes.

    if optimizer_config_kwargs is None:
        return {}
    coerced: Dict[str, Any] = {}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant