Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions dltype/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
from dltype._lib._symbolic_expressions import (
AnonymousAxis,
ConstantAxis,
Group,
ISqrt,
LiteralAxis,
Max,
Min,
Expand Down Expand Up @@ -133,7 +135,9 @@
"Float32Tensor",
"Float64Tensor",
"FloatTensor",
"Group",
"IEEE754HalfFloatTensor",
"ISqrt",
"Int8Tensor",
"Int16Tensor",
"Int32Tensor",
Expand Down
157 changes: 90 additions & 67 deletions dltype/_lib/_symbolic_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,79 +2,14 @@

from __future__ import annotations

import math
import typing
from abc import ABC, abstractmethod
from types import EllipsisType

from typing_extensions import override


class AxisOperationBase(ABC):
def __init__(
self,
lhs: OperableAxis | ComputedAxis | int,
rhs: OperableAxis | ComputedAxis | int,
) -> None:
self._lhs = lhs if isinstance(lhs, OperableAxis | ComputedAxis) else LiteralAxis(lhs)
self._rhs = rhs if isinstance(rhs, OperableAxis | ComputedAxis) else LiteralAxis(rhs)

@abstractmethod
def __str__(self) -> str:
pass

def __repr__(self) -> str:
return self.__str__()


class Add(AxisOperationBase):
def __str__(self) -> str:
if isinstance(self._lhs, LiteralAxis) and isinstance(self._rhs, LiteralAxis):
return f"{self._lhs.value + self._rhs.value}"
return f"({self._lhs}+{self._rhs})"


class Subtract(AxisOperationBase):
def __str__(self) -> str:
if isinstance(self._lhs, LiteralAxis) and isinstance(self._rhs, LiteralAxis):
return f"{self._lhs.value - self._rhs.value}"
return f"({self._lhs}-{self._rhs})"


class Divide(AxisOperationBase):
def __str__(self) -> str:
if isinstance(self._lhs, LiteralAxis) and isinstance(self._rhs, LiteralAxis):
return f"{self._lhs.value // self._rhs.value}"
return f"({self._lhs}/{self._rhs})"


class Multiply(AxisOperationBase):
def __str__(self) -> str:
if isinstance(self._lhs, LiteralAxis) and isinstance(self._rhs, LiteralAxis):
return f"{self._lhs.value * self._rhs.value}"
return f"({self._lhs}*{self._rhs})"


class Exp(AxisOperationBase):
def __str__(self) -> str:
if isinstance(self._lhs, LiteralAxis) and isinstance(self._rhs, LiteralAxis):
return f"{self._lhs.value**self._rhs.value}"
return f"({self._lhs}^{self._rhs})"


class Max(AxisOperationBase):
def __str__(self) -> str:
if isinstance(self._lhs, LiteralAxis) and isinstance(self._rhs, LiteralAxis):
return f"{max(self._lhs.value, self._rhs.value)}"
return f"max({self._lhs},{self._rhs})"


class Min(AxisOperationBase):
def __str__(self) -> str:
if isinstance(self._lhs, LiteralAxis) and isinstance(self._rhs, LiteralAxis):
return f"{min(self._lhs.value, self._rhs.value)}"
return f"min({self._lhs},{self._rhs})"


class OperableAxis(ABC):
@abstractmethod
def __str__(self) -> str: ...
Expand Down Expand Up @@ -124,6 +59,94 @@ def __rpow__(self, other: OperableAxisT) -> ComputedAxis:
return ComputedAxis(Exp(*self.__resolve_expr_sides(other, reverse=True)))


class AxisOperationBase(OperableAxis, ABC):
@abstractmethod
def __init__(self) -> None:
pass


class UnaryAxisOperationBase(AxisOperationBase):
def __init__(self, axis: Group | OperableAxis | ComputedAxis | int) -> None:
self._axis = axis if isinstance(axis, OperableAxis | ComputedAxis | Group) else LiteralAxis(axis)


class ISqrt(UnaryAxisOperationBase):
def __str__(self) -> str:
if isinstance(self._axis, LiteralAxis):
return f"{math.isqrt(self._axis.value)}"
return f"isqrt({self._axis})"


class Group(UnaryAxisOperationBase):
def __init__(
self,
grouped_op: OperableAxis | ComputedAxis | int,
) -> None:
self._operators = grouped_op

def __str__(self) -> str:
return f"({self._operators})"


class BinaryAxisOperationBase(AxisOperationBase):
def __init__(
self,
lhs: Group | OperableAxis | ComputedAxis | int,
rhs: Group | OperableAxis | ComputedAxis | int,
) -> None:
self._lhs = lhs if isinstance(lhs, OperableAxis | ComputedAxis | Group) else LiteralAxis(lhs)
self._rhs = rhs if isinstance(rhs, OperableAxis | ComputedAxis | Group) else LiteralAxis(rhs)


class Add(BinaryAxisOperationBase):
def __str__(self) -> str:
if isinstance(self._lhs, LiteralAxis) and isinstance(self._rhs, LiteralAxis):
return f"{self._lhs.value + self._rhs.value}"
return f"{self._lhs}+{self._rhs}"


class Subtract(BinaryAxisOperationBase):
def __str__(self) -> str:
if isinstance(self._lhs, LiteralAxis) and isinstance(self._rhs, LiteralAxis):
return f"{self._lhs.value - self._rhs.value}"
return f"{self._lhs}-{self._rhs}"


class Divide(BinaryAxisOperationBase):
def __str__(self) -> str:
if isinstance(self._lhs, LiteralAxis) and isinstance(self._rhs, LiteralAxis):
return f"{self._lhs.value // self._rhs.value}"
return f"{self._lhs}/{self._rhs}"


class Multiply(BinaryAxisOperationBase):
def __str__(self) -> str:
if isinstance(self._lhs, LiteralAxis) and isinstance(self._rhs, LiteralAxis):
return f"{self._lhs.value * self._rhs.value}"
return f"{self._lhs}*{self._rhs}"


class Exp(BinaryAxisOperationBase):
def __str__(self) -> str:
if isinstance(self._lhs, LiteralAxis) and isinstance(self._rhs, LiteralAxis):
return f"{self._lhs.value**self._rhs.value}"
return f"{self._lhs}^{self._rhs}"


class Max(BinaryAxisOperationBase):
def __str__(self) -> str:
if isinstance(self._lhs, LiteralAxis) and isinstance(self._rhs, LiteralAxis):
return f"{max(self._lhs.value, self._rhs.value)}"
return f"max({self._lhs},{self._rhs})"


class Min(BinaryAxisOperationBase):
def __str__(self) -> str:
if isinstance(self._lhs, LiteralAxis) and isinstance(self._rhs, LiteralAxis):
return f"{min(self._lhs.value, self._rhs.value)}"
return f"min({self._lhs},{self._rhs})"


class LiteralAxis(OperableAxis):
def __init__(self, value: int) -> None:
"""Initialize an axis with a literal integer value."""
Expand Down Expand Up @@ -171,7 +194,7 @@ def __str__(self) -> str:
return f"{self._identifier}={self._computation}"


OperableAxisT: typing.TypeAlias = LiteralAxis | VariableAxis | ComputedAxis | int
OperableAxisT: typing.TypeAlias = Group | LiteralAxis | VariableAxis | ComputedAxis | int


class ConstantAxis:
Expand Down
42 changes: 31 additions & 11 deletions dltype/tests/parser_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import pytest

from dltype import AnonymousAxis, ConstantAxis, LiteralAxis, Max, Min, Shape, VariableAxis
from dltype import AnonymousAxis, ConstantAxis, Group, ISqrt, LiteralAxis, Max, Min, Shape, VariableAxis
from dltype._lib import _parser


Expand Down Expand Up @@ -99,21 +99,41 @@ def test_parse_invalid_expression(expression: str, scope: dict[str, int]) -> Non
"*batch c h w",
),
(Shape[LiteralAxis(4), VariableAxis("r")], "4 r"),
(Shape[Min(4 + VariableAxis("image_w"), VariableAxis("imageh"))], "min((4+image_w),imageh)"),
(Shape[VariableAxis("a") - VariableAxis("b")], "(a-b)"),
(Shape[VariableAxis("a") + VariableAxis("b")], "(a+b)"),
(Shape[VariableAxis("a") * VariableAxis("b")], "(a*b)"),
(Shape[VariableAxis("a") ** VariableAxis("b")], "(a^b)"),
(Shape[VariableAxis("a") // VariableAxis("b")], "(a/b)"),
(Shape[Min(4 + VariableAxis("image_w"), VariableAxis("imageh"))], "min(4+image_w,imageh)"),
(Shape[4 // Max(4 + VariableAxis("image_w"), VariableAxis("imageh"))], "4/max(4+image_w,imageh)"),
(Shape[VariableAxis("a") - VariableAxis("b")], "a-b"),
(Shape[VariableAxis("a") + VariableAxis("b")], "a+b"),
(Shape[VariableAxis("a") * VariableAxis("b")], "a*b"),
(Shape[VariableAxis("a") ** VariableAxis("b")], "a^b"),
(Shape[VariableAxis("a") // VariableAxis("b")], "a/b"),
(Shape[LiteralAxis(99) - LiteralAxis(97)], "2"),
(Shape[LiteralAxis(12) + LiteralAxis(1)], "13"),
(Shape[LiteralAxis(10) * LiteralAxis(2)], "20"),
(Shape[LiteralAxis(3) ** LiteralAxis(3)], "27"),
(Shape[LiteralAxis(10) // LiteralAxis(2)], "5"),
(Shape[10 - VariableAxis("b")], "(10-b)"),
(Shape[10 // VariableAxis("b")], "(10/b)"),
(Shape[10 * VariableAxis("b")], "(10*b)"),
(Shape[10 ** VariableAxis("b")], "(10^b)"),
(Shape[10 - VariableAxis("b")], "10-b"),
(Shape[10 // VariableAxis("b")], "10/b"),
(Shape[10 * VariableAxis("b")], "10*b"),
(Shape[10 ** VariableAxis("b")], "10^b"),
(Shape[ISqrt(16)], "4"),
(Shape[Group(VariableAxis("a") - VariableAxis("b"))], "(a-b)"),
(
Shape[ISqrt(VariableAxis("a") - VariableAxis("b"))],
"isqrt(a-b)",
),
(
Shape[
Group(VariableAxis("a") - VariableAxis("b")) // ISqrt(VariableAxis("b") - VariableAxis("c"))
],
"(a-b)/isqrt(b-c)",
),
(
Shape[
Group(Group(VariableAxis("a") - VariableAxis("b") ** Group(4 - VariableAxis("z"))))
// ISqrt(VariableAxis("b") - VariableAxis("c"))
],
"((a-b^(4-z)))/isqrt(b-c)",
),
],
)
def test_parse_symbolic(expression: Shape, expected: str) -> None:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ license-files = ["LICENSE"]
name = "dltype"
readme = "README.md"
requires-python = ">=3.10"
version = "0.10.0"
version = "0.11.0"

[project.optional-dependencies]
jax = ["jax>=0.6.2"]
Expand Down
6 changes: 5 additions & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.