Automatic Sparse Differentiation in JAX.
asdex exploits sparsity structure to efficiently materialize Jacobians and Hessians.
It implements a custom Jaxpr interpreter
that uses abstract interpretation
to detect sparsity patterns from the computation graph,
then uses graph coloring to minimize the number of AD passes needed.
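As a rough illustration of the detection step, the pure-Python sketch below propagates index sets. Every input entry starts out depending only on itself, and each operation unions the dependencies of its operands. Because no numerical values are involved, the resulting pattern holds for all inputs. This is a toy written for this README that supports only slicing, subtraction and powers (the `IndexSetArray` and `detect_sparsity` names are hypothetical); asdex's Jaxpr interpreter applies the idea to JAX primitives instead of Python objects.

```python
class IndexSetArray:
    """Toy abstract value (hypothetical): for each entry, the input indices it depends on."""

    def __init__(self, sets):
        self.sets = [frozenset(s) for s in sets]

    def __getitem__(self, idx):
        # Slicing selects entries; their dependencies are unchanged.
        return IndexSetArray(self.sets[idx])

    def __sub__(self, other):
        # Elementwise a - b depends on whatever a or b depends on.
        return IndexSetArray(a | b for a, b in zip(self.sets, other.sets))

    def __pow__(self, n):
        # Elementwise powers keep each entry's dependencies.
        return IndexSetArray(self.sets)


def detect_sparsity(f, n):
    """Jacobian sparsity pattern of f as a set of (row, col) pairs that may be nonzero."""
    x = IndexSetArray({j} for j in range(n))  # input j depends only on itself
    y = f(x)
    return {(i, j) for i, deps in enumerate(y.sets) for j in deps}


def f(x):
    return (x[1:] - x[:-1]) ** 2


pattern = detect_sparsity(f, 50)
print(len(pattern))  # 98
```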
```shell
pip install asdex
```

Or with uv:

```shell
uv add asdex
```

```python
import asdex
import jax
import jax.numpy as jnp

def f(x):
    return (x[1:] - x[:-1]) ** 2

x_sample = jnp.zeros(50)  # sample input for sparsity pattern detection
jac_fn = jax.jit(asdex.jacobian(f, x_sample))
# ColoredPattern(49×50, nnz=98, sparsity=96.0%, JVP, 2 colors)
# 2 JVPs (instead of 49 VJPs or 50 JVPs)
# ⎡⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎤   ⎡⣿⎤
# ⎢⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥   ⎢⣿⎥
# ⎢⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥   ⎢⣿⎥
# ⎢⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥   ⎢⣿⎥
# ⎢⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥   ⎢⣿⎥
# ⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥   ⎢⣿⎥
# ⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥ → ⎢⣿⎥
# ⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⎥   ⎢⣿⎥
# ⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⎥   ⎢⣿⎥
# ⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⎥   ⎢⣿⎥
# ⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⎥   ⎢⣿⎥
# ⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⎥   ⎢⣿⎥
# ⎣⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⎦   ⎣⠉⎦

inputs = [jnp.arange(50.0), jnp.ones(50)]  # example inputs, same shape as x_sample
for x in inputs:
    J = jac_fn(x)
```

Instead of 49 VJPs or 50 JVPs, asdex computes the full sparse Jacobian with just 2 JVPs.
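The saving comes from compression: columns that never share a row can be seeded together, so a single JVP recovers several columns at once. The NumPy sketch below mimics this for the example `f`. The Jacobian is written out by hand, and each product `J @ S[:, c]` stands in for one JVP (asdex obtains these products through JAX's AD rather than from a dense matrix). The assignment `color = j % 2` is hard-coded for this banded pattern and is an assumption of the sketch.

```python
import numpy as np

n = 50
x = np.random.default_rng(0).normal(size=n)

# Dense Jacobian of f(x) = (x[1:] - x[:-1])**2, written out by hand:
# row i holds -2*d_i at column i and 2*d_i at column i+1, with d_i = x[i+1] - x[i].
d = x[1:] - x[:-1]
i = np.arange(n - 1)
J = np.zeros((n - 1, n))
J[i, i] = -2 * d
J[i, i + 1] = 2 * d

# Column coloring: columns j and j+2 never share a row, so two colors suffice.
colors = np.arange(n) % 2
S = np.stack([colors == c for c in range(2)], axis=1).astype(float)  # 50×2 seed matrix

# Each column of J @ S is what one JVP seeded with S[:, c] would return.
compressed = J @ S  # 49×2: two passes instead of 50

# Decompress: entry (row, col) is read from the compressed column of col's color.
rows = np.repeat(np.arange(n - 1), 2)  # 0, 0, 1, 1, ...
cols = rows + np.tile([0, 1], n - 1)  # i, i+1 for each row i
J_rec = np.zeros_like(J)
J_rec[rows, cols] = compressed[rows, colors[cols]]
```

Decompression works because each row has its two nonzeros in columns of different colors, so no two nonzeros are summed into the same compressed entry.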
Since sparsity detection and coloring can be expensive on large problems, we recommend saving and reusing colored patterns:
```python
import jax
import jax.numpy as jnp
from asdex import ColoredPattern, jacobian_coloring, jacobian_from_coloring

# f as defined in the example above

# Compute coloring once...
x = jnp.zeros(1000)
coloring = jacobian_coloring(f, x)
coloring.save("colored.npz")

# ...load and reuse later
coloring = ColoredPattern.load("colored.npz")
jac_fn = jax.jit(jacobian_from_coloring(f, coloring))
```

The full ASD pipeline:
- Sparse Jacobians and Hessians: one VJP/JVP/HVP per color, with automatic (or user-defined) mode selection.
- Sparsity detection: finds global sparsity patterns valid for all inputs.
- Graph coloring: row, column, and symmetric coloring minimize AD passes.
- Correctness verification against vanilla JAX.
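For column coloring, the rule is that two columns sharing a nonzero row must receive different colors. The sketch below is a plain greedy version of that rule in natural column order, assumed here for illustration only; asdex may use other orderings and algorithms, and it also supports row and symmetric coloring. `greedy_column_coloring` is a hypothetical name.

```python
def greedy_column_coloring(pattern, n_cols):
    """Hypothetical greedy column coloring: columns sharing a row get different colors.

    pattern: iterable of (row, col) nonzero positions.
    Returns a list mapping each column index to a color index.
    """
    cols_in_row = {}
    for i, j in pattern:
        cols_in_row.setdefault(i, set()).add(j)

    # Two columns conflict if some row has nonzeros in both.
    neighbors = {j: set() for j in range(n_cols)}
    for cols in cols_in_row.values():
        for j in cols:
            neighbors[j] |= cols - {j}

    colors = []
    for j in range(n_cols):
        # Colors already used by conflicting, previously colored columns.
        taken = {colors[k] for k in neighbors[j] if k < j}
        c = 0
        while c in taken:
            c += 1
        colors.append(c)
    return colors


# Pattern of the quick-start f with 50 inputs: nonzeros at (i, i) and (i, i+1).
pattern = {(i, i) for i in range(49)} | {(i, i + 1) for i in range(49)}
colors = greedy_column_coloring(pattern, 50)
print(max(colors) + 1)  # 2
```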
You already know your sparsity pattern?
- Manually provide sparsity patterns: supply a known pattern from dense, COO, or BCOO formats.
- Precompute, save & load: reuse a colored pattern across inputs, or persist it by saving and loading.
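A dense boolean mask and its COO form (parallel row and column index arrays) describe the same pattern. The NumPy sketch below only shows the conversion between the two for a small bidiagonal pattern; it does not show how asdex accepts a user-supplied pattern, for which the API reference is the authority.

```python
import numpy as np

# A known 4×5 sparsity pattern as a dense boolean mask (bidiagonal, like f above).
dense = np.zeros((4, 5), dtype=bool)
dense[np.arange(4), np.arange(4)] = True
dense[np.arange(4), np.arange(1, 5)] = True

# COO: row and column indices of the nonzeros, in row-major order.
rows, cols = np.nonzero(dense)
print(list(zip(rows.tolist(), cols.tolist())))
# [(0, 0), (0, 1), (1, 1), (1, 2), (2, 2), (2, 3), (3, 3), (3, 4)]

# Round-trip back to a dense mask.
back = np.zeros_like(dense)
back[rows, cols] = True
```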
An interface mirroring JAX:
- Multiple inputs and outputs: supports multi-argument functions via `argnums`, as well as multiple return values.
- PyTree inputs and outputs: sparse differentiation through arbitrary nested PyTrees.
- Auxiliary outputs: supports `has_aux=True` for functions returning `(output, aux)`.
- Value and derivative: `value_and_jacobian` / `value_and_hessian` return the primal value `f(x)` without a redundant forward pass.
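The primal value comes for free because a forward-mode pass carries each value alongside its tangent, so one JVP returns both `f(x)` and `J @ v`. The toy dual-number sketch below shows this for the quick-start `f` with the color-0 seed. The `Dual` and `jvp` helpers are hypothetical and support only slicing, subtraction and powers; asdex relies on JAX's own JVPs instead.

```python
import numpy as np


class Dual:
    """Hypothetical toy forward-mode value: primal `val` plus tangent `tan`."""

    def __init__(self, val, tan):
        self.val = np.asarray(val, dtype=float)
        self.tan = np.asarray(tan, dtype=float)

    def __getitem__(self, idx):
        return Dual(self.val[idx], self.tan[idx])

    def __sub__(self, other):
        return Dual(self.val - other.val, self.tan - other.tan)

    def __pow__(self, n):
        # Chain rule: d(u**n) = n * u**(n - 1) * du
        return Dual(self.val**n, n * self.val ** (n - 1) * self.tan)


def jvp(f, x, v):
    """One forward pass returns both f(x) and the directional derivative J @ v."""
    out = f(Dual(x, v))
    return out.val, out.tan


def f(x):
    return (x[1:] - x[:-1]) ** 2


x = np.arange(5.0) ** 2  # [0, 1, 4, 9, 16]
v = (np.arange(5) % 2 == 0).astype(float)  # color-0 seed: columns 0, 2, 4
y, t = jvp(f, x, v)
# y == [1, 9, 25, 49] is f(x); t == [-2, 6, -10, 14] is J @ v
```

Since every per-color pass yields `y` as a by-product, a value-and-Jacobian routine can keep it from any of those passes.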
And more:
- Multiple output formats: decompression to BCOO, dense JAX arrays, NumPy, and SciPy (COO/CSR/CSC) arrays.
- Bounded memory: `chunk_size` caps the number of AD passes run in parallel for large color counts.
- Visualizations: spy plots and braille pattern previews.
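In the braille previews, each character encodes a 4×2 block of the pattern, with one dot per entry, as an offset from U+2800. The sketch below is an assumed reconstruction, not asdex's code. For the 49×50 pattern of the quick-start example, it yields 13 rows of 25 characters and the same first row, beginning `⠙⢦⡀`. `braille` and `BRAILLE_BITS` are hypothetical names.

```python
import numpy as np

# Bit for each dot of a 4-tall × 2-wide Unicode braille cell, indexed [row][col].
BRAILLE_BITS = [[0x01, 0x08], [0x02, 0x10], [0x04, 0x20], [0x40, 0x80]]


def braille(mask):
    """Hypothetical renderer: one braille character per 4×2 block of a boolean matrix."""
    mask = np.asarray(mask, dtype=bool)
    n_rows, n_cols = mask.shape
    lines = []
    for r0 in range(0, n_rows, 4):
        chars = []
        for c0 in range(0, n_cols, 2):
            code = 0
            for dr in range(4):
                for dc in range(2):
                    r, c = r0 + dr, c0 + dc
                    if r < n_rows and c < n_cols and mask[r, c]:
                        code |= BRAILLE_BITS[dr][dc]
            chars.append(chr(0x2800 + code))
        lines.append("".join(chars))
    return "\n".join(lines)


# Sparsity pattern of f(x) = (x[1:] - x[:-1])**2 with 50 inputs.
pattern = np.zeros((49, 50), dtype=bool)
pattern[np.arange(49), np.arange(49)] = True
pattern[np.arange(49), np.arange(1, 50)] = True
print(braille(pattern))
```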
- Getting Started: step-by-step tutorial
- How-To Guides: task-oriented recipes
- Explanation: how and why it works
- API Reference: full API documentation
- Contributing: guidelines for collaborating on asdex
- AI Policy: guidelines for LLM contributions
Prior work on ASD by asdex's authors Adrian Hill (@adrhill) and Guillaume Dalle (@gdalle), as well as Alexis Montoison (@amontoison):
- An Illustrated Guide to Automatic Sparse Differentiation, Hill, Dalle, Montoison (2025)
- Sparser, Better, Faster, Stronger: Efficient Automatic Differentiation for Sparse Jacobians and Hessians, Hill & Dalle (2025)
- Revisiting Sparse Matrix Coloring and Bicoloring, Montoison, Dalle, Gebremedhin (2025)
- SparseConnectivityTracer.jl, Hill & Dalle
- SparseMatrixColorings.jl, Dalle & Montoison
- DifferentiationInterface.jl, Dalle & Hill
Prior and concurrent (partial) attempts at ASD in JAX:
- sparsejac: coloring and decompression
- sparsediffax: coloring and decompression (by asdex's @gdalle)
- jax-nansparse: sparsity detection using NaN propagation
- JAX-AMG: specialized ASD module for algebraic multigrid methods
- tatva: specialized ASD module for FEM
- See discussion in JAX issue #1032
Adrian Hill gratefully acknowledges funding from the German Federal Ministry of Education and Research under the grant BIFOLD26B.
This package is built with Claude Code, based on previous, hand-written work by the same authors in the Julia programming language, as noted above. These works in turn stand on the shoulders of giants, notably Andreas Griewank, Andrea Walther, and Assefaw Gebremedhin.
The asdex logo was designed by @overripemango.
If you use asdex in your research, please cite:
```bibtex
@software{asdex2026,
  author = {Hill, Adrian},
  title = {asdex: Automatic Sparse Differentiation in JAX},
  url = {https://github.com/adrhill/asdex},
  doi = {10.5281/zenodo.18788242}
}
```