def rechunk_no_shuffle(agg, target_mb=128):
    """Rechunk a dask-backed DataArray without triggering a shuffle.

    Computes an integer multiplier per dimension so that each new chunk
    is an exact multiple of the original chunk size. This lets dask
    merge whole source chunks in-place instead of splitting and
    recombining partial blocks (which is effectively a shuffle).

    Parameters
    ----------
    agg : xr.DataArray
        Input raster. If not backed by a dask array the input is
        returned unchanged.
    target_mb : int or float
        Target chunk size in megabytes. The actual chunk size will be
        the largest whole-chunk multiple of the source chunk that does
        not exceed this target. Default 128.

    Returns
    -------
    xr.DataArray
        Rechunked DataArray. Coordinates and attributes are preserved.

    Raises
    ------
    TypeError
        If *agg* is not an ``xr.DataArray``.
    ValueError
        If *target_mb* is not positive.

    Examples
    --------
    >>> import dask.array as da
    >>> import xarray as xr
    >>> arr = xr.DataArray(da.zeros((4096, 4096), chunks=256))
    >>> big = rechunk_no_shuffle(arr, target_mb=64)
    >>> big.chunks  # multiples of 256
    """
    if not isinstance(agg, xr.DataArray):
        raise TypeError(
            f"rechunk_no_shuffle(): expected xr.DataArray, "
            f"got {type(agg).__name__}"
        )
    if target_mb <= 0:
        raise ValueError(
            f"rechunk_no_shuffle(): target_mb must be > 0, got {target_mb}"
        )

    # Non-dask (e.g. numpy-backed) inputs pass through untouched.
    if not has_dask_array() or not isinstance(agg.data, da.Array):
        return agg

    chunks = agg.chunks  # tuple of tuples, one inner tuple per dim
    # Use the leading chunk of each dim as the "unit" block size; any
    # ragged trailing chunk stays ragged after multiplication.
    base = tuple(c[0] for c in chunks)

    current_bytes = agg.dtype.itemsize
    for b in base:
        current_bytes *= b

    target_bytes = target_mb * 1024 * 1024

    if current_bytes >= target_bytes:
        return agg

    ndim = len(base)
    ratio = target_bytes / current_bytes

    # We want the largest integer m with m**ndim <= ratio, i.e. the
    # floor of the ndim-th root. A plain int(ratio ** (1.0 / ndim))
    # can under-shoot exact integer roots due to floating-point error
    # (e.g. 8 ** (1/3) == 1.9999999999999998 floors to 1), so correct
    # the candidate in both directions using exact integer arithmetic.
    multiplier = max(1, int(ratio ** (1.0 / ndim)))
    while (multiplier + 1) ** ndim <= ratio:
        multiplier += 1
    while multiplier > 1 and multiplier ** ndim > ratio:
        multiplier -= 1

    if multiplier <= 1:
        return agg

    # dask clamps each requested chunk to the dim length, so dims
    # shorter than base * multiplier are handled automatically.
    new_chunks = {dim: b * multiplier for dim, b in zip(agg.dims, base)}
    return agg.chunk(new_chunks)
"""Tests for rechunk_no_shuffle."""

import numpy as np
import pytest
import xarray as xr

from xrspatial.utils import rechunk_no_shuffle

da = pytest.importorskip("dask.array")


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _dask_raster(shape=(1024, 1024), chunks=256, dtype=np.float32):
    """Build a dask-backed DataArray of zeros for the tests below."""
    if len(shape) == 2:
        dims = ['y', 'x']
    else:
        dims = [f'd{i}' for i in range(len(shape))]
    return xr.DataArray(da.zeros(shape, chunks=chunks, dtype=dtype), dims=dims)


# ---------------------------------------------------------------------------
# Basic behaviour
# ---------------------------------------------------------------------------

def test_chunks_are_exact_multiples():
    """New chunks must be an integer multiple of original chunks."""
    src = _dask_raster(shape=(2048, 2048), chunks=128)
    out = rechunk_no_shuffle(src, target_mb=16)

    for src_axis, out_axis in zip(src.chunks, out.chunks):
        unit = src_axis[0]
        # every chunk (except possibly the last) should be a multiple
        for c in out_axis[:-1]:
            assert c % unit == 0, f"chunk {c} is not a multiple of {unit}"


