Skip to content

Commit df947c3

Browse files
refactor disallow_str_iteration checks and add check to is_overlapping_types
1 parent 919a42a commit df947c3

File tree

7 files changed

+81
-29
lines changed

7 files changed

+81
-29
lines changed

mypy/checker.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@
155155
from mypy.semanal_enum import ENUM_BASES, ENUM_SPECIAL_PROPS
156156
from mypy.semanal_shared import SemanticAnalyzerCoreInterface
157157
from mypy.sharedparse import BINARY_MAGIC_METHODS
158+
from mypy.disallow_str_iteration_state import disallow_str_iteration_state
158159
from mypy.state import state
159160
from mypy.subtypes import (
160161
find_member,
@@ -514,7 +515,11 @@ def check_first_pass(self) -> None:
514515
Deferred functions will be processed by check_second_pass().
515516
"""
516517
self.recurse_into_functions = True
517-
with state.strict_optional_set(self.options.strict_optional), checker_state.set(self):
518+
with (
519+
state.strict_optional_set(self.options.strict_optional),
520+
disallow_str_iteration_state.set(self.options.disallow_str_iteration),
521+
checker_state.set(self),
522+
):
518523
self.errors.set_file(
519524
self.path, self.tree.fullname, scope=self.tscope, options=self.options
520525
)
@@ -559,7 +564,11 @@ def check_second_pass(
559564
"""
560565
self.allow_constructor_cache = allow_constructor_cache
561566
self.recurse_into_functions = True
562-
with state.strict_optional_set(self.options.strict_optional), checker_state.set(self):
567+
with (
568+
state.strict_optional_set(self.options.strict_optional),
569+
disallow_str_iteration_state.set(self.options.disallow_str_iteration),
570+
checker_state.set(self),
571+
):
563572
if not todo and not self.deferred_nodes:
564573
return False
565574
self.errors.set_file(
@@ -5382,7 +5391,9 @@ def analyze_iterable_item_type_without_expression(
53825391
iterable: Type
53835392
iterable = get_proper_type(type)
53845393

5385-
if self.options.disallow_str_iteration and self.is_str_iteration_type(iterable):
5394+
if disallow_str_iteration_state.disallow_str_iteration and self.is_str_iteration_type(
5395+
iterable
5396+
):
53865397
self.msg.str_iteration_disallowed(context, iterable)
53875398

53885399
iterator = echk.check_method_call_by_name("__iter__", iterable, [], [], context)[0]
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from __future__ import annotations
2+
3+
from collections.abc import Iterator
4+
from contextlib import contextmanager
5+
from typing import Final
6+
7+
8+
class DisallowStrIterationState:
9+
# Wrap this in a class since it's faster that using a module-level attribute.
10+
11+
def __init__(self, disallow_str_iteration: bool) -> None:
12+
# Value varies by file being processed
13+
self.disallow_str_iteration = disallow_str_iteration
14+
15+
@contextmanager
16+
def set(self, value: bool) -> Iterator[None]:
17+
saved = self.disallow_str_iteration
18+
self.disallow_str_iteration = value
19+
try:
20+
yield
21+
finally:
22+
self.disallow_str_iteration = saved
23+
24+
25+
disallow_str_iteration_state: Final = DisallowStrIterationState(disallow_str_iteration=False)

mypy/meet.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from mypy import join
66
from mypy.erasetype import erase_type
77
from mypy.maptype import map_instance_to_supertype
8+
from mypy.disallow_str_iteration_state import disallow_str_iteration_state
89
from mypy.state import state
910
from mypy.subtypes import (
1011
are_parameters_compatible,
@@ -14,6 +15,7 @@
1415
is_proper_subtype,
1516
is_same_type,
1617
is_subtype,
18+
is_subtype_relation_ignored_to_disallow_str_iteration,
1719
)
1820
from mypy.typeops import is_recursive_pair, make_simplified_union, tuple_fallback
1921
from mypy.types import (
@@ -596,6 +598,12 @@ def _type_object_overlap(left: Type, right: Type) -> bool:
596598
if right.type.fullname == "builtins.int" and left.type.fullname in MYPYC_NATIVE_INT_NAMES:
597599
return True
598600

601+
if disallow_str_iteration_state.disallow_str_iteration:
602+
if is_subtype_relation_ignored_to_disallow_str_iteration(left, right):
603+
return False
604+
elif is_subtype_relation_ignored_to_disallow_str_iteration(right, left):
605+
return False
606+
599607
# Two unrelated types cannot be partially overlapping: they're disjoint.
600608
if left.type.has_base(right.type.fullname):
601609
left = map_instance_to_supertype(left, right.type)

mypy/subtypes.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
Var,
3535
)
3636
from mypy.options import Options
37+
from mypy.disallow_str_iteration_state import disallow_str_iteration_state
3738
from mypy.state import state
3839
from mypy.types import (
3940
MYPYC_NATIVE_INT_NAMES,
@@ -481,22 +482,9 @@ def visit_instance(self, left: Instance) -> bool:
481482
right = self.right
482483

483484
if (
484-
self.options
485-
and self.options.disallow_str_iteration
486-
and left.type.has_base("builtins.str")
485+
disallow_str_iteration_state.disallow_str_iteration
487486
and isinstance(right, Instance)
488-
and not right.type.has_base("builtins.str")
489-
and any(
490-
right.type.has_base(base)
491-
for base in (
492-
"collections.abc.Collection",
493-
"collections.abc.Iterable",
494-
"collections.abc.Sequence",
495-
"typing.Collection",
496-
"typing.Iterable",
497-
"typing.Sequence",
498-
)
499-
)
487+
and is_subtype_relation_ignored_to_disallow_str_iteration(left, right)
500488
):
501489
return False
502490
if isinstance(right, TupleType) and right.partial_fallback.type.is_enum:
@@ -2331,3 +2319,21 @@ def is_erased_instance(t: Instance) -> bool:
23312319
elif not isinstance(get_proper_type(arg), AnyType):
23322320
return False
23332321
return True
2322+
2323+
2324+
def is_subtype_relation_ignored_to_disallow_str_iteration(left: Instance, right: Instance) -> bool:
2325+
return (
2326+
left.type.has_base("builtins.str")
2327+
and not right.type.has_base("builtins.str")
2328+
and any(
2329+
right.type.has_base(base)
2330+
for base in (
2331+
"collections.abc.Collection",
2332+
"collections.abc.Iterable",
2333+
"collections.abc.Sequence",
2334+
"typing.Collection",
2335+
"typing.Iterable",
2336+
"typing.Sequence",
2337+
)
2338+
)
2339+
)

mypy/typeops.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -638,15 +638,7 @@ def make_simplified_union(
638638

639639

640640
def _remove_redundant_union_items(items: list[Type], keep_erased: bool) -> list[Type]:
641-
from mypy.subtypes import SubtypeContext, is_proper_subtype
642-
643-
subtype_context = SubtypeContext(
644-
ignore_promotions=True,
645-
keep_erased_types=keep_erased,
646-
options=(
647-
checker_state.type_checker.options if checker_state.type_checker is not None else None
648-
),
649-
)
641+
from mypy.subtypes import is_proper_subtype
650642

651643
# The first pass through this loop, we check if later items are subtypes of earlier items.
652644
# The second pass through this loop, we check if earlier items are subtypes of later items
@@ -695,7 +687,9 @@ def _remove_redundant_union_items(items: list[Type], keep_erased: bool) -> list[
695687
):
696688
continue
697689

698-
if is_proper_subtype(ti, tj, subtype_context=subtype_context):
690+
if is_proper_subtype(
691+
ti, tj, keep_erased_types=keep_erased, ignore_promotions=True
692+
):
699693
duplicate_index = j
700694
break
701695
if duplicate_index != -1:

test-data/unit/check-flags.test

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2508,9 +2508,15 @@ takes_collection_subclass(StrSubclass()) # E: Argument 1 to "takes_collection_s
25082508
# N: "StrSubclass" is missing following "CollectionSubclass" protocol member: \
25092509
# N: __missing_impl__
25102510

2511-
def repro(x: Mapping[str, Union[str, Sequence[str]]]) -> None:
2511+
def dict_unpacking_unaffected_by_union_simplification(x: Mapping[str, Union[str, Sequence[str]]]) -> None:
25122512
x = {**x}
25132513

2514+
def narrowing(x: "str | Sequence[str]"):
2515+
if isinstance(x, str):
2516+
reveal_type(x) # N: Revealed type is "builtins.str"
2517+
else:
2518+
reveal_type(x) # N: Revealed type is "typing.Sequence[builtins.str]"
2519+
25142520
[builtins fixtures/str-iter.pyi]
25152521
[typing fixtures/typing-str-iter.pyi]
25162522

test-data/unit/fixtures/str-iter.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,5 @@ class dict(Mapping[_KT, _VT], Generic[_KT, _VT]):
4848
def __len__(self) -> int: pass
4949
def __contains__(self, item: object) -> bool: pass
5050
def __getitem__(self, key: _KT) -> _VT: pass
51+
52+
def isinstance(x: object, t: type) -> bool: pass

0 commit comments

Comments
 (0)