Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 39 additions & 7 deletions xrspatial/geotiff/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,7 +592,8 @@ def to_geotiff(data: xr.DataArray | np.ndarray, path: str, *,
overview_resampling: str = 'mean',
bigtiff: bool | None = None,
gpu: bool | None = None,
streaming_buffer_bytes: int = 256 * 1024 * 1024) -> None:
streaming_buffer_bytes: int = 256 * 1024 * 1024,
max_z_error: float = 0.0) -> None:
"""Write data as a GeoTIFF or Cloud Optimized GeoTIFF.

Dask-backed DataArrays are written in streaming mode: one tile-row
Expand Down Expand Up @@ -658,6 +659,13 @@ def to_geotiff(data: xr.DataArray | np.ndarray, path: str, *,
streaming a dask-backed DataArray. Defaults to 256 MB. Wide
rasters whose tile-row exceeds this budget are split into
horizontal segments. Ignored for numpy / CuPy / COG paths.
max_z_error : float
Per-pixel error budget for LERC compression. ``0.0`` (default)
is lossless; larger values let the encoder approximate values
within the bound, producing smaller files at the cost of accuracy
bounded by ``abs(decoded - original) <= max_z_error``. Only used
when ``compression='lerc'``; passing a non-zero value with any
other codec raises ``ValueError``, as does a negative value or a
non-zero value on the GPU write path.
"""
# Up-front validation: catch bad compression names before they reach
# any of the deeper write paths (streaming, GPU, VRT, COG) where the
Expand All @@ -668,6 +676,18 @@ def to_geotiff(data: xr.DataArray | np.ndarray, path: str, *,
f"Unknown compression {compression!r}. "
f"Valid options: {list(_VALID_COMPRESSIONS)}.")

# max_z_error only applies to LERC; reject negative values, and reject
# non-zero values paired with any other codec, so the caller learns the
# parameter would otherwise have been ignored before any bytes hit disk.
if max_z_error < 0:
raise ValueError(
f"max_z_error must be >= 0, got {max_z_error}")
if max_z_error != 0 and (
not isinstance(compression, str)
or compression.lower() != 'lerc'):
raise ValueError(
"max_z_error is only valid with compression='lerc'")

# tile_size only applies to tiled output; warn if the caller passed a
# non-default size alongside strip mode (it would otherwise be silently
# ignored).
Expand Down Expand Up @@ -696,12 +716,20 @@ def to_geotiff(data: xr.DataArray | np.ndarray, path: str, *,
compression_level=compression_level,
tile_size=tile_size,
predictor=predictor,
bigtiff=bigtiff)
bigtiff=bigtiff,
max_z_error=max_z_error)
return

# Auto-detect GPU data and dispatch to write_geotiff_gpu
use_gpu = gpu if gpu is not None else _is_gpu_data(data)
if use_gpu:
# GPU writer uses nvCOMP and does not support LERC; refuse rather
# than silently dropping the requested error budget.
if max_z_error != 0:
raise ValueError(
"max_z_error is not supported on the GPU writer "
"(nvCOMP has no LERC backend). Use the CPU path "
"(gpu=False) or omit max_z_error.")
try:
write_geotiff_gpu(data, path, crs=crs, nodata=nodata,
compression=compression,
Expand Down Expand Up @@ -824,6 +852,7 @@ def to_geotiff(data: xr.DataArray | np.ndarray, path: str, *,
extra_tags=extra_tags_list,
bigtiff=bigtiff,
streaming_buffer_bytes=streaming_buffer_bytes,
max_z_error=max_z_error,
)
return

Expand Down Expand Up @@ -893,12 +922,14 @@ def to_geotiff(data: xr.DataArray | np.ndarray, path: str, *,
gdal_metadata_xml=gdal_meta_xml,
extra_tags=extra_tags_list,
bigtiff=bigtiff,
max_z_error=max_z_error,
)


def _write_single_tile(chunk_data, path, geo_transform, epsg, wkt,
nodata, compression, compression_level,
tile_size, predictor, bigtiff):
tile_size, predictor, bigtiff,
max_z_error: float = 0.0):
"""Write a single tile GeoTIFF. Used by _write_vrt_tiled."""
if hasattr(chunk_data, 'compute'):
chunk_data = chunk_data.compute()
Expand Down Expand Up @@ -930,13 +961,14 @@ def _write_single_tile(chunk_data, path, geo_transform, epsg, wkt,
tile_size=tile_size,
predictor=predictor,
compression_level=compression_level,
bigtiff=bigtiff)
bigtiff=bigtiff,
max_z_error=max_z_error)


