
Commit eae7543 (1 parent: d08e0bb)

Commit message: update

File tree: 10 files changed, +140 −383 lines

tests/models/testing_utils/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -1,9 +1,10 @@
-from .attention import AttentionTesterMixin, ContextParallelTesterMixin
+from .attention import AttentionTesterMixin
 from .common import BaseModelTesterConfig, ModelTesterMixin
 from .compile import TorchCompileTesterMixin
 from .ip_adapter import IPAdapterTesterMixin
 from .lora import LoraHotSwappingForModelTesterMixin, LoraTesterMixin
 from .memory import CPUOffloadTesterMixin, GroupOffloadTesterMixin, LayerwiseCastingTesterMixin, MemoryTesterMixin
+from .parallelism import ContextParallelTesterMixin
 from .quantization import (
     BitsAndBytesTesterMixin,
     GGUFTesterMixin,
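ContextParallelTesterMixin is now re-exported from the new parallelism module rather than from attention, while test classes keep composing mixins with a config class (see the docstring context in common.py further down). Below is a self-contained sketch of that mixin-composition pattern, written against a toy model rather than diffusers code; TinyModel, TinyModelTestConfig, and OutputTesterMixin are illustrative names only, not part of this commit.

import torch


class TinyModel(torch.nn.Module):
    # Stand-in for a diffusers model: returns a dict by default, a tuple with return_dict=False.
    def __init__(self, hidden=4):
        super().__init__()
        self.proj = torch.nn.Linear(hidden, hidden)

    def forward(self, x, return_dict=True):
        out = self.proj(x)
        return {"sample": out} if return_dict else (out,)


class TinyModelTestConfig:
    # Config class: supplies model_class, init kwargs, and dummy inputs to every mixin.
    model_class = TinyModel

    def get_init_dict(self):
        return {"hidden": 4}

    def get_dummy_inputs(self):
        return {"x": torch.randn(2, 4)}


class OutputTesterMixin:
    # Each mixin contributes test_* methods that rely only on the config-class hooks.
    def test_output_not_none(self):
        model = self.model_class(**self.get_init_dict()).eval()
        with torch.no_grad():
            out = model(**self.get_dummy_inputs(), return_dict=False)[0]
        assert out is not None


class TestTinyModel(TinyModelTestConfig, OutputTesterMixin):
    # pytest collects the inherited test_* methods and runs them against TinyModel.
    pass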

tests/models/testing_utils/attention.py

Lines changed: 3 additions & 92 deletions
@@ -13,13 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import os
-
 import pytest
 import torch
-import torch.multiprocessing as mp
 
-from diffusers.models._modeling_parallel import ContextParallelConfig
 from diffusers.models.attention import AttentionModuleMixin
 from diffusers.models.attention_processor import (
     AttnProcessor,
@@ -28,8 +24,6 @@
 from ...testing_utils import (
     assert_tensors_close,
     is_attention,
-    is_context_parallel,
-    require_torch_multi_accelerator,
     torch_device,
 )
 
@@ -71,9 +65,7 @@ def test_fuse_unfuse_qkv_projections(self):
 
         # Get output before fusion
         with torch.no_grad():
-            output_before_fusion = model(**inputs_dict)
-            if isinstance(output_before_fusion, dict):
-                output_before_fusion = output_before_fusion.to_tuple()[0]
+            output_before_fusion = model(**inputs_dict, return_dict=False)[0]
 
         # Fuse projections
         model.fuse_qkv_projections()
@@ -90,9 +82,7 @@ def test_fuse_unfuse_qkv_projections(self):
         if has_fused_projections:
             # Get output after fusion
             with torch.no_grad():
-                output_after_fusion = model(**inputs_dict)
-                if isinstance(output_after_fusion, dict):
-                    output_after_fusion = output_after_fusion.to_tuple()[0]
+                output_after_fusion = model(**inputs_dict, return_dict=False)[0]
 
             # Verify outputs match
             assert_tensors_close(
@@ -115,9 +105,7 @@ def test_fuse_unfuse_qkv_projections(self):
 
         # Get output after unfusion
         with torch.no_grad():
-            output_after_unfusion = model(**inputs_dict)
-            if isinstance(output_after_unfusion, dict):
-                output_after_unfusion = output_after_unfusion.to_tuple()[0]
+            output_after_unfusion = model(**inputs_dict, return_dict=False)[0]
 
         # Verify outputs still match
         assert_tensors_close(
@@ -195,80 +183,3 @@ def test_attention_processor_count_mismatch_raises_error(self):
             model.set_attn_processor(wrong_processors)
 
         assert "number of processors" in str(exc_info.value).lower(), "Error should mention processor count mismatch"
-
-
-def _context_parallel_worker(rank, world_size, model_class, init_dict, cp_dict, inputs_dict, result_queue):
-    try:
-        # Setup distributed environment
-        os.environ["MASTER_ADDR"] = "localhost"
-        os.environ["MASTER_PORT"] = "12355"
-
-        torch.distributed.init_process_group(
-            backend="nccl",
-            init_method="env://",
-            world_size=world_size,
-            rank=rank,
-        )
-        torch.cuda.set_device(rank)
-        device = torch.device(f"cuda:{rank}")
-
-        model = model_class(**init_dict)
-        model.to(device)
-        model.eval()
-
-        inputs_on_device = {}
-        for key, value in inputs_dict.items():
-            if isinstance(value, torch.Tensor):
-                inputs_on_device[key] = value.to(device)
-            else:
-                inputs_on_device[key] = value
-
-        cp_config = ContextParallelConfig(**cp_dict)
-        model.enable_parallelism(config=cp_config)
-
-        with torch.no_grad():
-            output = model(**inputs_on_device)
-            if isinstance(output, dict):
-                output = output.to_tuple()[0]
-
-        if rank == 0:
-            result_queue.put(("success", output.shape))
-
-    except Exception as e:
-        if rank == 0:
-            result_queue.put(("error", str(e)))
-    finally:
-        if torch.distributed.is_initialized():
-            torch.distributed.destroy_process_group()
-
-
-@is_context_parallel
-@require_torch_multi_accelerator
-class ContextParallelTesterMixin:
-    base_precision = 1e-3
-
-    @pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
-    def test_context_parallel_inference(self, cp_type):
-        if not torch.distributed.is_available():
-            pytest.skip("torch.distributed is not available.")
-
-        if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
-            pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
-
-        world_size = 2
-        init_dict = self.get_init_dict()
-        inputs_dict = self.get_dummy_inputs()
-        cp_dict = {cp_type: world_size}
-
-        ctx = mp.get_context("spawn")
-        result_queue = ctx.Queue()
-
-        mp.spawn(
-            _context_parallel_worker,
-            args=(world_size, self.model_class, init_dict, cp_dict, inputs_dict, result_queue),
-            nprocs=world_size,
-            join=True,
-        )
-
-        status, result = result_queue.get(timeout=60)
-        assert status == "success", f"Context parallel inference failed: {result}"
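The removed _context_parallel_worker and ContextParallelTesterMixin line up with the new `from .parallelism import ContextParallelTesterMixin` re-export in __init__.py above, and the remaining call sites switch from the isinstance(output, dict) / to_tuple() handling to return_dict=False indexing. A minimal sketch of that calling convention, using AutoencoderKL purely as an example model (this commit touches only the shared test mixins, not this model):

import torch

from diffusers import AutoencoderKL

model = AutoencoderKL()  # randomly initialized default config, for illustration only
sample = torch.randn(1, 3, 32, 32)

with torch.no_grad():
    # return_dict=False makes the forward pass return a plain tuple, so [0]
    # selects the primary output tensor without any dict/to_tuple() handling.
    output = model(sample, return_dict=False)[0]

print(output.shape)  # decoded sample as a plain tensor, not an output dataclass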

tests/models/testing_utils/common.py

Lines changed: 22 additions & 48 deletions
@@ -259,7 +259,7 @@ class TestMyModel(MyModelTestConfig, ModelTesterMixin):
         pass
     """
 
-    def test_from_save_pretrained(self, tmp_path, atol=5e-5, rtol=0):
+    def test_from_save_pretrained(self, tmp_path, atol=5e-5, rtol=5e-5):
         torch.manual_seed(0)
         model = self.model_class(**self.get_init_dict())
         model.to(torch_device)
@@ -278,15 +278,8 @@ def test_from_save_pretrained(self, tmp_path, atol=5e-5, rtol=0):
         )
 
         with torch.no_grad():
-            image = model(**self.get_dummy_inputs())
-
-            if isinstance(image, dict):
-                image = image.to_tuple()[0]
-
-            new_image = new_model(**self.get_dummy_inputs())
-
-            if isinstance(new_image, dict):
-                new_image = new_image.to_tuple()[0]
+            image = model(**self.get_dummy_inputs(), return_dict=False)[0]
+            new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0]
 
         assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.")
 
@@ -308,14 +301,8 @@ def test_from_save_pretrained_variant(self, tmp_path, atol=5e-5, rtol=0):
         new_model.to(torch_device)
 
         with torch.no_grad():
-            image = model(**self.get_dummy_inputs())
-            if isinstance(image, dict):
-                image = image.to_tuple()[0]
-
-            new_image = new_model(**self.get_dummy_inputs())
-
-            if isinstance(new_image, dict):
-                new_image = new_image.to_tuple()[0]
+            image = model(**self.get_dummy_inputs(), return_dict=False)[0]
+            new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0]
 
         assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.")
 
@@ -343,13 +330,8 @@ def test_determinism(self, atol=1e-5, rtol=0):
         model.eval()
 
         with torch.no_grad():
-            first = model(**self.get_dummy_inputs())
-            if isinstance(first, dict):
-                first = first.to_tuple()[0]
-
-            second = model(**self.get_dummy_inputs())
-            if isinstance(second, dict):
-                second = second.to_tuple()[0]
+            first = model(**self.get_dummy_inputs(), return_dict=False)[0]
+            second = model(**self.get_dummy_inputs(), return_dict=False)[0]
 
         # Filter out NaN values before comparison
         first_flat = first.flatten()
@@ -369,10 +351,7 @@ def test_output(self, expected_output_shape=None):
 
         inputs_dict = self.get_dummy_inputs()
         with torch.no_grad():
-            output = model(**inputs_dict)
-
-            if isinstance(output, dict):
-                output = output.to_tuple()[0]
+            output = model(**inputs_dict, return_dict=False)[0]
 
         assert output is not None, "Model output is None"
         assert output[0].shape == expected_output_shape or self.output_shape, (
@@ -501,13 +480,8 @@ def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
             assert param.data.dtype == dtype
 
         with torch.no_grad():
-            output = model(**self.get_dummy_inputs())
-            if isinstance(output, dict):
-                output = output.to_tuple()[0]
-
-            output_loaded = model_loaded(**self.get_dummy_inputs())
-            if isinstance(output_loaded, dict):
-                output_loaded = output_loaded.to_tuple()[0]
+            output = model(**self.get_dummy_inputs(), return_dict=False)[0]
+            output_loaded = model_loaded(**self.get_dummy_inputs(), return_dict=False)[0]
 
         assert_tensors_close(output, output_loaded, atol=1e-4, rtol=0, msg=f"Loaded model output differs for {dtype}")
 
@@ -519,7 +493,7 @@ def test_sharded_checkpoints(self, tmp_path):
         model = self.model_class(**config).eval()
         model = model.to(torch_device)
 
-        base_output = model(**inputs_dict)
+        base_output = model(**inputs_dict, return_dict=False)[0]
 
         model_size = compute_module_persistent_sizes(model)[""]
         max_shard_size = int((model_size * 0.75) / (2**10))  # Convert to KB as these test models are small
@@ -539,10 +513,10 @@
 
         torch.manual_seed(0)
         inputs_dict_new = self.get_dummy_inputs()
-        new_output = new_model(**inputs_dict_new)
+        new_output = new_model(**inputs_dict_new, return_dict=False)[0]
 
         assert_tensors_close(
-            base_output[0], new_output[0], atol=1e-5, rtol=0, msg="Output should match after sharded save/load"
+            base_output, new_output, atol=1e-5, rtol=0, msg="Output should match after sharded save/load"
         )
 
     @require_accelerator
@@ -553,7 +527,7 @@ def test_sharded_checkpoints_with_variant(self, tmp_path):
         model = self.model_class(**config).eval()
         model = model.to(torch_device)
 
-        base_output = model(**inputs_dict)
+        base_output = model(**inputs_dict, return_dict=False)[0]
 
         model_size = compute_module_persistent_sizes(model)[""]
         max_shard_size = int((model_size * 0.75) / (2**10))  # Convert to KB as these test models are small
@@ -578,10 +552,10 @@
 
         torch.manual_seed(0)
         inputs_dict_new = self.get_dummy_inputs()
-        new_output = new_model(**inputs_dict_new)
+        new_output = new_model(**inputs_dict_new, return_dict=False)[0]
 
         assert_tensors_close(
-            base_output[0], new_output[0], atol=1e-5, rtol=0, msg="Output should match after variant sharded save/load"
+            base_output, new_output, atol=1e-5, rtol=0, msg="Output should match after variant sharded save/load"
         )
 
     def test_sharded_checkpoints_with_parallel_loading(self, tmp_path):
@@ -593,7 +567,7 @@ def test_sharded_checkpoints_with_parallel_loading(self, tmp_path):
            model = self.model_class(**config).eval()
            model = model.to(torch_device)
 
-            base_output = model(**inputs_dict)
+            base_output = model(**inputs_dict, return_dict=False)[0]
 
            model_size = compute_module_persistent_sizes(model)[""]
            max_shard_size = int((model_size * 0.75) / (2**10))  # Convert to KB as these test models are small
@@ -628,10 +602,10 @@
 
            torch.manual_seed(0)
            inputs_dict_parallel = self.get_dummy_inputs()
-            output_parallel = model_parallel(**inputs_dict_parallel)
+            output_parallel = model_parallel(**inputs_dict_parallel, return_dict=False)[0]
 
            assert_tensors_close(
-                base_output[0], output_parallel[0], atol=1e-5, rtol=0, msg="Output should match with parallel loading"
+                base_output, output_parallel, atol=1e-5, rtol=0, msg="Output should match with parallel loading"
            )
 
        finally:
@@ -652,7 +626,7 @@ def test_model_parallelism(self, tmp_path):
        model = model.to(torch_device)
 
        torch.manual_seed(0)
-        base_output = model(**inputs_dict)
+        base_output = model(**inputs_dict, return_dict=False)[0]
 
        model_size = compute_module_sizes(model)[""]
        max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents]
@@ -668,8 +642,8 @@
        check_device_map_is_respected(new_model, new_model.hf_device_map)
 
        torch.manual_seed(0)
-        new_output = new_model(**inputs_dict)
+        new_output = new_model(**inputs_dict, return_dict=False)[0]
 
        assert_tensors_close(
-            base_output[0], new_output[0], atol=1e-5, rtol=0, msg="Output should match with model parallelism"
+            base_output, new_output, atol=1e-5, rtol=0, msg="Output should match with model parallelism"
        )
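Throughout these hunks the outputs are now compared directly (base_output vs. new_output) instead of indexing [0] on both sides, and test_from_save_pretrained gains a small relative tolerance (rtol=5e-5) alongside atol=5e-5. As a hedged sketch of how a helper with the assert_tensors_close call signature used here might behave (the real helper lives in tests/testing_utils and may differ in details):

import torch


def assert_tensors_close(actual, expected, atol=1e-5, rtol=0.0, msg=""):
    # Elementwise check: |actual - expected| <= atol + rtol * |expected|,
    # raising an AssertionError that carries `msg` on failure.
    torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol, msg=msg)


a = torch.zeros(4)
b = torch.full((4,), 1e-6)

# Passes: the 1e-6 difference is within atol=5e-5 (plus a negligible rtol term).
assert_tensors_close(a, b, atol=5e-5, rtol=5e-5)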
