Skip to content

[JAX] Add an MoE Block (Layer) that compound router, permutation, groupedGEMM and communication#2912

Open
tdophung wants to merge 11 commits into
NVIDIA:mainfrom
tdophung:teddy/moe_block
Open

[JAX] Add an MoE Block (Layer) that compound router, permutation, groupedGEMM and communication#2912
tdophung wants to merge 11 commits into
NVIDIA:mainfrom
tdophung:teddy/moe_block

Conversation

@tdophung
Copy link
Copy Markdown
Collaborator

@tdophung tdophung commented Apr 21, 2026

Description

Most of MoE building blocks integration work has been deeply coupled with Maxtext development. Now creating this MoE block to isolate the work from Maxtext and provide more room for experimentation. MoEBlock is a self-contained Flax-Linen module that wires together TE's fused router, pluggable token-dispatch backends (pure-JAX argsort or Triton sort_chunks_by_index), grouped_dense-based expert FFN, and ragged-all-to-all (A2Av) expert parallelism via shard_map

This first iteration will start with ring-of-experts EP, sharding on batch dimention for FSDP, CUBLASLt groupedGEMM and 2 permutation backend: pure JAX or Triton kernels. The block also exposes pluggable knobs for: weight layout (wi_kernel_axes/ wo_kernel_axes), permutation backend, A2A vs no-EP (single GPU) path, data-parallelism axes for true FSDP (batch sharded across (ep, fsdp) simultaneously), top-K with optional grouped/sigmoid scoring (for DSv3 workload), and optional auxiliary load-balancing loss.

Fixes #2895

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • New transformer_engine/jax/flax/moe.py -- MoEBlock Linen module:
    gate -> fused topk -> global permute -> A2A EP shard_map (ragged_a2a fwd, local permute, 3x grouped GEMM SwiGLU FFN, local unpermute, ragged_a2a rev) -> global combine.
  • Extended transformer_engine/jax/permutation.py with A2A param helpers (compute_ragged_all_to_all_params, compute_reverse_ragged_all_to_all_params, local_permute_after_a2a, local_unpermute_before_a2a) and the pure-JAX unfused_token_dispatch / unfused_token_combine paths
    with custom VJPs.
  • tests/jax/test_moe_block.py -- single-device shape, backward,
    cross-backend equivalence, aux-loss, group-topk, JIT determinism.
  • tests/jax/test_distributed_moe_block.py -- EP=2 x FSDP=2 mesh test using the canonical Flax-Linen sharded-init pattern (eval_shape -> get_partition_spec -> logical_to_mesh_sharding -> jit(init, out_shardings=...)) and data_parallelism_axes=("fsdp",) to exercise true FSDP (batch sharded across both axes).

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@tdophung tdophung marked this pull request as ready for review May 5, 2026 21:47
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented May 5, 2026

Greptile Summary

This PR introduces MoEBlock, a self-contained Flax-Linen module that wires together TE's fused router, pluggable permutation backends (pure-JAX argsort or Triton), grouped_dense-based expert FFN, and ragged-all-to-all expert parallelism via shard_map. It also extends permutation.py with pure-JAX token dispatch/combine (custom VJPs), Triton-backed wrappers, and ragged-A2A EP parameter helpers.

  • MoEBlock (flax/moe.py): Full MoE forward/backward supporting no-EP and A2A-EP paths, DeepSeek-style grouped top-k, optional aux load-balancing loss, FSDP + TP axes, and two pluggable permutation backends. The EP path has a latent receive-buffer overflow when align_size > 0 is combined with EP.
  • permutation.py: Adds pure-JAX and Triton token dispatch/combine with correct custom VJPs, plus ragged-A2A EP helpers. The backward pass uses jnp.isnan as a sentinel to zero gradient slots left uninitialized by the Triton padding kernel.
  • gemm.py: Removes @cache from _should_enforce_v2_grouped_gemm so monkeypatch.setenv in tests takes effect immediately; negligible runtime cost.

Confidence Score: 3/5

The no-EP single-device path is safe; the A2A-EP path has a latent receive-buffer overflow when alignment padding and expert parallelism are active simultaneously.

