diff --git a/.gitignore b/.gitignore index 379f16be..35db05af 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ temp/ # Compiled files __pycache__/ +.ipynb_checkpoints/ # Distribution / packaging /build/ @@ -42,3 +43,4 @@ data data/ tests/data uv.lock +.asv/ diff --git a/asv.conf.json b/asv.conf.json index 516f0f1f..55f61b90 100644 --- a/asv.conf.json +++ b/asv.conf.json @@ -3,7 +3,7 @@ "project": "spatialdata-io", "project_url": "https://github.com/scverse/spatialdata-io", "repo": ".", - "branches": ["main", "xenium-labels-dask", "xenium-labels-dask-zipstore"], + "branches": ["image-reader-chunkwise"], "dvcs": "git", "environment_type": "virtualenv", "pythons": ["3.12"], diff --git a/benchmarks/benchmark_image.py b/benchmarks/benchmark_image.py new file mode 100644 index 00000000..3096fee7 --- /dev/null +++ b/benchmarks/benchmark_image.py @@ -0,0 +1,180 @@ +"""Benchmarks for SpatialData IO operations for large images. + +Instructions: + See benchmark_xenium.py for instructions. +""" + +import logging +import logging.handlers +import tempfile +from pathlib import Path +from typing import Any + +import numpy as np +import tifffile +from spatialdata import SpatialData +from spatialdata._logging import logger +from xarray import DataArray + +from spatialdata_io import image # type: ignore[attr-defined] + +# ============================================================================= +# CONFIGURATION - Edit these values to match your setup +# ============================================================================= +# Image dimensions: (channels, height, width) +IMAGE_SHAPE = (3, 30000, 30000) +# ============================================================================= + + +class IOBenchmarkImage: + """Benchmark IO read operations with different parameter combinations.""" + + timeout = 3600 + repeat = 3 + number = 1 + warmup_time = 0 + processes = 1 + + # Parameter combinations: scale_factors, (use_tiff_memmap, compressed), chunks + # Combinations: (memmap=False, compressed=True), (memmap=False, compressed=False), (memmap=True, compressed=False) + params = [ + [None, [2, 2]], # scale_factors + [(False, True), (False, False), (True, False)], # (use_tiff_memmap, compressed) + [(1, 250, 250), (3, 250, 250)], # chunks + ] + param_names = ["scale_factors", "memmap_compressed", "chunks"] + + # Class-level temp directory for image files (persists across all benchmarks) + _images_temp_dir: tempfile.TemporaryDirectory[str] | None = None + _path_read_uncompressed: Path | None = None + _path_read_compressed: Path | None = None + + @classmethod + def _setup_images(cls) -> None: + """Create fake image data once for all benchmarks.""" + if cls._images_temp_dir is not None: + return + + cls._images_temp_dir = tempfile.TemporaryDirectory() + images_dir = Path(cls._images_temp_dir.name) + cls._path_read_uncompressed = images_dir / "image_uncompressed.tif" + cls._path_read_compressed = images_dir / "image_compressed.tif" + + # Generate fake image data + rng = np.random.default_rng(42) + data = rng.integers(0, 255, size=IMAGE_SHAPE, dtype=np.uint8) + + # Write uncompressed TIFF (memmappable) + tifffile.imwrite(cls._path_read_uncompressed, data, compression=None) + # Write compressed TIFF (not memmappable) + tifffile.imwrite(cls._path_read_compressed, data, compression="zlib") + + def setup(self, *_: Any) -> None: + """Set up paths for benchmarking.""" + # Create images once (shared across all benchmark runs) + self._setup_images() + self.path_read_uncompressed = self._path_read_uncompressed + 
self.path_read_compressed = self._path_read_compressed + + # Create a separate temp directory for output (cleaned up after each run) + self._output_temp_dir = tempfile.TemporaryDirectory() + self.path_write = Path(self._output_temp_dir.name) / "data_benchmark.zarr" + + def teardown(self, *_: Any) -> None: + """Clean up output directory after each benchmark run.""" + if hasattr(self, "_output_temp_dir"): + self._output_temp_dir.cleanup() + + def _convert_image( + self, scale_factors: list[int] | None, memmap_compressed: tuple[bool, bool], chunks: tuple[int, ...] + ) -> SpatialData: + """Read image data with specified parameters.""" + use_tiff_memmap, compressed = memmap_compressed + # Select file based on compression setting + path_read = self.path_read_compressed if compressed else self.path_read_uncompressed + assert path_read is not None + + # Capture log messages to verify memmappable warning behavior + log_capture = logging.handlers.MemoryHandler(capacity=100) + log_capture.setLevel(logging.WARNING) + logger.addHandler(log_capture) + original_propagate = logger.propagate + logger.propagate = True + + try: + im = image( + input=path_read, + data_axes=("c", "y", "x"), + coordinate_system="global", + use_tiff_memmap=use_tiff_memmap, + chunks=chunks, + scale_factors=scale_factors, + ) + finally: + logger.removeHandler(log_capture) + logger.propagate = original_propagate + + # Check warning behavior: when use_tiff_memmap=True with uncompressed file, no warning should be raised + log_messages = [record.getMessage() for record in log_capture.buffer] + has_memmap_warning = any("image data is not memory-mappable" in msg for msg in log_messages) + if use_tiff_memmap and not compressed: + assert not has_memmap_warning, ( + "Uncompressed TIFF with memmap=True should not trigger memory-mappable warning" + ) + + sdata = SpatialData.init_from_elements({"image": im}) + # sanity check: chunks is (c, y, x) + if scale_factors is None: + assert isinstance(sdata["image"], DataArray) + if chunks is not None: + assert ( + sdata["image"].chunksizes["x"][0] == chunks[2] + or sdata["image"].chunksizes["x"][0] == sdata["image"].shape[2] + ) + assert ( + sdata["image"].chunksizes["y"][0] == chunks[1] + or sdata["image"].chunksizes["y"][0] == sdata["image"].shape[1] + ) + else: + assert len(sdata["image"].keys()) == len(scale_factors) + 1 + if chunks is not None: + assert ( + sdata["image"]["scale0"]["image"].chunksizes["x"][0] == chunks[2] + or sdata["image"]["scale0"]["image"].chunksizes["x"][0] + == sdata["image"]["scale0"]["image"].shape[2] + ) + assert ( + sdata["image"]["scale0"]["image"].chunksizes["y"][0] == chunks[1] + or sdata["image"]["scale0"]["image"].chunksizes["y"][0] + == sdata["image"]["scale0"]["image"].shape[1] + ) + + return sdata + + def time_io( + self, scale_factors: list[int] | None, memmap_compressed: tuple[bool, bool], chunks: tuple[int, ...] + ) -> None: + """Walltime for data parsing.""" + sdata = self._convert_image(scale_factors, memmap_compressed, chunks) + sdata.write(self.path_write) + + def peakmem_io( + self, scale_factors: list[int] | None, memmap_compressed: tuple[bool, bool], chunks: tuple[int, ...] 
+ ) -> None: + """Peak memory for data parsing.""" + sdata = self._convert_image(scale_factors, memmap_compressed, chunks) + sdata.write(self.path_write) + + +# if __name__ == "__main__": +# # Run a single test case for quick verification +# bench = IOBenchmarkImage() +# +# bench.setup() +# bench.time_io(None, (True, False), (1, 5000, 5000)) +# bench.teardown() +# +# # Clean up the shared images temp directory at the end +# if IOBenchmarkImage._images_temp_dir is not None: +# IOBenchmarkImage._images_temp_dir.cleanup() +# IOBenchmarkImage._images_temp_dir = None diff --git a/benchmarks/bench_xenium.py b/benchmarks/benchmark_xenium.py similarity index 85% rename from benchmarks/bench_xenium.py rename to benchmarks/benchmark_xenium.py index 685e5094..b1f09afe 100644 --- a/benchmarks/bench_xenium.py +++ b/benchmarks/benchmark_xenium.py @@ -11,18 +11,18 @@ cd /path/to/spatialdata-io # Quick benchmark (single run, for testing): - asv run --python=same -b IOBenchmark --quick --show-stderr -v + asv run --python=same -b IOBenchmarkXenium --quick --show-stderr -v # Full benchmark (multiple runs, for accurate results): - asv run --python=same -b IOBenchmark --show-stderr -v + asv run --python=same -b IOBenchmarkXenium --show-stderr -v Comparing branches: # Run on specific commits: - asv run main^! -b IOBenchmark --show-stderr -v - asv run xenium-labels-dask^! -b IOBenchmark --show-stderr -v + asv run main^! -b IOBenchmarkXenium --show-stderr -v + asv run xenium-labels-dask^! -b IOBenchmarkXenium --show-stderr -v # Or compare two branches directly: - asv continuous main xenium-labels-dask -b IOBenchmark --show-stderr -v + asv continuous main xenium-labels-dask -b IOBenchmarkXenium --show-stderr -v # View comparison: asv compare main xenium-labels-dask @@ -36,7 +36,6 @@ import inspect import shutil from pathlib import Path -from typing import TYPE_CHECKING from spatialdata import SpatialData @@ -62,9 +61,7 @@ def get_paths() -> tuple[Path, Path]: return path_read, path_write -class IOBenchmark: - """Benchmark IO read operations.""" - +class IOBenchmarkXenium: timeout = 3600 repeat = 3 number = 1 @@ -106,4 +103,6 @@ def peakmem_io(self) -> None: if __name__ == "__main__": - IOBenchmark().time_io() + benchmark = IOBenchmarkXenium() + benchmark.setup() + benchmark.time_io() diff --git a/src/spatialdata_io/readers/_utils/_image.py b/src/spatialdata_io/readers/_utils/_image.py new file mode 100644 index 00000000..3784ccca --- /dev/null +++ b/src/spatialdata_io/readers/_utils/_image.py @@ -0,0 +1,202 @@ +from collections.abc import Callable, Mapping, Sequence +from typing import Any + +import dask.array as da +import numpy as np +from dask import delayed +from numpy.typing import NDArray +from spatialdata.models.models import Chunks_t + +__all__ = ["Chunks_t", "_compute_chunks", "_read_chunks", "normalize_chunks"] + +_Y_IDX = 0 +"""Index of y coordinate in in chunk coordinate array format: (y, x, height, width)""" + +_X_IDX = 1 +"""Index of x coordinate in chunk coordinate array format: (y, x, height, width)""" + +_HEIGHT_IDX = 2 +"""Index of height specification in chunk coordinate array format: (y, x, height, width)""" + +_WIDTH_IDX = 3 +"""Index of width specification in chunk coordinate array format: (y, x, height, width)""" + +DEFAULT_CHUNK_SIZE = 1000 + + +def _compute_chunk_sizes_positions(size: int, chunk: int) -> tuple[NDArray[np.int_], NDArray[np.int_]]: + """Calculate chunk sizes and positions for a given dimension and chunk size.""" + # All chunks have the same size except for the last one + 
positions = np.arange(0, size, chunk)
+    lengths = np.minimum(chunk, size - positions)
+
+    return positions, lengths
+
+
+def _compute_chunks(
+    shape: tuple[int, int],
+    chunk_size: tuple[int, int],
+) -> NDArray[np.int_]:
+    """Create all chunk specs for a given image and chunk size.
+
+    Creates chunk specifications for tiling an image. Returns an array where position [i, j]
+    contains the spec for the chunk at block row i, block column j.
+    Each chunk specification consists of (y, x, height, width) with (y, x) being the upper left
+    corner of chunks of size chunk_size.
+    Chunks at the edges are truncated to the remaining extent of the image.
+
+    Parameters
+    ----------
+    shape : tuple[int, int]
+        Size of the image in (image height, image width).
+    chunk_size : tuple[int, int]
+        Size of individual tiles in (height, width).
+
+    Returns
+    -------
+    np.ndarray
+        Array of shape (n_tiles_y, n_tiles_x, 4). Each entry defines a tile
+        as (y, x, height, width).
+    """
+    y_positions, heights = _compute_chunk_sizes_positions(shape[0], chunk_size[0])
+    x_positions, widths = _compute_chunk_sizes_positions(shape[1], chunk_size[1])
+
+    # Generate the tiles
+    # Each entry defines the chunk dimensions for a tile
+    # The order of the chunk definitions (chunk_index_y=outer, chunk_index_x=inner) follows the dask.block convention
+    tiles = np.array(
+        [
+            [[y, x, h, w] for x, w in zip(x_positions, widths, strict=True)]
+            for y, h in zip(y_positions, heights, strict=True)
+        ],
+        dtype=int,
+    )
+    return tiles
+
+
+def _read_chunks(
+    func: Callable[..., NDArray[np.number]],
+    slide: Any,
+    coords: NDArray[np.int_],
+    n_channel: int,
+    dtype: np.dtype[Any],
+    **func_kwargs: Any,
+) -> list[list[da.Array]]:
+    """Tile a large microscopy image into a grid of lazily read chunks.
+
+    Parameters
+    ----------
+    func
+        Function to retrieve a single rectangular tile from the slide image. Must take the
+        arguments:
+
+        - slide: Full slide image
+        - y0: y (row) coordinate of upper left corner of chunk
+        - x0: x (col) coordinate of upper left corner of chunk
+        - height: Height of chunk
+        - width: Width of chunk
+
+        and should return the chunk as a numpy array of shape (c, y, x).
+    slide
+        Slide image in lazily loaded format compatible with `func`
+    coords
+        Coordinates of the upper left corner of each chunk image in format (n_row_y, n_row_x, 4)
+        where the last dimension defines the rectangular tile in format (y, x, height, width), as returned
+        by :func:`_compute_chunks`.
+        n_row_y represents the number of chunks in y dimension (block rows) and n_row_x the number of chunks
+        in x dimension (block columns).
+    n_channel
+        Number of channels in array
+    dtype
+        Data type of image
+    func_kwargs
+        Additional keyword arguments passed to func
+
+    Returns
+    -------
+    list[list[da.array]]
+        (Outer) list (length: n_row_y) of (inner) lists (length: n_row_x) of chunks with axes
+        (c, y, x). Represents all chunks of the full image.
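+        Chunks follow the ``dask.array.block`` layout (outer list = block rows along y,
+        inner list = block columns along x), so the full image can be reassembled with
+        ``da.block``.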
+ """ + func_kwargs = func_kwargs if func_kwargs else {} + + # Collect each delayed chunk (c, y, x) as item in list of list + # Inner list becomes dim=-1 (chunk columns/x) + # Outer list becomes dim=-2 (chunk rows/y) + # see dask.array.block + chunks = [ + [ + da.from_delayed( + delayed(func)( + slide, + y0=coords[chunk_y, chunk_x, _Y_IDX], + x0=coords[chunk_y, chunk_x, _X_IDX], + height=coords[chunk_y, chunk_x, _HEIGHT_IDX], + width=coords[chunk_y, chunk_x, _WIDTH_IDX], + **func_kwargs, + ), + dtype=dtype, + shape=(n_channel, coords[chunk_y, chunk_x, _HEIGHT_IDX], coords[chunk_y, chunk_x, _WIDTH_IDX]), + ) + for chunk_x in range(coords.shape[1]) + ] + for chunk_y in range(coords.shape[0]) + ] + return chunks + + +def normalize_chunks( + chunks: Chunks_t | None, + axes: Sequence[str], +) -> dict[str, int]: + """Normalize chunk specification to dict format. + + This function converts various chunk formats to a dict mapping dimension names + to chunk sizes. The dict format is preferred because it's explicit about which + dimension gets which chunk size and is compatible with spatialdata. + + Parameters + ---------- + chunks + Chunk specification. Can be: + - None: Uses DEFAULT_CHUNK_SIZE for all axes + - int: Applied to all axes + - tuple[int, ...]: Chunk sizes in order corresponding to axes + - dict: Mapping of axis names to chunk sizes (validated against axes) + axes + Tuple of axis names that defines the expected dimensions (e.g., ('c', 'y', 'x')). + + Returns + ------- + dict[str, int] + Dict mapping axis names to chunk sizes. + + Raises + ------ + ValueError + If chunks format is not supported or incompatible with axes. + """ + if chunks is None: + return dict.fromkeys(axes, DEFAULT_CHUNK_SIZE) + + if isinstance(chunks, int): + return dict.fromkeys(axes, chunks) + + if isinstance(chunks, Mapping): + chunks_dict = dict(chunks) + missing = set(axes) - set(chunks_dict.keys()) + if missing: + raise ValueError(f"chunks dict missing keys for axes {missing}, got: {list(chunks_dict.keys())}") + return {ax: chunks_dict[ax] for ax in axes} + + if isinstance(chunks, tuple): + if len(chunks) != len(axes): + raise ValueError(f"chunks tuple length {len(chunks)} doesn't match axes {axes} (length {len(axes)})") + if not all(isinstance(c, int) for c in chunks): + raise ValueError(f"All elements in chunks tuple must be int, got: {chunks}") + return dict(zip(axes, chunks, strict=True)) + + raise ValueError(f"Unsupported chunks type: {type(chunks)}. 
Expected int, tuple, dict, or None.") + + +## diff --git a/src/spatialdata_io/readers/generic.py b/src/spatialdata_io/readers/generic.py index 904b94dc..462989e0 100644 --- a/src/spatialdata_io/readers/generic.py +++ b/src/spatialdata_io/readers/generic.py @@ -1,12 +1,15 @@ from __future__ import annotations -import warnings from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Protocol, TypeVar +import dask.array as da import numpy as np +import tifffile from dask_image.imread import imread +from geopandas import GeoDataFrame from spatialdata._docs import docstring_parameter +from spatialdata._logging import logger from spatialdata.models import Image2DModel, ShapesModel from spatialdata.models._utils import DEFAULT_COORDINATE_SYSTEM from spatialdata.transformations import Identity @@ -15,13 +18,28 @@ from collections.abc import Sequence from geopandas import GeoDataFrame + from numpy.typing import NDArray + from spatialdata.models.models import Chunks_t from xarray import DataArray + +from spatialdata_io.readers._utils._image import ( + _compute_chunks, + _read_chunks, + normalize_chunks, +) + VALID_IMAGE_TYPES = [".tif", ".tiff", ".png", ".jpg", ".jpeg"] VALID_SHAPE_TYPES = [".geojson"] __all__ = ["generic", "geojson", "image", "VALID_IMAGE_TYPES", "VALID_SHAPE_TYPES"] +T = TypeVar("T", bound=np.generic) # Restrict to NumPy scalar types + + +class DaskArray(Protocol[T]): + dtype: np.dtype[T] + @docstring_parameter( valid_image_types=", ".join(VALID_IMAGE_TYPES), @@ -58,7 +76,7 @@ def generic( coordinate_system = DEFAULT_COORDINATE_SYSTEM if input.suffix in VALID_SHAPE_TYPES: if data_axes is not None: - warnings.warn("data_axes is not used for geojson files", UserWarning, stacklevel=2) + logger.warning("data_axes is not used for geojson files") return geojson(input, coordinate_system=coordinate_system) elif input.suffix in VALID_IMAGE_TYPES: if data_axes is None: @@ -73,11 +91,150 @@ def geojson(input: Path, coordinate_system: str) -> GeoDataFrame: return ShapesModel.parse(input, transformations={coordinate_system: Identity()}) -def image(input: Path, data_axes: Sequence[str], coordinate_system: str) -> DataArray: - """Reads an image file and returns a parsed Image2D spatial element.""" - # this function is just a draft, the more general one will be available when - # https://github.com/scverse/spatialdata-io/pull/234 is merged - image = imread(input) - if len(image.shape) == len(data_axes) + 1 and image.shape[0] == 1: - image = np.squeeze(image, axis=0) - return Image2DModel.parse(image, dims=data_axes, transformations={coordinate_system: Identity()}) +def _tiff_to_chunks( + input: Path, + axes_dim_mapping: dict[str, int], + chunks_cyx: dict[str, int], +) -> list[list[DaskArray[np.number]]]: + """Chunkwise reader for tiff files. + + Creates spatial tiles from a TIFF file. Each tile contains all channels. + Channel chunking is handled downstream by Image2DModel.parse(). + + Parameters + ---------- + input + Path to image + axes_dim_mapping + Mapping between dimension name (c, y, x) and index + chunks_cyx + Chunk size dict with 'c', 'y', and 'x' keys. The 'y' and 'x' values + are used for spatial tiling. The 'c' value is passed through for + downstream rechunking. + + Returns + ------- + list[list[DaskArray]] + 2D list of dask arrays representing spatial tiles, each with shape (n_channels, height, width). 
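+
+    Raises
+    ------
+    ValueError
+        If the TIFF is not memory-mappable (e.g. compressed); raised by ``tifffile.memmap``
+        and caught by ``image``, which then falls back to ``_dask_image_imread``.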
+ """ + # Lazy file reader + slide = tifffile.memmap(input) + + # Transpose to cyx order + slide = np.transpose(slide, (axes_dim_mapping["c"], axes_dim_mapping["y"], axes_dim_mapping["x"])) + + # Get dimensions in (y, x) + slide_dimensions = slide.shape[1], slide.shape[2] + + # Get number of channels (all channels are included in each spatial tile) + n_channel = slide.shape[0] + + # Compute chunk coords using (y, x) tuple + chunk_coords = _compute_chunks(slide_dimensions, chunk_size=(chunks_cyx["y"], chunks_cyx["x"])) + + # Define reader func - reads all channels for each spatial tile + def _reader_func(slide: np.memmap, y0: int, x0: int, height: int, width: int) -> NDArray[np.number]: + return np.array(slide[:, y0 : y0 + height, x0 : x0 + width]) + + return _read_chunks(_reader_func, slide, coords=chunk_coords, n_channel=n_channel, dtype=slide.dtype) + + +def _dask_image_imread(input: Path, data_axes: Sequence[str], chunks_cyx: dict[str, int]) -> da.Array: + """Read image using dask-image and rechunk. + + Parameters + ---------- + input + Path to image file. + data_axes + Axes of the input data. + chunks_cyx + Chunk size dict with 'c', 'y', 'x' keys. + + Returns + ------- + Dask array with (c, y, x) axes order. + """ + if set(data_axes) != {"c", "y", "x"}: + raise NotImplementedError(f"Only 'c', 'y', 'x' axes are supported, got {data_axes}") + im = imread(input) + + # dask_image.imread may add an extra leading dimension for frames/pages + # If image has one extra dimension with size 1, squeeze it out + if im.ndim == len(data_axes) + 1 and im.shape[0] == 1: + im = im[0] + + if im.ndim != len(data_axes): + raise ValueError(f"Expected image with {len(data_axes)} dimensions, got {im.ndim}") + + im = im.transpose(*[data_axes.index(ax) for ax in ["c", "y", "x"]]) + return im.rechunk((chunks_cyx["c"], chunks_cyx["y"], chunks_cyx["x"])) + + +def image( + input: Path, + data_axes: Sequence[str], + coordinate_system: str, + use_tiff_memmap: bool = True, + chunks: Chunks_t | None = None, + scale_factors: Sequence[int] | None = None, +) -> DataArray: + """Reads an image file and returns a parsed Image2D spatial element. + + Parameters + ---------- + input + Path to the image file. + data_axes + Axes of the data (e.g., ('c', 'y', 'x') or ('y', 'x', 'c')). + coordinate_system + Coordinate system of the spatial element. + use_tiff_memmap + Whether to use memory-mapped reading for TIFF files. + chunks + Chunk size specification. Can be: + - int: Applied to all dimensions + - tuple: Chunk sizes matching the order of output axes (c, y, x) + - dict: Mapping of axis names to chunk sizes (e.g., {'c': 1, 'y': 1000, 'x': 1000}) + If None, uses a default (DEFAULT_CHUNK_SIZE) for all axes. + scale_factors + Scale factors for building a multiscale image pyramid. Passed to Image2DModel.parse(). + + Returns + ------- + Parsed Image2D spatial element. + """ + # Map passed data axes to position of dimension + axes_dim_mapping = {axes: ndim for ndim, axes in enumerate(data_axes)} + + chunks_dict = normalize_chunks(chunks, axes=data_axes) + + im = None + if input.suffix in [".tiff", ".tif"] and use_tiff_memmap: + try: + im_chunks = _tiff_to_chunks(input, axes_dim_mapping=axes_dim_mapping, chunks_cyx=chunks_dict) + im = da.block(im_chunks, allow_unknown_chunksizes=True) + + # Edge case: Compressed images are not memory-mappable + except ValueError as e: + logger.warning( + f"Exception occurred: {str(e)}\nPossible troubleshooting: image data " + "is not memory-mappable, potentially due to compression. 
Trying to " + "load the image into memory at once", + ) + use_tiff_memmap = False + + if input.suffix in [".tiff", ".tif"] and not use_tiff_memmap or input.suffix in [".png", ".jpg", ".jpeg"]: + im = _dask_image_imread(input=input, data_axes=data_axes, chunks_cyx=chunks_dict) + + if im is None: + raise NotImplementedError(f"File format {input.suffix} not implemented") + + # the output axes are always cyx + return Image2DModel.parse( + im, + dims=("c", "y", "x"), + transformations={coordinate_system: Identity()}, + scale_factors=scale_factors, + chunks=chunks_dict, + ) diff --git a/tests/readers/test_utils_image.py b/tests/readers/test_utils_image.py new file mode 100644 index 00000000..c2e08230 --- /dev/null +++ b/tests/readers/test_utils_image.py @@ -0,0 +1,103 @@ +import numpy as np +import pytest +from numpy.typing import NDArray + +from spatialdata_io.readers._utils._image import ( + DEFAULT_CHUNK_SIZE, + Chunks_t, + _compute_chunk_sizes_positions, + _compute_chunks, + normalize_chunks, +) + + +@pytest.mark.parametrize( + ("size", "chunk", "expected_positions", "expected_lengths"), + [ + (300, 100, np.array([0, 100, 200]), np.array([100, 100, 100])), + (300, 200, np.array([0, 200]), np.array([200, 100])), + ], +) +def test_compute_chunk_sizes_positions( + size: int, + chunk: int, + expected_positions: NDArray[np.number], + expected_lengths: NDArray[np.number], +) -> None: + computed_positions, computed_lengths = _compute_chunk_sizes_positions(size, chunk) + assert (expected_positions == computed_positions).all() + assert (expected_lengths == computed_lengths).all() + + +@pytest.mark.parametrize( + ("dimensions", "chunk_size", "result"), + [ + # Regular grid 2x2 + ( + (2, 2), + (1, 1), + np.array( + [ + [[0, 0, 1, 1], [0, 1, 1, 1]], + [[1, 0, 1, 1], [1, 1, 1, 1]], + ] + ), + ), + # Different tile sizes + ( + (300, 300), + (100, 200), + np.array( + [ + [[0, 0, 100, 200], [0, 200, 100, 100]], + [[100, 0, 100, 200], [100, 200, 100, 100]], + [[200, 0, 100, 200], [200, 200, 100, 100]], + ] + ), + ), + ], +) +def test_compute_chunks( + dimensions: tuple[int, int], + chunk_size: tuple[int, int], + result: NDArray[np.number], +) -> None: + tiles = _compute_chunks(dimensions, chunk_size) + + assert (tiles == result).all() + + +@pytest.mark.parametrize( + "chunks, axes, expected", + [ + # 2D (y, x) + (None, ("y", "x"), {"y": DEFAULT_CHUNK_SIZE, "x": DEFAULT_CHUNK_SIZE}), + (256, ("y", "x"), {"y": 256, "x": 256}), + ((200, 100), ("x", "y"), {"y": 100, "x": 200}), + ({"y": 300, "x": 400}, ("x", "y"), {"y": 300, "x": 400}), + # 2D with channel (c, y, x) + (None, ("c", "y", "x"), {"c": DEFAULT_CHUNK_SIZE, "y": DEFAULT_CHUNK_SIZE, "x": DEFAULT_CHUNK_SIZE}), + (256, ("c", "y", "x"), {"c": 256, "y": 256, "x": 256}), + ((1, 100, 200), ("c", "y", "x"), {"c": 1, "y": 100, "x": 200}), + ({"c": 1, "y": 300, "x": 400}, ("c", "y", "x"), {"c": 1, "y": 300, "x": 400}), + # 3D (z, y, x) + ((10, 100, 200), ("z", "y", "x"), {"z": 10, "y": 100, "x": 200}), + ({"z": 10, "y": 300, "x": 400}, ("z", "y", "x"), {"z": 10, "y": 300, "x": 400}), + ], +) +def test_normalize_chunks_valid(chunks: Chunks_t, axes: tuple[str, ...], expected: dict[str, int]) -> None: + assert normalize_chunks(chunks, axes=axes) == expected + + +@pytest.mark.parametrize( + "chunks, axes, match", + [ + ({"y": 100}, ("y", "x"), "missing keys for axes"), + ((1, 2, 3), ("y", "x"), "doesn't match axes"), + ((1.5, 2), ("y", "x"), "must be int"), + ("invalid", ("y", "x"), "Unsupported chunks type"), + ], +) +def test_normalize_chunks_errors(chunks: 
Chunks_t, axes: tuple[str, ...], match: str) -> None: + with pytest.raises(ValueError, match=match): + normalize_chunks(chunks, axes=axes) diff --git a/tests/test_generic.py b/tests/test_generic.py index 466b7a3f..dd46e15d 100644 --- a/tests/test_generic.py +++ b/tests/test_generic.py @@ -1,3 +1,4 @@ +import logging import tempfile from collections.abc import Generator from contextlib import contextmanager @@ -8,10 +9,14 @@ from click.testing import CliRunner from PIL import Image from spatialdata import SpatialData +from spatialdata._logging import logger from spatialdata.datasets import blobs +from tifffile import imread as tiffread +from tifffile import imwrite as tiffwrite from spatialdata_io.__main__ import read_generic_wrapper from spatialdata_io.converters.generic_to_zarr import generic_to_zarr +from spatialdata_io.readers.generic import image @contextmanager @@ -33,6 +38,70 @@ def save_temp_files() -> Generator[tuple[Path, Path, Path], None, None]: yield jpg_path, geojson_path, Path(tmpdir) +@pytest.fixture( + scope="module", + params=[ + {"axes": ("c", "y", "x"), "compression": None}, + {"axes": ("x", "y", "c"), "compression": None}, + {"axes": ("c", "y", "x"), "compression": "lzw"}, + {"axes": ("x", "y", "c"), "compression": "lzw"}, + ], +) +def save_tiff_files( + request: pytest.FixtureRequest, +) -> Generator[tuple[Path, tuple[str], Path], None, None]: + with tempfile.TemporaryDirectory() as tmpdir: + axes = request.param["axes"] + compression = request.param["compression"] + + sdata = blobs() + # save the image as tiff + x = sdata["blobs_image"].transpose(*axes).data.compute() + + tiff_path = Path(tmpdir) / "blobs_image.tiff" + tiffwrite(tiff_path, x, compression=compression) + + yield tiff_path, axes, compression + + +@pytest.mark.parametrize("scale_factors", [None, [2]]) +def test_read_tiff( + save_tiff_files: tuple[Path, tuple[str, ...], str | None], + caplog: pytest.LogCaptureFixture, + scale_factors: list[int] | None, +) -> None: + tiff_path, axes, compression = save_tiff_files + # Use asymmetric chunk sizes to catch errors with the ordering of chunk dimensions and the assembly of the individual chunks + CHUNKS = {"c": 2, "y": 29, "x": 71} + + logger.propagate = True + with caplog.at_level(logging.WARNING): + obj = image( + tiff_path, + data_axes=axes, + coordinate_system="global", + chunks=CHUNKS, + scale_factors=scale_factors, + use_tiff_memmap=True, + ) + logger.propagate = False + assert ("image data is not memory-mappable" in caplog.text) == (compression is not None) + + target = obj if scale_factors is None else obj["scale0"]["image"] + + # check chunks are correct + for i, ax in enumerate(("c", "y", "x")): + assert target.chunksizes[ax][0] in (CHUNKS[ax], target.shape[i]) + + # check multiscale is correct + if scale_factors is not None: + assert "scale0" in obj and len(obj.keys()) == len(scale_factors) + 1 + + # check pixel data + ref = tiffread(tiff_path).transpose(*[axes.index(ax) for ax in ("c", "y", "x")]) + assert (target.compute() == ref).all() + + @pytest.mark.parametrize("cli", [True, False]) @pytest.mark.parametrize("element_name", [None, "test_element"]) def test_read_generic_image(runner: CliRunner, cli: bool, element_name: str | None) -> None:
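For reference, a minimal usage sketch of the chunkwise reader added in src/spatialdata_io/readers/generic.py, assuming the patched package is installed; the input path, chunk sizes, and output path below are illustrative placeholders, not part of the patch:

from pathlib import Path

from spatialdata import SpatialData

from spatialdata_io.readers.generic import image

# Read a large TIFF chunkwise via tifffile.memmap; compressed files fall back to dask-image.
im = image(
    input=Path("large_image.tif"),          # hypothetical input file
    data_axes=("c", "y", "x"),              # axes order of the data on disk
    coordinate_system="global",
    use_tiff_memmap=True,
    chunks={"c": 1, "y": 1000, "x": 1000},  # int, tuple or None are also accepted
    scale_factors=[2, 2],                   # optional multiscale pyramid
)
sdata = SpatialData.init_from_elements({"image": im})
sdata.write("large_image.zarr")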
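The internal chunking helpers can also be exercised directly; a small sketch whose values mirror the unit tests in tests/readers/test_utils_image.py:

from spatialdata_io.readers._utils._image import _compute_chunks, normalize_chunks

# Normalize a tuple chunk spec against the axes order -> per-axis dict
assert normalize_chunks((1, 100, 200), axes=("c", "y", "x")) == {"c": 1, "y": 100, "x": 200}

# Tile a 300 x 300 image into (100, 200) tiles: 3 block rows x 2 block columns,
# each entry stored as (y, x, height, width); edge tiles are truncated.
tiles = _compute_chunks((300, 300), chunk_size=(100, 200))
assert tiles.shape == (3, 2, 4)
assert tiles[0, 1].tolist() == [0, 200, 100, 100]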