From 5b0c7456f3c21f535b375c9abe192f9345ecc14f Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 10 Mar 2026 12:03:03 +0530 Subject: [PATCH] move test_hooks.py to pytest --- tests/hooks/test_hooks.py | 59 +++++++++++++++++++-------------------- 1 file changed, 28 insertions(+), 31 deletions(-) diff --git a/tests/hooks/test_hooks.py b/tests/hooks/test_hooks.py index 8a83f60ff278..bb444e58e280 100644 --- a/tests/hooks/test_hooks.py +++ b/tests/hooks/test_hooks.py @@ -13,8 +13,8 @@ # limitations under the License. import gc -import unittest +import pytest import torch from diffusers.hooks import HookRegistry, ModelHook @@ -134,20 +134,18 @@ def post_forward(self, module, output): return output -class HookTests(unittest.TestCase): +class TestHooks: in_features = 4 hidden_features = 8 out_features = 4 num_layers = 2 - def setUp(self): + def setup_method(self): params = self.get_module_parameters() self.model = DummyModel(**params) self.model.to(torch_device) - def tearDown(self): - super().tearDown() - + def teardown_method(self): del self.model gc.collect() free_memory() @@ -171,20 +169,20 @@ def test_hook_registry(self): registry_repr = repr(registry) expected_repr = "HookRegistry(\n (0) add_hook - AddHook\n (1) multiply_hook - MultiplyHook(value=2)\n)" - self.assertEqual(len(registry.hooks), 2) - self.assertEqual(registry._hook_order, ["add_hook", "multiply_hook"]) - self.assertEqual(registry_repr, expected_repr) + assert len(registry.hooks) == 2 + assert registry._hook_order == ["add_hook", "multiply_hook"] + assert registry_repr == expected_repr registry.remove_hook("add_hook") - self.assertEqual(len(registry.hooks), 1) - self.assertEqual(registry._hook_order, ["multiply_hook"]) + assert len(registry.hooks) == 1 + assert registry._hook_order == ["multiply_hook"] def test_stateful_hook(self): registry = HookRegistry.check_if_exists_or_initialize(self.model) registry.register_hook(StatefulAddHook(1), "stateful_add_hook") - self.assertEqual(registry.hooks["stateful_add_hook"].increment, 0) + assert registry.hooks["stateful_add_hook"].increment == 0 input = torch.randn(1, 4, device=torch_device, generator=self.get_generator()) num_repeats = 3 @@ -194,13 +192,13 @@ def test_stateful_hook(self): if i == 0: output1 = result - self.assertEqual(registry.get_hook("stateful_add_hook").increment, num_repeats) + assert registry.get_hook("stateful_add_hook").increment == num_repeats registry.reset_stateful_hooks() output2 = self.model(input) - self.assertEqual(registry.get_hook("stateful_add_hook").increment, 1) - self.assertTrue(torch.allclose(output1, output2)) + assert registry.get_hook("stateful_add_hook").increment == 1 + assert torch.allclose(output1, output2) def test_inference(self): registry = HookRegistry.check_if_exists_or_initialize(self.model) @@ -218,9 +216,9 @@ def test_inference(self): new_input = input * 2 + 1 output3 = self.model(new_input).mean().detach().cpu().item() - self.assertAlmostEqual(output1, output2, places=5) - self.assertAlmostEqual(output1, output3, places=5) - self.assertAlmostEqual(output2, output3, places=5) + assert output1 == pytest.approx(output2, abs=5e-6) + assert output1 == pytest.approx(output3, abs=5e-6) + assert output2 == pytest.approx(output3, abs=5e-6) def test_skip_layer_hook(self): registry = HookRegistry.check_if_exists_or_initialize(self.model) @@ -228,30 +226,29 @@ def test_skip_layer_hook(self): input = torch.zeros(1, 4, device=torch_device) output = self.model(input).mean().detach().cpu().item() - self.assertEqual(output, 0.0) + assert output == 0.0 registry.remove_hook("skip_layer_hook") registry.register_hook(SkipLayerHook(skip_layer=False), "skip_layer_hook") output = self.model(input).mean().detach().cpu().item() - self.assertNotEqual(output, 0.0) + assert output != 0.0 def test_skip_layer_internal_block(self): registry = HookRegistry.check_if_exists_or_initialize(self.model.linear_1) input = torch.zeros(1, 4, device=torch_device) registry.register_hook(SkipLayerHook(skip_layer=True), "skip_layer_hook") - with self.assertRaises(RuntimeError) as cm: + with pytest.raises(RuntimeError, match="mat1 and mat2 shapes cannot be multiplied"): self.model(input).mean().detach().cpu().item() - self.assertIn("mat1 and mat2 shapes cannot be multiplied", str(cm.exception)) registry.remove_hook("skip_layer_hook") output = self.model(input).mean().detach().cpu().item() - self.assertNotEqual(output, 0.0) + assert output != 0.0 registry = HookRegistry.check_if_exists_or_initialize(self.model.blocks[1]) registry.register_hook(SkipLayerHook(skip_layer=True), "skip_layer_hook") output = self.model(input).mean().detach().cpu().item() - self.assertNotEqual(output, 0.0) + assert output != 0.0 def test_invocation_order_stateful_first(self): registry = HookRegistry.check_if_exists_or_initialize(self.model) @@ -278,7 +275,7 @@ def test_invocation_order_stateful_first(self): .replace(" ", "") .replace("\n", "") ) - self.assertEqual(output, expected_invocation_order_log) + assert output == expected_invocation_order_log registry.remove_hook("add_hook") with CaptureLogger(logger) as cap_logger: @@ -289,7 +286,7 @@ def test_invocation_order_stateful_first(self): .replace(" ", "") .replace("\n", "") ) - self.assertEqual(output, expected_invocation_order_log) + assert output == expected_invocation_order_log def test_invocation_order_stateful_middle(self): registry = HookRegistry.check_if_exists_or_initialize(self.model) @@ -316,7 +313,7 @@ def test_invocation_order_stateful_middle(self): .replace(" ", "") .replace("\n", "") ) - self.assertEqual(output, expected_invocation_order_log) + assert output == expected_invocation_order_log registry.remove_hook("add_hook") with CaptureLogger(logger) as cap_logger: @@ -327,7 +324,7 @@ def test_invocation_order_stateful_middle(self): .replace(" ", "") .replace("\n", "") ) - self.assertEqual(output, expected_invocation_order_log) + assert output == expected_invocation_order_log registry.remove_hook("add_hook_2") with CaptureLogger(logger) as cap_logger: @@ -336,7 +333,7 @@ def test_invocation_order_stateful_middle(self): expected_invocation_order_log = ( ("MultiplyHook pre_forward\nMultiplyHook post_forward\n").replace(" ", "").replace("\n", "") ) - self.assertEqual(output, expected_invocation_order_log) + assert output == expected_invocation_order_log def test_invocation_order_stateful_last(self): registry = HookRegistry.check_if_exists_or_initialize(self.model) @@ -363,7 +360,7 @@ def test_invocation_order_stateful_last(self): .replace(" ", "") .replace("\n", "") ) - self.assertEqual(output, expected_invocation_order_log) + assert output == expected_invocation_order_log registry.remove_hook("add_hook") with CaptureLogger(logger) as cap_logger: @@ -374,4 +371,4 @@ def test_invocation_order_stateful_last(self): .replace(" ", "") .replace("\n", "") ) - self.assertEqual(output, expected_invocation_order_log) + assert output == expected_invocation_order_log