diff --git a/mypyc/irbuild/specialize.py b/mypyc/irbuild/specialize.py index 785a0a7ac61f..d75c9144dcda 100644 --- a/mypyc/irbuild/specialize.py +++ b/mypyc/irbuild/specialize.py @@ -95,7 +95,12 @@ tokenizer_format_call, ) from mypyc.primitives.bytearray_ops import isinstance_bytearray -from mypyc.primitives.bytes_ops import isinstance_bytes +from mypyc.primitives.bytes_ops import ( + bytes_adjust_index_op, + bytes_get_item_unsafe_op, + bytes_range_check_op, + isinstance_bytes, +) from mypyc.primitives.dict_ops import ( dict_items_op, dict_keys_op, @@ -1207,30 +1212,50 @@ def translate_object_setattr(builder: IRBuilder, expr: CallExpr, callee: RefExpr return builder.call_c(generic_setattr, [self_reg, name_reg, value], expr.line) -@specialize_dunder("__getitem__", bytes_writer_rprimitive) -def translate_bytes_writer_get_item( - builder: IRBuilder, base_expr: Expression, args: list[Expression], ctx_expr: Expression +def translate_getitem_with_bounds_check( + builder: IRBuilder, + base_expr: Expression, + args: list[Expression], + ctx_expr: Expression, + adjust_index_op: PrimitiveDescription, + range_check_op: PrimitiveDescription, + get_item_unsafe_op: PrimitiveDescription, ) -> Value | None: - """Optimized BytesWriter.__getitem__ implementation with bounds checking.""" + """Shared helper for optimized __getitem__ with bounds checking. + + This implements the common pattern of: + 1. Adjusting negative indices + 2. Checking if index is in valid range + 3. Raising IndexError if out of range + 4. Getting the item if in range + + Args: + builder: The IR builder + base_expr: The base object expression + args: The arguments to __getitem__ (should be length 1) + ctx_expr: The context expression for line numbers + adjust_index_op: Primitive op to adjust negative indices + range_check_op: Primitive op to check if index is in valid range + get_item_unsafe_op: Primitive op to get item (no bounds checking) + + Returns: + The result value, or None if optimization doesn't apply + """ # Check that we have exactly one argument if len(args) != 1: return None - # Get the BytesWriter object + # Get the object obj = builder.accept(base_expr) # Get the index argument index = builder.accept(args[0]) # Adjust the index (handle negative indices) - adjusted_index = builder.primitive_op( - bytes_writer_adjust_index_op, [obj, index], ctx_expr.line - ) + adjusted_index = builder.primitive_op(adjust_index_op, [obj, index], ctx_expr.line) # Check if the adjusted index is in valid range - range_check = builder.primitive_op( - bytes_writer_range_check_op, [obj, adjusted_index], ctx_expr.line - ) + range_check = builder.primitive_op(range_check_op, [obj, adjusted_index], ctx_expr.line) # Create blocks for branching valid_block = BasicBlock() @@ -1247,13 +1272,27 @@ def translate_bytes_writer_get_item( # Handle valid index - get the item builder.activate_block(valid_block) - result = builder.primitive_op( - bytes_writer_get_item_unsafe_op, [obj, adjusted_index], ctx_expr.line - ) + result = builder.primitive_op(get_item_unsafe_op, [obj, adjusted_index], ctx_expr.line) return result +@specialize_dunder("__getitem__", bytes_writer_rprimitive) +def translate_bytes_writer_get_item( + builder: IRBuilder, base_expr: Expression, args: list[Expression], ctx_expr: Expression +) -> Value | None: + """Optimized BytesWriter.__getitem__ implementation with bounds checking.""" + return translate_getitem_with_bounds_check( + builder, + base_expr, + args, + ctx_expr, + bytes_writer_adjust_index_op, + bytes_writer_range_check_op, + bytes_writer_get_item_unsafe_op, + ) + + @specialize_dunder("__setitem__", bytes_writer_rprimitive) def translate_bytes_writer_set_item( builder: IRBuilder, base_expr: Expression, args: list[Expression], ctx_expr: Expression @@ -1300,3 +1339,19 @@ def translate_bytes_writer_set_item( ) return builder.none() + + +@specialize_dunder("__getitem__", bytes_rprimitive) +def translate_bytes_get_item( + builder: IRBuilder, base_expr: Expression, args: list[Expression], ctx_expr: Expression +) -> Value | None: + """Optimized bytes.__getitem__ implementation with bounds checking.""" + return translate_getitem_with_bounds_check( + builder, + base_expr, + args, + ctx_expr, + bytes_adjust_index_op, + bytes_range_check_op, + bytes_get_item_unsafe_op, + ) diff --git a/mypyc/lib-rt/bytes_extra_ops.h b/mypyc/lib-rt/bytes_extra_ops.h index eebb5a345438..0f2917764ba0 100644 --- a/mypyc/lib-rt/bytes_extra_ops.h +++ b/mypyc/lib-rt/bytes_extra_ops.h @@ -2,9 +2,30 @@ #define MYPYC_BYTES_EXTRA_OPS_H #include +#include #include "CPy.h" // Optimized bytes translate operation PyObject *CPyBytes_Translate(PyObject *bytes, PyObject *table); +// Optimized bytes.__getitem__ operations + +// If index is negative, convert to non-negative index (no range checking) +static inline int64_t CPyBytes_AdjustIndex(PyObject *obj, int64_t index) { + if (index < 0) { + return index + Py_SIZE(obj); + } + return index; +} + +// Check if index is in valid range [0, len) +static inline bool CPyBytes_RangeCheck(PyObject *obj, int64_t index) { + return index >= 0 && index < Py_SIZE(obj); +} + +// Get byte at index (no bounds checking) - returns as CPyTagged +static inline CPyTagged CPyBytes_GetItemUnsafe(PyObject *obj, int64_t index) { + return ((CPyTagged)(uint8_t)(PyBytes_AS_STRING(obj))[index]) << 1; +} + #endif diff --git a/mypyc/primitives/bytes_ops.py b/mypyc/primitives/bytes_ops.py index 53e7832c3998..0b32c7937ba1 100644 --- a/mypyc/primitives/bytes_ops.py +++ b/mypyc/primitives/bytes_ops.py @@ -12,6 +12,7 @@ c_int_rprimitive, c_pyssize_t_rprimitive, dict_rprimitive, + int64_rprimitive, int_rprimitive, list_rprimitive, object_rprimitive, @@ -21,6 +22,7 @@ ERR_NEG_INT, binary_op, custom_op, + custom_primitive_op, function_op, load_address_op, method_op, @@ -148,3 +150,38 @@ c_function_name="CPyBytes_Ord", error_kind=ERR_MAGIC, ) + +# Optimized bytes.__getitem__ operations + +# bytes index adjustment - convert negative index to positive +bytes_adjust_index_op = custom_primitive_op( + name="bytes_adjust_index", + arg_types=[bytes_rprimitive, int64_rprimitive], + return_type=int64_rprimitive, + c_function_name="CPyBytes_AdjustIndex", + error_kind=ERR_NEVER, + experimental=True, + dependencies=[BYTES_EXTRA_OPS], +) + +# bytes range check - check if index is in valid range +bytes_range_check_op = custom_primitive_op( + name="bytes_range_check", + arg_types=[bytes_rprimitive, int64_rprimitive], + return_type=bool_rprimitive, + c_function_name="CPyBytes_RangeCheck", + error_kind=ERR_NEVER, + experimental=True, + dependencies=[BYTES_EXTRA_OPS], +) + +# bytes.__getitem__() - get byte at index (no bounds checking) +bytes_get_item_unsafe_op = custom_primitive_op( + name="bytes_get_item_unsafe", + arg_types=[bytes_rprimitive, int64_rprimitive], + return_type=int_rprimitive, + c_function_name="CPyBytes_GetItemUnsafe", + error_kind=ERR_NEVER, + experimental=True, + dependencies=[BYTES_EXTRA_OPS], +) diff --git a/mypyc/test-data/irbuild-bytes.test b/mypyc/test-data/irbuild-bytes.test index 6ec09bedad48..391be56f00d2 100644 --- a/mypyc/test-data/irbuild-bytes.test +++ b/mypyc/test-data/irbuild-bytes.test @@ -100,15 +100,26 @@ L0: return r0 [case testBytesIndex] -def f(a: bytes, i: int) -> int: +from mypy_extensions import i64 + +def f(a: bytes, i: i64) -> int: return a[i] [out] def f(a, i): a :: bytes - i, r0 :: int + i, r0 :: i64 + r1, r2 :: bool + r3 :: int L0: - r0 = CPyBytes_GetItem(a, i) - return r0 + r0 = CPyBytes_AdjustIndex(a, i) + r1 = CPyBytes_RangeCheck(a, r0) + if r1 goto L2 else goto L1 :: bool +L1: + r2 = raise IndexError('index out of range') + unreachable +L2: + r3 = CPyBytes_GetItemUnsafe(a, r0) + return r3 [case testBytesConcat] def f(a: bytes, b: bytes) -> bytes: