Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/CODEOWNERS
Original file line number Diff line number Diff line change
@@ -1 +1 @@
* @dlangerm-stackav
* @dlangerm-stackav @dlangerm
2 changes: 2 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
default_language_version:
python: python3.14


repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v6.0.0
Expand Down Expand Up @@ -38,6 +39,7 @@ repos:
args: [--autofix, --indent, '2', --offset, '2', --line-width, '80']
- id: pretty-format-toml
args: [--autofix, --indent, '2']
exclude: uv.lock
- repo: https://github.com/astral-sh/uv-pre-commit
# uv version.
rev: 0.10.0
Expand Down
2 changes: 2 additions & 0 deletions dltype/_lib/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ class _Env(BaseSettings):
MAX_ACCEPTABLE_EVALUATION_TIME_NS: typing.Final = int(5e9) # 5ms
GLOBAL_DISABLE: typing.Final = __env.DISABLE

if DEBUG_MODE:
warnings.warn("DLType debug mode enabled", UserWarning, stacklevel=1)

if GLOBAL_DISABLE:
warnings.warn(
Expand Down
26 changes: 11 additions & 15 deletions dltype/_lib/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def from_hint(
) -> tuple[DLTypeAnnotation | None, ...]:
"""Create a new _DLTypeAnnotation from a type hint."""
if hint is None:
warnings.warn(f"[{name}] is missing a DLType hint", category=UserWarning, stacklevel=3)
warnings.warn(f"[{name}] is missing a DLType hint", category=UserWarning, stacklevel=4)
return (None,)

_logger.debug("Creating DLType from hint %r", hint)
Expand All @@ -85,19 +85,18 @@ def from_hint(
if origin is tuple:
return tuple(itertools.chain(*[cls.from_hint(inner_hint, name) for inner_hint in args]))

# Only process Annotated types
# Only process Annotated types, warn if the annotated type is a tensor
if origin is not Annotated:
if any(T in hint.mro() for T in _dtypes.SUPPORTED_TENSOR_TYPES) if hint else False:
warnings.warn(f"[{name}] is missing a DLType hint", category=UserWarning, stacklevel=4)
return (None,)

# Ensure the annotation is a TensorTypeBase
if len(args) < n_expected_args or not isinstance(
args[1],
_tensor_type_base.TensorTypeBase,
):
_logger.warning(
"Invalid annotated dltype hint: %r",
args[1:] if len(args) >= n_expected_args else None,
)
warnings.warn(f"[{name}] has an invalid DLType hint", category=UserWarning, stacklevel=4)
return (None,)

# Ensure the base type is a supported tensor type
Expand Down Expand Up @@ -165,6 +164,10 @@ def _resolve_value(
return cast("tuple[Any]", value) if len(type_hint) > 1 else (value,)


def _get_func_lineref(func: Callable[P, R]) -> str:
return f"Function: {func.__name__}"


def dltyped( # noqa: C901, PLR0915
scope_provider: DLTypeScopeProvider | Literal["self"] | None = None,
*,
Expand All @@ -183,7 +186,7 @@ def dltyped( # noqa: C901, PLR0915

"""

def _inner_dltyped(func: Callable[P, R]) -> Callable[P, R]: # noqa: C901, PLR0915
def _inner_dltyped(func: Callable[P, R]) -> Callable[P, R]: # noqa: C901
if _dependency_utilities.is_torch_scripting() or not enabled:
# jit script doesn't support annotated type hints at all, we have no choice but to skip the type checking
return func
Expand All @@ -204,9 +207,8 @@ def _inner_dltyped(func: Callable[P, R]) -> Callable[P, R]: # noqa: C901, PLR09

# if we added dltype to a method where it will have no effect, warn the user
if dltype_hints is not None and all(all(vv is None for vv in v) for v in dltype_hints.values()):
_logger.warning("dltype_hints=%r", dltype_hints)
warnings.warn(
"No DLType hints found, skipping type checking",
f"No DLType hints found for {_get_func_lineref(func)}, skipping type checking",
UserWarning,
stacklevel=2,
)
Expand Down Expand Up @@ -284,12 +286,6 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: # noqa: C901
maybe_return_annotation,
)
ctx.assert_context()
elif any(isinstance(retval, T) for T in _dtypes.SUPPORTED_TENSOR_TYPES):
warnings.warn(
f"[{return_key}] is missing a DLType hint",
UserWarning,
stacklevel=2,
)
except _errors.DLTypeError as e:
# include the full function signature in the error message
e.set_context(f"{func.__name__}{signature}")
Expand Down
96 changes: 54 additions & 42 deletions dltype/tests/dltype_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def bad_ndim_error(tensor_name: str, *, expected: int, actual: int) -> str:
torch.ones(1, 2, 3, 4),
incomplete_annotated_function,
_RaisesInfo(value=torch.ones(1, 2, 3, 4)),
_WarnsInfo(match_text=re.escape("[return] is missing a DLType hint")),
None,
id="incomplete_annotated_4D",
),
pytest.param(
Expand Down Expand Up @@ -1064,33 +1064,38 @@ def good_function( # pyright: ignore[reportUnusedFunction]


def test_dimension_with_external_scope() -> None:
class Provider:
def get_dltype_scope(self) -> dict[str, int]:
return {"channels_in": 3, "channels_out": 4}
with pytest.warns(UserWarning, match=re.escape("[return] is missing a DLType hint")):

@dltype.dltyped(scope_provider="self")
def forward(
self,
class Provider:
def get_dltype_scope(self) -> dict[str, int]:
return {"channels_in": 3, "channels_out": 4}

@dltype.dltyped(scope_provider="self")
def forward(
self,
tensor: Annotated[
torch.Tensor,
dltype.FloatTensor["batch channels_in channels_out"],
],
) -> torch.Tensor:
return tensor

with pytest.warns(UserWarning, match=re.escape("[return] is missing a DLType hint")):

@dltype.dltyped(scope_provider=Provider())
def good_function(
tensor: Annotated[
torch.Tensor,
dltype.FloatTensor["batch channels_in channels_out"],
dltype.IntTensor["batch channels_in channels_out"],
],
) -> torch.Tensor:
return tensor

@dltype.dltyped(scope_provider=Provider())
def good_function(
tensor: Annotated[
torch.Tensor,
dltype.IntTensor["batch channels_in channels_out"],
],
) -> torch.Tensor:
return tensor

with pytest.warns(UserWarning, match=re.escape("[return] is missing a DLType hint")):
with pytest.WarningsRecorder() as rec:
good_function(torch.ones(1, 3, 4).int())
with pytest.warns(UserWarning, match=re.escape("[return] is missing a DLType hint")):
with pytest.WarningsRecorder() as rec:
good_function(torch.ones(4, 3, 4).int())
assert len(rec.list) == 0

with pytest.raises(dltype.DLTypeShapeError):
good_function(torch.ones(1, 3, 5).int())
Expand All @@ -1099,10 +1104,8 @@ def good_function(

provider = Provider()

with pytest.warns(UserWarning, match=re.escape("[return] is missing a DLType hint")):
provider.forward(torch.ones(1, 3, 4))
with pytest.warns(UserWarning, match=re.escape("[return] is missing a DLType hint")):
provider.forward(torch.ones(4, 3, 4))
provider.forward(torch.ones(1, 3, 4))
provider.forward(torch.ones(4, 3, 4))

with pytest.raises(dltype.DLTypeShapeError):
provider.forward(torch.ones(1, 3, 5))
Expand All @@ -1114,23 +1117,23 @@ def test_optional_type_handling() -> None:
"""Test that dltyped correctly handles Optional tensor types."""

# Test with a function with optional parameter
@dltype.dltyped()
def optional_tensor_func(
tensor: Annotated[torch.Tensor, dltype.FloatTensor["b c h w"]] | None,
) -> torch.Tensor:
if tensor is None:
return torch.zeros(1, 3, 5, 5)
return tensor
with pytest.warns(UserWarning, match=re.escape("[return] is missing a DLType hint")):

@dltype.dltyped()
def optional_tensor_func(
tensor: Annotated[torch.Tensor, dltype.FloatTensor["b c h w"]] | None,
) -> torch.Tensor:
if tensor is None:
return torch.zeros(1, 3, 5, 5)
return tensor

# Should work with None
with pytest.warns(UserWarning, match=re.escape("[return] is missing a DLType hint")):
result = optional_tensor_func(None)
result = optional_tensor_func(None)
assert result.shape == (1, 3, 5, 5)

# Should work with correct tensor
input_tensor = torch.rand(2, 3, 4, 4)
with pytest.warns(UserWarning, match=re.escape("[return] is missing a DLType hint")):
torch.testing.assert_close(optional_tensor_func(input_tensor), input_tensor)
torch.testing.assert_close(optional_tensor_func(input_tensor), input_tensor)

# Should fail with incorrect shape
with pytest.raises(dltype.DLTypeNDimsError):
Expand Down Expand Up @@ -1359,26 +1362,35 @@ def create(

def test_warning_if_decorator_has_no_annotations_to_check() -> None:
with pytest.warns(
UserWarning,
match="No DLType hints found, skipping type checking",
UserWarning, match="No DLType hints found for Function: no_annotations, skipping type checking"
):

@dltype.dltyped()
def no_annotations(tensor: torch.Tensor) -> torch.Tensor: # pyright: ignore[reportUnusedFunction]
return tensor

# should warn if some tensors are untyped
@dltype.dltyped()
def some_annotations(
tensor: Annotated[torch.Tensor, dltype.FloatTensor["1 2 3"]],
) -> torch.Tensor:
return tensor

with pytest.warns(
UserWarning,
match=re.escape("[return] is missing a DLType hint"),
):
some_annotations(torch.rand(1, 2, 3))

@dltype.dltyped()
def some_annotations(
tensor: Annotated[torch.Tensor, dltype.FloatTensor["1 2 3"]],
) -> torch.Tensor:
return tensor

some_annotations(torch.rand(1, 2, 3))

with pytest.warns(UserWarning, match=re.escape("[tensor] has an invalid DLType hint")):

@dltype.dltyped()
def some_annotations(
tensor: Annotated[torch.Tensor, 5],
) -> torch.Tensor:
return tensor


def test_scalar() -> None:
Expand Down
Loading