From 1a5f808c1e431f4350d000a9321bae2f0dddbb8e Mon Sep 17 00:00:00 2001 From: krokosik Date: Wed, 15 Apr 2026 13:28:18 +0200 Subject: [PATCH 1/7] Add a context manager for default dims --- src/xarray_einstats/__init__.py | 12 ++++++++++++ src/xarray_einstats/linalg.py | 2 ++ 2 files changed, 14 insertions(+) diff --git a/src/xarray_einstats/__init__.py b/src/xarray_einstats/__init__.py index 1b6da71..08f096e 100644 --- a/src/xarray_einstats/__init__.py +++ b/src/xarray_einstats/__init__.py @@ -1,6 +1,7 @@ """Stats, linear algebra and einops for xarray.""" from __future__ import annotations +from contextlib import contextmanager import numpy as np import xarray as xr @@ -188,3 +189,14 @@ def ones_ref(*args, dims, dtype=None): empty_ref, zeros_ref """ return _create_ref(*args, dims=dims, np_creator=np.ones, dtype=dtype) + + +@contextmanager +def default_linalg_dims(func: callable): + original_get_default_dims = linalg.get_default_dims + + linalg.get_default_dims = func + try: + yield + finally: + linalg.get_default_dims = original_get_default_dims diff --git a/src/xarray_einstats/linalg.py b/src/xarray_einstats/linalg.py index 0bead5b..a29823d 100644 --- a/src/xarray_einstats/linalg.py +++ b/src/xarray_einstats/linalg.py @@ -12,6 +12,8 @@ """ +from contextlib import contextmanager + import numpy as np import xarray as xr From d0e0fdf76d1e42e363ce74e017cc2d1a97055aef Mon Sep 17 00:00:00 2001 From: krokosik Date: Wed, 15 Apr 2026 13:34:21 +0200 Subject: [PATCH 2/7] Add a docstring and fix lint issues --- src/xarray_einstats/__init__.py | 26 +++++++++++++++++++++++++- src/xarray_einstats/linalg.py | 2 -- 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/src/xarray_einstats/__init__.py b/src/xarray_einstats/__init__.py index 08f096e..3f953e5 100644 --- a/src/xarray_einstats/__init__.py +++ b/src/xarray_einstats/__init__.py @@ -192,9 +192,33 @@ def ones_ref(*args, dims, dtype=None): @contextmanager -def default_linalg_dims(func: callable): +def default_linalg_dims(func_or_dims: callable | list): + """Context manager to temporarily set the default dimensions for linalg functions. + + Safer alternative to monkey patching the `get_default_dims` function in `linalg` module, + as it ensures that the original function is restored even if an error occurs within the context. + + Parameters + ---------- + func_or_dims : callable or list + If a callable is provided, it should take the same arguments as `get_default_dims` + and return the default dimensions based on those arguments. + If a list is provided, it will be used as the default dimensions + regardless of the input arguments. + + Yields + ------ + None + """ + from xarray_einstats import linalg + original_get_default_dims = linalg.get_default_dims + def func(*args): + if isinstance(func_or_dims, list): + return func_or_dims + return func_or_dims(*args) + linalg.get_default_dims = func try: yield diff --git a/src/xarray_einstats/linalg.py b/src/xarray_einstats/linalg.py index a29823d..0bead5b 100644 --- a/src/xarray_einstats/linalg.py +++ b/src/xarray_einstats/linalg.py @@ -12,8 +12,6 @@ """ -from contextlib import contextmanager - import numpy as np import xarray as xr From 3d338bb152a70940f79bb9e45f025328f789a691 Mon Sep 17 00:00:00 2001 From: krokosik Date: Wed, 15 Apr 2026 13:48:20 +0200 Subject: [PATCH 3/7] Add default linalg dims to package exports --- src/xarray_einstats/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/xarray_einstats/__init__.py b/src/xarray_einstats/__init__.py index 3f953e5..955767f 100644 --- a/src/xarray_einstats/__init__.py +++ b/src/xarray_einstats/__init__.py @@ -10,6 +10,7 @@ from .accessors import LinAlgAccessor, EinopsAccessor __all__ = [ + "default_linalg_dims", "einsum", "einsum_path", "matmul", From 39b4bf96d13beea72e65d3ddce1453e1537cdbc9 Mon Sep 17 00:00:00 2001 From: krokosik Date: Wed, 15 Apr 2026 13:54:10 +0200 Subject: [PATCH 4/7] add type hints --- src/xarray_einstats/__init__.pyi | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/xarray_einstats/__init__.pyi b/src/xarray_einstats/__init__.pyi index 870c3c1..55192d9 100644 --- a/src/xarray_einstats/__init__.pyi +++ b/src/xarray_einstats/__init__.pyi @@ -13,6 +13,7 @@ from .accessors import EinopsAccessor, LinAlgAccessor from .linalg import einsum, einsum_path, matmul __all__ = [ + "default_linalg_dims", "einsum", "einsum_path", "matmul", @@ -52,3 +53,4 @@ def ones_ref( dims: Sequence[Hashable], dtype: np.typing.DTypeLike | None = ..., ) -> xarray.DataArray: ... +def default_linalg_dims(func_or_dims: callable | list[Unknown]) -> Generator[None, Any, None]: ... From 96fd80109f6db9e8bebeb4b53fd28c291a01f1f0 Mon Sep 17 00:00:00 2001 From: krokosik Date: Wed, 15 Apr 2026 14:47:07 +0200 Subject: [PATCH 5/7] Handle all iterables --- src/xarray_einstats/__init__.py | 9 +++++---- src/xarray_einstats/__init__.pyi | 5 ++++- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/xarray_einstats/__init__.py b/src/xarray_einstats/__init__.py index 955767f..09b18dc 100644 --- a/src/xarray_einstats/__init__.py +++ b/src/xarray_einstats/__init__.py @@ -2,6 +2,7 @@ from __future__ import annotations from contextlib import contextmanager +from collections.abc import Iterable import numpy as np import xarray as xr @@ -193,7 +194,7 @@ def ones_ref(*args, dims, dtype=None): @contextmanager -def default_linalg_dims(func_or_dims: callable | list): +def default_linalg_dims(func_or_dims): """Context manager to temporarily set the default dimensions for linalg functions. Safer alternative to monkey patching the `get_default_dims` function in `linalg` module, @@ -201,10 +202,10 @@ def default_linalg_dims(func_or_dims: callable | list): Parameters ---------- - func_or_dims : callable or list + func_or_dims : callable or iterable If a callable is provided, it should take the same arguments as `get_default_dims` and return the default dimensions based on those arguments. - If a list is provided, it will be used as the default dimensions + If an iterable is provided, it will be used as the default dimensions regardless of the input arguments. Yields @@ -216,7 +217,7 @@ def default_linalg_dims(func_or_dims: callable | list): original_get_default_dims = linalg.get_default_dims def func(*args): - if isinstance(func_or_dims, list): + if isinstance(func_or_dims, Iterable): return func_or_dims return func_or_dims(*args) diff --git a/src/xarray_einstats/__init__.pyi b/src/xarray_einstats/__init__.pyi index 55192d9..f6e6646 100644 --- a/src/xarray_einstats/__init__.pyi +++ b/src/xarray_einstats/__init__.pyi @@ -3,6 +3,7 @@ from __future__ import annotations from collections.abc import Hashable, Iterable, Sequence +from typing import Any, Callable, Generator import numpy as np import xarray @@ -53,4 +54,6 @@ def ones_ref( dims: Sequence[Hashable], dtype: np.typing.DTypeLike | None = ..., ) -> xarray.DataArray: ... -def default_linalg_dims(func_or_dims: callable | list[Unknown]) -> Generator[None, Any, None]: ... +def default_linalg_dims( + func_or_dims: Callable | Iterable, +) -> Generator[None, Any, None]: ... From 21a607c02defec522173bee090415b7ded5bb9dd Mon Sep 17 00:00:00 2001 From: krokosik Date: Wed, 15 Apr 2026 14:54:20 +0200 Subject: [PATCH 6/7] Add missing contextmanager wrapper --- src/xarray_einstats/__init__.pyi | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/xarray_einstats/__init__.pyi b/src/xarray_einstats/__init__.pyi index f6e6646..0645295 100644 --- a/src/xarray_einstats/__init__.pyi +++ b/src/xarray_einstats/__init__.pyi @@ -3,6 +3,7 @@ from __future__ import annotations from collections.abc import Hashable, Iterable, Sequence +from contextlib import contextmanager from typing import Any, Callable, Generator import numpy as np @@ -54,6 +55,7 @@ def ones_ref( dims: Sequence[Hashable], dtype: np.typing.DTypeLike | None = ..., ) -> xarray.DataArray: ... +@contextmanager def default_linalg_dims( func_or_dims: Callable | Iterable, ) -> Generator[None, Any, None]: ... From 5712455ee72990743d50df456ed28c0a99e301e3 Mon Sep 17 00:00:00 2001 From: wkrokosz Date: Tue, 9 Jun 2026 12:45:21 +0200 Subject: [PATCH 7/7] Resolve PR comments --- src/xarray_einstats/__init__.py | 38 ------------------- src/xarray_einstats/__init__.pyi | 7 ---- src/xarray_einstats/linalg.py | 63 ++++++++++++++++++++++++++++++-- src/xarray_einstats/linalg.pyi | 11 ++++-- tests/test_linalg.py | 13 +++++++ 5 files changed, 81 insertions(+), 51 deletions(-) diff --git a/src/xarray_einstats/__init__.py b/src/xarray_einstats/__init__.py index 09b18dc..1b6da71 100644 --- a/src/xarray_einstats/__init__.py +++ b/src/xarray_einstats/__init__.py @@ -1,8 +1,6 @@ """Stats, linear algebra and einops for xarray.""" from __future__ import annotations -from contextlib import contextmanager -from collections.abc import Iterable import numpy as np import xarray as xr @@ -11,7 +9,6 @@ from .accessors import LinAlgAccessor, EinopsAccessor __all__ = [ - "default_linalg_dims", "einsum", "einsum_path", "matmul", @@ -191,38 +188,3 @@ def ones_ref(*args, dims, dtype=None): empty_ref, zeros_ref """ return _create_ref(*args, dims=dims, np_creator=np.ones, dtype=dtype) - - -@contextmanager -def default_linalg_dims(func_or_dims): - """Context manager to temporarily set the default dimensions for linalg functions. - - Safer alternative to monkey patching the `get_default_dims` function in `linalg` module, - as it ensures that the original function is restored even if an error occurs within the context. - - Parameters - ---------- - func_or_dims : callable or iterable - If a callable is provided, it should take the same arguments as `get_default_dims` - and return the default dimensions based on those arguments. - If an iterable is provided, it will be used as the default dimensions - regardless of the input arguments. - - Yields - ------ - None - """ - from xarray_einstats import linalg - - original_get_default_dims = linalg.get_default_dims - - def func(*args): - if isinstance(func_or_dims, Iterable): - return func_or_dims - return func_or_dims(*args) - - linalg.get_default_dims = func - try: - yield - finally: - linalg.get_default_dims = original_get_default_dims diff --git a/src/xarray_einstats/__init__.pyi b/src/xarray_einstats/__init__.pyi index 0645295..870c3c1 100644 --- a/src/xarray_einstats/__init__.pyi +++ b/src/xarray_einstats/__init__.pyi @@ -3,8 +3,6 @@ from __future__ import annotations from collections.abc import Hashable, Iterable, Sequence -from contextlib import contextmanager -from typing import Any, Callable, Generator import numpy as np import xarray @@ -15,7 +13,6 @@ from .accessors import EinopsAccessor, LinAlgAccessor from .linalg import einsum, einsum_path, matmul __all__ = [ - "default_linalg_dims", "einsum", "einsum_path", "matmul", @@ -55,7 +52,3 @@ def ones_ref( dims: Sequence[Hashable], dtype: np.typing.DTypeLike | None = ..., ) -> xarray.DataArray: ... -@contextmanager -def default_linalg_dims( - func_or_dims: Callable | Iterable, -) -> Generator[None, Any, None]: ... diff --git a/src/xarray_einstats/linalg.py b/src/xarray_einstats/linalg.py index 0bead5b..aef88ce 100644 --- a/src/xarray_einstats/linalg.py +++ b/src/xarray_einstats/linalg.py @@ -8,10 +8,14 @@ example usage. The functions that are not available via the accessor are ``einsum``, ``einsum_path``, - ``matmul`` and ``get_default_dims``. + ``matmul``, ``get_default_dims`` and ``default_dims``. """ +import sys +from collections.abc import Iterable +from contextlib import contextmanager + import numpy as np import xarray as xr @@ -35,6 +39,7 @@ "solve", "inv", "pinv", + "default_dims", ] @@ -109,6 +114,10 @@ def get_default_dims(dims1, dims2): You can still use ``dims`` explicitly to override those defaults. + .. note:: + Monkeypatching ``get_default_dims`` directly works but is error-prone. + Consider using the :func:`default_dims` context manager instead. + """ raise MissingMonkeypatchError() @@ -119,12 +128,60 @@ def _attempt_default_dims(func, da1_dims, da2_dims=None): aux = get_default_dims(da1_dims, da2_dims) except MissingMonkeypatchError: raise TypeError( - f"{func} missing required argument dims. You must monkeypatch " - "xarray_einstats.linalg.get_default_dims for dims=None to be supported" + f"{func} missing required argument dims. Use " + "xarray_einstats.linalg.default_dims context manager or pass dims explicitly" ) from None return aux +@contextmanager +def default_dims(func_or_dims): + """Context manager to temporarily set the default dimensions for linalg functions. + + Safer alternative to monkey patching :func:`get_default_dims`, + as it ensures that the original function is restored even if an error occurs + within the context. + + Parameters + ---------- + func_or_dims : callable or iterable + If a callable is provided, it should take the same arguments as :func:`get_default_dims` + and return the default dimensions based on those arguments. + If an iterable is provided, it will be used as the default dimensions + regardless of the input arguments. + + See Also + -------- + get_default_dims + + Examples + -------- + Set the default dims to ``("dim", "dim2")`` for the duration of the ``with`` block: + + .. code-block:: python + + from xarray_einstats import linalg, tutorial + da = tutorial.generate_matrices_dataarray(5) + + with linalg.default_dims(("dim", "dim2")): + linalg.inv(da) + + """ + _linalg = sys.modules[__name__] + original_get_default_dims = _linalg.get_default_dims + + def func(*args): + if isinstance(func_or_dims, Iterable): + return func_or_dims + return func_or_dims(*args) + + _linalg.get_default_dims = func + try: + yield + finally: + _linalg.get_default_dims = original_get_default_dims + + class PairHandler: def __init__(self, all_dims, keep_dims): self.potential_out_dims = keep_dims.union(all_dims) diff --git a/src/xarray_einstats/linalg.pyi b/src/xarray_einstats/linalg.pyi index d7ef47c..45c2771 100644 --- a/src/xarray_einstats/linalg.pyi +++ b/src/xarray_einstats/linalg.pyi @@ -1,12 +1,12 @@ # File generated with docstub import numbers -from collections.abc import Hashable, Sequence -from typing import Literal +from collections.abc import Hashable, Iterable, Sequence +from contextlib import contextmanager +from typing import Callable, Generator, Literal import numpy as np import xarray -import xarray as xr from _typeshed import Incomplete from numpy.typing import NDArray @@ -30,6 +30,7 @@ __all__ = [ "solve", "inv", "pinv", + "default_dims", ] class MissingMonkeypatchError(Exception): @@ -195,3 +196,7 @@ def pinv( hermitian: bool = ..., **kwargs: Incomplete, ) -> xarray.DataArray: ... +@contextmanager +def default_dims( + func_or_dims: Callable | Iterable, +) -> Generator[None, None, None]: ... diff --git a/tests/test_linalg.py b/tests/test_linalg.py index f688f5b..d0e00dc 100644 --- a/tests/test_linalg.py +++ b/tests/test_linalg.py @@ -79,6 +79,19 @@ def default_dims(dims1, dims2): # pylint: disable=unused-argument assert out.dims == matrices.dims +def test_default_dims_context_manager(matrices): + with pytest.raises(TypeError, match="missing required argument dims"): + inv(matrices) + + with linalg.default_dims(("dim", "dim2")): + out = inv(matrices) + assert out.dims == matrices.dims + + # outside the context, it should raise again + with pytest.raises(TypeError, match="missing required argument dims"): + inv(matrices) + + class TestEinsumFamily: # raw_einsum calls einsum, so the tests on raw_einsum also cover einsum, then # there are some specific ones for various reasons,