diff --git a/dltype/__init__.py b/dltype/__init__.py index 39fd0c3..c92225d 100644 --- a/dltype/__init__.py +++ b/dltype/__init__.py @@ -29,6 +29,8 @@ from dltype._lib._symbolic_expressions import ( AnonymousAxis, ConstantAxis, + Group, + ISqrt, LiteralAxis, Max, Min, @@ -133,7 +135,9 @@ "Float32Tensor", "Float64Tensor", "FloatTensor", + "Group", "IEEE754HalfFloatTensor", + "ISqrt", "Int8Tensor", "Int16Tensor", "Int32Tensor", diff --git a/dltype/_lib/_symbolic_expressions.py b/dltype/_lib/_symbolic_expressions.py index 0660eb4..82b4a0e 100644 --- a/dltype/_lib/_symbolic_expressions.py +++ b/dltype/_lib/_symbolic_expressions.py @@ -2,6 +2,7 @@ from __future__ import annotations +import math import typing from abc import ABC, abstractmethod from types import EllipsisType @@ -9,72 +10,6 @@ from typing_extensions import override -class AxisOperationBase(ABC): - def __init__( - self, - lhs: OperableAxis | ComputedAxis | int, - rhs: OperableAxis | ComputedAxis | int, - ) -> None: - self._lhs = lhs if isinstance(lhs, OperableAxis | ComputedAxis) else LiteralAxis(lhs) - self._rhs = rhs if isinstance(rhs, OperableAxis | ComputedAxis) else LiteralAxis(rhs) - - @abstractmethod - def __str__(self) -> str: - pass - - def __repr__(self) -> str: - return self.__str__() - - -class Add(AxisOperationBase): - def __str__(self) -> str: - if isinstance(self._lhs, LiteralAxis) and isinstance(self._rhs, LiteralAxis): - return f"{self._lhs.value + self._rhs.value}" - return f"({self._lhs}+{self._rhs})" - - -class Subtract(AxisOperationBase): - def __str__(self) -> str: - if isinstance(self._lhs, LiteralAxis) and isinstance(self._rhs, LiteralAxis): - return f"{self._lhs.value - self._rhs.value}" - return f"({self._lhs}-{self._rhs})" - - -class Divide(AxisOperationBase): - def __str__(self) -> str: - if isinstance(self._lhs, LiteralAxis) and isinstance(self._rhs, LiteralAxis): - return f"{self._lhs.value // self._rhs.value}" - return f"({self._lhs}/{self._rhs})" - - -class Multiply(AxisOperationBase): - def __str__(self) -> str: - if isinstance(self._lhs, LiteralAxis) and isinstance(self._rhs, LiteralAxis): - return f"{self._lhs.value * self._rhs.value}" - return f"({self._lhs}*{self._rhs})" - - -class Exp(AxisOperationBase): - def __str__(self) -> str: - if isinstance(self._lhs, LiteralAxis) and isinstance(self._rhs, LiteralAxis): - return f"{self._lhs.value**self._rhs.value}" - return f"({self._lhs}^{self._rhs})" - - -class Max(AxisOperationBase): - def __str__(self) -> str: - if isinstance(self._lhs, LiteralAxis) and isinstance(self._rhs, LiteralAxis): - return f"{max(self._lhs.value, self._rhs.value)}" - return f"max({self._lhs},{self._rhs})" - - -class Min(AxisOperationBase): - def __str__(self) -> str: - if isinstance(self._lhs, LiteralAxis) and isinstance(self._rhs, LiteralAxis): - return f"{min(self._lhs.value, self._rhs.value)}" - return f"min({self._lhs},{self._rhs})" - - class OperableAxis(ABC): @abstractmethod def __str__(self) -> str: ... @@ -124,6 +59,94 @@ def __rpow__(self, other: OperableAxisT) -> ComputedAxis: return ComputedAxis(Exp(*self.__resolve_expr_sides(other, reverse=True))) +class AxisOperationBase(OperableAxis, ABC): + @abstractmethod + def __init__(self) -> None: + pass + + +class UnaryAxisOperationBase(AxisOperationBase): + def __init__(self, axis: Group | OperableAxis | ComputedAxis | int) -> None: + self._axis = axis if isinstance(axis, OperableAxis | ComputedAxis | Group) else LiteralAxis(axis) + + +class ISqrt(UnaryAxisOperationBase): + def __str__(self) -> str: + if isinstance(self._axis, LiteralAxis): + return f"{math.isqrt(self._axis.value)}" + return f"isqrt({self._axis})" + + +class Group(UnaryAxisOperationBase): + def __init__( + self, + grouped_op: OperableAxis | ComputedAxis | int, + ) -> None: + self._operators = grouped_op + + def __str__(self) -> str: + return f"({self._operators})" + + +class BinaryAxisOperationBase(AxisOperationBase): + def __init__( + self, + lhs: Group | OperableAxis | ComputedAxis | int, + rhs: Group | OperableAxis | ComputedAxis | int, + ) -> None: + self._lhs = lhs if isinstance(lhs, OperableAxis | ComputedAxis | Group) else LiteralAxis(lhs) + self._rhs = rhs if isinstance(rhs, OperableAxis | ComputedAxis | Group) else LiteralAxis(rhs) + + +class Add(BinaryAxisOperationBase): + def __str__(self) -> str: + if isinstance(self._lhs, LiteralAxis) and isinstance(self._rhs, LiteralAxis): + return f"{self._lhs.value + self._rhs.value}" + return f"{self._lhs}+{self._rhs}" + + +class Subtract(BinaryAxisOperationBase): + def __str__(self) -> str: + if isinstance(self._lhs, LiteralAxis) and isinstance(self._rhs, LiteralAxis): + return f"{self._lhs.value - self._rhs.value}" + return f"{self._lhs}-{self._rhs}" + + +class Divide(BinaryAxisOperationBase): + def __str__(self) -> str: + if isinstance(self._lhs, LiteralAxis) and isinstance(self._rhs, LiteralAxis): + return f"{self._lhs.value // self._rhs.value}" + return f"{self._lhs}/{self._rhs}" + + +class Multiply(BinaryAxisOperationBase): + def __str__(self) -> str: + if isinstance(self._lhs, LiteralAxis) and isinstance(self._rhs, LiteralAxis): + return f"{self._lhs.value * self._rhs.value}" + return f"{self._lhs}*{self._rhs}" + + +class Exp(BinaryAxisOperationBase): + def __str__(self) -> str: + if isinstance(self._lhs, LiteralAxis) and isinstance(self._rhs, LiteralAxis): + return f"{self._lhs.value**self._rhs.value}" + return f"{self._lhs}^{self._rhs}" + + +class Max(BinaryAxisOperationBase): + def __str__(self) -> str: + if isinstance(self._lhs, LiteralAxis) and isinstance(self._rhs, LiteralAxis): + return f"{max(self._lhs.value, self._rhs.value)}" + return f"max({self._lhs},{self._rhs})" + + +class Min(BinaryAxisOperationBase): + def __str__(self) -> str: + if isinstance(self._lhs, LiteralAxis) and isinstance(self._rhs, LiteralAxis): + return f"{min(self._lhs.value, self._rhs.value)}" + return f"min({self._lhs},{self._rhs})" + + class LiteralAxis(OperableAxis): def __init__(self, value: int) -> None: """Initialize an axis with a literal integer value.""" @@ -171,7 +194,7 @@ def __str__(self) -> str: return f"{self._identifier}={self._computation}" -OperableAxisT: typing.TypeAlias = LiteralAxis | VariableAxis | ComputedAxis | int +OperableAxisT: typing.TypeAlias = Group | LiteralAxis | VariableAxis | ComputedAxis | int class ConstantAxis: diff --git a/dltype/tests/parser_test.py b/dltype/tests/parser_test.py index 256e6da..765d977 100644 --- a/dltype/tests/parser_test.py +++ b/dltype/tests/parser_test.py @@ -3,7 +3,7 @@ import pytest -from dltype import AnonymousAxis, ConstantAxis, LiteralAxis, Max, Min, Shape, VariableAxis +from dltype import AnonymousAxis, ConstantAxis, Group, ISqrt, LiteralAxis, Max, Min, Shape, VariableAxis from dltype._lib import _parser @@ -99,21 +99,41 @@ def test_parse_invalid_expression(expression: str, scope: dict[str, int]) -> Non "*batch c h w", ), (Shape[LiteralAxis(4), VariableAxis("r")], "4 r"), - (Shape[Min(4 + VariableAxis("image_w"), VariableAxis("imageh"))], "min((4+image_w),imageh)"), - (Shape[VariableAxis("a") - VariableAxis("b")], "(a-b)"), - (Shape[VariableAxis("a") + VariableAxis("b")], "(a+b)"), - (Shape[VariableAxis("a") * VariableAxis("b")], "(a*b)"), - (Shape[VariableAxis("a") ** VariableAxis("b")], "(a^b)"), - (Shape[VariableAxis("a") // VariableAxis("b")], "(a/b)"), + (Shape[Min(4 + VariableAxis("image_w"), VariableAxis("imageh"))], "min(4+image_w,imageh)"), + (Shape[4 // Max(4 + VariableAxis("image_w"), VariableAxis("imageh"))], "4/max(4+image_w,imageh)"), + (Shape[VariableAxis("a") - VariableAxis("b")], "a-b"), + (Shape[VariableAxis("a") + VariableAxis("b")], "a+b"), + (Shape[VariableAxis("a") * VariableAxis("b")], "a*b"), + (Shape[VariableAxis("a") ** VariableAxis("b")], "a^b"), + (Shape[VariableAxis("a") // VariableAxis("b")], "a/b"), (Shape[LiteralAxis(99) - LiteralAxis(97)], "2"), (Shape[LiteralAxis(12) + LiteralAxis(1)], "13"), (Shape[LiteralAxis(10) * LiteralAxis(2)], "20"), (Shape[LiteralAxis(3) ** LiteralAxis(3)], "27"), (Shape[LiteralAxis(10) // LiteralAxis(2)], "5"), - (Shape[10 - VariableAxis("b")], "(10-b)"), - (Shape[10 // VariableAxis("b")], "(10/b)"), - (Shape[10 * VariableAxis("b")], "(10*b)"), - (Shape[10 ** VariableAxis("b")], "(10^b)"), + (Shape[10 - VariableAxis("b")], "10-b"), + (Shape[10 // VariableAxis("b")], "10/b"), + (Shape[10 * VariableAxis("b")], "10*b"), + (Shape[10 ** VariableAxis("b")], "10^b"), + (Shape[ISqrt(16)], "4"), + (Shape[Group(VariableAxis("a") - VariableAxis("b"))], "(a-b)"), + ( + Shape[ISqrt(VariableAxis("a") - VariableAxis("b"))], + "isqrt(a-b)", + ), + ( + Shape[ + Group(VariableAxis("a") - VariableAxis("b")) // ISqrt(VariableAxis("b") - VariableAxis("c")) + ], + "(a-b)/isqrt(b-c)", + ), + ( + Shape[ + Group(Group(VariableAxis("a") - VariableAxis("b") ** Group(4 - VariableAxis("z")))) + // ISqrt(VariableAxis("b") - VariableAxis("c")) + ], + "((a-b^(4-z)))/isqrt(b-c)", + ), ], ) def test_parse_symbolic(expression: Shape, expected: str) -> None: diff --git a/pyproject.toml b/pyproject.toml index 973a740..5e2f2ba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ license-files = ["LICENSE"] name = "dltype" readme = "README.md" requires-python = ">=3.10" -version = "0.10.0" +version = "0.11.0" [project.optional-dependencies] jax = ["jax>=0.6.2"] diff --git a/uv.lock b/uv.lock index 049da9b..90406c2 100644 --- a/uv.lock +++ b/uv.lock @@ -158,7 +158,7 @@ wheels = [ [[package]] name = "dltype" -version = "0.10.0" +version = "0.11.0" source = { virtual = "." } dependencies = [ { name = "pydantic" }, @@ -1436,6 +1436,10 @@ dependencies = [ { name = "typing-extensions" }, ] wheels = [ + { url = "https://files.pythonhosted.org/packages/e3/ea/304cf7afb744aa626fa9855245526484ee55aba610d9973a0521c552a843/torch-2.10.0-1-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:c37fc46eedd9175f9c81814cc47308f1b42cfe4987e532d4b423d23852f2bf63", size = 79411450, upload-time = "2026-02-06T17:37:35.75Z" }, + { url = "https://files.pythonhosted.org/packages/25/d8/9e6b8e7df981a1e3ea3907fd5a74673e791da483e8c307f0b6ff012626d0/torch-2.10.0-1-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:f699f31a236a677b3118bc0a3ef3d89c0c29b5ec0b20f4c4bf0b110378487464", size = 79423460, upload-time = "2026-02-06T17:37:39.657Z" }, + { url = "https://files.pythonhosted.org/packages/c9/2f/0b295dd8d199ef71e6f176f576473d645d41357b7b8aa978cc6b042575df/torch-2.10.0-1-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:6abb224c2b6e9e27b592a1c0015c33a504b00a0e0938f1499f7f514e9b7bfb5c", size = 79498197, upload-time = "2026-02-06T17:37:27.627Z" }, + { url = "https://files.pythonhosted.org/packages/a4/1b/af5fccb50c341bd69dc016769503cb0857c1423fbe9343410dfeb65240f2/torch-2.10.0-1-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:7350f6652dfd761f11f9ecb590bfe95b573e2961f7a242eccb3c8e78348d26fe", size = 79498248, upload-time = "2026-02-06T17:37:31.982Z" }, { url = "https://files.pythonhosted.org/packages/0c/1a/c61f36cfd446170ec27b3a4984f072fd06dab6b5d7ce27e11adb35d6c838/torch-2.10.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:5276fa790a666ee8becaffff8acb711922252521b28fbce5db7db5cf9cb2026d", size = 145992962, upload-time = "2026-01-21T16:24:14.04Z" }, { url = "https://files.pythonhosted.org/packages/b5/60/6662535354191e2d1555296045b63e4279e5a9dbad49acf55a5d38655a39/torch-2.10.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:aaf663927bcd490ae971469a624c322202a2a1e68936eb952535ca4cd3b90444", size = 915599237, upload-time = "2026-01-21T16:23:25.497Z" }, { url = "https://files.pythonhosted.org/packages/40/b8/66bbe96f0d79be2b5c697b2e0b187ed792a15c6c4b8904613454651db848/torch-2.10.0-cp310-cp310-win_amd64.whl", hash = "sha256:a4be6a2a190b32ff5c8002a0977a25ea60e64f7ba46b1be37093c141d9c49aeb", size = 113720931, upload-time = "2026-01-21T16:24:23.743Z" },