diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index 4c7a48d7..d748a85a 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -2071,7 +2071,7 @@ def _multiscale_to_spatial_image( # use scale with highest resolution optimal_scale = scales[np.argmax(x_dims)] else: - # ensure that lists are sorted + # sort scales ascending by x resolution order = np.argsort(x_dims) scales = [scales[i] for i in order] x_dims = [x_dims[i] for i in order] @@ -2080,17 +2080,13 @@ def _multiscale_to_spatial_image( optimal_x = width * dpi optimal_y = height * dpi - # get scale where the dimensions are close to the optimal values - # when possible, pick higher resolution (worst case: downscaled afterwards) - optimal_index_y = np.searchsorted(y_dims, optimal_y) - if optimal_index_y == len(y_dims): - optimal_index_y -= 1 - optimal_index_x = np.searchsorted(x_dims, optimal_x) - if optimal_index_x == len(x_dims): - optimal_index_x -= 1 - - # pick the scale with higher resolution (worst case: downscaled afterwards) - optimal_scale = scales[min(int(optimal_index_x), int(optimal_index_y))] + # Pick the lowest-resolution scale where both x and y are >= the + # target pixel count. Falls back to highest available resolution. + optimal_scale = scales[-1] + for i, (xd, yd) in enumerate(zip(x_dims, y_dims, strict=True)): + if xd >= optimal_x and yd >= optimal_y: + optimal_scale = scales[i] + break # NOTE: problematic if there are cases with > 1 data variable data_var_keys = list(multiscale_image[optimal_scale].data_vars) diff --git a/tests/pl/test_utils.py b/tests/pl/test_utils.py index 699104f7..a456d765 100644 --- a/tests/pl/test_utils.py +++ b/tests/pl/test_utils.py @@ -283,3 +283,82 @@ def test_utils_get_subplots_produces_correct_axs_layout(input_output): assert len_axs == len(axs.flatten()) assert axs_visible == [ax.axison for ax in axs.flatten()] + + +class TestMultiscaleToSpatialImage: + """Regression tests for #589: multiscale resolution selection.""" + + @staticmethod + def _make_multiscale(shape, scale_factors): + from spatialdata.models import Image2DModel + + rng = np.random.default_rng(42) + return Image2DModel.parse( + rng.normal(size=shape), + scale_factors=scale_factors, + dims=("c", "y", "x"), + c_coords=["r", "g", "b"], + ) + + def test_larger_figure_never_picks_lower_resolution(self): + """Increasing figure size must select equal or higher resolution.""" + from spatialdata_plot.pl.utils import _multiscale_to_spatial_image + + multiscale = self._make_multiscale((3, 1024, 1024), [2, 2]) + dpi = 100.0 + prev_x = 0 + for size in [3, 4, 5, 6, 7, 8, 10, 12]: + result = _multiscale_to_spatial_image(multiscale, dpi, float(size), float(size)) + cur_x = result.sizes["x"] + assert cur_x >= prev_x, ( + f"figsize {size} selected x={cur_x} which is lower than x={prev_x} from a smaller figure" + ) + prev_x = cur_x + + def test_asymmetric_image_picks_sufficient_resolution(self): + """When image aspect ratio differs from figure, both axes must be covered.""" + from spatialdata_plot.pl.utils import _multiscale_to_spatial_image + + multiscale = self._make_multiscale((3, 400, 1200), [2, 2]) + scales_info = { + leaf.name: (multiscale[leaf.name].dims["x"], multiscale[leaf.name].dims["y"]) for leaf in multiscale.leaves + } + max_x = max(x for x, _ in scales_info.values()) + max_y = max(y for _, y in scales_info.values()) + + dpi = 100.0 + for w, h in [(5, 5), (3, 10), (10, 3), (7, 4)]: + result = _multiscale_to_spatial_image(multiscale, dpi, float(w), float(h)) + sel_x, sel_y = result.sizes["x"], result.sizes["y"] + opt_x, opt_y = w * dpi, h * dpi + assert sel_x >= opt_x or sel_x == max_x, ( + f"figsize {w}x{h}: x={sel_x} < optimal {opt_x} and not the maximum available" + ) + assert sel_y >= opt_y or sel_y == max_y, ( + f"figsize {w}x{h}: y={sel_y} < optimal {opt_y} and not the maximum available" + ) + + def test_all_scales_too_small_picks_highest_resolution(self): + """When no scale is large enough, the highest resolution is selected.""" + from spatialdata_plot.pl.utils import _multiscale_to_spatial_image + + multiscale = self._make_multiscale((3, 64, 64), [2, 2]) + result = _multiscale_to_spatial_image(multiscale, dpi=100.0, width=20.0, height=20.0) + assert result.sizes["x"] == 64 + + def test_single_scale_level(self): + """A single-level multiscale image always returns that level.""" + from spatialdata_plot.pl.utils import _multiscale_to_spatial_image + + multiscale = self._make_multiscale((3, 512, 512), [2]) + for size in [2, 5, 10]: + result = _multiscale_to_spatial_image(multiscale, dpi=100.0, width=float(size), height=float(size)) + assert result.sizes["x"] in (512, 256) + + def test_exact_match_selects_that_scale(self): + """When optimal pixels exactly match a scale's dimensions, that scale is selected.""" + from spatialdata_plot.pl.utils import _multiscale_to_spatial_image + + multiscale = self._make_multiscale((3, 500, 500), [2, 2]) + result = _multiscale_to_spatial_image(multiscale, dpi=100.0, width=2.5, height=2.5) + assert result.sizes["x"] == 250