Skip to content

[JAX] Size autotuned Triton grids per config #2975

Merged
tdophung merged 5 commits into
NVIDIA:mainfrom
tdophung:jax-triton-grid-autotune
May 14, 2026
Merged

[JAX] Size autotuned Triton grids per config #2975
tdophung merged 5 commits into
NVIDIA:mainfrom
tdophung:jax-triton-grid-autotune

Conversation

@tdophung
Copy link
Copy Markdown
Collaborator

@tdophung tdophung commented May 11, 2026

Description

The autotuned path in triton_call_lowering compiled all BLOCK_SIZE configs but dispatched every one with the same fixed grid sized for the smallest BLOCK_SIZE, so larger configs over-launched by the BLOCK_SIZE ratio. Make grid accept a callable(meta)->tuple evaluated per config, matching the jax-triton API. Update _permute_kernel, _unpermute_kernel, and _sort_chunks_by_map_kernel lowerings. Measured 22.6ms -> 7.4ms (3.06x) on GB200 for sort_chunks at 524k tokens, hidden=4096, fp32.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Pass the grid callable format to triton kernel call
  • Make grid passed into the triton lowering call for all triton permutation kernels to be a a callable(meta)->tuple

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

The autotuned path in triton_call_lowering compiled all BLOCK_SIZE configs
but dispatched every one with the same fixed grid sized for the smallest
BLOCK_SIZE, so larger configs over-launched by the BLOCK_SIZE ratio. Make
grid accept a callable(meta)->tuple evaluated per config, matching the
jax-triton API. Update _permute_kernel, _unpermute_kernel, and
_sort_chunks_by_map_kernel lowerings. Measured 22.6ms -> 7.4ms (3.06x) on
GB200 for sort_chunks at 524k tokens, hidden=4096, fp32.

Signed-off-by: tdophung <tdophung@nvidia.com>
@tdophung tdophung changed the title [JAX] Size autotuned Triton grids per config (3x perm-kernel speedup on JAX side) [JAX] Size autotuned Triton grids per config May 11, 2026
Signed-off-by: tdophung <tdophung@nvidia.com>
@tdophung tdophung marked this pull request as ready for review May 11, 2026 21:11
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented May 11, 2026

Greptile Summary

This PR fixes a performance bug where autotuned Triton kernels were launched with a grid sized for the smallest BLOCK_SIZE config regardless of which config was actually selected, causing the larger-block configs to over-launch. The fix replaces all three fixed-tuple grids in permutation.py with grid(meta) callables that compute the correct dimensions per config, and updates triton_call_lowering in utils.py to evaluate the callable per config at lowering time.

  • utils.py: Adds callable grid support (evaluated per autotune config), changes the default num_warps from 32 → 4 and num_stages from 1 → 3 to match Triton/jax-triton defaults, and adds an assertion on the grid argument type.
  • permutation.py: Replaces three fixed-tuple grid = (...) expressions with def grid(meta): return (num_tokens, triton.cdiv(hidden_size, meta["BLOCK_SIZE"])) in _permute_kernel, _unpermute_kernel, and _sort_chunks_by_map_kernel lowerings.
  • The _get_min_block_size call and the "BLOCK_SIZE": block_size entry in constexprs are still present in all three sites; they are harmless on the autotuned path but represent dead code there.

Confidence Score: 5/5

The core fix is correct and well-contained — callable grids are evaluated per config, matching jax-triton's contract. The only discrepancy is a type-assertion gap that does not affect any current caller.

All three lowering sites consistently adopt the callable grid pattern. The dispatch logic correctly routes callable vs. tuple grids on both the autotuned and fallback paths. No existing caller passes a plain-int grid.

No files require special attention; the minor assertion inconsistency in utils.py is straightforward to address.

Important Files Changed

