
asdex

Automatic Sparse Differentiation in JAX.


asdex exploits sparsity structure to efficiently materialize Jacobians and Hessians. It detects sparsity patterns with a custom Jaxpr interpreter that applies abstract interpretation to the computation graph. It then uses graph coloring to group columns (or rows) that share no nonzero entries, so that each group costs a single AD pass, minimizing the number of passes needed.
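To give a feel for how sparsity detection by abstract interpretation works, here is a toy sketch. It assumes a hand-written operator-overloading tracer, not asdex's actual Jaxpr interpreter, and the names `IndexSets` and `detect_pattern` are hypothetical. Each array entry carries the set of input indices it depends on. Slicing selects sets, elementwise binary operations take unions, and elementwise unary operations (here `** 2`) leave them unchanged.

```python
import numpy as np


class IndexSets:
    """Hypothetical toy abstract value: per entry, the set of input indices it depends on.

    Illustration only; asdex interprets Jaxprs instead of overloading operators.
    """

    def __init__(self, sets):
        self.sets = list(sets)

    def __len__(self):
        return len(self.sets)

    def __getitem__(self, idx):
        # Slicing or indexing just selects the dependency sets of the chosen entries.
        if isinstance(idx, slice):
            return IndexSets(self.sets[idx])
        return IndexSets([self.sets[idx]])

    def __sub__(self, other):
        # Elementwise binary op: each output entry depends on the union of both inputs.
        assert len(self) == len(other)
        return IndexSets(a | b for a, b in zip(self.sets, other.sets))

    __add__ = __sub__
    __mul__ = __sub__

    def __pow__(self, exponent):
        # Elementwise unary op: dependencies are unchanged.
        return IndexSets(self.sets)


def detect_pattern(f, n):
    """Run f on abstract inputs and return the boolean Jacobian sparsity pattern."""
    x = IndexSets(frozenset({i}) for i in range(n))
    y = f(x)
    pattern = np.zeros((len(y), n), dtype=bool)
    for row, deps in enumerate(y.sets):
        pattern[row, sorted(deps)] = True
    return pattern


def f(x):
    return (x[1:] - x[:-1]) ** 2


P = detect_pattern(f, 50)
print(P.shape, P.sum())  # (49, 50) 98
```

Like any analysis of this kind, the toy tracer is conservative: it never inspects numerical values, so an expression such as `x - x` would still be reported as depending on `x`.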

Installation

pip install asdex

Or with uv:

uv add asdex

Example

import asdex
import jax
import jax.numpy as jnp

def f(x):
    return (x[1:] - x[:-1]) ** 2

x_sample = jnp.zeros(50)  # sample input for sparsity pattern detection
jac_fn = jax.jit(asdex.jacobian(f, x_sample))
# Detected sparsity pattern and its coloring, as summarized by ColoredPattern:
# ColoredPattern(49×50, nnz=98, sparsity=96.0%, JVP, 2 colors)
#   2 JVPs (instead of 49 VJPs or 50 JVPs)
# ⎡⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎤   ⎡⣿⎤
# ⎢⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥   ⎢⣿⎥
# ⎢⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥   ⎢⣿⎥
# ⎢⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥   ⎢⣿⎥
# ⎢⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥   ⎢⣿⎥
# ⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥   ⎢⣿⎥
# ⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥ → ⎢⣿⎥
# ⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⎥   ⎢⣿⎥
# ⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⠀⠀⎥   ⎢⣿⎥
# ⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⠀⠀⎥   ⎢⣿⎥
# ⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⠀⠀⎥   ⎢⣿⎥
# ⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢦⡀⎥   ⎢⣿⎥
# ⎣⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⎦   ⎣⠉⎦

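# Evaluate the Jacobian on new inputs with the same shape as x_sample
inputs = [jnp.linspace(0.0, 1.0, 50), jnp.ones(50)]
for x in inputs:
    J = jac_fn(x)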

Dense reverse mode would need 49 VJPs (one per output) and dense forward mode 50 JVPs (one per input); asdex computes the full sparse Jacobian with just 2 JVPs.
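The 2-JVP count comes from column coloring: columns that never share a nonzero row can be seeded together in a single JVP, and each Jacobian entry can then be read back unambiguously from the compressed result. The sketch below illustrates this for the example `f`. It assumes a hand-written greedy coloring and a sparsity pattern written down directly rather than detected; `greedy_column_coloring` and `sparse_jacobian_via_jvps` are hypothetical helpers, not asdex's API. Only `jax.jvp` and `jax.jacfwd` are standard JAX functions.

```python
import jax
import jax.numpy as jnp
import numpy as np


def f(x):
    return (x[1:] - x[:-1]) ** 2


def greedy_column_coloring(pattern):
    """Hypothetical helper: give each column the smallest color unused by columns sharing a row."""
    m, n = pattern.shape
    colors = np.full(n, -1, dtype=int)
    for j in range(n):
        rows = np.nonzero(pattern[:, j])[0]
        # Columns that share at least one nonzero row with column j.
        neighbors = np.nonzero(pattern[rows].any(axis=0))[0]
        used = {int(colors[k]) for k in neighbors if colors[k] >= 0}
        c = 0
        while c in used:
            c += 1
        colors[j] = c
    return colors


def sparse_jacobian_via_jvps(f, x, pattern, colors):
    """Hypothetical helper: one JVP per color, then scatter compressed columns into a dense Jacobian."""
    num_colors = colors.max() + 1
    J = np.zeros(pattern.shape)
    for c in range(num_colors):
        # Seed vector: 1 on every column with color c, 0 elsewhere.
        seed = jnp.asarray(colors == c, dtype=x.dtype)
        _, compressed = jax.jvp(f, (x,), (seed,))
        compressed = np.asarray(compressed)
        # In each row, at most one column of color c is nonzero, so compressed[i] == J[i, j].
        rows, cols = np.nonzero(pattern & (colors == c)[None, :])
        J[rows, cols] = compressed[rows]
    return J, num_colors


# Known pattern of the example: row i depends on columns i and i + 1.
n = 50
pattern = np.zeros((n - 1, n), dtype=bool)
pattern[np.arange(n - 1), np.arange(n - 1)] = True
pattern[np.arange(n - 1), np.arange(1, n)] = True

colors = greedy_column_coloring(pattern)
x = jnp.linspace(0.0, 1.0, n) ** 2
J, num_colors = sparse_jacobian_via_jvps(f, x, pattern, colors)
print(num_colors)  # 2
print(np.allclose(J, jax.jacfwd(f)(x)))  # True
```

Greedy coloring is a heuristic and does not always find the fewest colors, but on this banded pattern it reaches the optimum of two: neighboring columns always share a row, while columns two apart never do, so even and odd columns form the two groups.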

Since sparsity detection and coloring can be expensive on large problems, we recommend saving and reusing colored patterns:

import jax
import jax.numpy as jnp
from asdex import ColoredPattern, jacobian_coloring, jacobian_from_coloring

# f is the function from the example above.
# Compute the coloring once...
x = jnp.zeros(1000)
coloring = jacobian_coloring(f, x)
coloring.save("colored.npz")

# ...then load and reuse it later
coloring = ColoredPattern.load("colored.npz")
jac_fn = jax.jit(jacobian_from_coloring(f, coloring))

Features

The full ASD pipeline:

You already know your sparsity pattern?

An interface mirroring JAX:

And more:

Documentation

Related work

Prior work on ASD by asdex's authors Adrian Hill (@adrhill) and Guillaume Dalle (@gdalle), as well as Alexis Montoison (@amontoison):

Prior and concurrent (partial) attempts at ASD in JAX:

Acknowledgements

Adrian Hill gratefully acknowledges funding from the German Federal Ministry of Education and Research under the grant BIFOLD26B.

This package is built with Claude Code, based on previous, hand-written work by the same authors in the Julia programming language, as noted above. These works in turn stand on the shoulders of giants, notably Andreas Griewank, Andrea Walther, and Assefaw Gebremedhin.

The asdex logo was designed by @overripemango.

Citation

If you use asdex in your research, please cite:

@software{asdex2026,
  author = {Hill, Adrian},
  title = {asdex: Automatic Sparse Differentiation in JAX},
  url = {https://github.com/adrhill/asdex},
  doi = {10.5281/zenodo.18788242}
}