def _write_vrt_tiled(data, vrt_path, *, crs=None, nodata=None,
compression='zstd', compression_level=None,
tile_size=256, predictor: bool | int = False,
bigtiff=None):
bigtiff=None, max_z_error: float = 0.0):
"""Write a DataArray as a directory of tiled GeoTIFFs with a VRT index.

This enables streaming dask arrays to disk without materializing the
Expand Down Expand Up @@ -1075,7 +1107,7 @@ def _write_vrt_tiled(data, vrt_path, *, crs=None, nodata=None,
task = dask.delayed(_write_single_tile)(
chunk_data, tile_path, tile_gt, epsg, wkt_fallback,
nodata, compression, compression_level,
tile_size, predictor, bigtiff)
tile_size, predictor, bigtiff, max_z_error)
delayed_tasks.append(task)
else:
# Numpy: slice and write directly
Expand All @@ -1084,7 +1116,7 @@ def _write_vrt_tiled(data, vrt_path, *, crs=None, nodata=None,
_write_single_tile(
chunk_data, tile_path, tile_gt, epsg, wkt_fallback,
nodata, compression, compression_level,
tile_size, predictor, bigtiff)
tile_size, predictor, bigtiff, max_z_error)

col_offset += chunk_w
row_offset += chunk_h
Expand Down
46 changes: 29 additions & 17 deletions xrspatial/geotiff/_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,8 @@ def _build_ifd(tags: list[tuple], overflow_base: int,

def _write_stripped(data: np.ndarray, compression: int, predictor: int,
rows_per_strip: int = 256,
compression_level: int | None = None) -> tuple[list, list, list]:
compression_level: int | None = None,
max_z_error: float = 0.0) -> tuple[list, list, list]:
"""Compress data as strips.

Returns
Expand Down Expand Up @@ -401,7 +402,8 @@ def _write_stripped(data: np.ndarray, compression: int, predictor: int,
elif compression == COMPRESSION_LERC:
from ._compression import lerc_compress
compressed = lerc_compress(
strip_data, width, strip_rows, samples=samples, dtype=dtype)
strip_data, width, strip_rows, samples=samples, dtype=dtype,
max_z_error=max_z_error)
elif compression_level is None:
compressed = compress(strip_data, compression)
else:
Expand All @@ -421,7 +423,7 @@ def _write_stripped(data: np.ndarray, compression: int, predictor: int,

def _prepare_tile(data, tr, tc, th, tw, height, width, samples, dtype,
bytes_per_sample, predictor: int, compression,
compression_level=None):
compression_level=None, max_z_error: float = 0.0):
"""Extract, pad, and compress a single tile. Thread-safe."""
r0 = tr * th
c0 = tc * tw
Expand Down Expand Up @@ -464,15 +466,17 @@ def _prepare_tile(data, tr, tc, th, tw, height, width, samples, dtype,
if compression == COMPRESSION_LERC:
from ._compression import lerc_compress
return lerc_compress(
tile_data, tw, th, samples=samples, dtype=dtype)
tile_data, tw, th, samples=samples, dtype=dtype,
max_z_error=max_z_error)
if compression_level is None:
return compress(tile_data, compression)
return compress(tile_data, compression, level=compression_level)


def _write_tiled(data: np.ndarray, compression: int, predictor: int,
tile_size: int = 256,
compression_level: int | None = None) -> tuple[list, list, list]:
compression_level: int | None = None,
max_z_error: float = 0.0) -> tuple[list, list, list]:
"""Compress data as tiles, using parallel compression.

For compressed formats (deflate, lzw, zstd), tiles are compressed
Expand Down Expand Up @@ -545,7 +549,7 @@ def _write_tiled(data: np.ndarray, compression: int, predictor: int,
compressed = _prepare_tile(
data, tr, tc, th, tw, height, width,
samples, dtype, bytes_per_sample, predictor, compression,
compression_level,
compression_level, max_z_error,
)
rel_offsets.append(current_offset)
byte_counts.append(len(compressed))
Expand All @@ -566,7 +570,7 @@ def _write_tiled(data: np.ndarray, compression: int, predictor: int,
pool.submit(
_prepare_tile, data, tr, tc, th, tw, height, width,
samples, dtype, bytes_per_sample, predictor, compression,
compression_level,
compression_level, max_z_error,
)
for tr, tc in tile_indices
]
Expand Down Expand Up @@ -976,7 +980,8 @@ def write(data: np.ndarray, path: str, *,
resolution_unit: int | None = None,
gdal_metadata_xml: str | None = None,
extra_tags: list | None = None,
bigtiff: bool | None = None) -> None:
bigtiff: bool | None = None,
max_z_error: float = 0.0) -> None:
"""Write a numpy array as a GeoTIFF or COG.