When align_size > 0 and EP are both enabled, recv_buffer_rows is sized for unpadded tokens but the A2A transmits aligned group sizes. In the worst case, actual received tokens exceed the buffer by num_experts*(align_size-1) rows, causing incorrect results or a crash. This combination is not currently tested, so the bug is latent but real.

transformer_engine/jax/flax/moe.py — recv_buffer_rows calculation in _forward_a2a_ep. transformer_engine/jax/permutation.py — NaN-sentinel zeroing in _token_combine_bwd_rule.

Important Files Changed

Filename Overview
transformer_engine/jax/flax/moe.py New MoEBlock Flax-Linen module; EP path has a latent buffer-overflow bug when align_size > 0 and EP are combined.
transformer_engine/jax/permutation.py Adds pure-JAX and Triton token dispatch/combine with custom VJPs and ragged-A2A EP helpers; NaN-sentinel zeroing in backward could mask legitimate NaN gradients.
transformer_engine/jax/cpp_extensions/gemm.py Removes @cache from _should_enforce_v2_grouped_gemm so tests can flip the env var via monkeypatch.setenv.
transformer_engine/jax/flax/init.py Exports MoEBlock from the flax subpackage.
tests/jax/test_moe_block.py Single-device tests covering forward shape, backward finiteness, cross-backend equivalence, aux-loss, group-topk, align_size equivalence, and JIT determinism.
tests/jax/test_distributed_moe_block.py EP=2 x FSDP=2 distributed test; no align_size > 0 EP test, leaving the receive-buffer overflow path untested.

Sequence Diagram

sequenceDiagram
    participant C as __call__
    participant G as _gate
    participant R as _route_topk
    participant GP as _global_permute
    participant SM as shard_map (_a2a_body)
    participant A2A as ragged_all_to_all
    participant LP as local_permute_after_a2a
    participant FFN as _expert_ffn
    participant LU as local_unpermute_before_a2a
    participant GC as _global_combine

    C->>G: inputs to gate_logits
    C->>R: gate_logits to sparse_probs and routing_map
    C->>GP: inputs_2d, sparse_probs, routing_map to GlobalPermuteResult
    alt expert_parallelism_axis is None
        C->>FFN: sorted_inputs, group_sizes to expert_outputs
        C->>GC: expert_outputs, perm to output
    else A2A EP path via shard_map
        SM->>A2A: fwd ragged_all_to_all sends per-expert chunks to owner shard
        SM->>LP: reorder source_shard,expert to expert,source_shard
        SM->>FFN: local grouped GEMM x3 plus activation
        SM->>LU: reorder back to source-shard-major
        SM->>A2A: rev ragged_all_to_all returns outputs to source shards
        SM->>GC: global_combine weighted sum
    end
    C-->>C: return output and aux_loss
Loading

Reviews (5): Last reviewed commit: "revert C++ changes and will put in a new..." | Re-trigger Greptile

