Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 4 additions & 8 deletions xrspatial/geotiff/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def _read_geo_info(source, *, overview_level: int | None = None):
overview_level : int or None
Overview IFD index (0 = full resolution).
"""
from ._dtypes import tiff_dtype_to_numpy
from ._dtypes import resolve_bits_per_sample, tiff_dtype_to_numpy
from ._geotags import extract_geo_info
from ._header import parse_all_ifds, parse_header, select_overview_ifd
from ._reader import _coerce_path, _is_file_like
Expand Down Expand Up @@ -230,9 +230,7 @@ def _read_geo_info(source, *, overview_level: int | None = None):
raise ValueError("No IFDs found in TIFF file")
ifd = select_overview_ifd(ifds, overview_level)
geo_info = extract_geo_info(ifd, data, header.byte_order)
bps = ifd.bits_per_sample
if isinstance(bps, tuple):
bps = bps[0]
bps = resolve_bits_per_sample(ifd.bits_per_sample)
file_dtype = tiff_dtype_to_numpy(bps, ifd.sample_format)
n_bands = ifd.samples_per_pixel if ifd.samples_per_pixel > 1 else 0
return geo_info, ifd.height, ifd.width, file_dtype, n_bands
Expand Down Expand Up @@ -1446,7 +1444,7 @@ def read_geotiff_gpu(source: str, *,
from ._header import (
parse_header, parse_all_ifds, select_overview_ifd, validate_tile_layout,
)
from ._dtypes import tiff_dtype_to_numpy
from ._dtypes import resolve_bits_per_sample, tiff_dtype_to_numpy
from ._geotags import extract_geo_info
from ._gpu_decode import gpu_decode_tiles

Expand All @@ -1469,9 +1467,7 @@ def read_geotiff_gpu(source: str, *,
# Skip mask IFDs (NewSubfileType bit 2)
ifd = select_overview_ifd(ifds, overview_level)

bps = ifd.bits_per_sample
if isinstance(bps, tuple):
bps = bps[0]
bps = resolve_bits_per_sample(ifd.bits_per_sample)
file_dtype = tiff_dtype_to_numpy(bps, ifd.sample_format)
geo_info = extract_geo_info(ifd, data, header.byte_order)

Expand Down
74 changes: 74 additions & 0 deletions xrspatial/geotiff/_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,80 @@ def tiff_dtype_to_numpy(bits_per_sample: int, sample_format: int = 1) -> np.dtyp
SUB_BYTE_BPS = {1, 2, 4, 12}


_GDAL_OT_FOR_BPS = {
8: 'Byte',
16: 'UInt16',
32: 'UInt32',
64: 'Float64',
}


def _suggest_gdal_ot(bps_values, sample_format=None) -> str:
"""Pick a sensible ``gdal_translate -ot`` value for a mixed-bps file.

Returns a real GDAL type name (``Byte``, ``UInt16`` etc.) when the
widest band has a recognised mapping, or ``<desired_type>`` as a
placeholder otherwise. ``sample_format`` (TIFF SampleFormat: 1=uint,
2=int, 3=float) refines the integer choice when known.
"""
if not bps_values:
return '<desired_type>'
widest = max(bps_values)
if sample_format == 3 and widest in (32, 64):
return 'Float32' if widest == 32 else 'Float64'
if sample_format == 2 and widest in (8, 16, 32):
return {8: 'Int8', 16: 'Int16', 32: 'Int32'}[widest]
return _GDAL_OT_FOR_BPS.get(widest, '<desired_type>')


def resolve_bits_per_sample(bps, sample_format=None) -> int:
"""Resolve a TIFF ``BitsPerSample`` tag value to a single integer.

The TIFF spec allows ``BitsPerSample`` to be either a scalar or a
sequence with one entry per sample. xarray-spatial decodes a whole
IFD with one numpy dtype, so the per-sample widths must agree.

Parameters
----------
bps : int or sequence of int
Raw value from ``IFD.bits_per_sample``. Accepts ``int``, ``tuple``,
or ``list``.
sample_format : int, optional
TIFF SampleFormat (1=uint, 2=int, 3=float). Used only to make the
``gdal_translate`` hint in the error message more accurate when
the entries don't agree; not consulted when they do.

Comment on lines +154 to +163
Returns
-------
int
The shared bits-per-sample value.

Raises
------
ValueError
If ``bps`` is a sequence whose entries are not all equal. Files
with per-band bit depths (e.g. RGB+8-bit-alpha with
``(16, 16, 16, 8)``) are not supported; convert with GDAL or
rasterio first.
"""
if isinstance(bps, (tuple, list)):
if len(bps) == 0:
raise ValueError("BitsPerSample tuple is empty")
first = bps[0]
for v in bps[1:]:
if v != first:
ot = _suggest_gdal_ot(bps, sample_format)
raise ValueError(
f"Mixed BitsPerSample per band is not supported: "
f"{tuple(bps)}. xarray-spatial decodes all bands with "
f"a single dtype. Convert the file to a uniform bit "
f"depth first, e.g. "
f"`gdal_translate -ot {ot} in.tif out.tif`."
)
Comment on lines +184 to +190
return int(first)
return int(bps)


def numpy_to_tiff_dtype(dt: np.dtype) -> tuple[int, int]:
"""Convert a numpy dtype to (bits_per_sample, sample_format).

