Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions dltype/_lib/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def from_hint( # noqa: PLR0911
raise TypeError(msg)

# Recursively process the non-None type with optional=True
return cls.from_hint(non_none_types[0], name, optional=True)
return cls.from_hint(non_none_types[0], name, optional=True, stack_offset=stack_offset + 1)

# tuple handling special case
if origin is tuple:
Expand All @@ -98,7 +98,11 @@ def from_hint( # noqa: PLR0911

# Only process Annotated types, warn if the annotated type is a tensor
if origin is not Annotated:
if any(T in hint.mro() for T in _dtypes.SUPPORTED_TENSOR_TYPES) if hint else False:
if (
any(T in hint.mro() for T in _dtypes.SUPPORTED_TENSOR_TYPES)
if hint and hasattr(hint, "mro")
else False
):
warnings.warn(
f"[{name}] is missing a DLType hint", category=UserWarning, stacklevel=4 + stack_offset
)
Expand Down Expand Up @@ -351,7 +355,7 @@ def _inner_dltyped_namedtuple(cls: type[NT]) -> type[NT]:
for field_name in cls._fields:
if field_name in field_hints:
hint = field_hints[field_name]
dltype_fields[field_name] = DLTypeAnnotation.from_hint(hint, field_name)
dltype_fields[field_name] = DLTypeAnnotation.from_hint(hint, field_name, stack_offset=-1)

# If no fields need validation, return the original class
if not dltype_fields:
Expand Down Expand Up @@ -421,7 +425,10 @@ def _inner_dltyped_dataclass(cls: type[DataclassT]) -> type[DataclassT]:
original_init = cls.__init__
# Get field annotations
field_hints = get_type_hints(cls, include_extras=True)
dltype_hints = {name: DLTypeAnnotation.from_hint(hint, name) for name, hint in field_hints.items()}
dltype_hints = {
name: DLTypeAnnotation.from_hint(hint, name, stack_offset=-1)
for name, hint in field_hints.items()
}

def new_init(self: DataclassT, *args: Any, **kwargs: Any) -> None: # noqa: ANN401
"""A new __init__ method that validates the fields after initialization."""
Expand Down
17 changes: 11 additions & 6 deletions dltype/tests/dltype_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import torch
from pydantic import BaseModel
from torch.jit import TracerWarning # pyright: ignore[reportPrivateImportUsage]
from typing_extensions import Self

import dltype

Expand Down Expand Up @@ -1700,9 +1701,13 @@ def test_tuple_ellipsis() -> None:

with pytest.warns(UserWarning, match="is missing a DLType hint"):

@dltype.dltyped()
def tuple_function( # pyright: ignore[reportUnusedFunction]
tensor: tuple[torch.Tensor, ...],
tensor1: tuple[Annotated[torch.Tensor, dltype.FloatTensor["1 2 3"]]],
) -> None:
"""A function that takes a tensor and returns a tensor."""
class MyClass: # pyright: ignore[reportUnusedClass]
@dltype.dltyped()
def tuple_function(
self,
tensor: tuple[torch.Tensor, ...],
tensor1: tuple[Annotated[torch.Tensor, dltype.FloatTensor["1 2 3"]]],
arg1: int,
) -> Self:
"""A function that takes a tensor and returns a tensor."""
return self
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ license-files = ["LICENSE"]
name = "dltype"
readme = "README.md"
requires-python = ">=3.10"
version = "0.13.1"
version = "0.13.2"

[project.optional-dependencies]
jax = ["jax>=0.6.2"]
Expand Down
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading