Skip to content

Commit be5eaf4

Browse files
committed
Make warnings a little nicer and don't check for annotations each call
1 parent 831122d commit be5eaf4

File tree

4 files changed

+70
-58
lines changed

4 files changed

+70
-58
lines changed

.pre-commit-config.yaml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
default_language_version:
22
python: python3.14
33

4+
exclude: ^uv[.]lock$
5+
46
repos:
57
- repo: https://github.com/pre-commit/pre-commit-hooks
68
rev: v6.0.0
@@ -37,7 +39,7 @@ repos:
3739
- id: pretty-format-yaml
3840
args: [--autofix, --indent, '2', --offset, '2', --line-width, '80']
3941
- id: pretty-format-toml
40-
args: [--autofix, --indent, '2']
42+
args: [--autofix, --indent, '2', --no-sort, pyproject.toml]
4143
- repo: https://github.com/astral-sh/uv-pre-commit
4244
# uv version.
4345
rev: 0.10.0

dltype/_lib/_constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ class _Env(BaseSettings):
2929
MAX_ACCEPTABLE_EVALUATION_TIME_NS: typing.Final = int(5e9) # 5ms
3030
GLOBAL_DISABLE: typing.Final = __env.DISABLE
3131

32+
if DEBUG_MODE:
33+
warnings.warn("DLType debug mode enabled", UserWarning, stacklevel=1)
3234