Parameters
Expand Down Expand Up @@ -1027,10 +1032,12 @@ def write(data: np.ndarray, path: str, *,
# Full resolution
if tiled:
rel_off, bc, comp_data = _write_tiled(data, comp_tag, pred_int, tile_size,
compression_level=compression_level)
compression_level=compression_level,
max_z_error=max_z_error)
else:
rel_off, bc, comp_data = _write_stripped(data, comp_tag, pred_int,
compression_level=compression_level)
compression_level=compression_level,
max_z_error=max_z_error)

h, w = data.shape[:2]
parts.append((data, w, h, rel_off, bc, comp_data))
Expand All @@ -1057,10 +1064,12 @@ def write(data: np.ndarray, path: str, *,
if tiled:
o_off, o_bc, o_data = _write_tiled(current, comp_tag, pred_int,
tile_size,
compression_level=compression_level)
compression_level=compression_level,
max_z_error=max_z_error)
else:
o_off, o_bc, o_data = _write_stripped(current, comp_tag, pred_int,
compression_level=compression_level)
compression_level=compression_level,
max_z_error=max_z_error)
parts.append((current, ow, oh, o_off, o_bc, o_data))

file_bytes = _assemble_tiff(
Expand All @@ -1086,7 +1095,8 @@ def write(data: np.ndarray, path: str, *,


def _compress_block(arr, block_w, block_h, samples, dtype, bytes_per_sample,
predictor: int, compression, compression_level=None):
predictor: int, compression, compression_level=None,
max_z_error: float = 0.0):
"""Compress a tile or strip. *arr* must be contiguous and correctly sized."""
if compression == COMPRESSION_JPEG:
return jpeg_compress(arr.tobytes(), block_w, block_h, samples)
Expand All @@ -1106,7 +1116,8 @@ def _compress_block(arr, block_w, block_h, samples, dtype, bytes_per_sample,
if compression == COMPRESSION_LERC:
from ._compression import lerc_compress
return lerc_compress(raw_data, block_w, block_h,
samples=samples, dtype=dtype)
samples=samples, dtype=dtype,
max_z_error=max_z_error)
if compression_level is None:
return compress(raw_data, compression)
return compress(raw_data, compression, level=compression_level)
Expand All @@ -1133,7 +1144,8 @@ def write_streaming(dask_data, path: str, *,
gdal_metadata_xml: str | None = None,
extra_tags: list | None = None,
bigtiff: bool | None = None,
streaming_buffer_bytes: int = 256 * 1024 * 1024) -> None:
streaming_buffer_bytes: int = 256 * 1024 * 1024,
max_z_error: float = 0.0) -> None:
"""Write a dask array as a GeoTIFF by streaming pixel data.

For tiled output, each tile-row is computed in horizontal segments
Expand Down Expand Up @@ -1411,7 +1423,7 @@ def write_streaming(dask_data, path: str, *,
compressed = _compress_block(
tile_arr, tw, th, samples, out_dtype,
bytes_per_sample, pred_int, comp_tag,
compression_level)
compression_level, max_z_error)

actual_offsets.append(current_offset)
actual_counts.append(len(compressed))
Expand Down Expand Up @@ -1448,7 +1460,7 @@ def write_streaming(dask_data, path: str, *,
np.ascontiguousarray(strip_np),
width, strip_rows, samples, out_dtype,
bytes_per_sample, pred_int, comp_tag,
compression_level)
compression_level, max_z_error)

actual_offsets.append(current_offset)
actual_counts.append(len(compressed))
Expand Down
Loading
Loading