diff --git a/xrspatial/reproject/__init__.py b/xrspatial/reproject/__init__.py index c35fd89a..fe14d74c 100644 --- a/xrspatial/reproject/__init__.py +++ b/xrspatial/reproject/__init__.py @@ -574,10 +574,10 @@ def reproject( else: is_cupy = is_cupy_array(data) - # For very large datasets, estimate whether a dask graph would fit - # in memory. Each dask task uses ~1KB of graph metadata. If the - # graph itself would exceed available memory, use a streaming - # approach instead of dask (process tiles sequentially, no graph). + # For large in-memory datasets, wrap in dask for chunked processing. + # map_blocks generates an O(1) HighLevelGraph (single blockwise layer) + # so graph metadata is no longer a concern -- the streaming fallback + # is only needed when dask itself is unavailable. _use_streaming = False if not is_dask and not is_cupy: nbytes = src_shape[0] * src_shape[1] * data.dtype.itemsize @@ -585,19 +585,10 @@ def reproject( nbytes *= data.shape[2] _OOM_THRESHOLD = 512 * 1024 * 1024 # 512 MB if nbytes > _OOM_THRESHOLD: - # Estimate graph size for the output cs = chunk_size or 2048 if isinstance(cs, int): cs = (cs, cs) - n_out_chunks = (math.ceil(out_shape[0] / cs[0]) - * math.ceil(out_shape[1] / cs[1])) - graph_bytes = n_out_chunks * 1024 # ~1KB per task - - if graph_bytes > 1024 * 1024 * 1024: # > 1GB graph - # Graph too large for dask -- use streaming - _use_streaming = True - else: - # Graph fits -- use dask with large chunks + try: import dask.array as _da data = _da.from_array(data, chunks=cs) raster = xr.DataArray( @@ -605,6 +596,9 @@ def reproject( name=raster.name, attrs=raster.attrs, ) is_dask = True + except ImportError: + # dask not available -- fall back to streaming + _use_streaming = True # Serialize CRS for pickle safety src_wkt = src_crs.to_wkt() @@ -1125,6 +1119,77 @@ def _reproject_dask_cupy( return result +def _source_footprint_in_target(src_bounds, src_wkt, tgt_wkt): + """Compute an approximate bounding box of the source raster in target 
CRS.
+
+    Transforms corners and edge midpoints (12 points) to handle non-linear
+    projections. Returns ``(left, bottom, right, top)`` in target CRS, or
+    *None* if the transform fails (e.g. out-of-domain).
+    """
+    try:
+        from ._crs_utils import _require_pyproj
+        pyproj = _require_pyproj()
+        src_crs = pyproj.CRS(src_wkt)
+        tgt_crs = pyproj.CRS(tgt_wkt)
+        transformer = pyproj.Transformer.from_crs(
+            src_crs, tgt_crs, always_xy=True
+        )
+    except Exception:
+        return None
+
+    sl, sb, sr, st = src_bounds
+    mx = (sl + sr) / 2
+    my = (sb + st) / 2
+    xs = [sl, mx, sr, sl, mx, sr, sl, mx, sr, sl, sr, mx]
+    ys = [sb, sb, sb, my, my, my, st, st, st, my, my, sb]
+    try:
+        tx, ty = transformer.transform(xs, ys)
+        tx = [v for v in tx if np.isfinite(v)]
+        ty = [v for v in ty if np.isfinite(v)]
+        if not tx or not ty:
+            return None
+        return (min(tx), min(ty), max(tx), max(ty))
+    except Exception:
+        return None
+
+
+def _bounds_overlap(a, b):
+    """Return True if bounding boxes *a* and *b* overlap."""
+    return a[0] < b[2] and a[2] > b[0] and a[1] < b[3] and a[3] > b[1]
+
+
+def _reproject_block_adapter(
+    block, block_info, source_data,
+    src_bounds, src_shape, y_desc,
+    src_wkt, tgt_wkt,
+    out_bounds, out_shape,
+    resampling, nodata, precision,
+    is_cupy, src_footprint_tgt,
+):
+    """``map_blocks`` adapter for reprojection.
+
+    Derives chunk bounds from *block_info* and delegates to the
+    per-chunk worker.
+ """ + info = block_info[0] + (row_start, row_end), (col_start, col_end) = info['array-location'] + chunk_shape = (row_end - row_start, col_end - col_start) + cb = _chunk_bounds(out_bounds, out_shape, + row_start, row_end, col_start, col_end) + + # Skip chunks that don't overlap the source footprint + if src_footprint_tgt is not None and not _bounds_overlap(cb, src_footprint_tgt): + return np.full(chunk_shape, nodata, dtype=np.float64) + + chunk_fn = _reproject_chunk_cupy if is_cupy else _reproject_chunk_numpy + return chunk_fn( + source_data, src_bounds, src_shape, y_desc, + src_wkt, tgt_wkt, + cb, chunk_shape, + resampling, nodata, precision, + ) + + def _reproject_dask( raster, src_bounds, src_shape, y_desc, src_wkt, tgt_wkt, @@ -1132,44 +1197,58 @@ def _reproject_dask( resampling, nodata, precision, chunk_size, is_cupy, ): - """Dask+NumPy backend: build output as ``da.block`` of delayed chunks.""" - import dask + """Dask+NumPy backend: ``map_blocks`` over a template array. + + Uses a single ``blockwise`` layer in the HighLevelGraph instead of + O(N) ``dask.delayed`` nodes, keeping graph metadata O(1). + + The source dask array is bound to the adapter via ``functools.partial`` + rather than passed as a ``map_blocks`` kwarg. This prevents dask from + adding the full source as a dependency of every output block (which + would cause a MemoryError on distributed schedulers when the source + exceeds the worker memory limit). 
+ """ + import functools + import dask.array as da row_chunks, col_chunks = _compute_chunk_layout(out_shape, chunk_size) - n_row = len(row_chunks) - n_col = len(col_chunks) - chunk_fn = _reproject_chunk_cupy if is_cupy else _reproject_chunk_numpy - dtype = np.float64 + # Precompute source footprint in target CRS for empty-chunk skipping + src_footprint_tgt = _source_footprint_in_target( + src_bounds, src_wkt, tgt_wkt + ) - blocks = [[None] * n_col for _ in range(n_row)] + # Bind the source dask array and all scalar params via partial so + # map_blocks doesn't detect them as dask Array kwargs (which would + # add the full source as a dependency of every output block). + bound_adapter = functools.partial( + _reproject_block_adapter, + source_data=raster.data, + src_bounds=src_bounds, + src_shape=src_shape, + y_desc=y_desc, + src_wkt=src_wkt, + tgt_wkt=tgt_wkt, + out_bounds=out_bounds, + out_shape=out_shape, + resampling=resampling, + nodata=nodata, + precision=precision, + is_cupy=is_cupy, + src_footprint_tgt=src_footprint_tgt, + ) - row_offset = 0 - for i in range(n_row): - col_offset = 0 - for j in range(n_col): - rchunk = row_chunks[i] - cchunk = col_chunks[j] - cb = _chunk_bounds( - out_bounds, out_shape, - row_offset, row_offset + rchunk, - col_offset, col_offset + cchunk, - ) - delayed_chunk = dask.delayed(chunk_fn)( - raster.data, - src_bounds, src_shape, y_desc, - src_wkt, tgt_wkt, - cb, (rchunk, cchunk), - resampling, nodata, precision, - ) - blocks[i][j] = da.from_delayed( - delayed_chunk, shape=(rchunk, cchunk), dtype=dtype - ) - col_offset += cchunk - row_offset += rchunk + template = da.empty( + out_shape, dtype=np.float64, chunks=(row_chunks, col_chunks) + ) - return da.block(blocks) + return da.map_blocks( + bound_adapter, + template, + dtype=np.float64, + meta=np.array((), dtype=np.float64), + ) # --------------------------------------------------------------------------- @@ -1434,23 +1513,38 @@ def _merge_inmemory( return _merge_arrays_numpy(arrays, 
nodata, strategy) -def _merge_chunk_worker( +def _merge_block_adapter( + block, block_info, raster_data_list, src_bounds_list, src_shape_list, y_desc_list, src_wkt_list, tgt_wkt, - chunk_bounds_tuple, chunk_shape, + out_bounds, out_shape, resampling, nodata, strategy, precision, + src_footprints_tgt, ): - """Worker for a single merge chunk.""" + """``map_blocks`` adapter for merge.""" + info = block_info[0] + (row_start, row_end), (col_start, col_end) = info['array-location'] + chunk_shape = (row_end - row_start, col_end - col_start) + cb = _chunk_bounds(out_bounds, out_shape, + row_start, row_end, col_start, col_end) + + # Only reproject rasters whose footprint overlaps this chunk arrays = [] for i in range(len(raster_data_list)): + if (src_footprints_tgt[i] is not None + and not _bounds_overlap(cb, src_footprints_tgt[i])): + continue reprojected = _reproject_chunk_numpy( raster_data_list[i], src_bounds_list[i], src_shape_list[i], y_desc_list[i], src_wkt_list[i], tgt_wkt, - chunk_bounds_tuple, chunk_shape, + cb, chunk_shape, resampling, nodata, precision, ) arrays.append(reprojected) + + if not arrays: + return np.full(chunk_shape, nodata, dtype=np.float64) return _merge_arrays_numpy(arrays, nodata, strategy) @@ -1458,13 +1552,12 @@ def _merge_dask( raster_infos, tgt_wkt, out_bounds, out_shape, resampling, nodata, strategy, chunk_size, ): - """Dask merge backend.""" - import dask + """Dask merge backend using ``map_blocks``.""" + import functools + import dask.array as da row_chunks, col_chunks = _compute_chunk_layout(out_shape, chunk_size) - n_row = len(row_chunks) - n_col = len(col_chunks) # Prepare lists for the worker data_list = [info['raster'].data for info in raster_infos] @@ -1473,30 +1566,38 @@ def _merge_dask( ydesc_list = [info['y_desc'] for info in raster_infos] wkt_list = [info['src_wkt'] for info in raster_infos] - dtype = np.float64 - blocks = [[None] * n_col for _ in range(n_row)] + # Precompute source footprints in target CRS + footprints = [ + 
_source_footprint_in_target(bounds_list[i], wkt_list[i], tgt_wkt) + for i in range(len(raster_infos)) + ] + + # Bind via partial to prevent map_blocks from adding dask arrays + # in data_list as whole-array dependencies. + bound_adapter = functools.partial( + _merge_block_adapter, + raster_data_list=data_list, + src_bounds_list=bounds_list, + src_shape_list=shape_list, + y_desc_list=ydesc_list, + src_wkt_list=wkt_list, + tgt_wkt=tgt_wkt, + out_bounds=out_bounds, + out_shape=out_shape, + resampling=resampling, + nodata=nodata, + strategy=strategy, + precision=16, + src_footprints_tgt=footprints, + ) - row_offset = 0 - for i in range(n_row): - col_offset = 0 - for j in range(n_col): - rchunk = row_chunks[i] - cchunk = col_chunks[j] - cb = _chunk_bounds( - out_bounds, out_shape, - row_offset, row_offset + rchunk, - col_offset, col_offset + cchunk, - ) - delayed_chunk = dask.delayed(_merge_chunk_worker)( - data_list, bounds_list, shape_list, ydesc_list, - wkt_list, tgt_wkt, - cb, (rchunk, cchunk), - resampling, nodata, strategy, 16, - ) - blocks[i][j] = da.from_delayed( - delayed_chunk, shape=(rchunk, cchunk), dtype=dtype - ) - col_offset += cchunk - row_offset += rchunk + template = da.empty( + out_shape, dtype=np.float64, chunks=(row_chunks, col_chunks) + ) - return da.block(blocks) + return da.map_blocks( + bound_adapter, + template, + dtype=np.float64, + meta=np.array((), dtype=np.float64), + ) diff --git a/xrspatial/tests/test_reproject.py b/xrspatial/tests/test_reproject.py index 75e2b258..303e9365 100644 --- a/xrspatial/tests/test_reproject.py +++ b/xrspatial/tests/test_reproject.py @@ -1159,3 +1159,157 @@ def test_cubic_nan_fallback_matches_cpu(self): finite = np.isfinite(cpu_result) np.testing.assert_allclose(cpu_result[finite], gpu_np[finite], rtol=1e-10) + + +# --------------------------------------------------------------------------- +# Dask graph optimization tests +# --------------------------------------------------------------------------- + 
+@pytest.mark.skipif(not HAS_DASK, reason="dask not installed") +class TestDaskGraphOptimization: + """Verify map_blocks conversion and empty-chunk skipping.""" + + def test_dask_reproject_uses_map_blocks(self): + """The dask path should produce a blockwise layer, not N delayed nodes.""" + from xrspatial.reproject import reproject + data = np.ones((64, 64), dtype=np.float64) + da_data = da.from_array(data, chunks=(32, 32)) + raster = xr.DataArray( + da_data, dims=['y', 'x'], + coords={'y': np.linspace(55, 45, 64), 'x': np.linspace(-5, 5, 64)}, + attrs={'crs': 'EPSG:4326', 'nodata': np.nan}, + ) + result = reproject(raster, 'EPSG:32633', chunk_size=32) + # Result should be a dask array + assert hasattr(result.data, 'dask') + # Should have few graph layers (map_blocks creates 1-2, not N) + graph = result.data.__dask_graph__() + assert len(graph.layers) <= 3 + + def test_source_not_whole_array_dependency(self): + """Source dask array should not be a dependency of every output block. + + When source_data is passed as a map_blocks kwarg, dask adds the + full source as a dependency of every output block -- this causes + MemoryError on distributed schedulers when the source exceeds + worker memory. Using functools.partial avoids this. + """ + from xrspatial.reproject import reproject + data = np.ones((64, 64), dtype=np.float64) + da_data = da.from_array(data, chunks=(32, 32)) + src_name = da_data.name # e.g. 'array-abc123' + raster = xr.DataArray( + da_data, dims=['y', 'x'], + coords={'y': np.linspace(55, 45, 64), 'x': np.linspace(-5, 5, 64)}, + attrs={'crs': 'EPSG:4326', 'nodata': np.nan}, + ) + result = reproject(raster, 'EPSG:32633', chunk_size=32) + graph = result.data.__dask_graph__() + # The source array's layer should NOT be in the output graph's + # dependencies (it's captured in the function closure instead). 
+ assert src_name not in graph.layers, ( + f"source array '{src_name}' should not be a graph layer " + f"dependency -- use functools.partial to bind it" + ) + + def test_dask_reproject_matches_numpy(self): + """Dask map_blocks path should produce same values as numpy.""" + from xrspatial.reproject import reproject + data = np.random.RandomState(42).rand(64, 64).astype(np.float64) + coords = { + 'y': np.linspace(55, 45, 64), + 'x': np.linspace(-5, 5, 64), + } + attrs = {'crs': 'EPSG:4326', 'nodata': np.nan} + + np_raster = xr.DataArray(data, dims=['y', 'x'], + coords=coords, attrs=attrs) + da_raster = xr.DataArray( + da.from_array(data, chunks=(32, 32)), + dims=['y', 'x'], coords=coords, attrs=attrs, + ) + np_result = reproject(np_raster, 'EPSG:32633') + da_result = reproject(da_raster, 'EPSG:32633') + + np_vals = np_result.values + da_vals = da_result.values + # Same shape + assert np_vals.shape == da_vals.shape + # Same NaN pattern + np.testing.assert_array_equal(np.isnan(np_vals), np.isnan(da_vals)) + # Same finite values + finite = np.isfinite(np_vals) + if finite.any(): + np.testing.assert_allclose(np_vals[finite], da_vals[finite], + rtol=1e-10) + + def test_empty_chunk_skipping(self): + """Chunks outside the source footprint should be nodata-filled + without touching pyproj.""" + import dask + + from xrspatial.reproject import reproject + # Small raster in a corner of the output grid + data = np.ones((16, 16), dtype=np.float64) * 42.0 + raster = xr.DataArray( + da.from_array(data, chunks=(16, 16)), + dims=['y', 'x'], + coords={'y': np.linspace(50.1, 50.0, 16), + 'x': np.linspace(10.0, 10.1, 16)}, + attrs={'crs': 'EPSG:4326', 'nodata': np.nan}, + ) + # Force a large output grid with small chunks so many are empty. + # Use synchronous scheduler to avoid PROJ C library thread-safety + # crashes on macOS when many chunks call pyproj.CRS concurrently. 
+ with dask.config.set(scheduler='synchronous'): + result = reproject(raster, 'EPSG:32633', chunk_size=64, + width=256, height=256) + vals = result.values + # Should have some valid pixels and some NaN (empty chunks) + assert np.any(np.isfinite(vals)) + assert np.any(np.isnan(vals)) + + def test_merge_dask_uses_map_blocks(self): + """The merge dask path should also use map_blocks.""" + from xrspatial.reproject import merge + t1 = xr.DataArray( + da.from_array(np.full((32, 32), 1.0), chunks=(16, 16)), + dims=['y', 'x'], + coords={'y': np.linspace(55, 50, 32), + 'x': np.linspace(-5, 0, 32)}, + attrs={'crs': 'EPSG:4326', 'nodata': np.nan}, + ) + t2 = xr.DataArray( + da.from_array(np.full((32, 32), 2.0), chunks=(16, 16)), + dims=['y', 'x'], + coords={'y': np.linspace(50, 45, 32), + 'x': np.linspace(0, 5, 32)}, + attrs={'crs': 'EPSG:4326', 'nodata': np.nan}, + ) + result = merge([t1, t2]) + vals = result.values + assert np.any(np.isfinite(vals)) + + def test_source_footprint_helper(self): + """_source_footprint_in_target should return a valid bbox.""" + from xrspatial.reproject import _source_footprint_in_target + src_bounds = (-5.0, 45.0, 5.0, 55.0) + fp = _source_footprint_in_target( + src_bounds, 'EPSG:4326', 'EPSG:32633' + ) + # Should return a tuple of 4 finite values + assert fp is not None + assert len(fp) == 4 + assert all(np.isfinite(v) for v in fp) + # left < right, bottom < top + assert fp[0] < fp[2] + assert fp[1] < fp[3] + + def test_bounds_overlap(self): + """_bounds_overlap should correctly detect overlap.""" + from xrspatial.reproject import _bounds_overlap + a = (0, 0, 10, 10) + assert _bounds_overlap(a, (5, 5, 15, 15)) # partial overlap + assert _bounds_overlap(a, (0, 0, 10, 10)) # identical + assert not _bounds_overlap(a, (11, 0, 20, 10)) # no overlap x + assert not _bounds_overlap(a, (0, 11, 10, 20)) # no overlap y