def test_chunks_grow():
    """Output chunks should be larger than input when target is larger."""
    src = _dask_raster(shape=(2048, 2048), chunks=64)
    out = rechunk_no_shuffle(src, target_mb=16)
    assert out.chunks[0][0] > src.chunks[0][0]


def test_already_large_returns_unchanged():
    """If chunks already meet or exceed target, return as-is."""
    src = _dask_raster(shape=(512, 512), chunks=512)
    out = rechunk_no_shuffle(src, target_mb=0.5)
    assert out.chunks == src.chunks


def test_3d_input():
    """Works with 3-D arrays (e.g. stacked bands)."""
    src = _dask_raster(shape=(4, 512, 512), chunks=(1, 128, 128))
    out = rechunk_no_shuffle(src, target_mb=16)
    for src_axis, out_axis in zip(src.chunks, out.chunks):
        unit = src_axis[0]
        for c in out_axis[:-1]:
            assert c % unit == 0


def test_preserves_values():
    """Rechunked array should contain identical data."""
    np.random.seed(1067)
    values = np.random.rand(256, 256).astype(np.float32)
    src = xr.DataArray(da.from_array(values, chunks=64), dims=['y', 'x'])
    out = rechunk_no_shuffle(src, target_mb=1)
    np.testing.assert_array_equal(src.values, out.values)


def test_preserves_coords_and_attrs():
    """Coordinates and attributes must survive rechunking."""
    src = xr.DataArray(
        da.zeros((256, 256), chunks=64, dtype=np.float32),
        dims=['y', 'x'],
        coords={'y': np.arange(256), 'x': np.arange(256)},
        attrs={'crs': 'EPSG:4326'},
    )
    out = rechunk_no_shuffle(src, target_mb=1)
    assert out.attrs == src.attrs
    xr.testing.assert_equal(out.coords.to_dataset(), src.coords.to_dataset())


# ---------------------------------------------------------------------------
# Non-dask passthrough
# ---------------------------------------------------------------------------

def test_numpy_passthrough():
    """Numpy-backed DataArray should be returned unchanged."""
    src = xr.DataArray(np.zeros((100, 100)), dims=['y', 'x'])
    out = rechunk_no_shuffle(src, target_mb=1)
    assert out is src


# ---------------------------------------------------------------------------
# Input validation
# ---------------------------------------------------------------------------

def test_rejects_non_dataarray():
    with pytest.raises(TypeError, match="expected xr.DataArray"):
        rechunk_no_shuffle(np.zeros((10, 10)))


def test_rejects_nonpositive_target():
    src = _dask_raster()
    for bad_target in (0, -1):
        with pytest.raises(ValueError, match="target_mb must be > 0"):
            rechunk_no_shuffle(src, target_mb=bad_target)


# ---------------------------------------------------------------------------
# Accessor integration
# ---------------------------------------------------------------------------

def test_accessor():
    """The .xrs.rechunk_no_shuffle() accessor delegates correctly."""
    import xrspatial  # noqa: F401 — registers accessor
    src = _dask_raster(shape=(1024, 1024), chunks=128)
    direct = rechunk_no_shuffle(src, target_mb=16)
    via_accessor = src.xrs.rechunk_no_shuffle(target_mb=16)
    assert direct.chunks == via_accessor.chunks
