Skip to content

Commit 1e1058a

Browse files
Move the top level 'tests/' into src/maxdiffusion/tests/ as legacy tests (#322)
* Move the top level 'tests/' into `src/maxdiffusion/tests/` as a legacy * Format by pyink
1 parent 9622341 commit 1e1058a

25 files changed

Lines changed: 2008 additions & 2025 deletions
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
"""
2+
Copyright 2024 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""

tests/conftest.py renamed to src/maxdiffusion/tests/legacy_hf_tests/conftest.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,14 @@
3131

3232

3333
def pytest_addoption(parser):
  """Pytest hook: delegate command-line option registration to maxdiffusion's shared helper."""
  # Local import — NOTE(review): presumably deferred so collecting this conftest
  # does not import the maxdiffusion package eagerly; confirm.
  from maxdiffusion.utils.testing_utils import pytest_addoption_shared

  pytest_addoption_shared(parser)
3737

3838

3939
def pytest_terminal_summary(terminalreporter):
  """Pytest hook: when `--make-reports` is set, hand the terminal reporter to
  maxdiffusion's shared summary writer under that report id."""
  from maxdiffusion.utils.testing_utils import pytest_terminal_summary_main

  report_id = terminalreporter.config.getoption("--make-reports")
  if not report_id:
    # Option not given: keep pytest's default summary untouched.
    return
  pytest_terminal_summary_main(terminalreporter, id=report_id)
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
"""
2+
Copyright 2024 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
"""
2+
Copyright 2024 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
import inspect
18+
19+
from maxdiffusion.utils import is_flax_available
20+
from maxdiffusion.utils.testing_utils import require_flax
21+
22+
23+
if is_flax_available():
24+
import jax
25+
26+
27+
@require_flax
class FlaxModelTesterMixin:
  """Shared forward-pass checks for Flax model test cases.

  Subclasses provide ``model_class`` and
  ``prepare_init_args_and_inputs_for_common()`` and mix this in together with
  ``unittest.TestCase`` (the assertion methods come from there).
  """

  def _init_and_check_output_shape(self, init_dict, inputs_dict):
    # Shared body of test_output / test_forward_with_norm_groups (previously
    # duplicated): build the model, run one forward pass, and assert the output
    # has the same shape as the input sample.
    model = self.model_class(**init_dict)
    variables = model.init(inputs_dict["prng_key"], inputs_dict["sample"])
    jax.lax.stop_gradient(variables)

    output = model.apply(variables, inputs_dict["sample"])

    if isinstance(output, dict):
      output = output.sample

    self.assertIsNotNone(output)
    expected_shape = inputs_dict["sample"].shape
    self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

  def test_output(self):
    """Forward pass with the default test config preserves the sample shape."""
    init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
    self._init_and_check_output_shape(init_dict, inputs_dict)

  def test_forward_with_norm_groups(self):
    """Forward pass with a non-default group-norm config preserves the sample shape."""
    init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

    # Override the normalization configuration before building the model.
    init_dict["norm_num_groups"] = 16
    init_dict["block_out_channels"] = (16, 32)

    self._init_and_check_output_shape(init_dict, inputs_dict)

  def test_deprecated_kwargs(self):
    """`**kwargs` in __init__ and `_deprecated_kwargs` must be declared together."""
    has_kwarg_in_model_class = "kwargs" in inspect.signature(self.model_class.__init__).parameters
    has_deprecated_kwarg = len(self.model_class._deprecated_kwargs) > 0

    if has_kwarg_in_model_class and not has_deprecated_kwarg:
      raise ValueError(
          f"{self.model_class} has `**kwargs` in its __init__ method but has not defined any deprecated kwargs"
          " under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if there are"
          " no deprecated arguments or add the deprecated argument with `_deprecated_kwargs ="
          " [<deprecated_argument>]`"
      )

    if not has_kwarg_in_model_class and has_deprecated_kwarg:
      raise ValueError(
          f"{self.model_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated kwargs"
          " under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs` argument to"
          f" {self.model_class}.__init__ if there are deprecated arguments or remove the deprecated argument"
          " from `_deprecated_kwargs = [<deprecated_argument>]`"
      )
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
"""
2+
Copyright 2024 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
import gc
18+
import unittest
19+
20+
from maxdiffusion import FlaxUNet2DConditionModel
21+
from maxdiffusion.utils import is_flax_available
22+
from maxdiffusion.utils.testing_utils import load_hf_numpy, require_flax, slow
23+
from parameterized import parameterized
24+
25+
26+
if is_flax_available():
27+
import jax
28+
import jax.numpy as jnp
29+
30+
31+
@slow
@require_flax
class FlaxUNet2DConditionModelIntegrationTests(unittest.TestCase):
  """Numerical parity tests: Flax (bfloat16) UNet outputs vs reference slices
  captured from the torch (float16) implementation."""

  def get_file_format(self, seed, shape):
    # File-name layout of the hosted reference noise fixtures.
    return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"

  def tearDown(self):
    # clean up the VRAM after each test
    super().tearDown()
    gc.collect()

  def get_latents(self, seed=0, shape=(4, 4, 64, 64), fp16=False):
    """Load the deterministic noise tensor for `seed`/`shape` via load_hf_numpy."""
    dtype = jnp.bfloat16 if fp16 else jnp.float32
    image = jnp.array(load_hf_numpy(self.get_file_format(seed, shape)), dtype=dtype)
    return image

  def get_unet_model(self, fp16=False, model_id="CompVis/stable-diffusion-v1-4"):
    """Fetch the pretrained UNet; fp16 runs use the "bf16" checkpoint revision."""
    dtype = jnp.bfloat16 if fp16 else jnp.float32
    revision = "bf16" if fp16 else None

    model, params = FlaxUNet2DConditionModel.from_pretrained(model_id, subfolder="unet", dtype=dtype, revision=revision)
    return model, params

  def get_encoder_hidden_states(self, seed=0, shape=(4, 77, 768), fp16=False):
    """Load deterministic text-encoder hidden states via load_hf_numpy."""
    dtype = jnp.bfloat16 if fp16 else jnp.float32
    hidden_states = jnp.array(load_hf_numpy(self.get_file_format(seed, shape)), dtype=dtype)
    return hidden_states

  def _assert_unet_matches_reference(self, model_id, seed, timestep, expected_slice, latents_shape, hidden_shape):
    # Shared body of the two parity tests below (previously duplicated): run one
    # fp16 forward pass and compare a fixed corner slice of the output sample
    # against the recorded reference values.
    model, params = self.get_unet_model(model_id=model_id, fp16=True)
    latents = self.get_latents(seed, shape=latents_shape, fp16=True)
    encoder_hidden_states = self.get_encoder_hidden_states(seed, shape=hidden_shape, fp16=True)

    sample = model.apply(
        {"params": params},
        latents,
        jnp.array(timestep, dtype=jnp.int32),
        encoder_hidden_states=encoder_hidden_states,
    ).sample

    assert sample.shape == latents.shape

    output_slice = jnp.asarray(jax.device_get((sample[-1, -2:, -2:, :2].flatten())), dtype=jnp.float32)
    expected_output_slice = jnp.array(expected_slice, dtype=jnp.float32)

    # Found torch (float16) and flax (bfloat16) outputs to be within this tolerance, on the same hardware
    assert jnp.allclose(output_slice, expected_output_slice, atol=1e-2)

  @parameterized.expand(
      [
          # fmt: off
          [83, 4, [-0.2323, -0.1304, 0.0813, -0.3093, -0.0919, -0.1571, -0.1125, -0.5806]],
          [17, 0.55, [-0.0831, -0.2443, 0.0901, -0.0919, 0.3396, 0.0103, -0.3743, 0.0701]],
          [8, 0.89, [-0.4863, 0.0859, 0.0875, -0.1658, 0.9199, -0.0114, 0.4839, 0.4639]],
          [3, 1000, [-0.5649, 0.2402, -0.5518, 0.1248, 1.1328, -0.2443, -0.0325, -1.0078]],
          # fmt: on
      ]
  )
  def test_compvis_sd_v1_4_flax_vs_torch_fp16(self, seed, timestep, expected_slice):
    self._assert_unet_matches_reference(
        "CompVis/stable-diffusion-v1-4",
        seed,
        timestep,
        expected_slice,
        latents_shape=(4, 4, 64, 64),
        hidden_shape=(4, 77, 768),
    )

  @parameterized.expand(
      [
          # fmt: off
          [83, 4, [0.1514, 0.0807, 0.1624, 0.1016, -0.1896, 0.0263, 0.0677, 0.2310]],
          [17, 0.55, [0.1164, -0.0216, 0.0170, 0.1589, -0.3120, 0.1005, -0.0581, -0.1458]],
          [8, 0.89, [-0.1758, -0.0169, 0.1004, -0.1411, 0.1312, 0.1103, -0.1996, 0.2139]],
          [3, 1000, [0.1214, 0.0352, -0.0731, -0.1562, -0.0994, -0.0906, -0.2340, -0.0539]],
          # fmt: on
      ]
  )
  def test_stabilityai_sd_v2_flax_vs_torch_fp16(self, seed, timestep, expected_slice):
    self._assert_unet_matches_reference(
        "stabilityai/stable-diffusion-2",
        seed,
        timestep,
        expected_slice,
        latents_shape=(4, 4, 96, 96),
        hidden_shape=(4, 77, 1024),
    )
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
"""
2+
Copyright 2024 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
import unittest
18+
19+
from maxdiffusion import FlaxAutoencoderKL
20+
from maxdiffusion.utils import is_flax_available
21+
from maxdiffusion.utils.testing_utils import require_flax
22+
23+
from .test_modeling_common_flax import FlaxModelTesterMixin
24+
25+
26+
if is_flax_available():
27+
import jax
28+
29+
30+
@require_flax
class FlaxAutoencoderKLTests(FlaxModelTesterMixin, unittest.TestCase):
  """Runs the shared FlaxModelTesterMixin checks against FlaxAutoencoderKL."""

  model_class = FlaxAutoencoderKL

  @property
  def dummy_input(self):
    """A fixed random batch of shape (4, 3, 32, 32) plus the PRNG key that made it."""
    key = jax.random.PRNGKey(0)
    shape = (4, 3) + (32, 32)  # (batch, channels) + spatial sizes
    sample = jax.random.uniform(key, shape)
    return {"sample": sample, "prng_key": key}

  def prepare_init_args_and_inputs_for_common(self):
    """Return (constructor kwargs for a small test VAE, forward-pass inputs)."""
    config = dict(
        block_out_channels=[32, 64],
        in_channels=3,
        out_channels=3,
        down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
        up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
        latent_channels=4,
    )
    return config, self.dummy_input
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
"""
2+
Copyright 2024 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""

tests/schedulers/rf_scheduler_test_ref/step_00_noisy_input.npy renamed to src/maxdiffusion/tests/legacy_hf_tests/schedulers/rf_scheduler_test_ref/step_00_noisy_input.npy

File renamed without changes.

tests/schedulers/rf_scheduler_test_ref/step_01.npy renamed to src/maxdiffusion/tests/legacy_hf_tests/schedulers/rf_scheduler_test_ref/step_01.npy

File renamed without changes.

tests/schedulers/rf_scheduler_test_ref/step_02.npy renamed to src/maxdiffusion/tests/legacy_hf_tests/schedulers/rf_scheduler_test_ref/step_02.npy

File renamed without changes.

0 commit comments

Comments
 (0)