fix(transform): restore JAX 0.10.2 compatibility (batching.not_mapped removal); release 0.5.1#222
Merged
Merged
Conversation
…rules JAX 0.10 removed `jax.interpreters.batching.not_mapped` (the "not batched" sentinel returned from a primitive's batching rule, now simply `None`). The custom `unvmap` primitives still referenced the removed attribute, so every `vmap`-traced path that crossed them — `ifelse`, `error_if` / `jit_error_if`, `bounded_while_loop`'s per-lane exit, and `unvmap` itself — raised `AttributeError` on JAX 0.10.2. Resolve the sentinel once via `getattr(batching, 'not_mapped', None)`: the real object on older JAX, `None` on 0.10+ (what the new batching machinery expects), preserving the supported `jax>=0.7.0` range. Release 0.5.1: bump version and add changelog entry.
Contributor
Reviewer's GuideAdds a JAX-version-agnostic Sequence diagram for vmap interaction with unvmap_all batching using version-agnostic not_mapped sentinelsequenceDiagram
actor User
participant vmap
participant unvmap_all_p as unvmap_all_primitive
participant _unvmap_all_batch
User ->> vmap: vmap(unvmap_all, x)
vmap ->> unvmap_all_p: apply_primitive(x)
unvmap_all_p ->> _unvmap_all_batch: _unvmap_all_batch(x, batch_axes)
_unvmap_all_batch ->> _unvmap_all_batch: not_mapped = getattr(batching, not_mapped, None)
_unvmap_all_batch -->> unvmap_all_p: return unvmap_all(x), not_mapped
unvmap_all_p -->> vmap: unmapped_result
vmap -->> User: result (no AttributeError on JAX 0.10.2)
File-Level Changes
Tips and commandsInteracting with Sourcery
Customizing Your ExperienceAccess your dashboard to:
Getting Help
|
Codecov Report✅ All modified and coverable lines are covered by tests. 📢 Thoughts on this report? Let us know! |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
JAX 0.10 removed
jax.interpreters.batching.not_mapped— the "not batched" sentinel returned from a primitive's batching rule, now collapsed to plainNone(NotMapped = type(None)). The customunvmapprimitives inbrainstate/transform/_unvmap.pystill referenced the removed attribute, so under JAX 0.10.2 everyvmap-traced path that crossed one of them raised:This surfaced as 8 test failures across the affected public surface:
ifelse,error_if/jit_error_if,bounded_while_loop's per-lane exit, andunvmapitself.Fix
Resolve the sentinel once, version-agnostically:
not_mappedobject — no behavior change.None, which is exactly the sentinel the new batching machinery expects.This preserves the full supported
jax>=0.7.0range and relies only on stable, public-facing batching behavior. The 4 call sites in_unvmap.pynow use the resolved local sentinel.Release 0.5.1
__version__0.5.0->0.5.1.0.5.1changelog entry.Verification
Summary by Sourcery
Restore JAX 0.10.2 compatibility for brainstate.transform unvmap primitives and cut the 0.5.1 patch release.
Bug Fixes:
Documentation:
Chores: