Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 60 additions & 6 deletions xrspatial/reproject/_interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,10 +649,48 @@ def _resample_cubic_cuda(src, row_coords, col_coords, out, nodata):
rv3 += sv * wc3
val += rv3 * wr3

if has_nan:
out[i, j] = nodata
else:
if not has_nan:
out[i, j] = val
else:
# Fall back to bilinear with weight renormalization
r1b = r0 + 1
c1b = c0 + 1
dr = r - r0
dc = c - c0

w00 = (1.0 - dr) * (1.0 - dc)
w01 = (1.0 - dr) * dc
w10 = dr * (1.0 - dc)
w11 = dr * dc

accum = 0.0
wsum = 0.0

if 0 <= r0 < sh and 0 <= c0 < sw:
v = src[r0, c0]
if v == v:
accum += w00 * v
wsum += w00
if 0 <= r0 < sh and 0 <= c1b < sw:
v = src[r0, c1b]
if v == v:
accum += w01 * v
wsum += w01
if 0 <= r1b < sh and 0 <= c0 < sw:
v = src[r1b, c0]
if v == v:
accum += w10 * v
wsum += w10
if 0 <= r1b < sh and 0 <= c1b < sw:
v = src[r1b, c1b]
if v == v:
accum += w11 * v
wsum += w11

if wsum > 1e-10:
out[i, j] = accum / wsum
else:
out[i, j] = nodata


# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -722,19 +760,32 @@ def _resample_cupy_native(source_window, src_row_coords, src_col_coords,
work, rc, cc, out, nd
)
if is_integer:
out = cp.round(out).astype(source_window.dtype)
info = cp.iinfo(source_window.dtype)
out = cp.clip(cp.round(out), info.min, info.max).astype(
source_window.dtype
)
return out

if order == 1:
_resample_bilinear_cuda[blocks_per_grid, threads_per_block](
work, rc, cc, out, nd
)
if is_integer:
info = cp.iinfo(source_window.dtype)
out = cp.clip(cp.round(out), info.min, info.max).astype(
source_window.dtype
)
return out

# Cubic
_resample_cubic_cuda[blocks_per_grid, threads_per_block](
work, rc, cc, out, nd
)
if is_integer:
info = cp.iinfo(source_window.dtype)
out = cp.clip(cp.round(out), info.min, info.max).astype(
source_window.dtype
)
return out


Expand Down Expand Up @@ -797,7 +848,10 @@ def _resample_cupy(source_window, src_row_coords, src_col_coords,

result[oob] = nodata

if is_integer and resampling == 'nearest':
result = cp.round(result).astype(source_window.dtype)
if is_integer:
info = cp.iinfo(source_window.dtype)
result = cp.clip(cp.round(result), info.min, info.max).astype(
source_window.dtype
)

return result
150 changes: 150 additions & 0 deletions xrspatial/tests/test_reproject.py
Original file line number Diff line number Diff line change
Expand Up @@ -1009,3 +1009,153 @@ def test_merge_single_tile(self):
)
r = merge([t])
assert np.any(np.isfinite(r.values))


# ---------------------------------------------------------------------------
# CuPy resampler unit tests (integer clipping + cubic NaN fallback)
# ---------------------------------------------------------------------------

@pytest.mark.skipif(not HAS_CUPY, reason="CuPy not installed")
class TestCuPyResamplerClipping:
"""Verify uint8 overflow protection in CuPy resampling paths."""

def _sharp_edge_inputs(self):
"""Build a uint8 source with a sharp 0->255 edge and coordinate grids
that place sample points right at the transition (where cubic ringing
produces out-of-range values)."""
src = np.zeros((16, 16), dtype=np.float64)
src[:, 8:] = 255.0

# Sample at half-pixel offsets across the edge
rows, cols = np.meshgrid(
np.linspace(2, 13, 24), np.linspace(6.5, 9.5, 24), indexing='ij'
)
return src, rows.astype(np.float64), cols.astype(np.float64)

def test_cupy_native_nearest_uint8_clamp(self):
from xrspatial.reproject._interpolate import _resample_cupy_native
src, rows, cols = self._sharp_edge_inputs()
src_gpu = cp.asarray(np.zeros((16, 16), dtype=np.uint8))
src_gpu[:, 8:] = 255
result = _resample_cupy_native(src_gpu, rows, cols,
resampling='nearest', nodata=np.nan)
assert result.dtype == np.uint8
vals = cp.asnumpy(result)
assert np.all((vals == 0) | (vals == 255) | np.isnan(vals.astype(float)))

def test_cupy_native_bilinear_uint8_clamp(self):
from xrspatial.reproject._interpolate import _resample_cupy_native
src_gpu = cp.zeros((16, 16), dtype=np.uint8)
src_gpu[:, 8:] = 255
_, rows, cols = self._sharp_edge_inputs()
result = _resample_cupy_native(src_gpu, rows, cols,
resampling='bilinear', nodata=np.nan)
assert result.dtype == np.uint8
vals = cp.asnumpy(result)
assert np.all(vals <= 255)
assert np.all(vals >= 0)

