From cf79ba5ea031b4a4e58b85dd3b80b6f932a42197 Mon Sep 17 00:00:00 2001 From: David Langerman Date: Wed, 25 Mar 2026 14:15:22 +0000 Subject: [PATCH] Fix non-class methods erroring on mro --- dltype/_lib/_core.py | 15 +++++++++++---- dltype/tests/dltype_test.py | 17 +++++++++++------ pyproject.toml | 2 +- uv.lock | 2 +- 4 files changed, 24 insertions(+), 12 deletions(-) diff --git a/dltype/_lib/_core.py b/dltype/_lib/_core.py index 24c1eb5..0165020 100644 --- a/dltype/_lib/_core.py +++ b/dltype/_lib/_core.py @@ -86,7 +86,7 @@ def from_hint( # noqa: PLR0911 raise TypeError(msg) # Recursively process the non-None type with optional=True - return cls.from_hint(non_none_types[0], name, optional=True) + return cls.from_hint(non_none_types[0], name, optional=True, stack_offset=stack_offset + 1) # tuple handling special case if origin is tuple: @@ -98,7 +98,11 @@ def from_hint( # noqa: PLR0911 # Only process Annotated types, warn if the annotated type is a tensor if origin is not Annotated: - if any(T in hint.mro() for T in _dtypes.SUPPORTED_TENSOR_TYPES) if hint else False: + if ( + any(T in hint.mro() for T in _dtypes.SUPPORTED_TENSOR_TYPES) + if hint and hasattr(hint, "mro") + else False + ): warnings.warn( f"[{name}] is missing a DLType hint", category=UserWarning, stacklevel=4 + stack_offset ) @@ -351,7 +355,7 @@ def _inner_dltyped_namedtuple(cls: type[NT]) -> type[NT]: for field_name in cls._fields: if field_name in field_hints: hint = field_hints[field_name] - dltype_fields[field_name] = DLTypeAnnotation.from_hint(hint, field_name) + dltype_fields[field_name] = DLTypeAnnotation.from_hint(hint, field_name, stack_offset=-1) # If no fields need validation, return the original class if not dltype_fields: @@ -421,7 +425,10 @@ def _inner_dltyped_dataclass(cls: type[DataclassT]) -> type[DataclassT]: original_init = cls.__init__ # Get field annotations field_hints = get_type_hints(cls, include_extras=True) - dltype_hints = {name: DLTypeAnnotation.from_hint(hint, name) for name, hint in field_hints.items()} + dltype_hints = { + name: DLTypeAnnotation.from_hint(hint, name, stack_offset=-1) + for name, hint in field_hints.items() + } def new_init(self: DataclassT, *args: Any, **kwargs: Any) -> None: # noqa: ANN401 """A new __init__ method that validates the fields after initialization.""" diff --git a/dltype/tests/dltype_test.py b/dltype/tests/dltype_test.py index 98f8d0d..dcf62ac 100644 --- a/dltype/tests/dltype_test.py +++ b/dltype/tests/dltype_test.py @@ -19,6 +19,7 @@ import torch from pydantic import BaseModel from torch.jit import TracerWarning # pyright: ignore[reportPrivateImportUsage] +from typing_extensions import Self import dltype @@ -1700,9 +1701,13 @@ def test_tuple_ellipsis() -> None: with pytest.warns(UserWarning, match="is missing a DLType hint"): - @dltype.dltyped() - def tuple_function( # pyright: ignore[reportUnusedFunction] - tensor: tuple[torch.Tensor, ...], - tensor1: tuple[Annotated[torch.Tensor, dltype.FloatTensor["1 2 3"]]], - ) -> None: - """A function that takes a tensor and returns a tensor.""" + class MyClass: # pyright: ignore[reportUnusedClass] + @dltype.dltyped() + def tuple_function( + self, + tensor: tuple[torch.Tensor, ...], + tensor1: tuple[Annotated[torch.Tensor, dltype.FloatTensor["1 2 3"]]], + arg1: int, + ) -> Self: + """A function that takes a tensor and returns a tensor.""" + return self diff --git a/pyproject.toml b/pyproject.toml index f7b4674..8c5615c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ license-files = ["LICENSE"] name = "dltype" readme = "README.md" requires-python = ">=3.10" -version = "0.13.1" +version = "0.13.2" [project.optional-dependencies] jax = ["jax>=0.6.2"] diff --git a/uv.lock b/uv.lock index ebcda54..9420626 100644 --- a/uv.lock +++ b/uv.lock @@ -158,7 +158,7 @@ wheels = [ [[package]] name = "dltype" -version = "0.13.1" +version = "0.13.2" source = { virtual = "." } dependencies = [ { name = "pydantic" },