Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions mypy/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -2397,6 +2397,8 @@ def parse_file(self, *, temporary: bool = False) -> None:
self.source_hash = compute_hash(source)

self.parse_inline_configuration(source)
self.check_for_invalid_options()

self.size_hint = len(source)
if not cached:
self.tree = manager.parse_file(
Expand Down Expand Up @@ -2447,6 +2449,13 @@ def parse_inline_configuration(self, source: str) -> None:
for lineno, error in config_errors:
self.manager.errors.report(lineno, 0, error)

def check_for_invalid_options(self) -> None:
if self.options.mypyc and not self.options.strict_bytes:
self.manager.errors.set_file(self.xpath, self.id, options=self.options)
self.manager.errors.report(
1, 0, "Option --strict-bytes cannot be disabled when using mypyc", blocker=True
)

def semantic_analysis_pass1(self) -> None:
"""Perform pass 1 of semantic analysis, which happens immediately after parsing.

Expand Down
4 changes: 4 additions & 0 deletions mypy/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -1371,6 +1371,7 @@ def process_options(
fscache: FileSystemCache | None = None,
program: str = "mypy",
header: str = HEADER,
mypyc: bool = False,
) -> tuple[list[BuildSource], Options]:
"""Parse command line arguments.

Expand Down Expand Up @@ -1398,6 +1399,9 @@ def process_options(

options = Options()
strict_option_set = False
if mypyc:
# Mypyc has strict_bytes enabled by default
options.strict_bytes = True

def set_strict_flags() -> None:
nonlocal strict_option_set
Expand Down
2 changes: 1 addition & 1 deletion mypy/typeshed/stubs/librt/librt/strings.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ from mypy_extensions import i64, u8
@final
class BytesWriter:
def append(self, /, x: int) -> None: ...
def write(self, /, b: bytes) -> None: ...
def write(self, /, b: bytes | bytearray) -> None: ...
def getvalue(self) -> bytes: ...
def truncate(self, /, size: i64) -> None: ...
def __len__(self) -> i64: ...
Expand Down
2 changes: 1 addition & 1 deletion mypyc/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def get_mypy_config(
fscache: FileSystemCache | None,
) -> tuple[list[BuildSource], list[BuildSource], Options]:
"""Construct mypy BuildSources and Options from file and options lists"""
all_sources, options = process_options(mypy_options, fscache=fscache)
all_sources, options = process_options(mypy_options, fscache=fscache, mypyc=True)
if only_compile_paths is not None:
paths_set = set(only_compile_paths)
mypyc_sources = [s for s in all_sources if s.path in paths_set]
Expand Down
2 changes: 1 addition & 1 deletion mypyc/codegen/emit.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,7 +657,7 @@ def emit_cast(
elif is_bytes_rprimitive(typ):
if declare_dest:
self.emit_line(f"PyObject *{dest};")
check = "(PyBytes_Check({}) || PyByteArray_Check({}))"
check = "(PyBytes_Check({}))"
if likely:
check = f"(likely{check})"
self.emit_arg_check(src, dest, typ, check.format(src, src), optional)
Expand Down
10 changes: 10 additions & 0 deletions mypyc/test-data/commandline.test
Original file line number Diff line number Diff line change
Expand Up @@ -312,3 +312,13 @@ print(type(Eggs(obj1=pkg1.A.B())["obj1"]).__module__)
[out]
B
pkg2.mod2

[case testStrictBytesRequired]
# cmd: --no-strict-bytes a.py

[file a.py]
def f(b: bytes) -> None: pass
f(bytearray())

[out]
a.py:1: error: Option --strict-bytes cannot be disabled when using mypyc
11 changes: 7 additions & 4 deletions mypyc/test-data/fixtures/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ class bytes:
def __init__(self) -> None: ...
@overload
def __init__(self, x: object) -> None: ...
def __add__(self, x: bytes) -> bytes: ...
def __add__(self, x: bytes | bytearray) -> bytes: ...
def __mul__(self, x: int) -> bytes: ...
def __rmul__(self, x: int) -> bytes: ...
def __eq__(self, x: object) -> bool: ...
Expand All @@ -178,8 +178,8 @@ def __getitem__(self, i: int) -> int: ...
def __getitem__(self, i: slice) -> bytes: ...
def join(self, x: Iterable[object]) -> bytes: ...
def decode(self, encoding: str=..., errors: str=...) -> str: ...
def translate(self, t: bytes) -> bytes: ...
def startswith(self, t: bytes) -> bool: ...
def translate(self, t: bytes | bytearray) -> bytes: ...
def startswith(self, t: bytes | bytearray) -> bool: ...
def __iter__(self) -> Iterator[int]: ...

class bytearray:
Expand All @@ -189,9 +189,12 @@ def __init__(self) -> None: pass
def __init__(self, x: object) -> None: pass
@overload
def __init__(self, string: str, encoding: str, err: str = ...) -> None: pass
def __add__(self, s: bytes) -> bytearray: ...
def __add__(self, s: bytes | bytearray) -> bytearray: ...
def __setitem__(self, i: int, o: int) -> None: ...
@overload
def __getitem__(self, i: int) -> int: ...
@overload
def __getitem__(self, i: slice) -> bytearray: ...
def decode(self, x: str = ..., y: str = ...) -> str: ...
def startswith(self, t: bytes) -> bool: ...

Expand Down
13 changes: 13 additions & 0 deletions mypyc/test-data/irbuild-bytes.test
Original file line number Diff line number Diff line change
Expand Up @@ -261,3 +261,16 @@ L0:
r0 = CPyBytes_Startswith(a, b)
r1 = truncate r0: i32 to builtins.bool
return r1

[case testBytesVsBytearray]
def bytes_func(b: bytes) -> None: pass
def bytearray_func(ba: bytearray) -> None: pass

def foo(b: bytes, ba: bytearray) -> None:
bytes_func(b)
bytearray_func(ba)
bytes_func(ba)
bytearray_func(b)
[out]
main:7: error: Argument 1 to "bytes_func" has incompatible type "bytearray"; expected "bytes"
main:8: error: Argument 1 to "bytearray_func" has incompatible type "bytes"; expected "bytearray"
6 changes: 3 additions & 3 deletions mypyc/test-data/run-base64.test
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[case testAllBase64Features_librt_experimental]
from typing import Any
from typing import Any, cast
import base64
import binascii
import random
Expand All @@ -14,7 +14,7 @@ def test_encode_basic() -> None:
assert b64encode(b"x") == b"eA=="

with assertRaises(TypeError):
b64encode(bytearray(b"x"))
b64encode(cast(Any, bytearray(b"x")))

def check_encode(b: bytes) -> None:
assert b64encode(b) == getattr(base64, "b64encode")(b)
Expand Down Expand Up @@ -56,7 +56,7 @@ def test_decode_basic() -> None:
assert b64decode(b"eA==") == b"x"

with assertRaises(TypeError):
b64decode(bytearray(b"eA=="))
b64decode(cast(Any, bytearray(b"eA==")))

for non_ascii in "\x80", "foo\u100bar", "foo\ua1234bar":
with assertRaises(ValueError):
Expand Down
44 changes: 22 additions & 22 deletions mypyc/test-data/run-bytes.test
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ def test_concat() -> None:
assert type(b1) == bytes
assert type(b2) == bytes
assert type(b3) == bytes
brr1: bytes = bytearray(3)
brr2: bytes = bytearray(range(5))
brr1 = bytearray(3)
brr2 = bytearray(range(5))
b4 = b1 + brr1
assert b4 == b'123\x00\x00\x00'
assert type(brr1) == bytearray
Expand All @@ -94,9 +94,9 @@ def test_concat() -> None:
b5 = brr2 + b2
assert b5 == bytearray(b'\x00\x01\x02\x03\x04456')
assert type(b5) == bytearray
b5 = b2 + brr2
assert b5 == b'456\x00\x01\x02\x03\x04'
assert type(b5) == bytes
b6 = b2 + brr2
assert b6 == b'456\x00\x01\x02\x03\x04'
assert type(b6) == bytes

def test_join() -> None:
seq = (b'1', b'"', b'\xf0')
Expand Down Expand Up @@ -217,9 +217,9 @@ def test_startswith() -> None:
assert test.startswith(bytearray(b'some'))
assert not test.startswith(bytearray(b'other'))

test = bytearray(b'some string')
assert test.startswith(b'some')
assert not test.startswith(b'other')
test2 = bytearray(b'some string')
assert test2.startswith(b'some')
assert not test2.startswith(b'other')

[case testBytesSlicing]
def test_bytes_slicing() -> None:
Expand Down Expand Up @@ -257,34 +257,38 @@ def test_bytes_slicing() -> None:
[case testBytearrayBasics]
from typing import Any

from testutil import assertRaises

def test_basics() -> None:
brr1: bytes = bytearray(3)
brr1 = bytearray(3)
assert brr1 == bytearray(b'\x00\x00\x00')
assert brr1 == b'\x00\x00\x00'
l = [10, 20, 30, 40]
brr2: bytes = bytearray(l)
brr2 = bytearray(l)
assert brr2 == bytearray(b'\n\x14\x1e(')
assert brr2 == b'\n\x14\x1e('
brr3: bytes = bytearray(range(5))
brr3 = bytearray(range(5))
assert brr3 == bytearray(b'\x00\x01\x02\x03\x04')
assert brr3 == b'\x00\x01\x02\x03\x04'
brr4: bytes = bytearray('string', 'utf-8')
brr4 = bytearray('string', 'utf-8')
assert brr4 == bytearray(b'string')
assert brr4 == b'string'
assert len(brr1) == 3
assert len(brr2) == 4

def f(b: bytes) -> bool:
return True
def f(b: bytes) -> str:
return "xy"

def test_bytearray_passed_into_bytes() -> None:
assert f(bytearray(3))
brr1: Any = bytearray()
assert f(brr1)
with assertRaises(TypeError, "bytes object expected; got bytearray"):
f(brr1)
with assertRaises(TypeError, "bytes object expected; got bytearray"):
b: bytes = brr1

[case testBytearraySlicing]
def test_bytearray_slicing() -> None:
b: bytes = bytearray(b'abcdefg')
b = bytearray(b'abcdefg')
zero = int()
ten = 10 + zero
two = 2 + zero
Expand Down Expand Up @@ -318,7 +322,7 @@ def test_bytearray_slicing() -> None:
from testutil import assertRaises

def test_bytearray_indexing() -> None:
b: bytes = bytearray(b'\xae\x80\xfe\x15')
b = bytearray(b'\xae\x80\xfe\x15')
assert b[0] == 174
assert b[1] == 128
assert b[2] == 254
Expand Down Expand Up @@ -347,10 +351,6 @@ def test_bytes_join() -> None:
assert b' '.join([b'a', b'b']) == b'a b'
assert b' '.join([]) == b''

x: bytes = bytearray(b' ')
assert x.join([b'a', b'b']) == b'a b'
assert type(x.join([b'a', b'b'])) == bytearray

y: bytes = bytes_subclass()
assert y.join([]) == b'spook'

Expand Down
6 changes: 3 additions & 3 deletions mypyc/test-data/run-strings.test
Original file line number Diff line number Diff line change
Expand Up @@ -892,15 +892,15 @@ def test_decode_error() -> None:
pass

def test_decode_bytearray() -> None:
b: bytes = bytearray(b'foo\x00bar')
b = bytearray(b'foo\x00bar')
assert b.decode() == 'foo\x00bar'
assert b.decode('utf-8') == 'foo\x00bar'
assert b.decode('latin-1') == 'foo\x00bar'
assert b.decode('ascii') == 'foo\x00bar'
assert b.decode('utf-8' + str()) == 'foo\x00bar'
assert b.decode('latin-1' + str()) == 'foo\x00bar'
assert b.decode('ascii' + str()) == 'foo\x00bar'
b2: bytes = bytearray(b'foo\x00bar\xbe')
b2 = bytearray(b'foo\x00bar\xbe')
assert b2.decode('latin-1') == 'foo\x00bar\xbe'
with assertRaises(UnicodeDecodeError):
b2.decode('ascii')
Expand All @@ -910,7 +910,7 @@ def test_decode_bytearray() -> None:
b2.decode('utf-8')
with assertRaises(UnicodeDecodeError):
b2.decode('utf-8' + str())
b3: bytes = bytearray(b'Z\xc3\xbcrich')
b3 = bytearray(b'Z\xc3\xbcrich')
assert b3.decode("utf-8") == 'Zürich'

def test_invalid_encoding() -> None:
Expand Down
3 changes: 3 additions & 0 deletions mypyc/test/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,9 @@ def run_case_step(self, testcase: DataDrivenTestCase, incremental_step: int) ->
options.use_builtins_fixtures = True
options.show_traceback = True
options.strict_optional = True
options.strict_bytes = True
options.disable_bytearray_promotion = True
options.disable_memoryview_promotion = True
options.python_version = sys.version_info[:2]
options.export_types = True
options.preserve_asts = True
Expand Down
3 changes: 3 additions & 0 deletions mypyc/test/testutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,9 @@ def build_ir_for_single_file2(
options.export_types = True
options.preserve_asts = True
options.allow_empty_bodies = True
options.strict_bytes = True
options.disable_bytearray_promotion = True
options.disable_memoryview_promotion = True
options.per_module_options["__main__"] = {"mypyc": True}

source = build.BuildSource("main", "__main__", program_text)
Expand Down