def test_cupy_native_cubic_uint8_clamp(self):
from xrspatial.reproject._interpolate import _resample_cupy_native
src_gpu = cp.zeros((16, 16), dtype=np.uint8)
src_gpu[:, 8:] = 255
_, rows, cols = self._sharp_edge_inputs()
result = _resample_cupy_native(src_gpu, rows, cols,
resampling='cubic', nodata=np.nan)
assert result.dtype == np.uint8
vals = cp.asnumpy(result)
assert np.all(vals <= 255)
assert np.all(vals >= 0)

def test_cupy_map_coords_bilinear_uint8_clamp(self):
from xrspatial.reproject._interpolate import _resample_cupy
src_gpu = cp.zeros((16, 16), dtype=np.uint8)
src_gpu[:, 8:] = 255
_, rows, cols = self._sharp_edge_inputs()
result = _resample_cupy(src_gpu, rows, cols,
resampling='bilinear', nodata=np.nan)
assert result.dtype == np.uint8
vals = cp.asnumpy(result)
assert np.all(vals <= 255)
assert np.all(vals >= 0)

def test_cupy_map_coords_cubic_uint8_clamp(self):
from xrspatial.reproject._interpolate import _resample_cupy
src_gpu = cp.zeros((16, 16), dtype=np.uint8)
src_gpu[:, 8:] = 255
_, rows, cols = self._sharp_edge_inputs()
result = _resample_cupy(src_gpu, rows, cols,
resampling='cubic', nodata=np.nan)
assert result.dtype == np.uint8
vals = cp.asnumpy(result)
assert np.all(vals <= 255)
assert np.all(vals >= 0)


@pytest.mark.skipif(not HAS_CUPY, reason="CuPy not installed")
class TestCudaCubicNanFallback:
"""Verify _resample_cubic_cuda falls back to bilinear near NaN instead
of writing nodata."""

def test_cubic_nan_fallback_produces_valid_values(self):
"""Cubic with a few NaN neighbors should interpolate from valid
neighbors (bilinear fallback), not produce nodata everywhere."""
from xrspatial.reproject._interpolate import _resample_cupy_native

# 16x16 source with value 100.0, a few NaN pixels scattered
src = np.full((16, 16), 100.0, dtype=np.float64)
src[5, 5] = np.nan
src[10, 10] = np.nan

src_gpu = cp.asarray(src)

# Sample at points near (but not on) NaN pixels
rows = np.array([[5.3, 6.0, 10.3, 8.0]], dtype=np.float64)
cols = np.array([[5.3, 6.0, 10.3, 8.0]], dtype=np.float64)

result = _resample_cupy_native(src_gpu, rows, cols,
resampling='cubic', nodata=np.nan)
vals = cp.asnumpy(result).ravel()

# Points near NaN should get valid interpolated values (bilinear
# fallback), not NaN. Point (6.0, 6.0) and (8.0, 8.0) are far
# enough from any NaN that cubic should succeed directly.
assert np.isfinite(vals[1]), "point far from NaN should be finite"
assert np.isfinite(vals[3]), "point far from NaN should be finite"
# Points adjacent to NaN should also be finite via bilinear fallback
assert np.isfinite(vals[0]), "bilinear fallback should produce finite value near NaN"
assert np.isfinite(vals[2]), "bilinear fallback should produce finite value near NaN"

def test_cubic_nan_fallback_matches_cpu(self):
"""CUDA cubic NaN fallback should produce values close to the CPU
Numba JIT version."""
from xrspatial.reproject._interpolate import (
_resample_cupy_native,
_resample_numpy,
)

src = np.full((16, 16), 50.0, dtype=np.float64)
src[4, 4] = np.nan
src[7, 12] = np.nan

# Sample grid covering the whole raster
rows, cols = np.meshgrid(
np.linspace(1, 14, 12), np.linspace(1, 14, 12), indexing='ij'
)
rows = rows.astype(np.float64)
cols = cols.astype(np.float64)

cpu_result = _resample_numpy(src, rows, cols,
resampling='cubic', nodata=np.nan)
gpu_result = _resample_cupy_native(
cp.asarray(src), rows, cols,
resampling='cubic', nodata=np.nan
)
gpu_np = cp.asnumpy(gpu_result)

# Both should have the same NaN pattern
np.testing.assert_array_equal(np.isnan(cpu_result), np.isnan(gpu_np))
# Finite values should match closely
finite = np.isfinite(cpu_result)
np.testing.assert_allclose(cpu_result[finite], gpu_np[finite],
rtol=1e-10)
Loading