Comment thread transformer_engine/pytorch/permutation.py
Comment thread transformer_engine/jax/flax/moe.py
tdophung added 6 commits May 5, 2026 16:35
Signed-off-by: tdophung <tdophung@nvidia.com>
…ody single GPU vs. multi GPU

Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
…e and single device initial params in the MoEBlock. Tests should pass now

Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
@tdophung tdophung force-pushed the teddy/moe_block branch from 8a838f3 to 6aeb491 Compare May 5, 2026 23:44
pre-commit-ci Bot and others added 2 commits May 5, 2026 23:45
Signed-off-by: tdophung <tdophung@nvidia.com>
Comment on lines +427 to +457
def _compute_aux_loss(
self,
logits_2d: jnp.ndarray,
) -> Optional[jnp.ndarray]:
"""Compute the MoE auxiliary load-balancing loss.

The score-for-aux kernel has no data dependency on the main
routing kernel, so XLA can overlap them on the GPU.

``logits_2d`` should be the *full* logits tensor over the global
token batch -- under EP the caller is responsible for
:func:`jax.lax.all_gather` ing the logits before calling this so
the aux_loss formula
``loss = (E * coeff / (k * T^2)) * sum_i(sum_t(probs[t,i]) * tokens[i])``
sees the global ``T`` and the global ``tokens_per_expert``.
"""
if self.aux_loss_coeff <= 0.0:
return None
aux_scores, aux_routing_map = fused_topk_with_score_function(
logits_2d.astype(jnp.float32),
topk=self.num_experts_per_tok,
score_function=self.score_function,
compute_aux_scores=True,
)
aux_tokens_per_expert = jnp.sum(aux_routing_map.astype(jnp.int32), axis=0)
return fused_moe_aux_loss(
aux_scores.astype(jnp.float32),
aux_tokens_per_expert,
topk=self.num_experts_per_tok,
coeff=self.aux_loss_coeff,
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Aux loss tokens_per_expert is inconsistent with actual grouped-topk routing

When num_groups > 0 and group_topk > 0 (DeepSeek-style routing), fused_topk_with_score_function(..., compute_aux_scores=True) intentionally ignores those parameters and runs a clean standard top-k. The returned aux_routing_map therefore reflects different expert selections than the actual routing_map produced by _route_topk, causing aux_tokens_per_expert = sum(aux_routing_map, axis=0) to count a different token–expert distribution. Any user who combines num_groups > 0 + group_topk > 0 + aux_loss_coeff > 0 silently trains with a wrong auxiliary objective. The existing test_group_topk_deepseek test does not catch this because it leaves aux_loss_coeff at its default of 0.0.

Comment thread tests/jax/test_distributed_moe_block.py Outdated
Comment thread tests/jax/test_moe_block.py Outdated
Comment thread tests/jax/test_moe_block.py Outdated
for name in ("gate_kernel", "wi_0", "wi_1", "wo"):
g_pj = _unwrap_partitioned(grads_pj["params"][name])
g_tr = _unwrap_partitioned(grads_tr["params"][name])
assert jnp.allclose(g_pj, g_tr, atol=1e-1, rtol=1e-1), (
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1e-1 is a pretty high tolerance for most of our tests. What error values of atol and rtol do you typically get from these tests and is that error difference expected between jax/triton backends?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tdophung following up on this tolerance

Comment thread tests/jax/test_moe_block.py Outdated
Comment thread transformer_engine/jax/permutation.py Outdated
Comment thread transformer_engine/jax/flax/moe.py Outdated
Comment thread transformer_engine/jax/flax/moe.py Outdated
Comment thread transformer_engine/jax/flax/moe.py Outdated
Comment thread transformer_engine/jax/flax/moe.py
Comment thread transformer_engine/jax/flax/moe.py Outdated
nvjax and others added 2 commits May 7, 2026 15:18
…int in C++ files, make FP8 works. Tested with current scaling

Signed-off-by: JAX Toolbox <jax@nvidia.com>
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented May 7, 2026

Want your agent to iterate on Greptile's feedback? Try greploops.


namespace transformer_engine::detail {

namespace {
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NOTE: I added this as I run into cudaErrorInvalidResourceHandle (at cast.cu:112 in nvte_multi_tensor_quantize) when trying to launch with 1 process in an independent script (that imports TE moe block) to test MoEBlock with data type FP8. This was because the global cudaStream or Event pool was created lazily via std::call_once, which leaves the resources bound to whichever device arrive first.

I fixed this with caching per cudaGetDevice() in an unordered map. Let me know if there is any reason why we should not do this. @jberchtold-nvidia

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tdophung Let's move this into a separate PR. This is an issue in the V1 codepath but not V2. It's a good fix and the code seems reasonable to me but it could impact PyTorch so I'd like to decouple from this PR to unblock MoEBlock+V2 and then we can have PyTorch and core review and test this fix separately.

} else {
NVTE_CHECK(k == sum_group_sizes, "Unexpected group_sizes! K = ", k,
", got sum(group_sizes)=", sum_group_sizes);
NVTE_CHECK(sum_group_sizes <= k, "Unexpected group_sizes! sum(group_sizes)=", sum_group_sizes,
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just pointing at this so reviewers also pay attention to this change (that is different from the initial version previous to the other comments). When I tested with FP8, which is when the padding to align_size took effect, I start to see these checks firing, to which I then relaxed the checks because I think it should allow for garbage data on dim m to exist when there is worst case padding.

Copy link
Copy Markdown
Collaborator

@jberchtold-nvidia jberchtold-nvidia May 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is for the V1 grouped GEMM FFI, not the one we are using now that binds to cuBLASLt. The V2 does support k >= group sizes and I have tested the V2 with it thoroughly. In theory, I think relaxing this constraint for V1 shouldn't cause issues, but I have not tested it so I am not sure.

In your message above, is FP8 = tensor-scaled FP8? If so, the reason that triggers this assertion is we don't support tensor-scaled FP8 for the V2 grouped quant + GEMM so it falls through to the old V1 codepath.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes I think I tested with current scaling that's why it hit this V1 implementation. I have changedto MXFP8_1D_SCALING in my test script now. Good catch

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tdophung can we revert this assertion relaxing since the test above was modified to only run if the V2 grouped GEMM was available? Unless you've already tried this and it works okay, then we can update the test above to run on both V1/V2 grouped GEMMs. But if it doesn't' work or you haven't tried yet, let's revert this change so we can focus on V2.

Comment thread tests/jax/test_moe_block.py Outdated
assert jnp.abs(aux_loss) < 1e2

def test_aux_loss_uses_real_routing_under_group_topk(self):
"""Regression test for PR #2912 review (greptile P1).
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SR: Remove this "Regression test for PR #2912 review (greptile P1).", I don't think we need those specifics. The other comments are useful though.

} else {
NVTE_CHECK(k == sum_group_sizes, "Unexpected group_sizes! K = ", k,
", got sum(group_sizes)=", sum_group_sizes);
NVTE_CHECK(sum_group_sizes <= k, "Unexpected group_sizes! sum(group_sizes)=", sum_group_sizes,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tdophung can we revert this assertion relaxing since the test above was modified to only run if the V2 grouped GEMM was available? Unless you've already tried this and it works okay, then we can update the test above to run on both V1/V2 grouped GEMMs. But if it doesn't' work or you haven't tried yet, let's revert this change so we can focus on V2.

Comment thread transformer_engine/jax/flax/moe.py
Comment thread tests/jax/test_moe_block.py Outdated
for name in ("gate_kernel", "wi_0", "wi_1", "wo"):
g_pj = _unwrap_partitioned(grads_pj["params"][name])
g_tr = _unwrap_partitioned(grads_tr["params"][name])
assert jnp.allclose(g_pj, g_tr, atol=1e-1, rtol=1e-1), (
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tdophung following up on this tolerance


namespace transformer_engine::detail {

namespace {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tdophung Let's move this into a separate PR. This is an issue in the V1 codepath but not V2. It's a good fix and the code seems reasonable to me but it could impact PyTorch so I'd like to decouple from this PR to unblock MoEBlock+V2 and then we can have PyTorch and core review and test this fix separately.

… grad tol to 5e-2, move arch/align_size docs into MoEBlock class

Signed-off-by: tdophung <tdophung@nvidia.com>
Comment on lines +909 to +914
batch_divisor = num_ep * dp_size
if global_batch_size % batch_divisor != 0:
raise ValueError(
f"batch={global_batch_size} not divisible by prod(data_parallelism_axes)={dp_size}"
)
recv_buffer_rows = (global_batch_size // dp_size) * sequence_length * topk
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Receive buffer undersized when align_size > 0 + EP are combined

recv_buffer_rows is computed assuming unpadded token counts, but when align_size > 0 the per-expert group_sizes are the aligned counts, so the send_sizes in compute_ragged_all_to_all_params include padding tokens. The worst-case receive per shard is num_ep * ((B/(num_ep*dp_size))*S*K + num_experts_per_shard*(align_size-1)), which exceeds the current recv_buffer_rows = (B/dp_size)*S*K by up to num_experts*(align_size-1) rows. ragged_all_to_all writing beyond the buffer produces incorrect results or a crash. The correct worst-case size is:

recv_buffer_rows = (global_batch_size // dp_size) * sequence_length * topk + num_experts * (self.align_size - 1 if self.align_size > 0 else 0)

This combination (EP + align_size > 0) is not exercised by the current distributed test (which defaults to align_size=0), so the bug is latent.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[JAX] Create initial MoE Block

3 participants