Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 8 additions & 12 deletions src/spatialdata_plot/pl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2071,7 +2071,7 @@ def _multiscale_to_spatial_image(
# use scale with highest resolution
optimal_scale = scales[np.argmax(x_dims)]
else:
# ensure that lists are sorted
# sort scales ascending by x resolution
order = np.argsort(x_dims)
scales = [scales[i] for i in order]
x_dims = [x_dims[i] for i in order]
Expand All @@ -2080,17 +2080,13 @@ def _multiscale_to_spatial_image(
optimal_x = width * dpi
optimal_y = height * dpi

# get scale where the dimensions are close to the optimal values
# when possible, pick higher resolution (worst case: downscaled afterwards)
optimal_index_y = np.searchsorted(y_dims, optimal_y)
if optimal_index_y == len(y_dims):
optimal_index_y -= 1
optimal_index_x = np.searchsorted(x_dims, optimal_x)
if optimal_index_x == len(x_dims):
optimal_index_x -= 1

# pick the scale with higher resolution (worst case: downscaled afterwards)
optimal_scale = scales[min(int(optimal_index_x), int(optimal_index_y))]
# Pick the lowest-resolution scale where both x and y are >= the
# target pixel count. Falls back to highest available resolution.
optimal_scale = scales[-1]
for i, (xd, yd) in enumerate(zip(x_dims, y_dims, strict=True)):
if xd >= optimal_x and yd >= optimal_y:
optimal_scale = scales[i]
break

# NOTE: problematic if there are cases with > 1 data variable
data_var_keys = list(multiscale_image[optimal_scale].data_vars)
Expand Down
79 changes: 79 additions & 0 deletions tests/pl/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,3 +283,82 @@ def test_utils_get_subplots_produces_correct_axs_layout(input_output):

assert len_axs == len(axs.flatten())
assert axs_visible == [ax.axison for ax in axs.flatten()]


class TestMultiscaleToSpatialImage:
"""Regression tests for #589: multiscale resolution selection."""

@staticmethod
def _make_multiscale(shape, scale_factors):
from spatialdata.models import Image2DModel

rng = np.random.default_rng(42)
return Image2DModel.parse(
rng.normal(size=shape),
scale_factors=scale_factors,
dims=("c", "y", "x"),
c_coords=["r", "g", "b"],
)

def test_larger_figure_never_picks_lower_resolution(self):
"""Increasing figure size must select equal or higher resolution."""
from spatialdata_plot.pl.utils import _multiscale_to_spatial_image

multiscale = self._make_multiscale((3, 1024, 1024), [2, 2])
dpi = 100.0
prev_x = 0
for size in [3, 4, 5, 6, 7, 8, 10, 12]:
result = _multiscale_to_spatial_image(multiscale, dpi, float(size), float(size))
cur_x = result.sizes["x"]
assert cur_x >= prev_x, (
f"figsize {size} selected x={cur_x} which is lower than x={prev_x} from a smaller figure"
)
prev_x = cur_x

def test_asymmetric_image_picks_sufficient_resolution(self):
"""When image aspect ratio differs from figure, both axes must be covered."""
from spatialdata_plot.pl.utils import _multiscale_to_spatial_image

multiscale = self._make_multiscale((3, 400, 1200), [2, 2])
scales_info = {
leaf.name: (multiscale[leaf.name].dims["x"], multiscale[leaf.name].dims["y"]) for leaf in multiscale.leaves
}
max_x = max(x for x, _ in scales_info.values())
max_y = max(y for _, y in scales_info.values())

dpi = 100.0
for w, h in [(5, 5), (3, 10), (10, 3), (7, 4)]:
result = _multiscale_to_spatial_image(multiscale, dpi, float(w), float(h))
sel_x, sel_y = result.sizes["x"], result.sizes["y"]
opt_x, opt_y = w * dpi, h * dpi
assert sel_x >= opt_x or sel_x == max_x, (
f"figsize {w}x{h}: x={sel_x} < optimal {opt_x} and not the maximum available"
)
assert sel_y >= opt_y or sel_y == max_y, (
f"figsize {w}x{h}: y={sel_y} < optimal {opt_y} and not the maximum available"
)

def test_all_scales_too_small_picks_highest_resolution(self):
"""When no scale is large enough, the highest resolution is selected."""
from spatialdata_plot.pl.utils import _multiscale_to_spatial_image

multiscale = self._make_multiscale((3, 64, 64), [2, 2])
result = _multiscale_to_spatial_image(multiscale, dpi=100.0, width=20.0, height=20.0)
assert result.sizes["x"] == 64

def test_single_scale_level(self):
"""A single-level multiscale image always returns that level."""
from spatialdata_plot.pl.utils import _multiscale_to_spatial_image

multiscale = self._make_multiscale((3, 512, 512), [2])
for size in [2, 5, 10]:
result = _multiscale_to_spatial_image(multiscale, dpi=100.0, width=float(size), height=float(size))
assert result.sizes["x"] in (512, 256)

def test_exact_match_selects_that_scale(self):
"""When optimal pixels exactly match a scale's dimensions, that scale is selected."""
from spatialdata_plot.pl.utils import _multiscale_to_spatial_image

multiscale = self._make_multiscale((3, 500, 500), [2, 2])
result = _multiscale_to_spatial_image(multiscale, dpi=100.0, width=2.5, height=2.5)
assert result.sizes["x"] == 250
Loading