diff --git a/src/spatialdata_plot/pl/basic.py b/src/spatialdata_plot/pl/basic.py index 9adf5d19..38183fa8 100644 --- a/src/spatialdata_plot/pl/basic.py +++ b/src/spatialdata_plot/pl/basic.py @@ -4,7 +4,7 @@ import sys import warnings from collections import OrderedDict -from collections.abc import Callable +from collections.abc import Callable, Sequence from copy import deepcopy from pathlib import Path from typing import Any, Literal, cast @@ -25,12 +25,14 @@ from mpl_toolkits.axes_grid1.inset_locator import inset_axes from spatialdata import get_extent from spatialdata._utils import _deprecation_alias +from spatialdata.transformations.operations import get_transformation from xarray import DataArray, DataTree from spatialdata_plot._accessor import register_spatial_data_accessor from spatialdata_plot._logging import _log_context from spatialdata_plot.pl.render import ( _draw_channel_legend, + _render_graph, _render_images, _render_labels, _render_points, @@ -44,6 +46,7 @@ ChannelLegendEntry, CmapParams, ColorbarSpec, + GraphRenderParams, ImageRenderParams, LabelsRenderParams, LegendParams, @@ -64,6 +67,7 @@ _prepare_cmap_norm, _prepare_params_plot, _set_outline, + _validate_graph_render_params, _validate_image_render_params, _validate_label_render_params, _validate_points_render_params, @@ -856,6 +860,143 @@ def render_labels( n_steps += 1 return sdata + def render_graph( + self, + element: str | None = None, + color: ColorLike | None = None, + *, + connectivity_key: str = "spatial", + obsp_key: str | None = None, + palette: dict[str, str] | list[str] | str | None = None, + na_color: ColorLike | None = "default", + cmap: Colormap | str | None = None, + norm: Normalize | None = None, + groups: list[str] | str | None = None, + group_key: str | None = None, + edge_width: float | Literal["weight"] = 1.0, + edge_alpha: float | Literal["weight"] = 1.0, + weight_key: str | None = None, + linestyle: str | Sequence[str] = "solid", + rasterize: bool = True, + include_self_loops: bool = False, + colorbar: bool | str | None = "auto", + colorbar_params: dict[str, object] | None = None, + table_name: str | None = None, + ) -> sd.SpatialData: + """Render spatial graph edges between observations. + + Draws edges from a connectivity matrix in ``table.obsp`` using + centroid coordinates of the linked spatial element. + + Parameters + ---------- + element : str | None + Name of the shapes/points/labels element the graph connects. + Auto-resolved from the table if omitted. + color : ColorLike | None + A color-like value applied to every edge, or the name of a + ``table.obs`` column. Categorical columns colour same-category + edges by the shared value and cross-category edges by + ``na_color``. Continuous columns colour edges by the mean of + their endpoint values. Defaults to grey when unset. + connectivity_key : str, default "spatial" + ``table.obsp`` key. Tries ``key`` first, then ``f"{key}_connectivities"``. + obsp_key : str | None + ``table.obsp`` matrix used as per-edge scalar; coloured via + ``cmap``/``norm``. Mutually exclusive with ``color``. + palette : dict[str, str] | list[str] | str | None + Palette for categorical obs coloring. Same as :meth:`render_shapes`. + na_color : ColorLike | None, default "default" + Colour for cross-category edges. ``None`` makes them transparent. + cmap : Colormap | str | None + Colormap for continuous edge coloring. + norm : Normalize | None + Pass ``Normalize(vmin=..., vmax=...)`` to clamp the colormap range. + groups : list[str] | str | None + Show only edges where **both** endpoints fall in these groups. + Requires ``group_key``. + group_key : str | None + ``table.obs`` column used for group filtering. + edge_width : float | Literal["weight"], default 1.0 + Line width. Pass ``"weight"`` to scale by ``weight_key`` values + into ``[0.5, 3.0]``. + edge_alpha : float | Literal["weight"], default 1.0 + Transparency. Pass ``"weight"`` to scale into ``[0.2, 1.0]``. + weight_key : str | None + ``table.obsp`` matrix providing per-edge weights. Defaults to + ``connectivity_key`` when omitted. + linestyle : str | Sequence[str], default "solid" + ``LineCollection`` linestyle (scalar or per-edge). + rasterize : bool, default True + Rasterize the edge collection. Set ``False`` for vector output. + include_self_loops : bool, default False + Render diagonal entries of the connectivity matrix as circles. + colorbar : bool | str | None, default "auto" + Whether to draw a colorbar for continuous edge coloring + (``obsp_key`` or a continuous obs column). ``"auto"`` draws it + when a mappable is present; ``True``/``False`` force it on/off. + colorbar_params : dict[str, object] | None + Optional matplotlib colorbar kwargs and layout hints + (e.g. ``{"loc": "right", "fraction": 0.05, "label": "..."}``). + table_name : str | None + Table containing the graph. Auto-discovered if omitted. + + Returns + ------- + sd.SpatialData + Copy with rendering parameters stored in the plotting tree. + + Notes + ----- + Chaining with ``render_shapes``/``render_points`` on the same + categorical column shares the legend; no dedicated edge legend is drawn. + """ + params = _validate_graph_render_params( + self._sdata, + element=element, + connectivity_key=connectivity_key, + obsp_key=obsp_key, + weight_key=weight_key, + palette=palette, + na_color=na_color, + cmap=cmap, + norm=norm, + table_name=table_name, + color=color, + edge_width=edge_width, + edge_alpha=edge_alpha, + groups=groups, + group_key=group_key, + ) + + sdata = self._copy() + sdata = _verify_plotting_tree(sdata) + n_steps = len(sdata.plotting_tree.keys()) + sdata.plotting_tree[f"{n_steps + 1}_render_graph"] = GraphRenderParams( + element=params["element"], + connectivity_obsp_key=params["connectivity_obsp_key"], + table_name=params["table_name"], + color=params["color"], + obs_col=params["obs_col"], + obsp_key=params["obsp_key"], + cmap_params=params["cmap_params"], + palette_map=params["palette_map"], + na_color=params["na_color"], + color_source=params["color_source"], + groups=params["groups"], + group_key=params["group_key"], + edge_width=params["edge_width"], + edge_alpha=params["edge_alpha"], + weight_key=params["weight_key"], + linestyle=linestyle, + rasterize=rasterize, + include_self_loops=include_self_loops, + zorder=n_steps, + colorbar=colorbar, + colorbar_params=colorbar_params, + ) + return sdata + def show( self, coordinate_systems: list[str] | str | None = None, @@ -1020,6 +1161,7 @@ def show( "render_shapes", "render_labels", "render_points", + "render_graph", ] # prepare rendering params @@ -1340,6 +1482,21 @@ def _draw_colorbar( rasterize=rasterize, ) + elif cmd == "render_graph": + graph_element = params_copy.element + element_in_cs = graph_element in sdata and cs in set( + get_transformation(sdata[graph_element], get_all=True).keys() + ) + if element_in_cs: + _render_graph( + sdata=sdata, + render_params=params_copy, + coordinate_system=cs, + ax=ax, + legend_params=legend_params_obj, + colorbar_requests=axis_colorbar_requests, + ) + if title is None: t = cs elif len(title) == 1: diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index 4da2cfa6..d97bb474 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -49,6 +49,7 @@ Color, ColorbarSpec, FigParams, + GraphRenderParams, ImageRenderParams, LabelsRenderParams, LegendParams, @@ -1834,3 +1835,222 @@ def _draw_labels( col_for_color if isinstance(col_for_color, str) else None, ), ) + + +def _normalise_to_range(values: np.ndarray, lo: float, hi: float) -> np.ndarray: + """Min-max normalise a 1-D array into ``[lo, hi]``. Constant input → midpoint.""" + if len(values) == 0: + return values + vmin = float(np.nanmin(values)) + vmax = float(np.nanmax(values)) + if not np.isfinite(vmin) or not np.isfinite(vmax) or vmax - vmin == 0.0: + return np.full_like(values, (lo + hi) / 2.0, dtype=float) + return lo + (values - vmin) * (hi - lo) / (vmax - vmin) + + +def _render_graph( + sdata: sd.SpatialData, + render_params: GraphRenderParams, + coordinate_system: str, + ax: matplotlib.axes.SubplotBase, + legend_params: LegendParams | None = None, + colorbar_requests: list[ColorbarSpec] | None = None, +) -> None: + """Render spatial graph edges as a LineCollection on the given axes.""" + from matplotlib.collections import CircleCollection, LineCollection + from scipy.sparse import triu + + _log_context.set("render_graph") + element_name = render_params.element + table_name = render_params.table_name + + table = sdata[table_name] + adjacency_key = render_params.connectivity_obsp_key + if adjacency_key not in table.obsp: + logger.warning(f"Connectivity key '{adjacency_key}' not found in table obsp. Skipping graph rendering.") + return + + adjacency = table.obsp[adjacency_key] + element = sdata[element_name] + centroids_df = sd.get_centroids(element, coordinate_system=coordinate_system) + if hasattr(centroids_df, "compute"): + centroids_df = centroids_df.compute() + + centroid_coords = np.column_stack([centroids_df["x"].values, centroids_df["y"].values]) + + _, region_key, instance_key = get_table_keys(table) + element_mask = table.obs[region_key].values == element_name + instance_ids = table.obs[instance_key].values[element_mask] + table_subset_indices = np.where(element_mask)[0] + + centroid_ids = np.asarray(centroids_df.index.values) + # Vectorised join: for each instance_id in the table subset, locate the + # matching row in centroid_ids. searchsorted requires a sorted index, which + # we can't assume, so fall back on isin + argsort for correctness. + order = np.argsort(centroid_ids) + sorted_ids = centroid_ids[order] + positions = np.searchsorted(sorted_ids, instance_ids) + positions = np.clip(positions, 0, len(sorted_ids) - 1) + found = sorted_ids[positions] == instance_ids + centroid_rows = order[positions] + + has_coord = np.zeros(table.n_obs, dtype=bool) + coords = np.full((table.n_obs, 2), np.nan) + matched_table_rows = table_subset_indices[found] + has_coord[matched_table_rows] = True + coords[matched_table_rows] = centroid_coords[centroid_rows[found]] + + groups = render_params.groups + group_key = render_params.group_key + if groups is not None and group_key is not None: + in_groups = np.isin(table.obs[group_key].values, groups) + has_coord &= in_groups + + coords[~has_coord] = np.nan + + # Per-edge attribute arrays are built in triu(adj, k=1).nonzero() order so + # the NaN-coord mask below subsets them consistently. + adj_upper = triu(adjacency, k=1) + all_rows, all_cols = adj_upper.nonzero() + + edge_color_arg: Any = "grey" + cmap_for_render = None + norm_for_render = None + cmap_params = render_params.cmap_params + + if render_params.color_source == "obsp": + value_matrix = table.obsp[render_params.obsp_key] + edge_color_arg = value_matrix[all_rows, all_cols].A1 + elif render_params.color_source in {"obs_continuous", "obs_categorical"}: + obs_series = table.obs[render_params.obs_col] + na_hex = render_params.na_color.get_hex_with_alpha() if render_params.na_color is not None else "#00000000" + if obs_series.isna().all(): + logger.warning(f"Column '{render_params.obs_col}' contains only NaN values; rendering edges with na_color.") + edge_color_arg = np.full(len(all_rows), na_hex, dtype=object) + elif render_params.color_source == "obs_continuous": + obs_values = np.asarray(obs_series.values, dtype=float) + edge_color_arg = 0.5 * (obs_values[all_rows] + obs_values[all_cols]) + else: + obs_values = obs_series.values + row_vals = obs_values[all_rows] + col_vals = obs_values[all_cols] + # Pre-fill with na_hex, then look up palette colours only for shared-endpoint edges. + palette_map = render_params.palette_map or {} + same = row_vals == col_vals + per_edge_colors = np.full(len(row_vals), na_hex, dtype=object) + if same.any(): + per_edge_colors[same] = [palette_map.get(v, na_hex) for v in row_vals[same]] + edge_color_arg = per_edge_colors + else: + edge_color_arg = (render_params.color or Color("grey")).get_hex() + + if render_params.color_source in {"obsp", "obs_continuous"} and cmap_params is not None: + cmap_for_render = cmap_params.cmap + norm_for_render = cmap_params.norm + + edge_width_arg: Any = render_params.edge_width + edge_alpha_arg: Any = render_params.edge_alpha + if render_params.edge_width == "weight" or render_params.edge_alpha == "weight": + weight_matrix = table.obsp[render_params.weight_key] + weights = weight_matrix[all_rows, all_cols].A1.astype(float) + if render_params.edge_width == "weight": + edge_width_arg = _normalise_to_range(weights, 0.5, 3.0) + if render_params.edge_alpha == "weight": + edge_alpha_arg = _normalise_to_range(weights, 0.2, 1.0) + + # Drop edges touching nodes without valid coords, and align per-edge arrays. + edge_mask = has_coord[all_rows] & has_coord[all_cols] + rows = all_rows[edge_mask] + cols = all_cols[edge_mask] + + def _maybe_subset(value: Any) -> Any: + if isinstance(value, np.ndarray) and value.ndim == 1 and len(value) == len(edge_mask): + return value[edge_mask] + return value + + edge_color_arg = _maybe_subset(edge_color_arg) + edge_width_arg = _maybe_subset(edge_width_arg) + edge_alpha_arg = _maybe_subset(edge_alpha_arg) + + if len(rows) == 0: + lc = LineCollection([]) + ax.add_collection(lc) + return + + segments = np.stack([coords[rows], coords[cols]], axis=1) + + lc_kwargs: dict[str, Any] = { + "linewidths": edge_width_arg, + "alpha": edge_alpha_arg, + "linestyles": render_params.linestyle, + "zorder": render_params.zorder, + } + + is_numeric_array = ( + isinstance(edge_color_arg, np.ndarray) + and edge_color_arg.ndim == 1 + and np.issubdtype(edge_color_arg.dtype, np.number) + ) + lc = LineCollection(segments, **lc_kwargs) + if is_numeric_array: + lc.set_array(edge_color_arg) + if cmap_for_render is not None: + lc.set_cmap(cmap_for_render) + if norm_for_render is not None: + lc.set_norm(norm_for_render) + else: + lc.set_color(edge_color_arg) + lc.set_rasterized(render_params.rasterize) + ax.add_collection(lc) + + if render_params.include_self_loops: + diag = np.asarray(adjacency.diagonal()).ravel() + sl_rows = np.where(diag != 0)[0] + sl_rows = sl_rows[has_coord[sl_rows]] + if len(sl_rows) > 0: + edge_lengths = np.linalg.norm(segments[:, 1] - segments[:, 0], axis=1) + median_len = float(np.median(edge_lengths)) if len(edge_lengths) else 1.0 + sl_color = edge_color_arg if isinstance(edge_color_arg, str) else "grey" + sl_alpha = edge_alpha_arg if isinstance(edge_alpha_arg, int | float) else 1.0 + cc = CircleCollection( + sizes=[max(median_len * 2.0, 4.0)] * len(sl_rows), + offsets=coords[sl_rows], + transOffset=ax.transData, + facecolors=sl_color, + edgecolors="none", + alpha=sl_alpha, + zorder=render_params.zorder, + ) + cc.set_rasterized(render_params.rasterize) + ax.add_collection(cc) + + is_continuous = render_params.color_source in {"obsp", "obs_continuous"} + should_request = _should_request_colorbar( + render_params.colorbar, + has_mappable=render_params.cmap_params is not None, + is_continuous=is_continuous, + ) + if ( + should_request + and colorbar_requests is not None + and legend_params is not None + and legend_params.colorbar + and render_params.cmap_params is not None + ): + sm = plt.cm.ScalarMappable( + cmap=render_params.cmap_params.cmap, + norm=render_params.cmap_params.norm, + ) + sm.set_array(lc.get_array()) + label = _resolve_colorbar_label( + render_params.colorbar_params, + fallback=render_params.obs_col or render_params.obsp_key, + ) + colorbar_requests.append( + ColorbarSpec( + ax=ax, + mappable=sm, + params=render_params.colorbar_params, + label=label, + ) + ) diff --git a/src/spatialdata_plot/pl/render_params.py b/src/spatialdata_plot/pl/render_params.py index bc69f16a..e7232ec7 100644 --- a/src/spatialdata_plot/pl/render_params.py +++ b/src/spatialdata_plot/pl/render_params.py @@ -309,3 +309,30 @@ class LabelsRenderParams: zorder: int = 0 colorbar: bool | str | None = "auto" colorbar_params: dict[str, object] | None = None + + +@dataclass +class GraphRenderParams: + """Graph render parameters.""" + + element: str + connectivity_obsp_key: str = "spatial_connectivities" + table_name: str | None = None + color: Color | None = None + obs_col: str | None = None + obsp_key: str | None = None + cmap_params: CmapParams | None = None + palette_map: dict[str, str] | None = None + na_color: Color | None = None + color_source: Literal["scalar", "obsp", "obs_categorical", "obs_continuous"] = "scalar" + groups: list[str] | str | None = None + group_key: str | None = None + edge_width: float | Literal["weight"] = 1.0 + edge_alpha: float | Literal["weight"] = 1.0 + weight_key: str | None = None + linestyle: str | Sequence[str] = "solid" + rasterize: bool = True + include_self_loops: bool = False + zorder: int = 0 + colorbar: bool | str | None = "auto" + colorbar_params: dict[str, object] | None = None diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index bcdcb0dd..a675e00e 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -75,6 +75,7 @@ Color, ColorbarSpec, FigParams, + GraphRenderParams, ImageRenderParams, LabelsRenderParams, OutlineParams, @@ -2098,7 +2099,7 @@ def _get_elements_to_be_rendered( render_cmds: list[ tuple[ str, - ImageRenderParams | LabelsRenderParams | PointsRenderParams | ShapesRenderParams, + ImageRenderParams | LabelsRenderParams | PointsRenderParams | ShapesRenderParams | GraphRenderParams, ] ], cs_index: pd.DataFrame, @@ -2125,9 +2126,14 @@ def _get_elements_to_be_rendered( cs_row = cs_index.loc[cs] if cs in cs_index.index else None for cmd, params in render_cmds: - key = _RENDER_CMD_TO_CS_FLAG.get(cmd) - if key and cs_row is not None and cs_row[key]: + if cmd == "render_graph": + # Graph doesn't have its own CS flag; include its element so + # _get_valid_cs keeps the coordinate system alive. elements_to_be_rendered.append(params.element) + else: + key = _RENDER_CMD_TO_CS_FLAG.get(cmd) + if key and cs_row is not None and cs_row[key]: + elements_to_be_rendered.append(params.element) return elements_to_be_rendered @@ -2319,29 +2325,29 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st "shapes", "points", "labels", + "graph", }: if not isinstance(color, str | tuple | list): raise TypeError("Parameter 'color' must be a string or a tuple/list of floats.") - if element_type in {"shapes", "points", "labels"}: - if _is_color_like(color): - logger.info("Value for parameter 'color' appears to be a color, using it as such.") - param_dict["col_for_color"] = None - param_dict["color"] = Color(color) - if param_dict["color"].alpha_is_user_defined(): - if element_type == "points" and param_dict.get("alpha") is None: - param_dict["alpha"] = param_dict["color"].get_alpha_as_float() - elif element_type in {"shapes", "labels"} and param_dict.get("fill_alpha") is None: - param_dict["fill_alpha"] = param_dict["color"].get_alpha_as_float() - else: - logger.info( - f"Alpha implied by color '{color}' is ignored since the parameter 'alpha' or 'fill_alpha' " - "is set and its value takes precedence." - ) - elif isinstance(color, str): - param_dict["col_for_color"] = color - param_dict["color"] = None - else: - raise ValueError(f"{color} is not a valid RGB(A) array and therefore can't be used as 'color' value.") + if _is_color_like(color): + logger.info("Value for parameter 'color' appears to be a color, using it as such.") + param_dict["col_for_color"] = None + param_dict["color"] = Color(color) + if param_dict["color"].alpha_is_user_defined(): + if element_type == "points" and param_dict.get("alpha") is None: + param_dict["alpha"] = param_dict["color"].get_alpha_as_float() + elif element_type in {"shapes", "labels"} and param_dict.get("fill_alpha") is None: + param_dict["fill_alpha"] = param_dict["color"].get_alpha_as_float() + else: + logger.info( + f"Alpha implied by color '{color}' is ignored since the parameter 'alpha' or 'fill_alpha' " + "is set and its value takes precedence." + ) + elif isinstance(color, str): + param_dict["col_for_color"] = color + param_dict["color"] = None + else: + raise ValueError(f"{color} is not a valid RGB(A) array and therefore can't be used as 'color' value.") elif "color" in param_dict and element_type != "images": param_dict["col_for_color"] = None @@ -2467,7 +2473,7 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st elif isinstance(palette, list): if not all(isinstance(p, str) for p in palette): raise ValueError("If specified, parameter 'palette' must contain only strings.") - elif isinstance(palette, str | type(None)) and "palette" in param_dict: + elif isinstance(palette, str | type(None)) and "palette" in param_dict and element_type != "graph": param_dict["palette"] = [palette] if palette is not None else None palette_group = param_dict.get("palette") @@ -2485,7 +2491,7 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st if not all(isinstance(c, Colormap | str) for c in cmap): raise TypeError("Each item in 'cmap' list must be a string or a Colormap.") elif isinstance(cmap, Colormap | str | type(None)): - if "cmap" in param_dict: + if "cmap" in param_dict and element_type != "graph": param_dict["cmap"] = [cmap] if cmap is not None else None else: raise TypeError("Parameter 'cmap' must be a string, a Colormap, or a list of these types.") @@ -2508,6 +2514,8 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st raise TypeError("Parameter 'norm' must be of type Normalize.") if element_type in {"shapes", "points"} and not isinstance(norm, bool | Normalize): raise TypeError("Parameter 'norm' must be a boolean or a mpl.Normalize.") + if element_type == "graph" and not isinstance(norm, Normalize): + raise TypeError("Parameter 'norm' must be a Normalize instance.") scale = param_dict.get("scale") if scale is not None: @@ -2601,6 +2609,35 @@ def _ensure_table_and_layer_exist_in_sdata( if ds_reduction and (ds_reduction not in valid_ds_reduction_methods): raise ValueError(f"Parameter 'ds_reduction' must be one of the following: {valid_ds_reduction_methods}.") + if element_type == "graph": + for key in ("connectivity_key",): + val = param_dict.get(key) + if val is not None and not isinstance(val, str): + raise TypeError(f"Parameter '{key}' must be a string.") + + for key in ("obsp_key", "weight_key", "group_key"): + val = param_dict.get(key) + if val is not None and not isinstance(val, str): + raise TypeError(f"Parameter '{key}' must be a string or None.") + + for key in ("edge_width", "edge_alpha"): + val = param_dict.get(key) + if val == "weight": + continue + if not isinstance(val, float | int): + raise TypeError(f"Parameter '{key}' must be numeric or the literal string 'weight'.") + if val < 0: + raise ValueError(f"Parameter '{key}' cannot be negative.") + + linestyle = param_dict.get("linestyle") + if linestyle is not None and not isinstance(linestyle, str | list | tuple): + raise TypeError("Parameter 'linestyle' must be a string or a sequence of strings.") + + for key in ("include_self_loops", "rasterize"): + val = param_dict.get(key) + if val is not None and not isinstance(val, bool): + raise TypeError(f"Parameter '{key}' must be a boolean.") + return param_dict @@ -2857,6 +2894,237 @@ def _resolve_gene_symbols( return str(adata.var.index[mask][0]) +def _validate_graph_render_params( + sdata: SpatialData, + element: str | None, + connectivity_key: str, + table_name: str | None, + color: ColorLike | None, + edge_width: float | Literal["weight"], + edge_alpha: float | Literal["weight"], + groups: list[str] | str | None, + group_key: str | None, + obsp_key: str | None = None, + weight_key: str | None = None, + palette: dict[str, str] | list[str] | str | None = None, + na_color: ColorLike | None = "default", + cmap: Colormap | str | None = None, + norm: Normalize | None = None, + linestyle: str | Sequence[str] = "solid", + include_self_loops: bool = False, + rasterize: bool = True, +) -> dict[str, Any]: + """Validate and resolve parameters for render_graph.""" + param_dict: dict[str, Any] = { + "sdata": sdata, + "element": element, + "color": color, + "groups": groups, + "palette": palette, + "na_color": na_color, + "cmap": cmap, + "norm": norm if norm is not None else Normalize(clip=False), + "table_name": table_name, + "connectivity_key": connectivity_key, + "obsp_key": obsp_key, + "weight_key": weight_key, + "group_key": group_key, + "edge_width": edge_width, + "edge_alpha": edge_alpha, + "linestyle": linestyle, + "include_self_loops": include_self_loops, + "rasterize": rasterize, + } + param_dict = _type_check_params(param_dict, "graph") + + if param_dict["table_name"] is None: + candidates = [tname for tname in sdata.tables if _resolve_obsp_key(sdata[tname], connectivity_key) is not None] + if len(candidates) == 0: + raise ValueError( + f"No table found with connectivity key '{connectivity_key}' in obsp. " + f"Available tables: {list(sdata.tables.keys())}." + ) + if len(candidates) > 1: + raise ValueError( + f"Multiple tables contain connectivity key '{connectivity_key}': {candidates}. " + "Please specify `table_name` explicitly." + ) + param_dict["table_name"] = candidates[0] + + if param_dict["table_name"] not in sdata.tables: + raise KeyError(f"Table '{param_dict['table_name']}' not found. Available: {list(sdata.tables.keys())}.") + + table = sdata[param_dict["table_name"]] + connectivity_obsp_key = _require_obsp_key(table, connectivity_key, param_name="connectivity_key") + + _, region_key, _ = get_table_keys(table) + if region_key is None: + raise ValueError( + f"Table '{param_dict['table_name']}' has no `region_key`; cannot associate its observations " + "with a spatial element. Re-parse the table with `TableModel.parse(..., region_key=...)`." + ) + + if param_dict["element"] is None: + regions = table.obs[region_key].unique().tolist() + spatial_regions = [r for r in regions if r in sdata.shapes or r in sdata.points or r in sdata.labels] + if len(spatial_regions) == 0: + raise ValueError( + f"Table '{param_dict['table_name']}' does not annotate any spatial element. Region values: {regions}." + ) + if len(spatial_regions) > 1: + raise ValueError( + f"Table '{param_dict['table_name']}' annotates multiple spatial elements: {spatial_regions}. " + "Please specify `element` explicitly." + ) + param_dict["element"] = spatial_regions[0] + elif not ( + param_dict["element"] in sdata.shapes + or param_dict["element"] in sdata.points + or param_dict["element"] in sdata.labels + ): + raise KeyError( + f"Element '{param_dict['element']}' not found in shapes, points, or labels. " + f"Available: shapes={list(sdata.shapes.keys())}, " + f"points={list(sdata.points.keys())}, labels={list(sdata.labels.keys())}." + ) + + # _type_check_params normalised string groups → list; renormalise the working set here. + if param_dict["groups"] is not None and param_dict["group_key"] is None: + raise ValueError("`groups` requires `group_key` to be specified.") + if param_dict["group_key"] is not None and param_dict["group_key"] not in table.obs.columns: + raise KeyError( + f"`group_key='{param_dict['group_key']}'` not found in table obs columns. " + f"Available: {list(table.obs.columns)}." + ) + if param_dict["groups"] is not None and param_dict["group_key"] is not None: + groups_set: set[Any] = set(param_dict["groups"]) + available_groups = set(table.obs[param_dict["group_key"]].dropna().unique()) + missing_groups = groups_set - available_groups + if missing_groups: + try: + missing_str = str(sorted(missing_groups)) + except TypeError: + missing_str = str(list(missing_groups)) + if missing_groups == groups_set: + logger.warning( + f"None of the requested groups {missing_str} were found in column " + f"'{param_dict['group_key']}'. Resulting plot will contain no edges." + ) + else: + logger.warning( + f"Groups {missing_str} not found in column '{param_dict['group_key']}' and will be ignored." + ) + + # After _type_check_params: col_for_color is the non-color string user passed via `color=`; + # color is either a Color (user gave a real color) or None (user gave a column name or nothing). + col_for_color = param_dict.get("col_for_color") + if col_for_color is not None and col_for_color not in table.obs.columns: + raise ValueError( + f"`color='{col_for_color}'` is not a matplotlib color and was not found in " + f"`table.obs` columns. Available obs columns: {list(table.obs.columns)}." + ) + + color_is_obs_col = col_for_color is not None + if obsp_key is not None and color_is_obs_col: + raise ValueError( + "Cannot set both `color` (as an obs column) and `obsp_key` for edge coloring. " + "Pick one source: scalar color, obs-column color, or obsp-matrix color." + ) + if obsp_key is not None and param_dict["color"] is not None: + raise ValueError( + "Cannot set both `color` and `obsp_key` for edge coloring. " + "Use `obsp_key` for matrix-driven coloring with `cmap`/`norm`, " + "or `color` for a scalar / obs-column-driven coloring." + ) + + color_obsp_key: str | None = None + obs_col: str | None = None + color_source: str = "scalar" + cmap_params: CmapParams | None = None + palette_map: dict[str, str] | None = None + + if obsp_key is not None: + color_obsp_key = _require_obsp_key(table, obsp_key, param_name="obsp_key") + color_source = "obsp" + cmap_params = _prepare_cmap_norm(cmap=cmap, norm=param_dict["norm"]) + elif color_is_obs_col: + obs_col = col_for_color + obs_values = table.obs[obs_col] + if isinstance(obs_values.dtype, pd.CategoricalDtype) or obs_values.dtype == object: + color_source = "obs_categorical" + categories = ( + obs_values.cat.categories.tolist() + if isinstance(obs_values.dtype, pd.CategoricalDtype) + else sorted(obs_values.dropna().unique().tolist()) + ) + if isinstance(palette, dict): + missing = [c for c in categories if c not in palette] + if missing: + raise KeyError( + f"Palette dict is missing entries for categories: {missing}. " + f"Available categories: {categories}." + ) + palette_map = {c: palette[c] for c in categories} + else: + cat_colors = _get_colors_for_categorical_obs(categories=categories, palette=palette) + palette_map = dict(zip(categories, cat_colors, strict=True)) + else: + color_source = "obs_continuous" + cmap_params = _prepare_cmap_norm(cmap=cmap, norm=param_dict["norm"]) + + # When edge_width/edge_alpha="weight" but weight_key isn't given, fall back to the + # connectivity matrix so binary graphs still produce a per-edge array. + resolved_weight_key: str | None = None + if edge_width == "weight" or edge_alpha == "weight": + resolved_weight_key = _require_obsp_key( + table, weight_key if weight_key is not None else connectivity_key, param_name="weight_key" + ) + + edge_color = param_dict["color"] if param_dict["color"] is not None else Color("grey") + parsed_na_color = param_dict["na_color"] + + return { + "element": param_dict["element"], + "connectivity_key": connectivity_key, + "connectivity_obsp_key": connectivity_obsp_key, + "obsp_key": color_obsp_key, + "obs_col": obs_col, + "cmap_params": cmap_params, + "palette_map": palette_map, + "na_color": parsed_na_color, + "color_source": color_source, + "table_name": param_dict["table_name"], + "weight_key": resolved_weight_key, + "color": edge_color, + "edge_width": edge_width, + "edge_alpha": edge_alpha, + "groups": param_dict["groups"], + "group_key": param_dict["group_key"], + } + + +def _resolve_obsp_key(table: AnnData, connectivity_key: str) -> str | None: + """Resolve connectivity_key to an actual obsp key. Accepts full key or prefix.""" + if connectivity_key in table.obsp: + return connectivity_key + suffixed = f"{connectivity_key}_connectivities" + if suffixed in table.obsp: + return suffixed + return None + + +def _require_obsp_key(table: AnnData, key: str, *, param_name: str) -> str: + """Resolve key (with prefix fallback) or raise KeyError.""" + resolved = _resolve_obsp_key(table, key) + if resolved is None: + raise KeyError( + f"`{param_name}='{key}'` not found in `table.obsp`. " + f"Tried '{key}' and '{key}_connectivities'. " + f"Available obsp keys: {list(table.obsp.keys())}." + ) + return resolved + + def _validate_col_for_column_table( sdata: SpatialData, element_name: str, diff --git a/tests/_images/Graph_can_render_graph_on_labels.png b/tests/_images/Graph_can_render_graph_on_labels.png new file mode 100644 index 00000000..7a60beea Binary files /dev/null and b/tests/_images/Graph_can_render_graph_on_labels.png differ diff --git a/tests/_images/Graph_can_render_graph_on_shapes.png b/tests/_images/Graph_can_render_graph_on_shapes.png new file mode 100644 index 00000000..001ddbab Binary files /dev/null and b/tests/_images/Graph_can_render_graph_on_shapes.png differ diff --git a/tests/_images/Graph_can_render_graph_with_groups_filter.png b/tests/_images/Graph_can_render_graph_with_groups_filter.png new file mode 100644 index 00000000..f7aa55bd Binary files /dev/null and b/tests/_images/Graph_can_render_graph_with_groups_filter.png differ diff --git a/tests/pl/test_render_graph.py b/tests/pl/test_render_graph.py new file mode 100644 index 00000000..96dca213 --- /dev/null +++ b/tests/pl/test_render_graph.py @@ -0,0 +1,258 @@ +import geopandas as gpd +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import pytest +import scanpy as sc +import spatialdata as sd +from anndata import AnnData +from matplotlib.collections import CircleCollection, LineCollection +from scipy.sparse import csr_matrix, lil_matrix, triu +from scipy.spatial import KDTree +from shapely.geometry import Point +from spatialdata import SpatialData +from spatialdata.datasets import blobs +from spatialdata.models import ShapesModel, TableModel + +import spatialdata_plot # noqa: F401 +from spatialdata_plot._logging import logger, logger_warns +from tests.conftest import DPI, PlotTester, PlotTesterMeta, get_standard_RNG + +sc.pl.set_rcParams_defaults() +sc.set_figure_params(dpi=DPI, color_map="viridis") +matplotlib.use("agg") +_ = spatialdata_plot + + +def _knn_adjacency(coords: np.ndarray, k: int = 3) -> csr_matrix: + n = len(coords) + adj = lil_matrix((n, n)) + tree = KDTree(coords) + for i in range(n): + _, neighbors = tree.query(coords[i], k=min(k + 1, n)) + for j in neighbors[1:]: + adj[i, j] = adj[j, i] = 1.0 + return adj.tocsr() + + +def _sdata_with_graph_on_shapes() -> SpatialData: + rng = get_standard_RNG() + n = 20 + coords = rng.uniform(10, 90, size=(n, 2)) + shapes = ShapesModel.parse( + gpd.GeoDataFrame(geometry=[Point(x, y) for x, y in coords], data={"radius": np.ones(n) * 2.5}) + ) + adata = AnnData(rng.normal(size=(n, 5))) + adata.obs["instance_id"] = np.arange(n) + adata.obs["region"] = "my_shapes" + adata.obs["cell_type"] = pd.Categorical(rng.choice(["tumor", "immune", "stroma"], size=n)) + adata.obsp["spatial_connectivities"] = _knn_adjacency(coords, k=3) + table = TableModel.parse(adata, region="my_shapes", region_key="region", instance_key="instance_id") + return SpatialData(shapes={"my_shapes": shapes}, tables={"table": table}) + + +def _sdata_with_graph_on_labels() -> SpatialData: + sdata = blobs() + table = sdata["table"] + centroids = sd.get_centroids(sdata["blobs_labels"]).compute() + coords = centroids.loc[table.obs["instance_id"].values, ["x", "y"]].to_numpy() + table.obsp["spatial_connectivities"] = _knn_adjacency(coords, k=3) + return sdata + + +def _sdata_with_weighted_graph() -> SpatialData: + sdata = _sdata_with_graph_on_shapes() + adj = sdata["table"].obsp["spatial_connectivities"].copy().astype(float).tolil() + rng = get_standard_RNG() + for r, c in zip(*adj.nonzero(), strict=True): + adj[r, c] = float(rng.uniform(0.1, 5.0)) + sdata["table"].obsp["spatial_distances"] = adj.tocsr() + return sdata + + +class TestGraph(PlotTester, metaclass=PlotTesterMeta): + def test_plot_can_render_graph_on_shapes(self): + sdata = _sdata_with_graph_on_shapes() + sdata.pl.render_graph("my_shapes", table_name="table").pl.render_shapes("my_shapes").pl.show() + + def test_plot_can_render_graph_on_labels(self): + sdata = _sdata_with_graph_on_labels() + ( + sdata.pl.render_images("blobs_image") + .pl.render_graph("blobs_labels", table_name="table", edge_alpha=0.5) + .pl.render_labels("blobs_labels") + .pl.show() + ) + + def test_plot_can_render_graph_with_groups_filter(self): + sdata = _sdata_with_graph_on_shapes() + ( + sdata.pl.render_graph("my_shapes", table_name="table", group_key="cell_type", groups=["tumor"]) + .pl.render_shapes("my_shapes", color="cell_type") + .pl.show() + ) + + +def test_render_graph_empty_graph_does_not_error(): + sdata = _sdata_with_graph_on_shapes() + sdata["table"].obsp["spatial_connectivities"] = csr_matrix((20, 20)) + sdata.pl.render_graph("my_shapes", table_name="table").pl.render_shapes("my_shapes").pl.show() + + +def test_render_graph_auto_discovers_element_and_table(): + sdata = _sdata_with_graph_on_shapes() + step_key, params = next(iter(sdata.pl.render_graph().plotting_tree.items())) + assert step_key.endswith("_render_graph") + assert params.element == "my_shapes" and params.table_name == "table" + + +@pytest.mark.parametrize( + ("kwargs", "match"), + [ + ({"connectivity_key": "nonexistent"}, "not found in `table.obsp`"), + ({"element": "no_such_element"}, "not found in shapes, points, or labels"), + ({"groups": ["tumor"]}, "`groups` requires `group_key`"), + ({"color": None, "obsp_key": "no_such"}, "`obsp_key='no_such'` not found"), + ({"edge_width": "weight", "weight_key": "no_such"}, "`weight_key='no_such'` not found"), + ], +) +def test_render_graph_error_paths(kwargs, match): + sdata = _sdata_with_graph_on_shapes() + element = kwargs.pop("element", "my_shapes") + with pytest.raises((KeyError, ValueError), match=match): + sdata.pl.render_graph(element, table_name="table", **kwargs) + + +def test_render_graph_rejects_color_and_obsp_key_together(): + sdata = _sdata_with_weighted_graph() + with pytest.raises(ValueError, match="Cannot set both `color` and `obsp_key`"): + sdata.pl.render_graph("my_shapes", table_name="table", color="red", obsp_key="spatial_distances") + + +def test_render_graph_warns_on_groups_not_in_column(caplog): + sdata = _sdata_with_graph_on_shapes() + with logger_warns(caplog, logger, match="not_a_real_group"): + sdata.pl.render_graph("my_shapes", table_name="table", group_key="cell_type", groups=["not_a_real_group"]) + + +def test_render_graph_raises_on_table_without_region_key(): + sdata = _sdata_with_graph_on_shapes() + sdata["table"].uns["spatialdata_attrs"]["region_key"] = None + with pytest.raises(ValueError, match="has no `region_key`"): + sdata.pl.render_graph("my_shapes", table_name="table") + + +def test_render_graph_obsp_key_populates_edge_values_from_matrix(): + """Edge color array must equal the obsp matrix entries for the rendered edges.""" + sdata = _sdata_with_weighted_graph() + fig, ax = plt.subplots() + ( + sdata.pl.render_shapes("my_shapes") + .pl.render_graph("my_shapes", table_name="table", color=None, obsp_key="spatial_distances", cmap="viridis") + .pl.show(ax=ax) + ) + lc = next(c for c in ax.collections if isinstance(c, LineCollection)) + distances = sdata["table"].obsp["spatial_distances"] + rows, cols = triu(distances, k=1).nonzero() + np.testing.assert_allclose(np.asarray(lc.get_array()), distances[rows, cols].A1) + plt.close(fig) + + +def test_render_graph_color_by_obs_categorical_with_palette_dict(): + """Same-category edges get the palette colour; cross-category edges get na_color.""" + sdata = _sdata_with_graph_on_shapes() + palette = {"tumor": "#ff0000", "immune": "#00ff00", "stroma": "#0000ff"} + fig, ax = plt.subplots() + ( + sdata.pl.render_shapes("my_shapes") + .pl.render_graph("my_shapes", table_name="table", color="cell_type", palette=palette, na_color="#888888") + .pl.show(ax=ax) + ) + lc = next(c for c in ax.collections if isinstance(c, LineCollection)) + allowed = {tuple(matplotlib.colors.to_rgba(v)) for v in palette.values()} + allowed.add(tuple(matplotlib.colors.to_rgba("#888888"))) + for c in lc.get_colors(): + assert tuple(c) in allowed + plt.close(fig) + + +def test_render_graph_color_by_obs_continuous_uses_endpoint_mean(): + sdata = _sdata_with_graph_on_shapes() + rng = get_standard_RNG() + scores = rng.uniform(0.0, 1.0, size=sdata["table"].n_obs) + sdata["table"].obs["score"] = scores + fig, ax = plt.subplots() + ( + sdata.pl.render_shapes("my_shapes") + .pl.render_graph("my_shapes", table_name="table", color="score", cmap="magma") + .pl.show(ax=ax) + ) + lc = next(c for c in ax.collections if isinstance(c, LineCollection)) + rows, cols = triu(sdata["table"].obsp["spatial_connectivities"], k=1).nonzero() + np.testing.assert_allclose(np.asarray(lc.get_array()), 0.5 * (scores[rows] + scores[cols])) + plt.close(fig) + + +def test_render_graph_draws_colorbar_for_continuous_coloring(): + sdata = _sdata_with_weighted_graph() + fig, ax = plt.subplots() + ( + sdata.pl.render_shapes("my_shapes") + .pl.render_graph("my_shapes", table_name="table", color=None, obsp_key="spatial_distances", cmap="viridis") + .pl.show(ax=ax) + ) + cbars = [c for c in fig.get_children() if isinstance(c, matplotlib.axes.Axes) and c is not ax] + assert cbars + plt.close(fig) + + +def test_render_graph_colorbar_can_be_disabled(): + sdata = _sdata_with_weighted_graph() + fig, ax = plt.subplots() + ( + sdata.pl.render_shapes("my_shapes") + .pl.render_graph("my_shapes", table_name="table", color=None, obsp_key="spatial_distances", colorbar=False) + .pl.show(ax=ax) + ) + cbars = [ + c + for c in fig.get_children() + if isinstance(c, matplotlib.axes.Axes) and c is not ax and c.get_ylim() != (0.0, 1.0) + ] + assert not cbars + plt.close(fig) + + +def test_render_graph_edge_width_by_weight_produces_normalised_array(): + sdata = _sdata_with_weighted_graph() + fig, ax = plt.subplots() + ( + sdata.pl.render_shapes("my_shapes") + .pl.render_graph("my_shapes", table_name="table", edge_width="weight", weight_key="spatial_distances") + .pl.show(ax=ax) + ) + lc = next(c for c in ax.collections if isinstance(c, LineCollection)) + widths = lc.get_linewidths() + assert len(widths) > 1 + assert 0.5 - 1e-6 <= float(np.min(widths)) < float(np.max(widths)) <= 3.0 + 1e-6 + plt.close(fig) + + +@pytest.mark.parametrize("include_self_loops", [True, False]) +def test_render_graph_include_self_loops(include_self_loops): + sdata = _sdata_with_graph_on_shapes() + adj = sdata["table"].obsp["spatial_connectivities"].tolil() + for i in range(sdata["table"].n_obs): + adj[i, i] = 1.0 + sdata["table"].obsp["spatial_connectivities"] = adj.tocsr() + + fig, ax = plt.subplots() + ( + sdata.pl.render_shapes("my_shapes") + .pl.render_graph("my_shapes", table_name="table", include_self_loops=include_self_loops) + .pl.show(ax=ax) + ) + has_circles = any(isinstance(c, CircleCollection) for c in ax.collections) + assert has_circles is include_self_loops + plt.close(fig)