Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion examples/user_guide/36_Rechunk_No_Shuffle.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,25 @@
"assert result is numpy_dem\n",
"print('Numpy passthrough: OK')"
]
},
{
"cell_type": "markdown",
"source": "## Dataset support\n\n`rechunk_no_shuffle` also accepts `xr.Dataset`. Each dask-backed variable\nis rechunked independently; numpy-backed variables pass through unchanged.",
"metadata": {}
},
{
"cell_type": "code",
"source": "ds = xr.Dataset({\n 'elevation': dem,\n 'slope': xr.DataArray(\n da.from_array(np.random.rand(4096, 4096).astype(np.float32), chunks=256),\n dims=['y', 'x'],\n ),\n 'mask': xr.DataArray(np.ones((4096, 4096), dtype=np.uint8), dims=['y', 'x']),\n})\n\nds_big = rechunk_no_shuffle(ds, target_mb=64)\n\nfor name in ds.data_vars:\n if hasattr(ds[name].data, 'dask'):\n print(f'{name}: {ds[name].chunks[0][0]} -> {ds_big[name].chunks[0][0]}')\n else:\n print(f'{name}: numpy (unchanged)')",
"metadata": {},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": "# Works on the Dataset accessor too\nds_big_acc = ds.xrs.rechunk_no_shuffle(target_mb=64)\nassert ds_big['elevation'].chunks == ds_big_acc['elevation'].chunks\nprint('Dataset accessor: OK')",
"metadata": {},
"execution_count": null,
"outputs": []
}
],
"metadata": {
Expand All @@ -162,4 +181,4 @@
},
"nbformat": 4,
"nbformat_minor": 4
}
}
6 changes: 6 additions & 0 deletions xrspatial/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -910,3 +910,9 @@ def open_geotiff(self, source, **kwargs):
y_min, y_max, x_min, x_max)
kwargs.pop('window', None)
return open_geotiff(source, window=window, **kwargs)

# ---- Chunking ----

def rechunk_no_shuffle(self, **kwargs):
    """Rechunk the wrapped object without triggering a dask shuffle.

    Thin accessor wrapper: every keyword argument (e.g. ``target_mb``)
    is forwarded unchanged to :func:`xrspatial.utils.rechunk_no_shuffle`
    with ``self._obj`` as the target.
    """
    from .utils import rechunk_no_shuffle as _rechunk_no_shuffle
    return _rechunk_no_shuffle(self._obj, **kwargs)
1 change: 1 addition & 0 deletions xrspatial/tests/test_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def test_dataset_accessor_has_expected_methods():
'proximity', 'allocation', 'direction', 'cost_distance',
'ndvi', 'evi', 'arvi', 'savi', 'nbr', 'sipi',
'rasterize',
'rechunk_no_shuffle',
]
for name in expected:
assert name in names, f"Missing method: {name}"
Expand Down
76 changes: 74 additions & 2 deletions xrspatial/tests/test_rechunk_no_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ def test_numpy_passthrough():
# Input validation
# ---------------------------------------------------------------------------

def test_rejects_non_dataarray():
with pytest.raises(TypeError, match="expected xr.DataArray"):
def test_rejects_non_dataarray_or_dataset():
    """A plain numpy array is neither accepted type, so TypeError is raised."""
    bad_input = np.zeros((10, 10))
    with pytest.raises(TypeError, match="expected xr.DataArray or xr.Dataset"):
        rechunk_no_shuffle(bad_input)


Expand All @@ -121,3 +121,75 @@ def test_accessor():
direct = rechunk_no_shuffle(raster, target_mb=16)
via_accessor = raster.xrs.rechunk_no_shuffle(target_mb=16)
assert direct.chunks == via_accessor.chunks


# ---------------------------------------------------------------------------
# Dataset support
# ---------------------------------------------------------------------------

def _make_dask_dataset(chunks=128):
    """Build a test Dataset: two dask-backed variables plus one numpy-backed."""
    shape = (512, 512)
    dims = ['y', 'x']
    data_vars = {
        'a': xr.DataArray(da.zeros(shape, chunks=chunks, dtype=np.float32),
                          dims=dims, name='a'),
        'b': xr.DataArray(da.ones(shape, chunks=chunks, dtype=np.float64),
                          dims=dims, name='b'),
        # numpy-backed variable: must pass through rechunking untouched
        'c': xr.DataArray(np.zeros(shape, dtype=np.float32),
                          dims=dims, name='c'),
    }
    return xr.Dataset(data_vars, attrs={'crs': 'EPSG:32610'})


def test_dataset_rechunks_all_dask_vars():
    """Every dask-backed variable ends up with strictly larger chunks."""
    ds = _make_dask_dataset(chunks=64)
    rechunked = rechunk_no_shuffle(ds, target_mb=16)
    assert isinstance(rechunked, xr.Dataset)
    for var in ('a', 'b'):
        before = ds[var].chunks[0][0]
        after = rechunked[var].chunks[0][0]
        # chunks grew, and by an exact integer multiple (no shuffle needed)
        assert after > before
        assert after % before == 0


def test_dataset_numpy_var_unchanged():
    """A numpy-backed variable passes through without modification."""
    ds = _make_dask_dataset()
    out = rechunk_no_shuffle(ds, target_mb=16)
    # 'c' was numpy-backed going in; it must still be numpy coming out
    assert not hasattr(out['c'].data, 'dask')
    np.testing.assert_array_equal(out['c'].values, ds['c'].values)