Expand Down
5 changes: 2 additions & 3 deletions xrspatial/geotiff/_geotags.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
TAG_MODEL_TRANSFORMATION,
TAG_GEO_KEY_DIRECTORY, TAG_GEO_DOUBLE_PARAMS, TAG_GEO_ASCII_PARAMS,
)
from ._dtypes import resolve_bits_per_sample

# ImageDescription tag (270). Captured for round-trip but not managed
# by the writer -- it flows through extra_tags pass-through.
Expand Down Expand Up @@ -514,9 +515,7 @@ def extract_geo_info(ifd: IFD, data: bytes | memoryview,
if ifd.photometric == 3:
raw_cmap = ifd.colormap
if raw_cmap is not None:
bps_val = ifd.bits_per_sample
if isinstance(bps_val, tuple):
bps_val = bps_val[0]
bps_val = resolve_bits_per_sample(ifd.bits_per_sample)
n_colors = 1 << bps_val # 2^BitsPerSample
# TIFF ColorMap: 3 * n_colors uint16 values
# Layout: [R0..R_{n-1}, G0..G_{n-1}, B0..B_{n-1}]
Expand Down
18 changes: 5 additions & 13 deletions xrspatial/geotiff/_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
predictor_decode,
unpack_bits,
)
from ._dtypes import SUB_BYTE_BPS, tiff_dtype_to_numpy
from ._dtypes import SUB_BYTE_BPS, resolve_bits_per_sample, tiff_dtype_to_numpy
from ._geotags import GeoInfo, GeoTransform, extract_geo_info
from ._header import (
IFD,
Expand Down Expand Up @@ -663,9 +663,7 @@ def _read_strips(data: bytes, ifd: IFD, header: TIFFHeader,
offsets = ifd.strip_offsets
byte_counts = ifd.strip_byte_counts
pred = ifd.predictor
bps = ifd.bits_per_sample
if isinstance(bps, tuple):
bps = bps[0]
bps = resolve_bits_per_sample(ifd.bits_per_sample)
bytes_per_sample = bps // 8
is_sub_byte = bps in SUB_BYTE_BPS
jpeg_tables = ifd.jpeg_tables
Expand Down Expand Up @@ -802,9 +800,7 @@ def _read_tiles(data: bytes, ifd: IFD, header: TIFFHeader,
samples = ifd.samples_per_pixel
compression = ifd.compression
pred = ifd.predictor
bps = ifd.bits_per_sample
if isinstance(bps, tuple):
bps = bps[0]
bps = resolve_bits_per_sample(ifd.bits_per_sample)
bytes_per_sample = bps // 8
is_sub_byte = bps in SUB_BYTE_BPS
jpeg_tables = ifd.jpeg_tables
Expand Down Expand Up @@ -994,9 +990,7 @@ def _read_cog_http(url: str, overview_level: int | None = None,
# Select IFD based on overview level, skipping any mask IFDs
ifd = select_overview_ifd(ifds, overview_level)

bps = ifd.bits_per_sample
if isinstance(bps, tuple):
bps = bps[0]
bps = resolve_bits_per_sample(ifd.bits_per_sample)
dtype = tiff_dtype_to_numpy(bps, ifd.sample_format)
geo_info = extract_geo_info(ifd, header_bytes, header.byte_order)

Expand Down Expand Up @@ -1211,9 +1205,7 @@ def read_to_array(source, *, window=None, overview_level: int | None = None,
# Select IFD, skipping any mask IFDs
ifd = select_overview_ifd(ifds, overview_level)

bps = ifd.bits_per_sample
if isinstance(bps, tuple):
bps = bps[0]
bps = resolve_bits_per_sample(ifd.bits_per_sample)
dtype = tiff_dtype_to_numpy(bps, ifd.sample_format)
geo_info = extract_geo_info(ifd, data, header.byte_order)

Expand Down
5 changes: 2 additions & 3 deletions xrspatial/geotiff/_vrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,7 @@ def write_vrt(vrt_path: str, source_files: list[str], *,
from ._header import parse_header, parse_all_ifds
from ._geotags import extract_geo_info
from ._reader import _FileSource
from ._dtypes import resolve_bits_per_sample

if not source_files:
raise ValueError("source_files must not be empty")
Expand All @@ -409,9 +410,7 @@ def write_vrt(vrt_path: str, source_files: list[str], *,
geo = extract_geo_info(ifd, data, header.byte_order)
src.close()

bps = ifd.bits_per_sample
if isinstance(bps, tuple):
bps = bps[0]
bps = resolve_bits_per_sample(ifd.bits_per_sample)

sources_meta.append({
'path': src_path,
Expand Down
Loading
Loading