From d182cc216d8489ac3a0e0586a079377f85d9f46e Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Fri, 13 Mar 2026 10:13:13 +0100 Subject: [PATCH 1/4] claude's draft --- docs/devnotes/feat-v3-scale-offset-cast.md | 137 ++++ src/zarr/codecs/__init__.py | 6 + src/zarr/codecs/cast_value.py | 371 +++++++++ src/zarr/codecs/scale_offset.py | 119 +++ src/zarr/core/chunk_grids/rectilinear.py | 802 ++++++++++++++++++++ tests/test_codecs/test_scale_offset_cast.py | 598 +++++++++++++++ 6 files changed, 2033 insertions(+) create mode 100644 docs/devnotes/feat-v3-scale-offset-cast.md create mode 100644 src/zarr/codecs/cast_value.py create mode 100644 src/zarr/codecs/scale_offset.py create mode 100644 src/zarr/core/chunk_grids/rectilinear.py create mode 100644 tests/test_codecs/test_scale_offset_cast.py diff --git a/docs/devnotes/feat-v3-scale-offset-cast.md b/docs/devnotes/feat-v3-scale-offset-cast.md new file mode 100644 index 0000000000..a7dcb504fd --- /dev/null +++ b/docs/devnotes/feat-v3-scale-offset-cast.md @@ -0,0 +1,137 @@ +# scale_offset and cast_value codecs + +Source: https://github.com/zarr-developers/zarr-extensions/pull/43 + +## Overview + +Two array-to-array codecs for zarr v3, designed to work together for the +common pattern of storing floating-point data as compressed integers. + +--- + +## scale_offset + +**Type:** array -> array (does NOT change dtype) + +**Encode:** `out = (in - offset) * scale` +**Decode:** `out = (in / scale) + offset` + +### Parameters +- `offset` (optional): scalar subtracted during encoding. Default: 0 (additive identity). + Serialized in JSON using the zarr v3 fill-value encoding for the array's dtype. +- `scale` (optional): scalar multiplied during encoding (after offset subtraction). Default: 1 + (multiplicative identity). Same JSON encoding as offset. + +### Key rules +- Arithmetic MUST use the input array's own data type semantics (no implicit promotion). 
+- If any intermediate or final value is unrepresentable in that dtype, error. +- If neither scale nor offset is given, `configuration` may be omitted (codec is a no-op). +- Fill value MUST be transformed through the codec (encode direction). +- Only valid for real-number data types (int/uint/float families). + +### JSON +```json +{"name": "scale_offset", "configuration": {"offset": 5, "scale": 0.1}} +``` + +--- + +## cast_value + +**Type:** array -> array (CHANGES dtype) + +**Purpose:** Value-convert (not binary-reinterpret) array elements to a new data type. + +### Parameters +- `data_type` (required): target zarr v3 data type. +- `rounding` (optional): how to round when exact representation is impossible. + Values: `"nearest-even"` (default), `"towards-zero"`, `"towards-positive"`, + `"towards-negative"`, `"nearest-away"`. +- `out_of_range` (optional): what to do when a value is outside the target's range. + Values: `"clamp"`, `"wrap"`. If absent, out-of-range MUST error. + `"wrap"` only valid for integral two's-complement types. +- `scalar_map` (optional): explicit value overrides. + `{"encode": [[input, output], ...], "decode": [[input, output], ...]}`. + Evaluated BEFORE rounding/out_of_range. + +### Casting procedure (same for encode and decode, swapping source/target) +1. Check scalar_map — if input matches a key, use mapped value. +2. Check exact representability — if yes, use directly. +3. Apply rounding and out_of_range rules. +4. If none apply, MUST error. + +### Special values +- NaN propagates between IEEE 754 types unless scalar_map overrides. +- Signed zero preserved between IEEE 754 types. +- If target doesn't support NaN/infinity and input has them, MUST error + unless scalar_map provides a mapping. + +### Fill value +- MUST be cast using same semantics as elements. +- Implementations SHOULD validate fill value survives round-trip at metadata + construction time. 
+ +### JSON +```json +{ + "name": "cast_value", + "configuration": { + "data_type": "uint8", + "rounding": "nearest-even", + "out_of_range": "clamp", + "scalar_map": { + "encode": [["NaN", 0], ["+Infinity", 0], ["-Infinity", 0]], + "decode": [[0, "NaN"]] + } + } +} +``` + +--- + +## Typical combined usage + +```json +{ + "data_type": "float64", + "fill_value": "NaN", + "codecs": [ + {"name": "scale_offset", "configuration": {"offset": -10, "scale": 0.1}}, + {"name": "cast_value", "configuration": { + "data_type": "uint8", + "rounding": "nearest-even", + "scalar_map": {"encode": [["NaN", 0]], "decode": [[0, "NaN"]]} + }}, + "bytes" + ] +} +``` + +--- + +## Implementation notes for zarr-python + +### scale_offset +- Subclass `ArrayArrayCodec`. +- `resolve_metadata`: transform fill_value via `(fill - offset) * scale`, keep dtype. +- `_encode_single`: `(array - offset) * scale` using numpy with same dtype. +- `_decode_single`: `(array / scale) + offset` using numpy with same dtype. +- `is_fixed_size = True`. + +### cast_value +- Subclass `ArrayArrayCodec`. +- `resolve_metadata`: change dtype to target dtype, cast fill_value. +- `_encode_single`: cast array from input dtype to target dtype. +- `_decode_single`: cast array from target dtype back to input dtype. +- Needs the input dtype stored (from `evolve_from_array_spec` or `resolve_metadata`). +- `is_fixed_size = True` (for fixed-size types). +- Initial implementation: support `rounding` and `out_of_range` for common cases. + `scalar_map` adds complexity but is needed for NaN handling. + +### Key design decisions from PR review +1. Encode = `(in - offset) * scale` (subtract, not add) — matches HDF5 and numcodecs. +2. No implicit precision promotion — arithmetic stays in the input dtype. +3. `out_of_range` defaults to error (not clamp). +4. `scalar_map` was added specifically to handle NaN-to-integer mappings. +5. Fill value must round-trip exactly through the codec chain. +6. 
Name uses underscore: `scale_offset`, `cast_value`. diff --git a/src/zarr/codecs/__init__.py b/src/zarr/codecs/__init__.py index 4c621290e7..27ba6778da 100644 --- a/src/zarr/codecs/__init__.py +++ b/src/zarr/codecs/__init__.py @@ -2,6 +2,7 @@ from zarr.codecs.blosc import BloscCname, BloscCodec, BloscShuffle from zarr.codecs.bytes import BytesCodec, Endian +from zarr.codecs.cast_value import CastValueCodec from zarr.codecs.crc32c_ import Crc32cCodec from zarr.codecs.gzip import GzipCodec from zarr.codecs.numcodecs import ( @@ -27,6 +28,7 @@ Zlib, Zstd, ) +from zarr.codecs.scale_offset import ScaleOffsetCodec from zarr.codecs.sharding import ShardingCodec, ShardingCodecIndexLocation from zarr.codecs.transpose import TransposeCodec from zarr.codecs.vlen_utf8 import VLenBytesCodec, VLenUTF8Codec @@ -38,9 +40,11 @@ "BloscCodec", "BloscShuffle", "BytesCodec", + "CastValueCodec", "Crc32cCodec", "Endian", "GzipCodec", + "ScaleOffsetCodec", "ShardingCodec", "ShardingCodecIndexLocation", "TransposeCodec", @@ -61,6 +65,8 @@ register_codec("vlen-utf8", VLenUTF8Codec) register_codec("vlen-bytes", VLenBytesCodec) register_codec("transpose", TransposeCodec) +register_codec("scale_offset", ScaleOffsetCodec) +register_codec("cast_value", CastValueCodec) # Register all the codecs formerly contained in numcodecs.zarr3 diff --git a/src/zarr/codecs/cast_value.py b/src/zarr/codecs/cast_value.py new file mode 100644 index 0000000000..df0d13c079 --- /dev/null +++ b/src/zarr/codecs/cast_value.py @@ -0,0 +1,371 @@ +from __future__ import annotations + +from dataclasses import dataclass, replace +from typing import TYPE_CHECKING, Literal + +import numpy as np + +from zarr.abc.codec import ArrayArrayCodec +from zarr.core.array_spec import ArraySpec +from zarr.core.common import JSON, parse_named_configuration +from zarr.core.dtype import get_data_type_from_json + +if TYPE_CHECKING: + from typing import Self + + from zarr.core.buffer import NDBuffer + from zarr.core.chunk_grids import 
ChunkGrid + from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar, ZDType + +RoundingMode = Literal[ + "nearest-even", + "towards-zero", + "towards-positive", + "towards-negative", + "nearest-away", +] + +OutOfRangeMode = Literal["clamp", "wrap"] + +ScalarMapJSON = dict[str, list[list[JSON]]] + +# Pre-parsed scalar map entry: (source_float, target_float, source_is_nan) +_MapEntry = tuple[float, float, bool] + + +def _special_float(s: str) -> float: + """Convert special float string representations to float values.""" + if s == "NaN": + return float("nan") + if s in ("+Infinity", "Infinity"): + return float("inf") + if s == "-Infinity": + return float("-inf") + return float(s) + + +def _parse_map_entries(mapping: dict[str, str]) -> list[_MapEntry]: + """Pre-parse a scalar map dict into a list of (src, tgt, src_is_nan) tuples.""" + entries: list[_MapEntry] = [] + for src_str, tgt_str in mapping.items(): + src = _special_float(src_str) + tgt = _special_float(tgt_str) + entries.append((src, tgt, np.isnan(src))) + return entries + + +def _apply_scalar_map(work: np.ndarray, entries: list[_MapEntry]) -> None: + """Apply scalar map entries in-place. Single pass per entry.""" + for src, tgt, src_is_nan in entries: + if src_is_nan: + mask = np.isnan(work) + else: + mask = work == src + work[mask] = tgt + + +def _round_inplace(arr: np.ndarray, mode: RoundingMode) -> np.ndarray: + """Round array, returning result (may or may not be a new array). + + For nearest-away, requires 3 numpy ops. All others are a single op. 
+ """ + if mode == "nearest-even": + return np.rint(arr) + elif mode == "towards-zero": + return np.trunc(arr) + elif mode == "towards-positive": + return np.ceil(arr) + elif mode == "towards-negative": + return np.floor(arr) + elif mode == "nearest-away": + return np.sign(arr) * np.floor(np.abs(arr) + 0.5) + raise ValueError(f"Unknown rounding mode: {mode}") + + +def _cast_array( + arr: np.ndarray, + target_dtype: np.dtype, + rounding: RoundingMode, + out_of_range: OutOfRangeMode | None, + scalar_map_entries: list[_MapEntry] | None, +) -> np.ndarray: + """Cast an array to target_dtype with rounding, out-of-range, and scalar_map handling. + + Optimized to minimize allocations and passes over the data. + For the simple case (no scalar_map, no rounding needed, no out-of-range), + this is essentially just ``arr.astype(target_dtype)``. + """ + src_is_int = np.issubdtype(arr.dtype, np.integer) + src_is_float = np.issubdtype(arr.dtype, np.floating) + tgt_is_int = np.issubdtype(target_dtype, np.integer) + tgt_is_float = np.issubdtype(target_dtype, np.floating) + + # Fast path: float→float with no scalar_map — single astype + if src_is_float and tgt_is_float and not scalar_map_entries: + return arr.astype(target_dtype) + + # Fast path: int→float with no scalar_map — single astype + if src_is_int and tgt_is_float and not scalar_map_entries: + return arr.astype(target_dtype) + + # Fast path: int→int with no scalar_map — check range then astype + if src_is_int and tgt_is_int and not scalar_map_entries: + # Check if source range could exceed target range + if arr.dtype.itemsize > target_dtype.itemsize or arr.dtype != target_dtype: + info = np.iinfo(target_dtype) + lo, hi = int(info.min), int(info.max) + arr_min, arr_max = int(arr.min()), int(arr.max()) + if arr_min >= lo and arr_max <= hi: + return arr.astype(target_dtype) + if out_of_range == "clamp": + return np.clip(arr, lo, hi).astype(target_dtype) + elif out_of_range == "wrap": + range_size = hi - lo + 1 + return 
((arr.astype(np.int64) - lo) % range_size + lo).astype(target_dtype) + else: + raise ValueError( + f"Values out of range for {target_dtype} and no out_of_range policy set" + ) + return arr.astype(target_dtype) + + # float→int: needs rounding, range check, possibly scalar_map + if src_is_float and tgt_is_int: + # Work in float64 for the arithmetic + if arr.dtype != np.float64: + work = arr.astype(np.float64) + else: + work = arr.copy() + + if scalar_map_entries: + _apply_scalar_map(work, scalar_map_entries) + + # Check for unmapped NaN/Inf + bad = np.isnan(work) | np.isinf(work) + if bad.any(): + raise ValueError("Cannot cast NaN or Infinity to integer type without scalar_map") + + work = _round_inplace(work, rounding) + + info = np.iinfo(target_dtype) + lo, hi = float(info.min), float(info.max) + if out_of_range == "clamp": + np.clip(work, lo, hi, out=work) + elif out_of_range == "wrap": + range_size = int(info.max) - int(info.min) + 1 + oor = (work < lo) | (work > hi) + if oor.any(): + work[oor] = (work[oor].astype(np.int64) - int(info.min)) % range_size + int( + info.min + ) + elif (work.min() < lo) or (work.max() > hi): + raise ValueError( + f"Values out of range for {target_dtype} and no out_of_range policy set" + ) + + return work.astype(target_dtype) + + # int→float with scalar_map + if src_is_int and tgt_is_float and scalar_map_entries: + work = arr.astype(np.float64) + _apply_scalar_map(work, scalar_map_entries) + return work.astype(target_dtype) + + # float→float with scalar_map + if src_is_float and tgt_is_float and scalar_map_entries: + work = arr.copy() + _apply_scalar_map(work, scalar_map_entries) + return work.astype(target_dtype) + + # int→int with scalar_map + if src_is_int and tgt_is_int and scalar_map_entries: + work = arr.astype(np.int64) + _apply_scalar_map(work, scalar_map_entries) + info = np.iinfo(target_dtype) + lo, hi = int(info.min), int(info.max) + w_min, w_max = int(work.min()), int(work.max()) + if w_min < lo or w_max > hi: + if 
out_of_range == "clamp": + np.clip(work, lo, hi, out=work) + elif out_of_range == "wrap": + range_size = hi - lo + 1 + oor = (work < lo) | (work > hi) + work[oor] = (work[oor] - lo) % range_size + lo + else: + raise ValueError( + f"Values out of range for {target_dtype} and no out_of_range policy set" + ) + return work.astype(target_dtype) + + # Fallback + return arr.astype(target_dtype) + + +def _parse_scalar_map( + data: ScalarMapJSON | None, +) -> tuple[list[_MapEntry] | None, list[_MapEntry] | None]: + """Parse scalar_map JSON into pre-parsed encode and decode entry lists. + + Returns (encode_entries, decode_entries). Either may be None. + """ + if data is None: + return None, None + encode_raw: dict[str, str] = {} + decode_raw: dict[str, str] = {} + for src, tgt in data.get("encode", []): + encode_raw[str(src)] = str(tgt) + for src, tgt in data.get("decode", []): + decode_raw[str(src)] = str(tgt) + return ( + _parse_map_entries(encode_raw) if encode_raw else None, + _parse_map_entries(decode_raw) if decode_raw else None, + ) + + +@dataclass(frozen=True) +class CastValueCodec(ArrayArrayCodec): + """Cast-value array-to-array codec. + + Value-converts array elements to a new data type during encoding, + and back to the original data type during decoding. + + Parameters + ---------- + data_type : str + Target zarr v3 data type name (e.g. "uint8", "float32"). + rounding : RoundingMode + How to round when exact representation is impossible. Default is "nearest-even". + out_of_range : OutOfRangeMode or None + What to do when a value is outside the target's range. + None means error. "clamp" clips to range. "wrap" uses modular arithmetic + (only valid for integer types). + scalar_map : dict or None + Explicit value overrides as JSON: {"encode": [[src, tgt], ...], "decode": [[src, tgt], ...]}. 
+ """ + + is_fixed_size = True + + data_type: str + rounding: RoundingMode + out_of_range: OutOfRangeMode | None + scalar_map: ScalarMapJSON | None + + def __init__( + self, + *, + data_type: str, + rounding: RoundingMode = "nearest-even", + out_of_range: OutOfRangeMode | None = None, + scalar_map: ScalarMapJSON | None = None, + ) -> None: + object.__setattr__(self, "data_type", data_type) + object.__setattr__(self, "rounding", rounding) + object.__setattr__(self, "out_of_range", out_of_range) + object.__setattr__(self, "scalar_map", scalar_map) + + @classmethod + def from_dict(cls, data: dict[str, JSON]) -> Self: + _, configuration_parsed = parse_named_configuration( + data, "cast_value", require_configuration=True + ) + return cls(**configuration_parsed) # type: ignore[arg-type] + + def to_dict(self) -> dict[str, JSON]: + config: dict[str, JSON] = {"data_type": self.data_type} + if self.rounding != "nearest-even": + config["rounding"] = self.rounding + if self.out_of_range is not None: + config["out_of_range"] = self.out_of_range + if self.scalar_map is not None: + config["scalar_map"] = self.scalar_map + return {"name": "cast_value", "configuration": config} + + def _target_zdtype(self) -> ZDType[TBaseDType, TBaseScalar]: + return get_data_type_from_json(self.data_type, zarr_format=3) + + def validate( + self, + *, + shape: tuple[int, ...], + dtype: ZDType[TBaseDType, TBaseScalar], + chunk_grid: ChunkGrid, + ) -> None: + source_native = dtype.to_native_dtype() + target_native = self._target_zdtype().to_native_dtype() + for label, dt in [("source", source_native), ("target", target_native)]: + if not np.issubdtype(dt, np.integer) and not np.issubdtype(dt, np.floating): + raise ValueError( + f"cast_value codec only supports integer and floating-point data types. " + f"Got {label} dtype {dt}." 
+ ) + if self.out_of_range == "wrap": + if not np.issubdtype(target_native, np.integer): + raise ValueError("out_of_range='wrap' is only valid for integer target types.") + + def resolve_metadata(self, chunk_spec: ArraySpec) -> ArraySpec: + target_zdtype = self._target_zdtype() + target_native = target_zdtype.to_native_dtype() + source_native = chunk_spec.dtype.to_native_dtype() + + fill = chunk_spec.fill_value + fill_arr = np.array([fill], dtype=source_native) + + encode_entries, _ = _parse_scalar_map(self.scalar_map) + + new_fill_arr = _cast_array( + fill_arr, target_native, self.rounding, self.out_of_range, encode_entries + ) + new_fill = target_native.type(new_fill_arr[0]) + + return replace(chunk_spec, dtype=target_zdtype, fill_value=new_fill) + + def _encode_sync( + self, + chunk_array: NDBuffer, + _chunk_spec: ArraySpec, + ) -> NDBuffer | None: + arr = chunk_array.as_ndarray_like() + target_native = self._target_zdtype().to_native_dtype() + + encode_entries, _ = _parse_scalar_map(self.scalar_map) + + result = _cast_array( + np.asarray(arr), target_native, self.rounding, self.out_of_range, encode_entries + ) + return chunk_array.__class__.from_ndarray_like(result) + + async def _encode_single( + self, + chunk_array: NDBuffer, + chunk_spec: ArraySpec, + ) -> NDBuffer | None: + return self._encode_sync(chunk_array, chunk_spec) + + def _decode_sync( + self, + chunk_array: NDBuffer, + chunk_spec: ArraySpec, + ) -> NDBuffer: + arr = chunk_array.as_ndarray_like() + target_native = chunk_spec.dtype.to_native_dtype() + + _, decode_entries = _parse_scalar_map(self.scalar_map) + + result = _cast_array( + np.asarray(arr), target_native, self.rounding, self.out_of_range, decode_entries + ) + return chunk_array.__class__.from_ndarray_like(result) + + async def _decode_single( + self, + chunk_array: NDBuffer, + chunk_spec: ArraySpec, + ) -> NDBuffer: + return self._decode_sync(chunk_array, chunk_spec) + + def compute_encoded_size(self, input_byte_length: int, chunk_spec: 
ArraySpec) -> int: + source_itemsize = chunk_spec.dtype.to_native_dtype().itemsize + target_itemsize = self._target_zdtype().to_native_dtype().itemsize + if source_itemsize == 0: + return 0 + num_elements = input_byte_length // source_itemsize + return num_elements * target_itemsize diff --git a/src/zarr/codecs/scale_offset.py b/src/zarr/codecs/scale_offset.py new file mode 100644 index 0000000000..f4cd95ed52 --- /dev/null +++ b/src/zarr/codecs/scale_offset.py @@ -0,0 +1,119 @@ +from __future__ import annotations + +from dataclasses import dataclass, replace +from typing import TYPE_CHECKING + +import numpy as np + +from zarr.abc.codec import ArrayArrayCodec +from zarr.core.array_spec import ArraySpec +from zarr.core.common import JSON, parse_named_configuration + +if TYPE_CHECKING: + from typing import Self + + from zarr.core.buffer import NDBuffer + from zarr.core.chunk_grids import ChunkGrid + from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar, ZDType + + +@dataclass(frozen=True) +class ScaleOffsetCodec(ArrayArrayCodec): + """Scale-offset array-to-array codec. + + Encodes values by subtracting an offset and multiplying by a scale factor. + Decodes by dividing by the scale and adding the offset. + + All arithmetic uses the input array's data type semantics. + + Parameters + ---------- + offset : float + Value subtracted during encoding. Default is 0. + scale : float + Value multiplied during encoding (after offset subtraction). Default is 1. 
+ """ + + is_fixed_size = True + + offset: float + scale: float + + def __init__(self, *, offset: float = 0, scale: float = 1) -> None: + object.__setattr__(self, "offset", offset) + object.__setattr__(self, "scale", scale) + + @classmethod + def from_dict(cls, data: dict[str, JSON]) -> Self: + _, configuration_parsed = parse_named_configuration( + data, "scale_offset", require_configuration=False + ) + configuration_parsed = configuration_parsed or {} + return cls(**configuration_parsed) # type: ignore[arg-type] + + def to_dict(self) -> dict[str, JSON]: + if self.offset == 0 and self.scale == 1: + return {"name": "scale_offset"} + config: dict[str, JSON] = {} + if self.offset != 0: + config["offset"] = self.offset + if self.scale != 1: + config["scale"] = self.scale + return {"name": "scale_offset", "configuration": config} + + def validate( + self, + *, + shape: tuple[int, ...], + dtype: ZDType[TBaseDType, TBaseScalar], + chunk_grid: ChunkGrid, + ) -> None: + native = dtype.to_native_dtype() + if not np.issubdtype(native, np.integer) and not np.issubdtype(native, np.floating): + raise ValueError( + f"scale_offset codec only supports integer and floating-point data types. " + f"Got {dtype}." 
+ ) + + def resolve_metadata(self, chunk_spec: ArraySpec) -> ArraySpec: + native_dtype = chunk_spec.dtype.to_native_dtype() + fill = chunk_spec.fill_value + new_fill = ( + np.dtype(native_dtype).type(fill) - native_dtype.type(self.offset) + ) * native_dtype.type(self.scale) + return replace(chunk_spec, fill_value=new_fill) + + def _encode_sync( + self, + chunk_array: NDBuffer, + _chunk_spec: ArraySpec, + ) -> NDBuffer | None: + arr = chunk_array.as_ndarray_like() + result = (arr - arr.dtype.type(self.offset)) * arr.dtype.type(self.scale) + return chunk_array.__class__.from_ndarray_like(result) + + async def _encode_single( + self, + chunk_array: NDBuffer, + chunk_spec: ArraySpec, + ) -> NDBuffer | None: + return self._encode_sync(chunk_array, chunk_spec) + + def _decode_sync( + self, + chunk_array: NDBuffer, + _chunk_spec: ArraySpec, + ) -> NDBuffer: + arr = chunk_array.as_ndarray_like() + result = (arr / arr.dtype.type(self.scale)) + arr.dtype.type(self.offset) + return chunk_array.__class__.from_ndarray_like(result) + + async def _decode_single( + self, + chunk_array: NDBuffer, + chunk_spec: ArraySpec, + ) -> NDBuffer: + return self._decode_sync(chunk_array, chunk_spec) + + def compute_encoded_size(self, input_byte_length: int, _chunk_spec: ArraySpec) -> int: + return input_byte_length diff --git a/src/zarr/core/chunk_grids/rectilinear.py b/src/zarr/core/chunk_grids/rectilinear.py new file mode 100644 index 0000000000..3985613244 --- /dev/null +++ b/src/zarr/core/chunk_grids/rectilinear.py @@ -0,0 +1,802 @@ +from __future__ import annotations + +import bisect +import itertools +import operator +from collections.abc import Sequence +from dataclasses import dataclass +from functools import cached_property, reduce +from typing import TYPE_CHECKING, Literal, Self, TypedDict + +import numpy as np +import numpy.typing as npt +from zarr.core.chunk_grids.common import ChunkEdgeLength, ChunkGrid + +from zarr.core.common import ( + JSON, + ShapeLike, + 
parse_named_configuration, + parse_shapelike, +) + +if TYPE_CHECKING: + from collections.abc import Iterator + + +class RectilinearChunkGridConfigurationDict(TypedDict): + """TypedDict for rectilinear chunk grid configuration""" + + kind: Literal["inline"] + chunk_shapes: Sequence[Sequence[ChunkEdgeLength]] + + +def _parse_chunk_shapes( + data: Sequence[Sequence[ChunkEdgeLength]], +) -> tuple[tuple[int, ...], ...]: + """ + Parse and expand chunk_shapes from metadata. + + Parameters + ---------- + data : Sequence[Sequence[ChunkEdgeLength]] + The chunk_shapes specification from metadata + + Returns + ------- + tuple[tuple[int, ...], ...] + Tuple of expanded chunk edge lengths for each axis + """ + # Runtime validation - strings are sequences but we don't want them + # Type annotation is for static typing, this validates actual JSON data + if isinstance(data, str) or not isinstance(data, Sequence): # type: ignore[redundant-expr,unreachable] + raise TypeError(f"chunk_shapes must be a sequence, got {type(data)}") + + result = [] + for i, axis_spec in enumerate(data): + # Runtime validation for each axis spec + if isinstance(axis_spec, str) or not isinstance(axis_spec, Sequence): # type: ignore[redundant-expr,unreachable] + raise TypeError(f"chunk_shapes[{i}] must be a sequence, got {type(axis_spec)}") + expanded = _expand_run_length_encoding(axis_spec) + result.append(expanded) + + return tuple(result) + + +def _expand_run_length_encoding(spec: Sequence[ChunkEdgeLength]) -> tuple[int, ...]: + """ + Expand a chunk edge length specification into a tuple of integers. + + The specification can contain: + - integers: representing explicit edge lengths + - tuples [value, count]: representing run-length encoded sequences + + Parameters + ---------- + spec : Sequence[ChunkEdgeLength] + The chunk edge length specification for one axis + + Returns + ------- + tuple[int, ...] 
+ Expanded sequence of chunk edge lengths + + Examples + -------- + >>> _expand_run_length_encoding([2, 3]) + (2, 3) + >>> _expand_run_length_encoding([[2, 3]]) + (2, 2, 2) + >>> _expand_run_length_encoding([1, [2, 1], 3]) + (1, 2, 3) + >>> _expand_run_length_encoding([[1, 3], 3]) + (1, 1, 1, 3) + """ + result: list[int] = [] + for item in spec: + if isinstance(item, int): + # Explicit edge length + result.append(item) + elif isinstance(item, list | tuple): + # Run-length encoded: [value, count] + if len(item) != 2: + raise TypeError( + f"Run-length encoded items must be [int, int], got list of length {len(item)}" + ) + value, count = item + # Runtime validation of JSON data + if not isinstance(value, int) or not isinstance(count, int): # type: ignore[redundant-expr] + raise TypeError( + f"Run-length encoded items must be [int, int], got [{type(value).__name__}, {type(count).__name__}]" + ) + if count < 0: + raise ValueError(f"Run-length count must be non-negative, got {count}") + result.extend([value] * count) + else: + raise TypeError( + f"Chunk edge length must be int or [int, int] for run-length encoding, " + f"got {type(item)}" + ) + return tuple(result) + + +def _compress_run_length_encoding(chunks: tuple[int, ...]) -> list[int | list[int]]: + """ + Compress a sequence of chunk sizes to RLE format where beneficial. + + This function automatically detects runs of identical values and compresses them + using the [value, count] format. Single values or short runs are kept as-is. + + Parameters + ---------- + chunks : tuple[int, ...] 
+ Sequence of chunk sizes along one dimension + + Returns + ------- + list[int | list[int]] + Compressed representation using RLE where beneficial + + Examples + -------- + >>> _compress_run_length_encoding((10, 10, 10, 10, 10, 10)) + [[10, 6]] + >>> _compress_run_length_encoding((10, 20, 30)) + [10, 20, 30] + >>> _compress_run_length_encoding((10, 10, 10, 20, 20, 30)) + [[10, 3], [20, 2], 30] + >>> _compress_run_length_encoding((5, 5, 10, 10, 10, 10, 15)) + [[5, 2], [10, 4], 15] + """ + if not chunks: + return [] + + result: list[int | list[int]] = [] + current_value = chunks[0] + current_count = 1 + + for value in chunks[1:]: + if value == current_value: + current_count += 1 + else: + # Decide whether to use RLE or explicit value + # Use RLE if count >= 3 to save space (tradeoff: [v,c] vs v,v,v) + if current_count >= 3: + result.append([current_value, current_count]) + elif current_count == 2: + # For count=2, RLE doesn't save space, but use it for consistency + result.append([current_value, current_count]) + else: + result.append(current_value) + + current_value = value + current_count = 1 + + # Handle the last run + if current_count >= 3 or current_count == 2: + result.append([current_value, current_count]) + else: + result.append(current_value) + + return result + + +def _normalize_rectilinear_chunks( + chunks: Sequence[Sequence[int | Sequence[int]]], shape: tuple[int, ...] +) -> tuple[tuple[int, ...], ...]: + """ + Normalize and validate variable chunks for RectilinearChunkGrid. + + Supports both explicit chunk sizes and run-length encoding (RLE). + RLE format: [[value, count]] expands to 'count' repetitions of 'value'. + + Parameters + ---------- + chunks : Sequence[Sequence[int | Sequence[int]]] + Nested sequence where each element is a sequence of chunk sizes along that dimension. + Each chunk size can be: + - An integer: explicit chunk size + - A sequence [value, count]: RLE format (expands to 'count' chunks of size 'value') + shape : tuple[int, ...] 
+ The shape of the array. + + Returns + ------- + tuple[tuple[int, ...], ...] + Normalized chunk shapes as tuple of tuples. + + Raises + ------ + ValueError + If chunks don't match shape or sum incorrectly. + TypeError + If chunk specification format is invalid. + + Examples + -------- + >>> _normalize_rectilinear_chunks([[10, 20, 30], [25, 25]], (60, 50)) + ((10, 20, 30), (25, 25)) + >>> _normalize_rectilinear_chunks([[[10, 6]], [[10, 5]]], (60, 50)) + ((10, 10, 10, 10, 10, 10), (10, 10, 10, 10, 10)) + """ + # Expand RLE for each dimension + try: + chunk_shapes = tuple( + _expand_run_length_encoding(dim) # type: ignore[arg-type] + for dim in chunks + ) + except (TypeError, ValueError) as e: + raise TypeError( + f"Invalid variable chunks: {chunks}. Expected nested sequence of integers " + f"or RLE format [[value, count]]." + ) from e + + # Validate dimensionality + if len(chunk_shapes) != len(shape): + raise ValueError( + f"Variable chunks dimensionality ({len(chunk_shapes)}) " + f"must match array shape dimensionality ({len(shape)})" + ) + + # Validate that chunks sum to shape for each dimension + for i, (dim_chunks, dim_size) in enumerate(zip(chunk_shapes, shape, strict=False)): + chunk_sum = sum(dim_chunks) + if chunk_sum < dim_size: + raise ValueError( + f"Variable chunks along dimension {i} sum to {chunk_sum} " + f"but array shape is {dim_size}. " + f"Chunks must sum to be greater than or equal to the shape." + ) + if sum(dim_chunks[:-1]) >= dim_size: + raise ValueError( + f"Dimension {i} has more chunks than needed. " + f"The last chunk(s) would contain no valid data. " + f"Remove the extra chunk(s) or increase the array shape." + ) + + return chunk_shapes + + +@dataclass(frozen=True) +class RectilinearChunkGrid(ChunkGrid): + """ + A rectilinear chunk grid where chunk sizes vary along each axis. + + .. warning:: + This is an experimental feature and may change in future releases. + Expected to stabilize in Zarr version 3.3. 
+ + Attributes + ---------- + chunk_shapes : tuple[tuple[int, ...], ...] + For each axis, a tuple of chunk edge lengths along that axis. + The sum of edge lengths must be >= the array shape along that axis. + """ + + _array_shape: tuple[int, ...] + chunk_shapes: tuple[tuple[int, ...], ...] + + def __init__(self, *, chunk_shapes: Sequence[Sequence[int]], array_shape: ShapeLike) -> None: + """ + Initialize a RectilinearChunkGrid. + + Parameters + ---------- + chunk_shapes : Sequence[Sequence[int]] + For each axis, a sequence of chunk edge lengths. + array_shape : ShapeLike + The shape of the array this chunk grid is bound to. + """ + array_shape_parsed = parse_shapelike(array_shape) + + # Convert to nested tuples and validate + parsed_shapes: list[tuple[int, ...]] = [] + for i, axis_chunks in enumerate(chunk_shapes): + if not isinstance(axis_chunks, Sequence): + raise TypeError(f"chunk_shapes[{i}] must be a sequence, got {type(axis_chunks)}") + # Validate all are positive integers + axis_tuple = tuple(axis_chunks) + for j, size in enumerate(axis_tuple): + if not isinstance(size, int): + raise TypeError( + f"chunk_shapes[{i}][{j}] must be an int, got {type(size).__name__}" + ) + if size <= 0: + raise ValueError(f"chunk_shapes[{i}][{j}] must be positive, got {size}") + parsed_shapes.append(axis_tuple) + + chunk_shapes_parsed = tuple(parsed_shapes) + object.__setattr__(self, "chunk_shapes", chunk_shapes_parsed) + object.__setattr__(self, "_array_shape", array_shape_parsed) + + # Validate array_shape is compatible with chunk_shapes + self._validate_array_shape() + + @property + def array_shape(self) -> tuple[int, ...]: + return self._array_shape + + @classmethod + def from_dict( # type: ignore[override] + cls, data: dict[str, JSON], *, array_shape: ShapeLike + ) -> Self: + """ + Parse a RectilinearChunkGrid from metadata dict. 
+ + Parameters + ---------- + data : dict[str, JSON] + Metadata dictionary with 'name' and 'configuration' keys + array_shape : ShapeLike + The shape of the array this chunk grid is bound to. + + Returns + ------- + Self + A RectilinearChunkGrid instance + """ + _, configuration = parse_named_configuration(data, "rectilinear") + + if not isinstance(configuration, dict): + raise TypeError(f"configuration must be a dict, got {type(configuration)}") + + # Validate kind field + kind = configuration.get("kind") + if kind != "inline": + raise ValueError(f"Only 'inline' kind is supported, got {kind!r}") + + # Parse chunk_shapes with run-length encoding support + chunk_shapes_raw = configuration.get("chunk_shapes") + if chunk_shapes_raw is None: + raise ValueError("configuration must contain 'chunk_shapes'") + + # Type ignore: JSON data validated at runtime by _parse_chunk_shapes + chunk_shapes_expanded = _parse_chunk_shapes(chunk_shapes_raw) # type: ignore[arg-type] + + return cls(chunk_shapes=chunk_shapes_expanded, array_shape=array_shape) + + def to_dict(self) -> dict[str, JSON]: + """ + Convert to metadata dict format with automatic RLE compression. + + This method automatically compresses chunk shapes using run-length encoding + where beneficial (runs of 2 or more identical values). This reduces metadata + size for arrays with many uniform chunks. 
+ + Returns + ------- + dict[str, JSON] + Metadata dictionary with 'name' and 'configuration' keys + + Examples + -------- + >>> grid = RectilinearChunkGrid(chunk_shapes=[[10, 10, 10, 10, 10, 10], [5, 5, 5, 5, 5]], array_shape=(60, 25)) + >>> grid.to_dict()['configuration']['chunk_shapes'] + [[[10, 6]], [[5, 5]]] + """ + # Compress each dimension using RLE where beneficial + chunk_shapes_compressed = [ + _compress_run_length_encoding(axis_chunks) for axis_chunks in self.chunk_shapes + ] + + return { + "name": "rectilinear", + "configuration": { + "kind": "inline", + "chunk_shapes": chunk_shapes_compressed, + }, + } + + def update_shape(self, new_shape: tuple[int, ...]) -> Self: + """ + Update the RectilinearChunkGrid to accommodate a new array shape. + + When resizing an array, this method adjusts the chunk grid to match the new shape. + For dimensions that grow, a new chunk is added with size equal to the size difference. + For dimensions that shrink, chunks are truncated or removed to fit the new shape. + + Parameters + ---------- + new_shape : tuple[int, ...] + The new shape of the array. Must have the same number of dimensions as the + chunk grid. + + Returns + ------- + Self + A new RectilinearChunkGrid instance with updated chunk shapes and array_shape + + Raises + ------ + ValueError + If the number of dimensions in new_shape doesn't match the number of dimensions + in the chunk grid + + Examples + -------- + >>> grid = RectilinearChunkGrid(chunk_shapes=[[10, 20], [15, 15]], array_shape=(30, 30)) + >>> grid.update_shape((50, 40)) # Grow both dimensions + RectilinearChunkGrid(_array_shape=(50, 40), chunk_shapes=((10, 20, 20), (15, 15, 10))) + + >>> grid = RectilinearChunkGrid(chunk_shapes=[[10, 20, 30], [25, 25]], array_shape=(60, 50)) + >>> grid.update_shape((25, 30)) # Shrink first dimension + RectilinearChunkGrid(_array_shape=(25, 30), chunk_shapes=((10, 20), (25, 25))) + + Notes + ----- + This method is automatically called when an array is resized. 
The chunk size + strategy for growing dimensions adds a single new chunk with size equal to + the growth amount. This may not be optimal for all use cases, and users may + want to manually adjust chunk shapes after resizing. + """ + + if len(new_shape) != len(self.chunk_shapes): + raise ValueError( + f"new_shape has {len(new_shape)} dimensions but " + f"chunk_shapes has {len(self.chunk_shapes)} dimensions" + ) + + new_chunk_shapes: list[tuple[int, ...]] = [] + for dim in range(len(new_shape)): + old_dim_length = sum(self.chunk_shapes[dim]) + new_dim_chunks: tuple[int, ...] + if new_shape[dim] == old_dim_length: + new_dim_chunks = self.chunk_shapes[dim] # no changes + + elif new_shape[dim] > old_dim_length: + new_dim_chunks = (*self.chunk_shapes[dim], new_shape[dim] - old_dim_length) + else: + # drop chunk sizes that are not inside the shape anymore + total = 0 + i = 0 + for c in self.chunk_shapes[dim]: + i += 1 + total += c + if total >= new_shape[dim]: + break + # keep the last chunk (it may be too long) + new_dim_chunks = self.chunk_shapes[dim][:i] + + new_chunk_shapes.append(new_dim_chunks) + + return type(self)(chunk_shapes=tuple(new_chunk_shapes), array_shape=new_shape) + + def all_chunk_coords(self) -> Iterator[tuple[int, ...]]: + """ + Generate all chunk coordinates. + + Yields + ------ + tuple[int, ...] + Chunk coordinates + """ + nchunks_per_axis = [len(axis_chunks) for axis_chunks in self.chunk_shapes] + return itertools.product(*(range(n) for n in nchunks_per_axis)) + + def get_nchunks(self) -> int: + """ + Get the total number of chunks. + + Returns + ------- + int + Total number of chunks + """ + return reduce(operator.mul, (len(axis_chunks) for axis_chunks in self.chunk_shapes), 1) + + def _validate_array_shape(self) -> None: + """ + Validate that array_shape is compatible with chunk_shapes. 
+ + Raises + ------ + ValueError + If array_shape is incompatible with chunk_shapes + """ + if len(self._array_shape) != len(self.chunk_shapes): + raise ValueError( + f"array_shape has {len(self._array_shape)} dimensions but " + f"chunk_shapes has {len(self.chunk_shapes)} dimensions" + ) + + for axis, (arr_size, axis_chunks) in enumerate( + zip(self._array_shape, self.chunk_shapes, strict=False) + ): + chunk_sum = sum(axis_chunks) + if chunk_sum < arr_size: + raise ValueError( + f"Sum of chunk sizes along axis {axis} is {chunk_sum} " + f"but array shape is {arr_size}. This is invalid for the " + "RectilinearChunkGrid." + ) + + @cached_property + def _cumulative_sizes(self) -> tuple[tuple[int, ...], ...]: + """ + Compute cumulative sizes for each axis. + + Returns a tuple of tuples where each inner tuple contains cumulative + chunk sizes for an axis. Used for efficient chunk boundary calculations. + + Returns + ------- + tuple[tuple[int, ...], ...] + Cumulative sizes for each axis + + Examples + -------- + For chunk_shapes = [[2, 3, 1], [4, 2]]: + Returns ((0, 2, 5, 6), (0, 4, 6)) + """ + result = [] + for axis_chunks in self.chunk_shapes: + cumsum = [0] + for size in axis_chunks: + cumsum.append(cumsum[-1] + size) + result.append(tuple(cumsum)) + return tuple(result) + + def get_chunk_start(self, chunk_coord: tuple[int, ...]) -> tuple[int, ...]: + """ + Get the starting position (offset) of a chunk in the array. + + Parameters + ---------- + chunk_coord : tuple[int, ...] + Chunk coordinates (indices into the chunk grid) + + Returns + ------- + tuple[int, ...] 
+ Starting index of the chunk in the array + + Raises + ------ + IndexError + If chunk_coord is out of bounds + + Examples + -------- + >>> grid = RectilinearChunkGrid(chunk_shapes=[[2, 2, 2], [3, 3]], array_shape=(6, 6)) + >>> grid.get_chunk_start((0, 0)) + (0, 0) + >>> grid.get_chunk_start((1, 1)) + (2, 3) + """ + # Validate chunk coordinates are in bounds + for axis, (coord, axis_chunks) in enumerate( + zip(chunk_coord, self.chunk_shapes, strict=False) + ): + if not (0 <= coord < len(axis_chunks)): + raise IndexError( + f"chunk_coord[{axis}] = {coord} is out of bounds [0, {len(axis_chunks)})" + ) + + # Use cumulative sizes to get start position + return tuple(self._cumulative_sizes[axis][coord] for axis, coord in enumerate(chunk_coord)) + + def get_chunk_shape(self, chunk_coord: tuple[int, ...]) -> tuple[int, ...]: + """ + Get the shape of a specific chunk. + + Parameters + ---------- + chunk_coord : tuple[int, ...] + Chunk coordinates (indices into the chunk grid) + + Returns + ------- + tuple[int, ...] + Shape of the chunk + + Raises + ------ + IndexError + If chunk_coord is out of bounds + + Examples + -------- + >>> grid = RectilinearChunkGrid(chunk_shapes=[[2, 3, 1], [4, 2]], array_shape=(6, 6)) + >>> grid.get_chunk_shape((0, 0)) + (2, 4) + >>> grid.get_chunk_shape((1, 0)) + (3, 4) + """ + # Validate chunk coordinates are in bounds + for axis, (coord, axis_chunks) in enumerate( + zip(chunk_coord, self.chunk_shapes, strict=False) + ): + if not (0 <= coord < len(axis_chunks)): + raise IndexError( + f"chunk_coord[{axis}] = {coord} is out of bounds [0, {len(axis_chunks)})" + ) + + # Get shape directly from chunk_shapes + return tuple( + axis_chunks[coord] + for axis_chunks, coord in zip(self.chunk_shapes, chunk_coord, strict=False) + ) + + def get_chunk_slice(self, chunk_coord: tuple[int, ...]) -> tuple[slice, ...]: + """ + Get the slice for indexing into an array for a specific chunk. + + Parameters + ---------- + chunk_coord : tuple[int, ...] 
+ Chunk coordinates (indices into the chunk grid) + + Returns + ------- + tuple[slice, ...] + Slice tuple for indexing the array + + Raises + ------ + IndexError + If chunk_coord is out of bounds + + Examples + -------- + >>> grid = RectilinearChunkGrid(chunk_shapes=[[2, 2, 2], [3, 3]], array_shape=(6, 6)) + >>> grid.get_chunk_slice((0, 0)) + (slice(0, 2, None), slice(0, 3, None)) + >>> grid.get_chunk_slice((1, 1)) + (slice(2, 4, None), slice(3, 6, None)) + """ + start = self.get_chunk_start(chunk_coord) + shape = self.get_chunk_shape(chunk_coord) + + return tuple(slice(s, s + length) for s, length in zip(start, shape, strict=False)) + + def get_chunk_grid_shape(self) -> tuple[int, ...]: + """ + Get the shape of the chunk grid (number of chunks per axis). + + Returns + ------- + tuple[int, ...] + Number of chunks along each axis + + Examples + -------- + >>> grid = RectilinearChunkGrid(chunk_shapes=[[2, 2, 2], [3, 3]], array_shape=(6, 6)) + >>> grid.get_chunk_grid_shape() + (3, 2) + """ + return tuple(len(axis_chunks) for axis_chunks in self.chunk_shapes) + + def array_index_to_chunk_coord(self, array_index: tuple[int, ...]) -> tuple[int, ...]: + """ + Find which chunk contains a given array index. + + Parameters + ---------- + array_index : tuple[int, ...] + Index into the array + + Returns + ------- + tuple[int, ...] 
+ Chunk coordinates containing the array index + + Raises + ------ + IndexError + If array_index is out of bounds + + Examples + -------- + >>> grid = RectilinearChunkGrid(chunk_shapes=[[2, 3, 1], [4, 2]], array_shape=(6, 6)) + >>> grid.array_index_to_chunk_coord((0, 0)) + (0, 0) + >>> grid.array_index_to_chunk_coord((2, 0)) + (1, 0) + >>> grid.array_index_to_chunk_coord((5, 5)) + (2, 1) + """ + # Validate array index is in bounds + for axis, (idx, size) in enumerate(zip(array_index, self._array_shape, strict=False)): + if not (0 <= idx < size): + raise IndexError(f"array_index[{axis}] = {idx} is out of bounds [0, {size})") + + # Use binary search in cumulative sizes to find chunk coordinate + result = [] + for axis, idx in enumerate(array_index): + cumsum = self._cumulative_sizes[axis] + # bisect_right gives us the chunk index + 1, so subtract 1 + chunk_idx = bisect.bisect_right(cumsum, idx) - 1 + result.append(chunk_idx) + + return tuple(result) + + def array_indices_to_chunk_dim( + self, dim: int, indices: npt.NDArray[np.intp] + ) -> npt.NDArray[np.intp]: + """ + Vectorized mapping of array indices to chunk coordinates along one dimension. + + For RectilinearChunkGrid, uses np.searchsorted on cumulative sizes. + """ + cumsum = np.asarray(self._cumulative_sizes[dim]) + return np.searchsorted(cumsum, indices, side="right").astype(np.intp) - 1 + + def chunks_in_selection(self, selection: tuple[slice, ...]) -> Iterator[tuple[int, ...]]: + """ + Get all chunks that intersect with a given selection. + + Parameters + ---------- + selection : tuple[slice, ...] + Selection (slices) into the array + + Yields + ------ + tuple[int, ...] 
+ Chunk coordinates that intersect with the selection + + Raises + ------ + ValueError + If selection is invalid + + Examples + -------- + >>> grid = RectilinearChunkGrid(chunk_shapes=[[2, 2, 2], [3, 3]], array_shape=(6, 6)) + >>> selection = (slice(1, 5), slice(2, 5)) + >>> list(grid.chunks_in_selection(selection)) + [(0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1)] + """ + # Normalize slices and find chunk ranges for each axis + chunk_ranges = [] + for axis, (sel, size) in enumerate(zip(selection, self._array_shape, strict=False)): + if not isinstance(sel, slice): + raise TypeError(f"selection[{axis}] must be a slice, got {type(sel)}") + + # Normalize slice with array size + start, stop, step = sel.indices(size) + + if step != 1: + raise ValueError(f"selection[{axis}] has step={step}, only step=1 is supported") + + if start >= stop: + # Empty selection + return + + # Find first and last chunk that intersect with [start, stop) + start_chunk = self.array_index_to_chunk_coord( + tuple(start if i == axis else 0 for i in range(len(self._array_shape))) + )[axis] + + # stop-1 is the last index we need + end_chunk = self.array_index_to_chunk_coord( + tuple(stop - 1 if i == axis else 0 for i in range(len(self._array_shape))) + )[axis] + + chunk_ranges.append(range(start_chunk, end_chunk + 1)) + + # Generate all combinations of chunk coordinates + yield from itertools.product(*chunk_ranges) + + def chunks_per_dim(self, dim: int) -> int: + """ + Get the number of chunks along a specific dimension. 
+ + Parameters + ---------- + dim : int + Dimension index + + Returns + ------- + int + Number of chunks along the dimension + + Examples + -------- + >>> grid = RectilinearChunkGrid(chunk_shapes=[[10, 20], [5, 5, 5]], array_shape=(30, 15)) + >>> grid.chunks_per_dim(0) # 2 chunks along axis 0 + 2 + >>> grid.chunks_per_dim(1) # 3 chunks along axis 1 + 3 + """ + return len(self.chunk_shapes[dim]) diff --git a/tests/test_codecs/test_scale_offset_cast.py b/tests/test_codecs/test_scale_offset_cast.py new file mode 100644 index 0000000000..f4a842ed87 --- /dev/null +++ b/tests/test_codecs/test_scale_offset_cast.py @@ -0,0 +1,598 @@ +"""Tests for scale_offset and cast_value codecs.""" + +from __future__ import annotations + +import numpy as np +import pytest + +import zarr +from zarr.codecs.bytes import BytesCodec +from zarr.codecs.cast_value import CastValueCodec +from zarr.codecs.scale_offset import ScaleOffsetCodec +from zarr.storage import MemoryStore + + +class TestScaleOffsetCodec: + """Tests for the scale_offset codec.""" + + def test_identity(self) -> None: + """Default parameters (offset=0, scale=1) should be a no-op.""" + store = MemoryStore() + data = np.arange(20, dtype="float64").reshape(4, 5) + arr = zarr.create( + store=store, + shape=data.shape, + dtype=data.dtype, + chunks=(4, 5), + codecs=[ScaleOffsetCodec(), BytesCodec()], + ) + arr[:] = data + np.testing.assert_array_equal(arr[:], data) + + def test_encode_decode_float64(self) -> None: + """Encode/decode round-trip with float64 data.""" + store = MemoryStore() + data = np.array([10.0, 20.0, 30.0, 40.0, 50.0], dtype="float64") + arr = zarr.create( + store=store, + shape=data.shape, + dtype=data.dtype, + chunks=(5,), + codecs=[ScaleOffsetCodec(offset=10, scale=0.1), BytesCodec()], + ) + arr[:] = data + result = arr[:] + np.testing.assert_allclose(result, data, rtol=1e-10) + + def test_encode_decode_float32(self) -> None: + """Round-trip with float32 data.""" + store = MemoryStore() + data = np.array([1.0, 
2.0, 3.0], dtype="float32") + arr = zarr.create( + store=store, + shape=data.shape, + dtype=data.dtype, + chunks=(3,), + codecs=[ScaleOffsetCodec(offset=1, scale=2), BytesCodec()], + ) + arr[:] = data + result = arr[:] + np.testing.assert_allclose(result, data, rtol=1e-6) + + def test_encode_decode_integer(self) -> None: + """Round-trip with integer data (uses integer arithmetic semantics).""" + store = MemoryStore() + data = np.array([10, 20, 30, 40, 50], dtype="int32") + arr = zarr.create( + store=store, + shape=data.shape, + dtype=data.dtype, + chunks=(5,), + codecs=[ScaleOffsetCodec(offset=10, scale=1), BytesCodec()], + ) + arr[:] = data + result = arr[:] + np.testing.assert_array_equal(result, data) + + def test_offset_only(self) -> None: + """Test with only offset (scale=1).""" + store = MemoryStore() + data = np.array([100.0, 200.0, 300.0], dtype="float64") + arr = zarr.create( + store=store, + shape=data.shape, + dtype=data.dtype, + chunks=(3,), + codecs=[ScaleOffsetCodec(offset=100), BytesCodec()], + ) + arr[:] = data + np.testing.assert_allclose(arr[:], data) + + def test_scale_only(self) -> None: + """Test with only scale (offset=0).""" + store = MemoryStore() + data = np.array([1.0, 2.0, 3.0], dtype="float64") + arr = zarr.create( + store=store, + shape=data.shape, + dtype=data.dtype, + chunks=(3,), + codecs=[ScaleOffsetCodec(scale=10), BytesCodec()], + ) + arr[:] = data + np.testing.assert_allclose(arr[:], data) + + def test_fill_value_transformed(self) -> None: + """Fill value should be transformed through the codec.""" + store = MemoryStore() + arr = zarr.create( + store=store, + shape=(5,), + dtype="float64", + chunks=(5,), + fill_value=100.0, + codecs=[ScaleOffsetCodec(offset=10, scale=2), BytesCodec()], + ) + # Without writing, reading should return the fill value + result = arr[:] + np.testing.assert_allclose(result, np.full(5, 100.0)) + + def test_validate_rejects_complex(self) -> None: + """Validate should reject complex dtypes.""" + with 
pytest.raises(ValueError, match="only supports integer and floating-point"): + zarr.create( + store=MemoryStore(), + shape=(5,), + dtype="complex128", + chunks=(5,), + codecs=[ScaleOffsetCodec(offset=1, scale=2), BytesCodec()], + ) + + def test_to_dict_no_config(self) -> None: + """Default codec should serialize without configuration.""" + codec = ScaleOffsetCodec() + assert codec.to_dict() == {"name": "scale_offset"} + + def test_to_dict_with_config(self) -> None: + """Non-default codec should include configuration.""" + codec = ScaleOffsetCodec(offset=5, scale=0.1) + d = codec.to_dict() + assert d == {"name": "scale_offset", "configuration": {"offset": 5, "scale": 0.1}} + + def test_to_dict_offset_only(self) -> None: + """Only offset in config when scale is default.""" + codec = ScaleOffsetCodec(offset=5) + d = codec.to_dict() + assert d == {"name": "scale_offset", "configuration": {"offset": 5}} + + def test_from_dict_no_config(self) -> None: + """Parse codec from JSON with no configuration.""" + codec = ScaleOffsetCodec.from_dict({"name": "scale_offset"}) + assert codec.offset == 0 + assert codec.scale == 1 + + def test_from_dict_with_config(self) -> None: + """Parse codec from JSON with configuration.""" + codec = ScaleOffsetCodec.from_dict( + {"name": "scale_offset", "configuration": {"offset": 5, "scale": 0.1}} + ) + assert codec.offset == 5 + assert codec.scale == 0.1 + + def test_roundtrip_json(self) -> None: + """to_dict -> from_dict should preserve parameters.""" + original = ScaleOffsetCodec(offset=3.14, scale=2.71) + restored = ScaleOffsetCodec.from_dict(original.to_dict()) + assert restored.offset == original.offset + assert restored.scale == original.scale + + +class TestCastValueCodec: + """Tests for the cast_value codec.""" + + def test_float64_to_float32(self) -> None: + """Cast float64 to float32 and back.""" + store = MemoryStore() + data = np.array([1.0, 2.0, 3.0], dtype="float64") + arr = zarr.create( + store=store, + shape=data.shape, + 
dtype=data.dtype, + chunks=(3,), + codecs=[CastValueCodec(data_type="float32"), BytesCodec()], + ) + arr[:] = data + result = arr[:] + np.testing.assert_allclose(result, data) + + def test_float64_to_int32_towards_zero(self) -> None: + """Cast float64 to int32 with towards-zero rounding.""" + store = MemoryStore() + data = np.array([1.7, -1.7, 2.3, -2.3], dtype="float64") + arr = zarr.create( + store=store, + shape=data.shape, + dtype=data.dtype, + chunks=(4,), + codecs=[CastValueCodec(data_type="int32", rounding="towards-zero"), BytesCodec()], + ) + arr[:] = data + result = arr[:] + # After encoding to int32 with towards-zero: [1, -1, 2, -2] + # After decoding back to float64: [1.0, -1.0, 2.0, -2.0] + np.testing.assert_array_equal(result, [1.0, -1.0, 2.0, -2.0]) + + def test_float64_to_uint8_clamp(self) -> None: + """Cast float64 to uint8 with clamping out-of-range values.""" + store = MemoryStore() + data = np.array([0.0, 128.0, 300.0, -10.0], dtype="float64") + arr = zarr.create( + store=store, + shape=data.shape, + dtype=data.dtype, + chunks=(4,), + codecs=[ + CastValueCodec(data_type="uint8", rounding="nearest-even", out_of_range="clamp"), + BytesCodec(), + ], + ) + arr[:] = data + result = arr[:] + np.testing.assert_array_equal(result, [0.0, 128.0, 255.0, 0.0]) + + def test_float64_to_int8_wrap(self) -> None: + """Cast float64 to int8 with wrapping for out-of-range values.""" + store = MemoryStore() + data = np.array([200.0], dtype="float64") + arr = zarr.create( + store=store, + shape=data.shape, + dtype=data.dtype, + chunks=(1,), + codecs=[ + CastValueCodec(data_type="int8", rounding="nearest-even", out_of_range="wrap"), + BytesCodec(), + ], + ) + arr[:] = data + result = arr[:] + # 200 wraps in int8 range [-128, 127]: (200 - (-128)) % 256 + (-128) = 328 % 256 - 128 = 72 - 128 = -56 + expected = np.array([200], dtype="float64") + expected_arr = np.array([200], dtype="float64") + # Encode: round(200) = 200, wrap: (200+128)%256-128 = 328%256-128 = 72-128 = 
-56 + # Decode: -56 cast back to float64 = -56.0 + np.testing.assert_array_equal(result, [-56.0]) + + def test_nan_to_integer_without_scalar_map_errors(self) -> None: + """NaN cast to integer without scalar_map should raise.""" + store = MemoryStore() + data = np.array([float("nan")], dtype="float64") + arr = zarr.create( + store=store, + shape=data.shape, + dtype=data.dtype, + chunks=(1,), + codecs=[CastValueCodec(data_type="uint8", out_of_range="clamp"), BytesCodec()], + ) + with pytest.raises(ValueError, match="Cannot cast NaN"): + arr[:] = data + + def test_nan_scalar_map(self) -> None: + """NaN should be mapped via scalar_map when provided.""" + store = MemoryStore() + data = np.array([1.0, float("nan"), 3.0], dtype="float64") + arr = zarr.create( + store=store, + shape=data.shape, + dtype=data.dtype, + chunks=(3,), + codecs=[ + CastValueCodec( + data_type="uint8", + out_of_range="clamp", + scalar_map={ + "encode": [["NaN", 0]], + "decode": [[0, "NaN"]], + }, + ), + BytesCodec(), + ], + ) + arr[:] = data + result = arr[:] + assert result[0] == 1.0 # 1.0 survives round-trip + assert np.isnan(result[1]) # NaN -> 0 -> NaN via scalar_map + assert result[2] == 3.0 + + def test_rounding_nearest_even(self) -> None: + """nearest-even rounding: 0.5 rounds to 0, 1.5 rounds to 2.""" + store = MemoryStore() + data = np.array([0.5, 1.5, 2.5, 3.5], dtype="float64") + arr = zarr.create( + store=store, + shape=data.shape, + dtype=data.dtype, + chunks=(4,), + codecs=[ + CastValueCodec(data_type="int32", rounding="nearest-even"), + BytesCodec(), + ], + ) + arr[:] = data + result = arr[:] + np.testing.assert_array_equal(result, [0.0, 2.0, 2.0, 4.0]) + + def test_rounding_towards_positive(self) -> None: + """towards-positive rounds up (ceil).""" + store = MemoryStore() + data = np.array([1.1, -1.1, 1.9, -1.9], dtype="float64") + arr = zarr.create( + store=store, + shape=data.shape, + dtype=data.dtype, + chunks=(4,), + codecs=[ + CastValueCodec(data_type="int32", 
rounding="towards-positive"), + BytesCodec(), + ], + ) + arr[:] = data + result = arr[:] + np.testing.assert_array_equal(result, [2.0, -1.0, 2.0, -1.0]) + + def test_rounding_towards_negative(self) -> None: + """towards-negative rounds down (floor).""" + store = MemoryStore() + data = np.array([1.1, -1.1, 1.9, -1.9], dtype="float64") + arr = zarr.create( + store=store, + shape=data.shape, + dtype=data.dtype, + chunks=(4,), + codecs=[ + CastValueCodec(data_type="int32", rounding="towards-negative"), + BytesCodec(), + ], + ) + arr[:] = data + result = arr[:] + np.testing.assert_array_equal(result, [1.0, -2.0, 1.0, -2.0]) + + def test_rounding_nearest_away(self) -> None: + """nearest-away rounds 0.5 away from zero.""" + store = MemoryStore() + data = np.array([0.5, 1.5, -0.5, -1.5], dtype="float64") + arr = zarr.create( + store=store, + shape=data.shape, + dtype=data.dtype, + chunks=(4,), + codecs=[ + CastValueCodec(data_type="int32", rounding="nearest-away"), + BytesCodec(), + ], + ) + arr[:] = data + result = arr[:] + np.testing.assert_array_equal(result, [1.0, 2.0, -1.0, -2.0]) + + def test_out_of_range_errors_by_default(self) -> None: + """Without out_of_range, values outside target range should error.""" + store = MemoryStore() + data = np.array([300.0], dtype="float64") + arr = zarr.create( + store=store, + shape=data.shape, + dtype=data.dtype, + chunks=(1,), + codecs=[CastValueCodec(data_type="uint8"), BytesCodec()], + ) + with pytest.raises(ValueError, match="out of range"): + arr[:] = data + + def test_wrap_only_valid_for_integers(self) -> None: + """wrap should be rejected for float target types.""" + with pytest.raises(ValueError, match="only valid for integer"): + zarr.create( + store=MemoryStore(), + shape=(5,), + dtype="float64", + chunks=(5,), + codecs=[ + CastValueCodec(data_type="float32", out_of_range="wrap"), + BytesCodec(), + ], + ) + + def test_validate_rejects_complex_source(self) -> None: + """Validate should reject complex source dtype.""" + 
with pytest.raises(ValueError, match="only supports integer and floating-point"): + zarr.create( + store=MemoryStore(), + shape=(5,), + dtype="complex128", + chunks=(5,), + codecs=[CastValueCodec(data_type="float64"), BytesCodec()], + ) + + def test_int32_to_int16_clamp(self) -> None: + """Integer-to-integer cast with clamping.""" + store = MemoryStore() + data = np.array([0, 100, 40000, -40000], dtype="int32") + arr = zarr.create( + store=store, + shape=data.shape, + dtype=data.dtype, + chunks=(4,), + codecs=[ + CastValueCodec(data_type="int16", out_of_range="clamp"), + BytesCodec(), + ], + ) + arr[:] = data + result = arr[:] + np.testing.assert_array_equal(result, [0, 100, 32767, -32768]) + + def test_to_dict(self) -> None: + """Serialization to dict.""" + codec = CastValueCodec( + data_type="uint8", + rounding="towards-zero", + out_of_range="clamp", + scalar_map={"encode": [["NaN", 0]], "decode": [[0, "NaN"]]}, + ) + d = codec.to_dict() + assert d["name"] == "cast_value" + assert d["configuration"]["data_type"] == "uint8" + assert d["configuration"]["rounding"] == "towards-zero" + assert d["configuration"]["out_of_range"] == "clamp" + assert d["configuration"]["scalar_map"] == { + "encode": [["NaN", 0]], + "decode": [[0, "NaN"]], + } + + def test_to_dict_minimal(self) -> None: + """Only required fields in dict when defaults are used.""" + codec = CastValueCodec(data_type="float32") + d = codec.to_dict() + assert d == {"name": "cast_value", "configuration": {"data_type": "float32"}} + + def test_from_dict(self) -> None: + """Deserialization from dict.""" + codec = CastValueCodec.from_dict( + { + "name": "cast_value", + "configuration": { + "data_type": "uint8", + "rounding": "towards-zero", + "out_of_range": "clamp", + }, + } + ) + assert codec.data_type == "uint8" + assert codec.rounding == "towards-zero" + assert codec.out_of_range == "clamp" + + def test_roundtrip_json(self) -> None: + """to_dict -> from_dict should preserve all parameters.""" + original = 
CastValueCodec( + data_type="int16", + rounding="towards-negative", + out_of_range="clamp", + scalar_map={"encode": [["NaN", 0]]}, + ) + restored = CastValueCodec.from_dict(original.to_dict()) + assert restored.data_type == original.data_type + assert restored.rounding == original.rounding + assert restored.out_of_range == original.out_of_range + assert restored.scalar_map == original.scalar_map + + def test_fill_value_cast(self) -> None: + """Fill value should be cast to the target dtype.""" + store = MemoryStore() + arr = zarr.create( + store=store, + shape=(5,), + dtype="float64", + chunks=(5,), + fill_value=42.0, + codecs=[CastValueCodec(data_type="int32"), BytesCodec()], + ) + result = arr[:] + np.testing.assert_array_equal(result, np.full(5, 42.0)) + + def test_computed_encoded_size(self) -> None: + """Encoded size should reflect the target dtype's item size.""" + codec = CastValueCodec(data_type="uint8") + from zarr.core.array_spec import ArrayConfig, ArraySpec + from zarr.core.buffer.cpu import buffer_prototype + from zarr.core.dtype import parse_dtype + + spec = ArraySpec( + shape=(10,), + dtype=parse_dtype("float64", zarr_format=3), + fill_value=0.0, + config=ArrayConfig.from_dict({}), + prototype=buffer_prototype, + ) + # 10 float64 elements = 80 bytes input, 10 uint8 elements = 10 bytes output + assert codec.compute_encoded_size(80, spec) == 10 + + +class TestScaleOffsetAndCastValueCombined: + """Tests for the combined scale_offset + cast_value codec pipeline.""" + + def test_float64_to_uint8_roundtrip(self) -> None: + """Typical usage: float64 -> scale_offset -> cast_value(uint8) -> bytes.""" + store = MemoryStore() + # Data in range [0, 25.5] maps to [0, 255] with scale=10 + data = np.array([0.0, 1.0, 2.5, 10.0, 25.5], dtype="float64") + arr = zarr.create( + store=store, + shape=data.shape, + dtype=data.dtype, + chunks=(5,), + codecs=[ + ScaleOffsetCodec(offset=0, scale=10), + CastValueCodec(data_type="uint8", out_of_range="clamp"), + BytesCodec(), + 
], + ) + arr[:] = data + result = arr[:] + np.testing.assert_allclose(result, data, atol=0.1) + + def test_temperature_storage_pattern(self) -> None: + """Realistic pattern: store temperature data as uint8. + + Temperature range: -10°C to 45°C + Encode: (temp - (-10)) * (255/55) = (temp + 10) * 4.636... + Use offset=-10, scale=255/55 + """ + store = MemoryStore() + offset = -10.0 + scale = 255.0 / 55.0 + data = np.array([-10.0, 0.0, 20.0, 37.5, 45.0], dtype="float64") + arr = zarr.create( + store=store, + shape=data.shape, + dtype=data.dtype, + chunks=(5,), + codecs=[ + ScaleOffsetCodec(offset=offset, scale=scale), + CastValueCodec(data_type="uint8", out_of_range="clamp"), + BytesCodec(), + ], + ) + arr[:] = data + result = arr[:] + # Precision limited by uint8 quantization (~0.22°C step) + np.testing.assert_allclose(result, data, atol=0.25) + + def test_nan_handling_pipeline(self) -> None: + """NaN values should be handled via scalar_map in cast_value.""" + store = MemoryStore() + data = np.array([1.0, float("nan"), 3.0], dtype="float64") + arr = zarr.create( + store=store, + shape=data.shape, + dtype=data.dtype, + chunks=(3,), + fill_value=float("nan"), + codecs=[ + ScaleOffsetCodec(offset=0, scale=1), + CastValueCodec( + data_type="uint8", + out_of_range="clamp", + scalar_map={ + "encode": [["NaN", 0]], + "decode": [[0, "NaN"]], + }, + ), + BytesCodec(), + ], + ) + arr[:] = data + result = arr[:] + assert result[0] == 1.0 + assert np.isnan(result[1]) + assert result[2] == 3.0 + + def test_metadata_persistence(self) -> None: + """Array metadata should be correctly persisted and reloaded.""" + store = MemoryStore() + arr = zarr.create( + store=store, + shape=(10,), + dtype="float64", + chunks=(10,), + codecs=[ + ScaleOffsetCodec(offset=5, scale=0.5), + CastValueCodec(data_type="int16", out_of_range="clamp"), + BytesCodec(), + ], + ) + # Reopen from same store + arr2 = zarr.open_array(store, mode="r") + assert arr2.dtype == np.dtype("float64") + assert arr2.shape 
== (10,) From 1487e59d7bb67e3d7f80d782ee7d73ae5965fe70 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Fri, 13 Mar 2026 16:46:31 +0100 Subject: [PATCH 2/4] rename codecs, improve types, and re-organize tests --- src/zarr/codecs/__init__.py | 12 +- src/zarr/codecs/cast_value.py | 30 ++- src/zarr/codecs/scale_offset.py | 63 +++-- ...cale_offset_cast.py => test_cast_value.py} | 224 +++--------------- tests/test_codecs/test_scale_offset.py | 166 +++++++++++++ 5 files changed, 265 insertions(+), 230 deletions(-) rename tests/test_codecs/{test_scale_offset_cast.py => test_cast_value.py} (63%) create mode 100644 tests/test_codecs/test_scale_offset.py diff --git a/src/zarr/codecs/__init__.py b/src/zarr/codecs/__init__.py index 27ba6778da..04b31d0d5f 100644 --- a/src/zarr/codecs/__init__.py +++ b/src/zarr/codecs/__init__.py @@ -2,7 +2,7 @@ from zarr.codecs.blosc import BloscCname, BloscCodec, BloscShuffle from zarr.codecs.bytes import BytesCodec, Endian -from zarr.codecs.cast_value import CastValueCodec +from zarr.codecs.cast_value import CastValue from zarr.codecs.crc32c_ import Crc32cCodec from zarr.codecs.gzip import GzipCodec from zarr.codecs.numcodecs import ( @@ -28,7 +28,7 @@ Zlib, Zstd, ) -from zarr.codecs.scale_offset import ScaleOffsetCodec +from zarr.codecs.scale_offset import ScaleOffset from zarr.codecs.sharding import ShardingCodec, ShardingCodecIndexLocation from zarr.codecs.transpose import TransposeCodec from zarr.codecs.vlen_utf8 import VLenBytesCodec, VLenUTF8Codec @@ -40,11 +40,11 @@ "BloscCodec", "BloscShuffle", "BytesCodec", - "CastValueCodec", + "CastValue", "Crc32cCodec", "Endian", "GzipCodec", - "ScaleOffsetCodec", + "ScaleOffset", "ShardingCodec", "ShardingCodecIndexLocation", "TransposeCodec", @@ -65,8 +65,8 @@ register_codec("vlen-utf8", VLenUTF8Codec) register_codec("vlen-bytes", VLenBytesCodec) register_codec("transpose", TransposeCodec) -register_codec("scale_offset", ScaleOffsetCodec) -register_codec("cast_value", CastValueCodec) 
+register_codec("scale_offset", ScaleOffset) +register_codec("cast_value", CastValue) # Register all the codecs formerly contained in numcodecs.zarr3 diff --git a/src/zarr/codecs/cast_value.py b/src/zarr/codecs/cast_value.py index df0d13c079..d5eb0c5122 100644 --- a/src/zarr/codecs/cast_value.py +++ b/src/zarr/codecs/cast_value.py @@ -1,18 +1,18 @@ from __future__ import annotations from dataclasses import dataclass, replace -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING, Literal, NotRequired, TypedDict import numpy as np from zarr.abc.codec import ArrayArrayCodec -from zarr.core.array_spec import ArraySpec from zarr.core.common import JSON, parse_named_configuration from zarr.core.dtype import get_data_type_from_json if TYPE_CHECKING: from typing import Self + from zarr.core.array_spec import ArraySpec from zarr.core.buffer import NDBuffer from zarr.core.chunk_grids import ChunkGrid from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar, ZDType @@ -27,7 +27,11 @@ OutOfRangeMode = Literal["clamp", "wrap"] -ScalarMapJSON = dict[str, list[list[JSON]]] + +class ScalarMapJSON(TypedDict): + encode: NotRequired[tuple[tuple[object, object]]] + decode: NotRequired[tuple[tuple[object, object]]] + # Pre-parsed scalar map entry: (source_float, target_float, source_is_nan) _MapEntry = tuple[float, float, bool] @@ -123,8 +127,12 @@ def _cast_array( range_size = hi - lo + 1 return ((arr.astype(np.int64) - lo) % range_size + lo).astype(target_dtype) else: + oor_vals = arr[(arr < lo) | (arr > hi)] raise ValueError( - f"Values out of range for {target_dtype} and no out_of_range policy set" + f"Values out of range for {target_dtype} (valid range: [{lo}, {hi}]), " + f"got values in [{arr_min}, {arr_max}]. " + f"Out-of-range values: {oor_vals.ravel()!r}. " + f"Set out_of_range='clamp' or out_of_range='wrap' to handle this." 
) return arr.astype(target_dtype) @@ -158,8 +166,12 @@ def _cast_array( info.min ) elif (work.min() < lo) or (work.max() > hi): + oor_vals = work[(work < lo) | (work > hi)] raise ValueError( - f"Values out of range for {target_dtype} and no out_of_range policy set" + f"Values out of range for {target_dtype} (valid range: [{lo}, {hi}]), " + f"got values in [{work.min()}, {work.max()}]. " + f"Out-of-range values: {oor_vals.ravel()!r}. " + f"Set out_of_range='clamp' or out_of_range='wrap' to handle this." ) return work.astype(target_dtype) @@ -191,8 +203,12 @@ def _cast_array( oor = (work < lo) | (work > hi) work[oor] = (work[oor] - lo) % range_size + lo else: + oor_vals = work[(work < lo) | (work > hi)] raise ValueError( - f"Values out of range for {target_dtype} and no out_of_range policy set" + f"Values out of range for {target_dtype} (valid range: [{lo}, {hi}]), " + f"got values in [{w_min}, {w_max}]. " + f"Out-of-range values: {oor_vals.ravel()!r}. " + f"Set out_of_range='clamp' or out_of_range='wrap' to handle this." ) return work.astype(target_dtype) @@ -222,7 +238,7 @@ def _parse_scalar_map( @dataclass(frozen=True) -class CastValueCodec(ArrayArrayCodec): +class CastValue(ArrayArrayCodec): """Cast-value array-to-array codec. 
Value-converts array elements to a new data type during encoding, diff --git a/src/zarr/codecs/scale_offset.py b/src/zarr/codecs/scale_offset.py index f4cd95ed52..55862137ae 100644 --- a/src/zarr/codecs/scale_offset.py +++ b/src/zarr/codecs/scale_offset.py @@ -1,24 +1,37 @@ from __future__ import annotations -from dataclasses import dataclass, replace -from typing import TYPE_CHECKING +from dataclasses import dataclass, field, replace +from typing import TYPE_CHECKING, Literal, NotRequired import numpy as np +from typing_extensions import TypedDict from zarr.abc.codec import ArrayArrayCodec -from zarr.core.array_spec import ArraySpec -from zarr.core.common import JSON, parse_named_configuration +from zarr.core.common import JSON, NamedConfig if TYPE_CHECKING: from typing import Self + from zarr.core.array_spec import ArraySpec from zarr.core.buffer import NDBuffer from zarr.core.chunk_grids import ChunkGrid from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar, ZDType -@dataclass(frozen=True) -class ScaleOffsetCodec(ArrayArrayCodec): +class ScaleOffsetConfig(TypedDict, closed=True): # type: ignore[call-arg] + scale: NotRequired[JSON] + offset: NotRequired[JSON] + + +ScaleOffsetName = Literal["scale_offset"] + + +class ScaleOffsetJSON(NamedConfig[ScaleOffsetName, ScaleOffsetConfig]): + """The JSON form(s) of the `scale_offset` codec""" + + +@dataclass(kw_only=True, frozen=True) +class ScaleOffset(ArrayArrayCodec): """Scale-offset array-to-array codec. Encodes values by subtracting an offset and multiplying by a scale factor. @@ -34,27 +47,21 @@ class ScaleOffsetCodec(ArrayArrayCodec): Value multiplied during encoding (after offset subtraction). Default is 1. 
""" - is_fixed_size = True - - offset: float - scale: float + is_fixed_size: bool = field(default=True, init=False) - def __init__(self, *, offset: float = 0, scale: float = 1) -> None: - object.__setattr__(self, "offset", offset) - object.__setattr__(self, "scale", scale) + offset: float = 0 + scale: float = 1 @classmethod - def from_dict(cls, data: dict[str, JSON]) -> Self: - _, configuration_parsed = parse_named_configuration( - data, "scale_offset", require_configuration=False - ) - configuration_parsed = configuration_parsed or {} - return cls(**configuration_parsed) # type: ignore[arg-type] - - def to_dict(self) -> dict[str, JSON]: + def from_dict(cls, data: ScaleOffsetJSON) -> Self: # type: ignore[override] + scale: float = data.get("configuration", {}).get("scale", 1) # type: ignore[assignment] + offset: float = data.get("configuration", {}).get("offset", 0) # type: ignore[assignment] + return cls(scale=scale, offset=offset) + + def to_dict(self) -> ScaleOffsetJSON: # type: ignore[override] if self.offset == 0 and self.scale == 1: return {"name": "scale_offset"} - config: dict[str, JSON] = {} + config: ScaleOffsetConfig = {} # if self.offset != 0: config["offset"] = self.offset if self.scale != 1: @@ -76,6 +83,10 @@ def validate( ) def resolve_metadata(self, chunk_spec: ArraySpec) -> ArraySpec: + """ + Define the effect of this codec on the spec for an array. The only change is to update + the output fill value by applying the scale + offset transformation. 
+ """ native_dtype = chunk_spec.dtype.to_native_dtype() fill = chunk_spec.fill_value new_fill = ( @@ -94,10 +105,10 @@ def _encode_sync( async def _encode_single( self, - chunk_array: NDBuffer, + chunk_data: NDBuffer, chunk_spec: ArraySpec, ) -> NDBuffer | None: - return self._encode_sync(chunk_array, chunk_spec) + return self._encode_sync(chunk_data, chunk_spec) def _decode_sync( self, @@ -110,10 +121,10 @@ def _decode_sync( async def _decode_single( self, - chunk_array: NDBuffer, + chunk_data: NDBuffer, chunk_spec: ArraySpec, ) -> NDBuffer: - return self._decode_sync(chunk_array, chunk_spec) + return self._decode_sync(chunk_data, chunk_spec) def compute_encoded_size(self, input_byte_length: int, _chunk_spec: ArraySpec) -> int: return input_byte_length diff --git a/tests/test_codecs/test_scale_offset_cast.py b/tests/test_codecs/test_cast_value.py similarity index 63% rename from tests/test_codecs/test_scale_offset_cast.py rename to tests/test_codecs/test_cast_value.py index f4a842ed87..16b7bd3859 100644 --- a/tests/test_codecs/test_scale_offset_cast.py +++ b/tests/test_codecs/test_cast_value.py @@ -1,4 +1,4 @@ -"""Tests for scale_offset and cast_value codecs.""" +"""Tests for the cast_value codec.""" from __future__ import annotations @@ -7,166 +7,11 @@ import zarr from zarr.codecs.bytes import BytesCodec -from zarr.codecs.cast_value import CastValueCodec -from zarr.codecs.scale_offset import ScaleOffsetCodec +from zarr.codecs.cast_value import CastValue +from zarr.codecs.scale_offset import ScaleOffset from zarr.storage import MemoryStore -class TestScaleOffsetCodec: - """Tests for the scale_offset codec.""" - - def test_identity(self) -> None: - """Default parameters (offset=0, scale=1) should be a no-op.""" - store = MemoryStore() - data = np.arange(20, dtype="float64").reshape(4, 5) - arr = zarr.create( - store=store, - shape=data.shape, - dtype=data.dtype, - chunks=(4, 5), - codecs=[ScaleOffsetCodec(), BytesCodec()], - ) - arr[:] = data - 
np.testing.assert_array_equal(arr[:], data) - - def test_encode_decode_float64(self) -> None: - """Encode/decode round-trip with float64 data.""" - store = MemoryStore() - data = np.array([10.0, 20.0, 30.0, 40.0, 50.0], dtype="float64") - arr = zarr.create( - store=store, - shape=data.shape, - dtype=data.dtype, - chunks=(5,), - codecs=[ScaleOffsetCodec(offset=10, scale=0.1), BytesCodec()], - ) - arr[:] = data - result = arr[:] - np.testing.assert_allclose(result, data, rtol=1e-10) - - def test_encode_decode_float32(self) -> None: - """Round-trip with float32 data.""" - store = MemoryStore() - data = np.array([1.0, 2.0, 3.0], dtype="float32") - arr = zarr.create( - store=store, - shape=data.shape, - dtype=data.dtype, - chunks=(3,), - codecs=[ScaleOffsetCodec(offset=1, scale=2), BytesCodec()], - ) - arr[:] = data - result = arr[:] - np.testing.assert_allclose(result, data, rtol=1e-6) - - def test_encode_decode_integer(self) -> None: - """Round-trip with integer data (uses integer arithmetic semantics).""" - store = MemoryStore() - data = np.array([10, 20, 30, 40, 50], dtype="int32") - arr = zarr.create( - store=store, - shape=data.shape, - dtype=data.dtype, - chunks=(5,), - codecs=[ScaleOffsetCodec(offset=10, scale=1), BytesCodec()], - ) - arr[:] = data - result = arr[:] - np.testing.assert_array_equal(result, data) - - def test_offset_only(self) -> None: - """Test with only offset (scale=1).""" - store = MemoryStore() - data = np.array([100.0, 200.0, 300.0], dtype="float64") - arr = zarr.create( - store=store, - shape=data.shape, - dtype=data.dtype, - chunks=(3,), - codecs=[ScaleOffsetCodec(offset=100), BytesCodec()], - ) - arr[:] = data - np.testing.assert_allclose(arr[:], data) - - def test_scale_only(self) -> None: - """Test with only scale (offset=0).""" - store = MemoryStore() - data = np.array([1.0, 2.0, 3.0], dtype="float64") - arr = zarr.create( - store=store, - shape=data.shape, - dtype=data.dtype, - chunks=(3,), - codecs=[ScaleOffsetCodec(scale=10), 
BytesCodec()], - ) - arr[:] = data - np.testing.assert_allclose(arr[:], data) - - def test_fill_value_transformed(self) -> None: - """Fill value should be transformed through the codec.""" - store = MemoryStore() - arr = zarr.create( - store=store, - shape=(5,), - dtype="float64", - chunks=(5,), - fill_value=100.0, - codecs=[ScaleOffsetCodec(offset=10, scale=2), BytesCodec()], - ) - # Without writing, reading should return the fill value - result = arr[:] - np.testing.assert_allclose(result, np.full(5, 100.0)) - - def test_validate_rejects_complex(self) -> None: - """Validate should reject complex dtypes.""" - with pytest.raises(ValueError, match="only supports integer and floating-point"): - zarr.create( - store=MemoryStore(), - shape=(5,), - dtype="complex128", - chunks=(5,), - codecs=[ScaleOffsetCodec(offset=1, scale=2), BytesCodec()], - ) - - def test_to_dict_no_config(self) -> None: - """Default codec should serialize without configuration.""" - codec = ScaleOffsetCodec() - assert codec.to_dict() == {"name": "scale_offset"} - - def test_to_dict_with_config(self) -> None: - """Non-default codec should include configuration.""" - codec = ScaleOffsetCodec(offset=5, scale=0.1) - d = codec.to_dict() - assert d == {"name": "scale_offset", "configuration": {"offset": 5, "scale": 0.1}} - - def test_to_dict_offset_only(self) -> None: - """Only offset in config when scale is default.""" - codec = ScaleOffsetCodec(offset=5) - d = codec.to_dict() - assert d == {"name": "scale_offset", "configuration": {"offset": 5}} - - def test_from_dict_no_config(self) -> None: - """Parse codec from JSON with no configuration.""" - codec = ScaleOffsetCodec.from_dict({"name": "scale_offset"}) - assert codec.offset == 0 - assert codec.scale == 1 - - def test_from_dict_with_config(self) -> None: - """Parse codec from JSON with configuration.""" - codec = ScaleOffsetCodec.from_dict( - {"name": "scale_offset", "configuration": {"offset": 5, "scale": 0.1}} - ) - assert codec.offset == 5 - 
assert codec.scale == 0.1 - - def test_roundtrip_json(self) -> None: - """to_dict -> from_dict should preserve parameters.""" - original = ScaleOffsetCodec(offset=3.14, scale=2.71) - restored = ScaleOffsetCodec.from_dict(original.to_dict()) - assert restored.offset == original.offset - assert restored.scale == original.scale - - class TestCastValueCodec: """Tests for the cast_value codec.""" @@ -179,7 +24,7 @@ def test_float64_to_float32(self) -> None: shape=data.shape, dtype=data.dtype, chunks=(3,), - codecs=[CastValueCodec(data_type="float32"), BytesCodec()], + codecs=[CastValue(data_type="float32"), BytesCodec()], ) arr[:] = data result = arr[:] @@ -194,7 +39,7 @@ def test_float64_to_int32_towards_zero(self) -> None: shape=data.shape, dtype=data.dtype, chunks=(4,), - codecs=[CastValueCodec(data_type="int32", rounding="towards-zero"), BytesCodec()], + codecs=[CastValue(data_type="int32", rounding="towards-zero"), BytesCodec()], ) arr[:] = data result = arr[:] @@ -212,7 +57,7 @@ def test_float64_to_uint8_clamp(self) -> None: dtype=data.dtype, chunks=(4,), codecs=[ - CastValueCodec(data_type="uint8", rounding="nearest-even", out_of_range="clamp"), + CastValue(data_type="uint8", rounding="nearest-even", out_of_range="clamp"), BytesCodec(), ], ) @@ -230,16 +75,13 @@ def test_float64_to_int8_wrap(self) -> None: dtype=data.dtype, chunks=(1,), codecs=[ - CastValueCodec(data_type="int8", rounding="nearest-even", out_of_range="wrap"), + CastValue(data_type="int8", rounding="nearest-even", out_of_range="wrap"), BytesCodec(), ], ) arr[:] = data result = arr[:] # 200 wraps in int8 range [-128, 127]: (200 - (-128)) % 256 + (-128) = 328 % 256 - 128 = 72 - 128 = -56 - expected = np.array([200], dtype="float64") - expected_arr = np.array([200], dtype="float64") - # Encode: round(200) = 200, wrap: (200+128)%256-128 = 328%256-128 = 72-128 = -56 # Decode: -56 cast back to float64 = -56.0 np.testing.assert_array_equal(result, [-56.0]) @@ -252,7 +94,7 @@ def 
test_nan_to_integer_without_scalar_map_errors(self) -> None: shape=data.shape, dtype=data.dtype, chunks=(1,), - codecs=[CastValueCodec(data_type="uint8", out_of_range="clamp"), BytesCodec()], + codecs=[CastValue(data_type="uint8", out_of_range="clamp"), BytesCodec()], ) with pytest.raises(ValueError, match="Cannot cast NaN"): arr[:] = data @@ -267,7 +109,7 @@ def test_nan_scalar_map(self) -> None: dtype=data.dtype, chunks=(3,), codecs=[ - CastValueCodec( + CastValue( data_type="uint8", out_of_range="clamp", scalar_map={ @@ -294,7 +136,7 @@ def test_rounding_nearest_even(self) -> None: dtype=data.dtype, chunks=(4,), codecs=[ - CastValueCodec(data_type="int32", rounding="nearest-even"), + CastValue(data_type="int32", rounding="nearest-even"), BytesCodec(), ], ) @@ -312,7 +154,7 @@ def test_rounding_towards_positive(self) -> None: dtype=data.dtype, chunks=(4,), codecs=[ - CastValueCodec(data_type="int32", rounding="towards-positive"), + CastValue(data_type="int32", rounding="towards-positive"), BytesCodec(), ], ) @@ -330,7 +172,7 @@ def test_rounding_towards_negative(self) -> None: dtype=data.dtype, chunks=(4,), codecs=[ - CastValueCodec(data_type="int32", rounding="towards-negative"), + CastValue(data_type="int32", rounding="towards-negative"), BytesCodec(), ], ) @@ -348,7 +190,7 @@ def test_rounding_nearest_away(self) -> None: dtype=data.dtype, chunks=(4,), codecs=[ - CastValueCodec(data_type="int32", rounding="nearest-away"), + CastValue(data_type="int32", rounding="nearest-away"), BytesCodec(), ], ) @@ -365,7 +207,7 @@ def test_out_of_range_errors_by_default(self) -> None: shape=data.shape, dtype=data.dtype, chunks=(1,), - codecs=[CastValueCodec(data_type="uint8"), BytesCodec()], + codecs=[CastValue(data_type="uint8"), BytesCodec()], ) with pytest.raises(ValueError, match="out of range"): arr[:] = data @@ -379,7 +221,7 @@ def test_wrap_only_valid_for_integers(self) -> None: dtype="float64", chunks=(5,), codecs=[ - CastValueCodec(data_type="float32", 
out_of_range="wrap"), + CastValue(data_type="float32", out_of_range="wrap"), BytesCodec(), ], ) @@ -392,7 +234,7 @@ def test_validate_rejects_complex_source(self) -> None: shape=(5,), dtype="complex128", chunks=(5,), - codecs=[CastValueCodec(data_type="float64"), BytesCodec()], + codecs=[CastValue(data_type="float64"), BytesCodec()], ) def test_int32_to_int16_clamp(self) -> None: @@ -405,7 +247,7 @@ def test_int32_to_int16_clamp(self) -> None: dtype=data.dtype, chunks=(4,), codecs=[ - CastValueCodec(data_type="int16", out_of_range="clamp"), + CastValue(data_type="int16", out_of_range="clamp"), BytesCodec(), ], ) @@ -415,7 +257,7 @@ def test_int32_to_int16_clamp(self) -> None: def test_to_dict(self) -> None: """Serialization to dict.""" - codec = CastValueCodec( + codec = CastValue( data_type="uint8", rounding="towards-zero", out_of_range="clamp", @@ -433,13 +275,13 @@ def test_to_dict(self) -> None: def test_to_dict_minimal(self) -> None: """Only required fields in dict when defaults are used.""" - codec = CastValueCodec(data_type="float32") + codec = CastValue(data_type="float32") d = codec.to_dict() assert d == {"name": "cast_value", "configuration": {"data_type": "float32"}} def test_from_dict(self) -> None: """Deserialization from dict.""" - codec = CastValueCodec.from_dict( + codec = CastValue.from_dict( { "name": "cast_value", "configuration": { @@ -455,13 +297,13 @@ def test_from_dict(self) -> None: def test_roundtrip_json(self) -> None: """to_dict -> from_dict should preserve all parameters.""" - original = CastValueCodec( + original = CastValue( data_type="int16", rounding="towards-negative", out_of_range="clamp", scalar_map={"encode": [["NaN", 0]]}, ) - restored = CastValueCodec.from_dict(original.to_dict()) + restored = CastValue.from_dict(original.to_dict()) assert restored.data_type == original.data_type assert restored.rounding == original.rounding assert restored.out_of_range == original.out_of_range @@ -476,14 +318,14 @@ def 
test_fill_value_cast(self) -> None: dtype="float64", chunks=(5,), fill_value=42.0, - codecs=[CastValueCodec(data_type="int32"), BytesCodec()], + codecs=[CastValue(data_type="int32"), BytesCodec()], ) result = arr[:] np.testing.assert_array_equal(result, np.full(5, 42.0)) def test_computed_encoded_size(self) -> None: """Encoded size should reflect the target dtype's item size.""" - codec = CastValueCodec(data_type="uint8") + codec = CastValue(data_type="uint8") from zarr.core.array_spec import ArrayConfig, ArraySpec from zarr.core.buffer.cpu import buffer_prototype from zarr.core.dtype import parse_dtype @@ -513,8 +355,8 @@ def test_float64_to_uint8_roundtrip(self) -> None: dtype=data.dtype, chunks=(5,), codecs=[ - ScaleOffsetCodec(offset=0, scale=10), - CastValueCodec(data_type="uint8", out_of_range="clamp"), + ScaleOffset(offset=0, scale=10), + CastValue(data_type="uint8", out_of_range="clamp"), BytesCodec(), ], ) @@ -539,8 +381,8 @@ def test_temperature_storage_pattern(self) -> None: dtype=data.dtype, chunks=(5,), codecs=[ - ScaleOffsetCodec(offset=offset, scale=scale), - CastValueCodec(data_type="uint8", out_of_range="clamp"), + ScaleOffset(offset=offset, scale=scale), + CastValue(data_type="uint8", out_of_range="clamp"), BytesCodec(), ], ) @@ -560,8 +402,8 @@ def test_nan_handling_pipeline(self) -> None: chunks=(3,), fill_value=float("nan"), codecs=[ - ScaleOffsetCodec(offset=0, scale=1), - CastValueCodec( + ScaleOffset(offset=0, scale=1), + CastValue( data_type="uint8", out_of_range="clamp", scalar_map={ @@ -581,14 +423,14 @@ def test_nan_handling_pipeline(self) -> None: def test_metadata_persistence(self) -> None: """Array metadata should be correctly persisted and reloaded.""" store = MemoryStore() - arr = zarr.create( + zarr.create( store=store, shape=(10,), dtype="float64", chunks=(10,), codecs=[ - ScaleOffsetCodec(offset=5, scale=0.5), - CastValueCodec(data_type="int16", out_of_range="clamp"), + ScaleOffset(offset=5, scale=0.5), + 
CastValue(data_type="int16", out_of_range="clamp"), BytesCodec(), ], ) diff --git a/tests/test_codecs/test_scale_offset.py b/tests/test_codecs/test_scale_offset.py new file mode 100644 index 0000000000..ad131c6493 --- /dev/null +++ b/tests/test_codecs/test_scale_offset.py @@ -0,0 +1,166 @@ +"""Tests for the scale_offset codec.""" + +from __future__ import annotations + +import numpy as np +import pytest + +import zarr +from zarr.codecs.bytes import BytesCodec +from zarr.codecs.scale_offset import ScaleOffset +from zarr.storage import MemoryStore + + +class TestScaleOffsetCodec: + """Tests for the scale_offset codec.""" + + def test_identity(self) -> None: + """Default parameters (offset=0, scale=1) should be a no-op.""" + store = MemoryStore() + data = np.arange(20, dtype="float64").reshape(4, 5) + arr = zarr.create( + store=store, + shape=data.shape, + dtype=data.dtype, + chunks=(4, 5), + codecs=[ScaleOffset(), BytesCodec()], + ) + arr[:] = data + np.testing.assert_array_equal(arr[:], data) + + def test_encode_decode_float64(self) -> None: + """Encode/decode round-trip with float64 data.""" + store = MemoryStore() + data = np.array([10.0, 20.0, 30.0, 40.0, 50.0], dtype="float64") + arr = zarr.create( + store=store, + shape=data.shape, + dtype=data.dtype, + chunks=(5,), + codecs=[ScaleOffset(offset=10, scale=0.1), BytesCodec()], + ) + arr[:] = data + result = arr[:] + np.testing.assert_allclose(result, data, rtol=1e-10) + + def test_encode_decode_float32(self) -> None: + """Round-trip with float32 data.""" + store = MemoryStore() + data = np.array([1.0, 2.0, 3.0], dtype="float32") + arr = zarr.create( + store=store, + shape=data.shape, + dtype=data.dtype, + chunks=(3,), + codecs=[ScaleOffset(offset=1, scale=2), BytesCodec()], + ) + arr[:] = data + result = arr[:] + np.testing.assert_allclose(result, data, rtol=1e-6) + + def test_encode_decode_integer(self) -> None: + """Round-trip with integer data (uses integer arithmetic semantics).""" + store = MemoryStore() + 
data = np.array([10, 20, 30, 40, 50], dtype="int32") + arr = zarr.create( + store=store, + shape=data.shape, + dtype=data.dtype, + chunks=(5,), + codecs=[ScaleOffset(offset=10, scale=1), BytesCodec()], + ) + arr[:] = data + result = arr[:] + np.testing.assert_array_equal(result, data) + + def test_offset_only(self) -> None: + """Test with only offset (scale=1).""" + store = MemoryStore() + data = np.array([100.0, 200.0, 300.0], dtype="float64") + arr = zarr.create( + store=store, + shape=data.shape, + dtype=data.dtype, + chunks=(3,), + codecs=[ScaleOffset(offset=100), BytesCodec()], + ) + arr[:] = data + np.testing.assert_allclose(arr[:], data) + + def test_scale_only(self) -> None: + """Test with only scale (offset=0).""" + store = MemoryStore() + data = np.array([1.0, 2.0, 3.0], dtype="float64") + arr = zarr.create( + store=store, + shape=data.shape, + dtype=data.dtype, + chunks=(3,), + codecs=[ScaleOffset(scale=10), BytesCodec()], + ) + arr[:] = data + np.testing.assert_allclose(arr[:], data) + + def test_fill_value_transformed(self) -> None: + """Fill value should be transformed through the codec.""" + store = MemoryStore() + arr = zarr.create( + store=store, + shape=(5,), + dtype="float64", + chunks=(5,), + fill_value=100.0, + codecs=[ScaleOffset(offset=10, scale=2), BytesCodec()], + ) + # Without writing, reading should return the fill value + result = arr[:] + np.testing.assert_allclose(result, np.full(5, 100.0)) + + def test_validate_rejects_complex(self) -> None: + """Validate should reject complex dtypes.""" + with pytest.raises(ValueError, match="only supports integer and floating-point"): + zarr.create( + store=MemoryStore(), + shape=(5,), + dtype="complex128", + chunks=(5,), + codecs=[ScaleOffset(offset=1, scale=2), BytesCodec()], + ) + + def test_to_dict_no_config(self) -> None: + """Default codec should serialize without configuration.""" + codec = ScaleOffset() + assert codec.to_dict() == {"name": "scale_offset"} + + def 
test_to_dict_with_config(self) -> None: + """Non-default codec should include configuration.""" + codec = ScaleOffset(offset=5, scale=0.1) + d = codec.to_dict() + assert d == {"name": "scale_offset", "configuration": {"offset": 5, "scale": 0.1}} + + def test_to_dict_offset_only(self) -> None: + """Only offset in config when scale is default.""" + codec = ScaleOffset(offset=5) + d = codec.to_dict() + assert d == {"name": "scale_offset", "configuration": {"offset": 5}} + + def test_from_dict_no_config(self) -> None: + """Parse codec from JSON with no configuration.""" + codec = ScaleOffset.from_dict({"name": "scale_offset"}) + assert codec.offset == 0 + assert codec.scale == 1 + + def test_from_dict_with_config(self) -> None: + """Parse codec from JSON with configuration.""" + codec = ScaleOffset.from_dict( + {"name": "scale_offset", "configuration": {"offset": 5, "scale": 0.1}} + ) + assert codec.offset == 5 + assert codec.scale == 0.1 + + def test_roundtrip_json(self) -> None: + """to_dict -> from_dict should preserve parameters.""" + original = ScaleOffset(offset=3.14, scale=2.71) + restored = ScaleOffset.from_dict(original.to_dict()) + assert restored.offset == original.offset + assert restored.scale == original.scale From 407ae82998cc173878fea34d2325c8a2f16808df Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Mon, 16 Mar 2026 13:17:20 +0100 Subject: [PATCH 3/4] code quality cleanup, handle lossy float -> int conversion, remove stray chunk grids file --- src/zarr/codecs/cast_value.py | 382 ++++++----- src/zarr/codecs/scale_offset.py | 15 +- src/zarr/core/chunk_grids/rectilinear.py | 802 ----------------------- tests/test_codecs/test_cast_value.py | 128 +++- tests/test_codecs/test_scale_offset.py | 16 +- 5 files changed, 329 insertions(+), 1014 deletions(-) delete mode 100644 src/zarr/core/chunk_grids/rectilinear.py diff --git a/src/zarr/codecs/cast_value.py b/src/zarr/codecs/cast_value.py index d5eb0c5122..ee1b2afbef 100644 --- 
a/src/zarr/codecs/cast_value.py +++ b/src/zarr/codecs/cast_value.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass, replace -from typing import TYPE_CHECKING, Literal, NotRequired, TypedDict +from typing import TYPE_CHECKING, Any, Literal, NotRequired, TypeAlias, TypedDict, cast import numpy as np @@ -17,6 +17,8 @@ from zarr.core.chunk_grids import ChunkGrid from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar, ZDType +NumericScalar: TypeAlias = np.integer[Any] | np.floating[Any] + RoundingMode = Literal[ "nearest-even", "towards-zero", @@ -33,208 +35,189 @@ class ScalarMapJSON(TypedDict): decode: NotRequired[tuple[tuple[object, object]]] -# Pre-parsed scalar map entry: (source_float, target_float, source_is_nan) -_MapEntry = tuple[float, float, bool] - +# Pre-parsed scalar map entry: (source_scalar, target_scalar) +_MapEntry = tuple[NumericScalar, NumericScalar] -def _special_float(s: str) -> float: - """Convert special float string representations to float values.""" - if s == "NaN": - return float("nan") - if s in ("+Infinity", "Infinity"): - return float("inf") - if s == "-Infinity": - return float("-inf") - return float(s) +def _parse_map_entries( + mapping: dict[str, str], + src_dtype: ZDType[TBaseDType, TBaseScalar], + tgt_dtype: ZDType[TBaseDType, TBaseScalar], +) -> list[_MapEntry]: + """Pre-parse a scalar map dict into a list of (src, tgt) tuples. -def _parse_map_entries(mapping: dict[str, str]) -> list[_MapEntry]: - """Pre-parse a scalar map dict into a list of (src, tgt, src_is_nan) tuples.""" + Each entry's source value is deserialized using ``src_dtype`` and its target + value using ``tgt_dtype``, preserving full precision for both data types. 
+ """ entries: list[_MapEntry] = [] for src_str, tgt_str in mapping.items(): - src = _special_float(src_str) - tgt = _special_float(tgt_str) - entries.append((src, tgt, np.isnan(src))) + src = src_dtype.from_json_scalar(src_str, zarr_format=3) + tgt = tgt_dtype.from_json_scalar(tgt_str, zarr_format=3) + entries.append((src, tgt)) # type: ignore[arg-type] return entries -def _apply_scalar_map(work: np.ndarray, entries: list[_MapEntry]) -> None: +def _apply_scalar_map(work: np.ndarray[Any, np.dtype[Any]], entries: list[_MapEntry]) -> None: """Apply scalar map entries in-place. Single pass per entry.""" - for src, tgt, src_is_nan in entries: - if src_is_nan: + for src, tgt in entries: + if isinstance(src, (float, np.floating)) and np.isnan(src): mask = np.isnan(work) else: mask = work == src work[mask] = tgt -def _round_inplace(arr: np.ndarray, mode: RoundingMode) -> np.ndarray: +def _round_inplace( + arr: np.ndarray[Any, np.dtype[Any]], mode: RoundingMode +) -> np.ndarray[Any, np.dtype[Any]]: """Round array, returning result (may or may not be a new array). For nearest-away, requires 3 numpy ops. All others are a single op. 
""" - if mode == "nearest-even": - return np.rint(arr) - elif mode == "towards-zero": - return np.trunc(arr) - elif mode == "towards-positive": - return np.ceil(arr) - elif mode == "towards-negative": - return np.floor(arr) - elif mode == "nearest-away": - return np.sign(arr) * np.floor(np.abs(arr) + 0.5) + match mode: + case "nearest-even": + return np.rint(arr) # type: ignore [no-any-return] + case "towards-zero": + return np.trunc(arr) # type: ignore [no-any-return] + case "towards-positive": + return np.ceil(arr) # type: ignore [no-any-return] + case "towards-negative": + return np.floor(arr) # type: ignore [no-any-return] + case "nearest-away": + return np.sign(arr) * np.floor(np.abs(arr) + 0.5) # type: ignore [no-any-return] raise ValueError(f"Unknown rounding mode: {mode}") def _cast_array( - arr: np.ndarray, - target_dtype: np.dtype, - rounding: RoundingMode, - out_of_range: OutOfRangeMode | None, + arr: np.ndarray[Any, np.dtype[Any]], + *, + target_dtype: np.dtype[Any], + rounding_mode: RoundingMode, + out_of_range_mode: OutOfRangeMode | None, scalar_map_entries: list[_MapEntry] | None, -) -> np.ndarray: +) -> np.ndarray[Any, np.dtype[Any]]: """Cast an array to target_dtype with rounding, out-of-range, and scalar_map handling. Optimized to minimize allocations and passes over the data. For the simple case (no scalar_map, no rounding needed, no out-of-range), this is essentially just ``arr.astype(target_dtype)``. + + All casts are performed under ``np.errstate(over='raise', invalid='raise')`` + so that numpy overflow or invalid-value warnings become hard errors instead + of being silently swallowed. 
""" - src_is_int = np.issubdtype(arr.dtype, np.integer) - src_is_float = np.issubdtype(arr.dtype, np.floating) - tgt_is_int = np.issubdtype(target_dtype, np.integer) - tgt_is_float = np.issubdtype(target_dtype, np.floating) - - # Fast path: float→float with no scalar_map — single astype - if src_is_float and tgt_is_float and not scalar_map_entries: - return arr.astype(target_dtype) - - # Fast path: int→float with no scalar_map — single astype - if src_is_int and tgt_is_float and not scalar_map_entries: - return arr.astype(target_dtype) - - # Fast path: int→int with no scalar_map — check range then astype - if src_is_int and tgt_is_int and not scalar_map_entries: - # Check if source range could exceed target range - if arr.dtype.itemsize > target_dtype.itemsize or arr.dtype != target_dtype: - info = np.iinfo(target_dtype) - lo, hi = int(info.min), int(info.max) - arr_min, arr_max = int(arr.min()), int(arr.max()) - if arr_min >= lo and arr_max <= hi: - return arr.astype(target_dtype) - if out_of_range == "clamp": - return np.clip(arr, lo, hi).astype(target_dtype) - elif out_of_range == "wrap": - range_size = hi - lo + 1 - return ((arr.astype(np.int64) - lo) % range_size + lo).astype(target_dtype) - else: - oor_vals = arr[(arr < lo) | (arr > hi)] - raise ValueError( - f"Values out of range for {target_dtype} (valid range: [{lo}, {hi}]), " - f"got values in [{arr_min}, {arr_max}]. " - f"Out-of-range values: {oor_vals.ravel()!r}. " - f"Set out_of_range='clamp' or out_of_range='wrap' to handle this." 
- ) - return arr.astype(target_dtype) + with np.errstate(over="raise", invalid="raise"): + return _cast_array_impl( + arr, + target_dtype=target_dtype, + rounding=rounding_mode, + out_of_range=out_of_range_mode, + scalar_map_entries=scalar_map_entries, + ) - # float→int: needs rounding, range check, possibly scalar_map - if src_is_float and tgt_is_int: - # Work in float64 for the arithmetic - if arr.dtype != np.float64: - work = arr.astype(np.float64) - else: - work = arr.copy() - if scalar_map_entries: - _apply_scalar_map(work, scalar_map_entries) - - # Check for unmapped NaN/Inf - bad = np.isnan(work) | np.isinf(work) - if bad.any(): - raise ValueError("Cannot cast NaN or Infinity to integer type without scalar_map") - - work = _round_inplace(work, rounding) - - info = np.iinfo(target_dtype) - lo, hi = float(info.min), float(info.max) - if out_of_range == "clamp": - np.clip(work, lo, hi, out=work) - elif out_of_range == "wrap": - range_size = int(info.max) - int(info.min) + 1 - oor = (work < lo) | (work > hi) - if oor.any(): - work[oor] = (work[oor].astype(np.int64) - int(info.min)) % range_size + int( - info.min - ) - elif (work.min() < lo) or (work.max() > hi): +def _check_int_range( + work: np.ndarray[Any, np.dtype[Any]], + *, + target_dtype: np.dtype[Any], + out_of_range: OutOfRangeMode | None, +) -> np.ndarray[Any, np.dtype[Any]]: + """Check integer range and apply out-of-range handling, then cast.""" + info = np.iinfo(target_dtype) + lo, hi = int(info.min), int(info.max) + w_min, w_max = int(work.min()), int(work.max()) + if w_min >= lo and w_max <= hi: + return work.astype(target_dtype) + match out_of_range: + case "clamp": + return np.clip(work, lo, hi).astype(target_dtype) + case "wrap": + range_size = hi - lo + 1 + return ((work.astype(np.int64) - lo) % range_size + lo).astype(target_dtype) + case None: oor_vals = work[(work < lo) | (work > hi)] raise ValueError( f"Values out of range for {target_dtype} (valid range: [{lo}, {hi}]), " - f"got values in 
[{work.min()}, {work.max()}]. " + f"got values in [{w_min}, {w_max}]. " f"Out-of-range values: {oor_vals.ravel()!r}. " f"Set out_of_range='clamp' or out_of_range='wrap' to handle this." ) - return work.astype(target_dtype) - # int→float with scalar_map - if src_is_int and tgt_is_float and scalar_map_entries: - work = arr.astype(np.float64) - _apply_scalar_map(work, scalar_map_entries) - return work.astype(target_dtype) +def _cast_array_impl( + arr: np.ndarray[Any, np.dtype[Any]], + *, + target_dtype: np.dtype[Any], + rounding: RoundingMode, + out_of_range: OutOfRangeMode | None, + scalar_map_entries: list[_MapEntry] | None, +) -> np.ndarray[Any, np.dtype[Any]]: + src_type: Literal["int", "float"] = "int" if np.issubdtype(arr.dtype, np.integer) else "float" + tgt_type: Literal["int", "float"] = ( + "int" if np.issubdtype(target_dtype, np.integer) else "float" + ) + has_map = bool(scalar_map_entries) - # float→float with scalar_map - if src_is_float and tgt_is_float and scalar_map_entries: - work = arr.copy() - _apply_scalar_map(work, scalar_map_entries) - return work.astype(target_dtype) + match (src_type, tgt_type, has_map): + # float→float or int→float without scalar_map — single astype + case (_, "float", False): + return arr.astype(target_dtype) + + # int→float with scalar_map — widen to float64, apply map, cast + case ("int", "float", True): + work = arr.astype(np.float64) + _apply_scalar_map(work, scalar_map_entries) # type: ignore[arg-type] + return work.astype(target_dtype) - # int→int with scalar_map - if src_is_int and tgt_is_int and scalar_map_entries: - work = arr.astype(np.int64) - _apply_scalar_map(work, scalar_map_entries) - info = np.iinfo(target_dtype) - lo, hi = int(info.min), int(info.max) - w_min, w_max = int(work.min()), int(work.max()) - if w_min < lo or w_max > hi: - if out_of_range == "clamp": - np.clip(work, lo, hi, out=work) - elif out_of_range == "wrap": - range_size = hi - lo + 1 - oor = (work < lo) | (work > hi) - work[oor] = (work[oor] 
- lo) % range_size + lo + # float→float with scalar_map — copy, apply map, cast + case ("float", "float", True): + work = arr.copy() + _apply_scalar_map(work, scalar_map_entries) # type: ignore[arg-type] + return work.astype(target_dtype) + + # int→int without scalar_map — range check then astype + case ("int", "int", False): + if arr.dtype.itemsize > target_dtype.itemsize or arr.dtype != target_dtype: + return _check_int_range(arr, target_dtype=target_dtype, out_of_range=out_of_range) + return arr.astype(target_dtype) + + # int→int with scalar_map — widen to int64, apply map, range check + case ("int", "int", True): + work = arr.astype(np.int64) + _apply_scalar_map(work, scalar_map_entries) # type: ignore[arg-type] + return _check_int_range(work, target_dtype=target_dtype, out_of_range=out_of_range) + + # float→int (with or without scalar_map) — rounding + range check + case ("float", "int", _): + if arr.dtype != np.float64: + work = arr.astype(np.float64) else: - oor_vals = work[(work < lo) | (work > hi)] - raise ValueError( - f"Values out of range for {target_dtype} (valid range: [{lo}, {hi}]), " - f"got values in [{w_min}, {w_max}]. " - f"Out-of-range values: {oor_vals.ravel()!r}. " - f"Set out_of_range='clamp' or out_of_range='wrap' to handle this." - ) - return work.astype(target_dtype) + work = arr.copy() - # Fallback - return arr.astype(target_dtype) + if scalar_map_entries: + _apply_scalar_map(work, scalar_map_entries) + bad = np.isnan(work) | np.isinf(work) + if bad.any(): + raise ValueError("Cannot cast NaN or Infinity to integer type without scalar_map") -def _parse_scalar_map( - data: ScalarMapJSON | None, -) -> tuple[list[_MapEntry] | None, list[_MapEntry] | None]: - """Parse scalar_map JSON into pre-parsed encode and decode entry lists. + work = _round_inplace(work, rounding) + return _check_int_range(work, target_dtype=target_dtype, out_of_range=out_of_range) - Returns (encode_entries, decode_entries). Either may be None. 
- """ + raise AssertionError( + f"Unhandled type combination: src={src_type}, tgt={tgt_type}" + ) # pragma: no cover + + +def _extract_raw_map(data: ScalarMapJSON | None, direction: str) -> dict[str, str] | None: + """Extract raw string mapping from scalar_map JSON for 'encode' or 'decode'.""" if data is None: - return None, None - encode_raw: dict[str, str] = {} - decode_raw: dict[str, str] = {} - for src, tgt in data.get("encode", []): - encode_raw[str(src)] = str(tgt) - for src, tgt in data.get("decode", []): - decode_raw[str(src)] = str(tgt) - return ( - _parse_map_entries(encode_raw) if encode_raw else None, - _parse_map_entries(decode_raw) if decode_raw else None, - ) + return None + raw: dict[str, str] = {} + pairs = data.get(direction, []) + for src, tgt in pairs: # type: ignore[attr-defined] + raw[str(src)] = str(tgt) + return raw or None @dataclass(frozen=True) @@ -260,7 +243,7 @@ class CastValue(ArrayArrayCodec): is_fixed_size = True - data_type: str + dtype: ZDType[TBaseDType, TBaseScalar] rounding: RoundingMode out_of_range: OutOfRangeMode | None scalar_map: ScalarMapJSON | None @@ -268,12 +251,16 @@ class CastValue(ArrayArrayCodec): def __init__( self, *, - data_type: str, + data_type: str | ZDType[TBaseDType, TBaseScalar], rounding: RoundingMode = "nearest-even", out_of_range: OutOfRangeMode | None = None, scalar_map: ScalarMapJSON | None = None, ) -> None: - object.__setattr__(self, "data_type", data_type) + if isinstance(data_type, str): + dtype = get_data_type_from_json(data_type, zarr_format=3) + else: + dtype = data_type + object.__setattr__(self, "dtype", dtype) object.__setattr__(self, "rounding", rounding) object.__setattr__(self, "out_of_range", out_of_range) object.__setattr__(self, "scalar_map", scalar_map) @@ -286,18 +273,15 @@ def from_dict(cls, data: dict[str, JSON]) -> Self: return cls(**configuration_parsed) # type: ignore[arg-type] def to_dict(self) -> dict[str, JSON]: - config: dict[str, JSON] = {"data_type": self.data_type} + 
config: dict[str, JSON] = {"data_type": cast(JSON, self.dtype.to_json(zarr_format=3))} if self.rounding != "nearest-even": config["rounding"] = self.rounding if self.out_of_range is not None: config["out_of_range"] = self.out_of_range if self.scalar_map is not None: - config["scalar_map"] = self.scalar_map + config["scalar_map"] = cast(JSON, self.scalar_map) return {"name": "cast_value", "configuration": config} - def _target_zdtype(self) -> ZDType[TBaseDType, TBaseScalar]: - return get_data_type_from_json(self.data_type, zarr_format=3) - def validate( self, *, @@ -306,29 +290,61 @@ def validate( chunk_grid: ChunkGrid, ) -> None: source_native = dtype.to_native_dtype() - target_native = self._target_zdtype().to_native_dtype() + target_native = self.dtype.to_native_dtype() for label, dt in [("source", source_native), ("target", target_native)]: if not np.issubdtype(dt, np.integer) and not np.issubdtype(dt, np.floating): raise ValueError( f"cast_value codec only supports integer and floating-point data types. " f"Got {label} dtype {dt}." ) - if self.out_of_range == "wrap": - if not np.issubdtype(target_native, np.integer): - raise ValueError("out_of_range='wrap' is only valid for integer target types.") + if self.out_of_range == "wrap" and not np.issubdtype(target_native, np.integer): + raise ValueError("out_of_range='wrap' is only valid for integer target types.") + # Check that int→float casts won't silently lose precision. + # A float type with `m` mantissa bits can exactly represent all integers + # in [-2**m, 2**m]. If the integer type's range exceeds that, the cast is lossy. 
+ if np.issubdtype(source_native, np.integer) and np.issubdtype(target_native, np.floating): + int_info = np.iinfo(source_native) # type: ignore[type-var] + mantissa_bits = np.finfo(target_native).nmant # type: ignore[arg-type] + max_exact_int = 2**mantissa_bits + if int_info.max > max_exact_int or int_info.min < -max_exact_int: + raise ValueError( + f"Casting {source_native} to {target_native} may silently lose precision. " + f"{target_native} can only exactly represent integers up to 2**{mantissa_bits} " + f"({max_exact_int}), but {source_native} has range " + f"[{int_info.min}, {int_info.max}]." + ) + # Same check for float→int decode direction + if np.issubdtype(target_native, np.integer) and np.issubdtype(source_native, np.floating): + int_info = np.iinfo(target_native) # type: ignore[type-var] + mantissa_bits = np.finfo(source_native).nmant # type: ignore[arg-type] + max_exact_int = 2**mantissa_bits + if int_info.max > max_exact_int or int_info.min < -max_exact_int: + raise ValueError( + f"Casting {source_native} to {target_native} may silently lose precision. " + f"{source_native} can only exactly represent integers up to 2**{mantissa_bits} " + f"({max_exact_int}), but {target_native} has range " + f"[{int_info.min}, {int_info.max}]." 
+ ) def resolve_metadata(self, chunk_spec: ArraySpec) -> ArraySpec: - target_zdtype = self._target_zdtype() + target_zdtype = self.dtype target_native = target_zdtype.to_native_dtype() source_native = chunk_spec.dtype.to_native_dtype() fill = chunk_spec.fill_value fill_arr = np.array([fill], dtype=source_native) - encode_entries, _ = _parse_scalar_map(self.scalar_map) + encode_raw = _extract_raw_map(self.scalar_map, "encode") + encode_entries = ( + _parse_map_entries(encode_raw, chunk_spec.dtype, self.dtype) if encode_raw else None + ) new_fill_arr = _cast_array( - fill_arr, target_native, self.rounding, self.out_of_range, encode_entries + fill_arr, + target_dtype=target_native, + rounding_mode=self.rounding, + out_of_range_mode=self.out_of_range, + scalar_map_entries=encode_entries, ) new_fill = target_native.type(new_fill_arr[0]) @@ -340,12 +356,19 @@ def _encode_sync( _chunk_spec: ArraySpec, ) -> NDBuffer | None: arr = chunk_array.as_ndarray_like() - target_native = self._target_zdtype().to_native_dtype() + target_native = self.dtype.to_native_dtype() - encode_entries, _ = _parse_scalar_map(self.scalar_map) + encode_raw = _extract_raw_map(self.scalar_map, "encode") + encode_entries = ( + _parse_map_entries(encode_raw, _chunk_spec.dtype, self.dtype) if encode_raw else None + ) result = _cast_array( - np.asarray(arr), target_native, self.rounding, self.out_of_range, encode_entries + np.asarray(arr), + target_dtype=target_native, + rounding_mode=self.rounding, + out_of_range_mode=self.out_of_range, + scalar_map_entries=encode_entries, ) return chunk_array.__class__.from_ndarray_like(result) @@ -364,10 +387,17 @@ def _decode_sync( arr = chunk_array.as_ndarray_like() target_native = chunk_spec.dtype.to_native_dtype() - _, decode_entries = _parse_scalar_map(self.scalar_map) + decode_raw = _extract_raw_map(self.scalar_map, "decode") + decode_entries = ( + _parse_map_entries(decode_raw, self.dtype, chunk_spec.dtype) if decode_raw else None + ) result = _cast_array( - 
np.asarray(arr), target_native, self.rounding, self.out_of_range, decode_entries + np.asarray(arr), + target_dtype=target_native, + rounding_mode=self.rounding, + out_of_range_mode=self.out_of_range, + scalar_map_entries=decode_entries, ) return chunk_array.__class__.from_ndarray_like(result) @@ -380,7 +410,7 @@ async def _decode_single( def compute_encoded_size(self, input_byte_length: int, chunk_spec: ArraySpec) -> int: source_itemsize = chunk_spec.dtype.to_native_dtype().itemsize - target_itemsize = self._target_zdtype().to_native_dtype().itemsize + target_itemsize = self.dtype.to_native_dtype().itemsize if source_itemsize == 0: return 0 num_elements = input_byte_length // source_itemsize diff --git a/src/zarr/codecs/scale_offset.py b/src/zarr/codecs/scale_offset.py index 55862137ae..a5037fdbc0 100644 --- a/src/zarr/codecs/scale_offset.py +++ b/src/zarr/codecs/scale_offset.py @@ -7,7 +7,7 @@ from typing_extensions import TypedDict from zarr.abc.codec import ArrayArrayCodec -from zarr.core.common import JSON, NamedConfig +from zarr.core.common import JSON, NamedConfig, parse_named_configuration if TYPE_CHECKING: from typing import Self @@ -39,7 +39,7 @@ class ScaleOffset(ArrayArrayCodec): All arithmetic uses the input array's data type semantics. - Parameters + Attributes ---------- offset : float Value subtracted during encoding. Default is 0. 
@@ -54,9 +54,12 @@ class ScaleOffset(ArrayArrayCodec): @classmethod def from_dict(cls, data: ScaleOffsetJSON) -> Self: # type: ignore[override] - scale: float = data.get("configuration", {}).get("scale", 1) # type: ignore[assignment] - offset: float = data.get("configuration", {}).get("offset", 0) # type: ignore[assignment] - return cls(scale=scale, offset=offset) + _, configuration_parsed = parse_named_configuration( + data, "scale_offset", require_configuration=False + ) + if configuration_parsed is None: + return cls() + return cls(**configuration_parsed) # type: ignore[arg-type] def to_dict(self) -> ScaleOffsetJSON: # type: ignore[override] if self.offset == 0 and self.scale == 1: @@ -90,7 +93,7 @@ def resolve_metadata(self, chunk_spec: ArraySpec) -> ArraySpec: native_dtype = chunk_spec.dtype.to_native_dtype() fill = chunk_spec.fill_value new_fill = ( - np.dtype(native_dtype).type(fill) - native_dtype.type(self.offset) + np.dtype(native_dtype).type(fill) - native_dtype.type(self.offset) # type: ignore[operator] ) * native_dtype.type(self.scale) return replace(chunk_spec, fill_value=new_fill) diff --git a/src/zarr/core/chunk_grids/rectilinear.py b/src/zarr/core/chunk_grids/rectilinear.py deleted file mode 100644 index 3985613244..0000000000 --- a/src/zarr/core/chunk_grids/rectilinear.py +++ /dev/null @@ -1,802 +0,0 @@ -from __future__ import annotations - -import bisect -import itertools -import operator -from collections.abc import Sequence -from dataclasses import dataclass -from functools import cached_property, reduce -from typing import TYPE_CHECKING, Literal, Self, TypedDict - -import numpy as np -import numpy.typing as npt -from zarr.core.chunk_grids.common import ChunkEdgeLength, ChunkGrid - -from zarr.core.common import ( - JSON, - ShapeLike, - parse_named_configuration, - parse_shapelike, -) - -if TYPE_CHECKING: - from collections.abc import Iterator - - -class RectilinearChunkGridConfigurationDict(TypedDict): - """TypedDict for rectilinear chunk grid 
configuration""" - - kind: Literal["inline"] - chunk_shapes: Sequence[Sequence[ChunkEdgeLength]] - - -def _parse_chunk_shapes( - data: Sequence[Sequence[ChunkEdgeLength]], -) -> tuple[tuple[int, ...], ...]: - """ - Parse and expand chunk_shapes from metadata. - - Parameters - ---------- - data : Sequence[Sequence[ChunkEdgeLength]] - The chunk_shapes specification from metadata - - Returns - ------- - tuple[tuple[int, ...], ...] - Tuple of expanded chunk edge lengths for each axis - """ - # Runtime validation - strings are sequences but we don't want them - # Type annotation is for static typing, this validates actual JSON data - if isinstance(data, str) or not isinstance(data, Sequence): # type: ignore[redundant-expr,unreachable] - raise TypeError(f"chunk_shapes must be a sequence, got {type(data)}") - - result = [] - for i, axis_spec in enumerate(data): - # Runtime validation for each axis spec - if isinstance(axis_spec, str) or not isinstance(axis_spec, Sequence): # type: ignore[redundant-expr,unreachable] - raise TypeError(f"chunk_shapes[{i}] must be a sequence, got {type(axis_spec)}") - expanded = _expand_run_length_encoding(axis_spec) - result.append(expanded) - - return tuple(result) - - -def _expand_run_length_encoding(spec: Sequence[ChunkEdgeLength]) -> tuple[int, ...]: - """ - Expand a chunk edge length specification into a tuple of integers. - - The specification can contain: - - integers: representing explicit edge lengths - - tuples [value, count]: representing run-length encoded sequences - - Parameters - ---------- - spec : Sequence[ChunkEdgeLength] - The chunk edge length specification for one axis - - Returns - ------- - tuple[int, ...] 
- Expanded sequence of chunk edge lengths - - Examples - -------- - >>> _expand_run_length_encoding([2, 3]) - (2, 3) - >>> _expand_run_length_encoding([[2, 3]]) - (2, 2, 2) - >>> _expand_run_length_encoding([1, [2, 1], 3]) - (1, 2, 3) - >>> _expand_run_length_encoding([[1, 3], 3]) - (1, 1, 1, 3) - """ - result: list[int] = [] - for item in spec: - if isinstance(item, int): - # Explicit edge length - result.append(item) - elif isinstance(item, list | tuple): - # Run-length encoded: [value, count] - if len(item) != 2: - raise TypeError( - f"Run-length encoded items must be [int, int], got list of length {len(item)}" - ) - value, count = item - # Runtime validation of JSON data - if not isinstance(value, int) or not isinstance(count, int): # type: ignore[redundant-expr] - raise TypeError( - f"Run-length encoded items must be [int, int], got [{type(value).__name__}, {type(count).__name__}]" - ) - if count < 0: - raise ValueError(f"Run-length count must be non-negative, got {count}") - result.extend([value] * count) - else: - raise TypeError( - f"Chunk edge length must be int or [int, int] for run-length encoding, " - f"got {type(item)}" - ) - return tuple(result) - - -def _compress_run_length_encoding(chunks: tuple[int, ...]) -> list[int | list[int]]: - """ - Compress a sequence of chunk sizes to RLE format where beneficial. - - This function automatically detects runs of identical values and compresses them - using the [value, count] format. Single values or short runs are kept as-is. - - Parameters - ---------- - chunks : tuple[int, ...] 
- Sequence of chunk sizes along one dimension - - Returns - ------- - list[int | list[int]] - Compressed representation using RLE where beneficial - - Examples - -------- - >>> _compress_run_length_encoding((10, 10, 10, 10, 10, 10)) - [[10, 6]] - >>> _compress_run_length_encoding((10, 20, 30)) - [10, 20, 30] - >>> _compress_run_length_encoding((10, 10, 10, 20, 20, 30)) - [[10, 3], [20, 2], 30] - >>> _compress_run_length_encoding((5, 5, 10, 10, 10, 10, 15)) - [[5, 2], [10, 4], 15] - """ - if not chunks: - return [] - - result: list[int | list[int]] = [] - current_value = chunks[0] - current_count = 1 - - for value in chunks[1:]: - if value == current_value: - current_count += 1 - else: - # Decide whether to use RLE or explicit value - # Use RLE if count >= 3 to save space (tradeoff: [v,c] vs v,v,v) - if current_count >= 3: - result.append([current_value, current_count]) - elif current_count == 2: - # For count=2, RLE doesn't save space, but use it for consistency - result.append([current_value, current_count]) - else: - result.append(current_value) - - current_value = value - current_count = 1 - - # Handle the last run - if current_count >= 3 or current_count == 2: - result.append([current_value, current_count]) - else: - result.append(current_value) - - return result - - -def _normalize_rectilinear_chunks( - chunks: Sequence[Sequence[int | Sequence[int]]], shape: tuple[int, ...] -) -> tuple[tuple[int, ...], ...]: - """ - Normalize and validate variable chunks for RectilinearChunkGrid. - - Supports both explicit chunk sizes and run-length encoding (RLE). - RLE format: [[value, count]] expands to 'count' repetitions of 'value'. - - Parameters - ---------- - chunks : Sequence[Sequence[int | Sequence[int]]] - Nested sequence where each element is a sequence of chunk sizes along that dimension. - Each chunk size can be: - - An integer: explicit chunk size - - A sequence [value, count]: RLE format (expands to 'count' chunks of size 'value') - shape : tuple[int, ...] 
- The shape of the array. - - Returns - ------- - tuple[tuple[int, ...], ...] - Normalized chunk shapes as tuple of tuples. - - Raises - ------ - ValueError - If chunks don't match shape or sum incorrectly. - TypeError - If chunk specification format is invalid. - - Examples - -------- - >>> _normalize_rectilinear_chunks([[10, 20, 30], [25, 25]], (60, 50)) - ((10, 20, 30), (25, 25)) - >>> _normalize_rectilinear_chunks([[[10, 6]], [[10, 5]]], (60, 50)) - ((10, 10, 10, 10, 10, 10), (10, 10, 10, 10, 10)) - """ - # Expand RLE for each dimension - try: - chunk_shapes = tuple( - _expand_run_length_encoding(dim) # type: ignore[arg-type] - for dim in chunks - ) - except (TypeError, ValueError) as e: - raise TypeError( - f"Invalid variable chunks: {chunks}. Expected nested sequence of integers " - f"or RLE format [[value, count]]." - ) from e - - # Validate dimensionality - if len(chunk_shapes) != len(shape): - raise ValueError( - f"Variable chunks dimensionality ({len(chunk_shapes)}) " - f"must match array shape dimensionality ({len(shape)})" - ) - - # Validate that chunks sum to shape for each dimension - for i, (dim_chunks, dim_size) in enumerate(zip(chunk_shapes, shape, strict=False)): - chunk_sum = sum(dim_chunks) - if chunk_sum < dim_size: - raise ValueError( - f"Variable chunks along dimension {i} sum to {chunk_sum} " - f"but array shape is {dim_size}. " - f"Chunks must sum to be greater than or equal to the shape." - ) - if sum(dim_chunks[:-1]) >= dim_size: - raise ValueError( - f"Dimension {i} has more chunks than needed. " - f"The last chunk(s) would contain no valid data. " - f"Remove the extra chunk(s) or increase the array shape." - ) - - return chunk_shapes - - -@dataclass(frozen=True) -class RectilinearChunkGrid(ChunkGrid): - """ - A rectilinear chunk grid where chunk sizes vary along each axis. - - .. warning:: - This is an experimental feature and may change in future releases. - Expected to stabilize in Zarr version 3.3. 
- - Attributes - ---------- - chunk_shapes : tuple[tuple[int, ...], ...] - For each axis, a tuple of chunk edge lengths along that axis. - The sum of edge lengths must be >= the array shape along that axis. - """ - - _array_shape: tuple[int, ...] - chunk_shapes: tuple[tuple[int, ...], ...] - - def __init__(self, *, chunk_shapes: Sequence[Sequence[int]], array_shape: ShapeLike) -> None: - """ - Initialize a RectilinearChunkGrid. - - Parameters - ---------- - chunk_shapes : Sequence[Sequence[int]] - For each axis, a sequence of chunk edge lengths. - array_shape : ShapeLike - The shape of the array this chunk grid is bound to. - """ - array_shape_parsed = parse_shapelike(array_shape) - - # Convert to nested tuples and validate - parsed_shapes: list[tuple[int, ...]] = [] - for i, axis_chunks in enumerate(chunk_shapes): - if not isinstance(axis_chunks, Sequence): - raise TypeError(f"chunk_shapes[{i}] must be a sequence, got {type(axis_chunks)}") - # Validate all are positive integers - axis_tuple = tuple(axis_chunks) - for j, size in enumerate(axis_tuple): - if not isinstance(size, int): - raise TypeError( - f"chunk_shapes[{i}][{j}] must be an int, got {type(size).__name__}" - ) - if size <= 0: - raise ValueError(f"chunk_shapes[{i}][{j}] must be positive, got {size}") - parsed_shapes.append(axis_tuple) - - chunk_shapes_parsed = tuple(parsed_shapes) - object.__setattr__(self, "chunk_shapes", chunk_shapes_parsed) - object.__setattr__(self, "_array_shape", array_shape_parsed) - - # Validate array_shape is compatible with chunk_shapes - self._validate_array_shape() - - @property - def array_shape(self) -> tuple[int, ...]: - return self._array_shape - - @classmethod - def from_dict( # type: ignore[override] - cls, data: dict[str, JSON], *, array_shape: ShapeLike - ) -> Self: - """ - Parse a RectilinearChunkGrid from metadata dict. 
- - Parameters - ---------- - data : dict[str, JSON] - Metadata dictionary with 'name' and 'configuration' keys - array_shape : ShapeLike - The shape of the array this chunk grid is bound to. - - Returns - ------- - Self - A RectilinearChunkGrid instance - """ - _, configuration = parse_named_configuration(data, "rectilinear") - - if not isinstance(configuration, dict): - raise TypeError(f"configuration must be a dict, got {type(configuration)}") - - # Validate kind field - kind = configuration.get("kind") - if kind != "inline": - raise ValueError(f"Only 'inline' kind is supported, got {kind!r}") - - # Parse chunk_shapes with run-length encoding support - chunk_shapes_raw = configuration.get("chunk_shapes") - if chunk_shapes_raw is None: - raise ValueError("configuration must contain 'chunk_shapes'") - - # Type ignore: JSON data validated at runtime by _parse_chunk_shapes - chunk_shapes_expanded = _parse_chunk_shapes(chunk_shapes_raw) # type: ignore[arg-type] - - return cls(chunk_shapes=chunk_shapes_expanded, array_shape=array_shape) - - def to_dict(self) -> dict[str, JSON]: - """ - Convert to metadata dict format with automatic RLE compression. - - This method automatically compresses chunk shapes using run-length encoding - where beneficial (runs of 2 or more identical values). This reduces metadata - size for arrays with many uniform chunks. 
- - Returns - ------- - dict[str, JSON] - Metadata dictionary with 'name' and 'configuration' keys - - Examples - -------- - >>> grid = RectilinearChunkGrid(chunk_shapes=[[10, 10, 10, 10, 10, 10], [5, 5, 5, 5, 5]], array_shape=(60, 25)) - >>> grid.to_dict()['configuration']['chunk_shapes'] - [[[10, 6]], [[5, 5]]] - """ - # Compress each dimension using RLE where beneficial - chunk_shapes_compressed = [ - _compress_run_length_encoding(axis_chunks) for axis_chunks in self.chunk_shapes - ] - - return { - "name": "rectilinear", - "configuration": { - "kind": "inline", - "chunk_shapes": chunk_shapes_compressed, - }, - } - - def update_shape(self, new_shape: tuple[int, ...]) -> Self: - """ - Update the RectilinearChunkGrid to accommodate a new array shape. - - When resizing an array, this method adjusts the chunk grid to match the new shape. - For dimensions that grow, a new chunk is added with size equal to the size difference. - For dimensions that shrink, chunks are truncated or removed to fit the new shape. - - Parameters - ---------- - new_shape : tuple[int, ...] - The new shape of the array. Must have the same number of dimensions as the - chunk grid. - - Returns - ------- - Self - A new RectilinearChunkGrid instance with updated chunk shapes and array_shape - - Raises - ------ - ValueError - If the number of dimensions in new_shape doesn't match the number of dimensions - in the chunk grid - - Examples - -------- - >>> grid = RectilinearChunkGrid(chunk_shapes=[[10, 20], [15, 15]], array_shape=(30, 30)) - >>> grid.update_shape((50, 40)) # Grow both dimensions - RectilinearChunkGrid(_array_shape=(50, 40), chunk_shapes=((10, 20, 20), (15, 15, 10))) - - >>> grid = RectilinearChunkGrid(chunk_shapes=[[10, 20, 30], [25, 25]], array_shape=(60, 50)) - >>> grid.update_shape((25, 30)) # Shrink first dimension - RectilinearChunkGrid(_array_shape=(25, 30), chunk_shapes=((10, 20), (25, 25))) - - Notes - ----- - This method is automatically called when an array is resized. 
The chunk size - strategy for growing dimensions adds a single new chunk with size equal to - the growth amount. This may not be optimal for all use cases, and users may - want to manually adjust chunk shapes after resizing. - """ - - if len(new_shape) != len(self.chunk_shapes): - raise ValueError( - f"new_shape has {len(new_shape)} dimensions but " - f"chunk_shapes has {len(self.chunk_shapes)} dimensions" - ) - - new_chunk_shapes: list[tuple[int, ...]] = [] - for dim in range(len(new_shape)): - old_dim_length = sum(self.chunk_shapes[dim]) - new_dim_chunks: tuple[int, ...] - if new_shape[dim] == old_dim_length: - new_dim_chunks = self.chunk_shapes[dim] # no changes - - elif new_shape[dim] > old_dim_length: - new_dim_chunks = (*self.chunk_shapes[dim], new_shape[dim] - old_dim_length) - else: - # drop chunk sizes that are not inside the shape anymore - total = 0 - i = 0 - for c in self.chunk_shapes[dim]: - i += 1 - total += c - if total >= new_shape[dim]: - break - # keep the last chunk (it may be too long) - new_dim_chunks = self.chunk_shapes[dim][:i] - - new_chunk_shapes.append(new_dim_chunks) - - return type(self)(chunk_shapes=tuple(new_chunk_shapes), array_shape=new_shape) - - def all_chunk_coords(self) -> Iterator[tuple[int, ...]]: - """ - Generate all chunk coordinates. - - Yields - ------ - tuple[int, ...] - Chunk coordinates - """ - nchunks_per_axis = [len(axis_chunks) for axis_chunks in self.chunk_shapes] - return itertools.product(*(range(n) for n in nchunks_per_axis)) - - def get_nchunks(self) -> int: - """ - Get the total number of chunks. - - Returns - ------- - int - Total number of chunks - """ - return reduce(operator.mul, (len(axis_chunks) for axis_chunks in self.chunk_shapes), 1) - - def _validate_array_shape(self) -> None: - """ - Validate that array_shape is compatible with chunk_shapes. 
- - Raises - ------ - ValueError - If array_shape is incompatible with chunk_shapes - """ - if len(self._array_shape) != len(self.chunk_shapes): - raise ValueError( - f"array_shape has {len(self._array_shape)} dimensions but " - f"chunk_shapes has {len(self.chunk_shapes)} dimensions" - ) - - for axis, (arr_size, axis_chunks) in enumerate( - zip(self._array_shape, self.chunk_shapes, strict=False) - ): - chunk_sum = sum(axis_chunks) - if chunk_sum < arr_size: - raise ValueError( - f"Sum of chunk sizes along axis {axis} is {chunk_sum} " - f"but array shape is {arr_size}. This is invalid for the " - "RectilinearChunkGrid." - ) - - @cached_property - def _cumulative_sizes(self) -> tuple[tuple[int, ...], ...]: - """ - Compute cumulative sizes for each axis. - - Returns a tuple of tuples where each inner tuple contains cumulative - chunk sizes for an axis. Used for efficient chunk boundary calculations. - - Returns - ------- - tuple[tuple[int, ...], ...] - Cumulative sizes for each axis - - Examples - -------- - For chunk_shapes = [[2, 3, 1], [4, 2]]: - Returns ((0, 2, 5, 6), (0, 4, 6)) - """ - result = [] - for axis_chunks in self.chunk_shapes: - cumsum = [0] - for size in axis_chunks: - cumsum.append(cumsum[-1] + size) - result.append(tuple(cumsum)) - return tuple(result) - - def get_chunk_start(self, chunk_coord: tuple[int, ...]) -> tuple[int, ...]: - """ - Get the starting position (offset) of a chunk in the array. - - Parameters - ---------- - chunk_coord : tuple[int, ...] - Chunk coordinates (indices into the chunk grid) - - Returns - ------- - tuple[int, ...] 
- Starting index of the chunk in the array - - Raises - ------ - IndexError - If chunk_coord is out of bounds - - Examples - -------- - >>> grid = RectilinearChunkGrid(chunk_shapes=[[2, 2, 2], [3, 3]], array_shape=(6, 6)) - >>> grid.get_chunk_start((0, 0)) - (0, 0) - >>> grid.get_chunk_start((1, 1)) - (2, 3) - """ - # Validate chunk coordinates are in bounds - for axis, (coord, axis_chunks) in enumerate( - zip(chunk_coord, self.chunk_shapes, strict=False) - ): - if not (0 <= coord < len(axis_chunks)): - raise IndexError( - f"chunk_coord[{axis}] = {coord} is out of bounds [0, {len(axis_chunks)})" - ) - - # Use cumulative sizes to get start position - return tuple(self._cumulative_sizes[axis][coord] for axis, coord in enumerate(chunk_coord)) - - def get_chunk_shape(self, chunk_coord: tuple[int, ...]) -> tuple[int, ...]: - """ - Get the shape of a specific chunk. - - Parameters - ---------- - chunk_coord : tuple[int, ...] - Chunk coordinates (indices into the chunk grid) - - Returns - ------- - tuple[int, ...] - Shape of the chunk - - Raises - ------ - IndexError - If chunk_coord is out of bounds - - Examples - -------- - >>> grid = RectilinearChunkGrid(chunk_shapes=[[2, 3, 1], [4, 2]], array_shape=(6, 6)) - >>> grid.get_chunk_shape((0, 0)) - (2, 4) - >>> grid.get_chunk_shape((1, 0)) - (3, 4) - """ - # Validate chunk coordinates are in bounds - for axis, (coord, axis_chunks) in enumerate( - zip(chunk_coord, self.chunk_shapes, strict=False) - ): - if not (0 <= coord < len(axis_chunks)): - raise IndexError( - f"chunk_coord[{axis}] = {coord} is out of bounds [0, {len(axis_chunks)})" - ) - - # Get shape directly from chunk_shapes - return tuple( - axis_chunks[coord] - for axis_chunks, coord in zip(self.chunk_shapes, chunk_coord, strict=False) - ) - - def get_chunk_slice(self, chunk_coord: tuple[int, ...]) -> tuple[slice, ...]: - """ - Get the slice for indexing into an array for a specific chunk. - - Parameters - ---------- - chunk_coord : tuple[int, ...] 
- Chunk coordinates (indices into the chunk grid) - - Returns - ------- - tuple[slice, ...] - Slice tuple for indexing the array - - Raises - ------ - IndexError - If chunk_coord is out of bounds - - Examples - -------- - >>> grid = RectilinearChunkGrid(chunk_shapes=[[2, 2, 2], [3, 3]], array_shape=(6, 6)) - >>> grid.get_chunk_slice((0, 0)) - (slice(0, 2, None), slice(0, 3, None)) - >>> grid.get_chunk_slice((1, 1)) - (slice(2, 4, None), slice(3, 6, None)) - """ - start = self.get_chunk_start(chunk_coord) - shape = self.get_chunk_shape(chunk_coord) - - return tuple(slice(s, s + length) for s, length in zip(start, shape, strict=False)) - - def get_chunk_grid_shape(self) -> tuple[int, ...]: - """ - Get the shape of the chunk grid (number of chunks per axis). - - Returns - ------- - tuple[int, ...] - Number of chunks along each axis - - Examples - -------- - >>> grid = RectilinearChunkGrid(chunk_shapes=[[2, 2, 2], [3, 3]], array_shape=(6, 6)) - >>> grid.get_chunk_grid_shape() - (3, 2) - """ - return tuple(len(axis_chunks) for axis_chunks in self.chunk_shapes) - - def array_index_to_chunk_coord(self, array_index: tuple[int, ...]) -> tuple[int, ...]: - """ - Find which chunk contains a given array index. - - Parameters - ---------- - array_index : tuple[int, ...] - Index into the array - - Returns - ------- - tuple[int, ...] 
- Chunk coordinates containing the array index - - Raises - ------ - IndexError - If array_index is out of bounds - - Examples - -------- - >>> grid = RectilinearChunkGrid(chunk_shapes=[[2, 3, 1], [4, 2]], array_shape=(6, 6)) - >>> grid.array_index_to_chunk_coord((0, 0)) - (0, 0) - >>> grid.array_index_to_chunk_coord((2, 0)) - (1, 0) - >>> grid.array_index_to_chunk_coord((5, 5)) - (2, 1) - """ - # Validate array index is in bounds - for axis, (idx, size) in enumerate(zip(array_index, self._array_shape, strict=False)): - if not (0 <= idx < size): - raise IndexError(f"array_index[{axis}] = {idx} is out of bounds [0, {size})") - - # Use binary search in cumulative sizes to find chunk coordinate - result = [] - for axis, idx in enumerate(array_index): - cumsum = self._cumulative_sizes[axis] - # bisect_right gives us the chunk index + 1, so subtract 1 - chunk_idx = bisect.bisect_right(cumsum, idx) - 1 - result.append(chunk_idx) - - return tuple(result) - - def array_indices_to_chunk_dim( - self, dim: int, indices: npt.NDArray[np.intp] - ) -> npt.NDArray[np.intp]: - """ - Vectorized mapping of array indices to chunk coordinates along one dimension. - - For RectilinearChunkGrid, uses np.searchsorted on cumulative sizes. - """ - cumsum = np.asarray(self._cumulative_sizes[dim]) - return np.searchsorted(cumsum, indices, side="right").astype(np.intp) - 1 - - def chunks_in_selection(self, selection: tuple[slice, ...]) -> Iterator[tuple[int, ...]]: - """ - Get all chunks that intersect with a given selection. - - Parameters - ---------- - selection : tuple[slice, ...] - Selection (slices) into the array - - Yields - ------ - tuple[int, ...] 
- Chunk coordinates that intersect with the selection - - Raises - ------ - ValueError - If selection is invalid - - Examples - -------- - >>> grid = RectilinearChunkGrid(chunk_shapes=[[2, 2, 2], [3, 3]], array_shape=(6, 6)) - >>> selection = (slice(1, 5), slice(2, 5)) - >>> list(grid.chunks_in_selection(selection)) - [(0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1)] - """ - # Normalize slices and find chunk ranges for each axis - chunk_ranges = [] - for axis, (sel, size) in enumerate(zip(selection, self._array_shape, strict=False)): - if not isinstance(sel, slice): - raise TypeError(f"selection[{axis}] must be a slice, got {type(sel)}") - - # Normalize slice with array size - start, stop, step = sel.indices(size) - - if step != 1: - raise ValueError(f"selection[{axis}] has step={step}, only step=1 is supported") - - if start >= stop: - # Empty selection - return - - # Find first and last chunk that intersect with [start, stop) - start_chunk = self.array_index_to_chunk_coord( - tuple(start if i == axis else 0 for i in range(len(self._array_shape))) - )[axis] - - # stop-1 is the last index we need - end_chunk = self.array_index_to_chunk_coord( - tuple(stop - 1 if i == axis else 0 for i in range(len(self._array_shape))) - )[axis] - - chunk_ranges.append(range(start_chunk, end_chunk + 1)) - - # Generate all combinations of chunk coordinates - yield from itertools.product(*chunk_ranges) - - def chunks_per_dim(self, dim: int) -> int: - """ - Get the number of chunks along a specific dimension. 
- - Parameters - ---------- - dim : int - Dimension index - - Returns - ------- - int - Number of chunks along the dimension - - Examples - -------- - >>> grid = RectilinearChunkGrid(chunk_shapes=[[10, 20], [5, 5, 5]], array_shape=(30, 15)) - >>> grid.chunks_per_dim(0) # 2 chunks along axis 0 - 2 - >>> grid.chunks_per_dim(1) # 3 chunks along axis 1 - 3 - """ - return len(self.chunk_shapes[dim]) diff --git a/tests/test_codecs/test_cast_value.py b/tests/test_codecs/test_cast_value.py index 16b7bd3859..e57772480a 100644 --- a/tests/test_codecs/test_cast_value.py +++ b/tests/test_codecs/test_cast_value.py @@ -9,6 +9,7 @@ from zarr.codecs.bytes import BytesCodec from zarr.codecs.cast_value import CastValue from zarr.codecs.scale_offset import ScaleOffset +from zarr.core.dtype import get_data_type_from_json from zarr.storage import MemoryStore @@ -28,7 +29,7 @@ def test_float64_to_float32(self) -> None: ) arr[:] = data result = arr[:] - np.testing.assert_allclose(result, data) + np.testing.assert_allclose(result, data) # type: ignore[arg-type] def test_float64_to_int32_towards_zero(self) -> None: """Cast float64 to int32 with towards-zero rounding.""" @@ -112,7 +113,7 @@ def test_nan_scalar_map(self) -> None: CastValue( data_type="uint8", out_of_range="clamp", - scalar_map={ + scalar_map={ # type: ignore[arg-type] "encode": [["NaN", 0]], "decode": [[0, "NaN"]], }, @@ -122,9 +123,91 @@ def test_nan_scalar_map(self) -> None: ) arr[:] = data result = arr[:] - assert result[0] == 1.0 # 1.0 survives round-trip - assert np.isnan(result[1]) # NaN -> 0 -> NaN via scalar_map - assert result[2] == 3.0 + assert result[0] == 1.0 # type: ignore[index] + assert np.isnan(result[1]) # type: ignore[index] + assert result[2] == 3.0 # type: ignore[index] + + def test_hex_nan_scalar_map(self) -> None: + """Hex-encoded NaN values in scalar_map should round-trip correctly. + + The hex string encoding is used for preserving specific NaN payloads + per the Zarr v3 spec. 
+        """
+        import struct
+
+        # 0x7fc00001 is a NaN with a non-default payload in float32
+        hex_nan = "0x7fc00001"
+        nan_bytes = bytes.fromhex("7fc00001")
+        nan_f32 = np.float32(struct.unpack(">f", nan_bytes)[0])
+        assert np.isnan(nan_f32)
+
+        store = MemoryStore()
+        data = np.array([1.0, nan_f32, 3.0], dtype="float32")
+        arr = zarr.create(
+            store=store,
+            shape=data.shape,
+            dtype="float32",
+            chunks=(3,),
+            codecs=[
+                CastValue(
+                    data_type="uint8",
+                    out_of_range="clamp",
+                    scalar_map={  # type: ignore[arg-type]
+                        "encode": [[hex_nan, 255]],
+                        "decode": [[255, hex_nan]],
+                    },
+                ),
+                BytesCodec(),
+            ],
+        )
+        arr[:] = data
+        result = arr[:]
+        assert result[0] == np.float32(1.0)  # type: ignore[index]
+        assert np.isnan(result[1])  # type: ignore[index]
+        assert result[2] == np.float32(3.0)  # type: ignore[index]
+
+        # Verify the NaN payload is preserved by checking the bit pattern
+        result_bytes = struct.pack(">f", result[1])  # type: ignore[index]
+        assert result_bytes == nan_bytes
+
+    def test_int64_to_float64_precision_loss_rejected(self) -> None:
+        """Casting int64 to float64 is rejected because float64 cannot
+        exactly represent all int64 values.
+
+        float64 has a 53-bit significand, so it can only exactly represent
+        integers up to 2**53. int64.max is 2**63 - 1, far exceeding this.
+ """ + store = MemoryStore() + with pytest.raises(ValueError, match="may silently lose precision"): + zarr.create( + store=store, + shape=(1,), + dtype="int64", + chunks=(1,), + codecs=[ + CastValue(data_type="float64"), + BytesCodec(), + ], + ) + + def test_int32_to_float64_ok(self) -> None: + """Casting int32 to float64 is safe because float64 has enough + mantissa bits (52) to exactly represent all int32 values (up to 2**31 - 1).""" + store = MemoryStore() + data = np.array([np.iinfo(np.int32).max, np.iinfo(np.int32).min], dtype="int32") + arr = zarr.create( + store=store, + shape=data.shape, + dtype="int32", + chunks=(2,), + codecs=[ + CastValue(data_type="float64"), + BytesCodec(), + ], + ) + arr[:] = data + result = arr[:] + np.testing.assert_array_equal(result, data) def test_rounding_nearest_even(self) -> None: """nearest-even rounding: 0.5 rounds to 0, 1.5 rounds to 2.""" @@ -261,16 +344,17 @@ def test_to_dict(self) -> None: data_type="uint8", rounding="towards-zero", out_of_range="clamp", - scalar_map={"encode": [["NaN", 0]], "decode": [[0, "NaN"]]}, + scalar_map={"encode": [["NaN", 0]], "decode": [[0, "NaN"]]}, # type: ignore[arg-type] ) d = codec.to_dict() - assert d["name"] == "cast_value" - assert d["configuration"]["data_type"] == "uint8" - assert d["configuration"]["rounding"] == "towards-zero" - assert d["configuration"]["out_of_range"] == "clamp" - assert d["configuration"]["scalar_map"] == { - "encode": [["NaN", 0]], - "decode": [[0, "NaN"]], + assert d == { + "name": "cast_value", + "configuration": { + "data_type": "uint8", + "rounding": "towards-zero", + "out_of_range": "clamp", + "scalar_map": {"encode": [["NaN", 0]], "decode": [[0, "NaN"]]}, + }, } def test_to_dict_minimal(self) -> None: @@ -291,7 +375,7 @@ def test_from_dict(self) -> None: }, } ) - assert codec.data_type == "uint8" + assert codec.dtype == get_data_type_from_json("uint8", zarr_format=3) assert codec.rounding == "towards-zero" assert codec.out_of_range == "clamp" @@ -301,10 
+385,10 @@ def test_roundtrip_json(self) -> None: data_type="int16", rounding="towards-negative", out_of_range="clamp", - scalar_map={"encode": [["NaN", 0]]}, + scalar_map={"encode": [["NaN", 0]]}, # type: ignore[arg-type] ) restored = CastValue.from_dict(original.to_dict()) - assert restored.data_type == original.data_type + assert restored.dtype == original.dtype assert restored.rounding == original.rounding assert restored.out_of_range == original.out_of_range assert restored.scalar_map == original.scalar_map @@ -362,7 +446,7 @@ def test_float64_to_uint8_roundtrip(self) -> None: ) arr[:] = data result = arr[:] - np.testing.assert_allclose(result, data, atol=0.1) + np.testing.assert_allclose(result, data, atol=0.1) # type: ignore[arg-type] def test_temperature_storage_pattern(self) -> None: """Realistic pattern: store temperature data as uint8. @@ -389,7 +473,7 @@ def test_temperature_storage_pattern(self) -> None: arr[:] = data result = arr[:] # Precision limited by uint8 quantization (~0.22°C step) - np.testing.assert_allclose(result, data, atol=0.25) + np.testing.assert_allclose(result, data, atol=0.25) # type: ignore[arg-type] def test_nan_handling_pipeline(self) -> None: """NaN values should be handled via scalar_map in cast_value.""" @@ -406,7 +490,7 @@ def test_nan_handling_pipeline(self) -> None: CastValue( data_type="uint8", out_of_range="clamp", - scalar_map={ + scalar_map={ # type: ignore[arg-type] "encode": [["NaN", 0]], "decode": [[0, "NaN"]], }, @@ -416,9 +500,9 @@ def test_nan_handling_pipeline(self) -> None: ) arr[:] = data result = arr[:] - assert result[0] == 1.0 - assert np.isnan(result[1]) - assert result[2] == 3.0 + assert result[0] == 1.0 # type: ignore[index] + assert np.isnan(result[1]) # type: ignore[index] + assert result[2] == 3.0 # type: ignore[index] def test_metadata_persistence(self) -> None: """Array metadata should be correctly persisted and reloaded.""" diff --git a/tests/test_codecs/test_scale_offset.py 
b/tests/test_codecs/test_scale_offset.py index ad131c6493..ace3445c10 100644 --- a/tests/test_codecs/test_scale_offset.py +++ b/tests/test_codecs/test_scale_offset.py @@ -41,7 +41,7 @@ def test_encode_decode_float64(self) -> None: ) arr[:] = data result = arr[:] - np.testing.assert_allclose(result, data, rtol=1e-10) + np.testing.assert_allclose(result, data, rtol=1e-10) # type: ignore[arg-type] def test_encode_decode_float32(self) -> None: """Round-trip with float32 data.""" @@ -56,7 +56,7 @@ def test_encode_decode_float32(self) -> None: ) arr[:] = data result = arr[:] - np.testing.assert_allclose(result, data, rtol=1e-6) + np.testing.assert_allclose(result, data, rtol=1e-6) # type: ignore[arg-type] def test_encode_decode_integer(self) -> None: """Round-trip with integer data (uses integer arithmetic semantics).""" @@ -85,7 +85,7 @@ def test_offset_only(self) -> None: codecs=[ScaleOffset(offset=100), BytesCodec()], ) arr[:] = data - np.testing.assert_allclose(arr[:], data) + np.testing.assert_allclose(arr[:], data) # type: ignore[arg-type] def test_scale_only(self) -> None: """Test with only scale (offset=0).""" @@ -99,7 +99,7 @@ def test_scale_only(self) -> None: codecs=[ScaleOffset(scale=10), BytesCodec()], ) arr[:] = data - np.testing.assert_allclose(arr[:], data) + np.testing.assert_allclose(arr[:], data) # type: ignore[arg-type] def test_fill_value_transformed(self) -> None: """Fill value should be transformed through the codec.""" @@ -114,7 +114,7 @@ def test_fill_value_transformed(self) -> None: ) # Without writing, reading should return the fill value result = arr[:] - np.testing.assert_allclose(result, np.full(5, 100.0)) + np.testing.assert_allclose(result, np.full(5, 100.0)) # type: ignore[arg-type] def test_validate_rejects_complex(self) -> None: """Validate should reject complex dtypes.""" @@ -130,19 +130,19 @@ def test_validate_rejects_complex(self) -> None: def test_to_dict_no_config(self) -> None: """Default codec should serialize without 
configuration.""" codec = ScaleOffset() - assert codec.to_dict() == {"name": "scale_offset"} + assert codec.to_dict() == {"name": "scale_offset"} # type: ignore[comparison-overlap] def test_to_dict_with_config(self) -> None: """Non-default codec should include configuration.""" codec = ScaleOffset(offset=5, scale=0.1) d = codec.to_dict() - assert d == {"name": "scale_offset", "configuration": {"offset": 5, "scale": 0.1}} + assert d == {"name": "scale_offset", "configuration": {"offset": 5, "scale": 0.1}} # type: ignore[comparison-overlap] def test_to_dict_offset_only(self) -> None: """Only offset in config when scale is default.""" codec = ScaleOffset(offset=5) d = codec.to_dict() - assert d == {"name": "scale_offset", "configuration": {"offset": 5}} + assert d == {"name": "scale_offset", "configuration": {"offset": 5}} # type: ignore[comparison-overlap] def test_from_dict_no_config(self) -> None: """Parse codec from JSON with no configuration.""" From 83fad887611a85503c55b751135c348b0d8b3078 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Mon, 16 Mar 2026 13:30:44 +0100 Subject: [PATCH 4/4] update dev notes --- docs/devnotes/feat-v3-scale-offset-cast.md | 132 ++++++++++++++------- 1 file changed, 87 insertions(+), 45 deletions(-) diff --git a/docs/devnotes/feat-v3-scale-offset-cast.md b/docs/devnotes/feat-v3-scale-offset-cast.md index a7dcb504fd..3236a56318 100644 --- a/docs/devnotes/feat-v3-scale-offset-cast.md +++ b/docs/devnotes/feat-v3-scale-offset-cast.md @@ -17,22 +17,20 @@ common pattern of storing floating-point data as compressed integers. **Decode:** `out = (in / scale) + offset` ### Parameters -- `offset` (optional): scalar subtracted during encoding. Default: 0 (additive identity). - Serialized in JSON using the zarr v3 fill-value encoding for the array's dtype. -- `scale` (optional): scalar multiplied during encoding (after offset subtraction). Default: 1 - (multiplicative identity). Same JSON encoding as offset. 
+- `offset` (optional, float): scalar subtracted during encoding. Default: 0. +- `scale` (optional, float): scalar multiplied during encoding (after offset subtraction). Default: 1. ### Key rules -- Arithmetic MUST use the input array's own data type semantics (no implicit promotion). -- If any intermediate or final value is unrepresentable in that dtype, error. +- Arithmetic uses the input array's own data type semantics (no implicit promotion). - If neither scale nor offset is given, `configuration` may be omitted (codec is a no-op). -- Fill value MUST be transformed through the codec (encode direction). -- Only valid for real-number data types (int/uint/float families). +- Fill value is transformed through the codec (encode direction). +- Only valid for real-number data types (int/uint/float families). Complex dtypes are rejected at validation time. ### JSON ```json {"name": "scale_offset", "configuration": {"offset": 5, "scale": 0.1}} ``` +When both offset and scale are defaults: `{"name": "scale_offset"}` (no configuration key). --- @@ -43,33 +41,54 @@ common pattern of storing floating-point data as compressed integers. **Purpose:** Value-convert (not binary-reinterpret) array elements to a new data type. ### Parameters -- `data_type` (required): target zarr v3 data type. -- `rounding` (optional): how to round when exact representation is impossible. +- `data_type` (required): target zarr v3 data type name (e.g. `"uint8"`, `"float32"`). + Internally stored as a `ZDType` instance, resolved via `get_data_type_from_json`. +- `rounding` (optional): how to round when casting float to int. Values: `"nearest-even"` (default), `"towards-zero"`, `"towards-positive"`, `"towards-negative"`, `"nearest-away"`. - `out_of_range` (optional): what to do when a value is outside the target's range. - Values: `"clamp"`, `"wrap"`. If absent, out-of-range MUST error. - `"wrap"` only valid for integral two's-complement types. + Values: `"clamp"`, `"wrap"`. 
If absent, out-of-range values raise an error. + `"wrap"` is only valid for integer target types. - `scalar_map` (optional): explicit value overrides. `{"encode": [[input, output], ...], "decode": [[input, output], ...]}`. - Evaluated BEFORE rounding/out_of_range. - -### Casting procedure (same for encode and decode, swapping source/target) -1. Check scalar_map — if input matches a key, use mapped value. -2. Check exact representability — if yes, use directly. -3. Apply rounding and out_of_range rules. -4. If none apply, MUST error. + Applied BEFORE rounding/out_of_range. Each entry's source is deserialized using the + source dtype and target using the target dtype (via `ZDType.from_json_scalar`), + preserving full precision for both sides. + +### Cast procedure (`_cast_array_impl`) + +Dispatches on `(src_type, tgt_type, has_map)` where src/tgt are `"int"` or `"float"`: + +| Source | Target | scalar_map | Procedure | +|--------|--------|------------|-----------| +| any | float | no | `arr.astype(target_dtype)` | +| int | float | yes | widen to float64, apply map, cast | +| float | float | yes | copy, apply map, cast | +| int | int | no | range check, then astype | +| int | int | yes | widen to int64, apply map, range check | +| float | int | any | widen to float64, apply map (if any), reject NaN/Inf, round, range check | + +All casts are wrapped in `np.errstate(over='raise', invalid='raise')` to convert +numpy overflow/invalid warnings to hard errors. + +### Validation checks +- Only integer and floating-point dtypes are allowed (both source and target). +- `out_of_range='wrap'` is rejected for non-integer target types. +- Int-to-float casts are rejected if the float type's mantissa cannot exactly represent + the full integer range (e.g. int64 -> float64 is rejected because float64 has only + 52 mantissa bits, but int64 has values up to 2^63-1). Same check applies for the + float-to-int decode direction. 
### Special values -- NaN propagates between IEEE 754 types unless scalar_map overrides. -- Signed zero preserved between IEEE 754 types. -- If target doesn't support NaN/infinity and input has them, MUST error - unless scalar_map provides a mapping. +- NaN: detected dynamically via `isinstance(src, (float, np.floating)) and np.isnan(src)`. + NaN-to-integer casts error unless `scalar_map` provides a mapping. + Hex-encoded NaN strings (e.g. `"0x7fc00001"`) preserve NaN payloads per the zarr v3 spec. +- `_check_int_range` handles out-of-range integer values with clamp (via `np.clip`) or + wrap (via modular arithmetic). ### Fill value -- MUST be cast using same semantics as elements. -- Implementations SHOULD validate fill value survives round-trip at metadata - construction time. +- Cast using the same `_cast_array` path as array elements, including scalar_map and rounding. +- Done in `resolve_metadata`, which also changes the chunk spec's dtype to the target. ### JSON ```json @@ -86,6 +105,7 @@ common pattern of storing floating-point data as compressed integers. } } ``` +Only non-default fields are serialized (rounding and out_of_range are omitted when default). --- @@ -109,29 +129,51 @@ common pattern of storing floating-point data as compressed integers. --- -## Implementation notes for zarr-python +## Implementation notes + +### Module structure +- `src/zarr/codecs/scale_offset.py` — `ScaleOffset` class +- `src/zarr/codecs/cast_value.py` — `CastValue` class and casting helpers +- `tests/test_codecs/test_scale_offset.py` — ScaleOffset tests +- `tests/test_codecs/test_cast_value.py` — CastValue tests + combined pipeline tests ### scale_offset -- Subclass `ArrayArrayCodec`. -- `resolve_metadata`: transform fill_value via `(fill - offset) * scale`, keep dtype. -- `_encode_single`: `(array - offset) * scale` using numpy with same dtype. -- `_decode_single`: `(array / scale) + offset` using numpy with same dtype. -- `is_fixed_size = True`. 
+- `@dataclass(kw_only=True, frozen=True)`, subclasses `ArrayArrayCodec`. +- Uses `ScaleOffsetJSON` (a `NamedConfig` TypedDict) for typed serialization. +- `from_dict` uses `parse_named_configuration(data, "scale_offset", require_configuration=False)`. +- `to_dict` omits the `configuration` key entirely when both offset=0 and scale=1. +- `resolve_metadata`: transforms fill_value via `(fill - offset) * scale`, dtype unchanged. +- `_encode_sync`: `(arr - offset) * scale` using the array's own dtype. +- `_decode_sync`: `(arr / scale) + offset` using the array's own dtype. +- `is_fixed_size = True`, `compute_encoded_size` returns input size unchanged. ### cast_value -- Subclass `ArrayArrayCodec`. -- `resolve_metadata`: change dtype to target dtype, cast fill_value. -- `_encode_single`: cast array from input dtype to target dtype. -- `_decode_single`: cast array from target dtype back to input dtype. -- Needs the input dtype stored (from `evolve_from_array_spec` or `resolve_metadata`). -- `is_fixed_size = True` (for fixed-size types). -- Initial implementation: support `rounding` and `out_of_range` for common cases. - `scalar_map` adds complexity but is needed for NaN handling. - -### Key design decisions from PR review +- `@dataclass(frozen=True)` with custom `__init__` (accepts `data_type: str | ZDType`). +- Stores `dtype: ZDType` (not a string). String data_type is resolved via `get_data_type_from_json`. +- `from_dict` uses `parse_named_configuration(data, "cast_value", require_configuration=True)`. +- `to_dict` serializes dtype via `self.dtype.to_json(zarr_format=3)`, only includes + non-default rounding/out_of_range/scalar_map. +- `resolve_metadata`: casts fill value, changes chunk spec dtype to target. +- `_encode_sync` / `_decode_sync`: delegate to `_cast_array`, threading the appropriate + scalar_map direction ("encode" or "decode") and the correct src/tgt ZDType pair for + scalar map deserialization. 
+- `compute_encoded_size`: scales by `target_itemsize / source_itemsize`. + +### Key helpers (cast_value.py) +- `_cast_array` — public entry point, wraps `_cast_array_impl` with `np.errstate`. +- `_cast_array_impl` — match-based dispatch on `(src_type, tgt_type, has_map)`. +- `_check_int_range` — integer range check with clamp/wrap/error. +- `_round_inplace` — rounding dispatch (rint, trunc, ceil, floor, nearest-away). +- `_apply_scalar_map` — in-place value remapping with NaN-aware matching. +- `_parse_map_entries` — deserializes scalar_map JSON using separate src/tgt ZDType instances. +- `_extract_raw_map` — extracts "encode" or "decode" direction from ScalarMapJSON. + +### Key design decisions 1. Encode = `(in - offset) * scale` (subtract, not add) — matches HDF5 and numcodecs. 2. No implicit precision promotion — arithmetic stays in the input dtype. 3. `out_of_range` defaults to error (not clamp). -4. `scalar_map` was added specifically to handle NaN-to-integer mappings. -5. Fill value must round-trip exactly through the codec chain. -6. Name uses underscore: `scale_offset`, `cast_value`. +4. `scalar_map` entries are typed: each side is deserialized with its own ZDType, + so int64 scalars don't lose precision through float64 intermediaries. +5. Fill value is cast through the same `_cast_array` path as data elements. +6. Int-to-float precision loss is caught at validate time (mantissa bit check). +7. Runtime overflow/invalid is caught via `np.errstate(over='raise', invalid='raise')`.