Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 32 additions & 26 deletions comtypes/hints.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -105,31 +105,35 @@ class _GetSetNormalProperty(Generic[_T_Inst, _R_Get, _T_SetVal]):
fset: Callable[[_T_Inst, _T_SetVal], Any]

@overload
def __get__(self, instance: None, owner: type[_T_Inst]) -> Self: ...
def __get__(self, instance: None, owner: type[_T_Inst], /) -> Self: ...
@overload
def __get__(self, instance: _T_Inst, owner: Optional[type[_T_Inst]]) -> _R_Get: ...
def __set__(self, instance: _T_Inst, value: _T_SetVal) -> None: ...
def __get__(
self, instance: _T_Inst, owner: Optional[type[_T_Inst]], /
) -> _R_Get: ...
def __set__(self, instance: _T_Inst, value: _T_SetVal, /) -> None: ...

class _GetOnlyNormalProperty(Generic[_T_Inst, _R_Get]):
fget: Callable[[_T_Inst], Any]

@overload
def __get__(self, instance: None, owner: type[_T_Inst]) -> Self: ...
def __get__(self, instance: None, owner: type[_T_Inst], /) -> Self: ...
@overload
def __get__(self, instance: _T_Inst, owner: Optional[type[_T_Inst]]) -> _R_Get: ...
def __set__(self, instance: _T_Inst, value: Any) -> NoReturn: ...
def __get__(
self, instance: _T_Inst, owner: Optional[type[_T_Inst]], /
) -> _R_Get: ...
def __set__(self, instance: _T_Inst, value: Any, /) -> NoReturn: ...

class _SetOnlyNormalProperty(Generic[_T_Inst, _T_SetVal]):
fget: Callable[[_T_Inst], Any]
fset: Callable[[_T_Inst, _T_SetVal], Any]

@overload
def __get__(self, instance: None, owner: type[_T_Inst]) -> Self: ...
def __get__(self, instance: None, owner: type[_T_Inst], /) -> Self: ...
@overload
def __get__(
self, instance: _T_Inst, owner: Optional[type[_T_Inst]]
self, instance: _T_Inst, owner: Optional[type[_T_Inst]], /
) -> NoReturn: ...
def __set__(self, instance: _T_Inst, value: _T_SetVal) -> None: ...
def __set__(self, instance: _T_Inst, value: _T_SetVal, /) -> None: ...