Same-CRS tiles skip reprojection entirely and are placed by direct coordinate al | [Preview](xrspatial/preview.py) | Downsamples a raster to target pixel dimensions for visualization | Custom | ✅️ | ✅️ | ✅️ | 🔄 | | [Rescale](xrspatial/normalize.py) | Min-max normalization to a target range (default [0, 1]) | Standard | ✅️ | ✅️ | ✅️ | ✅️ | | [Standardize](xrspatial/normalize.py) | Z-score normalization (subtract mean, divide by std) | Standard | ✅️ | ✅️ | ✅️ | ✅️ | +| [rechunk_no_shuffle](xrspatial/utils.py) | Rechunk dask arrays using whole-chunk multiples (no shuffle) | Custom | 🔄 | ✅️ | 🔄 | ✅️ | ----------- diff --git a/docs/source/reference/utilities.rst b/docs/source/reference/utilities.rst index a7b15c61..4bec052d 100644 --- a/docs/source/reference/utilities.rst +++ b/docs/source/reference/utilities.rst @@ -54,6 +54,13 @@ Normalization xrspatial.normalize.rescale xrspatial.normalize.standardize +Rechunking +========== +.. autosummary:: + :toctree: _autosummary + + xrspatial.utils.rechunk_no_shuffle + Diagnostics =========== .. autosummary:: From 0531d164326f8d735c110ee8110f4e60ae997be6 Mon Sep 17 00:00:00 2001 From: Brendan Collins Date: Mon, 23 Mar 2026 21:19:59 -0700 Subject: [PATCH 4/4] Add rechunk_no_shuffle user guide notebook (#1067) --- .../user_guide/36_Rechunk_No_Shuffle.ipynb | 165 ++++++++++++++++++ 1 file changed, 165 insertions(+) create mode 100644 examples/user_guide/36_Rechunk_No_Shuffle.ipynb diff --git a/examples/user_guide/36_Rechunk_No_Shuffle.ipynb b/examples/user_guide/36_Rechunk_No_Shuffle.ipynb new file mode 100644 index 00000000..f890229d --- /dev/null +++ b/examples/user_guide/36_Rechunk_No_Shuffle.ipynb @@ -0,0 +1,165 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Rechunk Without Shuffling\n", + "\n", + "When working with large dask-backed rasters, rechunking to bigger blocks can\n", + "speed up downstream operations like `slope()` or `focal_mean()` that use\n", + "`map_overlap`. 
But if the new chunk size is not an exact multiple of the\n", + "original, dask has to split and recombine blocks — essentially a shuffle —\n", + "which tanks performance.\n", + "\n", + "`rechunk_no_shuffle` picks the largest whole-chunk multiple that fits your\n", + "target size, so dask can merge blocks in place with zero shuffle overhead." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import dask.array as da\n", + "import xarray as xr\n", + "import xrspatial\n", + "from xrspatial.utils import rechunk_no_shuffle" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Create a synthetic dask-backed raster\n", + "\n", + "Start with a 4096 x 4096 raster chunked at 256 x 256 (about 0.25 MB per\n", + "chunk for float32)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "np.random.seed(42)\n", + "raw = np.random.rand(4096, 4096).astype(np.float32) * 1000\n", + "dem = xr.DataArray(\n", + " da.from_array(raw, chunks=256),\n", + " dims=['y', 'x'],\n", + " coords={\n", + " 'y': np.linspace(40.0, 41.0, 4096),\n", + " 'x': np.linspace(-105.0, -104.0, 4096),\n", + " },\n", + ")\n", + "print(f'Original chunks: {dem.chunks}')\n", + "print(f'Chunks per axis: {len(dem.chunks[0])} x {len(dem.chunks[1])}')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Rechunk to ~64 MB target\n", + "\n", + "Each new chunk will be an exact multiple of 256, so dask just groups\n", + "existing blocks together." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "big = rechunk_no_shuffle(dem, target_mb=64)\n", + "print(f'New chunks: {big.chunks}')\n", + "print(f'Chunks per axis: {len(big.chunks[0])} x {len(big.chunks[1])}')\n", + "print(f'Block size: {big.chunks[0][0]} x {big.chunks[1][0]}')\n", + "print(f'Multiple of 256: {big.chunks[0][0] // 256}x')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Using the .xrs accessor\n", + "\n", + "The same function is available directly on any DataArray." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "big_via_accessor = dem.xrs.rechunk_no_shuffle(target_mb=64)\n", + "print(f'Accessor chunks: {big_via_accessor.chunks}')\n", + "assert big.chunks == big_via_accessor.chunks" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Compare task graph sizes\n", + "\n", + "Fewer, larger chunks means a smaller task graph for downstream operations." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from xrspatial.slope import slope\n", + "\n", + "slope_small = slope(dem)\n", + "slope_big = slope(big)\n", + "\n", + "print(f'slope() graph with original chunks: {len(dict(slope_small.data.__dask_graph__())):,} tasks')\n", + "print(f'slope() graph with rechunked: {len(dict(slope_big.data.__dask_graph__())):,} tasks')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Non-dask arrays pass through unchanged\n", + "\n", + "If the input is a plain numpy-backed DataArray, the function returns it\n", + "as-is — no copy, no error." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "numpy_dem = xr.DataArray(raw, dims=['y', 'x'])\n", + "result = rechunk_no_shuffle(numpy_dem, target_mb=64)\n", + "assert result is numpy_dem\n", + "print('Numpy passthrough: OK')" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.11.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}