Skip to content

Commit 7bbcf4e

Browse files
Add unit tests for Affine.compute_w_affine method
Add focused tests covering identity transforms (2D/3D), different input/output sizes with expected translation offsets, output shape validation, and torch tensor input compatibility. Signed-off-by: Mohamed Salah <eng.mohamed.tawab@gmail.com>
1 parent bb698d6 commit 7bbcf4e

File tree

1 file changed

+42
-0
lines changed

1 file changed

+42
-0
lines changed

tests/transforms/test_affine.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,48 @@ def test_affine(self, input_param, input_data, expected_val):
199199
)
200200

201201

202+
class TestComputeWAffine(unittest.TestCase):
203+
def test_identity_2d(self):
204+
"""Identity matrix with same input/output size should produce pure translation to/from center."""
205+
mat = np.eye(3)
206+
img_size = (4, 4)
207+
sp_size = (4, 4)
208+
result = Affine.compute_w_affine(2, mat, img_size, sp_size)
209+
# For identity transform with same sizes, result should be identity
210+
assert_allclose(result, np.eye(3), atol=1e-6)
211+
212+
def test_identity_3d(self):
213+
"""Identity matrix in 3D with same input/output size."""
214+
mat = np.eye(4)
215+
img_size = (6, 6, 6)
216+
sp_size = (6, 6, 6)
217+
result = Affine.compute_w_affine(3, mat, img_size, sp_size)
218+
assert_allclose(result, np.eye(4), atol=1e-6)
219+
220+
def test_different_sizes(self):
221+
"""When img_size != sp_size, result should include net translation."""
222+
mat = np.eye(3)
223+
img_size = (4, 4)
224+
sp_size = (8, 8)
225+
result = Affine.compute_w_affine(2, mat, img_size, sp_size)
226+
# Translation should account for the shift: (4-1)/2 - (8-1)/2 = 1.5 - 3.5 = -2.0
227+
expected_translation = np.array([(d1 - 1) / 2 - (d2 - 1) / 2 for d1, d2 in zip(img_size, sp_size)])
228+
assert_allclose(result[:2, 2], expected_translation, atol=1e-6)
229+
230+
def test_output_shape(self):
231+
"""Output should be (r+1) x (r+1) matrix."""
232+
for r in [2, 3]:
233+
mat = np.eye(r + 1)
234+
result = Affine.compute_w_affine(r, mat, (4,) * r, (4,) * r)
235+
self.assertEqual(result.shape, (r + 1, r + 1))
236+
237+
def test_torch_input(self):
238+
"""Method should accept torch tensor input."""
239+
mat = torch.eye(3)
240+
result = Affine.compute_w_affine(2, mat, (4, 4), (4, 4))
241+
assert_allclose(result, np.eye(3), atol=1e-6)
242+
243+
202244
@unittest.skipUnless(optional_import("scipy")[1], "Requires scipy library.")
203245
class TestAffineConsistency(unittest.TestCase):
204246
@parameterized.expand([[7], [8], [9]])

0 commit comments

Comments
 (0)