Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 23 additions & 7 deletions dltype/_lib/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import itertools
import warnings
from functools import lru_cache, wraps
from types import EllipsisType
from typing import (
TYPE_CHECKING,
Annotated,
Expand Down Expand Up @@ -51,16 +52,22 @@ class DLTypeAnnotation(NamedTuple):
dltype_annotation: _tensor_type_base.TensorTypeBase | None

@classmethod
def from_hint(
def from_hint( # noqa: PLR0911
cls,
hint: type | None,
hint: type | EllipsisType | None,
name: str,
*,
optional: bool = False,
stack_offset: int = 0,
) -> tuple[DLTypeAnnotation | None, ...]:
"""Create a new _DLTypeAnnotation from a type hint."""
if isinstance(hint, EllipsisType):
return (None,)

if hint is None:
warnings.warn(f"[{name}] is missing a DLType hint", category=UserWarning, stacklevel=4)
warnings.warn(
f"[{name}] is missing a DLType hint", category=UserWarning, stacklevel=4 + stack_offset
)
return (None,)

_logger.debug("Creating DLType from hint %r", hint)
Expand All @@ -83,20 +90,28 @@ def from_hint(

# tuple handling special case
if origin is tuple:
return tuple(itertools.chain(*[cls.from_hint(inner_hint, name) for inner_hint in args]))
return tuple(
itertools.chain(
*[cls.from_hint(inner_hint, name, stack_offset=stack_offset + 1) for inner_hint in args]
)
)

# Only process Annotated types, warn if the annotated type is a tensor
if origin is not Annotated:
if any(T in hint.mro() for T in _dtypes.SUPPORTED_TENSOR_TYPES) if hint else False:
warnings.warn(f"[{name}] is missing a DLType hint", category=UserWarning, stacklevel=4)
warnings.warn(
f"[{name}] is missing a DLType hint", category=UserWarning, stacklevel=4 + stack_offset
)
return (None,)

# Ensure the annotation is a TensorTypeBase
if len(args) < n_expected_args or not isinstance(
args[1],
_tensor_type_base.TensorTypeBase,
):
warnings.warn(f"[{name}] has an invalid DLType hint", category=UserWarning, stacklevel=4)
warnings.warn(
f"[{name}] has an invalid DLType hint", category=UserWarning, stacklevel=4 + stack_offset
)
return (None,)

# Ensure the base type is a supported tensor type
Expand Down Expand Up @@ -130,13 +145,14 @@ def get_dltype_scope(self) -> _dltype_context.EvaluatedDimensionT:
def _maybe_get_type_hints(
existing_hints: dict[str, tuple[DLTypeAnnotation | None, ...]] | None,
func: Callable[P, R],
stack_offset: int = 0,
) -> dict[str, tuple[DLTypeAnnotation | None, ...]] | None:
"""Get the type hints for a function, or return an empty dict if not available."""
if existing_hints is not None:
return existing_hints
try:
return {
name: DLTypeAnnotation.from_hint(hint, name)
name: DLTypeAnnotation.from_hint(hint, name, stack_offset=stack_offset)
for name, hint in get_type_hints(func, include_extras=True).items()
}
except NameError:
Expand Down
12 changes: 12 additions & 0 deletions dltype/tests/dltype_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1694,3 +1694,15 @@ class CheckedNT(NamedTuple):
checked(bad_arr)
Checked(arg=bad_arr)
CheckedNT(arg=bad_arr)


def test_tuple_ellipsis() -> None:

with pytest.warns(UserWarning, match="is missing a DLType hint"):

@dltype.dltyped()
def tuple_function( # pyright: ignore[reportUnusedFunction]
tensor: tuple[torch.Tensor, ...],
tensor1: tuple[Annotated[torch.Tensor, dltype.FloatTensor["1 2 3"]]],
) -> None:
"""A function that takes a tensor and returns a tensor."""
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ license-files = ["LICENSE"]
name = "dltype"
readme = "README.md"
requires-python = ">=3.10"
version = "0.13.0"
version = "0.13.1"

[project.optional-dependencies]
jax = ["jax>=0.6.2"]
Expand Down
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading