diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 1635bdd2a..27aa16acc 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -19,7 +19,7 @@ jobs: fail-fast: false matrix: include: - - {os: windows-latest, python: "3.11", dask-version: "2025.12.0", name: "Dask 2025.12.0"} + - {os: windows-latest, python: "3.11", dask-version: "2026.3.0", name: "Dask 2026.3.0"} - {os: windows-latest, python: "3.13", dask-version: "latest", name: "Dask latest"} - {os: ubuntu-latest, python: "3.11", dask-version: "latest", name: "Dask latest"} - {os: ubuntu-latest, python: "3.13", dask-version: "latest", name: "Dask latest"} diff --git a/pyproject.toml b/pyproject.toml index 07ec8140b..04bbb2d44 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,8 +26,8 @@ dependencies = [ "annsel>=0.1.2", "click", "dask-image", - "dask>=2025.12.0,<2026.1.2", - "distributed<2026.1.2", + "dask>=2026.3.0", + "distributed>=2026.3.0", "datashader", "fsspec[s3,http]", "geopandas>=0.14", @@ -50,6 +50,7 @@ dependencies = [ "xarray>=2024.10.0", "xarray-spatial>=0.3.5", "zarr>=3.0.0", + "zarrs", ] [project.optional-dependencies] torch = [ @@ -62,6 +63,9 @@ extra = [ ] [dependency-groups] +sharding = [ + "zarrs", +] dev = [ "bump2version", ] diff --git a/src/spatialdata/__init__.py b/src/spatialdata/__init__.py index 7ba66e710..1bb0483c9 100644 --- a/src/spatialdata/__init__.py +++ b/src/spatialdata/__init__.py @@ -131,6 +131,10 @@ "settings", ] +import zarr + +zarr.config.set({"codec_pipeline.path": "zarrs.ZarrsCodecPipeline"}) + def __getattr__(name: str) -> Any: if name in _submodules: diff --git a/src/spatialdata/_core/_utils.py b/src/spatialdata/_core/_utils.py index a55815655..396c792c9 100644 --- a/src/spatialdata/_core/_utils.py +++ b/src/spatialdata/_core/_utils.py @@ -1,6 +1,7 @@ from __future__ import annotations from collections.abc import Iterable +from typing import Any from anndata import AnnData @@ -164,3 +165,26 @@ def get_unique_name(name: 
str, attr: str, is_dataframe_column: bool = False) -> setattr(sanitized, attr, new_dict) return None if inplace else sanitized + + +def create_raster_element_kwargs( + raster_write_kwargs: dict[str, dict[str, Any] | list[dict[str, Any]]] | list[dict[str, Any]], + element_name: str, +) -> dict[str, Any] | list[dict[str, Any]]: + + if isinstance(raster_write_kwargs, dict) and (kwargs := raster_write_kwargs.get(element_name)): + element_raster_write_kwargs = kwargs + elif isinstance(raster_write_kwargs, dict) and not all( + isinstance(x, (dict, list)) for x in raster_write_kwargs.values() + ): + element_raster_write_kwargs = raster_write_kwargs + elif isinstance(raster_write_kwargs, list): + if not all(isinstance(x, dict) for x in raster_write_kwargs): + raise ValueError( + "If passing raster_write_kwargs as list, it is assumed to be the storage " + "options for each scale of a multiscale raster as a dictionary." + ) + element_raster_write_kwargs = raster_write_kwargs + else: + raise ValueError(f"Type of raster_write_kwargs should be either dict or list, got {type(raster_write_kwargs)}.") + return element_raster_write_kwargs diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 739b225fe..11e1ed52d 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -1108,6 +1108,7 @@ def write( update_sdata_path: bool = True, sdata_formats: SpatialDataFormatType | list[SpatialDataFormatType] | None = None, shapes_geometry_encoding: Literal["WKB", "geoarrow"] | None = None, + raster_write_kwargs: dict[str, dict[str, Any] | list[dict[str, Any]]] | list[dict[str, Any]] | None = None, ) -> None: """ Write the `SpatialData` object to a Zarr store. @@ -1155,7 +1156,27 @@ def write( shapes_geometry_encoding Whether to use the WKB or geoarrow encoding for GeoParquet. See :meth:`geopandas.GeoDataFrame.to_parquet` for details. If None, uses the value from :attr:`spatialdata.settings.shapes_geometry_encoding`. 
- """ + raster_write_kwargs + Storage options for raster elements.These options are passed to the zarr storage backend for writing and + can be provided in several formats: + + 1. Single dictionary + A dictionary containing all storage options applied globally. + 2. Dictionary per raster element + A dictionary where: + - Keys = names of raster elements + - Values = storage options for each element + - For single-scale data: a dictionary + - For multiscale data: a list of dictionaries (one per scale) + 3. List of dictionaries (multiscale only) + A list where each dictionary defines the storage options for one scale of a multiscale raster element. + + Important Notes + - The available key–value pairs in these dictionaries depend on the Zarr format used for writing. + - For a full list of supported storage options, refer to: + https://zarr.readthedocs.io/en/stable/api/zarr/create/#zarr.create_array + """ + from spatialdata._core._utils import create_raster_element_kwargs from spatialdata._io._utils import _resolve_zarr_store from spatialdata._io.format import _parse_formats @@ -1173,6 +1194,10 @@ def write( store.close() for element_type, element_name, element in self.gen_elements(): + element_raster_write_kwargs = None + if element_type in ("images", "labels") and raster_write_kwargs: + element_raster_write_kwargs = create_raster_element_kwargs(raster_write_kwargs, element_name) + self._write_element( element=element, zarr_container_path=file_path, @@ -1181,6 +1206,7 @@ def write( overwrite=False, parsed_formats=parsed, shapes_geometry_encoding=shapes_geometry_encoding, + element_raster_write_kwargs=element_raster_write_kwargs, ) if self.path != file_path and update_sdata_path: @@ -1198,6 +1224,7 @@ def _write_element( overwrite: bool, parsed_formats: dict[str, SpatialDataFormatType] | None = None, shapes_geometry_encoding: Literal["WKB", "geoarrow"] | None = None, + element_raster_write_kwargs: dict[str, Any] | list[dict[str, Any]] | None = None, ) -> None: from 
spatialdata._io.io_zarr import _get_groups_for_element @@ -1231,6 +1258,7 @@ def _write_element( group=element_group, name=element_name, element_format=parsed_formats["raster"], + storage_options=element_raster_write_kwargs, ) elif element_type == "labels": write_labels( @@ -1238,6 +1266,7 @@ def _write_element( group=root_group, name=element_name, element_format=parsed_formats["raster"], + storage_options=element_raster_write_kwargs, ) elif element_type == "points": write_points( @@ -1268,6 +1297,9 @@ def write_element( overwrite: bool = False, sdata_formats: SpatialDataFormatType | list[SpatialDataFormatType] | None = None, shapes_geometry_encoding: Literal["WKB", "geoarrow"] | None = None, + raster_write_kwargs: dict[str, dict[str, Any] | list[dict[str, Any]] | Any] + | list[dict[str, Any]] + | None = None, ) -> None: """ Write a single element, or a list of elements, to the Zarr store used for backing. @@ -1286,12 +1318,32 @@ def write_element( shapes_geometry_encoding Whether to use the WKB or geoarrow encoding for GeoParquet. See :meth:`geopandas.GeoDataFrame.to_parquet` for details. If None, uses the value from :attr:`spatialdata.settings.shapes_geometry_encoding`. + raster_write_kwargs + Storage options for raster elements.These options are passed to the zarr storage backend for writing and + can be provided in several formats: + + 1. Single dictionary + A dictionary containing all storage options applied globally. + 2. Dictionary per raster element + A dictionary where: + - Keys = names of raster elements + - Values = storage options for each element + - For single-scale data: a dictionary + - For multiscale data: a list of dictionaries (one per scale) + 3. List of dictionaries (multiscale only) + A list where each dictionary defines the storage options for one scale of a multiscale raster element. + + Important Notes + - The available key–value pairs in these dictionaries depend on the Zarr format used for writing. 
+ - For a full list of supported storage options, refer to: + https://zarr.readthedocs.io/en/stable/api/zarr/create/#zarr.create_array Notes ----- If you pass a list of names, the elements will be written one by one. If an error occurs during the writing of an element, the writing of the remaining elements will not be attempted. """ + from spatialdata._core._utils import create_raster_element_kwargs from spatialdata._io.format import _parse_formats parsed_formats = _parse_formats(formats=sdata_formats) @@ -1331,6 +1383,10 @@ def write_element( self._check_element_not_on_disk_with_different_type(element_type=element_type, element_name=element_name) + element_raster_write_kwargs = None + if element_type in ("images", "labels") and raster_write_kwargs: + element_raster_write_kwargs = create_raster_element_kwargs(raster_write_kwargs, element_name) + self._write_element( element=element, zarr_container_path=self.path, @@ -1339,6 +1395,7 @@ def write_element( overwrite=overwrite, parsed_formats=parsed_formats, shapes_geometry_encoding=shapes_geometry_encoding, + element_raster_write_kwargs=element_raster_write_kwargs, ) # After every write, metadata should be consolidated, otherwise this can lead to IO problems like when deleting. 
if self.has_consolidated_metadata(): diff --git a/src/spatialdata/_io/io_raster.py b/src/spatialdata/_io/io_raster.py index a8b2ab2ce..4692dcfea 100644 --- a/src/spatialdata/_io/io_raster.py +++ b/src/spatialdata/_io/io_raster.py @@ -148,13 +148,13 @@ def _prepare_storage_options( return None if isinstance(storage_options, dict): prepared = dict(storage_options) - if "chunks" in prepared: + if "chunks" in prepared and prepared["chunks"] is not None: prepared["chunks"] = _normalize_explicit_chunks(prepared["chunks"]) return prepared prepared_options = [dict(options) for options in storage_options] for options in prepared_options: - if "chunks" in options: + if "chunks" in options and options["chunks"] is not None: options["chunks"] = _normalize_explicit_chunks(options["chunks"]) return prepared_options @@ -283,12 +283,27 @@ def _write_raster( raster_format The format used to write the raster data. storage_options - Additional options for writing the raster data, like chunks and compression. + Storage options for raster elements.These options are passed to the zarr storage backend for writing and + can be provided in several formats: + + 1. Single dictionary + A dictionary containing all storage options applied to the raster, either single or multiscale. + 2. List of dictionaries (multiscale only) + A list where each dictionary defines the storage options for one scale of the multiscale raster element. + + Important Notes + - The available key–value pairs in these dictionaries depend on the Zarr format used for writing. + - For a full list of supported storage options, refer to: + https://zarr.readthedocs.io/en/stable/api/zarr/create/#zarr.create_array label_metadata Label metadata which can only be defined when writing 'labels'. metadata Additional metadata for the raster element """ + from dataclasses import asdict + + from spatialdata import settings + if raster_type not in ["image", "labels"]: raise ValueError(f"{raster_type} is not a valid raster type. 
Must be 'image' or 'labels'.") # "name" and "label_metadata" are only used for labels. "name" is written in write_multiscale_ngff() but ignored in @@ -305,6 +320,24 @@ def _write_raster( for c in channels: metadata["metadata"]["omero"]["channels"].append({"label": c}) # type: ignore[union-attr, index, call-overload] + if isinstance(storage_options, dict): + storage_options = { + **{k.split("_")[1]: v for k, v in asdict(settings).items() if k in ("raster_chunks", "raster_shards")}, + **storage_options, + } + elif isinstance(storage_options, list): + storage_options = [ + { + **{k.split("_")[1]: v for k, v in asdict(settings).items() if k in ("raster_chunks", "raster_shards")}, + **x, + } + for x in storage_options + ] + elif not storage_options: + storage_options = { + k.split("_")[1]: v for k, v in asdict(settings).items() if k in ("raster_chunks", "raster_shards") + } + if isinstance(raster_data, DataArray): _write_raster_dataarray( raster_type, diff --git a/src/spatialdata/config.py b/src/spatialdata/config.py index 35b96e5f7..fff2b045f 100644 --- a/src/spatialdata/config.py +++ b/src/spatialdata/config.py @@ -1,8 +1,18 @@ from __future__ import annotations -from dataclasses import dataclass +import json +import os +from dataclasses import asdict, dataclass +from pathlib import Path from typing import Literal +from platformdirs import user_config_dir + + +def _config_path() -> Path: + """Return the platform-appropriate path to the user config file.""" + return Path(user_config_dir(appname="spatialdata")) / "settings.json" + @dataclass class Settings: @@ -10,6 +20,8 @@ class Settings: Attributes ---------- + custom_config_path + The path specified by the user of where to store the settings. shapes_geometry_encoding Default geometry encoding for GeoParquet files when writing shapes. Can be "WKB" (Well-Known Binary) or "geoarrow". @@ -18,13 +30,161 @@ class Settings: Chunk sizes bigger than this value (bytes) can trigger a compression error. 
See https://github.com/scverse/spatialdata/issues/812#issuecomment-2559380276 If detected during parsing/validation, a warning is raised. + raster_chunks + The chunksize to use for chunking an array. Length of the tuple must match + the number of dimensions. + raster_shards + The default shard size (zarr v3) to use when storing arrays. Length of the tuple + must match the number of dimensions. """ + custom_config_path: Path | None = None shapes_geometry_encoding: Literal["WKB", "geoarrow"] = "WKB" large_chunk_threshold_bytes: int = 2147483647 + raster_chunks: tuple[int, ...] | None = None + raster_shards: tuple[int, ...] | None = None + + def save(self, path: Path | str | None = None) -> None: + """Store current settings on disk. + + If Path is specified, it will store the config settings to this location. Otherwise, stores + the config in the default config directory for the given operating system. + + Parameters + ---------- + path + The path to use for storing settings if different from default. Must be + a json file. This will be stored in the global config as the custom_config_path. + + Returns + ------- + Path + The path the settings were written to. 
+ """ + target = Path(path) if path else _config_path() + + if not str(target).endswith(".json"): + raise ValueError("Path must end with .json") + + if path is not None: + data = asdict(self) + data["custom_config_path"] = str(target) + with target.open("w", encoding="utf-8") as f: + json.dump(data, f, indent=2) + + global_path = _config_path() + global_path.parent.mkdir(parents=True, exist_ok=True) + try: + with global_path.open(encoding="utf-8") as f: + global_data = json.load(f) + except (json.JSONDecodeError, OSError): + global_data = {} + global_data["custom_config_path"] = str(target) + with global_path.open("w", encoding="utf-8") as f: + json.dump(global_data, f, indent=2) + else: + target.parent.mkdir(parents=True, exist_ok=True) + data = asdict(self) + data["custom_config_path"] = str(data["custom_config_path"]) if data["custom_config_path"] else None + with target.open("w", encoding="utf-8") as f: + json.dump(data, f, indent=2) + + @classmethod + def load(cls, path: Path | str | None = None) -> Settings: + """Load settings from disk. + + This method falls back to default settings if either there is no config at the + given path or there is a decoding error. Unknown or renamed keys in the file + are silently ignored, e.g. old config files will not cause errors. + + Parameters + ---------- + path + The path to the config file if different from default. If not specified, + the default location is used. + + Returns + ------- + Settings + A populated Settings instance. + """ + target = Path(path) if path else _config_path() + + if not target.exists(): + instance = cls() + instance.apply_env() + return instance + + try: + with target.open(encoding="utf-8") as f: + data = json.load(f) + except (json.JSONDecodeError, OSError): + instance = cls() + instance.apply_env() + return instance + + # This prevents fields from old config files to be used. 
known_fields = {k: v for k, v in data.items() if k in cls.__dataclass_fields__}
+        instance = cls(**known_fields)
+        instance.apply_env()
+        return instance
+
+    def reset(self) -> None:
+        """Inplace reset all settings to their built-in defaults (in memory only).
+
+        Call 'save' method afterwards if you want the reset to be persisted.
+        """
+        defaults = Settings()
+        for field_name in self.__dataclass_fields__:
+            setattr(self, field_name, getattr(defaults, field_name))
+
+    def apply_env(self) -> None:
+        """Apply environment variable overrides on top of the current state.
+
+        Env vars take precedence over both the config file and any
+        in-session assignments. Useful in CI pipelines or HPC clusters
+        where you cannot edit the config file.
+
+        Supported variables
+        -------------------
+        SPATIALDATA_CUSTOM_CONFIG_PATH → custom_config_path
+        SPATIALDATA_SHAPES_GEOMETRY_ENCODING → shapes_geometry_encoding
+        SPATIALDATA_LARGE_CHUNK_THRESHOLD_BYTES → large_chunk_threshold_bytes
+        SPATIALDATA_RASTER_CHUNKS → raster_chunks (comma-separated ints, or "none")
+        SPATIALDATA_RASTER_SHARDS → raster_shards (comma-separated ints, or "none")
+        """
+        _ENV: dict[str, tuple[str, type]] = {
+            "SPATIALDATA_CUSTOM_CONFIG_PATH": ("custom_config_path", Path),
+            "SPATIALDATA_SHAPES_GEOMETRY_ENCODING": ("shapes_geometry_encoding", str),
+            "SPATIALDATA_LARGE_CHUNK_THRESHOLD_BYTES": ("large_chunk_threshold_bytes", int),
+            "SPATIALDATA_RASTER_CHUNKS": ("raster_chunks", str),  # parsed specially below
+            "SPATIALDATA_RASTER_SHARDS": ("raster_shards", str),  # parsed specially below
+        }
+        for env_key, (field_name, cast) in _ENV.items():
+            raw = os.environ.get(env_key)
+            if raw is None:
+                continue
+            if field_name in ("raster_chunks", "raster_shards"):
+                setattr(
+                    self,
+                    field_name,
+                    None if raw.lower() in ("none", "") else tuple(int(v) for v in raw.split(",")),
+                )
+            else:
+                setattr(self, field_name, cast(raw))
+
+    def __repr__(self) -> str:
+        fields = ", ".join(f"{k}={v!r}" for k, v in asdict(self).items())
+        return f"Settings({fields})"
+
+    @staticmethod
+    def config_path() -> Path:
+ """Return platform-specific path where settings are stored.""" + return _config_path() + -settings = Settings() +settings = Settings.load() # Backwards compatibility alias LARGE_CHUNK_THRESHOLD_BYTES = settings.large_chunk_threshold_bytes diff --git a/tests/conftest.py b/tests/conftest.py index c97939129..5a73b5b35 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -642,3 +642,14 @@ def complex_sdata() -> SpatialData: sdata.tables["labels_table"].layers["log"] = np.log1p(np.abs(sdata.tables["labels_table"].X)) return sdata + + +@pytest.fixture() +def settings_cls(tmp_path, monkeypatch): + """ + Provide setting class with default path redirected. + """ + from spatialdata.config import Settings + + monkeypatch.setattr("spatialdata.config._config_path", lambda: tmp_path / "default_settings.json") + return Settings diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py index 209a43046..d16ba2648 100644 --- a/tests/io/test_readwrite.py +++ b/tests/io/test_readwrite.py @@ -34,7 +34,7 @@ ) from spatialdata._io.io_raster import write_image from spatialdata.datasets import blobs -from spatialdata.models import Image2DModel +from spatialdata.models import Image2DModel, Labels2DModel from spatialdata.models._utils import get_channel_names from spatialdata.testing import assert_spatial_data_objects_are_identical from spatialdata.transformations.operations import ( @@ -53,6 +53,27 @@ RNG = default_rng(0) SDATA_FORMATS = list(SpatialDataContainerFormats.values()) +RASTER_CASES = [ + pytest.param( + {"model": Image2DModel, "dims": ("c", "y", "x"), "data_shape": (3, 800, 1000), "zarr_subpath": "images"}, + id="image", + ), + pytest.param( + {"model": Labels2DModel, "dims": ("y", "x"), "data_shape": (800, 1000), "zarr_subpath": "labels"}, + id="label", + ), +] + +RASTER_CASES_MULTISCALE = [ + pytest.param( + {"model": Image2DModel, "dims": ("c", "y", "x"), "data_shape": (3, 1600, 2000), "zarr_subpath": "images"}, + id="image", + ), + pytest.param( + {"model": 
Labels2DModel, "dims": ("y", "x"), "data_shape": (1600, 2000), "zarr_subpath": "labels"}, + id="label", + ), +] @pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS) @@ -743,6 +764,192 @@ def test_single_scale_image_roundtrip_stays_dataarray(tmp_path: Path) -> None: assert list(image_group.keys()) == ["s0"] +@pytest.mark.parametrize("raster_case", RASTER_CASES) +@pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS) +def test_write_raster_sharding( + tmp_path: Path, + raster_case: dict, + sdata_container_format: SpatialDataContainerFormatType, +) -> None: + model, dims, data_shape, zarr_subpath = ( + raster_case["model"], + raster_case["dims"], + raster_case["data_shape"], + raster_case["zarr_subpath"], + ) + chunks = (1, 100, 200) if len(dims) == 3 else (100, 200) + write_chunks = (1, 50, 100) if len(dims) == 3 else (50, 100) + write_shards = (1, 100, 200) if len(dims) == 3 else (100, 200) + + data = da.from_array(RNG.random(data_shape), chunks=chunks) + element = model.parse(data, dims=dims) + name = "element" + sdata = SpatialData(**{zarr_subpath: {name: element}}) + path = tmp_path / "data.zarr" + + if sdata_container_format.zarr_format == 2: + with pytest.raises(ValueError, match="Zarr format 2 arrays can only"): + sdata.write( + path, + sdata_formats=sdata_container_format, + raster_write_kwargs={"chunks": write_chunks, "shards": write_shards}, + ) + else: + sdata.write( + path, + sdata_formats=sdata_container_format, + raster_write_kwargs={"chunks": write_chunks, "shards": write_shards}, + ) + arr = zarr.open_group(path / zarr_subpath / name, mode="r")["s0"] + assert arr.chunks == write_chunks + assert arr.shards == write_shards + + +def test_write_raster_sharding_with_settings(tmp_path: Path) -> None: + from dataclasses import replace + + from spatialdata import settings + + old_settings = replace(settings) + settings.raster_chunks = (1, 100, 100) + settings.save() + + data = da.from_array(RNG.random((1, 1000, 1000)), chunks=(1, 200, 
200)) + element = Image2DModel.parse(data, dims=("c", "y", "x")) + name = "element" + sdata = SpatialData(images={name: element}) + path = tmp_path / "data.zarr" + + sdata.write( + path, + ) + arr = zarr.open_group(path / "images" / name, mode="r")["s0"] + assert arr.chunks == (1, 100, 100) + old_settings.save() + s = settings.load() + assert s.raster_chunks == old_settings.raster_chunks + + +@pytest.mark.parametrize("raster_case", RASTER_CASES_MULTISCALE) +def test_write_multiscale_raster_sharding(tmp_path: Path, raster_case: dict) -> None: + model, dims, data_shape, zarr_subpath = ( + raster_case["model"], + raster_case["dims"], + raster_case["data_shape"], + raster_case["zarr_subpath"], + ) + chunks = (1, 100, 200) if len(dims) == 3 else (100, 200) + write_chunks = (1, 50, 100) if len(dims) == 3 else (50, 100) + write_shards = (1, 100, 200) if len(dims) == 3 else (100, 200) + + data = da.from_array(RNG.random(data_shape), chunks=chunks) + element = model.parse(data, dims=dims, scale_factors=[2]) + name = "element" + sdata = SpatialData(**{zarr_subpath: {name: element}}) + path = tmp_path / "data.zarr" + + sdata.write(path, raster_write_kwargs={"chunks": write_chunks, "shards": write_shards}) + + group = zarr.open_group(path / zarr_subpath / name, mode="r") + for scale in ("s0", "s1"): + arr = group[scale] + assert arr.chunks == write_chunks + assert arr.shards == write_shards + + +@pytest.mark.parametrize("raster_case", RASTER_CASES_MULTISCALE) +def test_write_multiscale_raster_scale_sharding(tmp_path: Path, raster_case: dict) -> None: + model, dims, data_shape, zarr_subpath = ( + raster_case["model"], + raster_case["dims"], + raster_case["data_shape"], + raster_case["zarr_subpath"], + ) + chunks_s0 = (1, 50, 100) if len(dims) == 3 else (50, 100) + shards_s0 = (1, 100, 200) if len(dims) == 3 else (100, 200) + chunks_s1 = (1, 25, 50) if len(dims) == 3 else (25, 50) + shards_s1 = (1, 50, 100) if len(dims) == 3 else (50, 100) + base_chunks = (1, 100, 200) if 
len(dims) == 3 else (100, 200) + + data = da.from_array(RNG.random(data_shape), chunks=base_chunks) + element = model.parse(data, dims=dims, scale_factors=[2]) + name = "element" + sdata = SpatialData(**{zarr_subpath: {name: element}}) + path = tmp_path / "data.zarr" + + sdata.write( + path, + raster_write_kwargs=[ + {"chunks": chunks_s0, "shards": shards_s0}, + {"chunks": chunks_s1, "shards": shards_s1}, + ], + ) + + group = zarr.open_group(path / zarr_subpath / name, mode="r") + assert group["s0"].chunks == chunks_s0 + assert group["s0"].shards == shards_s0 + assert group["s1"].chunks == chunks_s1 + assert group["s1"].shards == shards_s1 + + +@pytest.mark.parametrize("raster_case", RASTER_CASES) +def test_write_raster_sharding_keyword(tmp_path: Path, raster_case: dict) -> None: + model, dims, data_shape, zarr_subpath = ( + raster_case["model"], + raster_case["dims"], + raster_case["data_shape"], + raster_case["zarr_subpath"], + ) + base_chunks = (1, 100, 200) if len(dims) == 3 else (100, 200) + write_chunks = (1, 50, 100) if len(dims) == 3 else (50, 100) + write_shards = (1, 100, 200) if len(dims) == 3 else (100, 200) + + data = da.from_array(RNG.random(data_shape), chunks=base_chunks) + element = model.parse(data, dims=dims) + other = model.parse(data.copy(), dims=dims) + name, other_name = "element", "other_element" + sdata = SpatialData(**{zarr_subpath: {name: element, other_name: other}}) + path = tmp_path / "data.zarr" + + sdata.write( + path, + raster_write_kwargs={name: {"chunks": write_chunks, "shards": write_shards}}, + ) + + arr = zarr.open_group(path / zarr_subpath / name, mode="r")["s0"] + assert arr.chunks == write_chunks + assert arr.shards == write_shards + + other_arr = zarr.open_group(path / zarr_subpath / other_name, mode="r")["s0"] + assert other_arr.chunks == base_chunks + + +def test_write_raster_elements_sharding_chunking(tmp_path: Path) -> None: + write_chunks = (1, 50, 100) + write_shards = (1, 100, 200) + + data = 
da.from_array(RNG.random((1, 500, 600))) + element = Image2DModel.parse(data, dims=("c", "y", "x")) + + sdata = SpatialData() + path = tmp_path / "data.zarr" + + sdata.write(path) + sdata["image"] = element + sdata["other_image"] = element + + sdata.write_element( + element_name=["image", "other_image"], raster_write_kwargs={"chunks": write_chunks, "shards": write_shards} + ) + + arr = zarr.open_group(path / "images" / "image", mode="r")["s0"] + assert arr.chunks == write_chunks + assert arr.shards == write_shards + arr = zarr.open_group(path / "images" / "other_image", mode="r")["s0"] + assert arr.chunks == write_chunks + assert arr.shards == write_shards + + @pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS) def test_self_contained(full_sdata: SpatialData, sdata_container_format: SpatialDataContainerFormatType) -> None: # data only in-memory, so the SpatialData object and all its elements are self-contained diff --git a/tests/utils/test_config.py b/tests/utils/test_config.py new file mode 100644 index 000000000..07c4ab76d --- /dev/null +++ b/tests/utils/test_config.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +import os +from pathlib import Path + + +def _config_path_for(tmp_path: Path) -> Path: + return tmp_path / "settings.json" + + +class TestDefaults: + def test_default_settings(self, settings_cls): + s = settings_cls() + assert s.shapes_geometry_encoding == "WKB" + assert s.large_chunk_threshold_bytes == 2_147_483_647 + assert s.raster_chunks is None + assert s.raster_shards is None + assert s.custom_config_path is None + + def test_change_settings_default_path(self, settings_cls): + s = settings_cls() + s.shapes_geometry_encoding = "geoarrow" + s.large_chunk_threshold_bytes = 1_000_000_000 + s.raster_chunks = (512, 512) + s.raster_shards = (1024, 1024) + s.save() + s = settings_cls().load() + assert s.shapes_geometry_encoding == "geoarrow" + assert s.large_chunk_threshold_bytes == 1_000_000_000 + assert s.raster_chunks == [512, 
512]
+        assert s.raster_shards == [1024, 1024]
+        assert s.custom_config_path is None
+
+    def test_change_settings_custom_path(self, settings_cls, tmp_path):
+        os.environ["SPATIALDATA_SHAPES_GEOMETRY_ENCODING"] = "geoarrow"
+        os.environ["SPATIALDATA_RASTER_CHUNKS"] = "40,40,40"
+
+        target_path = tmp_path / "custom_settings.json"
+        s = settings_cls().load()
+        assert s.shapes_geometry_encoding == "geoarrow"
+        assert s.raster_chunks == (40, 40, 40)
+
+        # Also set via environment variable to check that env vars override in-session assignments
+        s.large_chunk_threshold_bytes = 1_000_000_000
+        os.environ["SPATIALDATA_LARGE_CHUNK_THRESHOLD_BYTES"] = "1_111_111_111"
+
+        s.raster_chunks = (512, 512)
+        s.raster_shards = (1024, 1024)
+        s.save(path=target_path)
+        s = settings_cls().load()
+        assert s.shapes_geometry_encoding == "geoarrow"
+        assert s.large_chunk_threshold_bytes == 1_111_111_111
+        assert s.raster_chunks == (40, 40, 40)
+        assert s.raster_shards is None
+        assert s.custom_config_path == str(target_path)
+
+        s.reset()
+        s.save()
+        assert s.custom_config_path is None  # reset() restored the default (None) in memory
+        s = settings_cls().load()
+        assert s.custom_config_path is None