Filename Overview
transformer_engine/jax/triton_extensions/utils.py Adds callable grid support to triton_call_lowering, changes default num_warps from 32 to 4 and num_stages from 1 to 3; assertion omits int even though docstring and _normalize_grid still treat int as valid.
transformer_engine/jax/triton_extensions/permutation.py Replaces fixed-tuple grids with per-config callables in _permute_kernel, _unpermute_kernel, and _sort_chunks_by_map_kernel lowerings; block_size variable and BLOCK_SIZE in constexprs remain but are overridden by config.kwargs on the autotuned path.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["triton_call_lowering(grid=...)"] --> B{grid type?}
    B -->|"callable"| C[grid_callable = grid]
    B -->|"tuple"| D["grid_tuple = _normalize_grid(grid)"]
    B -->|"int (assert fails!)"| E["AssertionError"]
    C --> F{is_autotuned?}
    D --> F
    F -->|"yes - for each config"| G["config_constexprs = {**constexprs, **config.kwargs}"]
    G --> H["config_grid = _normalize_grid(grid_callable(config_constexprs))"]
    H --> I["TritonKernelCall per config"]
    I --> J["TritonAutotunedKernelCall"]
    F -->|"no"| K["kernel_constexprs merged"]
    K --> L{grid_callable?}
    L -->|"yes"| M["single_grid = _normalize_grid(grid_callable(kernel_constexprs))"]
    L -->|"no"| N["single_grid = grid_tuple"]
    M --> O["TritonKernelCall"]
    N --> O
Loading

Reviews (4): Last reviewed commit: "[JAX] Triton wrapper defaults match jax-..." | Re-trigger Greptile

Comment thread transformer_engine/jax/triton_extensions/permutation.py Outdated
Clarify that constexprs values override config.kwargs in the non-autotune
fallback path (utils.py merges {**first_cfg.kwargs, **constexprs}). Three
sites: _permute_kernel, _unpermute_kernel, _sort_chunks_by_map_kernel.

Signed-off-by: tdophung <tdophung@nvidia.com>
Copy link
Copy Markdown
Collaborator

@jberchtold-nvidia jberchtold-nvidia left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall looks good! Left a few comments. Thanks!

Comment thread transformer_engine/jax/triton_extensions/permutation.py Outdated
Comment thread transformer_engine/jax/triton_extensions/utils.py
num_warps default 32->4 and num_stages 1->3 in triton_call_lowering match
Triton's own triton.Config defaults. Non-autotuned kernels (e.g.
_make_chunk_sort_map_kernel) were running with 1024 threads/block, an 8x
kernel slowdown. Also: tuple/callable grid assertion + comment trims.

Signed-off-by: tdophung <tdophung@nvidia.com>
@tdophung
Copy link
Copy Markdown
Collaborator Author

/te_ci jax

@tdophung
Copy link
Copy Markdown
Collaborator Author

2 notes:

  1. The num_warp arg is now recovered to default to Triton default num warp =4 num_stages=3. I deleted the TODO to expose them as I do not think hand tuning the kernels is sustainable, but want to note that we could also try to do autotuning on num_warps and num_stages as metrics. Let me know what you think @jberchtold-nvidia
  2. this recover most of the speed lost from using our own jax triton binding. However, there is still the dummy buffer issue that causes 1 specific kernel sort_chunks_by_map to slow down, but only about 1.172ms slow down from 3.02ms to 4.19 ms. This I will not disable in this PR because it compromise the correctness of the kernel in CI if done so. But another direction is to introduce yet another env var to control whether the dummy buffer is there or not, for benchmarking/pef purposes while waiting for the root cause to be fixed from JAX side

@jberchtold-nvidia
Copy link
Copy Markdown
Collaborator

jberchtold-nvidia commented May 14, 2026

The num_warp arg is now recovered to default to Triton default num warp =4 num_stages=3. I deleted the TODO to expose them as I do not think hand tuning the kernels is sustainable, but want to note that we could also try to do autotuning on num_warps and num_stages as metrics. Let me know what you think @jberchtold-nvidia

I think the current approach gives a big speedup for now and we can leave it as you have it in this PR. We can try out the autotuning num_warps and num_stages in the future if needed

this recover most of the speed lost from using our own jax triton binding. However, there is still the dummy buffer issue that causes 1 specific kernel sort_chunks_by_map to slow down, but only about 1.172ms slow down from 3.02ms to 4.19 ms. This I will not disable in this PR because it compromise the correctness of the kernel in CI if done so. But another direction is to introduce yet another env var to control whether the dummy buffer is there or not, for benchmarking/pef purposes while waiting for the root cause to be fixed from JAX side

Sounds good! Makes sense we can't remove due to the correctness issue. I think the current speedups in this PR are big enough that we should move forward with it as is. If we end up needing to optimize further in the future we can explore the env var to remove the dummy buffer.

@tdophung

Copy link
Copy Markdown
Collaborator

@jberchtold-nvidia jberchtold-nvidia left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM pending CI, thanks!

@tdophung
Copy link
Copy Markdown
Collaborator Author

/te-ci jax

@tdophung tdophung merged commit c40398c into NVIDIA:main May 14, 2026
21 of 24 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants