From e9ea783b668d418b0cd288e3e2afb1f5c1a0595f Mon Sep 17 00:00:00 2001 From: ArneDefauw Date: Thu, 12 Mar 2026 09:17:09 +0100 Subject: [PATCH 01/11] ome zarr chunks --- src/spatialdata/_io/io_raster.py | 93 +++++++++++++++++++++++++++++--- tests/io/test_readwrite.py | 32 +++++++++++ 2 files changed, 117 insertions(+), 8 deletions(-) diff --git a/src/spatialdata/_io/io_raster.py b/src/spatialdata/_io/io_raster.py index df7e1cb8f..a75670232 100644 --- a/src/spatialdata/_io/io_raster.py +++ b/src/spatialdata/_io/io_raster.py @@ -1,5 +1,6 @@ from __future__ import annotations +from collections.abc import Sequence from pathlib import Path from typing import Any, Literal @@ -38,6 +39,88 @@ ) +def _is_flat_int_sequence(value: object) -> bool: + if isinstance(value, str | bytes): + return False + if not isinstance(value, Sequence): + return False + return all(isinstance(v, int) for v in value) + + +def _is_dask_chunk_grid(value: object) -> bool: + if isinstance(value, str | bytes): + return False + if not isinstance(value, Sequence): + return False + return len(value) > 0 and all(_is_flat_int_sequence(axis_chunks) for axis_chunks in value) + + +def _is_regular_dask_chunk_grid(chunk_grid: Sequence[Sequence[int]]) -> bool: + # Match Dask's private _check_regular_chunks() logic without depending on its internal API. + for axis_chunks in chunk_grid: + if len(axis_chunks) <= 1: + continue + if len(set(axis_chunks[:-1])) > 1: + return False + if axis_chunks[-1] > axis_chunks[0]: + return False + return True + + +def _chunks_to_zarr_chunks(chunks: object) -> tuple[int, ...] | int | None: + if isinstance(chunks, int): + return chunks + if _is_flat_int_sequence(chunks): + return tuple(chunks) + if _is_dask_chunk_grid(chunks): + chunk_grid = tuple(tuple(axis_chunks) for axis_chunks in chunks) + if _is_regular_dask_chunk_grid(chunk_grid): + return tuple(axis_chunks[0] for axis_chunks in chunk_grid) + return None + return None + + +def _normalize_explicit_chunks(chunks: object) -> tuple[int, ...] | int: + normalized = _chunks_to_zarr_chunks(chunks) + if normalized is None: + raise ValueError( + "storage_options['chunks'] must be a Zarr chunk shape or a regular Dask chunk grid. " + "Irregular Dask chunk grids must be rechunked before writing or omitted." + ) + return normalized + + +def _prepare_single_scale_storage_options( + storage_options: JSONDict | list[JSONDict] | None, +) -> JSONDict | list[JSONDict] | None: + if storage_options is None: + return None + if isinstance(storage_options, dict): + prepared = dict(storage_options) + if "chunks" in prepared: + prepared["chunks"] = _normalize_explicit_chunks(prepared["chunks"]) + return prepared + return [dict(options) for options in storage_options] + + +def _prepare_multiscale_storage_options( + storage_options: JSONDict | list[JSONDict] | None, +) -> JSONDict | list[JSONDict] | None: + if storage_options is None: + return None + if isinstance(storage_options, dict): + prepared = dict(storage_options) + if "chunks" in prepared: + prepared["chunks"] = _normalize_explicit_chunks(prepared["chunks"]) + return prepared + + prepared_options = [dict(options) for options in storage_options] + for options in prepared_options: + if "chunks" in options: + options["chunks"] = _normalize_explicit_chunks(options["chunks"]) + return prepared_options + + def _read_multiscale( store: str | Path, raster_type: Literal["image", "labels"], reader_format: Format ) -> DataArray | DataTree: @@ -251,13 +334,8 @@ def _write_raster_dataarray( if transformations is None: raise ValueError(f"{element_name} does not have any transformations and can therefore not be written.") input_axes: tuple[str, ...] = tuple(raster_data.dims) - chunks = raster_data.chunks parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=raster_format) - if storage_options is not None: - if "chunks" not in storage_options and isinstance(storage_options, dict): - storage_options["chunks"] = chunks - else: - storage_options = {"chunks": chunks} + storage_options = _prepare_single_scale_storage_options(storage_options) # Scaler needs to be None since we are passing the data already downscaled for the multiscale case. # We need this because the argument of write_image_ngff is called image while the argument of # write_labels_ngff is called label. @@ -322,10 +400,9 @@ def _write_raster_datatree( transformations = _get_transformations_xarray(xdata) if transformations is None: raise ValueError(f"{element_name} does not have any transformations and can therefore not be written.") - chunks = get_pyramid_levels(raster_data, "chunks") parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=raster_format) - storage_options = [{"chunks": chunk} for chunk in chunks] + storage_options = _prepare_multiscale_storage_options(storage_options) ome_zarr_format = get_ome_zarr_format(raster_format) dask_delayed = write_multi_scale_ngff( pyramid=data, diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py index be07d8be8..3e9371ea9 100644 --- a/tests/io/test_readwrite.py +++ b/tests/io/test_readwrite.py @@ -7,6 +7,7 @@ from pathlib import Path from typing import Any, Literal +import dask.array as da import dask.dataframe as dd import numpy as np import pandas as pd @@ -30,6 +31,7 @@ SpatialDataContainerFormatType, SpatialDataContainerFormatV01, ) +from spatialdata._io.io_raster import write_image from spatialdata.datasets import blobs from spatialdata.models import Image2DModel from spatialdata.models._utils import get_channel_names @@ -623,6 +625,36 @@ def test_bug_rechunking_after_queried_raster(): queried.write(f) +def test_write_irregular_dask_chunks_without_explicit_storage_options(tmp_path: Path) -> None: + data = da.from_array(RNG.random((3, 800, 1000)), chunks=((3,), (300, 200, 300), (512, 488))) + image = Image2DModel.parse(data, dims=("c", "y", "x")) + sdata = SpatialData(images={"image": image}) + + sdata.write(tmp_path / "data.zarr") + + +def test_write_image_normalizes_explicit_regular_dask_chunk_grid(tmp_path: Path) -> None: + data = da.from_array(RNG.random((3, 800, 1000)), chunks=((3,), (300, 300, 200), (512, 488))) + image = Image2DModel.parse(data, dims=("c", "y", "x")) + group = zarr.open_group(tmp_path / "image.zarr", mode="w") + + write_image(image, group, "image", storage_options={"chunks": image.data.chunks}) + + assert group["s0"].chunks == (3, 300, 512) + + +def test_write_image_rejects_explicit_irregular_dask_chunk_grid(tmp_path: Path) -> None: + data = da.from_array(RNG.random((3, 800, 1000)), chunks=((3,), (300, 200, 300), (512, 488))) + image = Image2DModel.parse(data, dims=("c", "y", "x")) + group = zarr.open_group(tmp_path / "image.zarr", mode="w") + + with pytest.raises( + ValueError, + match="storage_options\\['chunks'\\] must be a Zarr chunk shape or a regular Dask chunk grid", + ): + write_image(image, group, "image", storage_options={"chunks": image.data.chunks}) + + @pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS) def test_self_contained(full_sdata: SpatialData, sdata_container_format: SpatialDataContainerFormatType) -> None: # data only in-memory, so the SpatialData object and all its elements are self-contained From c4e8608a498d97f990a6d71cf3f4772450f0a8a8 Mon Sep 17 00:00:00 2001 From: ArneDefauw Date: Thu, 12 Mar 2026 09:45:21 +0100 Subject: [PATCH 02/11] set scale factors to emtpy list + fix unit tests --- src/spatialdata/_io/io_raster.py | 7 +++++-- tests/io/test_partial_read.py | 20 ++++++++++---------- tests/io/test_readwrite.py | 14 ++++++++++++++ 3 files changed, 29 insertions(+), 12 deletions(-) diff --git a/src/spatialdata/_io/io_raster.py b/src/spatialdata/_io/io_raster.py index a75670232..66bf6b6a0 100644 --- a/src/spatialdata/_io/io_raster.py +++ b/src/spatialdata/_io/io_raster.py @@ -336,13 +336,16 @@ def _write_raster_dataarray( input_axes: tuple[str, ...] = tuple(raster_data.dims) parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=raster_format) storage_options = _prepare_single_scale_storage_options(storage_options) - # Scaler needs to be None since we are passing the data already downscaled for the multiscale case. - # We need this because the argument of write_image_ngff is called image while the argument of + # Explicitly disable pyramid generation for single-scale rasters. Recent ome-zarr versions default + # write_image()/write_labels() to scale_factors=(2, 4, 8, 16), which would otherwise write s0, s1, ... + # even when the input is a plain DataArray. + # We need this because the argument of write_image_ngff is called image while the argument of # write_labels_ngff is called label. metadata[raster_type] = data ome_zarr_format = get_ome_zarr_format(raster_format) write_single_scale_ngff( group=group, + scale_factors=[], scaler=None, fmt=ome_zarr_format, axes=parsed_axes, diff --git a/tests/io/test_partial_read.py b/tests/io/test_partial_read.py index 7c5d47841..9f51b4e17 100644 --- a/tests/io/test_partial_read.py +++ b/tests/io/test_partial_read.py @@ -184,9 +184,9 @@ def sdata_with_corrupted_image_chunks_zarrv3(session_tmp_path: Path) -> PartialR sdata.write(sdata_path) corrupted = "blobs_image" - os.unlink(sdata_path / "images" / corrupted / "0" / "zarr.json") # it will hide the "0" array from the Zarr reader - os.rename(sdata_path / "images" / corrupted / "0", sdata_path / "images" / corrupted / "0_corrupted") - (sdata_path / "images" / corrupted / "0").touch() + os.unlink(sdata_path / "images" / corrupted / "S0" / "zarr.json") # it will hide the "0" array from the Zarr reader + os.rename(sdata_path / "images" / corrupted / "S0", sdata_path / "images" / corrupted / "S0_corrupted") + (sdata_path / "images" / corrupted / "S0").touch() not_corrupted = [name for _, name, _ in sdata.gen_elements() if name != corrupted] @@ -206,9 +206,9 @@ def sdata_with_corrupted_image_chunks_zarrv2(session_tmp_path: Path) -> PartialR sdata.write(sdata_path, sdata_formats=SpatialDataContainerFormatV01()) corrupted = "blobs_image" - os.unlink(sdata_path / "images" / corrupted / "0" / ".zarray") # it will hide the "0" array from the Zarr reader - os.rename(sdata_path / "images" / corrupted / "0", sdata_path / "images" / corrupted / "0_corrupted") - (sdata_path / "images" / corrupted / "0").touch() + os.unlink(sdata_path / "images" / corrupted / "S0" / ".zarray") # it will hide the "0" array from the Zarr reader + os.rename(sdata_path / "images" / corrupted / "S0", sdata_path / "images" / corrupted / "S0_corrupted") + (sdata_path / "images" / corrupted / "S0").touch() not_corrupted = [name for _, name, _ in sdata.gen_elements() if name != corrupted] return PartialReadTestCase( @@ -315,8 +315,8 @@ def sdata_with_missing_image_chunks_zarrv3( sdata.write(sdata_path) corrupted = "blobs_image" - os.unlink(sdata_path / "images" / corrupted / "0" / "zarr.json") - os.rename(sdata_path / "images" / corrupted / "0", sdata_path / "images" / corrupted / "0_corrupted") + os.unlink(sdata_path / "images" / corrupted / "S0" / "zarr.json") + os.rename(sdata_path / "images" / corrupted / "S0", sdata_path / "images" / corrupted / "S0_corrupted") not_corrupted = [name for _, name, _ in sdata.gen_elements() if name != corrupted] @@ -339,8 +339,8 @@ def sdata_with_missing_image_chunks_zarrv2( sdata.write(sdata_path, sdata_formats=SpatialDataContainerFormatV01()) corrupted = "blobs_image" - os.unlink(sdata_path / "images" / corrupted / "0" / ".zarray") - os.rename(sdata_path / "images" / corrupted / "0", sdata_path / "images" / corrupted / "0_corrupted") + os.unlink(sdata_path / "images" / corrupted / "S0" / ".zarray") + os.rename(sdata_path / "images" / corrupted / "S0", sdata_path / "images" / corrupted / "S0_corrupted") not_corrupted = [name for _, name, _ in sdata.gen_elements() if name != corrupted] diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py index 3e9371ea9..42f484782 100644 --- a/tests/io/test_readwrite.py +++ b/tests/io/test_readwrite.py @@ -19,6 +19,7 @@ from packaging.version import Version from shapely import MultiPolygon, Polygon from upath import UPath +from xarray import DataArray from zarr.errors import GroupNotFoundError import spatialdata.config @@ -655,6 +656,19 @@ def test_write_image_rejects_explicit_irregular_dask_chunk_grid(tmp_path: Path) write_image(image, group, "image", storage_options={"chunks": image.data.chunks}) +def test_single_scale_image_roundtrip_stays_dataarray(tmp_path: Path) -> None: + image = Image2DModel.parse(RNG.random((3, 64, 64)), dims=("c", "y", "x")) + sdata = SpatialData(images={"image": image}) + path = tmp_path / "data.zarr" + + sdata.write(path) + sdata_back = read_zarr(path) + + assert isinstance(sdata_back["image"], DataArray) + image_group = zarr.open_group(path / "images" / "image", mode="r") + assert list(image_group.keys()) == ["s0"] + + @pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS) def test_self_contained(full_sdata: SpatialData, sdata_container_format: SpatialDataContainerFormatType) -> None: # data only in-memory, so the SpatialData object and all its elements are self-contained From da9eef3b1a568ed25dd35b85a799cea7f8ae3bc8 Mon Sep 17 00:00:00 2001 From: ArneDefauw Date: Thu, 12 Mar 2026 10:04:21 +0100 Subject: [PATCH 03/11] mypy --- src/spatialdata/_io/io_raster.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spatialdata/_io/io_raster.py b/src/spatialdata/_io/io_raster.py index 66bf6b6a0..55e03abed 100644 --- a/src/spatialdata/_io/io_raster.py +++ b/src/spatialdata/_io/io_raster.py @@ -2,7 +2,7 @@ from collections.abc import Sequence from pathlib import Path -from typing import Any, Literal +from typing import Any, Literal, TypeGuard import dask.array as da import numpy as np @@ -39,7 +39,7 @@ ) -def _is_flat_int_sequence(value: object) -> bool: +def _is_flat_int_sequence(value: object) -> TypeGuard[Sequence[int]]: if isinstance(value, str | bytes): return False if not isinstance(value, Sequence): @@ -47,7 +47,7 @@ def _is_flat_int_sequence(value: object) -> bool: return all(isinstance(v, int) for v in value) -def _is_dask_chunk_grid(value: object) -> bool: +def _is_dask_chunk_grid(value: object) -> TypeGuard[Sequence[Sequence[int]]]: if isinstance(value, str | bytes): return False if not isinstance(value, Sequence): From 1c060b38a6c4101b84f0cdc460a78b05768bddc5 Mon Sep 17 00:00:00 2001 From: ArneDefauw Date: Thu, 12 Mar 2026 10:54:06 +0100 Subject: [PATCH 04/11] lowercase to fix unit test linux --- tests/io/test_partial_read.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/io/test_partial_read.py b/tests/io/test_partial_read.py index 9f51b4e17..28460046e 100644 --- a/tests/io/test_partial_read.py +++ b/tests/io/test_partial_read.py @@ -184,9 +184,9 @@ def sdata_with_corrupted_image_chunks_zarrv3(session_tmp_path: Path) -> PartialR sdata.write(sdata_path) corrupted = "blobs_image" - os.unlink(sdata_path / "images" / corrupted / "S0" / "zarr.json") # it will hide the "0" array from the Zarr reader - os.rename(sdata_path / "images" / corrupted / "S0", sdata_path / "images" / corrupted / "S0_corrupted") - (sdata_path / "images" / corrupted / "S0").touch() + os.unlink(sdata_path / "images" / corrupted / "s0" / "zarr.json") # it will hide the "0" array from the Zarr reader + os.rename(sdata_path / "images" / corrupted / "s0", sdata_path / "images" / corrupted / "s0_corrupted") + (sdata_path / "images" / corrupted / "s0").touch() not_corrupted = [name for _, name, _ in sdata.gen_elements() if name != corrupted] @@ -206,9 +206,9 @@ def sdata_with_corrupted_image_chunks_zarrv2(session_tmp_path: Path) -> PartialR sdata.write(sdata_path, sdata_formats=SpatialDataContainerFormatV01()) corrupted = "blobs_image" - os.unlink(sdata_path / "images" / corrupted / "S0" / ".zarray") # it will hide the "0" array from the Zarr reader - os.rename(sdata_path / "images" / corrupted / "S0", sdata_path / "images" / corrupted / "S0_corrupted") - (sdata_path / "images" / corrupted / "S0").touch() + os.unlink(sdata_path / "images" / corrupted / "s0" / ".zarray") # it will hide the "0" array from the Zarr reader + os.rename(sdata_path / "images" / corrupted / "s0", sdata_path / "images" / corrupted / "s0_corrupted") + (sdata_path / "images" / corrupted / "s0").touch() not_corrupted = [name for _, name, _ in sdata.gen_elements() if name != corrupted] return PartialReadTestCase( @@ -315,8 +315,8 @@ def sdata_with_missing_image_chunks_zarrv3( sdata.write(sdata_path) corrupted = "blobs_image" - os.unlink(sdata_path / "images" / corrupted / "S0" / "zarr.json") - os.rename(sdata_path / "images" / corrupted / "S0", sdata_path / "images" / corrupted / "S0_corrupted") + os.unlink(sdata_path / "images" / corrupted / "s0" / "zarr.json") + os.rename(sdata_path / "images" / corrupted / "s0", sdata_path / "images" / corrupted / "s0_corrupted") not_corrupted = [name for _, name, _ in sdata.gen_elements() if name != corrupted] @@ -339,8 +339,8 @@ def sdata_with_missing_image_chunks_zarrv2( sdata.write(sdata_path, sdata_formats=SpatialDataContainerFormatV01()) corrupted = "blobs_image" - os.unlink(sdata_path / "images" / corrupted / "S0" / ".zarray") - os.rename(sdata_path / "images" / corrupted / "S0", sdata_path / "images" / corrupted / "S0_corrupted") + os.unlink(sdata_path / "images" / corrupted / "s0" / ".zarray") + os.rename(sdata_path / "images" / corrupted / "s0", sdata_path / "images" / corrupted / "s0_corrupted") not_corrupted = [name for _, name, _ in sdata.gen_elements() if name != corrupted] From 85148d7f2823dd268831e726195f5f5a2874841d Mon Sep 17 00:00:00 2001 From: ArneDefauw Date: Thu, 12 Mar 2026 11:28:22 +0100 Subject: [PATCH 05/11] bump ome zarr in pyproject toml --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index e5f3134aa..07ec8140b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,7 @@ dependencies = [ "networkx", "numba>=0.55.0", "numpy", - "ome_zarr>=0.12.2", + "ome_zarr>=0.14.0", "pandas", "pooch", "pyarrow", From 930922c2ad96c3cf2d399b154b5d871c7248e4af Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Thu, 19 Mar 2026 23:16:51 +0100 Subject: [PATCH 06/11] dask accessor is now always loaded --- src/spatialdata/__init__.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/src/spatialdata/__init__.py b/src/spatialdata/__init__.py index 19be99d56..7ba66e710 100644 --- a/src/spatialdata/__init__.py +++ b/src/spatialdata/__init__.py @@ -4,6 +4,8 @@ from importlib.metadata import version from typing import TYPE_CHECKING, Any +import spatialdata.models._accessor # noqa: F401 + __version__ = version("spatialdata") _submodules = { @@ -129,15 +131,8 @@ "settings", ] -_accessor_loaded = False - def __getattr__(name: str) -> Any: - global _accessor_loaded - if not _accessor_loaded: - _accessor_loaded = True - import spatialdata.models._accessor # noqa: F401 - if name in _submodules: return importlib.import_module(f"spatialdata.{name}") if name in _LAZY_IMPORTS: From 20696b558589b42daf6fe227ebe2250ad77aab0d Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Thu, 19 Mar 2026 23:19:02 +0100 Subject: [PATCH 07/11] deduplicate storage option util; use chunks from data when not specified in storage options --- src/spatialdata/_io/io_raster.py | 45 ++++++++++++++------------------ tests/io/test_readwrite.py | 6 ++++- 2 files changed, 24 insertions(+), 27 deletions(-) diff --git a/src/spatialdata/_io/io_raster.py b/src/spatialdata/_io/io_raster.py index 55e03abed..2eea98162 100644 --- a/src/spatialdata/_io/io_raster.py +++ b/src/spatialdata/_io/io_raster.py @@ -90,34 +90,27 @@ def _normalize_explicit_chunks(chunks: object) -> tuple[int, ...] | int: return normalized -def _prepare_single_scale_storage_options( +def _prepare_storage_options( storage_options: JSONDict | list[JSONDict] | None, -) -> JSONDict | list[JSONDict] | None: + data: list[da.Array], +) -> list[JSONDict]: if storage_options is None: - return None - if isinstance(storage_options, dict): - prepared = dict(storage_options) - if "chunks" in prepared: - prepared["chunks"] = _normalize_explicit_chunks(prepared["chunks"]) - return prepared - return [dict(options) for options in storage_options] - - -def _prepare_multiscale_storage_options( - storage_options: JSONDict | list[JSONDict] | None, -) -> JSONDict | list[JSONDict] | None: - if storage_options is None: - return None + return [{"chunks": _normalize_explicit_chunks(arr.chunks)} for arr in data] if isinstance(storage_options, dict): + if "chunks" not in storage_options: + return [{**storage_options, "chunks": _normalize_explicit_chunks(arr.chunks)} for arr in data] prepared = dict(storage_options) - if "chunks" in prepared: - prepared["chunks"] = _normalize_explicit_chunks(prepared["chunks"]) - return prepared - - prepared_options = [dict(options) for options in storage_options] - for options in prepared_options: - if "chunks" in options: - options["chunks"] = _normalize_explicit_chunks(options["chunks"]) + prepared["chunks"] = _normalize_explicit_chunks(prepared["chunks"]) + return prepared # type: ignore[return-value] + + prepared_options = [] + for i, options in enumerate(storage_options): + opts = dict(options) + if "chunks" not in opts: + opts["chunks"] = _normalize_explicit_chunks(data[i].chunks) + else: + opts["chunks"] = _normalize_explicit_chunks(opts["chunks"]) + prepared_options.append(opts) return prepared_options @@ -335,7 +328,7 @@ def _write_raster_dataarray( raise ValueError(f"{element_name} does not have any transformations and can therefore not be written.") input_axes: tuple[str, ...] = tuple(raster_data.dims) parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=raster_format) - storage_options = _prepare_single_scale_storage_options(storage_options) + storage_options = _prepare_storage_options(storage_options, [data]) # Explicitly disable pyramid generation for single-scale rasters. Recent ome-zarr versions default # write_image()/write_labels() to scale_factors=(2, 4, 8, 16), which would otherwise write s0, s1, ... # even when the input is a plain DataArray. @@ -405,7 +398,7 @@ def _write_raster_datatree( raise ValueError(f"{element_name} does not have any transformations and can therefore not be written.") parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=raster_format) - storage_options = _prepare_multiscale_storage_options(storage_options) + storage_options = _prepare_storage_options(storage_options, data) ome_zarr_format = get_ome_zarr_format(raster_format) dask_delayed = write_multi_scale_ngff( pyramid=data, diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py index 42f484782..463ea2f21 100644 --- a/tests/io/test_readwrite.py +++ b/tests/io/test_readwrite.py @@ -631,7 +631,11 @@ def test_write_irregular_dask_chunks_without_explicit_storage_options(tmp_path: image = Image2DModel.parse(data, dims=("c", "y", "x")) sdata = SpatialData(images={"image": image}) - sdata.write(tmp_path / "data.zarr") + with pytest.raises( + ValueError, + match="storage_options\\['chunks'\\] must be a Zarr chunk shape or a regular Dask chunk grid", + ): + sdata.write(tmp_path / "data.zarr") def test_write_image_normalizes_explicit_regular_dask_chunk_grid(tmp_path: Path) -> None: From 2450bd437914004cc940279482f2dc60293e0575 Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Thu, 19 Mar 2026 23:27:08 +0100 Subject: [PATCH 08/11] simplify, document and test the chunk helper functions --- src/spatialdata/_io/io_raster.py | 52 +++++++++++++++++++++++++++++--- tests/io/test_readwrite.py | 23 ++++++++++++++ 2 files changed, 71 insertions(+), 4 deletions(-) diff --git a/src/spatialdata/_io/io_raster.py b/src/spatialdata/_io/io_raster.py index 2eea98162..e291fa9b8 100644 --- a/src/spatialdata/_io/io_raster.py +++ b/src/spatialdata/_io/io_raster.py @@ -40,22 +40,66 @@ def _is_flat_int_sequence(value: object) -> TypeGuard[Sequence[int]]: - if isinstance(value, str | bytes): - return False if not isinstance(value, Sequence): return False return all(isinstance(v, int) for v in value) def _is_dask_chunk_grid(value: object) -> TypeGuard[Sequence[Sequence[int]]]: - if isinstance(value, str | bytes): - return False if not isinstance(value, Sequence): return False return len(value) > 0 and all(_is_flat_int_sequence(axis_chunks) for axis_chunks in value) def _is_regular_dask_chunk_grid(chunk_grid: Sequence[Sequence[int]]) -> bool: + """Check whether a Dask chunk grid is regular (zarr-compatible). + + A grid is regular when every axis has at most one unique chunk size among all but the last + chunk, and the last chunk is not larger than the first. + + Parameters + ---------- + chunk_grid + Per-axis tuple of chunk sizes, for instance as returned by ``dask_array.chunks``. + + Examples + -------- + Triggers ``continue`` on the first ``if`` (single or empty axis): + + >>> _is_regular_dask_chunk_grid([(4,)]) # single chunk → True + True + >>> _is_regular_dask_chunk_grid([()]) # empty axis → True + True + + Triggers the first ``return False`` (non-uniform interior chunks): + + >>> _is_regular_dask_chunk_grid([(4, 4, 3, 4)]) # interior sizes differ → False + False + + Triggers the second ``return False`` (last chunk larger than the first): + + >>> _is_regular_dask_chunk_grid([(4, 4, 4, 5)]) # last > first → False + False + + Exits with ``return True``: + + >>> _is_regular_dask_chunk_grid([(4, 4, 4, 4)]) # all equal → True + True + >>> _is_regular_dask_chunk_grid([(4, 4, 4, 1)]) # last < first → True + True + + Empty grid (loop never executes) → True: + + >>> _is_regular_dask_chunk_grid([]) + True + + Multi-axis: all axes regular → True; one axis irregular → False: + + >>> _is_regular_dask_chunk_grid([(4, 4, 4, 1), (3, 3, 2)]) + True + >>> _is_regular_dask_chunk_grid([(4, 4, 4, 1), (4, 4, 3, 4)]) + False + """ # Match Dask's private _check_regular_chunks() logic without depending on its internal API. for axis_chunks in chunk_grid: if len(axis_chunks) <= 1: diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py index 463ea2f21..26b8cbfc4 100644 --- a/tests/io/test_readwrite.py +++ b/tests/io/test_readwrite.py @@ -626,6 +626,29 @@ def test_bug_rechunking_after_queried_raster(): queried.write(f) +def test_is_regular_dask_chunk_grid() -> None: + from spatialdata._io.io_raster import _is_regular_dask_chunk_grid + + # Single chunk per axis → continue branch, overall True + assert _is_regular_dask_chunk_grid([(4,)]) is True + # Empty axis → continue branch, overall True + assert _is_regular_dask_chunk_grid([()]) is True + # Non-uniform interior chunks → first return False + assert _is_regular_dask_chunk_grid([(4, 4, 3, 4)]) is False + # Last chunk larger than first → second return False + assert _is_regular_dask_chunk_grid([(4, 4, 4, 5)]) is False + # All chunks equal → True + assert _is_regular_dask_chunk_grid([(4, 4, 4, 4)]) is True + # Last chunk smaller than first → True + assert _is_regular_dask_chunk_grid([(4, 4, 4, 1)]) is True + # Empty grid (no axes) → True + assert _is_regular_dask_chunk_grid([]) is True + # Multi-axis: all axes regular → True + assert _is_regular_dask_chunk_grid([(4, 4, 4, 1), (3, 3, 2)]) is True + # Multi-axis: one axis irregular → False + assert _is_regular_dask_chunk_grid([(4, 4, 4, 1), (4, 4, 3, 4)]) is False + + def test_write_irregular_dask_chunks_without_explicit_storage_options(tmp_path: Path) -> None: data = da.from_array(RNG.random((3, 800, 1000)), chunks=((3,), (300, 200, 300), (512, 488))) image = Image2DModel.parse(data, dims=("c", "y", "x")) From 3dd6121fb12fa8a2de7ed435da5d1380c6cc36cd Mon Sep 17 00:00:00 2001 From: ArneDefauw Date: Fri, 20 Mar 2026 11:20:21 +0100 Subject: [PATCH 09/11] guard against storage_options["chunks"]="" + Change ValueError --- src/spatialdata/_io/io_raster.py | 10 +++++-- tests/io/test_readwrite.py | 51 ++++++++++++++++++++++++++++++-- 2 files changed, 57 insertions(+), 4 deletions(-) diff --git a/src/spatialdata/_io/io_raster.py b/src/spatialdata/_io/io_raster.py index e291fa9b8..afcf2f9aa 100644 --- a/src/spatialdata/_io/io_raster.py +++ b/src/spatialdata/_io/io_raster.py @@ -40,12 +40,16 @@ def _is_flat_int_sequence(value: object) -> TypeGuard[Sequence[int]]: + if isinstance(value, str | bytes): + return False if not isinstance(value, Sequence): return False return all(isinstance(v, int) for v in value) def _is_dask_chunk_grid(value: object) -> TypeGuard[Sequence[Sequence[int]]]: + if isinstance(value, str | bytes): + return False if not isinstance(value, Sequence): return False return len(value) > 0 and all(_is_flat_int_sequence(axis_chunks) for axis_chunks in value) @@ -128,8 +132,10 @@ def _normalize_explicit_chunks(chunks: object) -> tuple[int, ...] | int: normalized = _chunks_to_zarr_chunks(chunks) if normalized is None: raise ValueError( - "storage_options['chunks'] must be a Zarr chunk shape or a regular Dask chunk grid. " - "Irregular Dask chunk grids must be rechunked before writing or omitted." + 'storage_options["chunks"] must resolve to a Zarr chunk shape or a regular Dask chunk grid. ' + "The current raster has irregular Dask chunks, which cannot be written to Zarr. " + "To fix this, rechunk before writing, for example by passing regular chunks=... " + "to Image2DModel.parse(...) / Labels2DModel.parse(...)." ) return normalized diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py index 26b8cbfc4..0de42604d 100644 --- a/tests/io/test_readwrite.py +++ b/tests/io/test_readwrite.py @@ -656,7 +656,7 @@ def test_write_irregular_dask_chunks_without_explicit_storage_options(tmp_path: with pytest.raises( ValueError, - match="storage_options\\['chunks'\\] must be a Zarr chunk shape or a regular Dask chunk grid", + match='storage_options\\["chunks"\\] must resolve to a Zarr chunk shape or a regular Dask chunk grid', ): sdata.write(tmp_path / "data.zarr") @@ -678,11 +678,58 @@ def test_write_image_rejects_explicit_irregular_dask_chunk_grid(tmp_path: Path) with pytest.raises( ValueError, - match="storage_options\\['chunks'\\] must be a Zarr chunk shape or a regular Dask chunk grid", + match='storage_options\\["chunks"\\] must resolve to a Zarr chunk shape or a regular Dask chunk grid', ): write_image(image, group, "image", storage_options={"chunks": image.data.chunks}) +def test_write_image_normalizes_explicit_zarr_chunk_grid(tmp_path: Path) -> None: + data = da.from_array(RNG.random((3, 800, 1000)), chunks=((3,), (300, 200, 300), (512, 488))) + image = Image2DModel.parse(data, dims=("c", "y", "x")) + group = zarr.open_group(tmp_path / "image.zarr", mode="w") + + zarr_chunks = (3, 100, 512) # ome zarr rechunks when writing + write_image(image, group, "image", storage_options={"chunks": zarr_chunks}) + + assert group["s0"].chunks == (3, 100, 512) + + +def test_write_image_rejects_string(tmp_path: Path) -> None: + data = da.from_array(RNG.random((3, 800, 1000)), chunks=((3,), (300, 300, 200), (512, 488))) + image = Image2DModel.parse(data, dims=("c", "y", "x")) + group = zarr.open_group(tmp_path / "image.zarr", mode="w") + + with pytest.raises( + ValueError, + match='storage_options\\["chunks"\\] must resolve to a Zarr chunk shape or a regular Dask chunk grid', + ): + write_image(image, group, "image", storage_options={"chunks": "auto"}) + + +def test_write_image_rejects_empty_string(tmp_path: Path) -> None: + data = da.from_array(RNG.random((3, 800, 1000)), chunks=((3,), (300, 300, 200), (512, 488))) + image = Image2DModel.parse(data, dims=("c", "y", "x")) + group = zarr.open_group(tmp_path / "image.zarr", mode="w") + + with pytest.raises( + ValueError, + match='storage_options\\["chunks"\\] must resolve to a Zarr chunk shape or a regular Dask chunk grid', + ): + write_image(image, group, "image", storage_options={"chunks": ""}) + + +def test_write_image_rejects_byte_string(tmp_path: Path) -> None: + data = da.from_array(RNG.random((3, 800, 1000)), chunks=((3,), (300, 300, 200), (512, 488))) + image = Image2DModel.parse(data, dims=("c", "y", "x")) + group = zarr.open_group(tmp_path / "image.zarr", mode="w") + + with pytest.raises( + ValueError, + match='storage_options\\["chunks"\\] must resolve to a Zarr chunk shape or a regular Dask chunk grid', + ): + write_image(image, group, "image", storage_options={"chunks": b"auto"}) + + def test_single_scale_image_roundtrip_stays_dataarray(tmp_path: Path) -> None: image = Image2DModel.parse(RNG.random((3, 64, 64)), dims=("c", "y", "x")) sdata = SpatialData(images={"image": image}) From 8c85587404d5a9379fe103b27dbbeafd1363b9ef Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Fri, 20 Mar 2026 14:47:51 +0100 Subject: [PATCH 10/11] remove data argument from _prepare_storage_options() --- src/spatialdata/_io/io_raster.py | 31 +++++++++++++------------------ 1 file changed, 13 insertions(+), 18 deletions(-) diff --git a/src/spatialdata/_io/io_raster.py b/src/spatialdata/_io/io_raster.py index afcf2f9aa..a8b2ab2ce 100644 --- a/src/spatialdata/_io/io_raster.py +++ b/src/spatialdata/_io/io_raster.py @@ -40,6 +40,7 @@ def _is_flat_int_sequence(value: object) -> TypeGuard[Sequence[int]]: + # e.g. "", "auto" or b"auto" if isinstance(value, str | bytes): return False if not isinstance(value, Sequence): @@ -142,25 +143,19 @@ def _normalize_explicit_chunks(chunks: object) -> tuple[int, ...] | int: def _prepare_storage_options( storage_options: JSONDict | list[JSONDict] | None, - data: list[da.Array], -) -> list[JSONDict]: +) -> JSONDict | list[JSONDict] | None: if storage_options is None: - return [{"chunks": _normalize_explicit_chunks(arr.chunks)} for arr in data] + return None if isinstance(storage_options, dict): - if "chunks" not in storage_options: - return [{**storage_options, "chunks": _normalize_explicit_chunks(arr.chunks)} for arr in data] prepared = dict(storage_options) - prepared["chunks"] = _normalize_explicit_chunks(prepared["chunks"]) - return prepared # type: ignore[return-value] - - prepared_options = [] - for i, options in enumerate(storage_options): - opts = dict(options) - if "chunks" not in opts: - opts["chunks"] = _normalize_explicit_chunks(data[i].chunks) - else: - opts["chunks"] = _normalize_explicit_chunks(opts["chunks"]) - prepared_options.append(opts) + if "chunks" in prepared: + prepared["chunks"] = _normalize_explicit_chunks(prepared["chunks"]) + return prepared + + prepared_options = [dict(options) for options in storage_options] + for options in prepared_options: + if "chunks" in options: + options["chunks"] = _normalize_explicit_chunks(options["chunks"]) return prepared_options @@ -378,7 +373,7 @@ def _write_raster_dataarray( raise ValueError(f"{element_name} does not have any transformations and can therefore not be written.") input_axes: tuple[str, ...] = tuple(raster_data.dims) parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=raster_format) - storage_options = _prepare_storage_options(storage_options, [data]) + storage_options = _prepare_storage_options(storage_options) # Explicitly disable pyramid generation for single-scale rasters. Recent ome-zarr versions default # write_image()/write_labels() to scale_factors=(2, 4, 8, 16), which would otherwise write s0, s1, ... # even when the input is a plain DataArray. @@ -448,7 +443,7 @@ def _write_raster_datatree( raise ValueError(f"{element_name} does not have any transformations and can therefore not be written.") parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=raster_format) - storage_options = _prepare_storage_options(storage_options, data) + storage_options = _prepare_storage_options(storage_options) ome_zarr_format = get_ome_zarr_format(raster_format) dask_delayed = write_multi_scale_ngff( pyramid=data, From 385dd2ec86d62ca9bbe16f67c2c2f0fe9d87faa4 Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Fri, 20 Mar 2026 14:48:06 +0100 Subject: [PATCH 11/11] remove data argument from _prepare_storage_options() --- tests/io/test_readwrite.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py index 0de42604d..209a43046 100644 --- a/tests/io/test_readwrite.py +++ b/tests/io/test_readwrite.py @@ -654,11 +654,11 @@ def test_write_irregular_dask_chunks_without_explicit_storage_options(tmp_path: image = Image2DModel.parse(data, dims=("c", "y", "x")) sdata = SpatialData(images={"image": image}) - with pytest.raises( - ValueError, - match='storage_options\\["chunks"\\] must resolve to a Zarr chunk shape or a regular Dask chunk grid', - ): - sdata.write(tmp_path / "data.zarr") + path = tmp_path / "data.zarr" + with pytest.warns(UserWarning, match="irregular chunk sizes"): + sdata.write(path) + sdata_back = read_zarr(path) + assert sdata_back["image"].chunks == ((3,), (300, 300, 200), (512, 488)) def test_write_image_normalizes_explicit_regular_dask_chunk_grid(tmp_path: Path) -> None: