[megatron] Accept dtype-string optimizer_config_kwargs (coerce exp_avg_dtype etc. to torch.dtype)#1805
Conversation
…g_dtype etc. to torch.dtype)
## Summary
Megatron's precision-aware `OptimizerConfig` types its `*_dtype` fields (`exp_avg_dtype`, `exp_avg_sq_dtype`, `main_params_dtype`, `params_dtype`) as real `torch.dtype`, but `optimizer_config_kwargs` is forwarded verbatim from YAML/Hydra, which delivers plain strings (e.g. `"bf16"`). Such a string would reach TransformerEngine FusedAdam and crash. This adds a central `str -> torch.dtype` coercion at the single optimizer-construction choke point so low-precision optimizer state can be configured from YAML.
## What changed
- New torch-only module `skyrl/backends/skyrl_train/distributed/megatron/optimizer_dtype.py` holding `coerce_optimizer_dtype_kwargs(dict) -> dict` plus two mapping tables:
- `_DTYPE_NAME_TO_TORCH`: canonical dtype-name -> `torch.dtype`. The short forms `fp32`/`bf16`/`fp16`/`fp8` follow Megatron-LM's own `dtype_map`; common alias spellings (`bfloat16`, `float16`/`half`, `float32`/`float`, `float8`/`uint8`) are also accepted. `fp8 -> torch.uint8` is a TransformerEngine convention (TE represents FP8 optimizer state as uint8) and is added here, not sourced from Megatron's `dtype_map`.
- `_LEGAL_FIELD_DTYPES`: per-field legal sets for the fields actually forwarded to TE FusedAdam — `main_params_dtype` (master weights) restricted to `{fp32, fp16}`; `exp_avg_dtype`/`exp_avg_sq_dtype` allow `{fp32, fp16, bf16, fp8}`. Illegal values raise a clear `ValueError` before reaching FusedAdam.
- `optimizer.py`: `init_megatron_optim_config` now calls `coerce_optimizer_dtype_kwargs(optimizer_config_kwargs)` in place of the raw `.update(...)`. The coercion sits at the single shared construction point (sole caller serves SFT and the RL policy).
- Docs: documented the new string-dtype support for `optimizer_config_kwargs` in `docs/content/docs/configuration/config.mdx` (next to the existing `use_precision_aware_optimizer` callout, with the full name/alias table) and in `docs/content/docs/examples/megatron.mdx` (concise note cross-linking to the table) — accepted names/aliases, the per-field legal sets, and the `fp8 -> uint8` convention — noting these short forms differ from the full `bfloat16`/`float16`/`float32` spellings accepted by `str_to_torch_dtype` elsewhere.
The helper is deliberately kept free of any `megatron.core` import so it can be unit-tested on the CPU CI lane (torch only). Coercion lives at the optimizer-construction choke point rather than `MegatronConfig.__post_init__`, which would replace the YAML strings with `torch.dtype` objects in the dataclass and break the serializable config path (`asdict`/`yaml.dump`).
## Numerical equivalence / safety
Byte-identical to current behavior unless a `*_dtype` key is explicitly set in `optimizer_config_kwargs`. With no `*_dtype` overrides the coercion is a pure pass-through copy, so `OptimizerConfig` keeps its existing fp32 defaults for `exp_avg_dtype`/`exp_avg_sq_dtype`/`main_params_dtype` and the hardcoded `params_dtype=torch.bfloat16` seed is unchanged. The default optimizer kwargs contain no `*_dtype` keys, so the default path is unchanged. The only intentional behavior change: a `*_dtype` string that previously would have reached FusedAdam and crashed now becomes the correct `torch.dtype` (enabling low-precision optimizer state), and an illegal `main_params_dtype` now fails fast with a clear message instead of a cryptic TE error. Values already `torch.dtype` and non-dtype kwargs pass through untouched; non-string/non-dtype `*_dtype` values (e.g. `None`) pass through so Megatron's own validation surfaces them. `main_grads_dtype` is coerced str->dtype but intentionally has no legal-set row: at the pinned megatron-core rev it is not forwarded to TE FusedAdam, so there is no TE-backed legal set to enforce; it is left for `OptimizerConfig.__post_init__` to validate (mirroring how `params_dtype` is handled).
## Generality & follow-ups
- Covers all Megatron optimizer construction reachable via `init_megatron_optim_config` (SFT and RL policy). The critic worker does not construct an optimizer through this path.
- The FSDP optimizer-state path is intentionally out of scope — it is a separate code path with its own mixed-precision config; no change made there.
- A separate `str -> torch.dtype` helper already exists (`str_to_torch_dtype` / `PrecisionType.to_dtype`), but neither knows the `fp8 -> uint8` mapping nor does per-field legal-set validation, both of which the precision-aware optimizer-state feature requires; consolidating the canonical name table is a possible follow-up.
## Test plan
`tests/backends/skyrl_train/distributed/test_optimizer_dtype_coercion.py`:
- `TestCoerceOptimizerDtypeKwargs` (CPU, no skip-guard — runs on the CPU lane): parametrized name->dtype coercion for all aliases, `fp8->uint8`, case/whitespace insensitivity, `torch.dtype` pass-through, `main_params_dtype` accepts fp32/fp16 and rejects bf16/fp8, `params_dtype` coercion, `main_grads_dtype` coerced-but-not-field-validated, unrecognized-name `ValueError`, unrelated kwargs untouched, `None` pass-through, input not mutated.
- `TestInitMegatronOptimConfigDtypeCoercion` (megatron-gated via `_has_megatron` skip-guard, no GPU): end-to-end that string kwargs reach a real `OptimizerConfig` with coerced dtypes; `params_dtype` string override replaces the seeded default; default (no override) keeps fp32 defaults; and `use_precision_aware_optimizer=False` + non-fp32 state fast-fails with megatron's own `AssertionError`.
Run:
```bash
uv run --isolated --extra megatron --extra dev pytest \
tests/backends/skyrl_train/distributed/test_optimizer_dtype_coercion.py -v
```
The CPU class runs on the CPU lane (torch only, megatron-core not required); the megatron-gated class runs wherever megatron-core is installed. No GPU is required by either.
There was a problem hiding this comment.
Code Review
This pull request introduces a mechanism to coerce Megatron optimizer *_dtype string arguments (such as 'bf16' or 'fp8' from YAML configurations) into real torch.dtype objects before constructing Megatron's OptimizerConfig. It adds a new helper module optimizer_dtype.py with validation logic, updates the optimizer initialization, adds comprehensive unit tests, and updates the documentation. The reviewer suggested adding a defensive check for None in coerce_optimizer_dtype_kwargs to prevent potential AttributeError crashes if the configuration is omitted.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| ValueError: if a ``*_dtype`` value is an unrecognized dtype name, or if a coerced | ||
| dtype is illegal for that specific field (e.g. bf16/fp8 for ``main_params_dtype``). | ||
| """ | ||
| coerced: Dict[str, Any] = {} |
There was a problem hiding this comment.
If optimizer_config_kwargs is None (e.g., if it is omitted or set to null in the YAML configuration), calling .items() on it will raise an AttributeError. Adding a defensive None check at the beginning of the function ensures robustness and prevents runtime crashes.
if optimizer_config_kwargs is None:
return {}
coerced: Dict[str, Any] = {}
Summary
Megatron's precision-aware
OptimizerConfigtypes its*_dtypefields (exp_avg_dtype,exp_avg_sq_dtype,main_params_dtype,params_dtype) as realtorch.dtype, butoptimizer_config_kwargsis forwarded verbatim from YAML/Hydra, which delivers plain strings (e.g."bf16"). Such a string would reach TransformerEngine FusedAdam and crash. This adds a centralstr -> torch.dtypecoercion at the single optimizer-construction choke point so low-precision optimizer state can be configured from YAML.What changed
skyrl/backends/skyrl_train/distributed/megatron/optimizer_dtype.pyholdingcoerce_optimizer_dtype_kwargs(dict) -> dictplus two mapping tables:_DTYPE_NAME_TO_TORCH: canonical dtype-name ->torch.dtype. The short formsfp32/bf16/fp16/fp8follow Megatron-LM's owndtype_map; common alias spellings (bfloat16,float16/half,float32/float,float8/uint8) are also accepted.fp8 -> torch.uint8is a TransformerEngine convention (TE represents FP8 optimizer state as uint8) and is added here, not sourced from Megatron'sdtype_map._LEGAL_FIELD_DTYPES: per-field legal sets for the fields actually forwarded to TE FusedAdam —main_params_dtype(master weights) restricted to{fp32, fp16};exp_avg_dtype/exp_avg_sq_dtypeallow{fp32, fp16, bf16, fp8}. Illegal values raise a clearValueErrorbefore reaching FusedAdam.optimizer.py:init_megatron_optim_confignow callscoerce_optimizer_dtype_kwargs(optimizer_config_kwargs)in place of the raw.update(...). The coercion sits at the single shared construction point (sole caller serves SFT and the RL policy).optimizer_config_kwargsindocs/content/docs/configuration/config.mdx(next to the existinguse_precision_aware_optimizercallout, with the full name/alias table) and indocs/content/docs/examples/megatron.mdx(concise note cross-linking to the table) — accepted names/aliases, the per-field legal sets, and thefp8 -> uint8convention — noting these short forms differ from the fullbfloat16/float16/float32spellings accepted bystr_to_torch_dtypeelsewhere.The helper is deliberately kept free of any
megatron.coreimport so it can be unit-tested on the CPU CI lane (torch only). Coercion lives at the optimizer-construction choke point rather thanMegatronConfig.__post_init__, which would replace the YAML strings withtorch.dtypeobjects in the dataclass and break the serializable config path (asdict/yaml.dump).Numerical equivalence / safety
Byte-identical to current behavior unless a
*_dtypekey is explicitly set inoptimizer_config_kwargs. With no*_dtypeoverrides the coercion is a pure pass-through copy, soOptimizerConfigkeeps its existing fp32 defaults forexp_avg_dtype/exp_avg_sq_dtype/main_params_dtypeand the hardcodedparams_dtype=torch.bfloat16seed is unchanged. The default optimizer kwargs contain no*_dtypekeys, so the default path is unchanged. The only intentional behavior change: a*_dtypestring that previously would have reached FusedAdam and crashed now becomes the correcttorch.dtype(enabling low-precision optimizer state), and an illegalmain_params_dtypenow fails fast with a clear message instead of a cryptic TE error. Values alreadytorch.dtypeand non-dtype kwargs pass through untouched; non-string/non-dtype*_dtypevalues (e.g.None) pass through so Megatron's own validation surfaces them.main_grads_dtypeis coerced str->dtype but intentionally has no legal-set row: at the pinned megatron-core rev it is not forwarded to TE FusedAdam, so there is no TE-backed legal set to enforce; it is left forOptimizerConfig.__post_init__to validate (mirroring howparams_dtypeis handled).Generality & follow-ups
init_megatron_optim_config(SFT and RL policy). The critic worker does not construct an optimizer through this path.str -> torch.dtypehelper already exists (str_to_torch_dtype/PrecisionType.to_dtype), but neither knows thefp8 -> uint8mapping nor does per-field legal-set validation, both of which the precision-aware optimizer-state feature requires; consolidating the canonical name table is a possible follow-up.Test plan
tests/backends/skyrl_train/distributed/test_optimizer_dtype_coercion.py:TestCoerceOptimizerDtypeKwargs(CPU, no skip-guard — runs on the CPU lane): parametrized name->dtype coercion for all aliases,fp8->uint8, case/whitespace insensitivity,torch.dtypepass-through,main_params_dtypeaccepts fp32/fp16 and rejects bf16/fp8,params_dtypecoercion,main_grads_dtypecoerced-but-not-field-validated, unrecognized-nameValueError, unrelated kwargs untouched,Nonepass-through, input not mutated.TestInitMegatronOptimConfigDtypeCoercion(megatron-gated via_has_megatronskip-guard, no GPU): end-to-end that string kwargs reach a realOptimizerConfigwith coerced dtypes;params_dtypestring override replaces the seeded default; default (no override) keeps fp32 defaults; anduse_precision_aware_optimizer=False+ non-fp32 state fast-fails with megatron's ownAssertionError.Run:
uv run --isolated --extra megatron --extra dev pytest \ tests/backends/skyrl_train/distributed/test_optimizer_dtype_coercion.py -vThe CPU class runs on the CPU lane (torch only, megatron-core not required); the megatron-gated class runs wherever megatron-core is installed. No GPU is required by either.