diff --git a/src/zarr/abc/serializable.py b/src/zarr/abc/serializable.py new file mode 100644 index 0000000000..f0798dd43a --- /dev/null +++ b/src/zarr/abc/serializable.py @@ -0,0 +1,27 @@ +from typing import Protocol, Self, TypeVar + +T_co = TypeVar("T_co", covariant=True) +T_contra = TypeVar("T_contra", contravariant=True) + + +class JSONSerializable(Protocol[T_co, T_contra]): + @classmethod + def from_json(cls, obj: T_contra) -> Self: + """ + Deserialize from an instance of T_contra. + """ + ... + + @classmethod + def try_from_json(cls, obj: object) -> Self: + """ + Deserialize from an unknown object. Details of any + deserialization failure should be conveyed via an exception. + """ + ... + + def to_json(self) -> T_co: + """ + Serialize to JSON. + """ + ... diff --git a/src/zarr/api/asynchronous.py b/src/zarr/api/asynchronous.py index 6164cda957..66cf3bad7e 100644 --- a/src/zarr/api/asynchronous.py +++ b/src/zarr/api/asynchronous.py @@ -24,7 +24,7 @@ from zarr.core.common import ( JSON, AccessModeLiteral, - DimensionNames, + DimensionNamesLike, MemoryOrder, ZarrFormat, _default_zarr_format, @@ -914,7 +914,7 @@ async def create( | None ) = None, codecs: Iterable[Codec | dict[str, JSON]] | None = None, - dimension_names: DimensionNames = None, + dimension_names: DimensionNamesLike = None, storage_options: dict[str, Any] | None = None, config: ArrayConfigLike | None = None, **kwargs: Any, diff --git a/src/zarr/api/synchronous.py b/src/zarr/api/synchronous.py index 1204eba3c9..4e718a234e 100644 --- a/src/zarr/api/synchronous.py +++ b/src/zarr/api/synchronous.py @@ -33,7 +33,7 @@ from zarr.core.common import ( JSON, AccessModeLiteral, - DimensionNames, + DimensionNamesLike, MemoryOrder, ShapeLike, ZarrFormat, @@ -649,7 +649,7 @@ def create( | None ) = None, codecs: Iterable[Codec | dict[str, JSON]] | None = None, - dimension_names: DimensionNames = None, + dimension_names: DimensionNamesLike = None, storage_options: dict[str, Any] | None = None, config: ArrayConfigLike | None = None, **kwargs: Any, @@ -832,7 +832,7 @@ def create_array( zarr_format: ZarrFormat | None = 3, attributes: dict[str, JSON] | None = None, chunk_key_encoding: ChunkKeyEncodingLike | None = None, - dimension_names: DimensionNames = None, + dimension_names: DimensionNamesLike = None, storage_options: dict[str, Any] | None = None, overwrite: bool = False, config: ArrayConfigLike | None = None, @@ -1003,7 +1003,7 @@ def from_array( zarr_format: ZarrFormat | None = None, attributes: dict[str, JSON] | None = None, chunk_key_encoding: ChunkKeyEncodingLike | None = None, - dimension_names: DimensionNames = None, + dimension_names: DimensionNamesLike = None, storage_options: dict[str, Any] | None = None, overwrite: bool = False, config: ArrayConfigLike | None = None, diff --git a/src/zarr/core/array.py b/src/zarr/core/array.py index 486216fa32..b82c77fa9c 100644 --- a/src/zarr/core/array.py +++ b/src/zarr/core/array.py @@ -53,7 +53,7 @@ ZARR_JSON, ZARRAY_JSON, ZATTRS_JSON, - DimensionNames, + DimensionNamesLike, MemoryOrder, ShapeLike, ZarrFormat, @@ -389,7 +389,7 @@ async def create( | None ) = None, codecs: Iterable[Codec | dict[str, JSON]] | None = None, - dimension_names: DimensionNames = None, + dimension_names: DimensionNamesLike = None, # runtime overwrite: bool = False, data: npt.ArrayLike | None = None, @@ -417,7 +417,7 @@ async def create( | None ) = None, codecs: Iterable[Codec | dict[str, JSON]] | None = None, - dimension_names: DimensionNames = None, + dimension_names: DimensionNamesLike = None, # runtime overwrite: bool = False, data: npt.ArrayLike | None = None, @@ -445,7 +445,7 @@ async def create( | None ) = None, codecs: Iterable[Codec | dict[str, JSON]] | None = None, - dimension_names: DimensionNames = None, + dimension_names: DimensionNamesLike = None, # v2 only chunks: ShapeLike | None = None, dimension_separator: Literal[".", "/"] | None = None, @@ -479,7 +479,7 @@ async def create( | None ) = None, codecs: Iterable[Codec | dict[str, JSON]] | None = None, - dimension_names: DimensionNames = None, + dimension_names: DimensionNamesLike = None, # v2 only chunks: ShapeLike | None = None, dimension_separator: Literal[".", "/"] | None = None, @@ -630,7 +630,7 @@ async def _create( | None ) = None, codecs: Iterable[Codec | dict[str, JSON]] | None = None, - dimension_names: DimensionNames = None, + dimension_names: DimensionNamesLike = None, # v2 only chunks: ShapeLike | None = None, dimension_separator: Literal[".", "/"] | None = None, @@ -742,7 +742,7 @@ def _create_metadata_v3( fill_value: Any | None = DEFAULT_FILL_VALUE, chunk_key_encoding: ChunkKeyEncodingLike | None = None, codecs: Iterable[Codec | dict[str, JSON]] | None = None, - dimension_names: DimensionNames = None, + dimension_names: DimensionNamesLike = None, attributes: dict[str, JSON] | None = None, ) -> ArrayV3Metadata: """ @@ -803,7 +803,7 @@ async def _create_v3( | None ) = None, codecs: Iterable[Codec | dict[str, JSON]] | None = None, - dimension_names: DimensionNames = None, + dimension_names: DimensionNamesLike = None, attributes: dict[str, JSON] | None = None, overwrite: bool = False, ) -> AsyncArrayV3: @@ -1998,7 +1998,7 @@ def create( | None ) = None, codecs: Iterable[Codec | dict[str, JSON]] | None = None, - dimension_names: DimensionNames = None, + dimension_names: DimensionNamesLike = None, # v2 only chunks: tuple[int, ...] | None = None, dimension_separator: Literal[".", "/"] | None = None, @@ -2142,7 +2142,7 @@ def _create( | None ) = None, codecs: Iterable[Codec | dict[str, JSON]] | None = None, - dimension_names: DimensionNames = None, + dimension_names: DimensionNamesLike = None, # v2 only chunks: tuple[int, ...] | None = None, dimension_separator: Literal[".", "/"] | None = None, @@ -4266,7 +4266,7 @@ async def from_array( zarr_format: ZarrFormat | None = None, attributes: dict[str, JSON] | None = None, chunk_key_encoding: ChunkKeyEncodingLike | None = None, - dimension_names: DimensionNames = None, + dimension_names: DimensionNamesLike = None, storage_options: dict[str, Any] | None = None, overwrite: bool = False, config: ArrayConfigLike | None = None, @@ -4535,7 +4535,7 @@ async def init_array( zarr_format: ZarrFormat | None = 3, attributes: dict[str, JSON] | None = None, chunk_key_encoding: ChunkKeyEncodingLike | None = None, - dimension_names: DimensionNames = None, + dimension_names: DimensionNamesLike = None, overwrite: bool = False, config: ArrayConfigLike | None = None, ) -> AnyAsyncArray: @@ -4751,7 +4751,7 @@ async def create_array( zarr_format: ZarrFormat | None = 3, attributes: dict[str, JSON] | None = None, chunk_key_encoding: ChunkKeyEncodingLike | None = None, - dimension_names: DimensionNames = None, + dimension_names: DimensionNamesLike = None, storage_options: dict[str, Any] | None = None, overwrite: bool = False, config: ArrayConfigLike | None = None, @@ -4935,7 +4935,7 @@ def _parse_keep_array_attr( order: MemoryOrder | None, zarr_format: ZarrFormat | None, chunk_key_encoding: ChunkKeyEncodingLike | None, - dimension_names: DimensionNames, + dimension_names: DimensionNamesLike, ) -> tuple[ tuple[int, ...] | Literal["auto"], ShardsLike | None, @@ -4946,7 +4946,7 @@ def _parse_keep_array_attr( MemoryOrder | None, ZarrFormat, ChunkKeyEncodingLike | None, - DimensionNames, + DimensionNamesLike, ]: if isinstance(data, Array): if chunks == "keep": diff --git a/src/zarr/core/common.py b/src/zarr/core/common.py index 275d062eba..ebffbc4f2f 100644 --- a/src/zarr/core/common.py +++ b/src/zarr/core/common.py @@ -15,6 +15,7 @@ Generic, Literal, NotRequired, + TypeAlias, TypedDict, TypeVar, cast, @@ -43,11 +44,12 @@ ChunkCoords = tuple[int, ...] ZarrFormat = Literal[2, 3] NodeType = Literal["array", "group"] -JSON = str | int | float | bool | Mapping[str, "JSON"] | Sequence["JSON"] | None +JSON: TypeAlias = str | int | float | bool | Mapping[str, "JSON"] | Sequence["JSON"] | None MemoryOrder = Literal["C", "F"] AccessModeLiteral = Literal["r", "r+", "a", "w", "w-"] ANY_ACCESS_MODE: Final = "r", "r+", "a", "w", "w-" -DimensionNames = Iterable[str | None] | None +DimensionNamesLike = Iterable[str | None] | None +DimensionNames = DimensionNamesLike # backwards compatibility TName = TypeVar("TName", bound=str) TConfig = TypeVar("TConfig", bound=Mapping[str, object]) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 080e90ff0f..17b8b541b1 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -40,7 +40,7 @@ ZATTRS_JSON, ZGROUP_JSON, ZMETADATA_V2_JSON, - DimensionNames, + DimensionNamesLike, NodeType, ShapeLike, ZarrFormat, @@ -1032,7 +1032,7 @@ async def create_array( order: MemoryOrder | None = None, attributes: dict[str, JSON] | None = None, chunk_key_encoding: ChunkKeyEncodingLike | None = None, - dimension_names: DimensionNames = None, + dimension_names: DimensionNamesLike = None, storage_options: dict[str, Any] | None = None, overwrite: bool = False, config: ArrayConfigLike | None = None, @@ -2483,7 +2483,7 @@ def create( order: MemoryOrder | None = None, attributes: dict[str, JSON] | None = None, chunk_key_encoding: ChunkKeyEncodingLike | None = None, - dimension_names: DimensionNames = None, + dimension_names: DimensionNamesLike = None, storage_options: dict[str, Any] | None = None, overwrite: bool = False, config: ArrayConfigLike | None = None, @@ -2627,7 +2627,7 @@ def create_array( order: MemoryOrder | None = None, attributes: dict[str, JSON] | None = None, chunk_key_encoding: ChunkKeyEncodingLike | None = None, - dimension_names: DimensionNames = None, + dimension_names: DimensionNamesLike = None, storage_options: dict[str, Any] | None = None, overwrite: bool = False, config: ArrayConfigLike | None = None, @@ -3025,7 +3025,7 @@ def array( order: MemoryOrder | None = None, attributes: dict[str, JSON] | None = None, chunk_key_encoding: ChunkKeyEncodingLike | None = None, - dimension_names: DimensionNames = None, + dimension_names: DimensionNamesLike = None, storage_options: dict[str, Any] | None = None, overwrite: bool = False, config: ArrayConfigLike | None = None, diff --git a/src/zarr/core/metadata/common.py b/src/zarr/core/metadata/common.py index 44d3eb292b..bd62e08ac6 100644 --- a/src/zarr/core/metadata/common.py +++ b/src/zarr/core/metadata/common.py @@ -1,13 +1,15 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: + from collections.abc import Mapping + from zarr.core.common import JSON -def parse_attributes(data: dict[str, JSON] | None) -> dict[str, JSON]: +def parse_attributes(data: Mapping[str, Any] | None) -> dict[str, JSON]: if data is None: return {} - return data + return dict(data) diff --git a/src/zarr/core/metadata/v3.py b/src/zarr/core/metadata/v3.py index 5ce155bd9a..531d495b19 100644 --- a/src/zarr/core/metadata/v3.py +++ b/src/zarr/core/metadata/v3.py @@ -4,8 +4,15 @@ from typing import TYPE_CHECKING, NotRequired, TypedDict, TypeGuard, cast from zarr.abc.metadata import Metadata +from zarr.abc.serializable import JSONSerializable from zarr.core.buffer.core import default_buffer_prototype -from zarr.core.dtype import VariableLengthUTF8, ZDType, get_data_type_from_json +from zarr.core.dtype import ( + VariableLengthUTF8, + ZDType, + ZDTypeLike, + get_data_type_from_json, + parse_dtype, +) from zarr.core.dtype.common import check_dtype_spec_v3 if TYPE_CHECKING: @@ -33,8 +40,9 @@ from zarr.core.common import ( JSON, ZARR_JSON, - DimensionNames, + DimensionNamesLike, NamedConfig, + ShapeLike, parse_named_configuration, parse_shapelike, ) @@ -156,27 +164,35 @@ def check_allowed_extra_field(data: object) -> TypeGuard[AllowedExtraField]: def parse_extra_fields( - data: Mapping[str, AllowedExtraField] | None, + data: Mapping[str, object] | None, ) -> dict[str, AllowedExtraField]: if data is None: return {} - else: - conflict_keys = ARRAY_METADATA_KEYS & set(data.keys()) - if len(conflict_keys) > 0: - msg = ( - "Invalid extra fields. " - "The following keys: " - f"{sorted(conflict_keys)} " - "are invalid because they collide with keys reserved for use by the " - "array metadata document." - ) - raise ValueError(msg) - return dict(data) + + conflict_keys = ARRAY_METADATA_KEYS & set(data.keys()) + if len(conflict_keys) > 0: + msg = ( + "Invalid extra fields. " + "The following keys: " + f"{sorted(conflict_keys)} " + "are invalid because they collide with keys reserved for use by the " + "array metadata document." + ) + raise ValueError(msg) + + disallowed = {k: v for k, v in data.items() if not check_allowed_extra_field(v)} + if disallowed: + raise MetadataValidationError( + f"Disallowed extra fields: {sorted(disallowed.keys())}. " + 'Extra fields must be a mapping with "must_understand" set to False.' + ) + + return dict(data) # type: ignore[arg-type] class ArrayMetadataJSON_V3(TypedDict): """ - A typed dictionary model for zarr v3 metadata. + A typed dictionary model for Zarr v3 array metadata. """ zarr_format: Literal[3] @@ -194,9 +210,130 @@ class ArrayMetadataJSON_V3(TypedDict): ARRAY_METADATA_KEYS = set(ArrayMetadataJSON_V3.__annotations__.keys()) +ChunkGridLike = dict[str, JSON] | ChunkGrid | NamedConfig[str, Any] +CodecLike = Codec | dict[str, JSON] | NamedConfig[str, Any] | str + +# Required keys in ArrayMetadataJSONLike_V3 (excludes zarr_format and node_type, +# which are identity fields consumed by the I/O layer before reaching this point). +_REQUIRED_JSONLIKE_KEYS = frozenset( + {"shape", "data_type", "chunk_grid", "chunk_key_encoding", "codecs", "fill_value"} +) + + +def narrow_array_metadata_json(data: object) -> ArrayMetadataJSONLike_V3: + """ + Narrow an untrusted object to ``ArrayMetadataJSONLike_V3``. + + Performs only structural type checking — verifies that ``data`` is a mapping + with the expected keys and that each value has an acceptable Python type. + Does **not** validate the semantic correctness of values (e.g. whether a + data type string is a recognized dtype, or whether extra fields satisfy + ``must_understand``). That validation is the responsibility of ``__init__``. + + This function allows ``zarr_format`` and ``node_type`` to be absent. The + expectation is that these values have already been validated as a + precondition for invoking this function. + """ + errors: list[str] = [] + + if not isinstance(data, Mapping): + raise MetadataValidationError(f"Expected a mapping, got {type(data).__name__}") + + # --- required keys --- + missing = _REQUIRED_JSONLIKE_KEYS - set(data.keys()) + if missing: + errors.append(f"Missing required keys: {sorted(missing)}") + + # --- shape: Iterable (but not str or Mapping) --- + shape = data.get("shape") + if shape is not None and (not isinstance(shape, Iterable) or isinstance(shape, str | Mapping)): + errors.append(f"Invalid shape: expected an iterable, got {type(shape).__name__}") + + # --- data_type: str, Mapping, or ZDType --- + data_type_json = data.get("data_type") + if data_type_json is not None and not isinstance(data_type_json, str | Mapping | ZDType): + errors.append( + f"Invalid data_type: expected a string, mapping, or ZDType, got {type(data_type_json).__name__}" + ) + + # --- chunk_grid: Mapping or ChunkGrid --- + chunk_grid = data.get("chunk_grid") + if chunk_grid is not None and not isinstance(chunk_grid, Mapping | ChunkGrid): + errors.append( + f"Invalid chunk_grid: expected a mapping or ChunkGrid, got {type(chunk_grid).__name__}" + ) + + # --- chunk_key_encoding: Mapping or ChunkKeyEncoding --- + chunk_key_encoding = data.get("chunk_key_encoding") + if chunk_key_encoding is not None and not isinstance( + chunk_key_encoding, Mapping | ChunkKeyEncoding + ): + errors.append( + f"Invalid chunk_key_encoding: expected a mapping or ChunkKeyEncoding, got {type(chunk_key_encoding).__name__}" + ) + + # --- codecs: Iterable (but not str or Mapping) --- + codecs = data.get("codecs") + if codecs is not None and ( + not isinstance(codecs, Iterable) or isinstance(codecs, str | Mapping) + ): + errors.append(f"Invalid codecs: expected an iterable, got {type(codecs).__name__}") + + # --- fill_value: any type is allowed, just must be present (checked via required keys) --- + + # --- attributes (optional): Mapping --- + attributes = data.get("attributes") + if attributes is not None and not isinstance(attributes, Mapping): + errors.append(f"Invalid attributes: expected a mapping, got {type(attributes).__name__}") + + # --- dimension_names (optional): Iterable (but not str) --- + dimension_names = data.get("dimension_names") + if dimension_names is not None and ( + not isinstance(dimension_names, Iterable) or isinstance(dimension_names, str) + ): + errors.append( + f"Invalid dimension_names: expected an iterable or None, got {type(dimension_names).__name__}" + ) + + # --- storage_transformers (optional): Iterable (but not str or Mapping) --- + storage_transformers = data.get("storage_transformers") + if storage_transformers is not None and ( + not isinstance(storage_transformers, Iterable) + or isinstance(storage_transformers, str | Mapping) + ): + errors.append( + f"Invalid storage_transformers: expected an iterable, got {type(storage_transformers).__name__}" + ) + + if errors: + raise MetadataValidationError( + "Cannot interpret input as Zarr v3 array metadata:\n" + + "\n".join(f" - {e}" for e in errors) + ) + + return cast(ArrayMetadataJSONLike_V3, data) + + +class ArrayMetadataJSONLike_V3(TypedDict): + """ + A typed dictionary model of JSON-like input that can be used to create ArrayV3Metadata + """ + + zarr_format: NotRequired[Literal[3]] + node_type: NotRequired[Literal["array"]] + shape: ShapeLike + data_type: ZDTypeLike + chunk_grid: ChunkGridLike + chunk_key_encoding: ChunkKeyEncodingLike + codecs: Iterable[CodecLike] + fill_value: object + attributes: NotRequired[dict[str, JSON]] + dimension_names: NotRequired[DimensionNamesLike] + storage_transformers: NotRequired[Iterable[dict[str, JSON]]] + @dataclass(frozen=True, kw_only=True) -class ArrayV3Metadata(Metadata): +class ArrayV3Metadata(Metadata, JSONSerializable[ArrayMetadataJSON_V3, ArrayMetadataJSONLike_V3]): shape: tuple[int, ...] data_type: ZDType[TBaseDType, TBaseScalar] chunk_grid: ChunkGrid @@ -213,14 +350,14 @@ class ArrayV3Metadata(Metadata): def __init__( self, *, - shape: Iterable[int], - data_type: ZDType[TBaseDType, TBaseScalar], + shape: ShapeLike, + data_type: ZDTypeLike, chunk_grid: dict[str, JSON] | ChunkGrid | NamedConfig[str, Any], chunk_key_encoding: ChunkKeyEncodingLike, fill_value: object, codecs: Iterable[Codec | dict[str, JSON] | NamedConfig[str, Any] | str], attributes: dict[str, JSON] | None, - dimension_names: DimensionNames, + dimension_names: DimensionNamesLike, storage_transformers: Iterable[dict[str, JSON]] | None = None, extra_fields: Mapping[str, AllowedExtraField] | None = None, ) -> None: @@ -229,27 +366,27 @@ def __init__( """ shape_parsed = parse_shapelike(shape) + data_type_parsed = parse_dtype(data_type, zarr_format=3) chunk_grid_parsed = ChunkGrid.from_dict(chunk_grid) chunk_key_encoding_parsed = parse_chunk_key_encoding(chunk_key_encoding) dimension_names_parsed = parse_dimension_names(dimension_names) - # Note: relying on a type method is numpy-specific - fill_value_parsed = data_type.cast_scalar(fill_value) + fill_value_parsed = data_type_parsed.cast_scalar(fill_value) attributes_parsed = parse_attributes(attributes) codecs_parsed_partial = parse_codecs(codecs) storage_transformers_parsed = parse_storage_transformers(storage_transformers) extra_fields_parsed = parse_extra_fields(extra_fields) array_spec = ArraySpec( shape=shape_parsed, - dtype=data_type, + dtype=data_type_parsed, fill_value=fill_value_parsed, config=ArrayConfig.from_dict({}), # TODO: config is not needed here. prototype=default_buffer_prototype(), # TODO: prototype is not needed here. ) codecs_parsed = tuple(c.evolve_from_array_spec(array_spec) for c in codecs_parsed_partial) - validate_codecs(codecs_parsed_partial, data_type) + validate_codecs(codecs_parsed_partial, data_type_parsed) object.__setattr__(self, "shape", shape_parsed) - object.__setattr__(self, "data_type", data_type) + object.__setattr__(self, "data_type", data_type_parsed) object.__setattr__(self, "chunk_grid", chunk_grid_parsed) object.__setattr__(self, "chunk_key_encoding", chunk_key_encoding_parsed) object.__setattr__(self, "codecs", codecs_parsed) @@ -410,7 +547,7 @@ def from_dict(cls, data: dict[str, JSON]) -> Self: storage_transformers=_data_typed.get("storage_transformers", ()), # type: ignore[arg-type] ) - def to_dict(self) -> dict[str, JSON]: + def to_json(self) -> ArrayMetadataJSON_V3: out_dict = super().to_dict() extra_fields = out_dict.pop("extra_fields") out_dict = out_dict | extra_fields # type: ignore[operator] @@ -426,14 +563,44 @@ def to_dict(self) -> dict[str, JSON]: if out_dict["dimension_names"] is None: out_dict.pop("dimension_names") - # TODO: replace the `to_dict` / `from_dict` on the `Metadata`` class with - # to_json, from_json, and have ZDType inherit from `Metadata` - # until then, we have this hack here, which relies on the fact that to_dict will pass through - # any non-`Metadata` fields as-is. + # TODO: have ZDType inherit from JSONSerializable so we can remove this hack dtype_meta = out_dict["data_type"] if isinstance(dtype_meta, ZDType): out_dict["data_type"] = dtype_meta.to_json(zarr_format=3) # type: ignore[unreachable] - return out_dict + return cast(ArrayMetadataJSON_V3, out_dict) + + def to_dict(self) -> dict[str, JSON]: + return dict(self.to_json()) # type: ignore[arg-type] + + @classmethod + def from_json(cls, obj: ArrayMetadataJSONLike_V3) -> Self: + """ + Construct from a trusted, typed input. No validation of the input structure + is performed beyond what ``__init__`` already does. + """ + _known_keys = set(ArrayMetadataJSONLike_V3.__annotations__) + extra_fields = {k: v for k, v in obj.items() if k not in _known_keys} + return cls( + shape=obj["shape"], + data_type=obj["data_type"], + chunk_grid=obj["chunk_grid"], + chunk_key_encoding=obj["chunk_key_encoding"], + codecs=obj["codecs"], + fill_value=obj["fill_value"], + attributes=obj.get("attributes"), + dimension_names=obj.get("dimension_names"), + storage_transformers=obj.get("storage_transformers"), + extra_fields=extra_fields or None, # type: ignore[arg-type] + ) + + @classmethod + def try_from_json(cls, obj: object) -> Self: + """ + Construct from an untrusted input (e.g. JSON read from disk). + Validates the structure and raises a ``MetadataValidationError`` + listing all problems found. + """ + return cls.from_json(narrow_array_metadata_json(obj)) def update_shape(self, shape: tuple[int, ...]) -> Self: return replace(self, shape=shape) diff --git a/tests/conftest.py b/tests/conftest.py index 23a1e87d0a..86db02f6bf 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -25,7 +25,7 @@ from zarr.core.chunk_grids import RegularChunkGrid, _auto_partition from zarr.core.common import ( JSON, - DimensionNames, + DimensionNamesLike, MemoryOrder, ShapeLike, ZarrFormat, @@ -313,7 +313,7 @@ def create_array_metadata( zarr_format: ZarrFormat, attributes: dict[str, JSON] | None = None, chunk_key_encoding: ChunkKeyEncoding | ChunkKeyEncodingLike | None = None, - dimension_names: DimensionNames = None, + dimension_names: DimensionNamesLike = None, ) -> ArrayV2Metadata | ArrayV3Metadata: """ Create array metadata @@ -452,7 +452,7 @@ def meta_from_array( zarr_format: ZarrFormat = 3, attributes: dict[str, JSON] | None = None, chunk_key_encoding: ChunkKeyEncoding | ChunkKeyEncodingLike | None = None, - dimension_names: DimensionNames = None, + dimension_names: DimensionNamesLike = None, ) -> ArrayV3Metadata | ArrayV2Metadata: """ Create array metadata from an array diff --git a/tests/test_metadata/test_v3.py b/tests/test_metadata/test_v3.py index 01ed921053..713239e8a8 100644 --- a/tests/test_metadata/test_v3.py +++ b/tests/test_metadata/test_v3.py @@ -2,7 +2,8 @@ import json import re -from typing import TYPE_CHECKING, Literal +from dataclasses import dataclass +from typing import TYPE_CHECKING, Generic, Literal, TypeVar import numpy as np import pytest @@ -10,6 +11,7 @@ from zarr import consolidate_metadata, create_group from zarr.codecs.bytes import BytesCodec from zarr.core.buffer import default_buffer_prototype +from zarr.core.chunk_grids import RegularChunkGrid from zarr.core.chunk_key_encodings import DefaultChunkKeyEncoding, V2ChunkKeyEncoding from zarr.core.config import config from zarr.core.dtype import UInt8, get_data_type_from_native_dtype @@ -19,6 +21,7 @@ from zarr.core.metadata.v3 import ( ArrayMetadataJSON_V3, ArrayV3Metadata, + ChunkGridLike, parse_codecs, parse_dimension_names, parse_zarr_format, @@ -31,12 +34,14 @@ ) if TYPE_CHECKING: - from collections.abc import Sequence + from collections.abc import Iterable, Sequence from typing import Any from zarr.core.types import JSON from zarr.abc.codec import Codec + from zarr.core.chunk_key_encodings import ChunkKeyEncodingLike + from zarr.core.common import DimensionNamesLike from zarr.core.metadata.v3 import ( @@ -461,3 +466,139 @@ def test_group_to_dict(use_consolidated: bool, attributes: None | dict[str, Any] expect = {"node_type": "group", "zarr_format": 3, "attributes": expect_attributes} assert meta.to_dict() == expect + + +TIn = TypeVar("TIn") +TOut = TypeVar("TOut") + + +class _Unset: + """Sentinel for 'key not provided in the input dict'.""" + + +UNSET = _Unset() + + +@dataclass(frozen=True) +class Expect(Generic[TIn, TOut]): + """An (input, expected output) pair for parametrized tests.""" + + input: TIn + expected: TOut + + +@pytest.mark.parametrize("shape", [Expect((10,), (10,)), Expect([5], (5,))]) +@pytest.mark.parametrize("data_type", [Expect(UInt8(), UInt8())]) +@pytest.mark.parametrize( + "chunk_grid", + [ + Expect( + {"name": "regular", "configuration": {"chunk_shape": (10,)}}, + RegularChunkGrid(chunk_shape=(10,)), + ), + Expect(RegularChunkGrid(chunk_shape=(10,)), RegularChunkGrid(chunk_shape=(10,))), + ], +) +@pytest.mark.parametrize( + "chunk_key_encoding", + [ + Expect( + {"name": "default", "configuration": {"separator": "/"}}, + DefaultChunkKeyEncoding(separator="/"), + ), + Expect(DefaultChunkKeyEncoding(separator="."), DefaultChunkKeyEncoding(separator=".")), + ], +) +@pytest.mark.parametrize("codecs", [Expect((BytesCodec(),), (BytesCodec(endian=None),))]) +@pytest.mark.parametrize( + "fill_value", + [Expect(0, np.uint8(0)), Expect(42, np.uint8(42))], +) +@pytest.mark.parametrize( + "attributes", + [Expect({"key": "val"}, {"key": "val"}), Expect(None, {}), Expect(UNSET, {})], +) +@pytest.mark.parametrize( + "dimension_names", + [Expect(("x",), ("x",)), Expect(None, None), Expect(UNSET, None)], +) +@pytest.mark.parametrize( + "extra_fields", + [ + Expect(UNSET, {}), + Expect({"ext": {"must_understand": False}}, {"ext": {"must_understand": False}}), + ], +) +def test_from_json( + shape: Expect[Iterable[int], tuple[int, ...]], + data_type: Expect[UInt8, UInt8], + chunk_grid: Expect[ChunkGridLike, RegularChunkGrid], + chunk_key_encoding: Expect[ChunkKeyEncodingLike, DefaultChunkKeyEncoding], + codecs: Expect[tuple[Codec, ...], tuple[Codec, ...]], + fill_value: Expect[object, np.uint8], + attributes: Expect[dict[str, JSON] | _Unset, dict[str, JSON]], + dimension_names: Expect[DimensionNamesLike | _Unset, tuple[str | None, ...] | None], + extra_fields: Expect[dict[str, object] | _Unset, dict[str, object]], +) -> None: + """ + Test that ArrayV3Metadata.from_json correctly parses each field. + """ + data: dict[str, object] = { + "shape": shape.input, + "data_type": data_type.input, + "chunk_grid": chunk_grid.input, + "chunk_key_encoding": chunk_key_encoding.input, + "codecs": codecs.input, + "fill_value": fill_value.input, + } + if not isinstance(attributes.input, _Unset): + data["attributes"] = attributes.input + if not isinstance(dimension_names.input, _Unset): + data["dimension_names"] = dimension_names.input + if not isinstance(extra_fields.input, _Unset): + data.update(extra_fields.input) + + result = ArrayV3Metadata.from_json(data) # type: ignore[arg-type] + assert result.shape == shape.expected + assert result.data_type == data_type.expected + assert result.chunk_grid == chunk_grid.expected + assert result.chunk_key_encoding == chunk_key_encoding.expected + assert result.codecs == codecs.expected + assert result.fill_value == fill_value.expected + assert result.attributes == attributes.expected + assert result.dimension_names == dimension_names.expected + assert result.extra_fields == extra_fields.expected + + # try_from_json should produce the same result for the same input + assert ArrayV3Metadata.try_from_json(data) == result + + +_VALID_TRY_FROM_JSON_INPUT: dict[str, object] = { + "shape": (10,), + "data_type": "uint8", + "chunk_grid": {"name": "regular", "configuration": {"chunk_shape": (10,)}}, + "chunk_key_encoding": {"name": "default", "configuration": {"separator": "/"}}, + "codecs": ({"name": "bytes"},), + "fill_value": 0, +} + + +@pytest.mark.parametrize( + ("data", "error_match"), + [ + ("not a dict", "Expected a mapping"), + ({}, "Missing required keys"), + ({**_VALID_TRY_FROM_JSON_INPUT, "shape": "not a shape"}, "shape"), + ({**_VALID_TRY_FROM_JSON_INPUT, "data_type": 12345}, "data_type"), + ({**_VALID_TRY_FROM_JSON_INPUT, "chunk_grid": "not a mapping"}, "chunk_grid"), + ( + {**_VALID_TRY_FROM_JSON_INPUT, "chunk_key_encoding": "not a mapping"}, + "chunk_key_encoding", + ), + ({**_VALID_TRY_FROM_JSON_INPUT, "codecs": "not iterable"}, "codecs"), + ({**_VALID_TRY_FROM_JSON_INPUT, "unknown_field": "value"}, "Disallowed extra fields"), + ], +) +def test_try_from_json_invalid(data: object, error_match: str) -> None: + with pytest.raises(MetadataValidationError, match=error_match): + ArrayV3Metadata.try_from_json(data)