[JAX] Size autotuned Triton grids per config #2975
Conversation
The autotuned path in triton_call_lowering compiled all BLOCK_SIZE configs but dispatched every one with the same fixed grid sized for the smallest BLOCK_SIZE, so larger configs over-launched by the BLOCK_SIZE ratio. Make grid accept a callable(meta)->tuple evaluated per config, matching the jax-triton API. Update _permute_kernel, _unpermute_kernel, and _sort_chunks_by_map_kernel lowerings. Measured 22.6ms -> 7.4ms (3.06x) on GB200 for sort_chunks at 524k tokens, hidden=4096, fp32. Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
Greptile SummaryThis PR fixes a performance bug where autotuned Triton kernels were launched with a grid sized for the smallest
Confidence Score: 5/5The core fix is correct and well-contained — callable grids are evaluated per config, matching jax-triton's contract. The only discrepancy is a type-assertion gap that does not affect any current caller. All three lowering sites consistently adopt the callable grid pattern. The dispatch logic correctly routes callable vs. tuple grids on both the autotuned and fallback paths. No existing caller passes a plain-int grid. No files require special attention; the minor assertion inconsistency in utils.py is straightforward to address. Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["triton_call_lowering(grid=...)"] --> B{grid type?}
B -->|"callable"| C[grid_callable = grid]
B -->|"tuple"| D["grid_tuple = _normalize_grid(grid)"]
B -->|"int (assert fails!)"| E["AssertionError"]
C --> F{is_autotuned?}
D --> F
F -->|"yes - for each config"| G["config_constexprs = {**constexprs, **config.kwargs}"]
G --> H["config_grid = _normalize_grid(grid_callable(config_constexprs))"]
H --> I["TritonKernelCall per config"]
I --> J["TritonAutotunedKernelCall"]
F -->|"no"| K["kernel_constexprs merged"]
K --> L{grid_callable?}
L -->|"yes"| M["single_grid = _normalize_grid(grid_callable(kernel_constexprs))"]
L -->|"no"| N["single_grid = grid_tuple"]
M --> O["TritonKernelCall"]
N --> O
Reviews (4): Last reviewed commit: "[JAX] Triton wrapper defaults match jax-..." | Re-trigger Greptile |
Clarify that constexprs values override config.kwargs in the non-autotune
fallback path (utils.py merges {**first_cfg.kwargs, **constexprs}). Three
sites: _permute_kernel, _unpermute_kernel, _sort_chunks_by_map_kernel.
Signed-off-by: tdophung <tdophung@nvidia.com>
jberchtold-nvidia
left a comment
There was a problem hiding this comment.
Overall looks good! Left a few comments. Thanks!
num_warps default 32->4 and num_stages 1->3 in triton_call_lowering match Triton's own triton.Config defaults. Non-autotuned kernels (e.g. _make_chunk_sort_map_kernel) were running with 1024 threads/block, an 8x kernel slowdown. Also: tuple/callable grid assertion + comment trims. Signed-off-by: tdophung <tdophung@nvidia.com>
|
/te_ci jax |
|
2 notes:
|
I think the current approach gives a big speedup for now and we can leave it as you have it in this PR. We can try out the autotuning num_warps and num_stages in the future if needed
Sounds good! Makes sense we can't remove due to the correctness issue. I think the current speedups in this PR are big enough that we should move forward with it as is. If we end up needing to optimize further in the future we can explore the env var to remove the dummy buffer. |
jberchtold-nvidia
left a comment
There was a problem hiding this comment.
LGTM pending CI, thanks!
|
/te-ci jax |
Description
The autotuned path in triton_call_lowering compiled all BLOCK_SIZE configs but dispatched every one with the same fixed grid sized for the smallest BLOCK_SIZE, so larger configs over-launched by the BLOCK_SIZE ratio. Make grid accept a callable(meta)->tuple evaluated per config, matching the jax-triton API. Update _permute_kernel, _unpermute_kernel, and _sort_chunks_by_map_kernel lowerings. Measured 22.6ms -> 7.4ms (3.06x) on GB200 for sort_chunks at 524k tokens, hidden=4096, fp32.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: