diff --git a/src/spatialdata/_core/operations/transform.py b/src/spatialdata/_core/operations/transform.py index d29af709..49b9ef72 100644 --- a/src/spatialdata/_core/operations/transform.py +++ b/src/spatialdata/_core/operations/transform.py @@ -6,10 +6,11 @@ from functools import singledispatch from typing import TYPE_CHECKING, Any, cast +import dask import dask.array as da +import dask.dataframe as dd import dask_image.ndinterp import numpy as np -import pandas as pd from dask.array.core import Array as DaskArray from dask.dataframe import DataFrame as DaskDataFrame from geopandas import GeoDataFrame @@ -442,20 +443,26 @@ def _( axes = get_axes_names(data) arrays = [] - # Workaround to prevent partition collaps and missing dependency problem for now. + # Dask's expression optimizer can collapse partitions at compute time, making the partition + # structure inside vs. outside a disable_dask_tune_optimization() context inconsistent. To avoid + # index-alignment failures (e.g. "cannot reindex on an axis with duplicate labels" from parquet + # files that each start their index at 0) and length-mismatch errors, we materialise the non-axis + # columns and compute the axis arrays inside a single context where the partition structure is + # stable, then do plain pandas operations and re-wrap with dd.from_delayed (not dd.from_pandas, + # which sorts by index and would scramble rows for non-monotonic or duplicate indices). with disable_dask_tune_optimization() if data.npartitions > 1 else contextlib.nullcontext(): + lengths = [len(part) for part in data.partitions] for ax in axes: # TODO We have to pass on the lengths explicitly as automatic determination with dask graph optimization - # leads to collaps of the partitions. However this causes a missing dependency problem, which for now is + # leads to collapse of the partitions. 
However this causes a missing dependency problem, which for now is # prevented by setting the optimization to False when performing this operation. - arrays.append(data[ax].to_dask_array(lengths=[len(part) for part in data.partitions]).reshape(-1, 1)) + arrays.append(data[ax].to_dask_array(lengths=lengths).reshape(-1, 1)) - xdata = DataArray(da.concatenate(arrays, axis=1), coords={"points": range(len(data)), "dim": list(axes)}) - xtransformed = transformation._transform_coordinates(xdata) - transformed = data.drop(columns=list(axes)).copy() - # dummy transformation that will be replaced by _adjust_transformation() - default_cs = {DEFAULT_COORDINATE_SYSTEM: Identity()} - transformed.attrs[TRANSFORM_KEY] = default_cs + xdata = DataArray(da.concatenate(arrays, axis=1), coords={"points": range(sum(lengths)), "dim": list(axes)}) + xtransformed = transformation._transform_coordinates(xdata) + + # Compute non-axis columns while the partition structure is still stable; preserves original index. + transformed_pd = data.drop(columns=list(axes)).compute() for ax in axes: indices = xtransformed["dim"] == ax @@ -463,8 +470,24 @@ def _( # TODO: discuss with dask team # This is not nice, but otherwise there is a problem with the joint graph of new_ax and transformed, causing # a getattr missing dependency of dependent from_dask_array. - new_col = pd.Series(new_ax.data.flatten().compute(), index=transformed.index) - transformed[ax] = new_col + # Assigning a numpy array is positional (no index alignment), so the original index is preserved. + transformed_pd[ax] = new_ax.data.flatten().compute() + + # Reconstruct as a dask DataFrame via delayed partitions so that: + # (a) row order matches the original (dd.from_pandas sorts by index, which scrambles rows for + # non-monotonic or duplicate indices such as those produced by multi-file parquet reads), and + # (b) the original index is preserved exactly. 
+ offsets = np.cumsum([0] + lengths) + delayed_parts = [dask.delayed(transformed_pd.iloc[offsets[i] : offsets[i + 1]]) for i in range(len(lengths))] + transformed = dd.from_delayed(delayed_parts, meta=transformed_pd.iloc[:0]) + # Preserve spatialdata_attrs (feature_key, instance_key, …) from the original element; + # dd.from_delayed starts with empty attrs so we must copy them explicitly. + for k, v in data.attrs.items(): + if k != TRANSFORM_KEY: + transformed.attrs[k] = v + # dummy transformation that will be replaced by _adjust_transformation() + default_cs = {DEFAULT_COORDINATE_SYSTEM: Identity()} + transformed.attrs[TRANSFORM_KEY] = default_cs old_transformations = cast(dict[str, Any], get_transformation(data, get_all=True)) diff --git a/tests/core/operations/test_transform.py b/tests/core/operations/test_transform.py index 0a251d5d..ef307ac7 100644 --- a/tests/core/operations/test_transform.py +++ b/tests/core/operations/test_transform.py @@ -6,6 +6,7 @@ from pathlib import Path import numpy as np +import pandas as pd import pytest from dask import config from geopandas.testing import geom_almost_equals @@ -590,6 +591,53 @@ def test_transform_elements_and_entire_spatial_data_object(full_sdata: SpatialDa _ = full_sdata.transform_to_coordinate_system("my_space", maintain_positioning=maintain_positioning) +def test_transform_points_duplicate_index_gh1105(tmp_path: str): + """Regression test for https://github.com/scverse/spatialdata/issues/1105. + + Points loaded from multiple parquet files (e.g. Xenium transcripts) have a per-file 0-based + index, so the global dask DataFrame index has duplicate labels. The old implementation passed + ``index=transformed.index`` to ``pd.Series``, which materialised the duplicate dask Index and + caused ``ValueError: cannot reindex on an axis with duplicate labels`` when assigning back. 
+ """ + import dask.dataframe as dd + + n_per_partition = 50 + n_partitions = 4 + rng = np.random.default_rng(0) + + # Simulate multi-file parquet: each partition's index starts at 0 + parts = [ + pd.DataFrame( + { + "x": rng.random(n_per_partition).astype("float32"), + "y": rng.random(n_per_partition).astype("float32"), + "gene": [f"gene_{j}" for j in range(n_per_partition)], + } + ) + for _ in range(n_partitions) + ] + # test also the case of non-contiguous indices + for part in parts: + part.index = part.index.to_list()[:-1] + [100] + ddf = dd.from_map(lambda df: df, parts) + assert not ddf.index.compute().is_unique, "test setup: index must have duplicates" + + scale_factor = 4 + points = PointsModel.parse(ddf) + set_transformation(points, Scale([scale_factor, scale_factor], axes=("x", "y")), to_coordinate_system="global") + + result = transform(points, to_coordinate_system="global") + result_pd = result.compute() + + # Index must be preserved as-is (duplicate, non-contiguous labels [0..48, 100] × 4) + assert list(result_pd.index) == list(ddf.compute().index) + # Non-axis column must survive unchanged + assert list(result_pd["gene"]) == list(ddf.compute()["gene"]) + # Axis values must be correctly scaled + expected_x = ddf.compute()["x"].values * scale_factor + np.testing.assert_allclose(result_pd["x"].values, expected_x, rtol=1e-5) + + def test_transform_points_with_multiple_partitions(full_sdata: SpatialData, tmp_path: str): tmpdir = Path(tmp_path) / "tmp.zarr" points_memory = full_sdata["points_0"].compute()