3335
if GLOBAL_DISABLE:
3436
warnings.warn(

dltype/_lib/_core.py

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def from_hint(
6060
) -> tuple[DLTypeAnnotation | None, ...]:
6161
"""Create a new _DLTypeAnnotation from a type hint."""
6262
if hint is None:
63-
warnings.warn(f"[{name}] is missing a DLType hint", category=UserWarning, stacklevel=3)
63+
warnings.warn(f"[{name}] is missing a DLType hint", category=UserWarning, stacklevel=4)
6464
return (None,)
6565

6666
_logger.debug("Creating DLType from hint %r", hint)
@@ -85,19 +85,18 @@ def from_hint(
8585
if origin is tuple:
8686
return tuple(itertools.chain(*[cls.from_hint(inner_hint, name) for inner_hint in args]))
8787

88-
# Only process Annotated types
88+
# Only process Annotated types, warn if the annotated type is a tensor
8989
if origin is not Annotated:
90+
if any(T in hint.mro() for T in _dtypes.SUPPORTED_TENSOR_TYPES) if hint else False:
91+
warnings.warn(f"[{name}] is missing a DLType hint", category=UserWarning, stacklevel=4)
9092
return (None,)
9193

9294
# Ensure the annotation is a TensorTypeBase
9395
if len(args) < n_expected_args or not isinstance(
9496
args[1],
9597
_tensor_type_base.TensorTypeBase,
9698
):
97-
_logger.warning(
98-
"Invalid annotated dltype hint: %r",
99-
args[1:] if len(args) >= n_expected_args else None,
100-
)
99+
warnings.warn(f"[{name}] has an invalid DLType hint", category=UserWarning, stacklevel=4)
101100
return (None,)
102101

103102
# Ensure the base type is a supported tensor type
@@ -165,6 +164,10 @@ def _resolve_value(
165164
return cast("tuple[Any]", value) if len(type_hint) > 1 else (value,)
166165

167166

167+
def _get_func_lineref(func: Callable[P, R]) -> str:
168+
return f"Function: {func.__name__}"
169+
170+
168171
def dltyped( # noqa: C901, PLR0915
169172
scope_provider: DLTypeScopeProvider | Literal["self"] | None = None,
170173
*,
@@ -183,7 +186,7 @@ def dltyped( # noqa: C901, PLR0915
183186
184187
"""
185188

186-
def _inner_dltyped(func: Callable[P, R]) -> Callable[P, R]: # noqa: C901, PLR0915
189+
def _inner_dltyped(func: Callable[P, R]) -> Callable[P, R]: # noqa: C901
187190
if _dependency_utilities.is_torch_scripting() or not enabled:
188191
# jit script doesn't support annotated type hints at all, we have no choice but to skip the type checking
189192
return func
@@ -204,9 +207,8 @@ def _inner_dltyped(func: Callable[P, R]) -> Callable[P, R]: # noqa: C901, PLR09
204207

205208
# if we added dltype to a method where it will have no effect, warn the user
206209
if dltype_hints is not None and all(all(vv is None for vv in v) for v in dltype_hints.values()):
207-
_logger.warning("dltype_hints=%r", dltype_hints)
208210
warnings.warn(
209-
"No DLType hints found, skipping type checking",
211+
f"No DLType hints found for {_get_func_lineref(func)}, skipping type checking",
210212
UserWarning,
211213
stacklevel=2,
212214
)
@@ -284,12 +286,6 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: # noqa: C901
284286
maybe_return_annotation,
285287
)
286288
ctx.assert_context()
287-
elif any(isinstance(retval, T) for T in _dtypes.SUPPORTED_TENSOR_TYPES):
288-
warnings.warn(
289-
f"[{return_key}] is missing a DLType hint",
290-
UserWarning,
291-
stacklevel=2,
292-
)
293289
except _errors.DLTypeError as e:
294290
# include the full function signature in the error message
295291
e.set_context(f"{func.__name__}{signature}")

dltype/tests/dltype_test.py

Lines changed: 54 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ def bad_ndim_error(tensor_name: str, *, expected: int, actual: int) -> str:
165165
torch.ones(1, 2, 3, 4),
166166
incomplete_annotated_function,
167167
_RaisesInfo(value=torch.ones(1, 2, 3, 4)),
168-
_WarnsInfo(match_text=re.escape("[return] is missing a DLType hint")),
168+
None,
169169
id="incomplete_annotated_4D",
170170
),
171171
pytest.param(
@@ -1064,33 +1064,38 @@ def good_function( # pyright: ignore[reportUnusedFunction]
10641064

10651065

10661066
def test_dimension_with_external_scope() -> None:
1067-
class Provider:
1068-
def get_dltype_scope(self) -> dict[str, int]:
1069-
return {"channels_in": 3, "channels_out": 4}
1067+
with pytest.warns(UserWarning, match=re.escape("[return] is missing a DLType hint")):
10701068

1071-
@dltype.dltyped(scope_provider="self")
1072-
def forward(
1073-
self,
1069+
class Provider:
1070+
def get_dltype_scope(self) -> dict[str, int]:
1071+
return {"channels_in": 3, "channels_out": 4}
1072+
1073+
@dltype.dltyped(scope_provider="self")
1074+
def forward(
1075+
self,
1076+
tensor: Annotated[
1077+
torch.Tensor,
1078+
dltype.FloatTensor["batch channels_in channels_out"],
1079+
],
1080+
) -> torch.Tensor:
1081+
return tensor
1082+
1083+
with pytest.warns(UserWarning, match=re.escape("[return] is missing a DLType hint")):
1084+
1085+
@dltype.dltyped(scope_provider=Provider())
1086+
def good_function(
10741087
tensor: Annotated[
10751088
torch.Tensor,
1076-
dltype.FloatTensor["batch channels_in channels_out"],
1089+
dltype.IntTensor["batch channels_in channels_out"],
10771090
],
10781091
) -> torch.Tensor:
10791092
return tensor
10801093

1081-
@dltype.dltyped(scope_provider=Provider())
1082-
def good_function(
1083-
tensor: Annotated[
1084-
torch.Tensor,
1085-
dltype.IntTensor["batch channels_in channels_out"],
1086-
],
1087-
) -> torch.Tensor:
1088-
return tensor
1089-
1090-
with pytest.warns(UserWarning, match=re.escape("[return] is missing a DLType hint")):
1094+
with pytest.WarningsRecorder() as rec:
10911095
good_function(torch.ones(1, 3, 4).int())
1092-
with pytest.warns(UserWarning, match=re.escape("[return] is missing a DLType hint")):
1096+
with pytest.WarningsRecorder() as rec:
10931097
good_function(torch.ones(4, 3, 4).int())
1098+
assert len(rec.list) == 0
10941099

10951100
with pytest.raises(dltype.DLTypeShapeError):
10961101
good_function(torch.ones(1, 3, 5).int())
@@ -1099,10 +1104,8 @@ def good_function(
10991104

11001105
provider = Provider()
11011106

1102-
with pytest.warns(UserWarning, match=re.escape("[return] is missing a DLType hint")):
1103-
provider.forward(torch.ones(1, 3, 4))
1104-
with pytest.warns(UserWarning, match=re.escape("[return] is missing a DLType hint")):
1105-
provider.forward(torch.ones(4, 3, 4))
1107+
provider.forward(torch.ones(1, 3, 4))
1108+
provider.forward(torch.ones(4, 3, 4))
11061109

11071110
with pytest.raises(dltype.DLTypeShapeError):
11081111
provider.forward(torch.ones(1, 3, 5))
@@ -1114,23 +1117,23 @@ def test_optional_type_handling() -> None:
11141117
"""Test that dltyped correctly handles Optional tensor types."""
11151118

11161119
# Test with a function with optional parameter
1117-
@dltype.dltyped()
1118-
def optional_tensor_func(
1119-
tensor: Annotated[torch.Tensor, dltype.FloatTensor["b c h w"]] | None,
1120-
) -> torch.Tensor:
1121-
if tensor is None:
1122-
return torch.zeros(1, 3, 5, 5)
1123-
return tensor
1120+
with pytest.warns(UserWarning, match=re.escape("[return] is missing a DLType hint")):
1121+
1122+
@dltype.dltyped()
1123+
def optional_tensor_func(
1124+
tensor: Annotated[torch.Tensor, dltype.FloatTensor["b c h w"]] | None,
1125+
) -> torch.Tensor:
1126+
if tensor is None:
1127+
return torch.zeros(1, 3, 5, 5)
1128+
return tensor
11241129

11251130
# Should work with None
1126-
with pytest.warns(UserWarning, match=re.escape("[return] is missing a DLType hint")):
1127-
result = optional_tensor_func(None)
1131+
result = optional_tensor_func(None)
11281132
assert result.shape == (1, 3, 5, 5)
11291133

11301134
# Should work with correct tensor
11311135
input_tensor = torch.rand(2, 3, 4, 4)
1132-
with pytest.warns(UserWarning, match=re.escape("[return] is missing a DLType hint")):
1133-
torch.testing.assert_close(optional_tensor_func(input_tensor), input_tensor)
1136+
torch.testing.assert_close(optional_tensor_func(input_tensor), input_tensor)
11341137

11351138
# Should fail with incorrect shape
11361139
with pytest.raises(dltype.DLTypeNDimsError):
@@ -1359,26 +1362,35 @@ def create(
13591362

13601363
def test_warning_if_decorator_has_no_annotations_to_check() -> None:
13611364
with pytest.warns(
1362-
UserWarning,
1363-
match="No DLType hints found, skipping type checking",
1365+
UserWarning, match="No DLType hints found for Function: no_annotations, skipping type checking"
13641366
):
13651367

13661368
@dltype.dltyped()
13671369
def no_annotations(tensor: torch.Tensor) -> torch.Tensor: # pyright: ignore[reportUnusedFunction]
13681370
return tensor
13691371

13701372
# should warn if some tensors are untyped
1371-
@dltype.dltyped()
1372-
def some_annotations(
1373-
tensor: Annotated[torch.Tensor, dltype.FloatTensor["1 2 3"]],
1374-
) -> torch.Tensor:
1375-
return tensor
13761373

13771374
with pytest.warns(
13781375
UserWarning,
13791376
match=re.escape("[return] is missing a DLType hint"),
13801377
):
1381-
some_annotations(torch.rand(1, 2, 3))
1378+
1379+
@dltype.dltyped()
1380+
def some_annotations(
1381+
tensor: Annotated[torch.Tensor, dltype.FloatTensor["1 2 3"]],
1382+
) -> torch.Tensor:
1383+
return tensor
1384+
1385+
some_annotations(torch.rand(1, 2, 3))
1386+
1387+
with pytest.warns(UserWarning, match=re.escape("[tensor] has an invalid DLType hint")):
1388+
1389+
@dltype.dltyped()
1390+
def some_annotations(
1391+
tensor: Annotated[torch.Tensor, 5],
1392+
) -> torch.Tensor:
1393+
return tensor
13821394

13831395

13841396
def test_scalar() -> None:

0 commit comments

Comments
 (0)