Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 60 additions & 3 deletions src/xarray_einstats/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,14 @@
example usage.

The functions that are not available via the accessor are ``einsum``, ``einsum_path``,
``matmul`` and ``get_default_dims``.
``matmul``, ``get_default_dims`` and ``default_dims``.

"""

import sys
from collections.abc import Iterable
from contextlib import contextmanager

import numpy as np
import xarray as xr

Expand All @@ -35,6 +39,7 @@
"solve",
"inv",
"pinv",
"default_dims",
]


Expand Down Expand Up @@ -109,6 +114,10 @@ def get_default_dims(dims1, dims2):

You can still use ``dims`` explicitly to override those defaults.

.. note::
Monkeypatching ``get_default_dims`` directly works but is error-prone.
Consider using the :func:`default_dims` context manager instead.

"""
raise MissingMonkeypatchError()

Expand All @@ -119,12 +128,60 @@ def _attempt_default_dims(func, da1_dims, da2_dims=None):
aux = get_default_dims(da1_dims, da2_dims)
except MissingMonkeypatchError:
raise TypeError(
f"{func} missing required argument dims. You must monkeypatch "
"xarray_einstats.linalg.get_default_dims for dims=None to be supported"
f"{func} missing required argument dims. Use "
"xarray_einstats.linalg.default_dims context manager or pass dims explicitly"
) from None
return aux


@contextmanager
def default_dims(func_or_dims):
"""Context manager to temporarily set the default dimensions for linalg functions.

Safer alternative to monkey patching :func:`get_default_dims`,
as it ensures that the original function is restored even if an error occurs
within the context.

Parameters
----------
func_or_dims : callable or iterable
If a callable is provided, it should take the same arguments as :func:`get_default_dims`
and return the default dimensions based on those arguments.
If an iterable is provided, it will be used as the default dimensions
regardless of the input arguments.

See Also
--------
get_default_dims

Examples
--------
Set the default dims to ``("dim", "dim2")`` for the duration of the ``with`` block:

.. code-block:: python

from xarray_einstats import linalg, tutorial
da = tutorial.generate_matrices_dataarray(5)

with linalg.default_dims(("dim", "dim2")):
linalg.inv(da)

"""
_linalg = sys.modules[__name__]
original_get_default_dims = _linalg.get_default_dims

def func(*args):
if isinstance(func_or_dims, Iterable):
return func_or_dims
return func_or_dims(*args)

_linalg.get_default_dims = func
try:
yield
finally:
_linalg.get_default_dims = original_get_default_dims


class PairHandler:
def __init__(self, all_dims, keep_dims):
self.potential_out_dims = keep_dims.union(all_dims)
Expand Down
11 changes: 8 additions & 3 deletions src/xarray_einstats/linalg.pyi
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
# File generated with docstub

import numbers
from collections.abc import Hashable, Sequence
from typing import Literal
from collections.abc import Hashable, Iterable, Sequence
from contextlib import contextmanager
from typing import Callable, Generator, Literal

import numpy as np
import xarray
import xarray as xr
from _typeshed import Incomplete
from numpy.typing import NDArray

Expand All @@ -30,6 +30,7 @@ __all__ = [
"solve",
"inv",
"pinv",
"default_dims",
]

class MissingMonkeypatchError(Exception):
Expand Down Expand Up @@ -195,3 +196,7 @@ def pinv(
hermitian: bool = ...,
**kwargs: Incomplete,
) -> xarray.DataArray: ...
@contextmanager
def default_dims(
func_or_dims: Callable | Iterable,
) -> Generator[None, None, None]: ...
13 changes: 13 additions & 0 deletions tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,19 @@ def default_dims(dims1, dims2): # pylint: disable=unused-argument
assert out.dims == matrices.dims


def test_default_dims_context_manager(matrices):
with pytest.raises(TypeError, match="missing required argument dims"):
inv(matrices)

with linalg.default_dims(("dim", "dim2")):
out = inv(matrices)
assert out.dims == matrices.dims

# outside the context, it should raise again
with pytest.raises(TypeError, match="missing required argument dims"):
inv(matrices)


class TestEinsumFamily:
# raw_einsum calls einsum, so the tests on raw_einsum also cover einsum, then
# there are some specific ones for various reasons,
Expand Down