diff --git a/src/zarr/codecs/sharding.py b/src/zarr/codecs/sharding.py index 85162c2f74..9f26bc57b1 100644 --- a/src/zarr/codecs/sharding.py +++ b/src/zarr/codecs/sharding.py @@ -1,6 +1,6 @@ from __future__ import annotations -from collections.abc import Iterable, Mapping, MutableMapping +from collections.abc import Iterable, Mapping, MutableMapping, Sequence from dataclasses import dataclass, replace from enum import Enum from functools import lru_cache @@ -45,6 +45,7 @@ from zarr.core.dtype.npy.int import UInt64 from zarr.core.indexing import ( BasicIndexer, + ChunkProjection, SelectorTuple, _morton_order, _morton_order_keys, @@ -574,21 +575,26 @@ async def _encode_partial_single( chunks_per_shard = self._get_chunks_per_shard(shard_spec) chunk_spec = self._get_chunk_spec(shard_spec) - shard_reader = await self._load_full_shard_maybe( - byte_getter=byte_setter, - prototype=chunk_spec.prototype, - chunks_per_shard=chunks_per_shard, - ) - shard_reader = shard_reader or _ShardReader.create_empty(chunks_per_shard) - # Use vectorized lookup for better performance - shard_dict = shard_reader.to_dict_vectorized(np.asarray(_morton_order(chunks_per_shard))) - indexer = list( get_indexer( selection, shape=shard_shape, chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape) ) ) + if self._is_complete_shard_write(indexer, chunks_per_shard): + shard_dict = dict.fromkeys(morton_order_iter(chunks_per_shard)) + else: + shard_reader = await self._load_full_shard_maybe( + byte_getter=byte_setter, + prototype=chunk_spec.prototype, + chunks_per_shard=chunks_per_shard, + ) + shard_reader = shard_reader or _ShardReader.create_empty(chunks_per_shard) + # Use vectorized lookup for better performance + shard_dict = shard_reader.to_dict_vectorized( + np.asarray(_morton_order(chunks_per_shard)) + ) + await self.codec_pipeline.write( [ ( @@ -661,6 +667,16 @@ def _is_total_shard( chunk_coords in all_chunk_coords for chunk_coords in c_order_iter(chunks_per_shard) ) + def _is_complete_shard_write( + self, + indexed_chunks: Sequence[ChunkProjection], + chunks_per_shard: tuple[int, ...], + ) -> bool: + all_chunk_coords = {chunk_coords for chunk_coords, *_ in indexed_chunks} + return self._is_total_shard(all_chunk_coords, chunks_per_shard) and all( + is_complete_chunk for *_, is_complete_chunk in indexed_chunks + ) + async def _decode_shard_index( self, index_bytes: Buffer, chunks_per_shard: tuple[int, ...] ) -> _ShardIndex: diff --git a/tests/test_array.py b/tests/test_array.py index 5b85c6ba1d..8ea79b5f10 100644 --- a/tests/test_array.py +++ b/tests/test_array.py @@ -2259,9 +2259,34 @@ def test_create_array_with_data_num_gets( data = zarr.zeros(shape, dtype="int64") zarr.create_array(store, data=data, chunks=chunk_shape, shards=shard_shape, fill_value=-1) # type: ignore[arg-type] - # one get for the metadata and one per shard. - # Note: we don't actually need one get per shard, but this is the current behavior - assert store.counter["get"] == 1 + num_shards + # One get for the metadata; full-shard writes should not read shard payloads. + assert store.counter["get"] == 1 + + +@pytest.mark.parametrize( + ("selection", "expected_gets"), + [(slice(None), 0), (slice(1, 9), 1)], +) +def test_shard_write_num_gets(selection: slice, expected_gets: int) -> None: + """ + Test that partial-shard writes read the existing data and full-shard writes don't. + """ + store = LoggingStore(store=MemoryStore()) + arr = zarr.create_array( + store, + shape=(10,), + chunks=(1,), + shards=(10,), + dtype="int64", + fill_value=-1, + ) + arr[:] = 0 + + store.counter.clear() + + arr[selection] = 1 + + assert store.counter["get"] == expected_gets @pytest.mark.parametrize("config", [{}, {"write_empty_chunks": True}, {"order": "C"}])