@overload
def normal_property(
Expand All @@ -149,9 +153,9 @@ class _GetSetBoundNamedProperty(Generic[_T_Inst, _P_Get, _R_Get, _P_Set]):
fget: Callable[Concatenate[_T_Inst, _P_Get], _R_Get]
fset: Callable[Concatenate[_T_Inst, _P_Set], Any]
__doc__: Optional[str]
def __getitem__(self, index: Any) -> _R_Get: ...
def __getitem__(self, index: Any, /) -> _R_Get: ...
def __call__(self, *args: _P_Get.args, **kwargs: _P_Get.kwargs) -> _R_Get: ...
def __setitem__(self, index: Any, value: Any) -> None: ...
def __setitem__(self, index: Any, value: Any, /) -> None: ...
def __iter__(self) -> NoReturn: ...

class _GetSetNamedProperty(Generic[_T_Inst, _P_Get, _R_Get, _P_Set]):
Expand All @@ -161,20 +165,20 @@ class _GetSetNamedProperty(Generic[_T_Inst, _P_Get, _R_Get, _P_Set]):
__doc__: Optional[str]

@overload
def __get__(self, instance: None, owner: type[_T_Inst]) -> Self: ...
def __get__(self, instance: None, owner: type[_T_Inst], /) -> Self: ...
@overload
def __get__(
self, instance: _T_Inst, owner: Optional[type[_T_Inst]]
self, instance: _T_Inst, owner: Optional[type[_T_Inst]], /
) -> _GetSetBoundNamedProperty[_T_Inst, _P_Get, _R_Get, _P_Set]: ...
def __set__(self, instance: _T_Inst, value: Any) -> NoReturn: ...
def __set__(self, instance: _T_Inst, value: Any, /) -> NoReturn: ...

class _GetOnlyBoundNamedProperty(Generic[_T_Inst, _P_Get, _R_Get]):
name: str
fget: Callable[Concatenate[_T_Inst, _P_Get], _R_Get]
__doc__: Optional[str]
def __getitem__(self, index: Any) -> _R_Get: ...
def __getitem__(self, index: Any, /) -> _R_Get: ...
def __call__(self, *args: _P_Get.args, **kwargs: _P_Get.kwargs) -> _R_Get: ...
def __setitem__(self, index: Any, value: Any) -> NoReturn: ...
def __setitem__(self, index: Any, value: Any, /) -> NoReturn: ...
def __iter__(self) -> NoReturn: ...

class _GetOnlyNamedProperty(Generic[_T_Inst, _P_Get, _R_Get]):
Expand All @@ -183,20 +187,20 @@ class _GetOnlyNamedProperty(Generic[_T_Inst, _P_Get, _R_Get]):
__doc__: Optional[str]

@overload
def __get__(self, instance: None, owner: type[_T_Inst]) -> Self: ...
def __get__(self, instance: None, owner: type[_T_Inst], /) -> Self: ...
@overload
def __get__(
self, instance: _T_Inst, owner: Optional[type[_T_Inst]]
self, instance: _T_Inst, owner: Optional[type[_T_Inst]], /
) -> _GetOnlyBoundNamedProperty[_T_Inst, _P_Get, _R_Get]: ...
def __set__(self, instance: _T_Inst, value: Any) -> NoReturn: ...
def __set__(self, instance: _T_Inst, value: Any, /) -> NoReturn: ...

class _SetOnlyBoundNamedProperty(Generic[_T_Inst, _P_Set]):
name: str
fset: Callable[Concatenate[_T_Inst, _P_Set], Any]
__doc__: Optional[str]
def __getitem__(self, index: Any) -> NoReturn: ...
def __getitem__(self, index: Any, /) -> NoReturn: ...
def __call__(self, *args: Any, **kwargs: Any) -> NoReturn: ...
def __setitem__(self, index: Any, value: Any) -> None: ...
def __setitem__(self, index: Any, value: Any, /) -> None: ...
def __iter__(self) -> NoReturn: ...

class _SetOnlyNamedProperty(Generic[_T_Inst, _P_Set]):
Expand All @@ -205,12 +209,12 @@ class _SetOnlyNamedProperty(Generic[_T_Inst, _P_Set]):
__doc__: Optional[str]

@overload
def __get__(self, instance: None, owner: type[_T_Inst]) -> Self: ...
def __get__(self, instance: None, owner: type[_T_Inst], /) -> Self: ...
@overload
def __get__(
self, instance: _T_Inst, owner: Optional[type[_T_Inst]]
self, instance: _T_Inst, owner: Optional[type[_T_Inst]], /
) -> _SetOnlyBoundNamedProperty[_T_Inst, _P_Set]: ...
def __set__(self, instance: _T_Inst, value: Any) -> NoReturn: ...
def __set__(self, instance: _T_Inst, value: Any, /) -> NoReturn: ...

@overload
def named_property(
Expand All @@ -231,9 +235,11 @@ def named_property(

class _Descriptor(Protocol[_T_Inst, _R_Get]):
@overload
def __get__(self, instance: None, owner: type[_T_Inst]) -> Self: ...
def __get__(self, instance: None, owner: type[_T_Inst], /) -> Self: ...
@overload
def __get__(self, instance: _T_Inst, owner: Optional[type[_T_Inst]]) -> _R_Get: ...
def __get__(
self, instance: _T_Inst, owner: Optional[type[_T_Inst]], /
) -> _R_Get: ...

# `__len__` for objects with `Count`
@overload
Expand Down
39 changes: 33 additions & 6 deletions comtypes/test/test_typeannotator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import ast
import unittest

from comtypes.tools import typedesc
Expand Down Expand Up @@ -72,18 +73,29 @@ def test_disp_interface(self):
" pass # avoid using a keyword for def except(self) -> hints.Incomplete: ...\n" # noqa
" def bacon(self, *args: hints.Any, **kwargs: hints.Any) -> hints.Incomplete: ...\n" # noqa
" def _get_spam(self, arg1: hints.Incomplete = ..., /) -> hints.Incomplete: ...\n" # noqa
" def _set_spam(self, arg1: hints.Incomplete = ..., /, **kwargs: hints.Any) -> hints.Incomplete: ...\n" # noqa
" def _set_spam(self, arg1: hints.Incomplete = ..., /, *args: hints.Unpack[tuple[hints.Incomplete]]) -> hints.Incomplete: ...\n" # noqa
" spam = hints.named_property('spam', _get_spam, _set_spam)\n"
" pass # avoid using a keyword for def raise(self, foo: hints.Incomplete, bar: hints.Incomplete = ..., /) -> hints.Incomplete: ...\n" # noqa
" def _get_def(self, arg1: hints.Incomplete = ..., /) -> hints.Incomplete: ...\n" # noqa
" def _set_def(self, arg1: hints.Incomplete = ..., /, **kwargs: hints.Any) -> hints.Incomplete: ...\n" # noqa
" def _set_def(self, arg1: hints.Incomplete = ..., /, *args: hints.Unpack[tuple[hints.Incomplete]]) -> hints.Incomplete: ...\n" # noqa
" pass # avoid using a keyword for def = hints.named_property('def', _get_def, _set_def)\n" # noqa
" def egg(self) -> hints.Incomplete: ..." # noqa
)
self.assertEqual(
expected, typeannotator.DispInterfaceMembersAnnotator(itf).generate()
)

def test_valid_syntax_dispmethods(self):
itf = self._create_typedesc_disp_interface()
definition = "\n".join(
(
"class ISomeInterface(IDispatch):",
" if TYPE_CHECKING:",
f"{typeannotator.DispInterfaceMembersAnnotator(itf).generate()}",
)
)
ast.parse(definition, mode="exec")

def _create_typedesc_com_interface(self) -> typedesc.ComInterface:
guid = "{00000000-0000-0000-0000-000000000000}"
itf = typedesc.ComInterface(
Expand All @@ -98,6 +110,7 @@ def _create_typedesc_com_interface(self) -> typedesc.ComInterface:
put_ham = typedesc.ComMethod(
4, 1610678270, "ham", HRESULT_type, ["propput"], None
)
put_ham.add_argument(VARIANT_type, "arg1", ["in"], None)
bacon = typedesc.ComMethod(1, 1610678271, "bacon", HRESULT_type, [], None)
bacon.add_argument(VARIANT_type, "foo", ["in"], None)
bacon.add_argument(VARIANT_type, "or", ["in"], None)
Expand All @@ -107,9 +120,12 @@ def _create_typedesc_com_interface(self) -> typedesc.ComInterface:
get_class = typedesc.ComMethod(
2, 1610678273, "class", HRESULT_type, ["propget"], None
)
get_class.add_argument(VARIANT_type, "arg1", ["in"], None)
put_class = typedesc.ComMethod(
4, 1610678273, "class", HRESULT_type, ["propput"], None
)
put_class.add_argument(VARIANT_type, "arg1", ["in", "optional"], None)
put_class.add_argument(VARIANT_type, "arg2", ["in"], None)
pass_ = typedesc.ComMethod(1, 1610678274, "pass", HRESULT_type, [], None)
pass_.add_argument(VARIANT_type, "foo", ["in"], None)
pass_.add_argument(VARIANT_type, "bar", ["in", "optional"], None)
Expand All @@ -123,16 +139,27 @@ def test_com_interface(self):
" def _get_spam(self) -> hints.Hresult: ...\n"
" spam = hints.normal_property(_get_spam)\n"
" def _get_ham(self) -> hints.Hresult: ...\n"
" def _set_ham(self) -> hints.Hresult: ...\n"
" def _set_ham(self, arg1: hints.Incomplete) -> hints.Hresult: ...\n"
" ham = hints.normal_property(_get_ham, _set_ham)\n"
" def bacon(self, *args: hints.Any, **kwargs: hints.Any) -> hints.Hresult: ...\n" # noqa
" def _get_global(self) -> hints.Hresult: ...\n"
" pass # avoid using a keyword for global = hints.normal_property(_get_global)\n" # noqa
" def _get_class(self) -> hints.Hresult: ...\n"
" def _set_class(self) -> hints.Hresult: ...\n"
" pass # avoid using a keyword for class = hints.normal_property(_get_class, _set_class)\n" # noqa
" def _get_class(self, arg1: hints.Incomplete) -> hints.Hresult: ...\n"
" def _set_class(self, arg1: hints.Incomplete = ..., /, *args: hints.Unpack[tuple[hints.Incomplete]]) -> hints.Hresult: ...\n" # noqa
" pass # avoid using a keyword for class = hints.named_property('class', _get_class, _set_class)\n" # noqa
" pass # avoid using a keyword for def pass(self, foo: hints.Incomplete, bar: hints.Incomplete = ...) -> hints.Hresult: ..." # noqa
)
self.assertEqual(
expected, typeannotator.ComInterfaceMembersAnnotator(itf).generate()
)

def test_valid_syntax_commethods(self):
itf = self._create_typedesc_com_interface()
definition = "\n".join(
(
"class ISomeInterface(IUnknown):",
" if TYPE_CHECKING:",
f"{typeannotator.ComInterfaceMembersAnnotator(itf).generate()}",
)
)
ast.parse(definition, mode="exec")
35 changes: 25 additions & 10 deletions comtypes/tools/codegenerator/typeannotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,25 +222,41 @@ def _to_outtype(typ: Any) -> str:
return "hints.Incomplete"


def _generate_trailing_params(specs: Sequence[tuple[Any, str, Optional[Any]]]) -> str:
"""Generates a type hint for variadic positional arguments.

This is for cases where required parameters follow optional ones, which is
not directly representable in Python's syntax. This pattern typically
occurs in COM `propput` or `propputref` methods that take multiple
arguments, corresponding to assignments like `obj.prop[a, b] = value`.
"""
params = f"tuple[{', '.join(('hints.Incomplete',) * len(specs))}]"
return f"*args: hints.Unpack[{params}]"


class ComMethodAnnotator(_MethodAnnotator[typedesc.ComMethod]):
def _iter_outarg_specs(self) -> Iterator[tuple[Any, str]]:
for typ, name, flags, _ in self.method.arguments:
if "out" in flags:
yield typ, name

def getvalue(self, name: str) -> str:
specs = self.inarg_specs
inargs = []
has_optional = False
for _, argname, default in self.inarg_specs:
for i, (_, argname, default) in enumerate(specs):
if keyword.iskeyword(argname):
inargs = ["*args: hints.Any", "**kwargs: hints.Any"]
break
if default is None:
if has_optional:
# probably propput or propputref
# Required parameters are positioned after optional ones.
# This likely indicates a named propput or named propputref
# assignment in the form of `obj.prop[...] = ...`.
# HACK: Something that goes into this conditional branch
# should be a special callback.
inargs.append("**kwargs: hints.Any")
inargs.append("/")
inargs.append(_generate_trailing_params(specs[i:]))
break
inargs.append(f"{argname}: hints.Incomplete")
else:
Expand Down Expand Up @@ -275,28 +291,27 @@ def generate(self) -> str:

class DispMethodAnnotator(_MethodAnnotator[typedesc.DispMethod]):
def getvalue(self, name: str) -> str:
specs = self.inarg_specs
inargs = []
has_optional = False
# NOTE: Since named parameters are not yet implemented, all arguments
# for the dispmethod (called via `Invoke`) are marked as
# positional-only parameters, introduced in PEP570.
# See also `automation.IDispatch.Invoke`.
# See https://github.com/enthought/comtypes/issues/371
for _, argname, default in self.inarg_specs:
for i, (_, argname, default) in enumerate(specs):
if keyword.iskeyword(argname):
inargs = ["*args: hints.Any", "**kwargs: hints.Any"]
break
if default is None:
if has_optional:
# Required parameter follows an optional one.
# probably propput or propputref
# TODO: After named parameters are supported,
# the positional-only parameter markers
# will be removed.
# Required parameters are positioned after optional ones.
# This likely indicates a named propput or named propputref
# assignment in the form of `obj.prop[...] = ...`.
inargs.append("/")
# HACK: Something that goes into this conditional branch
# should be a special callback.
inargs.append("**kwargs: hints.Any")
inargs.append(_generate_trailing_params(specs[i:]))
break
inargs.append(f"{argname}: hints.Incomplete")
else:
Expand Down