Skip to content

Commit ea7d76e

Browse files
authored
Wire scalebar_dx/scalebar_units into pl.show()
1 parent af10e7e commit ea7d76e

10 files changed

Lines changed: 293 additions & 41 deletions

src/spatialdata_plot/pl/basic.py

Lines changed: 48 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
)
5555
from spatialdata_plot.pl.utils import (
5656
_RENDER_CMD_TO_CS_FLAG,
57+
_draw_scalebar,
5758
_get_cs_contents,
5859
_get_elements_to_be_rendered,
5960
_get_valid_cs,
@@ -890,6 +891,10 @@ def show(
890891
return_ax: bool = False,
891892
save: str | Path | None = None,
892893
show: bool | None = None,
894+
scalebar_dx: float | None = None,
895+
scalebar_units: str = "um",
896+
scalebar_params: dict[str, Any] | None = None,
897+
legend_params: dict[str, Any] | None = None,
893898
) -> Axes | list[Axes] | None:
894899
"""
895900
Execute the plotting tree and display the final figure.
@@ -950,6 +955,22 @@ def show(
950955
automatically when running in non-interactive mode (scripts) and suppressed in
951956
interactive sessions (e.g. Jupyter). When ``ax`` is provided by the user, defaults
952957
to ``False`` to allow further modifications.
958+
scalebar_dx : float | None
959+
Physical size of one axes-unit in ``scalebar_units``. If ``None``, no scalebar is drawn.
960+
SpatialData coordinate systems carry no unit metadata, so this value must be supplied
961+
explicitly (e.g. ``1.0`` when axes are already in micrometers; the microns-per-pixel
962+
value when axes are in image pixels).
963+
scalebar_units : str, default "um"
964+
Unit string for the scalebar (passed to :class:`matplotlib_scalebar.scalebar.ScaleBar`).
965+
Only takes effect when ``scalebar_dx`` is set.
966+
scalebar_params : dict[str, Any] | None
967+
Extra keyword arguments forwarded to :class:`matplotlib_scalebar.scalebar.ScaleBar`,
968+
e.g. ``{"location": "lower right", "color": "white", "length_fraction": 0.25}``.
969+
See the matplotlib-scalebar documentation for the full list of options.
970+
legend_params : dict[str, Any] | None
971+
Bundled legend options; overrides the matching ``legend_*`` flat kwargs. Accepted keys:
972+
``location`` (or ``loc``), ``fontsize``, ``fontweight``, ``fontoutline``,
973+
``na_in_legend``. Unknown keys raise ``ValueError``.
953974
954975
Returns
955976
-------
@@ -987,6 +1008,10 @@ def show(
9871008
return_ax,
9881009
save,
9891010
show,
1011+
scalebar_dx,
1012+
scalebar_units,
1013+
scalebar_params,
1014+
legend_params,
9901015
)
9911016

9921017
if fig is not None and not isinstance(ax, Sequence):
@@ -1100,7 +1125,7 @@ def show(
11001125
raise ValueError(msg)
11011126

11021127
# set up canvas
1103-
fig_params, scalebar_params = _prepare_params_plot(
1128+
fig_params, scalebar_params_obj = _prepare_params_plot(
11041129
num_panels=len(coordinate_systems),
11051130
figsize=figsize,
11061131
dpi=dpi,
@@ -1110,15 +1135,25 @@ def show(
11101135
hspace=hspace,
11111136
ncols=ncols,
11121137
frameon=frameon,
1138+
scalebar_dx=scalebar_dx,
1139+
scalebar_units=scalebar_units,
1140+
scalebar_kwargs=scalebar_params,
11131141
)
1114-
legend_colorbar = colorbar
1115-
legend_params = LegendParams(
1142+
if legend_params:
1143+
legend_fontsize = legend_params.get("fontsize", legend_fontsize)
1144+
legend_fontweight = legend_params.get("fontweight", legend_fontweight)
1145+
# `loc` is matplotlib.Legend's native key; `location` aligns with colorbar/scalebar.
1146+
legend_loc = legend_params.get("location", legend_params.get("loc", legend_loc))
1147+
legend_fontoutline = legend_params.get("fontoutline", legend_fontoutline)
1148+
na_in_legend = legend_params.get("na_in_legend", na_in_legend)
1149+
1150+
legend_params_obj = LegendParams(
11161151
legend_fontsize=legend_fontsize,
11171152
legend_fontweight=legend_fontweight,
11181153
legend_loc=legend_loc,
11191154
legend_fontoutline=legend_fontoutline,
11201155
na_in_legend=na_in_legend,
1121-
colorbar=legend_colorbar,
1156+
colorbar=colorbar,
11221157
)
11231158

11241159
def _draw_colorbar(
@@ -1210,7 +1245,7 @@ def _draw_colorbar(
12101245
has_shapes = cs_row["has_shapes"]
12111246
ax = fig_params.ax if fig_params.axs is None else fig_params.axs[i]
12121247
assert isinstance(ax, Axes)
1213-
axis_colorbar_requests: list[ColorbarSpec] | None = [] if legend_params.colorbar else None
1248+
axis_colorbar_requests: list[ColorbarSpec] | None = [] if legend_params_obj.colorbar else None
12141249
axis_channel_legend_entries: list[ChannelLegendEntry] = []
12151250

12161251
wants_images = False
@@ -1239,8 +1274,7 @@ def _draw_colorbar(
12391274
coordinate_system=cs,
12401275
ax=ax,
12411276
fig_params=fig_params,
1242-
scalebar_params=scalebar_params,
1243-
legend_params=legend_params,
1277+
legend_params=legend_params_obj,
12441278
colorbar_requests=axis_colorbar_requests,
12451279
channel_legend_entries=axis_channel_legend_entries,
12461280
rasterize=rasterize,
@@ -1258,8 +1292,7 @@ def _draw_colorbar(
12581292
coordinate_system=cs,
12591293
ax=ax,
12601294
fig_params=fig_params,
1261-
scalebar_params=scalebar_params,
1262-
legend_params=legend_params,
1295+
legend_params=legend_params_obj,
12631296
colorbar_requests=axis_colorbar_requests,
12641297
)
12651298

@@ -1275,8 +1308,7 @@ def _draw_colorbar(
12751308
coordinate_system=cs,
12761309
ax=ax,
12771310
fig_params=fig_params,
1278-
scalebar_params=scalebar_params,
1279-
legend_params=legend_params,
1311+
legend_params=legend_params_obj,
12801312
colorbar_requests=axis_colorbar_requests,
12811313
)
12821314

@@ -1311,8 +1343,7 @@ def _draw_colorbar(
13111343
coordinate_system=cs,
13121344
ax=ax,
13131345
fig_params=fig_params,
1314-
scalebar_params=scalebar_params,
1315-
legend_params=legend_params,
1346+
legend_params=legend_params_obj,
13161347
colorbar_requests=axis_colorbar_requests,
13171348
rasterize=rasterize,
13181349
)
@@ -1352,11 +1383,13 @@ def _draw_colorbar(
13521383
ax.set_xlim(x_min, x_max)
13531384
ax.set_ylim(y_max, y_min) # (0, 0) is top-left
13541385

1355-
if legend_params.colorbar and axis_colorbar_requests:
1386+
if legend_params_obj.colorbar and axis_colorbar_requests:
13561387
pending_colorbars.append((ax, axis_colorbar_requests))
13571388

13581389
if axis_channel_legend_entries:
1359-
_draw_channel_legend(ax, axis_channel_legend_entries, legend_params, fig_params)
1390+
_draw_channel_legend(ax, axis_channel_legend_entries, legend_params_obj, fig_params)
1391+
1392+
_draw_scalebar(ax, scalebar_params_obj, panel_idx=i)
13601393

13611394
if pending_colorbars and fig_params.fig is not None:
13621395
fig = fig_params.fig

src/spatialdata_plot/pl/render.py

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@
5353
LabelsRenderParams,
5454
LegendParams,
5555
PointsRenderParams,
56-
ScalebarParams,
5756
ShapesRenderParams,
5857
)
5958
from spatialdata_plot.pl.utils import (
@@ -268,9 +267,8 @@ def _add_legend_and_colorbar(
268267
colorbar: bool | str | None,
269268
colorbar_params: dict[str, object] | None,
270269
colorbar_requests: list[ColorbarSpec] | None,
271-
scalebar_params: ScalebarParams,
272270
) -> None:
273-
"""Add legend, colorbar, and scalebar decorations if the color vector warrants them."""
271+
"""Add legend and colorbar decorations if the color vector warrants them."""
274272
if not _want_decorations(color_vector, na_color):
275273
return
276274

@@ -309,8 +307,6 @@ def _add_legend_and_colorbar(
309307
colorbar_params,
310308
col_for_color if isinstance(col_for_color, str) else None,
311309
),
312-
scalebar_dx=scalebar_params.scalebar_dx,
313-
scalebar_units=scalebar_params.scalebar_units,
314310
)
315311

316312

@@ -320,7 +316,6 @@ def _render_shapes(
320316
coordinate_system: str,
321317
ax: matplotlib.axes.SubplotBase,
322318
fig_params: FigParams,
323-
scalebar_params: ScalebarParams,
324319
legend_params: LegendParams,
325320
colorbar_requests: list[ColorbarSpec] | None = None,
326321
) -> None:
@@ -696,7 +691,6 @@ def _render_shapes(
696691
colorbar=render_params.colorbar,
697692
colorbar_params=render_params.colorbar_params,
698693
colorbar_requests=colorbar_requests,
699-
scalebar_params=scalebar_params,
700694
)
701695

702696

@@ -706,7 +700,6 @@ def _render_points(
706700
coordinate_system: str,
707701
ax: matplotlib.axes.SubplotBase,
708702
fig_params: FigParams,
709-
scalebar_params: ScalebarParams,
710703
legend_params: LegendParams,
711704
colorbar_requests: list[ColorbarSpec] | None = None,
712705
) -> None:
@@ -1054,7 +1047,6 @@ def _render_points(
10541047
colorbar=render_params.colorbar,
10551048
colorbar_params=render_params.colorbar_params,
10561049
colorbar_requests=colorbar_requests,
1057-
scalebar_params=scalebar_params,
10581050
)
10591051

10601052

@@ -1201,7 +1193,6 @@ def _render_images(
12011193
coordinate_system: str,
12021194
ax: matplotlib.axes.SubplotBase,
12031195
fig_params: FigParams,
1204-
scalebar_params: ScalebarParams,
12051196
legend_params: LegendParams,
12061197
rasterize: bool,
12071198
colorbar_requests: list[ColorbarSpec] | None = None,
@@ -1604,7 +1595,6 @@ def _render_labels(
16041595
coordinate_system: str,
16051596
ax: matplotlib.axes.SubplotBase,
16061597
fig_params: FigParams,
1607-
scalebar_params: ScalebarParams,
16081598
legend_params: LegendParams,
16091599
rasterize: bool,
16101600
colorbar_requests: list[ColorbarSpec] | None = None,
@@ -1838,7 +1828,4 @@ def _draw_labels(
18381828
render_params.colorbar_params,
18391829
col_for_color if isinstance(col_for_color, str) else None,
18401830
),
1841-
scalebar_dx=scalebar_params.scalebar_dx,
1842-
scalebar_units=scalebar_params.scalebar_units,
1843-
# scalebar_kwargs=scalebar_params.scalebar_kwargs,
18441831
)

src/spatialdata_plot/pl/render_params.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from __future__ import annotations
22

3-
from collections.abc import Callable, Sequence
4-
from dataclasses import dataclass
5-
from typing import Literal
3+
from collections.abc import Callable, Mapping, Sequence
4+
from dataclasses import dataclass, field
5+
from typing import Any, Literal
66

77
import numpy as np
88
from matplotlib.axes import Axes
@@ -220,6 +220,7 @@ class ScalebarParams:
220220

221221
scalebar_dx: Sequence[float] | None = None
222222
scalebar_units: Sequence[str] | None = None
223+
scalebar_kwargs: Mapping[str, Any] = field(default_factory=dict)
223224

224225

225226
@dataclass

src/spatialdata_plot/pl/utils.py

Lines changed: 46 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from copy import copy
88
from functools import partial
99
from pathlib import Path
10-
from types import MappingProxyType
1110
from typing import Any, Literal
1211

1312
import dask
@@ -273,6 +272,7 @@ def _prepare_params_plot(
273272
# this args will be inferred from coordinate system
274273
scalebar_dx: float | Sequence[float] | None = None,
275274
scalebar_units: str | Sequence[str] | None = None,
275+
scalebar_kwargs: Mapping[str, Any] | None = None,
276276
) -> tuple[FigParams, ScalebarParams]:
277277
# handle axes and size
278278
wspace = 0.75 / rcParams["figure.figsize"][0] + 0.02 if wspace is None else wspace
@@ -325,11 +325,29 @@ def _prepare_params_plot(
325325
num_panels=num_panels,
326326
frameon=frameon,
327327
)
328-
scalebar_params = ScalebarParams(scalebar_dx=scalebar_dx, scalebar_units=scalebar_units)
328+
scalebar_params = ScalebarParams(
329+
scalebar_dx=scalebar_dx,
330+
scalebar_units=scalebar_units,
331+
scalebar_kwargs=dict(scalebar_kwargs) if scalebar_kwargs else {},
332+
)
329333

330334
return fig_params, scalebar_params
331335

332336

337+
def _draw_scalebar(ax: Axes, scalebar_params: ScalebarParams, panel_idx: int) -> None:
338+
"""Attach a single :class:`matplotlib_scalebar.scalebar.ScaleBar` to ``ax``.
339+
340+
No-op when ``scalebar_dx`` is ``None``. ``scalebar_dx`` and ``scalebar_units`` are
341+
broadcast lists indexed by the panel position; ``scalebar_kwargs`` is forwarded
342+
verbatim to :class:`~matplotlib_scalebar.scalebar.ScaleBar`.
343+
"""
344+
if scalebar_params.scalebar_dx is None or scalebar_params.scalebar_units is None:
345+
return
346+
dx = scalebar_params.scalebar_dx[panel_idx]
347+
units = scalebar_params.scalebar_units[panel_idx]
348+
ax.add_artist(ScaleBar(dx, units=units, **scalebar_params.scalebar_kwargs))
349+
350+
333351
def _get_cs_contents(sdata: sd.SpatialData) -> pd.DataFrame:
334352
"""Check which coordinate systems contain which elements and return that info."""
335353
cs_mapping = _get_coordinate_system_mapping(sdata)
@@ -1680,9 +1698,6 @@ def _decorate_axs(
16801698
colorbar_params: dict[str, object] | None = None,
16811699
colorbar_requests: list[ColorbarSpec] | None = None,
16821700
colorbar_label: str | None = None,
1683-
scalebar_dx: Sequence[float] | None = None,
1684-
scalebar_units: Sequence[str] | None = None,
1685-
scalebar_kwargs: Mapping[str, Any] = MappingProxyType({}),
16861701
) -> Axes:
16871702
if value_to_plot is not None:
16881703
# if only dots were plotted without an associated value
@@ -1729,10 +1744,6 @@ def _decorate_axs(
17291744
)
17301745
)
17311746

1732-
if isinstance(scalebar_dx, list) and isinstance(scalebar_units, list):
1733-
scalebar = ScaleBar(scalebar_dx, units=scalebar_units, **scalebar_kwargs)
1734-
ax.add_artist(scalebar)
1735-
17361747
return ax
17371748

17381749

@@ -2141,6 +2152,10 @@ def _validate_show_parameters(
21412152
return_ax: bool,
21422153
save: str | Path | None,
21432154
show: bool | None,
2155+
scalebar_dx: float | None,
2156+
scalebar_units: str,
2157+
scalebar_params: dict[str, Any] | None,
2158+
legend_params: dict[str, Any] | None,
21442159
) -> None:
21452160
if coordinate_systems is not None and not isinstance(coordinate_systems, list | str):
21462161
raise TypeError("Parameter 'coordinate_systems' must be a string or a list of strings.")
@@ -2234,6 +2249,28 @@ def _validate_show_parameters(
22342249
if show is not None and not isinstance(show, bool):
22352250
raise TypeError("Parameter 'show' must be a boolean or None.")
22362251

2252+
if scalebar_dx is not None:
2253+
if not isinstance(scalebar_dx, int | float) or isinstance(scalebar_dx, bool):
2254+
raise TypeError("Parameter 'scalebar_dx' must be a number or None.")
2255+
if scalebar_dx <= 0:
2256+
raise ValueError("Parameter 'scalebar_dx' must be > 0.")
2257+
if not isinstance(scalebar_units, str):
2258+
raise TypeError("Parameter 'scalebar_units' must be a string.")
2259+
2260+
if scalebar_params is not None and not isinstance(scalebar_params, dict):
2261+
raise TypeError("Parameter 'scalebar_params' must be a dictionary or None.")
2262+
2263+
if legend_params is not None:
2264+
if not isinstance(legend_params, dict):
2265+
raise TypeError("Parameter 'legend_params' must be a dictionary or None.")
2266+
# `loc` is matplotlib.Legend's native key; `location` aligns with colorbar_params / scalebar_params.
2267+
allowed_legend_keys = {"loc", "location", "fontsize", "fontweight", "fontoutline", "na_in_legend"}
2268+
unknown = set(legend_params) - allowed_legend_keys
2269+
if unknown:
2270+
raise ValueError(
2271+
f"Unknown legend_params key(s): {sorted(unknown)}. Allowed keys: {sorted(allowed_legend_keys)}."
2272+
)
2273+
22372274

22382275
def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[str, Any]:
22392276
colorbar = param_dict.get("colorbar", "auto")
90.5 KB
Loading
90.3 KB
Loading
88.6 KB
Loading
92 KB
Loading
91.2 KB
Loading

0 commit comments

Comments
 (0)