Skip to content

Commit 66f83ea

Browse files
authored
Refactor equality and identity narrowing for clarity (#20595)
This is a follow up to #20492 that is pure refactoring. I separated it out to make that change easier to review. We inline `narrow_identity_equality_comparison`, improve comments, etc. There is some further refactoring later in my commit stack too.
1 parent ad30be8 commit 66f83ea

File tree

3 files changed

+126
-169
lines changed

3 files changed

+126
-169
lines changed

mypy/checker.py

Lines changed: 105 additions & 137 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,8 @@
183183
fixup_partial_type,
184184
function_type,
185185
is_literal_type_like,
186-
is_singleton_type,
186+
is_singleton_equality_type,
187+
is_singleton_identity_type,
187188
make_simplified_union,
188189
true_only,
189190
try_expanding_sum_type_to_union,
@@ -6676,30 +6677,57 @@ def narrow_type_by_equality(
66766677
expr_indices: list[int],
66776678
narrowable_indices: AbstractSet[int],
66786679
) -> tuple[TypeMap, TypeMap]:
6679-
"""Calculate type maps for '==', '!=', 'is' or 'is not' expression, ignoring `type(x)` checks."""
6680-
# is_valid_target:
6681-
# Controls which types we're allowed to narrow exprs to. Note that
6682-
# we cannot use 'is_literal_type_like' in both cases since doing
6683-
# 'x = 10000 + 1; x is 10001' is not always True in all Python
6684-
# implementations.
6685-
#
6686-
# coerce_only_in_literal_context:
6687-
# If true, coerce types into literal types only if one or more of
6688-
# the provided exprs contains an explicit Literal type. This could
6689-
# technically be set to any arbitrary value, but it seems being liberal
6690-
# with narrowing when using 'is' and conservative when using '==' seems
6691-
# to break the least amount of real-world code.
6692-
#
6680+
"""
6681+
Calculate type maps for '==', '!=', 'is' or 'is not' expression, ignoring `type(x)` checks.
6682+
6683+
The 'operands' and 'operand_types' lists should be the full list of operands used
6684+
in the overall comparison expression. The 'chain_indices' list is the list of indices
6685+
actually used within this identity comparison chain.
6686+
6687+
So if we have the expression:
6688+
6689+
a <= b is c is d <= e
6690+
6691+
...then 'operands' and 'operand_types' would be lists of length 5 and 'chain_indices'
6692+
would be the list [1, 2, 3].
6693+
6694+
The 'narrowable_operand_indices' parameter is the set of all indices we are allowed
6695+
to refine the types of: that is, all operands that will potentially be a part of
6696+
the output TypeMaps.
6697+
6698+
"""
66936699
# should_narrow_by_identity_equality:
6694-
# Set to 'false' only if the user defines custom __eq__ or __ne__ methods
6695-
# that could cause identity-based narrowing to produce invalid results.
6700+
# If operator is "==" or "!=", we cannot narrow if we detect the presence of a user defined
6701+
# custom __eq__ or __ne__ method
6702+
should_narrow_by_identity_equality: bool
6703+
6704+
# is_target_for_value_narrowing:
6705+
# If the operator returns True when compared to this target, do we narrow in else branch?
6706+
# E.g. if operator is "==", then:
6707+
# - is_target_for_value_narrowing(str) == False
6708+
# - is_target_for_value_narrowing(Literal["asdf"]) == True
6709+
is_target_for_value_narrowing: Callable[[ProperType], bool]
6710+
6711+
# should_coerce_literals:
6712+
# Ideally, we should always attempt to have this set to True. Unfortunately, for now,
6713+
# performing this coercion can sometimes result in overly aggressive narrowing when taking
6714+
# in the context of other type checker behaviour.
6715+
should_coerce_literals: bool
6716+
66966717
if operator in {"is", "is not"}:
6697-
is_valid_target: Callable[[Type], bool] = is_singleton_type
6698-
coerce_only_in_literal_context = False
6718+
is_target_for_value_narrowing = is_singleton_identity_type
6719+
should_coerce_literals = True
66996720
should_narrow_by_identity_equality = True
6721+
67006722
elif operator in {"==", "!="}:
6701-
is_valid_target = is_singleton_value
6702-
coerce_only_in_literal_context = True
6723+
is_target_for_value_narrowing = is_singleton_equality_type
6724+
6725+
should_coerce_literals = False
6726+
for i in expr_indices:
6727+
typ = get_proper_type(operand_types[i])
6728+
if is_literal_type_like(typ) or (isinstance(typ, Instance) and typ.type.is_enum):
6729+
should_coerce_literals = True
6730+
break
67036731

67046732
expr_types = [operand_types[i] for i in expr_indices]
67056733
should_narrow_by_identity_equality = not any(
@@ -6708,21 +6736,63 @@ def narrow_type_by_equality(
67086736
else:
67096737
raise AssertionError
67106738

6711-
if should_narrow_by_identity_equality:
6712-
return self.narrow_identity_equality_comparison(
6713-
operands,
6714-
operand_types,
6715-
expr_indices,
6716-
narrowable_indices,
6717-
is_valid_target,
6718-
coerce_only_in_literal_context,
6739+
if not should_narrow_by_identity_equality:
6740+
# This is a bit of a legacy code path that might be a little unsound since it ignores
6741+
# custom __eq__. We should see if we can get rid of it in favour of `return {}, {}`
6742+
return self.refine_away_none_in_comparison(
6743+
operands, operand_types, expr_indices, narrowable_indices
67196744
)
67206745

6721-
# This is a bit of a legacy code path that might be a little unsound since it ignores
6722-
# custom __eq__. We should see if we can get rid of it.
6723-
return self.refine_away_none_in_comparison(
6724-
operands, operand_types, expr_indices, narrowable_indices
6725-
)
6746+
value_targets = []
6747+
type_targets = []
6748+
for i in expr_indices:
6749+
expr_type = operand_types[i]
6750+
if should_coerce_literals:
6751+
# TODO: doing this prevents narrowing a single-member Enum to literal
6752+
# of its member, because we expand it here and then refuse to add equal
6753+
# types to typemaps. As a result, `x: Foo; x == Foo.A` does not narrow
6754+
# `x` to `Literal[Foo.A]` iff `Foo` has exactly one member.
6755+
# See testMatchEnumSingleChoice
6756+
expr_type = coerce_to_literal(expr_type)
6757+
if is_target_for_value_narrowing(get_proper_type(expr_type)):
6758+
value_targets.append((i, TypeRange(expr_type, is_upper_bound=False)))
6759+
else:
6760+
type_targets.append((i, TypeRange(expr_type, is_upper_bound=False)))
6761+
6762+
partial_type_maps = []
6763+
6764+
if value_targets:
6765+
for i in expr_indices:
6766+
if i not in narrowable_indices:
6767+
continue
6768+
for j, target in value_targets:
6769+
if i == j:
6770+
continue
6771+
expr_type = coerce_to_literal(operand_types[i])
6772+
expr_type = try_expanding_sum_type_to_union(expr_type, None)
6773+
if_map, else_map = conditional_types_to_typemaps(
6774+
operands[i], *conditional_types(expr_type, [target])
6775+
)
6776+
partial_type_maps.append((if_map, else_map))
6777+
6778+
if type_targets:
6779+
for i in expr_indices:
6780+
if i not in narrowable_indices:
6781+
continue
6782+
for j, target in type_targets:
6783+
if i == j:
6784+
continue
6785+
expr_type = operand_types[i]
6786+
if_map, else_map = conditional_types_to_typemaps(
6787+
operands[i], *conditional_types(expr_type, [target])
6788+
)
6789+
if if_map:
6790+
else_map = {} # this is the big difference compared to the above
6791+
partial_type_maps.append((if_map, else_map))
6792+
6793+
# We will not have duplicate entries in our type maps if we only have two operands,
6794+
# so we can skip running meets on the intersections
6795+
return reduce_conditional_maps(partial_type_maps, use_meet=len(operands) > 2)
67266796

67276797
def propagate_up_typemap_info(self, new_types: TypeMap) -> TypeMap:
67286798
"""Attempts refining parent expressions of any MemberExpr or IndexExprs in new_types.
@@ -6905,103 +6975,6 @@ def _propagate_walrus_assignments(
69056975
return parent_expr
69066976
return expr
69076977

6908-
def narrow_identity_equality_comparison(
6909-
self,
6910-
operands: list[Expression],
6911-
operand_types: list[Type],
6912-
chain_indices: list[int],
6913-
narrowable_operand_indices: AbstractSet[int],
6914-
is_valid_target: Callable[[ProperType], bool],
6915-
coerce_only_in_literal_context: bool,
6916-
) -> tuple[TypeMap, TypeMap]:
6917-
"""Produce conditional type maps refining expressions by an identity/equality comparison.
6918-
6919-
The 'operands' and 'operand_types' lists should be the full list of operands used
6920-
in the overall comparison expression. The 'chain_indices' list is the list of indices
6921-
actually used within this identity comparison chain.
6922-
6923-
So if we have the expression:
6924-
6925-
a <= b is c is d <= e
6926-
6927-
...then 'operands' and 'operand_types' would be lists of length 5 and 'chain_indices'
6928-
would be the list [1, 2, 3].
6929-
6930-
The 'narrowable_operand_indices' parameter is the set of all indices we are allowed
6931-
to refine the types of: that is, all operands that will potentially be a part of
6932-
the output TypeMaps.
6933-
6934-
Although this function could theoretically try setting the types of the operands
6935-
in the chains to the meet, doing that causes too many issues in real-world code.
6936-
Instead, we use 'is_valid_target' to identify which of the given chain types
6937-
we could plausibly use as the refined type for the expressions in the chain.
6938-
6939-
Similarly, 'coerce_only_in_literal_context' controls whether we should try coercing
6940-
expressions in the chain to a Literal type. Performing this coercion is sometimes
6941-
too aggressive of a narrowing, depending on context.
6942-
"""
6943-
6944-
if coerce_only_in_literal_context:
6945-
should_coerce = False
6946-
for i in chain_indices:
6947-
typ = get_proper_type(operand_types[i])
6948-
if is_literal_type_like(typ) or (isinstance(typ, Instance) and typ.type.is_enum):
6949-
should_coerce = True
6950-
break
6951-
else:
6952-
should_coerce = True
6953-
6954-
value_targets = []
6955-
type_targets = []
6956-
for i in chain_indices:
6957-
expr_type = operand_types[i]
6958-
if should_coerce:
6959-
# TODO: doing this prevents narrowing a single-member Enum to literal
6960-
# of its member, because we expand it here and then refuse to add equal
6961-
# types to typemaps. As a result, `x: Foo; x == Foo.A` does not narrow
6962-
# `x` to `Literal[Foo.A]` iff `Foo` has exactly one member.
6963-
# See testMatchEnumSingleChoice
6964-
expr_type = coerce_to_literal(expr_type)
6965-
if is_valid_target(get_proper_type(expr_type)):
6966-
value_targets.append((i, TypeRange(expr_type, is_upper_bound=False)))
6967-
else:
6968-
type_targets.append((i, TypeRange(expr_type, is_upper_bound=False)))
6969-
6970-
partial_type_maps = []
6971-
6972-
if value_targets:
6973-
for i in chain_indices:
6974-
if i not in narrowable_operand_indices:
6975-
continue
6976-
for j, target in value_targets:
6977-
if i == j:
6978-
continue
6979-
expr_type = coerce_to_literal(operand_types[i])
6980-
expr_type = try_expanding_sum_type_to_union(expr_type, None)
6981-
if_map, else_map = conditional_types_to_typemaps(
6982-
operands[i], *conditional_types(expr_type, [target])
6983-
)
6984-
partial_type_maps.append((if_map, else_map))
6985-
6986-
if type_targets:
6987-
for i in chain_indices:
6988-
if i not in narrowable_operand_indices:
6989-
continue
6990-
for j, target in type_targets:
6991-
if i == j:
6992-
continue
6993-
expr_type = operand_types[i]
6994-
if_map, else_map = conditional_types_to_typemaps(
6995-
operands[i], *conditional_types(expr_type, [target])
6996-
)
6997-
if if_map:
6998-
else_map = {}
6999-
partial_type_maps.append((if_map, else_map))
7000-
7001-
# We will not have duplicate entries in our type maps if we only have two operands,
7002-
# so we can skip running meets on the intersections
7003-
return reduce_conditional_maps(partial_type_maps, use_meet=len(operands) > 2)
7004-
70056978
def refine_away_none_in_comparison(
70066979
self,
70076980
operands: list[Expression],
@@ -7012,7 +6985,7 @@ def refine_away_none_in_comparison(
70126985
"""Produces conditional type maps refining away None in an identity/equality chain.
70136986
70146987
For more details about what the different arguments mean, see the
7015-
docstring of 'refine_identity_comparison_expression' up above.
6988+
docstring of 'narrow_type_by_equality' up above.
70166989
"""
70176990

70186991
non_optional_types = []
@@ -8596,11 +8569,6 @@ def reduce_conditional_maps(
85968569
return final_if_map, final_else_map
85978570

85988571

8599-
def is_singleton_value(t: Type) -> bool:
8600-
t = get_proper_type(t)
8601-
return isinstance(t, LiteralType) or t.is_singleton_type()
8602-
8603-
86048572
BUILTINS_CUSTOM_EQ_CHECKS: Final = {
86058573
"builtins.bytes",
86068574
"builtins.bytearray",

mypy/typeops.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
)
3434
from mypy.state import state
3535
from mypy.types import (
36+
ELLIPSIS_TYPE_NAMES,
3637
AnyType,
3738
CallableType,
3839
ExtraAttrs,
@@ -985,24 +986,30 @@ def is_literal_type_like(t: Type | None) -> bool:
985986
return False
986987

987988

988-
def is_singleton_type(typ: Type) -> bool:
989-
"""Returns 'true' if this type is a "singleton type" -- if there exists
990-
exactly only one runtime value associated with this type.
989+
def is_singleton_identity_type(typ: ProperType) -> bool:
990+
"""
991+
Returns True if every value of this type is identical to every other value of this type,
992+
as judged by the `is` operator.
991993
992-
That is, given two values 'a' and 'b' that have the same type 't',
993-
'is_singleton_type(t)' returns True if and only if the expression 'a is b' is
994-
always true.
994+
Note that this is not true of certain LiteralType, such as Literal[100001] or Literal["string"]
995+
"""
996+
if isinstance(typ, NoneType):
997+
return True
998+
if isinstance(typ, Instance):
999+
return (typ.type.is_enum and len(typ.type.enum_members) == 1) or (
1000+
typ.type.fullname in ELLIPSIS_TYPE_NAMES
1001+
)
1002+
if isinstance(typ, LiteralType):
1003+
return typ.is_enum_literal() or isinstance(typ.value, bool)
1004+
return False
9951005

996-
Currently, this returns True when given NoneTypes, enum LiteralTypes,
997-
enum types with a single value and ... (Ellipses).
9981006

999-
Note that other kinds of LiteralTypes cannot count as singleton types. For
1000-
example, suppose we do 'a = 100000 + 1' and 'b = 100001'. It is not guaranteed
1001-
that 'a is b' will always be true -- some implementations of Python will end up
1002-
constructing two distinct instances of 100001.
1007+
def is_singleton_equality_type(typ: ProperType) -> bool:
10031008
"""
1004-
typ = get_proper_type(typ)
1005-
return typ.is_singleton_type()
1009+
Returns True if every value of this type compares equal to every other value of this type,
1010+
as judged by the `==` operator.
1011+
"""
1012+
return isinstance(typ, LiteralType) or is_singleton_identity_type(typ)
10061013

10071014

10081015
def try_expanding_sum_type_to_union(typ: Type, target_fullname: str | None) -> Type:

mypy/types.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -331,9 +331,6 @@ def write(self, data: WriteBuffer) -> None:
331331
def read(cls, data: ReadBuffer) -> Type:
332332
raise NotImplementedError(f"Cannot deserialize {cls.__name__} instance")
333333

334-
def is_singleton_type(self) -> bool:
335-
return False
336-
337334

338335
class TypeAliasType(Type):
339336
"""A type alias to another type.
@@ -1479,9 +1476,6 @@ def read(cls, data: ReadBuffer) -> NoneType:
14791476
assert read_tag(data) == END_TAG
14801477
return NoneType()
14811478

1482-
def is_singleton_type(self) -> bool:
1483-
return True
1484-
14851479

14861480
# NoneType used to be called NoneTyp so to avoid needlessly breaking
14871481
# external plugins we keep that alias here.
@@ -1848,15 +1842,6 @@ def copy_with_extra_attr(self, name: str, typ: Type) -> Instance:
18481842
new.extra_attrs = existing_attrs
18491843
return new
18501844

1851-
def is_singleton_type(self) -> bool:
1852-
# TODO:
1853-
# Also make this return True if the type corresponds to NotImplemented?
1854-
return (
1855-
self.type.is_enum
1856-
and len(self.type.enum_members) == 1
1857-
or self.type.fullname in ELLIPSIS_TYPE_NAMES
1858-
)
1859-
18601845

18611846
class InstanceCache:
18621847
def __init__(self) -> None:
@@ -3332,9 +3317,6 @@ def read(cls, data: ReadBuffer) -> LiteralType:
33323317
assert read_tag(data) == END_TAG
33333318
return ret
33343319

3335-
def is_singleton_type(self) -> bool:
3336-
return self.is_enum_literal() or isinstance(self.value, bool)
3337-
33383320

33393321
class UnionType(ProperType):
33403322
"""The union type Union[T1, ..., Tn] (at least one type argument)."""

0 commit comments

Comments
 (0)