diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 760736c6..9f8f9d3a 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -14,7 +14,7 @@ from anndata import AnnData from annsel.core.typing import Predicates from dask.dataframe import DataFrame as DaskDataFrame -from dask.dataframe import Scalar, read_parquet +from dask.dataframe import Scalar from geopandas import GeoDataFrame from shapely import MultiPolygon, Polygon from upath import UPath @@ -1979,21 +1979,14 @@ def h(s: str) -> str: if attr == "shapes": descr += f"{h(attr + 'level1.1')}{k!r}: {descr_class} shape: {v.shape} (2D shapes)" elif attr == "points": + import pyarrow.parquet as pq + + from spatialdata._io._utils import get_dask_backing_files + length: int | None = None - if len(v.dask) == 1: - name, layer = v.dask.items().__iter__().__next__() - if "read-parquet" in name: - t = layer.creation_info["args"] - assert isinstance(t, tuple) - assert len(t) == 1 - parquet_file = t[0] - table = read_parquet(parquet_file) - length = len(table) - else: - # length = len(v) - length = None - else: - length = None + backing_files = get_dask_backing_files(v) + if backing_files: + length = sum(pq.read_metadata(f).num_rows for f in backing_files) n = len(get_axes_names(v)) dim_string = f"({n}D points)" @@ -2084,8 +2077,8 @@ def _element_path_to_element_name_with_type(element_path: str) -> str: description = self.elements_are_self_contained() for _, element_name, element in self.gen_elements(): if not description[element_name]: - backing_files = ", ".join(get_dask_backing_files(element)) - descr += f"\n ▸ {element_name}: {backing_files}" + backing_files_str = ", ".join(get_dask_backing_files(element)) + descr += f"\n ▸ {element_name}: {backing_files_str}" if self.path is not None: elements_only_in_sdata, elements_only_in_zarr = self._symmetric_difference_with_zarr_store() diff --git a/tests/io/test_utils.py b/tests/io/test_utils.py index 1f7be358..00ca6494 100644 --- a/tests/io/test_utils.py +++ b/tests/io/test_utils.py @@ -5,11 +5,14 @@ from contextlib import nullcontext import dask.dataframe as dd +import numpy as np +import pandas as pd import pytest from upath import UPath -from spatialdata import read_zarr +from spatialdata import SpatialData, read_zarr from spatialdata._io._utils import get_dask_backing_files, handle_read_errors +from spatialdata.models import PointsModel def test_backing_files_points(points): @@ -141,3 +144,23 @@ def test_handle_read_errors(on_bad_files: str, actual_error: Exception, expectat with handle_read_errors(on_bad_files=on_bad_files, location="location", exc_types=KeyError): if actual_error is not None: raise actual_error + + +def test_repr_points_shows_row_count(): + """repr() must show the concrete row count, not , for backed points.""" + with tempfile.TemporaryDirectory() as tmp: + parquet_path = os.path.join(tmp, "points.parquet") + n_rows = 400 + rng = np.random.default_rng(0) + df = pd.DataFrame({"x": rng.random(n_rows), "y": rng.random(n_rows)}) + # aggregate_files=True produces a list-of-piece-dicts graph, the case reported in #1084 + dd.from_pandas(df, npartitions=4).to_parquet(parquet_path, write_index=False) + ddf = dd.read_parquet(parquet_path, aggregate_files=True) + + points = PointsModel.parse(ddf) + sdata = SpatialData(points={"pts": points}) + sdata.write(os.path.join(tmp, "example.zarr")) + + r = repr(sdata) + assert f"({n_rows}," in r, f"expected row count {n_rows} in repr, got: {r}" + assert "" not in r, f"repr still contains : {r}"