diff --git a/dltype/_lib/_core.py b/dltype/_lib/_core.py index 0165020..7de2c21 100644 --- a/dltype/_lib/_core.py +++ b/dltype/_lib/_core.py @@ -5,6 +5,7 @@ import inspect import itertools import warnings +from copy import copy from functools import lru_cache, wraps from types import EllipsisType from typing import ( @@ -124,6 +125,7 @@ def from_hint( # noqa: PLR0911 msg = f"Invalid base type=<{tensor_type}> in DLType hint, expected a subtype of {_dtypes.SUPPORTED_TENSOR_TYPES}" raise TypeError(msg) + dltype_hint = copy(dltype_hint) if optional else dltype_hint dltype_hint.optional = optional return (cls(tensor_type_hint=tensor_type, dltype_annotation=dltype_hint),) diff --git a/dltype/tests/dltype_test.py b/dltype/tests/dltype_test.py index dcf62ac..7f240e5 100644 --- a/dltype/tests/dltype_test.py +++ b/dltype/tests/dltype_test.py @@ -1711,3 +1711,24 @@ def tuple_function( ) -> Self: """A function that takes a tensor and returns a tensor.""" return self + + +SomeTensorT: TypeAlias = Annotated[torch.Tensor, dltype.FloatTensor["1 2 3"]] + + +def test_aliased_optional() -> None: + + @dltype.dltyped() + def func(non_optional_tensor: SomeTensorT, optional_tensor: SomeTensorT | None = None) -> None: + pass + + func(torch.zeros((1, 2, 3))) + + func(torch.zeros((1, 2, 3)), None) + + func(torch.zeros((1, 2, 3)), torch.zeros((1, 2, 3))) + + func(torch.zeros((1, 2, 3))) + + with pytest.raises(dltype.DLTypeUnsupportedTensorTypeError): + func(None, torch.zeros((1, 2, 3))) # pyright: ignore[reportArgumentType]