Skip to content

fix(transform): restore JAX 0.10.2 compatibility (batching.not_mapped removal); release 0.5.1#222

Merged
chaoming0625 merged 1 commit into
mainfrom
worktree-fix-jax-0.10.2
Jun 18, 2026
Merged

fix(transform): restore JAX 0.10.2 compatibility (batching.not_mapped removal); release 0.5.1#222
chaoming0625 merged 1 commit into
mainfrom
worktree-fix-jax-0.10.2

Conversation

@chaoming0625

@chaoming0625 chaoming0625 commented Jun 18, 2026

Copy link
Copy Markdown
Collaborator

Summary

JAX 0.10 removed jax.interpreters.batching.not_mapped — the "not batched" sentinel returned from a primitive's batching rule, now collapsed to plain None (NotMapped = type(None)). The custom unvmap primitives in brainstate/transform/_unvmap.py still referenced the removed attribute, so under JAX 0.10.2 every vmap-traced path that crossed one of them raised:

AttributeError: module 'jax.interpreters.batching' has no attribute 'not_mapped'

This surfaced as 8 test failures across the affected public surface: ifelse, error_if / jit_error_if, bounded_while_loop's per-lane exit, and unvmap itself.

Fix

Resolve the sentinel once, version-agnostically:

not_mapped = getattr(batching, 'not_mapped', None)
  • On older JAX it yields the real not_mapped object — no behavior change.
  • On JAX 0.10+ it yields None, which is exactly the sentinel the new batching machinery expects.

This preserves the full supported jax>=0.7.0 range and relies only on stable, public-facing batching behavior. The 4 call sites in _unvmap.py now use the resolved local sentinel.

Release 0.5.1

  • Bump __version__ 0.5.0 -> 0.5.1.
  • Add a professional 0.5.1 changelog entry.

Verification

  • The 8 previously-failing tests now pass.
  • Full suite on JAX 0.10.2: 5312 passed, 0 failed, 23 skipped (no regressions).

Summary by Sourcery

Restore JAX 0.10.2 compatibility for brainstate.transform unvmap primitives and cut the 0.5.1 patch release.

Bug Fixes:

  • Fix vmap batching for brainstate.transform unvmap-related primitives by using a version-agnostic sentinel so code runs on JAX 0.10.2 and older releases.

Documentation:

  • Add 0.5.1 changelog entry documenting the JAX 0.10.2 compatibility fix and test coverage.

Chores:

  • Bump the library version from 0.5.0 to 0.5.1.

…rules

JAX 0.10 removed `jax.interpreters.batching.not_mapped` (the "not batched"
sentinel returned from a primitive's batching rule, now simply `None`). The
custom `unvmap` primitives still referenced the removed attribute, so every
`vmap`-traced path that crossed them — `ifelse`, `error_if` / `jit_error_if`,
`bounded_while_loop`'s per-lane exit, and `unvmap` itself — raised
`AttributeError` on JAX 0.10.2.

Resolve the sentinel once via `getattr(batching, 'not_mapped', None)`: the real
object on older JAX, `None` on 0.10+ (what the new batching machinery expects),
preserving the supported `jax>=0.7.0` range.

Release 0.5.1: bump version and add changelog entry.
@sourcery-ai

sourcery-ai Bot commented Jun 18, 2026

Copy link
Copy Markdown
Contributor

Reviewer's Guide

Adds a JAX-version-agnostic not_mapped sentinel in the transform unvmap primitives to restore compatibility with JAX 0.10.2’s batching changes, and bumps the package to version 0.5.1 with corresponding changelog updates.

Sequence diagram for vmap interaction with unvmap_all batching using version-agnostic not_mapped sentinel

sequenceDiagram
    actor User
    participant vmap
    participant unvmap_all_p as unvmap_all_primitive
    participant _unvmap_all_batch

    User ->> vmap: vmap(unvmap_all, x)
    vmap ->> unvmap_all_p: apply_primitive(x)
    unvmap_all_p ->> _unvmap_all_batch: _unvmap_all_batch(x, batch_axes)
    _unvmap_all_batch ->> _unvmap_all_batch: not_mapped = getattr(batching, not_mapped, None)
    _unvmap_all_batch -->> unvmap_all_p: return unvmap_all(x), not_mapped
    unvmap_all_p -->> vmap: unmapped_result
    vmap -->> User: result (no AttributeError on JAX 0.10.2)
Loading

File-Level Changes

Change Details Files
Make unvmap batching rules compatible with JAX 0.10+ while preserving older JAX support.
  • Introduce a module-level not_mapped sentinel resolved via getattr(batching, 'not_mapped', None) so it is batching.not_mapped on older JAX and None on JAX 0.10+.
  • Update the batching rules for unvmap_all, unvmap_any, unvmap_max, and the internal no_vmap primitive to return the new not_mapped sentinel instead of batching.not_mapped.
  • Document the JAX batching behavior change in comments to clarify why the sentinel indirection exists.
brainstate/transform/_unvmap.py
Prepare the 0.5.1 patch release metadata and documentation.
  • Bump the library version constant from 0.5.0 to 0.5.1.
  • Add a 0.5.1 changelog entry describing the JAX 0.10.2 compatibility fix, impacted APIs, and test status.
brainstate/_version.py
changelog.md

Tips and commands

Interacting with Sourcery

  • Trigger a new review: Comment @sourcery-ai review on the pull request.
  • Continue discussions: Reply directly to Sourcery's review comments.
  • Generate a GitHub issue from a review comment: Ask Sourcery to create an
    issue from a review comment by replying to it. You can also reply to a
    review comment with @sourcery-ai issue to create an issue from it.
  • Generate a pull request title: Write @sourcery-ai anywhere in the pull
    request title to generate a title at any time. You can also comment
    @sourcery-ai title on the pull request to (re-)generate the title at any time.
  • Generate a pull request summary: Write @sourcery-ai summary anywhere in
    the pull request body to generate a PR summary at any time exactly where you
    want it. You can also comment @sourcery-ai summary on the pull request to
    (re-)generate the summary at any time.
  • Generate reviewer's guide: Comment @sourcery-ai guide on the pull
    request to (re-)generate the reviewer's guide at any time.
  • Resolve all Sourcery comments: Comment @sourcery-ai resolve on the
    pull request to resolve all Sourcery comments. Useful if you've already
    addressed all the comments and don't want to see them anymore.
  • Dismiss all Sourcery reviews: Comment @sourcery-ai dismiss on the pull
    request to dismiss all existing Sourcery reviews. Especially useful if you
    want to start fresh with a new review - don't forget to comment
    @sourcery-ai review to trigger a new review!

Customizing Your Experience

Access your dashboard to:

  • Enable or disable review features such as the Sourcery-generated pull request
    summary, the reviewer's guide, and others.
  • Change the review language.
  • Add, remove or edit custom review instructions.
  • Adjust other review settings.

Getting Help

@chaoming0625 chaoming0625 merged commit 094d018 into main Jun 18, 2026
6 checks passed
@chaoming0625 chaoming0625 deleted the worktree-fix-jax-0.10.2 branch June 18, 2026 07:39

@sourcery-ai sourcery-ai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey - I've reviewed your changes and they look great!


Sourcery is free for open source - if you like our reviews please consider sharing them ✨
Help me be more useful! Please click 👍 or 👎 on each comment and I'll use the feedback to improve your reviews.

@codecov

codecov Bot commented Jun 18, 2026

Copy link
Copy Markdown

Codecov Report

✅ All modified and coverable lines are covered by tests.

📢 Thoughts on this report? Let us know!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant