Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
40326e9
implement selection
giovp Aug 21, 2024
aa339aa
update
giovp Aug 21, 2024
4ac406e
Merge branch 'main' into giovp/dataloader3
giovp Sep 2, 2024
92d578f
vectorize adjust_bounding_box_to_real_axes
giovp Sep 2, 2024
2bb5c35
update
giovp Sep 2, 2024
c89dcdf
replace append with insert
giovp Sep 2, 2024
5bf0b43
add comment
giovp Sep 2, 2024
a60bf6f
vectorize
giovp Sep 2, 2024
017967b
update to handle multiple boxes
giovp Sep 2, 2024
ab774b7
vectorize with numba
giovp Sep 2, 2024
804b30a
Merge branch 'giovp/parallel-transform' into giovp/dataloader3
giovp Sep 2, 2024
38dba25
fix corner len
giovp Sep 2, 2024
df80902
Merge branch 'giovp/parallel-transform' into giovp/dataloader3
giovp Sep 2, 2024
b27607e
update
giovp Sep 3, 2024
a934e21
fix validation
giovp Sep 3, 2024
5bdd9df
Merge branch 'giovp/parallel-transform' into giovp/dataloader3
giovp Sep 3, 2024
77f73f4
refactor
giovp Sep 3, 2024
3adfea8
refactor
giovp Sep 3, 2024
dfdfdbf
add test for query with multiple bounding boxes
giovp Sep 3, 2024
5c5560d
fix typing
giovp Sep 3, 2024
dd2c573
vectorize bounding box query on polygons
giovp Sep 3, 2024
be95358
add test to cover no polygon overlap (None)
giovp Sep 3, 2024
fad9b1a
vectorize bounding box query on points and tests
giovp Sep 4, 2024
9b977d6
fix type
giovp Sep 4, 2024
f3f3d27
Merge branch 'giovp/parallel-transform' into giovp/dataloader3
giovp Sep 5, 2024
208e217
Merge branch 'main' into giovp/dataloader3
giovp Feb 24, 2025
df21672
Merge branch 'main' into giovp/dataloader3
LucaMarconato May 15, 2026
e87d318
Fix rasterize path and bugs in PR #687 dataloader; add benchmark
LucaMarconato May 15, 2026
a51cb95
add asv benchmark for dataloader performance
LucaMarconato May 21, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions benchmarks/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,32 @@ git checkout - && git stash pop
asv compare main HEAD
```

### Dataloader benchmarks

Dataloader benchmarks live in `benchmarks/benchmark_dataloader.py`. They build a synthetic in-memory `SpatialData` object (a 3-channel 2048×2048 image and 500 circle regions) and time two operations:

- `time_init` — constructing `ImageTilesDataset` (includes bounding-box pre-computation).
- `time_fetch` — iterating over all 500 tiles once (pure `__getitem__` calls, no `DataLoader` overhead).

Run both in your current environment:

```bash
asv run --python=same --show-stderr -b TimeDataloader
```

Or a single method:

```bash
asv run --python=same --show-stderr -b TimeDataloader.time_init
asv run --python=same --show-stderr -b TimeDataloader.time_fetch
```

Compare against `main` in one shot:

```bash
asv continuous --show-stderr -v -b TimeDataloader main HEAD
```

### Querying benchmarks

Bounding-box queries without a spatial index are slowed far more by a large number of points (transcripts) than by a large number of table rows (cells).
Expand Down
75 changes: 75 additions & 0 deletions benchmarks/benchmark_dataloader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# type: ignore
"""Benchmarks for ImageTilesDataset: init time and iteration time."""

from __future__ import annotations

import anndata as ad
import geopandas as gpd
import numpy as np
import pandas as pd
from shapely.geometry import Point

import spatialdata as sd
from spatialdata.dataloader import ImageTilesDataset
from spatialdata.models import Image2DModel, ShapesModel, TableModel
from spatialdata.transformations import Identity

# Module-level random generator with a fixed seed. It is shared by every call
# to `_make_sdata`, so successive calls continue the same random stream.
_RNG = np.random.default_rng(42)

# Side length of the square synthetic image (used for both the y and x axes).
_IMAGE_SIZE = 2048
# Number of circle regions; also the number of rows of the annotating table.
_N_CIRCLES = 500
# Length of the channel ("c") axis of the synthetic image.
_N_CHANNELS = 3

# Keyword arguments passed to every `ImageTilesDataset` built in this module.
# The element names ("circles", "image", "table") and the "global" coordinate
# system match what `_make_sdata` creates.
_DATASET_KWARGS = {
    "regions_to_images": {"circles": "image"},
    "regions_to_coordinate_systems": {"circles": "global"},
    "table_name": "table",
    "return_annotations": "instance_id",
}


def _make_sdata() -> sd.SpatialData:
    """Build the synthetic in-memory ``SpatialData`` used by the benchmarks.

    The object contains:

    - ``"image"``: a ``(c, y, x)`` float32 image of shape
      ``(_N_CHANNELS, _IMAGE_SIZE, _IMAGE_SIZE)`` holding random integer
      values in ``[0, 256)``;
    - ``"circles"``: ``_N_CIRCLES`` circles of radius 32 whose centers are
      drawn uniformly at least one radius away from the image border;
    - ``"table"``: a table with one row per circle (10 random float32
      features) annotating ``"circles"`` through the ``region`` and
      ``instance_id`` columns.

    Both spatial elements are mapped to the ``"global"`` coordinate system
    with an identity transformation.
    """
    # NOTE: the draws from the shared generator happen in a fixed order
    # (image, x centers, y centers, table features); keep it that way so the
    # generated data does not change.
    image_shape = (_N_CHANNELS, _IMAGE_SIZE, _IMAGE_SIZE)
    pixels = _RNG.integers(0, 256, size=image_shape, dtype=np.uint8)
    pixels = pixels.astype(np.float32)
    image = Image2DModel.parse(
        pixels,
        dims=["c", "y", "x"],
        transformations={"global": Identity()},
    )

    radius = 32.0
    lower, upper = radius, _IMAGE_SIZE - radius
    centers_x = _RNG.uniform(lower, upper, size=_N_CIRCLES)
    centers_y = _RNG.uniform(lower, upper, size=_N_CIRCLES)
    center_points = [Point(px, py) for px, py in zip(centers_x, centers_y, strict=True)]
    shapes_df = gpd.GeoDataFrame({"geometry": center_points})
    shapes_df["radius"] = radius
    circles = ShapesModel.parse(shapes_df, transformations={"global": Identity()})

    features = _RNG.random((_N_CIRCLES, 10)).astype(np.float32)
    obs = pd.DataFrame(
        {
            "region": pd.Categorical(["circles"] * _N_CIRCLES),
            "instance_id": np.arange(_N_CIRCLES, dtype=np.int64),
        },
        index=list(map(str, range(_N_CIRCLES))),
    )
    annotated = TableModel.parse(
        ad.AnnData(features, obs=obs),
        region="circles",
        region_key="region",
        instance_key="instance_id",
    )

    return sd.SpatialData(
        images={"image": image},
        shapes={"circles": circles},
        tables={"table": annotated},
    )


class TimeDataloader:
    """Benchmark suite timing ``ImageTilesDataset`` creation and tile access."""

    def setup(self):
        # Fixtures are created here rather than in the ``time_*`` methods so
        # that building the synthetic data is not part of what those methods do.
        sdata = _make_sdata()
        self.sdata = sdata
        self.ds = ImageTilesDataset(sdata=sdata, **_DATASET_KWARGS)

    def teardown(self):
        # Drop the references to the dataset and to the synthetic data.
        del self.ds, self.sdata

    def time_init(self):
        """Time building an ImageTilesDataset from the prepared SpatialData."""
        ImageTilesDataset(sdata=self.sdata, **_DATASET_KWARGS)

    def time_fetch(self):
        """Time one full pass over the dataset, one ``__getitem__`` per tile."""
        n_tiles = len(self.ds)
        for tile_idx in range(n_tiles):
            _ = self.ds[tile_idx]
8 changes: 4 additions & 4 deletions src/spatialdata/_core/query/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,11 @@ def get_bounding_box_corners(
return output.squeeze().drop_vars("box")


@nb.jit(parallel=False, nopython=True)
@nb.njit(parallel=False)
def _create_slices_and_translation(
min_values: nb.types.Array,
max_values: nb.types.Array,
) -> tuple[nb.types.Array, nb.types.Array]:
min_values: np.ndarray,
max_values: np.ndarray,
) -> tuple[np.ndarray, np.ndarray]:
n_boxes, n_dims = min_values.shape
slices = np.empty((n_boxes, n_dims, 2), dtype=np.float64) # (n_boxes, n_dims, [min, max])
translation_vectors = np.empty((n_boxes, n_dims), dtype=np.float64) # (n_boxes, n_dims)
Expand Down
2 changes: 1 addition & 1 deletion src/spatialdata/_core/query/spatial_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,7 @@ def _(
min_coordinate = _parse_list_into_array(min_coordinate)
max_coordinate = _parse_list_into_array(max_coordinate)

# for triggering validation
# for triggering validation (handles both 1-D single-box and 2-D multi-box arrays)
_ = BoundingBoxRequest(
target_coordinate_system=target_coordinate_system,
axes=axes,
Expand Down
52 changes: 35 additions & 17 deletions src/spatialdata/dataloader/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,9 @@ def __init__(
from spatialdata import bounding_box_query
from spatialdata._core.operations.rasterize import rasterize as rasterize_fn

self._validate(sdata, regions_to_images, regions_to_coordinate_systems, return_annotations, table_name)
self.sdata = sdata
self._rasterize = rasterize
self._validate(regions_to_images, regions_to_coordinate_systems, return_annotations, table_name)
self._preprocess(tile_scale, tile_dim_in_units, rasterize, table_name)

if rasterize_kwargs is not None and len(rasterize_kwargs) > 0 and rasterize is False:
Expand All @@ -151,14 +153,12 @@ def __init__(

def _validate(
self,
sdata: SpatialData,
regions_to_images: dict[str, str],
regions_to_coordinate_systems: dict[str, str],
return_annotations: str | list[str] | None,
table_name: str | None,
) -> None:
"""Validate input parameters."""
self.sdata = sdata
if return_annotations is not None and table_name is None:
raise ValueError("`table_name` must be provided if `return_annotations` is not `None`.")

Expand All @@ -173,8 +173,8 @@ def _validate(
image_name = regions_to_images[region_name]

# get elements
region_elem = sdata[region_name]
image_elem = sdata[image_name]
region_elem = self.sdata[region_name]
image_elem = self.sdata[image_name]

# check that the elements are supported
if get_model(region_elem) == PointsModel:
Expand All @@ -199,13 +199,13 @@ def _validate(
)

if table_name is not None:
_, region_key, instance_key = get_table_keys(sdata.tables[table_name])
_, region_key, instance_key = get_table_keys(self.sdata.tables[table_name])
if get_model(region_elem) in [Labels2DModel, Labels3DModel]:
indices = get_element_instances(region_elem).tolist()
else:
indices = region_elem.index.tolist()
table = sdata.tables[table_name]
if not isinstance(sdata.tables[table_name].obs[region_key].dtype, CategoricalDtype):
table = self.sdata.tables[table_name]
if not isinstance(self.sdata.tables[table_name].obs[region_key].dtype, CategoricalDtype):
raise TypeError(
f"The `regions_element` column `{region_key}` in the table must be a categorical dtype. "
f"Please convert it."
Expand All @@ -228,8 +228,10 @@ def _preprocess(
table_name: str | None,
) -> None:
"""Preprocess the dataset."""
from spatialdata import bounding_box_query

if table_name is not None:
_, region_key, instance_key = get_table_keys(self.sdata.tables[table_name])
_, region_key, _ = get_table_keys(self.sdata.tables[table_name])
filtered_table = self.sdata.tables[table_name][
self.sdata.tables[table_name].obs[region_key].isin(self.regions)
] # filtered table for the data loader
Expand All @@ -249,6 +251,18 @@ def _preprocess(
tile_scale=tile_scale,
tile_dim_in_units=tile_dim_in_units,
)
if not rasterize:
# Pre-compute all per-tile slice selections in a single vectorized call.
# Passing 2-D min/max arrays triggers the multi-box path in bounding_box_query,
# which returns a list of {axis: slice} dicts — one per tile.
tile_coords["selection"] = bounding_box_query(
self.sdata[image_name],
("x", "y"),
min_coordinate=tile_coords[["minx", "miny"]].values,
max_coordinate=tile_coords[["maxx", "maxy"]].values,
target_coordinate_system=cs,
return_request_only=True,
)
tile_coords_df.append(tile_coords)

inst = circles.index.values
Expand Down Expand Up @@ -276,7 +290,7 @@ def _preprocess(
self.dataset_index = pd.concat(index_df).reset_index(drop=True)
assert len(self.tiles_coords) == len(self.dataset_index)
if table_name:
self.dataset_table = ad.concat(*tables_l)
self.dataset_table = ad.concat(tables_l)
assert len(self.tiles_coords) == len(self.dataset_table)

dims_ = set(chain(*dims_l))
Expand Down Expand Up @@ -356,13 +370,17 @@ def __getitem__(self, idx: int) -> Any | SpatialData:
t_coords = self.tiles_coords.iloc[idx]

image = self.sdata[row["image"]]
tile = self._crop_image(
image,
axes=tuple(self.dims),
min_coordinate=t_coords[[f"min{i}" for i in self.dims]].values,
max_coordinate=t_coords[[f"max{i}" for i in self.dims]].values,
target_coordinate_system=row["cs"],
)
if self._rasterize:
tile = self._crop_image(
image,
axes=tuple(self.dims),
min_coordinate=t_coords[[f"min{i}" for i in self.dims]].values,
max_coordinate=t_coords[[f"max{i}" for i in self.dims]].values,
target_coordinate_system=row["cs"],
)
else:
# Use pre-computed slice selection (vectorized at init time).
tile = image.sel(t_coords["selection"])
if self.transform is not None:
out = self._return(idx, tile)
return self.transform(out)
Expand Down
Loading