Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions tests/metrics/test_ssim_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

class TestSSIMMetric(unittest.TestCase):

def test2d_gaussian(self):
def test_2d_gaussian(self):
set_determinism(0)
preds = torch.abs(torch.randn(2, 3, 16, 16))
target = torch.abs(torch.randn(2, 3, 16, 16))
Expand All @@ -34,7 +34,7 @@ def test2d_gaussian(self):
expected_value = 0.045415
self.assertTrue(expected_value - result.item() < 0.000001)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

One-sided tolerance check — upper-bound errors pass silently.

expected_value - result.item() < 0.000001 only fails when the result is too low; any result above expected_value trivially satisfies it. Use abs() or assertAlmostEqual.

🐛 Proposed fix (shown for line 35; apply identically to lines 48 and 61)
-        self.assertTrue(expected_value - result.item() < 0.000001)
+        self.assertAlmostEqual(result.item(), expected_value, places=5)

Also applies to: 48-48, 61-61

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/metrics/test_ssim_metric.py` at line 35, The assertion uses a one-sided
check (expected_value - result.item() < 0.000001) which allows results greater
than expected_value to pass; replace it with a symmetric tolerance check by
using abs(expected_value - result.item()) < 1e-6 or unittest's
self.assertAlmostEqual(result.item(), expected_value, places=6) for the
assertions that reference expected_value and result.item() (apply the same
change to the other two occurrences mentioned).


def test2d_uniform(self):
def test_2d_uniform(self):
set_determinism(0)
preds = torch.abs(torch.randn(2, 3, 16, 16))
target = torch.abs(torch.randn(2, 3, 16, 16))
Expand All @@ -47,7 +47,7 @@ def test2d_uniform(self):
expected_value = 0.050103
self.assertTrue(expected_value - result.item() < 0.000001)

def test3d_gaussian(self):
def test_3d_gaussian(self):
set_determinism(0)
preds = torch.abs(torch.randn(2, 3, 16, 16, 16))
target = torch.abs(torch.randn(2, 3, 16, 16, 16))
Expand Down
Loading