def test_dataset_preserves_attrs_and_coords():
    """Dataset attributes and coordinate values survive rechunking intact."""
    ds = _make_dask_dataset().assign_coords(y=np.arange(512), x=np.arange(512))
    out = rechunk_no_shuffle(ds, target_mb=16)
    assert out.attrs == ds.attrs
    xr.testing.assert_equal(out.coords.to_dataset(), ds.coords.to_dataset())


def test_dataset_preserves_values():
    """Rechunking must not alter any data values."""
    np.random.seed(1069)
    values = np.random.rand(256, 256).astype(np.float32)
    ds = xr.Dataset({'v': xr.DataArray(da.from_array(values, chunks=64),
                                       dims=['y', 'x'])})
    out = rechunk_no_shuffle(ds, target_mb=1)
    np.testing.assert_array_equal(ds['v'].values, out['v'].values)


def test_dataset_accessor():
    """The Dataset accessor yields the same chunking as the direct call."""
    import xrspatial  # noqa: F401  (registers the .xrs accessor)
    ds = _make_dask_dataset(chunks=64)
    expected = rechunk_no_shuffle(ds, target_mb=16)
    actual = ds.xrs.rechunk_no_shuffle(target_mb=16)
    dask_vars = [v for v in ds.data_vars if hasattr(ds[v].data, 'dask')]
    for v in dask_vars:
        assert expected[v].chunks == actual[v].chunks
72 changes: 42 additions & 30 deletions xrspatial/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1028,8 +1028,34 @@ def _sample_windows_min_max(
return float(np.nanmin(np.array(mins, dtype=float))), float(np.nanmax(np.array(maxs, dtype=float)))


def _rechunk_dataarray(agg, target_bytes):
    """Grow the chunks of a dask-backed DataArray toward *target_bytes*.

    The new chunk along each dimension is an exact integer multiple of the
    original first chunk, so dask can merge neighbors without a shuffle.
    Non-dask input (or a missing dask install) is returned unchanged.
    """
    if not has_dask_array() or not isinstance(agg.data, da.Array):
        return agg

    # First chunk size per dimension is the alignment unit we must preserve.
    first_chunks = tuple(sizes[0] for sizes in agg.chunks)

    chunk_bytes = agg.dtype.itemsize
    for size in first_chunks:
        chunk_bytes *= size

    # Already at or above target: nothing to do.
    if chunk_bytes >= target_bytes:
        return agg

    # Uniform per-dimension growth factor, floored so we never exceed target.
    growth = int((target_bytes / chunk_bytes) ** (1.0 / len(first_chunks)))
    if growth < 2:
        return agg

    return agg.chunk({dim: size * growth
                      for dim, size in zip(agg.dims, first_chunks)})


def rechunk_no_shuffle(agg, target_mb=128):
"""Rechunk a dask-backed DataArray without triggering a shuffle.
"""Rechunk dask-backed data without triggering a shuffle.

Computes an integer multiplier per dimension so that each new chunk
is an exact multiple of the original chunk size. This lets dask
Expand All @@ -1038,23 +1064,23 @@ def rechunk_no_shuffle(agg, target_mb=128):

Parameters
----------
agg : xr.DataArray
Input raster. If not backed by a dask array the input is
returned unchanged.
agg : xr.DataArray or xr.Dataset
Input raster or collection of rasters. Non-dask variables
pass through unchanged.
target_mb : int or float
Target chunk size in megabytes. The actual chunk size will be
the closest multiple of the source chunk that does not exceed
this target. Default 128.

Returns
-------
xr.DataArray
Rechunked DataArray. Coordinates and attributes are preserved.
xr.DataArray or xr.Dataset
Rechunked object. Coordinates and attributes are preserved.

Raises
------
TypeError
If *agg* is not an ``xr.DataArray``.
If *agg* is not an ``xr.DataArray`` or ``xr.Dataset``.
ValueError
If *target_mb* is not positive.

Expand All @@ -1066,37 +1092,23 @@ def rechunk_no_shuffle(agg, target_mb=128):
>>> big = rechunk_no_shuffle(arr, target_mb=64)
>>> big.chunks # multiples of 256
"""
if not isinstance(agg, xr.DataArray):
if not isinstance(agg, (xr.DataArray, xr.Dataset)):
raise TypeError(
f"rechunk_no_shuffle(): expected xr.DataArray, "
f"rechunk_no_shuffle(): expected xr.DataArray or xr.Dataset, "
f"got {type(agg).__name__}"
)
if target_mb <= 0:
raise ValueError(
f"rechunk_no_shuffle(): target_mb must be > 0, got {target_mb}"
)

if not has_dask_array() or not isinstance(agg.data, da.Array):
return agg

chunks = agg.chunks # tuple of tuples
base = tuple(c[0] for c in chunks)

current_bytes = agg.dtype.itemsize
for b in base:
current_bytes *= b

target_bytes = target_mb * 1024 * 1024

if current_bytes >= target_bytes:
return agg

ndim = len(base)
ratio = target_bytes / current_bytes
multiplier = max(1, int(ratio ** (1.0 / ndim)))
if isinstance(agg, xr.DataArray):
return _rechunk_dataarray(agg, target_bytes)

if multiplier <= 1:
return agg

new_chunks = {dim: b * multiplier for dim, b in zip(agg.dims, base)}
return agg.chunk(new_chunks)
# Dataset: rechunk each variable independently
new_vars = {}
for name, var in agg.data_vars.items():
new_vars[name] = _rechunk_dataarray(var, target_bytes)
return agg.assign(new_vars)
Loading