Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 91 additions & 0 deletions src/spatialdata_plot/pl/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,14 @@
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from spatialdata import get_extent
from spatialdata._utils import _deprecation_alias
from spatialdata.transformations.operations import get_transformation
from xarray import DataArray, DataTree

from spatialdata_plot._accessor import register_spatial_data_accessor
from spatialdata_plot._logging import _log_context, logger
from spatialdata_plot.pl.render import (
_draw_channel_legend,
_render_graph,
_render_images,
_render_labels,
_render_points,
Expand All @@ -44,6 +46,7 @@
ChannelLegendEntry,
CmapParams,
ColorbarSpec,
GraphRenderParams,
ImageRenderParams,
LabelsRenderParams,
LegendParams,
Expand All @@ -63,6 +66,7 @@
_prepare_cmap_norm,
_prepare_params_plot,
_set_outline,
_validate_graph_render_params,
_validate_image_render_params,
_validate_label_render_params,
_validate_points_render_params,
Expand Down Expand Up @@ -861,6 +865,79 @@ def render_labels(
n_steps += 1
return sdata

def render_graph(
self,
element: str | None = None,
color: ColorLike | None = "grey",
*,
connectivity_key: str = "spatial",
groups: list[str] | str | None = None,
group_key: str | None = None,
edge_width: float = 1.0,
edge_alpha: float = 1.0,
table_name: str | None = None,
) -> sd.SpatialData:
"""Render spatial graph edges between observations.

Draws edges from a connectivity matrix stored in a table's ``obsp``,
using centroid coordinates of the linked spatial element.

Parameters
----------
element : str | None, optional
Name of the spatial element (shapes, points, or labels) whose
observations the graph connects. Auto-resolved from the table
if not given.
color : ColorLike | None, default "grey"
Edge color as a color-like value (e.g. ``"red"``, ``"#aabbcc"``).
connectivity_key : str, default "spatial"
Key prefix in ``table.obsp``. Tries ``obsp[key]`` first, then
``obsp[f"{key}_connectivities"]``.
groups : list[str] | str | None, optional
Show only edges where **both** endpoints belong to the specified
groups. Requires ``group_key``.
group_key : str | None, optional
Column in ``table.obs`` used for group filtering.
edge_width : float, default 1.0
Line width for edges.
edge_alpha : float, default 1.0
Transparency for edges (0 = invisible, 1 = opaque).
table_name : str | None, optional
Table containing the graph. Auto-discovered if not given.

Returns
-------
sd.SpatialData
Copy with rendering parameters stored in the plotting tree.
"""
params = _validate_graph_render_params(
self._sdata,
element=element,
connectivity_key=connectivity_key,
table_name=table_name,
color=color,
edge_width=edge_width,
edge_alpha=edge_alpha,
groups=groups,
group_key=group_key,
)

sdata = self._copy()
sdata = _verify_plotting_tree(sdata)
n_steps = len(sdata.plotting_tree.keys())
sdata.plotting_tree[f"{n_steps + 1}_render_graph"] = GraphRenderParams(
element=params["element"],
connectivity_key=params["obsp_key"],
table_name=params["table_name"],
color=params["color"],
groups=params["groups"],
group_key=params["group_key"],
edge_width=params["edge_width"],
edge_alpha=params["edge_alpha"],
zorder=n_steps,
)
return sdata

def show(
self,
coordinate_systems: list[str] | str | None = None,
Expand Down Expand Up @@ -1001,6 +1078,7 @@ def show(
"render_shapes",
"render_labels",
"render_points",
"render_graph",
]

# prepare rendering params
Expand Down Expand Up @@ -1311,6 +1389,19 @@ def _draw_colorbar(
rasterize=rasterize,
)

elif cmd == "render_graph":
graph_element = params_copy.element
element_in_cs = graph_element in sdata and cs in set(
get_transformation(sdata[graph_element], get_all=True).keys()
)
if element_in_cs:
_render_graph(
sdata=sdata,
render_params=params_copy,
coordinate_system=cs,
ax=ax,
)

if title is None:
t = cs
elif len(title) == 1:
Expand Down
95 changes: 95 additions & 0 deletions src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
Color,
ColorbarSpec,
FigParams,
GraphRenderParams,
ImageRenderParams,
LabelsRenderParams,
LegendParams,
Expand Down Expand Up @@ -1815,3 +1816,97 @@ def _draw_labels(
scalebar_units=scalebar_params.scalebar_units,
# scalebar_kwargs=scalebar_params.scalebar_kwargs,
)


def _render_graph(
sdata: sd.SpatialData,
render_params: GraphRenderParams,
coordinate_system: str,
ax: matplotlib.axes.SubplotBase,
) -> None:
"""Render spatial graph edges as a LineCollection on the given axes."""
from matplotlib.collections import LineCollection
from scipy.sparse import triu

_log_context.set("render_graph")
element_name = render_params.element
table_name = render_params.table_name

# Get table and adjacency matrix
table = sdata[table_name]
obsp_key = render_params.connectivity_key
if obsp_key not in table.obsp:
logger.warning(f"Connectivity key '{obsp_key}' not found in table obsp. Skipping graph rendering.")
return

adjacency = table.obsp[obsp_key]

# Get the spatial element
if element_name in sdata.shapes:
element = sdata.shapes[element_name]
elif element_name in sdata.points:
element = sdata.points[element_name]
elif element_name in sdata.labels:
element = sdata.labels[element_name]
else:
logger.warning(f"Element '{element_name}' not found in sdata. Skipping graph rendering.")
return

# Get centroids in the target coordinate system
centroids_df = sd.get_centroids(element, coordinate_system=coordinate_system)
if hasattr(centroids_df, "compute"):
centroids_df = centroids_df.compute()

centroid_coords = np.column_stack([centroids_df["x"].values, centroids_df["y"].values])

# Align table observations to centroid positions via instance_key.
# Build a coordinate array indexed by full-table row so edge lookups are O(1).
_, region_key, instance_key = get_table_keys(table)

element_mask = table.obs[region_key] == element_name if region_key is not None else np.ones(table.n_obs, dtype=bool)
instance_ids = table.obs[instance_key].values[element_mask]
table_subset_indices = np.where(element_mask)[0]

centroid_ids = centroids_df.index.values if hasattr(centroids_df, "index") else np.arange(len(centroids_df))
id_to_centroid_row = {cid: row for row, cid in enumerate(centroid_ids)}

# has_coord[i] is True if table row i has a valid centroid
has_coord = np.zeros(table.n_obs, dtype=bool)
coord_lookup = np.full((table.n_obs, 2), np.nan)
for table_row, iid in zip(table_subset_indices, instance_ids, strict=True):
if iid in id_to_centroid_row:
has_coord[table_row] = True
coord_lookup[table_row] = centroid_coords[id_to_centroid_row[iid]]

# Apply group filtering: narrow has_coord to only rows in requested groups
groups = render_params.groups
group_key = render_params.group_key
if groups is not None and group_key is not None:
group_values = table.obs[group_key].values
in_groups = np.isin(group_values, groups)
has_coord &= in_groups

# Extract edges from upper triangle (undirected — draw each edge once, skip self-loops)
adj_upper = triu(adjacency, k=1)
rows, cols = adj_upper.nonzero()

# Vectorized filter: keep edges where both endpoints are valid
edge_mask = has_coord[rows] & has_coord[cols]
if not edge_mask.any():
return

valid_rows = rows[edge_mask]
valid_cols = cols[edge_mask]
segments = np.stack([coord_lookup[valid_rows], coord_lookup[valid_cols]], axis=1)

edge_color = render_params.color.get_hex() if render_params.color is not None else "#808080"

lc = LineCollection(
segments,
linewidths=render_params.edge_width,
colors=edge_color,
alpha=render_params.edge_alpha,
zorder=render_params.zorder,
)
lc.set_rasterized(True)
ax.add_collection(lc)
15 changes: 15 additions & 0 deletions src/spatialdata_plot/pl/render_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,3 +307,18 @@ class LabelsRenderParams:
zorder: int = 0
colorbar: bool | str | None = "auto"
colorbar_params: dict[str, object] | None = None


@dataclass
class GraphRenderParams:
"""Graph render parameters."""

element: str
connectivity_key: str = "spatial"
table_name: str | None = None
color: Color | None = None
groups: list[str] | str | None = None
group_key: str | None = None
edge_width: float = 1.0
edge_alpha: float = 1.0
zorder: int = 0
Loading
Loading