From 4cbccd5562479f59ac9645c62591b3ce257178ff Mon Sep 17 00:00:00 2001 From: Clouds Flowing Date: Tue, 6 Jan 2026 14:52:30 +0800 Subject: [PATCH 01/11] add typings to expr.py --- python/tvm/tir/expr.py | 68 +++++++++++++++++++++--------------------- 1 file changed, 34 insertions(+), 34 deletions(-) diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index f5476230c19b..1c4683f64f1b 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -27,7 +27,7 @@ assert(isinstance(y, tvm.tir.Add)) assert(y.a == x) """ -from typing import List, Optional, Union +from typing import List, Optional, Self, Union import tvm_ffi import tvm.ir._ffi_api @@ -74,111 +74,111 @@ class ExprOp: # TODO(tkonolige): use inspect to add source information to these objects - def __add__(self, other: PrimExpr) -> PrimExpr: + def __add__(self, other: Union[PrimExpr, float, int]) -> "Add": return _generic.add(self, other) - def __radd__(self, other: PrimExpr) -> PrimExpr: + def __radd__(self, other: Union[PrimExpr, float, int]) -> "Add": return _generic.add(other, self) - def __sub__(self, other: PrimExpr) -> PrimExpr: + def __sub__(self, other: Union[PrimExpr, float, int]) -> "Sub": return _generic.subtract(self, other) - def __rsub__(self, other: PrimExpr) -> PrimExpr: + def __rsub__(self, other: Union[PrimExpr, float, int]) -> "Sub": return _generic.subtract(other, self) - def __mul__(self, other: PrimExpr) -> PrimExpr: + def __mul__(self, other: Union[PrimExpr, float, int]) -> "Mul": return _generic.multiply(self, other) - def __rmul__(self, other: PrimExpr) -> PrimExpr: + def __rmul__(self, other: Union[PrimExpr, float, int]) -> "Mul": return _generic.multiply(other, self) - def __div__(self, other: PrimExpr) -> PrimExpr: + def __div__(self, other: Union[PrimExpr, float, int]) -> "Div": if _dtype_is_int(self) and _dtype_is_int(other): raise div_ambiguity_error() return _generic.divide(self, other) - def __rdiv__(self, other: PrimExpr) -> PrimExpr: + def __rdiv__(self, other: Union[PrimExpr, float, int]) -> "Div": if _dtype_is_int(self) and _dtype_is_int(other): raise div_ambiguity_error() return _generic.divide(other, self) - def __truediv__(self, other: PrimExpr) -> PrimExpr: + def __truediv__(self, other: Union[PrimExpr, float, int]) -> "Div": if _dtype_is_int(self) and _dtype_is_int(other): raise div_ambiguity_error() return _generic.divide(self, other) - def __rtruediv__(self, other: PrimExpr) -> PrimExpr: + def __rtruediv__(self, other: Union[PrimExpr, float, int]) -> "Div": if _dtype_is_int(self) and _dtype_is_int(other): raise div_ambiguity_error() return _generic.divide(other, self) - def __floordiv__(self, other: PrimExpr) -> PrimExpr: + def __floordiv__(self, other: Union[PrimExpr, float, int]) -> "FloorDiv": return _generic.floordiv(self, other) - def __rfloordiv__(self, other: PrimExpr) -> PrimExpr: + def __rfloordiv__(self, other: Union[PrimExpr, float, int]) -> "FloorDiv": return _generic.floordiv(other, self, None) - def __mod__(self, other: PrimExpr) -> PrimExpr: + def __mod__(self, other: Union[PrimExpr, float, int]) -> "Mod": return _ffi_api._OpFloorMod(self, other, None) # type: ignore - def __rmod__(self, other: PrimExpr) -> PrimExpr: + def __rmod__(self, other: Union[PrimExpr, float, int]) -> "Mod": return _ffi_api._OpFloorMod(other, self, None) # type: ignore - def __neg__(self) -> PrimExpr: + def __neg__(self) -> "Mul": neg_one = const(-1, self.dtype) # type: ignore return self.__mul__(neg_one) - def __lshift__(self, other: PrimExpr) -> PrimExpr: + def __lshift__(self, other: Union[PrimExpr, int]) -> "Call": return _ffi_api.left_shift(self, other, None) # type: ignore - def __rlshift__(self, other: PrimExpr) -> PrimExpr: + def __rlshift__(self, other: Union[PrimExpr, int]) -> "Call": return _ffi_api.left_shift(other, self, None) # type: ignore - def __rshift__(self, other: PrimExpr) -> PrimExpr: + def __rshift__(self, other: Union[PrimExpr, int]) -> "Call": return _ffi_api.right_shift(self, other, None) # type: ignore - def __rrshift__(self, other: PrimExpr) -> PrimExpr: + def __rrshift__(self, other: Union[PrimExpr, int]) -> "Call": return _ffi_api.right_shift(other, self, None) # type: ignore - def __and__(self, other: PrimExpr) -> PrimExpr: + def __and__(self, other: Union[PrimExpr, int, bool]) -> "Call": return _ffi_api.bitwise_and(self, other, None) # type: ignore - def __rand__(self, other: PrimExpr) -> PrimExpr: + def __rand__(self, other: Union[PrimExpr, int, bool]) -> "Call": return _ffi_api.bitwise_and(other, self, None) # type: ignore - def __or__(self, other: PrimExpr) -> PrimExpr: + def __or__(self, other: Union[PrimExpr, int, bool]) -> "Call": return _ffi_api.bitwise_or(self, other, None) # type: ignore - def __ror__(self, other: PrimExpr) -> PrimExpr: + def __ror__(self, other: Union[PrimExpr, int, bool]) -> "Call": return _ffi_api.bitwise_or(other, self, None) # type: ignore - def __xor__(self, other: PrimExpr) -> PrimExpr: + def __xor__(self, other: Union[PrimExpr, int]) -> "Call": return _ffi_api.bitwise_xor(self, other, None) # type: ignore - def __rxor__(self, other: PrimExpr) -> PrimExpr: + def __rxor__(self, other: Union[PrimExpr, int]) -> "Call": return _ffi_api.bitwise_xor(other, self, None) # type: ignore - def __invert__(self) -> PrimExpr: + def __invert__(self) -> "Call": if _dtype_is_float(self): raise RuntimeError("Cannot use ~ operator on float type Expr.") return _ffi_api.bitwise_not(self, None) # type: ignore - def __lt__(self, other: PrimExpr) -> PrimExpr: + def __lt__(self, other: PrimExpr) -> "LT": return _ffi_api._OpLT(self, other, None) # type: ignore - def __le__(self, other: PrimExpr) -> PrimExpr: + def __le__(self, other: PrimExpr) -> "LE": return _ffi_api._OpLE(self, other, None) # type: ignore - def __eq__(self, other: PrimExpr) -> PrimExpr: + def __eq__(self, other: PrimExpr) -> "EqualOp": return EqualOp(self, other) - def __ne__(self, other: PrimExpr) -> PrimExpr: + def __ne__(self, other: PrimExpr) -> "NotEqualOp": return NotEqualOp(self, other) - def __gt__(self, other: PrimExpr) -> PrimExpr: + def __gt__(self, other: PrimExpr) -> "GT": return _ffi_api._OpGT(self, other, None) # type: ignore - def __ge__(self, other: PrimExpr) -> PrimExpr: + def __ge__(self, other: PrimExpr) -> "GE": return _ffi_api._OpGE(self, other, None) # type: ignore def __nonzero__(self): @@ -208,7 +208,7 @@ def equal(self, other: PrimExpr, span: Optional[Span] = None) -> bool: """ return _ffi_api._OpEQ(self, other, span) # type: ignore - def astype(self, dtype: str, span: Optional[Span] = None) -> PrimExpr: + def astype(self, dtype: str, span: Optional[Span] = None) -> Union["Cast", "Self"]: """Cast the expression to other type. Parameters From 0e90827a5b5926875367a5e3caba592794c04171 Mon Sep 17 00:00:00 2001 From: Clouds Flowing Date: Tue, 6 Jan 2026 15:18:59 +0800 Subject: [PATCH 02/11] use numeric --- python/tvm/tir/expr.py | 43 ++++++++++++++++++++++-------------------- 1 file changed, 23 insertions(+), 20 deletions(-) diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index 1c4683f64f1b..451c7660ba4a 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -41,6 +41,9 @@ from .buffer import Buffer, DataProducer +numeric = Union[int, float, complex] + + def convert(expr) -> PrimExpr: return _ffi_api.convert(expr) @@ -74,54 +77,54 @@ class ExprOp: # TODO(tkonolige): use inspect to add source information to these objects - def __add__(self, other: Union[PrimExpr, float, int]) -> "Add": + def __add__(self, other: Union[PrimExpr, numeric]) -> "Add": return _generic.add(self, other) - def __radd__(self, other: Union[PrimExpr, float, int]) -> "Add": + def __radd__(self, other: Union[PrimExpr, numeric]) -> "Add": return _generic.add(other, self) - def __sub__(self, other: Union[PrimExpr, float, int]) -> "Sub": + def __sub__(self, other: Union[PrimExpr, numeric]) -> "Sub": return _generic.subtract(self, other) - def __rsub__(self, other: Union[PrimExpr, float, int]) -> "Sub": + def __rsub__(self, other: Union[PrimExpr, numeric]) -> "Sub": return _generic.subtract(other, self) - def __mul__(self, other: Union[PrimExpr, float, int]) -> "Mul": + def __mul__(self, other: Union[PrimExpr, numeric]) -> "Mul": return _generic.multiply(self, other) - def __rmul__(self, other: Union[PrimExpr, float, int]) -> "Mul": + def __rmul__(self, other: Union[PrimExpr, numeric]) -> "Mul": return _generic.multiply(other, self) - def __div__(self, other: Union[PrimExpr, float, int]) -> "Div": + def __div__(self, other: Union[PrimExpr, numeric]) -> "Div": if _dtype_is_int(self) and _dtype_is_int(other): raise div_ambiguity_error() return _generic.divide(self, other) - def __rdiv__(self, other: Union[PrimExpr, float, int]) -> "Div": + def __rdiv__(self, other: Union[PrimExpr, numeric]) -> "Div": if _dtype_is_int(self) and _dtype_is_int(other): raise div_ambiguity_error() return _generic.divide(other, self) - def __truediv__(self, other: Union[PrimExpr, float, int]) -> "Div": + def __truediv__(self, other: Union[PrimExpr, numeric]) -> "Div": if _dtype_is_int(self) and _dtype_is_int(other): raise div_ambiguity_error() return _generic.divide(self, other) - def __rtruediv__(self, other: Union[PrimExpr, float, int]) -> "Div": + def __rtruediv__(self, other: Union[PrimExpr, numeric]) -> "Div": if _dtype_is_int(self) and _dtype_is_int(other): raise div_ambiguity_error() return _generic.divide(other, self) - def __floordiv__(self, other: Union[PrimExpr, float, int]) -> "FloorDiv": + def __floordiv__(self, other: Union[PrimExpr, numeric]) -> "FloorDiv": return _generic.floordiv(self, other) - def __rfloordiv__(self, other: Union[PrimExpr, float, int]) -> "FloorDiv": + def __rfloordiv__(self, other: Union[PrimExpr, numeric]) -> "FloorDiv": return _generic.floordiv(other, self, None) - def __mod__(self, other: Union[PrimExpr, float, int]) -> "Mod": + def __mod__(self, other: Union[PrimExpr, numeric]) -> "Mod": return _ffi_api._OpFloorMod(self, other, None) # type: ignore - def __rmod__(self, other: Union[PrimExpr, float, int]) -> "Mod": + def __rmod__(self, other: Union[PrimExpr, numeric]) -> "Mod": return _ffi_api._OpFloorMod(other, self, None) # type: ignore def __neg__(self) -> "Mul": @@ -163,22 +166,22 @@ def __invert__(self) -> "Call": raise RuntimeError("Cannot use ~ operator on float type Expr.") return _ffi_api.bitwise_not(self, None) # type: ignore - def __lt__(self, other: PrimExpr) -> "LT": + def __lt__(self, other: Union[PrimExpr, numeric]) -> "LT": return _ffi_api._OpLT(self, other, None) # type: ignore - def __le__(self, other: PrimExpr) -> "LE": + def __le__(self, other: Union[PrimExpr, numeric]) -> "LE": return _ffi_api._OpLE(self, other, None) # type: ignore - def __eq__(self, other: PrimExpr) -> "EqualOp": + def __eq__(self, other: Union[PrimExpr, numeric]) -> "EqualOp": return EqualOp(self, other) - def __ne__(self, other: PrimExpr) -> "NotEqualOp": + def __ne__(self, other: Union[PrimExpr, numeric]) -> "NotEqualOp": return NotEqualOp(self, other) - def __gt__(self, other: PrimExpr) -> "GT": + def __gt__(self, other: Union[PrimExpr, numeric]) -> "GT": return _ffi_api._OpGT(self, other, None) # type: ignore - def __ge__(self, other: PrimExpr) -> "GE": + def __ge__(self, other: Union[PrimExpr, numeric]) -> "GE": return _ffi_api._OpGE(self, other, None) # type: ignore def __nonzero__(self): From 88c498143ae9506255bf17e74dd34b2196ea99bf Mon Sep 17 00:00:00 2001 From: Clouds Flowing Date: Tue, 6 Jan 2026 15:26:03 +0800 Subject: [PATCH 03/11] fix comment --- python/tvm/tir/expr.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index 451c7660ba4a..7c18ac326e67 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -155,10 +155,10 @@ def __or__(self, other: Union[PrimExpr, int, bool]) -> "Call": def __ror__(self, other: Union[PrimExpr, int, bool]) -> "Call": return _ffi_api.bitwise_or(other, self, None) # type: ignore - def __xor__(self, other: Union[PrimExpr, int]) -> "Call": + def __xor__(self, other: Union[PrimExpr, int, bool]) -> "Call": return _ffi_api.bitwise_xor(self, other, None) # type: ignore - def __rxor__(self, other: Union[PrimExpr, int]) -> "Call": + def __rxor__(self, other: Union[PrimExpr, int, bool]) -> "Call": return _ffi_api.bitwise_xor(other, self, None) # type: ignore def __invert__(self) -> "Call": From 632ae1eaf32bc9b6da85e5e5588a90b4904be024 Mon Sep 17 00:00:00 2001 From: Clouds Flowing Date: Tue, 6 Jan 2026 17:29:56 +0800 Subject: [PATCH 04/11] introduce PrimIntExpr, PrimFloatExpr, PrimLogicalExpr --- python/tvm/ir/__init__.py | 2 +- python/tvm/ir/expr.py | 7 ++- python/tvm/script/ir_builder/tir/ir.py | 10 ++-- python/tvm/tir/__init__.py | 2 +- python/tvm/tir/expr.py | 65 ++++++++++++-------------- 5 files changed, 44 insertions(+), 42 deletions(-) diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/__init__.py index b74e9954d9cf..3d3399139184 100644 --- a/python/tvm/ir/__init__.py +++ b/python/tvm/ir/__init__.py @@ -32,7 +32,7 @@ structural_hash, ) from .container import Array, Map -from .expr import BaseExpr, GlobalVar, PrimExpr, Range, RelaxExpr +from .expr import BaseExpr, GlobalVar, PrimExpr, PrimIntExpr, PrimFloatExpr, PrimLogicalExpr, Range, RelaxExpr from .function import BaseFunc, CallingConv from .global_info import GlobalInfo, DummyGlobalInfo, VDevice from .module import IRModule diff --git a/python/tvm/ir/expr.py b/python/tvm/ir/expr.py index 19abb6bd1eae..f6a73ff88919 100644 --- a/python/tvm/ir/expr.py +++ b/python/tvm/ir/expr.py @@ -16,7 +16,7 @@ # under the License. """Common expressions data structures in the IR.""" from numbers import Number -from typing import Optional +from typing import Optional, Union import tvm import tvm_ffi @@ -44,6 +44,11 @@ class PrimExpr(BaseExpr): dtype: str +PrimIntExpr = Union[PrimExpr, int] +PrimFloatExpr = Union[PrimExpr, float] +PrimLogicalExpr = Union[PrimExpr, int, bool] + + @tvm_ffi.register_object("ir.RelaxExpr") class RelaxExpr(BaseExpr): """Base class of all non-primitive expressions.""" diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index a08e66789fa3..fbde47db22aa 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -39,7 +39,7 @@ # pylint: disable=unused-import from tvm.target.codegen import llvm_lookup_intrinsic_id -from tvm.tir import Buffer, BufferRegion, IndexMap, PrimExpr +from tvm.tir import Buffer, BufferRegion, IndexMap, PrimExpr, PrimIntExpr from tvm.tir import op as _tir_op from tvm.tir import type_annotation @@ -271,11 +271,11 @@ def func_ret(ret_type: Type) -> Type: def match_buffer( param: Union[Var, BufferLoad, BufferRegion], - shape: Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral] = None, + shape: Union[List[PrimIntExpr], Tuple[PrimIntExpr, ...], PrimIntExpr] = None, dtype: str = "float32", data: Var = None, - strides: List[PrimExpr] = None, - elem_offset: PrimExpr = None, + strides: List[PrimIntExpr] = None, + elem_offset: PrimIntExpr = None, scope: str = "global", align: int = -1, offset_factor: int = 0, @@ -906,7 +906,7 @@ def thread_binding( ) -def grid(*extents: PrimExpr) -> frame.ForFrame: +def grid(*extents: PrimIntExpr) -> frame.ForFrame: """The grid For statement. Parameters diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 0a598e5e9bb9..b685c094f2d4 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -16,7 +16,7 @@ # under the License. # pylint: disable=unused-import, redefined-builtin """Namespace for Tensor-level IR""" -from tvm.ir import PrimExpr +from tvm.ir import PrimExpr, PrimIntExpr, PrimFloatExpr, PrimLogicalExpr from tvm.runtime import const from .buffer import Buffer, decl_buffer, DataProducer diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index 7c18ac326e67..be8a7909fb3d 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -32,7 +32,7 @@ import tvm_ffi import tvm.ir._ffi_api from tvm import ir -from tvm.ir import Op, PrimExpr +from tvm.ir import Op, PrimExpr, PrimIntExpr, PrimFloatExpr, PrimLogicalExpr from tvm.ir.base import Span from tvm.runtime import Object, ObjectConvertible, Scriptable, DataType, DataTypeCode, const @@ -41,9 +41,6 @@ from .buffer import Buffer, DataProducer -numeric = Union[int, float, complex] - - def convert(expr) -> PrimExpr: return _ffi_api.convert(expr) @@ -77,88 +74,88 @@ class ExprOp: # TODO(tkonolige): use inspect to add source information to these objects - def __add__(self, other: Union[PrimExpr, numeric]) -> "Add": + def __add__(self, other: PrimFloatExpr) -> "Add": return _generic.add(self, other) - def __radd__(self, other: Union[PrimExpr, numeric]) -> "Add": + def __radd__(self, other: PrimFloatExpr) -> "Add": return _generic.add(other, self) - def __sub__(self, other: Union[PrimExpr, numeric]) -> "Sub": + def __sub__(self, other: PrimFloatExpr) -> "Sub": return _generic.subtract(self, other) - def __rsub__(self, other: Union[PrimExpr, numeric]) -> "Sub": + def __rsub__(self, other: PrimFloatExpr) -> "Sub": return _generic.subtract(other, self) - def __mul__(self, other: Union[PrimExpr, numeric]) -> "Mul": + def __mul__(self, other: PrimFloatExpr) -> "Mul": return _generic.multiply(self, other) - def __rmul__(self, other: Union[PrimExpr, numeric]) -> "Mul": + def __rmul__(self, other: PrimFloatExpr) -> "Mul": return _generic.multiply(other, self) - def __div__(self, other: Union[PrimExpr, numeric]) -> "Div": + def __div__(self, other: PrimFloatExpr) -> "Div": if _dtype_is_int(self) and _dtype_is_int(other): raise div_ambiguity_error() return _generic.divide(self, other) - def __rdiv__(self, other: Union[PrimExpr, numeric]) -> "Div": + def __rdiv__(self, other: PrimFloatExpr) -> "Div": if _dtype_is_int(self) and _dtype_is_int(other): raise div_ambiguity_error() return _generic.divide(other, self) - def __truediv__(self, other: Union[PrimExpr, numeric]) -> "Div": + def __truediv__(self, other: PrimFloatExpr) -> "Div": if _dtype_is_int(self) and _dtype_is_int(other): raise div_ambiguity_error() return _generic.divide(self, other) - def __rtruediv__(self, other: Union[PrimExpr, numeric]) -> "Div": + def __rtruediv__(self, other: PrimFloatExpr) -> "Div": if _dtype_is_int(self) and _dtype_is_int(other): raise div_ambiguity_error() return _generic.divide(other, self) - def __floordiv__(self, other: Union[PrimExpr, numeric]) -> "FloorDiv": + def __floordiv__(self, other: PrimFloatExpr) -> "FloorDiv": return _generic.floordiv(self, other) - def __rfloordiv__(self, other: Union[PrimExpr, numeric]) -> "FloorDiv": + def __rfloordiv__(self, other: PrimFloatExpr) -> "FloorDiv": return _generic.floordiv(other, self, None) - def __mod__(self, other: Union[PrimExpr, numeric]) -> "Mod": + def __mod__(self, other: PrimFloatExpr) -> "Mod": return _ffi_api._OpFloorMod(self, other, None) # type: ignore - def __rmod__(self, other: Union[PrimExpr, numeric]) -> "Mod": + def __rmod__(self, other: PrimFloatExpr) -> "Mod": return _ffi_api._OpFloorMod(other, self, None) # type: ignore def __neg__(self) -> "Mul": neg_one = const(-1, self.dtype) # type: ignore return self.__mul__(neg_one) - def __lshift__(self, other: Union[PrimExpr, int]) -> "Call": + def __lshift__(self, other: PrimIntExpr) -> "Call": return _ffi_api.left_shift(self, other, None) # type: ignore - def __rlshift__(self, other: Union[PrimExpr, int]) -> "Call": + def __rlshift__(self, other: PrimIntExpr) -> "Call": return _ffi_api.left_shift(other, self, None) # type: ignore - def __rshift__(self, other: Union[PrimExpr, int]) -> "Call": + def __rshift__(self, other: PrimIntExpr) -> "Call": return _ffi_api.right_shift(self, other, None) # type: ignore - def __rrshift__(self, other: Union[PrimExpr, int]) -> "Call": + def __rrshift__(self, other: PrimIntExpr) -> "Call": return _ffi_api.right_shift(other, self, None) # type: ignore - def __and__(self, other: Union[PrimExpr, int, bool]) -> "Call": + def __and__(self, other: PrimLogicalExpr) -> "Call": return _ffi_api.bitwise_and(self, other, None) # type: ignore - def __rand__(self, other: Union[PrimExpr, int, bool]) -> "Call": + def __rand__(self, other: PrimLogicalExpr) -> "Call": return _ffi_api.bitwise_and(other, self, None) # type: ignore - def __or__(self, other: Union[PrimExpr, int, bool]) -> "Call": + def __or__(self, other: PrimLogicalExpr) -> "Call": return _ffi_api.bitwise_or(self, other, None) # type: ignore - def __ror__(self, other: Union[PrimExpr, int, bool]) -> "Call": + def __ror__(self, other: PrimLogicalExpr) -> "Call": return _ffi_api.bitwise_or(other, self, None) # type: ignore - def __xor__(self, other: Union[PrimExpr, int, bool]) -> "Call": + def __xor__(self, other: PrimLogicalExpr) -> "Call": return _ffi_api.bitwise_xor(self, other, None) # type: ignore - def __rxor__(self, other: Union[PrimExpr, int, bool]) -> "Call": + def __rxor__(self, other: PrimLogicalExpr) -> "Call": return _ffi_api.bitwise_xor(other, self, None) # type: ignore def __invert__(self) -> "Call": @@ -166,22 +163,22 @@ def __invert__(self) -> "Call": raise RuntimeError("Cannot use ~ operator on float type Expr.") return _ffi_api.bitwise_not(self, None) # type: ignore - def __lt__(self, other: Union[PrimExpr, numeric]) -> "LT": + def __lt__(self, other: PrimFloatExpr) -> "LT": return _ffi_api._OpLT(self, other, None) # type: ignore - def __le__(self, other: Union[PrimExpr, numeric]) -> "LE": + def __le__(self, other: PrimFloatExpr) -> "LE": return _ffi_api._OpLE(self, other, None) # type: ignore - def __eq__(self, other: Union[PrimExpr, numeric]) -> "EqualOp": + def __eq__(self, other: PrimFloatExpr) -> "EqualOp": return EqualOp(self, other) - def __ne__(self, other: Union[PrimExpr, numeric]) -> "NotEqualOp": + def __ne__(self, other: PrimFloatExpr) -> "NotEqualOp": return NotEqualOp(self, other) - def __gt__(self, other: Union[PrimExpr, numeric]) -> "GT": + def __gt__(self, other: PrimFloatExpr) -> "GT": return _ffi_api._OpGT(self, other, None) # type: ignore - def __ge__(self, other: Union[PrimExpr, numeric]) -> "GE": + def __ge__(self, other: PrimFloatExpr) -> "GE": return _ffi_api._OpGE(self, other, None) # type: ignore def __nonzero__(self): From c4760392cd67272ee6ca2c67ade0464c3bacc897 Mon Sep 17 00:00:00 2001 From: Clouds Flowing Date: Wed, 7 Jan 2026 14:57:29 +0800 Subject: [PATCH 05/11] apply PrimIntExpr to shape/range --- python/tvm/script/ir_builder/tir/ir.py | 198 ++++++++++++------------- 1 file changed, 99 insertions(+), 99 deletions(-) diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index fbde47db22aa..80d83e40d5a7 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -39,7 +39,7 @@ # pylint: disable=unused-import from tvm.target.codegen import llvm_lookup_intrinsic_id -from tvm.tir import Buffer, BufferRegion, IndexMap, PrimExpr, PrimIntExpr +from tvm.tir import Buffer, BufferRegion, IndexMap, PrimExpr, PrimIntExpr, PrimLogicalExpr from tvm.tir import op as _tir_op from tvm.tir import type_annotation @@ -119,34 +119,34 @@ def block_name_suffix_context(block_suffix: str): def buffer( - shape: Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral], + shape: Union[List[PrimIntExpr], Tuple[PrimIntExpr, ...], PrimIntExpr], dtype: str = "float32", - data: Var = None, - strides: List[PrimExpr] = None, - elem_offset: PrimExpr = None, + data: Optional[Var] = None, + strides: Optional[Union[List[PrimIntExpr], Tuple[PrimIntExpr, ...]]] = None, + elem_offset: Optional[PrimIntExpr] = None, scope: str = "global", align: int = 0, offset_factor: int = 0, buffer_type: str = "", - axis_separators: List[int] = None, + axis_separators: Optional[List[int]] = None, ) -> Buffer: """The buffer declaration function. Parameters ---------- - shape : Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral] + shape : Union[List[PrimIntExpr], Tuple[PrimIntExpr, ...], PrimIntExpr] The type of the buffer prior to flattening. dtype : str The data type in the content of the buffer. - data : Var + data : Optional[Var] The pointer to the head of the data. - strides : List[PrimExpr] + strides : Optional[Union[List[PrimIntExpr], Tuple[PrimIntExpr, ...]]] The strides of each dimension. - elem_offset : PrimExpr + elem_offset : Optional[PrimIntExpr] The offset in terms of number of dtype elements (including lanes). scope : str @@ -161,7 +161,7 @@ def buffer( buffer_type : str The buffer type. - axis_separators : List[int] + axis_separators : Optional[List[int]] The separators between input axes when generating flattened output axes. Returns @@ -271,16 +271,16 @@ def func_ret(ret_type: Type) -> Type: def match_buffer( param: Union[Var, BufferLoad, BufferRegion], - shape: Union[List[PrimIntExpr], Tuple[PrimIntExpr, ...], PrimIntExpr] = None, + shape: Optional[Union[List[PrimIntExpr], Tuple[PrimIntExpr, ...], PrimIntExpr]] = None, dtype: str = "float32", - data: Var = None, - strides: List[PrimIntExpr] = None, - elem_offset: PrimIntExpr = None, + data: Optional[Var] = None, + strides: Optional[Union[List[PrimIntExpr], Tuple[PrimIntExpr, ...]]] = None, + elem_offset: Optional[PrimIntExpr] = None, scope: str = "global", align: int = -1, offset_factor: int = 0, buffer_type: str = "default", - axis_separators: List[int] = None, + axis_separators: Optional[List[int]] = None, ) -> Buffer: """The buffer match function. @@ -305,19 +305,19 @@ def match_buffer( param : Union[Var, BufferLoad, BufferRegion] The parameter of the PrimFunc to match. - shape : Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral] + shape : Optional[Union[List[PrimIntExpr], Tuple[PrimIntExpr, ...], PrimIntExpr]] The type of the buffer prior to flattening. dtype : str The data type in the content of the buffer. - data : Var + data : Optional[Var] The pointer to the head of the data. - strides : List[PrimExpr] + strides : Optional[Union[List[PrimIntExpr], Tuple[PrimIntExpr, ...]]] The strides of each dimension. - elem_offset : PrimExpr + elem_offset : Optional[PrimIntExpr] The offset in terms of number of dtype elements (including lanes). scope : str @@ -332,7 +332,7 @@ def match_buffer( buffer_type : str The buffer type. - axis_separators : List[int] + axis_separators : Optional[List[int]] The separators between input axes when generating flattened output axes. Returns @@ -346,7 +346,7 @@ def match_buffer( shape = [region.extent for region in param.region] else: raise ValueError("Shape must be specified when binding input param") - shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape + shape = (shape,) if isinstance(shape, (PrimExpr, Integral, int)) else shape if strides is not None: idx_dtype = shape[0].dtype if isinstance(shape[0], PrimExpr) else "int32" strides = [Var(s, idx_dtype) if isinstance(s, str) else s for s in strides] @@ -400,12 +400,12 @@ def init() -> frame.BlockInitFrame: return _ffi_api.Init() # type: ignore[attr-defined] # pylint: disable=no-member -def where(predicate: Union[PrimExpr, int]) -> None: +def where(predicate: PrimLogicalExpr) -> None: """The block predicate statement. Parameters ---------- - predicate : Union[PrimExpr, Literal[0, 1]] + predicate : PrimLogicalExpr The predicate condition. """ if isinstance(predicate, bool): @@ -418,27 +418,27 @@ def where(predicate: Union[PrimExpr, int]) -> None: _ffi_api.Where(predicate) # type: ignore[attr-defined] # pylint: disable=no-member -def reads(*buffer_slices: List[Union[BufferRegion, BufferLoad]]) -> None: +def reads(*buffer_slices: Union[BufferRegion, BufferLoad]) -> None: """The block buffer region reading statement. Parameters ---------- - buffer_slices : List[Union[BufferRegion, BufferLoad]] + buffer_slices : Union[BufferRegion, BufferLoad] The array of buffer regions to read. """ if len(buffer_slices) == 1: if isinstance(buffer_slices[0], tuple): - buffer_slices = list(buffer_slices[0]) + buffer_slices = list(buffer_slices[0]) # type: ignore[assignment] elif isinstance(buffer_slices[0], list): buffer_slices = buffer_slices[0] # type: ignore[assignment] else: - buffer_slices = [buffer_slices[0]] + buffer_slices = [buffer_slices[0]] # type: ignore[assignment] else: buffer_slices = list(buffer_slices) # type: ignore[assignment] _ffi_api.Reads(buffer_slices) # type: ignore[attr-defined] # pylint: disable=no-member -def writes(*buffer_slices: List[Union[BufferRegion, BufferLoad]]) -> None: +def writes(*buffer_slices: Union[BufferRegion, BufferLoad]) -> None: """The block buffer region writing statement. Parameters @@ -448,11 +448,11 @@ def writes(*buffer_slices: List[Union[BufferRegion, BufferLoad]]) -> None: """ if len(buffer_slices) == 1: if isinstance(buffer_slices[0], tuple): - buffer_slices = list(buffer_slices[0]) + buffer_slices = list(buffer_slices[0]) # type: ignore[assignment] elif isinstance(buffer_slices[0], list): buffer_slices = buffer_slices[0] # type: ignore[assignment] else: - buffer_slices = [buffer_slices[0]] + buffer_slices = [buffer_slices[0]] # type: ignore[assignment] else: buffer_slices = list(buffer_slices) # type: ignore[assignment] _ffi_api.Writes(buffer_slices) # type: ignore[attr-defined] # pylint: disable=no-member @@ -470,34 +470,34 @@ def block_attr(attrs: Dict[str, Any]) -> None: def alloc_buffer( - shape: Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral], + shape: Union[List[PrimIntExpr], Tuple[PrimIntExpr], PrimIntExpr], dtype: str = "float32", - data: Var = None, - strides: List[PrimExpr] = None, - elem_offset: PrimExpr = None, + data: Optional[Var] = None, + strides: Optional[Union[List[PrimIntExpr], Tuple[PrimIntExpr, ...]]] = None, + elem_offset: Optional[PrimIntExpr] = None, scope: str = "global", align: int = -1, offset_factor: int = 0, buffer_type: str = "default", - axis_separators: List[int] = None, + axis_separators: Optional[List[int]] = None, ) -> Buffer: """The buffer alllocation function. Parameters ---------- - shape : Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral] + shape : Union[List[PrimIntExpr], Tuple[PrimIntExpr], PrimIntExpr] The type of the buffer prior to flattening. dtype : str The data type in the content of the buffer. - data : Var + data : Optional[Var] The pointer to the head of the data. - strides : List[PrimExpr] + strides : Optional[Union[List[PrimIntExpr], Tuple[PrimIntExpr, ...]]] The strides of each dimension. - elem_offset : PrimExpr + elem_offset : Optional[PrimIntExpr] The offset in terms of number of dtype elements (including lanes). scope : str @@ -512,7 +512,7 @@ def alloc_buffer( buffer_type : str The buffer type. - axis_separators : List[int] + axis_separators : Optional[List[int]] The separators between input axes when generating flattened output axes. Returns @@ -539,12 +539,12 @@ def alloc_buffer( ) -def _as_range(dom: Union[ir.Range, List[PrimExpr]]) -> ir.Range: +def _as_range(dom: Union[ir.Range, List[PrimIntExpr], Tuple[PrimIntExpr, PrimIntExpr]]) -> ir.Range: """The range constructor. Parameters ---------- - dom : Union[Range, List[PrimExpr]] + dom : Union[Range, List[PrimIntExpr], Tuple[PrimIntExpr, PrimIntExpr]] The domain. Returns @@ -571,7 +571,7 @@ class axis: # pylint: disable=invalid-name @staticmethod def spatial( - dom: Union[ir.Range, List[PrimExpr], Tuple[PrimExpr]], + dom: Union[ir.Range, List[PrimIntExpr], Tuple[PrimIntExpr, PrimIntExpr]], binding: PrimExpr, dtype: str = "int32", ) -> Var: @@ -579,7 +579,7 @@ def spatial( Parameters ---------- - dom : Union[Range, List[PrimExpr], Tuple[PrimExpr]] + dom : Union[Range, List[PrimIntExpr], Tuple[PrimIntExpr, PrimIntExpr]] The domain of the iteration variable. binding : PrimExpr @@ -599,7 +599,7 @@ def spatial( @staticmethod def reduce( - dom: Union[ir.Range, List[PrimExpr], Tuple[PrimExpr]], + dom: Union[ir.Range, List[PrimIntExpr], Tuple[PrimIntExpr, PrimIntExpr]], binding: PrimExpr, dtype: str = "int32", ) -> Var: @@ -607,7 +607,7 @@ def reduce( Parameters ---------- - dom : Union[Range, List[PrimExpr], Tuple[PrimExpr]] + dom : Union[Range, List[PrimIntExpr], Tuple[PrimIntExpr, PrimIntExpr]] The domain of the iteration variable. binding : PrimExpr @@ -627,7 +627,7 @@ def reduce( @staticmethod def scan( - dom: Union[ir.Range, List[PrimExpr], Tuple[PrimExpr]], + dom: Union[ir.Range, List[PrimIntExpr], Tuple[PrimIntExpr, PrimIntExpr]], binding: PrimExpr, dtype: str = "int32", ) -> Var: @@ -635,7 +635,7 @@ def scan( Parameters ---------- - dom : Union[Range, List[PrimExpr], Tuple[PrimExpr]] + dom : Union[Range, List[PrimIntExpr], Tuple[PrimIntExpr, PrimIntExpr]] The domain of the iteration variable. binding : PrimExpr @@ -655,7 +655,7 @@ def scan( @staticmethod def opaque( - dom: Union[ir.Range, List[PrimExpr], Tuple[PrimExpr]], + dom: Union[ir.Range, List[PrimIntExpr], Tuple[PrimIntExpr, PrimIntExpr]], binding: PrimExpr, dtype: str = "int32", ) -> Var: @@ -663,7 +663,7 @@ def opaque( Parameters ---------- - dom : Union[Range, List[PrimExpr], Tuple[PrimExpr]] + dom : Union[Range, List[PrimIntExpr], Tuple[PrimIntExpr, PrimIntExpr]] The domain of the iteration variable. binding : PrimExpr @@ -711,26 +711,26 @@ def remap(kinds: str, bindings: List[PrimExpr], dtype: str = "int32") -> Union[L def serial( - start: PrimExpr, - stop: PrimExpr = None, + start: PrimIntExpr, + stop: Optional[PrimIntExpr] = None, *, - annotations: Dict[str, Any] = None, - step: Optional[PrimExpr] = None, + annotations: Optional[Dict[str, Any]] = None, + step: Optional[PrimIntExpr] = None, ) -> frame.ForFrame: """The serial For statement. Parameters ---------- - start : PrimExpr + start : PrimIntExpr The minimum value of iteration. - stop : PrimExpr + stop : Optional[PrimIntExpr] The maximum value of iteration. - annotations : Dict[str, Any] + annotations : Optional[Dict[str, Any]] The optional annotations of the For statement. - step : PrimExpr + step : Optional[PrimIntExpr] The optional step value of iteration. Returns @@ -741,33 +741,33 @@ def serial( if stop is None: stop = start if hasattr(start, "dtype"): - start = IntImm(start.dtype, 0) + start = IntImm(start.dtype, 0) # type: ignore[attr-defined] # pylint: disable=no-member else: start = 0 return _ffi_api.Serial(start, stop, annotations, step) # type: ignore[attr-defined] # pylint: disable=no-member def parallel( - start: PrimExpr, - stop: PrimExpr = None, + start: PrimIntExpr, + stop: Optional[PrimIntExpr] = None, *, - annotations: Dict[str, Any] = None, - step: Optional[PrimExpr] = None, + annotations: Optional[Dict[str, Any]] = None, + step: Optional[PrimIntExpr] = None, ) -> frame.ForFrame: """The parallel For statement. Parameters ---------- - start : PrimExpr + start : PrimIntExpr The minimum value of iteration. - stop : PrimExpr + stop : Optional[PrimIntExpr] The maximum value of iteration. - annotations : Dict[str, Any] + annotations : Optional[Dict[str, Any]] The optional annotations of the For statement. - step : PrimExpr + step : Optional[PrimIntExpr] The optional step value of iteration. Returns @@ -778,33 +778,33 @@ def parallel( if stop is None: stop = start if hasattr(start, "dtype"): - start = IntImm(start.dtype, 0) + start = IntImm(start.dtype, 0) # type: ignore[attr-defined] # pylint: disable=no-member else: start = 0 return _ffi_api.Parallel(start, stop, annotations, step) # type: ignore[attr-defined] # pylint: disable=no-member def vectorized( - start: PrimExpr, - stop: PrimExpr = None, + start: PrimIntExpr, + stop: Optional[PrimIntExpr] = None, *, - annotations: Dict[str, Any] = None, - step: Optional[PrimExpr] = None, + annotations: Optional[Dict[str, Any]] = None, + step: Optional[PrimIntExpr] = None, ) -> frame.ForFrame: """The vectorized For statement. Parameters ---------- - start : PrimExpr + start : PrimIntExpr The minimum value of iteration. - stop : PrimExpr + stop : Optional[PrimIntExpr] The maximum value of iteration. - annotations : Dict[str, Any] + annotations : Optional[Dict[str, Any]] The optional annotations of the For statement. - step : PrimExpr + step : Optional[PrimIntExpr] The optional step value of iteration. Returns @@ -815,33 +815,33 @@ def vectorized( if stop is None: stop = start if hasattr(start, "dtype"): - start = IntImm(start.dtype, 0) + start = IntImm(start.dtype, 0) # type: ignore[attr-defined] # pylint: disable=no-member else: start = 0 return _ffi_api.Vectorized(start, stop, annotations, step) # type: ignore[attr-defined] # pylint: disable=no-member def unroll( - start: PrimExpr, - stop: PrimExpr = None, + start: PrimIntExpr, + stop: Optional[PrimIntExpr] = None, *, - annotations: Dict[str, Any] = None, - step: Optional[PrimExpr] = None, + annotations: Optional[Dict[str, Any]] = None, + step: Optional[PrimIntExpr] = None, ) -> frame.ForFrame: """The unrolled For statement. Parameters ---------- - start : PrimExpr + start : PrimIntExpr The minimum value of iteration. - stop : PrimExpr + stop : Optional[PrimIntExpr] The maximum value of iteration. - annotations : Dict[str, Any] + annotations : Optional[Dict[str, Any]] The optional annotations of the For statement. - step : PrimExpr + step : Optional[PrimIntExpr] The optional step value of iteration. Returns @@ -852,33 +852,33 @@ def unroll( if stop is None: stop = start if hasattr(start, "dtype"): - start = IntImm(start.dtype, 0) + start = IntImm(start.dtype, 0) # type: ignore[attr-defined] # pylint: disable=no-member else: start = 0 return _ffi_api.Unroll(start, stop, annotations, step) # type: ignore[attr-defined] # pylint: disable=no-member def thread_binding( - start: PrimExpr, - stop: PrimExpr = None, - thread: str = None, + start: PrimIntExpr, + stop: Optional[PrimIntExpr] = None, + thread: Optional[str] = None, *, - annotations: Dict[str, Any] = None, + annotations: Optional[Dict[str, Any]] = None, ) -> frame.ForFrame: """The thread-binding For statement. Parameters ---------- - start : PrimExpr + start : PrimIntExpr The minimum value of iteration. - stop : PrimExpr + stop : Optional[PrimIntExpr] The maximum value of iteration. - thread : str + thread : Optional[str] The thread for loop variable to bind. - annotations : Dict[str, Any] + annotations : Optional[Dict[str, Any]] The optional annotations of the For statement. Returns @@ -892,13 +892,13 @@ def thread_binding( thread = stop stop = start if hasattr(start, "dtype"): - start = IntImm(start.dtype, 0) + start = IntImm(start.dtype, 0) # type: ignore[attr-defined] else: start = 0 elif stop is None: stop = start if hasattr(start, "dtype"): - start = IntImm(start.dtype, 0) + start = IntImm(start.dtype, 0) # type: ignore[attr-defined] else: start = 0 return _ffi_api.ThreadBinding( # type: ignore[attr-defined] # pylint: disable=no-member @@ -922,12 +922,12 @@ def grid(*extents: PrimIntExpr) -> frame.ForFrame: return _ffi_api.Grid(extents) # type: ignore[attr-defined] # pylint: disable=no-member -def Assert(condition: PrimExpr, message: str) -> frame.AssertFrame: # pylint: disable=invalid-name +def Assert(condition: PrimLogicalExpr, message: str) -> frame.AssertFrame: # pylint: disable=invalid-name """Create an assertion statement. Parameters ---------- - condition : PrimExpr + condition : PrimLogicalExpr The PrimExpr to test. message : str From 6130381e754f655b41ea543a8372bce95158a413 Mon Sep 17 00:00:00 2001 From: Clouds Flowing Date: Wed, 7 Jan 2026 15:01:35 +0800 Subject: [PATCH 06/11] improve range --- python/tvm/ir/expr.py | 12 ++++---- python/tvm/script/ir_builder/tir/ir.py | 40 +++++++++++++------------- 2 files changed, 26 insertions(+), 26 deletions(-) diff --git a/python/tvm/ir/expr.py b/python/tvm/ir/expr.py index f6a73ff88919..56045822e239 100644 --- a/python/tvm/ir/expr.py +++ b/python/tvm/ir/expr.py @@ -120,11 +120,11 @@ class Range(Node, Scriptable): Parameters ---------- - begin : PrimExpr + begin : PrimIntExpr The begin value of the range when end is None. Otherwise it is the length of the range. - end : Optional[PrimExpr] + end : Optional[PrimIntExpr] The end value of the range. span : Optional[Span] @@ -141,13 +141,13 @@ class Range(Node, Scriptable): span: Optional[Span] def __init__( - self, begin: PrimExpr, end: Optional[PrimExpr] = None, span: Optional[Span] = None + self, begin: PrimIntExpr, end: Optional[PrimIntExpr] = None, span: Optional[Span] = None ) -> None: self.__init_handle_by_constructor__(_ffi_api.Range, begin, end, span) @staticmethod def from_min_extent( - min_value: PrimExpr, extent: PrimExpr, span: Optional[Span] = None + min_value: PrimIntExpr, extent: PrimIntExpr, span: Optional[Span] = None ) -> "Range": """Construct a Range by min and extent. @@ -155,10 +155,10 @@ def from_min_extent( Parameters ---------- - min_value : PrimExpr + min_value : PrimIntExpr The minimum value of the range. - extent : PrimExpr + extent : PrimIntExpr The extent of the range. span : Optional[Span] diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 80d83e40d5a7..962aa4e2149b 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -539,12 +539,12 @@ def alloc_buffer( ) -def _as_range(dom: Union[ir.Range, List[PrimIntExpr], Tuple[PrimIntExpr, PrimIntExpr]]) -> ir.Range: +def _as_range(dom: Union[ir.Range, List[PrimIntExpr], Tuple[PrimIntExpr, PrimIntExpr], PrimIntExpr]) -> ir.Range: """The range constructor. Parameters ---------- - dom : Union[Range, List[PrimIntExpr], Tuple[PrimIntExpr, PrimIntExpr]] + dom : Union[Range, List[PrimIntExpr], Tuple[PrimIntExpr, PrimIntExpr], PrimIntExpr] The domain. Returns @@ -571,18 +571,18 @@ class axis: # pylint: disable=invalid-name @staticmethod def spatial( - dom: Union[ir.Range, List[PrimIntExpr], Tuple[PrimIntExpr, PrimIntExpr]], - binding: PrimExpr, + dom: Union[ir.Range, List[PrimIntExpr], Tuple[PrimIntExpr, PrimIntExpr], PrimIntExpr], + binding: PrimIntExpr, dtype: str = "int32", ) -> Var: """The spatial block axis defining function. Parameters ---------- - dom : Union[Range, List[PrimIntExpr], Tuple[PrimIntExpr, PrimIntExpr]] + dom : Union[Range, List[PrimIntExpr], Tuple[PrimIntExpr, PrimIntExpr], PrimIntExpr] The domain of the iteration variable. - binding : PrimExpr + binding : PrimIntExpr The binding value of the iteration variable. dtype : str @@ -599,18 +599,18 @@ def spatial( @staticmethod def reduce( - dom: Union[ir.Range, List[PrimIntExpr], Tuple[PrimIntExpr, PrimIntExpr]], - binding: PrimExpr, + dom: Union[ir.Range, List[PrimIntExpr], Tuple[PrimIntExpr, PrimIntExpr], PrimIntExpr], + binding: PrimIntExpr, dtype: str = "int32", ) -> Var: """The reduced block axis defining function. Parameters ---------- - dom : Union[Range, List[PrimIntExpr], Tuple[PrimIntExpr, PrimIntExpr]] + dom : Union[Range, List[PrimIntExpr], Tuple[PrimIntExpr, PrimIntExpr], PrimIntExpr] The domain of the iteration variable. - binding : PrimExpr + binding : PrimIntExpr The binding value of the iteration variable. dtype : str @@ -627,18 +627,18 @@ def reduce( @staticmethod def scan( - dom: Union[ir.Range, List[PrimIntExpr], Tuple[PrimIntExpr, PrimIntExpr]], - binding: PrimExpr, + dom: Union[ir.Range, List[PrimIntExpr], Tuple[PrimIntExpr, PrimIntExpr], PrimIntExpr], + binding: PrimIntExpr, dtype: str = "int32", ) -> Var: """The scanning block axis defining function. Parameters ---------- - dom : Union[Range, List[PrimIntExpr], Tuple[PrimIntExpr, PrimIntExpr]] + dom : Union[Range, List[PrimIntExpr], Tuple[PrimIntExpr, PrimIntExpr], PrimIntExpr] The domain of the iteration variable. - binding : PrimExpr + binding : PrimIntExpr The binding value of the iteration variable. dtype : str @@ -655,18 +655,18 @@ def scan( @staticmethod def opaque( - dom: Union[ir.Range, List[PrimIntExpr], Tuple[PrimIntExpr, PrimIntExpr]], - binding: PrimExpr, + dom: Union[ir.Range, List[PrimIntExpr], Tuple[PrimIntExpr, PrimIntExpr], PrimIntExpr], + binding: PrimIntExpr, dtype: str = "int32", ) -> Var: """The opaque block axis defining function. Parameters ---------- - dom : Union[Range, List[PrimIntExpr], Tuple[PrimIntExpr, PrimIntExpr]] + dom : Union[Range, List[PrimIntExpr], Tuple[PrimIntExpr, PrimIntExpr], PrimIntExpr] The domain of the iteration variable. - binding : PrimExpr + binding : PrimIntExpr The binding value of the iteration variable. dtype : str @@ -682,7 +682,7 @@ def opaque( ) @staticmethod - def remap(kinds: str, bindings: List[PrimExpr], dtype: str = "int32") -> Union[List[Var], Var]: + def remap(kinds: str, bindings: List[PrimIntExpr], dtype: str = "int32") -> Union[List[Var], Var]: """The block axis remapping function. Parameters @@ -690,7 +690,7 @@ def remap(kinds: str, bindings: List[PrimExpr], dtype: str = "int32") -> Union[L kinds : str The types of the iteration variables. - bindings : List[PrimExpr] + bindings : List[PrimIntExpr] The binding values of the iteration variables. dtype : str From ae05abe5fbf54b636cfa5c8b03547aad7ff21945 Mon Sep 17 00:00:00 2001 From: Clouds Flowing Date: Wed, 7 Jan 2026 15:09:41 +0800 Subject: [PATCH 07/11] fix self not available in py3.9 --- python/tvm/script/ir_builder/tir/ir.py | 2 +- python/tvm/tir/expr.py | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 962aa4e2149b..302deaf0f1e4 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -562,7 +562,7 @@ def _as_range(dom: Union[ir.Range, List[PrimIntExpr], Tuple[PrimIntExpr, PrimInt return ir.Range.from_min_extent(dom[0], extent) return ir.Range(dom[0], dom[1]) if hasattr(dom, "dtype"): - return ir.Range(IntImm(dom.dtype, 0), dom) + return ir.Range(IntImm(dom.dtype, 0), dom) # type: ignore[attr-defined] # pylint: disable=no-member return ir.Range(0, dom) diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index be8a7909fb3d..77fa47b35017 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -27,7 +27,7 @@ assert(isinstance(y, tvm.tir.Add)) assert(y.a == x) """ -from typing import List, Optional, Self, Union +from typing import List, Optional, TypeVar, Union import tvm_ffi import tvm.ir._ffi_api @@ -69,6 +69,9 @@ def _dtype_is_float(value): ) # type: ignore +Self = TypeVar("Self", bound="ExprOp") + + class ExprOp: """Operator overloading for Expr like expressions.""" @@ -208,7 +211,7 @@ def equal(self, other: PrimExpr, span: Optional[Span] = None) -> bool: """ return _ffi_api._OpEQ(self, other, span) # type: ignore - def astype(self, dtype: str, span: Optional[Span] = None) -> Union["Cast", "Self"]: + def astype(self: Self, dtype: str, span: Optional[Span] = None) -> Union["Cast", Self]: """Cast the expression to other type. Parameters From e13213549ac088b57329eb6e6c646c10f9c1418b Mon Sep 17 00:00:00 2001 From: Clouds Flowing Date: Wed, 7 Jan 2026 16:22:21 +0800 Subject: [PATCH 08/11] fix overload for remap --- python/tvm/script/ir_builder/tir/ir.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 302deaf0f1e4..0bcfaef81e96 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -22,7 +22,7 @@ import sys import threading from numbers import Integral -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union, overload # isort: off from typing_extensions import Literal @@ -681,8 +681,15 @@ def opaque( _as_range(dom), binding, dtype ) + @overload @staticmethod - def remap(kinds: str, bindings: List[PrimIntExpr], dtype: str = "int32") -> Union[List[Var], Var]: + def remap(kinds: str, bindings: Union[PrimExpr, Tuple[PrimExpr]], dtype: str = "int32") -> Var: ... + @overload + @staticmethod + def remap(kinds: str, bindings: Union[List[PrimExpr], Tuple[()], Tuple[PrimExpr, PrimExpr, *tuple[PrimExpr, ...]]], dtype: str = "int32") -> List[Var]: ... + + @staticmethod + def remap(kinds: str, bindings: Union[List[PrimExpr], Tuple[PrimExpr, ...], PrimExpr], dtype: str = "int32") -> Union[List[Var], Var]: """The block axis remapping function. Parameters @@ -690,7 +697,7 @@ def remap(kinds: str, bindings: List[PrimIntExpr], dtype: str = "int32") -> Unio kinds : str The types of the iteration variables. - bindings : List[PrimIntExpr] + bindings : Union[List[PrimExpr], Tuple[PrimExpr, ...], PrimExpr] The binding values of the iteration variables. dtype : str @@ -698,9 +705,10 @@ def remap(kinds: str, bindings: List[PrimIntExpr], dtype: str = "int32") -> Unio Returns ------- - res : Var + res : Union[Var, List[Var]] The iteration variables. """ + bindings = (bindings,) if isinstance(bindings, PrimExpr) else bindings iter_vars = _ffi_api.AxisRemap( # type: ignore[attr-defined] # pylint: disable=no-member kinds, bindings, dtype ) From c16c85ba84c65892e8ffa858d88ab6166efeb852 Mon Sep 17 00:00:00 2001 From: Clouds Flowing Date: Wed, 7 Jan 2026 16:22:34 +0800 Subject: [PATCH 09/11] fix format --- python/tvm/ir/__init__.py | 11 ++++++++++- python/tvm/script/ir_builder/tir/ir.py | 27 +++++++++++++++++++++----- 2 files changed, 32 insertions(+), 6 deletions(-) diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/__init__.py index 3d3399139184..d9d8819813ed 100644 --- a/python/tvm/ir/__init__.py +++ b/python/tvm/ir/__init__.py @@ -32,7 +32,16 @@ structural_hash, ) from .container import Array, Map -from .expr import BaseExpr, GlobalVar, PrimExpr, PrimIntExpr, PrimFloatExpr, PrimLogicalExpr, Range, RelaxExpr +from .expr import ( + BaseExpr, + GlobalVar, + PrimExpr, + PrimIntExpr, + PrimFloatExpr, + PrimLogicalExpr, + Range, + RelaxExpr, +) from .function import BaseFunc, CallingConv from .global_info import GlobalInfo, DummyGlobalInfo, VDevice from .module import IRModule diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 0bcfaef81e96..92a0e1599d42 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -539,7 +539,9 @@ def alloc_buffer( ) -def _as_range(dom: Union[ir.Range, List[PrimIntExpr], Tuple[PrimIntExpr, PrimIntExpr], PrimIntExpr]) -> ir.Range: +def _as_range( + dom: Union[ir.Range, List[PrimIntExpr], Tuple[PrimIntExpr, PrimIntExpr], PrimIntExpr], +) -> ir.Range: """The range constructor. Parameters @@ -683,13 +685,26 @@ def opaque( @overload @staticmethod - def remap(kinds: str, bindings: Union[PrimExpr, Tuple[PrimExpr]], dtype: str = "int32") -> Var: ... + def remap( + kinds: str, bindings: Union[PrimExpr, Tuple[PrimExpr]], dtype: str = "int32" + ) -> Var: ... @overload @staticmethod - def remap(kinds: str, bindings: Union[List[PrimExpr], Tuple[()], Tuple[PrimExpr, PrimExpr, *tuple[PrimExpr, ...]]], dtype: str = "int32") -> List[Var]: ... + def remap( + kinds: str, + bindings: Union[Tuple[()], Tuple[PrimExpr, PrimExpr, *Tuple[PrimExpr, ...]]], + dtype: str = "int32", + ) -> List[Var]: ... + @overload + @staticmethod + def remap(kinds: str, bindings: List[PrimExpr], dtype: str = "int32") -> List[Var]: ... @staticmethod - def remap(kinds: str, bindings: Union[List[PrimExpr], Tuple[PrimExpr, ...], PrimExpr], dtype: str = "int32") -> Union[List[Var], Var]: + def remap( + kinds: str, + bindings: Union[List[PrimExpr], Tuple[PrimExpr, ...], PrimExpr], + dtype: str = "int32", + ) -> Union[List[Var], Var]: """The block axis remapping function. Parameters @@ -930,7 +945,9 @@ def grid(*extents: PrimIntExpr) -> frame.ForFrame: return _ffi_api.Grid(extents) # type: ignore[attr-defined] # pylint: disable=no-member -def Assert(condition: PrimLogicalExpr, message: str) -> frame.AssertFrame: # pylint: disable=invalid-name +def Assert( + condition: PrimLogicalExpr, message: str +) -> frame.AssertFrame: # pylint: disable=invalid-name """Create an assertion statement. Parameters From b6a5166716f331deb048dc12500c2db97a8deeb0 Mon Sep 17 00:00:00 2001 From: Clouds Flowing Date: Wed, 7 Jan 2026 18:07:16 +0800 Subject: [PATCH 10/11] fix overload error --- python/tvm/script/ir_builder/tir/ir.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 92a0e1599d42..9019e735bba7 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -692,7 +692,7 @@ def remap( @staticmethod def remap( kinds: str, - bindings: Union[Tuple[()], Tuple[PrimExpr, PrimExpr, *Tuple[PrimExpr, ...]]], + bindings: Tuple[PrimExpr, ...], dtype: str = "int32", ) -> List[Var]: ... @overload From c78d10411cd33044c029389e760fb63e57264287 Mon Sep 17 00:00:00 2001 From: Clouds Flowing Date: Wed, 7 Jan 2026 18:44:01 +0800 Subject: [PATCH 11/11] fix format --- python/tvm/script/ir_builder/tir/ir.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 9019e735bba7..7defbf1c3708 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -685,19 +685,18 @@ def opaque( @overload @staticmethod - def remap( - kinds: str, bindings: Union[PrimExpr, Tuple[PrimExpr]], dtype: str = "int32" - ) -> Var: ... + def remap(kinds: str, bindings: Union[PrimExpr, Tuple[PrimExpr]], dtype: str = "int32") -> Var: + ... + @overload @staticmethod - def remap( - kinds: str, - bindings: Tuple[PrimExpr, ...], - dtype: str = "int32", - ) -> List[Var]: ... + def remap(kinds: str, bindings: Tuple[PrimExpr, ...], dtype: str = "int32") -> List[Var]: + ... + @overload @staticmethod - def remap(kinds: str, bindings: List[PrimExpr], dtype: str = "int32") -> List[Var]: ... + def remap(kinds: str, bindings: List[PrimExpr], dtype: str = "int32") -> List[Var]: + ... @staticmethod def remap( @@ -945,9 +944,9 @@ def grid(*extents: PrimIntExpr) -> frame.ForFrame: return _ffi_api.Grid(extents) # type: ignore[attr-defined] # pylint: disable=no-member -def Assert( +def Assert( # pylint: disable=invalid-name condition: PrimLogicalExpr, message: str -) -> frame.AssertFrame: # pylint: disable=invalid-name +) -> frame.AssertFrame: """Create an assertion statement. Parameters