[JAX] Add an MoE Block (Layer) that compounds router, permutation, groupedGEMM and communication #2912
tdophung wants to merge 11 commits into
Conversation
Greptile Summary
This PR introduces a self-contained MoEBlock Flax-Linen module that combines TE's fused router, token permutation, grouped GEMM expert FFN, and ragged all-to-all expert-parallel communication.
Confidence Score: 3/5
The no-EP single-device path is safe; the A2A-EP path has a latent receive-buffer overflow when alignment padding and expert parallelism are active simultaneously. When align_size > 0 and EP are both enabled, recv_buffer_rows is sized for unpadded tokens but the A2A transmits aligned group sizes. In the worst case, actual received tokens exceed the buffer by num_experts*(align_size-1) rows, causing incorrect results or a crash. This combination is not currently tested, so the bug is latent but real.
transformer_engine/jax/flax/moe.py -- recv_buffer_rows calculation in _forward_a2a_ep.
transformer_engine/jax/permutation.py -- NaN-sentinel zeroing in _token_combine_bwd_rule.
Important Files Changed
Sequence Diagram
sequenceDiagram
participant C as __call__
participant G as _gate
participant R as _route_topk
participant GP as _global_permute
participant SM as shard_map (_a2a_body)
participant A2A as ragged_all_to_all
participant LP as local_permute_after_a2a
participant FFN as _expert_ffn
participant LU as local_unpermute_before_a2a
participant GC as _global_combine
C->>G: inputs to gate_logits
C->>R: gate_logits to sparse_probs and routing_map
C->>GP: inputs_2d, sparse_probs, routing_map to GlobalPermuteResult
alt expert_parallelism_axis is None
C->>FFN: sorted_inputs, group_sizes to expert_outputs
C->>GC: expert_outputs, perm to output
else A2A EP path via shard_map
SM->>A2A: fwd ragged_all_to_all sends per-expert chunks to owner shard
SM->>LP: reorder source_shard,expert to expert,source_shard
SM->>FFN: local grouped GEMM x3 plus activation
SM->>LU: reorder back to source-shard-major
SM->>A2A: rev ragged_all_to_all returns outputs to source shards
SM->>GC: global_combine weighted sum
end
C-->>C: return output and aux_loss
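For intuition, the global permute and combine steps above boil down, in the pure-JAX backend, to sorting tokens by their assigned expert and inverting that sort afterwards. Below is a minimal sketch of such an argsort-based dispatch (top-1 routing for brevity; with top-k each token row is replicated k times, and all names here are illustrative rather than the module's actual helpers):

```python
import jax.numpy as jnp

def argsort_dispatch(tokens, expert_idx, num_experts):
    # tokens: [T, H] activations; expert_idx: [T] expert chosen for each token (top-1 case).
    order = jnp.argsort(expert_idx)                           # group tokens by expert
    sorted_tokens = tokens[order]                             # expert-major layout for grouped GEMM
    group_sizes = jnp.bincount(expert_idx, length=num_experts)
    return sorted_tokens, group_sizes, order

def argsort_combine(expert_outputs, order):
    # Invert the permutation to put rows back in original token order.
    return expert_outputs[jnp.argsort(order)]
```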
Reviews (5): Last reviewed commit: "revert C++ changes and will put in a new..."
…ody single GPU vs. multi GPU
…e and single device initial params in the MoEBlock. Tests should pass now
def _compute_aux_loss(
    self,
    logits_2d: jnp.ndarray,
) -> Optional[jnp.ndarray]:
    """Compute the MoE auxiliary load-balancing loss.

    The score-for-aux kernel has no data dependency on the main
    routing kernel, so XLA can overlap them on the GPU.

    ``logits_2d`` should be the *full* logits tensor over the global
    token batch -- under EP the caller is responsible for
    :func:`jax.lax.all_gather` ing the logits before calling this so
    the aux_loss formula
    ``loss = (E * coeff / (k * T^2)) * sum_i(sum_t(probs[t,i]) * tokens[i])``
    sees the global ``T`` and the global ``tokens_per_expert``.
    """
    if self.aux_loss_coeff <= 0.0:
        return None
    aux_scores, aux_routing_map = fused_topk_with_score_function(
        logits_2d.astype(jnp.float32),
        topk=self.num_experts_per_tok,
        score_function=self.score_function,
        compute_aux_scores=True,
    )
    aux_tokens_per_expert = jnp.sum(aux_routing_map.astype(jnp.int32), axis=0)
    return fused_moe_aux_loss(
        aux_scores.astype(jnp.float32),
        aux_tokens_per_expert,
        topk=self.num_experts_per_tok,
        coeff=self.aux_loss_coeff,
    )
Aux loss tokens_per_expert is inconsistent with actual grouped-topk routing
When num_groups > 0 and group_topk > 0 (DeepSeek-style routing), fused_topk_with_score_function(..., compute_aux_scores=True) intentionally ignores those parameters and runs a clean standard top-k. The returned aux_routing_map therefore reflects different expert selections than the actual routing_map produced by _route_topk, causing aux_tokens_per_expert = sum(aux_routing_map, axis=0) to count a different token–expert distribution. Any user who combines num_groups > 0 + group_topk > 0 + aux_loss_coeff > 0 silently trains with a wrong auxiliary objective. The existing test_group_topk_deepseek test does not catch this because it leaves aux_loss_coeff at its default of 0.0.
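For reference, the aux-loss formula quoted in the docstring can be written directly against whatever routing map is passed in, which makes the mismatch concrete: feeding the actual grouped-topk routing_map gives a different tokens_per_expert than the aux-only standard top-k. A minimal pure-JAX sketch (function name and arguments are illustrative):

```python
import jax.numpy as jnp

def reference_aux_loss(probs, routing_map, topk, coeff):
    # probs:       [T, E] routing scores per token and expert
    # routing_map: [T, E] boolean mask of the experts actually selected per token
    T, E = probs.shape
    tokens_per_expert = jnp.sum(routing_map.astype(jnp.int32), axis=0)  # [E]
    # loss = (E * coeff / (k * T^2)) * sum_i(sum_t(probs[t, i]) * tokens[i])
    return (E * coeff / (topk * T * T)) * jnp.sum(jnp.sum(probs, axis=0) * tokens_per_expert)
```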
for name in ("gate_kernel", "wi_0", "wi_1", "wo"):
    g_pj = _unwrap_partitioned(grads_pj["params"][name])
    g_tr = _unwrap_partitioned(grads_tr["params"][name])
    assert jnp.allclose(g_pj, g_tr, atol=1e-1, rtol=1e-1), (
1e-1 is a pretty high tolerance for most of our tests. What atol and rtol values do you typically get from these tests, and is that error difference expected between the JAX/Triton backends?
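One way to answer that is to record the observed errors rather than only asserting against a fixed tolerance; a small sketch, assuming the gradients are already unwrapped arrays (helper name is illustrative):

```python
import jax.numpy as jnp

def max_errors(a, b, eps=1e-12):
    # Maximum absolute and relative error between two gradient tensors,
    # useful for deciding how far atol/rtol can be tightened from 1e-1.
    abs_err = jnp.max(jnp.abs(a - b))
    rel_err = jnp.max(jnp.abs(a - b) / (jnp.abs(b) + eps))
    return abs_err, rel_err
```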
…int in C++ files, make FP8 works. Tested with current scaling
namespace transformer_engine::detail {

namespace {
NOTE: I added this because I ran into cudaErrorInvalidResourceHandle (at cast.cu:112 in nvte_multi_tensor_quantize) when trying to launch with 1 process in an independent script (that imports the TE MoE block) to test MoEBlock with data type FP8. This happened because the global cudaStream/Event pool was created lazily via std::call_once, which leaves the resources bound to whichever device arrives first.
I fixed this by caching per cudaGetDevice() in an unordered_map. Let me know if there is any reason why we should not do this. @jberchtold-nvidia
@tdophung Let's move this into a separate PR. This is an issue in the V1 codepath but not V2. It's a good fix and the code seems reasonable to me but it could impact PyTorch so I'd like to decouple from this PR to unblock MoEBlock+V2 and then we can have PyTorch and core review and test this fix separately.
} else {
    NVTE_CHECK(k == sum_group_sizes, "Unexpected group_sizes! K = ", k,
               ", got sum(group_sizes)=", sum_group_sizes);
    NVTE_CHECK(sum_group_sizes <= k, "Unexpected group_sizes! sum(group_sizes)=", sum_group_sizes,
Just pointing at this so reviewers also pay attention to this change (which differs from the initial version prior to the other comments). When I tested with FP8, which is when the padding to align_size took effect, I started to see these checks firing, so I relaxed them because I think we should allow for garbage data on dim m when there is worst-case padding.
This is for the V1 grouped GEMM FFI, not the one we are using now that binds to cuBLASLt. The V2 does support k >= group sizes and I have tested the V2 with it thoroughly. In theory, I think relaxing this constraint for V1 shouldn't cause issues, but I have not tested it so I am not sure.
In your message above, is FP8 = tensor-scaled FP8? If so, the reason that triggers this assertion is we don't support tensor-scaled FP8 for the V2 grouped quant + GEMM so it falls through to the old V1 codepath.
Yes, I think I tested with current scaling; that's why it hit this V1 implementation. I have changed to MXFP8_1D_SCALING in my test script now. Good catch.
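For context, a sketch of how the test script might select MXFP8 block scaling instead of current scaling so the V2 grouped quant + GEMM path is exercised (recipe/class names should be checked against the installed TE version; block, params, and x are assumed to be set up elsewhere):

```python
import transformer_engine.jax as te
from transformer_engine.common import recipe

# Assumed: MXFP8 block scaling routes through the V2 grouped quant + GEMM,
# while tensor-scaled (current-scaling) FP8 falls back to the V1 codepath.
fp8_recipe = recipe.MXFP8BlockScaling()
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out, aux_loss = block.apply(params, x)
```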
@tdophung can we revert this assertion relaxing, since the test above was modified to only run if the V2 grouped GEMM is available? Unless you've already tried this and it works okay, in which case we can update the test above to run on both V1/V2 grouped GEMMs. But if it doesn't work or you haven't tried yet, let's revert this change so we can focus on V2.
assert jnp.abs(aux_loss) < 1e2


def test_aux_loss_uses_real_routing_under_group_topk(self):
    """Regression test for PR #2912 review (greptile P1).
SR: Remove this "Regression test for PR #2912 review (greptile P1).", I don't think we need those specifics. The other comments are useful though.
… grad tol to 5e-2, move arch/align_size docs into MoEBlock class
batch_divisor = num_ep * dp_size
if global_batch_size % batch_divisor != 0:
    raise ValueError(
        f"batch={global_batch_size} not divisible by prod(data_parallelism_axes)={dp_size}"
    )
recv_buffer_rows = (global_batch_size // dp_size) * sequence_length * topk
Receive buffer undersized when align_size > 0 + EP are combined
recv_buffer_rows is computed assuming unpadded token counts, but when align_size > 0 the per-expert group_sizes are the aligned counts, so the send_sizes in compute_ragged_all_to_all_params include padding tokens. The worst-case receive per shard is num_ep * ((B/(num_ep*dp_size))*S*K + num_experts_per_shard*(align_size-1)), which exceeds the current recv_buffer_rows = (B/dp_size)*S*K by up to num_experts*(align_size-1) rows. ragged_all_to_all writing beyond the buffer produces incorrect results or a crash. The correct worst-case size is:
recv_buffer_rows = (global_batch_size // dp_size) * sequence_length * topk + num_experts * (self.align_size - 1 if self.align_size > 0 else 0)
This combination (EP + align_size > 0) is not exercised by the current distributed test (which defaults to align_size=0), so the bug is latent.
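A sketch of that worst-case accounting, using the variable names from the snippet above (the padding term only applies when align_size > 0):

```python
# Each of the num_ep peers sends at most its local token rows plus up to
# (align_size - 1) padding rows per local expert group, so per shard:
#   num_ep * ((B // (num_ep * dp_size)) * S * K + experts_per_shard * (align_size - 1))
# = (B // dp_size) * S * K + num_experts * (align_size - 1)
token_rows = (global_batch_size // dp_size) * sequence_length * topk
pad_rows = num_experts * (align_size - 1) if align_size > 0 else 0
recv_buffer_rows = token_rows + pad_rows
```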
Description
Most of the MoE building-block integration work has been deeply coupled with Maxtext development. This MoE block isolates the work from Maxtext and provides more room for experimentation.
MoEBlock is a self-contained Flax-Linen module that wires together TE's fused router, pluggable token-dispatch backends (pure-JAX argsort or Triton sort_chunks_by_index), a grouped_dense-based expert FFN, and ragged-all-to-all (A2Av) expert parallelism via shard_map. This first iteration starts with ring-of-experts EP, sharding on the batch dimension for FSDP, cuBLASLt grouped GEMM, and two permutation backends: pure JAX or Triton kernels.
The block also exposes pluggable knobs for: weight layout (wi_kernel_axes/wo_kernel_axes), permutation backend, A2A vs. no-EP (single GPU) path, data-parallelism axes for true FSDP (batch sharded across (ep, fsdp) simultaneously), top-K with optional grouped/sigmoid scoring (for DSv3 workloads), and an optional auxiliary load-balancing loss.
Fixes #2895
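For orientation, a sketch of the canonical Flax-Linen sharded-init pattern that the distributed test (listed under Changes below) follows; the module factory, input shape, and logical-to-mesh rules are placeholders rather than the real MoEBlock constructor:

```python
import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn
from jax.sharding import Mesh

# Hypothetical setup: a 2x2 (ep, fsdp) mesh and a placeholder factory for the block.
mesh = Mesh(np.array(jax.devices()).reshape(2, 2), axis_names=("ep", "fsdp"))
block = make_moe_block()                      # hypothetical factory returning an MoEBlock
x = jnp.zeros((8, 128, 1024), dtype=jnp.bfloat16)
rng = jax.random.PRNGKey(0)

# eval_shape -> get_partition_spec -> logical_to_mesh_sharding -> jit(init, out_shardings=...)
abstract_params = jax.eval_shape(block.init, rng, x)
logical_specs = nn.get_partition_spec(abstract_params)
shardings = nn.logical_to_mesh_sharding(logical_specs, mesh, rules=(("expert", "ep"),))

with mesh:
    params = jax.jit(block.init, out_shardings=shardings)(rng, x)
```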
Type of change
Changes
transformer_engine/jax/flax/moe.py -- MoEBlock Linen module: gate -> fused topk -> global permute -> A2A EP shard_map (ragged_a2a fwd, local permute, 3x grouped GEMM SwiGLU FFN, local unpermute, ragged_a2a rev) -> global combine.
transformer_engine/jax/permutation.py -- A2A param helpers (compute_ragged_all_to_all_params, compute_reverse_ragged_all_to_all_params, local_permute_after_a2a, local_unpermute_before_a2a) and the pure-JAX unfused_token_dispatch/unfused_token_combine paths with custom VJPs.
tests/jax/test_moe_block.py -- single-device shape, backward, cross-backend equivalence, aux-loss, group-topk, JIT determinism.
tests/jax/test_distributed_moe_block.py -- EP=2 x FSDP=2 mesh test using the canonical Flax-Linen sharded-init pattern (eval_shape -> get_partition_spec -> logical_to_mesh_sharding -> jit(init, out_shardings=...)) and data_parallelism_axes=("fsdp",) to exercise true FSDP (batch sharded across both axes).
Checklist: