From d2d443bb811f1952226a7578858c557797284f7b Mon Sep 17 00:00:00 2001 From: David Langerman Date: Sun, 22 Mar 2026 13:46:29 -0400 Subject: [PATCH] Make warnings a little nicer and don't check for annotations each call --- .github/CODEOWNERS | 2 +- .pre-commit-config.yaml | 2 + dltype/_lib/_constants.py | 2 + dltype/_lib/_core.py | 26 +++++----- dltype/tests/dltype_test.py | 96 +++++++++++++++++++++---------------- 5 files changed, 70 insertions(+), 58 deletions(-) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 32bd756..88691a7 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1 +1 @@ -* @dlangerm-stackav +* @dlangerm-stackav @dlangerm diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 15c0022..329b3c6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,7 @@ default_language_version: python: python3.14 + repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v6.0.0 @@ -38,6 +39,7 @@ repos: args: [--autofix, --indent, '2', --offset, '2', --line-width, '80'] - id: pretty-format-toml args: [--autofix, --indent, '2'] + exclude: uv.lock - repo: https://github.com/astral-sh/uv-pre-commit # uv version. rev: 0.10.0 diff --git a/dltype/_lib/_constants.py b/dltype/_lib/_constants.py index 5d256fd..ae6a4c8 100644 --- a/dltype/_lib/_constants.py +++ b/dltype/_lib/_constants.py @@ -29,6 +29,8 @@ class _Env(BaseSettings): MAX_ACCEPTABLE_EVALUATION_TIME_NS: typing.Final = int(5e9) # 5ms GLOBAL_DISABLE: typing.Final = __env.DISABLE +if DEBUG_MODE: + warnings.warn("DLType debug mode enabled", UserWarning, stacklevel=1) if GLOBAL_DISABLE: warnings.warn( diff --git a/dltype/_lib/_core.py b/dltype/_lib/_core.py index 8b5b08b..80ade20 100644 --- a/dltype/_lib/_core.py +++ b/dltype/_lib/_core.py @@ -60,7 +60,7 @@ def from_hint( ) -> tuple[DLTypeAnnotation | None, ...]: """Create a new _DLTypeAnnotation from a type hint.""" if hint is None: - warnings.warn(f"[{name}] is missing a DLType hint", category=UserWarning, stacklevel=3) + warnings.warn(f"[{name}] is missing a DLType hint", category=UserWarning, stacklevel=4) return (None,) _logger.debug("Creating DLType from hint %r", hint) @@ -85,8 +85,10 @@ def from_hint( if origin is tuple: return tuple(itertools.chain(*[cls.from_hint(inner_hint, name) for inner_hint in args])) - # Only process Annotated types + # Only process Annotated types, warn if the annotated type is a tensor if origin is not Annotated: + if any(T in hint.mro() for T in _dtypes.SUPPORTED_TENSOR_TYPES) if hint else False: + warnings.warn(f"[{name}] is missing a DLType hint", category=UserWarning, stacklevel=4) return (None,) # Ensure the annotation is a TensorTypeBase @@ -94,10 +96,7 @@ def from_hint( args[1], _tensor_type_base.TensorTypeBase, ): - _logger.warning( - "Invalid annotated dltype hint: %r", - args[1:] if len(args) >= n_expected_args else None, - ) + warnings.warn(f"[{name}] has an invalid DLType hint", category=UserWarning, stacklevel=4) return (None,) # Ensure the base type is a supported tensor type @@ -165,6 +164,10 @@ def _resolve_value( return cast("tuple[Any]", value) if len(type_hint) > 1 else (value,) +def _get_func_lineref(func: Callable[P, R]) -> str: + return f"Function: {func.__name__}" + + def dltyped( # noqa: C901, PLR0915 scope_provider: DLTypeScopeProvider | Literal["self"] | None = None, *, @@ -183,7 +186,7 @@ def dltyped( # noqa: C901, PLR0915 """ - def _inner_dltyped(func: Callable[P, R]) -> Callable[P, R]: # noqa: C901, PLR0915 + def _inner_dltyped(func: Callable[P, R]) -> Callable[P, R]: # noqa: C901 if _dependency_utilities.is_torch_scripting() or not enabled: # jit script doesn't support annotated type hints at all, we have no choice but to skip the type checking return func @@ -204,9 +207,8 @@ def _inner_dltyped(func: Callable[P, R]) -> Callable[P, R]: # noqa: C901, PLR09 # if we added dltype to a method where it will have no effect, warn the user if dltype_hints is not None and all(all(vv is None for vv in v) for v in dltype_hints.values()): - _logger.warning("dltype_hints=%r", dltype_hints) warnings.warn( - "No DLType hints found, skipping type checking", + f"No DLType hints found for {_get_func_lineref(func)}, skipping type checking", UserWarning, stacklevel=2, ) @@ -284,12 +286,6 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: # noqa: C901 maybe_return_annotation, ) ctx.assert_context() - elif any(isinstance(retval, T) for T in _dtypes.SUPPORTED_TENSOR_TYPES): - warnings.warn( - f"[{return_key}] is missing a DLType hint", - UserWarning, - stacklevel=2, - ) except _errors.DLTypeError as e: # include the full function signature in the error message e.set_context(f"{func.__name__}{signature}") diff --git a/dltype/tests/dltype_test.py b/dltype/tests/dltype_test.py index 99a3a9c..d32befe 100644 --- a/dltype/tests/dltype_test.py +++ b/dltype/tests/dltype_test.py @@ -165,7 +165,7 @@ def bad_ndim_error(tensor_name: str, *, expected: int, actual: int) -> str: torch.ones(1, 2, 3, 4), incomplete_annotated_function, _RaisesInfo(value=torch.ones(1, 2, 3, 4)), - _WarnsInfo(match_text=re.escape("[return] is missing a DLType hint")), + None, id="incomplete_annotated_4D", ), pytest.param( @@ -1064,33 +1064,38 @@ def good_function( # pyright: ignore[reportUnusedFunction] def test_dimension_with_external_scope() -> None: - class Provider: - def get_dltype_scope(self) -> dict[str, int]: - return {"channels_in": 3, "channels_out": 4} + with pytest.warns(UserWarning, match=re.escape("[return] is missing a DLType hint")): - @dltype.dltyped(scope_provider="self") - def forward( - self, + class Provider: + def get_dltype_scope(self) -> dict[str, int]: + return {"channels_in": 3, "channels_out": 4} + + @dltype.dltyped(scope_provider="self") + def forward( + self, + tensor: Annotated[ + torch.Tensor, + dltype.FloatTensor["batch channels_in channels_out"], + ], + ) -> torch.Tensor: + return tensor + + with pytest.warns(UserWarning, match=re.escape("[return] is missing a DLType hint")): + + @dltype.dltyped(scope_provider=Provider()) + def good_function( tensor: Annotated[ torch.Tensor, - dltype.FloatTensor["batch channels_in channels_out"], + dltype.IntTensor["batch channels_in channels_out"], ], ) -> torch.Tensor: return tensor - @dltype.dltyped(scope_provider=Provider()) - def good_function( - tensor: Annotated[ - torch.Tensor, - dltype.IntTensor["batch channels_in channels_out"], - ], - ) -> torch.Tensor: - return tensor - - with pytest.warns(UserWarning, match=re.escape("[return] is missing a DLType hint")): + with pytest.WarningsRecorder() as rec: good_function(torch.ones(1, 3, 4).int()) - with pytest.warns(UserWarning, match=re.escape("[return] is missing a DLType hint")): + with pytest.WarningsRecorder() as rec: good_function(torch.ones(4, 3, 4).int()) + assert len(rec.list) == 0 with pytest.raises(dltype.DLTypeShapeError): good_function(torch.ones(1, 3, 5).int()) @@ -1099,10 +1104,8 @@ def good_function( provider = Provider() - with pytest.warns(UserWarning, match=re.escape("[return] is missing a DLType hint")): - provider.forward(torch.ones(1, 3, 4)) - with pytest.warns(UserWarning, match=re.escape("[return] is missing a DLType hint")): - provider.forward(torch.ones(4, 3, 4)) + provider.forward(torch.ones(1, 3, 4)) + provider.forward(torch.ones(4, 3, 4)) with pytest.raises(dltype.DLTypeShapeError): provider.forward(torch.ones(1, 3, 5)) @@ -1114,23 +1117,23 @@ def test_optional_type_handling() -> None: """Test that dltyped correctly handles Optional tensor types.""" # Test with a function with optional parameter - @dltype.dltyped() - def optional_tensor_func( - tensor: Annotated[torch.Tensor, dltype.FloatTensor["b c h w"]] | None, - ) -> torch.Tensor: - if tensor is None: - return torch.zeros(1, 3, 5, 5) - return tensor + with pytest.warns(UserWarning, match=re.escape("[return] is missing a DLType hint")): + + @dltype.dltyped() + def optional_tensor_func( + tensor: Annotated[torch.Tensor, dltype.FloatTensor["b c h w"]] | None, + ) -> torch.Tensor: + if tensor is None: + return torch.zeros(1, 3, 5, 5) + return tensor # Should work with None - with pytest.warns(UserWarning, match=re.escape("[return] is missing a DLType hint")): - result = optional_tensor_func(None) + result = optional_tensor_func(None) assert result.shape == (1, 3, 5, 5) # Should work with correct tensor input_tensor = torch.rand(2, 3, 4, 4) - with pytest.warns(UserWarning, match=re.escape("[return] is missing a DLType hint")): - torch.testing.assert_close(optional_tensor_func(input_tensor), input_tensor) + torch.testing.assert_close(optional_tensor_func(input_tensor), input_tensor) # Should fail with incorrect shape with pytest.raises(dltype.DLTypeNDimsError): @@ -1359,8 +1362,7 @@ def create( def test_warning_if_decorator_has_no_annotations_to_check() -> None: with pytest.warns( - UserWarning, - match="No DLType hints found, skipping type checking", + UserWarning, match="No DLType hints found for Function: no_annotations, skipping type checking" ): @dltype.dltyped() @@ -1368,17 +1370,27 @@ def no_annotations(tensor: torch.Tensor) -> torch.Tensor: # pyright: ignore[rep return tensor # should warn if some tensors are untyped - @dltype.dltyped() - def some_annotations( - tensor: Annotated[torch.Tensor, dltype.FloatTensor["1 2 3"]], - ) -> torch.Tensor: - return tensor with pytest.warns( UserWarning, match=re.escape("[return] is missing a DLType hint"), ): - some_annotations(torch.rand(1, 2, 3)) + + @dltype.dltyped() + def some_annotations( + tensor: Annotated[torch.Tensor, dltype.FloatTensor["1 2 3"]], + ) -> torch.Tensor: + return tensor + + some_annotations(torch.rand(1, 2, 3)) + + with pytest.warns(UserWarning, match=re.escape("[tensor] has an invalid DLType hint")): + + @dltype.dltyped() + def some_annotations( + tensor: Annotated[torch.Tensor, 5], + ) -> torch.Tensor: + return tensor def test_scalar() -> None: