Skip to content

Commit d9b73ff

Browse files
committed
update
1 parent dcd6026 commit d9b73ff

File tree

1 file changed

+102
-0
lines changed

1 file changed

+102
-0
lines changed
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
# coding=utf-8
2+
# Copyright 2025 HuggingFace Inc.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import os
17+
18+
import pytest
19+
import torch
20+
import torch.multiprocessing as mp
21+
22+
from diffusers.models._modeling_parallel import ContextParallelConfig
23+
24+
from ...testing_utils import (
25+
is_context_parallel,
26+
require_torch_multi_accelerator,
27+
)
28+
29+
30+
def _context_parallel_worker(rank, world_size, model_class, init_dict, cp_dict, inputs_dict, result_queue):
31+
try:
32+
# Setup distributed environment
33+
os.environ["MASTER_ADDR"] = "localhost"
34+
os.environ["MASTER_PORT"] = "12355"
35+
36+
torch.distributed.init_process_group(
37+
backend="nccl",
38+
init_method="env://",
39+
world_size=world_size,
40+
rank=rank,
41+
)
42+
torch.cuda.set_device(rank)
43+
device = torch.device(f"cuda:{rank}")
44+
45+
model = model_class(**init_dict)
46+
model.to(device)
47+
model.eval()
48+
49+
inputs_on_device = {}
50+
for key, value in inputs_dict.items():
51+
if isinstance(value, torch.Tensor):
52+
inputs_on_device[key] = value.to(device)
53+
else:
54+
inputs_on_device[key] = value
55+
56+
cp_config = ContextParallelConfig(**cp_dict)
57+
model.enable_parallelism(config=cp_config)
58+
59+
with torch.no_grad():
60+
output = model(**inputs_on_device, return_dict=False)[0]
61+
62+
if rank == 0:
63+
result_queue.put(("success", output.shape))
64+
65+
except Exception as e:
66+
if rank == 0:
67+
result_queue.put(("error", str(e)))
68+
finally:
69+
if torch.distributed.is_initialized():
70+
torch.distributed.destroy_process_group()
71+
72+
73+
@is_context_parallel
74+
@require_torch_multi_accelerator
75+
class ContextParallelTesterMixin:
76+
base_precision = 1e-3
77+
78+
@pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
79+
def test_context_parallel_inference(self, cp_type):
80+
if not torch.distributed.is_available():
81+
pytest.skip("torch.distributed is not available.")
82+
83+
if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
84+
pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
85+
86+
world_size = 2
87+
init_dict = self.get_init_dict()
88+
inputs_dict = self.get_dummy_inputs()
89+
cp_dict = {cp_type: world_size}
90+
91+
ctx = mp.get_context("spawn")
92+
result_queue = ctx.Queue()
93+
94+
mp.spawn(
95+
_context_parallel_worker,
96+
args=(world_size, self.model_class, init_dict, cp_dict, inputs_dict, result_queue),
97+
nprocs=world_size,
98+
join=True,
99+
)
100+
101+
status, result = result_queue.get(timeout=60)
102+
assert status == "success", f"Context parallel inference failed: {result}"

0 commit comments

Comments
 (0)