183183 fixup_partial_type ,
184184 function_type ,
185185 is_literal_type_like ,
186- is_singleton_type ,
186+ is_singleton_equality_type ,
187+ is_singleton_identity_type ,
187188 make_simplified_union ,
188189 true_only ,
189190 try_expanding_sum_type_to_union ,
@@ -6676,30 +6677,57 @@ def narrow_type_by_equality(
66766677 expr_indices : list [int ],
66776678 narrowable_indices : AbstractSet [int ],
66786679 ) -> tuple [TypeMap , TypeMap ]:
6679- """Calculate type maps for '==', '!=', 'is' or 'is not' expression, ignoring `type(x)` checks."""
6680- # is_valid_target:
6681- # Controls which types we're allowed to narrow exprs to. Note that
6682- # we cannot use 'is_literal_type_like' in both cases since doing
6683- # 'x = 10000 + 1; x is 10001' is not always True in all Python
6684- # implementations.
6685- #
6686- # coerce_only_in_literal_context:
6687- # If true, coerce types into literal types only if one or more of
6688- # the provided exprs contains an explicit Literal type. This could
6689- # technically be set to any arbitrary value, but it seems being liberal
6690- # with narrowing when using 'is' and conservative when using '==' seems
6691- # to break the least amount of real-world code.
6692- #
6680+ """
6681+ Calculate type maps for '==', '!=', 'is' or 'is not' expression, ignoring `type(x)` checks.
6682+
6683+ The 'operands' and 'operand_types' lists should be the full list of operands used
6684+ in the overall comparison expression. The 'chain_indices' list is the list of indices
6685+ actually used within this identity comparison chain.
6686+
6687+ So if we have the expression:
6688+
6689+ a <= b is c is d <= e
6690+
6691+ ...then 'operands' and 'operand_types' would be lists of length 5 and 'chain_indices'
6692+ would be the list [1, 2, 3].
6693+
6694+ The 'narrowable_operand_indices' parameter is the set of all indices we are allowed
6695+ to refine the types of: that is, all operands that will potentially be a part of
6696+ the output TypeMaps.
6697+
6698+ """
66936699 # should_narrow_by_identity_equality:
6694- # Set to 'false' only if the user defines custom __eq__ or __ne__ methods
6695- # that could cause identity-based narrowing to produce invalid results.
6700+ # If operator is "==" or "!=", we cannot narrow if we detect the presence of a user defined
6701+ # custom __eq__ or __ne__ method
6702+ should_narrow_by_identity_equality : bool
6703+
6704+ # is_target_for_value_narrowing:
6705+ # If the operator returns True when compared to this target, do we narrow in else branch?
6706+ # E.g. if operator is "==", then:
6707+ # - is_target_for_value_narrowing(str) == False
6708+ # - is_target_for_value_narrowing(Literal["asdf"]) == True
6709+ is_target_for_value_narrowing : Callable [[ProperType ], bool ]
6710+
6711+ # should_coerce_literals:
6712+ # Ideally, we should always attempt to have this set to True. Unfortunately, for now,
6713+ # performing this coercion can sometimes result in overly aggressive narrowing when taking
6714+ # in the context of other type checker behaviour.
6715+ should_coerce_literals : bool
6716+
66966717 if operator in {"is" , "is not" }:
6697- is_valid_target : Callable [[ Type ], bool ] = is_singleton_type
6698- coerce_only_in_literal_context = False
6718+ is_target_for_value_narrowing = is_singleton_identity_type
6719+ should_coerce_literals = True
66996720 should_narrow_by_identity_equality = True
6721+
67006722 elif operator in {"==" , "!=" }:
6701- is_valid_target = is_singleton_value
6702- coerce_only_in_literal_context = True
6723+ is_target_for_value_narrowing = is_singleton_equality_type
6724+
6725+ should_coerce_literals = False
6726+ for i in expr_indices :
6727+ typ = get_proper_type (operand_types [i ])
6728+ if is_literal_type_like (typ ) or (isinstance (typ , Instance ) and typ .type .is_enum ):
6729+ should_coerce_literals = True
6730+ break
67036731
67046732 expr_types = [operand_types [i ] for i in expr_indices ]
67056733 should_narrow_by_identity_equality = not any (
@@ -6708,21 +6736,63 @@ def narrow_type_by_equality(
67086736 else :
67096737 raise AssertionError
67106738
6711- if should_narrow_by_identity_equality :
6712- return self .narrow_identity_equality_comparison (
6713- operands ,
6714- operand_types ,
6715- expr_indices ,
6716- narrowable_indices ,
6717- is_valid_target ,
6718- coerce_only_in_literal_context ,
6739+ if not should_narrow_by_identity_equality :
6740+ # This is a bit of a legacy code path that might be a little unsound since it ignores
6741+ # custom __eq__. We should see if we can get rid of it in favour of `return {}, {}`
6742+ return self .refine_away_none_in_comparison (
6743+ operands , operand_types , expr_indices , narrowable_indices
67196744 )
67206745
6721- # This is a bit of a legacy code path that might be a little unsound since it ignores
6722- # custom __eq__. We should see if we can get rid of it.
6723- return self .refine_away_none_in_comparison (
6724- operands , operand_types , expr_indices , narrowable_indices
6725- )
6746+ value_targets = []
6747+ type_targets = []
6748+ for i in expr_indices :
6749+ expr_type = operand_types [i ]
6750+ if should_coerce_literals :
6751+ # TODO: doing this prevents narrowing a single-member Enum to literal
6752+ # of its member, because we expand it here and then refuse to add equal
6753+ # types to typemaps. As a result, `x: Foo; x == Foo.A` does not narrow
6754+ # `x` to `Literal[Foo.A]` iff `Foo` has exactly one member.
6755+ # See testMatchEnumSingleChoice
6756+ expr_type = coerce_to_literal (expr_type )
6757+ if is_target_for_value_narrowing (get_proper_type (expr_type )):
6758+ value_targets .append ((i , TypeRange (expr_type , is_upper_bound = False )))
6759+ else :
6760+ type_targets .append ((i , TypeRange (expr_type , is_upper_bound = False )))
6761+
6762+ partial_type_maps = []
6763+
6764+ if value_targets :
6765+ for i in expr_indices :
6766+ if i not in narrowable_indices :
6767+ continue
6768+ for j , target in value_targets :
6769+ if i == j :
6770+ continue
6771+ expr_type = coerce_to_literal (operand_types [i ])
6772+ expr_type = try_expanding_sum_type_to_union (expr_type , None )
6773+ if_map , else_map = conditional_types_to_typemaps (
6774+ operands [i ], * conditional_types (expr_type , [target ])
6775+ )
6776+ partial_type_maps .append ((if_map , else_map ))
6777+
6778+ if type_targets :
6779+ for i in expr_indices :
6780+ if i not in narrowable_indices :
6781+ continue
6782+ for j , target in type_targets :
6783+ if i == j :
6784+ continue
6785+ expr_type = operand_types [i ]
6786+ if_map , else_map = conditional_types_to_typemaps (
6787+ operands [i ], * conditional_types (expr_type , [target ])
6788+ )
6789+ if if_map :
6790+ else_map = {} # this is the big difference compared to the above
6791+ partial_type_maps .append ((if_map , else_map ))
6792+
6793+ # We will not have duplicate entries in our type maps if we only have two operands,
6794+ # so we can skip running meets on the intersections
6795+ return reduce_conditional_maps (partial_type_maps , use_meet = len (operands ) > 2 )
67266796
67276797 def propagate_up_typemap_info (self , new_types : TypeMap ) -> TypeMap :
67286798 """Attempts refining parent expressions of any MemberExpr or IndexExprs in new_types.
@@ -6905,103 +6975,6 @@ def _propagate_walrus_assignments(
69056975 return parent_expr
69066976 return expr
69076977
6908- def narrow_identity_equality_comparison (
6909- self ,
6910- operands : list [Expression ],
6911- operand_types : list [Type ],
6912- chain_indices : list [int ],
6913- narrowable_operand_indices : AbstractSet [int ],
6914- is_valid_target : Callable [[ProperType ], bool ],
6915- coerce_only_in_literal_context : bool ,
6916- ) -> tuple [TypeMap , TypeMap ]:
6917- """Produce conditional type maps refining expressions by an identity/equality comparison.
6918-
6919- The 'operands' and 'operand_types' lists should be the full list of operands used
6920- in the overall comparison expression. The 'chain_indices' list is the list of indices
6921- actually used within this identity comparison chain.
6922-
6923- So if we have the expression:
6924-
6925- a <= b is c is d <= e
6926-
6927- ...then 'operands' and 'operand_types' would be lists of length 5 and 'chain_indices'
6928- would be the list [1, 2, 3].
6929-
6930- The 'narrowable_operand_indices' parameter is the set of all indices we are allowed
6931- to refine the types of: that is, all operands that will potentially be a part of
6932- the output TypeMaps.
6933-
6934- Although this function could theoretically try setting the types of the operands
6935- in the chains to the meet, doing that causes too many issues in real-world code.
6936- Instead, we use 'is_valid_target' to identify which of the given chain types
6937- we could plausibly use as the refined type for the expressions in the chain.
6938-
6939- Similarly, 'coerce_only_in_literal_context' controls whether we should try coercing
6940- expressions in the chain to a Literal type. Performing this coercion is sometimes
6941- too aggressive of a narrowing, depending on context.
6942- """
6943-
6944- if coerce_only_in_literal_context :
6945- should_coerce = False
6946- for i in chain_indices :
6947- typ = get_proper_type (operand_types [i ])
6948- if is_literal_type_like (typ ) or (isinstance (typ , Instance ) and typ .type .is_enum ):
6949- should_coerce = True
6950- break
6951- else :
6952- should_coerce = True
6953-
6954- value_targets = []
6955- type_targets = []
6956- for i in chain_indices :
6957- expr_type = operand_types [i ]
6958- if should_coerce :
6959- # TODO: doing this prevents narrowing a single-member Enum to literal
6960- # of its member, because we expand it here and then refuse to add equal
6961- # types to typemaps. As a result, `x: Foo; x == Foo.A` does not narrow
6962- # `x` to `Literal[Foo.A]` iff `Foo` has exactly one member.
6963- # See testMatchEnumSingleChoice
6964- expr_type = coerce_to_literal (expr_type )
6965- if is_valid_target (get_proper_type (expr_type )):
6966- value_targets .append ((i , TypeRange (expr_type , is_upper_bound = False )))
6967- else :
6968- type_targets .append ((i , TypeRange (expr_type , is_upper_bound = False )))
6969-
6970- partial_type_maps = []
6971-
6972- if value_targets :
6973- for i in chain_indices :
6974- if i not in narrowable_operand_indices :
6975- continue
6976- for j , target in value_targets :
6977- if i == j :
6978- continue
6979- expr_type = coerce_to_literal (operand_types [i ])
6980- expr_type = try_expanding_sum_type_to_union (expr_type , None )
6981- if_map , else_map = conditional_types_to_typemaps (
6982- operands [i ], * conditional_types (expr_type , [target ])
6983- )
6984- partial_type_maps .append ((if_map , else_map ))
6985-
6986- if type_targets :
6987- for i in chain_indices :
6988- if i not in narrowable_operand_indices :
6989- continue
6990- for j , target in type_targets :
6991- if i == j :
6992- continue
6993- expr_type = operand_types [i ]
6994- if_map , else_map = conditional_types_to_typemaps (
6995- operands [i ], * conditional_types (expr_type , [target ])
6996- )
6997- if if_map :
6998- else_map = {}
6999- partial_type_maps .append ((if_map , else_map ))
7000-
7001- # We will not have duplicate entries in our type maps if we only have two operands,
7002- # so we can skip running meets on the intersections
7003- return reduce_conditional_maps (partial_type_maps , use_meet = len (operands ) > 2 )
7004-
70056978 def refine_away_none_in_comparison (
70066979 self ,
70076980 operands : list [Expression ],
@@ -7012,7 +6985,7 @@ def refine_away_none_in_comparison(
70126985 """Produces conditional type maps refining away None in an identity/equality chain.
70136986
70146987 For more details about what the different arguments mean, see the
7015- docstring of 'refine_identity_comparison_expression ' up above.
6988+ docstring of 'narrow_type_by_equality ' up above.
70166989 """
70176990
70186991 non_optional_types = []
@@ -8596,11 +8569,6 @@ def reduce_conditional_maps(
85968569 return final_if_map , final_else_map
85978570
85988571
8599- def is_singleton_value (t : Type ) -> bool :
8600- t = get_proper_type (t )
8601- return isinstance (t , LiteralType ) or t .is_singleton_type ()
8602-
8603-
86048572BUILTINS_CUSTOM_EQ_CHECKS : Final = {
86058573 "builtins.bytes" ,
86068574 "builtins.bytearray" ,
0 commit comments