diff --git a/changes/4001.misc.md b/changes/4001.misc.md new file mode 100644 index 0000000000..adbba988d9 --- /dev/null +++ b/changes/4001.misc.md @@ -0,0 +1,6 @@ +Restore sharding write performance for shards with many inner chunks. The +`subchunk_write_order` feature inadvertently rebuilt the per-shard chunk +coordinate grid (up to tens of thousands of coordinate tuples) on every shard +write. These coordinates are now computed once per shard shape and cached, so +repeated writes to same-shaped shards reuse them, restoring write throughput to +its previous level. diff --git a/src/zarr/codecs/sharding.py b/src/zarr/codecs/sharding.py index 33c8602ecb..d555f0d26d 100644 --- a/src/zarr/codecs/sharding.py +++ b/src/zarr/codecs/sharding.py @@ -46,9 +46,11 @@ BasicIndexer, ChunkProjection, SelectorTuple, - c_order_iter, + _lexicographic_order, + colexicographic_order_coords, get_indexer, - morton_order_iter, + lexicographic_order_coords, + morton_order_coords, ) from zarr.core.metadata.v3 import ( ChunkGridMetadata, @@ -261,31 +263,42 @@ def __len__(self) -> int: return int(self.index.offsets_and_lengths.size / 2) def __iter__(self) -> Iterator[tuple[int, ...]]: - return c_order_iter(self.index.chunks_per_shard) + return iter(lexicographic_order_coords(self.index.chunks_per_shard)) - def to_dict_vectorized( - self, - chunk_coords_array: npt.NDArray[np.integer[Any]], - ) -> dict[tuple[int, ...], Buffer | None]: + def to_dict_vectorized(self) -> dict[tuple[int, ...], Buffer | None]: """Build a dict of chunk coordinates to buffers using vectorized lookup. - Parameters - ---------- - chunk_coords_array : ndarray of shape (n_chunks, n_dims) - Array of chunk coordinates for vectorized index lookup. + The full per-shard chunk coordinate grid (both the array used for the + vectorized index lookup and the plain tuples used as dict keys) is + cached on `chunks_per_shard`, so neither is rebuilt on every call. 
For a + shard with tens of thousands of chunks this avoids reconstructing that + many tuples on every partial write. Returns ------- dict mapping chunk coordinate tuples to Buffer or None """ + chunks_per_shard = self.index.chunks_per_shard + # The same chunk-grid coordinates are needed in two forms, and neither can + # stand in for the other: + # - `chunk_coords_array`: an (n_chunks, n_dims) numpy array, fed to the + # vectorized index lookup, which does modulo + advanced indexing on it. + # A list of tuples can't be used for that without first being arrayified. + # - `chunk_coords_keys`: the same coordinates as hashable Python tuples, + # used as the result dict's keys. numpy array rows are unhashable + # (mutable), so they can't key a dict. + # Both are cached per shape (see indexing.py), so neither is rebuilt here; + # row i of the array and key i refer to the same chunk. + chunk_coords_array = _lexicographic_order(chunks_per_shard) + chunk_coords_keys = lexicographic_order_coords(chunks_per_shard) starts, ends, valid = self.index.get_chunk_slices_vectorized(chunk_coords_array) result: dict[tuple[int, ...], Buffer | None] = {} - for i, coords in enumerate(chunk_coords_array): + for i, coords in enumerate(chunk_coords_keys): if valid[i]: - result[tuple(coords.ravel())] = self.buf[int(starts[i]) : int(ends[i])] + result[coords] = self.buf[int(starts[i]) : int(ends[i])] else: - result[tuple(coords.ravel())] = None + result[coords] = None return result @@ -529,13 +542,14 @@ async def _decode_partial_single( def _subchunk_order_iter( self, chunks_per_shard: tuple[int, ...], subchunk_write_order: SubchunkWriteOrder ) -> Iterable[tuple[int, ...]]: + subchunk_iter: Iterable[tuple[int, ...]] match subchunk_write_order: case "morton": - subchunk_iter = morton_order_iter(chunks_per_shard) + subchunk_iter = morton_order_coords(chunks_per_shard) case "lexicographic": - subchunk_iter = np.ndindex(chunks_per_shard) + subchunk_iter = 
lexicographic_order_coords(chunks_per_shard) case "colexicographic": - subchunk_iter = (c[::-1] for c in np.ndindex(chunks_per_shard[::-1])) + subchunk_iter = colexicographic_order_coords(chunks_per_shard) case "unordered": subchunk_list = list(np.ndindex(chunks_per_shard)) (self.rng if self.rng is not None else np.random.default_rng()).shuffle( @@ -561,7 +575,7 @@ async def _encode_single( chunk_grid=ChunkGrid.from_sizes(shard_shape, chunk_shape), ) ) - shard_builder = dict.fromkeys(self._subchunk_order_iter(chunks_per_shard, "lexicographic")) + shard_builder = dict.fromkeys(lexicographic_order_coords(chunks_per_shard)) await self.codec_pipeline.write( [ @@ -604,7 +618,7 @@ async def _encode_partial_single( ) if self._is_complete_shard_write(indexer, chunks_per_shard): - shard_dict = dict.fromkeys(self._subchunk_order_iter(chunks_per_shard, "lexicographic")) + shard_dict = dict.fromkeys(lexicographic_order_coords(chunks_per_shard)) else: shard_reader = await self._load_full_shard_maybe( byte_getter=byte_setter, @@ -612,10 +626,10 @@ async def _encode_partial_single( chunks_per_shard=chunks_per_shard, ) shard_reader = shard_reader or _ShardReader.create_empty(chunks_per_shard) - # Use vectorized lookup for better performance - shard_dict = shard_reader.to_dict_vectorized( - np.array(list(self._subchunk_order_iter(chunks_per_shard, "lexicographic"))) - ) + # Use vectorized lookup for better performance. The lexicographic + # coordinate array and keys are cached, so neither is rebuilt on + # every write. + shard_dict = shard_reader.to_dict_vectorized() await self.codec_pipeline.write( [ @@ -685,9 +699,13 @@ async def _encode_shard_dict( def _is_total_shard( self, all_chunk_coords: set[tuple[int, ...]], chunks_per_shard: tuple[int, ...] 
) -> bool: - return len(all_chunk_coords) == product(chunks_per_shard) and all( - chunk_coords in all_chunk_coords for chunk_coords in c_order_iter(chunks_per_shard) - ) + # `all_chunk_coords` comes from an indexer over this shard's chunk grid, so + # it is always a subset of that grid (`validate` requires the shard shape to + # be divisible by the inner chunk shape, so the indexer cannot produce an + # out-of-grid coordinate). A subset whose size equals the grid's is the + # whole grid, so the count check alone proves totality — no need to build + # and membership-test the full coordinate set on this hot path. + return len(all_chunk_coords) == product(chunks_per_shard) def _is_complete_shard_write( self, diff --git a/src/zarr/core/indexing.py b/src/zarr/core/indexing.py index cb81164209..d205d49a11 100644 --- a/src/zarr/core/indexing.py +++ b/src/zarr/core/indexing.py @@ -1521,19 +1521,19 @@ def decode_morton_vectorized( @lru_cache(maxsize=16) -def _morton_order(chunk_shape: tuple[int, ...]) -> npt.NDArray[np.intp]: - n_total = product(chunk_shape) - n_dims = len(chunk_shape) +def _morton_order(shape: tuple[int, ...]) -> npt.NDArray[np.intp]: + n_total = product(shape) + n_dims = len(shape) if n_total == 0: out = np.empty((0, n_dims), dtype=np.intp) out.flags.writeable = False return out # Ceiling hypercube: smallest power-of-2 hypercube whose Morton codes span - # all valid coordinates in chunk_shape. (c-1).bit_length() gives the number + # all valid coordinates in shape. (c-1).bit_length() gives the number # of bits needed to index c values (0 for singleton dims). n_z = 2**total_bits # is the size of this hypercube. - total_bits = sum((c - 1).bit_length() for c in chunk_shape) + total_bits = sum((c - 1).bit_length() for c in shape) n_z = 1 << total_bits if total_bits > 0 else 1 # Decode all Morton codes in the ceiling hypercube, then filter to valid coords. 
@@ -1544,8 +1544,8 @@ def _morton_order(chunk_shape: tuple[int, ...]) -> npt.NDArray[np.intp]: # Ceiling strategy: decode all n_z codes vectorized, filter in-bounds. # Works well when the overgeneration ratio n_z/n_total is small (≤4). z_values = np.arange(n_z, dtype=np.intp) - all_coords = decode_morton_vectorized(z_values, chunk_shape) - shape_arr = np.array(chunk_shape, dtype=np.intp) + all_coords = decode_morton_vectorized(z_values, shape) + shape_arr = np.array(shape, dtype=np.intp) valid_mask = np.all(all_coords < shape_arr, axis=1) order = all_coords[valid_mask] else: @@ -1554,11 +1554,11 @@ def _morton_order(chunk_shape: tuple[int, ...]) -> npt.NDArray[np.intp]: # larger overgeneration penalty for near-miss shapes like (33,33,33). # Cost: O(n_total * bits) encode + O(n_total log n_total) sort, # vs O(n_z * bits) = O(8 * n_total * bits) for ceiling. - grids = np.meshgrid(*[np.arange(c, dtype=np.intp) for c in chunk_shape], indexing="ij") + grids = np.meshgrid(*[np.arange(c, dtype=np.intp) for c in shape], indexing="ij") all_coords = np.stack([g.ravel() for g in grids], axis=1) # Encode all coordinates to Morton codes (vectorized). - bits_per_dim = tuple((c - 1).bit_length() for c in chunk_shape) + bits_per_dim = tuple((c - 1).bit_length() for c in shape) max_coord_bits = max(bits_per_dim) z_codes = np.zeros(n_total, dtype=np.intp) output_bit = 0 @@ -1576,16 +1576,56 @@ def _morton_order(chunk_shape: tuple[int, ...]) -> npt.NDArray[np.intp]: @lru_cache(maxsize=16) -def _morton_order_keys(chunk_shape: tuple[int, ...]) -> tuple[tuple[int, ...], ...]: - return tuple(tuple(int(x) for x in row) for row in _morton_order(chunk_shape)) +def morton_order_coords(shape: tuple[int, ...]) -> tuple[tuple[int, ...], ...]: + # The grid coordinates in Morton (Z) order, as a cached sequence. 
The + # coordinate set of a finite grid has a known length and is reused in full on + # every shard write, so it is built once (vectorized, via `_morton_order`) and + # cached per shape rather than recomputed. Indexable and `len`-able; iterate it + # directly where an iterator is needed. + # + # `.tolist()` converts the whole array to native Python ints in one C-level + # call; building the tuples row-by-row with `int(x)` is ~9x slower. + return tuple(map(tuple, _morton_order(shape).tolist())) -def morton_order_iter(chunk_shape: tuple[int, ...]) -> Iterator[tuple[int, ...]]: - return iter(_morton_order_keys(tuple(chunk_shape))) +@lru_cache(maxsize=16) +def _lexicographic_order(shape: tuple[int, ...]) -> npt.NDArray[np.intp]: + # Lexicographic (C-order) coordinates, computed vectorized and cached so that + # the sharding codec's per-shard chunk grid is not rebuilt on every call. + # Equivalent to `np.array(list(np.ndindex(shape)))` but without the + # Python-level iteration over every coordinate. + n_dims = len(shape) + if n_dims == 0: + # A 0-d shard holds a single chunk addressed by the empty coordinate, so + # the coordinate array has one row and zero columns. np.indices(()) cannot + # express this, so build it directly. Matches list(np.ndindex(())) == [()]. + order = np.empty((1, 0), dtype=np.intp) + else: + order = np.indices(shape, dtype=np.intp).reshape(n_dims, -1).T + order.flags.writeable = False + return order -def c_order_iter(chunks_per_shard: tuple[int, ...]) -> Iterator[tuple[int, ...]]: - return itertools.product(*(range(x) for x in chunks_per_shard)) +@lru_cache(maxsize=16) +def lexicographic_order_coords(shape: tuple[int, ...]) -> tuple[tuple[int, ...], ...]: + # The grid coordinates in lexicographic (row-major / C) order, as a cached + # sequence. The coordinate set of a finite grid has a known length and is + # reused in full on every shard write, so it is built once (vectorized, via + # `_lexicographic_order`) and cached per shape. 
Indexable and `len`-able; + # iterate it directly where an iterator is needed. + # + # `.tolist()` converts the whole array to native Python ints in one C-level + # call; building the tuples row-by-row with `int(x)` is ~9x slower. + return tuple(map(tuple, _lexicographic_order(shape).tolist())) + + +@lru_cache(maxsize=16) +def colexicographic_order_coords(shape: tuple[int, ...]) -> tuple[tuple[int, ...], ...]: + # The grid coordinates in colexicographic (column-major / F) order, as a cached + # sequence: the first axis varies fastest. Equivalent to reversing the axis order, + # taking lexicographic order, and reversing each coordinate tuple back. Cached per + # shape like its siblings so shard writes don't rebuild it. + return tuple(c[::-1] for c in lexicographic_order_coords(shape[::-1])) def get_indexer( diff --git a/tests/benchmarks/test_indexing.py b/tests/benchmarks/test_indexing.py index 385a85b5b5..c9b80f9ff6 100644 --- a/tests/benchmarks/test_indexing.py +++ b/tests/benchmarks/test_indexing.py @@ -74,7 +74,7 @@ def test_sharded_morton_indexing( The Morton order cache is cleared before each iteration to measure the full computation cost. """ - from zarr.core.indexing import _morton_order, _morton_order_keys + from zarr.core.indexing import _morton_order, morton_order_coords # Create array where each shard contains many small chunks # e.g., shards=(32,32,32) with chunks=(2,2,2) means 16x16x16 = 4096 chunks per shard @@ -98,7 +98,7 @@ def test_sharded_morton_indexing( def read_with_cache_clear() -> None: _morton_order.cache_clear() - _morton_order_keys.cache_clear() + morton_order_coords.cache_clear() getitem(data, indexer) benchmark(read_with_cache_clear) @@ -126,7 +126,7 @@ def test_sharded_morton_indexing_large( the Morton order computation a more significant portion of total time. The Morton order cache is cleared before each iteration.
""" - from zarr.core.indexing import _morton_order, _morton_order_keys + from zarr.core.indexing import _morton_order, morton_order_coords # 1x1x1 chunks means chunks_per_shard equals shard shape shape = tuple(s * 2 for s in shards) # 2 shards per dimension @@ -149,7 +149,7 @@ def test_sharded_morton_indexing_large( def read_with_cache_clear() -> None: _morton_order.cache_clear() - _morton_order_keys.cache_clear() + morton_order_coords.cache_clear() getitem(data, indexer) benchmark(read_with_cache_clear) @@ -169,7 +169,7 @@ def test_sharded_morton_single_chunk( computing the full Morton order, making the optimization impact clear. The Morton order cache is cleared before each iteration. """ - from zarr.core.indexing import _morton_order, _morton_order_keys + from zarr.core.indexing import _morton_order, morton_order_coords # 1x1x1 chunks means chunks_per_shard equals shard shape shape = tuple(s * 2 for s in shards) # 2 shards per dimension @@ -192,13 +192,13 @@ def test_sharded_morton_single_chunk( def read_with_cache_clear() -> None: _morton_order.cache_clear() - _morton_order_keys.cache_clear() + morton_order_coords.cache_clear() getitem(data, indexer) benchmark(read_with_cache_clear) -# Benchmark for morton_order_iter directly (no I/O) +# Benchmark for morton_order_coords directly (no I/O) morton_iter_shapes = ( (8, 8, 8), # 512 elements (power-of-2) (10, 10, 10), # 1000 elements (non-power-of-2) @@ -211,23 +211,23 @@ def read_with_cache_clear() -> None: @pytest.mark.parametrize("shape", morton_iter_shapes, ids=str) -def test_morton_order_iter( +def test_morton_order( shape: tuple[int, ...], benchmark: BenchmarkFixture, ) -> None: - """Benchmark morton_order_iter directly without I/O. + """Benchmark morton_order_coords directly without I/O. This isolates the Morton order computation to measure the optimization impact without array read/write overhead. The cache is cleared before each iteration. 
""" - from zarr.core.indexing import _morton_order, _morton_order_keys, morton_order_iter + from zarr.core.indexing import _morton_order, morton_order_coords def compute_morton_order() -> None: _morton_order.cache_clear() - _morton_order_keys.cache_clear() - # Consume the iterator to force computation - list(morton_order_iter(shape)) + morton_order_coords.cache_clear() + # Build the full sequence to force computation + list(morton_order_coords(shape)) benchmark(compute_morton_order) @@ -250,7 +250,12 @@ def test_sharded_morton_write_single_chunk( """ import numpy as np - from zarr.core.indexing import _morton_order, _morton_order_keys + from zarr.core.indexing import ( + _lexicographic_order, + _morton_order, + lexicographic_order_coords, + morton_order_coords, + ) # 1x1x1 chunks means chunks_per_shard equals shard shape shape = tuple(s * 2 for s in shards) # 2 shards per dimension @@ -272,8 +277,67 @@ def test_sharded_morton_write_single_chunk( indexer = (slice(1), slice(1), slice(1)) def write_with_cache_clear() -> None: + # Clear every coordinate cache the write path touches, not just morton: + # the sharded write also builds the lexicographic grid (dict.fromkeys / + # to_dict_vectorized), so a partial clear would leave that path warm and + # under-report the cold build cost. _morton_order.cache_clear() - _morton_order_keys.cache_clear() + morton_order_coords.cache_clear() + _lexicographic_order.cache_clear() + lexicographic_order_coords.cache_clear() data[indexer] = write_data benchmark(write_with_cache_clear) + + +@pytest.mark.parametrize("store", ["memory"], indirect=["store"]) +@pytest.mark.parametrize("shards", large_morton_shards, ids=str) +def test_sharded_morton_write_single_chunk_warm_cache( + store: Store, + shards: tuple[int, ...], + benchmark: BenchmarkFixture, +) -> None: + """Benchmark a single-chunk shard write with the chunk-order cache warm. 
+ + Unlike ``test_sharded_morton_write_single_chunk``, this does NOT clear the + order cache between iterations: it warms the cache once, then repeatedly + writes the same single chunk. This isolates the amortized per-write cost the + cache exists to optimize — the regime where the coordinate grid was already + built (by an earlier write to this shard, or to any same-shaped shard) and is + reused rather than rebuilt. Repeated writes to one shard and writes spread + across many same-shaped shards exercise that cache reuse identically. + + This is the regime the cold benchmark cannot see. A regression that rebuilds + the per-shard coordinate tuples on every write (rather than reusing the + cached sequence) is invisible to the cold benchmark but shows up here. + """ + import numpy as np + + from zarr.core.indexing import _morton_order, morton_order_coords + + shape = tuple(s * 2 for s in shards) # 2 shards per dimension + chunks = (1,) * 3 # 1x1x1 chunks: chunks_per_shard = shards + + data = create_array( + store=store, + shape=shape, + dtype="uint8", + chunks=chunks, + shards=shards, + compressors=None, + filters=None, + fill_value=0, + ) + + write_data = np.ones((1, 1, 1), dtype="uint8") + indexer = (slice(1), slice(1), slice(1)) + + # Warm the cache once up front; the timed writes then hit the warm path. 
+ _morton_order.cache_clear() + morton_order_coords.cache_clear() + data[indexer] = write_data + + def write_warm() -> None: + data[indexer] = write_data + + benchmark(write_warm) diff --git a/tests/test_codecs/test_codecs.py b/tests/test_codecs/test_codecs.py index 6e3e3f6d28..b88aa6f507 100644 --- a/tests/test_codecs/test_codecs.py +++ b/tests/test_codecs/test_codecs.py @@ -18,7 +18,7 @@ TransposeCodec, ) from zarr.core.buffer import default_buffer_prototype -from zarr.core.indexing import BasicSelection, decode_morton, morton_order_iter +from zarr.core.indexing import BasicSelection, decode_morton, morton_order_coords from zarr.core.metadata.v3 import ArrayV3Metadata from zarr.dtype import UInt8 from zarr.errors import ZarrUserWarning @@ -173,8 +173,8 @@ def test_open(store: Store) -> None: def test_morton_exact_order() -> None: """Test exact morton ordering for power-of-2 shapes.""" - assert list(morton_order_iter((2, 2))) == [(0, 0), (1, 0), (0, 1), (1, 1)] - assert list(morton_order_iter((2, 2, 2))) == [ + assert list(morton_order_coords((2, 2))) == [(0, 0), (1, 0), (0, 1), (1, 1)] + assert list(morton_order_coords((2, 2, 2))) == [ (0, 0, 0), (1, 0, 0), (0, 1, 0), @@ -184,7 +184,7 @@ def test_morton_exact_order() -> None: (0, 1, 1), (1, 1, 1), ] - assert list(morton_order_iter((2, 2, 2, 2))) == [ + assert list(morton_order_coords((2, 2, 2, 2))) == [ (0, 0, 0, 0), (1, 0, 0, 0), (0, 1, 0, 0), @@ -223,12 +223,12 @@ def test_morton_exact_order() -> None: ], ) def test_morton_is_permutation(shape: tuple[int, ...]) -> None: - """Test that morton_order_iter produces every valid coordinate exactly once.""" + """Test that morton_order_coords produces every valid coordinate exactly once.""" import itertools from zarr.core.common import product - order = list(morton_order_iter(shape)) + order = list(morton_order_coords(shape)) expected_len = product(shape) # completeness: every valid coordinate is present assert len(order) == expected_len @@ -257,7 +257,7 @@ def 
test_morton_ordering(shape: tuple[int, ...]) -> None: so the ordering should be exactly decode_morton(0), decode_morton(1), ... """ - order = list(morton_order_iter(shape)) + order = list(morton_order_coords(shape)) for i, coord in enumerate(order): assert coord == decode_morton(i, shape) diff --git a/tests/test_codecs/test_sharding.py b/tests/test_codecs/test_sharding.py index 74e4a7e0d5..0319ac1845 100644 --- a/tests/test_codecs/test_sharding.py +++ b/tests/test_codecs/test_sharding.py @@ -21,7 +21,7 @@ ) from zarr.codecs.sharding import MAX_UINT_64, SubchunkWriteOrder, _ShardIndex, _ShardReader from zarr.core.buffer import NDArrayLike, default_buffer_prototype -from zarr.core.indexing import c_order_iter +from zarr.core.indexing import lexicographic_order_coords from zarr.storage import MemoryStore, StorePath, ZipStore from ..conftest import ArrayRequest @@ -978,7 +978,7 @@ def test_shard_index_get_chunk_slices_vectorized(chunks_per_shard: tuple[int, .. """get_chunk_slices_vectorized works uniformly across chunk grid ranks, including 0-D.""" index = _ShardIndex.create_empty(chunks_per_shard) # Write the first chunk; leave the rest (if any) empty. - all_coords = list(c_order_iter(chunks_per_shard)) + all_coords = list(lexicographic_order_coords(chunks_per_shard)) index.set_chunk_slice(all_coords[0], slice(10, 14)) coords_array = np.array(all_coords, dtype=np.uint64).reshape( @@ -992,3 +992,39 @@ def test_shard_index_get_chunk_slices_vectorized(chunks_per_shard: tuple[int, .. assert starts[0] == 10 assert ends[0] == 14 np.testing.assert_array_equal(starts[~expected_valid], MAX_UINT_64) + + +@pytest.mark.parametrize("chunks_per_shard", [(), (3,), (2, 3)]) +def test_shard_reader_to_dict_vectorized(chunks_per_shard: tuple[int, ...]) -> None: + """to_dict_vectorized derives its own coords and maps present chunks to buffers, empty to None. 
+ + The reader is given the full per-shard chunk grid implicitly (it reads + ``chunks_per_shard`` off its own index), so the result must contain every + lexicographic coordinate as a key, with the stored bytes for present chunks + and ``None`` for empty ones. + """ + all_coords = list(lexicographic_order_coords(chunks_per_shard)) + # Lay two chunks back-to-back in the buffer; leave the rest (if any) empty. + payload = b"abcdXY" + index = _ShardIndex.create_empty(chunks_per_shard) + index.set_chunk_slice(all_coords[0], slice(0, 4)) + present = {all_coords[0]: payload[0:4]} + if len(all_coords) > 1: + index.set_chunk_slice(all_coords[1], slice(4, 6)) + present[all_coords[1]] = payload[4:6] + + reader = _ShardReader() + reader.index = index + reader.buf = default_buffer_prototype().buffer.from_bytes(payload) + + result = reader.to_dict_vectorized() + + # Every lexicographic coordinate is present as a key, in order. + assert list(result.keys()) == all_coords + for coords in all_coords: + buf = result[coords] + if coords in present: + assert buf is not None + assert buf.to_bytes() == present[coords] + else: + assert buf is None