diff --git a/docs/superpowers/plans/2026-03-24-dask-graph-utilities.md b/docs/superpowers/plans/2026-03-24-dask-graph-utilities.md new file mode 100644 index 00000000..759211c1 --- /dev/null +++ b/docs/superpowers/plans/2026-03-24-dask-graph-utilities.md @@ -0,0 +1,1057 @@ +# Dask Graph Utilities Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Add `fused_overlap` and `multi_overlap` utilities that reduce dask graph size by fusing multiple `map_overlap` calls into single passes. + +**Architecture:** Two standalone functions in `xrspatial/utils.py`. `fused_overlap` composes sequential overlap stages into one `map_overlap` call with summed depth. `multi_overlap` runs a multi-output kernel via `da.overlap.overlap()` + `da.map_blocks()` with `new_axis=0`. Both are exposed on the DataArray `.xrs` accessor. + +**Tech Stack:** dask.array (map_overlap, overlap.overlap, map_blocks), numpy, xarray + +**Spec:** `docs/superpowers/specs/2026-03-24-dask-graph-utilities-design.md` + +--- + +## File structure + +| File | Role | +|------|------| +| `xrspatial/utils.py` | Add `_normalize_depth`, `_pad_nan`, `fused_overlap`, `multi_overlap` after existing `rechunk_no_shuffle` (line 1102) | +| `xrspatial/accessor.py` | Add `fused_overlap`, `multi_overlap` to DataArray accessor after `rechunk_no_shuffle` (line 501) | +| `xrspatial/__init__.py` | Export both functions | +| `xrspatial/tests/test_fused_overlap.py` | New test file | +| `xrspatial/tests/test_multi_overlap.py` | New test file | +| `xrspatial/tests/test_accessor.py` | Add to expected methods list (line 93) | +| `docs/source/reference/utilities.rst` | Add API entries | +| `README.md` | Add rows to Utilities table (line 289) | +| `examples/user_guide/37_Fused_Overlap.ipynb` | New notebook | + +--- + +### Task 1: `_normalize_depth` and `_pad_nan` helpers + +**Files:** +- Modify: `xrspatial/utils.py` (append after line 1102) +- Test: `xrspatial/tests/test_fused_overlap.py` (new) + +- [ ] **Step 1: Write failing tests for `_normalize_depth`** + +Create `xrspatial/tests/test_fused_overlap.py`: + +```python +"""Tests for fused_overlap and helpers.""" + +import numpy as np +import pytest +import xarray as xr + +from xrspatial.utils import _normalize_depth, _pad_nan + + +class TestNormalizeDepth: + def test_int_input(self): + assert _normalize_depth(2, ndim=2) == {0: 2, 1: 2} + + def test_tuple_input(self): + assert _normalize_depth((3, 1), ndim=2) == {0: 3, 1: 1} + + def test_dict_input(self): + assert _normalize_depth({0: 2, 1: 4}, ndim=2) == {0: 2, 1: 4} + + def test_dict_missing_axis_raises(self): + with pytest.raises(ValueError, match="missing axes"): + _normalize_depth({0: 1}, ndim=2) + + def test_dict_extra_axis_raises(self): + with pytest.raises(ValueError, match="extra axes"): + _normalize_depth({0: 1, 1: 1, 2: 1}, ndim=2) + + def test_negative_depth_raises(self): + with pytest.raises(ValueError, match="non-negative"): + _normalize_depth(-1, ndim=2) + + def test_tuple_wrong_length_raises(self): + with pytest.raises(ValueError, match="length"): + _normalize_depth((1, 2, 3), ndim=2) + + +class TestPadNan: + def test_2d_pads_with_nan(self): + data = np.ones((4, 4), dtype=np.float32) + result = _pad_nan(data, depth=(1, 1)) + assert result.shape == (6, 6) + assert np.isnan(result[0, 0]) + np.testing.assert_array_equal(result[1:-1, 1:-1], data) + + def test_asymmetric_depth(self): + data = np.ones((4, 4), dtype=np.float32) + result = _pad_nan(data, depth=(2, 1)) + assert result.shape == (8, 6) + + def test_integer_dtype_promotes_to_float(self): + data = np.ones((4, 4), dtype=np.int32) + result = _pad_nan(data, depth=(1, 1)) + assert np.issubdtype(result.dtype, np.floating) +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: `pytest xrspatial/tests/test_fused_overlap.py -v` +Expected: FAIL with ImportError + +- [ ] **Step 3: Implement `_normalize_depth` and `_pad_nan`** + +Append to `xrspatial/utils.py` after line 1102: + +```python +def _normalize_depth(depth, ndim): + """Normalize depth to a dict {axis: int}. + + Accepts int, tuple, or dict. Validates completeness and + non-negativity. + """ + if isinstance(depth, dict): + expected = set(range(ndim)) + got = set(depth.keys()) + missing = expected - got + extra = got - expected + if missing: + raise ValueError( + f"_normalize_depth: missing axes {sorted(missing)} " + f"for ndim={ndim}" + ) + if extra: + raise ValueError( + f"_normalize_depth: extra axes {sorted(extra)} " + f"for ndim={ndim}" + ) + for v in depth.values(): + if v < 0: + raise ValueError( + f"_normalize_depth: depth must be non-negative, got {v}" + ) + return dict(depth) + + if isinstance(depth, int): + if depth < 0: + raise ValueError( + f"_normalize_depth: depth must be non-negative, got {depth}" + ) + return {ax: depth for ax in range(ndim)} + + if isinstance(depth, tuple): + if len(depth) != ndim: + raise ValueError( + f"_normalize_depth: tuple length {len(depth)} != ndim {ndim}" + ) + for v in depth: + if v < 0: + raise ValueError( + f"_normalize_depth: depth must be non-negative, got {v}" + ) + return {ax: d for ax, d in enumerate(depth)} + + raise TypeError( + f"_normalize_depth: expected int, tuple, or dict, got {type(depth).__name__}" + ) + + +def _pad_nan(data, depth): + """Pad a 2-D numpy or cupy array with NaN on each side. + + Parameters + ---------- + data : numpy or cupy array + depth : tuple of int + ``(d0, d1)`` cells to pad per axis. + """ + pad_width = tuple((d, d) for d in depth) + if is_cupy_array(data): + if np.issubdtype(data.dtype, np.integer): + data = data.astype(cupy.float64) + out = cupy.pad(data, pad_width, mode='constant', + constant_values=np.nan) + else: + # Promote integer dtypes so NaN fill works + if np.issubdtype(data.dtype, np.integer): + data = data.astype(np.float64) + out = np.pad(data, pad_width, mode='constant', + constant_values=np.nan) + return out +``` + +- [ ] **Step 4: Run tests to verify they pass** + +Run: `pytest xrspatial/tests/test_fused_overlap.py -v` +Expected: all 10 tests PASS + +- [ ] **Step 5: Commit** + +```bash +git add xrspatial/utils.py xrspatial/tests/test_fused_overlap.py +git commit -m "Add _normalize_depth and _pad_nan helpers" +``` + +--- + +### Task 2: `fused_overlap` implementation + +**Files:** +- Modify: `xrspatial/utils.py` (append after `_pad_nan`) +- Modify: `xrspatial/tests/test_fused_overlap.py` (add tests) + +- [ ] **Step 1: Write failing tests for `fused_overlap`** + +Append to `xrspatial/tests/test_fused_overlap.py`: + +```python +da = pytest.importorskip("dask.array") + + +def _increment_interior(chunk): + """Stage func: adds 1 to every cell. Returns interior only.""" + # depth=1 means chunk is (H+2, W+2), interior is (H, W) + return chunk[1:-1, 1:-1] + 1 + + +def _double_interior(chunk): + """Stage func: doubles every cell. Returns interior only.""" + return chunk[1:-1, 1:-1] * 2 + + +def _make_dask_raster(shape=(64, 64), chunks=16, dtype=np.float32): + data = da.from_array( + np.random.RandomState(42).rand(*shape).astype(dtype), chunks=chunks + ) + return xr.DataArray(data, dims=['y', 'x']) + + +class TestFusedOverlapDask: + def test_single_stage_matches_map_overlap(self): + from xrspatial.utils import fused_overlap + raster = _make_dask_raster() + + fused = fused_overlap(raster, (_increment_interior, 1)) + + # Sequential reference + ref = raster.data.map_overlap( + _increment_interior, depth=1, boundary=np.nan, + meta=np.array(()), + ) + np.testing.assert_array_equal(fused.values, ref.compute()) + + def test_two_stages_match_sequential(self): + from xrspatial.utils import fused_overlap + raster = _make_dask_raster() + + fused = fused_overlap( + raster, + (_increment_interior, 1), + (_double_interior, 1), + ) + + # Sequential reference + step1 = raster.data.map_overlap( + _increment_interior, depth=1, boundary=np.nan, + meta=np.array(()), + ) + ref = step1.map_overlap( + _double_interior, depth=1, boundary=np.nan, + meta=np.array(()), + ) + np.testing.assert_array_equal(fused.values, ref.compute()) + + def test_three_stages(self): + from xrspatial.utils import fused_overlap + raster = _make_dask_raster() + + fused = fused_overlap( + raster, + (_increment_interior, 1), + (_double_interior, 1), + (_increment_interior, 1), + ) + + step1 = raster.data.map_overlap( + _increment_interior, depth=1, boundary=np.nan, + meta=np.array(()), + ) + step2 = step1.map_overlap( + _double_interior, depth=1, boundary=np.nan, + meta=np.array(()), + ) + ref = step2.map_overlap( + _increment_interior, depth=1, boundary=np.nan, + meta=np.array(()), + ) + np.testing.assert_array_equal(fused.values, ref.compute()) + + def test_nonsquare_depth(self): + from xrspatial.utils import fused_overlap + + def _stage_2_1(chunk): + return chunk[2:-2, 1:-1] + 1 + + raster = _make_dask_raster(shape=(64, 64), chunks=32) + fused = fused_overlap(raster, (_stage_2_1, (2, 1))) + + ref = raster.data.map_overlap( + _stage_2_1, depth=(2, 1), boundary=np.nan, + meta=np.array(()), + ) + np.testing.assert_array_equal(fused.values, ref.compute()) + + def test_returns_dataarray(self): + from xrspatial.utils import fused_overlap + raster = _make_dask_raster() + result = fused_overlap(raster, (_increment_interior, 1)) + assert isinstance(result, xr.DataArray) + + def test_fewer_graph_layers_than_sequential(self): + from xrspatial.utils import fused_overlap + raster = _make_dask_raster() + + fused = fused_overlap( + raster, + (_increment_interior, 1), + (_double_interior, 1), + ) + + step1 = raster.data.map_overlap( + _increment_interior, depth=1, boundary=np.nan, + meta=np.array(()), + ) + sequential = step1.map_overlap( + _double_interior, depth=1, boundary=np.nan, + meta=np.array(()), + ) + assert len(dict(fused.data.__dask_graph__())) < len( + dict(sequential.__dask_graph__()) + ) + + +class TestFusedOverlapNumpy: + def test_numpy_fallback_matches_dask(self): + from xrspatial.utils import fused_overlap + np_raster = xr.DataArray( + np.random.RandomState(42).rand(64, 64).astype(np.float32), + dims=['y', 'x'], + ) + dask_raster = np_raster.chunk(16) + + np_result = fused_overlap( + np_raster, + (_increment_interior, 1), + (_double_interior, 1), + ) + dask_result = fused_overlap( + dask_raster, + (_increment_interior, 1), + (_double_interior, 1), + ) + # Compare interior (edges may differ due to NaN propagation, + # but interior cells should match) + np.testing.assert_array_equal( + np_result.values[2:-2, 2:-2], + dask_result.values[2:-2, 2:-2], + ) + + +class TestFusedOverlapValidation: + def test_rejects_non_nan_boundary(self): + from xrspatial.utils import fused_overlap + raster = _make_dask_raster() + with pytest.raises(ValueError, match="boundary.*nan"): + fused_overlap(raster, (_increment_interior, 1), boundary='nearest') + + def test_rejects_empty_stages(self): + from xrspatial.utils import fused_overlap + raster = _make_dask_raster() + with pytest.raises(ValueError, match="at least one stage"): + fused_overlap(raster) + + def test_rejects_non_dataarray(self): + from xrspatial.utils import fused_overlap + with pytest.raises(TypeError): + fused_overlap(np.zeros((10, 10)), (_increment_interior, 1)) + + def test_rejects_chunks_smaller_than_total_depth(self): + from xrspatial.utils import fused_overlap + raster = _make_dask_raster(shape=(32, 32), chunks=4) + # total_depth = 5, chunks = 4 + def _big_depth(chunk): + return chunk[5:-5, 5:-5] + 1 + with pytest.raises(ValueError, match="[Cc]hunk size"): + fused_overlap(raster, (_big_depth, 5)) + + def test_small_chunks_barely_above_total_depth(self): + """Chunks just barely larger than total_depth should work.""" + from xrspatial.utils import fused_overlap + # chunks=6, total_depth=2 (two stages of depth 1) + raster = _make_dask_raster(shape=(24, 24), chunks=6) + result = fused_overlap( + raster, + (_increment_interior, 1), + (_double_interior, 1), + ) + assert result.shape == (24, 24) +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: `pytest xrspatial/tests/test_fused_overlap.py::TestFusedOverlapDask -v` +Expected: FAIL with ImportError (`fused_overlap` not found) + +- [ ] **Step 3: Implement `fused_overlap`** + +Append to `xrspatial/utils.py` after `_pad_nan`: + +```python +def fused_overlap(agg, *stages, boundary='nan'): + """Run multiple overlap operations in a single map_overlap call. + + Each stage is a ``(func, depth)`` pair. ``func`` receives a padded + chunk and returns the unpadded interior result. Stages are fused + into one ``map_overlap`` call with the sum of all depths, producing + one blockwise graph layer instead of N. + + Parameters + ---------- + agg : xr.DataArray + Input raster. + *stages : tuple of (callable, depth) + Each ``func`` takes array ``(H+2*d, W+2*d)`` -> ``(H, W)``. + ``depth`` is int, tuple, or dict. + boundary : str + Must be ``'nan'``. + + Returns + ------- + xr.DataArray + """ + if not isinstance(agg, xr.DataArray): + raise TypeError( + f"fused_overlap(): expected xr.DataArray, " + f"got {type(agg).__name__}" + ) + if not stages: + raise ValueError("fused_overlap(): need at least one stage") + if boundary != 'nan': + raise ValueError( + f"fused_overlap(): boundary must be 'nan', got {boundary!r}" + ) + + ndim = agg.ndim + + # Normalize and sum depths + stage_depths = [_normalize_depth(d, ndim) for _, d in stages] + total_depth = {ax: sum(sd[ax] for sd in stage_depths) + for ax in range(ndim)} + + # --- non-dask path --- + if not has_dask_array() or not isinstance(agg.data, da.Array): + result = agg.data + for i, (func, _) in enumerate(stages): + depth_tuple = tuple(stage_depths[i][ax] for ax in range(ndim)) + padded = _pad_nan(result, depth_tuple) + result = func(padded) + return agg.copy(data=result) + + # --- dask path --- + # Validate chunk sizes + for ax, d in total_depth.items(): + for cs in agg.chunks[ax]: + if cs < d: + raise ValueError( + f"Chunk size {cs} on axis {ax} is smaller than " + f"total depth {d}. Rechunk first." + ) + + funcs = [f for f, _ in stages] + + def _fused_wrapper(block): + result = block + for func in funcs: + result = func(result) + # result is now the interior; it still has enough valid + # overlap for the remaining stages + # re-pad to original block shape so map_overlap can crop + pad_width = tuple((total_depth[ax], total_depth[ax]) + for ax in range(block.ndim)) + if is_cupy_array(result): + padded = cupy.pad(result, pad_width, mode='constant', + constant_values=np.nan) + else: + padded = np.pad(result, pad_width, mode='constant', + constant_values=np.nan) + return padded + + out = agg.data.map_overlap( + _fused_wrapper, + depth=total_depth, + boundary=np.nan, + meta=np.array(()), + ) + + return agg.copy(data=out) +``` + +- [ ] **Step 4: Run tests to verify they pass** + +Run: `pytest xrspatial/tests/test_fused_overlap.py -v` +Expected: all tests PASS + +- [ ] **Step 5: Commit** + +```bash +git add xrspatial/utils.py xrspatial/tests/test_fused_overlap.py +git commit -m "Add fused_overlap utility" +``` + +--- + +### Task 3: `multi_overlap` implementation + +**Files:** +- Modify: `xrspatial/utils.py` (append after `fused_overlap`) +- Create: `xrspatial/tests/test_multi_overlap.py` + +- [ ] **Step 1: Write failing tests for `multi_overlap`** + +Create `xrspatial/tests/test_multi_overlap.py`: + +```python +"""Tests for multi_overlap.""" + +import numpy as np +import pytest +import xarray as xr + +from xrspatial.utils import multi_overlap + +da = pytest.importorskip("dask.array") + + +def _triple_kernel(chunk): + """Return 3 bands from a padded (H+2, W+2) chunk.""" + interior = chunk[1:-1, 1:-1] + return np.stack([interior + 1, interior * 2, interior - 1], axis=0) + + +def _make_dask_raster(shape=(64, 64), chunks=16, dtype=np.float32): + data = da.from_array( + np.random.RandomState(99).rand(*shape).astype(dtype), chunks=chunks + ) + return xr.DataArray( + data, dims=['y', 'x'], + coords={'y': np.arange(shape[0]), 'x': np.arange(shape[1])}, + attrs={'crs': 'EPSG:4326'}, + ) + + +class TestMultiOverlapDask: + def test_matches_sequential_stack(self): + raster = _make_dask_raster() + + multi = multi_overlap(raster, _triple_kernel, n_outputs=3, depth=1) + + # Sequential reference: run kernel 3 times extracting one band each + def _band_i(chunk, i=0): + return _triple_kernel(chunk)[i] + + bands = [] + for i in range(3): + from functools import partial + b = raster.data.map_overlap( + partial(_band_i, i=i), depth=1, boundary=np.nan, + meta=np.array(()), + ) + bands.append(b) + ref = da.stack(bands, axis=0).compute() + + np.testing.assert_array_equal(multi.values, ref) + + def test_output_shape(self): + raster = _make_dask_raster(shape=(32, 32), chunks=16) + result = multi_overlap(raster, _triple_kernel, n_outputs=3, depth=1) + assert result.shape == (3, 32, 32) + + def test_returns_dataarray_with_band_dim(self): + raster = _make_dask_raster() + result = multi_overlap(raster, _triple_kernel, n_outputs=3, depth=1) + assert isinstance(result, xr.DataArray) + assert result.dims[0] == 'band' + assert result.dims[1] == 'y' + assert result.dims[2] == 'x' + + def test_preserves_coords_and_attrs(self): + raster = _make_dask_raster() + result = multi_overlap(raster, _triple_kernel, n_outputs=3, depth=1) + assert result.attrs == raster.attrs + xr.testing.assert_equal( + result.coords['x'], raster.coords['x'] + ) + + def test_explicit_dtype(self): + raster = _make_dask_raster() + result = multi_overlap( + raster, _triple_kernel, n_outputs=3, depth=1, + dtype=np.float64, + ) + assert result.dtype == np.float64 + + def test_fewer_graph_tasks_than_sequential(self): + raster = _make_dask_raster() + multi = multi_overlap(raster, _triple_kernel, n_outputs=3, depth=1) + + from functools import partial + def _band_i(chunk, i=0): + return _triple_kernel(chunk)[i] + + bands = [] + for i in range(3): + b = raster.data.map_overlap( + partial(_band_i, i=i), depth=1, boundary=np.nan, + meta=np.array(()), + ) + bands.append(b) + sequential = da.stack(bands, axis=0) + + assert len(dict(multi.data.__dask_graph__())) < len( + dict(sequential.__dask_graph__()) + ) + + def test_single_output_matches_map_overlap(self): + def _single_kernel(chunk): + return (chunk[1:-1, 1:-1] + 1)[np.newaxis, :] + + raster = _make_dask_raster() + multi = multi_overlap(raster, _single_kernel, n_outputs=1, depth=1) + + def _ref_func(chunk): + return chunk[1:-1, 1:-1] + 1 + ref = raster.data.map_overlap( + _ref_func, depth=1, boundary=np.nan, meta=np.array(()), + ) + np.testing.assert_array_equal(multi.values[0], ref.compute()) + + def test_dtype_inference_defaults_to_input(self): + raster = _make_dask_raster(dtype=np.float32) + result = multi_overlap(raster, _triple_kernel, n_outputs=3, depth=1) + assert result.dtype == np.float32 + + def test_non_nan_boundary(self): + """multi_overlap supports all boundary modes.""" + raster = _make_dask_raster() + result = multi_overlap( + raster, _triple_kernel, n_outputs=3, depth=1, + boundary='nearest', + ) + assert result.shape == (3, 64, 64) + assert not np.any(np.isnan(result.values)) + + +class TestMultiOverlapNumpy: + def test_numpy_fallback(self): + raster = xr.DataArray( + np.random.RandomState(99).rand(32, 32).astype(np.float32), + dims=['y', 'x'], + ) + result = multi_overlap(raster, _triple_kernel, n_outputs=3, depth=1) + assert isinstance(result, xr.DataArray) + assert result.shape == (3, 32, 32) + + +class TestMultiOverlapValidation: + def test_rejects_non_2d(self): + raster = xr.DataArray(da.zeros((4, 32, 32), chunks=16), dims=['z', 'y', 'x']) + with pytest.raises(ValueError, match="2-D"): + multi_overlap(raster, _triple_kernel, n_outputs=3, depth=1) + + def test_rejects_n_outputs_zero(self): + raster = _make_dask_raster() + with pytest.raises(ValueError, match="n_outputs.*>= 1"): + multi_overlap(raster, _triple_kernel, n_outputs=0, depth=1) + + def test_rejects_depth_zero(self): + raster = _make_dask_raster() + with pytest.raises(ValueError, match="depth.*>= 1"): + multi_overlap(raster, _triple_kernel, n_outputs=3, depth=0) + + def test_rejects_chunks_smaller_than_depth(self): + raster = _make_dask_raster(shape=(32, 32), chunks=4) + with pytest.raises(ValueError, match="[Cc]hunk size"): + multi_overlap(raster, _triple_kernel, n_outputs=3, depth=5) + + def test_rejects_non_dataarray(self): + with pytest.raises(TypeError): + multi_overlap(np.zeros((10, 10)), _triple_kernel, 3, 1) +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: `pytest xrspatial/tests/test_multi_overlap.py -v` +Expected: FAIL with ImportError + +- [ ] **Step 3: Implement `multi_overlap`** + +Append to `xrspatial/utils.py` after `fused_overlap`: + +```python +def multi_overlap(agg, func, n_outputs, depth, boundary='nan', dtype=None): + """Run a multi-output kernel via a single overlap + map_blocks call. + + ``func`` receives a padded 2-D chunk and returns + ``(n_outputs, H, W)`` — the unpadded interior for each output band. + The result is a 3-D DataArray with a leading ``band`` dimension. + + Parameters + ---------- + agg : xr.DataArray + 2-D input raster. + func : callable + ``(H+2*dy, W+2*dx) -> (n_outputs, H, W)`` + n_outputs : int + Number of output bands (>= 1). + depth : int or tuple of int + Per-axis overlap (>= 1 on each axis). + boundary : str + Boundary mode: 'nan', 'nearest', 'reflect', or 'wrap'. + dtype : numpy dtype, optional + Output dtype. Defaults to input dtype. + + Returns + ------- + xr.DataArray + Shape ``(n_outputs, H, W)`` with ``band`` leading dimension. + """ + if not isinstance(agg, xr.DataArray): + raise TypeError( + f"multi_overlap(): expected xr.DataArray, " + f"got {type(agg).__name__}" + ) + if agg.ndim != 2: + raise ValueError( + f"multi_overlap(): input must be 2-D, got {agg.ndim}-D" + ) + if n_outputs < 1: + raise ValueError( + f"multi_overlap(): n_outputs must be >= 1, got {n_outputs}" + ) + + _validate_boundary(boundary) + + depth_dict = _normalize_depth(depth, agg.ndim) + for ax, d in depth_dict.items(): + if d < 1: + raise ValueError( + f"multi_overlap(): depth must be >= 1, got {d} on axis {ax}" + ) + + dtype = dtype or agg.dtype + + # --- non-dask path --- + if not has_dask_array() or not isinstance(agg.data, da.Array): + if boundary == 'nan': + depth_tuple = tuple(depth_dict[ax] for ax in range(agg.ndim)) + padded = _pad_nan(agg.data, depth_tuple) + else: + depth_tuple = tuple(depth_dict[ax] for ax in range(agg.ndim)) + padded = _pad_array(agg.data, depth_tuple, boundary) + result_data = func(padded).astype(dtype) + return xr.DataArray( + result_data, + dims=['band'] + list(agg.dims), + coords=agg.coords, + attrs=agg.attrs, + ) + + # --- dask path --- + import dask.array.overlap as _dask_overlap + + _validate_boundary(boundary) + boundary_val = _boundary_to_dask(boundary, is_cupy=is_cupy_backed(agg)) + + # Validate chunk sizes + for ax, d in depth_dict.items(): + for cs in agg.chunks[ax]: + if cs < d: + raise ValueError( + f"Chunk size {cs} on axis {ax} is smaller than " + f"depth {d}. Rechunk first." + ) + + # Step 1: pad with overlap + padded = _dask_overlap.overlap( + agg.data, depth=depth_dict, boundary=boundary_val + ) + + # Step 2: map_blocks — func returns (n_outputs, H, W) per block + out = da.map_blocks( + func, + padded, + dtype=dtype, + new_axis=0, + chunks=((n_outputs,),) + agg.data.chunks, + ) + + return xr.DataArray( + out, + dims=['band'] + list(agg.dims), + coords=agg.coords, + attrs=agg.attrs, + ) +``` + +- [ ] **Step 4: Run tests to verify they pass** + +Run: `pytest xrspatial/tests/test_multi_overlap.py -v` +Expected: all tests PASS + +- [ ] **Step 5: Commit** + +```bash +git add xrspatial/utils.py xrspatial/tests/test_multi_overlap.py +git commit -m "Add multi_overlap utility" +``` + +--- + +### Task 4: Accessor and exports + +**Files:** +- Modify: `xrspatial/accessor.py` (line 501, after `rechunk_no_shuffle`) +- Modify: `xrspatial/__init__.py` (after `rechunk_no_shuffle` import) +- Modify: `xrspatial/tests/test_accessor.py` (line 93, expected methods list) + +- [ ] **Step 1: Write failing accessor tests** + +Append to `xrspatial/tests/test_fused_overlap.py`: + +```python +class TestFusedOverlapAccessor: + def test_accessor_delegates(self): + import xrspatial # noqa: F401 + from xrspatial.utils import fused_overlap + raster = _make_dask_raster() + direct = fused_overlap(raster, (_increment_interior, 1)) + via_acc = raster.xrs.fused_overlap((_increment_interior, 1)) + np.testing.assert_array_equal(direct.values, via_acc.values) +``` + +Append to `xrspatial/tests/test_multi_overlap.py`: + +```python +class TestMultiOverlapAccessor: + def test_accessor_delegates(self): + import xrspatial # noqa: F401 + raster = _make_dask_raster() + direct = multi_overlap(raster, _triple_kernel, n_outputs=3, depth=1) + via_acc = raster.xrs.multi_overlap(_triple_kernel, n_outputs=3, depth=1) + np.testing.assert_array_equal(direct.values, via_acc.values) +``` + +- [ ] **Step 2: Run accessor tests to verify they fail** + +Run: `pytest xrspatial/tests/test_fused_overlap.py::TestFusedOverlapAccessor -v` +Expected: FAIL (method not found) + +- [ ] **Step 3: Add accessor methods** + +In `xrspatial/accessor.py`, after the `rechunk_no_shuffle` method (line 501), add: + +```python + def fused_overlap(self, *stages, **kwargs): + from .utils import fused_overlap + return fused_overlap(self._obj, *stages, **kwargs) + + def multi_overlap(self, func, n_outputs, **kwargs): + from .utils import multi_overlap + return multi_overlap(self._obj, func, n_outputs, **kwargs) +``` + +- [ ] **Step 4: Add exports to `__init__.py`** + +After the existing `rechunk_no_shuffle` import line, add: + +```python +from xrspatial.utils import fused_overlap # noqa +from xrspatial.utils import multi_overlap # noqa +``` + +- [ ] **Step 5: Update expected methods list in test_accessor.py** + +In `xrspatial/tests/test_accessor.py`, find the `test_dataarray_accessor_has_expected_methods` function's expected list (around line 93). Add `'fused_overlap'` and `'multi_overlap'` to the list. + +- [ ] **Step 6: Run all tests** + +Run: `pytest xrspatial/tests/test_fused_overlap.py xrspatial/tests/test_multi_overlap.py xrspatial/tests/test_accessor.py -v` +Expected: all PASS + +- [ ] **Step 7: Commit** + +```bash +git add xrspatial/accessor.py xrspatial/__init__.py xrspatial/tests/test_accessor.py xrspatial/tests/test_fused_overlap.py xrspatial/tests/test_multi_overlap.py +git commit -m "Add fused_overlap and multi_overlap to accessor and exports" +``` + +--- + +### Task 5: Documentation and README + +**Files:** +- Modify: `docs/source/reference/utilities.rst` +- Modify: `README.md` (line 289, Utilities table) + +- [ ] **Step 1: Add API entries to utilities.rst** + +In `docs/source/reference/utilities.rst`, before the `Rechunking` section, add: + +```rst +Overlap Fusion +============== +.. autosummary:: + :toctree: _autosummary + + xrspatial.utils.fused_overlap + xrspatial.utils.multi_overlap +``` + +- [ ] **Step 2: Add README rows** + +In `README.md`, in the Utilities table (after the `rechunk_no_shuffle` row around line 289), add: + +```markdown +| [fused_overlap](xrspatial/utils.py) | Fuse sequential map_overlap calls into a single pass | Custom | 🔄 | ✅️ | 🔄 | ✅️ | +| [multi_overlap](xrspatial/utils.py) | Run multi-output kernel in a single overlap pass | Custom | 🔄 | ✅️ | 🔄 | ✅️ | +``` + +- [ ] **Step 3: Commit** + +```bash +git add docs/source/reference/utilities.rst README.md +git commit -m "Add fused_overlap and multi_overlap to docs and README" +``` + +--- + +### Task 6: User guide notebook + +**Files:** +- Create: `examples/user_guide/37_Fused_Overlap.ipynb` + +- [ ] **Step 1: Create the notebook** + +Create `examples/user_guide/37_Fused_Overlap.ipynb` with these cells: + +**Cell 1 (markdown):** +``` +# Fusing Overlap Operations + +When you chain spatial operations like erode then dilate, each one adds a +blockwise layer to the dask graph. `fused_overlap` runs them in a single +`map_overlap` call, and `multi_overlap` does the same for kernels that +produce multiple output bands. +``` + +**Cell 2 (code):** +```python +import numpy as np +import dask.array as da +import xarray as xr +import xrspatial +from xrspatial.utils import fused_overlap, multi_overlap +``` + +**Cell 3 (markdown):** +``` +## fused_overlap: chained operations in one pass + +Define two stage functions. Each takes a padded chunk and returns the +unpadded interior. +``` + +**Cell 4 (code):** +```python +def smooth_interior(chunk): + """3x3 mean filter. Takes (H+2, W+2), returns (H, W).""" + from numpy.lib.stride_tricks import sliding_window_view + windows = sliding_window_view(chunk, (3, 3)) + return np.nanmean(windows, axis=(-2, -1)) + +def threshold_interior(chunk): + """Binary threshold. Takes (H+2, W+2), returns (H, W).""" + interior = chunk[1:-1, 1:-1] + return (interior > 0.5).astype(np.float32) + +np.random.seed(42) +raw = np.random.rand(512, 512).astype(np.float32) +dem = xr.DataArray(da.from_array(raw, chunks=128), dims=['y', 'x']) +``` + +**Cell 5 (code):** +```python +# Fused: one map_overlap call +fused = fused_overlap(dem, (smooth_interior, 1), (threshold_interior, 1)) + +# Sequential: two map_overlap calls +step1 = dem.data.map_overlap(smooth_interior, depth=1, boundary=np.nan, meta=np.array(())) +sequential = step1.map_overlap(threshold_interior, depth=1, boundary=np.nan, meta=np.array(())) + +print(f'Fused graph: {len(dict(fused.data.__dask_graph__())):,} tasks') +print(f'Sequential graph: {len(dict(sequential.__dask_graph__())):,} tasks') +``` + +**Cell 6 (markdown):** +``` +## multi_overlap: N outputs in one pass +``` + +**Cell 7 (code):** +```python +def gradient_kernel(chunk): + """Compute dx and dy gradients. Takes (H+2, W+2), returns (2, H, W).""" + dx = (chunk[1:-1, 2:] - chunk[1:-1, :-2]) / 2.0 + dy = (chunk[2:, 1:-1] - chunk[:-2, 1:-1]) / 2.0 + return np.stack([dx, dy], axis=0) + +result = multi_overlap(dem, gradient_kernel, n_outputs=2, depth=1) +print(f'Output shape: {result.shape}') +print(f'Dimensions: {result.dims}') +print(f'Graph tasks: {len(dict(result.data.__dask_graph__())):,}') +``` + +**Cell 8 (code):** +```python +# Accessor syntax works too +fused_acc = dem.xrs.fused_overlap((smooth_interior, 1), (threshold_interior, 1)) +multi_acc = dem.xrs.multi_overlap(gradient_kernel, n_outputs=2, depth=1) +print('Accessor: OK') +``` + +- [ ] **Step 2: Commit** + +```bash +git add examples/user_guide/37_Fused_Overlap.ipynb +git commit -m "Add fused_overlap and multi_overlap user guide notebook" +``` + +--- + +### Task 7: Final verification + +- [ ] **Step 1: Run all new tests** + +```bash +pytest xrspatial/tests/test_fused_overlap.py xrspatial/tests/test_multi_overlap.py xrspatial/tests/test_accessor.py -v +``` + +Expected: all PASS + +- [ ] **Step 2: Run existing test suite to check for regressions** + +```bash +pytest xrspatial/tests/test_rechunk_no_shuffle.py -v +``` + +Expected: all PASS (no regressions in utils.py) diff --git a/docs/superpowers/specs/2026-03-24-dask-graph-utilities-design.md b/docs/superpowers/specs/2026-03-24-dask-graph-utilities-design.md new file mode 100644 index 00000000..b28a6ee3 --- /dev/null +++ b/docs/superpowers/specs/2026-03-24-dask-graph-utilities-design.md @@ -0,0 +1,314 @@ +# Dask Graph Utilities: fused_overlap and multi_overlap + +**Date:** 2026-03-24 +**Status:** Draft +**Issue:** TBD (to be created during implementation) + +## Problem + +Several xrspatial operations run multiple `map_overlap` passes over the same data when a single pass would produce the same result with a smaller dask graph: + +- **Morphological opening/closing** chains erode + dilate as two separate `map_overlap` calls (2 blockwise layers). +- **Flow direction MFD** runs the same 3x3 kernel 8 times to extract 8 output bands, then stacks them (8 blockwise layers + 1 stack). +- **GLCM texture** does the same per-metric extraction pattern. + +Each extra `map_overlap` call adds a blockwise layer to the dask graph. For large rasters with many chunks, this inflates task counts and scheduler overhead. + +**Not in scope:** Iterative operations like diffusion (N steps of depth-1) are a poor fit for fusion because `total_depth = N`, and for large N the overlap region dominates chunk data. Those are better handled by the existing iterative approach. + +## Solution + +Two new utilities in `xrspatial/utils.py`, both exposed on the DataArray `.xrs` accessor. + +--- + +### 1. `fused_overlap` + +Fuses a sequence of overlap operations into a single `map_overlap` call with combined depth. + +**Stage function contract:** Each stage function takes a padded array and returns the **unpadded interior result only**. That is, given input of shape `(H + 2*dy, W + 2*dx)`, the function returns shape `(H, W)`. This is different from `map_overlap`'s built-in convention (same-shape return). Existing chunk functions that follow the same-shape convention need a one-line adapter: `lambda chunk: func(chunk)[dy:-dy, dx:-dx]` (where `dy`, `dx` are the per-axis depths). + +**Boundary restriction:** Only `boundary='nan'` is supported. For non-NaN boundary modes (`nearest`, `reflect`, `wrap`), the fused result would differ from sequential execution at chunk/array edges because boundary fill happens once on the original data rather than after each stage. Restricting to NaN avoids this correctness gap. NaN boundaries cover the vast majority of spatial raster operations in this codebase. + +**Signature:** + +```python +def fused_overlap(agg, *stages, boundary='nan'): + """Run multiple overlap operations in a single map_overlap call. + + Parameters + ---------- + agg : xr.DataArray + Input raster. If not dask-backed, stages are applied + sequentially with numpy/cupy padding. + *stages : tuple of (func, depth) + Each stage is a ``(callable, depth)`` pair. ``func`` takes a + padded numpy/cupy array of shape ``(H + 2*d, W + 2*d)`` and + returns the interior result of shape ``(H, W)``. ``depth`` + is an int or tuple of ints (per-axis overlap). + boundary : str + Must be ``'nan'``. + + Returns + ------- + xr.DataArray + Result of applying all stages in sequence. + + Raises + ------ + ValueError + If no stages are provided, boundary is not 'nan', or any + chunk dimension is smaller than total_depth. + """ +``` + +**Usage:** + +```python +from xrspatial.utils import fused_overlap + +# morphological opening in one pass instead of two +result = fused_overlap( + data, + (erode_interior, (1, 1)), + (dilate_interior, (1, 1)), + boundary='nan', +) + +# via accessor +result = data.xrs.fused_overlap( + (erode_interior, (1, 1)), + (dilate_interior, (1, 1)), + boundary='nan', +) +``` + +**How it works (dask path):** + +1. Normalize each stage's depth to a per-axis dict via `_normalize_depth`. +2. Compute `total_depth` by summing depths across all stages. +3. Validate that every chunk dimension exceeds `total_depth`. +4. Build a wrapper function that operates on the chunk padded with `total_depth`: + - Stage 0 receives the full `(H + 2*T, W + 2*T)` block, returns `(H + 2*(T - d0), W + 2*(T - d0))` interior. + - Stage 1 receives that `(H + 2*(T - d0), W + 2*(T - d0))` block (which has exactly `T - d0` cells of valid overlap remaining). It returns `(H + 2*(T - d0 - d1), W + 2*(T - d0 - d1))`. + - This continues until the final stage, which has exactly `d_last` overlap and returns `(H, W)`. + - The wrapper then re-pads the `(H, W)` result back to `(H + 2*T, W + 2*T)` using NaN fill so that `map_overlap` can crop its expected `total_depth`. +5. Call `data.map_overlap(wrapper, depth=total_depth, boundary=np.nan)` once. + +**Worked example (two stages, each depth 1, total_depth 2):** + +``` +map_overlap gives wrapper a chunk of shape (H+4, W+4). + +Stage 0: receives (H+4, W+4), returns interior (H+2, W+2). + - The (H+2, W+2) block has 1 cell of valid overlap on each side. + +Stage 1: receives (H+2, W+2), returns interior (H, W). + +Wrapper: pads (H, W) back to (H+4, W+4) with NaN fill. +map_overlap crops total_depth=2 from each side -> final (H, W). Correct. +``` + +**Non-dask fallback:** For each stage in sequence: pad with `np.pad(..., mode='constant', constant_values=np.nan)` (or `cupy.pad` for cupy arrays), apply `func`, take interior. Note: the existing `_pad_array` helper does not support `boundary='nan'` directly, so the implementation must use `np.pad` with constant NaN fill for this path. + +**Result:** N stages produce 1 blockwise layer instead of N. + +--- + +### 2. `multi_overlap` + +Runs a multi-output kernel via a single overlap + map_blocks call. + +**Signature:** + +```python +def multi_overlap(agg, func, n_outputs, depth, boundary='nan', dtype=None): + """Run a multi-output kernel via a single map_overlap call. + + Parameters + ---------- + agg : xr.DataArray + 2-D input raster. + func : callable + Takes a padded numpy/cupy chunk of shape + ``(H + 2*dy, W + 2*dx)`` and returns an array of shape + ``(n_outputs, H, W)`` -- the interior result per output band. + n_outputs : int + Number of output bands (must be >= 1). + depth : int or tuple of int + Per-axis overlap. Must be >= 1 on each spatial axis. + boundary : str + Boundary mode: 'nan', 'nearest', 'reflect', or 'wrap'. + dtype : numpy dtype, optional + Output dtype. If None, uses the input dtype. + + Returns + ------- + xr.DataArray + 3-D DataArray of shape ``(n_outputs, H, W)`` with a leading + ``band`` dimension. + + Raises + ------ + ValueError + If n_outputs < 1, depth < 1, input is not 2-D, or any chunk + dimension is smaller than depth. + """ +``` + +**Usage:** + +```python +from xrspatial.utils import multi_overlap + +# flow direction MFD: one pass instead of 8 +bands = multi_overlap(data, mfd_kernel, n_outputs=8, depth=(1, 1)) + +# via accessor +bands = data.xrs.multi_overlap(mfd_kernel, n_outputs=8, depth=(1, 1)) +``` + +**How it works (dask path):** + +```python +import dask.array.overlap as _overlap + +def multi_overlap(agg, func, n_outputs, depth, boundary='nan', dtype=None): + depth_dict = _normalize_depth(depth, agg.ndim) + boundary_val = _boundary_to_dask(boundary) + dtype = dtype or agg.dtype + + # Validate depth >= 1 on each axis + for ax, d in depth_dict.items(): + if d < 1: + raise ValueError(f"depth must be >= 1, got {d} on axis {ax}") + + # Validate chunk sizes exceed depth + for ax, d in depth_dict.items(): + for cs in agg.chunks[ax]: + if cs < d: + raise ValueError( + f"Chunk size {cs} on axis {ax} is smaller than " + f"depth {d}. Rechunk first." + ) + + # Step 1: pad the dask array with overlap + padded = _overlap.overlap(agg.data, depth=depth_dict, boundary=boundary_val) + + # Step 2: map_blocks with new output axis + def _wrapper(block): + # func returns (n_outputs, H, W) from padded (H+2dy, W+2dx) + return func(block) + + out = da.map_blocks( + _wrapper, + padded, + dtype=dtype, + new_axis=0, + chunks=((n_outputs,),) + agg.data.chunks, + ) + + # Step 3: wrap in DataArray with band dimension + result = xr.DataArray( + out, + dims=['band'] + list(agg.dims), + coords=agg.coords, + attrs=agg.attrs, + ) + return result +``` + +This produces 1 overlap layer + 1 map_blocks layer = 2 layers total, versus N+1 for the current N-separate-calls + stack approach. + +**Non-dask fallback:** Pad the numpy/cupy array (using `_pad_array` for non-NaN boundaries, or `np.pad`/`cupy.pad` with `constant_values=np.nan` for NaN boundary), call `func` (returns `(n_outputs, H, W)`), wrap in DataArray. + +--- + +## Helpers + +### `_normalize_depth(depth, ndim)` + +Accepts `int`, `tuple`, or `dict` and returns a dict `{axis: int}` for all axes. Follows dask's conventions: + +- `int` -> same depth on all axes +- `tuple` -> one depth per axis +- `dict` -> passed through, validated that all axes `0..ndim-1` are present, all values are non-negative ints, and no extra axes exist + +--- + +## Accessor integration + +Both functions go on `XrsSpatialDataArrayAccessor` only. Not on the Dataset accessor -- these are chunk-level operations that don't generalize to "apply to every variable." + +```python +# DataArray accessor +def fused_overlap(self, *stages, **kwargs): + from .utils import fused_overlap + return fused_overlap(self._obj, *stages, **kwargs) + +def multi_overlap(self, func, n_outputs, **kwargs): + from .utils import multi_overlap + return multi_overlap(self._obj, func, n_outputs, **kwargs) +``` + +## Backend support + +- **numpy:** Direct application with `_pad_array` for overlap simulation. +- **dask+numpy:** Primary target. One `map_overlap` or `overlap` + `map_blocks` call. +- **cupy:** Works if the user's `func` handles cupy arrays. `_pad_array` already supports cupy. +- **dask+cupy:** Same as dask+numpy, with `is_cupy=True` passed to `_boundary_to_dask`. + +No `ArrayTypeFunctionMapping` needed. These are dask wrappers, not spatial operations. + +## What this does NOT include + +- **Non-NaN boundaries for `fused_overlap`.** Sequential boundary fill between stages gives different results than a single outer fill. NaN is the only mode where fusion is equivalent to sequential execution. +- **Diffusion / high-iteration fusion.** When `total_depth = N` for large N, the overlap dominates chunk data. The existing iterative approach is better for those cases. Practical limit: 2-5 stages. +- **Auto-rechunk between stages.** Separate concern -- `rechunk_no_shuffle` exists for that. +- **Dataset accessor methods.** These are per-array operations. +- **Refactoring existing call sites** (MFD, GLCM, morphology) to use the new utilities. That's follow-up work after the utilities ship. + +## File changes + +| File | Change | +|------|--------| +| `xrspatial/utils.py` | Add `fused_overlap`, `multi_overlap`, `_normalize_depth` helper | +| `xrspatial/accessor.py` | Add `fused_overlap`, `multi_overlap` to DataArray accessor | +| `xrspatial/__init__.py` | Export both functions | +| `xrspatial/tests/test_fused_overlap.py` | New test file | +| `xrspatial/tests/test_multi_overlap.py` | New test file | +| `xrspatial/tests/test_accessor.py` | Add to expected methods list | +| `docs/source/reference/utilities.rst` | Add API entries | +| `README.md` | Add rows to Utilities table | +| `examples/user_guide/37_Fused_Overlap.ipynb` | New notebook | + +## Testing strategy + +**fused_overlap:** +- Single stage produces same result as plain `map_overlap` +- Two stages (erode + dilate) matches sequential `map_overlap` calls +- Three+ stages work correctly +- Depth accumulation is correct (depth 1+1 = 2 total overlap) +- Non-square depth (e.g. `(2, 1)`) works +- Small chunks (barely larger than total_depth) produce correct results +- Rejects non-NaN boundary modes with clear error +- Rejects chunks smaller than total_depth with clear error +- Non-dask fallback (numpy) works +- Non-dask fallback (cupy) works +- Accessor delegates correctly +- Input validation (empty stages, non-DataArray, etc.) + +**multi_overlap:** +- Single output matches plain `map_overlap` +- N outputs match N separate `map_overlap` calls + `da.stack` +- Output is an xr.DataArray with `band` leading dimension +- Output shape is `(n_outputs, H, W)` +- Values are identical to the sequential approach +- dtype inference works (None -> input dtype) +- Explicit dtype parameter is respected +- Rejects n_outputs < 1, depth < 1, non-2D input +- Rejects chunks smaller than depth +- Non-dask fallback (numpy) works +- Non-dask fallback (cupy) works +- Accessor delegates correctly +- Coords and attrs are preserved diff --git a/xrspatial/reproject/__init__.py b/xrspatial/reproject/__init__.py index fe14d74c..a10ac560 100644 --- a/xrspatial/reproject/__init__.py +++ b/xrspatial/reproject/__init__.py @@ -192,11 +192,10 @@ def _reproject_chunk_numpy( Called inside ``dask.delayed`` for the dask path, or directly for numpy. CRS objects are passed as WKT strings for pickle safety. """ - from ._crs_utils import _require_pyproj + from ._crs_utils import _crs_from_wkt - pyproj = _require_pyproj() - src_crs = pyproj.CRS.from_wkt(src_wkt) - tgt_crs = pyproj.CRS.from_wkt(tgt_wkt) + src_crs = _crs_from_wkt(src_wkt) + tgt_crs = _crs_from_wkt(tgt_wkt) # Try Numba fast path first (avoids creating pyproj Transformer) numba_result = None @@ -212,6 +211,8 @@ def _reproject_chunk_numpy( src_y, src_x = numba_result else: # Fallback: create pyproj Transformer (expensive) + from ._crs_utils import _require_pyproj + pyproj = _require_pyproj() transformer = pyproj.Transformer.from_crs( tgt_crs, src_crs, always_xy=True ) @@ -321,15 +322,10 @@ def _reproject_chunk_cupy( """CuPy variant of ``_reproject_chunk_numpy``.""" import cupy as cp - from ._crs_utils import _require_pyproj + from ._crs_utils import _crs_from_wkt - pyproj = _require_pyproj() - src_crs = pyproj.CRS.from_wkt(src_wkt) - tgt_crs = pyproj.CRS.from_wkt(tgt_wkt) - - transformer = pyproj.Transformer.from_crs( - tgt_crs, src_crs, always_xy=True - ) + src_crs = _crs_from_wkt(src_wkt) + tgt_crs = _crs_from_wkt(tgt_wkt) # Try CUDA transform first (keeps coordinates on-device) cuda_result = None @@ -371,6 +367,11 @@ def _reproject_chunk_cupy( _use_native_cuda = True else: # CPU fallback (Numba JIT or pyproj) + from ._crs_utils import _require_pyproj + pyproj = _require_pyproj() + transformer = pyproj.Transformer.from_crs( + tgt_crs, src_crs, always_xy=True + ) src_y, src_x = _transform_coords( transformer, chunk_bounds_tuple, chunk_shape, transform_precision, src_crs=src_crs, tgt_crs=tgt_crs, @@ -513,8 +514,6 @@ def reproject( If vertical transformation was applied, ``attrs['vertical_crs']`` records the target vertical datum. """ - from ._crs_utils import _require_pyproj - if not isinstance(raster, xr.DataArray): raise TypeError( f"reproject(): raster must be an xr.DataArray, " @@ -522,7 +521,6 @@ def reproject( ) _validate_resampling(resampling) - _require_pyproj() # Resolve CRS src_crs = _resolve_crs(source_crs) @@ -984,11 +982,10 @@ def _reproject_dask_cupy( """ import cupy as cp - from ._crs_utils import _require_pyproj + from ._crs_utils import _crs_from_wkt - pyproj = _require_pyproj() - src_crs = pyproj.CRS.from_wkt(src_wkt) - tgt_crs = pyproj.CRS.from_wkt(tgt_wkt) + src_crs = _crs_from_wkt(src_wkt) + tgt_crs = _crs_from_wkt(tgt_wkt) # Use larger chunks for GPU to amortize kernel launch overhead gpu_chunk = chunk_size or 2048 @@ -1048,6 +1045,8 @@ def _reproject_dask_cupy( c_max = int(np.ceil(c_max_val)) + 3 else: # CPU fallback for this chunk + from ._crs_utils import _require_pyproj + pyproj = _require_pyproj() transformer = pyproj.Transformer.from_crs( tgt_crs, src_crs, always_xy=True ) @@ -1120,30 +1119,44 @@ def _reproject_dask_cupy( def _source_footprint_in_target(src_bounds, src_wkt, tgt_wkt): - """Compute an approximate bounding box of the source raster in target CRS. - - Transforms corners and edge midpoints (12 points) to handle non-linear - projections. Returns ``(left, bottom, right, top)`` in target CRS, or - *None* if the transform fails (e.g. out-of-domain). - """ + """Compute approximate bounding box of source raster in target CRS.""" try: - from ._crs_utils import _require_pyproj - pyproj = _require_pyproj() - src_crs = pyproj.CRS(src_wkt) - tgt_crs = pyproj.CRS(tgt_wkt) - transformer = pyproj.Transformer.from_crs( - src_crs, tgt_crs, always_xy=True - ) + from ._crs_utils import _crs_from_wkt, _resolve_crs + try: + src_crs = _crs_from_wkt(src_wkt) + except Exception: + src_crs = _resolve_crs(src_wkt) + try: + tgt_crs = _crs_from_wkt(tgt_wkt) + except Exception: + tgt_crs = _resolve_crs(tgt_wkt) except Exception: return None sl, sb, sr, st = src_bounds mx = (sl + sr) / 2 my = (sb + st) / 2 - xs = [sl, mx, sr, sl, mx, sr, sl, mx, sr, sl, sr, mx] - ys = [sb, sb, sb, my, my, my, st, st, st, mx, mx, sb] + xs = np.array([sl, mx, sr, sl, mx, sr, sl, mx, sr, sl, sr, mx]) + ys = np.array([sb, sb, sb, my, my, my, st, st, st, mx, mx, sb]) + try: - tx, ty = transformer.transform(xs, ys) + from ._projections import transform_points + result = transform_points(src_crs, tgt_crs, xs, ys) + if result is not None: + tx, ty = result + tx = [v for v in tx if np.isfinite(v)] + ty = [v for v in ty if np.isfinite(v)] + if not tx or not ty: + return None + return (min(tx), min(ty), max(tx), max(ty)) + except (ImportError, ModuleNotFoundError): + pass + + try: + from ._crs_utils import _require_pyproj + pyproj = _require_pyproj() + transformer = pyproj.Transformer.from_crs(src_crs, tgt_crs, always_xy=True) + tx, ty = transformer.transform(xs.tolist(), ys.tolist()) tx = [v for v in tx if np.isfinite(v)] ty = [v for v in ty if np.isfinite(v)] if not tx or not ty: @@ -1298,14 +1311,11 @@ def merge( ------- xr.DataArray """ - from ._crs_utils import _require_pyproj - if not rasters: raise ValueError("merge(): rasters list must not be empty") _validate_resampling(resampling) _validate_strategy(strategy) - pyproj = _require_pyproj() # Resolve target CRS tgt_crs = _resolve_crs(target_crs) @@ -1485,9 +1495,8 @@ def _merge_inmemory( Detects same-CRS tiles and uses fast direct placement instead of reprojection. """ - from ._crs_utils import _require_pyproj - pyproj = _require_pyproj() - tgt_crs = pyproj.CRS.from_wkt(tgt_wkt) + from ._crs_utils import _crs_from_wkt + tgt_crs = _crs_from_wkt(tgt_wkt) arrays = [] for info in raster_infos: diff --git a/xrspatial/reproject/_crs_utils.py b/xrspatial/reproject/_crs_utils.py index fa5d699d..c4ebb511 100644 --- a/xrspatial/reproject/_crs_utils.py +++ b/xrspatial/reproject/_crs_utils.py @@ -1,36 +1,86 @@ -"""CRS detection utilities and optional pyproj import guard.""" +"""CRS detection utilities and optional pyproj import guard. + +Uses a two-tier strategy: try the lightweight built-in CRS first, +then fall back to pyproj for codes/formats not in the built-in table. +""" from __future__ import annotations +from xrspatial.reproject._lite_crs import CRS as LiteCRS -def _require_pyproj(): - """Import and return the pyproj module, raising a clear error if missing.""" + +def _try_import_pyproj(): + """Try to import pyproj, returning the module or None.""" try: import pyproj return pyproj except ImportError: + return None + + +def _require_pyproj(): + """Import and return the pyproj module, raising a clear error if missing.""" + pyproj = _try_import_pyproj() + if pyproj is None: raise ImportError( "pyproj is required for CRS reprojection. " "Install it with: pip install pyproj " "or: pip install xarray-spatial[reproject]" ) + return pyproj def _resolve_crs(crs_input): - """Convert *crs_input* to a ``pyproj.CRS`` object. - - Accepts anything ``pyproj.CRS()`` accepts: EPSG int, authority string, - WKT, proj4 dict, or an existing ``pyproj.CRS`` instance. - - Returns None if *crs_input* is None. + """Convert *crs_input* to a CRS object. + + Resolution order: + + 1. ``None`` passes through as ``None``. + 2. An existing ``LiteCRS`` instance passes through unchanged. + 3. An existing ``pyproj.CRS`` instance passes through unchanged + (only checked when pyproj is importable). + 4. Try ``LiteCRS(crs_input)`` -- covers EPSG ints and ``"EPSG:XXXX"`` + strings for codes in the built-in table. + 5. Fall back to ``pyproj.CRS(crs_input)`` -- raises ``ImportError`` + if pyproj is not installed. """ if crs_input is None: return None - pyproj = _require_pyproj() - if isinstance(crs_input, pyproj.CRS): + + # Pass through existing LiteCRS + if isinstance(crs_input, LiteCRS): + return crs_input + + # Pass through existing pyproj.CRS (if pyproj available) + pyproj = _try_import_pyproj() + if pyproj is not None and isinstance(crs_input, pyproj.CRS): return crs_input + + # Try lite CRS first + try: + return LiteCRS(crs_input) + except (ValueError, TypeError): + pass + + # Fall back to pyproj + pyproj = _require_pyproj() return pyproj.CRS(crs_input) +def _crs_from_wkt(wkt): + """Build a CRS from an OGC WKT string. + + Tries ``LiteCRS.from_wkt`` first (extracts the AUTHORITY tag), + then falls back to ``pyproj.CRS.from_wkt``. + """ + try: + return LiteCRS.from_wkt(wkt) + except (ValueError, TypeError): + pass + + pyproj = _require_pyproj() + return pyproj.CRS.from_wkt(wkt) + + def _detect_source_crs(raster): """Auto-detect the CRS of a DataArray. @@ -47,7 +97,7 @@ def _detect_source_crs(raster): crs_wkt = raster.attrs.get('crs_wkt') if crs_wkt is not None: - return _resolve_crs(crs_wkt) + return _crs_from_wkt(crs_wkt) # rioxarray fallback try: diff --git a/xrspatial/reproject/_grid.py b/xrspatial/reproject/_grid.py index 9cc2adf2..70918c9a 100644 --- a/xrspatial/reproject/_grid.py +++ b/xrspatial/reproject/_grid.py @@ -4,6 +4,38 @@ import numpy as np +def _transform_boundary(source_crs, target_crs, xs, ys): + """Transform coordinate arrays, preferring Numba fast path over pyproj. + + Parameters + ---------- + source_crs, target_crs : CRS-like + Source and target coordinate reference systems. + xs, ys : ndarray + 1-D arrays of x and y coordinates in *source_crs*. + + Returns + ------- + tx, ty : ndarray + Transformed coordinates as numpy arrays. + """ + from ._projections import transform_points + + result = transform_points(source_crs, target_crs, xs, ys) + if result is not None: + return result + + # Fall back to pyproj + from ._crs_utils import _require_pyproj + + pyproj = _require_pyproj() + transformer = pyproj.Transformer.from_crs( + source_crs, target_crs, always_xy=True + ) + tx, ty = transformer.transform(xs, ys) + return np.asarray(tx), np.asarray(ty) + + def _compute_output_grid(source_bounds, source_shape, source_crs, target_crs, resolution=None, bounds=None, width=None, height=None): """Compute the output raster grid parameters. @@ -14,7 +46,7 @@ def _compute_output_grid(source_bounds, source_shape, source_crs, target_crs, (left, bottom, right, top) in source CRS. source_shape : tuple (height, width) of source raster. - source_crs, target_crs : pyproj.CRS + source_crs, target_crs : CRS-like Source and target coordinate reference systems. resolution : float or tuple or None Target resolution. If tuple, (x_res, y_res). @@ -27,13 +59,6 @@ def _compute_output_grid(source_bounds, source_shape, source_crs, target_crs, ------- dict with keys: bounds, shape, res_x, res_y """ - from ._crs_utils import _require_pyproj - - pyproj = _require_pyproj() - transformer = pyproj.Transformer.from_crs( - source_crs, target_crs, always_xy=True - ) - if bounds is None: # Transform source corners and edges to target CRS src_left, src_bottom, src_right, src_top = source_bounds @@ -76,7 +101,7 @@ def _compute_output_grid(source_bounds, source_shape, source_crs, target_crs, ixx, iyy = np.meshgrid(ix, iy) xs = np.concatenate([edge_xs, ixx.ravel()]) ys = np.concatenate([edge_ys, iyy.ravel()]) - tx, ty = transformer.transform(xs, ys) + tx, ty = _transform_boundary(source_crs, target_crs, xs, ys) tx = np.asarray(tx) ty = np.asarray(ty) # Filter out inf/nan from failed transforms @@ -110,7 +135,9 @@ def _compute_output_grid(source_bounds, source_shape, source_crs, target_crs, ix = np.linspace(src_left, src_right, n_dense) iy = np.linspace(src_bottom, src_top, n_dense) ixx, iyy = np.meshgrid(ix, iy) - itx, ity = transformer.transform(ixx.ravel(), iyy.ravel()) + itx, ity = _transform_boundary( + source_crs, target_crs, ixx.ravel(), iyy.ravel() + ) itx = np.asarray(itx) ity = np.asarray(ity) ivalid = np.isfinite(itx) & np.isfinite(ity) @@ -150,13 +177,15 @@ def _compute_output_grid(source_bounds, source_shape, source_crs, target_crs, src_res_y = (src_top - src_bottom) / src_h center_x = (src_left + src_right) / 2 center_y = (src_bottom + src_top) / 2 - tc_x, tc_y = transformer.transform(center_x, center_y) - # Step along x only - tx_x, tx_y = transformer.transform(center_x + src_res_x, center_y) - dx = np.hypot(float(tx_x) - float(tc_x), float(tx_y) - float(tc_y)) - # Step along y only - ty_x, ty_y = transformer.transform(center_x, center_y + src_res_y) - dy = np.hypot(float(ty_x) - float(tc_x), float(ty_y) - float(tc_y)) + # Batch the three resolution-estimation points into one call + pts_x = np.array([center_x, center_x + src_res_x, center_x]) + pts_y = np.array([center_y, center_y, center_y + src_res_y]) + tp_x, tp_y = _transform_boundary(source_crs, target_crs, pts_x, pts_y) + tc_x, tc_y = float(tp_x[0]), float(tp_y[0]) + tx_x, tx_y = float(tp_x[1]), float(tp_y[1]) + ty_x, ty_y = float(tp_x[2]), float(tp_y[2]) + dx = np.hypot(tx_x - tc_x, tx_y - tc_y) + dy = np.hypot(ty_x - tc_x, ty_y - tc_y) if dx == 0 or dy == 0: res_x = (right - left) / src_w res_y = (top - bottom) / src_h diff --git a/xrspatial/reproject/_lite_crs.py b/xrspatial/reproject/_lite_crs.py new file mode 100644 index 00000000..135898ff --- /dev/null +++ b/xrspatial/reproject/_lite_crs.py @@ -0,0 +1,406 @@ +"""Lightweight CRS class with an embedded EPSG lookup table. + +Provides a drop-in subset of ``pyproj.CRS`` for the most common EPSG +codes so that reprojection can work without pyproj installed. +""" +from __future__ import annotations + +import re +from typing import Tuple + +# ------------------------------------------------------------------- +# Ellipsoid definitions: (a, f) +# ------------------------------------------------------------------- +_ELLIPSOIDS = { + "WGS84": (6378137.0, 1.0 / 298.257223563), + "GRS80": (6378137.0, 1.0 / 298.257222101), + "clrk66": (6378206.4, 1.0 / 294.9786982), + "bessel": (6377397.155, 1.0 / 299.1528128), +} + +# ------------------------------------------------------------------- +# Named EPSG table (internal keys prefixed with _ are stripped from +# to_dict output) +# ------------------------------------------------------------------- +_EPSG_TABLE: dict[int, dict] = { + # Geographic --------------------------------------------------- + 4326: { + "proj": "longlat", + "datum": "WGS84", + "ellps": "WGS84", + "_is_geographic": True, + "_name": "WGS 84", + }, + 4269: { + "proj": "longlat", + "datum": "NAD83", + "ellps": "GRS80", + "_is_geographic": True, + "_name": "NAD83", + }, + 4267: { + "proj": "longlat", + "datum": "NAD27", + "ellps": "clrk66", + "_is_geographic": True, + "_name": "NAD27", + }, + # Web Mercator ------------------------------------------------- + 3857: { + "proj": "merc", + "datum": "WGS84", + "ellps": "WGS84", + "lat_ts": 0, + "x_0": 0, + "y_0": 0, + "_is_geographic": False, + "_name": "WGS 84 / Pseudo-Mercator", + }, + # Ellipsoidal Mercator ----------------------------------------- + 3395: { + "proj": "merc", + "datum": "WGS84", + "ellps": "WGS84", + "lat_ts": 0, + "x_0": 0, + "y_0": 0, + "_is_geographic": False, + "_name": "WGS 84 / World Mercator", + }, + # Lambert Conformal Conic -------------------------------------- + 2154: { + "proj": "lcc", + "ellps": "GRS80", + "lon_0": 3, + "lat_0": 46.5, + "lat_1": 49, + "lat_2": 44, + "x_0": 700000, + "y_0": 6600000, + "_is_geographic": False, + "_name": "RGF93 / Lambert-93", + }, + # Albers Equal Area -------------------------------------------- + 5070: { + "proj": "aea", + "datum": "NAD83", + "ellps": "GRS80", + "lon_0": -96, + "lat_0": 23, + "lat_1": 29.5, + "lat_2": 45.5, + "x_0": 0, + "y_0": 0, + "_is_geographic": False, + "_name": "NAD83 / Conus Albers", + }, + # Lambert Azimuthal Equal Area --------------------------------- + 3035: { + "proj": "laea", + "ellps": "GRS80", + "lon_0": 10, + "lat_0": 52, + "x_0": 4321000, + "y_0": 3210000, + "_is_geographic": False, + "_name": "ETRS89-extended / LAEA Europe", + }, + # Polar Stereographic ------------------------------------------ + 3031: { + "proj": "stere", + "datum": "WGS84", + "ellps": "WGS84", + "lon_0": 0, + "lat_0": -90, + "lat_ts": -71, + "x_0": 0, + "y_0": 0, + "_is_geographic": False, + "_name": "WGS 84 / Antarctic Polar Stereographic", + }, + 3413: { + "proj": "stere", + "datum": "WGS84", + "ellps": "WGS84", + "lon_0": -45, + "lat_0": 90, + "lat_ts": 70, + "x_0": 0, + "y_0": 0, + "_is_geographic": False, + "_name": "WGS 84 / NSIDC Sea Ice Polar Stereographic North", + }, + 3996: { + "proj": "stere", + "datum": "WGS84", + "ellps": "WGS84", + "lon_0": 0, + "lat_0": 90, + "lat_ts": 75, + "x_0": 0, + "y_0": 0, + "_is_geographic": False, + "_name": "WGS 84 / IBCAO Polar Stereographic", + }, + # Oblique Stereographic ---------------------------------------- + 28992: { + "proj": "sterea", + "ellps": "bessel", + "lon_0": 5.38763888889, + "lat_0": 52.15616055556, + "k_0": 0.9999079, + "x_0": 155000, + "y_0": 463000, + "_is_geographic": False, + "_name": "Amersfoort / RD New", + }, + # Cylindrical Equal Area --------------------------------------- + 6933: { + "proj": "cea", + "datum": "WGS84", + "ellps": "WGS84", + "lon_0": 0, + "lat_ts": 30, + "x_0": 0, + "y_0": 0, + "_is_geographic": False, + "_name": "WGS 84 / NSIDC EASE-Grid 2.0 Global", + }, +} + + +def _make_utm_entry(zone: int, south: bool, datum: str, ellps: str) -> dict: + """Build an EPSG table entry for a UTM zone.""" + d: dict = { + "proj": "utm", + "zone": zone, + "datum": datum, + "ellps": ellps, + "x_0": 500000, + "y_0": 10000000 if south else 0, + "_is_geographic": False, + } + if south: + d["south"] = True + hemi = "S" if south else "N" + d["_name"] = f"{datum} / UTM zone {zone}{hemi}" + return d + + +def _populate_utm(): + """Add WGS84 and NAD83 UTM zones to the EPSG table.""" + # WGS84 UTM North: 32601-32660 + for zone in range(1, 61): + _EPSG_TABLE[32600 + zone] = _make_utm_entry( + zone, south=False, datum="WGS84", ellps="WGS84" + ) + # WGS84 UTM South: 32701-32760 + for zone in range(1, 61): + _EPSG_TABLE[32700 + zone] = _make_utm_entry( + zone, south=True, datum="WGS84", ellps="WGS84" + ) + # NAD83 UTM: 26901-26923 + for zone in range(1, 24): + _EPSG_TABLE[26900 + zone] = _make_utm_entry( + zone, south=False, datum="NAD83", ellps="GRS80" + ) + + +_populate_utm() + +# WKT projection name mapping +_PROJ_TO_WKT_METHOD = { + "longlat": None, # geographic, no projection method + "merc": "Mercator_1SP", + "utm": "Transverse_Mercator", + "lcc": "Lambert_Conformal_Conic_2SP", + "aea": "Albers_Conic_Equal_Area", + "laea": "Lambert_Azimuthal_Equal_Area", + "stere": "Polar_Stereographic", + "sterea": "Oblique_Stereographic", + "cea": "Cylindrical_Equal_Area", +} + +# Regex to extract AUTHORITY["EPSG","XXXX"] from WKT +_AUTHORITY_RE = re.compile(r'AUTHORITY\["EPSG"\s*,\s*"(\d+)"\]') + + +class CRS: + """Lightweight coordinate reference system, compatible with a subset + of the ``pyproj.CRS`` API. + + Parameters + ---------- + value : int or str + EPSG code (``4326``) or authority string (``"EPSG:4326"``). + """ + + __slots__ = ("_epsg", "_entry") + + def __init__(self, value: int | str): + epsg = self._parse_input(value) + if epsg not in _EPSG_TABLE: + raise ValueError( + f"EPSG:{epsg} is not in the built-in table. " + f"Install pyproj for full CRS support: " + f"pip install pyproj or: pip install xarray-spatial[reproject]" + ) + self._epsg = epsg + self._entry = _EPSG_TABLE[epsg] + + # -- construction helpers ------------------------------------------ + + @staticmethod + def _parse_input(value: int | str) -> int: + if isinstance(value, int): + return value + if isinstance(value, str): + m = re.match(r"^EPSG:(\d+)$", value, re.IGNORECASE) + if m: + return int(m.group(1)) + raise ValueError(f"Cannot parse CRS string: {value!r}") + raise TypeError(f"Expected int or str, got {type(value).__name__}") + + @classmethod + def from_epsg(cls, code: int) -> "CRS": + """Construct from an integer EPSG code.""" + return cls(code) + + @classmethod + def from_wkt(cls, wkt: str) -> "CRS": + """Construct from an OGC WKT string by extracting the AUTHORITY tag. + + Uses the last AUTHORITY["EPSG","..."] match, which in OGC WKT1 + is the outermost (root) authority. + """ + matches = _AUTHORITY_RE.findall(wkt) + if not matches: + raise ValueError("No AUTHORITY[\"EPSG\",...] found in WKT string") + epsg = int(matches[-1]) + return cls(epsg) + + # -- properties ---------------------------------------------------- + + @property + def is_geographic(self) -> bool: + """True if this CRS is geographic (lat/lon), False if projected.""" + return self._entry["_is_geographic"] + + # -- output methods ------------------------------------------------ + + def to_epsg(self) -> int: + """Return the EPSG code as an integer.""" + return self._epsg + + def to_authority(self) -> Tuple[str, str]: + """Return ``("EPSG", "")`` tuple.""" + return ("EPSG", str(self._epsg)) + + def to_dict(self) -> dict: + """Return a proj4-style parameter dictionary. + + Internal keys (prefixed with ``_``) are stripped. + """ + return {k: v for k, v in self._entry.items() if not k.startswith("_")} + + def to_wkt(self) -> str: + """Generate a minimal OGC WKT1 string with AUTHORITY tag.""" + entry = self._entry + name = entry.get("_name", f"EPSG:{self._epsg}") + ellps_name = entry.get("ellps", "WGS84") + a, f = _ELLIPSOIDS.get(ellps_name, _ELLIPSOIDS["WGS84"]) + rf = 1.0 / f if f != 0 else 0.0 + + # Datum name + datum_name = entry.get("datum", name) + + # SPHEROID + spheroid = ( + f'SPHEROID["{ellps_name}",{a},{rf}]' + ) + + if self.is_geographic: + # GEOGCS + wkt = ( + f'GEOGCS["{name}",' + f'DATUM["{datum_name}",{spheroid}],' + f'PRIMEM["Greenwich",0],' + f'UNIT["degree",0.0174532925199433],' + f'AUTHORITY["EPSG","{self._epsg}"]]' + ) + else: + # PROJCS wrapping a GEOGCS + proj = entry["proj"] + method = _PROJ_TO_WKT_METHOD.get(proj, proj) + + # Build GEOGCS for the datum + geogcs = ( + f'GEOGCS["{datum_name}",' + f'DATUM["{datum_name}",{spheroid}],' + f'PRIMEM["Greenwich",0],' + f'UNIT["degree",0.0174532925199433]]' + ) + + # Projection parameters + params = self._wkt_parameters(entry) + param_str = "," + ",".join(params) if params else "" + + wkt = ( + f'PROJCS["{name}",{geogcs},' + f'PROJECTION["{method}"]{param_str},' + f'UNIT["metre",1],' + f'AUTHORITY["EPSG","{self._epsg}"]]' + ) + + return wkt + + @staticmethod + def _wkt_parameters(entry: dict) -> list[str]: + """Build WKT PARAMETER[] entries from a proj dict.""" + # For UTM entries, expand zone into explicit TM parameters so that + # parsers (including pyproj) get the correct central meridian and + # scale factor rather than defaulting to 0 / 1. + if entry.get("proj") == "utm" and "zone" in entry: + zone = entry["zone"] + lon_0 = zone * 6 - 183 + k_0 = 0.9996 + lat_0 = 0 + x_0 = entry.get("x_0", 500000) + y_0 = entry.get("y_0", 0) + return [ + f'PARAMETER["latitude_of_origin",{lat_0}]', + f'PARAMETER["central_meridian",{lon_0}]', + f'PARAMETER["scale_factor",{k_0}]', + f'PARAMETER["false_easting",{x_0}]', + f'PARAMETER["false_northing",{y_0}]', + ] + + # Map from proj keys to WKT parameter names + key_map = { + "lat_0": "latitude_of_origin", + "lon_0": "central_meridian", + "lat_1": "standard_parallel_1", + "lat_2": "standard_parallel_2", + "lat_ts": "latitude_of_true_scale", + "k_0": "scale_factor", + "x_0": "false_easting", + "y_0": "false_northing", + } + params = [] + for proj_key, wkt_name in key_map.items(): + if proj_key in entry and not proj_key.startswith("_"): + params.append(f'PARAMETER["{wkt_name}",{entry[proj_key]}]') + return params + + # -- equality / hashing ------------------------------------------- + + def __eq__(self, other: object) -> bool: + if isinstance(other, CRS): + return self._epsg == other._epsg + return NotImplemented + + def __hash__(self) -> int: + return hash(self._epsg) + + def __repr__(self) -> str: + return f"CRS(epsg={self._epsg})" diff --git a/xrspatial/reproject/_projections.py b/xrspatial/reproject/_projections.py index 4d73a4f9..0b3739b6 100644 --- a/xrspatial/reproject/_projections.py +++ b/xrspatial/reproject/_projections.py @@ -2099,6 +2099,209 @@ def try_numba_transform(src_crs, tgt_crs, chunk_bounds, chunk_shape): return None +def transform_points(src_crs, tgt_crs, xs, ys): + """Transform scatter points from *src_crs* to *tgt_crs* using Numba kernels. + + Parameters + ---------- + src_crs, tgt_crs : CRS-like + Source and target coordinate reference systems (pyproj CRS or lite CRS). + xs, ys : array-like + 1-D arrays of x and y coordinates in *src_crs*. + + Returns + ------- + (tx, ty) : tuple of numpy arrays, or None + Transformed coordinates in *tgt_crs*, or ``None`` if no fast path + exists for this CRS pair. + + Notes + ----- + Intentional omissions (fall back to pyproj for these): + + * No datum-shift wrapping -- metre-level error is sub-pixel for the + boundary-estimation use case this function targets. + * Sinusoidal and Generic Transverse Mercator are not covered here; + those projections are dispatched via ``to_dict()['proj']`` which + requires a full pyproj CRS. + """ + src_epsg = _get_epsg(src_crs) + tgt_epsg = _get_epsg(tgt_crs) + if src_epsg is None and tgt_epsg is None: + return None + + src_is_geo = _is_supported_geographic(src_epsg) + tgt_is_geo = _is_supported_geographic(tgt_epsg) + if not src_is_geo and not tgt_is_geo: + return None + + xs = np.asarray(xs, dtype=np.float64).ravel() + ys = np.asarray(ys, dtype=np.float64).ravel() + n = xs.shape[0] + tx = np.empty(n, dtype=np.float64) + ty = np.empty(n, dtype=np.float64) + + # --- Geographic -> Web Mercator (3857) --- + if src_is_geo and tgt_epsg == 3857: + merc_forward(xs, ys, tx, ty) + return tx, ty + + if src_epsg == 3857 and tgt_is_geo: + merc_inverse(xs, ys, tx, ty) + return tx, ty + + # --- Geographic -> UTM --- + if src_is_geo: + utm = _utm_params(tgt_epsg) + if utm is not None: + lon0, k0, fe, fn = utm + Qn = k0 * _A_RECT + tmerc_forward(xs, ys, tx, ty, + lon0, k0, fe, fn, Qn, _ALPHA, _CBG) + return tx, ty + + # --- UTM -> Geographic --- + utm_src = _utm_params(src_epsg) + if utm_src is not None and tgt_is_geo: + lon0, k0, fe, fn = utm_src + Qn = k0 * _A_RECT + tmerc_inverse(xs, ys, tx, ty, + lon0, k0, fe, fn, Qn, _BETA, _CGB) + return tx, ty + + # --- Geographic -> Ellipsoidal Mercator (3395) --- + if src_is_geo and tgt_epsg == 3395: + emerc_forward(xs, ys, tx, ty, 1.0, _WGS84_E) + return tx, ty + + if src_epsg == 3395 and tgt_is_geo: + emerc_inverse(xs, ys, tx, ty, 1.0, _WGS84_E) + return tx, ty + + # --- Geographic -> LCC --- + if src_is_geo: + params = _lcc_params(tgt_crs) + if params is not None: + lon0, nn, c, rho0, k0, fe, fn, to_m = params + lcc_forward(xs, ys, tx, ty, + lon0, nn, c, rho0, k0, fe, fn, _WGS84_E, _WGS84_A) + if to_m != 1.0: + tx /= to_m + ty /= to_m + return tx, ty + + # --- LCC -> Geographic --- + if tgt_is_geo: + params = _lcc_params(src_crs) + if params is not None: + lon0, nn, c, rho0, k0, fe, fn, to_m = params + # lcc_inverse does NOT take a to_m param; pre-multiply if needed + if to_m != 1.0: + xs = xs * to_m + ys = ys * to_m + lcc_inverse(xs, ys, tx, ty, + lon0, nn, c, rho0, k0, fe, fn, _WGS84_E, _WGS84_A) + return tx, ty + + # --- Geographic -> AEA --- + if src_is_geo: + params = _aea_params(tgt_crs) + if params is not None: + lon0, nn, C, rho0, fe, fn = params + aea_forward(xs, ys, tx, ty, + lon0, nn, C, rho0, fe, fn, + _WGS84_E, _WGS84_A) + return tx, ty + + # --- AEA -> Geographic --- + if tgt_is_geo: + params = _aea_params(src_crs) + if params is not None: + lon0, nn, C, rho0, fe, fn = params + aea_inverse(xs, ys, tx, ty, + lon0, nn, C, rho0, fe, fn, + _WGS84_E, _WGS84_A, _QP, _APA) + return tx, ty + + # --- Geographic -> CEA --- + if src_is_geo: + params = _cea_params(tgt_crs) + if params is not None: + lon0, k0, fe, fn = params + cea_forward(xs, ys, tx, ty, + lon0, k0, fe, fn, + _WGS84_E, _WGS84_A, _QP) + return tx, ty + + # --- CEA -> Geographic --- + if tgt_is_geo: + params = _cea_params(src_crs) + if params is not None: + lon0, k0, fe, fn = params + cea_inverse(xs, ys, tx, ty, + lon0, k0, fe, fn, + _WGS84_E, _WGS84_A, _QP, _APA) + return tx, ty + + # --- Geographic -> LAEA --- + if src_is_geo: + params = _laea_params(tgt_crs) + if params is not None: + lon0, lat0, sinb1, cosb1, dd, xmf, ymf, rq, qp, fe, fn, mode = params + laea_forward(xs, ys, tx, ty, + lon0, sinb1, cosb1, xmf, ymf, rq, qp, + fe, fn, _WGS84_E, _WGS84_A, _WGS84_E2, mode) + return tx, ty + + # --- LAEA -> Geographic --- + if tgt_is_geo: + params = _laea_params(src_crs) + if params is not None: + lon0, lat0, sinb1, cosb1, dd, xmf, ymf, rq, qp, fe, fn, mode = params + laea_inverse(xs, ys, tx, ty, + lon0, sinb1, cosb1, xmf, ymf, rq, qp, + fe, fn, _WGS84_E, _WGS84_A, _WGS84_E2, mode, _APA) + return tx, ty + + # --- Geographic -> Polar Stereographic --- + if src_is_geo: + params = _stere_params(tgt_crs) + if params is not None: + lon0, k0, akm1, fe, fn, is_south = params + stere_forward(xs, ys, tx, ty, + lon0, akm1, fe, fn, _WGS84_E, is_south) + return tx, ty + + # --- Polar Stereographic -> Geographic --- + if tgt_is_geo: + params = _stere_params(src_crs) + if params is not None: + lon0, k0, akm1, fe, fn, is_south = params + stere_inverse(xs, ys, tx, ty, + lon0, akm1, fe, fn, _WGS84_E, is_south) + return tx, ty + + # --- Geographic -> Oblique Stereographic --- + if src_is_geo: + params = _sterea_params(tgt_crs) + if params is not None: + lon0, sinc0, cosc0, R2, C, K, ratexp, fe, fn, e = params + sterea_forward(xs, ys, tx, ty, + lon0, sinc0, cosc0, R2, C, K, ratexp, fe, fn, e) + return tx, ty + + # --- Oblique Stereographic -> Geographic --- + if tgt_is_geo: + params = _sterea_params(src_crs) + if params is not None: + lon0, sinc0, cosc0, R2, C, K, ratexp, fe, fn, e = params + sterea_inverse(xs, ys, tx, ty, + lon0, sinc0, cosc0, R2, C, K, ratexp, fe, fn, e) + return tx, ty + + return None + + # Wrap try_numba_transform with datum shift support _try_numba_transform_inner = try_numba_transform diff --git a/xrspatial/tests/test_lite_crs.py b/xrspatial/tests/test_lite_crs.py new file mode 100644 index 00000000..4c8d1e3d --- /dev/null +++ b/xrspatial/tests/test_lite_crs.py @@ -0,0 +1,379 @@ +"""Tests for the lightweight CRS class (xrspatial.reproject._lite_crs).""" +from __future__ import annotations + +import pytest + +from xrspatial.reproject._lite_crs import CRS + + +# ----------------------------------------------------------------------- +# Construction +# ----------------------------------------------------------------------- +class TestCRSConstruction: + def test_from_epsg_int(self): + crs = CRS(4326) + assert crs.to_epsg() == 4326 + + def test_from_epsg_classmethod(self): + crs = CRS.from_epsg(3857) + assert crs.to_epsg() == 3857 + + def test_from_authority_string(self): + crs = CRS("EPSG:4326") + assert crs.to_epsg() == 4326 + + def test_unknown_epsg_raises(self): + with pytest.raises(ValueError, match="not in the built-in table"): + CRS(99999) + + def test_to_authority(self): + crs = CRS(4326) + assert crs.to_authority() == ("EPSG", "4326") + + def test_is_geographic_true(self): + crs = CRS(4326) + assert crs.is_geographic is True + + def test_is_geographic_false(self): + crs = CRS(3857) + assert crs.is_geographic is False + + +# ----------------------------------------------------------------------- +# Equality & hashing +# ----------------------------------------------------------------------- +class TestCRSEquality: + def test_equal_same_epsg(self): + assert CRS(4326) == CRS("EPSG:4326") + + def test_not_equal_different_epsg(self): + assert CRS(4326) != CRS(3857) + + def test_hash_equal(self): + assert hash(CRS(4326)) == hash(CRS.from_epsg(4326)) + + def test_hash_in_set(self): + s = {CRS(4326), CRS(3857), CRS("EPSG:4326")} + assert len(s) == 2 + + +# ----------------------------------------------------------------------- +# to_dict for every named EPSG code +# ----------------------------------------------------------------------- +class TestCRSToDict: + def test_geographic_4326(self): + d = CRS(4326).to_dict() + assert d["proj"] == "longlat" + assert d["datum"] == "WGS84" + assert not any(k.startswith("_") for k in d) + + def test_utm_north_zone_32(self): + d = CRS(32632).to_dict() + assert d["proj"] == "utm" + assert d["zone"] == 32 + assert d.get("south") is None or d.get("south") is False + + def test_utm_south_zone_55(self): + d = CRS(32755).to_dict() + assert d["proj"] == "utm" + assert d["zone"] == 55 + assert d["south"] is True + + def test_utm_nad83_zone_10(self): + d = CRS(26910).to_dict() + assert d["proj"] == "utm" + assert d["zone"] == 10 + assert d["datum"] == "NAD83" + + def test_lcc_2154(self): + d = CRS(2154).to_dict() + assert d["proj"] == "lcc" + assert d["lon_0"] == 3 + assert d["lat_1"] == 49 + assert d["lat_2"] == 44 + + def test_aea_5070(self): + d = CRS(5070).to_dict() + assert d["proj"] == "aea" + assert d["lon_0"] == -96 + assert d["lat_1"] == 29.5 + assert d["lat_2"] == 45.5 + + def test_web_mercator_3857(self): + d = CRS(3857).to_dict() + assert d["proj"] == "merc" + assert d["datum"] == "WGS84" + + def test_laea_3035(self): + d = CRS(3035).to_dict() + assert d["proj"] == "laea" + assert d["lon_0"] == 10 + assert d["lat_0"] == 52 + + def test_stere_3031(self): + d = CRS(3031).to_dict() + assert d["proj"] == "stere" + assert d["lat_0"] == -90 + assert d["lat_ts"] == -71 + + def test_sterea_28992(self): + d = CRS(28992).to_dict() + assert d["proj"] == "sterea" + assert d["k_0"] == 0.9999079 + + def test_cea_6933(self): + d = CRS(6933).to_dict() + assert d["proj"] == "cea" + assert d["lat_ts"] == 30 + + +# ----------------------------------------------------------------------- +# WKT round-trip +# ----------------------------------------------------------------------- +class TestCRSWktRoundTrip: + def test_roundtrip_geographic(self): + crs = CRS(4326) + wkt = crs.to_wkt() + crs2 = CRS.from_wkt(wkt) + assert crs2.to_epsg() == 4326 + + def test_roundtrip_projected(self): + crs = CRS(32632) + wkt = crs.to_wkt() + crs2 = CRS.from_wkt(wkt) + assert crs2.to_epsg() == 32632 + + def test_roundtrip_all_named(self): + named_codes = [ + 4326, 4269, 4267, + 3857, 3395, + 2154, 5070, 3035, + 3031, 3413, 3996, + 28992, 6933, + ] + for code in named_codes: + crs = CRS(code) + wkt = crs.to_wkt() + crs2 = CRS.from_wkt(wkt) + assert crs2.to_epsg() == code, f"round-trip failed for EPSG:{code}" + + def test_wkt_contains_authority(self): + crs = CRS(4326) + wkt = crs.to_wkt() + assert 'AUTHORITY["EPSG","4326"]' in wkt + + +# ----------------------------------------------------------------------- +# Two-tier CRS resolution (_crs_utils integration) +# ----------------------------------------------------------------------- +try: + import pyproj as _pyproj_mod + _HAS_PYPROJ = True +except ImportError: + _HAS_PYPROJ = False + +from xrspatial.reproject._crs_utils import _resolve_crs, _crs_from_wkt + + +class TestTwoTierResolution: + def test_resolve_crs_int_uses_lite(self): + result = _resolve_crs(4326) + assert isinstance(result, CRS) + assert result.to_epsg() == 4326 + + def test_resolve_crs_string_uses_lite(self): + result = _resolve_crs("EPSG:32632") + assert isinstance(result, CRS) + assert result.to_epsg() == 32632 + + @pytest.mark.skipif(not _HAS_PYPROJ, reason="pyproj not installed") + def test_resolve_crs_unknown_falls_back(self): + result = _resolve_crs(2193) + assert not isinstance(result, CRS) + assert hasattr(result, "to_epsg") + assert result.to_epsg() == 2193 + + def test_resolve_crs_none_returns_none(self): + assert _resolve_crs(None) is None + + def test_resolve_crs_passes_through_lite_crs(self): + lite = CRS(4326) + result = _resolve_crs(lite) + assert result is lite + + @pytest.mark.skipif(not _HAS_PYPROJ, reason="pyproj not installed") + def test_resolve_crs_passes_through_pyproj(self): + pp = _pyproj_mod.CRS.from_epsg(4326) + result = _resolve_crs(pp) + assert result is pp + + def test_crs_from_wkt_lite(self): + wkt = CRS(4326).to_wkt() + result = _crs_from_wkt(wkt) + assert isinstance(result, CRS) + assert result.to_epsg() == 4326 + + +# ----------------------------------------------------------------------- +# Grid computation with lite CRS (no pyproj needed) +# ----------------------------------------------------------------------- +class TestGridWithoutPyproj: + def test_compute_output_grid_with_lite_crs(self): + from xrspatial.reproject._grid import _compute_output_grid + from xrspatial.reproject._lite_crs import CRS + + src_crs = CRS(4326) + tgt_crs = CRS(32632) + source_bounds = (6.0, 47.0, 12.0, 55.0) + source_shape = (64, 64) + grid = _compute_output_grid( + source_bounds, source_shape, src_crs, tgt_crs + ) + assert 'bounds' in grid + assert 'shape' in grid + h, w = grid['shape'] + assert h > 0 and w > 0 + left, bottom, right, top = grid['bounds'] + assert right > left + assert top > bottom + + +# ----------------------------------------------------------------------- +# Integration: CRS resolution works when pyproj is blocked +# ----------------------------------------------------------------------- +class TestNoPyproj: + """Verify CRS resolution works for supported codes when pyproj is absent.""" + + def test_resolve_without_pyproj(self, monkeypatch): + """_resolve_crs and _crs_from_wkt work without pyproj for known EPSG codes.""" + import sys + from xrspatial.reproject._lite_crs import CRS + + # Block pyproj import + monkeypatch.setitem(sys.modules, 'pyproj', None) + + from xrspatial.reproject._crs_utils import _resolve_crs, _crs_from_wkt + + src_crs = _resolve_crs(4326) + assert isinstance(src_crs, CRS) + assert src_crs.to_epsg() == 4326 + + tgt_crs = _resolve_crs("EPSG:32632") + assert isinstance(tgt_crs, CRS) + assert tgt_crs.is_geographic is False + + # _crs_from_wkt round-trips without pyproj + wkt = src_crs.to_wkt() + restored = _crs_from_wkt(wkt) + assert isinstance(restored, CRS) + assert restored.to_epsg() == 4326 + + def test_unknown_epsg_without_pyproj_raises(self, monkeypatch): + """Unknown EPSG codes raise clear error when pyproj is absent.""" + import sys + monkeypatch.setitem(sys.modules, 'pyproj', None) + from xrspatial.reproject._crs_utils import _resolve_crs + with pytest.raises((ImportError, ValueError)): + _resolve_crs(2193) + + +# ----------------------------------------------------------------------- +# Validate against pyproj (skipped when pyproj not installed) +# ----------------------------------------------------------------------- +pyproj = pytest.importorskip("pyproj", reason="pyproj not installed") + +# All named codes + a selection of UTM zones +_VALIDATE_CODES = [ + 4326, 4269, 4267, + 3857, 3395, + 2154, 5070, 3035, + 3031, 3413, 3996, + 28992, 6933, + 32601, 32632, 32660, + 32701, 32755, 32760, + 26901, 26910, 26923, +] + + +# ----------------------------------------------------------------------- +# transform_points +# ----------------------------------------------------------------------- +import numpy as np +from xrspatial.reproject._projections import transform_points + + +class TestTransformPoints: + def test_wgs84_to_web_mercator(self): + xs = np.array([0.0, 10.0, -75.5]) + ys = np.array([0.0, 45.0, 40.0]) + src = pyproj.CRS.from_epsg(4326) + tgt = pyproj.CRS.from_epsg(3857) + result = transform_points(src, tgt, xs, ys) + assert result is not None + tx, ty = result + # lon=0 lat=0 -> (0, 0) in Web Mercator + assert abs(tx[0]) < 1.0 + assert abs(ty[0]) < 1.0 + # lon=10 -> positive easting + assert tx[1] > 1e6 + + def test_wgs84_to_utm_zone32(self): + # Central meridian of UTM zone 32 is 9 deg E + xs = np.array([9.0]) + ys = np.array([0.0]) + src = pyproj.CRS.from_epsg(4326) + tgt = pyproj.CRS.from_epsg(32632) + result = transform_points(src, tgt, xs, ys) + assert result is not None + tx, ty = result + # On the central meridian, easting should be 500000 + assert abs(tx[0] - 500000.0) < 1.0 + + def test_unsupported_pair_returns_none(self): + # Two projected CRS -> no fast path + src = pyproj.CRS.from_epsg(32632) + tgt = pyproj.CRS.from_epsg(32633) + result = transform_points(src, tgt, [500000], [0]) + assert result is None + + def test_matches_pyproj_transformer(self): + # WGS84 -> Albers Equal Area (EPSG:5070), 20 random points + rng = np.random.RandomState(42) + xs = rng.uniform(-120, -70, 20) + ys = rng.uniform(25, 50, 20) + src = pyproj.CRS.from_epsg(4326) + tgt = pyproj.CRS.from_epsg(5070) + result = transform_points(src, tgt, xs, ys) + assert result is not None + tx, ty = result + # Compare against pyproj + transformer = pyproj.Transformer.from_crs(src, tgt, always_xy=True) + ref_x, ref_y = transformer.transform(xs, ys) + # Tolerance is ~1 m because transform_points skips datum shifts + # (metre-level error is sub-pixel for boundary estimation). + np.testing.assert_allclose(tx, ref_x, atol=1.0) + np.testing.assert_allclose(ty, ref_y, atol=1.0) + + +class TestCRSMatchesPyproj: + @pytest.mark.parametrize("epsg", _VALIDATE_CODES) + def test_to_dict_proj_key_matches(self, epsg): + lite = CRS(epsg).to_dict() + ref = pyproj.CRS.from_epsg(epsg).to_dict() + assert lite["proj"] == ref["proj"], ( + f"EPSG:{epsg} proj mismatch: {lite['proj']} != {ref['proj']}" + ) + + @pytest.mark.parametrize("epsg", _VALIDATE_CODES) + def test_is_geographic_matches(self, epsg): + lite_geo = CRS(epsg).is_geographic + ref_geo = pyproj.CRS.from_epsg(epsg).is_geographic + assert lite_geo == ref_geo, ( + f"EPSG:{epsg} is_geographic mismatch: {lite_geo} != {ref_geo}" + ) + + @pytest.mark.parametrize("epsg", _VALIDATE_CODES) + def test_equality_with_pyproj(self, epsg): + a = CRS(epsg) + b = CRS.from_epsg(epsg) + assert a == b diff --git a/xrspatial/tests/test_reproject.py b/xrspatial/tests/test_reproject.py index 303e9365..d883d9ef 100644 --- a/xrspatial/tests/test_reproject.py +++ b/xrspatial/tests/test_reproject.py @@ -1313,3 +1313,22 @@ def test_bounds_overlap(self): assert _bounds_overlap(a, (0, 0, 10, 10)) # identical assert not _bounds_overlap(a, (11, 0, 20, 10)) # no overlap x assert not _bounds_overlap(a, (0, 11, 10, 20)) # no overlap y + + +class TestReprojWithLiteCRS: + def test_reproject_wgs84_to_utm_with_lite_crs(self): + import xarray as xr + from xrspatial.reproject import reproject + import numpy as np + h, w = 32, 32 + y = np.linspace(49, 47, h) + x = np.linspace(8, 10, w) + data = np.random.default_rng(42).random((h, w)) + raster = xr.DataArray( + data, dims=['y', 'x'], + coords={'y': y, 'x': x}, + attrs={'crs': 4326}, + ) + result = reproject(raster, target_crs=32632) + assert result.attrs['crs'] is not None + assert result.shape[0] > 0 and result.shape[1] > 0