Skip to content

Commit 5ba0eaa

Browse files
authored
Improve CallableType join in simple cases (#18406)
Fixes #17479 , although as you can see in the test case the logic still remains far from perfect
1 parent dfc1fda commit 5ba0eaa

File tree

2 files changed

+77
-17
lines changed

2 files changed

+77
-17
lines changed

mypy/join.py

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -377,24 +377,34 @@ def visit_instance(self, t: Instance) -> ProperType:
377377
return self.default(self.s)
378378

379379
def visit_callable_type(self, t: CallableType) -> ProperType:
380-
if isinstance(self.s, CallableType) and is_similar_callables(t, self.s):
381-
if is_equivalent(t, self.s):
382-
return combine_similar_callables(t, self.s)
383-
result = join_similar_callables(t, self.s)
384-
# We set the from_type_type flag to suppress error when a collection of
385-
# concrete class objects gets inferred as their common abstract superclass.
386-
if not (
387-
(t.is_type_obj() and t.type_object().is_abstract)
388-
or (self.s.is_type_obj() and self.s.type_object().is_abstract)
389-
):
390-
result.from_type_type = True
391-
if any(
392-
isinstance(tp, (NoneType, UninhabitedType))
393-
for tp in get_proper_types(result.arg_types)
394-
):
395-
# We don't want to return unusable Callable, attempt fallback instead.
380+
if isinstance(self.s, CallableType):
381+
if is_similar_callables(t, self.s):
382+
if is_equivalent(t, self.s):
383+
return combine_similar_callables(t, self.s)
384+
result = join_similar_callables(t, self.s)
385+
if any(
386+
isinstance(tp, (NoneType, UninhabitedType))
387+
for tp in get_proper_types(result.arg_types)
388+
):
389+
# We don't want to return unusable Callable, attempt fallback instead.
390+
return join_types(t.fallback, self.s)
391+
# We set the from_type_type flag to suppress error when a collection of
392+
# concrete class objects gets inferred as their common abstract superclass.
393+
if not (
394+
(t.is_type_obj() and t.type_object().is_abstract)
395+
or (self.s.is_type_obj() and self.s.type_object().is_abstract)
396+
):
397+
result.from_type_type = True
398+
return result
399+
else:
400+
s2, t2 = self.s, t
401+
if t2.is_var_arg:
402+
s2, t2 = t2, s2
403+
if is_subtype(s2, t2):
404+
return t2.copy_modified()
405+
elif is_subtype(t2, s2):
406+
return s2.copy_modified()
396407
return join_types(t.fallback, self.s)
397-
return result
398408
elif isinstance(self.s, Overloaded):
399409
# Switch the order of arguments to that we'll get to visit_overloaded.
400410
return join_types(t, self.s)

test-data/unit/check-functions.test

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3512,6 +3512,56 @@ class Qux(Bar):
35123512
pass
35133513
[builtins fixtures/tuple.pyi]
35143514

3515+
[case testCallableJoinWithDefaults]
3516+
from typing import Callable, TypeVar
3517+
3518+
T = TypeVar("T")
3519+
3520+
def join(t1: T, t2: T) -> T: ...
3521+
3522+
def f1() -> None: ...
3523+
def f2(i: int = 0) -> None: ...
3524+
def f3(i: str = "") -> None: ...
3525+
3526+
reveal_type(join(f1, f2)) # N: Revealed type is "def ()"
3527+
reveal_type(join(f1, f3)) # N: Revealed type is "def ()"
3528+
reveal_type(join(f2, f3)) # N: Revealed type is "builtins.function" # TODO: this could be better
3529+
[builtins fixtures/tuple.pyi]
3530+
3531+
[case testCallableJoinWithDefaultsMultiple]
3532+
from typing import TypeVar
3533+
T = TypeVar("T")
3534+
def join(t1: T, t2: T, t3: T) -> T: ...
3535+
3536+
def f_1(common, a=None): ...
3537+
def f_any(*_, **__): ...
3538+
def f_3(common, b=None, x=None): ...
3539+
3540+
fdict = {
3541+
"f_1": f_1,
3542+
"f_any": f_any,
3543+
"f_3": f_3,
3544+
}
3545+
reveal_type(fdict) # N: Revealed type is "builtins.dict[builtins.str, def (common: Any, a: Any =) -> Any]"
3546+
3547+
reveal_type(join(f_1, f_any, f_3)) # N: Revealed type is "def (common: Any, a: Any =) -> Any"
3548+
3549+
[builtins fixtures/tuple.pyi]
3550+
3551+
[case testCallableJoinWithType]
3552+
from __future__ import annotations
3553+
class Exc: ...
3554+
class AttributeErr(Exc):
3555+
def __init__(self, *args: object) -> None: ...
3556+
class FnfErr(Exc): ...
3557+
3558+
x = [
3559+
FnfErr,
3560+
AttributeErr,
3561+
]
3562+
reveal_type(x) # N: Revealed type is "builtins.list[builtins.type]"
3563+
[builtins fixtures/type.pyi]
3564+
35153565
[case testDistinctFormatting]
35163566
from typing import Awaitable, Callable, ParamSpec
35173567

0 commit comments

Comments
 (0)