Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 35 additions & 12 deletions src/spatialdata/_core/operations/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@
from functools import singledispatch
from typing import TYPE_CHECKING, Any, cast

import dask
import dask.array as da
import dask.dataframe as dd
import dask_image.ndinterp
import numpy as np
import pandas as pd
from dask.array.core import Array as DaskArray
from dask.dataframe import DataFrame as DaskDataFrame
from geopandas import GeoDataFrame
Expand Down Expand Up @@ -442,29 +443,51 @@ def _(
axes = get_axes_names(data)
arrays = []

# Workaround to prevent partition collapse and the missing-dependency problem for now.
# Dask's expression optimizer can collapse partitions at compute time, making the partition
# structure inside vs. outside a disable_dask_tune_optimization() context inconsistent. To avoid
# index-alignment failures (e.g. "cannot reindex on an axis with duplicate labels" from parquet
# files that each start their index at 0) and length-mismatch errors, we materialise the non-axis
# columns and compute the axis arrays inside a single context where the partition structure is
# stable, then do plain pandas operations and re-wrap with dd.from_delayed (not dd.from_pandas,
# which sorts by index and would scramble rows for non-monotonic or duplicate indices).
with disable_dask_tune_optimization() if data.npartitions > 1 else contextlib.nullcontext():
lengths = [len(part) for part in data.partitions]
for ax in axes:
# TODO We have to pass on the lengths explicitly as automatic determination with dask graph optimization
# leads to collaps of the partitions. However this causes a missing dependency problem, which for now is
# leads to collapse of the partitions. However this causes a missing dependency problem, which for now is
# prevented by setting the optimization to False when performing this operation.
arrays.append(data[ax].to_dask_array(lengths=[len(part) for part in data.partitions]).reshape(-1, 1))
arrays.append(data[ax].to_dask_array(lengths=lengths).reshape(-1, 1))

xdata = DataArray(da.concatenate(arrays, axis=1), coords={"points": range(len(data)), "dim": list(axes)})
xtransformed = transformation._transform_coordinates(xdata)
transformed = data.drop(columns=list(axes)).copy()
# dummy transformation that will be replaced by _adjust_transformation()
default_cs = {DEFAULT_COORDINATE_SYSTEM: Identity()}
transformed.attrs[TRANSFORM_KEY] = default_cs
xdata = DataArray(da.concatenate(arrays, axis=1), coords={"points": range(sum(lengths)), "dim": list(axes)})
xtransformed = transformation._transform_coordinates(xdata)

# Compute non-axis columns while the partition structure is still stable; preserves original index.
transformed_pd = data.drop(columns=list(axes)).compute()

for ax in axes:
indices = xtransformed["dim"] == ax
new_ax = xtransformed[:, indices]
# TODO: discuss with dask team
# This is not nice, but otherwise there is a problem with the joint graph of new_ax and transformed, causing
# a getattr missing dependency of dependent from_dask_array.
new_col = pd.Series(new_ax.data.flatten().compute(), index=transformed.index)
transformed[ax] = new_col
# Assigning a numpy array is positional (no index alignment), so the original index is preserved.
transformed_pd[ax] = new_ax.data.flatten().compute()

# Reconstruct as a dask DataFrame via delayed partitions so that:
# (a) row order matches the original (dd.from_pandas sorts by index, which scrambles rows for
# non-monotonic or duplicate indices such as those produced by multi-file parquet reads), and
# (b) the original index is preserved exactly.
offsets = np.cumsum([0] + lengths)
delayed_parts = [dask.delayed(transformed_pd.iloc[offsets[i] : offsets[i + 1]]) for i in range(len(lengths))]
transformed = dd.from_delayed(delayed_parts, meta=transformed_pd.iloc[:0])
# Preserve spatialdata_attrs (feature_key, instance_key, …) from the original element;
# dd.from_delayed starts with empty attrs so we must copy them explicitly.
for k, v in data.attrs.items():
if k != TRANSFORM_KEY:
transformed.attrs[k] = v
# dummy transformation that will be replaced by _adjust_transformation()
default_cs = {DEFAULT_COORDINATE_SYSTEM: Identity()}
transformed.attrs[TRANSFORM_KEY] = default_cs

old_transformations = cast(dict[str, Any], get_transformation(data, get_all=True))

Expand Down
48 changes: 48 additions & 0 deletions tests/core/operations/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pathlib import Path

import numpy as np
import pandas as pd
import pytest
from dask import config
from geopandas.testing import geom_almost_equals
Expand Down Expand Up @@ -590,6 +591,53 @@ def test_transform_elements_and_entire_spatial_data_object(full_sdata: SpatialDa
_ = full_sdata.transform_to_coordinate_system("my_space", maintain_positioning=maintain_positioning)


def test_transform_points_duplicate_index_gh1105(tmp_path: str):
    """Regression test for https://github.com/scverse/spatialdata/issues/1105.

    Points loaded from multiple parquet files (e.g. Xenium transcripts) have a per-file 0-based
    index, so the global dask DataFrame index has duplicate labels. The old implementation passed
    ``index=transformed.index`` to ``pd.Series``, which materialised the duplicate dask Index and
    caused ``ValueError: cannot reindex on an axis with duplicate labels`` when assigning back.

    The test builds such a frame by hand, applies a scale transformation and checks that the
    index, the non-axis column and both coordinate columns come back in the original row order.
    """
    import dask.dataframe as dd

    n_per_partition = 50
    n_partitions = 4
    rng = np.random.default_rng(0)

    # Simulate multi-file parquet: each partition's index starts at 0
    parts = [
        pd.DataFrame(
            {
                "x": rng.random(n_per_partition).astype("float32"),
                "y": rng.random(n_per_partition).astype("float32"),
                "gene": [f"gene_{j}" for j in range(n_per_partition)],
            }
        )
        for _ in range(n_partitions)
    ]
    # Also cover non-contiguous indices: the last label of every partition becomes 100,
    # so each partition is indexed [0..48, 100].
    for part in parts:
        part.index = part.index.to_list()[:-1] + [100]
    ddf = dd.from_map(lambda df: df, parts)

    # Materialise the reference frame once; every assertion below compares against it.
    expected = ddf.compute()
    assert not expected.index.is_unique, "test setup: index must have duplicates"

    scale_factor = 4
    points = PointsModel.parse(ddf)
    set_transformation(points, Scale([scale_factor, scale_factor], axes=("x", "y")), to_coordinate_system="global")

    result = transform(points, to_coordinate_system="global")
    result_pd = result.compute()

    # Index must be preserved as-is (duplicate [0..48, 100] × 4), in the original row order
    assert list(result_pd.index) == list(expected.index)
    # Non-axis column must survive unchanged
    assert list(result_pd["gene"]) == list(expected["gene"])
    # Both axes are scaled by the transformation, so verify both
    for ax in ("x", "y"):
        np.testing.assert_allclose(result_pd[ax].values, expected[ax].values * scale_factor, rtol=1e-5)


def test_transform_points_with_multiple_partitions(full_sdata: SpatialData, tmp_path: str):
tmpdir = Path(tmp_path) / "tmp.zarr"
points_memory = full_sdata["points_0"].compute()
Expand Down
Loading