diff --git a/README.md b/README.md index 3888bb15..3026a44d 100644 --- a/README.md +++ b/README.md @@ -286,6 +286,7 @@ Same-CRS tiles skip reprojection entirely and are placed by direct coordinate al | [Preview](xrspatial/preview.py) | Downsamples a raster to target pixel dimensions for visualization | Custom | ✅️ | ✅️ | ✅️ | 🔄 | | [Rescale](xrspatial/normalize.py) | Min-max normalization to a target range (default [0, 1]) | Standard | ✅️ | ✅️ | ✅️ | ✅️ | | [Standardize](xrspatial/normalize.py) | Z-score normalization (subtract mean, divide by std) | Standard | ✅️ | ✅️ | ✅️ | ✅️ | +| [rechunk_no_shuffle](xrspatial/utils.py) | Rechunk dask arrays using whole-chunk multiples (no shuffle) | Custom | 🔄 | ✅️ | 🔄 | ✅️ | ----------- diff --git a/docs/source/reference/utilities.rst b/docs/source/reference/utilities.rst index a7b15c61..4bec052d 100644 --- a/docs/source/reference/utilities.rst +++ b/docs/source/reference/utilities.rst @@ -54,6 +54,13 @@ Normalization xrspatial.normalize.rescale xrspatial.normalize.standardize +Rechunking +========== +.. autosummary:: + :toctree: _autosummary + + xrspatial.utils.rechunk_no_shuffle + Diagnostics =========== .. autosummary:: diff --git a/examples/user_guide/36_Rechunk_No_Shuffle.ipynb b/examples/user_guide/36_Rechunk_No_Shuffle.ipynb new file mode 100644 index 00000000..f890229d --- /dev/null +++ b/examples/user_guide/36_Rechunk_No_Shuffle.ipynb @@ -0,0 +1,165 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Rechunk Without Shuffling\n", + "\n", + "When working with large dask-backed rasters, rechunking to bigger blocks can\n", + "speed up downstream operations like `slope()` or `focal_mean()` that use\n", + "`map_overlap`. 
which severely degrades performance.
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "big = rechunk_no_shuffle(dem, target_mb=64)\n", + "print(f'New chunks: {big.chunks}')\n", + "print(f'Chunks per axis: {len(big.chunks[0])} x {len(big.chunks[1])}')\n", + "print(f'Block size: {big.chunks[0][0]} x {big.chunks[1][0]}')\n", + "print(f'Multiple of 256: {big.chunks[0][0] // 256}x')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Using the .xrs accessor\n", + "\n", + "The same function is available directly on any DataArray." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "big_via_accessor = dem.xrs.rechunk_no_shuffle(target_mb=64)\n", + "print(f'Accessor chunks: {big_via_accessor.chunks}')\n", + "assert big.chunks == big_via_accessor.chunks" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Compare task graph sizes\n", + "\n", + "Fewer, larger chunks means a smaller task graph for downstream operations." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from xrspatial.slope import slope\n", + "\n", + "slope_small = slope(dem)\n", + "slope_big = slope(big)\n", + "\n", + "print(f'slope() graph with original chunks: {len(dict(slope_small.data.__dask_graph__())):,} tasks')\n", + "print(f'slope() graph with rechunked: {len(dict(slope_big.data.__dask_graph__())):,} tasks')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Non-dask arrays pass through unchanged\n", + "\n", + "If the input is a plain numpy-backed DataArray, the function returns it\n", + "as-is — no copy, no error." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "numpy_dem = xr.DataArray(raw, dims=['y', 'x'])\n", + "result = rechunk_no_shuffle(numpy_dem, target_mb=64)\n", + "assert result is numpy_dem\n", + "print('Numpy passthrough: OK')" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.11.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/xrspatial/__init__.py b/xrspatial/__init__.py index bcf1f814..3f7d4f72 100644 --- a/xrspatial/__init__.py +++ b/xrspatial/__init__.py @@ -134,6 +134,7 @@ from xrspatial.zonal import suggest_zonal_canvas as suggest_zonal_canvas # noqa from xrspatial.reproject import merge # noqa from xrspatial.reproject import reproject # noqa +from xrspatial.utils import rechunk_no_shuffle # noqa import xrspatial.mcda # noqa: F401 — exposes xrspatial.mcda subpackage import xrspatial.accessor # noqa: F401 — registers .xrs accessors diff --git a/xrspatial/accessor.py b/xrspatial/accessor.py index 1632432a..3db3e048 100644 --- a/xrspatial/accessor.py +++ b/xrspatial/accessor.py @@ -494,6 +494,12 @@ def to_geotiff(self, path, **kwargs): from .geotiff import to_geotiff return to_geotiff(self._obj, path, **kwargs) + # ---- Chunking ---- + + def rechunk_no_shuffle(self, **kwargs): + from .utils import rechunk_no_shuffle + return rechunk_no_shuffle(self._obj, **kwargs) + @xr.register_dataset_accessor("xrs") class XrsSpatialDatasetAccessor: diff --git a/xrspatial/tests/test_accessor.py b/xrspatial/tests/test_accessor.py index 88263aad..671c3f9f 100644 --- a/xrspatial/tests/test_accessor.py +++ b/xrspatial/tests/test_accessor.py @@ -89,6 +89,7 @@ def test_dataarray_accessor_has_expected_methods(elevation): 'generate_terrain', 'perlin', 'ndvi', 'evi', 'arvi', 'savi', 'nbr', 'sipi', 'rasterize', + 'rechunk_no_shuffle', ] for name in expected: assert name 
in names, f"Missing method: {name}" diff --git a/xrspatial/tests/test_rechunk_no_shuffle.py b/xrspatial/tests/test_rechunk_no_shuffle.py new file mode 100644 index 00000000..be6faa93 --- /dev/null +++ b/xrspatial/tests/test_rechunk_no_shuffle.py @@ -0,0 +1,123 @@ +"""Tests for rechunk_no_shuffle.""" + +import numpy as np +import pytest +import xarray as xr + +from xrspatial.utils import rechunk_no_shuffle + +da = pytest.importorskip("dask.array") + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_dask_raster(shape=(1024, 1024), chunks=256, dtype=np.float32): + data = da.zeros(shape, chunks=chunks, dtype=dtype) + dims = ['y', 'x'] if len(shape) == 2 else [f'd{i}' for i in range(len(shape))] + return xr.DataArray(data, dims=dims) + + +# --------------------------------------------------------------------------- +# Basic behaviour +# --------------------------------------------------------------------------- + +def test_chunks_are_exact_multiples(): + """New chunks must be an integer multiple of original chunks.""" + raster = _make_dask_raster(shape=(2048, 2048), chunks=128) + result = rechunk_no_shuffle(raster, target_mb=16) + + for orig, new in zip(raster.chunks, result.chunks): + base = orig[0] + # every chunk (except possibly the last) should be a multiple + for c in new[:-1]: + assert c % base == 0, f"chunk {c} is not a multiple of {base}" + + +def test_chunks_grow(): + """Output chunks should be larger than input when target is larger.""" + raster = _make_dask_raster(shape=(2048, 2048), chunks=64) + result = rechunk_no_shuffle(raster, target_mb=16) + assert result.chunks[0][0] > raster.chunks[0][0] + + +def test_already_large_returns_unchanged(): + """If chunks already meet or exceed target, return as-is.""" + raster = _make_dask_raster(shape=(512, 512), chunks=512) + result = rechunk_no_shuffle(raster, target_mb=0.5) + assert 
result.chunks == raster.chunks + + +def test_3d_input(): + """Works with 3-D arrays (e.g. stacked bands).""" + raster = _make_dask_raster(shape=(4, 512, 512), chunks=(1, 128, 128)) + result = rechunk_no_shuffle(raster, target_mb=16) + for orig, new in zip(raster.chunks, result.chunks): + base = orig[0] + for c in new[:-1]: + assert c % base == 0 + + +def test_preserves_values(): + """Rechunked array should contain identical data.""" + np.random.seed(1067) + data = da.from_array(np.random.rand(256, 256).astype(np.float32), chunks=64) + raster = xr.DataArray(data, dims=['y', 'x']) + result = rechunk_no_shuffle(raster, target_mb=1) + np.testing.assert_array_equal(raster.values, result.values) + + +def test_preserves_coords_and_attrs(): + """Coordinates and attributes must survive rechunking.""" + data = da.zeros((256, 256), chunks=64, dtype=np.float32) + raster = xr.DataArray( + data, + dims=['y', 'x'], + coords={'y': np.arange(256), 'x': np.arange(256)}, + attrs={'crs': 'EPSG:4326'}, + ) + result = rechunk_no_shuffle(raster, target_mb=1) + assert result.attrs == raster.attrs + xr.testing.assert_equal(result.coords.to_dataset(), raster.coords.to_dataset()) + + +# --------------------------------------------------------------------------- +# Non-dask passthrough +# --------------------------------------------------------------------------- + +def test_numpy_passthrough(): + """Numpy-backed DataArray should be returned unchanged.""" + raster = xr.DataArray(np.zeros((100, 100)), dims=['y', 'x']) + result = rechunk_no_shuffle(raster, target_mb=1) + assert result is raster + + +# --------------------------------------------------------------------------- +# Input validation +# --------------------------------------------------------------------------- + +def test_rejects_non_dataarray(): + with pytest.raises(TypeError, match="expected xr.DataArray"): + rechunk_no_shuffle(np.zeros((10, 10))) + + +def test_rejects_nonpositive_target(): + raster = _make_dask_raster() + with 
pytest.raises(ValueError, match="target_mb must be > 0"): + rechunk_no_shuffle(raster, target_mb=0) + with pytest.raises(ValueError, match="target_mb must be > 0"): + rechunk_no_shuffle(raster, target_mb=-1) + + +# --------------------------------------------------------------------------- +# Accessor integration +# --------------------------------------------------------------------------- + +def test_accessor(): + """The .xrs.rechunk_no_shuffle() accessor delegates correctly.""" + import xrspatial # noqa: F401 — registers accessor + raster = _make_dask_raster(shape=(1024, 1024), chunks=128) + direct = rechunk_no_shuffle(raster, target_mb=16) + via_accessor = raster.xrs.rechunk_no_shuffle(target_mb=16) + assert direct.chunks == via_accessor.chunks diff --git a/xrspatial/utils.py b/xrspatial/utils.py index 26046edd..5e72de4b 100644 --- a/xrspatial/utils.py +++ b/xrspatial/utils.py @@ -1026,3 +1026,77 @@ def _sample_windows_min_max( # numpy scalars return float(np.nanmin(np.array(mins, dtype=float))), float(np.nanmax(np.array(maxs, dtype=float))) + + +def rechunk_no_shuffle(agg, target_mb=128): + """Rechunk a dask-backed DataArray without triggering a shuffle. + + Computes an integer multiplier per dimension so that each new chunk + is an exact multiple of the original chunk size. This lets dask + merge whole source chunks in-place instead of splitting and + recombining partial blocks (which is effectively a shuffle). + + Parameters + ---------- + agg : xr.DataArray + Input raster. If not backed by a dask array the input is + returned unchanged. + target_mb : int or float + Target chunk size in megabytes. The actual chunk size will be + the closest multiple of the source chunk that does not exceed + this target. Default 128. + + Returns + ------- + xr.DataArray + Rechunked DataArray. Coordinates and attributes are preserved. + + Raises + ------ + TypeError + If *agg* is not an ``xr.DataArray``. + ValueError + If *target_mb* is not positive. 
+ + Examples + -------- + >>> import dask.array as da + >>> import xarray as xr + >>> arr = xr.DataArray(da.zeros((4096, 4096), chunks=256)) + >>> big = rechunk_no_shuffle(arr, target_mb=64) + >>> big.chunks # multiples of 256 + """ + if not isinstance(agg, xr.DataArray): + raise TypeError( + f"rechunk_no_shuffle(): expected xr.DataArray, " + f"got {type(agg).__name__}" + ) + if target_mb <= 0: + raise ValueError( + f"rechunk_no_shuffle(): target_mb must be > 0, got {target_mb}" + ) + + if not has_dask_array() or not isinstance(agg.data, da.Array): + return agg + + chunks = agg.chunks # tuple of tuples + base = tuple(c[0] for c in chunks) + + current_bytes = agg.dtype.itemsize + for b in base: + current_bytes *= b + + target_bytes = target_mb * 1024 * 1024 + + if current_bytes >= target_bytes: + return agg + + ndim = len(base) + ratio = target_bytes / current_bytes + multiplier = max(1, int(ratio ** (1.0 / ndim))) + + if multiplier <= 1: + return agg + + new_chunks = {dim: b * multiplier for dim, b in zip(agg.dims, base)} + return agg.chunk(new_chunks)