diff --git a/dltype/_lib/_core.py b/dltype/_lib/_core.py index 80ade20..24c1eb5 100644 --- a/dltype/_lib/_core.py +++ b/dltype/_lib/_core.py @@ -6,6 +6,7 @@ import itertools import warnings from functools import lru_cache, wraps +from types import EllipsisType from typing import ( TYPE_CHECKING, Annotated, @@ -51,16 +52,22 @@ class DLTypeAnnotation(NamedTuple): dltype_annotation: _tensor_type_base.TensorTypeBase | None @classmethod - def from_hint( + def from_hint( # noqa: PLR0911 cls, - hint: type | None, + hint: type | EllipsisType | None, name: str, *, optional: bool = False, + stack_offset: int = 0, ) -> tuple[DLTypeAnnotation | None, ...]: """Create a new _DLTypeAnnotation from a type hint.""" + if isinstance(hint, EllipsisType): + return (None,) + if hint is None: - warnings.warn(f"[{name}] is missing a DLType hint", category=UserWarning, stacklevel=4) + warnings.warn( + f"[{name}] is missing a DLType hint", category=UserWarning, stacklevel=4 + stack_offset + ) return (None,) _logger.debug("Creating DLType from hint %r", hint) @@ -83,12 +90,18 @@ def from_hint( # tuple handling special case if origin is tuple: - return tuple(itertools.chain(*[cls.from_hint(inner_hint, name) for inner_hint in args])) + return tuple( + itertools.chain( + *[cls.from_hint(inner_hint, name, stack_offset=stack_offset + 1) for inner_hint in args] + ) + ) # Only process Annotated types, warn if the annotated type is a tensor if origin is not Annotated: if any(T in hint.mro() for T in _dtypes.SUPPORTED_TENSOR_TYPES) if hint else False: - warnings.warn(f"[{name}] is missing a DLType hint", category=UserWarning, stacklevel=4) + warnings.warn( + f"[{name}] is missing a DLType hint", category=UserWarning, stacklevel=4 + stack_offset + ) return (None,) # Ensure the annotation is a TensorTypeBase @@ -96,7 +109,9 @@ def from_hint( args[1], _tensor_type_base.TensorTypeBase, ): - warnings.warn(f"[{name}] has an invalid DLType hint", category=UserWarning, stacklevel=4) + warnings.warn( + f"[{name}] has an invalid DLType hint", category=UserWarning, stacklevel=4 + stack_offset + ) return (None,) # Ensure the base type is a supported tensor type @@ -130,13 +145,14 @@ def get_dltype_scope(self) -> _dltype_context.EvaluatedDimensionT: def _maybe_get_type_hints( existing_hints: dict[str, tuple[DLTypeAnnotation | None, ...]] | None, func: Callable[P, R], + stack_offset: int = 0, ) -> dict[str, tuple[DLTypeAnnotation | None, ...]] | None: """Get the type hints for a function, or return an empty dict if not available.""" if existing_hints is not None: return existing_hints try: return { - name: DLTypeAnnotation.from_hint(hint, name) + name: DLTypeAnnotation.from_hint(hint, name, stack_offset=stack_offset) for name, hint in get_type_hints(func, include_extras=True).items() } except NameError: diff --git a/dltype/tests/dltype_test.py b/dltype/tests/dltype_test.py index d32befe..98f8d0d 100644 --- a/dltype/tests/dltype_test.py +++ b/dltype/tests/dltype_test.py @@ -1694,3 +1694,15 @@ class CheckedNT(NamedTuple): checked(bad_arr) Checked(arg=bad_arr) CheckedNT(arg=bad_arr) + + +def test_tuple_ellipsis() -> None: + + with pytest.warns(UserWarning, match="is missing a DLType hint"): + + @dltype.dltyped() + def tuple_function( # pyright: ignore[reportUnusedFunction] + tensor: tuple[torch.Tensor, ...], + tensor1: tuple[Annotated[torch.Tensor, dltype.FloatTensor["1 2 3"]]], + ) -> None: + """A function that takes a tensor and returns a tensor.""" diff --git a/pyproject.toml b/pyproject.toml index 81a4bde..f7b4674 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ license-files = ["LICENSE"] name = "dltype" readme = "README.md" requires-python = ">=3.10" -version = "0.13.0" +version = "0.13.1" [project.optional-dependencies] jax = ["jax>=0.6.2"] diff --git a/uv.lock b/uv.lock index 8a66bd5..ebcda54 100644 --- a/uv.lock +++ b/uv.lock @@ -158,7 +158,7 @@ wheels = [ [[package]] name = "dltype" -version = "0.13.0" +version = "0.13.1" source = { virtual = "." } dependencies = [ { name = "pydantic" },