@@ -199,6 +199,48 @@ def test_affine(self, input_param, input_data, expected_val):
199199 )
200200
201201
202+ class TestComputeWAffine (unittest .TestCase ):
203+ def test_identity_2d (self ):
204+ """Identity matrix with same input/output size should produce pure translation to/from center."""
205+ mat = np .eye (3 )
206+ img_size = (4 , 4 )
207+ sp_size = (4 , 4 )
208+ result = Affine .compute_w_affine (2 , mat , img_size , sp_size )
209+ # For identity transform with same sizes, result should be identity
210+ assert_allclose (result , np .eye (3 ), atol = 1e-6 )
211+
212+ def test_identity_3d (self ):
213+ """Identity matrix in 3D with same input/output size."""
214+ mat = np .eye (4 )
215+ img_size = (6 , 6 , 6 )
216+ sp_size = (6 , 6 , 6 )
217+ result = Affine .compute_w_affine (3 , mat , img_size , sp_size )
218+ assert_allclose (result , np .eye (4 ), atol = 1e-6 )
219+
220+ def test_different_sizes (self ):
221+ """When img_size != sp_size, result should include net translation."""
222+ mat = np .eye (3 )
223+ img_size = (4 , 4 )
224+ sp_size = (8 , 8 )
225+ result = Affine .compute_w_affine (2 , mat , img_size , sp_size )
226+ # Translation should account for the shift: (4-1)/2 - (8-1)/2 = 1.5 - 3.5 = -2.0
227+ expected_translation = np .array ([(d1 - 1 ) / 2 - (d2 - 1 ) / 2 for d1 , d2 in zip (img_size , sp_size )])
228+ assert_allclose (result [:2 , 2 ], expected_translation , atol = 1e-6 )
229+
230+ def test_output_shape (self ):
231+ """Output should be (r+1) x (r+1) matrix."""
232+ for r in [2 , 3 ]:
233+ mat = np .eye (r + 1 )
234+ result = Affine .compute_w_affine (r , mat , (4 ,) * r , (4 ,) * r )
235+ self .assertEqual (result .shape , (r + 1 , r + 1 ))
236+
237+ def test_torch_input (self ):
238+ """Method should accept torch tensor input."""
239+ mat = torch .eye (3 )
240+ result = Affine .compute_w_affine (2 , mat , (4 , 4 ), (4 , 4 ))
241+ assert_allclose (result , np .eye (3 ), atol = 1e-6 )
242+
243+
202244@unittest .skipUnless (optional_import ("scipy" )[1 ], "Requires scipy library." )
203245class TestAffineConsistency (unittest .TestCase ):
204246 @parameterized .expand ([[7 ], [8 ], [9 ]])
0 commit comments