diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 5ed2c53..4ad6b3b 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -50,5 +50,52 @@ jobs: - name: Upload coverage to Codecov uses: codecov/codecov-action@v5.4.0 with: + name: zarrv3 token: ${{ secrets.CODECOV_TOKEN }} fail_ci_if_error: true + + - name: Run tests (zarr v2) + if: ${{ matrix.python == '3.13' }} + run: | + uv pip install "zarr>=2.17,<3" + uv run --no-sync python -m pytest -xv --cov=tszip --cov-report=xml --cov-branch -n2 + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v5.4.0 + with: + name: zarrv2 + token: ${{ secrets.CODECOV_TOKEN }} + fail_ci_if_error: true + + zarr-compatibility: + name: Zarr v2/v3 Cross-compatibility + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4.2.2 + + - name: Install uv and set python version + uses: astral-sh/setup-uv@v6 + with: + python-version: 3.13 + version: "0.8.15" + + - name: Install dependencies (zarr v3) + run: | + uv venv + uv pip install -r pyproject.toml --extra test + + - name: Write test file with zarr v3 + run: uv run --no-sync python tests/zarr_cross_version_helper.py write test_v3.tsz + + - name: Switch to zarr v2 and test reading + run: | + uv pip install "zarr>=2.17,<3" + uv run --no-sync python tests/zarr_cross_version_helper.py read test_v3.tsz + uv run --no-sync python tests/zarr_cross_version_helper.py write test_v2.tsz + + - name: Switch back to zarr v3 and test reading both files + run: | + uv pip install "zarr>=3.0,<4" + uv run --no-sync python tests/zarr_cross_version_helper.py read test_v3.tsz + uv run --no-sync python tests/zarr_cross_version_helper.py read test_v2.tsz \ No newline at end of file diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 9c84866..936ecdc 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -3,6 +3,7 @@ -------------------- - Drop Python 3.9 support, require Python >= 3.10 (#112, benjeffery) +- Support zarr v3 (#114, benjeffery) -------------------- [0.2.4] - 2024-07-10 diff --git a/pyproject.toml b/pyproject.toml index a00ed06..83b5886 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,7 @@ dependencies = [ "humanize", "tskit>=0.3.3", "numcodecs>=0.6.4", - "zarr<3", + "zarr>=2.17,<4", ] dynamic = ["version"] @@ -59,7 +59,8 @@ test = [ "pytest-cov==6.3.0", "pytest-xdist==3.8.0", "tskit==0.6.4", - "zarr==2.17.2", + "zarr==2.18.3; python_version == '3.10'", + "zarr==3.1.2; python_version >= '3.11'", "numcodecs>=0.6,<0.15.1", #Pinned due to https://github.com/zarr-developers/numcodecs/issues/733 ] docs = [ @@ -69,7 +70,7 @@ docs = [ "sphinx-argparse==0.5.2", "setuptools_scm==9.2.0", "tskit==0.6.4", - "zarr==2.18.7", + "zarr==3.1.2", "numcodecs>=0.6,<0.15.1", #Pinned due to https://github.com/zarr-developers/numcodecs/issues/733 ] dev = [ @@ -87,7 +88,7 @@ dev = [ "sphinx-issues", "setuptools_scm", "tskit", - "zarr<3", + "zarr", "msprime", "humanize", ] diff --git a/tests/test_compression.py b/tests/test_compression.py index 4f6dff7..f26520e 100644 --- a/tests/test_compression.py +++ b/tests/test_compression.py @@ -30,12 +30,12 @@ import numpy as np import pytest import tskit -import zarr import tszip import tszip.compression as compression import tszip.exceptions as exceptions import tszip.provenance as provenance +from tszip import compat class TestMinimalDtype(unittest.TestCase): @@ -294,8 +294,8 @@ def tearDown(self): def test_format_written(self): ts = msprime.simulate(10, random_seed=1) tszip.compress(ts, self.path) - with zarr.ZipStore(str(self.path), mode="r") as store: - root = zarr.group(store=store) + with compat.create_zip_store(str(self.path), mode="r") as store: + root = compat.create_zarr_group(store=store) self.assertEqual(root.attrs["format_name"], compression.FORMAT_NAME) self.assertEqual(root.attrs["format_version"], compression.FORMAT_VERSION) @@ -303,16 +303,16 @@ def test_provenance(self): ts = msprime.simulate(10, random_seed=1) for variants_only in [True, False]: tszip.compress(ts, self.path, variants_only=variants_only) - with zarr.ZipStore(str(self.path), mode="r") as store: - root = zarr.group(store=store) + with compat.create_zip_store(str(self.path), mode="r") as store: + root = compat.create_zarr_group(store=store) self.assertEqual( root.attrs["provenance"], provenance.get_provenance_dict({"variants_only": variants_only}), ) def write_file(self, attrs, path): - with zarr.ZipStore(str(path), mode="w") as store: - root = zarr.group(store=store) + with compat.create_zip_store(str(path), mode="w") as store: + root = compat.create_zarr_group(store=store) root.attrs.update(attrs) def test_missing_format_keys(self): diff --git a/tests/zarr_cross_version_helper.py b/tests/zarr_cross_version_helper.py new file mode 100644 index 0000000..9ab72ce --- /dev/null +++ b/tests/zarr_cross_version_helper.py @@ -0,0 +1,121 @@ +#!/usr/bin/env python3 +""" +Script to test zarr cross-version compatibility. +Usage: python test_zarr_cross_version.py [write|read] +""" +import pathlib +import sys + +import msprime +import tskit + +# Add parent directory to path so we can import tszip +sys.path.insert(0, str(pathlib.Path(__file__).parent.parent)) + +import tszip # noqa: E402 + + +def all_fields_ts(edge_metadata=True, migrations=True): + """ + A tree sequence with data in all fields (except edge metadata is not set if + edge_metadata is False and migrations are not defined if migrations is False + (this is needed to test simplify, which doesn't allow either) + + """ + demography = msprime.Demography() + demography.add_population(name="A", initial_size=10_000) + demography.add_population(name="B", initial_size=5_000) + demography.add_population(name="C", initial_size=1_000) + demography.add_population(name="D", initial_size=500) + demography.add_population(name="E", initial_size=100) + demography.add_population_split(time=1000, derived=["A", "B"], ancestral="C") + ts = msprime.sim_ancestry( + samples={"A": 10, "B": 10}, + demography=demography, + sequence_length=5, + random_seed=42, + recombination_rate=1, + record_migrations=migrations, + record_provenance=True, + ) + ts = msprime.sim_mutations(ts, rate=0.001, random_seed=42) + tables = ts.dump_tables() + # Add locations to individuals + individuals_copy = tables.individuals.copy() + tables.individuals.clear() + for i, individual in enumerate(individuals_copy): + tables.individuals.append( + individual.replace(flags=i, location=[i, i + 1], parents=[i - 1, i - 1]) + ) + # Ensure all columns have unique values + nodes_copy = tables.nodes.copy() + tables.nodes.clear() + for i, node in enumerate(nodes_copy): + tables.nodes.append( + node.replace( + flags=i, + time=node.time + 0.00001 * i, + individual=i % len(tables.individuals), + population=i % len(tables.populations), + ) + ) + if migrations: + tables.migrations.add_row(left=0, right=1, node=21, source=1, dest=3, time=1001) + + # Add metadata + for name, table in tables.table_name_map.items(): + if name == "provenances": + continue + if name == "migrations" and not migrations: + continue + if name == "edges" and not edge_metadata: + continue + table.metadata_schema = tskit.MetadataSchema.permissive_json() + metadatas = [f'{{"foo":"n_{name}_{u}"}}' for u in range(len(table))] + metadata, metadata_offset = tskit.pack_strings(metadatas) + table.set_columns( + **{ + **table.asdict(), + "metadata": metadata, + "metadata_offset": metadata_offset, + } + ) + tables.metadata_schema = tskit.MetadataSchema.permissive_json() + tables.metadata = "Test metadata" + tables.time_units = "Test time units" + + tables.reference_sequence.metadata_schema = tskit.MetadataSchema.permissive_json() + tables.reference_sequence.metadata = "Test reference metadata" + tables.reference_sequence.data = "A" * int(ts.sequence_length) + tables.reference_sequence.url = "http://example.com/a_reference" + + # Add some more rows to provenance to have enough for testing. + for i in range(3): + tables.provenances.add_row(record="A", timestamp=str(i)) + + return tables.tree_sequence() + + +def write_test_file(filename): + """Write a test file with current zarr version""" + ts = all_fields_ts() + tszip.compress(ts, filename) + ts2 = tszip.decompress(filename) + ts.tables.assert_equals(ts2.tables) + + +def read_test_file(filename): + """Read and verify a test file with current zarr version""" + try: + tszip.decompress(filename) + except Exception: + sys.exit(1) + + +if __name__ == "__main__": + action = sys.argv[1] + filename = sys.argv[2] + if action == "write": + write_test_file(filename) + elif action == "read": + read_test_file(filename) diff --git a/tszip/compat.py b/tszip/compat.py new file mode 100644 index 0000000..62e859f --- /dev/null +++ b/tszip/compat.py @@ -0,0 +1,93 @@ +# MIT License +# +# Copyright (c) 2025 Tskit Developers +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +""" +Compatibility layer for zarr v2/v3 API differences +""" +import zarr + +ZARR_V3 = zarr.__version__.startswith("3.") + + +if ZARR_V3: + from zarr.storage import ZipStore + + def create_zip_store(path, mode="r"): + return ZipStore(path, mode=mode) + + def create_zarr_group(store=None): + if store is None: + return zarr.create_group(zarr_format=2) + else: + mode = "r" if getattr(store, "read_only", False) else "a" + return zarr.open_group(store=store, zarr_format=2, mode=mode) + + def create_empty_array( + group, name, shape, dtype, chunks=None, filters=None, compressor=None + ): + return group.empty( + name=name, + shape=shape, + dtype=dtype, + chunks=chunks, + zarr_format=2, + filters=filters, + compressor=compressor, + ) + + def get_nbytes_stored(array): + return array.nbytes_stored() + + def group_items(group): + return group.members() + + def visit_arrays(group, visitor): + for array in group.array_values(): + visitor(array) + +else: + + def create_zip_store(path, mode="r"): + return zarr.ZipStore(path, mode=mode) + + def create_zarr_group(store=None): + return zarr.group(store=store) + + def create_empty_array( + group, name, shape, dtype, chunks=None, filters=None, compressor=None + ): + return group.empty( + name, + shape=shape, + dtype=dtype, + chunks=chunks, + filters=filters, + compressor=compressor, + ) + + def get_nbytes_stored(array): + return array.nbytes_stored + + def group_items(group): + return group.items() + + def visit_arrays(group, visitor): + group.visitvalues(visitor) diff --git a/tszip/compression.py b/tszip/compression.py index 6bea6fd..e34806c 100644 --- a/tszip/compression.py +++ b/tszip/compression.py @@ -31,7 +31,6 @@ import tempfile import warnings import zipfile -from typing import Mapping import humanize import numcodecs @@ -39,6 +38,7 @@ import tskit import zarr +from . import compat from . import exceptions from . import provenance @@ -98,8 +98,8 @@ def compress(ts, destination, variants_only=False): with tempfile.TemporaryDirectory(dir=destdir, prefix=".tszip_work_") as tmpdir: filename = pathlib.Path(tmpdir, "tmp.trees.tgz") logging.debug(f"Writing to temporary file {filename}") - with zarr.ZipStore(filename, mode="w") as store: - root = zarr.group(store=store) + with compat.create_zip_store(filename, mode="w") as store: + root = compat.create_zarr_group(store=store) compress_zarr(ts, root, variants_only=variants_only) if is_path: os.replace(filename, destination) @@ -145,22 +145,26 @@ def compress(self, root, compressor): filters = None if self.delta_filter: filters = [numcodecs.Delta(dtype=dtype)] - compressed_array = root.empty( + compressed_array = compat.create_empty_array( + root, self.name, - chunks=chunks, shape=shape, dtype=dtype, + chunks=chunks, filters=filters, compressor=compressor, ) compressed_array[:] = self.array ratio = 0 if compressed_array.nbytes > 0: - ratio = compressed_array.nbytes / compressed_array.nbytes_stored + nbytes_stored = compat.get_nbytes_stored(compressed_array) + ratio = compressed_array.nbytes / nbytes_stored logger.debug( "{}: output={} compression={:.1f}".format( self.name, - humanize.naturalsize(compressed_array.nbytes_stored, binary=True), + humanize.naturalsize( + compat.get_nbytes_stored(compressed_array), binary=True + ), ratio, ) ) @@ -285,15 +289,17 @@ def check_format(root): def load_zarr(path): path = str(path) try: - store = zarr.ZipStore(path, mode="r") + store = compat.create_zip_store(path, mode="r") + root = compat.create_zarr_group(store=store) except zipfile.BadZipFile as bzf: raise exceptions.FileFormatError("File is not in tszip format") from bzf - root = zarr.group(store=store) + try: check_format(root) yield root finally: - store.close() + if hasattr(store, "close"): + store.close() def decompress_zarr(root): @@ -307,9 +313,10 @@ def decompress_zarr(root): "migrations/right", "sites/position", ] - for key, value in root.items(): - if isinstance(value, Mapping): - for sub_key, sub_value in value.items(): + for key, value in compat.group_items(root): + if hasattr(value, "members") or hasattr(value, "items"): + # This is a zarr Group, iterate over its contents + for sub_key, sub_value in compat.group_items(value): if f"{key}/{sub_key}" in quantised_arrays: dict_repr.setdefault(key, {})[sub_key] = coordinates[sub_value] elif sub_key.endswith("metadata_schema") or (key, sub_key) in [ @@ -323,12 +330,14 @@ def decompress_zarr(root): dict_repr.setdefault(key, {})[sub_key] = bytes(sub_value) else: dict_repr.setdefault(key, {})[sub_key] = sub_value - elif key.endswith("metadata_schema") or key == "time_units": - dict_repr[key] = bytes(value).decode("utf-8") - elif key.endswith("metadata"): - dict_repr[key] = bytes(value) else: - dict_repr[key] = value + # This is an array + if key.endswith("metadata_schema") or key == "time_units": + dict_repr[key] = bytes(value).decode("utf-8") + elif key.endswith("metadata"): + dict_repr[key] = bytes(value) + else: + dict_repr[key] = value return tskit.TableCollection.fromdict(dict_repr).tree_sequence() @@ -336,16 +345,17 @@ def print_summary(path, verbosity=0): arrays = [] def visitor(array): - if isinstance(array, zarr.core.Array): + if isinstance(array, zarr.Array): arrays.append(array) with load_zarr(path) as root: - root.visitvalues(visitor) + compat.visit_arrays(root, visitor) - arrays.sort(key=lambda x: x.nbytes_stored) + arrays.sort(key=lambda x: compat.get_nbytes_stored(x)) max_name_len = max(len(array.name) for array in arrays) storeds = [ - humanize.naturalsize(array.nbytes_stored, binary=True) for array in arrays + humanize.naturalsize(compat.get_nbytes_stored(array), binary=True) + for array in arrays ] max_stored_len = max(len(size) for size in storeds) actuals = [humanize.naturalsize(array.nbytes, binary=True) for array in arrays] @@ -374,7 +384,7 @@ def visitor(array): for array, stored, actual in zip(arrays, storeds, actuals): ratio = 0 if array.nbytes > 0: - ratio = array.nbytes_stored / array.nbytes + ratio = compat.get_nbytes_stored(array) / array.nbytes line = fmt.format( array.name, max_name_len,