diff --git a/examples/user_guide/36_Rechunk_No_Shuffle.ipynb b/examples/user_guide/36_Rechunk_No_Shuffle.ipynb index f890229d..ee70c000 100644 --- a/examples/user_guide/36_Rechunk_No_Shuffle.ipynb +++ b/examples/user_guide/36_Rechunk_No_Shuffle.ipynb @@ -147,6 +147,25 @@ "assert result is numpy_dem\n", "print('Numpy passthrough: OK')" ] + }, + { + "cell_type": "markdown", + "source": "## Dataset support\n\n`rechunk_no_shuffle` also accepts `xr.Dataset`. Each dask-backed variable\nis rechunked independently; numpy-backed variables pass through unchanged.", + "metadata": {} + }, + { + "cell_type": "code", + "source": "ds = xr.Dataset({\n 'elevation': dem,\n 'slope': xr.DataArray(\n da.from_array(np.random.rand(4096, 4096).astype(np.float32), chunks=256),\n dims=['y', 'x'],\n ),\n 'mask': xr.DataArray(np.ones((4096, 4096), dtype=np.uint8), dims=['y', 'x']),\n})\n\nds_big = rechunk_no_shuffle(ds, target_mb=64)\n\nfor name in ds.data_vars:\n if hasattr(ds[name].data, 'dask'):\n print(f'{name}: {ds[name].chunks[0][0]} -> {ds_big[name].chunks[0][0]}')\n else:\n print(f'{name}: numpy (unchanged)')", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": "# Works on the Dataset accessor too\nds_big_acc = ds.xrs.rechunk_no_shuffle(target_mb=64)\nassert ds_big['elevation'].chunks == ds_big_acc['elevation'].chunks\nprint('Dataset accessor: OK')", + "metadata": {}, + "execution_count": null, + "outputs": [] + } ], "metadata": { @@ -162,4 +181,4 @@ }, "nbformat": 4, "nbformat_minor": 4 } diff --git a/xrspatial/accessor.py b/xrspatial/accessor.py index 3db3e048..0c7672dc 100644 --- a/xrspatial/accessor.py +++ b/xrspatial/accessor.py @@ -910,3 +910,9 @@ def open_geotiff(self, source, **kwargs): y_min, y_max, x_min, x_max) kwargs.pop('window', None) return open_geotiff(source, window=window, **kwargs) + + # ---- Chunking ---- + + def rechunk_no_shuffle(self, **kwargs): + from .utils import 
rechunk_no_shuffle + return rechunk_no_shuffle(self._obj, **kwargs) diff --git a/xrspatial/tests/test_accessor.py b/xrspatial/tests/test_accessor.py index 671c3f9f..c3173c0a 100644 --- a/xrspatial/tests/test_accessor.py +++ b/xrspatial/tests/test_accessor.py @@ -106,6 +106,7 @@ def test_dataset_accessor_has_expected_methods(): 'proximity', 'allocation', 'direction', 'cost_distance', 'ndvi', 'evi', 'arvi', 'savi', 'nbr', 'sipi', 'rasterize', + 'rechunk_no_shuffle', ] for name in expected: assert name in names, f"Missing method: {name}" diff --git a/xrspatial/tests/test_rechunk_no_shuffle.py b/xrspatial/tests/test_rechunk_no_shuffle.py index be6faa93..875c2751 100644 --- a/xrspatial/tests/test_rechunk_no_shuffle.py +++ b/xrspatial/tests/test_rechunk_no_shuffle.py @@ -97,8 +97,8 @@ def test_numpy_passthrough(): # Input validation # --------------------------------------------------------------------------- -def test_rejects_non_dataarray(): - with pytest.raises(TypeError, match="expected xr.DataArray"): +def test_rejects_non_dataarray_or_dataset(): + with pytest.raises(TypeError, match="expected xr.DataArray or xr.Dataset"): rechunk_no_shuffle(np.zeros((10, 10))) @@ -121,3 +121,75 @@ def test_accessor(): direct = rechunk_no_shuffle(raster, target_mb=16) via_accessor = raster.xrs.rechunk_no_shuffle(target_mb=16) assert direct.chunks == via_accessor.chunks + + +# --------------------------------------------------------------------------- +# Dataset support +# --------------------------------------------------------------------------- + +def _make_dask_dataset(chunks=128): + """Dataset with two dask variables and one numpy variable.""" + dask_a = xr.DataArray( + da.zeros((512, 512), chunks=chunks, dtype=np.float32), + dims=['y', 'x'], name='a', + ) + dask_b = xr.DataArray( + da.ones((512, 512), chunks=chunks, dtype=np.float64), + dims=['y', 'x'], name='b', + ) + numpy_c = xr.DataArray( + np.zeros((512, 512), dtype=np.float32), + dims=['y', 'x'], name='c', + ) + return 
xr.Dataset({'a': dask_a, 'b': dask_b, 'c': numpy_c}, + attrs={'crs': 'EPSG:32610'}) + + +def test_dataset_rechunks_all_dask_vars(): + """Both dask variables should get bigger chunks.""" + ds = _make_dask_dataset(chunks=64) + result = rechunk_no_shuffle(ds, target_mb=16) + assert isinstance(result, xr.Dataset) + for name in ['a', 'b']: + orig_chunk = ds[name].chunks[0][0] + new_chunk = result[name].chunks[0][0] + assert new_chunk > orig_chunk + assert new_chunk % orig_chunk == 0 + + +def test_dataset_numpy_var_unchanged(): + """Numpy-backed variable passes through without modification.""" + ds = _make_dask_dataset() + result = rechunk_no_shuffle(ds, target_mb=16) + # 'c' is numpy-backed, should still be numpy + assert not hasattr(result['c'].data, 'dask') + np.testing.assert_array_equal(result['c'].values, ds['c'].values) + + +def test_dataset_preserves_attrs_and_coords(): + """Dataset attributes and coordinates survive rechunking.""" + ds = _make_dask_dataset() + ds = ds.assign_coords(y=np.arange(512), x=np.arange(512)) + result = rechunk_no_shuffle(ds, target_mb=16) + assert result.attrs == ds.attrs + xr.testing.assert_equal(result.coords.to_dataset(), ds.coords.to_dataset()) + + +def test_dataset_preserves_values(): + """Data values are identical after rechunking.""" + np.random.seed(1069) + arr = da.from_array(np.random.rand(256, 256).astype(np.float32), chunks=64) + ds = xr.Dataset({'v': xr.DataArray(arr, dims=['y', 'x'])}) + result = rechunk_no_shuffle(ds, target_mb=1) + np.testing.assert_array_equal(ds['v'].values, result['v'].values) + + +def test_dataset_accessor(): + """The Dataset .xrs.rechunk_no_shuffle() accessor works.""" + import xrspatial # noqa: F401 + ds = _make_dask_dataset(chunks=64) + direct = rechunk_no_shuffle(ds, target_mb=16) + via_accessor = ds.xrs.rechunk_no_shuffle(target_mb=16) + for name in ds.data_vars: + if hasattr(ds[name].data, 'dask'): + assert direct[name].chunks == via_accessor[name].chunks diff --git a/xrspatial/utils.py 
b/xrspatial/utils.py index 5e72de4b..155789bf 100644 --- a/xrspatial/utils.py +++ b/xrspatial/utils.py @@ -1028,8 +1028,34 @@ def _sample_windows_min_max( return float(np.nanmin(np.array(mins, dtype=float))), float(np.nanmax(np.array(maxs, dtype=float))) +def _rechunk_dataarray(agg, target_bytes): + """Rechunk a single dask-backed DataArray. Returns unchanged if not dask.""" + if not has_dask_array() or not isinstance(agg.data, da.Array): + return agg + + chunks = agg.chunks # tuple of tuples + base = tuple(c[0] for c in chunks) + + current_bytes = agg.dtype.itemsize + for b in base: + current_bytes *= b + + if current_bytes >= target_bytes: + return agg + + ndim = len(base) + ratio = target_bytes / current_bytes + multiplier = max(1, int(ratio ** (1.0 / ndim))) + + if multiplier <= 1: + return agg + + new_chunks = {dim: b * multiplier for dim, b in zip(agg.dims, base)} + return agg.chunk(new_chunks) + + def rechunk_no_shuffle(agg, target_mb=128): - """Rechunk a dask-backed DataArray without triggering a shuffle. + """Rechunk dask-backed data without triggering a shuffle. Computes an integer multiplier per dimension so that each new chunk is an exact multiple of the original chunk size. This lets dask @@ -1038,9 +1064,9 @@ def rechunk_no_shuffle(agg, target_mb=128): Parameters ---------- - agg : xr.DataArray - Input raster. If not backed by a dask array the input is - returned unchanged. + agg : xr.DataArray or xr.Dataset + Input raster or collection of rasters. Non-dask variables + pass through unchanged. target_mb : int or float Target chunk size in megabytes. The actual chunk size will be the closest multiple of the source chunk that does not exceed @@ -1048,13 +1074,13 @@ def rechunk_no_shuffle(agg, target_mb=128): Returns ------- - xr.DataArray - Rechunked DataArray. Coordinates and attributes are preserved. + xr.DataArray or xr.Dataset + Rechunked object. Coordinates and attributes are preserved. Raises ------ TypeError - If *agg* is not an ``xr.DataArray``. 
+ If *agg* is not an ``xr.DataArray`` or ``xr.Dataset``. ValueError If *target_mb* is not positive. @@ -1066,9 +1092,9 @@ def rechunk_no_shuffle(agg, target_mb=128): >>> big = rechunk_no_shuffle(arr, target_mb=64) >>> big.chunks # multiples of 256 """ - if not isinstance(agg, xr.DataArray): + if not isinstance(agg, (xr.DataArray, xr.Dataset)): raise TypeError( - f"rechunk_no_shuffle(): expected xr.DataArray, " + f"rechunk_no_shuffle(): expected xr.DataArray or xr.Dataset, " f"got {type(agg).__name__}" ) if target_mb <= 0: @@ -1076,27 +1102,13 @@ def rechunk_no_shuffle(agg, target_mb=128): f"rechunk_no_shuffle(): target_mb must be > 0, got {target_mb}" ) - if not has_dask_array() or not isinstance(agg.data, da.Array): - return agg - - chunks = agg.chunks # tuple of tuples - base = tuple(c[0] for c in chunks) - - current_bytes = agg.dtype.itemsize - for b in base: - current_bytes *= b - target_bytes = target_mb * 1024 * 1024 - if current_bytes >= target_bytes: - return agg - - ndim = len(base) - ratio = target_bytes / current_bytes - multiplier = max(1, int(ratio ** (1.0 / ndim))) + if isinstance(agg, xr.DataArray): + return _rechunk_dataarray(agg, target_bytes) - if multiplier <= 1: - return agg - - new_chunks = {dim: b * multiplier for dim, b in zip(agg.dims, base)} - return agg.chunk(new_chunks) + # Dataset: rechunk each variable independently + new_vars = {} + for name, var in agg.data_vars.items(): + new_vars[name] = _rechunk_dataarray(var, target_bytes) + return agg.assign(new_vars)