diff --git a/src/spatialdata_plot/pl/basic.py b/src/spatialdata_plot/pl/basic.py index 9448dbb8..df7eebe7 100644 --- a/src/spatialdata_plot/pl/basic.py +++ b/src/spatialdata_plot/pl/basic.py @@ -25,12 +25,14 @@ from mpl_toolkits.axes_grid1.inset_locator import inset_axes from spatialdata import get_extent from spatialdata._utils import _deprecation_alias +from spatialdata.transformations.operations import get_transformation from xarray import DataArray, DataTree from spatialdata_plot._accessor import register_spatial_data_accessor from spatialdata_plot._logging import _log_context, logger from spatialdata_plot.pl.render import ( _draw_channel_legend, + _render_graph, _render_images, _render_labels, _render_points, @@ -44,6 +46,7 @@ ChannelLegendEntry, CmapParams, ColorbarSpec, + GraphRenderParams, ImageRenderParams, LabelsRenderParams, LegendParams, @@ -63,6 +66,7 @@ _prepare_cmap_norm, _prepare_params_plot, _set_outline, + _validate_graph_render_params, _validate_image_render_params, _validate_label_render_params, _validate_points_render_params, @@ -861,6 +865,79 @@ def render_labels( n_steps += 1 return sdata + def render_graph( + self, + element: str | None = None, + color: ColorLike | None = "grey", + *, + connectivity_key: str = "spatial", + groups: list[str] | str | None = None, + group_key: str | None = None, + edge_width: float = 1.0, + edge_alpha: float = 1.0, + table_name: str | None = None, + ) -> sd.SpatialData: + """Render spatial graph edges between observations. + + Draws edges from a connectivity matrix stored in a table's ``obsp``, + using centroid coordinates of the linked spatial element. + + Parameters + ---------- + element : str | None, optional + Name of the spatial element (shapes, points, or labels) whose + observations the graph connects. Auto-resolved from the table + if not given. + color : ColorLike | None, default "grey" + Edge color as a color-like value (e.g. ``"red"``, ``"#aabbcc"``). + connectivity_key : str, default "spatial" + Key prefix in ``table.obsp``. Tries ``obsp[key]`` first, then + ``obsp[f"{key}_connectivities"]``. + groups : list[str] | str | None, optional + Show only edges where **both** endpoints belong to the specified + groups. Requires ``group_key``. + group_key : str | None, optional + Column in ``table.obs`` used for group filtering. + edge_width : float, default 1.0 + Line width for edges. + edge_alpha : float, default 1.0 + Transparency for edges (0 = invisible, 1 = opaque). + table_name : str | None, optional + Table containing the graph. Auto-discovered if not given. + + Returns + ------- + sd.SpatialData + Copy with rendering parameters stored in the plotting tree. + """ + params = _validate_graph_render_params( + self._sdata, + element=element, + connectivity_key=connectivity_key, + table_name=table_name, + color=color, + edge_width=edge_width, + edge_alpha=edge_alpha, + groups=groups, + group_key=group_key, + ) + + sdata = self._copy() + sdata = _verify_plotting_tree(sdata) + n_steps = len(sdata.plotting_tree.keys()) + sdata.plotting_tree[f"{n_steps + 1}_render_graph"] = GraphRenderParams( + element=params["element"], + connectivity_key=params["obsp_key"], + table_name=params["table_name"], + color=params["color"], + groups=params["groups"], + group_key=params["group_key"], + edge_width=params["edge_width"], + edge_alpha=params["edge_alpha"], + zorder=n_steps, + ) + return sdata + def show( self, coordinate_systems: list[str] | str | None = None, @@ -1001,6 +1078,7 @@ def show( "render_shapes", "render_labels", "render_points", + "render_graph", ] # prepare rendering params @@ -1311,6 +1389,19 @@ def _draw_colorbar( rasterize=rasterize, ) + elif cmd == "render_graph": + graph_element = params_copy.element + element_in_cs = graph_element in sdata and cs in set( + get_transformation(sdata[graph_element], get_all=True).keys() + ) + if element_in_cs: + _render_graph( + sdata=sdata, + render_params=params_copy, + coordinate_system=cs, + ax=ax, + ) + if title is None: t = cs elif len(title) == 1: diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index 892dbf6a..7f878cce 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -49,6 +49,7 @@ Color, ColorbarSpec, FigParams, + GraphRenderParams, ImageRenderParams, LabelsRenderParams, LegendParams, @@ -1815,3 +1816,97 @@ def _draw_labels( scalebar_units=scalebar_params.scalebar_units, # scalebar_kwargs=scalebar_params.scalebar_kwargs, ) + + +def _render_graph( + sdata: sd.SpatialData, + render_params: GraphRenderParams, + coordinate_system: str, + ax: matplotlib.axes.SubplotBase, +) -> None: + """Render spatial graph edges as a LineCollection on the given axes.""" + from matplotlib.collections import LineCollection + from scipy.sparse import triu + + _log_context.set("render_graph") + element_name = render_params.element + table_name = render_params.table_name + + # Get table and adjacency matrix + table = sdata[table_name] + obsp_key = render_params.connectivity_key + if obsp_key not in table.obsp: + logger.warning(f"Connectivity key '{obsp_key}' not found in table obsp. Skipping graph rendering.") + return + + adjacency = table.obsp[obsp_key] + + # Get the spatial element + if element_name in sdata.shapes: + element = sdata.shapes[element_name] + elif element_name in sdata.points: + element = sdata.points[element_name] + elif element_name in sdata.labels: + element = sdata.labels[element_name] + else: + logger.warning(f"Element '{element_name}' not found in sdata. Skipping graph rendering.") + return + + # Get centroids in the target coordinate system + centroids_df = sd.get_centroids(element, coordinate_system=coordinate_system) + if hasattr(centroids_df, "compute"): + centroids_df = centroids_df.compute() + + centroid_coords = np.column_stack([centroids_df["x"].values, centroids_df["y"].values]) + + # Align table observations to centroid positions via instance_key. + # Build a coordinate array indexed by full-table row so edge lookups are O(1). + _, region_key, instance_key = get_table_keys(table) + + element_mask = table.obs[region_key] == element_name if region_key is not None else np.ones(table.n_obs, dtype=bool) + instance_ids = table.obs[instance_key].values[element_mask] + table_subset_indices = np.where(element_mask)[0] + + centroid_ids = centroids_df.index.values if hasattr(centroids_df, "index") else np.arange(len(centroids_df)) + id_to_centroid_row = {cid: row for row, cid in enumerate(centroid_ids)} + + # has_coord[i] is True if table row i has a valid centroid + has_coord = np.zeros(table.n_obs, dtype=bool) + coord_lookup = np.full((table.n_obs, 2), np.nan) + for table_row, iid in zip(table_subset_indices, instance_ids, strict=True): + if iid in id_to_centroid_row: + has_coord[table_row] = True + coord_lookup[table_row] = centroid_coords[id_to_centroid_row[iid]] + + # Apply group filtering: narrow has_coord to only rows in requested groups + groups = render_params.groups + group_key = render_params.group_key + if groups is not None and group_key is not None: + group_values = table.obs[group_key].values + in_groups = np.isin(group_values, groups) + has_coord &= in_groups + + # Extract edges from upper triangle (undirected — draw each edge once, skip self-loops) + adj_upper = triu(adjacency, k=1) + rows, cols = adj_upper.nonzero() + + # Vectorized filter: keep edges where both endpoints are valid + edge_mask = has_coord[rows] & has_coord[cols] + if not edge_mask.any(): + return + + valid_rows = rows[edge_mask] + valid_cols = cols[edge_mask] + segments = np.stack([coord_lookup[valid_rows], coord_lookup[valid_cols]], axis=1) + + edge_color = render_params.color.get_hex() if render_params.color is not None else "#808080" + + lc = LineCollection( + segments, + linewidths=render_params.edge_width, + colors=edge_color, + alpha=render_params.edge_alpha, + zorder=render_params.zorder, + ) + lc.set_rasterized(True) + ax.add_collection(lc) diff --git a/src/spatialdata_plot/pl/render_params.py b/src/spatialdata_plot/pl/render_params.py index 16f81578..444c8d99 100644 --- a/src/spatialdata_plot/pl/render_params.py +++ b/src/spatialdata_plot/pl/render_params.py @@ -307,3 +307,18 @@ class LabelsRenderParams: zorder: int = 0 colorbar: bool | str | None = "auto" colorbar_params: dict[str, object] | None = None + + +@dataclass +class GraphRenderParams: + """Graph render parameters.""" + + element: str + connectivity_key: str = "spatial" + table_name: str | None = None + color: Color | None = None + groups: list[str] | str | None = None + group_key: str | None = None + edge_width: float = 1.0 + edge_alpha: float = 1.0 + zorder: int = 0 diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index 4c7a48d7..7d2ea343 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -76,6 +76,7 @@ Color, ColorbarSpec, FigParams, + GraphRenderParams, ImageRenderParams, LabelsRenderParams, OutlineParams, @@ -2103,7 +2104,7 @@ def _get_elements_to_be_rendered( render_cmds: list[ tuple[ str, - ImageRenderParams | LabelsRenderParams | PointsRenderParams | ShapesRenderParams, + ImageRenderParams | LabelsRenderParams | PointsRenderParams | ShapesRenderParams | GraphRenderParams, ] ], cs_contents: pd.DataFrame, @@ -2130,9 +2131,14 @@ def _get_elements_to_be_rendered( cs_query = cs_contents.query(f"cs == '{cs}'") for cmd, params in render_cmds: - key = _RENDER_CMD_TO_CS_FLAG.get(cmd) - if key and cs_query[key][0]: + if cmd == "render_graph": + # Graph doesn't have its own CS flag; include its element so + # _get_valid_cs keeps the coordinate system alive. elements_to_be_rendered += [params.element] + else: + key = _RENDER_CMD_TO_CS_FLAG.get(cmd) + if key and cs_query[key][0]: + elements_to_be_rendered += [params.element] return elements_to_be_rendered @@ -2829,6 +2835,105 @@ def _resolve_gene_symbols( return str(adata.var.index[mask][0]) +def _validate_graph_render_params( + sdata: SpatialData, + element: str | None, + connectivity_key: str, + table_name: str | None, + color: ColorLike | None, + edge_width: float, + edge_alpha: float, + groups: list[str] | str | None, + group_key: str | None, +) -> dict[str, Any]: + """Validate and resolve parameters for render_graph.""" + # Resolve table_name: find a table with the connectivity key + if table_name is None: + candidates = [] + for tname in sdata.tables: + t = sdata[tname] + obsp_key = _resolve_obsp_key(t, connectivity_key) + if obsp_key is not None: + candidates.append(tname) + if len(candidates) == 0: + raise ValueError( + f"No table found with connectivity key '{connectivity_key}' in obsp. " + f"Available tables: {list(sdata.tables.keys())}." + ) + if len(candidates) > 1: + raise ValueError( + f"Multiple tables contain connectivity key '{connectivity_key}': {candidates}. " + "Please specify `table_name` explicitly." + ) + table_name = candidates[0] + + if table_name not in sdata.tables: + raise KeyError(f"Table '{table_name}' not found. Available: {list(sdata.tables.keys())}.") + + table = sdata[table_name] + obsp_key = _resolve_obsp_key(table, connectivity_key) + if obsp_key is None: + raise KeyError( + f"Connectivity key '{connectivity_key}' not found in `table.obsp`. " + f"Tried '{connectivity_key}' and '{connectivity_key}_connectivities'. " + f"Available obsp keys: {list(table.obsp.keys())}." + ) + + # Resolve element: find the spatial element this table annotates + if element is None: + _, region_key, _ = get_table_keys(table) + regions = table.obs[region_key].unique().tolist() if region_key else [] + spatial_regions = [r for r in regions if r in sdata.shapes or r in sdata.points or r in sdata.labels] + if len(spatial_regions) == 0: + raise ValueError(f"Table '{table_name}' does not annotate any spatial element. Region values: {regions}.") + if len(spatial_regions) > 1: + raise ValueError( + f"Table '{table_name}' annotates multiple spatial elements: {spatial_regions}. " + "Please specify `element` explicitly." + ) + element = spatial_regions[0] + else: + if not (element in sdata.shapes or element in sdata.points or element in sdata.labels): + raise KeyError( + f"Element '{element}' not found in shapes, points, or labels. " + f"Available: shapes={list(sdata.shapes.keys())}, " + f"points={list(sdata.points.keys())}, labels={list(sdata.labels.keys())}." + ) + + # Validate groups/group_key + if groups is not None and group_key is None: + raise ValueError("`groups` requires `group_key` to be specified.") + if group_key is not None and group_key not in table.obs.columns: + raise KeyError( + f"`group_key='{group_key}'` not found in table obs columns. Available: {list(table.obs.columns)}." + ) + + # Parse color + edge_color = Color(color) if color is not None else Color("grey") + + return { + "element": element, + "connectivity_key": connectivity_key, + "obsp_key": obsp_key, + "table_name": table_name, + "color": edge_color, + "edge_width": edge_width, + "edge_alpha": edge_alpha, + "groups": [groups] if isinstance(groups, str) else groups, + "group_key": group_key, + } + + +def _resolve_obsp_key(table: AnnData, connectivity_key: str) -> str | None: + """Resolve connectivity_key to an actual obsp key. Accepts full key or prefix.""" + if connectivity_key in table.obsp: + return connectivity_key + suffixed = f"{connectivity_key}_connectivities" + if suffixed in table.obsp: + return suffixed + return None + + def _validate_col_for_column_table( sdata: SpatialData, element_name: str, diff --git a/tests/_images/Graph_can_render_graph_on_labels.png b/tests/_images/Graph_can_render_graph_on_labels.png new file mode 100644 index 00000000..4d9c8cf2 Binary files /dev/null and b/tests/_images/Graph_can_render_graph_on_labels.png differ diff --git a/tests/_images/Graph_can_render_graph_on_shapes.png b/tests/_images/Graph_can_render_graph_on_shapes.png new file mode 100644 index 00000000..001ddbab Binary files /dev/null and b/tests/_images/Graph_can_render_graph_on_shapes.png differ diff --git a/tests/_images/Graph_can_render_graph_with_auto_discovery.png b/tests/_images/Graph_can_render_graph_with_auto_discovery.png new file mode 100644 index 00000000..001ddbab Binary files /dev/null and b/tests/_images/Graph_can_render_graph_with_auto_discovery.png differ diff --git a/tests/_images/Graph_can_render_graph_with_groups_filter.png b/tests/_images/Graph_can_render_graph_with_groups_filter.png new file mode 100644 index 00000000..f7aa55bd Binary files /dev/null and b/tests/_images/Graph_can_render_graph_with_groups_filter.png differ diff --git a/tests/pl/test_render_graph.py b/tests/pl/test_render_graph.py new file mode 100644 index 00000000..a325c5eb --- /dev/null +++ b/tests/pl/test_render_graph.py @@ -0,0 +1,177 @@ +import geopandas as gpd +import matplotlib +import numpy as np +import pandas as pd +import pytest +import scanpy as sc +import spatialdata as sd +from anndata import AnnData +from scipy.sparse import csr_matrix, lil_matrix +from scipy.spatial import KDTree +from shapely.geometry import Point +from spatialdata import SpatialData +from spatialdata.datasets import blobs +from spatialdata.models import ShapesModel, TableModel + +import spatialdata_plot # noqa: F401 +from tests.conftest import DPI, PlotTester, PlotTesterMeta, get_standard_RNG + +sc.pl.set_rcParams_defaults() +sc.set_figure_params(dpi=DPI, color_map="viridis") +matplotlib.use("agg") +_ = spatialdata_plot + + +def _make_sdata_with_graph_on_shapes() -> SpatialData: + """Create SpatialData with shapes, an annotating table, and a spatial connectivity graph in obsp.""" + rng = get_standard_RNG() + n = 20 + + # Shapes at reproducible positions + coords = rng.uniform(10, 90, size=(n, 2)) + gdf = gpd.GeoDataFrame( + geometry=[Point(x, y) for x, y in coords], + data={"radius": np.ones(n) * 2.5}, + ) + shapes = ShapesModel.parse(gdf) + + # Table annotating the shapes + adata = AnnData(rng.normal(size=(n, 5))) + adata.obs["instance_id"] = np.arange(n) + adata.obs["region"] = "my_shapes" + adata.obs["cell_type"] = pd.Categorical(rng.choice(["tumor", "immune", "stroma"], size=n)) + + # Build KNN spatial graph (k=3 neighbors) + tree = KDTree(coords) + adj = lil_matrix((n, n)) + for i in range(n): + _, indices = tree.query(coords[i], k=4) # self + 3 neighbors + for j in indices[1:]: + adj[i, j] = 1.0 + adj[j, i] = 1.0 + + adata.obsp["spatial_connectivities"] = adj.tocsr() + + table = TableModel.parse(adata, region="my_shapes", region_key="region", instance_key="instance_id") + return SpatialData(shapes={"my_shapes": shapes}, tables={"table": table}) + + +def _make_sdata_with_graph_on_labels() -> SpatialData: + """Create SpatialData based on blobs with a spatial graph connecting label regions.""" + blob = blobs() + table = blob["table"] + n = table.n_obs + + # Compute label centroids to build a spatially meaningful graph + centroids_df = sd.get_centroids(blob["blobs_labels"]).compute() + instance_ids = table.obs["instance_id"].values.astype(int) + + # Align centroids to table instance order + # centroids_df index corresponds to label IDs (excluding background 0) + centroid_coords = np.column_stack([centroids_df["x"].values, centroids_df["y"].values]) + + # Map table instance_ids to centroid positions + # centroids_df is indexed 0..len-1, label IDs are in the index implicitly + # We need to match table's instance_ids to the centroid rows + if hasattr(centroids_df.index, "values"): + label_ids_in_centroids = centroids_df.index.values + else: + label_ids_in_centroids = np.arange(len(centroids_df)) + + # Build lookup: label_id -> row index in centroids + id_to_row = {lid: row for row, lid in enumerate(label_ids_in_centroids)} + + # Only include table obs that have centroids + valid_mask = np.array([iid in id_to_row for iid in instance_ids]) + valid_indices = np.where(valid_mask)[0] + valid_coords = np.array([centroid_coords[id_to_row[instance_ids[i]]] for i in valid_indices]) + + # Build KNN graph over valid observations + adj = lil_matrix((n, n)) + if len(valid_coords) > 1: + tree = KDTree(valid_coords) + k = min(4, len(valid_coords)) + for idx_in_valid, i in enumerate(valid_indices): + _, neighbors = tree.query(valid_coords[idx_in_valid], k=k) + for nb in neighbors[1:]: + j = valid_indices[nb] + adj[i, j] = 1.0 + adj[j, i] = 1.0 + + table.obsp["spatial_connectivities"] = adj.tocsr() + + rng = get_standard_RNG() + table.obs["cell_type"] = pd.Categorical(rng.choice(["tumor", "immune", "stroma"], size=n)) + + return blob + + +class TestGraph(PlotTester, metaclass=PlotTesterMeta): + def test_plot_can_render_graph_on_shapes(self): + """Basic graph rendering: edges overlaid on shapes.""" + sdata = _make_sdata_with_graph_on_shapes() + ( + sdata.pl.render_graph( + "my_shapes", + connectivity_key="spatial", + table_name="table", + ) + .pl.render_shapes("my_shapes") + .pl.show() + ) + + def test_plot_can_render_graph_on_labels(self): + """Graph overlay on label segmentation with background image — most common real-world use case.""" + sdata = _make_sdata_with_graph_on_labels() + ( + sdata.pl.render_images("blobs_image") + .pl.render_graph( + "blobs_labels", + connectivity_key="spatial", + table_name="table", + edge_alpha=0.5, + ) + .pl.render_labels("blobs_labels") + .pl.show() + ) + + def test_plot_can_render_graph_with_groups_filter(self): + """Graph filtered to show only edges between 'tumor' cells.""" + sdata = _make_sdata_with_graph_on_shapes() + ( + sdata.pl.render_graph( + "my_shapes", + connectivity_key="spatial", + table_name="table", + group_key="cell_type", + groups=["tumor"], + ) + .pl.render_shapes("my_shapes", color="cell_type") + .pl.show() + ) + + def test_plot_can_render_graph_with_auto_discovery(self): + """Auto-discover element and table when unambiguous.""" + sdata = _make_sdata_with_graph_on_shapes() + sdata.pl.render_graph().pl.render_shapes("my_shapes").pl.show() + + def test_render_graph_empty_graph_does_not_error(self): + """An adjacency matrix with no edges should render without error.""" + sdata = _make_sdata_with_graph_on_shapes() + sdata["table"].obsp["spatial_connectivities"] = csr_matrix((20, 20)) + sdata.pl.render_graph("my_shapes", table_name="table").pl.render_shapes("my_shapes").pl.show() + + def test_render_graph_raises_on_missing_obsp_key(self): + sdata = _make_sdata_with_graph_on_shapes() + with pytest.raises(KeyError, match="not found in `table.obsp`"): + sdata.pl.render_graph("my_shapes", connectivity_key="nonexistent", table_name="table") + + def test_render_graph_raises_on_missing_element(self): + sdata = _make_sdata_with_graph_on_shapes() + with pytest.raises(KeyError, match="not found in shapes, points, or labels"): + sdata.pl.render_graph("no_such_element", table_name="table") + + def test_render_graph_raises_on_groups_without_group_key(self): + sdata = _make_sdata_with_graph_on_shapes() + with pytest.raises(ValueError, match="`groups` requires `group_key`"): + sdata.pl.render_graph("my_shapes", table_name="table", groups=["tumor"])