From 823af24a2cd21f9802e62e4474ca115fbcf49a19 Mon Sep 17 00:00:00 2001 From: Benedikt Seidl Date: Sat, 12 Jul 2025 21:48:27 +0200 Subject: [PATCH 1/3] Improve typing: add References to Repository pygit2 was removed from the python typeshed, but the built in types of pygit2 are not complete. This change tries to improve this situation, by fixing type errors that occur when activating typing for the tests. So first the test functions (for the public API) in test/test_refs.py were adapted, in order to activate mypy type checking. Then _pygit2.pyi was adapted to fix those mypy type checking errors. The skeleton was provided by: stubgen -m pygit2.references -o /tmp/pygit2 The following command mypy test returned the following summary before this change: Found 289 errors in 10 files (checked 50 source files) and after this change: Found 289 errors in 10 files (checked 50 source files) --- pygit2/_pygit2.pyi | 35 +++++++++++++++++++++++++++++------ test/test_refs.py | 46 +++++++++++++++++++++++----------------------- 2 files changed, 52 insertions(+), 29 deletions(-) diff --git a/pygit2/_pygit2.pyi b/pygit2/_pygit2.pyi index 488beabd..1ef3e18a 100644 --- a/pygit2/_pygit2.pyi +++ b/pygit2/_pygit2.pyi @@ -1,4 +1,4 @@ -from typing import Iterator, Literal, Optional, overload +from typing import Iterator, Literal, Optional, overload, Type from io import IOBase from . import Index from .enums import ( @@ -20,6 +20,8 @@ from .enums import ( SortMode, ) +from .repository import BaseRepository + GIT_OBJ_BLOB = Literal[3] GIT_OBJ_COMMIT = Literal[1] GIT_OBJ_TAG = Literal[4] @@ -73,15 +75,15 @@ class Reference: def delete(self) -> None: ... def log(self) -> Iterator[RefLogEntry]: ... @overload - def peel(self, type: 'Literal[GIT_OBJ_COMMIT]') -> 'Commit': ... + def peel(self, type: 'Literal[GIT_OBJ_COMMIT] | Type[Commit]') -> 'Commit': ... @overload - def peel(self, type: 'Literal[GIT_OBJ_TREE]') -> 'Tree': ... + def peel(self, type: 'Literal[GIT_OBJ_TREE] | Type[Tree]') -> 'Tree': ... @overload - def peel(self, type: 'Literal[GIT_OBJ_TAG]') -> 'Tag': ... + def peel(self, type: 'Literal[GIT_OBJ_TAG] | Type[Tag]') -> 'Tag': ... @overload - def peel(self, type: 'Literal[GIT_OBJ_BLOB]') -> 'Blob': ... + def peel(self, type: 'Literal[GIT_OBJ_BLOB] | Type[Blob]') -> 'Blob': ... @overload - def peel(self, type: 'None') -> 'Commit|Tree|Blob': ... + def peel(self, type: 'None' = None) -> 'Commit|Tree|Tag|Blob': ... def rename(self, new_name: str) -> None: ... def resolve(self) -> Reference: ... def set_target(self, target: _OidArg, message: str = ...) -> None: ... @@ -329,6 +331,21 @@ class RefdbBackend: class RefdbFsBackend(RefdbBackend): def __init__(self, *args, **kwargs) -> None: ... +class References: + def __init__(self, repository: BaseRepository) -> None: ... + def __getitem__(self, name: str) -> Reference: ... + def get(self, key: str) -> Reference: ... + def __iter__(self) -> Iterator[str]: ... + def iterator( + self, references_return_type: ReferenceFilter = ... + ) -> Iterator[Reference]: ... + def create(self, name: str, target: _OidArg, force: bool = False) -> Reference: ... + def delete(self, name: str) -> None: ... + def __contains__(self, name: str) -> bool: ... + @property + def objects(self) -> list[Reference]: ... + def compress(self) -> None: ... + class Repository: _pointer: bytes default_signature: Signature @@ -342,10 +359,12 @@ class Repository: path: str refdb: Refdb workdir: str + references: References def __init__(self, *args, **kwargs) -> None: ... def TreeBuilder(self, src: Tree | _OidArg = ...) -> TreeBuilder: ... def _disown(self, *args, **kwargs) -> None: ... def _from_c(self, *args, **kwargs) -> None: ... + def __getitem__(self, key: str | bytes | Oid | Reference) -> Commit: ... def add_worktree(self, name: str, path: str, ref: Reference = ...) -> Worktree: ... def applies( self, @@ -394,6 +413,9 @@ class Repository: ref: str = 'refs/notes/commits', force: bool = False, ) -> Oid: ... + def create_reference( + self, name: str, target: _OidArg, force: bool = False + ) -> Reference: ... def create_reference_direct( self, name: str, target: _OidArg, force: bool, message: Optional[str] = None ) -> Reference: ... @@ -443,6 +465,7 @@ class Repository: def revparse(self, revspec: str) -> RevSpec: ... def revparse_ext(self, revision: str) -> tuple[Object, Reference]: ... def revparse_single(self, revision: str) -> Object: ... + def set_ident(self, name: str, email: str) -> None: ... def set_odb(self, odb: Odb) -> None: ... def set_refdb(self, refdb: Refdb) -> None: ... def status( diff --git a/test/test_refs.py b/test/test_refs.py index cddfa038..50dfa96e 100644 --- a/test/test_refs.py +++ b/test/test_refs.py @@ -29,7 +29,7 @@ import pytest -from pygit2 import Commit, Signature, Tree, reference_is_valid_name +from pygit2 import Commit, Signature, Tree, reference_is_valid_name, Repository from pygit2 import AlreadyExistsError, GitError, InvalidSpecError from pygit2.enums import ReferenceType @@ -45,7 +45,7 @@ def test_refs_list_objects(testrepo): ] -def test_refs_list(testrepo): +def test_refs_list(testrepo: Repository) -> None: # Without argument assert sorted(testrepo.references) == ['refs/heads/i18n', 'refs/heads/master'] @@ -58,13 +58,13 @@ def test_refs_list(testrepo): ] -def test_head(testrepo): +def test_head(testrepo: Repository) -> None: head = testrepo.head assert LAST_COMMIT == testrepo[head.target].id assert LAST_COMMIT == testrepo[head.raw_target].id -def test_refs_getitem(testrepo): +def test_refs_getitem(testrepo: Repository) -> None: refname = 'refs/foo' # Raise KeyError ? with pytest.raises(KeyError): @@ -78,37 +78,37 @@ def test_refs_getitem(testrepo): assert reference.name == 'refs/heads/master' -def test_refs_get_sha(testrepo): +def test_refs_get_sha(testrepo: Repository) -> None: reference = testrepo.references['refs/heads/master'] assert reference.target == LAST_COMMIT -def test_refs_set_sha(testrepo): +def test_refs_set_sha(testrepo: Repository) -> None: NEW_COMMIT = '5ebeeebb320790caf276b9fc8b24546d63316533' reference = testrepo.references.get('refs/heads/master') reference.set_target(NEW_COMMIT) assert reference.target == NEW_COMMIT -def test_refs_set_sha_prefix(testrepo): +def test_refs_set_sha_prefix(testrepo: Repository) -> None: NEW_COMMIT = '5ebeeebb320790caf276b9fc8b24546d63316533' reference = testrepo.references.get('refs/heads/master') reference.set_target(NEW_COMMIT[0:6]) assert reference.target == NEW_COMMIT -def test_refs_get_type(testrepo): +def test_refs_get_type(testrepo: Repository) -> None: reference = testrepo.references.get('refs/heads/master') assert reference.type == ReferenceType.DIRECT -def test_refs_get_target(testrepo): +def test_refs_get_target(testrepo: Repository) -> None: reference = testrepo.references.get('HEAD') assert reference.target == 'refs/heads/master' assert reference.raw_target == b'refs/heads/master' -def test_refs_set_target(testrepo): +def test_refs_set_target(testrepo: Repository) -> None: reference = testrepo.references.get('HEAD') assert reference.target == 'refs/heads/master' assert reference.raw_target == b'refs/heads/master' @@ -117,14 +117,14 @@ def test_refs_set_target(testrepo): assert reference.raw_target == b'refs/heads/i18n' -def test_refs_get_shorthand(testrepo): +def test_refs_get_shorthand(testrepo: Repository) -> None: reference = testrepo.references.get('refs/heads/master') assert reference.shorthand == 'master' reference = testrepo.references.create('refs/remotes/origin/master', LAST_COMMIT) assert reference.shorthand == 'origin/master' -def test_refs_set_target_with_message(testrepo): +def test_refs_set_target_with_message(testrepo: Repository) -> None: reference = testrepo.references.get('HEAD') assert reference.target == 'refs/heads/master' assert reference.raw_target == b'refs/heads/master' @@ -139,7 +139,7 @@ def test_refs_set_target_with_message(testrepo): assert first.committer == sig -def test_refs_delete(testrepo): +def test_refs_delete(testrepo: Repository) -> None: # We add a tag as a new reference that points to "origin/master" reference = testrepo.references.create('refs/tags/version1', LAST_COMMIT) assert 'refs/tags/version1' in testrepo.references @@ -163,7 +163,7 @@ def test_refs_delete(testrepo): reference.rename('refs/tags/version2') -def test_refs_rename(testrepo): +def test_refs_rename(testrepo: Repository) -> None: # We add a tag as a new reference that points to "origin/master" reference = testrepo.references.create('refs/tags/version1', LAST_COMMIT) assert reference.name == 'refs/tags/version1' @@ -177,7 +177,7 @@ def test_refs_rename(testrepo): reference.rename('b1') -# def test_reload(testrepo): +# def test_reload(testrepo: Repository) -> None: # name = 'refs/tags/version1' # ref = testrepo.create_reference(name, "refs/heads/master", symbolic=True) # ref2 = testrepo.lookup_reference(name) @@ -187,7 +187,7 @@ def test_refs_rename(testrepo): # with pytest.raises(GitError): getattr(ref2, 'name') -def test_refs_resolve(testrepo): +def test_refs_resolve(testrepo: Repository) -> None: reference = testrepo.references.get('HEAD') assert reference.type == ReferenceType.SYMBOLIC reference = reference.resolve() @@ -195,13 +195,13 @@ def test_refs_resolve(testrepo): assert reference.target == LAST_COMMIT -def test_refs_resolve_identity(testrepo): +def test_refs_resolve_identity(testrepo: Repository) -> None: head = testrepo.references.get('HEAD') ref = head.resolve() assert ref.resolve() is ref -def test_refs_create(testrepo): +def test_refs_create(testrepo: Repository) -> None: # We add a tag as a new reference that points to "origin/master" reference = testrepo.references.create('refs/tags/version1', LAST_COMMIT) refs = testrepo.references @@ -220,7 +220,7 @@ def test_refs_create(testrepo): assert reference.target == LAST_COMMIT -def test_refs_create_symbolic(testrepo): +def test_refs_create_symbolic(testrepo: Repository) -> None: # We add a tag as a new symbolic reference that always points to # "refs/heads/master" reference = testrepo.references.create('refs/tags/beta', 'refs/heads/master') @@ -241,11 +241,11 @@ def test_refs_create_symbolic(testrepo): assert reference.raw_target == b'refs/heads/master' -# def test_packall_references(testrepo): +# def test_packall_references(testrepo: Repository) -> None: # testrepo.packall_references() -def test_refs_peel(testrepo): +def test_refs_peel(testrepo: Repository) -> None: ref = testrepo.references.get('refs/heads/master') assert testrepo[ref.target].id == ref.peel().id assert testrepo[ref.raw_target].id == ref.peel().id @@ -254,7 +254,7 @@ def test_refs_peel(testrepo): assert commit.tree.id == ref.peel(Tree).id -def test_refs_equality(testrepo): +def test_refs_equality(testrepo: Repository) -> None: ref1 = testrepo.references.get('refs/heads/master') ref2 = testrepo.references.get('refs/heads/master') ref3 = testrepo.references.get('refs/heads/i18n') @@ -267,7 +267,7 @@ def test_refs_equality(testrepo): assert not ref1 == ref3 -def test_refs_compress(testrepo): +def test_refs_compress(testrepo: Repository) -> None: packed_refs_file = Path(testrepo.path) / 'packed-refs' assert not packed_refs_file.exists() old_refs = [(ref.name, ref.target) for ref in testrepo.references.objects] From 4b3f85d402012dc826c63bd7d8b4a2e806eb8695 Mon Sep 17 00:00:00 2001 From: Benedikt Seidl Date: Sun, 13 Jul 2025 13:41:48 +0200 Subject: [PATCH 2/3] Improve typing: add RemotesCollection to Repository The following command mypy test reports Found 284 errors in 9 files (checked 50 source files) after this change. --- pygit2/_pygit2.pyi | 47 ++++++++++++++++++++++- pygit2/callbacks.py | 51 +++++++++++++++---------- pygit2/credentials.py | 14 +++++-- pygit2/remotes.py | 12 +++++- test/test_remote.py | 86 ++++++++++++++++++++++++++----------------- 5 files changed, 151 insertions(+), 59 deletions(-) diff --git a/pygit2/_pygit2.pyi b/pygit2/_pygit2.pyi index 1ef3e18a..0d4e8f31 100644 --- a/pygit2/_pygit2.pyi +++ b/pygit2/_pygit2.pyi @@ -1,4 +1,4 @@ -from typing import Iterator, Literal, Optional, overload, Type +from typing import Iterator, Literal, Optional, overload, Type, TypedDict from io import IOBase from . import Index from .enums import ( @@ -19,8 +19,10 @@ from .enums import ( ResetMode, SortMode, ) +from collections.abc import Generator from .repository import BaseRepository +from .remotes import Remote GIT_OBJ_BLOB = Literal[3] GIT_OBJ_COMMIT = Literal[1] @@ -346,6 +348,48 @@ class References: def objects(self) -> list[Reference]: ... def compress(self) -> None: ... +_Proxy = None | Literal[True] | str + +class _StrArray: + # incomplete + count: int + +class ProxyOpts: + # incomplete + type: object + url: str + +class PushOptions: + version: int + pb_parallelism: int + callbacks: object # TODO + proxy_opts: ProxyOpts + follow_redirects: object # TODO + custom_headers: _StrArray + remote_push_options: _StrArray + +class _LsRemotesDict(TypedDict): + local: bool + loid: Oid | None + name: str | None + symref_target: str | None + oid: Oid + +class RemoteCollection: + def __init__(self, repo: BaseRepository) -> None: ... + def __len__(self) -> int: ... + def __iter__(self): ... + def __getitem__(self, name: str | int) -> Remote: ... + def names(self) -> Generator[str, None, None]: ... + def create(self, name: str, url: str, fetch: str | None = None) -> Remote: ... + def create_anonymous(self, url: str) -> Remote: ... + def rename(self, name: str, new_name: str) -> list[str]: ... + def delete(self, name: str) -> None: ... + def set_url(self, name: str, url: str) -> None: ... + def set_push_url(self, name: str, url: str) -> None: ... + def add_fetch(self, name: str, refspec: str) -> None: ... + def add_push(self, name: str, refspec: str) -> None: ... + class Repository: _pointer: bytes default_signature: Signature @@ -360,6 +404,7 @@ class Repository: refdb: Refdb workdir: str references: References + remotes: RemoteCollection def __init__(self, *args, **kwargs) -> None: ... def TreeBuilder(self, src: Tree | _OidArg = ...) -> TreeBuilder: ... def _disown(self, *args, **kwargs) -> None: ... diff --git a/pygit2/callbacks.py b/pygit2/callbacks.py index cd7d1c50..c0c3249f 100644 --- a/pygit2/callbacks.py +++ b/pygit2/callbacks.py @@ -65,7 +65,7 @@ # Standard Library from contextlib import contextmanager from functools import wraps -from typing import Optional, Union +from typing import Optional, Union, TYPE_CHECKING, Callable, Generator # pygit2 from ._pygit2 import Oid, DiffFile @@ -73,8 +73,13 @@ from .errors import check_error, Passthrough from .ffi import ffi, C from .utils import maybe_string, to_bytes, ptr_to_bytes, StrArray +from .credentials import Username, UserPass, Keypair +_Credentials = Username | UserPass | Keypair +if TYPE_CHECKING: + from .remotes import TransferProgress + from ._pygit2 import ProxyOpts, PushOptions # # The payload is the way to pass information from the pygit2 API, through # libgit2, to the Python callbacks. And back. @@ -82,7 +87,7 @@ class Payload: - def __init__(self, **kw: object): + def __init__(self, **kw: object) -> None: for key, value in kw.items(): setattr(self, key, value) self._stored_exception = None @@ -113,12 +118,18 @@ class RemoteCallbacks(Payload): RemoteCallbacks(certificate=certificate). """ - def __init__(self, credentials=None, certificate_check=None): + push_options: 'PushOptions' + + def __init__( + self, + credentials: _Credentials | None = None, + certificate_check: Callable[[None, bool, bytes], bool] | None = None, + ) -> None: super().__init__() if credentials is not None: - self.credentials = credentials + self.credentials = credentials # type: ignore[method-assign, assignment] if certificate_check is not None: - self.certificate_check = certificate_check + self.certificate_check = certificate_check # type: ignore[method-assign, assignment] def sideband_progress(self, string: str) -> None: """ @@ -136,7 +147,7 @@ def credentials( url: str, username_from_url: Union[str, None], allowed_types: CredentialType, - ): + ) -> _Credentials: """ Credentials callback. If the remote server requires authentication, this function will be called and its return value used for @@ -159,7 +170,7 @@ def credentials( """ raise Passthrough - def certificate_check(self, certificate: None, valid: bool, host: str) -> bool: + def certificate_check(self, certificate: None, valid: bool, host: bytes) -> bool: """ Certificate callback. Override with your own function to determine whether to accept the server's certificate. @@ -181,7 +192,7 @@ def certificate_check(self, certificate: None, valid: bool, host: str) -> bool: raise Passthrough - def transfer_progress(self, stats): + def transfer_progress(self, stats: 'TransferProgress') -> None: """ During the download of new data, this will be regularly called with the indexer's progress. @@ -196,7 +207,7 @@ def transfer_progress(self, stats): def push_transfer_progress( self, objects_pushed: int, total_objects: int, bytes_pushed: int - ): + ) -> None: """ During the upload portion of a push, this will be regularly called with progress information. @@ -207,7 +218,7 @@ def push_transfer_progress( Override with your own function to report push transfer progress. """ - def update_tips(self, refname, old, new): + def update_tips(self, refname: str, old: Oid, new: Oid) -> None: """ Update tips callback. Override with your own function to report reference updates. @@ -224,7 +235,7 @@ def update_tips(self, refname, old, new): The reference's new value. """ - def push_update_reference(self, refname, message): + def push_update_reference(self, refname: str, message: str) -> None: """ Push update reference callback. Override with your own function to report the remote's acceptance or rejection of reference updates. @@ -244,7 +255,7 @@ class CheckoutCallbacks(Payload): in your class, which you can then pass to checkout operations. """ - def __init__(self): + def __init__(self) -> None: super().__init__() def checkout_notify_flags(self) -> CheckoutNotify: @@ -275,7 +286,7 @@ def checkout_notify( baseline: Optional[DiffFile], target: Optional[DiffFile], workdir: Optional[DiffFile], - ): + ) -> None: """ Checkout will invoke an optional notification callback for certain cases - you pick which ones via `checkout_notify_flags`. @@ -290,7 +301,9 @@ def checkout_notify( """ pass - def checkout_progress(self, path: str, completed_steps: int, total_steps: int): + def checkout_progress( + self, path: str, completed_steps: int, total_steps: int + ) -> None: """ Optional callback to notify the consumer of checkout progress. """ @@ -304,7 +317,7 @@ class StashApplyCallbacks(CheckoutCallbacks): in your class, which you can then pass to stash apply or pop operations. """ - def stash_apply_progress(self, progress: StashApplyProgress): + def stash_apply_progress(self, progress: StashApplyProgress) -> None: """ Stash application progress notification function. @@ -373,9 +386,9 @@ def git_fetch_options(payload, opts=None): @contextmanager def git_proxy_options( payload: object, - opts: object | None = None, + opts: Optional['ProxyOpts'] = None, proxy: None | bool | str = None, -): +) -> Generator['ProxyOpts', None, None]: if opts is None: opts = ffi.new('git_proxy_options *') C.git_proxy_options_init(opts, C.GIT_PROXY_OPTIONS_VERSION) @@ -386,8 +399,8 @@ def git_proxy_options( elif type(proxy) is str: opts.type = C.GIT_PROXY_SPECIFIED # Keep url in memory, otherwise memory is freed and bad things happen - payload.__proxy_url = ffi.new('char[]', to_bytes(proxy)) - opts.url = payload.__proxy_url + payload.__proxy_url = ffi.new('char[]', to_bytes(proxy)) # type: ignore[attr-defined, no-untyped-call] + opts.url = payload.__proxy_url # type: ignore[attr-defined] else: raise TypeError('Proxy must be None, True, or a string') yield opts diff --git a/pygit2/credentials.py b/pygit2/credentials.py index 9a09db76..2b307f94 100644 --- a/pygit2/credentials.py +++ b/pygit2/credentials.py @@ -49,7 +49,7 @@ def credential_type(self) -> CredentialType: return CredentialType.USERNAME @property - def credential_tuple(self): + def credential_tuple(self) -> tuple[str]: return (self._username,) def __call__( @@ -74,7 +74,7 @@ def credential_type(self) -> CredentialType: return CredentialType.USERPASS_PLAINTEXT @property - def credential_tuple(self): + def credential_tuple(self) -> tuple[str, str]: return (self._username, self._password) def __call__( @@ -107,7 +107,11 @@ class Keypair: """ def __init__( - self, username: str, pubkey: str | Path, privkey: str | Path, passphrase: str + self, + username: str, + pubkey: str | Path | None, + privkey: str | Path | None, + passphrase: str | None, ): self._username = username self._pubkey = pubkey @@ -119,7 +123,9 @@ def credential_type(self) -> CredentialType: return CredentialType.SSH_KEY @property - def credential_tuple(self): + def credential_tuple( + self, + ) -> tuple[str, str | Path | None, str | Path | None, str | None]: return (self._username, self._pubkey, self._privkey, self._passphrase) def __call__( diff --git a/pygit2/remotes.py b/pygit2/remotes.py index 1451dbf6..7e91f6ae 100644 --- a/pygit2/remotes.py +++ b/pygit2/remotes.py @@ -24,7 +24,7 @@ # Boston, MA 02110-1301, USA. from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any # Import from pygit2 from ._pygit2 import Oid @@ -49,7 +49,15 @@ class TransferProgress: """Progress downloading and indexing data during a fetch.""" - def __init__(self, tp): + total_objects: int + indexed_objects: int + received_objects: int + local_objects: int + total_deltas: int + indexed_deltas: int + received_bytes: int + + def __init__(self, tp: Any) -> None: self.total_objects = tp.total_objects """Total number of objects to download""" diff --git a/test/test_remote.py b/test/test_remote.py index 1d414479..e3cc2146 100644 --- a/test/test_remote.py +++ b/test/test_remote.py @@ -24,10 +24,14 @@ # Boston, MA 02110-1301, USA. import sys +from pathlib import Path +from collections.abc import Generator import pytest import pygit2 +from pygit2 import Repository, Remote +from pygit2.remotes import TransferProgress from . import utils @@ -43,7 +47,7 @@ ORIGIN_REFSPEC = '+refs/heads/*:refs/remotes/origin/*' -def test_remote_create(testrepo): +def test_remote_create(testrepo: Repository) -> None: name = 'upstream' url = 'https://github.com/libgit2/pygit2.git' @@ -58,7 +62,7 @@ def test_remote_create(testrepo): testrepo.remotes.create(*(name, url)) -def test_remote_create_with_refspec(testrepo): +def test_remote_create_with_refspec(testrepo: Repository) -> None: name = 'upstream' url = 'https://github.com/libgit2/pygit2.git' fetch = '+refs/*:refs/*' @@ -72,7 +76,7 @@ def test_remote_create_with_refspec(testrepo): assert remote.push_url is None -def test_remote_create_anonymous(testrepo): +def test_remote_create_anonymous(testrepo: Repository) -> None: url = 'https://github.com/libgit2/pygit2.git' remote = testrepo.remotes.create_anonymous(url) @@ -83,7 +87,7 @@ def test_remote_create_anonymous(testrepo): assert [] == remote.push_refspecs -def test_remote_delete(testrepo): +def test_remote_delete(testrepo: Repository) -> None: name = 'upstream' url = 'https://github.com/libgit2/pygit2.git' @@ -96,7 +100,7 @@ def test_remote_delete(testrepo): assert 1 == len(testrepo.remotes) -def test_remote_rename(testrepo): +def test_remote_rename(testrepo: Repository) -> None: remote = testrepo.remotes[0] assert REMOTE_NAME == remote.name @@ -107,10 +111,10 @@ def test_remote_rename(testrepo): with pytest.raises(ValueError): testrepo.remotes.rename('', '') with pytest.raises(ValueError): - testrepo.remotes.rename(None, None) + testrepo.remotes.rename(None, None) # type: ignore -def test_remote_set_url(testrepo): +def test_remote_set_url(testrepo: Repository) -> None: remote = testrepo.remotes['origin'] assert REMOTE_URL == remote.url @@ -129,7 +133,7 @@ def test_remote_set_url(testrepo): testrepo.remotes.set_push_url('origin', '') -def test_refspec(testrepo): +def test_refspec(testrepo: Repository) -> None: remote = testrepo.remotes['origin'] assert remote.refspec_count == 1 @@ -169,13 +173,13 @@ def test_refspec(testrepo): testrepo.remotes.add_push('origin', '+refs/test/*:refs/test/remotes/*') with pytest.raises(TypeError): - testrepo.remotes.add_fetch(['+refs/*:refs/*', 5]) + testrepo.remotes.add_fetch(['+refs/*:refs/*', 5]) # type: ignore remote = testrepo.remotes['origin'] assert ['+refs/test/*:refs/test/remotes/*'] == remote.push_refspecs -def test_remote_list(testrepo): +def test_remote_list(testrepo: Repository) -> None: assert 1 == len(testrepo.remotes) remote = testrepo.remotes[0] assert REMOTE_NAME == remote.name @@ -189,7 +193,7 @@ def test_remote_list(testrepo): @utils.requires_network -def test_ls_remotes(testrepo): +def test_ls_remotes(testrepo: Repository) -> None: assert 1 == len(testrepo.remotes) remote = testrepo.remotes[0] @@ -200,7 +204,7 @@ def test_ls_remotes(testrepo): assert next(iter(r for r in refs if r['name'] == 'refs/tags/v0.28.2')) -def test_remote_collection(testrepo): +def test_remote_collection(testrepo: Repository) -> None: remote = testrepo.remotes['origin'] assert REMOTE_NAME == remote.name assert REMOTE_URL == remote.url @@ -216,7 +220,7 @@ def test_remote_collection(testrepo): @utils.requires_refcount -def test_remote_refcount(testrepo): +def test_remote_refcount(testrepo: Repository) -> None: start = sys.getrefcount(testrepo) remote = testrepo.remotes[0] del remote @@ -224,7 +228,7 @@ def test_remote_refcount(testrepo): assert start == end -def test_fetch(emptyrepo): +def test_fetch(emptyrepo: Repository) -> None: remote = emptyrepo.remotes[0] stats = remote.fetch() assert stats.received_bytes > 2700 @@ -234,7 +238,7 @@ def test_fetch(emptyrepo): @utils.requires_network -def test_fetch_depth_zero(testrepo): +def test_fetch_depth_zero(testrepo: Repository) -> None: remote = testrepo.remotes[0] stats = remote.fetch(REMOTE_FETCHTEST_FETCHSPECS, depth=0) assert stats.indexed_objects == REMOTE_REPO_FETCH_ALL_OBJECTS @@ -242,16 +246,16 @@ def test_fetch_depth_zero(testrepo): @utils.requires_network -def test_fetch_depth_one(testrepo): +def test_fetch_depth_one(testrepo: Repository) -> None: remote = testrepo.remotes[0] stats = remote.fetch(REMOTE_FETCHTEST_FETCHSPECS, depth=1) assert stats.indexed_objects == REMOTE_REPO_FETCH_HEAD_COMMIT_OBJECTS assert stats.received_objects == REMOTE_REPO_FETCH_HEAD_COMMIT_OBJECTS -def test_transfer_progress(emptyrepo): +def test_transfer_progress(emptyrepo: Repository) -> None: class MyCallbacks(pygit2.RemoteCallbacks): - def transfer_progress(self, stats): + def transfer_progress(self, stats: TransferProgress) -> None: self.tp = stats callbacks = MyCallbacks() @@ -262,7 +266,7 @@ def transfer_progress(self, stats): assert stats.received_objects == callbacks.tp.received_objects -def test_update_tips(emptyrepo): +def test_update_tips(emptyrepo: Repository) -> None: remote = emptyrepo.remotes[0] tips = [ ( @@ -292,14 +296,16 @@ def update_tips(self, name, old, new): @utils.requires_network -def test_ls_remotes_certificate_check(): +def test_ls_remotes_certificate_check() -> None: url = 'https://github.com/pygit2/empty.git' class MyCallbacks(pygit2.RemoteCallbacks): - def __init__(self): + def __init__(self) -> None: self.i = 0 - def certificate_check(self, certificate, valid, host): + def certificate_check( + self, certificate: None, valid: bool, host: str | bytes + ) -> bool: self.i += 1 assert certificate is None @@ -322,13 +328,13 @@ def certificate_check(self, certificate, valid, host): @pytest.fixture -def origin(tmp_path): +def origin(tmp_path: Path) -> Generator[Repository, None, None]: with utils.TemporaryRepository('barerepo.zip', tmp_path) as path: yield pygit2.Repository(path) @pytest.fixture -def clone(tmp_path): +def clone(tmp_path: Path) -> Generator[Repository, None, None]: clone = tmp_path / 'clone' clone.mkdir() with utils.TemporaryRepository('barerepo.zip', clone) as path: @@ -340,7 +346,9 @@ def remote(origin, clone): yield clone.remotes.create('origin', origin.path) -def test_push_fast_forward_commits_to_remote_succeeds(origin, clone, remote): +def test_push_fast_forward_commits_to_remote_succeeds( + origin: Repository, clone: Repository, remote: Remote +) -> None: tip = clone[clone.head.target] oid = clone.create_commit( 'refs/heads/master', @@ -354,14 +362,18 @@ def test_push_fast_forward_commits_to_remote_succeeds(origin, clone, remote): assert origin[origin.head.target].id == oid -def test_push_when_up_to_date_succeeds(origin, clone, remote): +def test_push_when_up_to_date_succeeds( + origin: Repository, clone: Repository, remote: Remote +) -> None: remote.push(['refs/heads/master']) origin_tip = origin[origin.head.target].id clone_tip = clone[clone.head.target].id assert origin_tip == clone_tip -def test_push_transfer_progress(origin, clone, remote): +def test_push_transfer_progress( + origin: Repository, clone: Repository, remote: Remote +) -> None: tip = clone[clone.head.target] new_tip_id = clone.create_commit( 'refs/heads/master', @@ -377,7 +389,9 @@ def test_push_transfer_progress(origin, clone, remote): # on the local filesystem, as is the case in this unit test. (When pushing # to a remote over the network, the value is correct.) class MyCallbacks(pygit2.RemoteCallbacks): - def push_transfer_progress(self, objects_pushed, total_objects, bytes_pushed): + def push_transfer_progress( + self, objects_pushed: int, total_objects: int, bytes_pushed: int + ) -> None: self.objects_pushed = objects_pushed self.total_objects = total_objects @@ -390,7 +404,9 @@ def push_transfer_progress(self, objects_pushed, total_objects, bytes_pushed): assert origin.branches['master'].target == new_tip_id -def test_push_interrupted_from_callbacks(origin, clone, remote): +def test_push_interrupted_from_callbacks( + origin: Repository, clone: Repository, remote: Remote +) -> None: tip = clone[clone.head.target] clone.create_commit( 'refs/heads/master', @@ -402,7 +418,9 @@ def test_push_interrupted_from_callbacks(origin, clone, remote): ) class MyCallbacks(pygit2.RemoteCallbacks): - def push_transfer_progress(self, objects_pushed, total_objects, bytes_pushed): + def push_transfer_progress( + self, objects_pushed: int, total_objects: int, bytes_pushed: int + ) -> None: raise InterruptedError('retreat! retreat!') assert origin.branches['master'].target == tip.id @@ -414,7 +432,9 @@ def push_transfer_progress(self, objects_pushed, total_objects, bytes_pushed): assert origin.branches['master'].target == tip.id -def test_push_non_fast_forward_commits_to_remote_fails(origin, clone, remote): +def test_push_non_fast_forward_commits_to_remote_fails( + origin: Repository, clone: Repository, remote: Remote +) -> None: tip = origin[origin.head.target] origin.create_commit( 'refs/heads/master', @@ -438,7 +458,7 @@ def test_push_non_fast_forward_commits_to_remote_fails(origin, clone, remote): remote.push(['refs/heads/master']) -def test_push_options(origin, clone, remote): +def test_push_options(origin: Repository, clone: Repository, remote: Remote) -> None: from pygit2 import RemoteCallbacks callbacks = RemoteCallbacks() @@ -468,7 +488,7 @@ def test_push_options(origin, clone, remote): # strings pointed to by remote_push_options.strings[] are already freed -def test_push_threads(origin, clone, remote): +def test_push_threads(origin: Repository, clone: Repository, remote: Remote) -> None: from pygit2 import RemoteCallbacks callbacks = RemoteCallbacks() From 375f238ae0ad4b7a6471394d285454ce3795836c Mon Sep 17 00:00:00 2001 From: Benedikt Seidl Date: Sun, 13 Jul 2025 14:00:52 +0200 Subject: [PATCH 3/3] Improve typing: add Branches to Repository The following command mypy test reports Found 280 errors in 8 files (checked 50 source files) after this change. --- pygit2/_pygit2.pyi | 20 +++++++++++++++++++- test/test_branch.py | 27 ++++++++++++++------------- 2 files changed, 33 insertions(+), 14 deletions(-) diff --git a/pygit2/_pygit2.pyi b/pygit2/_pygit2.pyi index 0d4e8f31..2cf8e39e 100644 --- a/pygit2/_pygit2.pyi +++ b/pygit2/_pygit2.pyi @@ -126,7 +126,7 @@ class Branch(Reference): def delete(self) -> None: ... def is_checked_out(self) -> bool: ... def is_head(self) -> bool: ... - def rename(self, name: str, force: bool = False) -> None: ... + def rename(self, name: str, force: bool = False) -> 'Branch': ... # type: ignore[override] class Commit(Object): author: Signature @@ -390,6 +390,23 @@ class RemoteCollection: def add_fetch(self, name: str, refspec: str) -> None: ... def add_push(self, name: str, refspec: str) -> None: ... +class Branches: + local: 'Branches' + remote: 'Branches' + def __init__( + self, + repository: BaseRepository, + flag: BranchType = ..., + commit: Commit | _OidArg | None = None, + ) -> None: ... + def __getitem__(self, name: str) -> Branch: ... + def get(self, key: str) -> Branch: ... + def __iter__(self) -> Iterator[str]: ... + def create(self, name: str, commit: Commit, force: bool = False) -> Branch: ... + def delete(self, name: str) -> None: ... + def with_commit(self, commit: Commit | _OidArg | None) -> 'Branches': ... + def __contains__(self, name: _OidArg) -> bool: ... + class Repository: _pointer: bytes default_signature: Signature @@ -405,6 +422,7 @@ class Repository: workdir: str references: References remotes: RemoteCollection + branches: Branches def __init__(self, *args, **kwargs) -> None: ... def TreeBuilder(self, src: Tree | _OidArg = ...) -> TreeBuilder: ... def _disown(self, *args, **kwargs) -> None: ... diff --git a/test/test_branch.py b/test/test_branch.py index bfc944c0..1128a1b1 100644 --- a/test/test_branch.py +++ b/test/test_branch.py @@ -29,6 +29,7 @@ import pytest import os from pygit2.enums import BranchType +from pygit2 import Repository LAST_COMMIT = '2be5719152d4f82c7302b1c0932d8e5f0a4a0e98' @@ -38,7 +39,7 @@ SHARED_COMMIT = '4ec4389a8068641da2d6578db0419484972284c8' -def test_branches_getitem(testrepo): +def test_branches_getitem(testrepo: Repository) -> None: branch = testrepo.branches['master'] assert branch.target == LAST_COMMIT @@ -49,12 +50,12 @@ def test_branches_getitem(testrepo): testrepo.branches['not-exists'] -def test_branches(testrepo): +def test_branches(testrepo: Repository) -> None: branches = sorted(testrepo.branches) assert branches == ['i18n', 'master'] -def test_branches_create(testrepo): +def test_branches_create(testrepo: Repository) -> None: commit = testrepo[LAST_COMMIT] reference = testrepo.branches.create('version1', commit) assert 'version1' in testrepo.branches @@ -70,27 +71,27 @@ def test_branches_create(testrepo): assert reference.target == LAST_COMMIT -def test_branches_delete(testrepo): +def test_branches_delete(testrepo: Repository) -> None: testrepo.branches.delete('i18n') assert testrepo.branches.get('i18n') is None -def test_branches_delete_error(testrepo): +def test_branches_delete_error(testrepo: Repository) -> None: with pytest.raises(pygit2.GitError): testrepo.branches.delete('master') -def test_branches_is_head(testrepo): +def test_branches_is_head(testrepo: Repository) -> None: branch = testrepo.branches.get('master') assert branch.is_head() -def test_branches_is_not_head(testrepo): +def test_branches_is_not_head(testrepo: Repository) -> None: branch = testrepo.branches.get('i18n') assert not branch.is_head() -def test_branches_rename(testrepo): +def test_branches_rename(testrepo: Repository) -> None: new_branch = testrepo.branches['i18n'].rename('new-branch') assert new_branch.target == I18N_LAST_COMMIT @@ -98,25 +99,25 @@ def test_branches_rename(testrepo): assert new_branch_2.target == I18N_LAST_COMMIT -def test_branches_rename_error(testrepo): +def test_branches_rename_error(testrepo: Repository) -> None: original_branch = testrepo.branches.get('i18n') with pytest.raises(ValueError): original_branch.rename('master') -def test_branches_rename_force(testrepo): +def test_branches_rename_force(testrepo: Repository) -> None: original_branch = testrepo.branches.get('master') new_branch = original_branch.rename('i18n', True) assert new_branch.target == LAST_COMMIT -def test_branches_rename_invalid(testrepo): +def test_branches_rename_invalid(testrepo: Repository) -> None: original_branch = testrepo.branches.get('i18n') with pytest.raises(ValueError): original_branch.rename('abc@{123') -def test_branches_name(testrepo): +def test_branches_name(testrepo: Repository) -> None: branch = testrepo.branches.get('master') assert branch.branch_name == 'master' assert branch.name == 'refs/heads/master' @@ -128,7 +129,7 @@ def test_branches_name(testrepo): assert branch.raw_branch_name == branch.branch_name.encode('utf-8') -def test_branches_with_commit(testrepo): +def test_branches_with_commit(testrepo: Repository) -> None: branches = testrepo.branches.with_commit(EXCLUSIVE_MASTER_COMMIT) assert sorted(branches) == ['master'] assert branches.get('i18n') is None