@@ -165,7 +165,7 @@ def bad_ndim_error(tensor_name: str, *, expected: int, actual: int) -> str:
165165 torch .ones (1 , 2 , 3 , 4 ),
166166 incomplete_annotated_function ,
167167 _RaisesInfo (value = torch .ones (1 , 2 , 3 , 4 )),
168- _WarnsInfo ( match_text = re . escape ( "[return] is missing a DLType hint" )) ,
168+ None ,
169169 id = "incomplete_annotated_4D" ,
170170 ),
171171 pytest .param (
@@ -1064,33 +1064,38 @@ def good_function( # pyright: ignore[reportUnusedFunction]
10641064
10651065
10661066def test_dimension_with_external_scope () -> None :
1067- class Provider :
1068- def get_dltype_scope (self ) -> dict [str , int ]:
1069- return {"channels_in" : 3 , "channels_out" : 4 }
1067+ with pytest .warns (UserWarning , match = re .escape ("[return] is missing a DLType hint" )):
10701068
1071- @dltype .dltyped (scope_provider = "self" )
1072- def forward (
1073- self ,
1069+ class Provider :
1070+ def get_dltype_scope (self ) -> dict [str , int ]:
1071+ return {"channels_in" : 3 , "channels_out" : 4 }
1072+
1073+ @dltype .dltyped (scope_provider = "self" )
1074+ def forward (
1075+ self ,
1076+ tensor : Annotated [
1077+ torch .Tensor ,
1078+ dltype .FloatTensor ["batch channels_in channels_out" ],
1079+ ],
1080+ ) -> torch .Tensor :
1081+ return tensor
1082+
1083+ with pytest .warns (UserWarning , match = re .escape ("[return] is missing a DLType hint" )):
1084+
1085+ @dltype .dltyped (scope_provider = Provider ())
1086+ def good_function (
10741087 tensor : Annotated [
10751088 torch .Tensor ,
1076- dltype .FloatTensor ["batch channels_in channels_out" ],
1089+ dltype .IntTensor ["batch channels_in channels_out" ],
10771090 ],
10781091 ) -> torch .Tensor :
10791092 return tensor
10801093
1081- @dltype .dltyped (scope_provider = Provider ())
1082- def good_function (
1083- tensor : Annotated [
1084- torch .Tensor ,
1085- dltype .IntTensor ["batch channels_in channels_out" ],
1086- ],
1087- ) -> torch .Tensor :
1088- return tensor
1089-
1090- with pytest .warns (UserWarning , match = re .escape ("[return] is missing a DLType hint" )):
1094+ with pytest .WarningsRecorder () as rec :
10911095 good_function (torch .ones (1 , 3 , 4 ).int ())
1092- with pytest .warns ( UserWarning , match = re . escape ( "[return] is missing a DLType hint" )) :
1096+ with pytest .WarningsRecorder () as rec :
10931097 good_function (torch .ones (4 , 3 , 4 ).int ())
1098+ assert len (rec .list ) == 0
10941099
10951100 with pytest .raises (dltype .DLTypeShapeError ):
10961101 good_function (torch .ones (1 , 3 , 5 ).int ())
@@ -1099,10 +1104,8 @@ def good_function(
10991104
11001105 provider = Provider ()
11011106
1102- with pytest .warns (UserWarning , match = re .escape ("[return] is missing a DLType hint" )):
1103- provider .forward (torch .ones (1 , 3 , 4 ))
1104- with pytest .warns (UserWarning , match = re .escape ("[return] is missing a DLType hint" )):
1105- provider .forward (torch .ones (4 , 3 , 4 ))
1107+ provider .forward (torch .ones (1 , 3 , 4 ))
1108+ provider .forward (torch .ones (4 , 3 , 4 ))
11061109
11071110 with pytest .raises (dltype .DLTypeShapeError ):
11081111 provider .forward (torch .ones (1 , 3 , 5 ))
@@ -1114,23 +1117,23 @@ def test_optional_type_handling() -> None:
11141117 """Test that dltyped correctly handles Optional tensor types."""
11151118
11161119 # Test with a function with optional parameter
1117- @dltype .dltyped ()
1118- def optional_tensor_func (
1119- tensor : Annotated [torch .Tensor , dltype .FloatTensor ["b c h w" ]] | None ,
1120- ) -> torch .Tensor :
1121- if tensor is None :
1122- return torch .zeros (1 , 3 , 5 , 5 )
1123- return tensor
1120+ with pytest .warns (UserWarning , match = re .escape ("[return] is missing a DLType hint" )):
1121+
1122+ @dltype .dltyped ()
1123+ def optional_tensor_func (
1124+ tensor : Annotated [torch .Tensor , dltype .FloatTensor ["b c h w" ]] | None ,
1125+ ) -> torch .Tensor :
1126+ if tensor is None :
1127+ return torch .zeros (1 , 3 , 5 , 5 )
1128+ return tensor
11241129
11251130 # Should work with None
1126- with pytest .warns (UserWarning , match = re .escape ("[return] is missing a DLType hint" )):
1127- result = optional_tensor_func (None )
1131+ result = optional_tensor_func (None )
11281132 assert result .shape == (1 , 3 , 5 , 5 )
11291133
11301134 # Should work with correct tensor
11311135 input_tensor = torch .rand (2 , 3 , 4 , 4 )
1132- with pytest .warns (UserWarning , match = re .escape ("[return] is missing a DLType hint" )):
1133- torch .testing .assert_close (optional_tensor_func (input_tensor ), input_tensor )
1136+ torch .testing .assert_close (optional_tensor_func (input_tensor ), input_tensor )
11341137
11351138 # Should fail with incorrect shape
11361139 with pytest .raises (dltype .DLTypeNDimsError ):
@@ -1359,26 +1362,35 @@ def create(
13591362
13601363def test_warning_if_decorator_has_no_annotations_to_check () -> None :
13611364 with pytest .warns (
1362- UserWarning ,
1363- match = "No DLType hints found, skipping type checking" ,
1365+ UserWarning , match = "No DLType hints found for Function: no_annotations, skipping type checking"
13641366 ):
13651367
13661368 @dltype .dltyped ()
13671369 def no_annotations (tensor : torch .Tensor ) -> torch .Tensor : # pyright: ignore[reportUnusedFunction]
13681370 return tensor
13691371
13701372 # should warn if some tensors are untyped
1371- @dltype .dltyped ()
1372- def some_annotations (
1373- tensor : Annotated [torch .Tensor , dltype .FloatTensor ["1 2 3" ]],
1374- ) -> torch .Tensor :
1375- return tensor
13761373
13771374 with pytest .warns (
13781375 UserWarning ,
13791376 match = re .escape ("[return] is missing a DLType hint" ),
13801377 ):
1381- some_annotations (torch .rand (1 , 2 , 3 ))
1378+
1379+ @dltype .dltyped ()
1380+ def some_annotations (
1381+ tensor : Annotated [torch .Tensor , dltype .FloatTensor ["1 2 3" ]],
1382+ ) -> torch .Tensor :
1383+ return tensor
1384+
1385+ some_annotations (torch .rand (1 , 2 , 3 ))
1386+
1387+ with pytest .warns (UserWarning , match = re .escape ("[tensor] has an invalid DLType hint" )):
1388+
1389+ @dltype .dltyped ()
1390+ def some_annotations (
1391+ tensor : Annotated [torch .Tensor , 5 ],
1392+ ) -> torch .Tensor :
1393+ return tensor
13821394
13831395
13841396def test_scalar () -> None :
0 commit comments