diff --git a/xrspatial/geotiff/_compression.py b/xrspatial/geotiff/_compression.py index 13fc8b88..a9558b5a 100644 --- a/xrspatial/geotiff/_compression.py +++ b/xrspatial/geotiff/_compression.py @@ -866,15 +866,59 @@ def packbits_compress(data: bytes) -> bytes: pass +def _splice_jpeg_tables(tile_data: bytes, + jpeg_tables: bytes | None) -> bytes: + """Splice a JPEGTables stream into a tile's JPEG fragment. + + GDAL-style tiled JPEG TIFFs store DQT/DHT tables once in tag 347 + (an abbreviated JPEG: SOI + tables + EOI) and each tile is a JPEG + fragment whose own DQT/DHT segments were stripped. To make a tile + self-contained, drop the tables stream's leading SOI and trailing + EOI and insert what remains after the tile's SOI marker. + + Both buffers must start with SOI (FF D8). If either does not, the + tile data is returned unchanged so libjpeg sees its original input + and raises a meaningful error. + """ + if not jpeg_tables: + return tile_data + if len(tile_data) < 2 or tile_data[0] != 0xFF or tile_data[1] != 0xD8: + return tile_data + if len(jpeg_tables) < 4: + return tile_data + if jpeg_tables[0] != 0xFF or jpeg_tables[1] != 0xD8: + return tile_data + # Strip SOI from the tables stream, and EOI if present at the end. + tables_body = jpeg_tables[2:] + if len(tables_body) >= 2 and tables_body[-2] == 0xFF and tables_body[-1] == 0xD9: + tables_body = tables_body[:-2] + return tile_data[:2] + tables_body + tile_data[2:] + + def jpeg_decompress(data: bytes, width: int = 0, height: int = 0, - samples: int = 1) -> bytes: - """Decompress JPEG tile/strip data. Requires Pillow.""" + samples: int = 1, jpeg_tables: bytes | None = None) -> bytes: + """Decompress JPEG tile/strip data. Requires Pillow. + + Parameters + ---------- + data : bytes + Raw JPEG bytes from one TIFF strip or tile. May be a fragment + when ``jpeg_tables`` is supplied (GDAL tiled JPEG). + jpeg_tables : bytes, optional + Contents of TIFF tag 347 (JPEGTables). If supplied, the shared + DQT/DHT segments are spliced into ``data`` before decoding so + the resulting stream is a complete JPEG. + """ if not JPEG_AVAILABLE: raise ImportError( "Pillow is required to read JPEG-compressed TIFFs. " "Install it with: pip install Pillow") import io + if jpeg_tables: + data = _splice_jpeg_tables(data, jpeg_tables) img = Image.open(io.BytesIO(data)) + # libjpeg already converts YCbCr->RGB during decode, so rely on the + # mode Pillow returns. Calling .convert() unnecessarily would copy. return np.asarray(img).tobytes() @@ -1089,7 +1133,8 @@ def lz4_compress(data: bytes, level: int = 0) -> bytes: def decompress(data, compression: int, expected_size: int = 0, - width: int = 0, height: int = 0, samples: int = 1) -> np.ndarray: + width: int = 0, height: int = 0, samples: int = 1, + jpeg_tables: bytes | None = None) -> np.ndarray: """Decompress tile/strip data based on TIFF compression tag. Parameters @@ -1116,8 +1161,10 @@ def decompress(data, compression: int, expected_size: int = 0, elif compression == COMPRESSION_PACKBITS: return np.frombuffer(packbits_decompress(data), dtype=np.uint8) elif compression == COMPRESSION_JPEG: - return np.frombuffer(jpeg_decompress(data, width, height, samples), - dtype=np.uint8) + return np.frombuffer( + jpeg_decompress(data, width, height, samples, + jpeg_tables=jpeg_tables), + dtype=np.uint8) elif compression == COMPRESSION_ZSTD: return np.frombuffer(zstd_decompress(data), dtype=np.uint8) elif compression == COMPRESSION_JPEG2000: diff --git a/xrspatial/geotiff/_header.py b/xrspatial/geotiff/_header.py index 5f6f6a1b..290b2407 100644 --- a/xrspatial/geotiff/_header.py +++ b/xrspatial/geotiff/_header.py @@ -38,6 +38,7 @@ TAG_COLORMAP = 320 TAG_EXTRA_SAMPLES = 338 TAG_SAMPLE_FORMAT = 339 +TAG_JPEG_TABLES = 347 TAG_GDAL_METADATA = 42112 TAG_GDAL_NODATA = 42113 @@ -166,6 +167,34 @@ def photometric(self) -> int: def planar_config(self) -> int: return self.get_value(TAG_PLANAR_CONFIG, 1) + @property + def jpeg_tables(self) -> bytes | None: + """JPEGTables tag (347): shared DQT/DHT segments for tiled JPEG. + + GDAL-tiled ``compress=JPEG`` TIFFs store the quantization and + Huffman tables once in this tag; each tile's payload is a JPEG + fragment that needs the tables spliced in before libjpeg can + decode it. Returns the raw bytes of the abbreviated JPEG stream + (SOI ... DQT/DHT ... EOI), or None if absent. + """ + v = self.get_value(TAG_JPEG_TABLES) + if v is None: + return None + if isinstance(v, (bytes, bytearray)): + return bytes(v) + # BYTE arrays may surface as a tuple/list of ints + if isinstance(v, (tuple, list)): + return bytes(v) + # A single-byte tag value comes back as an int; wrap it in a + # one-element bytes object. Plain ``bytes(v)`` would (incorrectly) + # allocate v zero bytes -- a malformed file with a huge int here + # could otherwise blow up memory. + if isinstance(v, int): + return bytes([v & 0xFF]) + raise TypeError( + f"unexpected JPEGTables tag value type: {type(v).__name__}" + ) + @property def x_resolution(self) -> float | None: """XResolution tag (282), or None if absent.""" diff --git a/xrspatial/geotiff/_reader.py b/xrspatial/geotiff/_reader.py index e7fec46e..1bedfd61 100644 --- a/xrspatial/geotiff/_reader.py +++ b/xrspatial/geotiff/_reader.py @@ -520,7 +520,7 @@ def _packed_byte_count(pixel_count: int, bps: int) -> int: def _decode_strip_or_tile(data_slice, compression, width, height, samples, bps, bytes_per_sample, is_sub_byte, dtype, pred, - byte_order='<'): + byte_order='<', jpeg_tables=None): """Decompress, apply predictor, unpack sub-byte, and reshape a strip/tile. Parameters @@ -529,6 +529,12 @@ def _decode_strip_or_tile(data_slice, compression, width, height, samples, '<' for little-endian, '>' for big-endian. When the file byte order differs from the system's native order, pixel data is byte-swapped after decompression. + jpeg_tables : bytes or None + Raw bytes of the file's JPEGTables tag (347), or None if the file + doesn't have one. GDAL-style tiled JPEG TIFFs store DQT/DHT tables + once in this tag and each tile is a JPEG fragment that depends on + them; the JPEG decoder splices the tables in before handing the + tile to libjpeg. Ignored for non-JPEG compressions. Returns an array shaped (height, width) or (height, width, samples). """ @@ -539,7 +545,8 @@ def _decode_strip_or_tile(data_slice, compression, width, height, samples, expected = pixel_count * bytes_per_sample chunk = decompress(data_slice, compression, expected, - width=width, height=height, samples=samples) + width=width, height=height, samples=samples, + jpeg_tables=jpeg_tables) # Validate the decompressed byte count. A truncated deflate stream or a # buggy compressor can produce fewer or more bytes than expected. Without @@ -654,6 +661,7 @@ def _read_strips(data: bytes, ifd: IFD, header: TIFFHeader, bps = bps[0] bytes_per_sample = bps // 8 is_sub_byte = bps in SUB_BYTE_BPS + jpeg_tables = ifd.jpeg_tables if offsets is None or byte_counts is None: raise ValueError("Missing strip offsets or byte counts") @@ -713,7 +721,8 @@ def _read_strips(data: bytes, ifd: IFD, header: TIFFHeader, strip_pixels = _decode_strip_or_tile( strip_data, compression, width, strip_rows, 1, bps, bytes_per_sample, is_sub_byte, dtype, pred, - byte_order=header.byte_order) + byte_order=header.byte_order, + jpeg_tables=jpeg_tables) src_r0 = max(r0 - strip_row, 0) src_r1 = min(r1 - strip_row, strip_rows) @@ -738,7 +747,8 @@ def _read_strips(data: bytes, ifd: IFD, header: TIFFHeader, strip_pixels = _decode_strip_or_tile( strip_data, compression, width, strip_rows, samples, bps, bytes_per_sample, is_sub_byte, dtype, pred, - byte_order=header.byte_order) + byte_order=header.byte_order, + jpeg_tables=jpeg_tables) src_r0 = max(r0 - strip_row, 0) src_r1 = min(r1 - strip_row, strip_rows) @@ -790,6 +800,7 @@ def _read_tiles(data: bytes, ifd: IFD, header: TIFFHeader, bps = bps[0] bytes_per_sample = bps // 8 is_sub_byte = bps in SUB_BYTE_BPS + jpeg_tables = ifd.jpeg_tables offsets = ifd.tile_offsets byte_counts = ifd.tile_byte_counts @@ -885,7 +896,8 @@ def _decode_one(job): return _decode_strip_or_tile( tile_data, compression, tw, th, tile_samples, bps, bytes_per_sample, is_sub_byte, dtype, pred, - byte_order=header.byte_order) + byte_order=header.byte_order, + jpeg_tables=jpeg_tables) if use_parallel: from concurrent.futures import ThreadPoolExecutor @@ -1001,6 +1013,7 @@ def _read_cog_http(url: str, overview_level: int | None = None, pred = ifd.predictor bytes_per_sample = bps // 8 is_sub_byte = bps in SUB_BYTE_BPS + jpeg_tables = ifd.jpeg_tables offsets = ifd.tile_offsets byte_counts = ifd.tile_byte_counts @@ -1067,7 +1080,8 @@ def _read_cog_http(url: str, overview_level: int | None = None, tile_pixels = _decode_strip_or_tile( tile_data, compression, tw, th, samples, bps, bytes_per_sample, is_sub_byte, dtype, pred, - byte_order=header.byte_order) + byte_order=header.byte_order, + jpeg_tables=jpeg_tables) y0 = tr * th x0 = tc * tw diff --git a/xrspatial/geotiff/tests/test_jpeg.py b/xrspatial/geotiff/tests/test_jpeg.py index b6a3f08d..66309ecf 100644 --- a/xrspatial/geotiff/tests/test_jpeg.py +++ b/xrspatial/geotiff/tests/test_jpeg.py @@ -1,12 +1,15 @@ """Tests for JPEG compression support (issue #1050).""" from __future__ import annotations +import importlib.util + import numpy as np import pytest import xarray as xr from xrspatial.geotiff._compression import ( COMPRESSION_JPEG, + _splice_jpeg_tables, jpeg_compress, jpeg_decompress, ) @@ -154,3 +157,138 @@ def test_to_geotiff_jpeg_rejected(self, tmp_path): path = str(tmp_path / 'api_1050.tif') with pytest.raises(ValueError, match="JPEGTables"): to_geotiff(da, path, compression='jpeg', tile_size=16) + + +class TestJpegTablesSplice: + """Verify the JPEGTables splice helper used for tiled JPEG TIFFs.""" + + def test_splice_reconstructs_complete_jpeg(self): + # Build a complete JPEG, then split it into a tables stream + a + # tile fragment. Splicing should recover a decodable stream. + from PIL import Image + import io + + rng = np.random.RandomState(1502) + arr = rng.randint(50, 200, (16, 16, 3), dtype=np.uint8) + img = Image.fromarray(arr, mode='RGB') + buf = io.BytesIO() + img.save(buf, format='JPEG', quality=85) + full = buf.getvalue() + + # Find the SOS marker (FF DA): everything before is tables. + sos = full.index(b'\xff\xda') + tables = b'\xff\xd8' + full[2:sos] + b'\xff\xd9' + tile_fragment = b'\xff\xd8' + full[sos:] + + spliced = _splice_jpeg_tables(tile_fragment, tables) + decoded = Image.open(io.BytesIO(spliced)) + decoded.load() + assert decoded.size == (16, 16) + + def test_splice_passthrough_on_empty_tables(self): + payload = b'\xff\xd8\xff\xd9' + assert _splice_jpeg_tables(payload, b'') == payload + assert _splice_jpeg_tables(payload, None) == payload + + def test_splice_passthrough_on_invalid_input(self): + # No SOI -> return unchanged so libjpeg's own error surfaces. + assert _splice_jpeg_tables(b'no soi', b'\xff\xd8\xff\xd9') == b'no soi' + + def test_jpeg_decompress_accepts_jpeg_tables_kwarg(self): + from PIL import Image + import io + + rng = np.random.RandomState(1502) + arr = rng.randint(50, 200, (16, 16, 3), dtype=np.uint8) + img = Image.fromarray(arr, mode='RGB') + buf = io.BytesIO() + img.save(buf, format='JPEG', quality=85) + full = buf.getvalue() + sos = full.index(b'\xff\xda') + tables = b'\xff\xd8' + full[2:sos] + b'\xff\xd9' + fragment = b'\xff\xd8' + full[sos:] + + out = jpeg_decompress(fragment, 16, 16, samples=3, jpeg_tables=tables) + assert len(out) == 16 * 16 * 3 + + +# rasterio-driven tests for issue #1502: GDAL writes tiled JPEG TIFFs +# whose per-tile fragments share DQT/DHT tables in tag 347. Skip the +# class -- not the whole module -- when rasterio is missing so the +# codec/splice unit tests above still run. + + +@pytest.mark.skipif( + importlib.util.find_spec('rasterio') is None, + reason='rasterio is required to write GDAL-style tiled JPEG TIFFs', +) +class TestGdalTiledJpegRead: + """Read GDAL-style tiled JPEG TIFFs that use the JPEGTables tag.""" + + def _gradient_rgb(self, size=128): + # Smooth content keeps JPEG error low and detection of bugs easy. + y = np.linspace(20, 240, size, dtype=np.uint8) + x = np.linspace(20, 240, size, dtype=np.uint8) + r = np.broadcast_to(y[:, None], (size, size)).astype(np.uint8) + g = np.broadcast_to(x[None, :], (size, size)).astype(np.uint8) + b = np.full((size, size), 128, dtype=np.uint8) + return np.stack([r, g, b], axis=0) # rasterio wants (bands, H, W) + + def test_tiled_ycbcr_jpeg(self, tmp_path): + import rasterio as rio + from xrspatial.geotiff._header import ( + parse_header, parse_all_ifds, TAG_JPEG_TABLES, + ) + + size = 128 + data = self._gradient_rgb(size) + path = str(tmp_path / 'tiled_jpeg_ycbcr_1502.tif') + with rio.open( + path, 'w', driver='GTiff', height=size, width=size, count=3, + dtype='uint8', tiled=True, blockxsize=64, blockysize=64, + compress='JPEG', photometric='YCBCR', + ) as dst: + dst.write(data) + + # Sanity: the file actually carries JPEGTables (tag 347). + with open(path, 'rb') as f: + blob = f.read() + hdr = parse_header(blob) + ifds = parse_all_ifds(blob, hdr) + assert TAG_JPEG_TABLES in ifds[0].entries + assert ifds[0].jpeg_tables is not None + assert ifds[0].jpeg_tables[:2] == b'\xff\xd8' + + arr, _ = read_to_array(path) + assert arr.shape == (size, size, 3) + assert arr.dtype == np.uint8 + + # Compare to rasterio's own decode. JPEG at quality 75 + 4:2:0 + # chroma subsampling shows ~1-3 absolute mean error on smooth + # gradients; allow a generous 5. + with rio.open(path) as src: + ref = src.read() # (bands, H, W) + ref = np.transpose(ref, (1, 2, 0)) + assert np.abs(arr.astype(int) - ref.astype(int)).mean() < 5 + + def test_tiled_grayscale_jpeg(self, tmp_path): + import rasterio as rio + + size = 96 + y = np.linspace(20, 240, size, dtype=np.uint8) + gray = np.broadcast_to(y[:, None], (size, size)).astype(np.uint8) + + path = str(tmp_path / 'tiled_jpeg_gray_1502.tif') + with rio.open( + path, 'w', driver='GTiff', height=size, width=size, count=1, + dtype='uint8', tiled=True, blockxsize=32, blockysize=32, + compress='JPEG', + ) as dst: + dst.write(gray, 1) + + arr, _ = read_to_array(path) + assert arr.shape == (size, size) + + with rio.open(path) as src: + ref = src.read(1) + assert np.abs(arr.astype(int) - ref.astype(int)).mean() < 5