From 80923fde9672c37cdc85d5da40e8ffaa45239930 Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Wed, 13 May 2026 18:04:29 +0200 Subject: [PATCH 1/2] fix _search_for_backing_files_recursively() to support multifile parquet --- src/spatialdata/_io/_utils.py | 51 +++++++++++++++++++++++++++-------- 1 file changed, 40 insertions(+), 11 deletions(-) diff --git a/src/spatialdata/_io/_utils.py b/src/spatialdata/_io/_utils.py index 6690d111..81f95d09 100644 --- a/src/spatialdata/_io/_utils.py +++ b/src/spatialdata/_io/_utils.py @@ -27,6 +27,7 @@ from spatialdata._core.spatialdata import SpatialData from spatialdata._io.format import RasterFormatType, RasterFormatV01, RasterFormatV02, RasterFormatV03 +from spatialdata._logging import logger from spatialdata._utils import get_pyramid_levels from spatialdata.models._utils import ( MappingToCoordinateSystem_t, @@ -357,17 +358,45 @@ def _search_for_backing_files_recursively(subgraph: Any, files: list[str]) -> No # This occurs when for example points and images are mixed, the main task still starts with # read_parquet, but the execution happens through a subgraph which we iterate over to get the # actual read_parquet task. - for task in v.args[0].values(): - # Recursively go through tasks, this is required because differences between dask versions. - piece_dict = _find_piece_dict(task) - if isinstance(piece_dict, dict) and "piece" in piece_dict: - parquet_file, check0, check1 = piece_dict["piece"] # type: ignore[misc] - if not parquet_file.endswith(".parquet") or check0 is not None or check1 is not None: - raise ValueError( - f"Unable to parse the parquet file from the dask subgraph {subgraph}. Please " - f"report this bug." - ) - files.append(os.path.realpath(parquet_file)) + # + # v.args[0] has two known shapes: + # dict – keys are task keys, values are Task objects (classic subgraph case) + # list – list of piece dicts produced when aggregate_files=True aggregates multiple + # parquet files into one partition; check0/check1 are row-group selectors + # ([0], []) rather than None, so only the file extension is validated. + args0 = v.args[0] + if isinstance(args0, dict): + for task in args0.values(): + # Recursively go through tasks, this is required because differences between dask + # versions. + piece_dict = _find_piece_dict(task) + if isinstance(piece_dict, dict) and "piece" in piece_dict: + parquet_file, check0, check1 = piece_dict["piece"] # type: ignore[misc] + if ( + not parquet_file.endswith(".parquet") + or check0 is not None + or check1 is not None + ): + raise ValueError( + f"Unable to parse the parquet file from the dask subgraph {subgraph}. " + f"Please report this bug." + ) + files.append(os.path.realpath(parquet_file)) + elif isinstance(args0, list): + for item in args0: + if isinstance(item, dict) and "piece" in item: + parquet_file = item["piece"][0] + if not parquet_file.endswith(".parquet"): + raise ValueError( + f"Unable to parse the parquet file from the dask subgraph {subgraph}. " + f"Please report this bug." + ) + files.append(os.path.realpath(parquet_file)) + else: + logger.warning( + f"Unexpected type {type(args0)} for v.args[0] in the read_parquet task graph. " + f"Backing files may not be detected correctly. Please report this as a bug." + ) def _backed_elements_contained_in_path(path: Path, object: SpatialData | SpatialElement | AnnData) -> list[bool]: From c2e3d06ef3df3608b91075779c9b6ac998c91060 Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Fri, 15 May 2026 11:40:40 +0200 Subject: [PATCH 2/2] Fix repr showing instead of row count for backed points (#1084) Replace broken dask graph introspection (which only worked for single-task graphs with a HighLevelGraph layer API that no longer exists) with get_dask_backing_files() + pyarrow footer metadata reads. This handles all graph shapes including the list-of-piece-dicts case produced by aggregate_files=True. Co-Authored-By: Claude Sonnet 4.6 --- src/spatialdata/_core/spatialdata.py | 27 ++++++++++----------------- tests/io/test_utils.py | 25 ++++++++++++++++++++++++- 2 files changed, 34 insertions(+), 18 deletions(-) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 760736c6..9f8f9d3a 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -14,7 +14,7 @@ from anndata import AnnData from annsel.core.typing import Predicates from dask.dataframe import DataFrame as DaskDataFrame -from dask.dataframe import Scalar, read_parquet +from dask.dataframe import Scalar from geopandas import GeoDataFrame from shapely import MultiPolygon, Polygon from upath import UPath @@ -1979,21 +1979,14 @@ def h(s: str) -> str: if attr == "shapes": descr += f"{h(attr + 'level1.1')}{k!r}: {descr_class} shape: {v.shape} (2D shapes)" elif attr == "points": + import pyarrow.parquet as pq + + from spatialdata._io._utils import get_dask_backing_files + length: int | None = None - if len(v.dask) == 1: - name, layer = v.dask.items().__iter__().__next__() - if "read-parquet" in name: - t = layer.creation_info["args"] - assert isinstance(t, tuple) - assert len(t) == 1 - parquet_file = t[0] - table = read_parquet(parquet_file) - length = len(table) - else: - # length = len(v) - length = None - else: - length = None + backing_files = get_dask_backing_files(v) + if backing_files: + length = sum(pq.read_metadata(f).num_rows for f in backing_files) n = len(get_axes_names(v)) dim_string = f"({n}D points)" @@ -2084,8 +2077,8 @@ def _element_path_to_element_name_with_type(element_path: str) -> str: description = self.elements_are_self_contained() for _, element_name, element in self.gen_elements(): if not description[element_name]: - backing_files = ", ".join(get_dask_backing_files(element)) - descr += f"\n ▸ {element_name}: {backing_files}" + backing_files_str = ", ".join(get_dask_backing_files(element)) + descr += f"\n ▸ {element_name}: {backing_files_str}" if self.path is not None: elements_only_in_sdata, elements_only_in_zarr = self._symmetric_difference_with_zarr_store() diff --git a/tests/io/test_utils.py b/tests/io/test_utils.py index 1f7be358..00ca6494 100644 --- a/tests/io/test_utils.py +++ b/tests/io/test_utils.py @@ -5,11 +5,14 @@ from contextlib import nullcontext import dask.dataframe as dd +import numpy as np +import pandas as pd import pytest from upath import UPath -from spatialdata import read_zarr +from spatialdata import SpatialData, read_zarr from spatialdata._io._utils import get_dask_backing_files, handle_read_errors +from spatialdata.models import PointsModel def test_backing_files_points(points): @@ -141,3 +144,23 @@ def test_handle_read_errors(on_bad_files: str, actual_error: Exception, expectat with handle_read_errors(on_bad_files=on_bad_files, location="location", exc_types=KeyError): if actual_error is not None: raise actual_error + + +def test_repr_points_shows_row_count(): + """repr() must show the concrete row count, not , for backed points.""" + with tempfile.TemporaryDirectory() as tmp: + parquet_path = os.path.join(tmp, "points.parquet") + n_rows = 400 + rng = np.random.default_rng(0) + df = pd.DataFrame({"x": rng.random(n_rows), "y": rng.random(n_rows)}) + # aggregate_files=True produces a list-of-piece-dicts graph, the case reported in #1084 + dd.from_pandas(df, npartitions=4).to_parquet(parquet_path, write_index=False) + ddf = dd.read_parquet(parquet_path, aggregate_files=True) + + points = PointsModel.parse(ddf) + sdata = SpatialData(points={"pts": points}) + sdata.write(os.path.join(tmp, "example.zarr")) + + r = repr(sdata) + assert f"({n_rows}," in r, f"expected row count {n_rows} in repr, got: {r}" + assert "" not in r, f"repr still contains : {r}"