From 5a22c7cb3c78df975d6eaeef4bffb6583018f5e4 Mon Sep 17 00:00:00 2001 From: Rogdham Date: Sun, 7 Dec 2025 12:04:11 +0100 Subject: [PATCH] refactor: embed type hints, check with mypy --- .github/workflows/build.yml | 16 ++ CHANGELOG.md | 1 + pyproject.toml | 48 +++++- requirements-dev.txt | 1 + requirements-type.txt | 1 + src/pyzstd/__init__.py | 283 +++++++++++++++++++++---------- src/pyzstd/__init__.pyi | 274 ------------------------------ src/pyzstd/__main__.py | 76 ++++++--- src/pyzstd/_seekable_zstdfile.py | 220 +++++++++++++----------- tests/test_seekable.py | 2 +- 10 files changed, 432 insertions(+), 490 deletions(-) create mode 100644 requirements-type.txt delete mode 100644 src/pyzstd/__init__.pyi diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 3fb04d6..fb02560 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -89,6 +89,22 @@ jobs: - name: ruff format run: ruff format --check + type: + name: Type + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v6 + - name: Setup Python + uses: actions/setup-python@v6 + with: + python-version: 3.14 + - name: Install dependencies + run: python -m pip install -r requirements-type.txt + - name: Create _version.py + run: echo '__version__ = ""' > src/pyzstd/_version.py + - name: mypy + run: mypy + publish: name: Publish to PyPI if: startsWith(github.ref, 'refs/tags') diff --git a/CHANGELOG.md b/CHANGELOG.md index 6cb56cd..61120da 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ All notable changes to this project will be documented in this file. - Remove git submodule usage - Drop support for Python 3.9 and below - Use `ruff` as formatter and linter +- Embed type hints in Python code, and check with `mypy` ## 0.18.0 (October 5, 2025) diff --git a/pyproject.toml b/pyproject.toml index 452c2c8..021c8b7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,6 +60,48 @@ version-file = "src/pyzstd/_version.py" source = "vcs" +# +# mypy +# + +[tool.mypy] +# Import discovery +files = "src" +ignore_missing_imports = false +follow_imports = "normal" +# Platform configuration +python_version = "3.14" +# Disallow dynamic typing +disallow_any_unimported = true +disallow_any_decorated = true +disallow_any_generics = true +disallow_subclassing_any = true +# Untyped definitions and calls +disallow_untyped_calls = true +disallow_untyped_defs = true +disallow_incomplete_defs = true +check_untyped_defs = true +disallow_untyped_decorators = true +# None and Optional handling +no_implicit_optional = true +strict_optional = true +# Configuring warning +warn_redundant_casts = true +warn_unused_ignores = true +warn_no_return = true +warn_return_any = true +warn_unreachable = true +# Suppressing errors +ignore_errors = false +# Miscellaneous strictness flags +strict_equality = true +# Configuring error messages +show_error_context = true +show_error_codes = true +# Miscellaneous +warn_unused_configs = true + + # # ruff # @@ -67,15 +109,11 @@ source = "vcs" [tool.ruff] src = ["src"] target-version = "py310" -extend-exclude = [ - "tests", - '*.pyi', # FIXME -] +extend-exclude = ["tests"] [tool.ruff.lint] select = ["ALL"] ignore = [ - "ANN", # FIXME "C901", "COM812", "D", diff --git a/requirements-dev.txt b/requirements-dev.txt index 3d574ce..873a45e 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,3 +1,4 @@ -e . -r requirements-lint.txt +-r requirements-type.txt diff --git a/requirements-type.txt b/requirements-type.txt new file mode 100644 index 0000000..4f2bbba --- /dev/null +++ b/requirements-type.txt @@ -0,0 +1 @@ +mypy==1.19.0 diff --git a/src/pyzstd/__init__.py b/src/pyzstd/__init__.py index ff98234..4f94298 100644 --- a/src/pyzstd/__init__.py +++ b/src/pyzstd/__init__.py @@ -1,6 +1,18 @@ +from collections.abc import Callable, Mapping from enum import IntEnum +from io import TextIOWrapper +from os import PathLike import sys -from typing import NamedTuple +from typing import ( + BinaryIO, + ClassVar, + Literal, + NamedTuple, + NoReturn, + TypeAlias, + cast, + overload, +) import warnings if sys.version_info < (3, 14): @@ -8,10 +20,15 @@ else: from compression import zstd -try: - from warnings import deprecated -except ImportError: +if sys.version_info < (3, 13): from typing_extensions import deprecated +else: + from warnings import deprecated + +if sys.version_info < (3, 12): + from typing_extensions import Buffer +else: + from collections.abc import Buffer from pyzstd._version import __version__ # noqa: F401 @@ -56,13 +73,24 @@ class _DeprecatedPlaceholder: - def __repr__(self): + def __repr__(self) -> str: return "" _DEPRECATED_PLACEHOLDER = _DeprecatedPlaceholder() +Strategy = zstd.Strategy +ZstdError = zstd.ZstdError +ZstdDict = zstd.ZstdDict +train_dict = zstd.train_dict +finalize_dict = zstd.finalize_dict +get_frame_info = zstd.get_frame_info +get_frame_size = zstd.get_frame_size +zstd_version = zstd.zstd_version +zstd_version_info = zstd.zstd_version_info + + class CParameter(IntEnum): """Compression parameters""" @@ -90,7 +118,7 @@ class CParameter(IntEnum): jobSize = zstd.CompressionParameter.job_size # noqa: N815 overlapLog = zstd.CompressionParameter.overlap_log # noqa: N815 - def bounds(self): + def bounds(self) -> tuple[int, int]: """Return lower and upper bounds of a compression parameter, both inclusive.""" return zstd.CompressionParameter(self).bounds() @@ -100,12 +128,20 @@ class DParameter(IntEnum): windowLogMax = zstd.DecompressionParameter.window_log_max # noqa: N815 - def bounds(self): + def bounds(self) -> tuple[int, int]: """Return lower and upper bounds of a decompression parameter, both inclusive.""" return zstd.DecompressionParameter(self).bounds() -def _convert_level_or_option(level_or_option, mode): +_LevelOrOption: TypeAlias = int | Mapping[int, int] | None +_Option: TypeAlias = Mapping[int, int] | None +_ZstdDict: TypeAlias = ZstdDict | tuple[ZstdDict, int] | None +_StrOrBytesPath: TypeAlias = str | bytes | PathLike[str] | PathLike[bytes] + + +def _convert_level_or_option( + level_or_option: _LevelOrOption | _Option, mode: str +) -> Mapping[int, int] | None: """Transform pyzstd params into PEP-784 `options` param""" if not isinstance(mode, str): raise TypeError(f"Invalid mode type: {mode}") @@ -135,14 +171,14 @@ def _convert_level_or_option(level_or_option, mode): class ZstdCompressor: """A streaming compressor. Thread-safe at method level.""" - CONTINUE = zstd.ZstdCompressor.CONTINUE + CONTINUE: ClassVar[Literal[0]] = zstd.ZstdCompressor.CONTINUE """Used for mode parameter in .compress() method. Collect more data, encoder decides when to output compressed result, for optimal compression ratio. Usually used for traditional streaming compression. """ - FLUSH_BLOCK = zstd.ZstdCompressor.FLUSH_BLOCK + FLUSH_BLOCK: ClassVar[Literal[1]] = zstd.ZstdCompressor.FLUSH_BLOCK """Used for mode parameter in .compress(), .flush() methods. Flush any remaining data, but don't close the current frame. Usually used for @@ -155,7 +191,7 @@ class ZstdCompressor: necessary. """ - FLUSH_FRAME = zstd.ZstdCompressor.FLUSH_FRAME + FLUSH_FRAME: ClassVar[Literal[2]] = zstd.ZstdCompressor.FLUSH_FRAME """Used for mode parameter in .compress(), .flush() methods. Flush any remaining data, and close the current frame. Usually used for @@ -168,7 +204,9 @@ class ZstdCompressor: only decompress single frame data. Use it only when necessary. """ - def __init__(self, level_or_option=None, zstd_dict=None): + def __init__( + self, level_or_option: _LevelOrOption = None, zstd_dict: _ZstdDict = None + ) -> None: """Initialize a ZstdCompressor object. Parameters @@ -177,11 +215,16 @@ def __init__(self, level_or_option=None, zstd_dict=None): parameters. zstd_dict: A ZstdDict object, pre-trained zstd dictionary. """ + zstd_dict = cast( + "ZstdDict | None", zstd_dict + ) # https://github.com/python/typeshed/pull/15113 self._compressor = zstd.ZstdCompressor( options=_convert_level_or_option(level_or_option, "w"), zstd_dict=zstd_dict ) - def compress(self, data, mode=zstd.ZstdCompressor.CONTINUE): + def compress( + self, data: Buffer, mode: Literal[0, 1, 2] = zstd.ZstdCompressor.CONTINUE + ) -> bytes: """Provide data to the compressor object. Return a chunk of compressed data if possible, or b'' otherwise. @@ -191,7 +234,7 @@ def compress(self, data, mode=zstd.ZstdCompressor.CONTINUE): """ return self._compressor.compress(data, mode) - def flush(self, mode=zstd.ZstdCompressor.FLUSH_FRAME): + def flush(self, mode: Literal[1, 2] = zstd.ZstdCompressor.FLUSH_FRAME) -> bytes: """Flush any remaining data in internal buffer. Since zstd data consists of one or more independent frames, the compressor @@ -202,7 +245,7 @@ def flush(self, mode=zstd.ZstdCompressor.FLUSH_FRAME): """ return self._compressor.flush(mode) - def _set_pledged_input_size(self, size): + def _set_pledged_input_size(self, size: int | None) -> None: """*This is an undocumented method, because it may be used incorrectly.* Set uncompressed content size of a frame, the size will be written into the @@ -218,7 +261,7 @@ def _set_pledged_input_size(self, size): return self._compressor.set_pledged_input_size(size) @property - def last_mode(self): + def last_mode(self) -> Literal[0, 1, 2]: """The last mode used to this compressor object, its value can be .CONTINUE, .FLUSH_BLOCK, .FLUSH_FRAME. Initialized to .FLUSH_FRAME. @@ -227,7 +270,7 @@ def last_mode(self): """ return self._compressor.last_mode - def __reduce__(self): + def __reduce__(self) -> NoReturn: raise TypeError(f"Cannot pickle {type(self)} object.") @@ -235,18 +278,21 @@ class ZstdDecompressor: """A streaming decompressor, it stops after a frame is decompressed. Thread-safe at method level.""" - def __init__(self, zstd_dict=None, option=None): + def __init__(self, zstd_dict: _ZstdDict = None, option: _Option = None) -> None: """Initialize a ZstdDecompressor object. Parameters zstd_dict: A ZstdDict object, pre-trained zstd dictionary. option: A dict object that contains advanced decompression parameters. """ + zstd_dict = cast( + "ZstdDict | None", zstd_dict + ) # https://github.com/python/typeshed/pull/15113 self._decompressor = zstd.ZstdDecompressor( zstd_dict=zstd_dict, options=_convert_level_or_option(option, "r") ) - def decompress(self, data, max_length=-1): + def decompress(self, data: Buffer, max_length: int = -1) -> bytes: """Decompress data, return a chunk of decompressed data if possible, or b'' otherwise. @@ -261,13 +307,13 @@ def decompress(self, data, max_length=-1): return self._decompressor.decompress(data, max_length) @property - def eof(self): + def eof(self) -> bool: """True means the end of the first frame has been reached. If decompress data after that, an EOFError exception will be raised.""" return self._decompressor.eof @property - def needs_input(self): + def needs_input(self) -> bool: """If the max_length output limit in .decompress() method has been reached, and the decompressor has (or may has) unconsumed input data, it will be set to False. In this case, pass b'' to .decompress() method may output further data. @@ -275,12 +321,12 @@ def needs_input(self): return self._decompressor.needs_input @property - def unused_data(self): + def unused_data(self) -> bytes: """A bytes object. When ZstdDecompressor object stops after a frame is decompressed, unused input data after the frame. Otherwise this will be b''.""" return self._decompressor.unused_data - def __reduce__(self): + def __reduce__(self) -> NoReturn: raise TypeError(f"Cannot pickle {type(self)} object.") @@ -288,25 +334,27 @@ class EndlessZstdDecompressor: """A streaming decompressor, accepts multiple concatenated frames. Thread-safe at method level.""" - def __init__(self, zstd_dict=None, option=None): + def __init__(self, zstd_dict: _ZstdDict = None, option: _Option = None) -> None: """Initialize an EndlessZstdDecompressor object. Parameters zstd_dict: A ZstdDict object, pre-trained zstd dictionary. option: A dict object that contains advanced decompression parameters. """ - self._zstd_dict = zstd_dict + self._zstd_dict = cast( + "ZstdDict | None", zstd_dict + ) # https://github.com/python/typeshed/pull/15113 self._options = _convert_level_or_option(option, "r") self._reset() - def _reset(self, data=b""): + def _reset(self, data: bytes = b"") -> None: self._decompressor = zstd.ZstdDecompressor( zstd_dict=self._zstd_dict, options=self._options ) self._buffer = data self._at_frame_edge = not data - def decompress(self, data, max_length=-1): + def decompress(self, data: Buffer, max_length: int = -1) -> bytes: """Decompress data, return a chunk of decompressed data if possible, or b'' otherwise. @@ -336,7 +384,7 @@ def decompress(self, data, max_length=-1): return out @property - def at_frame_edge(self): + def at_frame_edge(self) -> bool: """True when both the input and output streams are at a frame edge, means a frame is completely decoded and fully flushed, or the decompressor just be initialized. @@ -346,7 +394,7 @@ def at_frame_edge(self): return self._at_frame_edge @property - def needs_input(self): + def needs_input(self) -> bool: """If the max_length output limit in .decompress() method has been reached, and the decompressor has (or may has) unconsumed input data, it will be set to False. In this case, pass b'' to .decompress() method may output further data. @@ -355,11 +403,13 @@ def needs_input(self): self._at_frame_edge or self._decompressor.needs_input ) - def __reduce__(self): + def __reduce__(self) -> NoReturn: raise TypeError(f"Cannot pickle {type(self)} object.") -def compress(data, level_or_option=None, zstd_dict=None): +def compress( + data: Buffer, level_or_option: _LevelOrOption = None, zstd_dict: _ZstdDict = None +) -> bytes: """Compress a block of data, return a bytes object. Compressing b'' will get an empty content frame (9 bytes or more). @@ -371,6 +421,9 @@ def compress(data, level_or_option=None, zstd_dict=None): parameters. zstd_dict: A ZstdDict object, pre-trained dictionary for compression. """ + zstd_dict = cast( + "ZstdDict | None", zstd_dict + ) # https://github.com/python/typeshed/pull/15113 return zstd.compress( data, options=_convert_level_or_option(level_or_option, "w"), @@ -378,7 +431,9 @@ def compress(data, level_or_option=None, zstd_dict=None): ) -def decompress(data, zstd_dict=None, option=None): +def decompress( + data: Buffer, zstd_dict: _ZstdDict = None, option: _Option = None +) -> bytes: """Decompress a zstd data, return a bytes object. Support multiple concatenated frames. @@ -388,6 +443,9 @@ def decompress(data, zstd_dict=None, option=None): zstd_dict: A ZstdDict object, pre-trained zstd dictionary. option: A dict object, contains advanced decompression parameters. """ + zstd_dict = cast( + "ZstdDict | None", zstd_dict + ) # https://github.com/python/typeshed/pull/15113 return zstd.decompress( data, options=_convert_level_or_option(option, "r"), zstd_dict=zstd_dict ) @@ -397,19 +455,26 @@ def decompress(data, zstd_dict=None, option=None): "See https://pyzstd.readthedocs.io/en/stable/deprecated.html for alternatives to pyzstd.RichMemZstdCompressor" ) class RichMemZstdCompressor: - def __init__(self, level_or_option=None, zstd_dict=None): - self._compress_kwargs = { - "options": _convert_level_or_option(level_or_option, "w"), - "zstd_dict": zstd_dict, - } + def __init__( + self, level_or_option: _LevelOrOption = None, zstd_dict: _ZstdDict = None + ) -> None: + self._options = _convert_level_or_option(level_or_option, "w") + self._zstd_dict = cast( + "ZstdDict | None", zstd_dict + ) # https://github.com/python/typeshed/pull/15113 - def compress(self, data): - return zstd.compress(data, **self._compress_kwargs) + def compress(self, data: Buffer) -> bytes: + return zstd.compress(data, options=self._options, zstd_dict=self._zstd_dict) - def __reduce__(self): + def __reduce__(self) -> NoReturn: raise TypeError(f"Cannot pickle {type(self)} object.") +richmem_compress = deprecated( + "See https://pyzstd.readthedocs.io/en/stable/deprecated.html for alternatives to pyzstd.richmem_compress" +)(compress) + + class ZstdFile(zstd.ZstdFile): """A file object providing transparent zstd (de)compression. @@ -421,19 +486,16 @@ class ZstdFile(zstd.ZstdFile): supports the Buffer Protocol. """ - FLUSH_BLOCK = ZstdCompressor.FLUSH_BLOCK - FLUSH_FRAME = ZstdCompressor.FLUSH_FRAME - def __init__( self, - filename, - mode="r", + filename: _StrOrBytesPath | BinaryIO, + mode: Literal["r", "rb", "w", "wb", "x", "xb", "a", "ab"] = "r", *, - level_or_option=None, - zstd_dict=None, - read_size=_DEPRECATED_PLACEHOLDER, - write_size=_DEPRECATED_PLACEHOLDER, - ): + level_or_option: _LevelOrOption | _Option = None, + zstd_dict: _ZstdDict = None, + read_size: int | _DeprecatedPlaceholder = _DEPRECATED_PLACEHOLDER, + write_size: int | _DeprecatedPlaceholder = _DEPRECATED_PLACEHOLDER, + ) -> None: """Open a zstd compressed file in binary mode. filename can be either an actual file name (given as a str, bytes, or @@ -465,6 +527,9 @@ def __init__( DeprecationWarning, stacklevel=2, ) + zstd_dict = cast( + "ZstdDict | None", zstd_dict + ) # https://github.com/python/typeshed/pull/15113 super().__init__( filename, mode, @@ -473,16 +538,44 @@ def __init__( ) +@overload +def open( # noqa: A001 + filename: _StrOrBytesPath | BinaryIO, + mode: Literal["r", "rb", "w", "wb", "a", "ab", "x", "xb"] = "rb", + *, + level_or_option: _LevelOrOption | _Option = None, + zstd_dict: _ZstdDict = None, + encoding: None = None, + errors: None = None, + newline: None = None, +) -> zstd.ZstdFile: ... + + +@overload +def open( # noqa: A001 + filename: _StrOrBytesPath | BinaryIO, + mode: Literal["rt", "wt", "at", "xt"], + *, + level_or_option: _LevelOrOption | _Option = None, + zstd_dict: _ZstdDict = None, + encoding: str | None = None, + errors: str | None = None, + newline: str | None = None, +) -> TextIOWrapper: ... + + def open( # noqa: A001 - filename, - mode="rb", + filename: _StrOrBytesPath | BinaryIO, + mode: Literal[ + "r", "rb", "w", "wb", "a", "ab", "x", "xb", "rt", "wt", "at", "xt" + ] = "rb", *, - level_or_option=None, - zstd_dict=None, - encoding=None, - errors=None, - newline=None, -): + level_or_option: _LevelOrOption | _Option = None, + zstd_dict: _ZstdDict = None, + encoding: str | None = None, + errors: str | None = None, + newline: str | None = None, +) -> zstd.ZstdFile | TextIOWrapper: """Open a zstd compressed file in binary or text mode. filename can be either an actual file name (given as a str, bytes, or @@ -507,6 +600,9 @@ def open( # noqa: A001 io.TextIOWrapper instance with the specified encoding, error handling behavior, and line ending(s). """ + zstd_dict = cast( + "ZstdDict | None", zstd_dict + ) # https://github.com/python/typeshed/pull/15113 return zstd.open( filename, mode, @@ -518,26 +614,38 @@ def open( # noqa: A001 ) -def _create_callback(output_stream, callback): +def _create_callback( + output_stream: BinaryIO | None, + callback: Callable[[int, int, memoryview, memoryview], None] | None, +) -> Callable[[int, int, bytes, bytes], None]: if output_stream is None: if callback is None: raise TypeError( "At least one of output_stream argument and callback argument should be non-None." ) - def cb(total_input, total_output, data_in, data_out): + def cb( + total_input: int, total_output: int, data_in: bytes, data_out: bytes + ) -> None: callback( total_input, total_output, memoryview(data_in), memoryview(data_out) ) elif callback is None: - def cb(total_input, total_output, data_in, data_out): # noqa: ARG001 + def cb( + total_input: int, # noqa: ARG001 + total_output: int, # noqa: ARG001 + data_in: bytes, # noqa: ARG001 + data_out: bytes, + ) -> None: output_stream.write(data_out) else: - def cb(total_input, total_output, data_in, data_out): + def cb( + total_input: int, total_output: int, data_in: bytes, data_out: bytes + ) -> None: output_stream.write(data_out) callback( total_input, total_output, memoryview(data_in), memoryview(data_out) @@ -550,16 +658,16 @@ def cb(total_input, total_output, data_in, data_out): "See https://pyzstd.readthedocs.io/en/stable/deprecated.html for alternatives to pyzstd.compress_stream" ) def compress_stream( - input_stream, - output_stream, + input_stream: BinaryIO, + output_stream: BinaryIO | None, *, - level_or_option=None, - zstd_dict=None, - pledged_input_size=None, - read_size=131_072, - write_size=_DEPRECATED_PLACEHOLDER, # noqa: ARG001 - callback=None, -): + level_or_option: _LevelOrOption = None, + zstd_dict: _ZstdDict = None, + pledged_input_size: int | None = None, + read_size: int = 131_072, + write_size: int | _DeprecatedPlaceholder = _DEPRECATED_PLACEHOLDER, # noqa: ARG001 + callback: Callable[[int, int, memoryview, memoryview], None] | None = None, +) -> tuple[int, int]: """Compresses input_stream and writes the compressed data to output_stream, it doesn't close the streams. @@ -618,15 +726,15 @@ def compress_stream( "See https://pyzstd.readthedocs.io/en/stable/deprecated.html for alternatives to pyzstd.decompress_stream" ) def decompress_stream( - input_stream, - output_stream, + input_stream: BinaryIO, + output_stream: BinaryIO | None, *, - zstd_dict=None, - option=None, - read_size=131_075, - write_size=131_072, - callback=None, -): + zstd_dict: _ZstdDict = None, + option: _Option = None, + read_size: int = 131_075, + write_size: int = 131_072, + callback: Callable[[int, int, memoryview, memoryview], None] | None = None, +) -> tuple[int, int]: """Decompresses input_stream and writes the decompressed data to output_stream, it doesn't close the streams. @@ -684,21 +792,6 @@ def decompress_stream( return total_input, total_output -Strategy = zstd.Strategy -ZstdError = zstd.ZstdError -richmem_compress = deprecated( - "See https://pyzstd.readthedocs.io/en/stable/deprecated.html for alternatives to pyzstd.richmem_compress" -)(compress) -ZstdDict = zstd.ZstdDict -train_dict = zstd.train_dict -finalize_dict = zstd.finalize_dict -get_frame_info = zstd.get_frame_info -get_frame_size = zstd.get_frame_size -zstd_version = zstd.zstd_version -zstd_version_info = zstd.zstd_version_info -zstd_support_multithread = CParameter.nbWorkers.bounds() != (0, 0) - - class CompressionValues(NamedTuple): default: int min: int @@ -708,6 +801,8 @@ class CompressionValues(NamedTuple): compressionLevel_values = CompressionValues( # noqa: N816 zstd.COMPRESSION_LEVEL_DEFAULT, *CParameter.compressionLevel.bounds() ) +zstd_support_multithread = CParameter.nbWorkers.bounds() != (0, 0) + # import here to avoid circular dependency issues from ._seekable_zstdfile import SeekableFormatError, SeekableZstdFile # noqa: E402 diff --git a/src/pyzstd/__init__.pyi b/src/pyzstd/__init__.pyi deleted file mode 100644 index d8e4d17..0000000 --- a/src/pyzstd/__init__.pyi +++ /dev/null @@ -1,274 +0,0 @@ -from collections.abc import ByteString, Callable, Iterable -from enum import IntEnum -import io -from os import PathLike -from typing import ( - BinaryIO, - ClassVar, - Literal, - NamedTuple, - TextIO, - TypeAlias, - overload, -) - -try: - from warnings import deprecated -except ImportError: - from typing_extensions import deprecated - -__version__: str -zstd_version: str -zstd_version_info: tuple[int, int, int] -zstd_support_multithread: bool - -class CompressionValues(NamedTuple): - default: int - min: int - max: int - -compressionLevel_values: CompressionValues - -class Strategy(IntEnum): - fast: int - dfast: int - greedy: int - lazy: int - lazy2: int - btlazy2: int - btopt: int - btultra: int - btultra2: int - -class CParameter(IntEnum): - compressionLevel: int - windowLog: int - hashLog: int - chainLog: int - searchLog: int - minMatch: int - targetLength: int - strategy: int - targetCBlockSize: int - - enableLongDistanceMatching: int - ldmHashLog: int - ldmMinMatch: int - ldmBucketSizeLog: int - ldmHashRateLog: int - - contentSizeFlag: int - checksumFlag: int - dictIDFlag: int - - nbWorkers: int - jobSize: int - overlapLog: int - - def bounds(self) -> tuple[int, int]: ... - -class DParameter(IntEnum): - windowLogMax: int - - def bounds(self) -> tuple[int, int]: ... - -ZstdDictInfo: TypeAlias = tuple[ZstdDict, int] -class ZstdDict: - dict_content: bytes - dict_id: int - - as_digested_dict: ZstdDictInfo - as_undigested_dict: ZstdDictInfo - as_prefix: ZstdDictInfo - - def __init__(self, - dict_content, - is_raw: bool = False) -> None: ... - - def __len__(self) -> int: ... - -class ZstdCompressor: - CONTINUE: ClassVar[Literal[0]] - FLUSH_BLOCK: ClassVar[Literal[1]] - FLUSH_FRAME: ClassVar[Literal[2]] - - last_mode: Literal[0, 1, 2] - - def __init__(self, - level_or_option: None | int | dict[CParameter, int] = None, - zstd_dict: None | ZstdDict | ZstdDictInfo = None) -> None: ... - - def compress(self, - data, - mode: Literal[0, 1, 2] = ...) -> bytes: ... - - def flush(self, - mode: Literal[1, 2] = ...) -> bytes: ... - - def _set_pledged_input_size(self, size: int | None) -> None: ... - -@deprecated("See https://pyzstd.readthedocs.io/en/stable/deprecated.html for alternatives to pyzstd.RichMemZstdCompressor") -class RichMemZstdCompressor: - def __init__(self, - level_or_option: None | int | dict[CParameter, int] = None, - zstd_dict: None | ZstdDict | ZstdDictInfo = None) -> None: ... - - def compress(self, data) -> bytes: ... - -class ZstdDecompressor: - needs_input: bool - eof: bool - unused_data: bytes - - def __init__(self, - zstd_dict: None | ZstdDict | ZstdDictInfo = None, - option: dict[DParameter, int] | None = None) -> None: ... - - def decompress(self, - data: ByteString, - max_length: int = -1) -> bytes: ... - -class EndlessZstdDecompressor: - needs_input: bool - at_frame_edge: bool - - def __init__(self, - zstd_dict: None | ZstdDict | ZstdDictInfo = None, - option: dict[DParameter, int] | None = None) -> None: ... - - def decompress(self, - data: ByteString, - max_length: int = -1) -> bytes: ... - -class ZstdError(Exception): - ... - -def compress(data, - level_or_option: None | int | dict[CParameter, int] = None, - zstd_dict: None | ZstdDict | ZstdDictInfo = None) -> bytes: ... - -@deprecated("See https://pyzstd.readthedocs.io/en/stable/deprecated.html for alternatives to pyzstd.richmem_compress") -def richmem_compress(data, - level_or_option: None | int | dict[CParameter, int] = None, - zstd_dict: None | ZstdDict | ZstdDictInfo = None) -> bytes: ... - -def decompress(data: ByteString, - zstd_dict: None | ZstdDict | ZstdDictInfo = None, - option: dict[DParameter, int] | None = None) -> bytes: ... - -@deprecated("See https://pyzstd.readthedocs.io/en/stable/deprecated.html for alternatives to pyzstd.compress_stream") -def compress_stream(input_stream: BinaryIO, output_stream: BinaryIO | None, *, - level_or_option: None | int | dict[CParameter, int] = None, - zstd_dict: None | ZstdDict | ZstdDictInfo = None, - pledged_input_size: int | None = None, - read_size: int = 131_072, write_size: int = 131_591, - callback: Callable[[int, int, memoryview, memoryview], None] | None = None) -> tuple[int, int]: ... - -@deprecated("See https://pyzstd.readthedocs.io/en/stable/deprecated.html for alternatives to pyzstd.decompress_stream") -def decompress_stream(input_stream: BinaryIO, output_stream: BinaryIO | None, *, - zstd_dict: None | ZstdDict | ZstdDictInfo = None, - option: dict[DParameter, int] | None = None, - read_size: int = 131_075, write_size: int = 131_072, - callback: Callable[[int, int, memoryview, memoryview], None] | None = None) -> tuple[int, int]: ... - -def train_dict(samples: Iterable, - dict_size: int) -> ZstdDict: ... - -def finalize_dict(zstd_dict: ZstdDict, - samples: Iterable, - dict_size: int, - level: int) -> ZstdDict: ... - -class frame_info(NamedTuple): - decompressed_size: int | None - dictionary_id: int - -def get_frame_info(frame_buffer: ByteString) -> frame_info: ... - -def get_frame_size(frame_buffer: ByteString) -> int: ... - -class ZstdFile(io.BufferedIOBase): - FLUSH_BLOCK: ClassVar[Literal[1]] - FLUSH_FRAME: ClassVar[Literal[2]] - - def __init__(self, - filename: str | bytes | PathLike | BinaryIO, - mode: str = "r", - *, - level_or_option: None | int | dict[CParameter, int] | dict[DParameter, int] = None, - zstd_dict: None | ZstdDict | ZstdDictInfo = None, - read_size: int = 131_075, write_size: int = 131_591) -> None: ... - def close(self) -> None: ... - - def write(self, data) -> int: ... - def flush(self, - mode: Literal[1, 2] = ...) -> None: ... - - def read(self, size: int | None = -1) -> bytes: ... - def read1(self, size: int = -1) -> bytes: ... - def readinto(self, b) -> int: ... - def readinto1(self, b) -> int: ... - def readline(self, size: int | None = -1) -> bytes: ... - def seek(self, - offset: int, - whence: int = 0) -> int: ... - def peek(self, size: int = -1) -> bytes: ... - - def tell(self) -> int: ... - def fileno(self) -> int: ... - @property - def closed(self) -> bool: ... - def writable(self) -> bool: ... - def readable(self) -> bool: ... - def seekable(self) -> bool: ... - -class SeekableZstdFile(ZstdFile): - def __init__(self, - filename: str | bytes | PathLike | BinaryIO, - mode: str = "r", - *, - level_or_option: None | int | dict[CParameter, int] | dict[DParameter, int] = None, - zstd_dict: None | ZstdDict | ZstdDictInfo = None, - read_size: int = 131_075, write_size: int = 131_591, - max_frame_content_size: int = ...) -> None: ... - - @property - def seek_table_info(self) -> tuple[int, int, int]: ... - - @staticmethod - def is_seekable_format_file(filename: str | bytes | PathLike | BinaryIO) -> bool: ... - -_BinaryMode: TypeAlias = Literal["r", "rb", # read - "w", "wb", "a", "ab", "x", "xb"] # write -_TextMode: TypeAlias = Literal["rt", # read - "wt", "at", "xt"] # write - -@overload -def open(filename: str | bytes | PathLike | BinaryIO, - mode: _BinaryMode = "rb", - *, - level_or_option: None | int | dict[CParameter, int] | dict[DParameter, int] = None, - zstd_dict: None | ZstdDict | ZstdDictInfo = None, - encoding: None = None, - errors: None = None, - newline: None = None) -> ZstdFile: ... - -@overload -def open(filename: str | bytes | PathLike | BinaryIO, - mode: _TextMode, - *, - level_or_option: None | int | dict[CParameter, int] | dict[DParameter, int] = None, - zstd_dict: None | ZstdDict | ZstdDictInfo = None, - encoding: str | None = None, - errors: str | None = None, - newline: str | None = None) -> TextIO: ... - -@overload -def open(filename: str | bytes | PathLike | BinaryIO, - mode: str, - *, - level_or_option: None | int | dict[CParameter, int] | dict[DParameter, int] = None, - zstd_dict: None | ZstdDict | ZstdDictInfo = None, - encoding: str | None = None, - errors: str | None = None, - newline: str | None = None) -> ZstdFile | TextIO: ... diff --git a/src/pyzstd/__main__.py b/src/pyzstd/__main__.py index 4d7af14..6f9acf7 100644 --- a/src/pyzstd/__main__.py +++ b/src/pyzstd/__main__.py @@ -1,8 +1,10 @@ # CLI of pyzstd module: python -m pyzstd --help import argparse +from collections.abc import Mapping, Sequence import os from shutil import copyfileobj from time import time +from typing import Any, BinaryIO, Protocol, cast from pyzstd import ( CParameter, @@ -15,13 +17,36 @@ ) from pyzstd import __version__ as pyzstd_version + +class Args(Protocol): + dict: str + f: bool + compress: str + tar_input_dir: str + level: int + threads: int + long: int + checksum: bool + write_dictID: bool # noqa: N815 + decompress: str + tar_output_dir: str + test: str | None + windowLogMax: int # noqa: N815 + train: str + maxdict: int + dictID: int # noqa: N815 + output: BinaryIO | None + input: BinaryIO | None + zd: ZstdDict | None + + # buffer sizes recommended by zstd C_READ_BUFFER = 131072 D_READ_BUFFER = 131075 # open output file and assign to args.output -def open_output(args, path): +def open_output(args: Args, path: str) -> None: if not args.f and os.path.isfile(path): answer = input(f"output file already exists:\n{path}\noverwrite? (y/n) ") print() @@ -32,7 +57,7 @@ def open_output(args, path): args.output = open(path, "wb") # noqa: SIM115 -def close_files(args): +def close_files(args: Args) -> None: if args.input is not None: args.input.close() @@ -40,7 +65,7 @@ def close_files(args): args.output.close() -def compress_option(args): +def compress_option(args: Args) -> Mapping[int, int]: # threads message if args.threads == 0: threads_msg = "single-thread mode" @@ -58,7 +83,7 @@ def compress_option(args): long_msg = "no" # option - option = { + option: Mapping[int, int] = { CParameter.compressionLevel: args.level, CParameter.nbWorkers: args.threads, CParameter.enableLongDistanceMatching: use_long, @@ -80,10 +105,13 @@ def compress_option(args): return option -def compress(args): +def compress(args: Args) -> None: + args.input = cast("BinaryIO", args.input) + # output file if args.output is None: open_output(args, args.input.name + ".zst") + args.output = cast("BinaryIO", args.output) # pre-compress message msg = ( @@ -114,7 +142,9 @@ def compress(args): print(msg) -def decompress(args): +def decompress(args: Args) -> None: + args.input = cast("BinaryIO", args.input) + # output file if args.output is None: if args.test is None: @@ -126,9 +156,10 @@ def decompress(args): else: out_path = os.devnull open_output(args, out_path) + args.output = cast("BinaryIO", args.output) # option - option = {DParameter.windowLogMax: args.windowLogMax} + option: Mapping[int, int] = {DParameter.windowLogMax: args.windowLogMax} # pre-decompress message output_name = args.output.name @@ -159,7 +190,7 @@ def decompress(args): print(msg) -def train(args): +def train(args: Args) -> None: from glob import glob # check output file @@ -225,7 +256,7 @@ def train(args): print(msg) -def tarfile_create(args): +def tarfile_create(args: Args) -> None: import sys if sys.version_info < (3, 14): @@ -244,6 +275,7 @@ def tarfile_create(args): if args.output is None: out_path = os.path.join(dirname, basename + ".tar.zst") open_output(args, out_path) + args.output = cast("BinaryIO", args.output) # pre-compress message msg = ( @@ -263,7 +295,7 @@ def tarfile_create(args): None, fileobj=args.output, mode="w", options=option, zstd_dict=args.zd ) as f: f.add(args.tar_input_dir, basename) - uncompressed_size = f.fileobj.tell() + uncompressed_size = f.fileobj.tell() # type: ignore[union-attr] t2 = time() output_file_size = args.output.tell() @@ -282,7 +314,7 @@ def tarfile_create(args): print(msg) -def tarfile_extract(args): +def tarfile_extract(args: Args) -> None: import sys if sys.version_info < (3, 14): @@ -302,7 +334,7 @@ def tarfile_extract(args): raise NotADirectoryError(msg) # option - option = {DParameter.windowLogMax: args.windowLogMax} + option: Mapping[int, int] = {DParameter.windowLogMax: args.windowLogMax} # pre-extract message msg = ( @@ -320,7 +352,7 @@ def tarfile_extract(args): None, fileobj=args.input, mode="r", zstd_dict=args.zd, options=option ) as f: f.extractall(args.tar_output_dir, filter="data") - uncompressed_size = f.fileobj.tell() + uncompressed_size = f.fileobj.tell() # type: ignore[union-attr] t2 = time() close_files(args) @@ -336,12 +368,18 @@ def tarfile_extract(args): print(msg) -def range_action(start, end): +def range_action(start: int, end: int) -> type[argparse.Action]: class RangeAction(argparse.Action): - def __call__(self, parser, args, values, option_string=None): # noqa: ARG002 + def __call__( + self, + _: object, + namespace: object, + values: str | Sequence[Any] | None, + option_string: str | None = None, + ) -> None: # convert to int try: - v = int(values) + v = int(values) # type: ignore[arg-type] except ValueError: raise TypeError(f"{option_string} should be an integer") from None @@ -354,12 +392,12 @@ def __call__(self, parser, args, values, option_string=None): # noqa: ARG002 ) raise ValueError(msg) - setattr(args, self.dest, v) + setattr(namespace, self.dest, v) return RangeAction -def parse_arg(): +def parse_arg() -> Args: p = argparse.ArgumentParser( prog="CLI of pyzstd module", description=( @@ -546,7 +584,7 @@ def parse_arg(): return args -def main(): +def main() -> None: print(f"*** pyzstd module v{pyzstd_version}, zstd library v{zstd_version}. ***\n") args = parse_arg() diff --git a/src/pyzstd/_seekable_zstdfile.py b/src/pyzstd/_seekable_zstdfile.py index c4f6653..c91fffc 100644 --- a/src/pyzstd/_seekable_zstdfile.py +++ b/src/pyzstd/_seekable_zstdfile.py @@ -4,9 +4,30 @@ from os import PathLike from os.path import isfile from struct import Struct +import sys +from typing import BinaryIO, ClassVar, Literal, cast import warnings -from pyzstd import _DEPRECATED_PLACEHOLDER, ZstdCompressor, ZstdDecompressor +from pyzstd import ( + _DEPRECATED_PLACEHOLDER, + ZstdCompressor, + ZstdDecompressor, + _DeprecatedPlaceholder, + _LevelOrOption, + _Option, + _StrOrBytesPath, + _ZstdDict, +) + +if sys.version_info < (3, 12): + from typing_extensions import Buffer +else: + from collections.abc import Buffer + +if sys.version_info < (3, 11): + from typing_extensions import Self +else: + from typing import Self __all__ = ("SeekableFormatError", "SeekableZstdFile") @@ -18,7 +39,7 @@ class SeekableFormatError(Exception): "An error related to Zstandard Seekable Format." - def __init__(self, msg): + def __init__(self, msg: str) -> None: super().__init__("Zstandard Seekable Format error: " + msg) @@ -52,11 +73,11 @@ class _SeekTable: _s_footer = Struct(" None: self._read_mode = read_mode self._clear_seek_table() - def _clear_seek_table(self): + def _clear_seek_table(self) -> None: self._has_checksum = False # The seek table frame size, used for append mode. self._seek_frame_size = 0 @@ -83,7 +104,7 @@ def _clear_seek_table(self): # I is uint32_t. self._frames = array("I") - def append_entry(self, compressed_size, decompressed_size): + def append_entry(self, compressed_size: int, decompressed_size: int) -> None: if compressed_size == 0: if decompressed_size == 0: # (0, 0) frame is no sense @@ -104,7 +125,7 @@ def append_entry(self, compressed_size, decompressed_size): # seek_to_0 is True or False. # In read mode, seeking to 0 is necessary. - def load_seek_table(self, fp, seek_to_0): + def load_seek_table(self, fp: BinaryIO, seek_to_0: bool) -> None: # noqa: FBT001 # Get file size fsize = fp.seek(0, 2) # 2 is SEEK_END if fsize == 0: @@ -211,7 +232,7 @@ def load_seek_table(self, fp, seek_to_0): self._file_size = fsize # Find frame index by decompressed position - def index_by_dpos(self, pos): + def index_by_dpos(self, pos: int) -> int | None: # Array's first item is 0, so need this. pos = max(pos, 0) @@ -221,19 +242,19 @@ def index_by_dpos(self, pos): # None means >= EOF return None - def get_frame_sizes(self, i): + def get_frame_sizes(self, i: int) -> tuple[int, int]: return (self._cumulated_c_size[i - 1], self._cumulated_d_size[i - 1]) - def get_full_c_size(self): + def get_full_c_size(self) -> int: return self._full_c_size - def get_full_d_size(self): + def get_full_d_size(self) -> int: return self._full_d_size # Merge the seek table to max_frames frames. # The format allows up to 0xFFFF_FFFF frames. When frames # number exceeds, use this method to merge. - def _merge_frames(self, max_frames): + def _merge_frames(self, max_frames: int) -> None: if self._frames_count <= max_frames: return @@ -258,7 +279,7 @@ def _merge_frames(self, max_frames): pos += length - def write_seek_table(self, fp): + def write_seek_table(self, fp: BinaryIO) -> None: # Exceeded format limit if self._frames_count > 0xFFFFFFFF: # Emit a warning @@ -294,17 +315,17 @@ def write_seek_table(self, fp): fp.write(ba) @property - def seek_frame_size(self): + def seek_frame_size(self) -> int: return self._seek_frame_size @property - def file_size(self): + def file_size(self) -> int: return self._file_size - def __len__(self): + def __len__(self) -> int: return self._frames_count - def get_info(self): + def get_info(self) -> tuple[int, int, int]: return (self._frames_count, self._full_c_size, self._full_d_size) @@ -313,7 +334,9 @@ class _EOFSuccess(EOFError): # noqa: N818 class _SeekableDecompressReader(io.RawIOBase): - def __init__(self, fp, zstd_dict, option, read_size): + def __init__( + self, fp: BinaryIO, zstd_dict: _ZstdDict, option: _Option, read_size: int + ) -> None: # Check fp readable/seekable if not hasattr(fp, "readable") or not hasattr(fp, "seekable"): raise TypeError( @@ -343,22 +366,24 @@ def __init__(self, fp, zstd_dict, option, read_size): self._size = self._seek_table.get_full_d_size() self._pos = 0 - self._decompressor = ZstdDecompressor(self._zstd_dict, self._option) + self._decompressor: ZstdDecompressor | None = ZstdDecompressor( + self._zstd_dict, self._option + ) - def close(self): + def close(self) -> None: self._decompressor = None return super().close() - def readable(self): + def readable(self) -> bool: return True - def seekable(self): + def seekable(self) -> bool: return True - def tell(self): + def tell(self) -> int: return self._pos - def _decompress(self, size): + def _decompress(self, size: int) -> bytes: """ Decompress up to size bytes. May return b"", in which case try again. @@ -387,7 +412,7 @@ def _decompress(self, size): self._pos += len(out) return out - def readinto(self, b): + def readinto(self, b: Buffer) -> int: with memoryview(b) as view, view.cast("B") as byte_view: try: while True: @@ -399,7 +424,7 @@ def readinto(self, b): # If the new position is within BufferedReader's buffer, # this method may not be called. - def seek(self, offset, whence=0): + def seek(self, offset: int, whence: int = 0) -> int: # offset is absolute file position if whence == 0: # SEEK_SET pass @@ -443,7 +468,7 @@ def seek(self, offset, whence=0): return self._pos - def get_seek_table_info(self): + def get_seek_table_info(self) -> tuple[int, int, int]: return self._seek_table.get_info() @@ -459,24 +484,24 @@ class SeekableZstdFile(io.BufferedIOBase): # The format uses uint32_t for compressed/decompressed sizes. If flush # block a lot, compressed_size may exceed the limit, so set a max size. - FRAME_MAX_C_SIZE = 2 * 1024 * 1024 * 1024 + FRAME_MAX_C_SIZE: ClassVar[int] = 2 * 1024 * 1024 * 1024 # Zstd seekable format's example code also use 1GiB as max content size. - FRAME_MAX_D_SIZE = 1 * 1024 * 1024 * 1024 + FRAME_MAX_D_SIZE: ClassVar[int] = 1 * 1024 * 1024 * 1024 - FLUSH_BLOCK = ZstdCompressor.FLUSH_BLOCK - FLUSH_FRAME = ZstdCompressor.FLUSH_FRAME + FLUSH_BLOCK: ClassVar[Literal[1]] = ZstdCompressor.FLUSH_BLOCK + FLUSH_FRAME: ClassVar[Literal[2]] = ZstdCompressor.FLUSH_FRAME def __init__( self, - filename, - mode="r", + filename: _StrOrBytesPath | BinaryIO, + mode: Literal["r", "rb", "w", "wb", "a", "ab", "x", "xb"] = "r", *, - level_or_option=None, - zstd_dict=None, - read_size=_DEPRECATED_PLACEHOLDER, - write_size=_DEPRECATED_PLACEHOLDER, - max_frame_content_size=1024 * 1024 * 1024, - ): + level_or_option: _LevelOrOption | _Option = None, + zstd_dict: _ZstdDict = None, + read_size: int | _DeprecatedPlaceholder = _DEPRECATED_PLACEHOLDER, # type: ignore[has-type] + write_size: int | _DeprecatedPlaceholder = _DEPRECATED_PLACEHOLDER, # type: ignore[has-type] + max_frame_content_size: int = 1024 * 1024 * 1024, + ) -> None: """Open a Zstandard Seekable Format file in binary mode. In read mode, the file can be 0-size file. @@ -514,6 +539,7 @@ def __init__( DeprecationWarning, stacklevel=2, ) + read_size = cast("int", read_size) if write_size == _DEPRECATED_PLACEHOLDER: write_size = 131591 else: @@ -522,15 +548,16 @@ def __init__( DeprecationWarning, stacklevel=2, ) + write_size = cast("int", write_size) - self._fp = None + self._fp: BinaryIO | None = None self._close_fp = False self._mode = _MODE_CLOSED self._buffer = None if not isinstance(mode, str): raise TypeError("mode must be a str") - mode = mode.removesuffix("b") # handle rb, wb, xb, ab + mode = mode.removesuffix("b") # type: ignore[assignment] # handle rb, wb, xb, ab # Read or write mode if mode == "r": @@ -572,10 +599,10 @@ def __init__( # For seekable format self._max_frame_content_size = max_frame_content_size self._reset_frame_sizes() - self._seek_table = _SeekTable(read_mode=False) + self._seek_table: _SeekTable | None = _SeekTable(read_mode=False) mode_code = _MODE_WRITE - self._compressor = ZstdCompressor( + self._compressor: ZstdCompressor | None = ZstdCompressor( level_or_option=level_or_option, zstd_dict=zstd_dict ) self._pos = 0 @@ -606,7 +633,7 @@ def __init__( # File object if isinstance(filename, (str, bytes, PathLike)): - self._fp = open(filename, mode + "b") # noqa: SIM115 + self._fp = cast("BinaryIO", open(filename, mode + "b")) # noqa: SIM115 self._close_fp = True elif hasattr(filename, "read") or hasattr(filename, "write"): self._fp = filename @@ -619,19 +646,19 @@ def __init__( raw = _SeekableDecompressReader( self._fp, zstd_dict=zstd_dict, - option=level_or_option, + option=cast("_Option", level_or_option), # checked earlier on read_size=read_size, ) self._buffer = io.BufferedReader(raw) elif mode == "a": if self._fp.seekable(): - self._fp.seek(self._seek_table.get_full_c_size()) + self._fp.seek(self._seek_table.get_full_c_size()) # type: ignore[union-attr] # Necessary if the current table has many (0, 0) entries self._fp.truncate() else: # Add the seek table frame - self._seek_table.append_entry(self._seek_table.seek_frame_size, 0) + self._seek_table.append_entry(self._seek_table.seek_frame_size, 0) # type: ignore[union-attr] # Emit a warning warnings.warn( ( @@ -641,31 +668,31 @@ def __init__( "zstd skippable frame) at the end of the file " "can't be overwritten. Each time open such file " "in append mode, it will waste some storage " - f"space. {self._seek_table.seek_frame_size} bytes " + f"space. {self._seek_table.seek_frame_size} bytes " # type: ignore[union-attr] "were wasted this time." ), RuntimeWarning, 2, ) - def _reset_frame_sizes(self): + def _reset_frame_sizes(self) -> None: self._current_c_size = 0 self._current_d_size = 0 self._left_d_size = self._max_frame_content_size - def _check_not_closed(self): + def _check_not_closed(self) -> None: if self.closed: raise ValueError("I/O operation on closed file") - def _check_can_read(self): + def _check_can_read(self) -> None: if not self.readable(): raise io.UnsupportedOperation("File not open for reading") - def _check_can_write(self): + def _check_can_write(self) -> None: if not self.writable(): raise io.UnsupportedOperation("File not open for writing") - def close(self): + def close(self) -> None: """Flush and close the file. May be called more than once without error. Once the file is @@ -679,11 +706,11 @@ def close(self): try: if self._mode == _MODE_READ: if getattr(self, "_buffer", None): - self._buffer.close() + self._buffer.close() # type: ignore[union-attr] self._buffer = None elif self._mode == _MODE_WRITE: self.flush(self.FLUSH_FRAME) - self._seek_table.write_seek_table(self._fp) + self._seek_table.write_seek_table(self._fp) # type: ignore[union-attr] self._compressor = None finally: self._mode = _MODE_CLOSED @@ -695,7 +722,7 @@ def close(self): self._fp = None self._close_fp = False - def write(self, data): + def write(self, data: Buffer) -> int: """Write a bytes-like object to the file. Returns the number of uncompressed bytes written, which is @@ -715,10 +742,10 @@ def write(self, data): write_size = min(nbytes, self._left_d_size) # Compress & write - compressed = self._compressor.compress( + compressed = self._compressor.compress( # type: ignore[union-attr] byte_view[pos : pos + write_size] ) - output_size = self._fp.write(compressed) + output_size = self._fp.write(compressed) # type: ignore[union-attr] self._pos += write_size pos += write_size @@ -738,7 +765,7 @@ def write(self, data): return pos - def flush(self, mode=ZstdCompressor.FLUSH_BLOCK): + def flush(self, mode: Literal[1, 2] = ZstdCompressor.FLUSH_BLOCK) -> None: """Flush remaining data to the underlying stream. The mode argument can be ZstdFile.FLUSH_BLOCK, ZstdFile.FLUSH_FRAME. @@ -761,12 +788,12 @@ def flush(self, mode=ZstdCompressor.FLUSH_BLOCK): "ZstdFile.FLUSH_BLOCK" ) - if self._compressor.last_mode != mode: + if self._compressor.last_mode != mode: # type: ignore[union-attr] # Flush zstd block/frame, and write. - compressed = self._compressor.flush(mode) - output_size = self._fp.write(compressed) + compressed = self._compressor.flush(mode) # type: ignore[union-attr] + output_size = self._fp.write(compressed) # type: ignore[union-attr] if hasattr(self._fp, "flush"): - self._fp.flush() + self._fp.flush() # type: ignore[union-attr] # Cumulate self._current_c_size += output_size @@ -775,10 +802,10 @@ def flush(self, mode=ZstdCompressor.FLUSH_BLOCK): if mode == self.FLUSH_FRAME and self._current_c_size != 0: # Add an entry to seek table - self._seek_table.append_entry(self._current_c_size, self._current_d_size) + self._seek_table.append_entry(self._current_c_size, self._current_d_size) # type: ignore[union-attr] self._reset_frame_sizes() - def read(self, size=-1): + def read(self, size: int | None = -1) -> bytes: """Read up to size uncompressed bytes from the file. If size is negative or omitted, read until EOF is reached. @@ -787,9 +814,9 @@ def read(self, size=-1): if size is None: size = -1 self._check_can_read() - return self._buffer.read(size) + return self._buffer.read(size) # type: ignore[union-attr] - def read1(self, size=-1): + def read1(self, size: int = -1) -> bytes: """Read up to size uncompressed bytes, while trying to avoid making multiple reads from the underlying stream. Reads up to a buffer's worth of data if size is negative. @@ -799,26 +826,26 @@ def read1(self, size=-1): self._check_can_read() if size < 0: size = io.DEFAULT_BUFFER_SIZE - return self._buffer.read1(size) + return self._buffer.read1(size) # type: ignore[union-attr] - def readinto(self, b): + def readinto(self, b: Buffer) -> int: """Read bytes into b. Returns the number of bytes read (0 for EOF). """ self._check_can_read() - return self._buffer.readinto(b) + return self._buffer.readinto(b) # type: ignore[union-attr] - def readinto1(self, b): + def readinto1(self, b: Buffer) -> int: """Read bytes into b, while trying to avoid making multiple reads from the underlying stream. Returns the number of bytes read (0 for EOF). """ self._check_can_read() - return self._buffer.readinto1(b) + return self._buffer.readinto1(b) # type: ignore[union-attr] - def readline(self, size=-1): + def readline(self, size: int | None = -1) -> bytes: """Read a line of uncompressed bytes from the file. The terminating newline (if present) is retained. If size is @@ -826,9 +853,9 @@ def readline(self, size=-1): case the line may be incomplete). Returns b'' if already at EOF. """ self._check_can_read() - return self._buffer.readline(size) + return self._buffer.readline(size) # type: ignore[union-attr] - def seek(self, offset, whence=io.SEEK_SET): + def seek(self, offset: int, whence: int = io.SEEK_SET) -> int: """Change the file position. The new position is specified by offset, relative to the @@ -844,68 +871,68 @@ def seek(self, offset, whence=io.SEEK_SET): this operation may be extremely slow. """ self._check_can_read() - return self._buffer.seek(offset, whence) + return self._buffer.seek(offset, whence) # type: ignore[union-attr] - def peek(self, size=-1): + def peek(self, size: int = -1) -> bytes: """Return buffered data without advancing the file position. Always returns at least one byte of data, unless at EOF. The exact number of bytes returned is unspecified. """ self._check_can_read() - return self._buffer.peek(size) + return self._buffer.peek(size) # type: ignore[union-attr] - def __iter__(self): + def __iter__(self) -> Self: self._check_can_read() return self - def __next__(self): + def __next__(self) -> bytes: self._check_can_read() - if ret := self._buffer.readline(): + if ret := self._buffer.readline(): # type: ignore[union-attr] return ret raise StopIteration - def tell(self): + def tell(self) -> int: """Return the current file position.""" self._check_not_closed() if self._mode == _MODE_READ: - return self._buffer.tell() + return self._buffer.tell() # type: ignore[union-attr] if self._mode == _MODE_WRITE: return self._pos raise RuntimeError # impossible code path - def fileno(self): + def fileno(self) -> int: """Return the file descriptor for the underlying file.""" self._check_not_closed() - return self._fp.fileno() + return self._fp.fileno() # type: ignore[union-attr] @property - def name(self): + def name(self) -> str: """Return the file name for the underlying file.""" self._check_not_closed() - return self._fp.name + return self._fp.name # type: ignore[union-attr] @property - def closed(self): + def closed(self) -> bool: """True if this file is closed.""" return self._mode == _MODE_CLOSED - def writable(self): + def writable(self) -> bool: """Return whether the file was opened for writing.""" self._check_not_closed() return self._mode == _MODE_WRITE - def readable(self): + def readable(self) -> bool: """Return whether the file was opened for reading.""" self._check_not_closed() return self._mode == _MODE_READ - def seekable(self): + def seekable(self) -> bool: """Return whether the file supports seeking.""" - return self.readable() and self._buffer.seekable() + return self.readable() and self._buffer.seekable() # type: ignore[union-attr] @property - def seek_table_info(self): + def seek_table_info(self) -> tuple[int, int, int] | None: """A tuple: (frames_number, compressed_size, decompressed_size) 1, Frames_number and compressed_size don't count the seek table frame (a zstd skippable frame at the end of the file). @@ -914,14 +941,13 @@ def seek_table_info(self): 3, If the SeekableZstdFile object is closed, it's None. """ if self._mode == _MODE_WRITE: - return self._seek_table.get_info() + return self._seek_table.get_info() # type: ignore[union-attr] if self._mode == _MODE_READ: - return self._buffer.raw.get_seek_table_info() - # Closed + return self._buffer.raw.get_seek_table_info() # type: ignore[union-attr] return None @staticmethod - def is_seekable_format_file(filename): + def is_seekable_format_file(filename: _StrOrBytesPath | BinaryIO) -> bool: """Check if a file is Zstandard Seekable Format file or 0-size file. It parses the seek table at the end of the file, returns True if no @@ -932,7 +958,7 @@ def is_seekable_format_file(filename): """ # Check argument if isinstance(filename, (str, bytes, PathLike)): - fp = open(filename, "rb") # noqa: SIM115 + fp: BinaryIO = open(filename, "rb") # noqa: SIM115 is_file_path = True elif ( hasattr(filename, "readable") diff --git a/tests/test_seekable.py b/tests/test_seekable.py index bbc90a5..3dd56b1 100644 --- a/tests/test_seekable.py +++ b/tests/test_seekable.py @@ -51,7 +51,7 @@ def _check_deprecated(testcase): class SeekTableCase(unittest.TestCase): def create_table(self, sizes_lst, read_mode=True): - table = _SeekTable(read_mode) + table = _SeekTable(read_mode=read_mode) for item in sizes_lst: table.append_entry(*item) return table