Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
257 changes: 179 additions & 78 deletions xrspatial/reproject/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,37 +574,31 @@ def reproject(
else:
is_cupy = is_cupy_array(data)

# For very large datasets, estimate whether a dask graph would fit
# in memory. Each dask task uses ~1KB of graph metadata. If the
# graph itself would exceed available memory, use a streaming
# approach instead of dask (process tiles sequentially, no graph).
# For large in-memory datasets, wrap in dask for chunked processing.
# map_blocks generates an O(1) HighLevelGraph (single blockwise layer)
# so graph metadata is no longer a concern -- the streaming fallback
# is only needed when dask itself is unavailable.
_use_streaming = False
if not is_dask and not is_cupy:
nbytes = src_shape[0] * src_shape[1] * data.dtype.itemsize
if data.ndim == 3:
nbytes *= data.shape[2]
_OOM_THRESHOLD = 512 * 1024 * 1024 # 512 MB
if nbytes > _OOM_THRESHOLD:
# Estimate graph size for the output
cs = chunk_size or 2048
if isinstance(cs, int):
cs = (cs, cs)
n_out_chunks = (math.ceil(out_shape[0] / cs[0])
* math.ceil(out_shape[1] / cs[1]))
graph_bytes = n_out_chunks * 1024 # ~1KB per task

if graph_bytes > 1024 * 1024 * 1024: # > 1GB graph
# Graph too large for dask -- use streaming
_use_streaming = True
else:
# Graph fits -- use dask with large chunks
try:
import dask.array as _da
data = _da.from_array(data, chunks=cs)
raster = xr.DataArray(
data, dims=raster.dims, coords=raster.coords,
name=raster.name, attrs=raster.attrs,
)
is_dask = True
except ImportError:
# dask not available -- fall back to streaming
_use_streaming = True

# Serialize CRS for pickle safety
src_wkt = src_crs.to_wkt()
Expand Down Expand Up @@ -1125,51 +1119,136 @@ def _reproject_dask_cupy(
return result


def _source_footprint_in_target(src_bounds, src_wkt, tgt_wkt):
"""Compute an approximate bounding box of the source raster in target CRS.

Transforms corners and edge midpoints (12 points) to handle non-linear
projections. Returns ``(left, bottom, right, top)`` in target CRS, or
*None* if the transform fails (e.g. out-of-domain).
"""
try:
from ._crs_utils import _require_pyproj
pyproj = _require_pyproj()
src_crs = pyproj.CRS(src_wkt)
tgt_crs = pyproj.CRS(tgt_wkt)
transformer = pyproj.Transformer.from_crs(
src_crs, tgt_crs, always_xy=True
)
except Exception:
return None

sl, sb, sr, st = src_bounds
mx = (sl + sr) / 2
my = (sb + st) / 2
xs = [sl, mx, sr, sl, mx, sr, sl, mx, sr, sl, sr, mx]
ys = [sb, sb, sb, my, my, my, st, st, st, mx, mx, sb]
try:
tx, ty = transformer.transform(xs, ys)
tx = [v for v in tx if np.isfinite(v)]
ty = [v for v in ty if np.isfinite(v)]
if not tx or not ty:
return None
return (min(tx), min(ty), max(tx), max(ty))
except Exception:
return None


def _bounds_overlap(a, b):
"""Return True if bounding boxes *a* and *b* overlap."""
return a[0] < b[2] and a[2] > b[0] and a[1] < b[3] and a[3] > b[1]


def _reproject_block_adapter(
    block, block_info, source_data,
    src_bounds, src_shape, y_desc,
    src_wkt, tgt_wkt,
    out_bounds, out_shape,
    resampling, nodata, precision,
    is_cupy, src_footprint_tgt,
):
    """``map_blocks`` adapter for reprojection.

    Reads this block's position from *block_info*, converts it to
    geographic bounds, and delegates to the per-chunk worker. Blocks
    entirely outside the source footprint are filled with *nodata*
    without invoking the worker.
    """
    (r0, r1), (c0, c1) = block_info[0]['array-location']
    shape = (r1 - r0, c1 - c0)
    bounds = _chunk_bounds(out_bounds, out_shape, r0, r1, c0, c1)

    # Nothing from the source can land in this block -- emit nodata.
    if src_footprint_tgt is not None:
        if not _bounds_overlap(bounds, src_footprint_tgt):
            return np.full(shape, nodata, dtype=np.float64)

    if is_cupy:
        worker = _reproject_chunk_cupy
    else:
        worker = _reproject_chunk_numpy
    return worker(
        source_data, src_bounds, src_shape, y_desc,
        src_wkt, tgt_wkt,
        bounds, shape,
        resampling, nodata, precision,
    )


def _reproject_dask(
    raster, src_bounds, src_shape, y_desc,
    src_wkt, tgt_wkt,
    out_bounds, out_shape,
    resampling, nodata, precision,
    chunk_size, is_cupy,
):
    """Dask+NumPy backend: ``map_blocks`` over a template array.

    Uses a single ``blockwise`` layer in the HighLevelGraph instead of
    O(N) ``dask.delayed`` nodes, keeping graph metadata O(1).

    The source dask array is bound to the adapter via ``functools.partial``
    rather than passed as a ``map_blocks`` kwarg. This prevents dask from
    adding the full source as a dependency of every output block (which
    would cause a MemoryError on distributed schedulers when the source
    exceeds the worker memory limit).
    """
    import functools

    import dask.array as da

    row_chunks, col_chunks = _compute_chunk_layout(out_shape, chunk_size)

    # Precompute source footprint in target CRS for empty-chunk skipping
    src_footprint_tgt = _source_footprint_in_target(
        src_bounds, src_wkt, tgt_wkt
    )

    # Bind the source dask array and all scalar params via partial so
    # map_blocks doesn't detect them as dask Array kwargs (which would
    # add the full source as a dependency of every output block).
    bound_adapter = functools.partial(
        _reproject_block_adapter,
        source_data=raster.data,
        src_bounds=src_bounds,
        src_shape=src_shape,
        y_desc=y_desc,
        src_wkt=src_wkt,
        tgt_wkt=tgt_wkt,
        out_bounds=out_bounds,
        out_shape=out_shape,
        resampling=resampling,
        nodata=nodata,
        precision=precision,
        is_cupy=is_cupy,
        src_footprint_tgt=src_footprint_tgt,
    )

    # Lazy template: only its shape/chunks/dtype metadata is used -- the
    # adapter overwrites every block, so the empty values are never read.
    template = da.empty(
        out_shape, dtype=np.float64, chunks=(row_chunks, col_chunks)
    )

    return da.map_blocks(
        bound_adapter,
        template,
        dtype=np.float64,
        meta=np.array((), dtype=np.float64),
    )


# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -1434,37 +1513,51 @@ def _merge_inmemory(
return _merge_arrays_numpy(arrays, nodata, strategy)


def _merge_block_adapter(
    block, block_info,
    raster_data_list, src_bounds_list, src_shape_list, y_desc_list,
    src_wkt_list, tgt_wkt,
    out_bounds, out_shape,
    resampling, nodata, strategy, precision,
    src_footprints_tgt,
):
    """``map_blocks`` adapter for merge.

    Derives this block's bounds from *block_info*, reprojects every input
    raster whose footprint overlaps the block, and merges the results
    with *strategy*. Returns an all-*nodata* block when no raster
    overlaps it.
    """
    info = block_info[0]
    (row_start, row_end), (col_start, col_end) = info['array-location']
    chunk_shape = (row_end - row_start, col_end - col_start)
    cb = _chunk_bounds(out_bounds, out_shape,
                       row_start, row_end, col_start, col_end)

    # Only reproject rasters whose footprint overlaps this chunk
    arrays = []
    for i in range(len(raster_data_list)):
        if (src_footprints_tgt[i] is not None
                and not _bounds_overlap(cb, src_footprints_tgt[i])):
            continue
        reprojected = _reproject_chunk_numpy(
            raster_data_list[i],
            src_bounds_list[i], src_shape_list[i], y_desc_list[i],
            src_wkt_list[i], tgt_wkt,
            cb, chunk_shape,
            resampling, nodata, precision,
        )
        arrays.append(reprojected)

    if not arrays:
        return np.full(chunk_shape, nodata, dtype=np.float64)
    return _merge_arrays_numpy(arrays, nodata, strategy)


def _merge_dask(
raster_infos, tgt_wkt, out_bounds, out_shape,
resampling, nodata, strategy, chunk_size,
):
"""Dask merge backend."""
import dask
"""Dask merge backend using ``map_blocks``."""
import functools

import dask.array as da

row_chunks, col_chunks = _compute_chunk_layout(out_shape, chunk_size)
n_row = len(row_chunks)
n_col = len(col_chunks)

# Prepare lists for the worker
data_list = [info['raster'].data for info in raster_infos]
Expand All @@ -1473,30 +1566,38 @@ def _merge_dask(
ydesc_list = [info['y_desc'] for info in raster_infos]
wkt_list = [info['src_wkt'] for info in raster_infos]

dtype = np.float64
blocks = [[None] * n_col for _ in range(n_row)]
# Precompute source footprints in target CRS
footprints = [
_source_footprint_in_target(bounds_list[i], wkt_list[i], tgt_wkt)
for i in range(len(raster_infos))
]

# Bind via partial to prevent map_blocks from adding dask arrays
# in data_list as whole-array dependencies.
bound_adapter = functools.partial(
_merge_block_adapter,
raster_data_list=data_list,
src_bounds_list=bounds_list,
src_shape_list=shape_list,
y_desc_list=ydesc_list,
src_wkt_list=wkt_list,
tgt_wkt=tgt_wkt,
out_bounds=out_bounds,
out_shape=out_shape,
resampling=resampling,
nodata=nodata,
strategy=strategy,
precision=16,
src_footprints_tgt=footprints,
)

row_offset = 0
for i in range(n_row):
col_offset = 0
for j in range(n_col):
rchunk = row_chunks[i]
cchunk = col_chunks[j]
cb = _chunk_bounds(
out_bounds, out_shape,
row_offset, row_offset + rchunk,
col_offset, col_offset + cchunk,
)
delayed_chunk = dask.delayed(_merge_chunk_worker)(
data_list, bounds_list, shape_list, ydesc_list,
wkt_list, tgt_wkt,
cb, (rchunk, cchunk),
resampling, nodata, strategy, 16,
)
blocks[i][j] = da.from_delayed(
delayed_chunk, shape=(rchunk, cchunk), dtype=dtype
)
col_offset += cchunk
row_offset += rchunk
template = da.empty(
out_shape, dtype=np.float64, chunks=(row_chunks, col_chunks)
)

return da.block(blocks)
return da.map_blocks(
bound_adapter,
template,
dtype=np.float64,
meta=np.array((), dtype=np.float64),
)
Loading
Loading